#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 31 08:30:17 2022

@author: bertelsl
"""


import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from numpy import mean
from numpy import std
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_validate
from sklearn import datasets, linear_model
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.ensemble import RandomForestClassifier

#================================================================================================================
class cCross_validate(object):
#================================================================================================================
    """Pairwise separability of crop types using a Random Forest classifier.

    For every pair of crop types that have both an ``S1_<crop>.csv`` and an
    ``S2_<crop>.csv`` metrics file in ``fIndir``, a binary
    RandomForestClassifier is k-fold cross-validated to separate the two crops.
    Accuracy and F1 (mean and std over the folds) are collected in symmetric
    crop x crop matrices and written as ';'-separated CSV tables to ``fOutdir``.
    Each computed pair is also cached as a one-row CSV in ``fTmpdir``; cached
    pairs are reloaded instead of recomputed, so an interrupted run can be
    resumed.
    """

    # Pairs where either crop has fewer rows than this are not evaluated.
    nMin_samples = 10
    # Number of cross-validation folds.
    nFolds = 5

#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the configuration, load the optional exclusion lists and create output dirs.

        Parameters
        ----------
        Info : dict
            Configuration with the keys 'fIndir', 'fOutdir', 'fTmpdir',
            'overwrite', 'fBest_features', 'Best_features_threshold',
            'fMan_exclude_metrices', 'exclude_features', 'exclude_outliers',
            'man_exclude_metrices' and, when 'exclude_outliers' is true,
            'fIntra_outliers' and 'fInter_outliers'.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.fTmpdir = Info['fTmpdir']
        # NOTE(review): 'overwrite' is stored but never consulted; cached pair
        # results in fTmpdir are always reused.
        self.overwrite = Info['overwrite']
        self.fBest_features = Info['fBest_features']
        self.Best_features_threshold = Info['Best_features_threshold']
        self.fMan_exclude_metrices = Info['fMan_exclude_metrices']
        self.exclude_features = Info['exclude_features']
        self.exclude_outliers = Info['exclude_outliers']
        self.man_exclude_metrices = Info['man_exclude_metrices']

        # Neutral defaults, so read_metrics_data never meets an undefined attribute.
        self.aOutliers = np.array([])
        self.drop_features = []
        self.man_drop_features = []

        if self.exclude_outliers:
            aIntra_outliers = pd.read_csv(Info['fIntra_outliers'])['aIntraRemoveIDs'].values
            aInter_outliers = pd.read_csv(Info['fInter_outliers'])['aInterRemoveIDs'].values
            self.aOutliers = np.concatenate([aIntra_outliers, aInter_outliers])

        # A threshold of None means "do not drop low-scoring features".
        if self.exclude_features and self.Best_features_threshold is not None:
            df = pd.read_csv(self.fBest_features)
            # Features scoring below the threshold are removed from the metrics.
            self.drop_features = df.loc[df['score'] < self.Best_features_threshold, 'best_features']

        if self.man_exclude_metrices:
            # Columns listed under 'drop_features' are removed from the metrics.
            self.man_drop_features = pd.read_csv(self.fMan_exclude_metrices)['drop_features']

        # makedirs also creates missing parent dirs and accepts existing ones.
        os.makedirs(self.fOutdir, exist_ok=True)
        os.makedirs(self.fTmpdir, exist_ok=True)

#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Cross-validate every crop-type pair and write the score matrices.

        Each unordered pair is evaluated once and mirrored, so all matrices are
        symmetric. Pairs where either crop has no data or fewer than
        ``nMin_samples`` rows stay 0. Accuracy/F1 means are stored in percent,
        the standard deviations as fractions. Writes Scores_accuracy_mean.csv,
        Scores_accuracy_std.csv, Scores_f1_mean.csv and Scores_f1_std.csv to
        ``fOutdir`` and prints the mean of the non-zero accuracy / F1 entries.
        """
        aUnique_CropTypes, _ = self.get_unique_croptypes()
        nCropTypes = len(aUnique_CropTypes)
        shape = (nCropTypes, nCropTypes)

        aScores_accuracy_mean = np.zeros(shape, dtype=np.float32)
        aScores_accuracy_std = np.zeros(shape, dtype=np.float32)
        aScores_f1_mean = np.zeros(shape, dtype=np.float32)
        aScores_f1_std = np.zeros(shape, dtype=np.float32)

        # Number of usable rows per crop type (0 when it has no data).
        anEntries = []

        for iRow, own_crop in enumerate(aUnique_CropTypes):

            status1, dfRef = self.read_metrics_data(own_crop)

            if not status1:
                anEntries.append(0)
                continue

            anEntries.append(len(dfRef))

            # Visit each unordered pair once: only crops after own_crop.
            for iColumn in range(iRow + 1, nCropTypes):
                other_crop = aUnique_CropTypes[iColumn]
                fTmp = os.path.join(self.fTmpdir, '{} - {}.csv'.format(own_crop, other_crop))

                if os.path.isfile(fTmp):
                    # Computed in an earlier run: reload it so the output
                    # matrices contain it instead of leaving it at 0.
                    scores = self._read_cached_pair(fTmp)
                else:
                    status2, dfTar = self.read_metrics_data(other_crop)

                    if not status2:
                        continue

                    print('\r Processing: {} - {}'.format(own_crop, other_crop),
                          end='                                                                                                                  ')

                    if (len(dfRef) < self.nMin_samples) or (len(dfTar) < self.nMin_samples):
                        continue

                    scores = self._cross_validate_pair(dfRef, dfTar)
                    self._write_cached_pair(fTmp, iRow, iColumn, scores, len(dfRef))

                acc_mean, acc_std, f1_mean, f1_std = scores

                # Mirror into both triangles: the matrices are symmetric.
                for i, j in ((iRow, iColumn), (iColumn, iRow)):
                    aScores_accuracy_mean[i, j] = acc_mean * 100
                    aScores_accuracy_std[i, j] = acc_std
                    aScores_f1_mean[i, j] = f1_mean * 100
                    aScores_f1_std[i, j] = f1_std

        for aScores, fName in ((aScores_accuracy_mean, 'Scores_accuracy_mean.csv'),
                               (aScores_accuracy_std, 'Scores_accuracy_std.csv'),
                               (aScores_f1_mean, 'Scores_f1_mean.csv'),
                               (aScores_f1_std, 'Scores_f1_std.csv')):
            self._write_score_table(aScores, aUnique_CropTypes, anEntries, fName)

        print('mean score accuracy = {}'.format(self._mean_nonzero(aScores_accuracy_mean)))
        print('mean score f1 = {}'.format(self._mean_nonzero(aScores_f1_mean)))

#================================================================================================================
    def _cross_validate_pair(self, dfRef, dfTar):
#================================================================================================================
        """Cross-validate a Random Forest separating dfRef (label 0) from dfTar (label 1).

        Returns
        -------
        tuple of float
            (accuracy mean, accuracy std, f1 mean, f1 std) over the folds.
        """
        # Missing metric values are replaced by 0 before training.
        aX = pd.concat([dfRef, dfTar], axis=0).to_numpy(copy=True)
        aX[np.isnan(aX)] = 0
        ay = np.concatenate([np.zeros(len(dfRef), dtype=int), np.ones(len(dfTar), dtype=int)])

        # One cross_validate call scores both metrics on the same fitted
        # forests, instead of training every fold twice.
        scores = cross_validate(RandomForestClassifier(), aX, ay, cv=self.nFolds,
                                scoring=('accuracy', 'f1'))
        acc = scores['test_accuracy']
        f1 = scores['test_f1']
        return acc.mean(), acc.std(), f1.mean(), f1.std()

#================================================================================================================
    @staticmethod
    def _read_cached_pair(fTmp):
#================================================================================================================
        """Return (acc mean, acc std, f1 mean, f1 std) stored in a per-pair cache file."""
        row = pd.read_csv(fTmp).iloc[0]
        return (row['scores_accuracy_mean'], row['scores_accuracy_std'],
                row['scores_f1_mean'], row['scores_f1_std'])

#================================================================================================================
    @staticmethod
    def _write_cached_pair(fTmp, iRow, iColumn, scores, nEntries):
#================================================================================================================
        """Write one pair's fold statistics to its one-row cache CSV."""
        acc_mean, acc_std, f1_mean, f1_std = scores
        ddd = {'iRow': iRow, 'iColumn': iColumn,
               'scores_accuracy_mean': acc_mean,
               'scores_accuracy_std': acc_std,
               'scores_f1_mean': f1_mean,
               'scores_f1_std': f1_std,
               'nEntries': nEntries}
        pd.DataFrame(data=ddd, index=[0]).to_csv(fTmp, index=False)

#================================================================================================================
    def _write_score_table(self, aScores, aUnique_CropTypes, anEntries, fName):
#================================================================================================================
        """Write one crop x crop score matrix plus crop names and row counts to fOutdir."""
        df = pd.DataFrame(aScores, columns=aUnique_CropTypes)
        df['crop types'] = aUnique_CropTypes
        df['nEntries'] = anEntries
        df.to_csv(os.path.join(self.fOutdir, fName), sep=';')

#================================================================================================================
    @staticmethod
    def _mean_nonzero(aScores):
#================================================================================================================
        """Mean of the non-zero entries (0 marks pairs not evaluated); NaN if there are none."""
        aValues = aScores[aScores != 0]
        return np.mean(aValues) if aValues.size else float('nan')

#================================================================================================================
    def get_unique_croptypes(self):
#================================================================================================================
        """Return the crop types that have both an S1_ and an S2_ metrics file.

        Returns
        -------
        (numpy.ndarray, list)
            Sorted unique crop names, and the column names (without 'IDs') of
            the first and the last CSV file in fIndir, in name order.

        Raises
        ------
        FileNotFoundError
            If fIndir contains no CSV files.
        """
        afMetrics = sorted(glob.glob(os.path.join(self.fIndir, '*.csv')))
        if not afMetrics:
            raise FileNotFoundError('no metrics CSV files found in {}'.format(self.fIndir))

        aFeatures = []
        for fMetric in (afMetrics[0], afMetrics[-1]):
            # Only the header is needed here.
            aColumns = list(pd.read_csv(fMetric, sep=';', nrows=0).columns)
            aColumns.remove('IDs')
            aFeatures += aColumns

        aUnique_CropTypes = []

        for fMetric in afMetrics:
            # 'S1_<crop>.csv' / 'S2_<crop>.csv' -> '<crop>'
            crop = os.path.basename(fMetric).split('.')[0]
            crop = crop.split('S1_')[-1].split('S2_')[-1]

            if os.path.isfile(os.path.join(self.fIndir, 'S1_{}.csv'.format(crop))) and \
                    os.path.isfile(os.path.join(self.fIndir, 'S2_{}.csv'.format(crop))):
                aUnique_CropTypes.append(crop)

        return np.unique(aUnique_CropTypes), aFeatures

#================================================================================================================
    def read_metrics_data(self, crop):
#================================================================================================================
        """Load and merge the S1 and S2 metrics of one crop type.

        Rows are matched on 'IDs'; IDs present in only one of the two files are
        dropped. The configured feature columns and outlier IDs are removed when
        the corresponding switches are on.

        Returns
        -------
        (bool, pandas.DataFrame or None)
            (True, frame indexed by 'IDs' with the S1 columns followed by the
            S2 columns), or (False, None) when a file is missing or no rows
            remain.
        """
        fS1_Metrics = os.path.join(self.fIndir, 'S1_{}.csv'.format(crop))
        fS2_Metrics = os.path.join(self.fIndir, 'S2_{}.csv'.format(crop))

        if not os.path.isfile(fS1_Metrics) or not os.path.isfile(fS2_Metrics):
            return False, None

        dfS1 = pd.read_csv(fS1_Metrics, sep=';')
        dfS2 = pd.read_csv(fS2_Metrics, sep=';')

        # Keep only IDs found in both files (vectorised instead of O(n^2) scans).
        bInS2 = dfS1['IDs'].isin(dfS2['IDs'])
        bInS1 = dfS2['IDs'].isin(dfS1['IDs'])
        dfS1 = dfS1.loc[bInS2].set_index('IDs')
        dfS2 = dfS2.loc[bInS1].set_index('IDs')

        df = pd.concat([dfS1, dfS2], axis=1)

        if self.exclude_features:
            df = df.drop(self.drop_features, axis=1)

        if self.man_exclude_metrices:
            df = df.drop(self.man_drop_features, axis=1)

        if self.exclude_outliers:
            df = df[~df.index.isin(self.aOutliers)]

        if len(df) == 0:
            return False, None
        return True, df

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Run configuration for the pairwise crop-type cross-validation.
    Info = dict(
        # Input metrics, results, and per-pair cache directories.
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/05_LPIS_LUCAS_metrics/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/10_cross_validate_s2_only/',
        fTmpdir=r'/data/EEA_HRL_VLCC/user/luc/data/10_cross_validate_tmp4/',
        # Outlier ID lists (read only when exclude_outliers is True).
        fIntra_outliers=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/07_outliers/intra_outliers.csv',
        fInter_outliers=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/07_outliers/inter_outliers.csv',
        # Feature-ranking table and its score cut-off.
        fBest_features='/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/08_best_features/best_features.csv',
        Best_features_threshold=0.30,  # None for not using the best bands
        # Manually curated list of metric columns to drop.
        fMan_exclude_metrices=r'/data/EEA_HRL_VLCC/user/luc/xlsx/use_only_metrices.csv',
        # Switches selecting which exclusions are applied.
        exclude_features=False,
        exclude_outliers=False,
        man_exclude_metrices=True,
        overwrite=True,
    )

    cCross_validate(Info).start_processing()