# -*- coding: utf-8 -*-
# Uncomment the import only for coding support
from openeo_udf.api.udf_data import UdfData
import pandas as pd


def udf_timesat(data: UdfData):
    """
    Run TIMESAT on the time series held in ``data.get_structured_data_list()[0]``.

    The ``data`` attribute of that first structured-data entry is read as a
    mapping ``timestamp -> [[value_field_0], [value_field_1], ...]``; only the
    first element of every inner list is used and ``None`` is replaced by the
    TIMESAT no-data value. Entries whose inner lists do not add up to exactly
    one element per field (e.g. ``[[], [], ...]``) are skipped.

    On success the same ``data`` attribute is overwritten with the
    ``DataFrame.to_dict()`` form of a table holding one row per year and one
    ``(field, phenometric)`` column per field and per start/peak/end of season
    date (two seasons per year).

    ``data.user_context['metric']`` selects the physical value range handed to
    TIMESAT (``[0, 12]`` for ``'LAI'``, ``[0, 1]`` otherwise); unknown or
    missing metrics fall back to ``'FAPAR'``.

    :param data: UDF data object carrying the structured data and user context.
    :return: ``data``; it is returned unchanged when there is no usable input.
    """

    import numpy as np
    import sys
    import traceback
    import datetime
    from dateutil.parser import parse

    ### Load the user context

    # A missing user context is treated like an empty one instead of crashing.
    user_context = data.user_context or {}
    valid_metrics = ['FAPAR', 'LAI', 'FCOVER', 'NDVI']
    metric = user_context.get('metric')
    # the metric is used to determine its physical range
    if metric not in valid_metrics:
        metric = 'FAPAR'

    structured = data.get_structured_data_list()[0]
    raw_series = structured.data
    # Nothing to process: bail out before loading the TIMESAT extension.
    if len(raw_series) == 0:
        return data

    class TimeSatWrapper:
        """Thin wrapper holding the TIMESAT settings and decoding its output."""

        ###############################################################
        ###              LOAD                                       ###
        ###############################################################

        TimesatVersion = 'TIMESAT_py38_numpy1184/TIMESAT4.1.6/' #'TIMESAT4.1.2a/'
        _timesat_dir = r'/data/users/Public/zhanzhangcai/TIMESAT/' + TimesatVersion
        # Only extend sys.path once, otherwise it grows with every UDF call.
        if _timesat_dir not in sys.path:
            sys.path.append(_timesat_dir)
        import timesat

        # Note:
        # vpp   :
        #     Two seasons are stored per year      HRVPP format/scaling factor
        # ! 1) season start date                YYDOY
        # ! 2) season start value               *10000
        # ! 3) season start derivative          *10000
        # ! 4) season end date                  YYDOY
        # ! 5) season end values                *10000
        # ! 6) season end derivative            *10000
        # ! 7) season length                    -
        # ! 8) basevalue                        *10000
        # ! 9) time for peak                    YYDOY
        # ! 10) value at peak                   *10000
        # ! 11) seasonal amplitude              *10000
        # ! 12) large integral                  *10
        # ! 13) small integral                  *10

        def __init__(self):

            ###############################################################
            ###              TIMESAT PARAMETERS                         ###
            ###############################################################

            #self.p_ylu = [0., 12]
            self.p_ignoreday = 366
            self.p_a = [4., 6, 1., 0., 0., 0, 0., 0., 0.]
            self.p_printflag = 99  # only row information will be printed
            self.p_nodata = -9999.  # whatever you choose, keep it 5 letters long when str(int(round(...)))
            self.p_nenvi = 1
            # dbl# p_wfactnum                      ! adaptation strength
            self.p_wfactnum = 1
            # int# p_startmethod                   ! method for season start
            self.p_startmethod = 1
            # dbl# p_startcutoff(255,2)            ! parameter for season start
            # p_startcutoff = [start of season threshold, end of season threshold]
            self.p_startcutoff = [0.25, 0.25]

            # dbl# p_low_percentile                ! parameter for define base level
            self.p_low_percentile = 0.05

            # int# p_hrvppformat                     ! 1: output as HRVPP format (YYDOY)
            #                                       ! 0: output as TIMESAT format (sequential number, no scaling)
            self.p_hrvppformat = 1

            # int# p_continuityflag                ! 0: no continuity conditions (calval)
            #                                     ! 1: new run (produce) (janval,jande,decval,decde will update on the third year)
            #                                     ! 2: re-run  (janval,jande,decval,decde will update on the second year)
            self.p_continuityflag = 0

            # encoding in the results
            # the resulting array of phenologicalpar layout is:
            # [col][row][iyear*p_nresultvars*p_nseasons+iseason*p_nresultvars+ivariable]
            # where ivariable is for example p_sos/p_pos/p_eos index
            # where col is used for each polygon timeseries and row is fixed to size 1
            self.p_nseasons = 2  # how many seasons does timesat return per year
            self.p_nresultvars = 13  # number of variables for a single season
            self.p_sos = 0  # index of start of season variable
            self.p_pos = 8  # index of peak of season variable
            self.p_eos = 3  # index of end of season variable

            #
            #       ! close to 0: more seasons will be detected when the amplitude of time series is close to p_ylu(2)-p_ylu(1)
            #       ! close to 1: less seasons will be detected when the amplitude of time series is close to p_ylu(2)-p_ylu(1)
            self.p_seapar = 0.5

        ###############################################################
        ###              MAIN FUNC                                  ###
        ###############################################################

        def run(self, dates, values, metric):
            """
            Call ``timesat.tsfprocess`` and decode the season dates it returns.

            Expects that dates are sorted chronologically and that values
            belong to the right dates.

            :param dates: list of 'YYYYjjj' strings (four digit year followed
                by the three digit day of year); the first four characters of
                the first and last entry define the covered years.
            :param values: float64 array whose last axis is the time axis; as
                built by the caller its shape is (fields, 1, timesteps).
            :param metric: metric name; 'LAI' selects the range [0, 12], any
                other value the range [0, 1].
            :return: dict mapping the year (as string) to one list per time
                series holding [SOS, POS, EOS] for season 1 followed by
                [SOS, POS, EOS] for season 2, each a ``datetime.date`` or
                ``None`` when TIMESAT reported no result.
            :raises Exception: any failure is re-raised with an
                'INNER EXCEPTION' message that embeds the traceback.
            """

            try:
                p_ylu = [0, 12] if metric == 'LAI' else [0, 1]

                # number of years the data spans
                start_year = int(dates[0][:4])
                yr = int(dates[-1][:4]) - start_year + 1

                step = values.shape[-1]

                aQA = np.full(values.shape, 5, order='F', dtype='float64')

                # NOTE(review): this yields yr*365-1 indices (1 .. yr*365-1) although
                # the intent is "output all days" - confirm against what TIMESAT expects.
                outindex = np.arange(1, yr * 365, 1)  # output all days
                outindex_num = len(outindex)  # number of output images

                ret = self.timesat.tsfprocess(
                    yr, values, aQA, dates, outindex,
                    self.p_ignoreday, p_ylu, self.p_a, self.p_printflag, self.p_nodata, self.p_nenvi,
                    self.p_wfactnum,
                    self.p_startmethod, self.p_startcutoff, self.p_low_percentile, self.p_hrvppformat,
                    self.p_seapar, aQA.shape[0], aQA.shape[1], step, outindex_num)

                phenologicalpar = ret[0]

                nodata = str(int(round(self.p_nodata))).zfill(5)

                def to_date(raw):
                    # caught inconsistencies so far:
                    # * dates are returned as floating point -> round+int and pad with
                    #   zeroes so single digit years keep the two-digit year layout
                    # * timesat is inconsistent: invalid results are sometimes nodata,
                    #   sometimes just 0
                    code = str(int(round(raw))).zfill(5)
                    if code == "00000" or code == nodata:
                        return None
                    return datetime.datetime.strptime(code, "%y%j").date()

                # see the explanation at the parameters about the layout of the results
                year_stride = self.p_nresultvars * self.p_nseasons
                wanted = (self.p_sos, self.p_pos, self.p_eos)
                tsresult = {}
                for iyr in range(yr):
                    itsseasons = []
                    for its in phenologicalpar:
                        iseasons = []
                        for isn in range(self.p_nseasons):
                            base = iyr * year_stride + isn * self.p_nresultvars
                            iseasons += [to_date(its[0, base + ivar]) for ivar in wanted]
                        itsseasons.append(iseasons)
                    tsresult[str(start_year + iyr)] = itsseasons

                return tsresult
            except Exception as e:
                print(e)
                message = 'INNER EXCEPTION: ' + str(e) + ' -> ' + str(traceback.format_exc())
                try:
                    wrapped = type(e)(message)
                except Exception:
                    # Some exception classes cannot be rebuilt from a single
                    # message; do not let that hide the real failure.
                    wrapped = RuntimeError(message)
                raise wrapped from e

    # init the wrapper
    tsw = TimeSatWrapper()

    # preprocessing the time series
    #  * filtering [[],[],[],...] entries
    #  * sorting by key
    #  * parsing the data:
    #     * replacing timestamps in key with four-digit year + dayofyear (2016-02-10T00:00:00 -> 2016041)
    #     * replacing None with p_nodata in the value arrays
    nodata_value = float(tsw.p_nodata)
    usable_items = sorted(
        item for item in raw_series.items()
        if len(item[1]) == sum(len(cell) for cell in item[1])
    )
    inputs = [
        [parse(timestamp).strftime('%Y%j')]
        + [float(cell[0]) if cell[0] is not None else nodata_value for cell in cells]
        for timestamp, cells in usable_items
    ]
    # Every entry was filtered out: same outcome as an empty input.
    if not inputs:
        return data

    # extract data in the format the wrapper expects
    dates = [row[0] for row in inputs]
    values = np.array([[row[1:]] for row in inputs], order='F', dtype='float64').transpose()

    # run timesat and place the results back to data
    result = tsw.run(dates, values, metric)

    # now convert the dictionary to an interpretable dataframe for the user that
    # indicates which phenometrics are detected at which moment
    columns_pheno = ['SOS_S1', 'POS_S1', 'EOS_S1', 'SOS_S2', 'POS_S2', 'EOS_S2']
    nr_fields = len(result[next(iter(result))])

    # multi-index columns identifying the phenometric associated with each field
    column_index = pd.MultiIndex.from_tuples(
        [('Field_{}'.format(field), pheno)
         for field in range(nr_fields)
         for pheno in columns_pheno],
        names=['field', 'phenometric']
    )

    # One row per year, field after field. Building the rows directly gives the
    # same table as stacking a per-year frame, without relying on the legacy
    # DataFrame.stack(dropna=...) code path and without a recursive flatten helper.
    years = list(result)
    rows = [
        [value for field_values in result[year] for value in field_values]
        for year in years
    ]
    df_phen_final = pd.DataFrame(rows, index=pd.Index(years, name='year'), columns=column_index)

    structured.data = df_phen_final.to_dict()

    return data