# Coryn Bailer-Jones wants to compute robust distance estimates
# again.  For this, he needs CSV files in a special format.
# This script produces those.

import csv
import sys

import psycopg2

# The database is reached through an SSH tunnel.  This assumes port 9564
# on localhost is forwarded to the PostgreSQL socket on mintaka, e.g.:
# ssh -L 9564:/var/run/postgresql/.s.PGSQL.5432 mintaka

# Keyword arguments for psycopg2.connect.
DSN = dict(database="gaia", host="localhost", port="9564")

# Per-pixel extraction query.  {min_id} and {max_id} are str.format
# placeholders that the caller fills with integer source_id bounds;
# rows without a parallax are skipped.
QUERY = (
	"\n"
	"SELECT\n"
	"\tsource_id, phot_g_mean_mag,\n"
	"\tbp_rp,\n"
	"\tparallax, parallax_error,\n"
	"\tnu_eff_used_in_astrometry,\n"
	"\tpseudocolour,\n"
	"\tecl_lat\n"
	"FROM gaiaedr3.gaia_source\n"
	"WHERE source_id BETWEEN {min_id} AND {max_id}\n"
	"AND parallax IS NOT NULL")

# Gaia source_ids carry the level-12 HEALPix index above bit 35; a further
# factor of 4**7 coarsens that to level 5 (12*4**5 = 12288 pixels), so
# source_id // HPX_SHIFTER is the level-5 pixel a source lies in.
HPX_SHIFTER = 2**35*4**7

class StatusDisplay:
	"""A context manager for updating a one-line display.

	Each update() overwrites the previously shown status in place.  It
	returns to the start of the line with a carriage return, blanks out
	the old text with spaces, and returns again before writing the new
	text.  Entering the context starts a fresh line.  Leaving it ends the
	status line so that subsequent output starts cleanly.

	dest_f is the text stream to write to; it defaults to whatever
	sys.stdout is when the display is created.
	"""
	def __init__(self, dest_f=None):
		# Resolve sys.stdout now rather than at import time, so that a
		# later redirection of sys.stdout is honoured.
		self.dest_f = sys.stdout if dest_f is None else dest_f
		# Before anything has been shown, "clearing" means starting a
		# fresh line.
		self.clearer = "\r\n"

	def update(self, new_content):
		"""Replace the currently displayed status with new_content."""
		self.dest_f.write(self.clearer+new_content)
		self.dest_f.flush()
		# Next time, wipe exactly the characters just written.
		self.clearer = "\r"+(" "*len(new_content))+"\r"

	def __enter__(self):
		self.dest_f.write(self.clearer)
		self.dest_f.flush()
		# The fresh line has been started (or the old status wiped);
		# don't let the first update() emit another line break, which
		# used to show up as a spurious blank line.
		self.clearer = ""
		return self

	def __exit__(self, *args):
		# Terminate the status line; exceptions are not suppressed.
		self.dest_f.write("\r\n")
		self.dest_f.flush()


def dump_one(conn, hpx_index):
	"""Write the QUERY rows for one level-5 HEALPix pixel to a CSV file.

	conn is an open psycopg2 connection; hpx_index is an integer pixel
	index.  The output goes to "<hpx_index>.csv" in the current directory
	and consists of a header line of column names followed by the rows,
	all with the csv module's default "\\r\\n" line terminator.
	"""
	min_id = hpx_index*HPX_SHIFTER
	# SQL BETWEEN is inclusive at both ends, so stop one short of the
	# first source_id of the next pixel; otherwise a source sitting
	# exactly on the boundary would be written to two files.
	max_id = (hpx_index+1)*HPX_SHIFTER-1

	# The cursor is closed on exit, including when the query or the
	# write fails.  Both bounds are ints computed above, so formatting
	# them into the SQL text cannot inject anything.
	with conn.cursor() as cursor:
		cursor.execute(QUERY.format(min_id=min_id, max_id=max_id))

		# newline="" as required by the csv module, so that its "\r\n"
		# terminators are not translated on platforms like Windows.
		with open("{:d}.csv".format(hpx_index), "w", newline="") as dest_file:
			dest = csv.writer(dest_file)
			dest.writerow([d[0] for d in cursor.description])
			dest.writerows(cursor)


def main():
	"""Dump one CSV file per level-5 HEALPix pixel, showing progress.

	Without command-line arguments, all 12288 (= 12*4**5) pixels are
	dumped.  Otherwise each argument is taken as an integer pixel index to
	dump; a non-integer argument raises ValueError.
	"""
	if len(sys.argv)==1:
		# Number of pixels at HEALPix level 5.
		healpixes = range(12*4**5)
	else:
		healpixes = [int(a) for a in sys.argv[1:]]
	total = len(healpixes)

	conn = psycopg2.connect(**DSN)
	# Using a psycopg2 connection as a context manager only ends the
	# transaction and does not close it, hence the explicit finally.
	try:
		with StatusDisplay() as disp:
			for index, hpx in enumerate(healpixes):
				disp.update("{}/{}".format(index, total))
				dump_one(conn, hpx)
			disp.update("{}/{}".format(total, total))
	finally:
		conn.close()


if __name__ == "__main__":
	# Only run the export when executed as a script, not on import.
	main()
