"""
Tests for ADQL parsing and reasoning about query results.
"""

#c Copyright 2008-2019, the GAVO project
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import os
import re
import unittest
from pprint import pprint

import pyparsing

from gavo.helpers import testhelpers

from gavo import adql
from gavo import base
from gavo import stc
from gavo import rsc
from gavo import rscdef
from gavo import utils
from gavo.adql import annotations
from gavo.adql import morphpg
from gavo.adql import nodes
from gavo.adql import tree
from gavo.protocols import adqlglue
from gavo.protocols import tap
from gavo.stc import tapstc
from gavo.utils import pgsphere

import tresc

# Shorthand for base.makeStruct, used below to build column definitions.
MS = base.makeStruct

class Error(Exception):
	"""Raised by the test helpers in this module for problems that are not
	plain assertion failures (e.g., a statement that makes the parser
	recurse without bound).
	"""


# The resources below are used elsewhere (e.g., taptest).
class _ADQLQuerier(testhelpers.TestResource):
	"""A test resource providing an UnmanagedQuerier on the shared test
	database connection.

	When the resource is torn down, the connection is rolled back so that
	uncommitted statements issued by tests do not leak into later tests.
	"""
	resources = [("conn", tresc.dbConnection)]

	def make(self, deps):
		return base.UnmanagedQuerier(deps["conn"])

	def clean(self, querier):
		# The resource manager calls clean(resource); the former
		# cleanup(deps) method was never invoked, so the rollback never ran.
		querier.connection.rollback()

adqlQuerier = _ADQLQuerier()


class _ADQLTestTable(testhelpers.TestResource):
	"""A test resource holding the ADQLTest data from the test RD.

	The data is created on the shared test connection and the RD is
	published to TAP; clean() drops the tables again.
	"""
	resources = [("conn", tresc.dbConnection)]

	def make(self, deps):
		connection = deps["conn"]
		self.rd = testhelpers.getTestRD()
		tableData = rsc.makeData(self.rd.getById("ADQLTest"),
			connection=connection)
		tap.publishToTAP(self.rd, connection)
		return tableData

	def clean(self, ds):
		connection = ds.tables.values()[0].connection
		# discard whatever the tests left pending before dropping
		connection.rollback()
		ds.dropTables(rsc.parseNonValidating)
		connection.commit()
adqlTestTable = _ADQLTestTable()


class _GeometryTable(testhelpers.TestResource):
	"""A test resource with a small geometry table (an SMoc and an SPoint
	per row), built from the test RD's import_adqlgeo data.

	If building the data raises a SourceParseError, the resource is a
	plain string saying SMoc is not available instead of the data
	(presumably the database lacks SMoc support then); tests relying on
	the table will fail in that case.
	"""
	resources = [("conn", tresc.dbConnection)]

	def _makeRow(self, row_id=None, a_moc=None, a_point=None):
		# Build the row explicitly; returning locals() also put a
		# spurious "self" key into every row.
		return {
			"row_id": row_id,
			"a_moc": a_moc,
			"a_point": a_point,
		}

	def _iterRows(self):
		yield self._makeRow("moc-6-1", 
			a_moc = pgsphere.SMoc.fromASCII("6/450"),
			a_point = pgsphere.SPoint.fromDegrees(23, 42))
		yield self._makeRow("moc-4-1", 
			a_moc = pgsphere.SMoc.fromASCII("4/28"),
			a_point = pgsphere.SPoint.fromDegrees(25, 44))

	def make(self, deps):
		self.rd = testhelpers.getTestRD()
		try:
			ds = rsc.makeData(self.rd.getById("import_adqlgeo"),
				connection=deps["conn"],
				forceSource=self._iterRows())
			tap.publishToTAP(self.rd, deps["conn"])
		except base.SourceParseError:
			ds = "SMoc is not available, tests will fail"
		return ds
	
	def clean(self, ds):
		if isinstance(getattr(ds, "original", ds), basestring):
			# it's a fake thing because we didn't have SMoc available
			return

		conn = ds.tables.values()[0].connection
		conn.rollback()
		ds.dropTables(rsc.parseNonValidating)
		conn.commit()
geomTestTable = _GeometryTable()


class MatchLimitTest(testhelpers.VerboseTest):
	"""Tests for adqlglue._updateMatchLimits.

	Each sample is (TOP clause, maxrec, hard limit, expected setLimit of
	the parsed query, expected overflow limit returned by the function).
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		(topClause, maxrec, hardLimit, 
			expectedInQuery, expectedLimit) = sample
		query = "SELECT %s * FROM spatial"%topClause
		parsed = parseWithArtificialTable(query)
		# the call updates parsed.setLimit, so it must run before the checks
		returnedLimit = adqlglue._updateMatchLimits(parsed, maxrec, hardLimit)
		self.assertEqual(parsed.setLimit, expectedInQuery)
		self.assertEqual(returnedLimit, expectedLimit)
	
	samples = [
		("", 2000, 200000, 2000, 2000),
		("TOP 10", 2000, 200000, 10, 2000),
		("TOP 10000", 2000, 200000, 2001, 2001),
		("TOP 30000", None, 200000, 20001, 20001),
		("TOP 30000", None, None, 20001, 20001),
# 05 (NOTE: samples 05 and 06 are identical)
		("TOP 10000", 99999999999, None, 10000, 20000000),
		("TOP 10000", 99999999999, None, 10000, 20000000),
		("", 99999999999, None, 20000000, 20000000),
		("", 99999999999, 50000, 50000, 50000),
		("TOP 1", 1, 50000, 1, 2),
	]


class _SymbolsParseTestBase(testhelpers.VerboseTest):
	"""An abstract base for tests matching literals against individual
	symbols of the raw ADQL grammar.
	"""
	def setUp(self):
		self.symbols, _ = adql.getRawGrammar()

	def _parseSymbol(self, symbol, literal):
		"""parses literal with the grammar symbol named symbol.

		StringEnd is appended so the symbol has to consume all of literal.
		"""
		(self.symbols[symbol]+pyparsing.StringEnd()).parseString(literal)

	def _assertParses(self, symbol, literal):
		try:
			self._parseSymbol(symbol, literal)
		# ParseSyntaxException is not a ParseException subclass; catching
		# both makes a rejection fail the test rather than error it
		# (consistent with _assertDoesntParse).
		except (adql.ParseException, adql.ParseSyntaxException):
			raise AssertionError("%s doesn't parse %s but should."%(symbol,
				repr(literal)))

	def _assertDoesntParse(self, symbol, literal):
		try:
			self._parseSymbol(symbol, literal)
		except (adql.ParseException, adql.ParseSyntaxException):
			pass
		else:
			raise AssertionError("%s parses %s but shouldn't."%(symbol,
				repr(literal)))


class _GoodExamplesBase(_SymbolsParseTestBase):
	"""An abstract base for sample tests; each sample is a
	(symbol, literal) pair that the symbol must accept.
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		return self._assertParses(*sample)


class _BadExamplesBase(_SymbolsParseTestBase):
	"""An abstract base for sample tests; each sample is a
	(symbol, literal) pair that the symbol must reject.
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		return self._assertDoesntParse(*sample)


class MiscBadSymbolsTest(_BadExamplesBase):
	"""Malformed timestamps, casts, bitwise functions, joins, and
	tablesample clauses that their grammar symbols must reject.
	"""
	samples = [
		("dateValueExpression", "TIMESTAMP('1992-03-01', a)"),
		("dateValueExpression", "TIMESTAMP(b, a)"),
		("castSpecification", "cast(char)"),
		("castSpecification", "cast(x as y)"),
		("castSpecification", "cast(x as national bank)"),
# 5
		("castSpecification", "cast(x as char(*))"),
		("numericValueExpression", "BITWISE_NOT()"),
		("numericValueExpression", "BITWISE_AND(a)"),
		("numericValueExpression", "BITWISE_AND(a, b, c)"),
		("joinedTable", "foo natural join (t1, t2)"),
# 10
		("joinedTable", "(t1 join t2) using (foo)"),
		("joinedTable", "(t1 join t2) on (t1.foo=t2.bar)"),
		("possiblyAliasedTable", "t1 tablesample(a)"),
	]


class MiscGoodSymbolsTest(_GoodExamplesBase):
	"""Assorted literals (identifiers, comparisons, string concatenation,
	timestamps, casts, bitwise functions, WITH, arrays, aliases, tablesample,
	set-generating functions) that their grammar symbols must accept.
	"""
	samples = [
		("searchCondition", "5+9<'b' || 'foo'"),
		("delimitedIdentifier", '"a"'),
		("delimitedIdentifier", '"a""b"'),
		("comparisonPredicate", '"ja ja"<"Umph"'),
		("comparisonPredicate", "a<b"),
# 5
		("comparisonPredicate", "'a'<'b'"),
		("comparisonPredicate", "'a'<'b' || 'foo'"),
		("comparisonPredicate", "5+9<'b' || 'foo'"),
		("stringValueExpression", "'abc'"),
		("stringValueExpression", "'abc' || 'def'"),
# 10
		("stringValueExpression", "'abc' || 'def' || '78%%'"),
		("dateValueExpression", "TIMESTAMP('1992-03-01' || 'T12:33')"),
		("dateValueExpression", "TIMESTAMP(b)"),
		("castSpecification", "CAST(x+23.0 AS INTEGER)"),
		("castSpecification", "CAST('230' AS BIGINT)"),
# 15
		("castSpecification", "CAST(230 AS NATIONAL   CHAR ( 10 ))"),
		("castSpecification", 'CAST("My stupid col" || \'x\' AS CHAR(1230))'),
		("castSpecification", "CAST(230  AS CHAR)"),
		("castSpecification", "CAST(230+PI() AS   REAL)"),
		("castSpecification", "CAST(SQRT(honk) AS DOUBLE PRECISION)"),
# 20
		("castSpecification", "CAST('2017-02-30' AS TIMESTAMP)"),
		("castSpecification", "CAST(NULL AS TIMESTAMP)"),
		("numericValueExpression", "BitWISE_NOT(x)"),
		("numericValueExpression", "BITWISE_and(x, 2)"),
		("numericValueExpression", "BITWISE_OR(x, y+2)"),
# 25
		("numericValueExpression", "BITWISE_XOR(5, 2)"),
		("withSpecification", "WITH foobar as (select a,b,c from x),"
			" knatter as (select cos(d)+13 as foo from y)"),
		("valueExpressionPrimary", "arr[15]"),
		("valueExpressionPrimary", "arr[ROUND(x/10)+3]"),
		("derivedColumn", "98x"),
#30
		("derivedColumn", "(A+B)X"),
		("possiblyAliasedTable", '"gnott" as g tablesample (0.1)'),
		("possiblyAliasedTable", '"gnott" tablesample(1e-7)'),
		("setGeneratingFunction", "generate_series ( 3 , 4 )"),
	]


class GoodBooleanTermsTest(_GoodExamplesBase):
	"""BETWEEN, IN, LIKE, ILIKE, and IS [NOT] NULL predicates that
	searchCondition must accept.
	"""
	samples = [
		("searchCondition", "z BETWEEN 8 AND 9"),
		("searchCondition", "z BETWEEN 'a' AND 'b'"),
		("searchCondition", "z BEtWEEN x+8 AnD x*8"),
		("searchCondition", "z NOT BETWEEN x+8 AND x*8"),
		("searchCondition", "z iN (a)"),
		("searchCondition", "z NoT In (a)"),
		("searchCondition", "z NOT IN (a, 4, 'xy')"),
		("searchCondition", "z IN (select x from foo)"),
		("searchCondition", "u LIKE '%'"),
		("searchCondition", "u NoT LiKE '%'"),
		("searchCondition", "u ILIKE '%'"),
		("searchCondition", "u Not ILIKE '%'"),
		("searchCondition", "u || 'foo' NOT LIKE '%'"),
		("searchCondition", "u NOT LIKE '%' || 'xy'"),
		("searchCondition", "k IS NULL"),
		("searchCondition", "k IS NOT NULL"),
	]


class BadGeometriesTest(_BadExamplesBase):
	"""Malformed geometry constructors and geometry functions that their
	grammar symbols must reject.
	"""
	samples = [
		("point", "POINT(x,y,z)"),
		("circle", "circle('ICRS', x)"),
		("circle", "circle(5, y)"),
		("geometryExpression", "circle('ICRS', x)"),
		("polygon", "POLYGON(2, 3)"),
#5
		("polygon", "POLYGON(2, 4, 3)"),
		("polygon", "POLYGON('', 2, 4, 3)"),
		("polygon", "POLYGON(POINT(2, 4), POINT(3, 5), 3, 6)"),
		("region", "REGION(23, 'CIRCLE ICRS 2 3 4)"),
		("booleanTerm", "Point('fk5',2,3)"),
#10
		("booleanTerm", "CIRCLE('fk5', 2, 3)=x"),
		("booleanTerm", "POLYGON('fk5', 2, 3, 3, 0, 23, 0, 45)=x"),
		("booleanTerm", "CENTROID(3)=x"),
		("booleanTerm", "CENTROID(COUNT(*))=x"),
		("nonPredicateGeometryFunction", "DISTANCE()"),
#15
		("nonPredicateGeometryFunction", "DISTANCE(a, b, c)"),
		("nonPredicateGeometryFunction", "DISTANCE(POINT(a,b), c, d)"),
		("nonPredicateGeometryFunction", "DISTANCE(POINT(a,b), POINT(c, d), e)"),
		("nonPredicateGeometryFunction", "DISTANCE(a, CENTROID(CIRCLE(b, c, d)))"),
	]


class GoodGeometriesTest(_GoodExamplesBase):
	"""Geometry constructors (with and without coordinate systems) and
	geometry functions that their grammar symbols must accept.
	"""
	samples = [
		("point", "pOint('ICRS', x,y)"),
		("point", "point(NULL, x,y)"),
		("point", "POINT(x,y)"),
		("circle", "circle('ICRS', x,y, r)"),
		("circle", "CIRCLE(NULL, 1,2, 4)"),
#5
		("circle", "CIRCLE(1,2, 4)"),
		("circle", "CIRCLE('', c, r)"),
		("circle", "CIRCLE(c, 5)"),
		("polygon", "POLYGON(NULL, 1, 2, 4, 3, 4, 4)"),
		("polygon", "POLYGON('', 1, 2, 4, 3, 4, 4)"),
#10
		("polygon", "POLYGON(1, 2, 4, 3, 4, 4)"),
		("polygon", "POLYGON(a, b, c)"),
		("polygon", "POLYGON(POINT(2,3), b, c)"),
		("polygon", "POLYGON(a, b, POINT(2,3))"),
		("box", "BOX(1, 2, 0.2, 0.1)"),
#15
		("box", "BOX('GALACTIC', 1, 2, 0.2, 0.1)"),
		("region", "REGION('CIRCLE ICRS 2 3 4)')"),
		("geometryExpression", "CIRCLE('ICRS', 1,2, 4)"),
		("predicateGeometryFunction", 
			"Contains(pOint('ICRS', x,y),CIRCLE('ICRS', 1,2, 4))"),
		("booleanTerm", "Point(NULL, 2, 3)=x"),
#20
		("booleanTerm", "Point('fk5', 2, 3)=x"),
		("booleanTerm", "CIRCLE('fk5', 2, 3, 3)=x"),
		("booleanTerm", "box('fk5', 2, 3, 3, 0)=x"),
		("booleanTerm", "POLYGON('fk5', 2, 3, 3, 0, 23, 0, 45, 34)=x"),
		("booleanTerm", "REGION('mainfranken')=x"),
#25
		("booleanTerm", "CENTROID(CIRCLE('fk4', 2, 3, 3))=x"),
		("nonPredicateGeometryFunction", "DISTANCE(a, b)"),
		("nonPredicateGeometryFunction", "DISTANCE(POINT(a, b), POINT(c,d))"),
		("nonPredicateGeometryFunction", "DISTANCE(a, POINT(c, d))"),
		("nonPredicateGeometryFunction", 
			"DISTANCE(ivo_apply_pm(a, b, 0.1, -0.2, 20), c)"),
	]


class _ADQLParsesTest(testhelpers.VerboseTest):
	"""An abstract base for tests checking whether whole ADQL statements
	are accepted by the raw grammar.

	A RuntimeError during parsing (taken to mean runaway recursion) is
	reported as Error rather than as a plain test failure.
	"""
	def setUp(self):
		_, self.grammar = adql.getRawGrammar()
		testhelpers.VerboseTest.setUp(self)

	def _assertGoodADQL(self, statement):
		try:
			self.grammar.parseString(statement)
		except RuntimeError:
			raise Error("%s causes an infinite recursion"%statement)
		except (adql.ParseException, adql.ParseSyntaxException):
			raise AssertionError("%s doesn't parse but should."%statement)

	def _assertBadADQL(self, statement):
		expectedErrors = (adql.ParseException,adql.ParseSyntaxException)
		try:
			self.assertRaisesVerbose(expectedErrors, 
				self.grammar.parseString, (statement,), 
				"Parses but shouldn't: %s"%statement)
		except RuntimeError:
			raise Error("%s causes an infinite recursion"%statement)


class NakedParseTest(_ADQLParsesTest):
	"""Tests for the raw grammar accepting or rejecting complete statements
	(no tree building involved).
	"""
	def _assertParse(self, correctStatements):
		for goodStatement in correctStatements:
			self._assertGoodADQL(goodStatement)

	def _assertDontParse(self, badStatements):
		for badStatement in badStatements:
			self._assertBadADQL(badStatement)

	def testPlainSelects(self):
		"""elementary select statements are accepted.
		"""
		self._assertParse([
				"SELECT x FROM y",
				"SELECT x FROM y WHERE z=0",
				"SELECT x, v FROM y WHERE z=0 AND v>2",
				"SELECT 89 FROM X",
				"SELECT 89 FROM X AS Y",
				"SELECT 89 FROM X Y",
			])

	def testDelimited(self):
		"""delimited identifiers are accepted everywhere.
		"""
		self._assertParse([
			'SELECT "f-bar", "c""ho" FROM "nons-ak" WHERE "ja ja"<"Umph"'])

	def testSimpleSyntaxErrors(self):
		"""grossly malformed statements are rejected.
		"""
		self._assertDontParse([
				"W00T",
				"SELECT A",
				"SELECT A FROM",
				"SELECT A FROM B WHERE",
				"SELECT FROM",
				"SELECT 89! FROM z",
			])

	def testCaseInsensitivity(self):
		"""SQL keywords are matched regardless of case.
		"""
		self._assertParse([
				"select z as U From n",
				"seLect z AS U FROM n",
			])

	def testJoins(self):
		"""the various JOIN forms are accepted.
		"""
		self._assertParse([
			"select x from t1, t2",
			"select x from t1, t2, t3",
			"select x from t1, t2, t3 WHERE t1.x=t2.y",
			"select x from t1 JOIN t2",
			"select x from t1 NATURAL JOIN t2",
			"select x from t1 LEFT OUTER JOIN t2",
			"select x from t1 RIGHT OUTER JOIN t2",
			"select x from t1 FULL OUTER JOIN t2",
			"select x from t1 FULL OUTER JOIN t2 ON (x=y)",
			"select x from t1 FULL OUTER JOIN t2 USING (x,y)",
			"select x from t1 INNER JOIN (t2 JOIN t3)",
			"select x from (t1 JOIN t4) FULL OUTER JOIN (t2 JOIN t3)",
			"select x from t1 NATURAL JOIN t2, t3",
		])

	def testBadJoins(self):
		"""malformed JOINs are rejected.
		"""
		self._assertDontParse([
			"select x from t1 JOIN",
			"select x from JOIN t1",
			"select x from t1 join JOIN t1",
			"select x from t1 NATURAL JOIN t2, t3 OUTER",
			"select x from t1 NATURAL JOIN t2, t3 ON",
			"select x from t1, t2, t3 ON",
		])

	def testDetritus(self):
		"""ORDER BY, GROUP BY, and HAVING are accepted.
		"""
		self._assertParse([
			"select x from t1 order by z",
			"select x from t1 order by z desc",
			"select x from t1 order by z desc, x asc",
			"select x from t1 group by z",
			"select x from t1 group by z, s",
			"select x from t1 having x=z AND 7<u",
		])

	def testBadDetritus(self):
		"""a HAVING without a boolean condition is rejected.
		"""
		self._assertDontParse([
			"select x from t1 having y",
		])

	def testBadBooleanTerms(self):
		"""incomplete or malformed predicates are rejected.
		"""
		prefix = "select x from y where "
		self._assertDontParse([prefix+condition for condition in [
			"z BETWEEN",
			"z BETWEEN AND",
			"z BETWEEN AND 5",
			"z 7 BETWEEN 5 AND ",
			"x IN",
			"x IN 5",
			"x IN (23, 3,)",
			"x Is None",
		]])
	
	def testsBadFunctions(self):
		"""function calls with wrong arities or syntax are rejected.
		"""
		prefix = "select x from y where "
		self._assertDontParse([prefix+condition for condition in [
			"ABS()<3",
			"ABS(y,z)<3",
			"ATAN2(x)<3",
			"PI==3",
		]])
	
	def testFunkyIds(self):
		"""quoted identifiers are accepted in conditions.
		"""
		prefix = "select x from y where "
		self._assertParse([prefix+condition for condition in [
			'"some weird column">0',
			'"some even ""weirder"" column">0',
			'"SELECT">0',
		]])

	def testMiscGood(self):
		"""aliased subqueries in FROM are accepted.
		"""
		self._assertParse([
			"select a, b from (select * from x) AS q",
		])

	def testMiscBad(self):
		"""various invalid statements are rejected.
		"""
		self._assertDontParse([
			"select a, b from (select * from x) q r",
			"select a, b from (select * from x)",
			"select x.y.z.a.b from a",
			"select x from a.b.c.d",
		])

	def testStringExpressionSelect(self):
		"""string concatenation is accepted in the select list.
		"""
		self._assertParse([
			"select m || 'ab' from q",])


class FunctionsParseTest(_ADQLParsesTest):
	"""Function calls that must be accepted within a WHERE clause; each
	sample is the condition appended to a fixed select.
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		queryHead = "select x from y where "
		self._assertGoodADQL(queryHead+sample)

	samples = [
		"ABS(-3)<3",
		"ABS(-3.0)<3",
		"ABS(-3.0E4)<3",
		"ABS(-3.0e-4)<3",
		"ABS(x)<3",
		"ATAN2(-3.0e-4, 4.5)=x",
		"RAND(4)=x",
		"RAND()=x",
		"ROUND(23)=x",
		"ROUND(23,2)=x",
		"ROUND(PI(),2)=3.14",
		"POWER(x,10)=3.14",
		"POWER(10,x)=3.14",
		"1=CROSSMATCH(ra1, dec1, ra2, dec2, 0.001)",
	]


class SetExpressionsTest(_ADQLParsesTest):
	"""UNION, INTERSECT, and EXCEPT statements (also within subqueries)
	that must be accepted.
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		statement = sample
		self._assertGoodADQL(statement)
	
	samples = [
		"select x from t1 union select x from t2",
		"select x from t1 intersect select x from t2",
		"select x from t1 except select x from t2",
		"select x from t1 where x>2 union select x from t2",
		"select * from t1 union select x from t2 intersect select x from t3"
			" except select x from t4",
# 5
		"select * from (select * from t1 except select * from t2) as q union"
			" select * from  t3",
		"select * from t1 union select foo from (select * from t2 except select * from t1) as q",
	]


class AsTreeTest(testhelpers.VerboseTest):
	"""Tests for the indexable tree structure returned by asTree().
	"""
	def testSimple(self):
		query = ("SELECT * FROM t WHERE 1=CONTAINS("
			"CIRCLE('ICRS', 4, 4, 2), POINT('', ra, dec))")
		selectNode = adql.parseToTree(query).asTree()[1][1]
		self.assertEqual(selectNode[1][1][0], 'possiblyAliasedTable')
		whereNode = selectNode[3]
		self.assertEqual(whereNode[0], 'whereClause')
		self.assertEqual(whereNode[1][2][1][0], 'circle')


class _TreeParseTestBase(testhelpers.VerboseTest):
	"""An abstract base for tests inspecting trees built by the full
	(tree-building) ADQL grammar.
	"""
	def setUp(self):
		self.grammar = adql.getGrammar()

	def assertAttrs(self, c, **assertions):
		"""asserts that for each keyword, c has an attribute of that name
		with the given value.
		"""
		for attName, expected in assertions.items():
			failMessage = "%s, %s!=%s"%(
				attName, repr(getattr(c, attName)), repr(expected))
			self.assertEqual(getattr(c, attName), expected, failMessage)


class TreeParseTest(_TreeParseTestBase):
	"""Tests for properties of trees built by the full ADQL grammar (select
	fields, source tables, taintedness, set limits, literals).
	"""
	def testSelectList(self):
		for q, e in [
			("select a from z", ["a"]),
			("select x.a from z", ["a"]),
			("select x.a, b from z", ["a", "b"]),
			('select "one weird name", b from z', 
				[utils.QuotedName('one weird name'), "b"]),
		]:
			# don't call this "tree"; that would shadow the imported module
			parsed = self.grammar.parseString(q)[0]
			res = [c.name for c in parsed.getSelectFields()]
			self.assertEqual(res, e, 
				"Select list from %s: expected %s, got %s"%(q, e, res))

	def testSourceTables(self):
		for q, e in [
			("select * from z", ["z"]),
			("select * from z.x", ["z.x"]),
			("select * from z.x.y", ["z.x.y"]),
			("select * from z.x.y, a", ["z.x.y", "a"]),
			("select * from (select * from z) as q, a", ["q", "a"]),
		]:
			res = list(self.grammar.parseString(q)[0].getAllNames())
			self.assertEqual(res, e, 
				"Source tables from %s: expected %s, got %s"%(q, e, res))

	def testSourceTablesJoin(self):
		for q, e in [
			("select * from z join x", ["z", "x"]),
			("select * from (select * from a,b, (select * from c,d) as q) as r join"
				"(select * from x,y) as p", ["r", "p"]),
		]:
			res = list(self.grammar.parseString(q)[0].getAllNames())
			self.assertEqual(res, e, 
				"Source tables from %s: expected %s, got %s"%(q, e, res))

	def testContributingTables(self):
		q = ("select * from (select * from urks.a,b,"
			" (select * from c,monk.d) as q) as r join"
			" (select * from x,y) as p")
		# parse once; the previous version parsed and discarded a result first
		parsed = self.grammar.parseString(q)[0]
		self.assertEqual(parsed.getContributingNames(),
			set(['c', 'b', 'urks.a', 'q', 'p', 'r', 'y', 'x', 'monk.d']))

	def testAliasedColumn(self):
		q = "select foo+2 as fp2 from x"
		res = self.grammar.parseString(q)[0]
		field = list(res.getSelectFields())[0]
		self.assertEqual(field.name, "fp2")
	
	def testTainting(self):
		for q, (exName, exTaint) in [
			("select x from z", ("x", False)),
			("select x as u from z", ("u", False)),
			("select x+2 from z", (None, True)),
			('select x+2 as "99 Monkeys" from z', (utils.QuotedName("99 Monkeys"), 
				True)),
			('select x+2 as " ""cute"" Monkeys" from z', 
				(utils.QuotedName(' "cute" Monkeys'), True)),
		]:
			res = list(self.grammar.parseString(q)[0].getSelectFields())[0]
			self.assertEqual(res.tainted, exTaint, "Field taintedness wrong in %s"%
				q)
			if exName:
				self.assertEqual(res.name, exName)

	def testValueExpressionColl(self):
		t = adql.parseToTree("select x from z where 5+9>'gaga'||'bla'")
		compPred = t.whereClause.children[1]
		self.assertEqual(compPred.op1.type, "numericValueExpression")
		self.assertEqual(compPred.opr, ">")
		self.assertEqual(compPred.op2.type, "stringValueExpression")

	def testQualifiedStar(self):
		t = adql.parseToTree("select t1.*, s1.t2.* from t1, s1.t2, s2.t3")
		self.assertEqual(t.selectList.selectFields[0].type, "qualifiedStar")
		self.assertEqual(t.selectList.selectFields[0].sourceTable.qName,
			"t1")
		self.assertEqual(t.selectList.selectFields[1].sourceTable.qName,
			"s1.t2")

	def testBadSystem(self):
		self.assertRaises(adql.ParseSyntaxException, 
			self.grammar.parseString, "select point('QUARK', 1, 2) from spatial")

	def testQuotedTableName(self):
		t = adql.parseToTree('select "abc-g".* from "abc-g" JOIN "select"')
		self.assertEqual(t.selectList.selectFields[0].sourceTable.name, "abc-g")
		self.assertEqual(t.selectList.selectFields[0].sourceTable.qName, '"abc-g"')

	def testQuotedSchemaName(self):
		t = adql.parseToTree('select * from "Murks Schema"."Murks Tabelle"')
		table = t.fromClause.tableReference
		self.assertEqual(table.tableName.name,
			utils.QuotedName("Murks Tabelle"))
		self.assertEqual(table.tableName.schema,
			utils.QuotedName("Murks Schema"))
	
	def testSetLimitInherited(self):
		t = adql.parseToTree('select top 3 * from t1 union'
			' select top 4 * from t2 except select * from t3')
		self.assertEqual(t.setLimit, 4)
	
	def testSetLimitDeep(self):
		t = adql.parseToTree(
			'select top 7 * from t1 union'
			' (select top 4 * from t2 except select * from t3)'
			' except (select top 30 x from t4 except select top 3 y from t5)')
		self.assertEqual(t.setLimit, 30)

	def testHexadecimal(self):
		t = adql.parseToTree(
			"select x-0xaf, -0x1fffffff from t1")
		sels = list(t.getSelectClauses())[0].selectList.selectFields
		self.assertEqual(sels[0].flatten(), 'x - 175')
		self.assertEqual(sels[1].flatten(), '- 536870911')


class CircleTreeParseTest(_TreeParseTestBase):
	"""Tests for how CIRCLE arguments end up in the circle node (center,
	split x/y, radius, coordinate system).
	"""
	def _getParsed(self, circleLiteral):
		"""returns the expression node of the only select item of a query
		selecting circleLiteral.
		"""
		query = 'select %s from t1'%circleLiteral
		firstSelect = adql.parseToTree(query).children[0]
		return firstSelect.selectList.selectFields[0].expr

	def testColrefNosys(self):
		circle = self._getParsed("CIRCLE(x, r)")
		self.assertAttrs(circle, cooSys="", x=None, y=None)
		self.assertAttrs(circle.radius, type="columnReference", name="r")
		self.assertAttrs(circle.center, type="columnReference", name="x")

	def testColrefWithsys(self):
		circle = self._getParsed("CIRCLE('ICRS', \"center\", \"radius\")")
		self.assertAttrs(circle, cooSys="ICRS", x=None, y=None)
		self.assertAttrs(circle.radius, type="columnReference", name="radius")
		self.assertAttrs(circle.center, type="columnReference", name="center")

	def testSplitNosys(self):
		circle = self._getParsed("CIRCLE(a, d, r)")
		self.assertAttrs(circle, cooSys="", center=None)
		self.assertAttrs(circle.radius, type="columnReference", name="r")
		self.assertAttrs(circle.x, type="columnReference", name="a")
		self.assertAttrs(circle.y, type="columnReference", name="d")

	def testLiteralPoint(self):
		circle = self._getParsed("CIRCLE(NULL, POINT(a, d) ,r)")
		self.assertAttrs(circle, cooSys="UNKNOWN", center=None)
		self.assertAttrs(circle.radius, type="columnReference", name="r")
		self.assertAttrs(circle.x, type="columnReference", name="a")
		self.assertAttrs(circle.y, type="columnReference", name="d")

	def testLiteralPointNosys(self):
		circle = self._getParsed("CIRCLE(POINT('ICRS', a, d) ,r)")
		self.assertAttrs(circle, cooSys="", center=None)
		self.assertAttrs(circle.radius, type="columnReference", name="r")
		self.assertAttrs(circle.x, type="columnReference", name="a")
		self.assertAttrs(circle.y, type="columnReference", name="d")

	def testSplitWithExpression(self):
		circle = self._getParsed("CIRCLE(ra+2, dec ,r-1)")
		self.assertAttrs(circle, cooSys="", center=None)
		self.assertAttrs(circle.radius, type="valueExpression")
		self.assertAttrs(circle.x, type="valueExpression")
		self.assertAttrs(circle.y, type="columnReference", name="dec")
		self.assertEqual(circle.radius.children[1], '-')

	def testGeoUDF(self):
		circle = self._getParsed(
			"CIRCLE(ivo_apply_pm(12, 13, 1e-7, -1e-7, 19) ,r)")
		self.assertEqual(circle.x.children[1], '+')
		self.assertEqual(circle.x.children[2].children[0].children[0], '19')
		self.assertEqual(circle.radius.name, "r")
		self.assertEqual(circle.flatten(),
			"CIRCLE(POINT(12 + 19 * 1e-7 / COS(RADIANS(13)), 13 + - 1e-7 * 19),r)")


class PolygonTreeParseTest(_TreeParseTestBase):
	"""Tests for how POLYGON arguments end up in the polygon node (points
	versus split coordinate pairs, coordinate system).
	"""
	def _getParsed(self, polygonLiteral):
		"""returns the expression node of the only select item of a query
		selecting polygonLiteral.
		"""
		query = 'select %s from t1'%polygonLiteral
		firstSelect = adql.parseToTree(query).children[0]
		return firstSelect.selectList.selectFields[0].expr

	def testColrefNosys(self):
		poly = self._getParsed("POLYGON(a, b, c, d)")
		self.assertAttrs(poly, type="polygon", cooSys="", coos=None)
		self.assertEqual(len(poly.points), 4)
		self.assertAttrs(poly.points[0], type="columnReference", name="a")
		self.assertAttrs(poly.points[-1], type="columnReference", name="d")

	def testColrefWithsys(self):
		poly = self._getParsed("polygon('ICRS', \"p 1\", \"p 2\", p3)")
		self.assertAttrs(poly, cooSys="ICRS", coos=None)
		self.assertEqual(len(poly.points), 3)
		for point, expectedName in zip(poly.points, ["p 1", "p 2", "p3"]):
			self.assertAttrs(point, type="columnReference", name=expectedName)

	def testSplitNosys(self):
		poly = self._getParsed("polygon(x1, y1, x2, y2, x3, y3)")
		self.assertAttrs(poly, cooSys="", points=None)
		self.assertEqual(len(poly.coos), 3)
		self.assertAttrs(poly.coos[0][0], type="columnReference", name="x1")
		self.assertAttrs(poly.coos[-1][1], type="columnReference", name="y3")

	def testLiteralPoints(self):
		poly = self._getParsed("POLYGON(NULL, POINT(x1, y1),"
			" POINT('ICRS', x2, y2), POINT(x3,y3))")
		self.assertAttrs(poly, cooSys="UNKNOWN", points=None)
		self.assertEqual(len(poly.coos), 3)
		self.assertAttrs(poly.coos[0][0], type="columnReference", name="x1")
		self.assertAttrs(poly.coos[-1][1], type="columnReference", name="y3")

	def testLiteralPointNosys(self):
		poly = self._getParsed("POLYGON(NULL, POINT(x1, y1),"
			" p, POINT(x3,y3))")
		self.assertAttrs(poly, cooSys="UNKNOWN", coos=None)
		self.assertEqual(len(poly.points), 3)
		self.assertAttrs(poly.points[0].x, type="columnReference", name="x1")
		self.assertAttrs(poly.points[-1].y, type="columnReference", name="y3")

	def testSplitWithExpression(self):
		poly = self._getParsed(
			"polygon(ra+1, dec+1 ,ra+1, dec-1, ra-1, dec+1, ra-1, dec-1)")
		self.assertAttrs(poly, cooSys="", points=None)
		self.assertAttrs(poly.coos[0][0], type="valueExpression")
		self.assertEqual(poly.coos[-1][1].children[1], "-")

	def testGeoUDF(self):
		poly = self._getParsed("polygon(ivo_apply_pm(12, 13, 1e-7, -1e-7, 19),"
			"ivo_apply_pm(12, 13, -1e-7, 1e-7, -19),"
			"ivo_apply_pm(12, 13, 1e-7, -1e-7, -19),"
			"ivo_apply_pm(12, 13, 1e-7, 1e-7, 19))")
		self.assertEqual(len(poly.coos), 4)
		self.assertEqual(poly.flatten(),
			"POLYGON(12 + 19 * 1e-7 / COS(RADIANS(13)), 13 + - 1e-7 * 19, 12 + - 19 * - 1e-7 / COS(RADIANS(13)), 13 + 1e-7 * - 19, 12 + - 19 * 1e-7 / COS(RADIANS(13)), 13 + - 1e-7 * - 19, 12 + 19 * 1e-7 / COS(RADIANS(13)), 13 + 1e-7 * 19)")

class DistanceParseTest(_TreeParseTestBase):
	"""Tests for how DISTANCE arguments are stored (four scalars versus two
	point arguments).
	"""
	def _getParsed(self, distanceLiteral):
		"""returns the expression node of the only select item of a query
		selecting distanceLiteral.
		"""
		query = 'select %s from t1'%distanceLiteral
		firstSelect = adql.parseToTree(query).children[0]
		return firstSelect.selectList.selectFields[0].expr

	def testSplitArgs(self):
		distFunc = self._getParsed("distance(ra1, dec1, ra2, dec2)")
		self.assertAttrs(distFunc, funName="DISTANCE", pointArguments=False)
		self.assertEqual(len(distFunc.args), 4)
		self.assertAttrs(distFunc.args[0], type="columnReference", name="ra1")
		self.assertAttrs(distFunc.args[-1], type="columnReference", name="dec2")

	def testPointArgs(self):
		distFunc = self._getParsed("distance(p1, p2)")
		self.assertAttrs(distFunc, funName="DISTANCE", pointArguments=True)
		self.assertEqual(len(distFunc.args), 2)
		self.assertAttrs(distFunc.args[0], type="columnReference", name="p1")
		self.assertAttrs(distFunc.args[-1], type="columnReference", name="p2")

	def testPointArgsLiterals(self):
		# two literal points are split into four scalar arguments
		distFunc = self._getParsed("distance(point(ra1, dec1), point(ra2, dec2))")
		self.assertAttrs(distFunc, funName="DISTANCE", pointArguments=False)
		self.assertEqual(len(distFunc.args), 4)
		self.assertAttrs(distFunc.args[0], type="columnReference", name="ra1")
		self.assertAttrs(distFunc.args[-1], type="columnReference", name="dec2")

	def testPointArgsOneLiteral(self):
		distFunc = self._getParsed("distance(point(ra1, dec1), p2)")
		self.assertAttrs(distFunc, funName="DISTANCE", pointArguments=True)
		self.assertEqual(len(distFunc.args), 2)
		firstArg, secondArg = distFunc.args[0], distFunc.args[1]
		self.assertAttrs(firstArg, type="point")
		self.assertAttrs(firstArg.x, type="columnReference", name="ra1")
		self.assertAttrs(secondArg, type="columnReference", name="p2")


class ParseErrorTest(testhelpers.VerboseTest):
	"""tests for sensible error messages.

	Each sample is (query, fragment); parsing the query with parseAll
	must fail, and the message of the resulting exception must contain
	the fragment.
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		query, msgFragment = sample
		try:
			adql.getGrammar().parseString(query, parseAll=True)
		except (adql.ParseException, adql.ParseSyntaxException) as ex:
			msg = unicode(ex)
			# assertTrue rather than the deprecated failUnless alias
			self.assertTrue(msgFragment in msg,
				"'%s' does not contain '%s'"%(msg, msgFragment))
		else:
			self.fail("'%s' parses but should not"%query)

	samples = [
		("", 'Expected "SELECT" (at char 0)'),
		("select mag from %s", 'Expected table reference (at char 16)'),
		("SELECT TOP foo FROM x", 'Expected unsigned integer (at char 11)'),
		("SELECT FROM x", 'Expected select list (at char 7)'),
		("SELECT x, FROM y", 'Expected select list item (at char 10)'),
#5
		("SELECT * FROM distinct", 'Expected table reference (at char 14)'),
		("SELECT DISTINCT FROM y", 'Expected select list (at char 16)'),
		("SELECT *", 'Expected "FROM" (at char 8)'),
		("SELECT * FROM y WHERE", 'Expected boolean expression (at char 21)'),
		("SELECT * FROM y WHERE y u 2", 
			'Expected boolean expression (at char 24)'),
# 10
		("SELECT * FROM y WHERE y < 2 AND", 
			'Expected boolean expression (at char 31)'),
		("SELECT * FROM y WHERE y < 2 OR", 
			'Expected boolean expression (at char 30)'),
		("SELECT * FROM y WHERE y IS 3", 'Expected "NULL" (at char 27)'),
		("SELECT * FROM y WHERE CONTAINS(a,b)", 
			'Expected boolean expression (at char 35)'),
		("SELECT * FROM y WHERE 1=CONTAINS(POINT('ICRS',x,'sy')"
			" ,CIRCLE('ICRS',x,y,z))", 
			'Expected numeric expression (at char 48)'),
# 15
		("SELECT * FROM (SELECT * FROM x)", 
			'Expected table reference (at char 31)'),
		("SELECT * FROM x WHERE EXISTS z", 'Expected subquery (at char 29)'),
		("SELECT POINT('junk', 3,4) FROM z",
			"xpected numeric expression (at char 13)"),
		("SELECT * from a join b on foo",
			"Expected boolean expression (at char 29"),
		("SELECT * from a OFFSET 20 join b on foo",
			"Expected end of text (at char 26)"),
# 20
		("SELECT * from a natural join b OFFSET banana",
			"Expected unsigned integer (at char 38)"),
	]


class JoinTypeTest(testhelpers.VerboseTest):
	"""Tests for the join types reported by joinedTable nodes.

	Each sample is (join expression, join types collected left to right,
	descending into nested joins).
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest
	sym = adql.getSymbols()["joinedTable"]

	def _collectJoinTypes(self, joinedNode):
		"""returns the join types of joinedNode and any joins nested in
		its operands, in in-order sequence.
		"""
		joinTypes = []
		leftSide = joinedNode.leftOperand
		# an operand having a leftOperand is itself a join
		if hasattr(leftSide, "leftOperand"):
			joinTypes.extend(self._collectJoinTypes(leftSide))
		joinTypes.append(joinedNode.getJoinType())
		rightSide = joinedNode.rightOperand
		if hasattr(rightSide, "leftOperand"):
			joinTypes.extend(self._collectJoinTypes(rightSide))
		return joinTypes

	def _runTest(self, sample):
		query, joinType = sample
		parsedJoin = self.sym.parseString(query)[0]
		self.assertEqual(self._collectJoinTypes(parsedJoin), joinType)
	
	samples = [
		("a CROSS JOIN b", ["CROSS"]),
		("a join b", ["NATURAL"]),
		("a join b using (x)", ["USING"]),
		("a CROSS JOIN b CROSS JOIN c", ["CROSS", "CROSS"]),
		("a CROSS JOIN b join c", ["CROSS", "NATURAL"]),
#5
		("a join b cross join c", ["NATURAL", "CROSS"]),
		("a join b on (x=y) cross join c", ["CROSS", "CROSS"]),
		("a join b using (x,y) join c", ["USING", "NATURAL"]),
		("a join b using (x,y) join c using (z,v)", ["USING", "USING"]),
		("(a join b using (x,y)) join c using (z,v)", ["USING", "USING"]),
# 10
		("(a join b) cross join (c join d)", ["NATURAL", "CROSS", "NATURAL"]),
	]


# Column definitions for the artificial tables that _MTH resolves for
# parseWithArtificialTable ("spatial", "spatial2", "misc", "quoted",
# "crazy", "geo").  _addSpatialSTC attaches STC annotations to some of
# them after definition.

# "spatial": distances, sizes, and RAs; gibtnet is a hidden column.
spatialFields = [
	MS(rscdef.Column, name="dist", ucd="phys.distance", unit="m"),
	MS(rscdef.Column, name="width", ucd="phys.dim", unit="m"),
	MS(rscdef.Column, name="height", ucd="phys.dim", unit="km"),
	MS(rscdef.Column, name="ra1", ucd="pos.eq.ra", unit="deg"),
	MS(rscdef.Column, name="ra2", ucd="pos.eq.ra", unit="rad"),
	MS(rscdef.Column, name="gibtnet", ucd="invalid", unit="junk", hidden=True),]
# "spatial2": main positions plus a distance and a time column.
spatial2Fields = [
	MS(rscdef.Column, name="ra1", ucd="pos.eq.ra;meta.main", unit="deg"),
	MS(rscdef.Column, name="dec", ucd="pos.eq.dec;meta.main", unit="deg"),
	MS(rscdef.Column, name="dist", ucd="phys.distance", unit="m"),
	MS(rscdef.Column, name="t", ucd="time.epoch", unit="h")]
# "misc": physical quantities without STC.
miscFields = [
	MS(rscdef.Column, name="mass", ucd="phys.mass", unit="kg"),
	MS(rscdef.Column, name="mag", ucd="phot.mag", unit="mag"),
	MS(rscdef.Column, name="speed", ucd="phys.veloc", unit="km/s")]
# "quoted": columns whose names are delimited identifiers (QuotedName).
quotedFields = [
	MS(rscdef.Column, name=utils.QuotedName("left-right"), ucd="mess", 
		unit="bg"),
	MS(rscdef.Column, name=utils.QuotedName('inch"ing'), ucd="imperial.mess",
		unit="fin"),
	MS(rscdef.Column, name=utils.QuotedName('plAin'), ucd="boring.stuff",
		unit="pc"),
	MS(rscdef.Column, name=utils.QuotedName('alllower'), ucd="simple.case",
		unit="km"),]
# "crazy": assorted column types, some with explicit null literals, and
# a real array column.
crazyFields = [
	MS(rscdef.Column, name="ct", type="integer"),
	MS(rscdef.Column, name="wot", type="bigint", 
		values=MS(rscdef.Values, nullLiteral="-1")),
	MS(rscdef.Column, name="wotb", type="bytea", 
		values=MS(rscdef.Values, nullLiteral="254")),
	MS(rscdef.Column, name="mass", ucd="event;using.incense"),
	MS(rscdef.Column, name="name", type="unicode"),
	MS(rscdef.Column, name="version", type="text"),
	MS(rscdef.Column, name="flag", type="char"),
	MS(rscdef.Column, name="vals", type="real[]", ucd="some.value", unit="yr")]
# "geo": a spoint and a timestamp column.
geoFields = [
	MS(rscdef.Column, name="pt", type="spoint"),
	MS(rscdef.Column , name="dt", type="timestamp", ucd="time;obs"),
]

def _addSpatialSTC(sf, sf2, geo):
	"""Attaches STC ASTs to the artificial column lists in place.

	sf, sf2 and geo are the column lists spatialFields, spatial2Fields and
	geoFields.  For every spatial column touched, stcUtype is reset to None.
	The geo columns only get their stc attribute set.
	"""
	ast1 = stc.parseQSTCS('Position ICRS "ra1" "dec" Size "width" "height"')
	ast2 = stc.parseQSTCS('Position FK4 SPHER3 "ra2" "dec" "dist"')
	# XXX TODO: get utypes from ASTs
	sf[0].stc, sf[0].stcUtype = ast2, None
	sf[1].stc, sf[1].stcUtype = ast1, None
	sf[2].stc, sf[2].stcUtype = ast1, None
	sf[3].stc, sf[3].stcUtype = ast1, None
	sf[4].stc, sf[4].stcUtype = ast2, None
	sf2[0].stc, sf2[0].stcUtype = ast1, None
	# Each column's stcUtype must be set on that same column.  These lines
	# previously wrote sf2[0].stcUtype again by copy-paste mistake.
	sf2[1].stc, sf2[1].stcUtype = ast1, None
	sf2[2].stc, sf2[2].stcUtype = ast2, None
	ast3 = stc.parseQSTCS('Time TT BARYCENTER "dt" Position GALACTIC [pt]')
	geo[0].stc = ast3
	geo[1].stc = ast3
_addSpatialSTC(spatialFields, spatial2Fields, geoFields)


class _MTH(object):
	"""A stand-in table-definition source for the tests in this module.

	It maps the artificial table names to the column lists defined above.
	"""
	@classmethod
	def getTableDefForTable(cls, tableName):
		# Unknown table names raise a KeyError from the lookup below.
		tablesByName = dict(
			spatial=spatialFields,
			spatial2=spatial2Fields,
			misc=miscFields,
			quoted=quotedFields,
			crazy=crazyFields,
			geo=geoFields)
		return tablesByName[tableName]


class _SampleFieldInfoGetter(adqlglue.DaCHSFieldInfoGetter):
	"""A DaCHSFieldInfoGetter whose mth attribute is _MTH.

	Presumably the base class consults self.mth for table definitions, so
	names resolve against this module's artificial tables (confirm in
	adqlglue).
	"""
	def __init__(self, *_ignored):
		# Any positional arguments are accepted and discarded.
		adqlglue.DaCHSFieldInfoGetter.__init__(self)
		self.mth = _MTH

_sampleFieldInfoGetter = _SampleFieldInfoGetter()


def parseWithArtificialTable(query):
	"""Returns the parse tree for the ADQL query, annotated against the
	artificial tables of _sampleFieldInfoGetter.
	"""
	parsedTree = adql.getGrammar().parseString(query)[0]
	# annotate attaches its results to parsedTree itself (cf. the fieldInfos
	# attribute used by ColumnTest); the context it returns is not needed.
	adql.annotate(parsedTree, _sampleFieldInfoGetter)
	return parsedTree


class TypecalcTest(testhelpers.VerboseTest):
	"""Tests for adql.getSubsumingType.

	Each sample is a pair of (list of SQL type names, expected subsuming
	type).
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		typeNames, expectedType = sample
		computedType = adql.getSubsumingType(typeNames)
		self.assertEqual(computedType, expectedType)
	
	samples = [
		(["double precision", "integer", "bigint"], 'double precision'),
		(["date", "timestamp", "timestamp"], 'timestamp'),
		(["date", "boolean", "smallint"], 'text'),
		(["box", "raw"], 'raw'),
		(["date", "time"], 'timestamp'),
# 5
		(["char(3)", "integer"], "text"),
		(["double precision", "char(3)"], "text"),
		(["integer[3]", "bigint"], "bigint[]"),
		(["integer", "smallint", "double precision[]"], "double precision[]"),
		(["integer[][]", "smallint", "double precision[]"], "double precision[]"),
# 10
		# Admittedly, the next expectation is plain wrong, but I'm relying on
		# postgres to reject such nonsense in the first place.
		(["double precision[340]", "char(3)"], "text"),
		(["boolean", "boolean"], "boolean"),
		(["boolean", "smallint"], "smallint"),
		(["sbox", "spoint"], "text"),
		(["sbox", "spoly"], "spoly"),
		(["sbox", "whacko"], "raw"),
	]


class ColumnTest(testhelpers.VerboseTest):
	"""Base class for tests inspecting the result columns inferred for
	ADQL queries.

	By default, queries are annotated with _sampleFieldInfoGetter.
	Subclasses may replace self.fieldInfoGetter (e.g., in setUp).
	"""
	def setUp(self):
		testhelpers.VerboseTest.setUp(self)
		self.fieldInfoGetter = _sampleFieldInfoGetter
		self.grammar = adql.getGrammar()

	def _getColSeqAndCtx(self, query):
		"""Parses and annotates query.

		Returns the result column sequence (a sequence of (name, field info)
		pairs) and the annotation context.
		"""
		parsed = self.grammar.parseString(query)[0]
		ctx = adql.annotate(parsed, self.fieldInfoGetter)
		return parsed.fieldInfos.seq, ctx

	def _getColSeq(self, query):
		"""Returns only the result column sequence for query."""
		return self._getColSeqAndCtx(query)[0]

	def _assertColumns(self, resultColumns, assertProperties):
		"""Asserts that resultColumns match assertProperties.

		resultColumns is a sequence of (name, field info) pairs as returned
		by _getColSeq.  assertProperties is a sequence of equal length of
		(type, unit, ucd, tainted) tuples.  A None in any tuple position skips
		that particular check.
		"""
		self.assertEqual(len(resultColumns), len(assertProperties))
		# colType rather than type, so the builtin is not shadowed; the column
		# name is not checked here.
		for index, ((_, col), (colType, unit, ucd, taint)) in enumerate(zip(
				resultColumns, assertProperties)):
			if colType is not None:
				self.assertEqual(col.type, colType, "Type %d: %r != %r"%
					(index, col.type, colType))
			if unit is not None:
				self.assertEqual(col.unit, unit, "Unit %d: %r != %r"%
					(index, col.unit, unit))
			if ucd is not None:
				self.assertEqual(col.ucd, ucd, "UCD %d: %r != %r"%
					(index, col.ucd, ucd))
			if taint is not None:
				self.assertEqual(col.tainted, taint, "Taint %d: should be %s"%
					(index, taint))


class SelectClauseTest(ColumnTest):
	"""Tests for column inference from select lists: constants, qualified
	stars, and simple expressions.
	"""
	def testConstantSelect(self):
		query = "select 1, 'const' from spatial"
		expected = [
			("smallint", "", "", False),
			("text", "", "", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testConstantExprSelect(self):
		query = "select 1+0.1, 'const'||'ab' from spatial"
		expected = [
			("double precision", "", "", True),
			("text", "", "", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testConstantSelectWithAs(self):
		query = "select 1+0.1 as x from spatial"
		expected = [("double precision", "", "", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testSimpleColumn(self):
		query = "select mass from misc"
		expected = [("real", "kg", "phys.mass", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testBadRefRaises(self):
		query = "select x, foo.* from spatial, misc"
		self.assertRaises(adql.ColumnNotFound, self._getColSeq, query)

	def testQualifiedStarSingle(self):
		query = "select misc.* from misc"
		expected = [
			("real", "kg", "phys.mass", False),
			("real", "mag", "phot.mag", False),
			("real", "km/s", "phys.veloc", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testQualifiedStar(self):
		query = "select misc.* from spatial, misc"
		expected = [
			("real", "kg", "phys.mass", False),
			("real", "mag", "phot.mag", False),
			("real", "km/s", "phys.veloc", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testMixedQualifiedStar(self):
		query = ("select misc.*, dist, round(mass/10)"
			" from spatial, misc")
		expected = [
			("real", "kg", "phys.mass", False),
			("real", "mag", "phot.mag", False),
			("real", "km/s", "phys.veloc", False),
			("real", "m", "phys.distance", False),
			("double precision", "kg", "phys.mass", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testAliasedStar(self):
		query = ("select misc.* from spatial join misc as foo"
			" on (spatial.dist=foo.mass)")
		columns = self._getColSeq(query)
		self.assertEqual(len(columns), 3)

	def testFancyRounding(self):
		query = "select round(dist, 2) from spatial"
		expected = [("double precision", "m", "phys.distance", True)]
		self._assertColumns(self._getColSeq(query), expected)


class ColResTest(ColumnTest):
	"""Tests for the resolution of output columns from various expressions.
	"""
	def testSimpleSelect(self):
		colSeq = self._getColSeq("select width, height from spatial")
		self.assertEqual(colSeq[0][0], 'width')
		self.assertEqual(colSeq[1][0], 'height')
		widthInfo = colSeq[0][1]
		self.assertEqual(widthInfo.unit, "m")
		self.assertEqual(widthInfo.ucd, "phys.dim")
		self.assertTrue(widthInfo.userData[0] is spatialFields[1])

	def testIgnoreCase(self):
		query = "select Width, hEiGHT from spatial"
		expected = [
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testStarSelect(self):
		query = "select * from spatial"
		expected = [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testStarSelectJoined(self):
		query = "select * from spatial, misc"
		expected = [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False),
			("real", "kg", "phys.mass", False),
			("real", "mag", "phot.mag", False),
			("real", "km/s", "phys.veloc", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testDimlessSelect(self):
		info = self._getColSeq("select 3+4 from spatial")[0][1]
		self.assertEqual(info.type, "smallint")
		self.assertEqual(info.unit, "")
		self.assertEqual(info.ucd, "")

	def testSimpleScalarExpression(self):
		query = ("select 2+width, 2*height, height*2"
			" from spatial")
		expected = [
			("real", "m", "", True),
			("real", "km", "phys.dim", True),
			("real", "km", "phys.dim", True)]
		colSeq = self._getColSeq(query)
		self._assertColumns(colSeq, expected)
		self.assertTrue(colSeq[1][1].userData[0] is spatialFields[2])

	def testFieldOperandExpression(self):
		query = ("select width*height, width/speed, "
			"3*mag*height, mag+height, height+height from spatial, misc")
		expected = [
			("real", "m*km", "", True),
			("real", "m/(km/s)", "", True),
			("real", "mag*km", "", True),
			("real", "", "", True),
			("real", "km", "phys.dim", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testMiscOperands(self):
		query = "select -3*mag from misc"
		expected = [("real", "mag", "phot.mag", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testSetFunctions(self):
		query = ("select AVG(mag), mAx(mag), max(2*mag),"
			" Min(Mag), sum(mag), count(mag), avg(3), count(*) from misc")
		expected = [
			("double precision", "mag", "phot.mag;stat.mean", False),
			("real", "mag", "phot.mag;stat.max", False),
			("real", "mag", "phot.mag;stat.max", True),
			("real", "mag", "phot.mag;stat.min", False),
			("real", "mag", "phot.mag", False),
			("integer", "", "meta.number;phot.mag", False),
			("double precision", "", None, False),
			("integer", "", "meta.number", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testNumericFunctions(self):
		query = ("select acos(ra2), degrees(ra2), RadianS(ra1),"
			" PI(), ABS(width), Ceiling(Width), Truncate(height*2)"
			" from spatial")
		expected = [
			("double precision", "rad", "", True),
			("double precision", "deg", "pos.eq.ra", True),
			("double precision", "rad", "pos.eq.ra", True),
			("double precision", "", "", True),
			("double precision", "m", "phys.dim", True),
			("double precision", "m", "phys.dim", True),
			("double precision", "km", "phys.dim", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testAggFunctions(self):
		query = "select max(ra1), min(ra1) from spatial"
		expected = [
			("real", "deg", "pos.eq.ra;stat.max", False),
			("real", "deg", "pos.eq.ra;stat.min", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testPoint(self):
		query = "select point('ICRS', ra1, ra2) from spatial"
		expected = [("spoint", 'deg,rad', '', False)]
		colSeq = self._getColSeq(query)
		self._assertColumns(colSeq, expected)
		self.assertTrue(colSeq[0][1].userData[0] is spatialFields[3])

	def testDistance(self):
		query = ("select distance(point('galactic', 2, 3),"
			" point('ICRS', ra1, ra2)) from spatial")
		expected = [("double precision", 'deg', 'pos.angDistance', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testCentroid(self):
		query = ("select centroid(circle('galactic', ra1, ra2, 0.5))"
			" from spatial")
		expected = [("spoint", '', '', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testParenExprs(self):
		query = "select (width+width)*height from spatial"
		expected = [("real", "m*km", "", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testSubquery(self):
		query = ("select q.p from (select ra2 as p from"
			" spatial) as q")
		expected = [("real", 'rad', 'pos.eq.ra', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testSubqueryStar(self):
		query = ("select p, speed, q.*"
			" from (select speed, mag as p from misc) as q")
		expected = [
			("real", "mag", "phot.mag", False),
			("real", "km/s", "phys.veloc", False),
			("real", "km/s", "phys.veloc", False),
			("real", "mag", "phot.mag", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testJoin(self):
		query = ("select dist, speed, 2*mass*height"
			" from spatial join misc on (mass>height)")
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'km/s', 'phys.veloc', False),
			("real", 'kg*km', '', True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testWhereResolutionPlain(self):
		query = ("select dist from spatial where exists"
			" (select * from misc where dist=misc.mass)")
		expected = [("real", 'm', 'phys.distance', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testWhereResolutionWithAlias(self):
		query = ("select dist from spatial as q where exists"
			" (select * from misc where q.dist=misc.mass)")
		expected = [("real", 'm', 'phys.distance', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testErrorReporting(self):
		query = "select gnurks from spatial"
		self.assertRaises(adql.ColumnNotFound, self._getColSeq, query)

	def testExpressionWithUnicode(self):
		query = "select crazy.name||geo.pt from crazy, geo"
		expected = [("unicode", '', '', True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testIdenticalNames(self):
		query = ("SELECT u.ra1 FROM spatial AS mine"
			" LEFT OUTER JOIN spatial2 as u"
			" ON (1=CONTAINS(POINT('', mine.ra1, mine.ra2),"
			"   CIRCLE('', u.ra1, u.dec, 1)))")
		expected = [("real", 'deg', 'pos.eq.ra;meta.main', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testAliasedColumn(self):
		query = ("SELECT foo, ra1 FROM ("
			"SELECT ra1 as foO, ra1 FROM spatial) as q")
		expected = [
			("real", 'deg', 'pos.eq.ra', False),
			("real", 'deg', 'pos.eq.ra', False)]
		self._assertColumns(self._getColSeq(query), expected)


class ExprColTest(ColumnTest):
	"""Tests for result columns of concatenations, UCDCOL, IN_UNIT, UDFs
	and array subscripts.
	"""
	def testCharConcat(self):
		query = "select flag||'ab' as cat from crazy"
		expected = [("text", '', "", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testTextConcat(self):
		query = "select version||'ab' as cat from crazy"
		expected = [("text", '', "", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testUnicodeConcat(self):
		query = "select name||'ab' as cat from crazy"
		expected = [("unicode", '', "", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testUCDColSimple(self):
		query = "select UCDCOL('phys.mass') from misc"
		expected = [("real", "kg", "phys.mass", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testUCDColPattern(self):
		query = ("select UCDCOL('phys.mass'), UCDCOL('phys.dist*')"
			" from misc join spatial on (dist=mass)")
		expected = [
			("real", "kg", "phys.mass", False),
			("real", "m", "phys.distance", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testUCDColFails(self):
		query = "select UCDCOL('phys.mass') from spatial"
		self.assertRaisesWithMsg(base.NotFoundError,
			"column matching ucd 'phys.mass' could not be located in from clause",
			self._getColSeq,
			(query,))
	
	def testInUnit(self):
		query = "select IN_UNIT(height, 'm') from spatial"
		expected = [("real", "m", "phys.dim", False)]
		self._assertColumns(self._getColSeq(query), expected)
	
	def testInUnitFailsIncompatibleUnit(self):
		query = "select IN_UNIT(height, 'rad') from spatial"
		self.assertRaisesWithMsg(adql.Error,
			"in_unit error: km and rad do not have the same SI base",
			self._getColSeq,
			(query,))

	def testAggregateUDF(self):
		query = ("select gavo_histogram(mass, 0, 100, 10) as h"
			" from misc")
		expected = [("integer[]", "", "stat.histogram;phys.mass", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testArrayElement(self):
		query = ("select vals[5] as h"
			" from crazy")
		expected = [("real", "yr", "some.value", True)]
		self._assertColumns(self._getColSeq(query), expected)
	
	def testNonArraySubscript(self):
		query = "select width[0] from spatial"
		self.assertRaisesWithMsg(adql.Error,
			"Cannot subscript a non-array in width [ 0 ]",
			self._getColSeq,
			(query,))

	def testNoNumberSubscript(self):
		query = "select 5[0] from spatial"
		self.assertRaisesWithMsg(base.ParseException,
			'Expected "FROM" (at char 8), (line:1, col:9)',
			self._getColSeq,
			(query,))


class DelimitedColResTest(ColumnTest):
	"""Tests for column resolution with delimited (quoted) identifiers.
	"""
	def testCaseSensitive(self):
		query = 'select "Inch""ing" from quoted'
		self.assertRaises(adql.ColumnNotFound, self._getColSeq, query)

	def testMixedCase(self):
		firstName = self._getColSeq('select "plAin" from quoted')[0][0]
		self.assertEqual(firstName, utils.QuotedName("plAin"))

	def testNoFoldToRegular(self):
		query = 'select plain from quoted'
		self.assertRaises(adql.ColumnNotFound, self._getColSeq, query)

	def testDelimitedMatchesRegular(self):
		firstName = self._getColSeq('select "mass" from misc')[0][0]
		self.assertEqual(firstName, "mass")

	def testConstantSelectWithAs(self):
		firstName = self._getColSeq('select 1+0.1 as "x" from spatial')[0][0]
		self.assertEqual(firstName, "x")

	def testRegularMatchesDelmitied(self):
		firstName = self._getColSeq('select alllower from quoted')[0][0]
		self.assertEqual(firstName, "alllower")

	def testSimpleStar(self):
		query = "select * from quoted"
		expected = [
			("real", 'bg', "mess", False),
			("real", 'fin', "imperial.mess", False),
			("real", 'pc', "boring.stuff", False),
			("real", 'km', "simple.case", False)]
		self._assertColumns(self._getColSeq(query), expected)
	
	def testSimpleJoin(self):
		query = ('select "inch""ing", "mass" from misc join'
			' quoted on ("left-right"=speed)')
		expected = [
			("real", 'fin', "imperial.mess", False),
			("real", 'kg', 'phys.mass', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testPlainAndSubselect(self):
		query = ('select "inch""ing", alllower from ('
			'select TOP 5 * from quoted where alllower<"inch""ing") as q')
		expected = [
			("real", 'fin', "imperial.mess", False),
			("real", 'km', "simple.case", False)]
		self._assertColumns(self._getColSeq(query), expected)
	
	def testQuotedExpressions(self):
		query = 'select 4*alllower*"inch""ing" from quoted'
		expected = [("real", 'km*fin', None, True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testReferencingRegulars(self):
		query = 'select "ra1" from spatial'
		expected = [("real", 'deg', "pos.eq.ra", False)]
		self._assertColumns(self._getColSeq(query), expected)


class JoinColResTest(ColumnTest):
	"""Tests for column resolution in joins (explicit, natural, USING,
	with subqueries and aliases).
	"""
	def testJoin(self):
		query = ("select dist, speed, 2*mass*height"
			" from spatial join misc on (mass>height)")
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'km/s', 'phys.veloc', False),
			("real", 'kg*km', '', True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testJoinStar(self):
		query = ("select * from spatial as q join misc as p on"
			" (1=contains(point('ICRS', q.dist, q.width), circle('ICRS',"
			" p.mass, p.mag, 0.02)))")
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'm', 'phys.dim', False),
			("real", 'km', 'phys.dim', False),
			("real", 'deg', 'pos.eq.ra', False),
			("real", 'rad', 'pos.eq.ra', False),
			("real", 'kg', 'phys.mass', False),
			("real", 'mag', 'phot.mag', False),
			("real", 'km/s', 'phys.veloc', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testSubqueryJoin(self):
		query = ("SELECT * FROM ("
			"SELECT ALL q.mass, spatial.ra1 FROM ("
			"  SELECT TOP 100 mass, mag FROM misc"
			"    WHERE speed BETWEEN 0 AND 1) AS q JOIN"
			"  spatial ON (mass=width)) AS f")
		expected = [
			("real", 'kg', 'phys.mass', False),
			("real", 'deg', 'pos.eq.ra', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testAutoJoin(self):
		colSeq = self._getColSeq("SELECT * FROM misc JOIN"
			" (SELECT TOP 3 * FROM crazy) AS q ON (mag=q.ct)")
		miscMass = colSeq[0]
		self.assertEqual(miscMass[0], "mass")
		self.assertEqual(miscMass[1].ucd, "phys.mass")
		crazyMass = colSeq[6]
		self.assertEqual(crazyMass[0], "mass")
		self.assertEqual(crazyMass[1].ucd, "event;using.incense")

	def testSelfUsingJoin(self):
		query = ("SELECT * FROM "
			" misc JOIN misc AS u USING (mass)")
		expected = [
			("real", 'kg', 'phys.mass', False),
			("real", 'mag', 'phot.mag', False),
			("real", 'km/s', 'phys.veloc', False),
			("real", 'mag', 'phot.mag', False),
			("real", 'km/s', 'phys.veloc', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testExReferenceBad(self):
		query = "select foo.dist from spatial join misc on (mass>height)"
		self.assertRaises(adql.TableNotFound, self._getColSeq, query)

	def testExReference(self):
		query = ("select a.dist, b.dist"
			" from spatial as a join spatial as b on (a.dist>b.dist)")
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'm', 'phys.distance', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testExReferenceMixed(self):
		query = ("select spatial.dist, b.speed"
			" from spatial as a join misc as b on (a.dist>b.speed)")
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'km/s', 'phys.veloc', False)]
		self._assertColumns(self._getColSeq(query), expected)
	
	def testNaturalJoin(self):
		query = ("SELECT * FROM"
			" spatial JOIN spatial2 WHERE dist<2")
		expected = [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "h", "time.epoch", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testNaturalJoinSubquery(self):
		query = ("SELECT dist, width FROM"
			" spatial JOIN spatial2 WHERE dist IN (SELECT spatial2.dist FROM spatial2)")
		expected = [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testUsingJoin1(self):
		query = ("SELECT * FROM"
			" spatial JOIN spatial2 USING (ra1)")
		expected = [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "m", "phys.distance", False),
			("real", "h", "time.epoch", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testUsingJoin2(self):
		query = ("SELECT * FROM"
			" spatial JOIN spatial2 USING (ra1, dist)")
		expected = [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "h", "time.epoch", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testUsingJoin3(self):
		query = ("SELECT ra1, dec, mass FROM"
			" spatial JOIN spatial2 USING (ra1, dist) JOIN misc ON (dist=mass)")
		expected = [
			("real", "deg", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "kg", "phys.mass", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testUsingJoin4(self):
		query = ("SELECT ra1, dec, mass FROM"
			" (SELECT * FROM spatial) as q JOIN spatial2"
			" USING (ra1, dist) JOIN misc ON (dist=mass)")
		expected = [
			("real", "deg", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "kg", "phys.mass", False)]
		self._assertColumns(self._getColSeq(query), expected)
	
	def testCommaAll(self):
		colSeq = self._getColSeq("SELECT * from spatial, spatial, misc")
		sourceNames = [info.userData[0].name for _, info in colSeq]
		self.assertEqual(sourceNames, [
			'dist', 'width', 'height', 'ra1', 'ra2', 'dist', 'width', 
			'height', 'ra1', 'ra2', 'mass', 'mag', 'speed'])

	def testHaving1(self):
		query = ("SELECT ct FROM crazy "
			"JOIN ("
			"  SELECT height FROM spatial"
			"  JOIN spatial2 ON (ra2=dist)"
			"  GROUP BY height"
			"  HAVING (height>avg(dist))) AS q "
			"ON (wot=height)")
		expected = [('integer', '', '', False)]
		self._assertColumns(self._getColSeq(query), expected)


class SetColResTest(ColumnTest):
	"""Tests for result columns of set operations (UNION, INTERSECT,
	EXCEPT) and of set-generating functions.
	"""
	def testSimple(self):
		query = ("select dist, height from spatial"
			" union select dist, height from spatial")
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'km', 'phys.dim', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testStars(self):
		query = ("select 0 as i, misc.* from misc"
			" union select 1 as i, misc.* from misc")
		expected = [
			("smallint", '', '', False),
			("real", 'kg', 'phys.mass', False),
			("real", 'mag', 'phot.mag', False),
			("real", 'km/s', 'phys.veloc', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testLengthFailure(self):
		query = ("select dist, width, height from spatial"
			" union select dist, height from spatial")
		self.assertRaisesWithMsg(adql.IncompatibleTables,
			"Operands in set operation have differing result tuple lengths.",
			self._getColSeq,
			(query,))

	def testName(self):
		query = ("select width, height from spatial"
			" union select dist, height from spatial")
		self.assertRaisesWithMsg(adql.IncompatibleTables,
			"Operands if set operation have differing names.  First differing name: width vs. dist",
			self._getColSeq,
			(query,))

	def testAliasing(self):
		query = ("select dist, height from spatial"
			" union select dist, ra1 as height from spatial2"
			" intersect select mass as dist, mag as height from misc"
			' except select "left-right" as dist, "plAin" as height from quoted')
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'km', 'phys.dim', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testNested(self):
		query = ("select dist, height from spatial"
			" union select ra1 as dist, dec as height from ("
			"   select ra1, dec, dist from spatial2"
			"   except select mag as ra1, mass as dec, speed as dist from misc) as q"
			"  where dist>2")
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'km', 'phys.dim', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testSetGeneratingFunction(self):
		colSeq = self._getColSeq("select * from generate_series(1, 4)")
		self._assertColumns(colSeq, [('integer', None, None, False)])
		self.assertEqual(colSeq[0][0], "generate_series")

	def testSetGeneratingFunctionAlias(self):
		colSeq = self._getColSeq("select * from generate_series(1, 4) as q")
		self._assertColumns(colSeq, [('integer', None, None, False)])
		self.assertEqual(colSeq[0][0], "q")

	def testSetGeneratingFunctionJoin(self):
		colSeq = self._getColSeq("select mass, q from generate_series(1, 4) as q"
			" join misc on (q=speed)")
		expected = [
			('real', 'kg', 'phys.mass', False),
			('integer', None, None, False)]
		self._assertColumns(colSeq, expected)
		self.assertEqual(colSeq[1][0], "q")



class CastColResTest(ColumnTest):
	"""Tests for the result columns of CAST expressions.

	The expectations below keep the operand's unit and UCD, change the type,
	and mark the column as tainted.
	"""
	def testSimple(self):
		query = ("select dist,"
			" cast(dist as char(10)) as d_str,"
			" cast(dist as CHAR) as d_chr,"
			" cast(dist as natIonal char(13)) as d_uni,"
			" cast(dist as national char) as d_uni,"
			" cast(dist as integer) as d_int,"
			" cast(dist as bigint) as d_long,"
			" cast(ra1 as smallint) as ra_short,"
			" cast(dist as real) as d_float,"
			" cast(dist as double precision) as d_double,"
			" cast(dist as timestamp) as d_ts,"
			" cast(NULL as bigint) as n_long"
			" from spatial")
		expected = [
			("real", 'm', 'phys.distance', False),
			("text", 'm', 'phys.distance', True),
			("char", 'm', 'phys.distance', True),
			("unicode", 'm', 'phys.distance', True),
			("unicode", 'm', 'phys.distance', True),
			("integer", 'm', 'phys.distance', True),
			("bigint", 'm', 'phys.distance', True),
			("smallint", 'deg', 'pos.eq.ra', True),
			("real", 'm', 'phys.distance', True),
			("double precision", 'm', 'phys.distance', True),
			("timestamp", 'm', 'phys.distance', True),
			("bigint", '', '', True),
			]
		self._assertColumns(self._getColSeq(query), expected)


class WithColResTest(ColumnTest):
	"""Tests for column resolution through common table expressions (WITH).
	"""
	def testSimple(self):
		query = ("with hollow as (select * from spatial)"
			" select dist, ra2 from hollow")
		expected = [
			("real", "m", "phys.distance", False),
			("real", "rad", "pos.eq.ra", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testWithSetOperations(self):
		query = ("with hollow as"
			" (select * from spatial where dist>5"
			"   union select * from spatial where ra1<3)"
			" select dist, ra2 from hollow")
		expected = [
			("real", "m", "phys.distance", False),
			("real", "rad", "pos.eq.ra", False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testMultipleWith(self):
		query = ("with hollow as (select dist, ra2 from spatial),"
			"   filled as (select mass as dist, mag from misc)"
			" select ra2, mag from hollow natural join filled")
		expected = [
			("real", "rad", "pos.eq.ra", False),
			("real", "mag", "phot.mag", False)]
		self._assertColumns(self._getColSeq(query), expected)


class _UploadTDWithOID(testhelpers.TestResource):
	"""A test resource: the table definition "foo" made from a tiny
	VOTable whose single column is named oid.

	UploadColResTest.testPGForbiddenNames expects that column to come back
	as oid_.
	"""
	def make(self, ignored):
		from gavo import votable
		from gavo.formats import votableread
		from cStringIO import StringIO

		source = StringIO(
			"""<VOTABLE><RESOURCE><TABLE>
				<FIELD name="oid" datatype="float"/>
				<DATA><TABLEDATA><TR><TD>1</TD></TR></TABLEDATA></DATA>
				</TABLE></RESOURCE></VOTABLE>""")
		# only the first table of the parse result is used.
		firstTable = votable.parse(source).next()
		nameMaker = votableread.AutoQuotedNameMaker()
		return votableread.makeTableDefForVOTable(
			"foo", firstTable.tableDefinition, nameMaker)

_uploadTDWithOID = _UploadTDWithOID()


class UploadColResTest(ColumnTest):
	"""Tests for column resolution against TAP_UPLOAD tables.
	"""
	resources = [("nastyTD", _uploadTDWithOID)]

	def setUp(self):
		ColumnTest.setUp(self)
		uploads = [testhelpers.getTestTable("adql")]
		self.fieldInfoGetter = adqlglue.DaCHSFieldInfoGetter(
			tdsForUploads=uploads)
	
	def testNormalResolution(self):
		query = "select alpha, rv from TAP_UPLOAD.adql"
		expected = [
			("real", 'deg', 'pos.eq.ra;meta.main', False),
			("double precision", 'km/s', 'phys.veloc;pos.heliocentric', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testFailedResolutionCol(self):
		query = 'select alp, rv from TAP_UPLOAD.adql'
		self.assertRaises(base.NotFoundError, self._getColSeq, query)
	
	def testFailedResolutionTable(self):
		query = 'select alpha, rv from TAP_UPLOAD.junk'
		self.assertRaises(base.NotFoundError, self._getColSeq, query)

	def testPGForbiddenNames(self):
		self.fieldInfoGetter = adqlglue.DaCHSFieldInfoGetter(
			tdsForUploads=[self.nastyTD])
		colSeq = self._getColSeq(
			"select q.*, q.oid from (select oid from tap_upload.foo) as q")
		self.assertEqual(colSeq[0][0], "oid_")
		self.assertEqual(colSeq[1][0], "oid_")


class STCTest(ColumnTest):
	"""Tests for STC inference in ADQL expressions.
	"""
	def testSimple(self):
		cs = self._getColSeq("select ra1, ra2 from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'ICRS')
		self.assertEqual(cs[1][1].stc.astroSystem.spaceFrame.refFrame, 'FK4')

	def testBroken(self):
		cs = self._getColSeq("select ra1+ra2 from spatial")
		self.assertTrue(hasattr(cs[0][1].stc, "broken"))

	def testOKPoint(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point('ICRS', ra1, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'ICRS')
		self.assertEqual(ctx.errors, [])

	def testEmptyCoosysInherits(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point('', ra2, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'FK4')
		self.assertEqual(ctx.errors, [])

	def testEmptyCoosysBecomesNone(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point('', mag, 2) from misc")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, None)
		self.assertEqual(ctx.errors, [])

	def testMissingCoosysBecomesNone(self):
		# The coordinate system argument is omitted entirely here.  An empty
		# string is covered by testEmptyCoosysBecomesNone.
		cs, ctx = self._getColSeqAndCtx(
			"select point(mag, 2) from misc")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, None)
		self.assertEqual(ctx.errors, [])

	def testNULLCoosysInherits(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point(NULL, ra1, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'ICRS')
		self.assertEqual(ctx.errors, [])

	def testMissingCoosysInherits(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point(ra2, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'FK4')
		self.assertEqual(ctx.errors, [])

	def testMissingCoosysInheritsCircle(self):
		cs, ctx = self._getColSeqAndCtx(
			"select circle(ra2, dist, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'FK4')
		self.assertEqual(ctx.errors, [])

	def testPointBadCoo(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point('ICRS', ra2, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'ICRS')
		self.assertEqual(ctx.errors, ['When constructing point:'
			' Argument 1 has incompatible STC'])

	def testPointFunctionsSelect(self):
		cs, ctx = self._getColSeqAndCtx(
			"select coordsys(p), coord1(p), coord2(p) from"
			"	(select point('FK5', ra1, width) as p from spatial) as q")
		self._assertColumns(cs, [
			("text", '', 'meta.ref;pos.frame', False),
			("double precision", 'deg', None, False),
			("double precision", 'deg', None, False)])

	def testBadSTCSRegion(self):
		self.assertRaisesWithMsg(adql.RegionError, 
			"Invalid argument to REGION: 'Time TT'.",
			self._getColSeqAndCtx, (
				"select * from spatial where 1=intersects("
				"region('Time TT'), circle('icrs', 1, 1, 0.1))",))

	def testRegionExpressionRaises(self):
		self.assertRaisesWithMsg(adql.RegionError, 
			"Invalid argument to REGION: ''Position'||alphaName||deltaName'.",
			self._getColSeqAndCtx, (
				"select * from spatial where 1=intersects("
				"region('Position' || alphaName || deltaName),"
				" circle('icrs', 1, 1, 0.1))",))

	def testSTCSRegion(self):
		cs, ctx = self._getColSeqAndCtx(
				"select region('Circle FK4 10 10 1')"
				" from spatial")
		self.assertEqual(cs[0][1].unit, "deg")
	
	def testPolygonInheritsGeo(self):
		cs, ctx = self._getColSeqAndCtx(
				"select polygon(pt, point(1, 2), point(3,4))"
				" from geo")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 
			'GALACTIC_II')

	def testPolygonInheritsSplit(self):
		cs, ctx = self._getColSeqAndCtx(
				"select polygon(ra2, dist, ra1, height, 2, 3)"
				" from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 
			'FK4')
	
	def testTimestampMetaInference(self):
		# This method used to be named timestampMetaInference.  Without the
		# "test" prefix, the unittest loader never collected it.
		cs, ctx = self._getColSeqAndCtx(
			"select timestamp('2017-12-01'), TIMESTAMP(dt)"
			" from geo")
		self.assertEqual(cs[0][1].ucd, '')
		self.assertEqual(cs[0][1].unit, '')
		self.assertEqual(cs[0][1].type, 'timestamp')
		self.assertEqual(cs[0][1].stc, None)
		self.assertEqual(cs[1][1].ucd, 'time;obs')
		self.assertEqual(cs[1][1].type, 'timestamp')
		self.assertEqual(cs[1][1].stc.astroSystem.timeFrame.timeScale, "TT")


class FunctionNodeTest(unittest.TestCase):
	"""tests for nodes.FunctionMixin and friends, via argument parsing
	of POINT.
	"""
	def setUp(self):
		self.grammar = adql.getGrammar()

	def _getFirstSelectExpr(self, query):
		"""returns the expression node of the first select field of query.
		"""
		parsed = self.grammar.parseString(query)[0]
		return parsed.selectList.selectFields[0].expr

	def testPlainArgparse(self):
		point = self._getFirstSelectExpr(
			"select POINT('ICRS', width,height) from spatial")
		self.assertEqual(point.cooSys, "ICRS")
		self.assertEqual(nodes.flatten(point.x), "width")
		self.assertEqual(nodes.flatten(point.y), "height")

	def testExprArgparse(self):
		point = self._getFirstSelectExpr(
			"select POINT('ICRS', 5*width+height*LOG(width),height)"
			" from spatial")
		self.assertEqual(point.cooSys, "ICRS")
		self.assertEqual(nodes.flatten(point.x),
			"5 * width + height * LOG(width)")
		self.assertEqual(nodes.flatten(point.y), "height")


class ComplexExpressionTest(unittest.TestCase):
	"""assorted tests for the correct processing of somewhat complex
	search expressions.
	"""
	def testOne(self):
		query = ("select top 5 * from lsw.plates"
			" where dateobs between 'J2416642 ' and 'J2416643'")
		parsed = adql.getGrammar().parseString(query)[0]
		whereParts = parsed.whereClause.children
		self.assertEqual(whereParts[1].name, "dateobs")
		self.assertEqual(adql.flatten(whereParts[-1]), "'J2416643'")


class NameSuggestingTest(testhelpers.VerboseTest):
	"""tests for suggestAName on parsed queries.

	Each sample is a pair of (ADQL query, expected suggested name).
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		query, expectedName = sample
		parsed = adql.getGrammar().parseString(query)[0]
		self.assertEqual(parsed.suggestAName(), expectedName)

	samples = [
		("select * from plaintable", "plaintable"),
		('select * from "plaintable"', "_plaintable"),
		('select * from "useless & table"" name"', "_uselesstablename"),
		('select * from "3columns"', "_3columns"),
		('select * from t1 join t2', "t1_t2"),
# 5
		('select * from (select * from gnug) as booger', "booger"),
		('select * from (select * from gnug) as booger join boog',
			"booger_boog"),]


class _FlatteningTest(testhelpers.VerboseTest):
	"""a base class for tests comparing flattened ADQL parse trees
	with expected strings.
	"""
	def _assertFlattensTo(self, rawADQL, flattenedADQL):
		"""asserts that parsing rawADQL and flattening the resulting
		tree gives flattenedADQL.
		"""
		parsed = adql.parseToTree(rawADQL)
		self.assertEqual(adql.flatten(parsed), flattenedADQL)


class MiscFlatteningTest(_FlatteningTest):
	"""tests for flattening plain (non-morphed) ADQL trees.
	"""
	def testCircle(self):
		adqlIn = ("select alphaFloat, deltaFloat from ppmx.data"
			" where contains(point('ICRS', alphaFloat, deltaFloat), "
			" circle('ICRS', 23, 24, 0.2))=1")
		expected = ("SELECT alphaFloat, deltaFloat FROM ppmx.data WHERE"
			" CONTAINS(POINT(alphaFloat, deltaFloat),"
			" CIRCLE(ICRS,23,24,0.2)) = 1")
		self._assertFlattensTo(adqlIn, expected)

	def testFunctions(self):
		adqlIn = "select round(x,2)as a, truncate(x,-2) as b from foo"
		expected = "SELECT ROUND(x, 2) AS a, TRUNCATE(x, - 2) AS b FROM foo"
		self._assertFlattensTo(adqlIn, expected)

	def testJoin(self):
		adqlIn = ("SELECT ra1, dec, mass FROM\n"
			" (SELECT * FROM spatial) as q LEFT OUTER JOIN spatial2\n"
			" USING (ra1, dist) JOIN misc ON (dist=mass)")
		expected = ("SELECT ra1, dec, mass FROM (SELECT * FROM spatial) AS q"
			" LEFT OUTER JOIN spatial2 USING ( ra1 , dist ) JOIN misc"
			" ON ( dist = mass )")
		self._assertFlattensTo(adqlIn, expected)

	def testCommaJoin(self):
		adqlIn = "SELECT ra1, dec, mass FROM\n spatial, spatial2, misc"
		expected = "SELECT ra1, dec, mass FROM spatial , spatial2 , misc"
		self._assertFlattensTo(adqlIn, expected)

	def testSubJoin(self):
		adqlIn = ("SELECT ra1, dec, mass FROM\n"
			" (spatial join spatial2 using (ra1)), misc")
		expected = ("SELECT ra1, dec, mass FROM"
			" (spatial JOIN spatial2 USING ( ra1 )) , misc")
		self._assertFlattensTo(adqlIn, expected)

	def testConcat(self):
		adqlIn = "select 'ivo://' ||  name || '%' as pat from crazy"
		expected = "SELECT 'ivo://' || name || '%' AS pat FROM crazy"
		self._assertFlattensTo(adqlIn, expected)

	def testAliasExpr(self):
		adqlIn = "select a+b/(8+x) as num from crazy"
		expected = "SELECT a + b / ( 8 + x ) AS num FROM crazy"
		self._assertFlattensTo(adqlIn, expected)


class CommentTest(_FlatteningTest):
	"""tests that SQL-style comments disappear from flattened ADQL.
	"""
	def testTopComment(self):
		adqlIn = ("-- opening remarks;\n"
			"-- quite a few of them, actually.\nselect * from foo")
		self._assertFlattensTo(adqlIn, "SELECT * FROM foo")

	def testEmbeddedComments(self):
		adqlIn = ("select -- comment\n"
			"bar, --comment\n"
			"quux --comment\n"
			"from -- comment\n"
			"foo --comment")
		self._assertFlattensTo(adqlIn, "SELECT bar, quux FROM foo")

	def testStringJoining(self):
		# a comment between two string literals must not stop them
		# from being joined.
		adqlIn = "select * from bar where a='qua' -- cmt\n'tsch'"
		self._assertFlattensTo(adqlIn, "SELECT * FROM bar WHERE a = 'quatsch'")

	def testLeadingWhitespaceCleanup(self):
		adqlIn = "select * from--comment\n   bar"
		self._assertFlattensTo(adqlIn, "SELECT * FROM bar")

	def testEquivalentToWhitespace(self):
		adqlIn = "select * from--comment\nbar"
		self._assertFlattensTo(adqlIn, "SELECT * FROM bar")


class Q3CMorphTest(testhelpers.VerboseTest):
	"""tests for the q3c morphing of queries with spatial conditions.
	"""
	def setUp(self):
		# note: none of the tests in this class use self.grammar.
		self.grammar = adql.getGrammar()

	def _morphToFlat(self, adqlTree):
		"""returns the flattened result of morphing adqlTree with morphpg.
		"""
		_, morphed = morphpg.morphPG(adqlTree)
		return adql.flatten(morphed)

	def _parseAnnotating(self, query):
		"""returns the tree of query annotated with the sample field infos.
		"""
		return adql.parseAnnotating(query, _sampleFieldInfoGetter)[1]

	def testCircleIn(self):
		flat = self._morphToFlat(
			adql.parseToTree("select alphaFloat, deltaFloat from ppmx.data"
			" where contains(point('ICRS', alphaFloat, deltaFloat), "
				" circle('ICRS', 23, 24, 0.2))=1"))
		self.assertEqual(flat,
			"SELECT alphaFloat, deltaFloat FROM ppmx.data WHERE"
				" q3c_join(23, 24, alphaFloat, deltaFloat, 0.2)")

	def testCircleOut(self):
		flat = self._morphToFlat(
			adql.parseToTree("select alphaFloat, deltaFloat from ppmx.data"
			" where 0=contains(point('ICRS', alphaFloat, deltaFloat),"
				" circle('ICRS', 23, 24, 0.2))"))
		self.assertEqual(flat,
			"SELECT alphaFloat, deltaFloat FROM ppmx.data WHERE"
				" NOT q3c_join(23, 24, alphaFloat, deltaFloat, 0.2)")

	def testConstantsFirst(self):
		flat = self._morphToFlat(
			adql.parseToTree("select alphaFloat, deltaFloat from ppmx.data"
			" where 0=contains(point('ICRS', 23, 24),"
				" circle('ICRS', alphaFloat, deltaFloat, 0.2))"))
		self.assertEqual(flat,
			"SELECT alphaFloat, deltaFloat FROM ppmx.data WHERE"
				" NOT q3c_join(23, 24, alphaFloat, deltaFloat, 0.2)")

	def testCircleAnnotated(self):
		flat = self._morphToFlat(
			self._parseAnnotating("SELECT TOP 10 * FROM spatial"
			" WHERE 1=CONTAINS(POINT('ICRS', ra1, ra2),"
			"  CIRCLE('ICRS', 10, 10, 0.5))"))
		self.assertEqual(flat,
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2 FROM spatial WHERE q3c_join(10, 10, ra1, ra2, 0.5) LIMIT 10")

	def testMogrifiedIntersect(self):
		flat = self._morphToFlat(
			self._parseAnnotating("SELECT TOP 10 * FROM spatial"
			" WHERE 1=INTERSECTS(CIRCLE('ICRS', 10, 10, 0.5),"
				"POINT('ICRS', ra1, ra2))"))
		self.assertEqual(flat,
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2 FROM spatial WHERE q3c_join(10, 10, ra1, ra2, 0.5) LIMIT 10")

	def testDistanceTranslatedCrossmatch(self):
		flat = self._morphToFlat(
			self._parseAnnotating("select ra1 from spatial as a join spatial as b"
			" where distance(a.ra1, a.ra2, b.ra1, b.ra2)<0.001"))
		self.assertEqual(flat,
			"SELECT ra1 FROM spatial AS a JOIN spatial AS b"
			"  WHERE  q3c_join(a.ra1, a.ra2, b.ra1, b.ra2, 0.001)")

	def testDistanceTranslatedInvertedCrossmatch(self):
		flat = self._morphToFlat(
			self._parseAnnotating("select * from spatial as a join spatial as b"
			" where distance(a.ra1, a.ra2, b.ra1, b.ra2)>=0.001"))
		self.assertEqual(flat, "SELECT dist, width, height, ra1, ra2"
			" FROM spatial AS a JOIN spatial AS b  WHERE"
			" NOT  q3c_join(a.ra1, a.ra2, b.ra1, b.ra2, 0.001)")

	def testDistanceTranslatedSelect(self):
		flat = self._morphToFlat(
			self._parseAnnotating(
				"select distance(4+ra1, ra2*2, ra2, dist) AS d from spatial"))
		self.assertEqual(flat,
			"SELECT q3c_dist(4 + ra1, ra2 * 2, ra2, dist) AS d FROM spatial")

	def testDistanceWithCoord(self):
		flat = self._morphToFlat(
			self._parseAnnotating(
				"select dist from spatial where"
				"  2>distance(coord1(ivo_healpix_center(5, 100)),"
				"	coord2(ivo_healpix_center(5, 100)), ra1, ra2)"))
		self.assertEqual(flat,
			"SELECT dist FROM spatial WHERE  q3c_join(DEGREES(long(center_of_healpix_nest(5, 100))), DEGREES(lat(center_of_healpix_nest(5, 100))), ra1, ra2, 2)")

	def testCircleWithCoord(self):
		flat = self._morphToFlat(
			self._parseAnnotating(
				"select dist from spatial where"
				"  1=contains(point(ra1, ra2), "
				" CIRCLE(coord1(ivo_healpix_center(5, 100)),"
				"	coord2(ivo_healpix_center(5, 100)), 2))"))
		self.assertEqual(flat,
			"SELECT dist FROM spatial WHERE q3c_join(DEGREES(long(center_of_healpix_nest(5, 100))), DEGREES(lat(center_of_healpix_nest(5, 100))), ra1, ra2, 2)")

	def testPolygonConst(self):
		flat = self._morphToFlat(
			self._parseAnnotating(
				"select polygon(coord1(p1), coord2(p1), coord1(p2), coord2(p2),"
					"2,3) from (SELECT POINT(ra1, ra2) as p1, ivo_apply_pm(ra1, ra2,"
					" 1e-7, 2e-7, 10) as p2 from spatial) as q"))
		self.assertEqualIgnoringAliases(flat,
			"SELECT (SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(DEGREES(long(p1))), RADIANS(DEGREES(lat(p1))))), (1, spoint(RADIANS(DEGREES(long(p2))), RADIANS(DEGREES(lat(p2))))), (2, spoint(RADIANS(2), RADIANS(3))) ORDER BY column1) as q(ind,p)) ASWHATEVER FROM (SELECT spoint(RADIANS(ra1), RADIANS(ra2)) AS p1, spoint(RADIANS(ra1 + 10 * 1e-7 / COS(RADIANS(ra2))), RADIANS(ra2 + 2e-7 * 10)) AS p2 FROM spatial) AS q")


class PQMorphTest(testhelpers.VerboseTest):
	"""tests for morphing non-geometry ADQL syntax to postgres.
	"""
	resources = [("nastyTD", _uploadTDWithOID)]

	def _testMorph(self, stIn, stOut, fieldInfoGetter=None):
		"""asserts that the ADQL stIn, morphed to postgres and flattened,
		equals stOut (ignoring generated aliases).

		If fieldInfoGetter is given, the tree is annotated with it before
		morphing; tests expecting expanded select stars pass one.
		"""
		# not called "tree" to avoid shadowing the gavo.adql.tree module
		parsed = adql.parseToTree(stIn)
		if fieldInfoGetter:
			# annotate is only called for its effect on parsed
			adql.annotate(parsed, fieldInfoGetter)
		_, morphed = adql.morphPG(parsed)
		self.assertEqualIgnoringAliases(nodes.flatten(morphed), stOut)

	def testSyntax(self):
		self._testMorph("select distinct top 10 x, y from foo offset 3",
			'SELECT DISTINCT x, y FROM foo  LIMIT 10 OFFSET 3')

	def testWhitespace(self):
		self._testMorph("select\t distinct top\n\r\n    10 x, y from foo",
			'SELECT DISTINCT x, y FROM foo LIMIT 10')

	def testGroupby(self):
		self._testMorph("select count(*), inc from ("
			" select round(x) as inc from foo) as q group by inc",
			"SELECT COUNT ( * ) ASWHATEVER, inc FROM"
			" (SELECT ROUND(x) AS inc FROM foo) AS q"
			" GROUP BY inc")

	def testTwoArgRound(self):
		self._testMorph(
			"select round(x, 2) as a, truncate(x, -2) as b from foo",
			'SELECT ROUND((x)*10^(2)) / 10^(2) AS a, TRUNC((x)*'
				'10^(- 2)) / 10^(- 2) AS b FROM foo')

	def testExprArgs(self):
		self._testMorph(
			"select truncate(round((x*2)+y, 4)) from foo",
			'SELECT TRUNC(ROUND((( x * 2 ) + y)*10^(4)) / 10^(4)) ASWHATEVER FROM foo')

	def testPointFunctionWithFieldInfo(self):
		t = adql.parseToTree("select coordsys(q.p) from "
			"(select point('ICRS', ra1, ra2) as p from spatial) as q")
		ctx = adql.annotate(t, _sampleFieldInfoGetter)
		self.assertEqual(ctx.errors[0],
			'When constructing point: Argument 2 has incompatible STC')
		status, t = adql.morphPG(t)
		self.assertEqualIgnoringAliases(nodes.flatten(t),
			"SELECT 'ICRS' ASWHATEVER FROM (SELECT spoint"
			"(RADIANS(ra1), RADIANS(ra2)) AS p FROM spatial) AS q")

	def testStringReplacedNumerics(self):
		self._testMorph("select square(x+x) from foo",
			"SELECT (x + x)^2 ASWHATEVER FROM foo")

	def testNumerics(self):
		self._testMorph("select log10(x), log(x), rand(), rand(5), "
			" TRUNCATE(x), TRUNCATE(x,3) from foo",
			'SELECT LOG(x) ASWHATEVER, LN(x) ASWHATEVER, random() ASWHATEVER,'
				' random() ASWHATEVER, TRUNC('
				'x) ASWHATEVER, TRUNC((x)*10^(3)) / 10^(3) ASWHATEVER FROM foo')

	def testHarmless(self):
		self._testMorph("select delta*2, alpha*mag, alpha+delta"
			" from something where mag<-10",
			'SELECT delta * 2 ASWHATEVER, alpha * mag ASWHATEVER, alpha + delta ASWHATEVER FROM something'
			' WHERE mag < - 10')

	def testUnaryLogic(self):
		self._testMorph("select x from something where y not in (1,2)",
			'SELECT x FROM something WHERE y NOT IN ( 1 , 2 )')

	def testOrder(self):
		self._testMorph("select top 100 * from spatial where dist>10"
			" order by dist, height",
			'SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2 FROM spatial WHERE dist > 10 ORDER BY dist , height LIMIT 100',
			_sampleFieldInfoGetter)

	def testUploadKilled(self):
		self._testMorph("select * from TAP_UPLOAD.abc",
			"SELECT * FROM abc")

	def testAliasedUploadKilled(self):
		self._testMorph("select * from TAP_UPLOAD.abc as o",
			"SELECT * FROM abc AS o")

	def testUploadColRef(self):
		self._testMorph("select TAP_UPLOAD.abc.c from TAP_UPLOAD.abc",
			"SELECT abc.c FROM abc")

	def testUploadColRefInGeom(self):
		self._testMorph("select POINT('', TAP_UPLOAD.abc.b, TAP_UPLOAD.abc.c)"
			" from TAP_UPLOAD.abc",
			"SELECT spoint(RADIANS(abc.b), RADIANS(abc.c)) ASWHATEVER FROM abc")

	def testUploadColRefInGeomContains(self):
		self._testMorph("SELECT TAP_UPLOAD.user_table.ra FROM"
			" TAP_UPLOAD.user_table WHERE (1=CONTAINS(POINT('ICRS',"
			" usnob.data.raj2000, usnob.data.dej2000), CIRCLE('ICRS',"
			" TAP_UPLOAD.user_table.ra2000, a.dec2000, 0.016666666666666666)))",
			'SELECT user_table.ra FROM user_table WHERE ( q3c_join('
			"user_table.ra2000, a.dec2000, usnob.data.raj2000,"
			" usnob.data.dej2000, 0.016666666666666666) )")

	def testSTCSSingle(self):
		self._testMorph(
			"select * from foo where 1=CONTAINS(REGION('Position ICRS 1 2'), x)",
			"SELECT * FROM foo WHERE"
			" ((spoint '(0.0174532925,0.0349065850)') @ (x))")

	def testSTCSExpr(self):
		# TODO: Have a long, close look at the expected translation below.
		self._testMorph(
			"select * from foo where 1=CONTAINS("
				"REGION('Union ICRS (Position 1 2 Intersection"
				" (circle  1 2 3 box 1 2 3 4 circle 30 40 2))'),"
				" REGION('circle GALACTIC 1 2 3'))",
			"SELECT * FROM foo WHERE ((spoint '(0.0174532925,0.0349065850)' @ ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >')+strans(1.346356097441,-1.097319001837,0.574770524729)))) OR (((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >' @ ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >')+strans(1.346356097441,-1.097319001837,0.574770524729)))) AND ((spoly '{(-0.0087266463,0.0000000000),(-0.0087266463,0.0698131701),(0.0436332313,0.0698131701),(0.0436332313,0.0000000000)}' @ ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >')+strans(1.346356097441,-1.097319001837,0.574770524729)))) AND ((scircle '< (0.5235987756, 0.6981317008), 0.0349065850 >' @ ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >')+strans(1.346356097441,-1.097319001837,0.574770524729)))))")

	def testSTCSNotRegion(self):
		self._testMorph(
			"select * from foo where 1=INTERSECTS(REGION('NOT (circle  1 2 3)'), x)",
			"SELECT * FROM foo WHERE NOT ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >' && (x)))")

	def testIsNotNull(self):
		self._testMorph(
			"select * from foo where x is not null",
			"SELECT * FROM foo WHERE x IS NOT NULL")

	def testIsNull(self):
		self._testMorph(
			"select * from foo where x is null",
			"SELECT * FROM foo WHERE x IS NULL")

	def testMultiJoin(self):
		self._testMorph(
			"select * from spatial natural join spatial2 join misc on (dist=speed)",
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2, spatial2.ra1, spatial2.dec, spatial2.dist, spatial2.t, misc.mass, misc.mag, misc.speed FROM spatial NATURAL JOIN spatial2  JOIN misc ON ( dist = speed )",
			_sampleFieldInfoGetter)

	def testMoveAndUnit(self):
		self._testMorph("select ivo_apply_pm("
			"in_unit(ra1, 'arcmin'), in_unit(dec, 'deg'),"
			"in_unit(ra1/t, 'mas/s'), in_unit(dec/(t+1), 'uarcsec/min'), 80) as gnack"
			" from spatial natural join spatial2 where"
			" 5<distance(point(ra1, dec),"
			" ivo_apply_pm(ra1, ra2, in_unit(dec/t, 'deg/yr'),"
			"   in_unit(ra1/t, 'mas/h'), 30))",
			"SELECT spoint(RADIANS((ra1 * 60.0)"
				" + 80 * ((ra1 / t) * 1000.0) / COS(RADIANS((dec * 1)))),"
				" RADIANS((dec * 1) + ((dec / ( t + 1 )) * 60000000.0) * 80)) AS gnack"
				" FROM spatial NATURAL JOIN spatial2  WHERE NOT "
				" q3c_join(ra1 + 30 * ((dec / t) * 8766.0) / COS(RADIANS(ra2)),"
				" ra2 + ((ra1 / t) * 3600000.0) * 30, ra1, dec, 5)",
			_sampleFieldInfoGetter)

	def testQualifiedStar(self):
		self._testMorph(
			"select spatial.*, misc.* from spatial natural join spatial2"
				" join misc on (dist=speed)",
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2, misc.mass, misc.mag, misc.speed FROM spatial NATURAL JOIN spatial2  JOIN misc ON ( dist = speed )",
			_sampleFieldInfoGetter)

	def testStarWithAlias(self):
		self._testMorph("select * from spatial as b",
			"SELECT b.dist, b.width, b.height, b.ra1, b.ra2 FROM spatial AS b",
			_sampleFieldInfoGetter)

	def testStarWithJoin(self):
		self._testMorph("select * from spatial join spatial2 on (width=dec)",
			"SELECT spatial.dist, spatial.width, spatial.height,"
			" spatial.ra1, spatial.ra2, spatial2.ra1, spatial2.dec,"
			" spatial2.dist, spatial2.t FROM spatial JOIN spatial2 ON ( width = dec )",
			_sampleFieldInfoGetter)

	def testStarWithSubquery(self):
		# the alias generated for ra1+dec is not predictable, hence the regex;
		# \1 makes sure it is the same in the select list and the subquery.
		parsed = adql.parseToTree("select * from spatial join "
			" (select ra1+dec, dist-2 as foo, dec from spatial2 offset 0) as q"
			" ON ( width = dec )")
		adql.annotate(parsed, _sampleFieldInfoGetter)
		status, t = adql.morphPG(parsed)
		flattened = nodes.flatten(t)
		self.assertTrue(re.match(r'SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2, q.([a-z]*), q.foo, q.dec FROM spatial JOIN \(SELECT ra1 \+ dec AS \1, dist - 2 AS foo, dec FROM spatial2  OFFSET 0\) AS q ON \( width = dec \)$', flattened))

	def testSetLimitIntegrated(self):
		self._testMorph("select top 3 * from x union (select top 40 a from y"
			" except select * from z)",
			"SELECT * FROM x UNION ( SELECT a FROM y EXCEPT SELECT * FROM z ) LIMIT 40")

	def testDeepSetLimitProtected(self):
		self._testMorph("select * from (select TOP 30 * from x) as q union"
			" select TOP 4 * from u",
			"SELECT * FROM (SELECT * FROM x LIMIT 30) AS q UNION SELECT * FROM u LIMIT 4")

	def testUCDCOL(self):
		# non postgres-specific, but we need the annotation
		self._testMorph("select UCDCOL('pos.eq.dec*'), UCDCOL('phys.dim')"
			" from spatial natural join spatial2",
			"SELECT dec, width FROM spatial NATURAL JOIN spatial2 ",
			_sampleFieldInfoGetter)

	def testInUnit(self):
		# non postgres-specific, but we need the annotation
		self._testMorph("select in_unit("
				"in_unit(ra1, 'deg')/in_unit(height, 'pc')+"
				"in_unit(ra2, 'deg')/in_unit(width, 'pc'), 'rad/pc') as fantasy"
				" from spatial",
			"SELECT (((ra1 * 1) / (height * 3.24075574424e-14) + (ra2 * 57.2957795131) / (width * 3.24075574424e-17)) * 0.0174532925199) AS fantasy FROM spatial",
			_sampleFieldInfoGetter)

	def testTimestamp(self):
		self._testMorph("select TIMESTAMP('2007-02-03' || 'T12:33:44') from geo"
			" where dt>TIMESTAMP('1997-03-02')",
			"SELECT ('2007-02-03' || 'T12:33:44')::TIMESTAMP ASWHATEVER FROM geo WHERE dt > ('1997-03-02')::TIMESTAMP")

	def testBitwise(self):
		self._testMorph("select BITWISE_AND(mass, mag) as one,"
			" BITWISE_OR(ROUND(mass), power(mag, 2)) as two,"
			" BITWISE_XOR(BITWISE_NOT(mass)+3, power(mag, 2)/2) as three"
			" from misc",
			"SELECT (mass)&(mag) AS one, (ROUND(mass))|(POWER(mag, 2)) AS two, (~(mass) + 3)#(POWER(mag, 2) / 2) AS three FROM misc")

	def testOidInUpload(self):
		self._testMorph(
			"select q.*, q.oid from (select oid from tap_upload.foo) as q",
			"SELECT q.oid_, q.oid_ FROM (SELECT oid_ FROM foo) AS q",
			fieldInfoGetter=adqlglue.DaCHSFieldInfoGetter(
			tdsForUploads=[self.nastyTD]))

	def testArray(self):
		self._testMorph("select vals[3] as x from crazy where vals[round(mass)]=3",
			"SELECT vals [ 3 ] AS x FROM crazy WHERE vals [ ROUND(mass) ] = 3",
			_sampleFieldInfoGetter)

	def testEmbeddedQuotes(self):
		self._testMorph("select count(*) as x from crazy where name='O''Toole'",
			"SELECT COUNT ( * ) AS x FROM crazy WHERE name = 'O''Toole'",
			_sampleFieldInfoGetter)

	def testSampling(self):
		self._testMorph("select * from foo as a tablesample( 0.01 )",
			"SELECT * FROM foo AS a TABLESAMPLE SYSTEM (0.01)")

	def testSetGenerating(self):
		self._testMorph("select * from generate_series(low, high) as foo",
			"SELECT * FROM generate_series ( low , high ) AS foo")


class PGSMorphTest(testhelpers.VerboseTest):
	"""tests for some pgSphere morphing.

	Each sample is a pair of (ADQL query, expected morphed postgres SQL);
	the trees are not annotated before morphing.
	"""
	__metaclass__ = testhelpers.SamplesBasedAutoTest

	def _runTest(self, sample):
		adqlQuery, expectedSQL = sample
		_, morphedTree = adql.morphPG(adql.parseToTree(adqlQuery))
		self.assertEqualIgnoringAliases(nodes.flatten(morphedTree), expectedSQL)

	samples = [
		("select AREA(circle('ICRS', COORD1(p1), coord2(p1), 2)),"
				" DISTANCE(p1,p2), centroid(circle('ICRS', coord1(p1), coord2(p1),"
				" 3)) from (select point('ICRS', ra1, dec1) as p1,"
				"   point('ICRS', ra2, dec2) as p2 from foo) as q",
			'SELECT 3282.806350011744*AREA(scircle(spoint(RADIANS(DEGREES(long(p1))), RADIANS(DEGREES(lat(p1)))), RADIANS(2))) ASWHATEVER, DEGREES((p1) <-> (p2)) ASWHATEVER, @@(scircle(spoint(RADIANS(DEGREES(long(p1))), RADIANS(DEGREES(lat(p1)))), RADIANS(3))) ASWHATEVER FROM (SELECT spoint(RADIANS(ra1), RADIANS(dec1)) AS p1, spoint(RADIANS(ra2), RADIANS(dec2)) AS p2 FROM foo) AS q'),
		("select coord1(p) from foo", 'SELECT DEGREES(long(p)) ASWHATEVER FROM foo'),
		("select coord2(p) from foo", 'SELECT DEGREES(lat(p)) ASWHATEVER FROM foo'),
		# Ahem -- the following could resolve the coordsys, but intra-query
		# communication is through field infos; the trees here are not annotated,
		# though.  See PQMorphTest.testPointFunctionWithFieldInfo above.
		("select coordsys(q.p) from (select point('ICRS', x, y)"
			" as p from foo) as q",
			"SELECT 'UNKNOWN' ASWHATEVER FROM (SELECT spoint(RADIANS(x), RADIANS(y)) AS p FROM foo) AS q"),
		("select alpha from foo where"
				" Intersects(circle('ICRS', alpha, delta,"
				" margin*margin), polygon('ICRS', 1, 12, 3, 4, 5, 6, 7, 8))=0",
				"SELECT alpha FROM foo WHERE NOT ((scircle(spoint(RADIANS(alpha), RADIANS(delta)), RADIANS(margin * margin))) && ((SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(1), RADIANS(12))), (1, spoint(RADIANS(3), RADIANS(4))), (2, spoint(RADIANS(5), RADIANS(6))), (3, spoint(RADIANS(7), RADIANS(8))) ORDER BY column1) as q(ind,p))))"),

# 5
		("select alpha from foo where"
				" contains(circle('ICRS', alpha, delta,"
				" margin*margin), box('ICRS', lf, up, ri, lo))=0",
			"SELECT alpha FROM foo WHERE NOT ((scircle(spoint(RADIANS(alpha), RADIANS(delta)), RADIANS(margin * margin))) @ ((SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(lf)-RADIANS(ri)/2, RADIANS(up)-RADIANS(lo)/2)), (1, spoint(RADIANS(lf)-RADIANS(ri)/2, RADIANS(up)+RADIANS(lo)/2)), (2, spoint(RADIANS(lf)+RADIANS(ri)/2, RADIANS(up)+RADIANS(lo)/2)), (3, spoint(RADIANS(lf)+RADIANS(ri)/2, RADIANS(up)-RADIANS(lo)/2)) ORDER BY column1) as q(ind,p))))"),
		("select point('ICRS', cos(a)*sin(b), cos(a)*sin(b)),"
				" circle('ICRS', raj2000, dej2000, 25-mag*mag) from foo",
			'SELECT spoint(RADIANS(COS(a) * SIN(b)), RADIANS(COS(a) * SIN(b))) ASWHATEVER, scircle(spoint(RADIANS(raj2000), RADIANS(dej2000)), RADIANS(25 - mag * mag)) ASWHATEVER FROM foo'),
		("select POiNT('ICRS', 1, 2), CIRCLE('ICRS', 2, 3, 4),"
				" bOx('ICRS', 2 ,3, 4, 5), polygon('ICRS', 2, 3, 4, 5, 6, 7)"
				" from foo",
			'SELECT spoint(RADIANS(1), RADIANS(2)) ASWHATEVER,'
			' scircle(spoint(RADIANS(2), RADIANS(3)), RADIANS(4)) ASWHATEVER,'
			' (SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(2)-RADIANS(4)/2, RADIANS(3)-RADIANS(5)/2)), (1, spoint(RADIANS(2)-RADIANS(4)/2, RADIANS(3)+RADIANS(5)/2)), (2, spoint(RADIANS(2)+RADIANS(4)/2, RADIANS(3)+RADIANS(5)/2)), (3, spoint(RADIANS(2)+RADIANS(4)/2, RADIANS(3)-RADIANS(5)/2)) ORDER BY column1) as q(ind,p)) ASWHATEVER,'
			' (SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(2), RADIANS(3))), (1, spoint(RADIANS(4), RADIANS(5))), (2, spoint(RADIANS(6), RADIANS(7))) ORDER BY column1) as q(ind,p)) ASWHATEVER FROM foo'),
		("select Box('ICRS',alphaFloat,deltaFloat,pmra*100,pmde*100)"
			"	from ppmx.data where pmra!=0 and pmde!=0",
			"SELECT (SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(alphaFloat)-RADIANS(pmra * 100)/2, RADIANS(deltaFloat)-RADIANS(pmde * 100)/2)), (1, spoint(RADIANS(alphaFloat)-RADIANS(pmra * 100)/2, RADIANS(deltaFloat)+RADIANS(pmde * 100)/2)), (2, spoint(RADIANS(alphaFloat)+RADIANS(pmra * 100)/2, RADIANS(deltaFloat)+RADIANS(pmde * 100)/2)), (3, spoint(RADIANS(alphaFloat)+RADIANS(pmra * 100)/2, RADIANS(deltaFloat)-RADIANS(pmde * 100)/2)) ORDER BY column1) as q(ind,p)) ASWHATEVER FROM ppmx.data WHERE pmra != 0 AND pmde != 0"),
		("select * from data where 1=contains(point('fk4', 1,2),"
			" circle('Galactic',2,3,4))",
			"SELECT * FROM data WHERE (((spoint(RADIANS(1), RADIANS(2)))-strans(1.565186433367,-0.004859055280,-1.576368104353)+strans(1.346356097441,-1.097319001837,0.574770524729)) @ (scircle(spoint(RADIANS(2), RADIANS(3)), RADIANS(4))))"),
# 10
		("select * from data where 1=contains(point('UNKNOWN', ra,de),"
			" circle('Galactic',2,3,4))",
			"SELECT * FROM data WHERE q3c_join(2, 3, ra, de, 4)"),
		("select * from data where 1=intersects(coverage,"
			"circle('icrs', 10, 10, 1))",
			"SELECT * FROM data WHERE ((coverage) && (scircle(spoint(RADIANS(10), RADIANS(10)), RADIANS(1))))"),
		("select * from data where 1=intersects(\"coVerage\","
			"circle('icrs', 10, 10, 1))",
			"SELECT * FROM data WHERE ((\"coVerage\") && (scircle(spoint(RADIANS(10), RADIANS(10)), RADIANS(1))))"),
		("select contains(coverage, circle('', 10, 10, 1)) from data",
			"SELECT CONTAINS(coverage, scircle(spoint(RADIANS(10), RADIANS(10)), RADIANS(1))) ASWHATEVER FROM data"),
		("select * from (select point(ra, de) as p from data) as q"
			" where 1=intersects(circle(p, 0.1), circle(1,2,3))",
			"SELECT * FROM (SELECT spoint(RADIANS(ra), RADIANS(de)) AS p FROM data) AS q WHERE ((scircle(p, RADIANS(0.1))) && (scircle(spoint(RADIANS(1), RADIANS(2)), RADIANS(3))))"),
# 15
			]


class PGSNoMorphTest(testhelpers.VerboseTest):
	"""tests for geometry expressions that morphPG refuses to translate.
	"""
	_noCentroidMessage = ("Can only compute centroids of circles and points yet."
		"  Complain to make us implement other geometries faster.")

	def _assertNoCentroid(self, parsedTree):
		"""asserts that morphing parsedTree fails with the "no centroid"
		MorphError.
		"""
		self.assertRaisesWithMsg(adql.MorphError,
			self._noCentroidMessage,
			adql.morphPG,
			(parsedTree,))

	def testPolygonNoCentroid(self):
		self._assertNoCentroid(adql.parseToTree(
			"select centroid(polygon('ICRS', 12, 13, 14, 15, 15, 17)) from foo"))

	def testReferencedBoxNoCentroid(self):
		self._assertNoCentroid(parseWithArtificialTable(
			"select centroid(b) from (select"
				" box('', 1, 1, 2, 2) as b from spatial) as q"))


class GlueTest(testhelpers.VerboseTest):
	"""tests for some aspects of adqlglue, mainly the column metadata
	(null literals, types) of the output table descriptions.
	"""
	def _getOutputTD(self, query):
		"""returns what adqlglue._getTableDescForOutput makes of query parsed
		against the artificial test tables.
		"""
		return adqlglue._getTableDescForOutput(
			parseWithArtificialTable(query))

	def testAutoNull(self):
		td = self._getOutputTD("select * from crazy")
		self.assertEqual(td.getColumnByName("ct").values.nullLiteral, "-2147483648")

	def testSpecifiedNull(self):
		td = self._getOutputTD("select * from crazy")
		self.assertEqual(td.getColumnByName("wot").values.nullLiteral, "-1")

	def testSpecifiedNullOverridden(self):
		firstColumn = self._getOutputTD("select 2+wot from crazy").columns[0]
		self.assertEqual(firstColumn.values.nullLiteral, '-9223372036854775808')

	def testPureByteaNotPromoted(self):
		firstColumn = self._getOutputTD("select wotb from crazy").columns[0]
		self.assertEqual(firstColumn.values.nullLiteral, '254')
		self.assertEqual(firstColumn.type, 'bytea')

	def testByteaInMultiplication(self):
		# This is probably behaviour that doesn't work with postgres anyway.
		# Fix the whole unsignedByte mess by not linking it to bytea.
		firstColumn = self._getOutputTD("select 2*wotb from crazy").columns[0]
		self.assertEqual(firstColumn.type, 'smallint')
		self.assertEqual(firstColumn.values.nullLiteral, "-32768")


class QueryTest(testhelpers.VerboseTest):
	"""performs some actual queries to test the whole thing.

	The queries run through adqlglue.query against the test.adql and
	test.adqlgeo tables; the tests check both the rows returned and the
	metadata of the resulting table definitions.
	"""
	resources = [("ds",  adqlTestTable), ("querier", adqlQuerier),
		("geomds", geomTestTable)]

	def setUp(self):
		testhelpers.VerboseTest.setUp(self)
		self.tableName = self.ds.tables["adql"].tableDef.getQName()

	def _assertFieldProperties(self, dataField, expected):
		"""asserts that dataField has the attribute values in expected.

		expected is a sequence of (attribute name, expected value) pairs;
		attributes missing on dataField compare as None.
		"""
		for label, value in expected:
			found = getattr(dataField, label, None)
			self.assertEqual(found, value, 
				"Data field %s:"
				" Expected %s for %s, found %s"%(dataField.name, repr(value), 
					label, repr(found)))

	def runQuery(self, query, **kwargs):
		"""returns the result of running the ADQL query through adqlglue.

		kwargs are passed on to adqlglue.query unchanged.
		"""
		return adqlglue.query(self.querier, query, **kwargs)

	def testPlainSelect(self):
		res = self.runQuery(
			"select alpha, delta from %s where mag<0"%
			self.tableName)
		self.assertEqual(res.tableDef.id, self.tableName.split(".")[-1])
		self.assertEqual(len(res.rows), 1)
		self.assertEqual(len(res.rows[0]), 2)
		self.assertEqual(res.rows[0]["alpha"], 290.125)
		raField, deField = res.tableDef.columns
		self._assertFieldProperties(raField, [("ucd", 'pos.eq.ra;meta.main'),
			("description", 'A sample RA'), ("unit", 'deg'), 
			("tablehead", "Raw RA")])
		self._assertFieldProperties(deField, [("ucd", 'pos.eq.dec;meta.main'),
			("description", 'A sample Dec'), ("unit", 'deg'), 
			("tablehead", None)])

	def testStarSelect(self):
		res = self.runQuery("select * from %s where mag<0"%
			self.tableName)
		self.assertEqual(len(res.rows), 1)
		self.assertEqual(len(res.rows[0]), 5)
		fields = res.tableDef.columns
		self._assertFieldProperties(fields[0], [("ucd", 'pos.eq.ra;meta.main'),
			("description", 'A sample RA'), ("unit", 'deg'), 
			("tablehead", "Raw RA")])
		self._assertFieldProperties(fields[1], [("ucd", 'pos.eq.dec;meta.main'),
			("description", 'A sample Dec'), ("unit", 'deg'), 
			("tablehead", None)])
		self._assertFieldProperties(fields[3], [
			("ucd", 'phys.veloc;pos.heliocentric'),
			("description", 'A sample radial velocity'), ("unit", 'km/s')])
		self._assertFieldProperties(fields[4], [
			("ucd", ''), ("description", ''), ("unit", '')])

	def testQualifiedStarSelect(self):
		res = self.runQuery("select %s.* from %s join %s as q1"
			" using (mag) where q1.mag<0"%(
			self.tableName, self.tableName, self.tableName))
		self.assertEqual(res.tableDef.id, "adql_q1")
		self.assertEqual(len(res.rows), 1)
		self.assertEqual(len(res.rows[0]), 5)
		fields = res.tableDef.columns
		self._assertFieldProperties(fields[0], [("ucd", 'pos.eq.ra;meta.main'),
			("description", 'A sample RA'), ("unit", 'deg'), 
			("tablehead", "Raw RA")])

	def testNoCase(self):
		# will just raise an Exception if things are broken.
		self.runQuery("select ALPHA, DeLtA, MaG from %s"%self.tableName)
	
	def testDelimitedMapping(self):
		res = self.runQuery('select "alpha" from "test"."adql"')
		self.assertEqual(
			str(res.tableDef.columns[0].name),
			'"alpha"')
		self.assertEqual(
			res.tableDef.id,
			"_adql")

	def testDelimitedBadTableFails(self):
		self.assertRaisesWithMsg(base.DBError,
			'relation "test.Adql" does not exist\nLINE 1: SELECT "alpha" FROM test."Adql" LIMIT 20000\n                            ^\n',
			self.runQuery,
			('select "alpha" from test."Adql"',))

	def testTainting(self):
		res = self.runQuery("select delta*2, alpha*mag, alpha+delta"
			" from %s where mag<-10"% self.tableName)
		f1, f2, f3 = res.tableDef.columns
		self._assertFieldProperties(f1, [("ucd", 'pos.eq.dec;meta.main'),
			("description", 'A sample Dec -- *TAINTED*: the value was operated'
				' on in a way that unit and ucd may be severely wrong'),
			("unit", 'deg')])
		self._assertFieldProperties(f2, [("ucd", ''),
			("description", 'This field has traces of: A sample RA;'
				' A sample magnitude -- *TAINTED*: the value was operated'
				' on in a way that unit and ucd may be severely wrong'),
			("unit", 'deg*mag')])
		self._assertFieldProperties(f3, [("ucd", ''),
			("description", 'This field has traces of: A sample RA; A sample Dec'
				' -- *TAINTED*: the value was operated on in a way that unit and'
				' ucd may be severely wrong'),
			("unit", 'deg')])

	def testTransformation(self):
		res = self.runQuery("select mag from %s where"
			" 1=contains(point('icrs', alpha, delta),"
			"   circle('galactic', 107,-47, 1))"%self.tableName)
		self.assertEqual(list(res)[0]["mag"], 10.25)
	
	def testGeometryInSelect(self):
		res = self.runQuery(
			"select rv, point('icrs', alpha, delta) as p, mag, alpha, delta,"
			" contains(point(alpha, delta), circle('', 3, 15, 1)) as c,"
			" contains(point('', alpha, delta), circle(3, 15, 2)) as c1,"
			" intersects(point(alpha, delta), circle(3, 15, 2)) as i1,"
			" intersects(circle(alpha, delta, 1.5), circle(3, 15, 1.5)) as i2,"
			" intersects(circle(alpha, delta, 0.5), circle(3, 15, 0.5)) as i3,"
			" circle('icrs', alpha, delta, 10) as ci,"
			" circle(delta, alpha, 5) as cn"
			" from %s where mag>5"%self.tableName)
		rows = list(res)

		expected = pgsphere.SPoint.fromDegrees(2, 14)
		self.assertAlmostEqual(rows[0]["p"].x, expected.x)
		self.assertAlmostEqual(rows[0]["p"].y, expected.y)
		self.assertEqual(rows[0]["rv"], -23.75)
		self.assertEqual(rows[0]["c"], 0)
		self.assertEqual(rows[0]["c1"], 1)
		self.assertEqual(rows[0]["i1"], 1)
		self.assertEqual(rows[0]["i2"], 1)
		self.assertEqual(rows[0]["i3"], 0)
		expected = pgsphere.SCircle.fromDALI([2, 14, 10])
		expected2 = pgsphere.SCircle.fromDALI([14, 2, 5])
		self.assertAlmostEqual(rows[0]["ci"].center.x, expected.center.x)
		self.assertAlmostEqual(rows[0]["ci"].radius, expected.radius)
		self.assertAlmostEqual(rows[0]["cn"].center.y, expected2.center.y)
		self.assertAlmostEqual(rows[0]["cn"].radius, expected2.radius)
		self.assertEqual(res.tableDef.getColumnByName("p").xtype,
			"point")
		self.assertEqual(res.tableDef.getColumnByName("ci").xtype,
			"circle")
		self.assertEqual(res.tableDef.getColumnByName("cn").xtype,
			"circle")

	def testCrossmatch(self):
		res = self.runQuery(
			"select alpha, delta from %s" 
			" WHERE 1=CROSSMATCH(3, 15, alpha, delta, 2)"%self.tableName)
		self.assertEqual(list(res), [{'alpha': 2.0, 'delta': 14.0}])

	def testQuotedIdentifier(self):
		res = self.runQuery(
			'select "rv", rV from %s where delta=89'%self.tableName)
		self.assertEqual(res.rows, [{"rv": 28., "rv_": 28.}])

	def testDistanceDegrees(self):
		res = self.runQuery(
			"select DISTANCE(POINT('ICRS', 22, -3), POINT('ICRS', 183, 50)) as d"
			" from %s"%self.tableName)
		self.assertAlmostEqual(res.rows[0]["d"], 130.31777623681)

	def testInUnitGeometry(self):
		res = self.runQuery(
			"select in_unit(DISTANCE(POINT('ICRS', 22, -3),"
				" POINT('ICRS', 183, 50)), 'rad') as d"
			" from %s"%self.tableName)
		self.assertAlmostEqual(res.rows[0]["d"], 2.27447426920956)

	def testApplyPM(self):
		res = self.runQuery(
			"SELECT ivo_apply_pm(alpha, delta, 0.002, -0.001, -55) as moved"
			" FROM %s where alpha between 20 and 26"%self.tableName)
		mapped = list(base.SerManager(res).getMappedValues())[0]["moved"]
		self.assertAlmostEqual(mapped[0],  24.8866325007715)
		self.assertAlmostEqual(mapped[1], -13.945)
	
	def testStringFunctions(self):
		res = self.runQuery(
			"SELECT UPPer(table_name) as tn, lower(description) as td"
				" from tap_schema.columns where column_name ='delta'")
		# assertIn reports the actual rows on failure, unlike assertTrue(x in y)
		self.assertIn({'tn': 'TEST.ADQL', 'td': 'a sample dec'}, res.rows)

	def testDMAnnotation(self):
		res = self.runQuery("SELECT TOP 1 * FROM %s"%self.tableName)
		# builtin next() rather than the Python 2-only .next() method
		ann = next(res.tableDef.iterAnnotationsOfType("geojson:FeatureCollection"))
		self.assertEqual(ann["feature"]["geometry"]["type"], "sepcoo")
		self.assertEqual(ann["feature"]["geometry"]["latitude"].value, 
			res.tableDef.getByName("delta"))

	def testMOCVsPoint(self):
		res = self.runQuery("SELECT row_id FROM test.adqlgeo"
			" WHERE 1=CONTAINS(POINT('ICRS', 55.5, 20.7), a_moc)")
		self.assertEqual([r["row_id"] for r in res.rows],
			['moc-6-1', 'moc-4-1'])
		res = self.runQuery("SELECT row_id FROM test.adqlgeo"
			" WHERE 1=CONTAINS(POINT('ICRS', 56, 21), a_moc)")
		self.assertEqual([r["row_id"] for r in res.rows],
			['moc-4-1'])

	def testSpointedCircle(self):
		res = self.runQuery("SELECT row_id FROM test.adqlgeo"
			" WHERE 1=INTERSECTS(CIRCLE(a_point, 1), CIRCLE(23, 42, 1))")
		self.assertEqual([r["row_id"] for r in res.rows],
			['moc-6-1'])
	
	def testSpointedPolygon(self):
		res = self.runQuery("SELECT row_id,"
			" polygon(a_point,POINT(23,42), POINT(23, 41)) as p FROM test.adqlgeo"
			" WHERE 1=CONTAINS(POINT(24, 43), POLYGON(a_point,"
			" POINT(23,42), POINT(23, 41), POINT(24, 41)))")
		self.assertEqual([r["row_id"] for r in res.rows],
			['moc-4-1'])
		# Address the (single) matching row explicitly instead of relying on
		# the list comprehension variable leaking, which it does not in
		# Python 3.
		self.assertAlmostEqual(res.rows[0]["p"].points[0].x, 0.436332313)
	
	def testCTE(self):
		res = self.runQuery("WITH knall as (SELECT POINT(alpha, delta) as pt"
			" from test.adql)"
			" SELECT * FROM kNall")
		self.assertEqual(len(res.rows), 3)
		self.assertEqual(type(res.rows[0]["pt"]), pgsphere.SPoint)
		self.assertEqual(res.tableDef.columns[0].name, "pt")


class SimpleSTCSTest(testhelpers.VerboseTest):
	"""Tests for the simple STC-S parser returned by
	tapstc.getSimpleSTCSParser.
	"""
	def setUp(self):
		self.parse = tapstc.getSimpleSTCSParser()

	def testPosParses(self):
		res = self.parse("Position 10 20 ")
		self.assertEqual(res.pgType, "spoint")
		self.assertAlmostEqual(res.x, 0.174532925199432)
		self.assertEqual(res.cooSys, "UNKNOWN")
	
	def testCircleParses(self):
		res = self.parse(" Circle ICRS 10 20 1e0")
		self.assertEqual(res.pgType, "scircle")
		self.assertEqual(res.cooSys, "ICRS")

	def testBadCircleRaises(self):
		self.assertRaisesWithMsg(stc.STCSParseError,
			'STC-S circles want three numbers.',
			self.parse,
			("Circle 2 1",))

	def testBoxParses(self):
		res = self.parse("box TOPOCENTER SPHERICAL2 -10  20 2.1 5.4")
		self.assertEqual(res.pgType, "spoly")
	
	def testPolyParses(self):
		res = self.parse("PolyGon FK4 TOPOCENTER SPHERICAL2 -10  20 2.1 5.4 1 3")
		self.assertEqual(res.pgType, "spoly")

	def testNotParses(self):
		res = self.parse("NOT  (Box ICRS 1 2 3 4)")
		# assertIsInstance replaces the deprecated failUnless alias
		# (removed in Python 3.12) and reports the actual type on failure.
		self.assertIsInstance(res, tapstc.GeomExpr)
		self.assertEqual(len(res.operands), 1)
		self.assertAlmostEqual(res.operands[0].points[0].x, -0.00872664626)
		self.assertEqual(res.cooSys, "UNKNOWN")
	
	def testSimpleOpParses(self):
		res = self.parse("UNiON (Box ICRS 1 2 3 4 Circle 1 2 3)")
		self.assertIsInstance(res, tapstc.GeomExpr)
		self.assertEqual(res.operator, "UNION")
		self.assertEqual(len(res.operands), 2)
		self.assertEqual(res.operands[0].pgType, "spoly")
		self.assertEqual(res.operands[1].pgType, "scircle")
		self.assertEqual(res.cooSys, "UNKNOWN")
	
	def testComplexOpParses(self):
		res = self.parse("INtersection FK4 ("
			"UNiON BARYCENTER (Box ICRS 1 2 3 4 Circle 1 2 3)"
			" Polygon ICRS GEOCENTER 2 3 4 5 6 7"
			" Circle Fk4 spherical2 3 4 5)")
		self.assertEqual(res.operands[0].operator, "UNION")
		self.assertEqual(res.operands[1].cooSys, "ICRS")
		self.assertEqual(res.cooSys, "FK4")

	def testCartesianRaises(self):
		self.assertRaisesWithMsg(stc.STCSParseError, 
			'Only SPHERICAL2 STC-S supported here',
			self.parse,
			("Position CARTESIAN3 1 2 3",))


class IntersectsFallbackTest(testhelpers.VerboseTest):
	"""Does INTERSECTS fall back to CONTAINS?

	These tests check that an INTERSECTS with a point operand is annotated
	into CONTAINS with the point first, while other geometry pairs are left
	alone.
	"""
	def _annotate(self, query):
		# parseAnnotating returns (context, tree); only the tree is inspected
		_, annotated = adql.parseAnnotating(query, _sampleFieldInfoGetter)
		return annotated

	def _whereFunction(self, query):
		# left operand of the comparison in the WHERE clause
		return self._annotate(query).whereClause.children[1].op1

	def _joinFunction(self, query):
		# left operand of the comparison in the join's ON condition
		joinSpec = self._annotate(query).fromClause.tableReference.joinSpecification
		return joinSpec.children[2].op1

	def testArg1(self):
		funNode = self._whereFunction(
			"SELECT pt from geo where intersects(pt, circle('ICRS', 2, 2, 1))=1")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].type, "columnReference")
		self.assertEqual(funNode.args[1].type, "circle")

	def testArg2(self):
		funNode = self._whereFunction(
			"SELECT pt from geo where intersects(circle('ICRS', 2, 2, 1), pt)=1")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].type, "columnReference")
		self.assertEqual(funNode.args[1].type, "circle")

	def testExpr(self):
		funNode = self._whereFunction(
			"SELECT pt from geo where intersects("
			"point('ICRS', 2, 2), circle('ICRS', 2, 2, 1))=1")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].type, "point")
	
	def testNotouch(self):
		funNode = self._whereFunction(
			"SELECT pt from geo where intersects("
			"box('ICRS', 2, 2, 3, 3), circle('ICRS', 2, 2, 1))=1")
		self.assertEqual(funNode.funName, "INTERSECTS")

	def testJoinCond(self):
		funNode = self._joinFunction(
			"SELECT * from geo as a join geo as b on (intersects("
			"circle('ICRS', coord1(b.pt), coord2(b.pt), 1), a.pt)=1)")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].fieldInfo.type, "spoint")

	def testJoinCondGeoCol(self):
		funNode = self._joinFunction(
			"SELECT * from geo as a join geo as b on (intersects("
			"circle('ICRS', b.pt, 1), a.pt)=1)")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].fieldInfo.type, "spoint")



if __name__=="__main__":
	# When run as a script, hand STCTest to testhelpers.main.
	# NOTE(review): STCTest is not defined in this part of the file; confirm it
	# exists elsewhere in the module (otherwise this raises NameError), and
	# consider whether a different default test class is intended.
	testhelpers.main(STCTest)
