#!/usr/bin/env python3
#
'''
.. module:: mhfit
   :synopsis: Manual (interactive) Holzer fit on the optimal-fit values of an event list
.. moduleauthor:: Cor de Vries <c.p.de.vries@sron.nl>

'''

import matplotlib
if __name__ == "__main__":
   matplotlib.use('GTK3Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages


from matplotlib.widgets import Cursor

from tesfdmtools.utils.reventlist import reventlist
from tesfdmtools.methods.HDF_Eventutils import makespectrum,limdist
from tesfdmtools.utils.widgets.filechooser import getfile
from tesfdmtools.utils.widgets.printbutton import printbutton
from tesfdmtools.utils.tesfdm_defaults import Tesfdm_Defaults
from tesfdmtools.methods.holzer import holzerfit,holzer
from tesfdmtools.utils.fiterrplot import fiterrplot

import numpy


def onclick(event):
   '''
   Matplotlib 'button_press_event' callback that collects two mouse clicks
   in the current plot window.

   Module-level state used (set by the main program before this callback is
   connected):

   * ``nclicks`` -- number of clicks received so far in this selection (0 or 1)
   * ``xx``, ``yy`` -- 2-element arrays receiving the x/y data coordinates of
     the two clicks; after the second click ``xx`` is sorted so that
     ``xx[0] <= xx[1]`` (``yy`` is not reordered)
   * ``msg`` -- prompt printed after the first click
   * ``fig``, ``ax``, ``cid`` -- figure, plot axes and connection id; after the
     second click the callback is disconnected and the figure window is
     destroyed, so the blocking ``plt.show()`` in the main program can return

   Clicks that do not fall inside the plot axes ``ax`` (outside the plot area,
   or on any other axes in the figure) carry no usable data coordinates and
   are ignored.
   '''
   global nclicks

   # Outside the plot axes xdata/ydata are None (or belong to another axes);
   # storing them would put NaN or meaningless values into the selection.
   if event.inaxes is not ax or event.xdata is None or event.ydata is None:
      return

   if nclicks == 0:
      xx[0]=event.xdata
      yy[0]=event.ydata
      nclicks=1
      print(msg)
   else:
      xx[1]=event.xdata
      yy[1]=event.ydata
      nclicks=0
      if xx[0] > xx[1]:          # always return the interval as (low, high)
         xx[0],xx[1]=xx[1],xx[0]
      fig.canvas.mpl_disconnect(cid)
      fig.canvas.manager.window.destroy()
  
#===========================================================

if __name__ == '__main__':

   import os,sys
   import argparse

   from tesfdmtools.methods.mxsspec import findpeaks,fitlines,mbasecorr

   def _reset_selection():
      '''
      Reset the click state used by ``onclick`` before a new interactive
      two-click selection. Without this, a window closed after a single click
      would leave ``nclicks == 1``, and the first click in the next window
      would be taken as the second one.
      '''
      global nclicks
      nclicks=0
      xx[:]=0.0
      yy[:]=0.0

   def _check_selection(what):
      '''
      Stop the program with a message when the last interactive selection did
      not yield a valid interval in ``xx`` (window closed before two clicks,
      or both clicks at the same x position).

      :param what: description of the selection, used in the message
      '''
      if not ( numpy.all(numpy.isfinite(xx)) and ( xx[1] > xx[0] ) ):
         print("No valid %s selected (got %s); exiting" % (what,xx))
         sys.exit(1)

   parser = argparse.ArgumentParser(\
                      description="Manual holzerfit on eventlist (generated by HDF_Eventpar) optimal fit list. ")
   parser.add_argument('-n','--nbins',type=int,required=False,default=5000,\
                       help='Number of bins for the optimum fit spectrum')
   parser.add_argument('-t','--tsel',action='store_true',required=False,\
                       help='Select events on fall time when available')
   parser.add_argument('--xfrac',type=float,required=False,default=0.05,\
                       help='Fraction of events at low and high end to exclude as outliers')
   parser.add_argument('-s','--show',action='store_true',required=False,\
                       help='When set, only show the spectrum, do not attempt any fits')
   parser.add_argument('--log',action='store_true',required=False,\
                       help='When set, use log scale for spectrumplot with "--show"')
   parser.add_argument('--mxs',action='store_true',required=False,\
                       help='Adapt parameters to process an mxs+mn55 input spectrum')
   parser.add_argument('--alsi',action='store_true',required=False,\
                       help='Include processing of the Al and Si lines in the MXS spectrum')
   parser.add_argument('--pdf',action='store_true',required=False,\
                       help='When "--mxs" is set, output MXS fit results to pdf file')
   parser.add_argument('-notb','--notimebounds',action='store_true',required=False,\
                       help='Do not automatically select bounds on the fall time selection window')
   parser.add_argument('--debug',action='store_true',required=False,\
                       help='Extra output for debugging purposes')

   args=parser.parse_args()
   defs=Tesfdm_Defaults()

   if args.mxs:                  # MXS processing: finer binning, smaller outlier cut
      if args.xfrac > 0.04:
         args.xfrac=0.001
      args.show=True
      if args.nbins < 10000:
         args.nbins=40000

   elist=getfile(path=defs.get_filepath('lst'),pattern='eventlist_*.lst')
   if elist == '':
      sys.exit()
   defs.set_filepath(elist)

   fnam=os.path.basename(elist).rsplit('.',1)[0]
   ftit=fnam.replace('eventlist_','')

   indx,optfit,basel,peaks,hdffile,rtimes,ftimes,energies=reventlist(elist)

   print("Number of events read: ",optfit.size)

   xx=numpy.zeros(2,dtype=float)     # x positions of the two clicks (filled by onclick)
   yy=numpy.zeros(2,dtype=float)     # y positions of the two clicks (filled by onclick)
   nclicks=0                         # clicks received in the current selection

   if ( ftimes is not None ) and args.tsel:

      # optional interactive pre-selection of events on their fall time
      indx,=numpy.where( ftimes > 0.0 )
      tmin,tmax=limdist(ftimes[indx],bins=50)

      fig, ax = plt.subplots(1)
      printbutton(fig,orientation='landscape')
      ax.plot(ftimes[indx],optfit[indx],'b.')
      if not args.notimebounds:
         rr=list(ax.axis())
         rr[0]=tmin
         rr[1]=tmax
         ax.axis(rr)
      ax.set_xlabel('Fall time (usec)')
      ax.set_ylabel('Opt. fit')
      ax.set_title(fnam)
      _reset_selection()
      msg="Select high bound of fall time interval"
      print(" Select low bound of fall time interval")
      cursor=Cursor(ax, useblit=True, color='red', linewidth=1 )    # crosshair cursor (reference kept so the widget stays alive)
      cid = fig.canvas.mpl_connect('button_press_event', onclick)   # activate 'onclick' routine
      plt.show()
      plt.close('all')
      print("Selected fall time interval: ",xx)
      _check_selection('fall time interval')
      tindx,=numpy.where( ( ftimes > xx[0] ) & ( ftimes < xx[1] ) )

      spectrum,bins=makespectrum(optfit[tindx],sbins=args.nbins,efrac=args.xfrac)

   else:

      spectrum,bins=makespectrum(optfit,sbins=args.nbins,efrac=args.xfrac)

   if args.mxs:

      fontsize=10
      if not args.debug:
         fontsize=7
      matplotlib.rcParams.update({'font.size': fontsize})

      if args.pdf:
         pdfnam='mhfit-mxs_'+ftit+'.pdf'
         pdf=PdfPages(pdfnam)

      if args.debug:
         pltax=None
      else:
         pltax={}                                            # plot placements for overview pdf plot
         plt.suptitle(ftit,size=12)
         pltax['mn']=plt.subplot2grid((4,2),(0,0))           # mn line fit plot
         pltax['cu']=plt.subplot2grid((4,2),(0,1))           # cu line fit plot
         pltax['cr']=plt.subplot2grid((4,2),(1,0))           # cr line fit plot
         pltax['pecal']=plt.subplot2grid((4,2),(1,1))        # preliminary energy scale plot
         pltax['spectrum']=plt.subplot2grid((4,1),(3,0))     # total spectrum plot
         pltax['tilts']=plt.subplot2grid((4,2),(2,0))        # baseline tilts fit
         pltax['ecal']=plt.subplot2grid((4,2),(2,1))         # final energy calibration

      escale,lindx,lenergy,lxx,lfit=findpeaks(bins,spectrum,fitas=args.alsi,debug=args.debug,pltax=pltax)

      if escale is None:

         print("findpeaks returned no energy scale; no line fits performed")

      else:

         cspectrum,tpar=mbasecorr(optfit,bins,escale,basel,pltax=pltax,\
                                  debug=args.debug)  # correct spectrum for baseline dependence

         if args.debug:
            f,ax=plt.subplots(1)
            ax.plot(bins,spectrum,'b-',bins,cspectrum,'r-')
            ax.set_xlabel('Optimal fit')
            ax.set_ylabel('Counts')
            ax.set_title('Original spectrum vs. corrected')
            plt.show()
            plt.close('all')

         # lfit holds the lines to fit per element as 'element':[alpha,beta],
         # e.g. {'mn':[True,True],'cu':[True,True],'cr':[True,True]}
         for el in lfit:  # when there is an Alpha line, always fit the beta line
            if lfit[el][0]:
               lfit[el][1]=True
         print("lfit: ",lfit)
         escale,afitpars=fitlines(escale,cspectrum,debug=args.debug,\
                                  lfit=lfit,lindx=lindx,lenergy=lenergy,fixbgain=True,\
                                  pltaxs=pltax,optbins=bins)
         if args.debug:
            for element in afitpars:
               print('lratio: ',element,afitpars[element][-1])

         # convert optimal fit results into energies for all events
         optee=numpy.interp(optfit,bins,escale)

         # check spectrum (only possible when an energy scale was found)
         if args.debug:
            nspectrum,nebins=numpy.histogram(optee,bins=args.nbins)
            nebins=(nebins[:-1]+nebins[1:])/2.0
            f,ax=plt.subplots(1)
            ax.plot(nebins,nspectrum,'b-')
            ax.set_xlabel('Energy (eV)')
            ax.set_ylabel('Counts')
            ax.set_title('Events spectrum')
            plt.show()
            plt.close('all')

      if pltax is not None:       # output all plots to pdf
         fpage=plt.gcf()
         fpage.set_size_inches(8.27, 10.75)
         plt.tight_layout()
         plt.subplots_adjust(top=0.94)
         if args.pdf:
            # PDF page size follows the figure size set above; 'papertype'
            # is a PostScript-only savefig option, so it is not passed here
            pdf.savefig(fpage)
         else:
            plt.show()
         plt.close('all')

      if args.pdf:
         pdf.close()

      sys.exit()

   if args.show:          # just show the spectrum; no processing

      xlabel='Optimal fit parameter'
      if energies is not None:
         spectrum,bins=makespectrum(energies,sbins=args.nbins,efrac=args.xfrac)
         xlabel='energy (eV)'
      f,ax=plt.subplots(1)
      printbutton(f,orientation='landscape')
      ax.plot(bins,spectrum,'b-')
      if args.log:
         # symlog keyword 'linthreshy' was renamed to 'linthresh' in matplotlib 3.3;
         # compare (major, minor) numerically so versions like 3.10 work too
         mplver=tuple( int(v) for v in matplotlib.__version__.split('.')[:2] if v.isdigit() )
         if mplver >= (3,3):
            ax.set_yscale('symlog',linthresh=1.0)
         else:
            ax.set_yscale('symlog',linthreshy=1.0)
      ax.set_xlabel(xlabel)
      ax.set_ylabel('Counts')
      ax.set_title('Optimal fit spectrum')
      plt.show()
      plt.close('all')

      sys.exit()

   fig, ax = plt.subplots(1)
   printbutton(fig,orientation='landscape')

   ax.plot(bins,spectrum)                                     # plot optimal fit spectrum
   ax.set_xlabel('opt. fit')
   ax.set_ylabel('Counts')
   ax.set_title(ftit)

   apos=5898.882                                              # main positions of Mnk55 peaks
   bpos=6486.31

   _reset_selection()
   msg="Select highest Alpha peak"
   cursor=Cursor(ax, useblit=True, color='red', linewidth=1 ) # crosshair cursor (reference kept so the widget stays alive)
   cid = fig.canvas.mpl_connect('button_press_event', onclick) # activate 'onclick' routine

   print("Select Beta line in plot")

   plt.show()
   plt.close('all')

   print("Selected line positions: ",xx)
   _check_selection('line positions')   # also guards the division below

   # linear energy calibration based on the selected alpha and beta positions
   pee=apos+(bins-xx[0])/(xx[1]-xx[0])*(bpos-apos)

   aindx,=numpy.where( ( pee > 5850.0 ) & ( pee < 5930.0 ) )  # select window on alpha lines

   fig, ax = plt.subplots(1)
   printbutton(fig,orientation='landscape')

   ax.plot(pee[aindx],spectrum[aindx])                        # plot with energy axis
   ax.set_xlabel('Energy (eV)')
   ax.set_ylabel('Counts')
   ax.set_title(ftit)
   _reset_selection()
   msg="Select 2nd fit boundary"
   cursor=Cursor(ax, useblit=True, color='red', linewidth=1 ) # crosshair cursor (reference kept so the widget stays alive)
   cid = fig.canvas.mpl_connect('button_press_event', onclick) # activate 'onclick' routine
   print("Select fit range")
   plt.show()
   plt.close('all')
   _check_selection('fit range')

   indx,=numpy.where( ( pee > xx[0] ) & ( pee < xx[1] ) )     # select fit window
   if indx.size == 0:
      print("Selected fit range contains no spectrum bins; exiting")
      sys.exit(1)

   bb=holzerfit(pee[indx]/1000.0,spectrum[indx],line='Alpha') # perform holzer fit (energy axis in keV)
   fit=holzer(bb,pee[aindx]/1000.0)
   yerr=numpy.full(spectrum.size,1E9)                         # effectively infinite error outside fit window
   yrr=numpy.sqrt(spectrum[indx])                             # 1 sigma error inside fit window
   yerr[indx]=numpy.where(yrr == 0,1.0,yrr)                   # avoid zero errors for empty bins

   print("\n    FWHM: %5.2f (eV)\n" % (bb[3]*1000.0))

   fig, ax = plt.subplots(1)
   printbutton(fig)

   tt=['Pos1:','Pos2:','Counts:','FWHM:','C-stat:','Degr. of frdm:']
   fig.caption=['%7s %7.4f' % (tt[0],bb[0]),
                '%7s %7.4f' % (tt[1],bb[1]),
                '%7s %7d' % (tt[2],bb[2]),
                '%7s %5.2f (eV)' % (tt[3],bb[3]*1000.0),
                '%7s %7.1f' % (tt[4],bb[4]),
                '%s %d' % (tt[5],bb[5]) ]

   fiterrplot(pee[aindx],spectrum[aindx],yerr[aindx],fit,\
              xtitle='Energy (eV)',ytitle='Counts',ptitle=ftit,ptxt='',pltax=ax) # show result
   plt.show()
   plt.close('all')
     
