"""
A regression test for the TAP client/server.

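We use the unittest framework, but the tests themselves are, admittedly,
not unit tests; their primary purpose is to exercise the remote TAP
server.

To run this, you need the GAVO votable library.  The server under test
is taken from the ENDPOINT_URL environment variable and defaults to
http://dc.zah.uni-heidelberg.de/tap.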
"""

import datetime
import os
import re
import sys
import time
import unittest
import urllib.request, urllib.parse, urllib.error
from io import BytesIO

from gavo import votable
from gavo.votable import tapquery


# Base URL of the TAP service under test; set the ENDPOINT_URL environment
# variable to point the tests at a different server.
_DEFAULT_ENDPOINT_URL = "http://dc.zah.uni-heidelberg.de/tap"
ENDPOINT_URL = os.getenv("ENDPOINT_URL", _DEFAULT_ENDPOINT_URL)

class ADQLJobTest(unittest.TestCase):
	"""An abstract base for tests running a single query.

	Subclasses define the query to run in the ``query`` class attribute and
	may define a ``parameters`` dict, which is passed as ``userParams`` when
	the job is created.  Each test gets a fresh job in ``self.job``; it is
	deleted again in tearDown.
	"""
	def setUp(self):
		# A missing or empty ``parameters`` attribute means "no extra params".
		params = getattr(self, "parameters", None) or {}
		self.job = votable.ADQLTAPJob(ENDPOINT_URL, self.query, userParams=params)
	
	def tearDown(self):
		self.job.delete()
	
	def assertRaisesWithMessage(self, exClass, msgPat, callable, *args, **kwargs):
		"""Fails unless callable(*args, **kwargs) raises exClass with a message
		matching msgPat.

		exClass may be an exception class or a tuple of them.  msgPat is a
		regular expression applied with re.match, i.e., anchored at the start
		of the exception message.  Exceptions not matching exClass propagate
		unchanged.  (The parameter name ``callable`` shadows the builtin but is
		kept for compatibility with existing callers.)
		"""
		try:
			callable(*args, **kwargs)
		except exClass as ex:
			if not re.match(msgPat, str(ex)):
				raise self.failureException(
					"Expected %r, got %r as exception message"%(msgPat, str(ex)))
		else:
			# Wrap in a 1-tuple so a tuple exClass doesn't break %-formatting.
			raise self.failureException("%s not raised"%(
				getattr(exClass, "__name__", exClass),))


class AllInOneTest(ADQLJobTest):
	"""A wild mixture of exercises for GAVO's TAP server.

	Walks one async job through its life: checks and modifies
	executionDuration, destruction and the QUERY parameter while the job is
	PENDING, then runs it and checks that the first result can be fetched.
	"""
	# XXX TODO: split out some pieces.
	query = "SELECT * FROM taptest.unmain"
	
	def test(self):
		# Switch on tapquery's debug flag for this test only; the previous
		# value is restored afterwards so later tests are not affected.
		self.addCleanup(setattr, tapquery, "debug",
			getattr(tapquery, "debug", False))
		tapquery.debug = True

		j = self.job
		self.assertEqual(j.phase, 'PENDING')
		self.assertEqual(j.executionDuration, 7200.)
		j.executionDuration = 12
		self.assertEqual(j.executionDuration, 12.)

		# NOTE(review): naive UTC datetimes are used because j.destruction is
		# compared against them; utcnow is deprecated since Python 3.12.
		self.assertTrue(j.destruction>datetime.datetime.utcnow())
		# Drop microseconds, presumably because the server round-trips
		# destruction at second resolution.
		newDestruction = (datetime.datetime.utcnow()
			+datetime.timedelta(hours=2)).replace(microsecond=0)
		j.destruction = newDestruction
		self.assertIsInstance(j.quote, datetime.datetime)
		self.assertEqual(j.destruction, newDestruction)
		self.assertEqual(j.getErrorFromServer(), "")  # anything more sensible?

		self.assertEqual(j.parameters["QUERY"], "SELECT * FROM taptest.unmain")
		j.setParameter("QUERY", "SELECT * FROM taptest.main")
		self.assertEqual(j.parameters["QUERY"], "SELECT * FROM taptest.main")
		self.assertEqual(j.owner, "NULL")

		j.start()
		j.waitForPhases(set([tapquery.COMPLETED, tapquery.ERROR]), giveUpAfter=20)
		j.raiseIfError()
		self.assertEqual(j.phase, "COMPLETED")
		with urllib.request.urlopen(j.allResults[0].href) as f:
			self.assertIn(b"instr.bandpass", f.read())


class TestWithBadUpload(ADQLJobTest):
	"""Checks that uploads which are not usable VOTables make the job fail."""
	query = "SELECT * from TAP_UPLOAD.upup"

	def testEmpty(self):
		# A VOTable without any table must put the job into ERROR.
		j = self.job
		j.addUpload("upup", BytesIO(b"<VOTABLE/>"))
		# The last comma-separated item of UPLOAD is the URL the uploaded
		# bytes were stored under; they must come back unchanged.
		with urllib.request.urlopen(j.parameters["UPLOAD"].split(",")[-1]) as f:
			self.assertEqual(f.read(), b"<VOTABLE/>")
		j.start()
		j.waitForPhases(set([tapquery.COMPLETED, tapquery.ERROR]), giveUpAfter=20)
		self.assertEqual(j.phase, "ERROR")
		self.assertEqual(j.getErrorFromServer(),
			"While ingesting upload upup: Cannot parse VOTable"
			" (or no table contained)")

	def testRubbish(self):
		# Non-XML input must put the job into ERROR with a parser message.
		j = self.job
		j.addUpload("upup", BytesIO(b"loremipsum"*30))
		j.start()
		j.waitForPhases(set([tapquery.COMPLETED, tapquery.ERROR]), giveUpAfter=20)
		self.assertEqual(j.phase, "ERROR")
		# assertIn, not assertTrue: assertTrue(literal, msg) always passed.
		self.assertIn("syntax error: line 1, column 0", j.getErrorFromServer())
	

class Maxrec0Test(ADQLJobTest):
	"""With MAXREC=0 the result must carry column metadata but no rows.

	Its QUERY_STATUS INFOs are expected to read OK followed by OVERFLOW.
	"""
	query = 'SELECT * FROM taptest.main'
	parameters = {'MAXREC': 0}
	
	def test(self):
		job = self.job
		job.run()
		rows, meta = votable.load(job.openResult())
		self.assertEqual(meta.getFields()[0].name, "ra")
		self.assertEqual(len(rows), 0)
		statusInfos = meta.infos["QUERY_STATUS"]
		self.assertEqual(statusInfos[0].value, "OK")
		self.assertEqual(statusInfos[1].value, "OVERFLOW")


class Maxrec10Test(ADQLJobTest):
	"""With MAXREC=10 the result must contain exactly ten rows."""
	query = 'SELECT * FROM taptest.main'
	parameters = {'MAXREC': 10}
	
	def test(self):
		job = self.job
		job.run()
		rows, meta = votable.load(job.openResult())
		firstField = meta.getFields()[0]
		self.assertEqual(firstField.name, "ra")
		self.assertEqual(len(rows), 10)


class FormatsTest(ADQLJobTest):
	"""Checks that the FORMAT parameter controls how results are served."""
	query = 'SELECT TOP 1 * FROM taptest.main'

	def _runWithFormat(self, fmt):
		"""Sets FORMAT to fmt, runs the job and returns the opened result."""
		job = self.job
		job.setParameter("FORMAT", fmt)
		job.run()
		return job.openResult()

	def testVOTMIME1(self):
		# Incredibly, the service must keep our whim of VOTable mime...
		headers = self._runWithFormat("application/x-votable+xml").info()
		self.assertEqual(headers["content-type"], "application/x-votable+xml")

	def testVOTMIME2(self):
		# ...and the same holds when we ask for text/xml.
		headers = self._runWithFormat("text/xml").info()
		self.assertEqual(headers["content-type"], "text/xml")

	def testCSV(self):
		# "csv" output is expected to contain the column header line.
		payload = self._runWithFormat("csv").read()
		self.assertIn(b"ra,de,spectral", payload)
	
	def testPlainCSV(self):
		# "text/csv" output is apparently expected to come without that header.
		payload = self._runWithFormat("text/csv").read()
		self.assertNotIn(b"ra,de,spectral", payload)


class TAPSchemaTest(ADQLJobTest):
	"""Queries TAP_SCHEMA.tables and checks the returned VOTable."""
	query = 'SELECT TOP 1 * FROM TAP_SCHEMA.tables'

	def testRuns(self):
		self.job.setParameter("FORMAT", "votable")
		self.job.run()
		res = self.job.openResult().read()
		# For want of a better place: ensure RESOURCE has type="results"
		self.assertIn(b'RESOURCE type="results"', res)
		data, metadata = votable.load(BytesIO(res))
		# The server's TAP_SCHEMA.tables is expected to have eight columns.
		self.assertEqual(len(list(metadata)), 8)


class TestWithUpload(ADQLJobTest):
	"""Common base for the tests exercising TAP_UPLOAD tables.

	It adds nothing to ADQLJobTest; it only groups the upload tests.
	"""


class InlineUploadTest(TestWithUpload):
	"""Joins taptest.main against a table uploaded inline from test1.vot."""
	query = (
		'SELECT TOP 20 *'
		'   FROM taptest.main AS m '
		'   JOIN TAP_UPLOAD.upload1 AS u1'
		'   ON (m.ra<u1.alphaFloat) ORDER BY de,cmag')

	def test(self):
		# test1.vot is looked up relative to the current working directory.
		with open("test1.vot", "rb") as uploadSource:
			self.job.addUpload("upload1", uploadSource)
		self.job.run()
		rows, meta = votable.load(self.job.openResult())
		firstRow = rows[0]
		self.assertEqual(firstRow[3], '043555.2+163033')
		self.assertEqual(meta[3].ucd, "meta.id;meta.main")


class UploadLimitTestSync(unittest.TestCase):
	"""An oversized upload to the sync endpoint must be rejected with 413."""
	def test(self):
		job = votable.ADQLSyncJob(
			ENDPOINT_URL,
			'SELECT TOP 20 *'
			'   FROM taptest.main AS m '
			'   JOIN TAP_UPLOAD.upload1 AS u1'
			'   ON (m.ra<u1.alphaFloat) ORDER BY de')
		# Adjacent literals are joined before the multiplication, so this
		# repeats the whole 35-byte phrase a million times (about 35 MB).
		job.addUpload("upload1", BytesIO(b"loremipsumjumpsoverthe"
			b"quickbrownfox"*1000000))
		try:
			job.run()
		except tapquery.WrongStatus as ex:
			self.assertEqual(str(ex), "Expected status 200, got status 413")
			self.assertEqual(
				job.getErrorFromServer(),
				'Your upload is too large')
		except Exception as ex:
			# Use %r so the exception type is visible even if its message is
			# empty, and chain it so the original traceback is kept.
			raise self.failureException(
				"%r raised instead of WrongStatus"%ex) from ex
		else:
			raise self.failureException("WrongStatus not raised")


class MissingUploadTest(TestWithUpload):
	"""Referencing a TAP_UPLOAD table that was never uploaded must fail."""
	query = (
		'SELECT TOP 20 *'
		'   FROM taptest.main AS m '
		'   JOIN TAP_UPLOAD.upload1 AS u1'
		'   ON (m.ra<u1.alphaFloat) ORDER BY de')

	def test(self):
		# This is used as a regular expression, so the '.' matches any character.
		expectedMessage = "Field query: Could not locate table 'public.upload1'"
		self.assertRaisesWithMessage(
			tapquery.RemoteError, expectedMessage, self.job.run)


class URLUploadTest(TestWithUpload):
	"""Has the server fetch the upload table from a URL given in UPLOAD."""
	query = (
		'SELECT TOP 5 glimpse FROM TAP_UPLOAD.glimpse2 ORDER BY Jmag')
	parameters = {
		"UPLOAD": "glimpse2,http://vo.ari.uni-heidelberg.de/docs"
			"/upload_for_regressiontest.vot",
	}

	def test(self):
		job = self.job
		job.run()
		rows, meta = votable.load(job.openResult())
		self.assertEqual(len(rows), 5)
		firstRow = rows[0]
		self.assertEqual(firstRow[0], 'G014.9997+00.0027')
		self.assertEqual(meta[0].ucd, "meta.id;meta.main")


class AbortionTest(unittest.TestCase):
	"""Tests aborting async jobs in various stages of their execution."""
	def _assertJobAbortable(self, j, ensureExecution=None):
		"""Starts j, waits for EXECUTING, aborts it and checks it is ABORTED.

		ensureExecution, if given, is called with the job after it has been
		seen EXECUTING and before it is aborted.  The job is deleted in any
		case.
		"""
		try:
			j.start()
			j.waitForPhases(set([tapquery.EXECUTING]),
				giveUpAfter=50, pollInterval=0.2)
			if ensureExecution is not None:
				ensureExecution(j)
			j.raiseIfError()
			j.abort()
			j.waitForPhases(set([tapquery.ABORTED]), giveUpAfter=4)
			self.assertEqual(j.phase, tapquery.ABORTED)
		finally:
			j.delete()

	def _waitForTaprunnerUp(self, j):
		"""Blocks until the job's startTime is no longer NULL.

		It gives up silently after 40 polls spaced 0.2 seconds apart.
		"""
		# in the gavo DC, you can tell taprunner has started to process a
		# job by checking startTime.  urlopen yields bytes, so the comparison
		# must be against b"NULL"; the str "NULL" would never compare equal.
		for _ in range(40):
			with urllib.request.urlopen(j.makeJobURL("/startTime")) as f:
				startTime = f.read().strip()
			if startTime!=b"NULL":
				break
			time.sleep(0.2)

	def testAbortRunner(self):
		# This posts a magic query that merely lets the runner hang and kills it
		# rather than the postgres worker.
		self._assertJobAbortable(votable.ADQLTAPJob(ENDPOINT_URL,
			'JUST HANG around'), self._waitForTaprunnerUp)

	def testAbortQueryInfant(self):
		# in GAVO DC, this will kill the job before taprunner is
		# actually ready.
		self._assertJobAbortable(
			votable.ADQLTAPJob(ENDPOINT_URL,
				'SELECT TOP 1000000000 * FROM taptest.main'
				' CROSS JOIN taptest.main AS a CROSS JOIN taptest.main as b'))

	def testAbortQueryMature(self):
		# this waits for taprunner to be properly up before killing it.
		self._assertJobAbortable(
			votable.ADQLTAPJob(ENDPOINT_URL,
				'SELECT TOP 1000000000 * FROM taptest.main'
				' CROSS JOIN taptest.main AS a CROSS JOIN taptest.main as b'),
				self._waitForTaprunnerUp)

	def testAbortPending(self):
		# A job that was never started can be aborted directly.
		j = votable.ADQLTAPJob(ENDPOINT_URL, "SELECT * FROM taptest.main")
		try:
			j.abort()
			self.assertEqual(j.phase, tapquery.ABORTED)
		finally:
			j.delete()


class VOSITest(unittest.TestCase):
	"""Tests the service's VOSI resources: availability, capabilities, tables."""
	def setUp(self):
		self.srv = tapquery.ADQLEndpoint(ENDPOINT_URL)

	def testAvailability(self):
		self.assertTrue(self.srv.available==True)
	
	def testCapabilities(self):
		caps = self.srv.capabilities
		standardIds = [c.get("standardID") for c in caps]
		self.assertIn('ivo://ivoa.net/std/TAP', standardIds)
		self.assertIn('ivo://ivoa.net/std/VOSI#availability', standardIds)
		self.assertTrue(caps[0]["interfaces"][0]["accessURL"].startswith("http"))
	
	def testTables(self):
		tables = self.srv.tables
		self.assertIsInstance(tables[0], votable.V.TABLE)
		self.assertIsInstance(tables[0].getFields()[0], votable.V.FIELD)

	def testCapabilitiesRequest(self):
		# The sync endpoint must also answer REQUEST=getCapabilities; the
		# response is closed once read.
		with urllib.request.urlopen(
				ENDPOINT_URL+"/sync?REQUEST=getCapabilities") as f:
			stuff = f.read()
		self.assertIn(b'<interface role="std"', stuff)


# NOTE(review): the string below is a disabled sketch, not live code.  It
# uses the Python 2 urllib API (urllib.urlopen, urllib.urlencode), which the
# rest of this file has replaced by urllib.request/urllib.parse, and it would
# issue its request at class-definition time.  It needs porting and moving
# into a test method before it can be re-enabled.
"""
Do we want this behaviour?
class ImmediateRunTest(unittest.TestCase):
# Test for posting phase=RUN with the rest
	f = urllib.urlopen(ENDPOINT_URL+"/async", urllib.urlencode({
		"LANG": "ADQL",
		"REQUEST": "doQuery",
		"QUERY": "SELECT TOP 10 * FROM TAP_SCHEMA.tables",
		"PHASE": "run"}))
"""

def main():
	"""Runs the regression tests.

	Usage: ``<script> [TestClass [methodPrefix]]``.  Given a test class name
	(and optionally a method name prefix, default "test"), only the matching
	methods of that class are run, and the process exits with status 0 if
	they all pass and 1 otherwise.  With no arguments, or more than two,
	control is handed to unittest.main().
	"""
	if 2<=len(sys.argv)<=3:
		testClass = globals()[sys.argv[1]]
		# unittest.makeSuite was deprecated in Python 3.11 and removed in
		# 3.13; a TestLoader with a custom method prefix does the same job.
		loader = unittest.TestLoader()
		if len(sys.argv)==3:
			loader.testMethodPrefix = sys.argv[2]
		suite = loader.loadTestsFromTestCase(testClass)
		result = unittest.TextTestRunner().run(suite)
		# Report the outcome through the exit status, as unittest.main() does.
		sys.exit(0 if result.wasSuccessful() else 1)
	else:
		unittest.main()


# Only run the tests when executed as a script, not when imported.
if __name__ == "__main__":
	main()
