"""
Computation of the RUWE quality measure (see Lindegren et al,
GAIA-C3-TN-LU-LL-124-01).  A bit of that code was lifted from code
written by Jan Ribizki.  In particular, the two .npy files where
provided by him at https://keeper.mpdl.mpg.de/f/be3ee4b8264a484b806d/?dl=1
(Nov. 2018; changing to the txt files provided by Lindegren should be
straightforward).
"""

import math
import os
import subprocess

import numpy
from scipy.spatial import cKDTree


class RUWEComputer(object):
	"""A facade for the data and indices to compute a RUWE value by
	magnitudes, astrometric_n_good_obs_al, and astrometric_chi2_al.

	This pulls in two lookup tables of reference unit weight errors (u0)
	from the directory passed during construction:

	* table_u0_g.npy -- a record array with g_mag and u0 fields, used when
	  no BP-RP colour is available;
	* table_u0_g_col.npy -- a record array with g_mag, bp_rp, and u0 fields.

	Each table gets a KD-tree index, and lookups use the u0 of the nearest
	tabulated point.

	Just call the instance with the parameters defined in __call__ to compute
	a RUWE.
	"""
	def __init__(self, src_dir):
		self.u0_g = numpy.load(os.path.join(src_dir, "table_u0_g.npy"))
		self.u0_gc = numpy.load(os.path.join(src_dir, "table_u0_g_col.npy"))
		# numpy.c_ stacks the columns into the (n_points, n_dims) arrays
		# cKDTree expects: 2D over (G, BP-RP), 1D over G alone.
		self.gc_index = cKDTree(numpy.c_[self.u0_gc['g_mag'],self.u0_gc['bp_rp']])
		self.g_index = cKDTree(numpy.c_[self.u0_g['g_mag']])

	def __call__(self,
			astrometric_chi2_al,
			astrometric_n_good_obs_al,
			phot_g_mean_mag,
			phot_bp_mean_mag,
			phot_rp_mean_mag):
		"""returns the RUWE for a single source.

		The unit weight error u = sqrt(chi2/(n_good_obs-5)) is divided by
		the u0 of the table point nearest to the source's G magnitude and
		BP-RP colour.  If either BP or RP is missing (None or NaN), the
		G-only table is used instead.

		This raises a ValueError if astrometric_n_good_obs_al is 5 or less,
		since then n_good_obs-5 is not positive and u is undefined.
		"""
		dof = astrometric_n_good_obs_al-5
		if dof<=0:
			raise ValueError("RUWE undefined for astrometric_n_good_obs_al=%s"
				" (need more than 5 observations)"%astrometric_n_good_obs_al)
		u = math.sqrt(astrometric_chi2_al/float(dof))

		# NaN is the only value unequal to itself; normalise it to None so
		# missing photometry is detected uniformly below.
		if phot_bp_mean_mag!=phot_bp_mean_mag: phot_bp_mean_mag = None
		if phot_rp_mean_mag!=phot_rp_mean_mag: phot_rp_mean_mag = None

		if phot_bp_mean_mag is None or phot_rp_mean_mag is None:
			# query returns (distance, index); we only need the index.
			return u/self.u0_g['u0'][
				self.g_index.query([phot_g_mean_mag])[1]]

		else:
			bp_rp = phot_bp_mean_mag-phot_rp_mean_mag
			return u/self.u0_gc['u0'][
				self.gc_index.query([phot_g_mean_mag, bp_rp])[1]]


def get_records_tap():
	"""yields DR2 records (as dicts) for a small "local" subset via TAP.

	The source_ids of the local gaia.dr2light table are uploaded to the
	TAP service at gaia.ari.uni-heidelberg.de as the table fromdb, and at
	most 3000 matching rows of gaiadr2.gaia_source are fetched.
	"""
	from gavo import api
	from astropy import table
	with api.getTableConn() as conn:
		local_ids = [row[0]
			for row in conn.query("select source_id from gaia.dr2light")]
		upload = table.Table([local_ids], names=('source_id',))

	from pyvo import dal
	query = """
			select top 3000 source_id,
				astrometric_chi2_al,
				astrometric_n_good_obs_al,
				phot_g_mean_mag,
				phot_bp_mean_mag,
				phot_rp_mean_mag
			from gaiadr2.gaia_source
			where source_id in (SELECT source_id FROM TAP_UPLOAD.fromdb)"""
	service = dal.TAPService("http://gaia.ari.uni-heidelberg.de/tap")
	result = service.run_sync(query, uploads={"fromdb": upload})
	for record in result:
		yield dict(record)


def floatOrNone(s):
	r"""returns s converted to a float, or None if s is the \N marker
	PostgreSQL's text COPY format uses for NULL.

	s may be a str or bytes: pipes from subprocess yield bytes on
	Python 3, and b"\N" never compares equal to the str r"\N".
	"""
	if isinstance(s, bytes):
		s = s.decode("ascii")
	if s==r"\N":
		return None
	return float(s)


def get_records_dump():
	"""yields DR2 records obtained from mintaka.

	This runs psql on mintaka over ssh, streaming the columns listed below
	out of gaiadr2.gaia_source in PostgreSQL's text COPY format.  It yields
	one dict per row mapping column name to value: source_id as an int,
	everything else as a float or None for NULL.

	After all rows have been read, this raises subprocess.CalledProcessError
	if the remote command exited with a non-zero status.  Without that check,
	a failed ssh connection or psql error would look like an empty table.
	"""
	columns = [
		"source_id",
		"astrometric_chi2_al",
		"astrometric_n_good_obs_al",
		"phot_g_mean_mag",
		"phot_bp_mean_mag",
		"phot_rp_mean_mag"]

	cmd = ["ssh", "msdemlei@mintaka.ari.uni-heidelberg.de",
		"psql gaia -c '\\copy gaiadr2.gaia_source(%s) to stdout"
		" with (format text)'"%(",".join(columns))]

	# universal_newlines makes the pipe yield str lines on Python 3, too,
	# so the \N NULL marker compares equal in floatOrNone.
	src = subprocess.Popen(cmd,
		stdout=subprocess.PIPE, universal_newlines=True)
	try:
		for line in src.stdout:
			parts = line.split()
			parts = [int(parts[0])]+[floatOrNone(s) for s in parts[1:]]
			yield dict(zip(columns, parts))
	finally:
		# Release the pipe and reap the ssh child even if the consumer
		# stops early or parsing fails.
		src.stdout.close()
		retcode = src.wait()

	if retcode:
		raise subprocess.CalledProcessError(retcode, cmd)


def get_default_computer():
	"""returns a RUWEComputer reading its tables from the data2
	subdirectory of the resource directory of the gaia/q2 RD.
	"""
	from gavo import api
	table_dir = os.path.join(api.getRD("gaia/q2").resdir, "data2")
	return RUWEComputer(table_dir)


def compute_RUWES(get_records=get_records_dump):
	"""writes pairs of source_id, RUWE to stdout.

	get_records is a callable returning an iterable of dicts that contain
	a source_id key plus the keyword arguments of RUWEComputer.__call__.
	It defaults to get_records_dump; pass get_records_tap to fetch the
	records through TAP instead.

	If computing a RUWE fails, the traceback goes to stderr and
	"<source_id> NULL" is written instead, so one bad record does not
	abort the run.
	"""
	import traceback

	get_RUWE = get_default_computer()
	for rec in get_records():
		source_id = rec.pop("source_id")
		try:
			ruwe = get_RUWE(**rec)
		except Exception:
			# Not a bare except: KeyboardInterrupt/SystemExit must still
			# stop the run rather than turn into a NULL row.
			traceback.print_exc()
			print("{} NULL".format(source_id))
		else:
			print("{} {}".format(source_id, ruwe))


if __name__=="__main__":
	compute_RUWES()
