#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 29 09:36:28 2022

@author: bertelsl

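Algorithm:

    - read the S1 and S2 metrics of each crop type from fIndir and join them on the point IDs present in both files;
    - if BestBand_threshold is not None, drop the bands whose score in fBestBands is below that threshold;
    - for every pair of crop types with at least 10 samples each, run a 5-fold Random Forest cross validation
      that separates the two crops, and write the mean and standard deviation of the accuracy and F1 scores
      as crop-by-crop matrices (CSV files) to fOutdir.

    Example configuration:

        'fIndir':r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/05_LUCAS_metrics/', 
        'fOutdir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/10_cross_validate_030/',
        'fBestBands': '/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/6_satio_bestbands/best_bands.csv',
        'BestBand_threshold': 0.30, # None for not using the best bands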
        
Version: 31/08/2022

"""

import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from numpy import mean
from numpy import std
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_validate
from sklearn import datasets, linear_model
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.ensemble import RandomForestClassifier

#================================================================================================================
class cCross_validate(object):
#================================================================================================================
    """Pairwise Random Forest separability of crop types from combined S1 + S2 metrics.

    For every pair of crop types found in ``fIndir`` a binary Random Forest classifier is
    cross-validated and the mean / standard deviation of accuracy and F1 are written as four
    square CSV matrices to ``fOutdir``.

    Keys read from ``Info``:
        fIndir             -- directory with ';'-separated ``S1_<crop>.csv`` / ``S2_<crop>.csv`` files,
                              each having 'IDs', 'lat' and 'lon' columns
        fOutdir            -- output directory (created, including parents, if missing)
        overwrite          -- stored but not used by this class
        fBestBands         -- CSV with 'best_bands' and 'score' columns
        BestBand_threshold -- bands scoring below this value are dropped; None keeps every band
    """

    # Minimum number of samples each crop needs before a pair is cross-validated.
    MIN_SAMPLES = 10
    # Number of cross-validation folds.
    N_FOLDS = 5

#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the configuration and make sure the output directory exists."""

        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.overwrite = Info['overwrite']
        self.fBestBands = Info['fBestBands']
        self.BestBand_threshold = Info['BestBand_threshold']
        # Bands to remove from the metrics; filled by get_best_bands().
        self.drop_bands = None

        # makedirs also creates missing parent directories and accepts an existing one.
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Cross-validate every pair of crop types and write the four score matrices to fOutdir.

        Cell [i, j] holds the score for separating crop i (label 0) from crop j (label 1).
        Only the upper triangle (i < j) is computed; other cells and skipped pairs stay 0.
        """

        aUnique_CropTypes, aFeatures = self.get_unique_croptypes()
        nCropTypes = len(aUnique_CropTypes)

        if self.BestBand_threshold is not None:
            self.get_best_bands()

        aScores_accuracy_mean = np.zeros((nCropTypes, nCropTypes), dtype=np.float32)
        aScores_accuracy_std = np.zeros((nCropTypes, nCropTypes), dtype=np.float32)
        aScores_f1_mean = np.zeros((nCropTypes, nCropTypes), dtype=np.float32)
        aScores_f1_std = np.zeros((nCropTypes, nCropTypes), dtype=np.float32)

        for iRow, own_crop in enumerate(aUnique_CropTypes):

            status1, dfRef = self.read_metrics_data(own_crop)
            # No pair can be evaluated when the reference crop is missing or too small.
            if not status1 or len(dfRef) < self.MIN_SAMPLES:
                continue
            nRef = len(dfRef)

            # aUnique_CropTypes is sorted and unique, so pairing only with later crops
            # evaluates every unordered pair exactly once.
            for iColumn in range(iRow + 1, nCropTypes):
                other_crop = aUnique_CropTypes[iColumn]

                status2, dfTar = self.read_metrics_data(other_crop)
                if not status2:
                    continue

                print('\r Processing: {} - {}'.format(own_crop, other_crop), \
                  end='                                                                                                                  ')

                nTar = len(dfTar)
                if nTar < self.MIN_SAMPLES:
                    continue

                df = pd.concat([dfRef, dfTar], axis=0)

                aX = df.values
                # Label 0 = reference crop, 1 = target crop (F1 is reported for label 1).
                ay = np.concatenate([np.zeros(nRef, dtype=int), np.ones(nTar, dtype=int)])

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    model = RandomForestClassifier()
                    # A single cross_validate call scores both metrics on the same fitted
                    # forests: half the training cost of two separate runs, and accuracy
                    # and F1 describe the same models.
                    scores = cross_validate(model, aX, ay, cv=self.N_FOLDS,
                                            scoring=('accuracy', 'f1'))

                aScores_accuracy_mean[iRow, iColumn] = scores['test_accuracy'].mean()
                aScores_accuracy_std[iRow, iColumn] = scores['test_accuracy'].std()
                aScores_f1_mean[iRow, iColumn] = scores['test_f1'].mean()
                aScores_f1_std[iRow, iColumn] = scores['test_f1'].std()

        for aScores, fName in ((aScores_accuracy_mean, 'Scores_accuracy_mean.csv'),
                               (aScores_accuracy_std, 'Scores_accuracy_std.csv'),
                               (aScores_f1_mean, 'Scores_f1_mean.csv'),
                               (aScores_f1_std, 'Scores_f1_std.csv')):
            self._write_scores(aScores, aUnique_CropTypes, fName)

#================================================================================================================
    def _write_scores(self, aScores, aUnique_CropTypes, fName):
#================================================================================================================
        """Write one square score matrix to fOutdir/fName with an extra 'crop types' column."""

        df = pd.DataFrame(aScores, columns=aUnique_CropTypes)
        df['crop types'] = aUnique_CropTypes
        df.to_csv(os.path.join(self.fOutdir, fName))

#================================================================================================================
    def get_best_bands(self):
#================================================================================================================
        """Set self.drop_bands to the 'best_bands' entries whose 'score' is below BestBand_threshold."""

        df = pd.read_csv(self.fBestBands)
        self.drop_bands = df.loc[df['score'] < self.BestBand_threshold, 'best_bands']

#================================================================================================================
    def get_unique_croptypes(self):
#================================================================================================================
        """Return (sorted unique crop types having both S1 and S2 files, feature names).

        The feature names are the columns (minus 'lat', 'lon', 'IDs') of the alphabetically
        first and last CSV in fIndir, presumably one S1 and one S2 file -- TODO confirm.

        Raises:
            FileNotFoundError: fIndir contains no CSV files.
            ValueError: one of those two files lacks a 'lat', 'lon' or 'IDs' column.
        """

        afMetrics = sorted(glob.glob(os.path.join(self.fIndir, '*.csv')))
        if not afMetrics:
            raise FileNotFoundError('No metric CSV files found in {}'.format(self.fIndir))

        aFeatures = []
        for fMetric in (afMetrics[0], afMetrics[-1]):
            # nrows=0 reads only the header line.
            aColumns = list(pd.read_csv(fMetric, sep=';', nrows=0).columns)
            for sColumn in ('lat', 'lon', 'IDs'):
                aColumns.remove(sColumn)
            aFeatures += aColumns

        aUnique_CropTypes = []

        for fMetric in afMetrics:
            # splitext strips only the extension, so crop names containing dots survive.
            crop = os.path.splitext(os.path.basename(fMetric))[0]
            crop = crop.split('S1_')[-1]
            crop = crop.split('S2_')[-1]

            if os.path.isfile(os.path.join(self.fIndir, 'S1_{}.csv'.format(crop))) and \
                os.path.isfile(os.path.join(self.fIndir, 'S2_{}.csv'.format(crop))):
                    aUnique_CropTypes.append(crop)

        return np.unique(aUnique_CropTypes), aFeatures

#================================================================================================================
    def read_metrics_data(self, crop):
#================================================================================================================
        """Load and join the S1 and S2 metrics of one crop type.

        Only points whose ID occurs in both files are kept; the result is indexed by 'IDs'
        and has the S1 columns followed by the S2 columns ('lat'/'lon' removed). When
        BestBand_threshold is set, the bands in self.drop_bands are removed as well.

        Returns:
            (True, DataFrame) on success, (False, None) if a file is missing or no point remains.
        """

        fS1_Metrics = os.path.join(self.fIndir, 'S1_{}.csv'.format(crop))
        fS2_Metrics = os.path.join(self.fIndir, 'S2_{}.csv'.format(crop))

        if not os.path.isfile(fS1_Metrics) or not os.path.isfile(fS2_Metrics):
            return False, None

        dfS1 = pd.read_csv(fS1_Metrics, sep=';')
        dfS2 = pd.read_csv(fS2_Metrics, sep=';')

        # Keep only points present in both files (hash-based, linear in the number of points).
        bS1_keep = dfS1['IDs'].isin(dfS2['IDs'])
        bS2_keep = dfS2['IDs'].isin(dfS1['IDs'])
        dfS1 = dfS1.loc[bS1_keep].set_index('IDs').drop(columns=['lat', 'lon'])
        dfS2 = dfS2.loc[bS2_keep].set_index('IDs').drop(columns=['lat', 'lon'])

        # Side by side; rows are aligned on the shared IDs.
        df = pd.concat([dfS1, dfS2], axis=1)

        if self.BestBand_threshold is not None:
            if getattr(self, 'drop_bands', None) is None:
                self.get_best_bands()
            df = df.drop(self.drop_bands, axis=1)

        if len(df) == 0:
            return False, None
        return True, df

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Run configuration for the LUCAS 2018 metrics; set 'BestBand_threshold'
    # to None to keep every band.
    dInfo = dict()
    dInfo['fIndir'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/05_LUCAS_metrics/'
    dInfo['fOutdir'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/10_cross_validate_030/'
    dInfo['fBestBands'] = '/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/6_satio_bestbands/best_bands.csv'
    dInfo['BestBand_threshold'] = 0.30
    dInfo['overwrite'] = True

    cCross_validate(dInfo).start_processing()