# -*- coding: utf-8 -*-
# Uncomment the import only for coding support
from openeo.udf import XarrayDataCube
from typing import Dict
import xarray as xr
import numpy as np
import threading
import functools
from pathlib import Path
import sys
import os

# Unless the IGNORE_KRISTOF_DEPENDENCIES environment variable is set to
# something other than 'NO', put the shared dependency directory at the
# front of the import path so the lazily imported packages resolve there.
if os.environ.get('IGNORE_KRISTOF_DEPENDENCIES', 'NO') == 'NO':
    sys.path.insert(0, r'/data/users/Public/kristofvt/python/dep')

# Thread-local storage object; not referenced elsewhere in the code shown here.
_threadlocal = threading.local()

# Defaults used by `classify`: directory holding the trained models, the
# tag (sub-path) of the model to load, and the extra dependency path that
# is prepended to `sys.path` before the heavy imports.
DEFAULTMODELDIR = Path('/data/EEA_HRL_VLCC/data/inference/crop_type/models/')  # NOQA
DEFAULTMODELTAG = 'model_oldDs_noSAR_noMETEO'
DEFAULT_HRL_DEP = '/data/users/Public/kristofvt/python/hrl_dep/'


def _setup_logging():
    """Bind loguru's logger object to the module-level name ``logger``.

    The import is done lazily so that loguru is only required once the
    UDF actually runs.
    """
    global logger
    import loguru
    logger = loguru.logger


def get_preproc_settings():
    """Return the preprocessing settings for each supported sensor.

    Returns:
        dict: maps the sensor keys ``'S2'``, ``'S1'`` and ``'METEO'`` to a
            dict with the target ``dtype``, the ``pre_func`` callable to
            apply and the ``interpolate`` flag.
    """
    from cropclass.features import divide_10000, divide_100, to_db

    # (sensor key, pre-processing function, interpolate flag)
    sensor_specs = (
        ('S2', divide_10000, True),
        ('S1', to_db, True),
        ('METEO', divide_100, False),  # Interpolation should not be needed
    )

    return {
        sensor: {
            'dtype': np.float32,
            'pre_func': pre_func,
            'interpolate': interpolate,
        }
        for sensor, pre_func, interpolate in sensor_specs
    }


def add_dem_features(features, dem_data):
    """Merge DEM altitude and slope layers into ``features``.

    Args:
        features: feature container exposing a ``merge`` method.
        dem_data: array with the DEM values. It is modified in place:
            values below -10000 are overwritten with NaN.

    Returns:
        The result of merging ``features`` with the two DEM layers
        ``'DEM-alt-20m'`` and ``'DEM-slo-20m'``.
    """
    from satio.utils import dem_attrs
    from satio.features import Features

    # Flag implausibly low values as missing (in place on the input array)
    invalid = dem_data < -10000
    dem_data[invalid] = np.nan

    # Only the first of the two outputs of dem_attrs is used
    slope, _unused = dem_attrs(dem_data)

    stacked = np.array([dem_data, slope])
    layer_names = ['DEM-alt-20m', 'DEM-slo-20m']

    return features.merge(Features(stacked, layer_names))


@functools.lru_cache(maxsize=25)
def load_croptype_model(modeldir: str, modeltag: str):
    """Load a crop type model and the class names it predicts.

    Results are memoised per ``(modeldir, modeltag)`` pair (up to 25
    entries), so repeated calls with the same arguments reuse the model.

    Args:
        modeldir (str): directory containing the models.
        modeltag (str): tag of the model, appended to ``modeldir``.

    Returns:
        tuple: ``(model, class_names)``.
    """
    from vito_crop_classification.model import Model

    # The model location is the tag resolved inside the model directory
    model_path = Path(modeldir) / modeltag
    model = Model.load(mdl_f=model_path)

    return model, model.get_class_names()


def _make_prediction(features, modeldir, modeltag, worldcovermask=True):
    """Run the crop type model on ``features`` and byte-encode the result.

    Relies on the module-level ``logger`` having been set up beforehand.

    Args:
        features: feature container exposing ``df`` (the model input
            dataframe) and ``data`` (whose 2nd and 3rd dimensions give the
            output patch shape).
        modeldir: directory containing the models.
        modeltag: tag of the model to load.
        worldcovermask (bool, optional): if True, labels are set to 255
            where the probability is null or where the
            ``'WORLDCOVER-label-10m'`` column is not 30 or 40.
            Defaults to True.

    Returns:
        tuple: ``(prediction, probability)``, two uint8 arrays of shape
            ``(1, *features.data.shape[1:3])``. Labels are the encoded
            class indices, with 255 for pixels without a prediction.
    """
    from sklearn import preprocessing
    from vito_crop_classification import run_prediction

    model, classes = load_croptype_model(modeldir, modeltag)

    # Encoder translating class names into integer codes
    logger.info('Getting label encoder ...')
    encoder = preprocessing.LabelEncoder()
    encoder.fit(classes)

    # TODO: remove
    logger.info(f'Input columns before prediction: {features.df.columns}')

    # Run the predictor over the new data
    logger.info('Running Predictor.')
    preds = run_prediction(
        df=features.df,
        model=model,
        patch_smoothing=True,
        patch_shape=(features.data.shape[1], features.data.shape[2]),
        transform=True
    )

    labels = preds['prediction_name'].values
    probability = preds['probability'].values

    # Entries carrying the string 'NaN' have no prediction: encode the
    # others through the label encoder and flag these with 255.
    is_nan = labels == 'NaN'
    has_label = ~is_nan
    labels[has_label] = encoder.transform(labels[has_label])
    labels[is_nan] = 255
    labels = labels.astype(np.uint8)

    # Scale by 100 and add 0.5 so that truncation to uint8 rounds to the
    # closest integer (0-100 assuming probabilities lie in [0, 1]).
    # NOTE(review): null probabilities are cast to uint8 as-is — confirm
    # that the resulting value is acceptable.
    probability = (probability * 100.0 + 0.5).astype(np.uint8)

    if worldcovermask:
        logger.info('Applying WorldCover mask.')
        worldcoverlabel = features.df['WORLDCOVER-label-10m']
        masked = (preds['probability'].isnull()
                  | ~worldcoverlabel.isin([30, 40]))
        # Only the labels are masked; the probabilities are left untouched
        labels[masked] = 255

    out_shape = (1, *features.data.shape[1:3])
    return labels.reshape(out_shape), probability.reshape(out_shape)


def compute_features(inarr: xr.DataArray,
                     startdate: str,
                     enddate: str,
                     n_tsteps: int = None,
                     worldcovermask: bool = False,
                     **kwargs) -> xr.DataArray:
    """Compute the classification features from the input array.

    Args:
        inarr (xr.DataArray): input array with dimensions ``bands``,
            ``t``, ``x`` and ``y``. A ``'DEM'`` band is required, and a
            ``'MAP'`` band as well when ``worldcovermask`` is True.
        startdate (str): start date passed to the features computer.
        enddate (str): end date passed to the features computer.
        n_tsteps (int, optional): passed to ``get_feature_settings``.
        worldcovermask (bool, optional): whether to add the
            ``'WORLDCOVER-label-10m'`` feature taken from the ``'MAP'``
            band. Defaults to False.
        **kwargs: accepted but not used.

    Returns:
        The merged features object.
        NOTE(review): this looks like a satio ``Features`` object rather
        than the annotated ``xr.DataArray`` — confirm and fix the
        annotation.
    """
    from cropclass.features import FeaturesComputer
    from cropclass.features.settings import get_feature_settings
    from satio.features import Features

    # This function can be called on its own, in which case the
    # module-level logger may not have been set up yet.
    if 'logger' not in globals():
        _setup_logging()

    # Enforce a known dimension order
    inarr = inarr.transpose('bands', 't', 'x', 'y')

    # Initialize features computer
    feature_settings = get_feature_settings(n_tsteps)
    fc = FeaturesComputer(
        feature_settings,
        startdate,
        enddate,
        preprocessing_settings=get_preproc_settings()
    )

    band_names = list(inarr.bands.values)

    def _bands_containing(token):
        """Return the band names that contain ``token``."""
        return [name for name in band_names if token in name]

    # S2 is always present in the sensor arrays.
    # TODO: add other sensors
    sensor_arrays = {'S2': inarr.sel(bands=_bands_containing('B'))}

    # S1 and METEO are only added when matching bands exist. For METEO
    # only the temperature-mean band is useful now.
    optional_sensors = (('S1', 'V'), ('METEO', 'temperature-mean'))
    for sensor, token in optional_sensors:
        selected = _bands_containing(token)
        if selected:
            sensor_arrays[sensor] = inarr.sel(bands=selected)

    # Compute features
    features = fc.get_features(sensor_arrays)

    # Add DEM, reduced over time with a maximum
    dem_values = inarr.sel(bands='DEM').max(dim='t').values
    features = add_dem_features(features, dem_values)

    # Add worldcover, reduced over time with a maximum
    if worldcovermask:
        worldcover_values = inarr.sel(bands='MAP').max(dim='t').values
        features = features.merge(
            Features(worldcover_values, ['WORLDCOVER-label-10m']))

    return features


def classify(
        inarr: xr.DataArray,
        startdate: str,
        enddate: str,
        modeltag: str = DEFAULTMODELTAG,
        modeldir: Path = DEFAULTMODELDIR,
        n_tsteps: int = None,
        worldcovermask: bool = False,
        custom_dependency_path: str = DEFAULT_HRL_DEP) -> xr.DataArray:
    """Main inference classification UDF for OpenEO workflow.
    This function takes an input xarray fetched through OpenEO
    and returns a result xarray back to OpenEO for saving.

    Args:
        inarr (xr.DataArray): preprocessed
            input arrays for the different sensors. Must have the
            dimensions ``bands``, ``t``, ``x`` and ``y``.
        startdate (str): start date for the feature computation
            (yyyy-mm-dd)
        enddate (str): end date for the feature computation
            (yyyy-mm-dd)
        modeltag (str, optional): tag of the model to load, resolved
            inside ``modeldir``.
        modeldir (Path, optional): directory containing the models.
        n_tsteps (int, optional): Number of t_steps to extract
            from timeseries
        worldcovermask (bool, optional): Whether or not to apply
            mask based on worldcover 30 and 40 values (grass and crop).
            Pixels outside this mask will get value 255 in
            output product. Defaults to False.
        custom_dependency_path (str, optional): optional path to be
            added in front of the python path. Pass None to leave the
            python path untouched.

    Returns:
        xr.DataArray: resulting classification xarray with the bands
            ``'cropclass_label'`` and ``'probability'``, in the dimension
            order of ``inarr`` without the ``t`` dimension.
    """
    if custom_dependency_path is not None:
        dep_path = str(custom_dependency_path)
        # Only prepend when the path is not already first, so repeated
        # calls do not keep growing sys.path with duplicate entries.
        if not sys.path or sys.path[0] != dep_path:
            sys.path.insert(0, dep_path)

    # Imported only after the python path has been adjusted
    import cropclass
    import vito_crop_classification
    from loguru import logger
    logger.info(f'Using cropclass library from: {cropclass.__file__}')
    logger.info(('Using vito_crop_classification '
                 f'library from: {vito_crop_classification.__file__}'))

    # Store the original dimension order for later
    orig_dims = list(inarr.dims)
    orig_dims.remove('t')  # Time dimension will be removed after the inference

    features = compute_features(inarr, startdate, enddate, n_tsteps, worldcovermask)

    # Make the crop type prediction
    logger.info(f'Using model: {modeltag}')
    prediction, probability = _make_prediction(
        features,
        modeldir,
        modeltag=modeltag,
        worldcovermask=worldcovermask
    )

    # Stack the (1, x, y) prediction and probability arrays along their
    # leading axis. Unlike squeeze(), this keeps spatial dimensions of
    # length 1 intact, so one-pixel-wide chunks still yield (2, x, y).
    predicted_data = np.concatenate([prediction, probability], axis=0)

    # Finally transform result to DataArray
    result_da = xr.DataArray(predicted_data, coords=[
        ['cropclass_label', 'probability'],
        inarr.coords["x"],
        inarr.coords["y"]], dims=["bands", "x", "y"])

    # And make sure we revert back to original dimension order
    result_da = result_da.transpose(*orig_dims)

    return result_da


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """OpenEO UDF entry point: classify the data held by ``cube``.

    Args:
        cube (XarrayDataCube): input datacube.
        context (Dict): keyword arguments forwarded to ``classify``
            (it must at least provide ``startdate`` and ``enddate``).

    Returns:
        XarrayDataCube: datacube wrapping the classification result.
    """
    _setup_logging()

    # Unwrap the xarray.DataArray, classify it and wrap the result again
    result = classify(cube.get_array(), **context)
    return XarrayDataCube(result)


def load_cropclass_udf() -> str:
    """Return the source code of this module as a string.

    The file is opened read-only, so this also works when the module
    lives in a location without write permission, and it is decoded as
    UTF-8 in line with the coding declaration at the top of the file.

    Returns:
        str: full text of this source file.
    """
    # resolve() follows symlinks to the real location of the file
    return Path(__file__).resolve().read_text(encoding='utf-8')
