#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Aug  9 12:46:16 2022

@author: bertelsl
"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

#================================================================================================================
class cAnalyse_separability(object):
    """Quantify how well crop classes separate in a pairwise sample matrix.

    The input directory must hold exactly one netCDF file (``*.nc``) with a
    2-D variable ``MX``, the label coordinates ``labelsX``/``labelsY`` and an
    ``ids`` variable.  The legacy code calls the matrix entries "RMSE", so
    ``MX`` presumably holds pairwise RMSE between samples -- TODO confirm.

    Two CSV tables are written to ``fOutdir``: ``stats_in.csv`` (all data) and
    ``stats_ex.csv`` (after outlier exclusion).  Existing files are reused,
    not recomputed.

    NOTE(review): the helpers assume ``labelsY`` and ``labelsX`` are 1-D
    coordinates on the row and column axes of ``MX`` respectively, and that
    ``ids`` runs along the ``labelsX`` axis -- confirm against the code that
    writes the matrix file.
    """

    # Statistic suffixes of the per-crop output columns; the order must match
    # the thresholds stacked in _separability (min, then percentiles 10/25/50).
    STAT_NAMES = ('min', 'Q10', 'Q25', 'Q50')

    # Padding printed after '\r' progress lines so longer previous lines are overwritten.
    _CLEAR = ' ' * 120

#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the run configuration.

        Parameters
        ----------
        Info : dict
            Must provide 'fIndir' (directory holding the matrix file),
            'fOutdir' (directory for the CSV output) and 'intra_exclude'
            (whether to run the intra-class outlier exclusion).  Other keys
            are ignored.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.intra_exclude = Info['intra_exclude']

        self.aMX = None                 # matrix dataset, loaded by start_processing
        self.aUnique_CropTypes = []     # sorted crop labels found in the matrix
        self.ID_intra_outliers = []     # ids removed by exclude_intra_outliers
        self.dfCropInfo = None          # last table produced by calculate_stats

#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Load the matrix file and write the 'in' and 'ex' statistics tables.

        Raises
        ------
        FileNotFoundError
            If ``fIndir`` contains no ``*.nc`` file.
        ValueError
            If ``fIndir`` contains more than one ``*.nc`` file.
        """
        # Exactly one matrix file is expected.  glob returns a list, which
        # xr.load_dataset cannot open, so the single entry is picked explicitly.
        afMX_file = sorted(glob.glob(os.path.join(self.fIndir, '*.nc')))
        if not afMX_file:
            raise FileNotFoundError('No *.nc matrix file found in {}'.format(self.fIndir))
        if len(afMX_file) > 1:
            raise ValueError('Expected one *.nc matrix file in {}, found {}'.format(
                self.fIndir, len(afMX_file)))

        # Load once; the exclusion steps below work on self.aMX, not on the file.
        self.aMX = xr.load_dataset(afMX_file[0])
        # np.unique returns the labels already sorted.
        self.aUnique_CropTypes = list(np.unique(self.aMX['labelsY'].values))

        os.makedirs(self.fOutdir, exist_ok=True)

        # Calculate stats on all data
        fOut_stats_in = os.path.join(self.fOutdir, 'stats_in.csv')
        if not os.path.isfile(fOut_stats_in):
            self.calculate_stats(self.aMX, fOut_stats_in)

        # Calculate stats on the data left after outlier exclusion
        fOut_stats_ex = os.path.join(self.fOutdir, 'stats_ex.csv')
        if not os.path.isfile(fOut_stats_ex):
            if self.intra_exclude:
                self.exclude_intra_outliers()
            self.exclude_inter_outliers()
            self.calculate_stats(self.aMX, fOut_stats_ex)

#================================================================================================================
    def calculate_stats(self, fMX_file, fOut_stats):
#================================================================================================================
        """Compute the crop-vs-crop separability table and write it as CSV.

        Row ``crop`` gets, for every other crop, the four columns
        ``<other_crop>_<stat>`` defined in ``_separability``.  Self-comparison
        columns are left NaN.  ``nEntries`` is the number of ``crop`` samples on
        the labelsX axis.

        Parameters
        ----------
        fMX_file : str, os.PathLike or xarray.Dataset
            Path to the matrix file, or an already loaded dataset.
        fOut_stats : str
            Destination CSV path.

        Returns
        -------
        pandas.DataFrame
            The table that was written (also stored in ``self.dfCropInfo``).
        """
        print(' * Calculating the statistics')

        xMX = fMX_file if isinstance(fMX_file, xr.Dataset) else xr.load_dataset(fMX_file)

        aCrops = list(np.unique(xMX['labelsY'].values))
        self.aUnique_CropTypes = aCrops

        columns = ['crop', 'nEntries'] + ['{}_{}'.format(c, s) for c in aCrops for s in self.STAT_NAMES]

        aRows = []
        for crop in aCrops:
            aOwn = self._submatrix(xMX, crop, crop)
            dRow = {'crop': crop, 'nEntries': aOwn.shape[1]}
            for other_crop in aCrops:
                if other_crop == crop:
                    continue
                print('\r   - processing: {} v.s. {}'.format(crop, other_crop), end=self._CLEAR)
                aOther = self._submatrix(xMX, other_crop, crop)
                for stat, value in self._separability(aOwn, aOther).items():
                    dRow['{}_{}'.format(other_crop, stat)] = value
            aRows.append(dRow)

        # Build a fresh table per call so repeated calls (in/ex) do not interfere.
        self.dfCropInfo = pd.DataFrame(aRows, columns=columns)
        self.dfCropInfo.to_csv(fOut_stats, index=False)

        print('\r * Calculating the statistics: Done!', end=self._CLEAR)
        return self.dfCropInfo

#================================================================================================================
    @staticmethod
    def _submatrix(xMX, row_crop, col_crop):
#================================================================================================================
        """Return ``MX`` restricted to rows labelled row_crop and columns labelled col_crop.

        The NumPy result is oriented (labelsY axis, labelsX axis).  Column j
        therefore holds the MX values of the j-th ``col_crop`` sample against
        every ``row_crop`` sample.
        """
        dimY = xMX['labelsY'].dims[0]
        dimX = xMX['labelsX'].dims[0]
        rowMask = (xMX['labelsY'] == row_crop).values
        colMask = (xMX['labelsX'] == col_crop).values
        return xMX['MX'].isel({dimY: rowMask, dimX: colMask}).transpose(dimY, dimX).values

#================================================================================================================
    @classmethod
    def _separability(cls, aOwn, aOther):
#================================================================================================================
        """Percentage of own-class pairs whose MX is below the other-class thresholds.

        Parameters
        ----------
        aOwn : array-like, shape (nOwn, nOwn)
            MX among own-class samples; column j belongs to own sample j.
        aOther : array-like, shape (nOther, nOwn)
            MX of the other-class samples (rows) against own sample j (column j).

        Returns
        -------
        dict
            Maps each name in STAT_NAMES to a rounded percentage.  The value is
            the share of ordered pairs (i, j), i != j, with ``aOwn[i, j]`` below
            the min / 10th / 25th / 50th percentile of ``aOther[:, j]``.  NaN
            entries are ignored when computing thresholds.  All values are NaN
            when there are fewer than two own samples or no other samples.

        Raises
        ------
        ValueError
            If the shapes are inconsistent.
        """
        aOwn = np.asarray(aOwn, dtype=float)
        aOther = np.asarray(aOther, dtype=float)
        if aOwn.ndim != 2 or aOwn.shape[0] != aOwn.shape[1]:
            raise ValueError('aOwn must be square, got shape {}'.format(aOwn.shape))
        nOwn = aOwn.shape[1]
        if aOther.ndim != 2 or aOther.shape[1] != nOwn:
            raise ValueError('aOther must have {} columns, got shape {}'.format(nOwn, aOther.shape))

        # Self-pairs are excluded from both numerator and denominator.
        nPairs = nOwn * (nOwn - 1)
        if nPairs == 0 or aOther.shape[0] == 0:
            return {stat: np.nan for stat in cls.STAT_NAMES}

        with warnings.catch_warnings():
            # All-NaN columns yield NaN thresholds (with RuntimeWarnings); NaN
            # comparisons are False, so such columns simply contribute no hits.
            warnings.simplefilter('ignore', category=RuntimeWarning)
            aThreshold = np.vstack([
                np.nanmin(aOther, axis=0),
                np.nanpercentile(aOther, [10, 25, 50], axis=0),
            ])  # shape (len(STAT_NAMES), nOwn)
            aOffDiag = ~np.eye(nOwn, dtype=bool)
            # aHits[k, i, j] = aOwn[i, j] < threshold k of own sample j, i != j
            aHits = (aOwn[np.newaxis, :, :] < aThreshold[:, np.newaxis, :]) & aOffDiag

        aPercent = np.round(aHits.sum(axis=(1, 2)) / nPairs * 100)
        return dict(zip(cls.STAT_NAMES, aPercent))

#================================================================================================================
    def exclude_intra_outliers(self):
#================================================================================================================
        """Remove samples whose typical within-class MX is atypical for their crop.

        For each crop, column j of its own square submatrix is a candidate when
        its NaN-ignoring median lies outside ``m * (1 +/- THRESHOLD)``, where m
        is the median of the whole submatrix.  Candidates of all crops are then
        weighted by their deviation relative to the largest one.  Those with a
        weight above OUTLIER_WEIGHT_THRESHOLD are stored in
        ``self.ID_intra_outliers`` and removed from ``self.aMX`` along the
        labelsX axis.

        Raises
        ------
        RuntimeError
            If no matrix has been loaded yet (call start_processing first).
        """
        THRESHOLD = 0.5
        OUTLIER_WEIGHT_THRESHOLD = 0.5

        if self.aMX is None:
            raise RuntimeError('No matrix loaded; call start_processing first')

        dimX = self.aMX['labelsX'].dims[0]
        aID = []
        aDelta = []

        for crop in self.aUnique_CropTypes:
            aCrop = self._submatrix(self.aMX, crop, crop)
            colMask = (self.aMX['labelsX'] == crop).values
            # NOTE(review): assumes 'ids' runs along the labelsX axis -- confirm.
            aIDs = self.aMX['ids'].isel({dimX: colMask}).values
            aCropID, aCropDelta = self._find_intra_outliers(aCrop, aIDs, THRESHOLD)
            aID.extend(aCropID)
            aDelta.extend(aCropDelta)

        self.ID_intra_outliers = self._select_outliers(aID, aDelta, OUTLIER_WEIGHT_THRESHOLD)

        if self.ID_intra_outliers:
            keep = ~np.isin(self.aMX['ids'].values, self.ID_intra_outliers)
            # TODO(review): if MX is a square all-vs-all matrix, the matching
            # rows on the labelsY axis should probably be dropped as well.
            self.aMX = self.aMX.isel({dimX: keep})

        print('**** Overall ' + str(len(self.ID_intra_outliers)) + ' outliers (pure training points) were removed ...')
        if len(self.ID_intra_outliers) > 0:
            print('**** location_ids: ' + ', '.join([str(x) for x in self.ID_intra_outliers]))

#================================================================================================================
    @staticmethod
    def _find_intra_outliers(aCrop, aIDs, threshold):
#================================================================================================================
        """Return (ids, deviations) of the intra-class outlier candidates of one crop.

        ``aCrop`` is the crop's own square submatrix and ``aIDs[j]`` the id of
        column j.  A column is a candidate when its median lies outside
        ``[m - m*threshold, m + m*threshold]``, where m is the median of the
        whole submatrix (NaN ignored).  The deviation is ``|column median - m|``.
        It is an absolute value so that low and high outliers are weighted
        alike later on.
        """
        aCrop = np.asarray(aCrop, dtype=float)
        if aCrop.size == 0:
            return [], []
        with warnings.catch_warnings():
            # All-NaN data gives NaN medians, which compare False and are never flagged.
            warnings.simplefilter('ignore', category=RuntimeWarning)
            fClassMedian = np.nanmedian(aCrop)
            aColMedian = np.nanmedian(aCrop, axis=0)
            fBand = fClassMedian * threshold
            aIsOutlier = (aColMedian < fClassMedian - fBand) | (aColMedian > fClassMedian + fBand)
        aIdx = np.flatnonzero(aIsOutlier)
        return ([aIDs[i] for i in aIdx],
                [float(abs(aColMedian[i] - fClassMedian)) for i in aIdx])

#================================================================================================================
    @staticmethod
    def _select_outliers(aID, aDelta, weight_threshold):
#================================================================================================================
        """Keep the ids whose deviation, divided by the largest deviation, exceeds weight_threshold.

        Returns an empty list when there are no candidates or no positive deviation.
        """
        if not aDelta:
            return []
        fDeltaMax = max(aDelta)
        if not fDeltaMax > 0:
            return []
        return [i for i, d in zip(aID, aDelta) if d / fDeltaMax > weight_threshold]

#================================================================================================================
    def exclude_inter_outliers(self):
#================================================================================================================
        """Placeholder for the inter-class outlier exclusion used by start_processing.

        Raises
        ------
        NotImplementedError
            Always.  Silently skipping this step would cache a stats_ex.csv
            that does not reflect inter-class exclusion.
        """
        raise NotImplementedError('exclude_inter_outliers is not implemented yet')




#================================================================================================================
if __name__ == '__main__':
#================================================================================================================
    # Hard-coded configuration for a local test run.  Only 'fIndir', 'fOutdir'
    # and 'intra_exclude' are read by cAnalyse_separability; 'periods' is
    # carried along unused.
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/4_HANTS_matrices_test/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/5_HANTS_analysis_test/',
        intra_exclude=True,
        periods=577,  # 2017/09/01 till 2019/03/31
    )

    oAnalyse_separability = cAnalyse_separability(Info)
    oAnalyse_separability.start_processing()