# -*- coding: utf-8 -*-
# XarrayDataCube is used for type hints and, in apply_datacube, to wrap the result at runtime
from openeo.udf import XarrayDataCube
from typing import Dict
import xarray as xr
import numpy as np
import threading
import functools


# Thread-local storage used by _load_cropland_model and _load_croptype_model:
# attributes set on it are visible only to the thread that set them, so each
# thread keeps its own loaded model instances.
_threadlocal = threading.local()


def _load_cropland_model(path: str):
    """Return the cropland model built from ``path``, loading it once per thread.

    Keras/tensorflow models are not guaranteed to be threadsafe, so the model
    is stored on the thread-local ``_threadlocal`` object and reused by later
    calls from the same thread instead of being reloaded at predict time.

    Args:
        path: Model config file passed to ``CroptypeModel.from_config``.

    Returns:
        The model instance built from ``path``.
    """
    from cropclass.models import CroptypeModel

    cropland_model = getattr(_threadlocal, 'cropland_model', None)
    cached_path = getattr(_threadlocal, 'cropland_model_path', None)

    # Reload when this thread has no model yet, or when the cached model was
    # built from a different config. Without the path check, the first model
    # loaded in a thread would be returned for every later path.
    if cropland_model is None or cached_path != path:
        cropland_model = CroptypeModel.from_config(path)
        # Store per thread, together with the config it came from
        _threadlocal.cropland_model = cropland_model
        _threadlocal.cropland_model_path = path

    return cropland_model


def _load_croptype_model(path: str):
    """Return the crop type model built from ``path``, loading it once per thread.

    Keras/tensorflow models are not guaranteed to be threadsafe, so the model
    is stored on the thread-local ``_threadlocal`` object and reused by later
    calls from the same thread instead of being reloaded at predict time.

    Args:
        path: Model config file passed to ``CroptypeModel.from_config``.

    Returns:
        The model instance built from ``path``.
    """
    from cropclass.models import CroptypeModel

    croptype_model = getattr(_threadlocal, 'croptype_model', None)
    cached_path = getattr(_threadlocal, 'croptype_model_path', None)

    # Reload when this thread has no model yet, or when the cached model was
    # built from a different config. Without the path check, the first model
    # loaded in a thread would be returned for every later path.
    if croptype_model is None or cached_path != path:
        croptype_model = CroptypeModel.from_config(path)
        # Store per thread, together with the config it came from
        _threadlocal.croptype_model = croptype_model
        _threadlocal.croptype_model_path = path

    return croptype_model


def load_model(configfile: str, modelname: str):
    """Return the model named ``modelname`` built from ``configfile``.

    This function is deliberately not wrapped in ``functools.lru_cache``.
    That cache is shared by the whole process, so it would hand the model
    loaded by one thread to every other thread. That defeats the per-thread
    storage the loaders use, which exists because Keras/tensorflow models are
    not guaranteed to be threadsafe. Caching is left to the per-thread
    loaders.

    Args:
        configfile: Path to the model config file.
        modelname: Either ``'cropland_model'`` or ``'croptype_model'``.

    Returns:
        The loaded model.

    Raises:
        ValueError: If ``modelname`` is not one of the supported names.
    """
    if modelname == 'cropland_model':
        return _load_cropland_model(configfile)

    if modelname == 'croptype_model':
        return _load_croptype_model(configfile)

    raise ValueError(('`modelname` should be one of '
                      '[cropland_model, croptype_model]'))


def add_sensor_name(features, s):
    """Prefix every name in ``features.names`` with ``s`` followed by a dash.

    The ``names`` attribute is replaced in place and the same ``features``
    object is returned, so the call can be chained.
    """
    features.names = [s + '-' + name for name in features.names]
    return features


def classify(inarr: xr.DataArray, startdate, enddate,
             **kwargs) -> xr.DataArray:
    """Run the two-stage cropland / crop type classification on ``inarr``.

    Stage 1 runs the cropland model and turns pixels predicted as class 11
    into a cropland mask. Stage 2 runs the crop type model with that mask and
    ``nodatavalue=0``.

    Args:
        inarr: Input cube with dimensions ``bands``, ``t``, ``x`` and ``y``
            (in any order). S2 bands are the bands whose name contains
            ``'B'``, plus ``'SCL'``. S1 bands are those whose name contains
            ``'V'``.
        startdate: Value written to ``['composite']['start']`` of every
            feature-settings collection.
        enddate: Value written to ``['composite']['end']`` of every
            feature-settings collection.
        **kwargs: Accepted so extra UDF context keys are tolerated; unused.

    Returns:
        DataArray with a single ``'cropclass_label'`` band on the ``x``/``y``
        grid of ``inarr``, in the original dimension order minus ``t``.
    """
    # OpenEO needs access to worldcereal, satio and cropclass. Only extend
    # sys.path when the entry is missing: an unconditional append adds a
    # duplicate entry on every call.
    import sys
    dep_path = r'/data/users/Public/kristofvt/python/dep'
    if dep_path not in sys.path:
        sys.path.append(dep_path)
    from cropclass.classifier import CropclassClassifier
    from cropclass.openeo.fp import (OpenEOS2FeaturesProcessor,
                                     OpenEOS1FeaturesProcessor)
    from cropclass.features.settings_old import get_croptype_tsteps_parameters
    from worldcereal.utils import is_real_feature

    # Get processing parameters and set the compositing period
    parameters = get_croptype_tsteps_parameters()
    settings = parameters['settings']
    features_meta = parameters['features_meta']
    ignore_def_feat = parameters['ignore_def_feat']
    for coll in settings:
        settings[coll]['composite']['start'] = startdate
        settings[coll]['composite']['end'] = enddate

    # Store the original dimension order for later
    orig_dims = list(inarr.dims)
    orig_dims.remove('t')  # Time dimension will be removed

    # Make sure we have a controlled dimension order
    inarr = inarr.transpose('bands', 't', 'x', 'y')

    # Identify the S1 and S2 bands
    bands = list(inarr.bands.values)
    S2_bands = [b for b in bands if 'B' in b] + ['SCL']
    S1_bands = [b for b in bands if 'V' in b]

    # Split according to S1/S2
    S2data = inarr.sel(bands=S2_bands).assign_attrs(sensor='S2')
    S1data = inarr.sel(bands=S1_bands).assign_attrs(sensor='S1')

    # Preprocess and compute SAR features
    fp_s1 = OpenEOS1FeaturesProcessor(
        S1data,
        settings=settings['SAR'],
        features_meta=features_meta['SAR'],
        ignore_def_features=ignore_def_feat['SAR'])
    features_s1 = add_sensor_name(fp_s1.compute_features(), 'SAR')

    # Preprocess and compute optical features
    fp_s2 = OpenEOS2FeaturesProcessor(
        S2data,
        settings=settings['OPTICAL'],
        features_meta=features_meta['OPTICAL'],
        ignore_def_features=ignore_def_feat['OPTICAL'])
    features_s2 = add_sensor_name(fp_s2.compute_features(), 'OPTICAL')

    # Merge the features, then drop the per-sensor references so they can be
    # garbage-collected
    features = features_s2.merge(features_s1)
    features_s1 = features_s2 = None

    # Select real features
    real_fts = [ft for ft in features.names if is_real_feature(ft)]
    features = features.select(real_fts)

    # Construct the feature names the models expect: the first three
    # dash-separated components of each name, followed by '10m' when the
    # second component is one of bands10m and '20m' otherwise.
    bands10m = ["B02", "B03", "B04", "B08"]
    ft_labels = []
    for label in features.names:
        parts = label.split('-')
        resolution = '10m' if parts[1] in bands10m else '20m'
        ft_labels.append(f'{parts[0]}-{parts[1]}-{parts[2]}-{resolution}')
    features.names = ft_labels

    # --------------------------------------------------------------------------
    # STAGE 1: CREATE CROPLAND MASK
    # --------------------------------------------------------------------------

    # Define the cropland model and encoder
    cropland_model = '/data/users/Public/kristofvt/NEXTLAND/models/cropland_detector_WorldCerealPixelLSTM_11Bands_v050/config.json'  # NOQA
    cropland_encoder = '/data/users/Public/kristofvt/NEXTLAND/models/cropland_detector_WorldCerealPixelLSTM_11Bands_v050/label_encoder.p'  # NOQA

    # Load the cropland model
    cropland_model = load_model(cropland_model, 'cropland_model')

    # Create a WorldCerealClassifier from the model
    croplandclassifier = CropclassClassifier(
        cropland_model,
        filtersettings={'kernelsize': 5, 'conf_threshold': 0.8},
        encoder=cropland_encoder)

    # Run cropland model (the confidence output is not used)
    prediction, _ = croplandclassifier.predict(features, nodatavalue=255)

    # Transform landcover in a cropland mask: 1 where class 11, else 0
    mask = np.zeros_like(prediction)
    mask[prediction == 11] = 1

    # --------------------------------------------------------------------------
    # STAGE 2: IDENTIFY CROPS WITHIN MASK
    # --------------------------------------------------------------------------

    # Define the croptype model and encoder
    croptype_model = '/data/users/Public/kristofvt/NEXTLAND/models/croptype_detector_WorldCerealPixelLSTM_11Bands_v050/config.json'  # NOQA
    croptype_encoder = '/data/users/Public/kristofvt/NEXTLAND/models/croptype_detector_WorldCerealPixelLSTM_11Bands_v050/label_encoder.p'  # NOQA

    # Load the croptype model
    croptype_model = load_model(croptype_model, 'croptype_model')

    # Create a WorldCerealClassifier from the model
    croptypeclassifier = CropclassClassifier(
        croptype_model,
        filtersettings={'kernelsize': 5, 'conf_threshold': 0.8},
        maskdata=mask,
        encoder=croptype_encoder)

    # Run crop type model, put masked regions to nodata (confidence unused)
    prediction, _ = croptypeclassifier.predict(features, nodatavalue=0)

    # Finally transform result to DataArray, adding a leading bands dimension
    prediction = prediction.reshape((1, *prediction.shape))
    result_da = xr.DataArray(prediction, coords=[
        ['cropclass_label'],
        inarr.coords["x"],
        inarr.coords["y"]], dims=["bands", "x", "y"])

    # And make sure we revert back to original dimension order
    return result_da.transpose(*orig_dims)


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """UDF entry point: classify the cube's array and return it as a datacube.

    ``context`` is unpacked into keyword arguments for ``classify``, so it
    must supply ``startdate`` and ``enddate``. Any other keys are passed
    through to ``classify`` as well.
    """
    return XarrayDataCube(classify(cube.get_array(), **context))


def load_cropclass_udf() -> str:
    """Return the full source code of this module as a string.

    The module defines the ``apply_datacube`` UDF entry point, so the
    returned text can be used as the UDF code.
    """
    import os
    # Open read-only: 'r+' also requires write permission on the file and
    # fails on read-only installs. The encoding matches the module's utf-8
    # coding declaration instead of relying on the locale default.
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()
