#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug  4 06:55:24 2022

@author: bertelsl


The LUCAS dataset (S1, S2) was collected for the period 2017/09/01 to 2019/03/31 and contains the
following entries:

    Dimensions:                (feature: 712, t: 528)
    Coordinates:
      * t                      (t) datetime64[ns] 2017-09-01 ... 2019-03-31
        lat                    (feature) float64 58.54 58.62 58.62 ... 57.38 57.35
        lon                    (feature) float64 14.37 14.42 14.35 ... 12.74 12.74
        feature_names          (feature) object 'feature_0' ... 'feature_711'
    Dimensions without coordinates: feature
    Data variables:
        VH                     (feature, t) float64 -15.87 -23.53 nan ... -15.33 nan
        VV                     (feature, t) float64 -9.958 -10.49 nan ... -8.22 nan
        local_incidence_angle  (feature, t) float64 15.59 16.27 nan ... 15.89 nan
        B01                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B02                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B03                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B04                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B05                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B06                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B07                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B08                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B09                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B11                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B12                    (feature, t) float64 nan nan nan nan ... nan nan nan
        B8A                    (feature, t) float64 nan nan nan nan ... nan nan nan
        SCL                    (feature, t) float64 nan nan nan nan ... nan nan nan
    Attributes:
        Conventions:  CF-1.8
        source:       Aggregated timeseries generated by openEO GeoPySpark backend.
        
"""
import os
import glob
import netCDF4
import json
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

#================================================================================================================
class cLUCASpreprocessing(object):
#================================================================================================================
    def __init__(self, Info):
#================================================================================================================
        """
        Store the run configuration and make sure the output directory exists.

        Info is a dict with the keys:
            'fBasedir'  : directory that is searched for the 'Lucas *' input sub-directories
            'fSplitdir' : directory with the fallback '*.json' feature files
            'fOutdir'   : directory receiving the output files (created if missing)
            'periods'   : number of steps of the regular time axis, counted from 2017-09-01
        """
        self.fBasedir = Info['fBasedir']
        self.fSplitdir = Info['fSplitdir']
        self.fOutdir = Info['fOutdir']
        self.periods = Info['periods']

        # regular time axis of the output (pandas default frequency: one calendar day per step)
        self.newT = pd.date_range("2017-09-01", periods=Info['periods'])

        # makedirs also creates missing parent directories and accepts an already existing target,
        # which os.mkdir does not
        os.makedirs(self.fOutdir, exist_ok=True)
 
#================================================================================================================
    def start_processing(self):
#================================================================================================================
        """
        Process every 'Lucas *' sub-directory found in self.fBasedir.

        For each sub-directory the labels (property 'LC1_LABEL' of every entry in "features") are
        read from a JSON file and handed, together with the 'timeseries.nc' file of the
        sub-directory, to extract_LUCAS_data(). The JSON file is the first 'group_*.json' inside
        the sub-directory; if there is none, the first '*.json' of self.fSplitdir whose file name
        contains the second space-separated token of the sub-directory name is used.

        Sub-directories with a missing input file are reported and skipped.
        """
        # os.path.join works with and without a trailing separator on the configured directory
        aInDirs = glob.glob(os.path.join(self.fBasedir, 'Lucas *'))
        aSplit = glob.glob(os.path.join(self.fSplitdir, '*.json'))

        for Nr, fSubdir in enumerate(aInDirs, start=1):

            basename = os.path.basename(fSubdir).split('.')[0]
            self.fOut = os.path.join(self.fOutdir, basename + '.nc')

            print('Processing: {} - {}'.format(Nr, basename))

            # check if output file is already there
            # if os.path.isfile(self.fOut):
            #     continue

            fTimeseries = os.path.join(fSubdir, 'timeseries.nc')
            if not os.path.isfile(fTimeseries):
                print('*** ERROR: file not found: {}'.format(fTimeseries))
                continue

            aInfo = glob.glob(os.path.join(fSubdir, 'group_*.json'))
            if not aInfo:
                # Fall back to the split directory. Only the file name is compared: all candidates
                # share the same directory, so the directory part cannot tell them apart.
                # NOTE(review): this is still a substring match ('1' also matches '10') -- confirm
                # that the naming scheme of the split files rules such collisions out.
                key = basename.split(' ')[1]
                aInfo = [s for s in aSplit if key in os.path.basename(s)]

            if not aInfo:
                print('*** ERROR: no feature file (*.json) found for: {}'.format(basename))
                continue

            fInfo = aInfo[0]

            # the context manager closes the file handle right after parsing
            with open(fInfo, encoding='utf-8') as fIn:
                data = json.load(fIn)

            aLabels = [feature['properties']['LC1_LABEL'] for feature in data["features"]]

            self.extract_LUCAS_data(fTimeseries, aLabels, basename)

#================================================================================================================
    def extract_LUCAS_data(self, fTimeseries, aLabels, basename):
#================================================================================================================
        """
        Smooth all bands of one time-series file with HANTS and write the result to self.fOut.

        fTimeseries : path of the input netCDF file (variables 't', 'lat', 'lon' and the bands)
        aLabels     : one label per series; becomes the 'labels' coordinate of the output
        basename    : name of the processed sub-directory (used for the master-mask cache file)

        The smoothed bands are also kept on the instance as self.aB02 ... self.aVV.
        Raises ValueError if the number of labels differs from the number of series.
        """
        print(' * Extracting LUCAS data')

        # Raw values of the time variable; they are used as column indices into the regular time
        # axis. NOTE(review): presumably integer day offsets from 2017-09-01 -- confirm.
        # The context manager closes the file handle again.
        with netCDF4.Dataset(fTimeseries, 'r') as file2read:
            self.at_date = np.ma.getdata(file2read.variables['t'][:])

        # read the complete dataset into memory
        self.data = xr.load_dataset(fTimeseries)

        # get coordinates
        lat = self.data['lat'].values
        lon = self.data['lon'].values

        # fail early with a clear message instead of after all bands have been processed
        if len(aLabels) != len(lat):
            raise ValueError('{}: {} labels for {} series'.format(basename, len(aLabels), len(lat)))

        # get madHANTS master mask (needs self.at_date, is needed by preprocess_band)
        self.get_master_mask(fTimeseries, basename)

        # NOTE(review): VH and VV get the same 0.0001 scaling and pass the same HANTS validity
        # test (value > 0) as the reflectance bands, although the module docstring shows negative
        # VH/VV values -- confirm that this is intended.
        aBands = ('B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12', 'VH', 'VV')

        data_vars = {}
        for sBand in aBands:
            aSmoothed = self.preprocess_band(sBand, 0.0001)
            setattr(self, 'a' + sBand, aSmoothed)      # self.aB02, self.aB03, ...
            data_vars[sBand] = (["labels", "date"], aSmoothed)

        data_vars['lat'] = (["labels"], lat)
        data_vars['lon'] = (["labels"], lon)

        # create the new dataset
        new_data = xr.Dataset(
                data_vars=data_vars,
                coords=dict(
                    date=(self.newT),
                    labels=(aLabels)
                ),
                attrs=dict(description="JFdist applied on full date range data."),
            )

        new_data.to_netcdf(self.fOut)

#================================================================================================================
    def preprocess_band(self, InBand, scaling):
#================================================================================================================
        """
        Put one band on the regular time axis, blank unusable samples and smooth it with HANTS.

        The dataset has no entries for some steps of the regular axis. Those steps, NaN samples
        and samples flagged in self.master_mask are set to 0 (not NaN); with its default
        exclude_low of 0, HANTS_light gives such samples weight 0.

        InBand  : name of the variable in self.data
        scaling : factor applied to the raw values

        Returns the array given back by HANTS_light for an input of shape
        (number of series, self.periods).
        """
        print('    - preprocessing {}'.format(InBand))

        aInBand = self.data[InBand].values * scaling

        # rows = series, columns = steps of the regular axis; steps without an entry stay 0
        aBand = np.zeros(shape=(np.shape(aInBand)[0], self.periods))
        aBand[:, self.at_date] = aInBand
        aBand[np.isnan(aBand)] = 0

        # The master mask is laid out (time, series), hence the transpose. It is a DataFrame when
        # it was read from the CSV cache, so it is converted to a plain boolean array first.
        aMask = np.asarray(self.master_mask, dtype=bool).T
        aBand[aMask] = 0

        return self.HANTS_light(self.periods, aBand, frequencies_considered_count=2)

#================================================================================================================
    def get_master_mask(self, fS2, basename):        
#================================================================================================================
        
        fMaster_mask = os.path.join(self.fOutdir, '{}_master_mask.csv'.format(basename))
        
        if os.path.isfile(fMaster_mask):
            self.master_mask = pd.read_csv(fMaster_mask)
            return
        
        file2read = netCDF4.Dataset(fS2,'r')
        
        B02 = (file2read.variables['B02'][:] * 0.0001).T
        aB02 = np.ma.getdata(B02)
        aB02[np.isnan(aB02)] = 0
        
        print(' * Calculating the madHANTS master mask; Nr. entries: {}'.format(aB02.shape[1]))

        aBlue = np.zeros(shape=(self.periods, np.shape(aB02)[1]))
        aBlue[self.at_date, :] = aB02
        
        B11 = (file2read.variables['B11'][:] * 0.0001).T
        aB11 = np.ma.getdata(B11)
        aB11[np.isnan(aB11)] = 0
        
        aSWIR = np.zeros(shape=(self.periods, np.shape(aB11)[1]))
        aSWIR[self.at_date, :] = aB11
 
        blue_HANTS = self.HANTS_light(self.periods, aBlue.T)
        swir_HANTS = self.HANTS_light(self.periods, aSWIR.T)
        
        ### TEST
        # y1 =aBlue[:, 37]
        # y2 = blue_HANTS[37, :]
        # x = np.arange(self.periods)
        
        # plt.plot(x,y1)
        # plt.plot(x,y2)
        # # plt.show()
        ### TEST
        
        # calculate difference between HANTS data and real data point
        diff_blue = np.abs(blue_HANTS.T - aBlue)
        diff_swir = np.abs(swir_HANTS.T - aSWIR)
          
        # calculate MAD (we don\t use the normalized MAD here, since we didn't used 
        # the absolute difference to the median - so MAD has not to be scaled to fit standard deviation)
        # NOTE: we have to exclude all -1 points from the MAD calculation
        #create masked array out of diff array to exclude -1 points from MAD        
        ma_diff_blue = np.ma.array(diff_blue, mask=(aBlue == 0))
        ma_diff_swir = np.ma.array(diff_swir, mask=(aSWIR == 0))

        MAD_blue = np.ma.median(ma_diff_blue, axis=0, keepdims=True).filled(0)
        MAD_swir = np.ma.median(ma_diff_swir, axis=0, keepdims=True).filled(0)
        
        # set numpy error warning for divide to avoid messages for water pixel
        np.seterr(divide='ignore', invalid='ignore')
          
        # calculate score value for each data point
        score_blue = diff_blue / MAD_blue
        score_swir = diff_swir / MAD_swir
        
        # create mask for both channels via comparison of score to threshold
        threshold = 3.5  # is nearly 3.5 standard deviations
         
        mask_blue = score_blue >= threshold
        mask_swir = score_swir >= threshold
        # create master mask by taking all outliers from blue and swir into account
        self.master_mask = mask_swir | mask_blue
        
        dfOut = pd.DataFrame(self.master_mask)
        dfOut.to_csv(fMaster_mask, index=False)

        ### TEST
        # aBlue[self.master_mask == True] = 0
        # y3 =aBlue[:, 37]
        
        # plt.plot(x,y3)       
        # plt.show()
        # a=1
        ### TEST
        
#================================================================================================================
    def makediag3d(self, M):
#================================================================================================================
        """
        Helper for the HANTS algorithm: turn every row of the 2-D array M into a diagonal matrix.

        Returns a float array of shape (rows, cols, cols) with result[r, c, c] == M[r, c] and 0
        everywhere else. See: http://stackoverflow.com/q/27214027/2459096
        """
        nRows, nCols = M.shape[0], M.shape[1]
        aDiag = np.zeros((nRows, nCols, nCols))

        # the paired index arrays address exactly the main diagonal of every (cols, cols) slice
        aIdx = np.arange(nCols)
        aDiag[:, aIdx, aIdx] = M
        return aDiag
    
#================================================================================================================
    def get_starter_matrix(self, base_period_len, sample_count, frequencies_considered_count):
#================================================================================================================
        """
        Helper for the HANTS algorithm: build the matrix of harmonic base functions.

        Row 0 is constant 1; for frequency f = 1 .. frequencies_considered_count row 2f-1 holds the
        cosine and row 2f the sine of 2*pi*f*t/base_period_len at the samples
        t = 0 .. sample_count-1. The result has min(2*frequencies_considered_count+1, sample_count)
        rows and sample_count columns.
        """
        nRows = min(2 * frequencies_considered_count + 1, sample_count)
        aMat = np.zeros(shape=(nRows, sample_count))
        aMat[0, :] = 1

        # one full period of cosine / sine, used as lookup tables
        aAngle = 2 * np.pi * np.arange(base_period_len) / base_period_len
        aCos = np.cos(aAngle)
        aSin = np.sin(aAngle)

        aFreq = np.arange(1, frequencies_considered_count + 1)

        # lookup position for every (frequency, sample) pair, wrapped into the base period
        aPos = np.mod(np.outer(aFreq, np.arange(sample_count)), base_period_len)
        aMat[2 * aFreq - 1, :] = aCos.take(aPos)
        aMat[2 * aFreq, :] = aSin.take(aPos)
        return aMat

#================================================================================================================
    def HANTS_light(self, sample_count, inputs, frequencies_considered_count=3, outliers_to_reject='Hi',
              exclude_low=0., exclude_high=255, fit_error_tolerance=5, delta=0.1):
#================================================================================================================
        """
        Harmonic analysis of time series (HANTS) applied to a stack of series.

        This version gives only back the harmonized time series.

        sample_count    = number of samples per series (second dimension of inputs); it is also
                used as length of the base period, i.e. sample i is taken at time i
        inputs     = 2-D array (series, samples) of input sample values
        frequencies_considered_count    = number of frequencies to be considered above the zero
                frequency
        outliers_to_reject  = 'Hi' rejects samples lying above the fitted curve, 'Lo' samples
                lying below it; for any other value the fit error is multiplied by 0
        exclude_low   = valid range minimum (samples <= exclude_low get weight 0 right away)
        exclude_high  = valid range maximum (samples > exclude_high get weight 0 right away)
        fit_error_tolerance   = fit error tolerance (the sample deviating most from the curve is
                rejected as long as it deviates more than fit_error_tolerance)
        delta = small positive number (e.g. 0.1) added to the diagonal of the normal matrix
                (except for the zero frequency) to suppress high amplitudes

        Returns an array of shape (series, sample_count) with the reconstructed series.
        """
        # define some parameters
        base_period_len = sample_count
        series_count = inputs.shape[0]

        # check which setting to set for outlier filtering
        if outliers_to_reject == 'Hi':
            sHiLo = -1
        elif outliers_to_reject == 'Lo':
            sHiLo = 1
        else:
            sHiLo = 0

        nr = min(2 * frequencies_considered_count + 1,
                 sample_count)  # number of 2*+1 frequencies, or number of input images

        # returned unchanged if the loop below never runs
        outputs = np.zeros(shape=(series_count, sample_count))

        # matrix of harmonic base functions, shape (nr, sample_count); it is the same for every
        # series, so it is kept once and broadcast instead of being repeated per series
        mat = self.get_starter_matrix(base_period_len, sample_count, frequencies_considered_count)

        # weights: 1 = sample is used, 0 = sample is outside the valid range or was rejected
        p = np.ones_like(inputs)
        p[(exclude_low >= inputs) | (inputs > exclude_high)] = 0
        nout = np.sum(p == 0, axis=-1)  # count the outliers for each timeseries

        # prepare for the loop
        ready = np.zeros(series_count, dtype=bool)  # all timeseries set to false

        dod = 1  # degree of overdeterminedness (safety margin on the number of remaining samples)
        noutmax = sample_count - nr - dod

        # Series without a single valid sample: declare all their samples valid so that the system
        # stays solvable ...
        p[p.sum(axis=1) == 0] = 1

        # ... and set their nout to noutmax, which makes them ready after the first pass
        nout[nout == sample_count] = noutmax

        # delta on the diagonal, but not for the zero frequency [0, 0]
        ridge = np.eye(nr) * delta
        ridge[0, 0] = 0

        rows = np.arange(series_count)

        for _ in range(sample_count):
            if ready.all():
                break

            # right-hand side: base functions times the weighted samples, shape (series, nr)
            za = (p * inputs) @ mat.T

            # normal matrix mat * diag(p) * mat.T per series, shape (series, nr, nr); the weights
            # are applied by broadcasting, no (samples x samples) diagonal matrix is built
            A = (p[:, None, :] * mat[None, :, :]) @ mat.T + ridge

            # Solve per series. The explicit trailing axis makes every row of za a column vector;
            # NumPy >= 2.0 no longer interprets a 2-D right-hand side as a stack of vectors.
            zr = np.linalg.solve(A, za[..., None])[..., 0]

            # reconstructed timeseries, shape (series, sample_count)
            outputs = zr @ mat

            # signed, weighted fit error and, per series, the position of its maximum
            err = p * (sHiLo * (outputs - inputs))
            j = np.argsort(err, axis=1)[:, sample_count - 1]
            maxerr = err[rows, j]

            ready = (maxerr <= fit_error_tolerance) | (nout == noutmax)

            if not ready.all():
                # reject the worst sample of every series that is not ready yet
                p[rows, j] = p[rows, j] * ready.astype(int)
                # NOTE(review): nout is raised for every series, also for those already ready, and
                # the stop test uses '=='; a series starting with more than noutmax excluded
                # samples is therefore never stopped by the count. Kept as it was -- confirm that
                # this is intended.
                nout += 1
        return outputs

#================================================================================================================
if __name__ == '__main__':
#================================================================================================================

    # Run configuration: input directory with the 'Lucas *' sub-directories, directory with the
    # fallback feature files, output directory and length of the regular time axis.
    Info = dict(
        fBasedir=r'/data/EEA_HRL_VLCC/data/ref/lucas/LUCAS2018/2018_EU_LUCAS2018_POINT_JD/',
        fSplitdir=r'/data/EEA_HRL_VLCC/data/ref/lucas/LUCAS2018/2018_EU_LUCAS2018_POINT_JD/lucas_split/',
        fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/1_HANTS_preprocessed_allbands/',
        # fOutdir=r'/data/EEA_HRL_VLCC/user/luc/data/LUCAS2018/test/',
        periods=577,  # 2017/09/01 till 2019/03/31
    )

    # the object is kept in a variable so that it can be inspected after an interactive run
    oLUCASpreprocessing = cLUCASpreprocessing(Info)
    oLUCASpreprocessing.start_processing()
    
    
    
    
    
    