#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 29 13:28:36 2022

@author: bertelsl

Algorithm:
    
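    - Remove the outlier IDs from the matrices, plot the histogram of matrix values for every crop pair and write the pairwise histogram overlap (in %) to '_histogram_overlap.csv'. Example configuration: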

        'fIndir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/06_LUCAS_matrices_features_032/',
        'fOutdir': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/09_LUCAS_MX_plots_filtered',
        'fIntra_outliers': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/07_outliers/intra_outliers.csv',
        'fInter_outliers': r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/07_outliers/inter_outliers.csv',
        
Version: 31/08/2022

"""

import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

#================================================================================================================
class cMX_histogram(object):
    """Compare, per crop pair, the histograms of outlier-filtered matrix values.

    For every ``<crop>.nc`` file in ``fIndir`` the ``MX`` matrix is loaded and
    the rows (``IDsX``) and columns (``IDsY``) whose IDs appear in the intra-
    or inter-class outlier lists are removed. The columns labelled with the
    file's own crop type are then compared with the columns of every crop type
    that comes later in the sorted file list. One histogram plot is written
    per pair, and the overlap percentages are stored in
    ``<fOutdir>/_histogram_overlap.csv``. Only the upper triangle is filled;
    cells that are not computed stay 0.
    """

#================================================================================================================
    def __init__(self, Info):
        """Read the configuration and the outlier ID lists; create the output directory.

        Parameters
        ----------
        Info : dict
            Must contain:
            'fIndir'           directory holding the ``*.nc`` matrices;
            'fOutdir'          output directory (created, with parents, if missing);
            'overwrite'        if True, existing plots are regenerated;
            'fIntra_outliers'  CSV file with an 'aIntraRemoveIDs' column;
            'fInter_outliers'  CSV file with an 'aInterRemoveIDs' column.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.overwrite = Info['overwrite']

        aIntra_outliers = pd.read_csv(Info['fIntra_outliers'])['aIntraRemoveIDs'].values
        aInter_outliers = pd.read_csv(Info['fInter_outliers'])['aInterRemoveIDs'].values

        # Union of both outlier lists; IDs in it are dropped from both matrix axes.
        self.aOutliers = np.array(list(aIntra_outliers) + list(aInter_outliers))

        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    def start_processing(self, dist):
        """Plot the histogram of every crop pair and write the overlap table.

        Parameters
        ----------
        dist : str
            Not used by this method; kept so existing callers keep working.

        Notes
        -----
        Writes '<own> - <other>.png' for each processed pair into ``self.fOutdir``.
        An existing plot is kept when ``self.overwrite`` is False, but its overlap
        is still computed. The method finally writes '_histogram_overlap.csv'
        (one row and one column per crop type, plus a 'crop types' column).
        """
        # Sorted so the crop order (and hence the CSV layout) does not depend on the
        # file-system listing order; glob.escape guards against glob metacharacters
        # in the directory name, and os.path.join against a missing trailing slash.
        afMX = sorted(glob.glob(os.path.join(glob.escape(self.fIndir), '*.nc')))
        # splitext keeps crop names that themselves contain dots intact.
        aCropTypes = [os.path.splitext(os.path.basename(fMX))[0] for fMX in afMX]
        nCropTypes = len(aCropTypes)

        aOverlap = np.zeros((nCropTypes, nCropTypes), dtype=np.float32)

        for iRow, (fMX, own_crop) in enumerate(zip(afMX, aCropTypes)):

            # exclude_outliers returns an in-memory copy, so the file can be closed right away.
            with xr.open_dataset(fMX) as xRaw:
                print(own_crop)
                xCrop = self.exclude_outliers(xRaw)

            # exclude_outliers lays MX out as (IDsX, IDsY); yCroptypes labels each IDsY column.
            aMX = xCrop['MX'].values
            aColumnCrops = xCrop['yCroptypes'].values

            aOwn = aMX[:, aColumnCrops == own_crop]
            if aOwn.shape[1] == 0:
                continue

            # Each unordered pair is handled once: only crop types after own_crop.
            for iColumn in range(iRow + 1, nCropTypes):
                other_crop = aCropTypes[iColumn]

                fOut = os.path.join(self.fOutdir, '{} - {}.png'.format(own_crop, other_crop))

                if self.overwrite and os.path.isfile(fOut):
                    # Removed up front so no stale plot survives if this pair is skipped below.
                    os.remove(fOut)
                make_plot = not os.path.isfile(fOut)

                print('\r Processing: {} - {}'.format(own_crop, other_crop),
                      end='                                                                                                                  ')

                aOther = aMX[:, aColumnCrops == other_crop]
                if aOther.shape[1] == 0:
                    continue

                if make_plot:
                    overlap = self.hist_intersect(aOwn, aOther, own_crop, other_crop, fOut)
                else:
                    # The plot exists and overwrite is off: skip plotting, but still compute
                    # the overlap so the CSV does not record 0 for this pair.
                    hist_own, hist_other, _ = self._histograms(aOwn, aOther)
                    overlap = self._overlap_percentage(hist_own, hist_other)

                aOverlap[iRow, iColumn] = overlap

        fOut = os.path.join(self.fOutdir, '_histogram_overlap.csv')
        df = pd.DataFrame(aOverlap, columns=aCropTypes)
        df['crop types'] = aCropTypes
        df.to_csv(fOut)

#================================================================================================================
    @staticmethod
    def _histograms(aOwn, aOther, bins=200, value_range=(0., 1.)):
        """Return (hist_own, hist_other, bin_edges) of two arrays on a shared binning.

        Inputs are flattened and NaNs are discarded; values outside
        ``value_range`` are ignored by ``np.histogram``.
        """
        aOwn = np.asarray(aOwn)
        aOther = np.asarray(aOther)
        hist_own, bin_edges = np.histogram(aOwn[~np.isnan(aOwn)], bins=bins, range=value_range)
        hist_other, _ = np.histogram(aOther[~np.isnan(aOther)], bins=bins, range=value_range)
        return hist_own, hist_other, bin_edges

#================================================================================================================
    @staticmethod
    def _overlap_percentage(hist1, hist2):
        """Return the histogram intersection as a percentage of the smaller total count.

        This equals max(overlap / sum(hist1), overlap / sum(hist2)) * 100.
        Returns NaN when either histogram is empty.
        """
        n1 = np.sum(hist1)
        n2 = np.sum(hist2)
        if n1 == 0 or n2 == 0:
            # The percentage is undefined for an empty histogram; return NaN without a 0/0 warning.
            return np.float64(np.nan)
        shared = np.minimum(hist1, hist2).sum()
        return np.float64(shared / min(n1, n2) * 100)

#================================================================================================================
    def hist_intersect(self, aOwn, aOther, crop1, crop2, fOut, bins=200, value_range=(0., 1.)):
        """Plot both value histograms to ``fOut`` and return their overlap in percent.

        Parameters
        ----------
        aOwn, aOther : array-like
            Matrix values of the two crop types (any shape; flattened, NaNs dropped).
        crop1, crop2 : str
            Legend labels for ``aOwn`` and ``aOther``.
        fOut : str
            Path of the image file to write.
        bins : int, optional
            Number of histogram bins (default 200).
        value_range : tuple of float, optional
            (min, max) range of the histograms (default (0, 1)).

        Returns
        -------
        numpy.float64
            Shared count as a percentage of the smaller histogram's total,
            or NaN if either histogram is empty.
        """
        hist1, hist2, bin_edges = self._histograms(aOwn, aOther, bins, value_range)
        max_overlap = self._overlap_percentage(hist1, hist2)

        fig, ax = plt.subplots()
        try:
            widths = np.diff(bin_edges)
            # align='edge' makes each bar span its bin instead of being centred on the left edge.
            ax.bar(bin_edges[:-1], hist1, width=widths, align='edge', label=crop1)
            ax.bar(bin_edges[:-1], hist2, width=widths, align='edge', label=crop2, alpha=0.7)

            ax.set_ylabel('Count')
            ax.set_title('overlap = {:.1f} %'.format(max_overlap))
            ax.legend()
            fig.tight_layout()

            fig.savefig(fOut)
        finally:
            plt.close(fig)

        return max_overlap

#================================================================================================================
    @staticmethod
    def _keep_mask(aIDs, aOutliers):
        """Return a boolean array that is True for each ID in ``aIDs`` not listed in ``aOutliers``."""
        # A set gives O(1) membership tests; `ID in ndarray` scans the whole array for every ID.
        outliers = set(np.asarray(aOutliers).ravel().tolist())
        return np.array([ID not in outliers for ID in np.asarray(aIDs).ravel().tolist()], dtype=bool)

#================================================================================================================
    def exclude_outliers(self, xMX, level=None):
        """Return a new in-memory dataset without the outlier rows and columns.

        Parameters
        ----------
        xMX : xarray.Dataset
            Matrix dataset with dimensions 'IDsX' and 'IDsY' and variables
            'MX', 'lat', 'lon' and 'yCroptypes'.
        level : any, optional
            Not used; kept for backward compatibility.

        Returns
        -------
        xarray.Dataset
            Dataset with MX (IDsX, IDsY), lat/lon (IDsX), yCroptypes (IDsY),
            coordinates IDsX/IDsY and attrs {'description': 'matrix'}. Every ID
            in ``self.aOutliers`` is removed from both axes.
        """
        # Positional selection keeps 1-D variables 1-D, and still works when
        # every ID on an axis is an outlier (the axis just becomes empty).
        keepX = np.flatnonzero(self._keep_mask(xMX['IDsX'].values, self.aOutliers))
        keepY = np.flatnonzero(self._keep_mask(xMX['IDsY'].values, self.aOutliers))
        xKept = xMX.isel(IDsX=keepX, IDsY=keepY)

        # NOTE(review): assumes lat/lon are 1-D along IDsX and yCroptypes is 1-D
        # along IDsY in the input files -- confirm against the matrix producer.
        new_data = xr.Dataset(
                data_vars=dict(
                    MX=(["IDsX", "IDsY"], xKept['MX'].transpose('IDsX', 'IDsY').values),
                    lat=(["IDsX"], xKept['lat'].values),
                    lon=(["IDsX"], xKept['lon'].values),
                    yCroptypes=(["IDsY"], xKept['yCroptypes'].values),
                ),
                coords=dict(
                    IDsY=xKept['IDsY'].values,
                    IDsX=xKept['IDsX'].values,
                ),
                attrs=dict(description="matrix"),
            )

        return new_data
    
#================================================================================================================
if __name__ == '__main__':
    # Example run on the LUCAS 2018 feature matrices, filtered with the intra- and
    # inter-class outlier lists; existing plots are regenerated (overwrite=True).
    Info = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/06_LUCAS_matrices_features_032/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/09_LUCAS_MX_plots_filtered',
        fIntra_outliers=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/07_outliers/intra_outliers.csv',
        fInter_outliers=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/07_outliers/inter_outliers.csv',
        overwrite=True,
    )

    # The 'RMSE' argument is passed through to start_processing unchanged.
    oMX_histogram = cMX_histogram(Info)
    oMX_histogram.start_processing('RMSE')
    