#!/usr/bin/env python3
#
'''
.. module:: HDF_average
   :synopsis: Average pulse or noise data records and plot the averaged raw data and its spectrum
.. moduleauthor:: Cor de Vries <c.p.de.vries@sron.nl>

'''

import matplotlib
if __name__ == '__main__':
   matplotlib.use('GTK3Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from tesfdmtools.utils.widgets.printbutton import printbutton

import numpy

def doublepeak(record,trat=0.5,dist=100):
   '''
   Determine if pulse record contains double peak or not

   The baseline is estimated as the mean of the first 10 samples. A threshold
   is placed a fraction `trat` of the way from the baseline down to the record
   minimum, and the indices of all samples below that threshold are collected
   (so pulses are assumed to be downward going). The record is flagged as a
   double peak when two consecutive below-threshold samples are at least
   `dist` samples apart.

   Args:
      * `record` = array with the record pulse data

   Kwargs:
      * `trat` = threshold in terms of fraction of the peak value
      * `dist` = minimum distance for double peak at threshold value

   Returns:
      * True or False whether it is a double peak. A record with fewer than
        two samples below the threshold (e.g. a flat record) gives False.
   '''
   record=numpy.asarray(record)
   pmin=numpy.min(record)
   pbas=numpy.mean(record[0:10])
   thresh=pbas-trat*(pbas-pmin)
   indx,=numpy.where(record < thresh)
   if indx.size < 2:
      # no gap between below-threshold samples can be measured
      # (the max() of an empty difference array would raise ValueError)
      return False
   dd=numpy.diff(indx)                  # gaps between successive below-threshold samples
   return bool(dd.max() >= dist)

def avplot(pltx,x,y1,y2,xtitle='',y1title='',y2title='',title='',log=False,noq=False):
   '''
   Draw averaged I (and optionally Q) data on one set of plot axes

   The I-data is drawn as a blue line on `pltx`. Unless `noq` is set, the
   Q-data is drawn as a red line against a second Y-axis that shares the
   X-axis (created with ``twinx``); that axis label and its tick labels are
   coloured red.

   Args:
      * `pltx` = plot axes instance to draw on
      *   `x`  = X-axis values for both curves
      *  `y1`  = data for the I curve (left Y-axis)
      *  `y2`  = data for the Q curve (right Y-axis)

   Kwargs:
      *  `xtitle` = title for the X-axis
      * `y1title` = title for the left (I) Y-axis
      * `y2title` = title for the right (Q) Y-axis
      *   `title` = general title for the plot
      *     `log` = if True, log scale on the X-axis and on the Y-axis/axes
      *     `noq` = if True, do not plot the Q-signal (`y2` is then unused)
   '''

   # I-signal on the primary (left) axis
   pltx.plot(x,y1,'b-')
   pltx.set_ylabel(y1title)
   pltx.set_xlabel(xtitle)
   pltx.set_title(title)

   qaxis=None if noq else pltx.twinx()
   if qaxis is not None:
      qaxis.plot(x,y2,'r-')
      qaxis.set_ylabel(y2title,color='r')
      for ticklabel in qaxis.get_yticklabels():
         ticklabel.set_color('r')
      # stack the I-axes above the Q-axes; the I background is made
      # transparent so the Q curve underneath stays visible
      pltx.set_zorder(2)
      pltx.patch.set_visible(False)
      qaxis.set_zorder(1)
      qaxis.patch.set_visible(True)

   if log:
      pltx.set_xscale('log')
      yaxes=[pltx] if qaxis is None else [pltx,qaxis]
      for axis in yaxes:
         axis.set_yscale('log')
 
#===================================================================================

if __name__ == "__main__":

   import sys,os
   import argparse

   from tesfdmtools.utils.cu import cu
   from tesfdmtools.utils.widgets.filechooser import getfile
   from tesfdmtools.utils.tesfdm_defaults import Tesfdm_Defaults
   from tesfdmtools.hdf.HMUX import HMUX

   def _new_page():
      '''Start a new current figure, sized for one A4 page of the PDF output.'''
      fpage=plt.gcf()
      fpage.set_size_inches(8.27, 10.75)

   def _save_page(pdf,title):
      '''Put `title` on the current figure, append it as one page to `pdf` and close it.'''
      plt.suptitle(title,size=12)
      plt.tight_layout()
      plt.subplots_adjust(top=0.94)
      plt.savefig(pdf,format='pdf',papertype='a4')
      plt.close('all')

   parser = argparse.ArgumentParser(\
                      description="Show averaged pulse or noise data record or spectrum")
   parser.add_argument('-ch','--channel',type=int,required=False,default=None,\
                       help='Channel number to use. Default is first available channel.')
   parser.add_argument('-f','--frequency',type=str,nargs='+',required=False,default=[None],\
                       help="Frequency numbers (pixels) to use. Default is first available frequency. 'all' for all frequencies ")
   parser.add_argument('-fr','--frange',nargs=2,type=float,required=False,default=[None,None],\
                       help="Range for frequencies to plot in spectrum")
   parser.add_argument('--nfrac',type=float,required=False,default=0.25,\
                       help='Fraction of lowest intensity data to take for noise spectra')
   parser.add_argument('--pthresh',type=int,required=False,default=500,\
                       help='Threshold for detecting pulses')
   parser.add_argument('--noq',action='store_true',required=False,\
                       help='Do not plot Q-signals') 
   parser.add_argument('-v','--volts',action='store_true',required=False,\
                       help='Convert output values to Volts')
   parser.add_argument('-a','--ampphase',action='store_true',required=False,\
                       help='Output amplitude and phase for average pulse')
   parser.add_argument('-o','--output',action='store_true',required=False,\
                       help='Write ASCII output files of results') 
   parser.add_argument('--file',type=str,required=False,default=None,\
                       help='HDF5 input file to use. When not provided, user can select the file interactively')

   args=parser.parse_args()
   defs=Tesfdm_Defaults()

   if args.file is None:
      infile=getfile(pattern='*.h5',path=defs.get_filepath('h5')) # select input file
   else:
      infile=args.file
   defs.set_filepath(infile)

   # x-ray (pulse) file or noise file?  A substring test also accepts names
   # that start with 'xray' (str.find() returns 0 there, which a "> 0" test
   # would treat as "not found").
   xray=( 'xray' in os.path.basename(infile) )

   hdf=HMUX(filename=infile)
   fnam=os.path.basename(infile).rsplit('.',1)[0]
   print("\nFile: ",fnam)
   pdffile='Av_'+fnam+'.pdf'                     # pdf output plot file

   if args.channel is None:
      chan=int(hdf.channels[0])  
   else:
      chan=args.channel

   conf_tab=hdf.channel[chan].freq_conf
   cpixels=conf_tab['pixel_index']
   fpixels=hdf.channel[chan].freqs
   apixels=numpy.intersect1d(cpixels,fpixels)   # take only frequencies which occur both in the table as in the channel
   print("Available pixels(frequencies): ",apixels)
   if apixels.size == 0:
      sys.exit("pixel configuration table and available frequencies do not match")

   if args.frequency[0] is None:
      frequencies=numpy.array([apixels[0]],dtype=int)
   elif args.frequency[0] == 'all':
      frequencies=numpy.array(sorted(apixels),dtype=int)
   else:
      fmt='Error, illegal frequency specification: "'+len(args.frequency)*'%s '+'"'
      selected=[]
      for fff in args.frequency:
         if not fff.isdigit():
            print(fmt % tuple(args.frequency))
            continue
         frr=int(fff)
         if frr not in apixels:
            # skip it: reading the records of a missing pixel would fail later on
            print("Selected frequency (pixel) %d does not exist in file" % frr)
            continue
         selected.append(frr)
      frequencies=numpy.array(sorted(selected),dtype=int)
   if frequencies.size == 0:
      sys.exit("No valid frequencies (pixels) selected")

   nevents=int(hdf.channel[chan].attrs['Nevents'])
   print("Number of records: ",nevents)

   # the data type is a channel attribute, hence the same for every pixel
   datatype=cu(hdf.channel[chan].attrs['data_type'])
   print("     datatype is: '%s'" % datatype)
   ampphas=( datatype == 'ampl-phase' )

   below=( frequencies.size == 1 )               # single pixel: data and spectrum above each other
   nrows=min([4,frequencies.size])               # at most 4 pixels (rows) per page
   ppage=False                                   # True when the current page holds unsaved plots
   nplotted=0                                    # number of pixels plotted so far (sets the row slot)

   with PdfPages(pdffile) as pdf:
      _new_page()

      for freq in frequencies:                   # go through all frequencies

         print("Processing pixel (frequency): ",freq," .......")

         # ---- select the records to average
         if not xray:                            # noise file: lowest 'nfrac' of records
            ptps=numpy.zeros(nevents,dtype=float)
            for i,event in enumerate(hdf.channel[chan].freq[freq]):
               ptps[i]=numpy.ptp(event[:,0])
            indx=numpy.argsort(ptps)
            urecs=max([1,int(args.nfrac*float(nevents))])
            valid=numpy.sort(indx[0:urecs])
         else:                                   # x-ray file: single pulses only
            valid=[]
            for j,event in enumerate(hdf.channel[chan].freq[freq]):
               if numpy.ptp(event[:,0]) > args.pthresh and not doublepeak(event[:,0]):
                  valid.append(j)
            valid=numpy.array(valid,dtype=int)

         if valid.size == 0:
            # nothing to average; without this check the averages (and sample
            # rate) would be undefined or left over from the previous pixel
            print("     no usable records for pixel (frequency) %d; skipped" % freq)
            continue

         # ---- page layout: up to 4 pixels per page, one row per pixel
         jp=nplotted % 4
         if ppage and (jp == 0):
            _save_page(pdf,fnam)
            ppage=False
            _new_page()
         if below:
            plt1=plt.subplot2grid((2,1),(0,0))
            plt2=plt.subplot2grid((2,1),(1,0))
         else:
            plt1=plt.subplot2grid((nrows,2),(jp,0))
            plt2=plt.subplot2grid((nrows,2),(jp,1))

         ptt='pixel(freq)= %d' % freq

         # ---- sum the selected records
         for j,i in enumerate(valid):
            event=hdf.channel[chan].freq[freq][i]
            if j == 0:                           # initialize from the first record
               avIdata=numpy.array(event[:,0],dtype=float)
               avQdata=numpy.array(event[:,1],dtype=float)
               srate=int(event.attrs['sample_rate'])
               vscale=float(event.attrs['scale_volt_DEMUX'])
               if ampphas:
                  pscale=float(event.attrs['scale_phase_rad'])
               else:
                  pscale=vscale
               if nplotted == 0:
                  print("Sample rate: ",srate)
            else:                                # add remaining records
               avIdata+=event[:,0]
               avQdata+=event[:,1]

         if args.volts:
            yunit=' (V)'
         else:
            vscale=1.0
            # 'ampl-phase' phase data keep their conversion to radians: the
            # phase is passed to cos()/sin() below and labelled 'Phase (rad)'
            if not ampphas:
               pscale=1.0
            yunit=''

         avIdata=avIdata/float(valid.size)*vscale          # average
         avQdata=avQdata/float(valid.size)*pscale
         x=numpy.arange(avIdata.size,dtype=float)

         if args.ampphase:
            if ampphas:
               amplitude=avIdata
               phase=avQdata
               avIdata=amplitude*numpy.cos(phase)
               avQdata=amplitude*numpy.sin(phase)
            else:
               amplitude=numpy.sqrt(avIdata**2+avQdata**2)
               # phase of I + jQ: arctan2 takes (y, x) = (Q, I), consistent
               # with I = A*cos(phase), Q = A*sin(phase) used above
               phase=numpy.arctan2(avQdata,avIdata)
            avplot(plt1,x,amplitude,phase,xtitle='sample#',y1title='Amplitude'+yunit,\
                   y2title='Phase (rad)',title=ptt,noq=args.noq)  # plot result
         else:
            if ampphas:
               # convert averaged amplitude/phase to I/Q
               avIdata,avQdata=avIdata*numpy.cos(avQdata),avIdata*numpy.sin(avQdata)
            avplot(plt1,x,avIdata,avQdata,xtitle='sample#',y1title='average I-data'+yunit,\
                   y2title='average Q-data'+yunit,title=ptt,noq=args.noq)  # plot result

         # ---- amplitude spectrum of the averaged data
         Ifft=numpy.absolute(numpy.fft.fft(avIdata))
         Qfft=numpy.absolute(numpy.fft.fft(avQdata))
         fax=x/float(x.size)*srate
         n2=Ifft.size//2
         # the spectrum of real data is symmetric: keep only the positive-
         # frequency half (without DC, which a log frequency axis cannot show),
         # then restrict to the requested range, if any
         fax=fax[1:n2]
         Ifft=Ifft[1:n2]
         Qfft=Qfft[1:n2]
         if ( args.frange[0] is not None ) and ( args.frange[1] is not None ):
            findx,=numpy.where( ( fax > args.frange[0] ) & ( fax < args.frange[1] ) )
            fax=fax[findx]
            Ifft=Ifft[findx]
            Qfft=Qfft[findx]

         avplot(plt2,fax,Ifft,Qfft,xtitle='frequency (Hz)',y1title='average I-data'+yunit,\
                y2title='average Q-data'+yunit,log=True,title=ptt,noq=args.noq)      # plot spectrum result

         ppage=True
         nplotted+=1

         if args.output:                         # write ASCII output files

            oid='_ch%2.2d_px%2.2d' % (chan,freq)
            if args.ampphase:
               iqfile='AverageAP_'+fnam+oid+'.txt'     # amplitude/phase data output file
            else:
               iqfile='AverageIQ_'+fnam+oid+'.txt'     # IQ data output file
            spfile='AverageSp_'+fnam+oid+'.txt'        # Spectrum data output file
            cmd='#cmd:'
            for sw in sys.argv:                  # command line up to (not including) '--file'
               if sw == '--file':
                  break
               cmd=cmd+' '+sw
            fil='#file: '+infile

            stime=x/float(srate)*1000.0          # time axis (msec)
            units=['msec','ADC','radians']
            if args.volts:
               units[1]='Volts'
            if not args.ampphase:
               units[2]=units[1]
            if args.ampphase:
               colhdr=('time','Ampl','Phase')
               col1,col2=amplitude,phase
            else:
               colhdr=('time','I','Q')
               col1,col2=avIdata,avQdata

            with open(iqfile,'w') as ofile:      # write IQ (or amplitude/phase) data
               ofile.write(cmd+'\n')
               ofile.write(fil+'\n')
               ofile.write('#Data unit: 0:[%s] 1:[%s] 2:[%s]\n' % tuple(units))
               ofile.write("#%9.9s%20.15s%20.15s\n" % colhdr)
               for xx,v1,v2 in zip(stime,col1,col2):
                  ofile.write("%10.4f%20.6e%20.6e\n" % (xx,v1,v2))

            with open(spfile,'w') as ofile:      # write spectral data
               ofile.write(cmd+'\n')
               ofile.write(fil+'\n')
               if args.volts:
                  ofile.write('#Data unit: Volts\n')
               ofile.write("#%19.15s%20.15s%20.15s\n" % ('f (Hz)','Ifft','Qfft'))
               for xx,v1,v2 in zip(fax,Ifft,Qfft):
                  ofile.write("%20.6e%20.6e%20.6e\n" % (xx,v1,v2))

      if ppage:                                  # write the last (partial) page
         _save_page(pdf,fnam)


   

   
      


   
