#!/usr/bin/env python3
#
'''
.. module:: PCA_analysis
   :synopsis: PCA analysis for selected eventlist file
.. moduleauthor:: Cor de Vries <c.p.de.vries@sron.nl>
'''

import sys,os
import numpy

import matplotlib
matplotlib.use('GTK3Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

#====================================================================

def saveplot(pdf):
   '''
   Write the current matplotlib figure to *pdf* and close all figures.

   The current figure is resized to 8.27 x 10.75 inch, laid out with
   ``tight_layout`` and the top of the subplot area is pulled down to 0.94
   before the figure is saved in PDF format, portrait orientation.

   :param pdf: destination handed to ``plt.savefig`` (this script passes an
               open ``PdfPages`` object, so every call adds one page)
   '''
   fpage = plt.gcf()
   fpage.set_size_inches(8.27,10.75)
   plt.tight_layout()
   plt.subplots_adjust(top=0.94)                        # head room above the top axes
   # 'papertype' is deliberately not passed: it only ever applied to PostScript
   # output and recent matplotlib rejects it as an unexpected keyword for PDF.
   plt.savefig(pdf,format='pdf',orientation='portrait')

   plt.close('all')                                     # release the figure(s) after saving

#====================================================================
if __name__ == "__main__":

   # Command line driver.
   #
   # Reads an eventlist, builds a matrix with one baseline-subtracted record per
   # row (time domain samples, or FFT amplitudes for the frequency domain), runs
   # the PCA on that matrix and writes the plots to 'PCA_analysis_<hdfid>.pdf'.
   # With -o the projections of the records on the first (max. 5) components
   # are also written to 'PCA_components_<hdfid>.lst'.

   import argparse
   import datetime

   from tesfdmtools.utils.widgets.filechooser import getfile
   from tesfdmtools.hdf.HMUX import HMUX
   from tesfdmtools.utils.reventlist import reventlist
   from tesfdmtools.methods.HDF_Eventutils import makespectrum
   from tesfdmtools.methods.spectrumfit import spectrumfit
   from tesfdmtools.utils.fiterrplot import fiterrplot
   from tesfdmtools.methods.PCA import PCA
   from tesfdmtools.utils.tesfdm_defaults import Tesfdm_Defaults

   parser = argparse.ArgumentParser(
                      description="Principal Component Analysis (PCA) of selected pixel event records")
   parser.add_argument('-t','--type',type=str,required=False,default='freq',choices=['freq','time'],
                       help='Analysis domain type: "freq"=frequency domain (default) or "time"=time domain')
   parser.add_argument('-f','--file',type=str,required=False,default=None,
                       help='input eventlist (.lst) file. (output file of HDF_Eventpar)')
   parser.add_argument('--bsec',type=float,required=False,default=0.05,
                       help='Section (start+end) of record to use for background (default 0.05)')
   parser.add_argument('-c','--comp',type=int,required=False,default=10,
                       help='Max number of PCA components to consider')
   parser.add_argument('-p','--nplot',type=int,required=False,default=4,
                       help='Max number of PCA components to plot')
   parser.add_argument('-o','--output',action='store_true',required=False,
                       help='Write (ASCII) output file of eigenvector projections for first 5 components')
   parser.add_argument('-n','--nrecs',type=int,required=False,default=9999,
                       help='Max number of data records to use')

   args=parser.parse_args()

   # The correlation and spectrum pages always use components 1 and 2.
   if args.comp < 2:
      parser.error("--comp must be at least 2")

   defs=Tesfdm_Defaults()                               # last locations for files

   blfrac=args.bsec                                     # baseline fraction on either side
   ncomponents=args.comp                                # max. number of PCA components
   rmax=args.nrecs                                      # for tests max number of record to process
   tdomain=(args.type == 'time')

   if args.file is None:
      elistname=getfile(path=defs.get_filepath('lst'),pattern='eventlist*.lst')
      if elistname == '':                               # nothing returned by the file chooser
         sys.exit()
      defs.set_filepath(elistname)
   else:
      elistname=args.file

   if not os.path.exists(elistname):
      # sys.exit() accepts a single argument, so the message is formatted first
      sys.exit("Error, file does not exist: %s" % elistname)

   indx,optfit,base,peak,hdfpars,rtimes,ftimes,fenergies=reventlist(elistname)
   nrecs=min([indx.size,rmax])                          # number of records actually used
   if nrecs < 1:
      sys.exit("Error, no records to process in: %s" % elistname)

   hdfnam=hdfpars[0]
   hdfid=os.path.basename(hdfnam).split('.')[0]         # HDF ident
   channel=hdfpars[1]
   freq=hdfpars[2]
   hdf=HMUX(filename=hdfnam)
   title=hdfid+('_ch%2.2d_px%2.2d' % (channel,freq))    # common plot title

   # the first selected record defines record length and sample rate
   rec0=hdf.channel[channel].freq[freq][indx[0]][:,0]
   samplerate=float(hdf.channel[channel].freq[freq][indx[0]].attrs['sample_rate'])
   recl=rec0.size
   xx=numpy.arange(recl)
   if tdomain:
      frecl=recl
      fax=xx                                            # record axis in sample number
      xlabel='Sample #'
   else:
      frecl=recl//2
      # NOTE(review): FFT bin k of a recl-sample record normally corresponds to
      # k*samplerate/recl; this axis divides by frecl (= recl//2) and therefore
      # runs from 0 to samplerate. Confirm whether that factor of two is intended.
      fax=xx[0:frecl]/float(frecl)*samplerate           # record axis in Hz
      xlabel='Frequency (Hz)'
   bl=int(recl*blfrac)
   # NOTE(review): when bl evaluates to 0, xx[-0:] selects the whole record, so the
   # baseline is then fitted over all samples instead of only the record edges.
   blxx=numpy.concatenate((xx[0:bl],xx[-bl:]))          # baseline index
   data=numpy.zeros((nrecs,frecl),dtype=float)          # data matrix

   for j,i in enumerate(indx[:nrecs]):                  # go through all used records
      idata=hdf.channel[channel].freq[freq][i][:,0]     # i-data
      bc=numpy.polyfit(blxx,idata[blxx],1)              # linear baseline fit
      idata=idata-bc[0]*xx-bc[1]                        # baseline subtracted
      if tdomain:
         data[j,:]=idata                                # time domain
      else:
         data[j,:]=numpy.absolute(numpy.fft.fft(idata))[0:frecl] # frequency domain amplitudes
      if ( j % 1000 ) == 0:
         print("Read record: ",i)

   print("Start PCA analysis of %d records: %s" % (nrecs,datetime.datetime.now()))
   pca = PCA(n_components = ncomponents)
   output = pca.fit_transform(data)                     # projections: one row per record
   print("End of PCA analysis: %s" % datetime.datetime.now() )

   evecs = pca.components_                              # eigen vectors
   explvar = pca.explained_variance_ratio_              # variance ratio

# output file

   if args.output:
      oname='PCA_components_'+hdfid+'.lst'
      # 'output' holds nrecs rows and one column per component; never index past
      # either (the eventlist can be longer than nrecs, --comp smaller than 5)
      ncols=min(5,output.shape[1])                      # number of components for output
      with open(oname,'w') as ofile:
         ofile.write("#File: %s\n" % hdfpars[0])
         ofile.write("#Channel: %s\n" % hdfpars[1])
         ofile.write("#Freq: %s\n" % hdfpars[2])
         head="#"+"".join("%10s%1d    " % ("comp",col) for col in range(ncols))
         ofile.write(head+"\n")
         for row in output[:,:ncols]:
            ofile.write("".join("%15.5e" % val for val in row)+'\n')

   # the context manager finalizes the PDF file, also when a plot step fails
   with PdfPages('PCA_analysis_'+hdfid+'.pdf') as pdf:

# plot weights of main eigenvectors

      fig, ax = plt.subplots(2)
      ax[0].set_title(title)
      ax[0].set_xlabel('Eigenvector number')
      ax[0].set_ylabel('Variance ratio')
      ax[0].plot((numpy.arange(explvar.size)+1),explvar,'bo-')
      ax[0].set_yscale('log')

# plot correlation between two main eigenvectors

      ax[1].plot(output[:,0],output[:,1],'b.')
      ax[1].set_xlabel('Eigenvector 1')
      ax[1].set_ylabel('Eigenvector 2')

      saveplot(pdf)

# plot main eigenvectors

      nplot=min(args.nplot,evecs.shape[0])              # cannot plot more vectors than computed
      if nplot > 0:
         # squeeze=False: always get an array of axes, also for a single panel
         fig, ax = plt.subplots(nplot, sharex=True, squeeze=False)
         ax=ax[:,0]
         for ip in range(nplot):
            ss = -1.0 if evecs[ip,1] < 0 else 1.0       # flip sign so that element 1 is not negative
            ax[ip].plot(fax[1:],ss*evecs[ip,1:])        # element 0 is left out of the plot
            ax[ip].set_ylabel('Vect. %d' % (ip+1))
            if not tdomain:
               ax[ip].set_xscale('log')
         ax[nplot-1].set_xlabel(xlabel)
         ax[0].set_title(title)
         saveplot(pdf)

# plot spectrum for first two (main) eigenvector components

      fig, ax = plt.subplots(2)
      for i in range(2):
         spectrum,bins=makespectrum(output[:,i],sbins=1000)
         ax[i].plot(bins,spectrum,'b-')
         ax[i].set_xlabel('bin# component %d' % ((i+1)))
         ax[i].set_ylabel('counts')
      saveplot(pdf)

# try to make spectrum

      spectrum,bins=makespectrum(output[:,0],sbins=1000)
      jm=numpy.argmax(spectrum)
      if jm > (spectrum.size//2):                       # highest bin in upper half: reverse the histogram
         spectrum=spectrum[::-1]
         bins=bins[::-1]
      pars,fitspectrum,energies,fitted,lratio=spectrumfit(spectrum,bins)
      # pars[10] and pars[11] are applied as offset and gain of a linear conversion
      eoutput=pars[10]+pars[11]*output[:,0]     # convert eigenvector-1 projections to energy
      aindx,=numpy.where( ( eoutput > 5.86 ) & ( eoutput < 5.92 ) )  # limit energy range
      fig, ax = plt.subplots(2)
      yerr=numpy.sqrt(fitted)
      ptxt='FWHM: %4.2f eV' % (pars[3]*1000.0)
      fiterrplot(energies,fitspectrum,yerr,fitted,xtitle='Energy(keV)',ytitle='Counts',
                 ptitle='',ptxt=ptxt,pltax=ax[0])
      ax[1].plot(eoutput[aindx],output[aindx,1],'b.')
      ax[1].set_xlabel('Eigenvector 1 (keV)')
      ax[1].set_ylabel('Eigenvector 2')

      saveplot(pdf)
   
   
   

