#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Aug  2 06:21:11 2022

@author: bertelsl
"""
import os
import netCDF4
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

#================================================================================================================
class cHRLseparability(object):
#================================================================================================================
    def __init__(self, fOutdir):
#================================================================================================================
        """
        Prepare a separability analysis that reads and writes its files in *fOutdir*.

        Parameters
        ----------
        fOutdir : str
            Directory used by the other methods for their cached CSV files
            and for the PNG histograms.

        Notes
        -----
        ``labels``, ``codes``, ``unique_labels``, ``master_mask`` and ``aNDVI``
        are not created here on purpose: they only exist once
        ``read_merged_info``, ``get_master_mask`` and ``get_VIs`` have run,
        so using them too early fails loudly with an ``AttributeError``.
        """

        # odict_keys(['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B11', 'B12', 'B8A', 'CODE_OBJ', 'LABEL', 'SCL', 'VH', 'VV', 'angle', 'time'])

        # Crop type names, presumably the LPIS nomenclature (judging by the name).
        # Stored on the instance so the list is no longer built and thrown away.
        # The trailing '#' / '##' markers come from the original author; their
        # meaning is not documented.
        self.CROPTYPES_LPIS = [
            'Unknown',
            'Maize',
            'Common wheat',
            'Durum wheat',
            'Barley',
            'Rye',
            'Triticale',
            'Spelt',
            'Oats',
            'Rice',
            'Other cereals',
            'Vegetables and herbs',
            'Peas and beans',
            'Other pulses',  #
            'Potatoes',
            'Beet',  ##
            'Other root crops',  ##
            'Sunflower',
            'Soybean',
            'Rapeseed',
            'Flax, cotton and hemp',
            'Other industrial crops',  #
            'Grapes',
            'Olives',
            'Pome and stone fruit',
            'Citrus fruit',
            'Berries',
            'Nuts',
            'Other fruit',
            'Other permanent crops',  #
            'Grass',
            'Alfalfa',
            'Other fodder crops',  #
            'Forest',
            'Other',
            'Unclassified',
        ]

        # Crop / land cover names, presumably the LUCAS nomenclature (judging by the name).
        self.CROPTYPES_LUCAS = [
            'Apple fruit',
            'Barley',
            'Broadleaved woodland',
            'Cherry fruit',
            'Clovers',
            'Common wheat',
            'Cotton',
            'Dry pulses',
            'Durum wheat',
            'Floriculture and ornamental plants',
            'Grassland with sparse tree/shrub cover',
            'Grassland without tree/shrub cover',
            'Lucerne',
            'Maize',
            'Mixed cereals for fodder',
            'Nuts trees',
            'Oats',
            'Olive groves',
            'Oranges',
            'Other cereals',
            'Other citrus fruit',
            'Other coniferous woodland',
            'Other fibre and oleaginous crops',
            'Other fresh vegetables',
            'Other fruit trees and berries',
            'Other leguminous and mixtures for fodder',
            'Other mixed woodland',
            'Other non-permanent industrial crops',
            'Pear fruit',
            'Permanent industrial crops',
            'Pine dominated coniferous woodland',
            'Pine dominated mixed woodland',
            'Potatoes',
            'Rape and turnip rape',
            'Rice',
            'Rye',
            'Soya',
            'Spruce dominated coniferous woodland',
            'Spruce dominated mixed woodland',
            'Strawberries',
            'Sugar beet',
            'Sunflower',
            'Temporary grasslands',
            'Tobacco',
            'Tomatoes',
            'Triticale',
            'Vineyards',
        ]

        self.fOutdir = fOutdir

        # HANTS-smoothed red (B04) and NIR (B8A) series, filled in by get_VIs()
        self.aB04_HANTS = None
        self.aB8A_HANTS = None
        
#================================================================================================================
    def read_merged_info(self, fMerged):
#================================================================================================================
        """
        Load the sample labels and object codes from the merged NetCDF file.

        Sets ``self.labels`` (variable ``LABEL``), ``self.codes`` (variable
        ``CODE_OBJ``) and ``self.unique_labels`` (the set of distinct labels).

        Parameters
        ----------
        fMerged : str
            Path of the NetCDF file that holds the ``LABEL`` and ``CODE_OBJ``
            variables.
        """

        print(' * Reading the required info')

        # The context manager closes the file handle even when a variable is
        # missing; the [:] slices copy the data into memory before that.
        with netCDF4.Dataset(fMerged, 'r') as file2read:
            self.labels = file2read.variables['LABEL'][:]
            self.codes = file2read.variables['CODE_OBJ'][:]

        self.unique_labels = set(self.labels)

#================================================================================================================
    def get_master_mask(self, fS2):        
#================================================================================================================

        print(' * Calculating the madHANTS master mask')
        
        fMaster_mask = os.path.join(self.fOutdir, 'master_mask.csv')
        
        if os.path.isfile(fMaster_mask):
            self.master_mask = pd.read_csv(fMaster_mask)
            return
        
        file2read = netCDF4.Dataset(fS2,'r')
        
        B02 = file2read.variables['B02'][:]
        aBlue = np.ma.getdata(B02)
        aBlue[np.isnan(aBlue)] = 0
        
        B11 = file2read.variables['B11'][:]
        aSWIR = np.ma.getdata(B11)
        aSWIR[np.isnan(aSWIR)] = 0
        
        nDays = np.shape(aBlue)[0]
        
        blue_HANTS = self.HANTS_light(nDays, aBlue.T)
        swir_HANTS = self.HANTS_light(nDays, aSWIR.T)
        
        ### TEST
        # y1 =aBlue[:, 0]
        # y2 = blue_HANTS[0, :]
        # x = np.arange(nDays)
        
        # plt.plot(x,y1)
        # plt.plot(x,y2)
        ### TEST
        
        diff_blue = np.abs(blue_HANTS.T - aBlue)
        diff_swir = np.abs(swir_HANTS.T - aSWIR)
        
        ma_diff_blue = np.ma.array(diff_blue, mask=(aBlue == 0))
        ma_diff_swir = np.ma.array(diff_swir, mask=(aSWIR == 0))
        
        MAD_blue = np.ma.median(ma_diff_blue, axis=1, keepdims=True).filled(0)
        MAD_swir = np.ma.median(ma_diff_swir, axis=1, keepdims=True).filled(0)
        
        # set numpy error warning for divide to avoid messages for water pixel
        np.seterr(divide='ignore', invalid='ignore')
          
        # calculate score value for each data point
        score_blue = diff_blue / MAD_blue
        score_swir = diff_swir / MAD_swir
        
        # create mask for both channels via comparison of score to threshold
        threshold = 3.5  # is nearly 3.5 standard deviations
         
        mask_blue = score_blue >= threshold
        mask_swir = score_swir >= threshold
        # create master mask by taking all outliers from blue and swir into account
        self.master_mask = mask_swir | mask_blue
        
        dfOut = pd.DataFrame(self.master_mask)
        dfOut.to_csv(fMaster_mask, index=False)

        ### TEST
        # aBlue[self.master_mask == True] = 0
        # y3 =aBlue[:, 0]
        
        # plt.plot(x,y3)       
        # plt.show()
        # a=1
        ### TEST
        
#================================================================================================================
    def get_S2_HANTS(self, fS2, BXX):
#================================================================================================================
        """
        Return the HANTS-smoothed series of one Sentinel-2 band as an ndarray.

        The band is read from *fS2*, NaN values and all samples flagged in
        ``self.master_mask`` are set to 0, and the result is smoothed with
        ``HANTS_light``. The smoothed array (one row per series, one column
        per date) is cached as ``<BXX>_madHANTS.csv`` in ``self.fOutdir`` and
        is reloaded from there on later calls.

        Parameters
        ----------
        fS2 : str
            Path of the NetCDF file holding the band. It is not opened when
            the cached CSV already exists.
        BXX : str
            Name of the band variable, e.g. ``'B04'``.

        Returns
        -------
        numpy.ndarray
            The smoothed series; the same type on the cached and on the
            freshly computed path.
        """

        print(' * Calculating the HANTS interpolated/ smoothed curve for {}'.format(BXX))

        fBXX = os.path.join(self.fOutdir, '{}_madHANTS.csv'.format(BXX))

        if os.path.isfile(fBXX):
            # hand back a plain array, exactly like the computed path below
            return pd.read_csv(fBXX).to_numpy()

        # Read the band and release the file handle right away.
        with netCDF4.Dataset(fS2, 'r') as file2read:
            aBXX = np.ma.getdata(file2read.variables[BXX][:])

        aBXX[np.isnan(aBXX)] = 0

        nDays = np.shape(aBXX)[0]

        # Blank every sample flagged by the master mask. np.asarray also
        # accepts a mask that is held as a DataFrame.
        aBXX[np.asarray(self.master_mask, dtype=bool)] = 0

        # HANTS_light wants one series per row, hence the transpose
        aBXX_HANTS = self.HANTS_light(nDays, aBXX.T)

        dfOut = pd.DataFrame(aBXX_HANTS)
        dfOut.to_csv(fBXX, index=False)

        return aBXX_HANTS

#================================================================================================================
    def get_VIs(self, aVIs, fS2):
#================================================================================================================
        """
        Compute, or reload from cache, the requested vegetation indices.

        Only ``'NDVI'`` is implemented; other names in *aVIs* are ignored.
        The NDVI is derived from the HANTS-smoothed ``B8A`` and ``B04`` series
        and stored in ``self.aNDVI`` as a DataFrame: one row per sample, one
        column per date, followed by the columns ``code_obj`` and ``label``.
        The same table is cached as ``NDVI.csv`` in ``self.fOutdir``.

        Parameters
        ----------
        aVIs : iterable of str
            Names of the indices to compute.
        fS2 : str
            Path of the Sentinel-2 NetCDF file, passed on to ``get_S2_HANTS``.
        """

        print(' * Calculating the S2 vegetation indices')

        self.aB04_HANTS = self.get_S2_HANTS(fS2, 'B04')
        self.aB8A_HANTS = self.get_S2_HANTS(fS2, 'B8A')

        for VI in aVIs:
            if VI != 'NDVI':
                continue

            fNDVI = os.path.join(self.fOutdir, 'NDVI.csv')

            if os.path.isfile(fNDVI):
                self.aNDVI = pd.read_csv(fNDVI)
                continue

            # Work on plain arrays so the result does not depend on whether
            # the smoothed bands were handed over as ndarray or DataFrame.
            aRed = np.asarray(self.aB04_HANTS, dtype=float)
            aNIR = np.asarray(self.aB8A_HANTS, dtype=float)

            # a zero denominator gives nan/inf; do not warn about it
            with np.errstate(divide='ignore', invalid='ignore'):
                aNDVI = (aNIR - aRed) / (aNIR + aRed)

            dfOut = pd.DataFrame(aNDVI)
            # same column names as after a round trip through the CSV cache
            dfOut.columns = dfOut.columns.astype(str)
            dfOut['code_obj'] = self.codes
            dfOut['label'] = self.labels
            dfOut.to_csv(fNDVI, index=False)

            # Keep the table including code_obj/label: this is what
            # calculate_intra_crop_variability() needs, and what the cached
            # branch above delivers as well.
            self.aNDVI = dfOut

#================================================================================================================
    def JMdist(self, v1, v2):
#================================================================================================================
        """
        Jeffries-Matusita distance between the samples *v1* and *v2*.

        Means and (co)variances come from ``np.mean`` and ``np.cov``; all
        further operations are element-wise, which for 1-D inputs gives the
        univariate form of the distance.

        Parameters
        ----------
        v1, v2 : array-like
            The two samples to compare.

        Returns
        -------
        The distance ``2 * (1 - exp(-B))`` with ``B`` the Bhattacharyya
        distance; it is 0 for identical distributions and approaches 2 as
        ``B`` grows.
        """

        delta = np.mean(v1) - np.mean(v2)

        var1 = np.cov(v1)
        var2 = np.cov(v2)
        pooled = (var1 + var2) / 2

        # Bhattacharyya distance (R notation of the reference implementation):
        #   0.125 * t(mean.difference) * p^(-1) * mean.difference
        #   + 0.5 * log(det(p) / sqrt(det(cv.Matrix.1) * det(cv.Matrix.2)))
        mean_term = 0.125 * delta * np.power(pooled, -1) * delta
        spread_term = 0.5 * np.log(pooled / np.sqrt(var1 * var2))
        bhattacharyya = mean_term + spread_term

        # Jeffries-Matusita distance: 2 * (1 - exp(-bh.distance))
        return 2 * (1 - np.exp(-1 * bhattacharyya))

#================================================================================================================
    def calculate_intra_crop_variability(self, aVIs, prefix='', threshold=None):        
#================================================================================================================
        
        print(' * Calculating the intra crop variability')
        
        for VI in aVIs:
            
            if VI == 'NDVI':
                dfVI = pd.DataFrame(self.aNDVI)
                
            elif VI == 'NDVI_drop':
                fIn = os.path.join(self.fOutdir, 'NDVI_drop_{}.csv'.format(str(threshold).replace('.', '')))
                dfVI = pd.read_csv(fIn)
            
            # nDays = np.shape(self.aNDVI)[1]
            nDays = dfVI.shape[1] - 2
            
            # dfVI['code_obj'] = self.codes
            # dfVI['label'] = self.labels
            
            for label in self.unique_labels:
                
                fName = '{}_{}_{}'.format(prefix, VI, label)
                fOut = os.path.join(self.fOutdir, fName+'.csv')
                
                if os.path.isfile(fOut):
                    continue
                
                print(' ** Processing: {}'.format(fName))
    
                dfCrop = dfVI[dfVI['label'] == label]
                nCrop = len(dfCrop)
                
                aMX = np.zeros(shape=(nCrop, nCrop))
                aMX[:] = np.nan
                
                for iX in range(0, nCrop-1):
                    for iY in range (iX+1, nCrop):
                        
                        v1 = np.array(dfCrop.iloc[iX, 0:nDays].values.tolist())
                        v2 = np.array(dfCrop.iloc[iY, 0:nDays].values.tolist())
                        
                        JMdist = self.JMdist(v1, v2)
                
                        aMX[iY, iX] = JMdist
                        aMX[iX, iY] = JMdist

                dfOut = pd.DataFrame(aMX, columns = dfCrop['code_obj'])
                dfOut['mean'] = np.nanmean(aMX, axis=0)
                dfOut['code_obj'] = dfCrop['code_obj'].values                
                dfOut.to_csv(fOut, index=False)
                
                fOut = fOut.replace('.csv', '.png')
                nnn, bins, patches = plt.hist(aMX, 20, range=[0, 2], facecolor='blue')
                plt.xlim([0, 2])
                plt.title('{} - {} samples'.format(fName, nCrop))
                plt.savefig(fOut)
                plt.close()
                # plt.show()
                # A = 1

#================================================================================================================
    def remove_intra_crop_outliers(self, VI, prefix='', threshold=0.5):
#================================================================================================================
        """
        Write a copy of ``NDVI.csv`` without the samples that differ too much from their own class.

        A sample is dropped when the ``mean`` column of its label's distance
        file ``<prefix>_<VI>_<label>.csv`` is at least *threshold*. The result
        is written to ``NDVI_drop_<threshold>.csv`` in ``self.fOutdir`` (the
        decimal point is removed from the threshold, e.g. 0.5 -> ``05``).
        Nothing happens when that file already exists or when *VI* is not
        ``'NDVI'``.

        Parameters
        ----------
        VI : str
            Name of the vegetation index; only ``'NDVI'`` is handled.
        prefix : str
            Leading part of the distance file names.
        threshold : float
            Lowest mean distance at which a sample is dropped.

        Raises
        ------
        FileNotFoundError
            If ``NDVI.csv`` is missing from ``self.fOutdir``.
        """

        if VI != 'NDVI':
            return

        fNDVI = os.path.join(self.fOutdir, 'NDVI.csv')
        fOut = os.path.join(self.fOutdir, 'NDVI_drop_{}.csv'.format(str(threshold).replace('.', '')))

        if os.path.isfile(fOut):
            return

        # Fail right here with a clear message instead of running into an
        # undefined variable after all distance files have been read.
        if not os.path.isfile(fNDVI):
            raise FileNotFoundError('*** ERROR: file not found: {}'.format(fNDVI))

        dfNDVI = pd.read_csv(fNDVI)

        # collect the code_obj of every sample to be removed, over all labels
        aDroplist = []

        for label in self.unique_labels:

            fName = '{}_{}_{}'.format(prefix, VI, label)
            fIn = os.path.join(self.fOutdir, fName+'.csv')

            df = pd.read_csv(fIn)
            # a NaN mean (label with a single sample) never passes the test
            aDroplist.extend(df.loc[df['mean'] >= threshold, 'code_obj'].tolist())

        # one vectorised filter instead of one drop() per code
        dfNDVI = dfNDVI[~dfNDVI['code_obj'].isin(aDroplist)]

        dfNDVI.to_csv(fOut, index=False)

#================================================================================================================
    def makediag3d(self, M):
#================================================================================================================
        """
        Turn every row of the 2-D array *M* into a diagonal matrix.

        Helper for the HANTS algorithm.
        See: http://stackoverflow.com/q/27214027/2459096

        Returns an array of shape ``(M.shape[0], M.shape[1], M.shape[1])``
        whose slice ``[r]`` carries ``M[r]`` on its main diagonal and zeros
        everywhere else.
        """
        nRows, nCols = M.shape[0], M.shape[1]
        stacked = np.zeros((nRows, nCols, nCols))
        # write each row onto the main diagonal of its own matrix
        idx = np.arange(nCols)
        stacked[:, idx, idx] = M
        return stacked
    
#================================================================================================================
    def get_starter_matrix(self, base_period_len, sample_count, frequencies_considered_count):
#================================================================================================================
        """
        Build the matrix of harmonic base functions used by the HANTS algorithm.

        Row 0 is constant 1. For every frequency ``f`` from 1 to
        *frequencies_considered_count*, row ``2*f - 1`` holds the cosine and
        row ``2*f`` the sine of ``2*pi*f*t / base_period_len`` at the sample
        positions ``t = 0 .. sample_count - 1``.

        Returns an array with ``sample_count`` columns and
        ``min(2 * frequencies_considered_count + 1, sample_count)`` rows.
        """
        # number of 2*+1 frequencies, or number of input images
        n_rows = min(2 * frequencies_considered_count + 1, sample_count)
        harmonics = np.zeros(shape=(n_rows, sample_count))
        harmonics[0, :] = 1

        # one full period of cosine / sine values to look up from
        phase = 2 * np.pi * np.arange(base_period_len) / base_period_len
        cos_table = np.cos(phase)
        sin_table = np.sin(phase)

        freqs = np.arange(1, frequencies_considered_count + 1)
        times = np.arange(sample_count)

        # position in the look-up tables for every (frequency, sample) pair;
        # per sample it reads 000, 123, 246, ... until it wraps around
        lookup = np.mod(np.outer(freqs, times), base_period_len)

        harmonics[2 * freqs - 1, :] = cos_table.take(lookup)
        harmonics[2 * freqs, :] = sin_table.take(lookup)
        return harmonics

#================================================================================================================
    def HANTS_light(self, sample_count, inputs, frequencies_considered_count=3, outliers_to_reject='Hi',
              exclude_low=0., exclude_high=255, fit_error_tolerance=5, delta=0.1):
#================================================================================================================
        """
        Harmonic ANalysis of Time Series (HANTS), applied to many series at once.

        This "light" version returns only the reconstructed (harmonised) series.
        A constant plus *frequencies_considered_count* cosine/sine pairs is
        fitted to every row of *inputs* by weighted least squares. While a row
        still has a point whose error exceeds *fit_error_tolerance*, the
        weight of its worst point is set to 0 and the fit is repeated.

        Parameters
        ----------
        sample_count : int
            Number of samples per series; must equal ``inputs.shape[1]``. It
            is also used as the length of the base period.
        inputs : numpy.ndarray
            2-D array with one time series per row.
        frequencies_considered_count : int
            Number of frequencies above the zero frequency.
        outliers_to_reject : str
            ``'Hi'`` rejects points above the fitted curve, ``'Lo'`` points
            below it; with any other value no point is rejected.
        exclude_low, exclude_high : float
            Valid range: samples ``<= exclude_low`` or ``> exclude_high`` get
            weight 0 from the start.
        fit_error_tolerance : float
            A row counts as fitted once its largest error is not above this value.
        delta : float
            Small positive number added to the diagonal of the normal
            equations (except element [0, 0]) to suppress high amplitudes.

        Returns
        -------
        numpy.ndarray
            Reconstructed series, shape ``(inputs.shape[0], sample_count)``.
        """
        # the base period spans the whole series
        base_period_len = sample_count

        # Sign applied to (fit - data) so that the side to reject ends up
        # with a positive error.
        if outliers_to_reject == 'Hi':
            sHiLo = -1
        elif outliers_to_reject == 'Lo':
            sHiLo = 1
        else:
            sHiLo = 0

        nr = min(2 * frequencies_considered_count + 1,
                 sample_count)  # number of 2*+1 frequencies, or number of input images

        n_series = inputs.shape[0]
        rows = np.arange(n_series)

        # create empty arrays to fill
        outputs = np.zeros(shape=(n_series, sample_count))

        # get starter matrix
        mat = self.get_starter_matrix(base_period_len, sample_count, frequencies_considered_count)

        # repeat the matrix for every series -> shape (n_series, nr, sample_count)
        mat = np.tile(mat[None].T, (1, n_series)).T

        # weights: 1 = use the sample, 0 = ignore it (outside the valid range)
        p = np.ones_like(inputs)
        p[(exclude_low >= inputs) | (inputs > exclude_high)] = 0
        nout = np.sum(p == 0, axis=-1)  # count the outliers for each timeseries

        # prepare for the loop
        ready = np.zeros(n_series, dtype=bool)  # all timeseries set to false

        # degree of overdeterminedness: stop rejecting points once only
        # nr + dod of them would be left
        dod = 1
        noutmax = sample_count - nr - dod

        # Series without a single valid sample (gaps): treat the whole series
        # as valid so that the system below can still be solved ...
        p[p.sum(axis=1) == 0] = 1

        # ... and give them nout == noutmax, so they are flagged as ready
        # after the first pass.
        nout[nout == sample_count] = noutmax

        # delta on the diagonal to suppress high amplitudes; it does not
        # change between iterations, so it is built only once
        damping = np.tile(np.diag(np.ones(nr))[None].T, (1, n_series)).T * delta

        ## here comes now the real calculations!
        for _ in range(sample_count):
            if ready.all():
                break

            # right-hand side: base functions times the weighted samples
            za = np.einsum('ijk,ik->ij', mat, p * inputs)

            # normal equations: mat * diag(p) * mat^T for every series
            # NOTE(review): makediag3d() allocates n_series * sample_count^2
            # values; p[:, None, :] * mat holds the same products in far less memory.
            diag = self.makediag3d(p)
            A = np.einsum('ajk,aki->aji', mat, np.einsum('aij,jka->ajk', diag, mat.T))
            # add delta to suppress high amplitudes but not for [0,0]
            A = A + damping
            A[:, 0, 0] = A[:, 0, 0] - delta

            # Solve for the harmonic coefficients. The right-hand side is
            # passed as a stack of column matrices: NumPy >= 2.0 no longer
            # treats a 2-D b as a stack of vectors.
            zr = np.linalg.solve(A, za[..., np.newaxis])[..., 0]
            outputs = np.einsum('ijk,kj->ki', mat.T, zr)

            # calculate error and sort err by index
            err = p * (sHiLo * (outputs - inputs))
            rankVec = np.argsort(err, axis=1)

            # largest error of every series, picked directly instead of via
            # an (n_series x n_series) intermediate matrix
            worst = rankVec[:, sample_count - 1]
            maxerr = err[rows, worst]
            ready = (maxerr <= fit_error_tolerance) | (nout == noutmax)

            # if ready is still false
            if not ready.all():
                # switch off the worst point of every series that is not ready;
                # the weights of ready series are multiplied by 1 (unchanged)
                p[rows, worst] = p[rows, worst] * ready.astype(int)
                # counted for every series, as in the original implementation
                nout += 1
        return outputs


#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Input files of the 2018 LPIS Flanders data set (site specific, hard-coded paths).
    fMerged = r'/data/EEA_HRL_VLCC/data/ref/crop_type/2018_BE_LPIS-Flanders_POLY_110/2018_BE_LPIS-Flanders_POLY_110_2017-09-01_2018-08-30_InputsOutputs_crops.nc'
    # the S1 file is not used by the steps below
    fS1 = r'/data/EEA_HRL_VLCC/data/ref/crop_type/2018_BE_LPIS-Flanders_POLY_110/TS/S1_2017-09-01_2018-08-30_2018_BE_LPIS-Flanders_POLY_110.nc'
    fS2 = r'/data/EEA_HRL_VLCC/data/ref/crop_type/2018_BE_LPIS-Flanders_POLY_110/TS/S2_2017-09-01_2018-08-30_2018_BE_LPIS-Flanders_POLY_110.nc'
    # fS2 = r'/home/bertelsl/Public/PyCo/HRL_VLCC/data/S2_2017-09-01_2018-08-30_2018_AT_LPIS_POLY_110.nc'

    fOutdir = r'/data/EEA_HRL_VLCC/user/luc/intra_seperability/'
    # All results are written to fOutdir; make sure it exists before the
    # first CSV is saved.
    os.makedirs(fOutdir, exist_ok=True)

    # the first three '_' separated parts of the merged file name
    prefix = '_'.join(os.path.basename(fMerged).split('_')[0:3])

    # Mean JM distance from which a sample counts as an intra-crop outlier.
    # The value is part of the NDVI_drop file name, so the last two steps
    # have to use the same one.
    threshold = 0.5

    oHRLseparability = cHRLseparability(fOutdir)
    oHRLseparability.read_merged_info(fMerged)

    oHRLseparability.get_master_mask(fS2)
    oHRLseparability.get_VIs(['NDVI'], fS2)

    # 1) variability within each crop, 2) drop the outliers, 3) variability again
    oHRLseparability.calculate_intra_crop_variability(['NDVI'], prefix=prefix)
    oHRLseparability.remove_intra_crop_outliers('NDVI', prefix=prefix, threshold=threshold)
    oHRLseparability.calculate_intra_crop_variability(['NDVI_drop'], prefix=prefix, threshold=threshold)