#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep  7 07:40:48 2022

@author: bertelsl
"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from numpy import mean
from numpy import std
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_validate
from sklearn import datasets, linear_model
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.ensemble import RandomForestClassifier

#================================================================================================================
class cGroup_cross_validate(object):
#================================================================================================================
    """Merge per-crop-pair cross-validation CSVs into square score matrices.

    Every ``*.csv`` file in the input directory is expected to be named
    ``<crop A> - <crop B>.csv`` and to provide the columns ``iRow``,
    ``iColumn``, ``scores_accuracy_mean``, ``scores_accuracy_std``,
    ``scores_f1_mean`` and ``scores_f1_std``.  The scores are gathered into
    symmetric crop-by-crop matrices that are written as ';'-separated CSVs.
    """

    # (input column, output file name, factor applied before storing).
    # Only the mean scores are multiplied by 100; the std columns are stored
    # as read.  NOTE(review): confirm the unscaled std values are intended.
    _METRICS = (
        ('scores_accuracy_mean', 'Scores_accuracy_mean_.csv', 100),
        ('scores_accuracy_std', 'Scores_accuracy_std_.csv', 1),
        ('scores_f1_mean', 'Scores_f1_mean_.csv', 100),
        ('scores_f1_std', 'Scores_f1_std_.csv', 1),
    )

#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the input and output directories.

        Parameters
        ----------
        Info : dict
            Must contain 'fIndir' (directory holding the per-pair CSV files)
            and 'fOutdir' (directory receiving the merged matrices).  Any
            other key is ignored by this class.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']

#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Build the score matrices, write them to disk and print the means.

        Writes 'Scores_accuracy_mean_.csv', 'Scores_accuracy_std_.csv',
        'Scores_f1_mean_.csv' and 'Scores_f1_std_.csv' into ``self.fOutdir``
        (created when missing) and prints the mean of the non-zero accuracy
        and F1 entries.

        Raises
        ------
        FileNotFoundError
            If ``self.fIndir`` contains no '*.csv' file.
        """
        # os.path.join works whether or not fIndir ends with a separator.
        afCV = sorted(glob.glob(os.path.join(self.fIndir, '*.csv')))
        if not afCV:
            raise FileNotFoundError(
                'no *.csv files found in {}'.format(self.fIndir))

        # Collect BOTH crop names of every file so that the label list does
        # not depend on the sort order or on every pair being present.
        aCrops = []
        for fCV in afCV:
            # splitext only strips the extension, keeping dots in crop names.
            sStem = os.path.splitext(os.path.basename(fCV))[0]
            aCrops.extend(sStem.split(' - ')[:2])

        aUnique_CropTypes = np.unique(aCrops)
        nCropTypes = len(aUnique_CropTypes)

        # One zero-initialised square matrix per metric.
        dScores = {
            sColumn: np.zeros((nCropTypes, nCropTypes), dtype=np.float32)
            for sColumn, _, _ in self._METRICS
        }

        for fCV in afCV:
            df = pd.read_csv(fCV)
            # Assumes iRow/iColumn are positions in the sorted unique crop
            # list -- TODO confirm against the script producing these files.
            aRow = df['iRow'].to_numpy()
            aColumn = df['iColumn'].to_numpy()

            for sColumn, _, fScale in self._METRICS:
                aValues = df[sColumn].to_numpy() * fScale
                # Mirror each value so the matrix is symmetric.
                dScores[sColumn][aRow, aColumn] = aValues
                dScores[sColumn][aColumn, aRow] = aValues

        # An empty fOutdir means "current directory": nothing to create.
        if self.fOutdir:
            os.makedirs(self.fOutdir, exist_ok=True)

        for sColumn, sFilename, _ in self._METRICS:
            dfOut = pd.DataFrame(dScores[sColumn], columns=aUnique_CropTypes)
            dfOut['crop types'] = aUnique_CropTypes
            dfOut.to_csv(os.path.join(self.fOutdir, sFilename), sep=';')

        # Cells that are exactly zero (never filled, or a genuine 0 score)
        # are left out of the overall means.
        aScores_acc = dScores['scores_accuracy_mean']
        aScores_acc = aScores_acc[aScores_acc != 0]
        aScores_f1 = dScores['scores_f1_mean']
        aScores_f1 = aScores_f1[aScores_f1 != 0]

        print('mean score accuracy = {}'.format(np.mean(aScores_acc)))
        print('mean score f1 = {}'.format(np.mean(aScores_f1)))

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Script entry point: describe the run, then merge the per-pair
    # cross-validation results found in 'fIndir' into 'fOutdir'.
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/10_cross_validate_tmp2/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/10_cross_validate_bf032/',
        overwrite=True,
    )

    oGroup_cross_validate = cGroup_cross_validate(Info)
    oGroup_cross_validate.start_processing()