# A base core for computing the various APFS versions.
# To make this work, inherit from BaseCore and fill out:
#
# * inputTableXML -- INPUT_TABLE_XML_TEMPLATE with objectDescription
#     and objectDefault put in
# * sourceCatName -- the name of the catalog the astrometry is pulled from
# * sourceCatId -- a DaCHS cross id of the table the astrometry
#     is pulled from
# * query(identifier, inputTable) -- a method returning a tuple of apfswrap
#     arguments for identifier (as determined by resolve)
# * catalogueEpoch -- JD of the epoch of the astrometry.  Defaults to
#     jdn(J2000.0), i.e., 2451545.0
#
# The apfswrap binary takes the following arguments:
# alpha, delta, mualpha, mudelta (in mas/yr), radial velocity (in km/s),
#   parallax (in arcsec), and optionally
# period (julian year), eccentricity, tperiastron (julian year),
#   majoraxis, incl, nodelength, peraparg (all in degrees)

import datetime
import io
import os
import re
import subprocess

from gavo import api
from gavo import base
from gavo import utils
from gavo.protocols import scs
from gavo.protocols import simbadinterface #noflake: for registration


# RD XML for the inputTable shared by the APFS services: an object
# identifier, a start and end date, and an interval in hours (minimum 1,
# default 24).  The {objectDescription} and {objectDefault} placeholders
# must be filled in by the concrete service (presumably via str.format --
# confirm against the RD using this template).
INPUT_TABLE_XML_TEMPLATE = """
<inputTable>
	<inputKey tablehead="Object" name="object"
		type="text" required="True"
		multiplicity="single"
		description="{objectDescription}">
		<property key="defaultForForm">{objectDefault}</property>
	</inputKey>
	<inputKey tablehead="Start date"
		name="startDate" type="date" required="True"
		multiplicity="single"
		description="Start date of generated ephemeris"/>
	<inputKey tablehead="End date"
		multiplicity="single"
		name="endDate" type="date" required="True"
		description="End date of generated ephemeris"/>
	<inputKey tablehead="Interval of generation (hrs)"
		multiplicity="single"
		name="hrInterval" type="integer" required="True"
		description="Number of hours between two apparent positions">
		<values min="1" default="24"/>
	</inputKey>
</inputTable>
"""


def ensureWithinSOFA(dt):
	"""returns dt unchanged if its year is between 1960 and 2059 (inclusive).

	SOFA has time limits; checking them here gives a legible
	base.ValidationError (attributed to the startDate field) rather than
	an obscure failure further down the line.
	"""
	year = dt.year
	if year<1960 or year>2059:
		raise base.ValidationError("Can only compute ephemeris"
			" between 1960 and 2059", "startDate")
	return dt


def expandDates(startTime, endTime, hrInterval, matchLimit=100000):
	"""yields datetime.datetime instances from 00:00 on startTime's date
	up to and including 23:59 on endTime's date, hrInterval hours apart.

	hrInterval values below 1 are treated as 1.  Each stamp is passed
	through ensureWithinSOFA before being yielded.  At most matchLimit
	stamps are produced; any further ones are silently dropped.

	Cf. //procs#expandDates.
	"""
	current = datetime.datetime.combine(startTime, datetime.time(0))
	lastAllowed = datetime.datetime.combine(
		endTime, datetime.time(23, 59, 00))

	step = datetime.timedelta(hours=max(1, hrInterval))

	remaining = matchLimit
	while current<=lastAllowed:
		remaining -= 1
		if remaining<0:
			return
		yield ensureWithinSOFA(current)
		current += step


def getInput(inputTable):
	"""returns the bytes to feed to the apfs computer on its stdin.

	There is one line per stamp that expandDates produces from the
	startDate, endDate and hrInterval arguments of inputTable.  Each line
	has year, month and day as integers, followed by the hour as a float.
	The text always ends with a newline.
	"""
	args = inputTable.args
	stamps = expandDates(
		args["startDate"], args["endDate"], args["hrInterval"])
	lines = ["%d %d %d %f"%(stamp.year, stamp.month, stamp.day, stamp.hour)
		for stamp in stamps]
	return api.bytify("\n".join(lines)+"\n")


class BaseCore(api.Core):
	"""Abstract APFS core: compute ephemeris from a variety of sources.

	See module docstring for how to make this work.  In particular,
	subclasses provide sourceCatName, sourceCatId and a
	query(identifier, inputTable) method returning the astrometry tuple
	laid out as in astrometryLabels.

	With useLegacy set (the default), the apparent places are computed
	by the apfswrap binary (runFortran); otherwise, the capplc function
	from res/apfs is used (runModern).
	"""
	useLegacy = True
	# JD of the epoch of the catalogue astrometry (2451545.0 is J2000.0).
	catalogueEpoch = 2451545.0

	def initialize(self):
		"""sets up the output table and loads capplc from res/apfs.
		"""
		self.outputTable = api.OutputTableDef.fromTableDef(
			self.rd.getById("apfsOutput"))
		mod, _ = api.loadPythonModule(self.rd.getAbsPath("res/apfs"))
		self.capplc = mod.capplc

	def resolve(self, inputTable):
		"""sets inputTable.args["star"] from inputTable.args["object"].

		An all-digits object is taken as a catalogue number and used as-is.
		Otherwise, the object is interpreted as an "ra,dec" pair (presumably
		in degrees) or, failing that, as a name to resolve through the
		simbad Sesame cache.  The identifier of the closest object in the
		table referenced by sourceCatId is then used.

		Raises api.ValidationError (on the object field) when none of these
		interpretations works.
		"""
		# Bind ob before the try so that the error message in the except
		# clause never hits an unbound name (e.g., when the object
		# argument itself is missing).
		ob = None
		try:
			ob = inputTable.args["object"]
			if re.match(r"^[0-9]+$", ob):
				inputTable.args["star"] = ob
				return

			mat = re.match(r"([0-9.]+),([0-9.+-]+)$", ob)
			if mat:
				# float() raises ValueError on things like "1.2.3", which
				# ends up in the ValidationError below.
				ra, dec = [float(v) for v in mat.groups()]
			else:
				ra, dec = base.caches.getSesame("simbad").getPositionFor(ob)

			cat = api.resolveCrossId(self.sourceCatId)
			idCol = cat.getColumnByUCD("meta.id;meta.main")
			inputTable.args["star"] = scs.findNClosest(
				ra, dec,
				cat,
				1, [idCol.name])[0][0]
		except (KeyError, ValueError, AttributeError):
			raise api.ui.logOldExc(
				api.ValidationError("%s is neither an %s number nor"
				" a ra,dec position nor an object known by simbad."%(
					repr(ob),
					self.sourceCatName),
				colName="object"))

	def runFortran(self, service, inputTable, astrometry):
		"""uses the old Fortran machinery to compute the apparent places.

		This should no longer be used in production and should be
		considered breakable any time.

		In particular, it will yield bad results when sources not on J2000
		are used.

		The astrometry values are passed as command line arguments to
		bin/apfswrap; falsy values (e.g., None) become "0".  The epochs
		from getInput go to its stdin, and its stdout is parsed through
		the apfswrap data element of service's RD.

		Raises base.ValidationError (on startDate) if the binary exits with
		a non-zero code; whatever it wrote to stderr is given as the hint.
		"""
		astrometry = [str(v or 0) for v in astrometry]
		computer = base.getBinaryName(
			service.rd.getAbsPath("bin/apfswrap"))
		# stderr must be piped as well; otherwise communicate() returns
		# None for it and the error hint below would always be empty.
		pipe = subprocess.Popen([computer]+astrometry,
			stdin=subprocess.PIPE, stdout=subprocess.PIPE,
			stderr=subprocess.PIPE, close_fds=True,
			cwd=os.path.dirname(computer))

		data, errmsg = pipe.communicate(getInput(inputTable))
		if pipe.returncode:
			raise base.ValidationError("The backend computing program failed"
				" (exit code %s).  Messages may be available as"
				" hints."%pipe.returncode,
				"startDate",
				hint=(errmsg or b"").decode("utf-8", "replace") or None)

		return api.makeData(service.rd.getById("apfswrap"),
			forceSource=io.BytesIO(data)).getPrimaryTable()

	# see "apfswrap binary" above for the units of the tuples
	# returned by query (this isn't a dict for historical reasons;
	# and we should probably change the query signature one of these
	# days)
	# NOTE(review): the module header lists radial velocity before
	# parallax, while these labels put parallax first; confirm which
	# order query actually returns.
	astrometryLabels = [
		"raj2000", "dej2000", "pmra", "pmdec", "parallax", "rv",
		"period", "ecc", "tper", "axis",
		"incl", "node", "omega", "objectname"]

	# Names given to the three values returned by capplc in runModern.
	resultLabels = ["raCio", "raEqu", "dec"]

	def runModern(self, service, inputTable, astrometry):
		"""computes the apparent places using python code.

		For each epoch from expandDates, self.capplc is called with the
		basic astrometry (the first six items, as a dict keyed by
		astrometryLabels), catalogueEpoch, the epoch's JD, and the orbital
		elements (or None).  Its results, keyed by resultLabels, plus
		isodate and arg_hour, become the rows fed to the buildresult data
		element.
		"""
		rows = []
		for dt in expandDates(
				inputTable.args["startDate"],
				inputTable.args["endDate"],
				inputTable.args["hrInterval"]):

			basic_motion = dict(zip(self.astrometryLabels[:6], astrometry))
			orbit = None
			# Orbital elements are only used when the eccentricity (index 7)
			# is given; the length guard must cover index 7, or a 7-element
			# tuple would raise an IndexError here.
			if len(astrometry)>7 and astrometry[7] is not None:
				orbit = dict(zip(self.astrometryLabels[6:13], astrometry[6:13]))

			row = dict(zip(self.resultLabels,
				self.capplc(
					basic_motion,
					self.catalogueEpoch,
					api.dateTimeToJdn(dt),
					orbit)))
			row["isodate"] = dt
			row["arg_hour"] = (dt.hour/24.+dt.minute/24./60+dt.second/86400.)*24
			rows.append(row)

		return api.makeData(
			service.rd.getById("buildresult"),
			forceSource=rows).getPrimaryTable()

	def run(self, service, inputTable, queryMeta):
		"""resolves the object, fetches its astrometry and computes places.

		Sets the forStar meta (and commonName, if a comname argument is
		given) on inputTable.  The first two astrometry items go into its
		args as alpha and delta.  An IndexError from query is turned into
		an api.ValidationError on the object field.
		"""
		self.resolve(inputTable)
		inputTable.setMeta("forStar", str(inputTable.args["star"]))
		if inputTable.args.get("comname"):
			inputTable.setMeta("commonName", inputTable.args["comname"])
		try:
			astrometry = self.query(inputTable.args["star"], inputTable)
		except IndexError:
			# While this might hide other problems, the assumption
			# this is because the object doesn't exist is probably safe.
			raise api.ValidationError("No associated object found", "object")

		# the first couple of pieces of astrometry go to the input table
		# so the custom renderer can pull them from there
		inputTable.args.update(
			dict(zip(["alpha", "delta"], astrometry)))

		if self.useLegacy:
			return self.runFortran(service, inputTable, astrometry)
		else:
			return self.runModern(service, inputTable, astrometry)
