#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Aug  8 09:50:46 2022

@author: bertelsl
"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy

#================================================================================================================
class cCalculate_metrics(object):
    """Derive per-sample time-series metrics from NetCDF files and export them as CSV.

    Every ``*.nc`` file found in the input folder is loaded, the configured
    time window is cut out of each variable listed in ``_BANDS`` and, depending
    on the configuration, index (VI) statistics and/or per-band statistics are
    written to a ';'-separated CSV file named after the input file.
    """

    # Variables read from every input file. Each one is stored on the instance
    # as ``self.a<name>`` (e.g. ``self.aB02``) after slicing the time window.
    _BANDS = ('B02', 'B03', 'B04', 'B05', 'B06', 'B07',
              'B08', 'B8A', 'B11', 'B12', 'VH', 'VV')

#================================================================================================================
    def __init__(self, Info):
        """Store the run configuration and make sure the output folder exists.

        Parameters
        ----------
        Info : dict
            Required keys: 'fIndir', 'fOutdir', 'start_day', 'end_day',
            'calculate_VIs', 'overwrite' and the time-step flag. The flag is
            historically spelled 'calcuate_tstep'; the corrected spelling
            'calculate_tstep' is accepted as well and wins when both are given.

        Raises
        ------
        KeyError
            If a required key is missing from *Info*.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.start_day = Info['start_day']
        self.end_day = Info['end_day']
        self.calculate_VIs = Info['calculate_VIs']
        if 'calculate_tstep' in Info:
            self.calculate_tstep = Info['calculate_tstep']
        else:
            # legacy (misspelled) key, kept so existing callers keep working
            self.calculate_tstep = Info['calcuate_tstep']
        self.overwrite = Info['overwrite']

        # makedirs also creates missing parent folders and does not fail when
        # the folder already exists (or is created concurrently).
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    def start_processing(self):
        """Process every ``*.nc`` file of the input folder.

        Files whose CSV already exists are skipped unless ``overwrite`` is set.
        When neither metric family is requested nothing is read, written or
        deleted; previously existing outputs would have been removed without
        being replaced in that situation.
        """
        if self.calculate_VIs or self.calculate_tstep:
            for fCropfile in glob.glob(os.path.join(self.fIndir, '*.nc')):
                self._process_file(fCropfile)

        print('\rDone!',end='                                                                                                                               ')
        print('')

    def _process_file(self, fCropfile):
        """Compute the requested metrics for one NetCDF file and write its CSV."""
        print('\r Calculating metrics for {}'.format(os.path.basename(fCropfile)),
              end='                                                                                                                                ')

        # NOTE: everything after the FIRST dot is dropped, so 'a.b.nc' maps to
        # 'a.csv'. Kept as is because it defines the names of existing outputs.
        basename = os.path.basename(fCropfile).split('.')[0]
        fOut = os.path.join(self.fOutdir, '{}.csv'.format(basename))

        if os.path.isfile(fOut):
            if not self.overwrite:
                return
            os.remove(fOut)

        data = xr.load_dataset(fCropfile)
        self._load_bands(data)

        frames = []
        if self.calculate_VIs:
            frames.append(self._vi_dataframe())
        if self.calculate_tstep:
            frames.append(self._ts_dataframe())

        # VI columns first, then time-step columns, then the coordinates.
        df = frames[0] if len(frames) == 1 else pd.concat(frames, axis=1)
        df['lat'] = data['lat'].values
        df['lon'] = data['lon'].values
        df.to_csv(fOut, sep=';', index=False)

    def _load_bands(self, data):
        """Store the [start_day, end_day) slice (axis 1) of every band as ``self.a<band>``."""
        window = slice(self.start_day, self.end_day)
        for band in self._BANDS:
            setattr(self, 'a{}'.format(band), data[band].values[:, window])

    def _vi_dataframe(self):
        """Return a DataFrame with eight statistics for each of the nine indices."""
        # 'VALUE' must directly follow 'HUE': both are produced by one call and
        # their metrics are stored back to back.
        aVIs = ['NDVI', 'NDVI2', 'NDRE', 'SIPI', 'LCI', 'HUE', 'VALUE', 'VHVVR', 'RVI']
        # NOTE(review): 'aQ90_dot' looks like a typo for 'Q90_dot'; the spelling
        # is kept because it ends up in the CSV header.
        aVI_Metrics = ['Mean', 'Std', 'Median', 'Q10', 'Q90', 'QR', 'Q10_dot', 'aQ90_dot']
        aVI_Columns = ['{}_{}'.format(VI, Metric) for VI in aVIs for Metric in aVI_Metrics]

        self.aVI_Metrics_out = np.zeros((self.aB02.shape[0], len(aVI_Columns)), dtype=np.float32)
        self.iMetrics = 0

        for VI in aVIs:
            if VI == 'VALUE':
                continue  # already stored together with 'HUE'
            if VI == 'HUE':
                aHUE, aVALUE = self.calculate_VI('HUE_VALUE')
                self.calculate_VI_metrics(aHUE)
                self.calculate_VI_metrics(aVALUE)
            else:
                self.calculate_VI_metrics(self.calculate_VI(VI))

        return pd.DataFrame(self.aVI_Metrics_out, columns=aVI_Columns)

    def _ts_dataframe(self):
        """Return a DataFrame with five statistics plus ``nSteps`` resampled values per band."""
        self.nSteps = 6
        aTs_metrics = ['Q10', 'Q50', 'Q90', 'QR', 'Std']
        aTs_metrics += ['tstep_{}'.format(iX + 1) for iX in range(self.nSteps)]
        aTs_Columns = ['{}_{}'.format(band, ts_metric)
                       for band in self._BANDS for ts_metric in aTs_metrics]

        self.aTs_Metrics_out = np.zeros((self.aB02.shape[0], len(aTs_Columns)), dtype=np.float32)
        self.iMetrics = 0

        for band in self._BANDS:
            self.calculate_Ts_metrics(band)

        return pd.DataFrame(self.aTs_Metrics_out, columns=aTs_Columns)

    def _store_columns(self, aOut, columns):
        """Copy each 1-D array of *columns* into the next free column of *aOut*.

        ``self.iMetrics`` is the write cursor and is advanced by one per column.
        """
        for column in columns:
            aOut[:, self.iMetrics] = column
            self.iMetrics += 1

#================================================================================================================
    def tsteps(self, x, n_steps=6):
        """Resample *x* along axis 1 to *n_steps* samples with ``scipy.signal.resample``."""
        # Imported here explicitly: a bare ``import scipy`` does not load the
        # ``signal`` submodule on older SciPy releases, where ``scipy.signal``
        # would raise AttributeError.
        from scipy import signal

        return signal.resample(x, n_steps, axis=1)

#================================================================================================================
    def calculate_Ts_metrics(self, aBand):
        """Append the time-series metrics of one band to ``self.aTs_Metrics_out``.

        Written in this order, starting at column ``self.iMetrics``: Q10, Q50,
        Q90, QR (Q90 - Q10), Std, followed by ``self.nSteps`` resampled values.

        Parameters
        ----------
        aBand : str
            Band name, one of ``_BANDS``; its data is read from ``self.a<aBand>``.

        Raises
        ------
        ValueError
            If *aBand* is not a known band name.
        """
        if aBand not in self._BANDS:
            raise ValueError('Unknown band name: {!r}'.format(aBand))
        aValues = getattr(self, 'a{}'.format(aBand))

        Q10, Q50, Q90 = np.percentile(aValues, [10, 50, 90], axis=1)
        QR = np.abs(Q90 - Q10)
        Std = np.std(aValues, axis=1)
        ts = self.tsteps(aValues, n_steps=self.nSteps)

        columns = [Q10, Q50, Q90, QR, Std]
        columns.extend(ts[:, iX] for iX in range(self.nSteps))
        self._store_columns(self.aTs_Metrics_out, columns)

#================================================================================================================
    def calculate_VI_metrics(self, aMetric):
        """Append the statistics of one index to ``self.aVI_Metrics_out``.

        Written in this order, starting at column ``self.iMetrics``: mean,
        standard deviation, median, Q10, Q90, QR (Q90 - Q10) and the positions
        returned by ``calculate_DOM`` for Q10 and Q90.

        Parameters
        ----------
        aMetric : numpy.ndarray
            2-D array; all statistics are taken along axis 1.
        """
        aQ10, aQ90 = np.percentile(aMetric, [10, 90], axis=1)
        aQ10_dot, aQ90_dot = self.calculate_DOM(aMetric)

        self._store_columns(self.aVI_Metrics_out, [
            np.mean(aMetric, axis=1),
            np.std(aMetric, axis=1),
            np.median(aMetric, axis=1),
            aQ10,
            aQ90,
            np.abs(aQ90 - aQ10),
            aQ10_dot,
            aQ90_dot,
        ])

#================================================================================================================
    def calculate_VI(self, VIname):
        """Compute one index from the band arrays stored on the instance.

        Parameters
        ----------
        VIname : str
            'NDVI', 'NDVI2', 'NDRE', 'SIPI', 'LCI', 'VHVVR', 'RVI' or 'HUE_VALUE'.

        Returns
        -------
        numpy.ndarray, or a ``(hue, value)`` tuple of arrays for 'HUE_VALUE'.

        Raises
        ------
        ValueError
            If *VIname* is not supported (previously ``None`` was returned,
            which only failed later inside the statistics code).
        """
        if VIname == 'NDVI':
            return (self.aB08 - self.aB04) / (self.aB08 + self.aB04)

        if VIname == 'NDVI2':
            return (self.aB12 - self.aB08) / (self.aB12 + self.aB08)

        if VIname == 'NDRE':
            return (self.aB07 - self.aB05) / (self.aB07 + self.aB05)

        if VIname == 'SIPI':
            return (self.aB08 - self.aB02) / (self.aB08 - self.aB04)

        if VIname == 'LCI':
            return (self.aB08 - self.aB05) / (self.aB08 + self.aB04)

        if VIname == 'HUE_VALUE':
            return self.calculateHUE_VALUE()

        if VIname == 'VHVVR':
            return self.aVH - self.aVV

        if VIname == 'RVI':
            # 10 ** (x / 10) is evaluated once per band instead of twice for VH
            aVH_pow = np.power(10, self.aVH / 10)
            aVV_pow = np.power(10, self.aVV / 10)
            return (4 * aVH_pow) / (aVV_pow + aVH_pow)

        raise ValueError('Unknown VI name: {!r}'.format(VIname))

#================================================================================================================
    def calculate_DOM(self, aMetric):
        """Locate the Q10 and Q90 values of every row within its own series.

        For each row of *aMetric* the 10th and 90th percentile (cast to
        float32) are computed and the index along axis 1 of the sample closest
        to each of them is returned; on ties the first such index wins.

        Parameters
        ----------
        aMetric : numpy.ndarray
            2-D array, one series per row.

        Returns
        -------
        tuple of numpy.ndarray
            ``(aQ10_dot, aQ90_dot)``, float32 arrays with one entry per row.
        """
        aQ10, aQ90 = np.percentile(aMetric, [10, 90], axis=1).astype(np.float32)

        # Vectorised over all rows; equivalent to a per-row argmin loop.
        aQ10_dot = np.argmin(np.abs(aMetric - aQ10[:, np.newaxis]), axis=1)
        aQ90_dot = np.argmin(np.abs(aMetric - aQ90[:, np.newaxis]), axis=1)

        return aQ10_dot.astype(np.float32), aQ90_dot.astype(np.float32)

#================================================================================================================
    def calculateHUE_VALUE(self):
        """Calculate HUE and VALUE from the B12, B08 and B04 arrays.

        VALUE is the element-wise maximum of the three bands. HUE depends on
        which band holds that maximum and stays 0 where all three are equal.

        Returns
        -------
        tuple of numpy.ndarray
            ``(aHue, aValue)`` as int16; hue is scaled by 50 and value by 2000
            before rounding.
        """
        # ini array with zeros for output
        aHue = np.zeros_like(self.aB04, dtype=np.float64)

        # calculate Value and Value range (max - min of the three bands)
        aValue = np.fmax(np.fmax(self.aB12, self.aB08), self.aB04)
        aDivValue = aValue - np.fmin(np.fmin(self.aB12, self.aB08), self.aB04)

        # masks: which band holds the maximum, excluding a zero range
        bHasRange = aDivValue != 0.
        iS = np.logical_and(aValue == self.aB12, bHasRange)
        iN = np.logical_and(aValue == self.aB08, bHasRange)
        iR = np.logical_and(aValue == self.aB04, bHasRange)

        # calculate Hue; where masks overlap the later assignment wins, so the
        # order B12 -> B08 -> B04 must not be changed
        aHue[iS] = np.mod((60. * (self.aB08[iS] - self.aB04[iS]) / aDivValue[iS] + 360.), 360.)
        aHue[iN] = 60. * (self.aB04[iN] - self.aB12[iN]) / aDivValue[iN] + 120.
        aHue[iR] = 60. * (self.aB12[iR] - self.aB08[iR]) / aDivValue[iR] + 240.

        # scale the output in order to bring them in Int16
        aHue = np.round(aHue * 50, decimals=0).astype(np.int16)
        aValue = np.round(aValue * 2000, decimals=0).astype(np.int16)

        return aHue, aValue
        
#================================================================================================================
if __name__ == '__main__':
    # Script entry point: build the run configuration and process every file
    # found in the input folder.
    #
    # start_day / end_day are slice bounds into the time axis of the inputs.
    # Per the original notes the series starts on 2017/09/01 (577 steps, until
    # 2019/03/31), so 181 corresponds to 2018/03/01 and 365 ends the window
    # after 2018/08/31.
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/2_HANTS_crops/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/3_HANTS_Ts_metrics/',
        start_day=181,
        end_day=365,
        calculate_VIs=False,
        # the key spelling 'calcuate_tstep' is the one the class reads
        calcuate_tstep=True,
        overwrite=True,
    )

    oCalculate_metrics = cCalculate_metrics(Info)
    oCalculate_metrics.start_processing()
