#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Aug  9 12:46:16 2022

@author: bertelsl
"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

        
# Factor applied to the class median RMSE in exclude_inter_outliers(): a point is flagged
# when its own median RMSE is larger than class median * INTER_RMSE_THRESHOLD.
INTER_RMSE_THRESHOLD = 1
# Relative half-width of the band around the class median RMSE in exclude_intra_outliers():
# a point is flagged when its own median RMSE lies outside
# class median +/- class median * INTRA_RMSE_THRESHOLD.
INTRA_RMSE_THRESHOLD = 0.5
# A flagged point is only put on a removal list when its weight
# (its delta divided by the largest delta of all flagged points) exceeds this value.
OUTLIER_WEIGHT_THRESHOLD = 0.5
        

#================================================================================================================
class cAnalyse_separability(object):
#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Copy the run settings from *Info* onto the instance.

        Info must provide the keys 'fIndir', 'fOutdir' and 'intra_exclude';
        each value is stored under the attribute of the same name. A missing
        key raises KeyError. Any other key in Info is ignored here.
        """
        for key in ('fIndir', 'fOutdir', 'intra_exclude'):
            setattr(self, key, Info[key])
        
#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Run the separability analysis for every matrix file in ``self.fIndir``.

        Steps:
          1. collect the ``*.nc`` files (sorted) in ``self.afMX_files``;
          2. derive one crop name per file (the part of the file name before
             ``_RMSE``) and store the sorted names in ``self.aUnique_CropTypes``;
          3. create the empty statistics table ``self.dfCropInfo`` (indexed by
             crop, one min/Q10/Q25/Q50 column per crop);
          4. write ``stats_in.csv`` (all data) and ``stats_ex.csv`` (outliers
             excluded) to ``self.fOutdir``.

        Each CSV is only computed when it does not exist yet; an existing file
        is left untouched.
        """
        self.afMX_files = sorted(glob.glob(os.path.join(self.fIndir, '*.nc')))

        self.aUnique_CropTypes = sorted(
            os.path.basename(fMX_file).split('_RMSE')[0] for fMX_file in self.afMX_files)

        columns = ['crop', 'nEntries']
        for crop in self.aUnique_CropTypes:
            columns.extend('{}_{}'.format(crop, stat) for stat in ('min', 'Q10', 'Q25', 'Q50'))

        def new_crop_table():
            # Empty statistics table: one row per crop, indexed by crop name.
            dfTable = pd.DataFrame(columns=columns)
            dfTable['crop'] = self.aUnique_CropTypes
            dfTable.set_index('crop', inplace=True)
            return dfTable

        self.dfCropInfo = new_crop_table()

        # Calculate stats on all data
        fOut_stats_in = os.path.join(self.fOutdir, 'stats_in.csv')

        if not os.path.isfile(fOut_stats_in):
            self.calculate_stats(fOut_stats_in)

        # Calculate stats on excluded data
        fOut_stats_ex = os.path.join(self.fOutdir, 'stats_ex.csv')

        if not os.path.isfile(fOut_stats_ex):

            if self.intra_exclude:
                self.exclude_intra_outliers()

            self.exclude_inter_outliers()

            # calculate_stats() modifies the table it is given (values are
            # filled in and the index is reset), so the second pass starts
            # from a fresh one.
            self.dfCropInfo = new_crop_table()

            # Only here are the removal ID lists guaranteed to be loaded, so
            # the excluded statistics must be computed inside this branch.
            self.calculate_stats(fOut_stats_ex, exclude=True)

#================================================================================================================
    def calculate_stats(self, fOut_stats, exclude=False):
#================================================================================================================
        """Fill ``self.dfCropInfo`` with the class-versus-class statistics and write it to CSV.

        For every file in ``self.afMX_files`` the dataset is loaded into
        ``self.xMX`` and the entries labelled with the file's own crop are
        compared against those of every other crop in ``self.aUnique_CropTypes``.
        For each own index the own values are compared with the minimum and the
        10th/25th/50th percentile of the other-class values; the number of own
        values below each threshold, divided by ``nOwn**2 - nOwn`` and expressed
        as a rounded percentage, is stored in the ``<other_crop>_min/_Q10/_Q25/_Q50``
        columns of the own crop's row.

        Parameters
        ----------
        fOut_stats : path of the CSV file to write.
        exclude : when True, ``self.exclude()`` is applied to each dataset
            right after loading.

        Side effects: on return 'crop' is a regular column of
        ``self.dfCropInfo`` (the index is reset before writing).
        """
        print(' * Calculating the statistics')

        # A previous call leaves 'crop' behind as a regular column (see the
        # reset_index at the end). Restore it as index so the label based
        # assignments below address the existing rows instead of failing.
        if 'crop' in self.dfCropInfo.columns:
            self.dfCropInfo.set_index('crop', inplace=True)

        aStats = ('min', 'Q10', 'Q25', 'Q50')

        for fMX_file in self.afMX_files:

            crop = os.path.basename(fMX_file).split('_RMSE')[0]

            self.xMX = xr.load_dataset(fMX_file)

            if exclude:
                self.exclude()

            xOwn = self.xMX.where((self.xMX.labelsY == crop), drop=True)
            aOwn = xOwn['MX'].values
            nOwn = aOwn.shape[0]

            # Assign by crop label: the file order (sorted paths) is not
            # guaranteed to equal the row order (sorted crop names).
            self.dfCropInfo.at[crop, 'nEntries'] = nOwn

            # Number of off-diagonal own-versus-own pairs; used as the
            # denominator of all percentages of this crop.
            nPixels = nOwn**2 - nOwn

            # loop over the classes and calculate the class statistics
            for other_crop in self.aUnique_CropTypes:

                if crop == other_crop:
                    continue

                print('\r   - processing: {} v.s. {}'.format(crop, other_crop), 
                      end='                                                                                                                         ')

                xOther = self.xMX.where((self.xMX.labelsY == other_crop), drop=True)
                aOther = xOther['MX'].values

                aCounts = {stat: np.zeros((nOwn), dtype=float) for stat in aStats}

                # NumPy warnings (e.g. all-NaN slices, division by zero for
                # nOwn <= 1) are deliberately silenced, as before.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")

                    # run over all spectra (training points) in the OWN class
                    for iOwn in range(nOwn):
                        # NOTE(review): aOther is indexed along axis 0 with an
                        # own-class index; this assumes axis 0 of 'MX' runs over
                        # the own-class points for both arrays -- confirm
                        # against the matrix layout.
                        aOwnValues = aOwn[iOwn]
                        aOtherValues = aOther[iOwn]

                        # NOTE(review): the percentiles are not NaN-aware while
                        # the minimum is -- confirm this is intended.
                        aThresholds = {
                            'min': np.nanmin(aOtherValues),
                            'Q10': np.percentile(aOtherValues, 10),
                            'Q25': np.percentile(aOtherValues, 25),
                            'Q50': np.percentile(aOtherValues, 50),
                        }

                        # Count the own values that stay below each threshold
                        for stat in aStats:
                            aCounts[stat][np.less(aOwnValues, aThresholds[stat])] += 1

                    # now calculate the statistics for the whole class
                    for stat in aStats:
                        self.dfCropInfo.at[crop, '{}_{}'.format(other_crop, stat)] = \
                            np.round(aCounts[stat].sum() / nPixels * 100)

        self.dfCropInfo.reset_index(level='crop', inplace=True)
        self.dfCropInfo.to_csv(fOut_stats, index=False)

        print('\r * Calculating the statistics: Done!', end='                                                                                                            ')
 
#================================================================================================================
    def exclude(self):        
#================================================================================================================

        # add sequence number to the Y labels
        aLabelsY = self.xMX['labelsY'].values
        
        current_label = aLabelsY[0]
        count = 1
        aNewLabels = []
        
        for label in aLabelsY:
            if label != current_label:
                count = 1
                current_label = label
            aNewLabels.append('{}_{}'.format(label, count))
            count+=1
            
        self.xMX['labelsY'] = aNewLabels
        labelsX = self.xMX['labelsX'].values
        labelsY = self.xMX['labelsY'].values
        condX = []
        condY = []
        
        ###
        if self.intra_exclude:   
            for ID in self.aIntraRemoveIDs:
                if ID in labelsX:
                    condX.append(False)
                    aNewLabels.remove(ID)
                else:
                    condX.append(True)
                    
                if ID in labelsY:
                    condY.append(False)
                else:
                    condY.append(True)
                
        for ID in self.aInterRemoveIDs:
            if ID in labelsX:
                condX.append(False)
                aNewLabels.remove(ID)
            else:
                condX.append(True)
                
            if ID in labelsY:
                condY.append(False)
            else:
                condY.append(True)
    
        dsX = xr.DataArray(condX, dims=('y'), coords={'y':['a', 'b', 'c']}, name='y')
        dsY = xr.DataArray(condY, dims=('y'), coords={'y':['a', 'b', 'c']}, name='y')
        self.xMX.where(dsX, drop=True)
        self.xMY.where(dsY, drop=True)         
        
        aLabels = []
        for label in aNewLabels:
            aLabels.append('_'.join(label.split('_')[0:-1]))
        
        self.xMX['labelsY'] = aLabels
        
        a=1
        ###



        # drop the reduested entries in X and Y direction
        # if self.intra_exclude:
        #     count = 1
        #     for IntraRemoveID in self.aIntraRemoveIDs:
        #         print ('\r   - Excluding intra crop ID: {}  {}/{}'.format(IntraRemoveID, count, len(self.aIntraRemoveIDs), end=''))
        #         count += 1
        #         self.xMX = self.xMX.where(self.xMX.labelsX != IntraRemoveID, drop=True)
        #         self.xMX = self.xMX.where(self.xMX.labelsY != IntraRemoveID, drop=True)           
        #         aNewLabels.remove(IntraRemoveID)
        # print('')
        
        # count = 1
        # for InterRemoveID in self.aInterRemoveIDs:
        #     print ('\r   - Excluding intra crop ID: {}  {}/{}'.format(IntraRemoveID, count, len(self.aInterRemoveIDs), end=''))
        #     count += 1            

        #     self.xMX = self.xMX.where(self.xMX.labelsX != InterRemoveID, drop=True)
        #     self.xMX = self.xMX.where(self.xMX.labelsY != InterRemoveID, drop=True)         
            
        #     if InterRemoveID in aNewLabels:
        #         aNewLabels.remove(InterRemoveID)
        #     else:
        #         print('Not found: {}'.format(InterRemoveID))
            
        # print('')
       
        # # remove sequence number from the Y labels
        # aLabels = []
        # for label in aNewLabels:
        #     aLabels.append('_'.join(label.split('_')[0:-1]))
        
        # self.xMX['labelsY'] = aLabels
        
#================================================================================================================
    def exclude_intra_outliers(self):
#================================================================================================================
        """Determine the intra crop outlier IDs and store them in ``self.aIntraRemoveIDs``.

        When ``intra_outliers.csv`` already exists in ``self.fOutdir`` the IDs
        are read from its 'aIntraRemoveIDs' column and nothing is computed.

        Otherwise, for every file in ``self.afMX_files`` the entries labelled
        with the file's own crop are selected and each point's median RMSE is
        compared with the class median RMSE. A point is flagged when its median
        lies outside ``class median +/- class median * INTRA_RMSE_THRESHOLD``.
        Each flagged point gets the weight ``delta / largest delta`` (delta =
        point median - class median); points with a weight above
        ``OUTLIER_WEIGHT_THRESHOLD`` end up on the removal list, which is also
        written to ``intra_outliers.csv``. Without any flagged point the list
        is empty.
        """
        fOut_intra_outliers = os.path.join(self.fOutdir, 'intra_outliers.csv')

        if os.path.isfile(fOut_intra_outliers):
            df = pd.read_csv(fOut_intra_outliers)
            self.aIntraRemoveIDs = list(df['aIntraRemoveIDs'])
            return

        aID = []
        aDelta = []

        for fMX_file in self.afMX_files:

            crop = os.path.basename(fMX_file).split('_RMSE')[0]

            print('\r * Calculating intra crop outliers for {}'.format(crop), end= '                                                                                                                  ')

            xMX = xr.load_dataset(fMX_file)
            xCrop = xMX.where((xMX.labelsY == crop), drop=True)
            aCrop = xCrop['MX'].values
            aIDs = xCrop['ids'].values

            Class_Median_RMSE = np.nanmedian(aCrop)
            Class_Median_delta = Class_Median_RMSE * INTRA_RMSE_THRESHOLD
            lower_limit = Class_Median_RMSE - Class_Median_delta
            upper_limit = Class_Median_RMSE + Class_Median_delta

            # Handle all in class indices:
            for iX in range(aCrop.shape[0]):
                # check the training point RMSE against the class median band --> INTRA outlier check
                Point_Median_RMSE = np.nanmedian(aCrop[:, iX])

                if Point_Median_RMSE < lower_limit or Point_Median_RMSE > upper_limit:
                    aID.append(aIDs[iX])
                    aDelta.append(Point_Median_RMSE - Class_Median_RMSE)

        # Weight the found outliers against the "to report outlier threshold":
        # the weight of an outlier is its delta relative to the largest delta.
        # NOTE(review): points below the band have a negative delta, so their
        # weight is negative whenever the largest delta is positive and they
        # are never reported -- confirm this is intended.
        if aDelta:
            DeltaMax = float(max(aDelta))
            self.aIntraRemoveIDs = [ID for ID, delta in zip(aID, aDelta)
                                    if delta / DeltaMax > OUTLIER_WEIGHT_THRESHOLD]
        else:
            # nothing flagged (or no input files): max() of an empty list would raise
            self.aIntraRemoveIDs = []

        df = pd.DataFrame(self.aIntraRemoveIDs, columns = ['aIntraRemoveIDs'])
        df.to_csv(fOut_intra_outliers, index=False)

#================================================================================================================
    def exclude_inter_outliers(self):
#================================================================================================================
        """Determine the inter crop outlier IDs and store them in ``self.aInterRemoveIDs``.

        When ``inter_outliers.csv`` already exists in ``self.fOutdir`` the IDs
        are read from its first column and nothing is computed. In both cases
        ``self.aInterRemoveIDs`` is a plain list.

        Otherwise every file in ``self.afMX_files`` is loaded into ``self.xMX``
        (with the intra outliers removed first when ``self.intra_exclude`` is
        set), split into ``self.xCrop`` (own crop) and ``self.xOther`` (all
        other labels), and each own point's median RMSE is compared with the
        class median RMSE. A point is flagged when its median exceeds
        ``class median * INTER_RMSE_THRESHOLD``. Each flagged point gets the
        weight ``delta / largest delta`` (delta = point median - class median);
        points with a weight above ``OUTLIER_WEIGHT_THRESHOLD`` end up on the
        removal list, which is also written to ``inter_outliers.csv``. Without
        any flagged point the list is empty.
        """
        fOut_inter_outliers = os.path.join(self.fOutdir, 'inter_outliers.csv')

        if os.path.isfile(fOut_inter_outliers):
            # Read by position: the column header written below is
            # 'aIntraRemoveIDs' (kept for compatibility with existing files).
            df = pd.read_csv(fOut_inter_outliers)
            self.aInterRemoveIDs = df.iloc[:, 0].tolist()
            return

        aID = []
        aDelta = []

        for fMX_file in self.afMX_files:

            crop = os.path.basename(fMX_file).split('_RMSE')[0]

            print('  * Calculating inter crop outliers for {}'.format(crop))

            self.xMX = xr.load_dataset(fMX_file)

            if self.intra_exclude:
                self.update_xarray()

            self.xCrop = self.xMX.where((self.xMX.labelsY == crop), drop=True)
            self.xOther = self.xMX.where((self.xMX.labelsY != crop), drop=True)

            aIDs = self.xCrop['ids'].values
            aCrop = self.xCrop['MX'].values

            Class_Median_RMSE = np.nanmedian(aCrop)
            upper_limit = Class_Median_RMSE * INTER_RMSE_THRESHOLD

            # Handle all in class indices:
            for iX in range(aCrop.shape[0]):
                # check the training point RMSE against the class median threshold --> INTER outlier check
                Point_Median_RMSE = np.nanmedian(aCrop[:, iX])

                if Point_Median_RMSE > upper_limit:
                    aID.append(aIDs[iX])
                    aDelta.append(Point_Median_RMSE - Class_Median_RMSE)

        # Weight the found outliers against the "to report outlier threshold":
        # the weight of an outlier is its delta relative to the largest delta.
        if aDelta:
            DeltaMax = float(max(aDelta))
            self.aInterRemoveIDs = [ID for ID, delta in zip(aID, aDelta)
                                    if delta / DeltaMax > OUTLIER_WEIGHT_THRESHOLD]
        else:
            # nothing flagged (or no input files): max() of an empty list would raise
            self.aInterRemoveIDs = []

        df = pd.DataFrame(self.aInterRemoveIDs, columns = ['aIntraRemoveIDs'])
        df.to_csv(fOut_inter_outliers, index=False)
                   
#================================================================================================================
    def update_xarray(self):
#================================================================================================================
        """Drop the intra crop outlier IDs from ``self.xMX`` in X and Y direction.

          1. the Y labels are made unique by appending a per-run sequence
             number (``<label>_<n>``, restarting at 1 whenever the label
             changes), so they can be compared with the IDs;
          2. every entry whose 'labelsX' or numbered 'labelsY' value is in
             ``self.aIntraRemoveIDs`` is dropped;
          3. the sequence number is stripped again from the remaining Y labels.

        IDs that do not occur in the dataset are ignored. ``self.xMX`` is
        replaced by the filtered dataset.

        Assumes 'labelsY' can be re-assigned from a plain list -- presumably it
        is a dimension coordinate of the dataset; TODO confirm.
        """
        # add sequence number to the Y labels
        aNumberedLabels = []
        current_label = None
        count = 1
        for label in self.xMX['labelsY'].values:
            if label != current_label:
                count = 1
                current_label = label
            aNumberedLabels.append('{}_{}'.format(label, count))
            count += 1

        self.xMX['labelsY'] = aNumberedLabels

        aRemoveIDs = list(self.aIntraRemoveIDs)

        print(' * Updating arrays: excluding {} intra crop outlier IDs'.format(len(aRemoveIDs)))

        # One mask per direction instead of one pass over the whole dataset
        # for every single ID.
        if aRemoveIDs:
            for name in ('labelsX', 'labelsY'):
                keep = ~self.xMX[name].isin(aRemoveIDs)
                self.xMX = self.xMX.where(keep, drop=True)

        # remove sequence number from the Y labels that are left; taking them
        # from the filtered dataset keeps labels and data in step even when an
        # ID was not present
        self.xMX['labelsY'] = [str(label).rsplit('_', 1)[0]
                               for label in self.xMX['labelsY'].values]

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Run configuration for the script entry point. Only 'fIndir', 'fOutdir'
    # and 'intra_exclude' are read by cAnalyse_separability.__init__;
    # 'periods' is not used by the code in this file.
    dRunSettings = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/4_HANTS_matrices/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/5_HANTS_analysis/',
        intra_exclude=True,
        periods=577,  # 2017/09/01 till 2019/03/31
    )

    oAnalyser = cAnalyse_separability(dRunSettings)
    oAnalyser.start_processing()
    
    
    