#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 16 10:39:50 2022

@author: bertelsl

Algorithm:
    
    - calculate the RMSE matrix of every crop type; optionally drop the features whose score in 'fBest_features' is below 'Best_features_threshold' before the matrices are calculated:
        
        'fIndir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/05_LUCAS_metrics/',
        'fOutdir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/06_LUCAS_matrices_bestbands_032/',
        'fBest_features': '/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/08_best_features/best_features.csv',
        'Best_features_threshold': 0.32, # None for not using the best bands
        
Version: 31/08/2022

"""

import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

#================================================================================================================
class cCalculate_matrices(object):
    """Build per-crop RMSE matrices from the S1/S2 metrics CSV files.

    For every crop type that has both an ``S1_<crop>.csv`` and an
    ``S2_<crop>.csv`` file in ``fIndir``, the min-max normalised metrics of
    its samples are compared (RMSE over the features) with the samples of all
    crop types. The resulting matrix is written to ``<fOutdir>/<crop>.nc``.
    """
#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the configuration and make sure the output folder exists.

        Parameters
        ----------
        Info : dict
            Must provide the keys 'fIndir', 'fOutdir', 'overwrite',
            'fBest_features' and 'Best_features_threshold' (None disables
            the feature selection).
        """

        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.overwrite = Info['overwrite']
        self.fBest_features = Info['fBest_features']
        self.Best_features_threshold = Info['Best_features_threshold']

        # makedirs also creates missing parent folders and does not fail when
        # the folder appears between a check and the creation.
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    def start_processing(self, dist):
#================================================================================================================
        """Read all metrics and write one matrix file per valid crop type.

        Parameters
        ----------
        dist :
            Not used by the current implementation; the RMSE is always
            calculated.

        An existing output file is removed when ``overwrite`` is set,
        otherwise the crop type is skipped.
        """

        self.aUnique_CropTypes = self.get_unique_croptypes()

        if self.Best_features_threshold is not None:
            self.get_best_bands()

        self.read_metrics_data()

        # Running count of the matrix rows written so far (bookkeeping only).
        self.dia_count = 0

        # One entry per sample in aValid_CropTypes; a set keeps the membership
        # test cheap.
        valid_crops = set(self.aValid_CropTypes)

        for crop in self.aUnique_CropTypes:

            fOut_MX = os.path.join(self.fOutdir, '{}.nc'.format(crop))

            if os.path.isfile(fOut_MX):
                if self.overwrite:
                    os.remove(fOut_MX)
                else:
                    continue

            if crop in valid_crops:
                print(' * Calculating the matrix for {}'.format(crop))

                self.calculate_MX(crop, fOut_MX)

        print('')
        print(' * Done!')

#================================================================================================================
    def get_best_bands(self):
#================================================================================================================
        """Select the features to drop from the best-features CSV file.

        Reads ``fBest_features`` (columns 'best_features' and 'score') and
        stores in ``self.drop_features`` the features whose score is below
        ``Best_features_threshold``.
        """

        df = pd.read_csv(self.fBest_features)
        below_threshold = df['score'] < self.Best_features_threshold

        self.drop_features = df.loc[below_threshold, 'best_features']

#================================================================================================================
    def get_unique_croptypes(self):
#================================================================================================================
        """Return the sorted unique crop types found in ``fIndir``.

        The crop type is the CSV file name without its extension and without
        a leading 'S1_' or 'S2_' sensor prefix.

        Returns
        -------
        numpy.ndarray
            Sorted unique crop type names.
        """

        CropTypes = []

        for fMetric in glob.glob(os.path.join(self.fIndir, '*.csv')):

            # Strip only the final extension and only a *leading* sensor
            # prefix, so names containing a dot or 'S1_'/'S2_' further on
            # (e.g. 'S2_GRASS1_A.csv') are kept intact.
            name = os.path.splitext(os.path.basename(fMetric))[0]

            for prefix in ('S1_', 'S2_'):
                if name.startswith(prefix):
                    name = name[len(prefix):]
                    break

            CropTypes.append(name)

        return np.unique(CropTypes)

#================================================================================================================
    def read_metrics_data(self):
#================================================================================================================
        """Load, align and normalise the S1/S2 metrics of all crop types.

        Crop types that lack either the S1 or the S2 file are skipped. Per
        crop type only the IDs present in both files are kept, the S1 and S2
        columns are joined, the optional ``drop_features`` are removed and
        every column is min-max normalised.

        Sets ``aValid_CropTypes`` (one entry per sample), ``nFeatures``,
        ``xData``, ``nData`` and ``aData``.

        Raises
        ------
        ValueError
            When no crop type has both an S1 and an S2 metrics file.
        """

        self.aValid_CropTypes = []
        aALL_IDs = []
        aBlocks = []
        aMetrics = None

        for crop in self.aUnique_CropTypes:

            fS1_Metrics = os.path.join(self.fIndir, 'S1_{}.csv'.format(crop))
            fS2_Metrics = os.path.join(self.fIndir, 'S2_{}.csv'.format(crop))

            if not (os.path.isfile(fS1_Metrics) and os.path.isfile(fS2_Metrics)):
                continue

            print('\r * Reading the metrics data for {}'.format(crop), \
                  end='                                                                                                                                                               ')

            dfS1 = pd.read_csv(fS1_Metrics, sep=';').set_index('IDs')
            dfS2 = pd.read_csv(fS2_Metrics, sep=';').set_index('IDs')

            # Keep only the samples that are present in both sensors. Both
            # masks are taken from the unfiltered frames.
            in_S2 = dfS1.index.isin(dfS2.index)
            in_S1 = dfS2.index.isin(dfS1.index)
            dfS1 = dfS1.loc[in_S2]
            dfS2 = dfS2.loc[in_S1]

            # NOTE: exporting lat/lon alongside the metrics is currently disabled.

            self.aValid_CropTypes.extend([crop] * len(dfS1))
            aALL_IDs.extend(dfS1.index.values)
            df = pd.concat([dfS1, dfS2], axis=1)

            if self.Best_features_threshold is not None:
                df = df.drop(self.drop_features, axis=1)

            # normalize the data (min-max, per column, within this crop type)
            # NOTE(review): a column that is constant within a crop type has
            # a zero range and becomes NaN here - confirm this is acceptable.
            df_min = df.min(axis=0)
            df_range = df.max(axis=0) - df_min
            aNorm = (df - df_min).div(df_range, axis=1).to_numpy()

            if aMetrics is None:
                # The first crop type defines the feature names and count.
                aMetrics = df.columns
                self.nFeatures = len(aMetrics)

            aBlocks.append(aNorm)

        print('')

        if not aBlocks:
            raise ValueError(
                'No crop type in {} has both an S1_ and an S2_ metrics file'.format(self.fIndir))

        # Stack once instead of re-copying the growing array per crop type.
        aData = np.concatenate(aBlocks, axis=0)

        self.xData = xr.Dataset(
            data_vars=dict(
                METRICS=(["crops", "metrics"], aData),
                IDs = (["crops"], aALL_IDs),
            ),
            coords=dict(
                crops=(self.aValid_CropTypes),
                metrics = (aMetrics),
            ),
            attrs=dict(description="metrics data"),
        )

        self.nData = aData.shape[0]
        self.aData = aData

#================================================================================================================
    def _distance_matrix(self, crop):
#================================================================================================================
        """Return the RMSE matrix of the samples of `crop` versus all samples.

        Returns
        -------
        numpy.ndarray
            float32 array of shape (samples of `crop`, nData). Rows follow
            the order of the crop's samples in ``aData``; the distance of a
            sample to itself is NaN.
        """

        # Global row numbers of this crop's samples. The matrix must use
        # these, not 0..n-1, because only the first crop starts at row 0.
        own_idx = np.flatnonzero([c == crop for c in self.aValid_CropTypes])
        nOwn = own_idx.size

        aMX = np.zeros((nOwn, self.nData), dtype=np.float32)

        for iR, iGlobal in enumerate(own_idx):

            print('\r  - Progress: {} / {}'.format(iR+1, nOwn), end='                                                                                    ')

            aDiff = self.aData - self.aData[iGlobal]
            aMX[iR, :] = np.sqrt(np.power(aDiff, 2).sum(axis=1) / self.nFeatures)

            # A sample is not its own neighbour.
            aMX[iR, iGlobal] = np.nan

        return aMX

#================================================================================================================
    def calculate_MX(self, crop, fOut_MX):
#================================================================================================================
        """Calculate the RMSE matrix of one crop type and write it to NetCDF.

        Parameters
        ----------
        crop : str
            Crop type; must occur in ``aValid_CropTypes``.
        fOut_MX : str
            Path of the NetCDF file to write.

        The file holds 'MX' (dims IDsX = samples of `crop`, IDsY = all
        samples) and 'yCroptypes' (crop type of every IDsY sample).
        """

        xCrop = self.xData.where((self.xData.crops == crop), drop=True)
        aAll_IDs = self.xData['IDs'].values
        aOwn_IDs = xCrop['IDs'].values

        aMX = self._distance_matrix(crop)

        # Kept for backward compatibility only; the NaN positions no longer
        # depend on this counter, so skipped crop types cannot shift them.
        self.dia_count = getattr(self, 'dia_count', 0) + aMX.shape[0]

        new_data = xr.Dataset(
                data_vars=dict(
                    MX=(["IDsX", "IDsY"], aMX),
                    yCroptypes=(["IDsY"], self.aValid_CropTypes),
                ),
                coords=dict(
                    IDsY = aAll_IDs,
                    IDsX = aOwn_IDs,
                ),
                attrs=dict(description="matrix"),
            )

        new_data.to_netcdf(fOut_MX)

        print('')

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Run configuration: input metrics folder, output matrix folder and the
    # optional best-feature selection.
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/05_LPIS_LUCAS_metrics/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/06_LPIS_LUCAS_matrices/',
        fBest_features=r'/data/EEA_HRL_VLCC/user/luc/data/best_features.csv',
        Best_features_threshold=None,  # e.g. 0.32; None for not using the best bands
        overwrite=False,
    )

    oCalculate_matrices = cCalculate_matrices(Info)
    oCalculate_matrices.start_processing('RMSE')
    