# -*- coding: utf-8 -*-
# XarrayDataCube is used in the signature and return value of apply_datacube
from openeo.udf import XarrayDataCube
from typing import Dict
import xarray as xr
import numpy as np
import threading
import functools


# Thread-local storage: each thread keeps its own loaded model instances here
_threadlocal = threading.local()


def _load_cropland_model(path: str):
    """Return the cropland model for ``path``, cached per thread.

    Keras/tensorflow models are not guaranteed to be threadsafe, so each
    thread keeps its own instance on ``_threadlocal``. The config path is
    remembered next to the model, so a call with a different ``path`` loads
    the requested model instead of returning the one cached earlier.

    Args:
        path: Model config file handed to ``CroptypeModel.from_config``.

    Returns:
        The object returned by ``CroptypeModel.from_config(path)``.
    """
    # NOTE(review): the cropland model is also built through CroptypeModel;
    # confirm this is intended.
    from cropclass.models import CroptypeModel

    cropland_model = getattr(_threadlocal, 'cropland_model', None)
    cached_path = getattr(_threadlocal, 'cropland_model_path', None)

    # (Re)load when this thread has no model yet, or holds one that was
    # loaded from a different config path
    if cropland_model is None or cached_path != path:
        cropland_model = CroptypeModel.from_config(path)

        # Store per thread, together with the path it was loaded from
        _threadlocal.cropland_model = cropland_model
        _threadlocal.cropland_model_path = path

    return cropland_model


def _load_croptype_model(path: str):
    """Return the crop type model for ``path``, cached per thread.

    Keras/tensorflow models are not guaranteed to be threadsafe, so each
    thread keeps its own instance on ``_threadlocal``. The config path is
    remembered next to the model, so a call with a different ``path`` loads
    the requested model instead of returning the one cached earlier.

    Args:
        path: Model config file handed to ``CroptypeModel.from_config``.

    Returns:
        The object returned by ``CroptypeModel.from_config(path)``.
    """
    from cropclass.models import CroptypeModel

    croptype_model = getattr(_threadlocal, 'croptype_model', None)
    cached_path = getattr(_threadlocal, 'croptype_model_path', None)

    # (Re)load when this thread has no model yet, or holds one that was
    # loaded from a different config path
    if croptype_model is None or cached_path != path:
        croptype_model = CroptypeModel.from_config(path)

        # Store per thread, together with the path it was loaded from
        _threadlocal.croptype_model = croptype_model
        _threadlocal.croptype_model_path = path

    return croptype_model


def load_model(configfile: str, modelname: str):
    """Load the model identified by ``modelname`` from ``configfile``.

    Caching is left to the per-thread loader functions. A process-wide
    ``functools.lru_cache`` on this function would hand the instance loaded
    by one thread to every other thread, which defeats the thread-local
    storage used because Keras/tensorflow models are not guaranteed to be
    threadsafe.

    Args:
        configfile: Path to the model config file.
        modelname: Either ``'cropland_model'`` or ``'croptype_model'``.

    Returns:
        The model object produced by the matching loader.

    Raises:
        ValueError: If ``modelname`` is not one of the two supported names.
    """
    if modelname == 'cropland_model':
        return _load_cropland_model(configfile)

    if modelname == 'croptype_model':
        return _load_croptype_model(configfile)

    raise ValueError(('`modelname` should be one of '
                      '[cropland_model, croptype_model]'))


def tsteps(x, n_steps=18, axis=1):
    """Resample ``x`` to ``n_steps`` samples along ``axis``.

    Thin wrapper around ``scipy.signal.resample``; scipy is imported inside
    the function rather than at module level.
    """
    from scipy.signal import resample

    return resample(x, n_steps, axis=axis)


def classify(inarr: xr.DataArray, **kwargs) -> xr.DataArray:
    """Run the two-stage cropland / crop type classification on ``inarr``.

    The array is expected to carry the dimensions ``bands``, ``t``, ``x``
    and ``y``. Bands with a ``'B'`` in their name are handled as Sentinel-2
    bands, bands with a ``'V'`` as Sentinel-1 bands.

    NOTE: the rescaling is written back into ``inarr``, so the array passed
    in by the caller is modified in place.

    Args:
        inarr: Input data with S1 and S2 bands as digital numbers.
        **kwargs: Accepted but not used by this function.

    Returns:
        A DataArray with dimensions ``("x", "y")`` holding the crop type
        prediction; the crop type classifier is run with ``nodatavalue=0``
        and the cropland mask from stage 1.
    """
    # OpenEO needs access to worldcereal, satio and cropclass
    from cropclass.classifier import CropclassClassifier
    from satio.features import Features

    # Split the band names into S2 (optical) and S1 (radar) bands
    band_names = list(inarr.bands.values)
    optical_bands = [name for name in band_names if 'B' in name]
    radar_bands = [name for name in band_names if 'V' in name]

    # Rescale to what the classifiers expect:
    # S1: 10 * log10(DN) (to dB)
    radar_db = 10. * np.log10(inarr.sel(bands=radar_bands))
    inarr.loc[{'bands': radar_bands}] = radar_db
    # S2: DN / 10000. (to reflectance)
    optical_reflectance = inarr.sel(bands=optical_bands) / 10000.
    inarr.loc[{'bands': optical_bands}] = optical_reflectance

    # Collapse bands and tsteps into a single leading "feature" dimension
    stacked = inarr.stack(feature=("bands", "t"))
    stacked = stacked.transpose('feature', 'x', 'y')

    # Build the feature names the models expect:
    # <L2A|SIGMA0>-<band>-<tstep>-<10m|20m>
    bands10m = ("B02", "B03", "B04", "B08")
    feature_names = []
    for label in stacked.feature.values:
        resolution = '10m' if label[0] in bands10m else '20m'
        name = f'{label[0]}-{label[1]}-{resolution}'
        # The source prefix is decided on the complete name
        source = 'L2A' if 'B' in name else 'SIGMA0'
        feature_names.append(f'{source}-{name}')

    # Wrap the raw values in satio Features
    stacked_features = Features(data=stacked.values,
                                names=feature_names)

    # --------------------------------------------------------------------------
    # STAGE 1: CREATE CROPLAND MASK
    # --------------------------------------------------------------------------

    # Locations of the cropland model config and its label encoder
    cropland_config = '/data/users/Public/kristofvt/NEXTLAND/models/cropland_detector_WorldCerealPixelLSTM_11Bands_v050/config.json'  # NOQA
    cropland_encoder = '/data/users/Public/kristofvt/NEXTLAND/models/cropland_detector_WorldCerealPixelLSTM_11Bands_v050/label_encoder.p'  # NOQA

    # Build a CropclassClassifier around the loaded cropland model
    cropland_classifier = CropclassClassifier(
        load_model(cropland_config, 'cropland_model'),
        filtersettings={'kernelsize': 7, 'conf_threshold': 0.8},
        encoder=cropland_encoder)

    # Predict land cover; the confidence output is not used
    landcover, _ = cropland_classifier.predict(stacked_features,
                                               nodatavalue=255)

    # Pixels predicted as class 11 form the cropland mask (1), the rest is 0
    cropland_mask = np.zeros_like(landcover)
    cropland_mask[landcover == 11] = 1

    # --------------------------------------------------------------------------
    # STAGE 2: IDENTIFY CROPS WITHIN MASK
    # --------------------------------------------------------------------------

    # Locations of the croptype model config and its label encoder
    croptype_config = '/data/users/Public/kristofvt/NEXTLAND/models/croptype_detector_WorldCerealPixelLSTM_11Bands_v050/config.json'  # NOQA
    croptype_encoder = '/data/users/Public/kristofvt/NEXTLAND/models/croptype_detector_WorldCerealPixelLSTM_11Bands_v050/label_encoder.p'  # NOQA

    # Build a CropclassClassifier around the loaded croptype model,
    # restricted by the cropland mask from stage 1
    croptype_classifier = CropclassClassifier(
        load_model(croptype_config, 'croptype_model'),
        filtersettings={'kernelsize': 7, 'conf_threshold': 0.8},
        maskdata=cropland_mask,
        encoder=croptype_encoder)

    # Predict crop types, put masked regions to nodata
    croptypes, _ = croptype_classifier.predict(
        stacked_features,
        nodatavalue=0)

    # Hand the prediction back as a DataArray on the input x/y coordinates
    return xr.DataArray(
        croptypes,
        coords=[inarr.coords["x"], inarr.coords["y"]],
        dims=["x", "y"])


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """Classify the contents of ``cube`` and return them as a new cube.

    Args:
        cube: Data cube whose xarray.DataArray is passed to ``classify``.
        context: Mapping forwarded to ``classify`` as keyword arguments.

    Returns:
        An ``XarrayDataCube`` wrapping the prediction from ``classify``.
    """
    # Unwrap the DataArray, run the two-stage classification on it and
    # wrap the prediction in an OpenEO datacube again
    return XarrayDataCube(classify(cube.get_array(), **context))


def load_cropclass_classify_udf() -> str:
    """Return the source code of this module as a string.

    The file is opened read-only, so it can also be read from a location
    that is not writable, and with an explicit UTF-8 encoding to match the
    coding declaration at the top of the file.

    Returns:
        The full text of the file this function is defined in.
    """
    import os

    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()
