#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 26 08:53:34 2022

@author: bertelsl
"""



import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

#================================================================================================================
class cAnalyse_separability(object):
#================================================================================================================
    """Build crop-versus-all distance matrices from per-crop metrics CSV files.

    Each ``*.csv`` file in ``fIndir`` (';'-separated, with ``lat``/``lon``
    columns plus metric columns) is read as the samples of one crop type; the
    crop name is the file name up to its first dot. The metric columns are
    min-max normalised per file and stacked into one array. For every crop a
    matrix of RMSE distances between that crop's samples and *all* stacked
    samples is written to ``fOutdir`` as ``<crop>_<dist>_MX.nc``.
    """

    def __init__(self, Info):
#================================================================================================================
        """Store the configuration.

        Parameters
        ----------
        Info : dict
            Must provide ``'fIndir'`` (folder with the ``*.csv`` metrics
            files), ``'fOutdir'`` (folder for the NetCDF matrices) and
            ``'intra_exclude'``. Other keys are ignored.
        """
        self.fIndir = Info['fIndir']
        self.fOutdir = Info['fOutdir']
        # NOTE(review): stored but not read anywhere in this class.
        self.intra_exclude = Info['intra_exclude']

#================================================================================================================
    def start_processing(self, dist):
#================================================================================================================
        """Read all metrics and write one distance matrix per crop type.

        Crops whose output file already exists are skipped, so an interrupted
        run can be resumed.

        Parameters
        ----------
        dist : str
            ``'RMSE'`` or ``'JMdist'``; used in the output file names. Only the
            RMSE distance is actually computed (see ``calculate_MX``), so
            ``'JMdist'`` emits a warning. Any other value prints an error and
            returns without processing anything.
        """
        if dist not in ('RMSE', 'JMdist'):
            print('ERROR: no algoritm for {} distance'.format(dist))
            return

        if dist == 'JMdist':
            # calculate_MX always computes RMSE: make the mislabelled output visible.
            warnings.warn('JMdist is not implemented: the RMSE distance is computed '
                          'and written under the JMdist file name.')

        print(' * Processing {} distance'.format(dist))

        self.read_metrics_data()

        # to_netcdf does not create missing folders.
        os.makedirs(self.fOutdir, exist_ok=True)

        for crop in self.aUnique_CropTypes:

            fOut_MX = os.path.join(self.fOutdir, '{}_{}_MX.nc'.format(crop, dist))

            if os.path.isfile(fOut_MX):
                continue

            self.calculate_MX(dist, crop, fOut_MX)

#================================================================================================================
    def read_metrics_data(self):
#================================================================================================================
        """Read, normalise and stack the metrics of all crop CSV files.

        Sets the attributes:
          aUnique_CropTypes : crop names, in file-read order
          aCropTypes        : crop name of every stacked row
          dfCropInfo        : DataFrame ('crop', 'nEntries') with rows per crop
          nFeatures         : number of metric columns
          aData             : (nRows, nFeatures) array of normalised metrics
          nTotalCropEntries : nRows
          xData             : xarray Dataset with METRICS, lat and lon

        Raises
        ------
        FileNotFoundError
            If ``fIndir`` contains no ``*.csv`` file.
        ValueError
            If a file's metric columns differ from those of the first file.
        """
        afMetrics = glob.glob(os.path.join(self.fIndir, '*.csv'))
        if not afMetrics:
            raise FileNotFoundError('No *.csv metrics files found in {}'.format(self.fIndir))

        self.aCropTypes = []
        self.aUnique_CropTypes = []
        lCropInfo = []
        lData, lLat, lLon = [], [], []
        aMetrics = None

        for fMetric in afMetrics:

            crop = os.path.basename(fMetric).split('.')[0]
            self.aUnique_CropTypes.append(crop)

            print('\r * Reading the metrics data for {}'.format(crop), end='                                                                                                             ')

            df = pd.read_csv(fMetric, sep=';')
            lLat.append(df['lat'].values)
            lLon.append(df['lon'].values)
            df = df.drop(['lat', 'lon'], axis=1)

            if aMetrics is None:
                aMetrics = df.columns
                self.nFeatures = len(aMetrics)
            elif set(df.columns) != set(aMetrics):
                raise ValueError('Metric columns of {} differ from those of the first file'.format(fMetric))
            else:
                # Same metrics, possibly in another order: align by name so the
                # stacked feature columns mean the same thing for every crop.
                df = df[aMetrics]

            self.aCropTypes.extend([crop] * len(df))
            lCropInfo.append({'crop': crop, 'nEntries': len(df)})

            # NOTE(review): normalisation is done per crop file, so every crop is
            # rescaled to [0, 1] independently before crops are compared; confirm
            # this is intended rather than a normalisation over all crops.
            lData.append(self._normalise(df))

        # Built once from a list: DataFrame.append no longer exists in pandas >= 2.0.
        self.dfCropInfo = pd.DataFrame(lCropInfo, columns=['crop', 'nEntries'])

        # Concatenate once instead of growing arrays inside the loop.
        aData = np.concatenate(lData, axis=0)
        aLat = np.concatenate(lLat)
        aLon = np.concatenate(lLon)

        self.xData = xr.Dataset(
            data_vars=dict(
                METRICS=(["crops", "metrics"], aData),
                lat = (["crops"], aLat),
                lon = (["crops"], aLon),
            ),
            coords=dict(
                crops=(self.aCropTypes),
                metrics = (aMetrics),
            ),
            attrs=dict(description="JFdist applied on full date range data."),
        )

        self.nTotalCropEntries  = aData.shape[0]
        self.aData = aData

        print('\r * Reading the metrics data: Done!', end='                                                                                                 ')
        print('')

#================================================================================================================
    @staticmethod
    def _normalise(df):
#================================================================================================================
        """Min-max scale every column of *df* to [0, 1] and return a NumPy array.

        Constant columns (max == min) are mapped to 0 instead of 0/0 = NaN,
        which would otherwise make every distance involving those rows NaN.
        """
        dfMin = df.min(axis=0)
        dfRange = df.max(axis=0) - dfMin
        dfRange = dfRange.where(dfRange != 0, 1)
        return (df - dfMin).div(dfRange, axis=1).to_numpy()

#================================================================================================================
    def _crop_slice(self, crop):
#================================================================================================================
        """Return the slice of rows of ``aData``/``aCropTypes`` that belong to *crop*.

        Rows of one crop are contiguous because ``read_metrics_data`` stacks
        the files one after the other. NOTE(review): this does not hold if two
        files map to the same crop name (same text before the first dot).
        """
        iStart = self.aCropTypes.index(crop)
        return slice(iStart, iStart + self.aCropTypes.count(crop))

#================================================================================================================
    @staticmethod
    def _rmse_row(vSample, aAll, nFeatures):
#================================================================================================================
        """Return the RMSE between *vSample* and every row of *aAll*.

        RMSE = sqrt(sum((vSample - row) ** 2) / nFeatures); NaNs propagate.
        """
        return np.sqrt(np.square(aAll - vSample).sum(axis=1) / nFeatures)

#================================================================================================================
    def calculate_MX(self, dist, crop, fOut_MX):
#================================================================================================================
        """Write the RMSE distance matrix of one crop against all samples.

        ``MX`` has one row per sample of *crop* (``labelsX``) and one column per
        stacked sample of all crops (``labelsY``); ``MX[i, j]`` is the RMSE
        between crop sample ``i`` and sample ``j``. The distance of each crop
        sample to itself is set to NaN.

        Parameters
        ----------
        dist : str
            Only used in the progress message; the distance is always RMSE.
        crop : str
            Crop name as listed in ``aUnique_CropTypes``.
        fOut_MX : str
            Path of the NetCDF file to write.
        """
        print(' * Calculating the {} matrix for {}'.format(dist, crop))

        sCrop = self._crop_slice(crop)
        aCrop = self.aData[sCrop]
        aLat = self.xData['lat'].values[sCrop]
        aLon = self.xData['lon'].values[sCrop]
        nCrop = aCrop.shape[0]

        aMX = np.empty((nCrop, self.aData.shape[0]), dtype=np.float32)
        aIds = ['{}_{}'.format(crop, iR + 1) for iR in range(nCrop)]

        for iR in range(nCrop):

            print('\r  - Progress: {} / {}'.format(iR+1, nCrop), end='                                                                                    ')

            # Every column is computed explicitly: a triangular/mirroring shortcut
            # is only valid when crop row iR is also global row iR.
            aMX[iR] = self._rmse_row(aCrop[iR], self.aData, self.nFeatures)

            # Global position of this sample, independent of which crops were
            # skipped earlier in start_processing.
            aMX[iR, sCrop.start + iR] = np.nan

        new_data = xr.Dataset(
                data_vars=dict(
                    MX=(["labelsX", "labelsY"], aMX),
                ),
                coords=dict(
                    labelsY = (self.aCropTypes),
                    labelsX = aIds,
                    ids=(["labels"], aIds),
                    lat=(["labelsX"], aLat),
                    lon=(["labelsX"], aLon),
                ),
                attrs=dict(description="matrix"),
            )

        new_data.to_netcdf(fOut_MX)

        print('')

#================================================================================================================
    def JMdist(self, v1, v2):
#================================================================================================================
        """Jeffries-Matusita distance between two samples (NOT used, too slow).

        Computes the univariate Bhattacharyya distance from the sample means
        and ``np.cov`` (which, for 1-D inputs, returns the sample variance),
        then converts it to the Jeffries-Matusita distance.
        NOTE(review): for 2-D inputs the element-wise operations below are not
        the matrix form of the formula.
        """
        meanA = np.mean(v1)
        meanB = np.mean(v2)
        meanDif = meanA - meanB

        covA = np.cov(v1)
        covB = np.cov(v2)

        p = (covA + covB) / 2

        # Bhattacharyya distance (R reference):
        # bh.distance <- 0.125 * t(mean.difference) * p^(-1) * mean.difference +
        #   0.5 * log(det(p) / sqrt(det(cv.Matrix.1) * det(cv.Matrix.2)))
        BHdist = 0.125 * meanDif * np.power(p, -1) * meanDif + 0.5 * np.log(p / np.sqrt(covA * covB))

        # Jeffries-Matusita distance:
        # jm.distance <- 2 * (1 - exp(-bh.distance))
        JMdist = 2 * (1 - np.exp(-1 * BHdist))

        return JMdist



#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Hard-coded run configuration: folder with the per-crop metrics CSVs and
    # folder that receives the distance matrices.
    dInfo = dict(
        fIndir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/3_HANTS_metrics/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/4_HANTS_matrices/',
        intra_exclude=True,
        periods=577,  # 2017/09/01 till 2019/03/31
    )

    cAnalyse_separability(dInfo).start_processing('RMSE')
    