#!/usr/bin/env python3
#
'''
.. module:: HDF_Crosstalk
   :synopsis: Pipeline script for crosstalk analysis  
.. moduleauthor:: Cor de Vries <c.p.de.vries@sron.nl>

Suited for multi-processor execution using 'mpiexec'.

'''
import sys,os
import numpy

from tesfdmtools.hdf.HMUX import HMUX
from tesfdmtools.utils.widgets.filechooser import getfile
from tesfdmtools.methods.IQrotate import IQphase,IQphrot

def mexit(rank,msg):
   '''
   Terminate the calling process.

   Process 0 hands `msg` to `sys.exit`; every other process exits
   without a message, so the text is reported only once.
   '''
   sys.exit(msg if rank == 0 else None)

def replicate(n,data):
   '''
   Return a list holding `n` references to the same `data` object.

   Used to build the per-process list handed to `comm.scatter` when
   every process must receive the same value.
   '''
   return [data for _ in range(n)]

# =======================================================================

if __name__ == "__main__":

   import argparse

   # ---- command line definition ------------------------------------------

   parser = argparse.ArgumentParser(\
                      description="Show crosstalk signals on other pixels for given pixel(frequency)",\
                      epilog="Execution can be distributed over multiple CPU's, using 'mpiexec'")
   parser.add_argument('-d','--diag',action='store_true',required=False,\
                        help='Show diagnostices output and plots')
   parser.add_argument('-ch','--channel',type=int,required=False,default=None,\
                        help='Channel number to use. Default is first available channel.')
   parser.add_argument('-f','--frequency',type=str,nargs='+',required=False,default=[None],\
                       help="Frequency numbers (pixels) to use. Default is first available frequency. 'all' for all frequencies ")
   parser.add_argument('-n','--nrecs',type=str,required=False,default='all',\
                        help='Number of records to display. When single is not selected, average of records is shown. Default is all records')
   parser.add_argument('-t','--threshold',type=int,required=False,default=1000,\
                        help='Threshold for real X-ray pulse detection')
   parser.add_argument('-x','--xrange',type=int,nargs=2,required=False,default=(0,0),\
                        help='Range of x-values to plot')
   parser.add_argument('-s','--single',action='store_true',required=False,\
                        help='show single records')
   parser.add_argument('-q','--qdata',action='store_true',required=False,\
                        help='also show Q-signals')
   parser.add_argument('-pp','--ppos',action='store_true',required=False,\
                        help='assume all pulses are on same position (triggered data)')
   parser.add_argument('-pr','--prange',type=int,nargs=2,required=False,default=[-200,2000],\
                        help='range in bins to use, centered around pulse peak')
   parser.add_argument('-rot','--rotate',action='store_true',required=False,\
                        help='rotate pulses for minimum Q signal')
   parser.add_argument('--pdf',action='store_true',required=False,\
                        help='Output pdf files of plots')
   parser.add_argument('-sum','--summary',action='store_true',required=False,\
                        help='print ASCII summary of results')
   parser.add_argument('--txtout',action='store_true',default=None,\
                        help='Write text files with crosstalk data for each pixel')
   parser.add_argument('dir',help='Directory or file of x-ray data HDF5 file(s). Query for file when directory.')

   # ---- multi-process administration ---------------------------------------

   from mpi4py import MPI

   comm=MPI.COMM_WORLD                       # start multiprocess administration

   rank=comm.rank                            # identify this process
   nproc=comm.size                           # total number of processes

   # Only process 0 runs the parser, so help and usage messages are printed
   # once.  argparse leaves through SystemExit for '-h/--help' and for usage
   # errors; that exit is intercepted here so the outcome can be broadcast.
   # Otherwise the other processes would wait forever for arguments that
   # never arrive.
   args=None
   exitcode=0
   if rank == 0:
      try:
         args=parser.parse_args()
      except SystemExit as exc:
         exitcode=exc.code                   # keep argparse's own exit status for process 0

   args=comm.bcast(args,root=0)              # every process receives the same namespace (or None)
   if args is None:                          # parsing did not succeed: stop all processes together
      sys.exit(exitcode)

   import matplotlib
   if args.pdf:
      matplotlib.use('Agg')                  # backend must be chosen before pyplot is imported
   import matplotlib.pyplot as plt
   from matplotlib.backends.backend_pdf import PdfPages

   from tesfdmtools.methods.HDF_Cutils import surface,crossdata,cprint,cplot,pixelmap,crec

   # Determine the HDF5 file to process.  A file argument is used as is; for
   # anything else process 0 asks for a file through the file chooser and
   # the choice is then shared with all other processes.
   if os.path.isfile(args.dir):
      cfilename=args.dir
   else:
      if rank == 0:
         if not os.path.exists(args.dir):
            print('Error, directory %s does not exist.' % args.dir)
         cfilename=getfile(pattern='*.h5',path=args.dir)
      else:
         cfilename=None

      cfilename=comm.bcast(cfilename,root=0)   # hand the selected filename to all processes

   cfile=HMUX(filename=cfilename)
   fnam=os.path.basename(cfilename).split('.')[0]     # file name up to the first '.'
   runversion=fnam.replace('Run','').lstrip('_-').replace('-','_').split('_')[0]

   # Diagnostics (main process only): list every channel with its
   # frequencies and the 'freq' column of its configuration table.
   if rank == 0 and args.diag:
      print("available channels: ",cfile.channels)
      for chan_id in cfile.channels:
         chan_info=cfile.channel[chan_id]
         chan_table=chan_info.freq_conf
         print("channel: ",chan_id,"  frequencies: ",chan_info.freqs)
         print("              kHz: ",chan_table['freq'])

   # Fall back to the first channel of the file when none was requested.
   channel=cfile.channels[0] if args.channel is None else args.channel

   selected=cfile.channel[channel]
   conf_tab=selected.freq_conf
   cpixels=conf_tab['pixel_index']
   fpixels=selected.freqs
   # keep only the pixels that occur both in the configuration table and in the channel
   apixels=numpy.intersect1d(cpixels,fpixels)
   if rank == 0:
      print("Available pixels(frequencies): ",apixels)
   if not apixels.size:
      mexit(rank,"pixel configuration table and available frequencies do not match")

   # Translate the -f/--frequency option into the list of pixels to process:
   # nothing given -> first available pixel, 'all' -> every available pixel,
   # otherwise a list of pixel numbers that must all be present in the file.
   requested=args.frequency
   if requested[0] is None:
      frequencies=[apixels[0]]
   else:
      fmt='Error, illegal frequency specification: "'+len(requested)*'%s '+'"'
      badspec=fmt % tuple(requested)
      if requested[0] == 'all':
         if len(requested) > 1:              # 'all' cannot be combined with anything else
            mexit(rank,badspec)
         frequencies=numpy.array(sorted(apixels),dtype=int)
      else:
         chosen=[]
         for token in requested:
            if not token.isdigit():
               mexit(rank,badspec)
            pixel=int(token)
            if not numpy.any(apixels == pixel):
               mexit(rank,"Selected frequency (pixel) %d does not exist in file" % pixel)
            chosen.append(pixel)
         frequencies=numpy.array(sorted(chosen),dtype=int)

   # unique id for this process, built from host name and process id
   unq="%s_%d_" % (os.uname()[1],os.getpid())

   # one row per selected frequency, one column per entry of the pixel table
   crossmatrix=numpy.zeros((len(frequencies),len(conf_tab['pixel_index'])),dtype=float)

   sumtext={}
   pdfplt={}
   pdata=None
   if rank == 0:
      print("Pixels (frequencies) selected for processing: ",frequencies)

      # the summary file is only written when averages are analysed
      sumfile=None
      if args.summary and not args.single:
         sumfile=open("Crosstalk_%s.txt" % fnam,'w')
         sumfile.write(' file: %s\n' % fnam)
         sumfile.write('         primary pixel   crosstalk pixels\n')

      # deal the (plot index, frequency) pairs out over the processes in turn
      pdata=[[] for _ in range(nproc)]
      for iplt,freq in enumerate(sorted(frequencies)):
         pdata[iplt % nproc].append([iplt,freq])

   frequencies=comm.scatter(pdata,root=0)       # each process receives its own work list

   # In --single mode the record count is converted with int(); the default
   # value 'all' cannot be converted.  Stop every process here with a clear
   # message instead of failing later on only those processes that have work.
   if args.single and not args.nrecs.isdigit():
      mexit(rank,"Error, option --single requires a numeric record count (-n), got '%s'" % args.nrecs)

   # Process the (plot index, frequency) pairs assigned to this process.
   for jf,frequency in frequencies:

      print("process: ",unq," (iplt,freq): ",jf,frequency)

      pdfnam='freq%3.3d_Crosstalk_%s.pdf' % (jf,fnam)                                   # unique plot file for each frequency
      # NOTE(review): the PdfPages object is created regardless of --pdf but is
      # only closed when --pdf is given; confirm the intended behaviour without --pdf.
      pdf = PdfPages(pdfnam)
      pdfplt[jf]=pdfnam

      print("Processing frequency(pixel): ",frequency)
      qdata=None
      if args.single:   # show only single records, no further analysis
         for nn in range(int(args.nrecs)):
            # the event count returned by crossdata gets its own name so the
            # MPI process count 'nproc' is not overwritten
            tff,freqs,recs,qrecs,nevents,apulse=crossdata(cfile,channel,frequency,'1',args.threshold,\
                     ppos=args.ppos,prange=args.prange,rotate=args.rotate,debug=args.diag)
            xax=numpy.arange(recs[0,:].size)
            if args.qdata:
               qdata=qrecs
            pnm='_px%d_%d' % (frequency,nn)
            cplot(tff,recs,xax,'sample nr.','ADC',freqs,ps=pdf,filename=fnam,num=pnm,xrange=args.xrange,\
                  qdata=qdata,qtitle='Q-ADC',ptitle=fnam)
      else:   # show averages and analyze data
         tff,freqs,recs,qrecs,nevents,apulse=crossdata(cfile,channel,frequency,args.nrecs,args.threshold,\
                  ppos=args.ppos,prange=args.prange,rotate=args.rotate,debug=args.diag)
         xax=numpy.arange(recs[0,:].size)

         # surface of every row of 'apulse'; used below as normalisation
         psurf=numpy.zeros(apulse[:,0].size,dtype=float)
         for i in range(psurf.size):
            psurf[i],bline=surface(xax,apulse[i,:],range=args.xrange)
            if args.diag:
               print('surface opix: ',i,psurf[i])

         # surface, baseline and plot label for every record in 'recs'
         npix=len(tff)
         dlabel=numpy.empty(npix,dtype=object)
         dbckgr=numpy.zeros((npix,xax.size))
         isurfaces=numpy.zeros(npix,dtype=float)
         ievs=numpy.zeros((npix-1),dtype=float)
         for i in range(npix):
            surf,bline=surface(xax,recs[i,:],range=args.xrange)
            if i > 0:
               # record 0 gets no energy value; the others are scaled by 6000
               # (presumably the reference photon energy in eV -- TODO confirm)
               ievs[i-1]=surf/psurf[i-1]*6000.0
               dlabel[i]='sf: %10.1f\nev: %5.2f (eV)' % (surf,ievs[i-1])
            else:
               dlabel[i]='sf: %10.1f' % surf
            dbckgr[i,:]=bline
            isurfaces[i]=surf

         # the same for the Q-signals, when requested
         qlabel=None
         qbckgr=None
         if args.qdata:
            qdata=qrecs
            qlabel=numpy.empty(npix,dtype=object)
            qbckgr=numpy.zeros((npix,xax.size))
            qsurfaces=numpy.zeros(npix,dtype=float)
            for i in range(npix):
               surf,bline=surface(xax,qrecs[i,:],range=args.xrange)
               qlabel[i]='sf: %10.1f' % surf
               qsurfaces[i]=surf
               qbckgr[i,:]=bline
         pt="%s   n=%d" % (fnam,nevents)
         pnm='_px%d' % frequency
         cplot(tff,recs,xax,'sample nr.','ADC',freqs,ps=pdf,filename=fnam,xrange=args.xrange,qdata=qdata,qtitle='Q-ADC',\
               dlabel=dlabel,qlabel=qlabel,dbckgr=dbckgr,qbckgr=qbckgr,ptitle=pt,num=pnm)

         if args.txtout:
            txtname='Crosstalk_%s_px_%d.txt' % (fnam,frequency)
            cprint(txtname,tff,freqs,xax,recs,qrecs,xrange=args.xrange,filename=fnam)

         # store the energies in this frequency's matrix row, columns in
         # sorted pixel order; position 0 of 'tff' has no energy value
         for ji,jp in enumerate(sorted(tff)):
            jxx,=numpy.where( tff == jp )
            jx=jxx[0]
            if jx > 0:
               crossmatrix[jf,ji]=ievs[jx-1]

         if args.summary and ( not args.single):
            nfreq=len(freqs)
            fff=numpy.zeros(nfreq,dtype=float)
            for i in range(nfreq):
               fff[i]=freqs[tff[i]]
            fmt=nfreq*'%12.2f'
            ifmt=nfreq*'%12d'
            efmt='            '+(nfreq-1)*'%12.3f'
            stxt=''
            stxt=stxt+('  nevents:%12d\n' % nevents)
            stxt=stxt+(('    pixel:'+ifmt+'\n') % tuple(tff))
            stxt=stxt+(('frequency:'+fmt+'   (kHz)\n') % tuple(fff))
            stxt=stxt+(('I-surface:'+fmt+'\n') % tuple(isurfaces))
            if args.qdata:
               stxt=stxt+(('Q-surface:'+fmt+'\n') % tuple(qsurfaces))
            stxt=stxt+(('   Energy:'+efmt+'   (eV)\n') % tuple(ievs))
            sumtext[jf]=stxt

      if args.pdf:
         pdf.close()

# store information of this process  

   # results of this process, collected on the main process
   sdata={'pdfnam':pdfplt,'sumtext':sumtext,'frequencies':frequencies,'crossmatrix':crossmatrix}
   pdata=comm.gather(sdata,root=0)

   if rank != 0:        # only the main process continues from here
      print("Stop process %s" % unq)
      sys.exit()

   if args.single:      # single-record mode has nothing to assemble
      sys.exit()

   pixels=numpy.array(sorted(tff))

   # main process: assemble the results of all processes

   print("finishing up in process %s...................." % unq)
   pdfs={}
   sumtexts={}
   cmatrix={}
   fff=[]
   for jd,data in enumerate(pdata):
      pdfs.update(data['pdfnam'])
      sumtexts.update(data['sumtext'])
      cmatrix[jd]=data['crossmatrix']
      fff.extend(frq for _,frq in data['frequencies'])

   frequencies=numpy.array(sorted(fff),dtype=int)
   spixels=numpy.array(sorted(frequencies))

   # merge the crosstalk matrices: start from the first one and copy the
   # non-zero elements of all the others into it
   crossmatrix=cmatrix[0]
   for jd in range(1,len(cmatrix)):
      filled=cmatrix[jd] != 0.0
      crossmatrix[filled]=cmatrix[jd][filled]

# print summary file

   if sumfile is not None:

      try:
         # per-frequency summaries, in plot-index order
         for jf in range(len(frequencies)):
            sumfile.write(sumtexts[jf])

         # crosstalk matrix: header line with the sorted pixel numbers, then
         # one line per processed pixel
         sumfile.write('\n\n\nCrosstalk matrix (eV)\n')
         ifmt=nfreq*'%9d'
         sumfile.write(('         '+ifmt+'\n') % tuple(sorted(tff)))
         blank=9*' '
         for ip,jp in enumerate(spixels):
            cells=[]
            for value in crossmatrix[ip,:]:
               cell='%9.3f' % value
               # Blank out empty matrix elements one cell at a time; a text
               # replace of '0.000' on the whole line would also damage
               # values such as 10.000.
               if cell.strip() in ('0.000','-0.000'):
                  cell=blank
               cells.append(cell)
            sumfile.write(('%9d' % jp)+''.join(cells)+'\n')
      finally:
         sumfile.close()                      # close the file even if writing fails

   # frequency distance matrix
   # NOTE(review): 'frdist' is not used in the remainder of this script.

   frdist=numpy.zeros_like(crossmatrix)
   sff=sorted(tff)
   for j in numpy.arange(len(tff)):
      for i in numpy.arange(len(frequencies)):
         if i != j:
            frdist[i,j]=freqs[sff[i]]-freqs[sff[j]]
  
# plot correlation matrix + pixel configuration data


   # With --pdf the matrix plot gets its own pdf file.
   if args.pdf:
      cpdfnam='Cmatrix_Crosstalk_%s.pdf' % fnam
      pdf = PdfPages(cpdfnam)

   nrows=2             # grid of two rows, only the top one is used here
   cmax=20.0           # upper limit applied to the plotted absolute values

   cmat_axes=plt.subplot2grid((nrows,1),(0,0))
   clipped=numpy.minimum(numpy.fabs(crossmatrix),cmax)
   pixelmap(cmat_axes,clipped,pixels,title='Crosstalk matrix',ypixelax=numpy.array(frequencies))

   plt.suptitle(fnam,size=12)

   if not args.pdf:
      # no pdf requested: show the figure
      plt.tight_layout()
      plt.show()
      plt.close('all')
   else:
      page=plt.gcf()
      page.set_size_inches(8.27, 10.75)
      plt.tight_layout()
      plt.subplots_adjust(top=0.94)
      # NOTE(review): confirm that the installed matplotlib still accepts 'papertype' for pdf output
      plt.savefig(pdf,format='pdf',papertype='a4')
      pdf.close()

 # now combine pdf's and delete duplicates

   # Combine the per-frequency pdf's and the matrix pdf into one document and
   # remove the separate files afterwards.  The matrix pdf only exists when
   # --pdf was given, so this step is skipped otherwise.
   if args.pdf:
      import subprocess

      print("combine pdf's................")

      pdfnam='Crosstalk_%s.pdf' % fnam

      # Build the command as an argument list (no shell involved), so file
      # names with spaces or shell metacharacters are passed on unchanged.
      cmd=['pdfunite']
      for iplt in range(len(frequencies)):
         if os.path.exists(pdfs[iplt]):
            cmd.append(pdfs[iplt])
      cmd.append(cpdfnam)
      cmd.append(pdfnam)

      try:
         xx=subprocess.call(cmd)
      except OSError:                         # e.g. 'pdfunite' is not installed
         xx=-1

      if xx != 0:
         sys.exit("Error executing 'pdfunite' to combine pdf's ")

      # the combined file was written: delete the separate pdf's
      for pnn in pdfs:
         if os.path.exists(pdfs[pnn]):
            os.unlink(pdfs[pnn])
      os.unlink(cpdfnam)

   print("Finished!")
   


    


            

   
