"""
Tests for dealing with SQL arrays in ADQL.

(and perhaps elsewhere).
"""

import math

from gavo.helpers import testhelpers

from gavo import base

import adqltest
import tresc


class _ArrayTestbed(tresc.RDDataResource):
	"""A test resource for a table that contains arrays.

	It names the import_arr data in data/ufuncex.rd; presumably that is
	what provides the test.arrsample table queried by the tests in this
	module.
	"""
	dataId = "import_arr"
	rdName = "data/ufuncex.rd"

# A single module-level instance, referenced from the resources lists of
# the test classes in this module.
_arrayTestbed = _ArrayTestbed()


class AdditionTest(testhelpers.VerboseTest):
	"""Tests for adding array-valued columns to each other in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		# Adding an array to itself must keep its length and double its values.
		result = self.querier.queryADQL(
			"SELECT flux,"
			" flux+flux as fluxtwice,"
			" plc+plc as plctwice"
			" FROM test.arrsample")
		firstRow = result.rows[0]

		self.assertEqual(len(firstRow["fluxtwice"]), len(firstRow["flux"]))
		self.assertEqual(len(firstRow["plctwice"]), 3)
		self.assertAlmostEqual(
			firstRow["fluxtwice"][0]/2,
			firstRow["flux"][0])
		self.assertAlmostEqual(firstRow["plctwice"][2], -3.2491233)

	def testUnequalLength(self):
		# The sum is expected to have six elements, with the last three
		# being NaN.
		result = self.querier.queryADQL(
			"SELECT flux+CAST(plc AS real[]) as s FROM test.arrsample")
		summed = result.rows[0]["s"]

		self.assertEqual(len(summed), 6)
		for index in (-1, -2, -3):
			self.assertTrue(math.isnan(summed[index]))
		self.assertAlmostEqual(summed[-4], -0.86078703)
		self.assertEqual(
			result.tableDef.getColumnByName("s").type,
			"real[]")

	def testAddedMetadata(self):
		result = self.querier.queryADQL(
			"SELECT flux, flux+flux as ff, cast(flux as double precision[])+plc as fp"
			" FROM test.arrsample")

		# One entry per output column: name, ucd, unit, type, description.
		expectedMetadata = [
			("flux", "phot.mag", "Jy", "real[]",
				"Some flux values in an array"),
			("ff", "phot.mag", "Jy", "real[]",
				"This field has traces of: Some flux values in an array;"
				" Some flux values in an array -- *TAINTED*: the value was operated"
				" on in a way that unit and ucd may be severely wrong"),
			("fp", "", "", "double precision[]",
				"This field has traces of: Some flux values in an array;"
				" Possibly a cartesian position -- *TAINTED*: the value was operated"
				" on in a way that unit and ucd may be severely wrong"),
		]

		for name, ucd, unit, sqlType, description in expectedMetadata:
			column = result.tableDef.getColumnByName(name)
			self.assertEqual(column.ucd, ucd)
			self.assertEqual(column.unit, unit)
			self.assertEqual(column.type, sqlType)
			self.assertEqual(column.description, description)


class SubtractionTest(testhelpers.VerboseTest):
	"""Tests for subtracting array-valued columns from each other in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		result = self.querier.queryADQL(
			"SELECT flux,"
			" flux-flux as allnull,"
			" flux-plc as indiff"
			" FROM test.arrsample")
		firstRow = result.rows[0]

		# Both differences are expected to have the same length.
		self.assertEqual(len(firstRow["allnull"]), len(firstRow["indiff"]))
		self.assertEqual(firstRow["allnull"][5], 0)
		self.assertAlmostEqual(firstRow["indiff"][2], 2.388336286310653)
		# The last element of the mixed difference is expected to be NaN.
		self.assertTrue(math.isnan(firstRow["indiff"][5]))


class MultiplicationTest(testhelpers.VerboseTest):
	"""Tests for elementwise multiplication of array columns in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		# fstest and mixtest are scalar products of single elements; they
		# serve as references for the corresponding array elements.
		result = self.querier.queryADQL(
			"SELECT flux[6]*flux[6] as fstest,"
			" flux*flux as fluxsquare,"
			" flux[3]*plc[3] as mixtest,"
			" flux*plc as mixed"
			" FROM test.arrsample")
		row = result.rows[0]

		self.assertEqual(len(row["fluxsquare"]), len(row["mixed"]))
		# SQL element 6 is compared against Python index 5, element 3
		# against index 2.
		self.assertEqual(row["fluxsquare"][5], row["fstest"])
		self.assertAlmostEqual(row["mixed"][2], row["mixtest"])
		self.assertTrue(math.isnan(row["mixed"][5]))


class DivisionTest(testhelpers.VerboseTest):
	"""Tests for elementwise division of array columns in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		# fstest and mixtest are scalar quotients of single elements; they
		# serve as references for the corresponding array elements.
		result = self.querier.queryADQL(
			"SELECT flux[6]/flux[6] as fstest,"
			" flux/flux as allones,"
			" flux[3]/plc[3] as mixtest,"
			" flux/plc as mixed"
			" FROM test.arrsample")
		row = result.rows[0]

		self.assertEqual(len(row["allones"]), len(row["mixed"]))
		self.assertEqual(row["allones"][5], row["fstest"])
		self.assertAlmostEqual(row["mixed"][2], row["mixtest"])
		self.assertTrue(math.isnan(row["mixed"][5]))

		# The declared result types of the two quotient columns.
		columnOf = result.tableDef.getColumnByName
		self.assertEqual(columnOf("allones").type, "real[]")
		self.assertEqual(columnOf("mixed").type, "double precision[]")


class ScalarProdTest(testhelpers.VerboseTest):
	"""Tests for the arr_dot function in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		result = self.querier.queryADQL(
			"SELECT"
			" arr_dot(flux, flux) as fluxprod,"
			" arr_dot(flux, plc) as mixedprod"
			" FROM test.arrsample")
		firstRow = result.rows[0]

		self.assertAlmostEqual(firstRow["fluxprod"], 1.832107663154602)
		# The product of flux and plc is expected to be NaN.
		self.assertTrue(math.isnan(firstRow["mixedprod"]))

		# Both products are expected to be double precision without UCD
		# or unit.
		for name in ("fluxprod", "mixedprod"):
			column = result.tableDef.getColumnByName(name)
			self.assertEqual(column.type, "double precision")
			self.assertEqual(column.ucd, "")
			self.assertEqual(column.unit, "")


class ScalarMulTest(testhelpers.VerboseTest):
	"""Tests for multiplying array columns with scalars in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testWithLiteral(self):
		result = self.querier.queryADQL(
			"SELECT"
			" 2*flux-flux*2 as realzeros,"
			" 5*plc as dub"
			" FROM test.arrsample")
		columnOf = result.tableDef.getColumnByName

		# Multiplying by a literal is expected to keep type, UCD and unit.
		zerosColumn = columnOf("realzeros")
		self.assertEqual(zerosColumn.type, "real[]")
		self.assertEqual(zerosColumn.ucd, "phot.mag")
		self.assertEqual(zerosColumn.unit, "Jy")

		# The literal may stand on either side of the array.
		self.assertEqual(result.rows[0]["realzeros"], [0.]*6)

		dubColumn = columnOf("dub")
		self.assertEqual(dubColumn.type, "double precision[]")
		self.assertEqual(dubColumn.ucd, "pos.cartesian")
		self.assertEqual(dubColumn.unit, "m")
		self.assertEqual(
			dubColumn.description,
			"Possibly a cartesian position -- *TAINTED*: the value was operated"
			" on in a way that unit and ucd may be severely wrong")

		self.assertAlmostEqual(result.rows[0]["dub"][2], -8.122808264515303)

	def testWithExpression(self):
		result = self.querier.queryADQL(
			"SELECT"
			" plc[1]*flux as fromleft,"
			" flux*plc[1] as fromright"
			" FROM test.arrsample")

		# With a scalar that has a unit, the UCD is expected to be empty
		# and the units to be multiplied.
		leftColumn = result.tableDef.getColumnByName("fromleft")
		self.assertEqual(leftColumn.type, "double precision[]")
		self.assertEqual(leftColumn.ucd, "")
		self.assertEqual(leftColumn.unit, "m*Jy")

		# Multiplication from the left and from the right must agree.
		firstRow = result.rows[0]
		differences = [
			left-right
			for left, right in zip(firstRow["fromleft"], firstRow["fromright"])]
		self.assertEqual(differences, [0.]*6)


class ScalarDivTest(testhelpers.VerboseTest):
	"""Tests for dividing array columns by scalars in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testWithLiteral(self):
		res = self.querier.queryADQL(
			"SELECT flux/2 as half FROM test.arrsample")

		# Dividing by a literal is expected to keep type, UCD and unit.
		col = res.tableDef.getColumnByName("half")
		self.assertEqual(col.type, "real[]")
		self.assertEqual(col.ucd, "phot.mag")
		self.assertEqual(col.unit, "Jy")
		# Compare approximately, like the other non-trivial value checks in
		# this module: the exact digits of a float coming back from the
		# database should not decide whether this test passes.
		self.assertAlmostEqual(res.rows[0]["half"][0], 0.06718212)

	def testWithExpression(self):
		res = self.querier.queryADQL(
			"SELECT flux/plc[1] as k FROM test.arrsample")

		# With a scalar that has a unit, the UCD is expected to be empty
		# and the unit to be a quotient.
		col = res.tableDef.getColumnByName("k")
		self.assertEqual(col.type, "double precision[]")
		self.assertEqual(col.ucd, "")
		self.assertEqual(col.unit, "Jy/(m)")
		# Approximate comparison for the same reason as in testWithLiteral.
		self.assertAlmostEqual(res.rows[0]["k"][-1], 0.7412795)


class DerivedAggFunctionTest(testhelpers.VerboseTest):
	"""Tests for the arr_* functions reducing one array to a scalar.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testMessage(self):
		# Passing two arguments must be rejected as a parse error.
		self.assertRaisesWithMsg(
			base.ValidationError,
			"Field query: Could not parse your query: Expected \")\", found ','  (at char 18), (line:1, col:19)",
			self.querier.queryADQL,
			("SELECT arr_min(plc, plc) FROM test.arrsample",))

	def testMin(self):
		result = self.querier.queryADQL(
			"SELECT"
			" arr_min(plc) as minplc,"
			" arr_min(flux) as minflux"
			" FROM test.arrsample")
		rows = result.rows

		self.assertAlmostEqual(rows[0]["minplc"], -1.6245616529030604)
		self.assertAlmostEqual(rows[1]["minplc"], -0.7674541696434232)
		self.assertAlmostEqual(rows[0]["minflux"], 0.13436425)

		# One entry per output column: name, type, ucd, unit.
		expectedMetadata = [
			("minplc", "double precision", "stat.min;pos.cartesian", "m"),
			("minflux", "real", "stat.min;phot.mag", "Jy"),
		]
		for name, sqlType, ucd, unit in expectedMetadata:
			column = result.tableDef.getColumnByName(name)
			self.assertEqual(column.type, sqlType)
			self.assertEqual(column.ucd, ucd)
			self.assertEqual(column.unit, unit)

	def testOnNonArray(self):
		# A single array element as argument must be rejected.
		self.assertRaisesWithMsg(
			base.ValidationError,
			"Field query: Scalar array function called on non-array?",
			self.querier.queryADQL,
			("SELECT arr_min(plc[0]) as minplc FROM test.arrsample",))

	def testMax(self):
		result = self.querier.queryADQL(
			"SELECT"
			" arr_max(plc) as maxplc,"
			" arr_max(flux) as maxflux"
			" FROM test.arrsample")

		self.assertAlmostEqual(result.rows[1]["maxplc"], 0.6789216057608836)
		self.assertAlmostEqual(result.rows[0]["maxflux"], 0.84743375)

		expectedUCDs = [
			("maxplc", "stat.max;pos.cartesian"),
			("maxflux", "stat.max;phot.mag"),
		]
		for name, ucd in expectedUCDs:
			column = result.tableDef.getColumnByName(name)
			self.assertEqual(column.ucd, ucd)

	def testSum(self):
		# The function names are written in mixed case here on purpose.
		result = self.querier.queryADQL(
			"SELECT"
			" ARR_sum(plc) as splc,"
			" arr_SUM(flux) as sflux"
			" FROM test.arrsample")

		self.assertAlmostEqual(result.rows[1]["splc"], 0.3352440988313101)
		self.assertAlmostEqual(result.rows[0]["sflux"], 2.9455676)

		expectedUCDs = [
			("splc", "pos.cartesian;arith.sum"),
			("sflux", "phot.mag;arith.sum"),
		]
		for name, ucd in expectedUCDs:
			column = result.tableDef.getColumnByName(name)
			self.assertEqual(column.ucd, ucd)

	def testAvg(self):
		result = self.querier.queryADQL(
			"SELECT"
			" arr_avg(plc) as mplc,"
			" arr_avg(flux) as mflux"
			" FROM test.arrsample")

		self.assertAlmostEqual(result.rows[1]["mplc"], 0.11174803294377005)
		self.assertAlmostEqual(result.rows[0]["mflux"], 0.49092796)

		expectedUCDs = [
			("mplc", "pos.cartesian;stat.mean"),
			("mflux", "phot.mag;stat.mean"),
		]
		for name, ucd in expectedUCDs:
			column = result.tableDef.getColumnByName(name)
			self.assertEqual(column.ucd, ucd)

	def testStddev(self):
		result = self.querier.queryADQL(
			"SELECT"
			" power(arr_stddev(plc), 2) as varce, plc,"
			" arr_stddev(flux) as fluxerr, flux"
			" FROM test.arrsample")

		self.assertAlmostEqual(result.rows[0]["fluxerr"], 0.27786547)
		self.assertAlmostEqual(result.rows[1]["varce"], 0.596022120)

		# Wrapping the result in power() is expected to drop the UCD and
		# mark the description as tainted.
		varianceColumn = result.tableDef.getColumnByName("varce")
		self.assertEqual(varianceColumn.ucd, "")
		self.assertEqual(
			varianceColumn.description,
			"Possibly a cartesian position -- *TAINTED*: the value was operated"
			" on in a way that unit and ucd may be severely wrong")

		errorColumn = result.tableDef.getColumnByName("fluxerr")
		self.assertEqual(errorColumn.ucd, "stat.stdev;phot.mag")


class ArrayAggregationTest(testhelpers.VerboseTest):
	"""Tests for the SQL aggregate functions applied to array columns.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasicSum(self):
		result = self.querier.queryADQL(
			"SELECT"
			" sum(flux) allflux,"
			" sum(plc) as allplc"
			" FROM test.arrsample")
		firstRow = result.rows[0]

		# The aggregates are expected to be arrays of the input lengths.
		self.assertEqual(len(firstRow["allflux"]), 6)
		self.assertEqual(len(firstRow["allplc"]), 3)

		self.assertAlmostEqual(firstRow["allflux"][0], 1.0903986)
		self.assertAlmostEqual(firstRow["allplc"][-1], -1.2007849901)

		self.assertEqual(
			result.tableDef.getColumnByName("allflux").ucd,
			"phot.mag")

	def testBasicMin(self):
		result = self.querier.queryADQL(
			"SELECT"
			" min(flux) minflux,"
			" min(plc) as minplc"
			" FROM test.arrsample")
		firstRow = result.rows[0]

		self.assertEqual(len(firstRow["minflux"]), 6)
		self.assertEqual(len(firstRow["minplc"]), 3)

		self.assertAlmostEqual(firstRow["minflux"][-1], 0.44949105)
		self.assertAlmostEqual(firstRow["minplc"][0], 0.606371890)

		self.assertEqual(
			result.tableDef.getColumnByName("minflux").ucd,
			"stat.min;phot.mag")

	def testMixedMin(self):
		# The aggregated column comes from a UNION of the flux and plc
		# arrays.
		result = self.querier.queryADQL(
			"SELECT min(arr) AS mixedmin FROM ("
			" SELECT flux AS arr FROM test.arrsample"
			" UNION"
			" SELECT plc AS arr FROM test.arrsample) as q")
		mixedMin = result.rows[0]["mixedmin"]

		self.assertEqual(len(mixedMin), 6)

		self.assertAlmostEqual(mixedMin[-1], 0.44949105)
		self.assertAlmostEqual(mixedMin[0], 0.1343642473)

		# Laziness alert: the UCD should probably be "", but I won't dig
		# into why it is not right now.
		mixedColumn = result.tableDef.getColumnByName("mixedmin")
		self.assertEqual(mixedColumn.ucd, None)
		self.assertEqual(
			result.tableDef.getColumnByName("mixedmin").unit,
			"")

	def testBasicMax(self):
		result = self.querier.queryADQL(
			"SELECT"
			" max(flux) maxflux,"
			" max(plc) as maxplc"
			" FROM test.arrsample")
		firstRow = result.rows[0]

		self.assertEqual(len(firstRow["maxflux"]), 6)
		self.assertEqual(len(firstRow["maxplc"]), 3)

		self.assertAlmostEqual(firstRow["maxflux"][-1], 0.73596996)
		self.assertAlmostEqual(firstRow["maxplc"][0], 0.678921605)

		self.assertEqual(
			result.tableDef.getColumnByName("maxflux").ucd,
			"stat.max;phot.mag")

	def testBasicAvg(self):
		result = self.querier.queryADQL(
			"SELECT"
			" avg(flux) avgflux,"
			" avg(plc) as avgplc"
			" FROM test.arrsample")
		firstRow = result.rows[0]

		self.assertEqual(len(firstRow["avgflux"]), 6)
		self.assertEqual(len(firstRow["avgplc"]), 3)

		# The reference values come from averaging single elements through
		# the querier's connection.
		lastFluxAvg = firstRow["avgflux"][-1]
		fluxReference = list(self.querier.connection.query(
			"select avg(flux[6]) from test.arrsample"))[0][0]
		self.assertAlmostEqual(lastFluxAvg, fluxReference)

		firstPlcAvg = firstRow["avgplc"][0]
		plcReference = list(self.querier.connection.query(
			"select avg(plc[1]) from test.arrsample"))[0][0]
		self.assertAlmostEqual(firstPlcAvg, plcReference)

		self.assertEqual(
			result.tableDef.getColumnByName("avgflux").ucd,
			"phot.mag;stat.mean")

	def testBasicStddev(self):
		# fluxtest and plctest are aggregates over single elements; they
		# serve as references for the corresponding array elements.
		result = self.querier.queryADQL(
			"SELECT"
			" stddev(flux[4]) as fluxtest,"
			" stddev(flux) fluxerr,"
			" stddev(plc[1]) as plctest,"
			" stddev(plc) as plcerr"
			" FROM test.arrsample")
		row = result.rows[0]

		self.assertAlmostEqual(row["fluxerr"][3], row["fluxtest"])
		self.assertAlmostEqual(row["plcerr"][0], row["plctest"])

		self.assertEqual(
			result.tableDef.getColumnByName("fluxtest").ucd,
			"stat.stdev;phot.mag")


class ArrMapTest(testhelpers.VerboseTest):
	"""Tests for the arr_map function in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		# xt and pt apply the mapped expressions to single elements; they
		# serve as references for the corresponding array elements.
		result = self.querier.queryADQL(
			"SELECT arr_map(sin(x)+1, flux) as x, sin(flux[3])+1 as xt,"
			"  arr_map(round(x*100), plc) as p, round(plc[1]*100) as pt"
			" FROM test.arrsample")
		firstRow = result.rows[0]

		self.assertAlmostEqual(firstRow["x"][2], firstRow["xt"])
		self.assertAlmostEqual(firstRow["p"][0], firstRow["pt"])

		columnOf = result.tableDef.getColumnByName
		self.assertEqual(columnOf("x").ucd, "")
		self.assertEqual(columnOf("x").type, "double precision[]")


# When executed as a script, hand AdditionTest to testhelpers.main
# (presumably the test case run by default -- see testhelpers).
if __name__=="__main__":
	testhelpers.main(AdditionTest)
