#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Aug  9 12:46:16 2022

@author: bertelsl
"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

        
# Factor applied to the other class's median matrix value in the inter-class
# test: a sample is an outlier candidate when its own median exceeds
# (other class median * INTER_RMSE_THRESHOLD).
INTER_RMSE_THRESHOLD = 1
# Fraction of the class median that defines the accepted band in the
# intra-class test: samples whose median falls outside
# class median +/- (class median * INTRA_RMSE_THRESHOLD) are candidates.
INTRA_RMSE_THRESHOLD = 0.5
# Candidates are weighted as delta / max(delta); only those whose weight
# exceeds this value are reported as outliers.
OUTLIER_WEIGHT_THRESHOLD = 0.5
        
#================================================================================================================
class cAnalyse_separability(object):
    """Analyse how well crop types can be separated from per-crop matrices.

    Every '*.nc' file in the input directory is treated as one crop type. Each
    file is expected to hold a variable 'MX' and the coordinates 'labelsX' and
    'labelsY' (sample IDs). The code assumes 'MX' is a square sample-by-sample
    matrix with the same IDs on both axes -- presumably pairwise RMSE values,
    judging by the naming; verify against the producer of these files.

    Outputs written to the output directory:
        stats_in.csv        pairwise histogram overlap using all samples
        intra_outliers.csv  IDs flagged by the intra-class test (optional)
        inter_outliers.csv  IDs flagged by the inter-class test
        stats_ex.csv        pairwise histogram overlap without the outliers
    """

#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the configuration and make sure the output directory exists.

        Parameters
        ----------
        Info : dict
            Must provide 'fIndir' (input directory), 'fOutdir' (output
            directory), 'intra_exclude' (bool: run the intra-class outlier
            test) and 'overwrite' (bool: recompute existing outputs).
        """

        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.intra_exclude = Info['intra_exclude']
        self.overwrite = Info['overwrite']

        # Always defined, so that exclude() also works when one of the outlier
        # steps is skipped (e.g. intra_exclude is False).
        self.afMX_files = []
        self.aUnique_CropTypes = []
        self.aIntraRemoveIDs = []
        self.aInterRemoveIDs = []

        # makedirs also creates missing parents and tolerates an existing dir
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    @staticmethod
    def _crop_name(fMX_file):
#================================================================================================================
        """Return the crop name encoded in a matrix file name."""

        return os.path.basename(fMX_file).split('_RMSE')[0].split('.')[0]

#================================================================================================================
    @staticmethod
    def _read_outlier_ids(fCsv):
#================================================================================================================
        """Return the IDs stored in the first column of an outlier CSV as a list."""

        try:
            df = pd.read_csv(fCsv)
        except pd.errors.EmptyDataError:
            return []

        if df.shape[1] == 0:
            return []

        # Read by position: both outlier files were historically written with
        # the column name 'aIntraRemoveIDs'.
        return df.iloc[:, 0].tolist()

#================================================================================================================
    @staticmethod
    def _select_outliers(aIDs, aDelta, threshold=None):
#================================================================================================================
        """Return the IDs whose weight (delta / max delta) exceeds the threshold.

        Parameters
        ----------
        aIDs, aDelta : sequences of equal length
            Outlier candidates and their distance to the reference median.
        threshold : float, optional
            Minimum weight to report; defaults to OUTLIER_WEIGHT_THRESHOLD.

        Returns
        -------
        list
            Selected IDs in order of first appearance, without duplicates.
            Empty when there are no candidates or when the largest delta is
            not positive (weights relative to it would be meaningless).
        """

        if threshold is None:
            threshold = OUTLIER_WEIGHT_THRESHOLD

        if len(aDelta) == 0:
            return []

        DeltaMax = float(max(aDelta))

        if not DeltaMax > 0:
            return []

        aSelected = []
        seen = set()

        for ID, delta in zip(aIDs, aDelta):
            if delta / DeltaMax > threshold and ID not in seen:
                seen.add(ID)
                aSelected.append(ID)

        return aSelected

#================================================================================================================
    @staticmethod
    def _histogram_overlap(aValues1, aValues2, nBins=100):
#================================================================================================================
        """Return the intersection of the normalised histograms of two samples.

        Both histograms are built on the same bins, spanning the combined
        value range; without shared bins the bin-wise minimum would compare
        unrelated value intervals. Non-finite values are ignored.

        Returns
        -------
        float
            Overlap in [0, 1]; NaN when either sample has no finite values.
        """

        aValues1 = np.asarray(aValues1, dtype=float).ravel()
        aValues2 = np.asarray(aValues2, dtype=float).ravel()
        aValues1 = aValues1[np.isfinite(aValues1)]
        aValues2 = aValues2[np.isfinite(aValues2)]

        if aValues1.size == 0 or aValues2.size == 0:
            return float('nan')

        common_range = (min(aValues1.min(), aValues2.min()),
                        max(aValues1.max(), aValues2.max()))

        hist1, _ = np.histogram(aValues1, bins=nBins, range=common_range)
        hist2, _ = np.histogram(aValues2, bins=nBins, range=common_range)

        return float(np.minimum(hist1 / aValues1.size, hist2 / aValues2.size).sum())

#================================================================================================================
    def _load_matrix(self, fMX_file, exclude):
#================================================================================================================
        """Load the 'MX' values of one file, optionally without the outliers."""

        xMX = xr.load_dataset(fMX_file)

        if exclude:
            xMX = self.exclude(xMX)

        return xMX['MX'].values

#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """Run the complete analysis and write all CSV outputs.

        Existing outputs are reused unless 'overwrite' is set.
        """

        # Sort by crop name and keep file list and name list aligned, so that
        # index i always refers to the same crop in both.
        aPairs = sorted((self._crop_name(fMX_file), fMX_file)
                        for fMX_file in glob.glob(os.path.join(self.fIndir, '*.nc')))
        self.aUnique_CropTypes = [crop for crop, _ in aPairs]
        self.afMX_files = [fMX_file for _, fMX_file in aPairs]

        # Calculate stats on all data
        fOut_stats_in = os.path.join(self.fOutdir, 'stats_in.csv')

        if self.overwrite or not os.path.isfile(fOut_stats_in):
            self.calculate_stats(fOut_stats_in)

        # Calculate stats on excluded data
        fOut_stats_ex = os.path.join(self.fOutdir, 'stats_ex.csv')

        if self.intra_exclude:
            self.exclude_intra_outliers()

        self.exclude_inter_outliers()

        # Honour 'overwrite' here as well, otherwise stats_ex.csv would not
        # reflect freshly recomputed outlier lists.
        if self.overwrite or not os.path.isfile(fOut_stats_ex):
            self.calculate_stats(fOut_stats_ex, exclude=True)

#================================================================================================================
    def calculate_stats(self, fOut_stats, exclude=False):
#================================================================================================================
        """Write the pairwise histogram-overlap matrix of all crops to a CSV.

        Parameters
        ----------
        fOut_stats : str
            Path of the CSV file to write.
        exclude : bool
            When True, the flagged outliers are removed from every matrix
            before its values are used.

        The CSV has one column per crop plus 'nEnries' (number of rows of the
        crop's matrix) and 'crop'. The diagonal is left at 0: a crop is not
        compared with itself.
        """

        print(' * Calculating the statistics')

        nEntries = len(self.afMX_files)
        anEntries = []
        aStats = np.zeros((nEntries, nEntries), dtype=np.float32)

        # Loop over all matrices
        for iR in range(nEntries):
            aMX1 = self._load_matrix(self.afMX_files[iR], exclude)
            anEntries.append(aMX1.shape[0])
            aValues1 = aMX1[~np.isnan(aMX1)]

            # Loop over all other matrices (upper triangle, mirrored below)
            for iC in range(iR+1, nEntries):
                print('\r   - processing: {} v.s. {}'.format(self.aUnique_CropTypes[iR], self.aUnique_CropTypes[iC]), 
                      end='                                                                                                                         ')

                aMX2 = self._load_matrix(self.afMX_files[iC], exclude)
                aValues2 = aMX2[~np.isnan(aMX2)]

                overlap = self._histogram_overlap(aValues1, aValues2)

                aStats[iR, iC] = overlap
                aStats[iC, iR] = overlap

        df = pd.DataFrame(aStats, columns=self.aUnique_CropTypes)
        # Column name kept as-is (sic) so existing readers of the CSV keep working
        df['nEnries'] = anEntries
        df['crop'] = self.aUnique_CropTypes
        df.to_csv(fOut_stats, index=False)

#================================================================================================================
    def exclude(self, xMX, level=None):
#================================================================================================================
        """Return the dataset without the flagged outlier samples.

        Parameters
        ----------
        xMX : xarray.Dataset
            Dataset with the dimensions 'labelsX' and 'labelsY'; the 'labelsX'
            values are used as the labels of both axes.
        level : str, optional
            'intra_only' removes only the intra-class outliers; any other
            value removes intra- and inter-class outliers.

        Returns
        -------
        xarray.Dataset
            Copy of the input with the flagged labels dropped on both axes.
        """

        aLabels = list(xMX['labelsX'].values)

        aRemove = set(self.aIntraRemoveIDs)
        if level != 'intra_only':
            aRemove |= set(self.aInterRemoveIDs)

        cond = [ID not in aRemove for ID in aLabels]

        # Trick to remove multiple entries at once: a boolean mask per axis
        dsX = xr.DataArray(cond, dims='labelsX', coords={'labelsX': aLabels}, name='labelsX')
        dsY = xr.DataArray(cond, dims='labelsY', coords={'labelsY': aLabels}, name='labelsY')
        xMX = xMX.where(dsX, drop=True)
        xMX = xMX.where(dsY, drop=True)

        return xMX

#================================================================================================================
    def exclude_intra_outliers(self):
#================================================================================================================
        """Determine the intra-class outliers and store them in aIntraRemoveIDs.

        A sample is a candidate when the median of its matrix column lies
        outside class median +/- (class median * INTRA_RMSE_THRESHOLD). The
        candidates of all crops are then weighted together and filtered by
        OUTLIER_WEIGHT_THRESHOLD. The result is cached in 'intra_outliers.csv'
        and read back from there unless 'overwrite' is set.
        """

        fOut_intra_outliers = os.path.join(self.fOutdir, 'intra_outliers.csv')

        if os.path.isfile(fOut_intra_outliers) and not self.overwrite:
            self.aIntraRemoveIDs = self._read_outlier_ids(fOut_intra_outliers)
            return

        aID = []
        aDelta = []

        for fMX_file in self.afMX_files:

            crop = self._crop_name(fMX_file)

            print('\r * Calculating intra crop outliers for {}'.format(crop), end= '                                                                                                                  ')

            xCrop = xr.load_dataset(fMX_file)
            aCrop = xCrop['MX'].values
            aIDs = xCrop['labelsY'].values

            Class_Median_RMSE = np.nanmedian(aCrop)
            Class_Median_delta = Class_Median_RMSE * INTRA_RMSE_THRESHOLD

            # Median per sample (matrix column); assumes a square matrix so
            # that column iX belongs to aIDs[iX].
            aSample_Medians = np.nanmedian(aCrop, axis=0)

            # INTRA outlier check against the band around the class median.
            # NaN medians compare False and are therefore never flagged.
            aCandidate = ((aSample_Medians < (Class_Median_RMSE - Class_Median_delta)) |
                          (aSample_Medians > (Class_Median_RMSE + Class_Median_delta)))

            for iX in np.flatnonzero(aCandidate):
                aID.append(aIDs[iX])
                aDelta.append(aSample_Medians[iX] - Class_Median_RMSE)

        # Weight the candidates against the largest delta and keep the ones
        # above the "to report" threshold.
        self.aIntraRemoveIDs = self._select_outliers(aID, aDelta)

        df = pd.DataFrame(self.aIntraRemoveIDs, columns = ['aIntraRemoveIDs'])
        df.to_csv(fOut_intra_outliers, index=False)

#================================================================================================================
    def exclude_inter_outliers(self):
#================================================================================================================
        """Determine the inter-class outliers and store them in aInterRemoveIDs.

        A sample is a candidate when the median of its matrix column exceeds
        the class median of another crop multiplied by INTER_RMSE_THRESHOLD.
        When 'intra_exclude' is set, the intra-class outliers are removed from
        every matrix first. The candidates are weighted together and filtered
        by OUTLIER_WEIGHT_THRESHOLD. The result is cached in
        'inter_outliers.csv' and read back from there unless 'overwrite' is set.
        """

        fOut_inter_outliers = os.path.join(self.fOutdir, 'inter_outliers.csv')

        if os.path.isfile(fOut_inter_outliers) and not self.overwrite:
            # Store a plain list: membership tests on a DataFrame would look
            # at its column names instead of the IDs.
            self.aInterRemoveIDs = self._read_outlier_ids(fOut_inter_outliers)
            return

        # Pass 1: load every file once and keep only what the comparison needs
        aSummaries = []

        for fMX_file in self.afMX_files:
            xCrop = xr.load_dataset(fMX_file)

            # Filter before extracting values and labels so both stay aligned
            if self.intra_exclude:
                xCrop = self.exclude(xCrop, level='intra_only')

            aMX = xCrop['MX'].values
            aSummaries.append((self._crop_name(fMX_file),
                               xCrop['labelsX'].values,
                               np.nanmedian(aMX, axis=0),
                               np.nanmedian(aMX)))

        # Pass 2: compare the samples of each crop with every other class median
        aRemoveIDs = []
        aDelta = []

        for iOwn, (crop, aLabels, aSample_Medians, _) in enumerate(aSummaries):
            for iOther, (other_crop, _, _, Class_Median_other) in enumerate(aSummaries):

                if iOther == iOwn:
                    continue

                print('\r  * Calculating inter crop outliers between {} and {}'.format(crop, other_crop), 
                      end = '                                                                                                                                ')

                # INTER outlier check; NaN medians compare False
                aCandidate = aSample_Medians > (Class_Median_other * INTER_RMSE_THRESHOLD)

                for iX in np.flatnonzero(aCandidate):
                    aRemoveIDs.append(aLabels[iX])
                    aDelta.append(aSample_Medians[iX] - Class_Median_other)

        # Weight the candidates against the largest delta and keep the ones
        # above the "to report" threshold (each ID is reported once).
        self.aInterRemoveIDs = self._select_outliers(aRemoveIDs, aDelta)

        # Column name kept as historically written so existing files and
        # readers stay compatible.
        df = pd.DataFrame(self.aInterRemoveIDs, columns = ['aIntraRemoveIDs'])
        df.to_csv(fOut_inter_outliers, index=False)

#================================================================================================================
    def update_xarray(self, xMX):
#================================================================================================================
        """Drop the intra-class outliers from a dataset, one ID at a time.

        Legacy, per-ID variant of exclude(level='intra_only'); exclude()
        removes all IDs in one step.

        Parameters
        ----------
        xMX : xarray.Dataset or None
            Dataset with the coordinates 'labelsX' and 'labelsY'. When None,
            a previously stored self.xMX is used.

        Returns
        -------
        xarray.Dataset
            The filtered dataset; it is also stored in self.xMX.
        """

        if xMX is None:
            xMX = self.xMX

        for ID in self.aIntraRemoveIDs:

            print('\r * Updating arrays for ID: {}'.format(ID), end='                                                                ')

            # Compare against the unmodified labels so the ID can match
            xMX = xMX.where(xMX.labelsX != ID, drop=True)
            xMX = xMX.where(xMX.labelsY != ID, drop=True)

        print('')

        self.xMX = xMX

        return xMX

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Script entry point: assemble the run configuration item by item and
    # launch the full separability analysis.
    Info = {}
    Info['fIndir'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/4_HANTS_Ts_matrices_bb012/'
    Info['fOutdir'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/5_HANTS_Ts_analysis/'
    Info['intra_exclude'] = True   # also run the intra-class outlier test
    Info['overwrite'] = False      # reuse outlier files that already exist

    oAnalyse_separability = cAnalyse_separability(Info)
    oAnalyse_separability.start_processing()
    
    
    