#!/usr/bin/env python3
#
'''
.. module:: Crosstalkmatrix
   :synopsis: Plot crosstalk matrix based on the HDF_Crosstalk output files
.. moduleauthor:: Cor de Vries <c.p.de.vries@sron.nl>


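Use:

   Crosstalkmatrix [-h] [-ap] [-f FILE]

This script plots the crosstalk data as a matrix of profiles, one panel per pixel pair,
and writes the result to ``Crossmat_<name>.pdf``.

Unless an input file is given with `-f`, the script asks for a sample **HDF_Crosstalk** output file
which contains the crosstalk profiles.
The **HDF_Crosstalk** output files are produced when the parameter `--txtout` is provided.
Files for the other pixels of the same measurement (same name prefix before ``_px_``) that are present
in the same directory are processed as well.
With `-ap`, amplitude and phase are plotted instead of I and Q.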

'''
import os,sys
import numpy
from glob import glob

from tesfdmtools.utils.widgets.filechooser import getfile

def getnrec(txtfile):
   '''
   Read header information from HDF_Crosstalk files

   Lines that start with a digit are data lines; they are counted to obtain
   the number of samples. Lines that start with ``#`` are comment lines. A
   comment line containing ``(kHz)`` describes one pixel: its third
   whitespace-separated field is the pixel number and its fifth field is the
   frequency. Blank lines are ignored.

   Args:
     * `txtfile` = name of input HDF_Crosstalk text file

   Returns:
     * `cnt` = number of samples in the crosstalk profiles
     * `khz` = dictionary of frequencies (kHz) keyed by pixel number
     * `rot` = dictionary keyed by pixel number; every value is 0.0
     * `npix` = number of pixels (number of ``(kHz)`` header lines)
   '''

   cnt=0
   khz={}
   rot={}
   npix=0
   with open(txtfile,'r') as ff:          # file is closed even if parsing fails
      for line in ff:
         ll=line.strip(' \n')
         if not ll:                        # blank line: nothing to count or parse
            continue
         if ll[0] != '#':
            if ll[0:1].isdigit():
               cnt=cnt+1
         elif ll.find('(kHz)') > 0:
            tt=ll.split()
            ipix=int(tt[2])
            khz[ipix]=float(tt[4])
            # the value in field tt[7] is not read; rotation is fixed at 0.0
            rot[ipix]=0.0
            npix=npix+1
   return (cnt,khz,rot,npix)

def getcdata(txtfile,npix,nrec):
   '''
   Get the crosstalk profiles data

   The file is parsed as follows (blank lines are ignored):

   * a comment line containing `` pixel =`` gives the pixel number of this
     file in its fourth whitespace-separated field;
   * a comment line containing `` i `` is the column header; its fields
     4, 6, 8, ... give the pixel number of each of the `npix` column pairs;
   * every other non-comment line is one sample. Columns ``2*k+1`` and
     ``2*k+2`` hold the two components of the k-th pixel listed in the header.

   Args:
      * `txtfile` = name of the crosstalk profile file
      *    `npix` = number of pixels in the file
      *    `nrec` = number of samples for the profiles

   Returns:
      * `ipx` = pixel number from the `` pixel =`` line (None if there is none)
      * `dd`  = float array of shape (npix, 2, nrec). Axis 0 is ordered by
        ascending pixel number of the header columns; axis 1 holds the two
        components of each column pair.

   Raises:
      * `ValueError` if a data line occurs before the column header line
   '''

   dd=numpy.zeros((npix,2,nrec),dtype=float)

   pxs=numpy.arange(npix)

   irec=0
   ipx=None
   pxpos=None      # pixel number of each column pair, in file order
   pxindx=None     # pixel number -> row of dd (rank in ascending pixel order)
   with open(txtfile,'r') as ff:          # file is closed even if parsing fails
      for line in ff:
         ll=line.strip(' \n')
         if not ll:                        # skip blank lines
            continue
         if ll[0] != '#':
            if pxindx is None:
               raise ValueError("data line found before column header in %s" % txtfile)
            nums=numpy.array(ll.split(),dtype=float)
            for ii in pxs:
               row=pxindx[pxpos[ii]]
               dd[row,0,irec]=nums[2*ii+1]
               dd[row,1,irec]=nums[2*ii+2]
            irec=irec+1
         elif ll.find(" pixel =") > 0:
            pxtxt=ll.split()
            ipx=int(pxtxt[3])
         elif ll.find(" i ") > 0:
            off=3                          # first pixel number is the 4th field
            pxtxt=ll.split()
            pxpos=numpy.array([int(pxtxt[off+ii*2]) for ii in pxs],dtype=int)
            print("pixel ",ipx," columns: ",pxpos)
            pxindx=numpy.zeros((pxpos.max()+1),dtype=int)
            for k,px in enumerate(sorted(pxpos)):
               pxindx[px]=k

   return ipx,dd


# ====================================================================================

if __name__ == "__main__":

   import argparse

   parser = argparse.ArgumentParser(\
                      description="Plot the crosstalk data as output from script HDF_Crosstalk in a matrix")
   parser.add_argument('-ap','--ampphase',action='store_true',required=False,\
                       help='plot amplitude/phase in place of I and Q')
   parser.add_argument('-f','--file',type=str,required=False,default=None,\
                       help='Crosstalk data input file. This file is made by HDF_Crosstalk when the --txtout parameter is set')

   args=parser.parse_args()

   # plotting libraries are only needed when run as a script
   import matplotlib
   matplotlib.use('GTK3Agg')
   import matplotlib.pyplot as plt
   from matplotlib.backends.backend_pdf import PdfPages

   # --- select the input file ---------------------------------------------------
   if args.file is None:
      tfile=getfile(path='.',pattern='Crosstalk_*.txt')   # get an input crosstalk profile file
      # NOTE(review): assumes getfile returns None/empty when nothing is chosen
      if not tfile:
         sys.exit("Error, no input crosstalk file selected")
   else:
      if not os.path.isfile(args.file):
         sys.exit("Error, %s is not a regular file" % (os.path.basename(args.file)) )
      tfile=args.file

   # --- find the profile files of all pixels sharing the same name prefix -------
   ff=tfile.split("_px_")[0]
   pfiles=sorted(glob(ff+"_px_*.txt"))
   if not pfiles:
      sys.exit("Error, no files matching %s_px_*.txt found" % os.path.basename(ff))

   ptitle=os.path.basename(ff)

   print("data are in files: ")
   for pf in pfiles:
      print(os.path.basename(pf))

   nfil=len(pfiles)
   print("number of files: ",nfil)
   nrec,khz,rot,npix=getnrec(pfiles[0])     # header of the first file is used for all
   print("number of pixels: ",npix)
   print("number of samples: ",nrec)

   ispx=numpy.array(sorted(khz.keys()),dtype=int)   # pixel numbers, ascending
   print("List of available pixels: ",ispx)

   # data[j,i,c,:]: file of pixel ispx[j], column pixel of rank i, component c
   # (0 = I, 1 = Q). Assumes every file lists the same pixels as the first
   # file's header -- TODO confirm.
   data=numpy.zeros((npix,npix,2,nrec),dtype=float)

   for pf in pfiles:                              # get the actual crosstalk profiles
      ipix,ddd=getcdata(pf,npix,nrec)
      pindx,=numpy.where( ispx == ipix )
      if pindx.size == 1:
         data[pindx[0],:,:,:]=ddd
      else:
         # %s because ipix is None when the file has no " pixel =" header line
         sys.exit("Error, pixel %s for file %s does not exist in pixel table" % (ipix,os.path.basename(pf)))

   # remove baselines and scale the profiles such that a full pulse has 1.0 intensity sqrt(I^2+Q^2)
   # or show amplitude and phase

   pxs=numpy.arange(npix)
   scales=numpy.zeros(npix,dtype=float)
   ibaselines=numpy.zeros(npix,dtype=float)
   # baseline = mean of the first 10 and last 10 samples of a profile
   bindx=numpy.concatenate([numpy.arange(10),numpy.arange((nrec-10),nrec)])
   if args.ampphase:
      ilabel='norm. Amplitude'
      qlabel='phase (deg)'
      for jpix in pxs:
         for ipix in pxs:
            cmplx=data[jpix,ipix,0,:]+1j*data[jpix,ipix,1,:]
            data[jpix,ipix,0,:]=numpy.absolute(cmplx)
            data[jpix,ipix,1,:]=numpy.angle(cmplx,deg=True)
            if jpix == ipix:
               # scale so that the minimum of the pixel's own baseline-subtracted
               # amplitude becomes -1.0
               ibaselines[ipix]=numpy.mean(data[ipix,ipix,0,bindx])
               top=numpy.min(data[ipix,ipix,0,:])-ibaselines[ipix]
               scales[ipix]=-1.0/top
   else:
      ilabel='I norm. units'
      qlabel='Q norm. units'
      qbaselines=numpy.zeros(npix,dtype=float)
      for ipix in pxs:
         ibasl=numpy.mean(data[ipix,ipix,0,bindx])
         qbasl=numpy.mean(data[ipix,ipix,1,bindx])
         xtop=numpy.argmin(data[ipix,ipix,0,:])   # sample with the minimum I value
         top=numpy.sqrt((ibasl-data[ipix,ipix,0,xtop])**2+\
                        (qbasl-data[ipix,ipix,1,xtop])**2)
         scales[ipix]=1.0/top
         ibaselines[ipix]=ibasl
         qbaselines[ipix]=qbasl

   xax=numpy.arange(nrec)

   plt.figure()

   matplotlib.rcParams.update({'font.size': 4})

   # split the npix x npix matrix into 6x6 quadrants; each quadrant is one PDF page
   quadrant=numpy.zeros((npix,npix),dtype=int)
   qoffs={}
   for jpix in pxs:
      jq=jpix//6
      for ipix in pxs:
         iq=ipix//6
         qq=jq*100+iq
         quadrant[jpix,ipix]=qq
         qoffs[qq]=(jq*6,iq*6)                  # first row/column index of the quadrant

   qpix=min(npix,6)                             # subplot grid size per page

   with PdfPages('Crossmat_%s.pdf' % ptitle) as pdf:   # closed even if plotting fails
      for quad in qoffs:                        # plot the matrix
         for jpix in pxs:
            for ipix in pxs:
               if quadrant[jpix,ipix] != quad:
                  continue
               idata=(data[jpix,ipix,0,:]-ibaselines[ipix])*scales[ipix]
               if args.ampphase:
                  qdata=data[jpix,ipix,1,:]
               else:
                  qdata=(data[jpix,ipix,1,:]-qbaselines[ipix])*scales[ipix]
               row=jpix-qoffs[quad][0]
               col=ipix-qoffs[quad][1]
               ax=plt.subplot2grid((qpix, qpix),(row,col))
               if row == (qpix-1):
                  ax.set_xlabel('sample#/100')
               if col == 0:
                  ax.set_ylabel(ilabel)
               if ipix == jpix:
                  ax.set_title("%d: %7.2f (kHz) r: %5.3f" % (ispx[ipix],khz[ispx[ipix]],rot[ispx[ipix]]))
               else:
                  dkhz=khz[ispx[ipix]]-khz[ispx[jpix]]
                  ax.set_title(r'%d: %7.2f ($\Delta$kHz)' % (ispx[ipix],dkhz))
               ax.ticklabel_format(useOffset=False)
               ax.plot(xax/100.0,idata,'b-')
               # second y axis (red) for Q or phase
               qax=ax.twinx()
               qax.ticklabel_format(useOffset=False)
               if col == (qpix-1):
                  qax.set_ylabel(qlabel,color='r')
               for tl in qax.get_yticklabels():
                  tl.set_color('r')
               qax.plot(xax/100.0,qdata,'r-')
               # draw the I/amplitude axes above the twin axes
               ax.set_zorder(2)
               ax.patch.set_visible(False)
               qax.set_zorder(1)
               qax.patch.set_visible(True)

         plt.suptitle(ptitle,size=10)
         fpage = plt.gcf()
         fpage.set_size_inches(10.75,8.27)
         plt.tight_layout()
         plt.subplots_adjust(top=0.94)
         # papertype/orientation are PostScript-only savefig options; the PDF
         # page size follows the figure size set above
         pdf.savefig(fpage)
         plt.close('all')

  
