# Convert the dump of the continuous DR3 XP spectra to calibrated,
# sampled ones.
#
# I've abandoned this because the spectra produced by GaiaXP are so
# terrible -- they've got the polynomials all over them and horrendously
# correlated errors.  See dr3_to_mcsampled.py for what I'm trying instead.
#
# This is an extra program because the whole procedure is really slow
# with python code.  If I want this to terminate in time, I need to
# parallelise, and that's not trivial within DaCHS.
#
# The main program reads ~CHUNK_SIZE characters at a time from the input,
# extends each piece to the next line break, and hands it on to the workers.
# It collects the records they return, writing a quickly importable,
# tab-separated dump of source_id, flux, flux_error.
#
# This is expected to run in the resdir (i.e., the parent of data3).


import gzip
import multiprocessing as mp
import time

import numpy
import pandas

from gaiaxpy import core
from gaiaxpy.calibrator import calibrator

# Number of worker processes the conversion is spread over.
N_WORKERS = 8

# Approximate number of characters handed to a worker at a time (each
# piece is extended to the next line break).  In operation, keep this
# reasonably large (1e8, perhaps).
CHUNK_SIZE = 100_000_000

# Wavelengths the spectra are sampled at: 400, 410, ..., 800 nm.
# That's a local choice; upstream does something completely different.
# If you change this, adapt s3#spectralPoints, too.
SPECTRAL_POINTS = numpy.linspace(400., 800., 41)

def parseArray(arrlit):
	"""returns a float32 numpy array for a bracketed, comma-separated
	array literal such as "(1.5,2,3)".

	The first and last characters (the brackets) are dropped without being
	inspected.  An empty literal like "()" or "[]" yields an empty array
	(splitting an empty string would otherwise produce [''], on which
	float() fails).
	"""
	inner = arrlit[1:-1]
	items = inner.split(',') if inner.strip() else []
	return numpy.fromiter((float(item) for item in items), 'f')


def _toTypes(vars):
	"""converts the string values of a raw dump record, in place.

	vars is a dict as produced by _parseToRecords (column name -> string).
	Array-valued columns are parsed with parseArray, the standard deviations
	become floats, source_id and the n_parameters become ints.  Finally, for
	each band (rp first, then bp), the coefficient correlations are replaced
	by the result of gaiaxpy's core.array_to_symmetric_matrix, called with
	that band's n_parameters and correlation array.

	The modified dict is returned.
	"""
	# NOTE(review): bp/rp_chi_squared are treated as array literals here; if
	# the dump has them as plain scalars, parseArray silently drops their
	# first and last characters -- confirm against the dump format.
	conversions = (
		(parseArray, (
			"bp_chi_squared",
			"bp_coefficients",
			"bp_coefficient_errors",
			"bp_coefficient_correlations",
			"rp_chi_squared",
			"rp_coefficients",
			"rp_coefficient_errors",
			"rp_coefficient_correlations")),
		(float, (
			"bp_standard_deviation",
			"rp_standard_deviation")),
		(int, (
			"source_id",
			"bp_n_parameters",
			"rp_n_parameters")),
	)
	for convert, colNames in conversions:
		for colName in colNames:
			vars[colName] = convert(vars[colName])

	for band in ("rp", "bp"):
		corrKey = f"{band}_coefficient_correlations"
		vars[corrKey] = core.array_to_symmetric_matrix(
			vars[f"{band}_n_parameters"],
			vars[corrKey])

	return vars


def _addSampled(buffer):
	"""adds sampled spectra to the typed records in buffer, in place.

	buffer is a list of dicts as returned by _toTypes.  They are passed to
	gaiaxpy's calibrator (sampled at SPECTRAL_POINTS, with truncation), and
	each record gets the calibrator's values under the keys "flux" and
	"flux_error".

	Raises ValueError if the calibrator output does not line up with buffer
	(a different number of rows, or source_ids out of order).  These are
	explicit checks rather than asserts so they survive python -O; zip alone
	would silently leave trailing records without a flux.
	"""
	sampledSpectra = calibrator.calibrate(
		pandas.DataFrame.from_records(buffer),
		sampling=SPECTRAL_POINTS,
		truncation=True,
		save_file=False)[0]

	if len(sampledSpectra)!=len(buffer):
		raise ValueError(
			f"Calibrator returned {len(sampledSpectra)} spectra"
			f" for {len(buffer)} sources")

	# itertuples puts the index first, so sampled[1] is the first column
	# (compared against source_id) and the remaining two values are taken
	# as flux and flux_error.
	for rec, sampled in zip(buffer, sampledSpectra.itertuples()):
		if rec["source_id"]!=sampled[1]:
			raise ValueError(
				f"Calibrator output out of order: expected source_id"
				f" {rec['source_id']}, got {sampled[1]}")
		rec["flux"], rec["flux_error"] = sampled[2:]


# The columns of the continuous-spectra dump, in the order they appear on
# each line.
FIELD_NAMES = (
	"source_id,solution_id,"
	"bp_basis_function_id,bp_degrees_of_freedom,bp_n_parameters,"
	"bp_n_measurements,bp_n_rejected_measurements,bp_standard_deviation,"
	"bp_chi_squared,bp_coefficients,bp_coefficient_errors,"
	"bp_coefficient_correlations,bp_n_relevant_bases,bp_relative_shrinking,"
	"rp_basis_function_id,rp_degrees_of_freedom,rp_n_parameters,"
	"rp_n_measurements,rp_n_rejected_measurements,rp_standard_deviation,"
	"rp_chi_squared,rp_coefficients,rp_coefficient_errors,"
	"rp_coefficient_correlations,rp_n_relevant_bases,rp_relative_shrinking"
	).split(",")

def _parseToRecords(chunk):
	"""yields one dict per non-empty line of chunk.

	chunk is split at newline characters; each remaining line is split on
	whitespace and its values are paired with FIELD_NAMES in order.  The
	values stay strings (see _toTypes).  A line with fewer values than
	FIELD_NAMES simply yields a dict lacking the trailing fields.
	"""
	for line in chunk.split("\n"):
		if not line:
			continue
		yield dict(zip(FIELD_NAMES, line.split()))


class Worker:
	"""the controller-side representation of a processing job.

	Each Worker starts one child process running _run.  Chunks of dump text
	go to it through inQueue; serialised results come back through outQueue.
	chunksNotReturned counts submissions not yet answered (the empty exit
	sentinel is counted, too, although it never gets an answer).  joined
	becomes true once the child has exited and all its results have been
	picked up.
	"""
	def __init__(self, id):
		self.id = id
		self.inQueue = mp.Queue()
		self.outQueue = mp.Queue()
		self._process = mp.Process(target=self._run,
			args=(self.id, self.inQueue, self.outQueue))
		self._process.start()
		self.chunksNotReturned = 0
		self.joined = False

	def submit(self, data):
		"""queues data (a chunk of dump text, or "" to make the worker exit)
		for the child process.
		"""
		self.inQueue.put(data)
		self.chunksNotReturned += 1

	@staticmethod
	def _run(workerId, inQueue, outQueue):
		"""the child process' main loop: converts chunks from inQueue until
		an empty one arrives.

		Each chunk is parsed, typed and calibrated; the result is put on
		outQueue as a single string with one line per record, holding
		source_id, flux and flux_error separated by tabs, where flux and
		flux_error are comma-separated values with 7 significant digits.
		"""
		while True:
			chunk = inQueue.get()
			if not chunk:
				# we're done; exit
				break

			buffer = [_toTypes(row) for row in _parseToRecords(chunk)]
			_addSampled(buffer)

			serialised = []
			for row in buffer:
				source_id = row["source_id"]
				flux = ",".join(f"{v:.7g}" for v in row["flux"])
				fluxError = ",".join(f"{v:.7g}" for v in row["flux_error"])
				serialised.append(f"{source_id}\t{flux}\t{fluxError}")

			outQueue.put("\n".join(serialised))
		print(f"Process {workerId} exiting.")

	def _raiseIfCrashed(self):
		"""raises a RuntimeError if the child process has terminated with a
		non-zero exit code (e.g., because _run raised an exception).

		Without this, a dead worker would look busy forever and the chunk it
		was working on would be lost without notice.
		"""
		exitcode = self._process.exitcode
		if exitcode is not None and exitcode!=0:
			raise RuntimeError(
				f"Worker {self.id} died with exit code {exitcode}")

	def writeResult(self, destF):
		"""will write a result to destF if one is there.

		This is a no-op if the worker has no new results, unless the worker
		has crashed, in which case a RuntimeError is raised.
		"""
		if self.outQueue.empty():
			self._raiseIfCrashed()
			return

		print(f"Worker {self.id} writing result")
		destF.write(self.outQueue.get()+"\n")
		self.chunksNotReturned -= 1

	def join(self):
		"""waits up to a second for the child process to exit.

		joined is only set once the process has exited *and* outQueue is
		drained: a child flushes its last result into the pipe and may then
		exit before the controller has read it, so an exited process can
		still have a result waiting.  A child that exited with an error
		raises a RuntimeError.
		"""
		self._process.join(1)
		if self._process.exitcode is None:
			return

		self._raiseIfCrashed()
		if self.outQueue.empty():
			self.joined = True


def _iterChunks(inF, chunkSize=CHUNK_SIZE):
	"""yields the text of inF in pieces of about chunkSize characters.

	Each piece is extended up to (but not including) the next line break, so
	no input line is ever split between two pieces.  The line break itself is
	dropped.  Once a read comes back short, the final piece takes everything
	left over.

	A ValueError is raised if no line break is found within chunkSize
	characters.
	"""
	chunk = inF.read(chunkSize)
	while chunk:
		rest = inF.read(chunkSize)
		if len(rest)<chunkSize:
			# file exhausted, let the last job do its thing
			yield chunk+rest
			return

		# normal operation: complete the last line of chunk from rest
		lineSep = rest.find("\n")
		if lineSep==-1:
			raise ValueError(f"No line break within {chunkSize} characters"
				" of input; increase CHUNK_SIZE")
		yield chunk+rest[:lineSep]

		# Whatever follows the line break starts the next piece.  If the
		# line break was the last character of rest, that is empty, and an
		# empty piece must not be mistaken for the end of the input.
		chunk = rest[lineSep+1:] or inF.read(chunkSize)


def _submitToFreeWorker(workers, chunk, outF):
	"""hands chunk to the first worker without outstanding chunks.

	While waiting for a worker to become free, incoming results are written
	to outF.  Since only idle workers get chunks, no worker ever has more
	than one chunk outstanding.
	"""
	while True:
		# Let workers write any results coming in
		for w in workers:
			w.writeResult(outF)

		for w in workers:
			if w.chunksNotReturned==0:
				print(f"Submitting to worker {w.id}")
				w.submit(chunk)
				return

		# no free worker.  Retry in a bit
		time.sleep(2)


def _shutdownWorkers(workers, outF, maxRetries=100):
	"""asks all workers to exit and waits for them, writing their remaining
	results to outF.

	Each round joins every unfinished worker with a one-second timeout.
	If workers are still running after maxRetries rounds, an Exception is
	raised.
	"""
	# NOTE(review): with the default, workers get only about 100 s to finish
	# their final chunks; confirm that's enough for CHUNK_SIZE-sized chunks.
	print("Asking Workers to exit")
	for w in workers:
		w.submit("")

	for retry in range(maxRetries):
		notJoined = []
		for w in workers:
			if w.joined:
				continue

			w.writeResult(outF)
			w.join()
			if w.joined:
				# A worker may put its last result and exit between the
				# writeResult and the join above.  As no worker has more than
				# one chunk outstanding, one more read picks up anything left.
				w.writeResult(outF)
			else:
				notJoined.append(w)

		if not notJoined:
			return
		print("Waiting for "+", ".join(w.id for w in notJoined))

	raise Exception("Hung workers?")


def main():
	"""converts data3/xp_continuous_dump.txt.gz into
	data3/xp_sampled_computed.txt using N_WORKERS worker processes.

	Results are written in the order they come back from the workers, not
	in input order.
	"""
	workers = [Worker(str(i)) for i in range(N_WORKERS)]

	try:
		with gzip.open("data3/xp_continuous_dump.txt.gz",
				mode="rt", encoding="ascii") as inF,\
			open("data3/xp_sampled_computed.txt", "w", encoding="ascii") as outF:

			for chunk in _iterChunks(inF):
				_submitToFreeWorker(workers, chunk, outF)
			_shutdownWorkers(workers, outF)

	except BaseException:
		# The worker processes are not daemonic, and those not yet told to
		# exit block in inQueue.get() forever; without terminating them,
		# the interpreter would wait for them indefinitely on exit.
		for w in workers:
			if not w.joined:
				w._process.terminate()
		raise


def main_sync():
	"""debugging aid: converts only the first record of the gzipped dump
	named by the first command line argument, within this process.

	The sampled record is discarded; this only shows whether the conversion
	runs through.
	"""
	import sys
	with gzip.open(sys.argv[1], "rt") as srcF:
		firstRecord = list(_parseToRecords(srcF.readline()))[0]
		_addSampled([_toTypes(firstRecord)])


if __name__ == "__main__":
	# run the parallel conversion (main_sync only checks a single record)
	main()
