#!/usr/bin/env python3
'''
.. module:: TDoptfilter
   :synopsis: Optimum filter fit of TES pulses in the time domain  
.. moduleauthor:: Cor de Vries <c.p.de.vries@sron.nl>
'''
import numpy
import matplotlib
import matplotlib.pyplot as plt
from tesfdmtools.methods.HDF_Eventutils import getrisetime,getphase
from tesfdmtools.methods.IQrotate import IQphase,IQphrot
from tesfdmtools.methods.shiftarr import shiftarr
from tesfdmtools.utils.cu import cu
def _analysis_window(recsize, peak, prlen):
    '''
    Return the sample range (c1, c2) of the analysis window.

    With `prlen` None the whole record is used. Otherwise the window is
    `prlen` samples long and starts 10% of `prlen` before `peak`. If it would
    run past the end of the record it is moved back, and it is shortened when
    the record itself is shorter than `prlen`.
    '''
    if prlen is None:
        return 0, recsize
    prlen = int(prlen)
    c1 = max(0, int(peak - 0.10 * prlen))
    c2 = c1 + prlen
    if c2 > recsize:
        c2 = recsize
        c1 = max(0, c2 - prlen)
    return c1, c2


def _noise_weighted_template(apulse, noisspec):
    '''
    Return the template `apulse` divided by the squared noise in the Fourier domain.

    `noisspec` holds the noise for frequency bins 0 .. N/2-1 of an N-sample
    template (N even). It is mirrored onto the negative frequencies, and the
    Nyquist bin reuses the last supplied value, so `len(noisspec)` must be
    `apulse.size // 2`.
    '''
    afft = numpy.fft.fft(apulse)
    nnspec = numpy.concatenate((noisspec, [noisspec[-1]], noisspec[-1:0:-1]))
    # The DC noise value is replaced by |FFT(template)| at zero frequency. This
    # avoids dividing by a zero DC noise value, unless the template sums to zero.
    nnspec[0] = numpy.absolute(afft[0])
    return numpy.real(numpy.fft.ifft(afft / (nnspec**2)))


def _fit_pulse_heights(fpulses, apulse, weight, nppos):
    '''
    Fit a pulse height for every row of `fpulses`.

    Each pulse is correlated with the weighted template `weight`, rolled over
    `nppos` consecutive shifts centred on zero. The correlations are normalized
    by the correlation of `weight` with the template `apulse`. A parabola through
    the `nppos` values per pulse gives the extremum value, returned with
    inverted sign as the height, and the extremum position, in shift-index units
    0 .. nppos-1.

    Returns (heights, shifts) as float arrays with one entry per pulse.
    '''
    hna = nppos // 2
    weight = numpy.array(weight, dtype=float)   # local copy; the edges are zeroed below
    # Zero hna samples at both ends: rolling by at most hna then never wraps
    # non-zero template samples around the window.
    weight[:hna] = 0.0
    weight[weight.size - hna:] = 0.0            # explicit end: weight[-0:] would clear everything
    shifts = numpy.arange(nppos)
    weights = numpy.array([numpy.roll(weight, s - hna) for s in shifts])
    norm = numpy.sum(apulse * weight)           # template-on-template normalization
    # ees[k, j] = normalized correlation of pulse k with the template rolled by j - hna
    ees = fpulses.dot(weights.T) / norm
    # one quadratic per pulse; polyfit fits all columns of ees.T in one call
    pp = numpy.polyfit(shifts.astype(float), ees.T, 2)
    heights = -(pp[2] - (pp[1]**2 / (4.0 * pp[0])))    # value at the parabola vertex
    ishft = -pp[1] / (2.0 * pp[0])                     # vertex position
    return heights, ishft


def _plot_fit_diagnostics(ifit, ishft, apulse, bl, rtimes):
    '''
    Show the debug plots of optfilter.

    The plots are: fit shift versus height; a histogram of the heights; the
    heights versus event number; the template with its baseline sections; and,
    when `rtimes` is given, a histogram of the rise times.
    '''
    f, ax = plt.subplots(1)
    ax.plot(ifit, ishft, 'b.')
    ax.set_xlabel('optimal fit')
    ax.set_ylabel('opt. fit shift')
    ax.set_title('optimal fit shift')
    plt.show()
    plt.close('all')

    nplots = 3 if rtimes is None else 4
    f, ax = plt.subplots(nplots)
    nax = numpy.arange(ifit.size)
    ax[1].plot(nax, ifit, 'b.')
    ax[1].set_xlabel('event number')
    ax[1].set_ylabel('Pulse intensity')
    ax[0].set_title('Optimal filter fits')
    ax[0].hist(ifit, bins=500, histtype='stepfilled')
    ax[0].set_xlabel('Energy (arb. units)')
    ax[0].set_ylabel('Counts')
    axax = numpy.arange(apulse.size, dtype=int)
    ax[2].set_xlabel('record bin')
    ax[2].set_ylabel('TD filter')
    ax[2].plot(axax, apulse, 'b-')
    ax[2].plot(axax[:bl], apulse[:bl], 'r-', axax[-bl:], apulse[-bl:], 'r-')
    if rtimes is not None:
        ax[3].hist(rtimes * 1000, bins=512, align='mid')
        ax[3].set_xlabel("Rise time (ms)")
        ax[3].set_ylabel("N")
        ax[3].set_title("Pulse rise times")
    plt.show()
    plt.close('all')


def optfilter(hdf, channel, freq, indx, base1, base2, noisspec, debug=False,
              prlen=None, rotate=False, risetime=False, bsec=0.05,
              absolute=False, nppos=11, shiftpulse=False, flip=False,
              **kwargs):
    '''
    Perform an optimal filtering fit of individual pulses (time domain).

    Processing steps:

    1. Each selected record is converted to a single signal: the I channel
       (negated when `flip` is set or its first sample is negative),
       sqrt(I^2+Q^2) when `absolute` is set, or the IQphrot-rotated I channel
       when a rotation phase applies.
    2. The signal is optionally shifted, cut to the analysis window, and
       corrected by a linear fit through the first and last `bsec` fraction of
       the window.
    3. The mean of the corrected pulses is the template. It is weighted by the
       noise spectrum in the Fourier domain, and every pulse is correlated with
       that weighted template at `nppos` neighbouring shifts.
    4. A parabola through those values gives the pulse height.

    Args:
       *       `hdf` = HDF5 input file object
       *   `channel` = channel number being processed
       *      `freq` = frequency number (pixel) being processed
       *      `indx` = numpy array of indices of selected events, to be processed
       *     `base1` = baseline level at start of record (not used by this function)
       *     `base2` = baseline level at end of record (not used by this function)
       *  `noisspec` = noise spectrum for frequency bins 0 .. N/2-1 of the
                       N-sample analysis window (len(noisspec) must be N//2)
    Kwargs:
       *     `debug`  = show diagnostic plots [default: False]
       *     `prlen`  = number of samples to use, starting 10% of prlen before the
                        pulse minimum of the first record; clipped to the record
                        [default: whole record]
       *    `rotate`  = True: rotate pulses by the phase from getphase (minimum Q);
                        a float: use it directly as the rotation phase [default: False]
       *  `risetime`  = if True, compute rise and fall times of pulses [default: False]
       *      `bsec`  = fraction of the window at each end used for the linear
                        baseline fit [default: 0.05]
       *  `absolute`  = use sqrt(I^2+Q^2) signal [default: False]
       *     `nppos`  = number of template shifts used for the parabolic fit (>= 3)
       * `shiftpulse` = if True, shift every pulse so its minimum is at the same
                        sample as in the first record
       *      `flip`  = negate the I signal (records whose first I sample is
                        negative are always negated)
    Returns:
       *    `ifit` = fitted pulse height parameters
       *  `rtimes` = rise times from getrisetime divided by the sample rate
                     (seconds if getrisetime returns samples), None if not `risetime`
       *  `ftimes` = fall times, same units, None if not `risetime`
       * `avpulse` = average of the (flipped/rotated/shifted, windowed) records,
                     with a linear baseline through its first and last 100
                     samples removed
       * `avbline` = value of that removed baseline at the minimum of `avpulse`
    Raises:
       ValueError: if `nppos` < 3 (a parabola needs at least three points)
    '''
    if nppos < 3:
        raise ValueError("nppos must be at least 3 for the parabolic position fit")
    print("Optimal filter (time domain), channel=", channel, ' freq=', freq)

    aphase = 0.0
    if isinstance(rotate, bool):
        if rotate:
            aphase = getphase(hdf, channel, freq, indx, debug=debug)
    elif isinstance(rotate, (float, numpy.floating)):   # also accepts numpy.float32 etc.
        aphase = rotate

    dataset = hdf.channel[channel].freq[freq]
    # the first selected record fixes the record length and the reference pulse position
    sirecord = dataset[indx[0]][:, 0]
    if shiftpulse:
        print('Shift pulse maxima to identical position')
    pulsepos = numpy.argmin(sirecord)
    samplerate = float(cu(dataset.attrs['sample_rate']))  # get sample rate

    c1, c2 = _analysis_window(sirecord.size, pulsepos, prlen)
    # Keep at least one baseline sample per side: with bl == 0, iax[-0:] would
    # select the whole window and the "baseline" fit would run through the pulse.
    bl = max(1, int(bsec * (c2 - c1)))
    if prlen is not None:
        print("record size: ", sirecord.size, "   pulse peak at: ", pulsepos)
        print("      use record from samples ", c1, "  to: ", c2)
        print("   background outside samples ", (c1 + bl), "  to: ", (c2 - bl))
    # actual window length; shorter than the requested prlen if the record is too short
    prlen = c2 - c1

    xax = numpy.arange(prlen, dtype=float)                  # sample axis of the window
    iax = numpy.arange(prlen, dtype=int)
    bax = numpy.concatenate((iax[:bl], iax[prlen - bl:]))   # baseline sections at both ends
    fbax = bax.astype(float)
    fpulses = numpy.zeros((indx.size, prlen), dtype=float)  # baseline-corrected pulses
    avpulse = numpy.zeros(prlen, dtype=float)               # sum of uncorrected pulses
    if risetime:
        rtimes = numpy.empty(indx.size, dtype=float)
        ftimes = numpy.empty(indx.size, dtype=float)
    else:
        rtimes = None
        ftimes = None

    for i, irec in enumerate(indx):
        record = dataset[irec]
        if absolute:
            irecord = numpy.sqrt(record[:, 0].astype(float)**2 + record[:, 1].astype(float)**2)
        elif flip or (record[0, 0] < 0):                    # wrong rotation, rotate by pi
            irecord = -record[:, 0]
        else:
            irecord = record[:, 0]
        if aphase != 0.0:
            # a rotation phase overrides the absolute/flip selection above
            irecord, _ = IQphrot(record[:, 0], record[:, 1], aphase)

        if shiftpulse:
            irecord = shiftarr(irecord, (pulsepos - irecord.argmin()))

        irecord = irecord[c1:c2]
        pp = numpy.polyfit(fbax, irecord[bax].astype(float), 1)
        cirecord = irecord - (pp[0] * xax + pp[1])          # subtract linear baseline
        avpulse += irecord                                   # accumulate uncorrected pulse
        if risetime:
            rt, ft = getrisetime(xax, cirecord, debug=debug)
            rtimes[i] = rt / samplerate
            ftimes[i] = ft / samplerate
        fpulses[i, :] = cirecord
        if (i % 1000) == 0:
            print("process record: ", irec)
    print("number of pulses processed: ", indx.size)

    # average raw pulse with a linear baseline through its first/last samples removed
    avpulse /= float(indx.size)
    nb = min(100, avpulse.size // 2)    # 100 samples per side, fewer for short windows
    bindx = numpy.concatenate((numpy.arange(nb, dtype=int),
                               numpy.arange(avpulse.size - nb, avpulse.size, dtype=int)))
    abfit = numpy.polyfit(xax[bindx], avpulse[bindx], 1)
    avpulse = avpulse - (abfit[0] * xax + abfit[1])
    minav = numpy.argmin(avpulse)
    avbline = abfit[0] * minav + abfit[1]                   # baseline level at pulse minimum

    # template: mean corrected pulse, baseline removed again and sign inverted
    apulse = fpulses.sum(axis=0) / float(indx.size)
    axax = numpy.arange(apulse.size, dtype=int)
    pp = numpy.polyfit(fbax, apulse[bax], 1)
    apulse = -(apulse - (pp[0] * axax + pp[1]))

    weight = _noise_weighted_template(apulse, noisspec)
    if debug:
        f, ax = plt.subplots(1)
        ax.plot(axax, weight, 'b-')
        ax.set_xlabel('record bin')
        ax.set_ylabel('template weight')
        ax.set_title('noise-filtered template')
        plt.show()
        plt.close('all')

    ifit, ishft = _fit_pulse_heights(fpulses, apulse, weight, nppos)
    if debug:
        _plot_fit_diagnostics(ifit, ishft, apulse, bl, rtimes)
    return (ifit, rtimes, ftimes, avpulse, avbline)