#    SP_Ace derives stellar parameters, such as gravity, temperature, and element
#    abundances from optical stellar spectra, assuming Local Thermodynamic
#    Equilibrium (LTE) and 1D stellar atmosphere models.
#
#    Copyright (C) 2016 Corrado Boeche
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

from tkinter import *
import tkinter, tkinter.filedialog
import tkinter.messagebox
from tkinter import ttk
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from multiprocessing import Pool
import pdb

#the user can edit these two lines, or override them without editing the file
#through the SPACE_GCOG_LIB and SPACE_EXECUTABLE environment variables.
# Default GCOG library directory (initial value of the "GCOG library" field).
address_GCOG_lib=os.environ.get('SPACE_GCOG_LIB', '/home/corrado/workTux2/EW_library_SPACE2.2/libGCOG/')
# SP_Ace command run through os.system by func_parallel; also the initial
# value of the "SP_Ace executable" field.
space_executable=os.environ.get('SPACE_EXECUTABLE', 'SPACE_v1.4')

##################################

def chose_file(*args, entry=None):
    """Let the user pick a file and write its path into an entry widget.

    entry: the ttk.Entry to fill.  When omitted, the entry of the llist_rej
        widget is used; it is the only "select" button wired to this callback
        in this file.
    Cancelling the dialog leaves the entry unchanged.
    """
    if entry is None:
        entry = llist_rej_select_file_check_button.file_entry
    # work_dir is a StringVar: pass its value, not the variable object
    path = tkinter.filedialog.askopenfilename(initialdir = work_dir.get(),title = "Select file",filetypes = (("ASCII files","*.asc"),("all files","*.*")))
    if not path:
        return
    entry.delete(0,END)
    entry.insert(0,path)

def chose_exe(*args):
    """Let the user pick the SP_Ace executable and show its path in space_entry.

    The path is kept in a local variable: it must not be stored in
    root.filename, which holds the selected spectrum.  Cancelling the dialog
    leaves the entry unchanged.
    """
    path = tkinter.filedialog.askopenfilename(initialdir = os.path.expanduser("~"),title = "Select file")
    if not path:
        return
    space_entry.delete(0,END)
    space_entry.insert(0,path)

def chose_dir(*args):
    """Let the user pick the GCOG library directory and show it in dir_entry.

    Starts in the user's home directory.  Cancelling the dialog leaves the
    entry (and root.dirname) unchanged.
    """
    # NOTE(review): the default address_GCOG_lib ends with '/', but askdirectory
    # returns a path without a trailing separator - confirm SP_Ace accepts both.
    path = tkinter.filedialog.askdirectory(initialdir = os.path.expanduser("~"),title = "Select directory")
    if not path:
        return
    root.dirname = path
    dir_entry.delete(0,END)
    dir_entry.insert(0,path)

def chose_work_dir(*args):
    """Let the user pick the work directory, show it and make it the cwd.

    SP_Ace parameter and result files are written relative to the current
    directory, hence the os.chdir.  Cancelling the dialog changes nothing
    (previously it called os.chdir('') and raised FileNotFoundError).
    """
    # work_dir is a StringVar: pass its value, not the variable object
    path = tkinter.filedialog.askdirectory(initialdir = work_dir.get(),title = "Select directory")
    if not path:
        return
    root.work_dir = path
    work_entry.delete(0,END)
    work_entry.insert(0,path)
    #move to the work directory
    os.chdir(path)

def warn_win(text):
    """Show *text* in a modal information dialog.

    Uses the fully qualified tkinter.messagebox: ``from tkinter import *``
    does not bind the ``messagebox`` submodule as a bare name.
    """
    tkinter.messagebox.showinfo(message=text)
###################################################
def collect_results(params_file, params_list_dict):
    """Read and delete the SP_Ace results file that belongs to one parameter file.

    params_file: name of the '.par' file SP_Ace was run on.  The results are
        read from the same name with '.par' removed plus '_TGM_ABD.dat'.
    params_list_dict: keyword dict of that run; only 'obs_sp_file' is used.

    Returns (filename_line, results_header):
      filename_line  -- the spectrum path followed by the second line of the
                        results file, or by '  None' and a newline when the
                        file could not be read or removed;
      results_header -- the first line of the results file, or a bare newline
                        when nothing could be read.
    """
    params_file_root = params_file.replace('.par','')
    address_spec = params_list_dict['obs_sp_file']
    results_file = params_file_root + '_TGM_ABD.dat'

    error_line = '  None\n'
    # placeholder so a failed run still yields a header value
    results_header = '\n'
    try:
        with open(results_file, "r") as file:
            results_header = file.readline()
            line = file.readline()
        filename_line = address_spec + line
        os.remove(results_file)
    except (OSError, ValueError):
        # OSError: missing/unreadable/undeletable file;
        # ValueError covers UnicodeDecodeError on undecodable content
        filename_line = address_spec + error_line

    return filename_line, results_header
#######################################
def save_results():
    """Ask for a destination file and write the results of the last run to it.

    Writes a 'file_name' header line followed by one line per spectrum, taken
    from the module globals results_header / results_collection set by
    run_space.  If no run has completed yet, the user is told so instead of
    raising NameError.  Cancelling the dialog writes nothing.
    """
    try:
        header, rows = results_header, results_collection
    except NameError:
        tkinter.messagebox.showinfo(message='No results to save yet: run SP_Ace first.')
        return

    # work_dir is a StringVar: pass its value, not the variable object
    file = tkinter.filedialog.asksaveasfile(initialdir = work_dir.get(),title = "Save results as...",mode='w')
    if file is None:
        return
    with file:
        file.write('file_name    ' + header)
        file.writelines(rows)

############################
class entry_check:
    """Check box paired with a text entry that is editable only while checked.

    Arguments: testo -- check box label; colonna, riga -- grid column and row
    of the check box (the entry goes in the next column); yshift -- vertical
    grid padding; lf -- parent widget; var_value -- initial entry text.

    Public attributes: ``flag`` (IntVar, nonzero when checked) and ``variab``
    (StringVar holding the entry text).
    """
    def __init__(self,testo,colonna,riga,yshift,lf,var_value):
        self.flag = IntVar(value=0)
        self.variab = StringVar(value=var_value)

        # grid options shared by the check box and its entry
        cell = {'row': riga, 'pady': yshift, 'sticky': (W)}

        self.check = ttk.Checkbutton(lf, text=testo, command=self.naccheck, var=self.flag)
        self.check.grid(column=colonna, **cell)

        self.entry_field = ttk.Entry(lf, width=10, textvariable=self.variab, state=DISABLED)
        self.entry_field.grid(column=colonna + 1, padx=5, **cell)

    def naccheck(self):
        """Enable and focus the entry when checked; disable it otherwise."""
        checked = self.flag.get() != 0
        self.entry_field.configure(state=NORMAL if checked else DISABLED)
        if checked:
            self.entry_field.focus()
############################
class spinbox_check:
    """Check box paired with a Spinbox of null-value choices, placed in lf2.

    The spinbox (offering 'null', '-9.99', 'NaN') is editable only while the
    box is checked.  Public attributes: ``flag`` (IntVar, nonzero when
    checked) and ``variab`` (StringVar holding the spinbox value).
    """
    def __init__(self,testo,colonna,riga,yshift,var_value):
        self.flag = IntVar(value=0)
        self.variab = StringVar(value=var_value)

        # grid options shared by the check box and its spinbox
        cell = {'row': riga, 'pady': yshift, 'sticky': (W)}

        self.check = ttk.Checkbutton(lf2, text=testo, command=self.naccheck, var=self.flag)
        self.check.grid(column=colonna, **cell)

        choices = ['null', '-9.99', 'NaN']
        self.entry_field = Spinbox(lf2, values=choices, width=6, textvariable=self.variab, state=DISABLED)
        self.entry_field.grid(column=colonna + 1, padx=5, **cell)

    def naccheck(self):
        """Enable and focus the spinbox when checked; disable it otherwise."""
        checked = self.flag.get() != 0
        self.entry_field.configure(state=NORMAL if checked else DISABLED)
        if checked:
            self.entry_field.focus()
############################
class select_file_check:
    """Check box enabling a file-path entry and a "select" button.

    Arguments: testo -- check box label; colonna, riga -- grid column and row
    of the check box (entry and button follow in the next two columns);
    yshift -- vertical grid padding; var_value -- initial entry text;
    lf -- parent frame, defaults to lf2 (the optional-keywords frame).

    Public attributes: ``flag`` (IntVar, nonzero when checked) and ``variab``
    (StringVar holding the file path).
    """
    def __init__(self,testo,colonna,riga,yshift,var_value,lf=None):
        parent = lf2 if lf is None else lf
        self.flag=IntVar(value=0)
        self.variab=StringVar(value=var_value)
        #
        self.check = ttk.Checkbutton(parent, text=testo, command=self.naccheck,var=self.flag)

        self.check.grid(column=colonna,row=riga, pady=yshift,sticky=(W))
        #
        self.file_entry = ttk.Entry(parent, width=6, textvariable=self.variab,state=DISABLED)
        self.file_entry.grid(column=colonna+1, row=riga, sticky=(EW), padx=5, pady=yshift)
        #
        # the button fills this instance's own entry
        self.select_button=ttk.Button(parent, text="select", command=self.select_file,state=DISABLED)
        self.select_button.grid(column=colonna+2, row=riga, sticky=(E), padx=5, pady=yshift)

    def naccheck(self):
        """Enable entry and button (focusing the entry) when checked; disable them otherwise."""
        if(self.flag.get()==0):
            self.file_entry.configure(state=DISABLED)
            self.select_button.configure(state=DISABLED)
        else:
            self.file_entry.configure(state=NORMAL)
            self.select_button.configure(state=NORMAL)
            self.file_entry.focus()

    def select_file(self, *args):
        """Open a file dialog and put the chosen path into this entry.

        Cancelling the dialog leaves the entry unchanged.
        """
        # work_dir is a StringVar: pass its value, not the variable object
        path = tkinter.filedialog.askopenfilename(initialdir = work_dir.get(),title = "Select file",filetypes = (("ASCII files","*.asc"),("all files","*.*")))
        if not path:
            return
        self.file_entry.delete(0,END)
        self.file_entry.insert(0,path)
############################
class select_file_radiobutton:
    """Radio buttons choosing between one spectrum and a list of spectra (frame lf1).

    Row 1 holds the "spectrum" choice with its entry and "select" button;
    row 2 holds the "list of spectra" choice.  Only the entry and button of
    the selected choice are enabled.

    Public attributes:
      file_type -- StringVar: 'file' or 'list' once a choice is made, '' before;
      variab    -- StringVar with the spectrum path (row 1 entry);
      variab1   -- StringVar with the list-file path (row 2 entry).
    """

    def __init__(self):

        self.file_type=StringVar()
        self.spectrum_name=StringVar()
        self.list_name=StringVar()
        self.variab=StringVar()
        self.variab1=StringVar()
        #
        self.file_entry = ttk.Entry(lf1, width=20, textvariable=self.variab,state=DISABLED)
        self.file_entry.grid(column=2, row=1, sticky=(W, E), padx=5, pady=5)
        self.list_entry = ttk.Entry(lf1, width=20, textvariable=self.variab1,state=DISABLED)
        self.list_entry.grid(column=2, row=2, sticky=(W, E), padx=5, pady=5)

        #
        self.file_check = ttk.Radiobutton(lf1, text='spectrum', variable=self.file_type, value='file',command=self.naccheck)
        self.file_check.grid(column=1, row=1, sticky=(W))
        self.list_check = ttk.Radiobutton(lf1, text='list of spectra', variable=self.file_type, value='list',command=self.naccheck)
        self.list_check.grid(column=1, row=2, sticky=(W))
        #
        self.select_button_file=ttk.Button(lf1, text="select", command=self.file_choose,state=DISABLED)
        self.select_button_file.grid(column=3, row=1, sticky=(E), padx=5, pady=5)
        self.select_button_list=ttk.Button(lf1, text="select", command=self.list_choose,state=DISABLED)
        self.select_button_list.grid(column=3, row=2, sticky=(E), padx=5, pady=5)

    def naccheck(self):
        """Enable (and focus) the selected row's widgets and disable the other row."""
        mode = self.file_type.get()
        if mode == 'file':
            enable = (self.file_entry, self.select_button_file)
            disable = (self.list_entry, self.select_button_list)
            self.spectrum_name=self.variab
        elif mode == 'list':
            enable = (self.list_entry, self.select_button_list)
            disable = (self.file_entry, self.select_button_file)
            self.list_name=self.variab1
        else:
            return
        for widget in enable:
            widget.configure(state=NORMAL)
        enable[0].focus()
        for widget in disable:
            widget.configure(state=DISABLED)

    def file_choose(self):
        """Pick a spectrum file; store it in the row 1 entry and in root.filename.

        Cancelling the dialog leaves both unchanged.
        """
        # work_dir is a StringVar: pass its value, not the variable object
        path = tkinter.filedialog.askopenfilename(initialdir = work_dir.get(),title = "Select file",filetypes = (("ASCII files","*.asc"),("all files","*.*")))
        if not path:
            return
        # root.filename is kept up to date because other code may read it
        root.filename = path
        self.file_entry.delete(0,END)
        self.file_entry.insert(0,path)

    def list_choose(self):
        """Pick a list-of-spectra file; store it in the row 2 entry and in root.listname.

        Cancelling the dialog leaves both unchanged.
        """
        path = tkinter.filedialog.askopenfilename(initialdir = work_dir.get(),title = "Select list",filetypes = (("list files","*list*"),("all files","*.*")))
        if not path:
            return
        # root.listname is kept up to date because other code may read it
        root.listname = path
        self.list_entry.delete(0,END)
        self.list_entry.insert(0,path)

#############################
def prepare_lists_params(sp_list):
    """Build one SP_Ace keyword set per spectrum from the current GUI state.

    sp_list: spectrum paths (trailing whitespace/newlines are stripped).

    Returns a dict mapping a parameter-file name 'space_GUI_<i>.par' (i
    zero-padded to the number of digits of len(sp_list)) to an OrderedDict of
    SP_Ace keywords.  Value-less switches are stored with an empty string.
    """
    digits = len(str(len(sp_list)))

    # check boxes that add a keyword without a value
    switches = (
        (abdloop_flag, 'ABD_loop'),
        (Salaris_MH_flag, 'Salaris_MH'),
        (alpha_flag, 'alpha'),
        (error_flag, 'error_est'),
        (nonorm_flag, 'no_norm'),
    )
    # check-box widgets that add a keyword with the value of their field
    valued = (
        (sn_entry_check_button, 'sn_ratio'),
        (T_entry_check_button, 'T_force'),
        (G_entry_check_button, 'G_force'),
        (Nrad_entry_check_button, 'norm_rad'),
        (RV_entry_check_button, 'RV_ini'),
        (ele_entry_check_button, 'ele2write'),
        (null_spinbox_check_button, 'null_value'),
        (llist_rej_select_file_check_button, 'llist_rej'),
    )

    inputs = {}
    for index, spectrum in enumerate(sp_list):
        params = OrderedDict([
            ('obs_sp_file', spectrum.rstrip()),
            ('GCOGlib', dirname.get()),
            ('fwhm', fwhm.get()),
            ('wave_lims', wave.get()),
        ])
        for variable, keyword in switches:
            if variable.get():
                params[keyword] = ''
        for widget, keyword in valued:
            if widget.flag.get():
                params[keyword] = widget.variab.get()

        inputs['space_GUI_' + str(index).zfill(digits) + '.par'] = params

    return inputs
#############################
def _selected_spectra():
    """Return the spectrum paths chosen in the "Necessary keywords" frame.

    Reads the paths from the radio-button entries themselves.  root.filename
    and root.listname are only set when a file dialog is used, so a typed
    path would otherwise be ignored.  Returns an empty list, after telling
    the user why, when there is no usable selection.
    """
    mode = file_list_radiobutton.file_type.get()
    if mode == 'file':
        spectrum = file_list_radiobutton.variab.get().strip()
        if spectrum:
            return [spectrum]
        tkinter.messagebox.showinfo(message='Please select a spectrum.')
        return []
    if mode == 'list':
        list_path = file_list_radiobutton.variab1.get().strip()
        try:
            with open(list_path, "r") as list_file:
                # skip blank lines (e.g. a trailing empty line)
                spectra = [line for line in list_file if line.strip()]
        except OSError as err:
            tkinter.messagebox.showinfo(message='Cannot read the list of spectra: %s' % err)
            return []
        if not spectra:
            tkinter.messagebox.showinfo(message='The list of spectra is empty.')
        return spectra
    tkinter.messagebox.showinfo(message='Please choose either a spectrum or a list of spectra.')
    return []

def run_space(*args):
    """Run SP_Ace on the selected spectrum or list of spectra and keep the results.

    Accepts and ignores positional arguments so it works both as a button
    command (no arguments) and as a <Return> key binding (one Tk event).

    Sets the module globals results_header and results_collection, which
    save_results writes out.  In batch mode the spectra are processed by a
    multiprocessing Pool with the number of jobs given in the batch field.
    """
    global results_header
    global results_collection
    global space_executable

    sp_list = _selected_spectra()
    if not sp_list:
        return

    # Honour the executable chosen in the GUI.  func_parallel reads the module
    # global, and Pool workers are created after this assignment.
    space_executable = space_exe.get().strip() or space_executable

    if os.path.isfile('space.par'):
        os.remove('space.par')

    space_inputs_dict = prepare_lists_params(sp_list)

    if batch_check_button.flag.get():
        try:
            n_jobs = int(batch_check_button.variab.get())
        except ValueError:
            n_jobs = 0
        if n_jobs < 1:
            tkinter.messagebox.showinfo(message='The number of batch jobs must be a positive integer.')
            return
        with Pool(processes=n_jobs) as pool:
            pending = [pool.apply_async(func_parallel, args=(name_file, params_list_dict))
                       for name_file, params_list_dict in space_inputs_dict.items()]
            results = [job.get() for job in pending]
    else:
        results = [func_parallel(name_file, params_list_dict)
                   for name_file, params_list_dict in space_inputs_dict.items()]

    results_collection = [result[0] for result in results]
    # prefer a real header over the placeholder of a run that produced no results
    results_header = next((header for _, header in results if header.strip()), results[0][1])
################################
def func_parallel(space_params_file, params_list_dict):
    """Write one SP_Ace parameter file, run SP_Ace on it and collect the result.

    space_params_file: name of the '.par' file to write (relative to the cwd).
    params_list_dict: ordered keyword -> value mapping; an empty value writes
        the bare keyword, and values of path/string keywords are single-quoted.

    SP_Ace is run as ``<space_executable> <space_params_file>`` through
    os.system.  When it exits with status 0 and batch mode is off, the fit
    is plotted.  Returns [result_line, result_header] from collect_results.
    """
    # keywords whose values are written in single quotes
    keys_with_apex = ('obs_sp_file', 'GCOGlib', 'null_value', 'llist_rej')
    par_lines = []

    for key, value in params_list_dict.items():
        if not value:
            par_lines.append(key)
        elif key in keys_with_apex:
            par_lines.append(key + " '" + value + "'")
        else:
            par_lines.append(key + ' ' + value)

    with open(space_params_file, "w") as par_file:
        par_file.writelines(line + '\n' for line in par_lines)

    sys_msg = os.system(space_executable + ' ' + space_params_file)
    if sys_msg == 0 and not batch_check_button.flag.get():
        open_plot(space_params_file, params_list_dict['obs_sp_file'])

    result_line, result_header = collect_results(space_params_file, params_list_dict)

    return [result_line, result_header]
##############################################
def open_plot(space_params_file, spec_file_name):
    """Plot the SP_Ace fit stored next to a parameter file.

    Reads columns 0-6 of '<space_params_file without .par>_model.dat' as,
    in order: wavelength, input flux, SP_Ace-normalized flux, model flux,
    continuum, pixel weight and S/N (names as used below).  It draws three
    stacked panels:
      input spectrum with continuum;
      normalized spectrum with model, shading pixels with weight < 0.01;
      S/N.
    Blocks in plt.show() until the window is closed, then frees the figure.

    spec_file_name: text used as the plot title.
    """
    space_model_file = space_params_file.replace('.par','') + '_model.dat'


    wave,spec_input,spec_norm,spec_model,spec_cont,spec_weight,spec_sn=np.genfromtxt(space_model_file,usecols=[0,1,2,3,4,5,6],unpack=True)

    fig, axes = plt.subplots(ncols=1, nrows=3, sharex=True, sharey=False,
                figsize=(12,12), constrained_layout=False)
    fig.subplots_adjust(top=0.9,bottom=0.2,left=0.1,right=0.93,wspace=0.1,hspace=0.03)

    title_txt = spec_file_name
    axes[0].set_title(title_txt)

    axes[0].plot(wave, spec_input, 'r-', label='input spectrum')
    axes[0].plot(wave, spec_cont, 'b-', label='continuum chosen by SP_Ace')
    axes[0].set_ylabel('norm flux')
    axes[0].legend(bbox_to_anchor=(0., 0.,1.,0.15), loc=1, ncol=2,borderaxespad=0.1, fontsize=5)
    axes[0].set_ylim(0.0,1.1)
    axes[0].set_xlim(min(wave)-20., max(wave)+20.)
    axes[0].xaxis.set_tick_params(labelbottom=True, direction='in')

    # boolean mask: fill_between shades from 0 to 1 where the weight is < 0.01
    bool_shade = (spec_weight<0.01)
    axes[1].fill_between(wave, bool_shade, 0, facecolor="gray", alpha=0.8,label='wavelength rejected by SPAce')
    axes[1].plot(wave, spec_norm, 'b-', label='spectrum normalized by SP_Ace')
    axes[1].plot(wave, spec_model, 'g-', label='best matching model')
    axes[1].xaxis.set_tick_params(labelbottom=True, direction='in')
    axes[1].set_ylabel('norm flux')
    axes[1].legend(bbox_to_anchor=(0., 0.,1.,0.15), loc=1, ncol=3,borderaxespad=0.1,fontsize=5)
    axes[1].set_ylim(0.0,1.1)

    axes[2].plot(wave, spec_sn, 'k-')
    axes[2].set_ylabel('S/N pixel')
    # raw string: '\l' and '\A' are invalid escapes in a normal string literal
    axes[2].set_xlabel(r'$\lambda$ [$\AA$]')
    axes[2].set_ylim(0.0,max(spec_sn)*1.1)


    plt.show()

    #clean the axes and release this specific figure
    for ax in axes:
        ax.clear()
    plt.close(fig)
##############################################
# Main window.  Three labelled frames group the widgets:
#   lf0 -- work directory and SP_Ace executable;
#   lf1 -- necessary SP_Ace keywords;
#   lf2 -- optional SP_Ace keywords.
root = Tk()
root.geometry("1200x500")
root.title("SP_Ace GUI")
# NOTE(review): savename is not referenced anywhere else in this file.
savename=StringVar()

lf0 = ttk.Labelframe(root, text='Main',padding=8)
lf0.grid(column=1, row=0, sticky=(NW),columnspan=3,rowspan=3)
lf1 = ttk.Labelframe(root, text='Necessary keywords',padding=8)
lf1.grid(column=1, row=7, sticky=(NW),columnspan=3,rowspan=5)
lf2 = ttk.Labelframe(root, text='Optional keywords',padding=8)
lf2.grid(column=5, row=7, sticky=(NE),columnspan=4,rowspan=7)
######### tab 1 #########
# Widgets of the 'Main' (lf0) and 'Necessary keywords' (lf1) frames.
# NOTE(review): filename, listname and rej_file are not referenced elsewhere in this file.
filename = StringVar()
listname = StringVar()
rej_file = StringVar()
space_exe= StringVar(value=space_executable)
# work_dir is a StringVar: use work_dir.get() wherever the path string is needed
work_dir= StringVar(value=os.getcwd())
dirname = StringVar(value=address_GCOG_lib)

#work directory
ttk.Label(lf0, text="work directory:").grid(column=1, row=0, sticky=(W))
work_entry = ttk.Entry(lf0, width=20, textvariable=work_dir)
work_entry.grid(column=2, row=0, sticky=(W, E))
ttk.Button(lf0, text="select", command=chose_work_dir).grid(column=3, row=0, sticky=(E), padx=5,pady=5)

#SP_Ace executable
ttk.Label(lf0, text="SP_Ace executable:").grid(column=1, row=1, sticky=(W))
space_entry = ttk.Entry(lf0, width=20, textvariable=space_exe)
space_entry.grid(column=2, row=1, sticky=(W, E))
ttk.Button(lf0, text="select", command=chose_exe).grid(column=3, row=1, sticky=(E), padx=5,pady=5)

#file choose: spectrum / list-of-spectra radio buttons (rows 1-2 of lf1)
file_list_radiobutton=select_file_radiobutton()

#GCOG library choose
ttk.Label(lf1, text="GCOG library:").grid(column=1, row=3, sticky=(W))
dir_entry = ttk.Entry(lf1, width=20, textvariable=dirname)
dir_entry.grid(column=2, row=3, sticky=(W, E))
ttk.Button(lf1, text="select", command=chose_dir).grid(column=3, row=3, sticky=(E), padx=5,pady=5)

#FWHM
ttk.Label(lf1, text="approx FWHM:").grid(column=1, row=4, sticky=(W))
fwhm=StringVar(value='0.4')
fwhm_entry = ttk.Entry(lf1, width=4, textvariable=fwhm)
fwhm_entry.grid(column=2, row=4, sticky=(W), pady=5)

#wave_lims: wavelength range passed to SP_Ace
ttk.Label(lf1, text="wavelength limits:").grid(column=1, row=5, sticky=(W))
wave=StringVar(value='4800 6860')
wave_entry = ttk.Entry(lf1, width=20, textvariable=wave)
wave_entry.grid(column=2, row=5, sticky=(W), pady=5)

#batch mode: when checked, run_space processes the spectra with a
#multiprocessing Pool of this many worker processes and skips the plots
batch_check_button=entry_check('batch_mode (number of jobs)',6,1,1,root,'1')

#button to run SP_Ace
ttk.Button(root, text="Run SP_Ace", command=run_space).grid(column=8, row=1, sticky=(NW))
#button to save results
ttk.Button(root, text="Save results", command=save_results).grid(column=8, row=2, sticky=(NW))
#button to quit
ttk.Button(root, text="  Quit  ", command=root.destroy).grid(column=8, row=3, sticky=(SW))

# Pressing <Return> also runs SP_Ace.  Tk passes an event object to bound
# callbacks, so wrap run_space in a lambda that discards it.
root.bind('<Return>', lambda event: run_space())



######### tab 2 #########
# Optional SP_Ace keywords, all placed in the lf2 frame.  Plain check boxes
# (BooleanVar) add a keyword without a value; the entry_check, spinbox_check
# and select_file_check widgets add "keyword value" when checked.
#Salaris_MH
Salaris_MH_flag=BooleanVar(value=True)
check = ttk.Checkbutton(lf2, text='Salaris_MH (by default if libGCOG v1.4)', var=Salaris_MH_flag)
check.grid(column=1,row=0, pady=5, sticky=(W))

#ABD_loop
abdloop_flag=BooleanVar(value=True)
check = ttk.Checkbutton(lf2, text='ABD_loop (by default if libGCOG v1.4)', var=abdloop_flag)
check.grid(column=5,row=0, pady=5, sticky=(W))


#error est
error_flag=BooleanVar()
check = ttk.Checkbutton(lf2, text='error_est', var=error_flag)
check.grid(column=1,row=1, pady=5, sticky=(W))

#alpha
alpha_flag=BooleanVar()
check = ttk.Checkbutton(lf2, text='alpha', var=alpha_flag)
check.grid(column=5,row=1, pady=5, sticky=(W))

#T_force
T_entry_check_button=entry_check('T_force',1,3,5,lf2,'')

#G_force
G_entry_check_button=entry_check('G_force',5,3,5,lf2,'')

#RV_ini
RV_entry_check_button=entry_check('RV_ini',1,4,5,lf2,'0.0')

#sn_ratio
sn_entry_check_button=entry_check('sn_ratio',5,4,5,lf2,'')

#norm_rad
Nrad_entry_check_button=entry_check('norm_rad',1,5,5,lf2,'30.0')

#null_value (spinbox offers 'null', '-9.99', 'NaN')
null_spinbox_check_button=spinbox_check('null value',5,5,5,'null')

#ele2write
ele_entry_check_button=entry_check('ele2write',1,6,5,lf2,'12 14 20 21 22 23 24 26 27 28')

#no_norm
nonorm_flag=BooleanVar()
check = ttk.Checkbutton(lf2, text='no_norm', var=nonorm_flag)
check.grid(column=5,row=6, pady=5, sticky=(W))


#llist_rej
llist_rej_select_file_check_button=select_file_check('llist_rej:',1,8,5,'')

# start the Tk event loop; returns once the main window is destroyed (e.g. Quit)
root.mainloop()

