#!/usr/bin/env python3
#

import sys,os
import numpy

def mexit(rank,msg):
   """Terminate the calling MPI process.

   The main process (rank 0) passes *msg* on to ``sys.exit`` so the message
   is reported once; every other rank exits without a message.

   rank : MPI rank of the calling process
   msg  : exit message/status given to ``sys.exit`` on rank 0
   """
   sys.exit(msg if rank == 0 else None)

def replicate(n,data):
   """Return a list holding *n* references to the same *data* object.

   Used to build the send buffer for ``comm.scatter`` so every MPI process
   receives *data*.

   n    : number of entries (integer, normally the number of processes)
   data : object to replicate; it is not copied, every entry is *data* itself
   """
   return [data]*n

# =======================================================================

if __name__ == "__main__":

   # Compute the NEP for X-ray/noise HDF5 file pairs. Rank 0 handles the user
   # interface and collects results; the selected frequencies (pixels) are
   # distributed round-robin over all MPI processes.

   from mpi4py import MPI
   import argparse
   import subprocess
   from glob import glob

   comm=MPI.COMM_WORLD                       # start multiprocess administration

   rank=comm.rank                            # identify this process
   nproc=comm.size                           # total number of processes

   pid=os.getpid()                           # make unique id for this process
   host=os.uname()[1]
   unq="%s_%d_" % (host,pid)

   if rank == 0:                             # user interface only in main process
      parser = argparse.ArgumentParser(
                         description="Compute NEP for specified X-ray file and accompanying noise file",
                         epilog="Execution can be distributed over multiple CPU's, using 'mpiexec'")
      parser.add_argument('-d','--diag',action='store_true',required=False,
                           help='Show diagnostics output and plots')
      parser.add_argument('-lf','--logfreq',action='store_true',required=False,
                           help='Logarithmic scale for any frequency axis')
      parser.add_argument('-c','--cutoff',default=30000,required=False,type=int,
                           help='Cutoff frequency in Hz for computation of NEP')
      parser.add_argument('-x','--xexclude',default=0.05,required=False,type=float,
                           help='Fraction of pulses, sorted on integrated maximum, to exclude. '+
                                'Aim is to exclude double pulses')
      parser.add_argument('-n','--ninclude',default=0.2,required=False,type=float,
                           help='Fraction of lowest noise records to include. '+
                                'Aim is to exclude data records with drifts')
      parser.add_argument('-a','--all',action='store_true',required=False,
                           help='process all files in current or given directory')
      parser.add_argument('-ch','--channel',type=str,required=False,default=None,
                           help='Channel number to use. Default is first available channel. "all" for all channels')
      parser.add_argument('-f','--frequency',type=str,nargs='+',required=False,default=[None],
                           help='Frequency numbers (pixels) to use. Default is first available frequency. "all" for all frequencies')
      parser.add_argument('dir',help='Directory of x-ray and noise data HDF5 file(s), or x-ray file')
      parser.add_argument('--pdf',action='store_true',required=False,
                           help='Output pdf files of plots')
      parser.add_argument('--noise',action='store_true',required=False,
                           help='Ask for noise file (only applicable when "--all" is not set)')
      parser.add_argument('-sum','--summary',action='store_true',required=False,
                           help='Write ASCII summary file of results')
      parser.add_argument('-ip','--ignoreposition',action='store_false',required=False,
                           help='If set, also use pulses outside trigger position')
      parser.add_argument('-rot','--rotate',action='store_true',required=False,
                           help='rotate for minimum Q, based on average phase of max 100 records')
      parser.add_argument('--dtype',action='store_true',required=False,
                           help='convert records to I-Q when datatype is "ampl-phase"')
      parser.add_argument('--prlen',type=int,default=None,required=False,
                           help='Set processing record length (rounded down to a power of 2). '+
                                'Must be smaller than or equal to the real length')
      parser.add_argument('--prclip',choices=('start','end'),required=False,default='end',
                           help='clip record at start or end when length is larger than processing length')
      parser.add_argument('--ascii',action='store_true',required=False,
                           help='write ASCII output files for averaged pulse and pulse/noise spectra and NEP')

      try:
         args=parser.parse_args()
      except SystemExit:
         # -h/--help or an argument error: argparse exits here. Release the
         # other processes (waiting in scatter) with a None sentinel so they
         # terminate as well instead of hanging.
         comm.scatter(replicate(nproc,None),root=0)
         raise
      pargs=replicate(nproc,args)
   else:
      pargs=None

   args=comm.scatter(pargs,root=0)                 # get input parameters into all processes
   if args is None:                                # main process stopped while parsing arguments
      sys.exit()

   prlen=None
   if args.prlen is not None:
      if args.prlen < 1:
         mexit(rank,'Error, illegal processing record length %d' % args.prlen)
      prlen=1 << (args.prlen.bit_length()-1)       # largest power of 2 <= args.prlen (exact integer math)
      if rank == 0:
         print("Process record length changed to: ",prlen)

   import matplotlib
   if args.pdf:
      matplotlib.use('Agg')
   import matplotlib.pyplot as plt
   from matplotlib.backends.backend_pdf import PdfPages

   from tesfdmtools.methods.HDF_NEPutils import NEPs,NEPfiles
   from tesfdmtools.hdf.HMUX import HMUX

   if not os.path.exists(args.dir):
      mexit(rank,'Error, directory %s does not exist.' % args.dir)

   channel=None
   allchannel=False       # NOTE(review): set for "-ch all" but not used below; a single channel is processed
   if args.channel is not None:
      if args.channel.isdigit():
         channel=int(args.channel)
      elif args.channel == 'all':
         allchannel=True
      else:
         mexit(rank,'Error, illegal channel number "%s"' % args.channel)

   freq=None
   allfreq=False
   if args.frequency[0] is not None:
      fmt='Error, illegal frequency specification: "'+len(args.frequency)*'%s '+'"'
      if args.frequency[0] == 'all':
         if len(args.frequency) > 1:
            mexit(rank,fmt % tuple(args.frequency))
         allfreq=True
      else:
         freq=[]
         for fff in args.frequency:
            if not fff.isdigit():
               mexit(rank,fmt % tuple(args.frequency))
            freq.append(int(fff))
         freq=numpy.array(sorted(freq),dtype=int)

   pdf=None
   fontsize=9
   matplotlib.rcParams.update({'font.size': fontsize})

   nfile=None

   if args.all:                                      # one or more files ?
      ffs=sorted(glob(os.path.join(args.dir,'*.h5')))
      if not ffs:                                    # nothing to process: fail early with a clear message
         mexit(rank,'Error, no HDF5 (*.h5) files found in %s' % args.dir)
   else:
      if os.path.isfile(args.dir):
         ffs=[args.dir]
      else:
         ffs=[None]
      if args.noise:                                 # ask for noise file ?
         nfile=True

   for i,xfile in enumerate(ffs):

      if rank == 0:
         hdfnams=NEPfiles(xfile=xfile,nfile=nfile,fpath=args.dir,diag=args.diag,names=True)
         if args.pdf:
            pdfnam='HDF_NEP_%s.pdf' % os.path.basename(hdfnams[0]).split('.')[0]
         nhdfs=replicate(nproc,hdfnams)
      else:
         nhdfs=None
      xhdfnam,nhdfnam=comm.scatter(nhdfs,root=0)
      xhdf=HMUX(filename=xhdfnam)
      nhdf=HMUX(filename=nhdfnam)

      if channel is None:
         channel=xhdf.channels[0]

      if rank == 0:                                  # get frequencies data for individual file in main process
         conf_tab=xhdf.channel[channel].freq_conf
         cpixels=conf_tab['pixel_index']
         fpixels=xhdf.channel[channel].freqs
         apixels=numpy.intersect1d(cpixels,fpixels)  # only frequencies present both in the table and in the channel
         print("Available pixels(frequencies): ",apixels)
         if allfreq:
            freq=apixels
         if freq is None:
            freq=numpy.array([apixels[0]],dtype=int)
         else:
            freq=numpy.intersect1d(freq,apixels)
         # row of each selected frequency in the pixel configuration table
         rowof={int(pp): l for l,pp in enumerate(cpixels)}
         pxindx=numpy.array([rowof.get(int(ff),0) for ff in freq],dtype=int)
         if i == 0:
            neps=numpy.zeros((len(ffs),len(freq)),dtype=float)
            ampl=numpy.zeros((len(ffs),len(freq)),dtype=float)
            hz=conf_tab['freq'][pxindx]              # table frequencies of the selected pixels, in plot order

         npix=len(freq)

         # spread [plot index, pixel] pairs round-robin over the processes
         pdata=[[] for _ in range(nproc)]
         for iplt,fff in enumerate(sorted(freq)):
            pdata[iplt % nproc].append([iplt,fff])
      else:
         pdata=None
      freqs=comm.scatter(pdata,root=0)               # list of applicable frequencies for this process

      print(("process: %s  freqs:" % unq),freqs)

      cneps={}                                       # storage for data from this process, keyed on plot index
      pdfs={}
      fffs={}
      slines={}
      for k,pix in freqs:                            # process frequencies for this process
         pdffreq="freq_%3.3d_HDF_NEP_%s.pdf" % (pix,os.path.basename(xhdfnam).split('.')[0])
         pdf=PdfPages(pdffreq) if args.pdf else None
         pdfs[k]=pdffreq
         fffs[k]=pix
         try:
            cneps[k],slines[k]=NEPs(xrays=xhdf,noise=nhdf,diag=args.diag,cutfreq=args.cutoff,
                                    xexclude=args.xexclude,ninclude=args.ninclude,pdf=pdf,
                                    channel=channel,freqs=[pix],logfreq=args.logfreq,sumret=args.summary,
                                    pulsepos=args.ignoreposition,rot=args.rotate,ascii=args.ascii,
                                    dtype=args.dtype,prlen=prlen,prclip=args.prclip)
         finally:
            if pdf is not None:
               pdf.close()
      sdata={'cneps':cneps,'pdfs':pdfs,'fffs':fffs,'slines':slines} if fffs else None

      gdata=comm.gather(sdata,root=0)                # get data from all processes

      if rank == 0:                                  # and process in main process

         freq=numpy.zeros(npix,dtype=int)
         pdfs=numpy.empty(npix,dtype=object)
         slin=numpy.empty(npix,dtype=object)
         for data in gdata:
            if data is not None:
               for k in data['fffs']:
                  neps[i,k]=data['cneps'][k]
                  freq[k]=data['fffs'][k]
                  pdfs[k]=data['pdfs'][k]
                  slin[k]=data['slines'][k]

         if args.pdf:                                # combine all pdf's
            parts=[str(pdfs[k]) for k in range(npix)]
            try:
               # list form: no shell involved, file names with spaces are safe
               status=subprocess.call(['pdfunite']+parts+[pdfnam])
            except OSError:                          # e.g. pdfunite not installed
               status=-1
            if status == 0:
               for pdff in parts:
                  if os.path.exists(pdff):
                     os.unlink(pdff)
            else:
               print("Error executing 'pdfunite' to combine pdf's ")

         if args.summary:                            # write summary file
            sname='NEP_'+os.path.basename(xhdfnam).split('.')[0]+'.txt'
            with open(sname,'w') as sfile:
               sfile.write('# xray_file: %s\n' % (os.path.basename(xhdfnam).split('.')[0]))
               sfile.write('#noise_file: %s\n#\n' % (os.path.basename(nhdfnam).split('.')[0]))
               sfile.write('# i  px  freq(khz)     V_bias       gbwp   NEP(fwhm) (eV)  risetime(ms) falltime(ms)\n#\n')
               for k in range(npix):
                  sfile.write(slin[k])

         ampl[i,:]=conf_tab['ampl'][pxindx]
         fff="%s: "+len(neps[i,:])*" %5.2f"
         ppp=tuple([os.path.basename(xhdfnam)]+list(neps[i,:]))
         print(fff % ppp)

   if args.all and rank == 0:                        # combine data from all files in main process

      if args.pdf:
         pdf=PdfPages('HDF_NEPoverview_%s.pdf' % os.path.basename(os.path.normpath(args.dir)))
      else:
         pdf=None

      nplt=len(freq)
      ncol=2 if nplt > 2 else 1
      nrow=(nplt+ncol-1)//ncol
      grid=(nrow,ncol)
      fig=plt.figure()
      plots={}
      for ip in range(nplt):
         plots[ip]=plt.subplot2grid(grid,((ip//ncol),(ip % ncol)))
         indx,=numpy.where(neps[:,ip] > 0.0)
         plots[ip].plot(ampl[indx,ip],neps[indx,ip],'bo')
         plots[ip].set_xlabel("Ampl.")
         plots[ip].set_ylabel("NEP (eV)")
         plots[ip].set_title('%d: %8.2f (kHz)' % (freq[ip],hz[ip]))
      if args.pdf:
         fig.set_size_inches(8.27, 11.69)            # A4 portrait
         plt.tight_layout()
         pdf.savefig(fig)
         pdf.close()
      else:
         plt.tight_layout()
         plt.show()
      plt.close('all')

   print("End of process: %s" % unq)
   sys.exit()



  
