"""
Pull LAMOST spectra.

This needs python3-aiphttp.

It expects the spectrum catalogue
http://dr6.lamost.org/v2/catdl?name=dr6_med_v2_MRS.fits.gz in
data/MRS-catalogue.fits.gz

This uses coroutines to mitigate the connection overhead; upstream didn't
want to use rsync...
"""

import asyncio
import gzip
import os
from io import BytesIO

import aiohttp

from gavo import api
from gavo.utils import ostricks

# Number of worker coroutines main() starts; they all drain one shared
# download list.  With 1, spectra are fetched strictly one at a time.
N_PARALLEL = 1
# The DaCHS resource descriptor lamost6/q; its getAbsPath() is used to
# resolve the data/ destination paths of the downloaded spectra.
RD = api.getRD("lamost6/q")


async def download_one(dldesc, session):
	"""Fetch one LAMOST spectrum and store it, gunzipped, below data/.

	dldesc is a (resolution, obsid) pair as built by main(): 'm' selects
	the medium-resolution (MRS) archive, 'l' the low-resolution (LRS) one.
	session is the aiohttp.ClientSession to fetch with.

	Nothing is done if the destination file already exists.  The
	destination is only created after the download succeeded and
	decompressed cleanly (and then atomically, via a .part file and
	os.replace), so a failed attempt is retried on the next run rather
	than leaving an empty file behind that would be skipped forever.

	Raises ValueError for an unknown resolution code,
	aiohttp.ClientResponseError for HTTP error statuses, and
	OSError/EOFError when the payload is not valid gzip data.
	"""
	res, obsid = dldesc
	if res=='m':
		dest_path = RD.getAbsPath("data/MRS/{}/{}.fits".format(obsid[5:], obsid))
		src_url = "http://dr6.lamost.org/v2/medspectrum/fits/"+obsid
	elif res=='l':
		dest_path = RD.getAbsPath("data/LRS/{}/{}.fits".format(obsid[5:], obsid))
		src_url = "http://dr6.lamost.org/v2/spectrum/fits/"+obsid
	else:
		# An explicit raise rather than an assert: asserts vanish under -O.
		raise ValueError("Unknown resolution code {!r} for {}".format(res, obsid))

	if os.path.exists(dest_path):
		return

	async with session.get(src_url) as response:
		# Fail clearly on HTTP errors instead of trying to gunzip an error page.
		response.raise_for_status()
		spec_data = await response.read()

	# Decompress before touching the file system, so corrupt or truncated
	# payloads never produce a (bogus) destination file.
	fits_data = gzip.decompress(spec_data)

	os.makedirs(os.path.dirname(dest_path), exist_ok=True)
	tmp_path = dest_path+".part"
	try:
		with open(tmp_path, "wb") as f:
			f.write(fits_data)
		os.replace(tmp_path, dest_path)
	finally:
		# Only present if writing or renaming failed; don't leave debris.
		if os.path.exists(tmp_path):
			os.remove(tmp_path)


async def worker(to_download, total):
	"""Drain the shared to_download list, fetching each entry.

	Entries are taken from the end of to_download; total is only used for
	the "remaining/total" progress display.  A failing download is
	reported and then dropped, so one bad spectrum does not end the run.
	There is no await between the emptiness check and pop(), so several
	workers on one event loop can safely share the list.
	"""
	with ostricks.StatusDisplay() as status:
		async with aiohttp.ClientSession() as http:
			while len(to_download) > 0:
				item = to_download.pop()
				remaining = len(to_download)
				status.update("{}/{}".format(remaining, total))
				try:
					await download_one(item, http)
				except Exception as exc:
					print("Skipping {} for this round: {}".format(item[1], exc))


async def main():
	"""Collect all MRS and LRS obsids and download them with N_PARALLEL workers.

	The catalogues are located through RD.getAbsPath, just like the
	download destinations in download_one, rather than relative to the
	current working directory.
	"""
	with api.pyfits.open(RD.getAbsPath("data/MRS-catalogue.fits.gz")) as hdus:
		# there are multiple catalogue rows per obsid, and each file
		# contains multiple spectra.  Hence, uniquify on obsid.
		to_download = [('m', str(oid)) for oid in set(hdus[1].data["obsid"])]
	with api.pyfits.open(RD.getAbsPath("data/LRS-catalogue.fits.gz")) as hdus:
		# Uniquify here, too: a no-op if LRS obsids are already unique,
		# otherwise it keeps duplicates from inflating the progress total.
		to_download.extend(("l", str(oid)) for oid in set(hdus[1].data["obsid"]))

	total = len(to_download)
	# All workers share, and drain, the same to_download list.
	tasks = [asyncio.create_task(worker(to_download, total))
		for _ in range(N_PARALLEL)]
	await asyncio.gather(*tasks)
		

if __name__=="__main__":
	asyncio.run(main())
