"""
Make plots for paper.  The input files were written by S. Roeser, basically
as

ra dec <thing>

Where thing was:

nucac3.gz -- number of objects with valid PM in the 0.25x0.25 in UCAC3
muedeltaucac.gz -- pm delta in UCAC3
muealphacac.gz -- pm alpha in UCAC3
muealpps1fig1.gz -- pm alpha in intermediate system (cannot be recovered here)
muedelpps1fig1.gz -- pm delta in intermediate system (cannot be recovered here)
muealfinalfig1.gz -- pm alpha in ppmxl
muedelfinalfig1.gz -- pm delta in ppmxl
muedeltfaintfig8.gz -- pm delta in ppmxl, as for fig. 8
"""


from __future__ import with_statement

import gzip
import os
import sys

import matplotlib
from scipy import ndimage
from matplotlib import cm
from matplotlib import pyplot
import numpy
from numpy import ma


# Figure definitions, one tuple per plot:
#   (dataName, maskName, limits, outputName)
# maskName None means no masking; outputName None means dataName with
# its extension replaced by ".eps" (see makeFromFigDef).
FIGURES = [
	("muedeltaucac.gz",     "nucac3.gz",     (-1200, 1200), "muedeltaucac.eps"),
	("muealphaucac.gz",     "nucac3.gz",     (-1200, 1200), "muealphaucac.eps"),
	("muealpps1fig1.gz",    None,            (-1200, 1200), None),
	("muedelps1fig2.gz",    None,            (-1200, 1200), None),
	("muealfinalfig5.gz",   None,            (-1200, 1200), None),
	("muedelfinalfig6.gz",  None,            (-1200, 1200), None),
	("muedeltfaintfig8.gz", "nfaintfig8.gz", (-1200, 1200), None),
]


def isNewer(name1, name2):
	"""Tell whether file name1 exists and was modified after file name2.

	A missing name1 gives False without looking at name2; otherwise
	name2 must exist, since os.path.getmtime raises for missing files.
	"""
	return (os.path.exists(name1)
		and os.path.getmtime(name1) > os.path.getmtime(name2))


def readFromName(fName, height=720):
	"""Return the pixel image stored in the text file fName as a 2D array.

	Only the last column of each row is used.  The values are reshaped
	column-major (Fortran order) into an array with height rows.

	The result is cached as fName+".npz" and the cache is used whenever
	it is newer than fName.  Note that the cache does not record height.

	Raises ValueError if the number of values is not a multiple of
	height.
	"""
	cacheName = fName+".npz"
	if isNewer(cacheName, fName):
		# NpzFile keeps the archive open; close it once the array is read.
		with numpy.load(cacheName) as cached:
			return cached["arr_0"]
	# ndmin=2 keeps arr two-dimensional even for single-column input,
	# so arr[:,-1] always works.
	arr = numpy.loadtxt(fName, ndmin=2)
	pixels = arr[:,-1]
	if len(pixels)%height:
		raise ValueError("%s: %d values cannot be arranged in %d rows"%(
			fName, len(pixels), height))
	# Integer division: reshape needs an int (plain / gives a float
	# under Python 3).
	width = len(pixels)//height
	pixels = pixels.reshape((height, width), order='F')
	# Save to cacheName explicitly rather than relying on savez appending
	# ".npz" to fName.
	numpy.savez(cacheName, pixels)
	return pixels


def makeArray(dataName, maskName=None, limits=None, minCount=3):
	"""Return the pixel values from dataName, optionally clipped and masked.

	(The actual loading and caching happens in readFromName.)

	limits is None or a (lower, upper) two-tuple.  Values outside it are
	clipped to lower/upper, *not* masked (Roeser wanted it this way).

	If maskName is given, it is read the same way.  Pixels whose value
	there is below minCount are masked.  The good-pixel map is cleaned up
	with a binary opening followed by a binary closing before it is
	applied.  In that case a numpy.ma.MaskedArray is returned; otherwise
	the plain array is returned.
	"""
	pixels = readFromName(dataName)
	if limits:
		lower, upper = limits
		# In-place clip; NaNs are left as they are.
		numpy.clip(pixels, lower, upper, out=pixels)
	if maskName:
		counts = readFromName(maskName)
		# True for good pixels.  Written as "not below minCount" rather than
		# ">= minCount" so that NaN counts still count as good.
		valid = ~(counts<minCount)
		valid = ndimage.binary_opening(valid)
		valid = ndimage.binary_closing(valid)
		# Mask everything that did not survive as a good pixel.
		pixels = ma.masked_array(pixels, mask=~valid)
	return pixels


def annotateSkyPlot(ax):
	"""Put RA/Dec ticks, axis labels and a pixel-size note on the axes ax.

	Ticks are set every 60 deg on the x axis (0..360) and every 45 deg
	on the y axis (-90..90).  The note is placed in axes coordinates just
	below the lower left corner.
	"""
	ax.xaxis.set_major_locator(pyplot.FixedLocator(numpy.arange(0, 361, 60)))
	ax.set_xlabel('Right ascension [deg]')
	ax.yaxis.set_major_locator(pyplot.FixedLocator(numpy.arange(-90, 91, 45)))
	ax.set_ylabel('Declination [deg]')
	# Raw string: "\," in a normal literal is an invalid escape sequence
	# (warned about by recent Pythons); the resulting text is unchanged.
	ax.text(0, -0.15, r"One pixel is $0.25^\circ\times\,0.25^\circ$", ha="left",
		transform=ax.transAxes, size=8)



def makeSkyPlot(dataName, maskName, limits=None):
	"""Build and return a figure showing one sky map.

	The image data is makeArray(dataName, maskName, limits).  It is drawn
	with origin at the lower left over x 0..360 and y -90..90, with a
	horizontal colour bar and the axis annotation from annotateSkyPlot.
	"""
	# Load the data first, so no figure is created if loading fails.
	skyMap = makeArray(dataName, maskName, limits)
	figure = pyplot.figure()
	axes = figure.add_axes([0.1, 0.1, 0.8, 0.9])
	image = axes.imshow(
		skyMap,
		origin="lower",
		extent=(0, 360, -90, 90),
		cmap=cm.jet_r)
	colorBar = figure.colorbar(image, orientation='horizontal', aspect=25)
	colorBar.set_label("Proper motion [mas/100yr]")
	annotateSkyPlot(axes)
	return figure


def makeFromFigDef(figDef, saveFig):
	"""Render one FIGURES entry and return the figure.

	figDef is a (dataName, maskName, limits, destName) tuple.  If saveFig
	is true, the figure is written as EPS to destName.  A destName of
	None means dataName with its extension replaced by ".eps".
	"""
	dataName, maskName, limits, destName = figDef
	fig = makeSkyPlot(dataName, maskName, limits)
	if saveFig:
		if destName is None:
			destName = os.path.splitext(dataName)[0]+".eps"
		fig.savefig(destName, orientation="portrait", format="eps")
	return fig


def main(argv=None):
	"""Command line entry point.

	argv defaults to sys.argv.  Without arguments, every entry of FIGURES
	is rendered and saved as EPS.  With an argument, FIGURES[int(argv[1])]
	is rendered and shown interactively instead of being saved.  An
	argument that is not a valid index exits with a usage message.
	"""
	if argv is None:
		argv = sys.argv
	pyplot.rc("font", size=9)
	if len(argv)==1:
		for figDef in FIGURES:
			makeFromFigDef(figDef, True)
		return
	try:
		figDef = FIGURES[int(argv[1])]
	except (ValueError, IndexError):
		sys.exit("Usage: %s [figure index, 0..%d]"%(argv[0], len(FIGURES)-1))
	makeFromFigDef(figDef, False)
	pyplot.show()


#  l1 = matplotlib.lines.Line2D([0, 1], [0, 1], transform=fig.transFigure, figure=fig)
# fig.axes
# ax.set_image_extent


if __name__=="__main__":
	main()
