# Script to crop empty data from the borders of ppakm31 images
# (actually, it ought to be relatively generic, modulo the non-data detection).
#
# To use it, run it on a set of images.  It will then output a command
# that actually does the cropping.
#
# In the current context, if you really need to re-crop the original
# images, run
#
# python3 bin/cropfits.py data/PPAK_M31_F1_*
#
# see if the output makes sense and then run
#
# python3 bin/cropfits.py data/PPAK_M31_F1_* | sh
#
# Repeat for F2-F5.


import sys

import numpy

from gavo import api


def cropOne(fName, cropTo):
	"""crops the two-HDU FITS file fName in place.

	cropTo is a sequence of cut specs that is passed unchanged to
	api.cutoutFITS for both HDUs; main() builds it from -r as
	[[1, lower, upper], [2, lower, upper]].

	In the cropped primary HDU, values >1e-13 are additionally set to NaN,
	DATAMIN/DATAMAX are set from the remaining finite values, and a
	history card is added.  The result replaces fName via api.safeReplaced.

	Since DATAMAX is written here, its presence in the input is taken to
	mean the file has already been processed, and an Exception is raised.
	An Exception is also raised if the file does not have exactly two HDUs,
	and a ValueError if no finite values are left after cropping.
	"""
	hdus = api.pyfits.open(fName)
	try:
		if len(hdus)!=2:
			raise Exception("%s: expected 2 HDUs, found %d."%(
				fName, len(hdus)))
		hdr = hdus[0].header
		# EXTEND is a FITS logical; assign True rather than the string 'T'.
		hdr["EXTEND"] = True
		if "DATAMAX" in hdr:
			raise Exception("Looks like I've already run here.  Bailing out.")

		newHDUs = api.pyfits.HDUList([
			api.cutoutFITS(hdus[0], *cropTo),
			api.cutoutFITS(hdus[1], *cropTo)])
	finally:
		# close the input even when bailing out above
		hdus.close()

	primary = newHDUs[0]

	# while we're dealing with all the data, also NaN out pixels that
	# obviously make no sense (comparing existing NaNs is expected here,
	# so silence numpy's "invalid value" warning).
	with numpy.errstate(invalid="ignore"):
		primary.data[primary.data>1e-13] = numpy.nan

	# boolean indexing already yields a flat array
	vals = primary.data[numpy.isfinite(primary.data)]
	if not vals.size:
		raise ValueError("%s: no finite values left after cropping."%fName)
	# ndarray.min/max run in C, unlike the builtins on a large array.
	primary.header["DATAMIN"] = vals.min()
	primary.header["DATAMAX"] = vals.max()

	# Add the history to the header that is actually written out; the
	# input header (hdr) may be a different object than the cutout's.
	api.addHistoryCard(primary.header,
		"ppakm31: cropped all-NaN lines, NaNed fluxes >1e-13.",
		"all-NaN")

	newHDUs.update_extend()
	with api.safeReplaced(fName) as destF:
		newHDUs.writeto(destF)


def cropdetect(imgs):
	"""prints a shell command cropping all FITSes in imgs to a common rectangle.

	For each file, the primary HDU's data is scanned for rows and columns
	containing at least one non-NaN pixel; for data with more than two
	axes (e.g., cubes), a spatial pixel only counts as empty if it is NaN
	along all the extra axes.  The per-file extents are merged (smallest
	lower, largest upper limit) so every file is cropped identically.

	The printed command re-runs this script with
	-r 1-<ax1lower>-<ax1upper>+2-<ax2lower>-<ax2upper>, where lower is the
	0-based index of the first non-empty column (axis 1) or row (axis 2)
	and upper is the 0-based index of the last one plus 2.  The script name
	and file names are shell-quoted, since the output is meant to be piped
	into sh.

	NOTE(review): if cutoutFITS takes 1-based inclusive limits, these
	values keep a one-pixel margin on each side -- confirm that's intended.

	Returns the crop rectangle as [[1, ax1lower, ax1upper],
	[2, ax2lower, ax2upper]], the structure main() builds from -r.

	Files that are entirely NaN are ignored; ValueError is raised if no
	file contains any non-NaN data or a file does not have two HDUs.
	"""
	import shlex

	extents = []
	for fName in imgs:
		with api.pyfits.open(fName) as hdus:
			if len(hdus)!=2:
				raise ValueError("%s: expected 2 HDUs, found %d."%(
					fName, len(hdus)))
			# compute while the file is open; the result is a fresh array
			isEmpty = numpy.isnan(hdus[0].data)

		# Collapse all non-spatial axes so that a pixel is only "empty" if
		# it is NaN in every plane.  Going by ndim rather than NAXIS3>1
		# also handles degenerate NAXIS3=1 cubes.
		while isEmpty.ndim>2:
			isEmpty = isEmpty.all(axis=0)

		hasData = ~isEmpty
		# numpy axis 0 is FITS axis 2 (rows), numpy axis 1 is FITS axis 1.
		rows = numpy.flatnonzero(hasData.any(axis=1))
		cols = numpy.flatnonzero(hasData.any(axis=0))
		if not rows.size:
			# an all-NaN image says nothing about where the data is
			continue
		extents.append((cols[0], cols[-1]+2, rows[0], rows[-1]+2))

	if not extents:
		raise ValueError("No non-NaN data in any of the images.")

	ax1lows, ax1highs, ax2lows, ax2highs = zip(*extents)
	cropRect = [
		[1, int(min(ax1lows)), int(max(ax1highs))],
		[2, int(min(ax2lows)), int(max(ax2highs))]]

	print("python3 {} -r 1-{}-{}+2-{}-{} {}".format(
		shlex.quote(sys.argv[0]),
		cropRect[0][1], cropRect[0][2],
		cropRect[1][1], cropRect[1][2],
		" ".join(shlex.quote(name) for name in imgs)))
	return cropRect


def parse_command_line(argv=None):
	"""returns the parsed command line of this script.

	argv is the list of arguments to parse; the default, None, makes
	argparse fall back to sys.argv[1:].
	"""
	import argparse
	p = argparse.ArgumentParser(
		description="Crop all-NaN borders from FITS images.  Without -r,"
			" print a command line that crops all images to a common"
			" rectangle; with -r, crop the images in place.")
	p.add_argument("-r", dest="crop_rect", action="store", type=str,
		default=None,
		help="Crop rectangle as 1-<lower>-<upper>+2-<lower>-<upper>,"
			" as printed by a run without -r.")
	p.add_argument("images", type=str, nargs='+',
		help="FITS files to inspect or crop.")
	return p.parse_args(argv)


def main():
	"""runs the script: crops the images in place if -r is given, and
	otherwise lets cropdetect print the command that would crop them.
	"""
	args = parse_command_line()

	if not args.crop_rect:
		cropdetect(args.images)
		return

	# "1-l1-u1+2-l2-u2" -> [[1, l1, u1], [2, l2, u2]]
	crop_rect = [
		list(map(int, axis_spec.split("-")))
		for axis_spec in args.crop_rect.split("+")]
	for image_name in args.images:
		cropOne(image_name, crop_rect)


if __name__ == "__main__":
	# run only when executed as a script, not when imported
	main()
