"""
A DaCHS custom renderer to give a custom, json-based interface to
the relational registry data for WIRR's web interface.
"""

import datetime
import json
import re

from twisted.internet import threads
from twisted.web import resource
from twisted.web import server

from gavo import api
from gavo import base
from gavo.base import sqlmunge
from gavo.utils import pgsphere

# Timeout (in seconds) passed to the database for the queries run by
# ParameterQuery.  5 seconds may be too stingy when we hit queries over
# rr.table_column that are not supported by an index, but let's see.
QUERY_TIMEOUT = 5

def escapeForSQL(aString):
	"""returns aString with a backslash put in front of each backslash and
	each LIKE wildcard character (% and _).

	This lets user input be embedded literally into LIKE patterns that use
	the backslash as escape character.
	"""
	# a single translation pass; since every character is mapped on its
	# own, backslashes added here are never escaped a second time.
	return aString.translate(str.maketrans({
		"\\": "\\\\",
		"%": "\\%",
		"_": "\\_"}))


def makeTextCondition(constraint, sqlPars, additionalTables):
	"""returns an SQL fragment for a full-text search in textual metadata.

	The operand is turned into a text search query with plainto_tsquery
	and matched against res_description, res_title, and creator_seq; a
	resource matches if any of these columns match.

	Documents and query are both processed with the 'english' text search
	configuration so that they are normalised (stemmed) the same way
	independently of the server's default_text_search_config.
	"""
	# TODO: add subjects here?
	# Matching against short_name (e.g., short_name ILIKE '%operand%') really
	# messes up the query plan (and is not terribly useful in the form),
	# so we don't do it for now.  Figuring out why the query plan is
	# ruined would be good, though.
	_, _, operand = constraint
	key = sqlmunge.getSQLKey("fulltext", operand, sqlPars)

	# parenthesised so the disjunction stays intact when this fragment
	# is ANDed with other conditions.
	return "(%s)"%" or ".join(
		"to_tsvector('english', %s) @@ plainto_tsquery('english', %%(%s)s)"%(
			colName, key)
		for colName in ("res_description", "res_title", "creator_seq"))


def makeStringIsConditionMaker(colName, extraTable=None, normalize=False,
		fixedOperator=None):
	"""returns a condition maker comparing colName for (in)equality.

	The returned function takes the (consType, operator, operand) triple,
	the SQL parameter dictionary, and the set of additional tables.  The
	client's operator is used unless fixedOperator is given; "!=" yields an
	inequality test, anything else equality.  extraTable, if given, is
	added to the tables to join; with normalize, the operand is lowercased
	and stripped before use.
	"""
	sqlComparisons = {"=": "=", "!=": "!="}

	def makeStringIsCondition(constraint, sqlPars, additionalTables):
		_, requested, value = constraint
		effective = requested if fixedOperator is None else fixedOperator

		if extraTable is not None:
			additionalTables.add(extraTable)

		if normalize:
			value = value.lower().strip()

		comparison = sqlComparisons.get(effective, "=")
		paramKey = sqlmunge.getSQLKey(colName, value, sqlPars)
		return "%s %s %%(%s)s"%(colName, comparison, paramKey)

	return makeStringIsCondition


def makeStringLikeConditionMaker(colName, extraTable=None):
	"""returns a condition maker matching colName against a LIKE pattern.

	The operand is used as the pattern as-is, i.e., % and _ in it act as
	wildcards.  The operator must be "=" (giving LIKE) or "!=" (giving
	NOT LIKE); anything else is rejected with a ValueError.  extraTable,
	if given, is added to the tables to join.
	"""
	def makeStringLikeCondition(constraint, sqlPars, additionalTables):
		_, operator, operand = constraint

		if extraTable is not None:
			additionalTables.add(extraTable)

		try:
			sqlOperator = {
				"!=": "NOT LIKE",
				"=": "LIKE",}[operator]
		except KeyError:
			# A bare KeyError would reach the user with just the operator as
			# its message; say what is wrong instead (as the timestamp
			# condition does).
			raise ValueError("Invalid operator %s"%repr(operator)) from None

		return "%s %s %%(%s)s"%(
			colName,
			sqlOperator,
			sqlmunge.getSQLKey(colName, operand, sqlPars))

	return makeStringLikeCondition


def makeHashContainsConditionMaker(colName, extraTable=None):
	"""returns a condition maker testing list membership via ivo_hashlist_has.

	The generated condition compares 1 with ivo_hashlist_has(colName,
	operand): for the operator "!=" with !=, for "=" or any other operator
	with =.  extraTable, if given, is added to the tables to join.
	"""
	comparisons = {"=": "=", "!=": "!="}

	def makeHashContainsCondition(constraint, sqlPars, additionalTables):
		_, operator, operand = constraint
		if extraTable is not None:
			additionalTables.add(extraTable)

		relation = comparisons.get(operator, "=")
		paramKey = sqlmunge.getSQLKey(colName, operand, sqlPars)
		return "1 %s ivo_hashlist_has(%s, %%(%s)s)"%(
			relation, colName, paramKey)

	return makeHashContainsCondition


def makeStringContainsConditionMaker(colName, extraTable=None):
	"""returns a condition maker testing for a word via ivo_hasword.

	The generated condition compares 1 with ivo_hasword(colName, operand):
	for the operator "!=" with !=, for "=" or any other operator with =.
	extraTable, if given, is added to the tables to join.
	"""
	comparisons = {"=": "=", "!=": "!="}

	def makeStringContainsCondition(constraint, sqlPars, additionalTables):
		_, operator, word = constraint
		if extraTable is not None:
			additionalTables.add(extraTable)

		relation = comparisons.get(operator, "=")
		paramKey = sqlmunge.getSQLKey(colName, word, sqlPars)
		return "1 %s ivo_hasword(%s, %%(%s)s)"%(relation, colName, paramKey)

	return makeStringContainsCondition


def makeTimestampConditionMaker(colName, extraTable=None):
	"""returns a condition maker restricting colName to a time interval
	ending now.

	With the operator "ndays", the interval starts operand days ago; with
	"date", it starts at the operand parsed by api.parseDefaultDatetime.
	Any other operator raises a ValueError.  extraTable, if given, is
	added to the tables to join (even if the operator is invalid).
	"""
	def makeTimestampCondition(constraint, sqlPars, additionalTables):
		_, operator, operand = constraint
		if extraTable is not None:
			additionalTables.add(extraTable)
		upper = datetime.datetime.now()

		if operator=='ndays':
			lower = upper-datetime.timedelta(days=int(operand))
		elif operator=='date':
			lower = api.parseDefaultDatetime(operand)
		else:
			raise ValueError("Invalid operator %s"%repr(operator))

		lowKey = sqlmunge.getSQLKey(colName, lower, sqlPars)
		highKey = sqlmunge.getSQLKey(colName, upper, sqlPars)
		return "%s BETWEEN %%(%s)s AND %%(%s)s"%(colName, lowKey, highKey)

	return makeTimestampCondition


def makeStringHasConditionMaker(colName, extraTable=None, normalize=False):
	"""returns a condition maker for case-insensitive substring matches.

	The operand is escaped for LIKE and wrapped in % so that it matches
	anywhere in colName.  The operator "!=" yields NOT ILIKE; "=" and any
	other operator yield ILIKE.  extraTable, if given, is added to the
	tables to join; with normalize, the operand is lowercased.
	"""
	def makeStringIsLikeCondition(constraint, sqlPars, additionalTables):
		_, operator, operand = constraint
		operand = "%{}%".format(
			escapeForSQL(operand))

		if extraTable is not None:
			additionalTables.add(extraTable)
		if normalize:
			operand = operand.lower()

		return "%s %s %%(%s)s"%(
			colName,
			# Fall back to ILIKE: the operand is a %-wrapped pattern, so
			# comparing it with = (the previous fallback) would practically
			# never match.
			{"=": "ILIKE", "!=": "NOT ILIKE"}.get(operator, "ILIKE"),
			sqlmunge.getSQLKey(colName, operand, sqlPars))

	return makeStringIsLikeCondition


def makeUCDCondition(constraint, sqlPars, additionalTables):
	"""returns a condition on the UCDs of a resource's table columns.

	A resource matches if at least one of its rr.table_column rows has a
	ucd matching the lowercased operand with ILIKE (for "="), or not
	matching it (for "!=", using NOT ILIKE).  Other operators compare
	with =.  The operand is not escaped, so % and _ in it are wildcards.
	"""
	_, operator, operand = constraint
	pattern = ("%s"%operand).lower()
	sqlOperator = {"=": "ILIKE", "!=": "NOT ILIKE"}.get(operator, "=")
	paramKey = sqlmunge.getSQLKey("ucd", pattern, sqlPars)

	return ("EXISTS("
		" SELECT name"
		" FROM rr.table_column AS tcol"
		" WHERE a.ivoid=tcol.ivoid"
		" AND ucd %s %%(%s)s)")%(sqlOperator, paramKey)


def makeDetailsCondition(constraint, sqlPars, additionalTables):
	"""returns a constraint for rr.res_detail.

	constraint is the usual consType, operator, operand triple, but
	operator now is a whitespace-concatenated combination of a detail
	xpath and an SQL operator, of which only = and ILIKE are accepted.
	For ILIKE, the operand is wrapped in % so it may match anywhere in
	detail_value.  If operator does not have this form, all of it is used
	as the xpath and detail_value is compared with =.
	"""
	_, operator, operand = constraint
	parts = operator.split()
	if len(parts)==2 and parts[1] in ("ILIKE", "="):
		# The SQL operator is whitelisted by the condition above, which is
		# what makes interpolating it into the query below safe.
		detailXpath, operator = parts
		if operator=="ILIKE":
			# operand is not escaped, so % and _ in it still act as
			# wildcards.
			operand = "%%%s%%"%operand
	else:
		detailXpath, operator = operator, "="

	return ("EXISTS("
		" SELECT detail_value"
		" FROM rr.res_detail AS de"
		" WHERE a.ivoid=de.ivoid"
		"   AND detail_xpath=%%(%s)s AND detail_value %s %%(%s)s)")%(
		sqlmunge.getSQLKey("detail_xpath", detailXpath, sqlPars),
		operator,
		sqlmunge.getSQLKey("detail_value", operand, sqlPars))


def makeRelationshipCondition(constraint, sqlPars, additionalTables):
	"""constrains resources to those related to one with an IVOID.

	The operand is an IVOID; matching resources are those listed as
	related_id in the rr.relationship rows of that resource.
	"""
	_, _, targetIvoid = constraint
	targetKey = sqlmunge.getSQLKey("targetid", targetIvoid, sqlPars)
	return ("ivoid in (select distinct related_id"
		" from rr.relationship as rinner where rinner.ivoid=%%(%s)s)"
		)%targetKey


def makeAltRelationshipCondition(constraint, sqlPars, additionalTables):
	"""constrains resources to those having a relationship whose
	related_alt_identifier starts with the operand.

	The operand is escaped, so it is matched literally as a prefix.
	"""
	_, _, prefix = constraint
	pattern = "%s%%"%escapeForSQL(prefix)
	altKey = sqlmunge.getSQLKey("altid", pattern, sqlPars)
	return ("ivoid in (select ivoid"
		" from rr.relationship where related_alt_identifier LIKE %%(%s)s)"
		)%altKey


def makeSpatialCondition(constraint, sqlPars, additionalTables):
	"""builds a spatial constraint.

	The operand is parsed with the same parser used for SCS positions
	(scs.parseHumanSpoint); a resource matches if the resulting point lies
	within the coverage column of rr.stc_spatial.
	"""
	from gavo.protocols import scs
	raDeg, decDeg = scs.parseHumanSpoint(constraint[2])
	additionalTables.add("rr.stc_spatial")

	pointKey = sqlmunge.getSQLKey(
		"point", pgsphere.SPoint.fromDegrees(raDeg, decDeg), sqlPars)
	return "%%(%s)s <@ coverage"%pointKey


def makeTemporalCondition(constraint, sqlPars, additionalTables):
	"""builds a temporal constraint.

	Numeric operands strictly between 1000 and 2100 are taken as Julian
	years, other numbers as MJDs; anything that is not a number is parsed
	as an ISO timestamp.  A resource matches if the resulting MJD lies
	between time_start and time_end of rr.stc_temporal.
	"""
	literal = constraint[2]
	try:
		mjd = float(literal)
		if 1000<mjd<2100:
			# a plausible Julian year rather than an MJD
			mjd = api.dateTimeToMJD(api.jYearToDateTime(mjd))
	except ValueError:
		mjd = api.dateTimeToMJD(api.parseISODT(literal))

	additionalTables.add("rr.stc_temporal")
	timeKey = sqlmunge.getSQLKey("time", mjd, sqlPars)
	return "%%(%s)s between time_start and time_end"%timeKey


def makeSpectralCondition(constraint, sqlPars, additionalTables):
	"""builds a spectral constraint.

	Here, the operand is a (value, unit) pair, which is why the template
	has two operandN inputs for this constraint type.  The value is
	converted to "J" using the expression template returned by
	base.getSpecExpr, and a resource matches if the result lies between
	spectral_start and spectral_end of rr.stc_spectral.  An empty value
	does not constrain anything (None is returned).
	"""
	value, unit = constraint[2]
	if not value:
		return None

	# float() is essential: it guarantees only a number is formatted into
	# the expression we eval; without it, this would allow code injection.
	# NOTE(review): unit comes from the user as well, so this also relies on
	# getSpecExpr rejecting units it does not know -- confirm.
	conversion = base.getSpecExpr(unit, "J")
	converted = eval(conversion.format(float(value)))

	additionalTables.add("rr.stc_spectral")
	specKey = sqlmunge.getSQLKey("spectral", converted, sqlPars)
	return "%%(%s)s between spectral_start and spectral_end"%specKey


# Maps the constraint types sent by the web form (in fieldN parameters) to
# functions building SQL conditions.  Each function receives the
# (consType, operator, operand) triple, the dictionary of SQL parameters,
# and the set of additional tables to join; it returns a condition string,
# or something false if nothing is to be constrained.
_CONSTRAINTDEF = dict(
	capid=makeStringLikeConditionMaker("standard_id"),
	title=makeStringContainsConditionMaker("res_title"),
	textfields=makeTextCondition,
	colucd=makeUCDCondition,
	subject=makeStringContainsConditionMaker(
		"res_subject", extraTable="rr.res_subject"),
	description=makeStringContainsConditionMaker("res_description"),
	coldesc=makeStringContainsConditionMaker(
		"column_description", extraTable="rr.table_column"),
	creator=makeStringContainsConditionMaker("creator_seq"),
	waveband=makeHashContainsConditionMaker("waveband"),
	updated=makeTimestampConditionMaker("updated"),
	resdetail=makeDetailsCondition,
	ivoid=makeStringHasConditionMaker("ivoid", normalize=True),
	relation=makeRelationshipCondition,
	alt_relation=makeAltRelationshipCondition,
	accurl=makeStringHasConditionMaker("access_url"),
	restype=makeStringIsConditionMaker("res_type"),
	uat=makeStringIsConditionMaker(
		"uat_concept", extraTable="rr.subject_uat"),
	spatial=makeSpatialCondition,
	temporal=makeTemporalCondition,
	spectral=makeSpectralCondition,
)


def _listOrItem(arg):
	"""helps resortArguments to pass on list-valued operands if required.
	"""
	return arg[0] if len(arg)==1 else arg


def resortArguments(queryArgs):
	"""returns a list of (consType, operator, operand) constraints from
	"sequenced" request arguments (see above).
	"""
	kwPat = re.compile("field([0-9]+)$")
	constraints = []

	for key in queryArgs:
		mat = kwPat.match(key)
		if mat:
			index = mat.group(1)
			try:
				constraints.append((
					queryArgs["field"+index][0],
					queryArgs["operator"+index][0],
					_listOrItem(queryArgs["operand"+index])))
			except (KeyError, IndexError):
				pass

	return constraints


class _JSONQuery(resource.Resource):
	"""A resource returning Json for the database query given in the
	query class attribute (and potentially some arguments).

	Subclasses implement _runQuery(args), which receives the request's
	string arguments and returns JSON-serialisable content.  The response
	is an object with status "ok" and that content in "content" (plus
	whatever _runQuery put into self.moreData), or an object with status
	"error" and a message.

	The query runs in a thread from twisted's thread pool.  The response is
	written back from the reactor thread, as twisted request objects must
	not be used from other threads.

	TODO: we should do some more sensible error handling.
	"""
	def _realRender(self, request):
		"""runs the query for request and returns the serialised JSON
		response as bytes.

		This runs in a worker thread and therefore only reads the request's
		arguments; errors are turned into error responses.
		"""
		self.moreData = {}
		try:
			res = {"status": "ok", "content": self._runQuery(request.strargs)}
			res.update(self.moreData)
			# serialise within the try so unserialisable results are
			# reported like any other error.
			return json.dumps(res).encode("utf-8")

		except base.QueryCanceledError:
			return json.dumps({"status": "error", "message":
				"Oops: this query runs so long that it's clear we got it"
				" wrong.  Please complain to the operators, sending along"
				" the URL of this page.  Thanks!"}).encode("utf-8")

		except Exception as msg:
			api.ui.notifyError(f"RR query error: {msg}")
			return json.dumps({"status": "error", "message": str(msg)}
				).encode("utf-8")

	def _deliver(self, payload, request):
		"""writes payload to request and finishes it.

		This is a callback on the deferred from deferToThread and hence
		runs in the reactor thread.
		"""
		request.write(payload)
		request.finish()

	def render(self, request):
		# success and error responses are both JSON
		request.setHeader("content-type", "text/json;charset=utf-8")
		threads.deferToThread(self._realRender, request
			).addCallback(self._deliver, request)
		return server.NOT_DONE_YET


class ParameterQuery(_JSONQuery):
	"""the central query builder.

	This exposes json results for a set of constraints defined via
	_CONSTRAINTDEF.  On success, the response has status "ok" and, in
	content, a list of objects with the fields selected in _runQuery, one
	per matching resource, ordered by title.

	There's a little UI hack in here: when offset is 0 (which from the
	JS means the initial query), this will also return the total number
	of matches (at that moment) in a top-level n_total key.  In addition,
	the SQL of the query (without aggregates and paging) is returned in a
	top-level basicQuery key.
	"""

	def _runQuery(self, args):
		"""returns a list of dicts for the resources matching the
		constraints in args.

		Raises ReportableError for unknown constraint types and
		ValidationError if no usable constraint is left.
		"""
		constraints = resortArguments(args)
		sqlPars, fragments, additionalTables = {}, [], set()
		for constraint in constraints:
			if constraint[0] not in _CONSTRAINTDEF:
				raise api.ReportableError(
					f"Unknown constraint type '{constraint[0]}'")
			fragments.append(
				_CONSTRAINTDEF[constraint[0]](
					constraint, sqlPars, additionalTables))

		# condition makers may return nothing (e.g., for an empty spectral
		# value); such constraints do not restrict the result.
		fragments = [f for f in fragments if f]

		if not fragments:
			raise api.ValidationError("No conditions", "query")

		limit, offset = self._getQueryMeta(args)

		# Sorting the extra tables makes a given set of constraints always
		# produce the same SQL (set iteration order depends on per-process
		# string hashing).
		joinExpr = ""
		for t in sorted(additionalTables):
			if "res_detail" in t:
				# NOTE(review): no condition maker in _CONSTRAINTDEF adds
				# rr.res_detail (makeDetailsCondition uses EXISTS), so this
				# branch seems unused.
				joinExpr += "\n JOIN %s AS c ON (b.ivoid=c.ivoid)" % (t)
			else:
				joinExpr += "\n NATURAL LEFT OUTER JOIN %s" % (t)

		infoColumns = (
			" a.ivoid, a.res_title, a.res_description, a.source_value,"
			" a.reference_url")
		queryCore = (
			"\n  FROM ("
			"\n    SELECT array_agg(relationship_type) AS relations, res.*"
			"\n       FROM rr.resource AS res"
			"\n       NATURAL LEFT OUTER JOIN rr.relationship"
			"\n       GROUP BY res.ivoid) AS a"
			"\n  NATURAL LEFT OUTER JOIN rr.capability AS b"
			"\n  NATURAL LEFT OUTER JOIN rr.interface"
			"\n  %(joinExpr)s"
			"\n  WHERE "
			"\n  %(conditions)s ")%{
				"conditions": sqlmunge.joinOperatorExpr("\n  AND", fragments),
				"joinExpr": joinExpr}

		self.moreData["basicQuery"] = f"SELECT {infoColumns}\n"+queryCore

		# a.ivoid breaks ties between equal titles; without a total order,
		# paging through the results with OFFSET could skip or repeat
		# resources.
		query = ("SELECT %(infoColumns)s, array_agg(standard_id) AS capids,"
			"  array_agg(access_url) AS urls, array_agg(intf_type) AS intftype,"
			"  to_char(updated, 'YYYY-MM-DD') AS lastupdate,"
			"  relations"
			"  %(queryCore)s"
			"  GROUP BY %(infoColumns)s, a.updated, a.relations"
			"  ORDER BY a.res_title, a.ivoid"
			"  LIMIT %(limit)d"
			"  OFFSET %(offset)d")%{
				"infoColumns": infoColumns,
				"queryCore": queryCore,
				"limit": limit,
				"offset": offset}

		with api.getTableConn() as conn:
			if offset==0:
				self.moreData["n_total"] = list(conn.query(
					"SELECT count(*)"
					" FROM (SELECT 1 "
					+queryCore
					+" GROUP BY a.ivoid) AS q",
					sqlPars, timeout=QUERY_TIMEOUT))[0][0]
			return list(conn.queryToDicts(query, sqlPars, timeout=QUERY_TIMEOUT))

	def _getQueryMeta(self, args):
		"""returns limit and offset (or good defaults) from args.

		Both are non-negative integers: limit comes from MAXREC (default
		300), offset from OFFSET (default 0).  Missing, non-integer, or
		negative values are replaced by the defaults, as the database
		rejects negative LIMIT and OFFSET values.
		"""
		def getNonNegativeInt(key, default):
			try:
				value = int(args[key][0])
			except (ValueError, KeyError, IndexError):
				# any botched input: Keep default
				return default
			return value if value>=0 else default

		return getNonNegativeInt("MAXREC", 300), getNonNegativeInt("OFFSET", 0)


class CannedJSONQuery(_JSONQuery):
	"""a JSON resource running the fixed SQL in the query class attribute.

	Subclasses define query; the first value of each request parameter is
	passed as the query parameters (presumably referenced as %(name)s, as
	in the other queries in this module).
	"""
	def _runQuery(self, args):
		firstValues = {name: values[0] for name, values in args.items()}
		with api.getTableConn() as conn:
			return list(conn.queryToDicts(self.query, firstValues))


class ResponsibilityQuery(_JSONQuery):
	"""a JSON resource listing the rr.res_role rows of a resource.

	The resource is selected by the ivoid request parameter; each row gives
	the base role (passed through initcap), role name, email, and
	telephone.
	"""
	def _runQuery(self, args):
		params = {name: values[0] for name, values in args.items()}
		with api.getTableConn() as conn:
			return list(conn.queryToDicts(
				"SELECT initcap(base_role) AS baserole,"
				"    role_name, email, telephone"
				"  FROM rr.res_role"
				"  WHERE ivoid=%(ivoid)s",
				params))


class MainPage(api.ServiceBasedPage):
	"""the renderer's page, handing out the JSON resources by path segment.

	parameterQuery and responsibilityQuery get fresh ParameterQuery and
	ResponsibilityQuery instances, respectively; anything else is a
	NoResource.
	"""
	checkedRenderer = False

	def getChild(self, name, request):
		for segment, childClass in (
				(b'parameterQuery', ParameterQuery),
				(b'responsibilityQuery', ResponsibilityQuery)):
			if name==segment:
				return childClass()
		return resource.NoResource()
