#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 20 05:11:35 2022

@author: bertelsl
"""

import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

#================================================================================================================
class cAnalyse_separability(object):
    """Build per-crop pairwise RMSE ("equivalence") matrices from S1/S2 metrics.

    For every crop type derived from the ``*.csv`` names in ``fIndir``, the
    files ``S1_<crop>.csv`` and ``S2_<crop>.csv`` (``;``-separated, with an
    ``IDs``, ``lat`` and ``lon`` column) are joined on the samples present in
    both. The features are min-max normalised per column, and the RMSE between
    every pair of samples is written to ``<fOutdir>/<crop>.nc``.
    """
#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the configuration and make sure the output directory exists.

        Parameters
        ----------
        Info : dict
            Must provide the keys ``fIndir`` (directory holding the metric
            CSVs), ``fOutdir`` (directory for the ``.nc`` matrices),
            ``overwrite`` (re-create existing matrices when truthy),
            ``fBestBands`` (CSV with ``score`` and ``best_bands`` columns) and
            ``BestBand_threshold`` (``None`` keeps all bands; otherwise the
            bands whose score is below it are dropped).
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.overwrite = Info['overwrite']
        self.fBestBands = Info['fBestBands']
        self.BestBand_threshold = Info['BestBand_threshold']

        # makedirs also creates missing parents; exist_ok avoids a race/crash
        # when the directory already exists.
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Compute and save the matrix of every crop type found in ``fIndir``.

        Existing output files are skipped unless ``overwrite`` is set, in which
        case they are removed first.
        """
        aUnique_CropTypes = self.get_unique_croptypes()

        if self.BestBand_threshold is not None:
            self.get_best_bands()

        for crop in aUnique_CropTypes:

            fOut_MX = os.path.join(self.fOutdir, '{}.nc'.format(crop))

            if os.path.isfile(fOut_MX):
                if not self.overwrite:
                    continue
                os.remove(fOut_MX)

            if self.read_metrics_data(crop):
                self.calculate_MX(crop, fOut_MX)

        print(' * Done!')

#================================================================================================================
    def get_best_bands(self):
#================================================================================================================
        """Load the bands to drop from ``fBestBands``.

        The rows whose ``score`` is below ``BestBand_threshold`` are selected.
        Their ``best_bands`` values are stored in ``self.drop_bands``.
        """
        df = pd.read_csv(self.fBestBands)
        dfbb = df[df['score'] < self.BestBand_threshold]

        self.drop_bands = dfbb['best_bands']

#================================================================================================================
    def get_unique_croptypes(self):
#================================================================================================================
        """Return the sorted unique crop names derived from ``fIndir/*.csv``."""
        CropTypes = []

        for fMetric in glob.glob(os.path.join(self.fIndir, '*.csv')):
            basename = os.path.basename(fMetric).split('.')[0]
            # Keep the text after the last 'S1_' / 'S2_' (the sensor prefix).
            basename = basename.split('S1_')[-1]
            basename = basename.split('S2_')[-1]
            CropTypes.append(basename)

        return np.unique(CropTypes)

#================================================================================================================
    def read_metrics_data(self, crop):
#================================================================================================================
        """Load, join and normalise the S1 and S2 metrics of one crop type.

        Only the samples whose ID occurs in both ``S1_<crop>.csv`` and
        ``S2_<crop>.csv`` are kept, in S1 file order. On success, the
        normalised feature matrix is stored in ``self.aCropMetrics``. The
        matching S1 coordinates are stored in ``self.aLat`` / ``self.aLon``.

        Returns
        -------
        bool
            ``False`` when either CSV file is missing, ``True`` otherwise.
        """
        print(' * Reading the metrics data for {}'.format(crop))

        fS1_Metrics = os.path.join(self.fIndir, 'S1_{}.csv'.format(crop))
        fS2_Metrics = os.path.join(self.fIndir, 'S2_{}.csv'.format(crop))

        if not (os.path.isfile(fS1_Metrics) and os.path.isfile(fS2_Metrics)):
            return False

        self.aCropTypes = []

        dfS1 = pd.read_csv(fS1_Metrics, sep=';').set_index('IDs')
        dfS2 = pd.read_csv(fS2_Metrics, sep=';').set_index('IDs')

        # Keep only the samples present in both files. Vectorised membership
        # replaces the former per-ID scans, which were O(n^2).
        dfS1 = dfS1.loc[dfS1.index.isin(dfS2.index)]
        dfS2 = dfS2.loc[dfS2.index.isin(dfS1.index)]
        # Put the S2 rows in S1 order so that rows of the joined table line up
        # with the lat/lon taken from S1 below.
        # NOTE(review): assumes IDs are unique within each file.
        dfS2 = dfS2.reindex(dfS1.index)

        lat = dfS1['lat'].values
        lon = dfS1['lon'].values
        dfS1 = dfS1.drop(['lat', 'lon'], axis=1)
        dfS2 = dfS2.drop(['lat', 'lon'], axis=1)

        df = pd.concat([dfS1, dfS2], axis=1)

        if self.BestBand_threshold is not None:
            df = df.drop(self.drop_bands, axis=1)

        self.aCropMetrics = self._normalise(df).to_numpy()

        self.aLat = lat
        self.aLon = lon

        return True

#================================================================================================================
    @staticmethod
    def _normalise(df):
#================================================================================================================
        """Min-max scale every column of *df* to [0, 1].

        Columns with a zero range (constant features) are mapped to 0 instead
        of NaN. A single NaN column would otherwise turn every pairwise RMSE
        into NaN.
        """
        col_min = df.min(axis=0)
        col_range = df.max(axis=0) - col_min
        col_range = col_range.where(col_range != 0, 1)
        return (df - col_min).div(col_range, axis=1)

#================================================================================================================
    @staticmethod
    def _pairwise_rmse(aMetrics, verbose=False):
#================================================================================================================
        """Return the symmetric (n, n) float32 matrix of row-to-row RMSEs.

        *aMetrics* is a 2-D array with one row per sample and one column per
        feature. Entry ``[i, j]`` equals
        ``sqrt(mean((aMetrics[i] - aMetrics[j]) ** 2))``. The diagonal is NaN.
        When *verbose* is true, a progress line is printed per row.
        """
        aMetrics = np.asarray(aMetrics)
        nEntries, nFeatures = aMetrics.shape
        aMX = np.zeros((nEntries, nEntries), dtype=np.float32)

        for iR in range(nEntries):
            if verbose:
                print('\r  - Progress: {} / {}'.format(iR+1, nEntries), end='                                                                                    ')

            # One vectorised pass over all later rows replaces the inner loop.
            aDiff = aMetrics[iR + 1:] - aMetrics[iR]
            aRMSE = np.sqrt(np.square(aDiff).sum(axis=1) / nFeatures)
            aMX[iR, iR + 1:] = aRMSE
            aMX[iR + 1:, iR] = aRMSE

        np.fill_diagonal(aMX, np.nan)
        return aMX

#================================================================================================================
    def calculate_MX(self, crop, fOut_MX):
#================================================================================================================
        """Compute the pairwise RMSE matrix of ``self.aCropMetrics`` and save it.

        The result is written to *fOut_MX* as a netCDF dataset. Its variable
        ``MX`` has dims ``(labelsX, labelsY)``, labels ``<crop>_<k>`` (k starting
        at 1), and ``lat``/``lon`` coordinates along ``labelsX``.
        """
        print(' * Calculating the matrix for {}'.format(crop))

        aMX = self._pairwise_rmse(self.aCropMetrics, verbose=True)
        aIds = ['{}_{}'.format(crop, iR + 1) for iR in range(aMX.shape[0])]

        new_data = xr.Dataset(
                data_vars=dict(
                    MX=(["labelsX", "labelsY"], aMX),
                ),
                coords=dict(
                    labelsY = aIds,
                    labelsX = aIds,
                    lat=(["labelsX"], self.aLat),
                    lon=(["labelsX"], self.aLon),
                ),
                attrs=dict(description="equivalence matrix"),
            )

        new_data.to_netcdf(fOut_MX)

        print('')
      
#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Run configuration. Set 'BestBand_threshold' to a number (e.g. 0.12) to
    # drop the bands listed in 'fBestBands' whose score is below it; None keeps
    # every band.
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/3_satio_LUCAS_metrics/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/4_satio_LUCAS_matrices/',
        fBestBands=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/6_HANTS_Ts_bestbands/best_bands.csv',
        BestBand_threshold=None,
        overwrite=False,
    )

    oAnalyse_separability = cAnalyse_separability(Info)
    oAnalyse_separability.start_processing()