#! /usr/bin/env python3
"""
# filter is part of pyGMTSAR. 
# It is migrated from filter.csh by Dunyu Liu since 20230424.
# Originally it was written by Xiaopeng Tong and David Sandwell on Feb 2010.
# Matt Wei added ENVISAT on May 4 2010.
# DTS added phase gradient on May 26 2010.
# EF, DTS, XT - Jan 10 2014, TSX.

# This script convolves real.grd and imag.grd with Gaussian filters
# and forms the amplitude, phase, phase gradient, and correlation images.
"""

import sys, os, re, configparser
import subprocess, glob, shutil
from gmtsar_lib import * 

def _print_usage():
    """Print the command-line usage text of the filter script."""
    print( " ")
    print( "Usage: filter master.PRM aligned.PRM filter decimation [rng_dec azi_dec] [compute_phase_gradient]")  
    print( " ")
    print( " Apply gaussian filter to amplitude and phase images.")
    print( " ")
    print( " filter -  wavelength of the filter in meters (0.5 gain)")
    print( " decimation - (1) better resolution, (2) smaller files")
    print( " ")
    print( "Example: filter IMG-HH-ALPSRP055750660-H1.0__A.PRM IMG-HH-ALPSRP049040660-H1.0__A.PRM 300  2")
    print( " ")


def _conv(idec, jdec, kernel, src, dst):
    """Run the external ``conv`` command as ``conv idec jdec kernel src dst``.

    All five values are converted to text and joined with single spaces
    before being handed to ``run``.
    """
    run(' '.join(['conv', str(idec), str(jdec), str(kernel), src, dst]))


def _first_prf():
    """Return the PRF of the first ``*.PRM`` file in the current directory.

    Files are visited in sorted name order so that the choice is
    deterministic (the csh original used ``grep PRF *.PRM`` and kept the
    first hit).  Returns ``None`` when no ``*.PRM`` file exists.
    """
    prm_files = sorted(glob.glob('*.PRM'))
    if not prm_files:
        return None
    # presumably grep_value returns field 3 of the first matching line,
    # mirroring the csh `grep ... | awk '{print $3}'` -- verify in gmtsar_lib.
    return float(grep_value(prm_files[0], 'PRF', 3))


def filter():
    """Filter an interferogram and build amplitude/phase/correlation products.

    The function takes no Python parameters; it reads ``sys.argv``:

        filter master.PRM aligned.PRM filter decimation
               [rng_dec azi_dec] [compute_phase_gradient]

    It runs the external programs ``conv``, ``make_gaussian_filter``,
    ``phasefilt`` and several ``gmt`` modules in the current directory,
    reading ``real.grd``/``imag.grd`` and writing the ``amp*``, ``corr``,
    ``phase``, ``phasefilt`` and ``mask`` grids plus their PostScript/PDF
    renderings.

    Exits with status 1 (``SystemExit``) when the argument count is wrong,
    when no ``*.PRM`` file is found, when ``rng_samp_rate`` cannot be read
    from the master PRM file, or when ``make_gaussian_filter`` produced no
    decimation factors.

    NOTE: the name shadows the ``filter`` builtin; it is kept because it is
    this module's public entry point.
    """
    n = len(sys.argv)
    print('FILTER: input arguments are ', sys.argv[0:])
    print(' ')
    print('FILTER - START ... ...')

    # Validate the command line before any external command is executed so
    # that a bad invocation has no side effects.
    if n < 5 or n > 8:
        print('FILTER: Wrong # of input arguments; # should be >=4 and <=7 ... ...')
        _print_usage()
        sys.exit(1)

    # NOTE(review): carried over from the csh original; presumably a no-op
    # if run() starts a fresh shell for every command -- confirm.
    run("alias rm 'rm -f'")
    run('gmt set IO_NC4_CHUNK_SIZE classic')

    print('set grdimage options')

    scale  = '-JX6.5i'
    thresh = '5.e-21'

    run('gmt set COLOR_MODEL = hsv')
    run('gmt set PROJ_LENGTH_UNIT = inch')

    print(' ')
    print('FILTER: define filter and decimation variables ... ...')

    # The csh original obtained this path from `gmtsar_sharedir.csh`;
    # it is hard-coded here.
    sharedir = '/usr/local/GMTSAR/share/gmtsar'
    filter3  = sharedir + '/filters/fill.3x3'
    filter4  = sharedir + '/filters/xdir'
    filter5  = sharedir + '/filters/ydir'

    dec = int(sys.argv[4]) # the 4th arg
    az_lks = 4

    print('FILTER: dec, az_lks are ', dec, az_lks)
    print(' ')
    # original csh command: set PRF = `grep PRF *.PRM | awk 'NR == 1 {printf("%d", $3)}'`
    # set PRF to be the value of first encounter of PRF in *.PRM files.
    PRF = _first_prf()
    if PRF is None:
        print('FILTER: ERROR. No *.PRM file found in ', os.getcwd())
        sys.exit(1)

    if PRF < 1000:
        az_lks = 1

    print('FILTER: look for range sampling rate ... ...')
    try:
        rng_samp_rate = float(grep_value(sys.argv[1], 'rng_samp_rate', 3))
    except (TypeError, ValueError):
        # grep_value gave nothing convertible to a number.
        print('FILTER: undefined rng_sample_rate in the master PRM file ... ...')
        print('FILTER: ERROR.')
        sys.exit(1)
    print('FILTER: rng_samp_rate is ', rng_samp_rate)
    print(' ')
    print('FILTER: set the range spacing in units of image range pixel size ... ...')
    if rng_samp_rate > 110000000:
        dec_rng = 4
        filter1 = sharedir + '/filters/gauss15x5'
    elif 20000000 < rng_samp_rate < 110000000:
        dec_rng = 2
        filter1 = sharedir + '/filters/gauss15x5'

        print('FILTER: special for TOPS mode ... ...')
        if az_lks == 1:
            filter1 = sharedir + '/filters/gauss5x5'
    else:
        # Also reached for exactly 110000000, as in the original logic.
        dec_rng = 1
        filter1 = sharedir + '/filters/gauss15x3'
    print(' ')
    print('FILTER: filter1 is ', filter1)  
    print('FILTER: determine whether to run phase gradient ... ...')

    compute_phase_gradient = 0
    if n == 6: # if 5 arguments were input
        compute_phase_gradient = int(sys.argv[5])
    if n == 8: # if 7 arguments were input
        compute_phase_gradient = int(sys.argv[7])

    print('FILTER: set az_lks and dec_rng to 1 for odd decimation ... ...')

    custom_decimation = n in (7, 8)
    if custom_decimation:
        range_dec = int(sys.argv[5])
        azimuth_dec = int(sys.argv[6])

        if azimuth_dec % 2 != 0:
            az_lks = 1

        # jud ends up describing the range decimation (1 = even, 0 = odd).
        if range_dec % 2 == 0:
            jud = 1
        else:
            jud = 0
            dec_rng = 1
        print('FILTER: jud, az_lks, dec_rng are ', jud, az_lks, dec_rng)

    print('FILTER: make the custom filter2 and set the decimation ... ...')
    run('make_gaussian_filter '+sys.argv[1]+' '+str(dec_rng)+' '+str(az_lks)+' '+sys.argv[3]+' > ijdec')
    filter2 = 'gauss_'+sys.argv[3]

    # assign idec and jdec from the first line written by make_gaussian_filter
    with open('ijdec', 'r') as f:
        ijdec = f.readline().split()
    if len(ijdec) < 2:
        print('FILTER: ERROR. make_gaussian_filter did not write decimation factors to ijdec.')
        sys.exit(1)
    idec = int(ijdec[0])*dec    
    jdec = int(ijdec[1])*dec

    print('ijdev, idec, jdec, dec', ijdec, idec, jdec, dec)
    if custom_decimation:
        # Integer division: conv takes whole decimation factors, and the
        # csh original used integer arithmetic here.
        idec = azimuth_dec // az_lks
        jdec = range_dec // dec_rng
        print('FILTER: setting range range_dec = ', range_dec, '... ...')
        print('FILTER: setting azimuth_dec = ', azimuth_dec, '... ...')

    print(filter2, idec, jdec, az_lks, dec_rng)
    print(' ')

    print('FILTER: filter the two amplitude images ... ...')
    print(' ')

    print('FILTER: making amplitudes ... ...')
    for prm_file, stem in ((sys.argv[1], 'amp1'), (sys.argv[2], 'amp2')):
        _conv(az_lks, dec_rng, filter1, prm_file, stem+'_tmp.grd=bf')
        _conv(idec, jdec, filter2, stem+'_tmp.grd=bf', stem+'.grd')
        delete(stem+'_tmp.grd')

    print('FILTER: filter the real and imaginary parts of the interferogram ... ...')
    print('FILTER: filtering interferogram ... ...')
    for part in ('real', 'imag'):
        _conv(az_lks, dec_rng, filter1, part+'.grd=bf', part+'_tmp.grd=bf')
        _conv(idec, jdec, filter2, part+'_tmp.grd=bf', part+'filt.grd')

    print(' ')
    print('FILTER: also compute gradients and filter them the same way ... ...')

    if compute_phase_gradient != 0:
        print('FILTER: filtering for phase gradient ... ...')
        for part in ('real', 'imag'):
            # xt/yt are scratch grids reused for the real and imaginary parts.
            _conv(1, 1, filter4, part+'_tmp.grd', 'xt.grd=bf')
            _conv(1, 1, filter5, part+'_tmp.grd', 'yt.grd=bf')
            _conv(idec, jdec, filter2, 'xt.grd=bf', 'x'+part+'.grd')
            _conv(idec, jdec, filter2, 'yt.grd=bf', 'y'+part+'.grd')
            delete('xt.grd')
            delete('yt.grd')

    delete('real_tmp.grd')
    delete('imag_tmp.grd')

    print(' ')
    print('FILTER: form amplitude image ... ...')
    print('FILTER: making amplitude ... ...')

    run('gmt grdmath realfilt.grd imagfilt.grd HYPOT = amp.grd')
    run('gmt grdmath amp.grd 0.5 POW FLIPUD = display_amp.grd')

    # The colour scale tops out at three standard deviations of the
    # display amplitude, taken from the `gmt grdinfo -L2` report.
    output = subprocess.check_output(['gmt', 'grdinfo', '-L2', 'display_amp.grd'])
    stdev_line = [line for line in output.splitlines() if b'stdev' in line][0]
    stdev = float(stdev_line.split()[-3])
    AMAX = 3 * stdev
    print('FILTER: cmd gmt grdinfo -L2 display_amp.grd output', output)
    print('FILTER: stdev_line is ', stdev_line)
    print('FILTER: stdev is ', stdev)
    print('FILTER: AMAX is ', AMAX)
    print(' ')
    run('gmt grd2cpt display_amp.grd -Z -D -L0/'+str(AMAX)+' -Cgray > display_amp.cpt')
    append_new_line('display_amp.cpt','N 255 255 254')
    run('gmt grdimage display_amp.grd -Cdisplay_amp.cpt '+str(scale)+' -Bxaf+lRange -Byaf+lAzimuth -BWSen -X1.3i -Y3i -P -K > display_amp.ps')
    run('''gmt psscale -Rdisplay_amp.grd -J -DJTC+w5i/0.2i+h+ef -Cdisplay_amp.cpt -Bx0+l"Amplitude (histogram equalized)" -O >> display_amp.ps''')
    run('gmt psconvert -Tf -P -A -Z display_amp.ps')

    print(' ')
    print('FILTER: form the correlation ... ...')
    print('FILTER: making correction ... ...')
    run('gmt grdmath amp1.grd amp2.grd MUL = tmp.grd')
    run('gmt grdmath tmp.grd '+str(thresh)+' GE 0 NAN = mask.grd')
    run('gmt grdmath amp.grd tmp.grd SQRT DIV mask.grd MUL FLIPUD = tmp2.grd=bf')
    _conv(1, 1, filter3, 'tmp2.grd=bf', 'corr.grd')
    run("gmt makecpt -T0./.8/0.1 -Cgray -Z -N > corr.cpt")
    append_new_line('corr.cpt','N 255 255 254')
    run('gmt grdimage corr.grd '+str(scale)+' -Ccorr.cpt -Bxaf+lRange -Byaf+lAzimuth -BWSen -X1.3i -Y3i -P -K > corr.ps')
    run("gmt psscale -Rcorr.grd -J -DJTC+w5i/0.2i+h+ef -Ccorr.cpt -Baf+lCorrelation -O >> corr.ps")
    run('gmt psconvert -Tf -P -A -Z corr.ps')

    print(' ')
    print('FILTER: form the phase ... ...')
    run('gmt grdmath imagfilt.grd realfilt.grd ATAN2 mask.grd MUL FLIPUD = phase.grd')
    run("gmt makecpt -Crainbow -T-3.15/3.15/0.1 -Z -N > phase.cpt")
    run('gmt grdimage phase.grd '+str(scale)+' -Bxaf+lRange -Byaf+lAzimuth -BWSen -Cphase.cpt -X1.3i -Y3i -P -K > phase.ps')
    run('''gmt psscale -Rphase.grd -J -DJTC+w5i/0.2i+h -Cphase.cpt -B1.57+l"Phase" -By+lrad -O >> phase.ps''')
    run('gmt psconvert -Tf -P -A -Z phase.ps')

    print(' ')
    print('FILTER: compute the solid earth tide (Commented so far) ')

    print(' ')
    print('FILTER: make the Werner/Goldstain filtered phase ... ...')
    run('phasefilt -imag imagfilt.grd -real realfilt.grd -amp1 amp1.grd -amp2 amp2.grd -psize 32') 
    run('''gmt grdedit filtphase.grd `gmt grdinfo mask.grd -I- --FORMAT_FLOAT_OUT=%.12lg`''') 
    run('gmt grdmath filtphase.grd mask.grd MUL FLIPUD = phasefilt.grd')
    delete('filtphase.grd')
    run('gmt grdimage phasefilt.grd '+str(scale)+' -Bxaf+lRange -Byaf+lAzimuth -BWSen -Cphase.cpt -X1.3i -Y3i -P -K > phasefilt.ps')
    run('''gmt psscale -Rphasefilt.grd -J -DJTC+w5i/0.2i+h -Cphase.cpt -Bxa1.57+l"Phase" -By+lrad -O >> phasefilt.ps''')
    run('gmt psconvert -Tf -P -A -Z phasefilt.ps')

    print(' ')
    print('FILTER: form the phase gradients ... ...')

    if compute_phase_gradient != 0:
        print('FILTER: making phase gradient ... ...')
        run('gmt grdmath amp.grd 2. POW = amp_pow.grd')
        run('gmt grdmath realfilt.grd ximag.grd MUL imagfilt.grd xreal.grd MUL SUB amp_pow.grd DIV mask.grd MUL FLIPUD = xphase.grd')
        run('gmt grdmath realfilt.grd yimag.grd MUL imagfilt.grd yreal.grd MUL SUB amp_pow.grd DIV mask.grd MUL FLIPUD = yphase.grd') 

    # Flip the mask last: the unflipped mask was needed by the steps above.
    file_shuttle('mask.grd','tmp.grd','mv')
    run('gmt grdmath tmp.grd FLIPUD = mask.grd')

    print(' ')
    print('FILTER: delete files ... ...')
    for leftover in ('real.grd', 'imag.grd',
                     'ximag.grd', 'yimag.grd', 'xreal.grd', 'yreal.grd'):
        delete(leftover)

    print("FILTER - END ... ...")

def _main_func(description):
    """Run the filter workflow when this file is executed as a script.

    description -- text handed over by the caller (the module docstring);
    it is accepted but not used.
    """
    filter()


if "__main__" == __name__:
    _main_func(description=__doc__)

