#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 24 11:31:41 2022

@author: bertelsl

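Algorithm:

    - For every S1_*.nc and S2_*.nc crop file in the input directory, calculate
      per-sample metrics for each band/index and save them to a semicolon-separated
      CSV file with the same base name:

        S1: aBands = ['VH', 'VV', 'VHVVR', 'RVI']
            aMetrics = ['Q10', 'Q50', 'Q90', 'QR', 'Std'] + nSteps sampled time steps

        S2: aBands = ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12',
                      'NDVI', 'NDMI', 'NDGI', 'NDRE1', 'NDRE2', 'NDRE5']
            aMetrics = ['Q10', 'Q50', 'Q90', 'QR', 'Std'] + nSteps sampled time steps

        (nSteps is 6 in the default configuration.)

        Example directories (the __main__ block may be configured with different ones):
        'fIndir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/03_LUCAS_crops/',
        'fOutdir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/05_LUCAS_metrics/',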
        
Version: 31/08/2022

"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
from datetime import datetime

#================================================================================================================
class cCalculate_metrics(object):
    """Calculate per-sample metrics from S1 and S2 crop time series files.

    For every ``S1_*.nc`` and ``S2_*.nc`` file in ``Info['fIndir']`` a
    semicolon-separated CSV with the same base name is written to
    ``Info['fOutdir']``. For every band/index, each CSV row contains
    ``Q10, Q50, Q90, QR, Std`` followed by ``nSteps`` values sampled at evenly
    spaced indices along axis 0, plus an ``IDs`` column copied from the input.

    Statistics are taken over axis 0 of each band array (presumably time) and
    every index along axis 1 becomes one CSV row.
    """

    # Statistics written for every band, before the sampled time steps.
    STATS = ['Q10', 'Q50', 'Q90', 'QR', 'Std']

    # Variables read from the S1 files, and the bands/indices written per file.
    S1_RAW_BANDS = ['VH', 'VV']
    S1_BANDS = ['VH', 'VV', 'VHVVR', 'RVI']

    # Variables read from the S2 files, and the bands/indices written per file.
    S2_RAW_BANDS = ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12']
    S2_BANDS = S2_RAW_BANDS + ['NDVI', 'NDMI', 'NDGI', 'NDRE1', 'NDRE2', 'NDRE5']

    # Trailing blanks printed after a progress message so that it overwrites any
    # longer message left on the same terminal line by the previous '\r' print.
    _PROGRESS_PAD = ' ' * 160

#================================================================================================================
    def __init__(self, Info):
        """Store the run configuration and create the output directory.

        Args:
            Info: dict with keys 'fIndir' (directory holding the S1_*.nc /
                S2_*.nc files), 'fOutdir' (directory for the CSV output),
                'nSteps' (number of sampled time steps per band) and
                'overwrite' (when True, existing CSVs are recomputed).
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.nSteps = Info['nSteps']
        self.overwrite = Info['overwrite']

        # makedirs also creates missing parents and is a no-op if it exists.
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    def start_processing(self):
        """Calculate and save the S1 metrics, then the S2 metrics."""
        self.calculate_S1_metrics()
        self.calculate_S2_metrics()

#================================================================================================================
    def calculate_S1_metrics(self):
        """Write a metrics CSV for every S1_*.nc file (VH, VV, VHVVR, RVI)."""
        self._process_sensor('S1', self.S1_RAW_BANDS, self.S1_BANDS)

#================================================================================================================
    def calculate_S2_metrics(self):
        """Write a metrics CSV for every S2_*.nc file (10 raw bands + 6 indices)."""
        self._process_sensor('S2', self.S2_RAW_BANDS, self.S2_BANDS)

#================================================================================================================
    def _process_sensor(self, sensor, raw_bands, bands):
        """Compute and save metrics for every '<sensor>_*.nc' file in fIndir.

        Existing CSVs are skipped unless ``self.overwrite`` is True, in which
        case they are removed and recomputed.

        Args:
            sensor: file prefix, 'S1' or 'S2'.
            raw_bands: variables read from each file into ``self.a<band>``.
            bands: raw bands and derived indices to compute metrics for.
        """
        columns = self._metric_columns(bands)
        # os.path.join works whether or not fIndir ends with a separator.
        pattern = os.path.join(self.fIndir, '{}_*.nc'.format(sensor))

        # sorted() makes the processing order (and progress output) reproducible.
        for fIn in sorted(glob.glob(pattern)):
            basename = os.path.basename(fIn).split('.')[0]
            fOut = os.path.join(self.fOutdir, '{}.csv'.format(basename))

            if os.path.isfile(fOut):
                if not self.overwrite:
                    continue
                os.remove(fOut)

            print('\r Calculating {} metrics for {}'.format(sensor, basename),
                  end=self._PROGRESS_PAD, flush=True)

            xData = xr.load_dataset(fIn)

            # Expose each raw band as self.a<band> (e.g. self.aVH, self.aB02);
            # metrics() reads the bands from these attributes.
            for band in raw_bands:
                setattr(self, 'a' + band, xData[band].values)

            n_samples = getattr(self, 'a' + raw_bands[0]).shape[1]
            self.aMetrics = np.zeros((n_samples, len(columns)), dtype=np.float32)
            self.iMetrics = 0

            for band in bands:
                self.metrics(band)

            df = pd.DataFrame(self.aMetrics, columns=columns)
            df['IDs'] = xData['IDs'].values
            df.to_csv(fOut, sep=';', index=False)

#================================================================================================================
    def _metric_columns(self, bands):
        """Return the '<band>_<metric>' column names in the order metrics() writes them."""
        metric_names = self.STATS + ['tstep_{}'.format(i + 1) for i in range(self.nSteps)]
        return ['{}_{}'.format(band, metric) for band in bands for metric in metric_names]

#================================================================================================================
    def metrics(self, band):
        """Compute the metrics of *band* and store them in ``self.aMetrics``.

        Writes ``len(STATS) + self.nSteps`` columns starting at column
        ``self.iMetrics`` and advances ``self.iMetrics`` past them, so
        successive calls fill the matrix band after band.

        Raises:
            ValueError: if *band* is neither a raw band nor a known index.
        """
        block = self._band_metrics(self._band_array(band))
        stop = self.iMetrics + block.shape[1]
        self.aMetrics[:, self.iMetrics:stop] = block
        self.iMetrics = stop

#================================================================================================================
    def _band_metrics(self, aBand):
        """Return a (n_samples, 5 + nSteps) array: Q10, Q50, Q90, QR, Std, time steps.

        NaNs are ignored by the percentile and standard-deviation statistics.
        """
        q10, q50, q90 = np.nanpercentile(aBand, [10, 50, 90], axis=0)
        std = np.nanstd(aBand, axis=0)
        ts = self.tsteps(aBand, n_steps=self.nSteps)
        # QR (inter-quantile range) is Q90 - Q10; Q90 >= Q10 always holds.
        return np.vstack([q10, q50, q90, q90 - q10, std, ts]).T

#================================================================================================================
    @staticmethod
    def _as_float(a):
        """Return *a* as an ndarray, promoting integer dtypes to float64.

        Floating inputs are returned unchanged, so their precision is kept.
        Promotion prevents unsigned-integer subtraction from wrapping around.
        """
        a = np.asarray(a)
        return a if np.issubdtype(a.dtype, np.floating) else a.astype(np.float64)

#================================================================================================================
    def _band_array(self, band):
        """Return the array for a raw band or a derived index.

        Raw bands ('B02' ... 'B12', 'VH', 'VV') are read from the matching
        ``self.a<band>`` attribute; indices are computed from those attributes.

        Raises:
            ValueError: if *band* is neither a raw band nor a known index.
        """
        if band in self.S1_RAW_BANDS or band in self.S2_RAW_BANDS:
            return getattr(self, 'a' + band)

        def nd(x, y):
            # Normalized difference (x - y) / (x + y).
            x = self._as_float(x)
            y = self._as_float(y)
            return (x - y) / (x + y)

        def rvi():
            # VH and VV are converted from dB to linear power via 10**(x/10).
            vh = np.power(10, self._as_float(self.aVH) / 10)
            vv = np.power(10, self._as_float(self.aVV) / 10)
            return 4 * vh / (vv + vh)

        derived = {
            'NDVI': lambda: nd(self.aB08, self.aB04),
            'NDMI': lambda: nd(self.aB08, self.aB11),
            'NDGI': lambda: nd(self.aB03, self.aB04),
            'NDRE1': lambda: nd(self.aB08, self.aB05),
            'NDRE2': lambda: nd(self.aB08, self.aB06),
            'NDRE5': lambda: nd(self.aB07, self.aB05),
            'ANIR': lambda: self.anir(self.aB04, self.aB08, self.aB11),
            # VH - VV; presumably both in dB, making this the VH/VV ratio in dB.
            'VHVVR': lambda: self._as_float(self.aVH) - self._as_float(self.aVV),
            'RVI': rvi,
        }
        try:
            compute = derived[band]
        except KeyError:
            raise ValueError('Unknown band or index: {!r}'.format(band)) from None

        # Zero denominators yield inf/NaN; the NaN-aware statistics skip NaNs.
        with np.errstate(divide='ignore', invalid='ignore'):
            return compute()

#================================================================================================================
    @staticmethod
    def anir(B04, B08, B11):
        """Return the angle at NIR (ANIR), normalized to [0, 1].

        The points (wavelength, reflectance) of B04, B08 and B11 form a
        triangle. The angle at the B08 vertex is obtained with the law of
        cosines and divided by pi. Works for scalars and arrays.
        """
        # Central wavelengths of the three bands (values appear to be in µm).
        WL_B04 = 0.665
        WL_B08 = 0.842
        WL_B11 = 1.610

        B04 = cCalculate_metrics._as_float(B04)
        B08 = cCalculate_metrics._as_float(B08)
        B11 = cCalculate_metrics._as_float(B11)

        a = np.sqrt(np.square(WL_B08 - WL_B04) + np.square(B08 - B04))
        b = np.sqrt(np.square(WL_B11 - WL_B08) + np.square(B11 - B08))
        c = np.sqrt(np.square(WL_B11 - WL_B04) + np.square(B11 - B04))
        with np.errstate(divide='ignore', invalid='ignore'):
            cos_angle = (np.square(a) + np.square(b) - np.square(c)) / (2 * a * b)
        # Round-off can push the cosine slightly outside [-1, 1]; clip before arccos.
        cos_angle = np.clip(cos_angle, -1, 1)
        return 1. / np.pi * np.arccos(cos_angle)

#================================================================================================================
    def tsteps(self, xx, n_steps=36):
        """Sample *n_steps* evenly spaced slices along axis 0 of *xx*.

        Indices are ``round(linspace(0, len - 1, n_steps))``, so for
        n_steps >= 2 the first and last slices are always included. Returns a
        new array of shape ``(n_steps,) + xx.shape[1:]``.

        Raises:
            ValueError: if *xx* is empty along axis 0 and n_steps > 0.
        """
        x = np.asarray(xx)
        if x.shape[0] == 0 and n_steps > 0:
            raise ValueError('Cannot sample time steps from an empty series')

        idx = np.round(np.linspace(0, x.shape[0] - 1, n_steps)).astype(int)
        # Fancy indexing already returns a copy, so no explicit copy is needed.
        return x[idx, ...]
    
#================================================================================================================
if __name__ == '__main__':
    # Run configuration: directory with the S1_*.nc / S2_*.nc crop files, directory
    # for the metric CSVs, number of sampled time steps per band, and whether
    # already existing CSVs are recomputed.
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/04_LPIS_LUCAS_merged_crops/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/05_LPIS_LUCAS_metrics/',
        nSteps=6,
        overwrite=False,
    )

    # Kept as a named object so it stays inspectable in an interactive session.
    oCalculate_metrics = cCalculate_metrics(Info)
    oCalculate_metrics.start_processing()
        