"""
Write a dump of the non-null columns of the main table in nice, manageable
partitions.
"""

import os
import shutil
import urllib
import urllib.request

from astropy import table
import pyvo

MAX_ROWS = 5000000

ENDPOINT = "http://dc.g-vo.org/tap"

# the divisor here is max(source_id)/10000 -- but it's probably wiser
# to just use 2e14 or so.
ID_DIVISOR = 691752899328320

QUERY = """
select round(source_id/%d) as bin, count(*) as ct
from gdr2mock.generated_data
group by bin
"""%ID_DIVISOR


def get_bin_sizes():
	"""Return a sorted list of (bin_center, num_objects) pairs.

	The histogram is cached in partitions.vot; when that file is
	missing, it is re-computed with an async TAP query (slow) and the
	cache file is written for the next run.
	"""
	try:
		with open("partitions.vot", "rb") as src:
			bins = table.Table.read(src)
	except IOError:
		# No cache yet: run the aggregate query against the service.
		print("Fetching partitions from source; this will take a while"
			" (provide partitions.vot to avoid re-querying)")
		service = pyvo.dal.TAPService(ENDPOINT)
		bins = service.run_async(QUERY, maxrec=2000000).table
		with open("partitions.vot", "wb") as dest:
			bins.write(output=dest, format="votable")

	return sorted((row["bin"], row["ct"]) for row in bins)


def get_partition_limits(bin_sizes, max_rows=None, id_divisor=None):
	"""returns a list of limits of source_id ranges exhausting the whole
	catalog.

	bin_sizes is what get_bin_sizes returns (and it must be sorted by
	bin center).  max_rows (maximum number of rows per partition) and
	id_divisor (the divisor that produced the bin numbers) default to
	the module-level MAX_ROWS and ID_DIVISOR constants.
	"""
	# Resolve defaults lazily so explicit arguments work even where the
	# module constants are not in scope.
	if max_rows is None:
		max_rows = MAX_ROWS
	if id_divisor is None:
		id_divisor = ID_DIVISOR

	limits, cur_count = [0], 0
	for bin_center, bin_count in bin_sizes:
		if cur_count+bin_count>max_rows:
			# This bin would overflow the current partition: close it at
			# the lower edge of this bin (bin_center is the *center*, so
			# subtract half a divisor) and start counting afresh.
			limits.append(int(bin_center*id_divisor-id_divisor/2))
			cur_count = 0
		cur_count += bin_count
	# One past Gaia DR2's largest source_id, so the final range is
	# guaranteed to cover everything remaining.
	limits.append(6917528993281343490+1)
	return limits


def get_data_for(svc, low, high):
	"""writes a FITS table for the catalog between source_id low and high.

	svc is a pyvo TAPService.  The destination file name encodes low and
	high shifted right by 36 bits to keep it short; if that file already
	exists, the partition is assumed done and nothing happens, so an
	interrupted run can simply be restarted.
	"""
	dest_name = "dump-{:09d}-{:09d}.count".format(low>>36, high>>36)
	if os.path.exists(dest_name):
		return

	cols = ("source_id, ra, dec, ra_error, dec_error, l, b, pmra,"
		" pmdec, pmra_error, pmdec_error, parallax, parallax_error,"
		" nobs, phot_g_mean_mag, phot_rp_mean_mag,"
		" phot_bp_mean_mag, phot_g_mean_flux, phot_bp_mean_flux,"
		" phot_rp_mean_flux, phot_g_mean_flux_error, phot_bp_mean_flux_error,"
		" phot_rp_mean_flux_error, "
		" bp_rp, bp_g, g_rp, radial_velocity, e_bp_min_rp_val, lum_val,"
		"	feh, teff_val, a_g_val, mass, age, logg, a0, radius_val,"
		" index_parsec, random_index")

	query = ("SELECT {cols} from gdr2mock.main where"
			 " source_id between {low} and {high}").format(
		cols=cols, low=low, high=high)
	print(query)

	job = svc.submit_job(query, maxrec=20000000, format="fits")
	try:
		job.run()
		job.wait()

		# Download to a temporary name first so a partial file left by a
		# crash or ^C is never mistaken for a finished partition by the
		# os.path.exists check above.
		temp_name = dest_name+".part"
		with open(temp_name, "wb") as dest:
			# was urllib.urlopen, a Python 2 API that does not exist in
			# Python 3's urllib package.
			with urllib.request.urlopen(job.result_uri) as src:
				shutil.copyfileobj(src, dest, 10000000)
		os.replace(temp_name, dest_name)
	finally:
		job.delete()


def write_fitses(limits):
	"""writes FITS binary files for the partitions defined by limits.

	Each partition covers [limits[i], limits[i+1]-1]; progress is
	reported after every completed partition.
	"""
	service = pyvo.dal.TAPService(ENDPOINT)
	n_partitions = len(limits)-1
	for index in range(n_partitions):
		get_data_for(service, limits[index], limits[index+1]-1)
		print("{}/{}".format(index+1, n_partitions))


if __name__=="__main__":
	# Compute source_id partition limits (cached via partitions.vot),
	# show them for the record, then dump one FITS file per partition.
	limits = get_partition_limits(get_bin_sizes())
	print(limits)
	write_fitses(limits)
