#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 20 07:41:12 2022

@author: bertelsl
"""

import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

#================================================================================================================
class cMX_histogram(object):
    """Plot pairwise 'MX' histograms between crop types and tabulate their overlap.

    Every ``*.nc`` file found in ``Info['fIndir']`` is treated as one crop type
    (the file name up to its first dot). For each unordered pair of crop types
    a PNG with both normalised histograms is written to ``Info['fOutdir']``,
    and the histogram-intersection values are collected in
    ``_histogram_overlap.csv`` in the same directory.
    """
#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """Store the configuration and make sure the output directory exists.

        Parameters
        ----------
        Info : dict
            Must provide 'fIndir' (directory holding the ``*.nc`` files),
            'fOutdir' (directory receiving plots and CSV) and 'overwrite'
            (whether existing plots are regenerated). Other keys are ignored.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        self.overwrite = Info['overwrite']

        # makedirs also creates missing parent directories and does not fail
        # when the directory already exists.
        os.makedirs(self.fOutdir, exist_ok=True)

#================================================================================================================
    def start_processing(self, dist):
#================================================================================================================
        """Create all pairwise histogram plots and write the overlap matrix as CSV.

        Only the upper triangle of the matrix is filled: each pair is handled
        once, with the crop that sorts first as the row.

        Parameters
        ----------
        dist : object
            Not used by this method; kept so existing callers keep working.
        """
        # Sorted so that the row/column order of the CSV is reproducible
        # (glob returns files in arbitrary order).
        afMX = sorted(glob.glob(os.path.join(self.fIndir, '*.nc')))

        # Crop type = file name up to the first dot.
        aCropTypes = [os.path.basename(fMX).split('.')[0] for fMX in afMX]
        nCropTypes = len(aCropTypes)

        aOverlap = np.zeros((nCropTypes, nCropTypes), dtype=np.float32)

        for iRow, (own_crop, fMX) in enumerate(zip(aCropTypes, afMX)):

            # Open the file that belongs to this crop (previously the last
            # globbed file was opened for every crop) and close it when done.
            # NOTE(review): assumes each crop's file carries 'MX' values
            # labelled via 'labelsY' for all crop types - confirm.
            with xr.open_dataset(fMX) as xCrop:

                xOwn = xCrop.where(xCrop.labelsY == own_crop, drop=True)
                aOwn = xOwn['MX'].values

                # Only crops listed after the current one: upper triangle.
                for iColumn in range(iRow + 1, nCropTypes):
                    other_crop = aCropTypes[iColumn]

                    fOut = os.path.join(self.fOutdir, '{} - {}.png'.format(own_crop, other_crop))

                    if os.path.isfile(fOut):
                        if not self.overwrite:
                            # NOTE: a skipped pair keeps its initial 0 in the
                            # overlap matrix written below.
                            continue
                        os.remove(fOut)

                    print('\r Processing: {} - {}'.format(own_crop, other_crop),
                          end='                                                                                                                  ')

                    xOther = xCrop.where(xCrop.labelsY == other_crop, drop=True)
                    aOther = xOther['MX'].values

                    aOverlap[iRow, iColumn] = self.hist_intersect(aOwn, aOther, own_crop, other_crop, fOut)

        fOut = os.path.join(self.fOutdir, '_histogram_overlap.csv')
        df = pd.DataFrame(aOverlap, columns=aCropTypes)
        df['crop types'] = aCropTypes
        df.to_csv(fOut)

#================================================================================================================
    @staticmethod
    def _histogram_overlap(aOwn, aOther, bins=200, value_range=(0, 1)):
#================================================================================================================
        """Return the intersection of two normalised histograms plus the plot data.

        NaN values are dropped first. Each histogram is divided by the number
        of remaining (non-NaN) samples; an empty sample set yields an all-zero
        histogram instead of a division by zero.

        Returns
        -------
        tuple
            ``(overlap, hist1_norm, hist2_norm, bin_edges)`` where ``overlap``
            is the sum over bins of the smaller of the two normalised counts.
        """
        aOwn = np.asarray(aOwn, dtype=float)
        aOther = np.asarray(aOther, dtype=float)

        # Boolean indexing also flattens multi-dimensional input.
        aOwn = aOwn[~np.isnan(aOwn)]
        aOther = aOther[~np.isnan(aOther)]

        hist1, bin_edges = np.histogram(aOwn, bins=bins, range=value_range)
        hist2, _ = np.histogram(aOther, bins=bins, range=value_range)

        hist1_norm = hist1 / aOwn.size if aOwn.size else np.zeros(bins)
        hist2_norm = hist2 / aOther.size if aOther.size else np.zeros(bins)

        overlap = float(np.minimum(hist1_norm, hist2_norm).sum())

        return overlap, hist1_norm, hist2_norm, bin_edges

#================================================================================================================
    def hist_intersect(self, aOwn, aOther, crop1, crop2, fOut):
#================================================================================================================
        """Plot both normalised histograms to ``fOut`` and return their overlap.

        Parameters
        ----------
        aOwn, aOther : array-like
            Values to histogram (200 bins over [0, 1]); NaNs are ignored.
        crop1, crop2 : str
            Legend labels for ``aOwn`` and ``aOther`` respectively.
        fOut : str
            Path of the image file to write.

        Returns
        -------
        float
            Histogram intersection; shown in the plot title with two decimals.
        """
        overlap, hist1_norm, hist2_norm, bin_edges = self._histogram_overlap(aOwn, aOther)

        fig, ax = plt.subplots()

        bar_widths = np.diff(bin_edges)
        ax.bar(bin_edges[:-1], hist1_norm, width=bar_widths, label=crop1)
        # Second series is semi-transparent so the first stays visible.
        ax.bar(bin_edges[:-1], hist2_norm, width=bar_widths, label=crop2, alpha=0.7)

        ax.set_ylabel('Scores')
        ax.set_title('overlap = {:.2f}'.format(overlap))
        ax.legend()
        fig.tight_layout()

        # Save and close this specific figure rather than the implicit
        # "current" one, so figures do not accumulate across calls.
        fig.savefig(fOut)
        plt.close(fig)

        return overlap

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Run configuration, filled in key by key. cMX_histogram.__init__ reads
    # only 'fIndir', 'fOutdir' and 'overwrite'; the other entries are carried
    # along unused by this script.
    Info = {}
    Info['fIndir'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/4_satio_LUCAS_matrices/'
    Info['fOutdir'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/5_satio_LUCAS_MX_plots2/'
    Info['fIntra_outliers'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/7_satio_analysis/intra_outliers.csv'
    Info['fInter_outliers'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/7_satio_analysis/inter_outliers.csv'
    Info['fBest_bands'] = r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/6_satio_bestbands/'
    Info['overwrite'] = True

    # Build the processor and run it; 'RMSE' is handed over as the `dist` argument.
    cMX_histogram(Info).start_processing('RMSE')
    