#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 22 12:42:52 2022

@author: bertelsl

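Algorithm:
    
    - Select the best features: for every feature (metric band), measure how well
      each pair of crop classes is separated (fraction of samples lying outside the
      value range shared by both classes), average this over the class pairs and
      rank the features from most to least separable.

    Example input/output folders:
        'fIndir':r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/3_satio_LUCAS_metrics/', 
        'fOutdir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/6_satio_bestbands/',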
        
Version: 31/08/2022

"""

import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings


#================================================================================================================
class cSelect_best_features(object):
    """Rank metric features (bands) by how well they separate pairs of crop classes.

    The input folder is expected to hold, per crop type, an ``S1_<crop>.csv`` and an
    ``S2_<crop>.csv`` file (';'-separated) with 'IDs', 'lat' and 'lon' columns plus
    one column per feature. The ranking is written to ``<fOutdir>/best_features.csv``.
    """

#================================================================================================================
    def __init__(self, Info):
        """Store the configuration and make sure the output folder exists.

        Parameters
        ----------
        Info : dict
            'fIndir'    : folder with the S1_/S2_ metrics CSV files.
            'fOutdir'   : output folder; created (including missing parents) if absent.
            'overwrite' : optional, default False; when True an existing
                          best_features.csv is replaced, otherwise the run is skipped.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.overwrite = Info.get('overwrite', False)
        # Not used anywhere in this class; kept so existing attribute access keeps working.
        self.BestFeatures = {}

        # makedirs (unlike mkdir) also creates missing parent folders.
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    @staticmethod
    def _separability(aClass_Ref, aClass_Tar):
        """Return the fraction of samples of two classes lying outside their shared value range.

        The shared range is [max of the minima, min of the maxima]. The result is 1.0
        when the two value ranges do not overlap at all, and 0.0 when every sample
        lies inside the shared range. NaN samples are ignored; NaN is returned when
        either class has no valid sample.
        """
        aRef = np.asarray(aClass_Ref, dtype=np.float64)
        aTar = np.asarray(aClass_Tar, dtype=np.float64)
        aRef = aRef[~np.isnan(aRef)]
        aTar = aTar[~np.isnan(aTar)]
        if aRef.size == 0 or aTar.size == 0:
            return np.nan

        # When the ranges do not overlap, Min_v > Max_v, so every sample is counted.
        Min_v = max(aRef.min(), aTar.min())
        Max_v = min(aRef.max(), aTar.max())
        nRef = np.count_nonzero((aRef < Min_v) | (aRef > Max_v))
        nTar = np.count_nonzero((aTar < Min_v) | (aTar > Max_v))
        return float(nRef + nTar) / (aRef.size + aTar.size)

#================================================================================================================
    def start_processing(self):
        """Score every feature and write the ranking to ``<fOutdir>/best_features.csv``.

        For each pair of usable crop classes, a feature's separability is computed with
        ``_separability``. A feature's score is the mean over the class pairs for which
        a value could be computed. Features are sorted from highest to lowest score;
        features without any score come last.

        Does nothing when the output file exists and ``overwrite`` is False.
        """
        fOut = os.path.join(self.fOutdir, 'best_features.csv')

        if os.path.isfile(fOut):
            if not self.overwrite:
                return
            os.remove(fOut)

        # Class names and feature (metrics band) names for the comparison.
        aClasses, aFeatures = self.get_unique_croptypes()
        nClasses = len(aClasses)

        # Read every class once, instead of re-reading both CSV files for every
        # (feature, class pair) combination. Unusable classes are stored as None.
        aData = []
        for crop in aClasses:
            status, df = self.read_metrics_data(crop)
            aData.append(df if status else None)

        # All unordered class pairs: (0,1), (0,2), ..., (1,2), ...
        aPairs = [(iCs, iCe) for iCs in range(nClasses) for iCe in range(iCs + 1, nClasses)]

        # Per-pair separability; pairs involving an unusable class stay NaN.
        aStatsAna = np.full((len(aFeatures), len(aPairs)), np.nan, dtype=np.float32)
        # Per-feature score; NaN when no class pair could be evaluated.
        aStatsResult = np.full(len(aFeatures), np.nan, dtype=np.float32)

        for iFeature, feature in enumerate(aFeatures):
            # '\r' plus a fixed-width field overwrites the previous progress line in place.
            print('\r * Processing feature: {:<60}'.format(feature), end='', flush=True)

            for iD, (iCs, iCe) in enumerate(aPairs):
                dfRef, dfTar = aData[iCs], aData[iCe]
                if dfRef is None or dfTar is None:
                    continue
                aStatsAna[iFeature, iD] = self._separability(dfRef[feature].values,
                                                             dfTar[feature].values)

            # Mean over the evaluated pairs only: skipped pairs (NaN) must not turn
            # the whole feature score into NaN.
            aRow = aStatsAna[iFeature, :]
            aValid = aRow[~np.isnan(aRow)]
            if aValid.size > 0:
                aStatsResult[iFeature] = aValid.mean()
        print()

        # Highest score first. argsort puts NaN last; 'stable' keeps ties in feature order.
        iBestFeatures = np.argsort(-aStatsResult, kind='stable')
        aBestFeatures = [aFeatures[iX] for iX in iBestFeatures]
        aScore = [aStatsResult[iX] for iX in iBestFeatures]

        print('**** BestBands: ' + ', '.join(aBestFeatures))

        df = pd.DataFrame(aBestFeatures, columns=['best_features'])
        df['score'] = aScore
        df.to_csv(fOut, index=False)

#================================================================================================================
    def read_metrics_data(self, crop):
        """Load and join the S1 and S2 metrics of one crop type.

        Parameters
        ----------
        crop : str
            Crop name. The files ``S1_<crop>.csv`` and ``S2_<crop>.csv`` in ``fIndir``
            are read (';'-separated).

        Returns
        -------
        (bool, pandas.DataFrame or None)
            (True, df): df is indexed by 'IDs' and holds the S1 feature columns followed
            by the S2 feature columns, for the IDs present in both files.
            (False, None): a file is missing, or the two files share no ID.
        """
        fS1_Metrics = os.path.join(self.fIndir, 'S1_{}.csv'.format(crop))
        fS2_Metrics = os.path.join(self.fIndir, 'S2_{}.csv'.format(crop))

        if not os.path.isfile(fS1_Metrics) or not os.path.isfile(fS2_Metrics):
            return False, None

        dfS1 = pd.read_csv(fS1_Metrics, sep=';')
        dfS2 = pd.read_csv(fS2_Metrics, sep=';')

        # Keep only samples present in both files. Set lookups keep this linear
        # in the number of rows.
        aCommonIDs = set(dfS1['IDs']).intersection(dfS2['IDs'])
        dfS1 = dfS1[dfS1['IDs'].isin(aCommonIDs)].set_index('IDs')
        dfS2 = dfS2[dfS2['IDs'].isin(aCommonIDs)].set_index('IDs')

        dfS1 = dfS1.drop(columns=['lat', 'lon'])
        dfS2 = dfS2.drop(columns=['lat', 'lon'])

        df = pd.concat([dfS1, dfS2], axis=1)

        if len(df) == 0:
            return False, None
        return True, df

#================================================================================================================
    def get_unique_croptypes(self):
        """Return the crop types that have both metrics files, plus the feature names.

        Returns
        -------
        (numpy.ndarray, list of str)
            The sorted unique crop names for which both ``S1_<crop>.csv`` and
            ``S2_<crop>.csv`` exist in ``fIndir``, and the feature names: the S1
            columns followed by the S2 columns, without 'lat', 'lon' and 'IDs'.
            The feature names are taken from the first crop's files.

        Raises
        ------
        FileNotFoundError
            If no crop has both an S1 and an S2 metrics file.
        """
        aFiles = [f for f in os.listdir(self.fIndir)
                  if f.endswith('.csv') and os.path.isfile(os.path.join(self.fIndir, f))]

        def _crops(prefix):
            # Crop names of all '<prefix><crop>.csv' files (splitext keeps dots in crop names).
            return {os.path.splitext(f)[0][len(prefix):] for f in aFiles if f.startswith(prefix)}

        aCommon = _crops('S1_') & _crops('S2_')
        if not aCommon:
            raise FileNotFoundError(
                'No crop with both S1_<crop>.csv and S2_<crop>.csv found in {}'.format(self.fIndir))
        aUnique_CropTypes = np.unique(list(aCommon))

        # Column order matches pd.concat([dfS1, dfS2], axis=1) in read_metrics_data.
        aFeatures = []
        for prefix in ('S1_', 'S2_'):
            fMetric = os.path.join(self.fIndir, '{}{}.csv'.format(prefix, aUnique_CropTypes[0]))
            aColumns = pd.read_csv(fMetric, sep=';', nrows=0).columns
            aFeatures += [c for c in aColumns if c not in ('lat', 'lon', 'IDs')]

        return aUnique_CropTypes, aFeatures

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================
    # Run configuration: per-crop S1_/S2_ metrics CSVs in, best_features.csv out.
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/05_LUCAS_metrics/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/08_best_features/',
        overwrite=True,
    )

    cSelect_best_features(Info).start_processing()