#!/usr/bin/env python3
#
'''
.. module:: HDF_Bfit
   :synopsis: Script for optimal fitting of baseline-only records
.. moduleauthor:: Cor de Vries <c.p.de.vries@sron.nl>

Suited for multi-processor execution using 'mpiexec'.

'''
import numpy

def mexit(rank,msg):
   '''
   Terminate the calling (MPI) process, reporting *msg* only once.

   Intended to be called by every process with the same *msg*: only process 0
   passes the message to ``sys.exit`` (so it is printed to stderr and the exit
   status is 1); all other processes exit silently with status 0.

   :param rank: MPI rank of the calling process.
   :param msg:  message reported by process 0.
   :raises SystemExit: always.
   '''
   import sys          # 'sys' is otherwise only imported inside the __main__ block

   if rank == 0:
      sys.exit(msg)
   sys.exit()

def replicate(n,data):
   '''
   Return a list containing *data* once per entry of ``numpy.arange(n)``.

   Used to build the per-process list that ``comm.scatter`` distributes.
   The entries are references to the very same object, not copies.

   :param n:    number of entries (normally the number of processes).
   :param data: object to place in every entry.
   :returns:    list of length ``len(numpy.arange(n))``.
   '''
   return [data for _copy in numpy.arange(n)]

# =============================================================

if __name__ == "__main__":

   # Driver: parse options, select the HDF5 file and the frequencies (pixels),
   # distribute the frequencies round-robin over the MPI processes, let every
   # process compute its optimal-filter fits and per-frequency PDF pages, then
   # gather everything in process 0, which writes the summary file and merges
   # all PDF's with 'pdfunite'.

   import sys,os
   import subprocess
   import argparse

   parser = argparse.ArgumentParser(\
                      description="Process baseline HDF5 records for optimal fitting",\
                      epilog="Execution can be distributed over multiple CPU's, using 'mpiexec'")
   parser.add_argument('-ch','--channel',type=int,required=False,default=None,\
                       help='Channel number to use. Default is first available channel.')
   parser.add_argument('-f','--frequency',type=str,nargs='+',required=False,default=[None],\
                       help="Frequency numbers (pixels) to use. Default is first available frequency. 'all' for all frequencies ")
   parser.add_argument('-n','--nrecs',type=str,required=False,default='all',\
                        help='Number of records to display. When single is not selected, average of records is shown. Default is all records')
   parser.add_argument('-sum','--summary',action='store_true',required=False,\
                       help='Print summary of results to file')
   parser.add_argument('--pdf',action='store_true',required=False,\
                       help='Only pdf output')
   parser.add_argument('--debug',action='store_true',required=False,\
                       help='show some debug info + plots on screen')
   parser.add_argument('--threshold','-t',type=int,default=2000,required=False,\
                       help='Amplitude threshold for event detection (default 2000)') 
   parser.add_argument('--bsec',type=float,required=False,default=0.05,\
                       help='Section (start+end) of record to use for background (default 0.05)')
   parser.add_argument('--bpo',action='store_true',required=False,\
                       help='select only pretrigger section for baseline value')
   parser.add_argument('--absolute',action='store_true',required=False,\
                       help='Use sqrt(I^2+Q^2) for pulse record')
   parser.add_argument('-ect','--ectfilter',action='store_true',required=False,\
                       help='Perform filtering to exclude electric crosstalk positive pulses') 
   parser.add_argument('--noise',action='store_true',required=False,\
                       help='Let the script ask for seperate noise file')
   parser.add_argument('--brange',nargs=2,type=int,default=[None,None],\
                       help='Range of baseline values to select. Default: no selection')
   parser.add_argument('-nf','--noisefile',type=str,required=False,default=None,\
                       help='Define seperate noise file(name)')
   parser.add_argument('-np','--nopulseposition',action='store_true',required=False,\
                       help='do not select pulses based on pulse position')
   parser.add_argument('-fc','--freqcutoff',type=int,required=False,default=None,\
                       help='Maximum frequency cutoff for optimal filtering')
   parser.add_argument('-pl','--prlen',type=int,required=False,default=None,\
                       help='Length (power of 2) of record to process for optimal filtering (default entire record)')
   parser.add_argument('-rot','--rotate',action='store_true',required=False,\
                       help='Rotate events for minimum Q-signal (based on 100 events phase average)')
   parser.add_argument('-ph','--phase',type=float,required=False,default=None,\
                       help='Rotate events with given fixed phase (degrees)')
   parser.add_argument('dir',help='Directory or file of x-ray data HDF5 file(s)')

   from mpi4py import MPI

   comm=MPI.COMM_WORLD                 # start multiprocess administration

   rank=comm.rank                      # identify this process
   nproc=comm.size                     # total number of processes

   # Parse the command line only in process 0 and scatter the result. When
   # argparse exits (help requested, or invalid arguments) its exit code is
   # scattered instead, so the other processes stop too rather than blocking
   # forever in the scatter call.
   if rank == 0:
      try:
         args=parser.parse_args()
         exitcode=None
      except SystemExit as exc:
         args,exitcode=None,exc.code
      pargs=replicate(nproc,(args,exitcode))
   else:
      pargs=None

   args,exitcode=comm.scatter(pargs,root=0)
   if args is None:
      sys.exit(exitcode)

   freqcutoff=args.freqcutoff

   import matplotlib
   if args.pdf: 
      matplotlib.use('Agg')             # non-interactive backend: only pdf output
   import matplotlib.pyplot as plt
   from matplotlib.backends.backend_pdf import PdfPages

   from tesfdmtools.utils.widgets.filechooser import getfile
   from tesfdmtools.hdf.HMUX import HMUX
   from tesfdmtools.methods.fitcor import fitcor
   from tesfdmtools.utils.fiterrplot import fiterrplot

   from tesfdmtools.methods.HDF_Eventutils import \
                                 eventpars,plotepar,selectevs,noisefile,noisespec,\
                                 makeplot,subsel,makespectrum

   from tesfdmtools.methods.Bfit_Utils import optfilt,nfilt,fitgauss

   if args.ectfilter:
      # 'filterect' is needed for '--ectfilter' but was never imported, which
      # made that option fail with a NameError half-way through processing.
      # NOTE(review): its home module is assumed to be HDF_Eventutils -- confirm.
      try:
         from tesfdmtools.methods.HDF_Eventutils import filterect
      except ImportError:
         mexit(rank,"Error, option '--ectfilter' requires function 'filterect', which could not be imported")

   if os.path.isfile(args.dir):
      cfilename=args.dir
   else:
      if rank == 0:                     # only select filename in process 0
         if os.path.exists(args.dir):
            cfilename=getfile(pattern='*.h5',path=args.dir)
         else:
            print('Error, directory %s does not exist.' % args.dir)
            cfilename=None
         pcfilename=replicate(nproc,cfilename)
      else:
         pcfilename=None

      cfilename=comm.scatter(pcfilename,root=0)   # scatter selected filename to all processes
      if not cfilename:                 # missing directory or nothing selected: stop all processes
         mexit(rank,"Error, no input HDF5 file selected")

   if args.prlen is not None:           # optimal filter processing length, must be power of 2
      if args.prlen < 1:
         mexit(rank,"Error, --prlen must be a positive integer")
      # largest power of 2 <= prlen, in exact integer arithmetic (a float
      # log(x)/log(2) ratio may truncate an exact power of 2 one step too low)
      proclen=2**(args.prlen.bit_length()-1)
   else:
      proclen=None
   
   hdf=HMUX(filename=cfilename)
   fnam=os.path.basename(cfilename).split('.')[0]

   channel=args.channel
   if channel is None:
      channel=int(hdf.channels[0])

   conf_tab=hdf.channel[channel].freq_conf
   cpixels=conf_tab['pixel_index']
   fpixels=hdf.channel[channel].freqs
   apixels=numpy.intersect1d(cpixels,fpixels)  # take only frequencies which occur both in the table as in the channel
   if rank == 0:
      print("Available pixels(frequencies): ",apixels)
   if apixels.size == 0:
      mexit(rank,"pixel configuration table and available frequencies do not match")

   if args.frequency[0] is None:
      frequencies=[apixels[0]]
   else:
      fmt='Error, illegal frequency specification: "'+len(args.frequency)*'%s '+'"'
      if args.frequency[0] == 'all':
         if len(args.frequency) > 1:
            mexit(rank,fmt % tuple(args.frequency))
         else:
            frequencies=numpy.array(sorted(apixels),dtype=int)
      else:
         frequencies=[]
         for fff in args.frequency:
            if fff.isdigit():
               frr=int(fff)
               findx,=numpy.where(apixels == frr)
               if len(findx) == 0:
                  mexit(rank,"Selected frequency (pixel) %d does not exist in file" % frr)
               frequencies.append(frr)
            else:
               mexit(rank,fmt % tuple(args.frequency))
         frequencies=numpy.array(sorted(frequencies),dtype=int)

   # page layout of the final line-fit plots: up to 4 rows x 2 columns per page
   nrow=4
   ncol=1
   fontsize=10
   if len(frequencies) > 4:
      ncol=2
      fontsize=7
   else:
      nrow=len(frequencies)

   pid=os.getpid()                      # make unique id for this process
   host=os.uname()[1]
   unq="%s_%d_" % (host,pid)

   matplotlib.rcParams.update({'font.size': fontsize})

   show='X'
   if args.pdf:
      show='pdf'

   rotate=args.rotate
   rotp=0.0
   if args.phase is not None:           # fixed phase given: convert degrees to radians
      rotate=args.phase/180.0*numpy.pi
      rotp=rotate

   ptitle='%s' % fnam
   pdfplt={}                            # iplt -> per-frequency pdf file name
   fitplt={}                            # iplt -> line fit plot data (None if not fitted)
   titplt={}                            # iplt -> plot title for frequencies without a fit

   nfile=None

   if rank == 0:                        # identify info for the separate processes

      if args.summary:
         sfile=open('Bfit_'+fnam+'.txt','w')
         sfile.write('file: %s\n' % fnam)
         sfile.write('pars:')
         sfile.write(''.join(' '+pp for pp in sys.argv[1:-1])+'\n')   # options, without the trailing 'dir'
         sfile.write('\n  i  px freq(khz)    V_bias      gbwp  fit(err)(eV)   records\n\n')

      # spread the [iplt,freq] pairs round-robin over the processes
      pdata=[[] for _ in range(nproc)]
      for iplt,freq in enumerate(sorted(frequencies)):
         pdata[iplt % nproc].append([iplt,freq])

   else:

      pdata=None


   avapos=(5897.867+5887.743)/2.0               # average alpha line position (eV)
   
   frequencies=comm.scatter(pdata,root=0)       # now scatter info to all processes
   sfiletxt={}

   # process events into spectra

   for iplt,freq in frequencies:

      print("process: ",unq," (iplt,freq): ",iplt,freq)

      pdfnam='freq%3.3d_Bfit_%s.pdf' % (iplt,fnam)                                   # unique plot file for each frequency
      pdf = PdfPages(pdfnam)
      pdfplt[iplt]=pdfnam
      
      pindx,=numpy.where(conf_tab['pixel_index'] == freq)
      channel,freq,blines1,blines2,positions,peaks,surf,ptp,rclen=\
             eventpars(hdf,channel=channel,freq=freq,nrecs=args.nrecs,rotation=rotp,\
                       threshold=args.threshold,bsec=args.bsec,bpt=args.bpo)        # get basic pulse parameters
      bline=blines1+positions/float(rclen)*(blines2-blines1)                        # baseline interpolated at pulse position
      indx=selectevs(hdf,channel,freq,bline,positions,peaks,surf,rclen,\
                     show=show,debug=args.debug,nopos=args.nopulseposition,pdf=pdf) # select proper events
      if ( args.brange[0] is not None ) and ( args.brange[1] is not None ):         # any additional baseline selection?
         indx=subsel(indx,bline,args.brange)
      if args.ectfilter:
         ectpars=filterect(hdf,channel,freq,indx,blines1,blines2)                   # filter electric crosstalk positive pulses
      else:
         ectpars=None
      if show is not None:
         plotepar(hdf,channel,freq,bline[indx],positions[indx],peaks[indx],surf[indx],\
                  show=show,pdf=pdf,indx=indx,ectpar=ectpars)                       # show parameters of selected events
      if args.ectfilter:
         indx=indx[ectpars[0]]                                                      # use electric crosstalk filter if applicable
      if indx.size < 20:
         print("Insufficient events for frequency: ",freq)
         fitplt[iplt]=None
         titplt[iplt]="ch=%d, px=%d" % (channel,freq)                               # plot titles
         pdf.close()                    # finalize the pages already written, so 'pdfunite' can read the file
         continue

      nfile=True if args.noise else args.noisefile      # True: ask for a separate noise file; else optional file name
      nhdf,nptp=noisefile(hdf,channel,freq,nfile=nfile)
      if nhdf is None:                  # no separate noise file: use the records of the x-ray file itself
         srchdf,srcptp=hdf,ptp
      else:
         srchdf,srcptp=nhdf,nptp
      noisespc,fax=noisespec(srchdf,channel,freq,srcptp,debug=args.debug,\
                             absolute=args.absolute,prlen=proclen)                  # compute noise spectrum
      avpulsepos=numpy.mean(positions[indx])                                        # average pulseposition

      ofilter,norm=optfilt(hdf,channel,freq,indx,blines1,blines2,noisespc,\
                     freqcutoff=freqcutoff,debug=args.debug,rotate=rotate,\
                     bsec=args.bsec,absolute=args.absolute,prlen=proclen,\
                     avpulsepos=avpulsepos)                                         # compute optimal filter

      optfits,indx=nfilt(srchdf,channel,freq,srcptp,ofilter,norm,debug=args.debug,absolute=args.absolute,\
                    rotate=rotate,freqcutoff=freqcutoff,prlen=proclen,\
                    threshold=args.threshold)                                       # perform optimal fits on noise records

      spectrum,bins=makespectrum(optfits*avapos,sbins=500,debug=args.debug)         # make spectrum
      fitplt[iplt]={}
      fwhm,fpars=fitgauss(fitplt,iplt,hdf,channel,freq,spectrum,bins,\
                             debug=args.debug)                                      # fit gaussian function
      nrecs=numpy.sum(spectrum)                                                     # number of records used
      fwhmtxt="%7.2f(%4.2f)" % (fwhm,fpars[2])                                      # FWHM + 1-sigma error

      fffx,=numpy.where( cpixels == freq )
      freqindx=fffx[0]
      khz=conf_tab['freq'][freqindx]

      # store summary line of this frequency
      if args.summary:
         sfiletxt[freq]='%3d %3d %9.2f %9.2f %9.4f %13.13s %8d\n' % \
                         (iplt,freq,khz,conf_tab['ampl'][pindx[0]],conf_tab['gbwp'][pindx[0]],\
                          fwhmtxt,nrecs)

      pdf.close()

   # store information of this process (titplt is needed by process 0 for the
   # titles of frequencies that were skipped in other processes)

   sdata={'sfiletxt':sfiletxt,'pdfnam':pdfplt,'fitplt':fitplt,'titplt':titplt,'frequencies':frequencies}

   # gather information from all processes

   pdata=comm.gather(sdata,root=0)

   if rank != 0:                        # stop all processes except main process
      print("Stop process %s" % unq)
      sys.exit()

   # for the main process start assembling of the data

   print("finishing up in process %s...................." % unq)
   pdfs={}
   sfiletxt={}
   fitplt={}
   titplt={}
   fff=[]
   for data in pdata:
      pdfs.update(data['pdfnam'])
      sfiletxt.update(data['sfiletxt'])
      fitplt.update(data['fitplt'])
      titplt.update(data['titplt'])
      fff.extend(frq for _iplt,frq in data['frequencies'])

   frequencies=numpy.array(sorted(fff),dtype=int)

   pdfmain='main_Bfit_%s.pdf' % fnam
   pdf = PdfPages(pdfmain)

   # line fit plots, 8 per page

   for jplt in range(len(frequencies)):

      iplt=jplt % 8
      if iplt == 0:
         plots={}
         plt.suptitle(ptitle,size=12)
         
      plots[iplt] = plt.subplot2grid((nrow,ncol), ((iplt//ncol),(iplt % ncol)) )
      fp=fitplt[jplt]
      if fp is not None:
         fiterrplot(fp['xax'],fp['spectrum'],fp['yerr'],fp['gauss'],\
                 xtitle="Energy (eV)",ytitle="Counts",ptitle=fp['ptitle'],ptxt=fp['ptxt'],pltax=plots[iplt])
      else:
         plots[iplt].set_title(titplt[jplt])

      if ( iplt == 7 ) or ( jplt == (len(frequencies)-1) ):    # page full, or last plot
         makeplot(pdf,ps=args.pdf)

   # (ascii) summary file

   if args.summary:
      for ff in frequencies:
         if ff in sfiletxt:
            sfile.write(sfiletxt[ff])
         else:
            print("No summary info found for frequency: ",ff)
      if 'total' in sfiletxt:
         sfile.write(sfiletxt['total'])
         sfile.write(sfiletxt['totct'])  
      sfile.write('\n')
      sfile.close()
  
   pdf.close()

   # now combine pdf's and delete duplicates

   print("combine pdf's................")

   # pass an argument list (no shell), so file names with spaces or shell
   # metacharacters are handled correctly
   cmd=['pdfunite',pdfmain]
   for iplt in range(len(frequencies)):
      if os.path.exists(pdfs[iplt]):
         cmd.append(pdfs[iplt])
   pdfnam='Bfit_%s.pdf' % fnam
   cmd.append(pdfnam)
   try:
      xx=subprocess.run(cmd).returncode
   except OSError:                      # e.g. 'pdfunite' is not installed
      xx=-1
   if xx == 0:
      os.unlink(pdfmain)
      for pnam in pdfs.values():
         if os.path.exists(pnam):
            os.unlink(pnam)
   else:
      sys.exit("Error executing 'pdfunite' to combine pdf's ") 

   print("Finished!")  

   


