"""
Computation of apparent and intermediate places.

This is based on Helmut Lenhard's apfscompute.f and for now blindly follows
its prescriptions; this is to be sure we are not inadvertently "updating"
his recipes without knowing what we do.
"""

import datetime
from math import sin, cos, tan, sqrt, pi, atan, atan2

from astropy.coordinates.erfa_astrom import erfa
from jplephem.spk import SPK
import numpy
from numpy import linalg

from gavo import api
from gavo.stc import spherc
from gavo.utils import DEG, mathtricks

# kilometres per astronomical unit; capplc divides the ephemeris vectors
# by this to get them in AU.
km_AU = 1.495978701e8


# The apfs_new RD (presumably the DaCHS resource descriptor of this
# service); it is only used here to locate the ephemeris file.
RD = api.getRD("apfs/res/apfs_new")
# The DE430 SPK kernel, opened once at import time and kept open for
# the lifetime of the module.
JPLEPH = SPK.open(RD.getAbsPath("data/de430.bsp"))

def normalize(v):
		"""returns v divided by its norm (as computed by numpy.linalg.norm).

		No special handling is done for vectors of zero length.
		"""
		length = linalg.norm(v)
		return v/length


def light_deflection(earth_radvec, object_pos, heliocentric_distance):
		"""returns object_pos with a light deflection correction added.

		earth_radvec and object_pos are combined through a dot product and
		a linear combination, so they must be numpy vectors of equal length.
		capplc passes in the unit vector of the heliocentric earth position,
		the unit vector towards the object, and the length of the heliocentric
		earth vector in AU.

		The result is not re-normalised.
		"""
		# presumably GM_sun/c^2 expressed in AU -- taken over from the
		# original FORTRAN, TODO confirm
		xmc2 = 9.8731e-09
		cos_sep = numpy.dot(earth_radvec, object_pos)

		strength = 2*xmc2/heliocentric_distance
		bending = strength*(earth_radvec-cos_sep*object_pos)/(1+cos_sep)
		return object_pos+bending


def aberration(earth_veloc, object_pos):
		"""returns the aberrated direction for object_pos as seen by an
		observer moving with earth_veloc.

		earth_veloc is multiplied by a fixed factor to bring it to units of
		c (capplc passes it in as AU/d); it must be a numpy vector, as must
		object_pos.  The argument itself is not modified.
		"""
		# velocity in units of c
		beta_vec = earth_veloc*0.00577551833
		beta = linalg.norm(beta_vec)
		inv_gamma = sqrt(1-beta**2)

		proj = numpy.dot(beta_vec, object_pos)
		weight = 1+proj/(1+inv_gamma)

		numerator = inv_gamma*object_pos+weight*beta_vec
		return numerator/(1+proj)


def orbitDelta(orbit, basicMotion, toEpoch):
	"""returns Δalpha, Δdelta in degrees for the apparent positions of a star
	on a given orbit at toEpoch (which is JD).

	orbit is a dictionary with the keys

	* period -- years
	* tper -- time of periastron passage as julian year
	* ecc -- eccentricity
	* axis -- major axis [arcsec]
	* incl -- inclination [deg]
	* node -- longitude of ascending node [deg] (gom)
	* omega -- argument of periastron [deg] (om)

	basicMotion is a dictionary as in capplc; only its raj2000 and dej2000
	items are used here.

	This raises a plain Exception if the iteration for the anomaly has not
	converged after 80 rounds.
	"""
	# angular frequency in rad/yr
	circfreq = 2*pi/orbit["period"]  # pmu in Helmut
	# angle travelled since tper at constant angular speed
	arg = circfreq*(api.jdnToJYear(toEpoch)-orbit["tper"])  # pmea in Helmut

	# internally, all angles are in rad, except for the axis, which remains
	# in arcsec
	node = orbit["node"]*DEG
	omega = orbit["omega"]*DEG
	incl = orbit["incl"]*DEG
	# NOTE(review): this divides by sin(incl), so incl=0 will fail here.
	axis_sin = orbit["axis"]/sin(incl)
	raj2000 = basicMotion["raj2000"]*DEG
	dej2000 = basicMotion["dej2000"]*DEG

	# I frankly have not researched what this code does; it's just taken
	# from Helmut's FORTRAN code.
	# (It has the form of a fixed-point iteration ge = arg+ecc*sin(ge),
	# i.e., Kepler's equation with arg as mean and ge as eccentric anomaly;
	# the atan2 expression provides the starting value.)
	ge = atan2(sin(arg), cos(arg)-orbit["ecc"])

	for niter in range(80):
		ea = ge
		ge = arg + orbit["ecc"]*sin(ea)
		if abs(ge-ea)<1e-10:
			break
	else:
		# only reached when the loop ran through without hitting the break
		raise Exception("No convergence")

	radius = axis_sin*(1-orbit["ecc"]*cos(ge))
	# truano is presumably the true anomaly
	truano= 2*atan(sqrt((1+orbit["ecc"])/(1-orbit["ecc"])
		)*tan(ge*0.5))
	# theta and rho look like position angle and separation on the sky
	# (TODO confirm against the FORTRAN original)
	theta = atan2(sin(truano+omega)*cos(incl), cos(truano+omega))+node
	rho = radius*cos(truano+omega)/cos(theta-node)

	# The corrections come out of the math for equinox J2000 and are in the
	# tangential plane ("with cos(delta)").  Fix this up.
	# NOTE(review): dividing by 3.6e6 to get to degrees suggests rho (and
	# hence axis) is in mas here rather than in arcsec as the docstring
	# says; capplc multiplies axis by 3.6e6 before calling us.  Confirm
	# the intended unit.
	deltaRA = rho*sin(theta)/3.6e6*DEG/cos(dej2000)
	deltaDe = rho*cos(theta)/3.6e6*DEG

	precMat = spherc.getPrecMatrix(
		# our orbit elements are given such that our deltas are for J2000
		datetime.datetime(2000, 1, 1),
		api.jdnToDateTime(toEpoch),
		spherc.prec_IAU1976)

	# precess the J2000 position itself to toEpoch
	posVec = mathtricks.spherToCart(raj2000, dej2000)
	precessedPos = numpy.dot(precMat, posVec)
	raNew, decNew = mathtricks.cartToSpher(precessedPos)

	# This repeats a bit of code we have elsewhere; essentially, it's
	# applying proper motions with cartesian vectors.  But the
	# equivalent astropy functions are hidden rather deeply, and I didn't
	# want extra moving parts.
	# dirRA and dirDec are the derivatives of the cartesian unit vector
	# by RA and Dec at the J2000 position; note that dirRA has length
	# cos(dej2000) rather than 1.
	dirRA = [
		-cos(dej2000)*sin(raj2000),
		cos(dej2000)*cos(raj2000),
		0]
	dirDec = [
		-sin(dej2000)*cos(raj2000),
		-sin(dej2000)*sin(raj2000),
		cos(dej2000)]

	# the displacement as a cartesian vector, precessed to toEpoch
	motion = (numpy.dot(precMat, dirRA)*deltaRA
		+ numpy.dot(precMat, dirDec)*deltaDe)
	dirRANew = [
		-cos(decNew)*sin(raNew),
		cos(decNew)*cos(raNew),
		0]
	# dirRANew has length cos(decNew); one factor of cos(decNew) undoes
	# that, the other turns the tangential displacement into a difference
	# in the RA coordinate.
	dra = numpy.dot(motion, dirRANew)/cos(decNew)**2

	dirDecNew = [
		-sin(decNew)*cos(raNew),
		-sin(decNew)*sin(raNew),
		cos(decNew)]
	ddec = numpy.dot(motion, dirDecNew)
	# back from rad to degrees
	return dra/DEG, ddec/DEG


def capplc(basicMotion, fromEpoch, toEpoch, orbit=None):
		"""returns a tuple (ra_cio, ra_equ, dec) for toEpoch, all in degrees.

		ra_cio is the right ascension from the CIO method (intermediate
		place, via erfa.c2i06a), ra_equ the one from the equinox method
		(via erfa.pnm06a).  dec is the declination from the CIO method; it is
		asserted to agree with the one from the equinox method.

		As in Helmut's original code, positions are in degrees, proper
		motions in mas/yr, the parallax in arcsec, and the rv in km/s.

		fromEpoch and toEpoch are JDs; fromEpoch is the epoch basicMotion is
		given for, and the position is propagated over toEpoch-fromEpoch.

		Orbit, if given, must be a dict as defined in orbitDelta, except
		that its axis is multiplied by 3.6e6 before it is handed on to
		orbitDelta.  The dictionary passed in is not modified.

		basicMotion is a dictionary with the following keys:

		* raj2000 -- ICRS RA in degrees for J2000
		* dej2000 -- ICRS Declination in degrees for J2000
		* pmra -- proper motion in RA (tangential plane), mas/yr
		* pmdec -- proper motion in Dec, mas/yr
		* parallax -- in arcsec
		* rv -- in km/s; a false value (None, 0) is treated as 0

		An api.ValidationError is raised if one of the first five items
		yields a TypeError in the unit conversion (e.g., because it is None).
		"""
		# internally, we want all angles in rad; as per hallowed
		# tradition, the time unit is the julian century.
		try:
			raj2000 = basicMotion["raj2000"]*DEG
			dej2000 = basicMotion["dej2000"]*DEG
			# pmra comes "with cos(delta)"; make it a coordinate rate
			pmra = basicMotion["pmra"]/3.6e6*DEG/cos(dej2000)*100
			pmdec = basicMotion["pmdec"]/3.6e6*DEG*100
			parallax = basicMotion["parallax"]/3.6e3*DEG
		except TypeError:
			raise api.ValidationError("This object does not have a full"
				" five-parameter solution (ra, dec, pmra, pmdec, parallax)."
				" I will not try to guess parameters.  Choose another star.",
				"object")
		rv = basicMotion["rv"] or 0

		# "radial proper motion"
		# (the factor presumably converts km/s to AU per julian century)
		radpm = rv*21.09495
		vaupar = parallax*radpm

		unit_vec0 = mathtricks.spherToCart(raj2000, dej2000) # q in Helmut's
		motion = numpy.array([ # m in Helmut's
				-cos(dej2000)*sin(raj2000)*pmra
						-sin(dej2000)*cos(raj2000)*pmdec
						+cos(dej2000)*cos(raj2000)*vaupar,
				cos(dej2000)*cos(raj2000)*pmra
					 -sin(dej2000)*sin(raj2000)*pmdec
					 +cos(dej2000)*sin(raj2000)*vaupar,
				cos(dej2000)*pmdec + sin(dej2000)*vaupar])

		# earth/moon barycentre
		eb_pos, eb_vel = JPLEPH[0,3].compute_and_differentiate(toEpoch)
		# earth relative to embary
		er_pos, er_vel = JPLEPH[3,399].compute_and_differentiate(toEpoch)

		# barycentric position and velocity of the earth, km -> AU
		eb, ev = (eb_pos+er_pos)/km_AU, (eb_vel+er_vel)/km_AU
		# barycentric position of the sun in AU
		sb = JPLEPH[0,10].compute(toEpoch)/km_AU
		# time unit is the julian century
		dt = (toEpoch-fromEpoch)/36525

		# space motion and parallax in one go, then back to a unit vector
		propagated = normalize(unit_vec0+dt*motion-parallax*eb)
		heliocentric_earth = eb-sb
		heliocentric_distance = linalg.norm(heliocentric_earth)

		deflected = light_deflection(
				heliocentric_earth/heliocentric_distance,
				propagated,
				heliocentric_distance)

		aberrated = aberration(ev, deflected)

		# equinox method: precession
		prec_matrix = erfa.pnm06a(toEpoch, 0)
		pos_equ = numpy.matmul(prec_matrix, aberrated)
		ra_equ, dec_equ = mathtricks.cartToSpher(pos_equ)

		# CIO method: intermediate position
		mat = erfa.c2i06a(toEpoch, 0)
		pos_cio = numpy.matmul(mat, aberrated)
		ra_cio, dec_cio = mathtricks.cartToSpher(pos_cio)

		ra_equ, ra_cio = ra_equ/DEG, ra_cio/DEG
		dec_equ, dec_cio = dec_equ/DEG, dec_cio/DEG
		# Sanity check: both methods are expected to give the same
		# declination.  Compare the absolute difference (so deviations of
		# either sign are caught) against a tolerance that is relative for
		# large declinations but does not collapse to zero near dec=0.
		assert abs(dec_cio-dec_equ) < 1e-10*max(1.0, abs(dec_equ))

		if orbit:
			# unit version vs. apfs.fk6orbits; scale a copy so the caller's
			# dictionary is not scaled again on every call.
			scaled_orbit = dict(orbit)
			scaled_orbit["axis"] = orbit["axis"]*3.6e6
			dra, dde = orbitDelta(scaled_orbit, basicMotion, toEpoch)
			ra_equ, ra_cio = ra_equ+dra, ra_cio+dra
			dec_equ, dec_cio = dec_equ+dde, dec_cio+dde

		return (ra_cio, ra_equ, dec_cio)


if __name__=="__main__":
		# Test case: FK6 905 for 2023-01-30, 0h (jd 2459974.5)
		# Compute using:
		# echo 2023 1 30 0 | ./apfswrap-i386 0.93495080 -17.33599078 25.84 -8.55 0.01431 -5.0

		# "Sollte" is German for "expected"; the line gives the reference
		# values to compare the computed tuple against by eye.
		print("Sollte: 0.9315561918 1.2248792111 -17.2113204949")
		# capplc wants the star as a dictionary, followed by the epoch of
		# the parameters and the target epoch (both JD).
		# NOTE(review): fromEpoch is taken to be J2000.0 (jd 2451545.0) since
		# capplc documents raj2000/dej2000 as being for J2000 -- confirm this
		# is what the reference values assume.
		print(capplc({
						"raj2000": 0.93495080,
						"dej2000": -17.33599078,
						"pmra": 25.84,
						"pmdec": -8.55,
						"parallax": 14.31/1e3,
						"rv": -5.0},
				2451545.0,
				2459974.5))
