# -*- coding: utf-8 -*-

"""
Created on Fri Jul 15 09:44:54 2022

@author: bertelsl
"""
import netCDF4
import numpy as np
import matplotlib.pyplot as plt

def makediag3d(M):
    """Turn every row of a 2-d array into its own diagonal matrix.

    Helper function for the HANTS algorithm.
    See: http://stackoverflow.com/q/27214027/2459096

    For ``M`` of shape ``(n, m)`` the result has shape ``(n, m, m)`` with
    ``result[r, c, c] == M[r, c]`` and zeros everywhere else. The result is
    always a float array, whatever the dtype of ``M``.
    """
    row_count, width = M.shape[0], M.shape[1]
    stacked = np.zeros((row_count, width, width))
    # Address the main diagonal of every (width x width) slice at once.
    diag_idx = np.arange(width)
    stacked[:, diag_idx, diag_idx] = M
    return stacked

def get_starter_matrix(base_period_len, sample_count, frequencies_considered_count):
    """Build the harmonic design matrix used by the HANTS algorithm.

    Helper function for the HANTS algorithm.

    The result has ``min(2 * frequencies_considered_count + 1, sample_count)``
    rows and ``sample_count`` columns. Row 0 is constant 1; for each
    frequency ``k = 1 .. frequencies_considered_count`` row ``2k - 1`` holds
    the cosine and row ``2k`` the sine of that harmonic, sampled at the
    positions ``0 .. sample_count - 1`` of a base period of
    ``base_period_len`` virtual samples.
    """
    # number of 2*+1 frequencies, or number of input images
    row_count = min(2 * frequencies_considered_count + 1, sample_count)
    basis = np.zeros(shape=(row_count, sample_count))
    basis[0, :] = 1

    # One period of cosine/sine, tabulated once and looked up below.
    phase = 2 * np.pi * np.arange(base_period_len) / base_period_len
    cos_table = np.cos(phase)
    sin_table = np.sin(phase)

    harmonics = np.arange(1, frequencies_considered_count + 1)
    sample_pos = np.arange(sample_count)
    # lookup[k - 1, t] is the table position of harmonic k at sample t; it
    # wraps around once k * t exceeds the base period.
    lookup = np.mod(np.outer(harmonics, sample_pos), base_period_len)
    basis[2 * harmonics - 1, :] = cos_table.take(lookup)
    basis[2 * harmonics, :] = sin_table.take(lookup)
    return basis

def HANTS_light(sample_count, inputs, frequencies_considered_count=3, outliers_to_reject='Hi',
                exclude_low=0., exclude_high=255, fit_error_tolerance=5, delta=0.1):
    """
    Harmonic ANalysis of Time Series (HANTS) applied to a stack of series.

    This version gives only back the harmonized time series.

    Every row of ``inputs`` is fitted with a constant plus
    ``frequencies_considered_count`` harmonics by regularised least squares.
    After each fit the valid sample with the largest signed fit error is
    rejected and the fit is repeated, until the largest error is within
    ``fit_error_tolerance`` or the series has no more samples to spare.
    At most ``sample_count`` fits are made.

    sample_count    = number of samples per time series; it is also used as
                      the length of the base period
    inputs          = 2-d array of shape (number of series, sample_count)
                      holding the sample values (e.g. NDVI values)
    frequencies_considered_count
                    = number of frequencies to be considered above the zero
                      frequency
    outliers_to_reject
                    = 'Hi' rejects samples lying above the fitted curve,
                      'Lo' rejects samples lying below it; any other value
                      (e.g. 'None') disables the iterative rejection
    exclude_low     = samples <= exclude_low are rejected right away
    exclude_high    = samples > exclude_high are rejected right away
    fit_error_tolerance
                    = fit error tolerance (samples deviating more than this
                      from the curve fit are rejected)
    delta           = small positive number (e.g. 0.1) added to the diagonal
                      of the normal equations, except for the constant
                      term, to suppress high amplitudes

    Returns a float array of shape (number of series, sample_count) with the
    fitted curves.

    Raises ValueError if ``inputs`` is not 2-d with ``sample_count`` columns.

    Notes
    - A series without a single valid sample is treated as entirely valid
      and accepted after the first fit.
    - NaN samples are not caught by the range test (comparisons with NaN are
      false); replace them before calling.
    """
    if inputs.ndim != 2 or inputs.shape[1] != sample_count:
        raise ValueError(
            "inputs must have shape (n_series, sample_count); got {} with "
            "sample_count={}".format(inputs.shape, sample_count))

    series_count = inputs.shape[0]
    base_period_len = sample_count

    # Sign that turns "outputs - inputs" into an error that is positive for
    # the kind of outlier to reject.
    if outliers_to_reject == 'Hi':
        sHiLo = -1
    elif outliers_to_reject == 'Lo':
        sHiLo = 1
    else:
        sHiLo = 0

    # number of 2*+1 frequencies, or number of input images
    nr = min(2 * frequencies_considered_count + 1, sample_count)

    # Design matrix of shape (nr, sample_count). It is identical for every
    # series, so it is broadcast below instead of being tiled per series.
    basis = get_starter_matrix(base_period_len, sample_count, frequencies_considered_count)
    basis_t = basis.T

    # p is the per-sample weight: 1 = used in the fit, 0 = rejected.
    p = np.ones(inputs.shape)
    p[(exclude_low >= inputs) | (inputs > exclude_high)] = 0
    nout = np.sum(p == 0, axis=-1)  # rejected samples per series

    # dod = degree of overdeterminedness: rejection stops once a series is
    # down to the minimum number of samples needed for the fit, plus dod.
    dod = 1
    noutmax = sample_count - nr - dod

    # Series consisting only of excluded values (gaps): use the whole line
    # as it is and mark it as exhausted so it is ready after the first fit.
    gap_rows = nout == sample_count
    p[gap_rows] = 1
    nout[gap_rows] = noutmax

    # delta on the diagonal suppresses high amplitudes, but not for [0, 0]
    ridge = delta * np.eye(nr)
    ridge[0, 0] = 0

    outputs = np.zeros(shape=(series_count, sample_count))
    ready = np.zeros(series_count, dtype=bool)
    rows = np.arange(series_count)

    for _ in range(sample_count):
        if ready.all():
            break

        # Weighted normal equations per series:
        #   A = B diag(p) B^T + ridge,   za = B (p * y)
        za = np.matmul(p * inputs, basis_t)
        A = np.matmul(p[:, None, :] * basis, basis_t) + ridge

        # Pass b as a stack of column vectors: NumPy >= 2.0 no longer reads
        # a 2-d b as a stack of vectors.
        zr = np.linalg.solve(A, za[..., None])[..., 0]
        outputs = np.matmul(zr, basis)

        # Signed fit error of the samples still in use.
        err = p * (sHiLo * (outputs - inputs))
        worst = err.argmax(axis=1)
        maxerr = err[rows, worst]

        # A series is done when it fits well enough or has no samples left
        # to spare (>= also covers series that start beyond the limit).
        ready = (maxerr <= fit_error_tolerance) | (nout >= noutmax)

        # Reject the worst sample of every unfinished series and count it
        # for that series only.
        pending = ~ready
        p[rows[pending], worst[pending]] = 0
        nout[pending] += 1
    return outputs
############### END of Calculation of Harmonics of time series #######################
    
fIn = r'/home/bertelsl/Public/PyCo/HRL_VLCC/data/S2_2017-09-01_2018-08-30_2018_AT_LPIS_POLY_110.nc'
# fIn = r'/home/bertelsl/Public/PyCo/HRL_VLCC/data/2018_crops.nc'

# Load both bands completely into memory; the context manager closes the
# dataset again as soon as the data has been read.
with netCDF4.Dataset(fIn, 'r') as file2read:
    B02 = file2read.variables['B02'][:]
    B11 = file2read.variables['B11'][:]

# Work on the plain data arrays and use 0 as the "no data" value, which
# HANTS_light rejects with its default exclude_low=0.
aBlue = np.ma.getdata(B02)
aBlue[np.isnan(aBlue)] = 0

aSWIR = np.ma.getdata(B11)
aSWIR[np.isnan(aSWIR)] = 0

# The first axis is the sample axis: the arrays are transposed before they
# are handed to HANTS_light, which expects one series per row.
n_samples = aBlue.shape[0]

# First pass: harmonic fit of the unfiltered series.
blue_HANTS = HANTS_light(n_samples, aBlue.T)
swir_HANTS = HANTS_light(n_samples, aSWIR.T)
aOut = blue_HANTS  # same fit, kept under its original name

entry = 0

y1 = aBlue[:, entry]
y2 = aOut[entry, :]
x = np.arange(n_samples)

plt.plot(x, y1)
plt.plot(x, y2)

# Absolute deviation of every data point from its fitted curve; no-data
# points are masked so they do not enter the median.
diff_blue = np.abs(blue_HANTS.T - aBlue)
diff_swir = np.abs(swir_HANTS.T - aSWIR)

ma_diff_blue = np.ma.array(diff_blue, mask=(aBlue == 0))
ma_diff_swir = np.ma.array(diff_swir, mask=(aSWIR == 0))

# NOTE(review): axis=1 is the series axis of these (sample, series) arrays,
# so this is one median per sample across all series, not one per series
# across time - confirm that this is intended.
MAD_blue = np.ma.median(ma_diff_blue, axis=1, keepdims=True).filled(0)
MAD_swir = np.ma.median(ma_diff_swir, axis=1, keepdims=True).filled(0)

# data points deviating by at least this many MADs are flagged as outliers
threshold = 3.5

# A MAD of 0 gives inf/nan scores; silence the warnings for these lines only
# instead of changing the global numpy error state.
with np.errstate(divide='ignore', invalid='ignore'):
    # calculate score value for each data point
    score_blue = diff_blue / MAD_blue
    score_swir = diff_swir / MAD_swir

    # create mask for both channels via comparison of score to threshold
    mask_blue = score_blue >= threshold
    mask_swir = score_swir >= threshold

# create master mask by taking all outliers from blue and swir into account
master_mask = mask_swir | mask_blue

# set the flagged data points to "no data" in both bands
aBlue[master_mask] = 0
aSWIR[master_mask] = 0

# Second pass: fit the blue band again without the flagged outliers.
blue_HANTS2 = HANTS_light(n_samples, aBlue.T)

y3 = blue_HANTS2[entry, :]
plt.plot(x, y3)

plt.show()

