from openeo_udf.api.udf_data import UdfData
from openeo_udf.api.structured_data import StructuredData
from tensorflow.keras.models import load_model
import xarray as xr
import _pickle as pickle
from openeo.rest.conversions import timeseries_json_to_pandas
import os
import pandas as pd
import numpy as np
import sys
sys.path.append('/data/users/Public/bontek/Nextland/Croptype_classification/parcel-1.0.0-py3-none-any.whl')
from parcel.feature.classification import preprocess_ts
import nextland_services as nx
from pathlib import Path


import threading

# Per-thread model cache. It must live at module scope: a threading.local()
# created inside load_generator_model() would be a new, empty object on every
# call, so the cached model would never be found and the model would be
# reloaded from the package resource on each call.
_MODEL_CACHE = threading.local()


def load_generator_model():
    """Return the croptype Keras model, loading it at most once per thread.

    Keras/tensorflow models are not guaranteed to be thread-safe, so rather
    than sharing one instance across threads, every thread keeps its own copy
    in ``_MODEL_CACHE``. The first call on a thread reads the HDF5 resource
    ``resources/models/SoilEssentials_croptype_detector.h5`` from the
    ``nextland_services`` package and deserializes it in memory; subsequent
    calls on the same thread return the cached model.

    :return: the loaded ``tensorflow.keras`` model.
    :raises FileNotFoundError: if the model resource cannot be read from the
        ``nextland_services`` package.
    """
    generator_model = getattr(_MODEL_CACHE, 'generator_model', None)
    if generator_model is None:
        # Heavy imports are deferred until a thread actually needs to load.
        import io
        import pkgutil
        import h5py
        from tensorflow.keras.models import load_model

        path = 'resources/models/SoilEssentials_croptype_detector.h5'
        data = pkgutil.get_data('nextland_services', path)
        if data is None:
            # pkgutil signals "not found / loader can't read it" with None;
            # fail here instead of passing an empty buffer to h5py.
            raise FileNotFoundError(
                f"model resource {path!r} not found in package 'nextland_services'")

        # Wrap the raw bytes so h5py can read the model without a temp file.
        with h5py.File(io.BytesIO(data), mode='r') as h5:
            generator_model = load_model(h5)

        _MODEL_CACHE.generator_model = generator_model

    return generator_model

# Expected column order of every field's timeseries. Columns 0-2 (VH, VV,
# angle) are treated as Sentinel-1, the remaining ones as Sentinel-2 bands
# plus the SCL layer.
_COLUMNS_ORDER = ['VH', 'VV', 'angle', "B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11",
                  "B12", "B8A", 'SCL']

# Crop type label per class; key k corresponds to model output index k - 1.
_INVERSE_LUT = {
    1: 'Barley',
    2: 'Wheat',
    3: 'Beans',
    4: 'Grass',
    5: 'Flax',
    6: 'Sugar beet',
    7: 'Potato',
    8: 'Rye',
    9: 'Rapeseed',
    10: 'Oats',
    11: 'Peas'
}


def _field_to_dataset(field_df, field_id):
    """Convert one field's band timeseries into an xarray Dataset.

    :param field_df: timeseries of a single field, indexed by date, with
        columns named as in ``_COLUMNS_ORDER``.
    :param field_id: integer id stored in the ``CODE_OBJ`` index level.
    :return: Dataset over ``(time, CODE_OBJ)`` holding the S2 columns scaled
        by 1e-4, ``VV_ASC``/``VH_ASC`` transformed with ``10 ** (x / 10)``
        (presumably dB to linear power), and all-NaN ``VV_DES``/``VH_DES``.
    """
    time_index = pd.to_datetime(field_df.index).tz_localize(None)
    row_index = pd.MultiIndex.from_product([time_index.tolist(), [field_id]],
                                           names=['time', 'CODE_OBJ'])

    # Sentinel-1: only VV/VH are used; they are stored as ascending-orbit data.
    s1_data = field_df[['VV', 'VH']].rename(columns={'VV': 'VV_ASC', 'VH': 'VH_ASC'})
    s1_data = np.power(10, s1_data / 10.)
    s1_data.index = row_index
    s1_asc_ds = xr.Dataset.from_dataframe(s1_data)

    # No descending-orbit input exists, so add all-NaN DES variables with the
    # same coordinates (preprocessing is run with combine_asc_des=True).
    s1_des_ds = xr.full_like(s1_asc_ds, np.nan).rename_vars({'VV_ASC': 'VV_DES', 'VH_ASC': 'VH_DES'})

    # Sentinel-2: every column after the three S1 ones, scaled by 1e-4.
    s2_data = field_df.iloc[:, 3:] * 0.0001
    s2_data.index = row_index
    s2_ds = xr.Dataset.from_dataframe(s2_data)

    return xr.merge([s2_ds, s1_asc_ds, s1_des_ds])


def _timeseries_to_dataset(ts_df, n_fields):
    """Build one Dataset holding the timeseries of all fields.

    :param ts_df: output of ``timeseries_json_to_pandas``; with more than one
        field, the first column level is the field index (0 .. n_fields - 1).
    :param n_fields: number of fields in the timeseries.
    :return: merged Dataset with the fields stacked along ``CODE_OBJ``.
    """
    field_datasets = []
    for field_id in range(n_fields):
        if n_fields > 1:
            field_df = ts_df.loc[:, ts_df.columns.get_level_values(0) == field_id]
        else:
            # Single field: the frame's columns are that field's bands.
            field_df = ts_df
        # Work on a copy so renaming the columns never modifies ts_df itself.
        field_df = field_df.copy()
        field_df.columns = _COLUMNS_ORDER
        field_datasets.append(_field_to_dataset(field_df, field_id))
    # A single merge over all fields instead of pairwise merging in a loop.
    return xr.merge(field_datasets)


def _classify_season(model, dataset, year, field_ids):
    """Predict the crop type of every field for the season ending in ``year``.

    The season runs from October 1st of ``year - 1`` to August 15th of ``year``.

    :param model: Keras model taking ``[S1, S2]`` arrays.
    :param dataset: merged Dataset from ``_timeseries_to_dataset``.
    :param year: season year (int).
    :param field_ids: index for the returned frame; must match the number and
        order of rows produced by ``preprocess_ts.to_numpy_arrays``.
    :return: DataFrame indexed by ``field_ids`` with columns ``{year}_CT``
        (label from ``_INVERSE_LUT``) and ``{year}_CONF`` (highest class score).
    """
    start_date = str(int(year) - 1) + '-10-01'
    end_date = str(year) + '-08-15'
    season_ds = dataset.sel(time=slice(start_date, end_date))

    arrays = preprocess_ts.to_numpy_arrays(season_ds,
                                           identifier='CODE_OBJ',
                                           resample_freq='5D',
                                           combine_asc_des=True)
    scores = model.predict([arrays['S1'], arrays['S2']])
    # Model output index k maps to _INVERSE_LUT key k + 1.
    labels = [_INVERSE_LUT[int(idx) + 1] for idx in np.argmax(scores, axis=1)]

    season_df = pd.DataFrame(index=field_ids)
    season_df[f'{year}_CT'] = labels
    season_df[f'{year}_CONF'] = np.max(scores, axis=1)
    return season_df


def udf_croptypeclassification(udf_data: UdfData):
    """Classify the crop type of every field for each season in the input.

    Reads the timeseries from the first structured-data item of ``udf_data``,
    predicts a crop type and confidence per field and season year, and sets
    the structured data list to a single ``dict`` item mapping the column
    names ``'{year}_CT'`` / ``'{year}_CONF'`` to ``{field_index: value}`` dicts.

    :param udf_data: UDF input; its first structured-data item must hold the
        timeseries JSON accepted by ``timeseries_json_to_pandas``.
    :return: ``udf_data`` with the predictions set, or ``None`` when the input
        timeseries is empty. If the data covers fewer than two calendar years
        no season can be classified and the predictions dict is empty.
    """
    ts_dict = udf_data.get_structured_data_list()[0].data
    if not ts_dict:  # workaround for an empty timeseries
        # NOTE(review): returns None here but udf_data below — confirm the
        # UDF runtime accepts both.
        return

    ts_df = timeseries_json_to_pandas(ts_dict)
    ts_df.index = pd.to_datetime(ts_df.index).date

    # Season `year` starts on October 1st of `year - 1`, so the earliest
    # calendar year in the data has no preceding autumn and is skipped.
    years = sorted({day.year for day in ts_df.index})[1:]

    # The code assumes every value of ts_dict holds one entry per field.
    n_fields = len(next(iter(ts_dict.values())))
    field_ids = list(range(n_fields))

    if years:
        model = load_generator_model()
        # The dataset does not depend on the season year: build it only once.
        dataset = _timeseries_to_dataset(ts_df, n_fields)
        predictions_df = pd.concat(
            [_classify_season(model, dataset, year, field_ids) for year in years],
            axis=1)
    else:
        # pd.concat([]) would raise; report "no seasons" as an empty result.
        predictions_df = pd.DataFrame(index=field_ids)

    udf_data.set_structured_data_list([StructuredData(description='croptype_classification',
                                                      data=predictions_df.to_dict(),
                                                      type="dict")])
    return udf_data

