# -*- coding: utf-8 -*-
# Uncomment the import only for coding support
from openeo.udf import XarrayDataCube
from typing import Dict
import xarray as xr
import numpy as np
import threading
import functools
from pathlib import Path
import sys
from typing import List

# Thread-local storage object.
# NOTE(review): not referenced by any code visible in this section of the
# file; possibly a leftover — confirm before removing.
_threadlocal = threading.local()

# Defaults used by classify(): the model directory and the model tag (joined
# as modeldir / modeltag by load_croptype_model() to locate the model), and
# an extra dependency path that classify() puts at the front of sys.path.
DEFAULTMODELDIR = Path('/vitodata/EEA_HRL_VLCC/data/inference/crop_type/models/')  # NOQA
DEFAULTMODELTAG = '20221102T184737-transformer_optical_dem'
DEFAULT_HRL_DEP = '/data/users/Public/couchard/hrl_dep/'

# Name of the processing option (forwarded from classify()'s
# **processing_opts to _compute_features) that controls Sentinel-1 handling:
# when truthy, the S1 bands are used as-is; otherwise they are converted with
# 10 ** ((value / 1000 - 30) / 10).
# NOTE(review): both the constant name and the option name read as if a
# truthy value should trigger that conversion — confirm the intended polarity.
SENTINEL1_AS_SHORTS = "rescale_s1"


def _setup_logging():
    """Bind loguru's logger to the module-global name ``logger``.

    Other functions in this module log through the global ``logger``; it
    only exists once this function (or an equivalent import) has run.
    """
    global logger
    import loguru
    logger = loguru.logger


def get_preproc_settings():
    """Return the per-sensor preprocessing settings for the features computer.

    The result maps a sensor key ('S2', 'S1', 'METEO') to a settings dict
    holding the working ``dtype``, an ``interpolate`` flag and, depending on
    the sensor, a ``pre_func`` and/or ``post_func`` callable taken from
    ``cropclass.features``.
    """
    from cropclass.features import divide_10000, divide_100, to_db

    optical = dict(dtype=np.float32, pre_func=divide_10000, interpolate=True)
    radar = dict(dtype=np.float32, interpolate=True, post_func=to_db)
    # Interpolation should not be needed for the meteo input.
    meteo = dict(dtype=np.float32, pre_func=divide_100, interpolate=False)

    settings = {}
    settings['S2'] = optical
    settings['S1'] = radar
    settings['METEO'] = meteo
    return settings


def add_dem_features(features, dem_data):
    """Append DEM altitude and slope features to ``features``.

    Args:
        features: satio ``Features`` object to extend (via ``.merge``).
        dem_data: DEM values per pixel (the caller in this module passes an
            (x, y) array). Values below -10000 are treated as nodata and
            replaced by NaN. The caller's array is left untouched.

    Returns:
        The result of ``features.merge`` with the two extra features
        'DEM-alt-20m' (altitude) and 'DEM-slo-20m' (slope from ``dem_attrs``).
    """
    from satio.utils import dem_attrs
    from satio.features import Features

    # Work on a floating-point copy: NaN cannot be stored in an integer
    # array, and masking in place would silently modify the caller's data.
    dem_data = np.asarray(dem_data)
    if np.issubdtype(dem_data.dtype, np.floating):
        dem_data = dem_data.copy()
    else:
        dem_data = dem_data.astype(np.float32)

    dem_data[dem_data < -10000] = np.nan

    slope, _ = dem_attrs(dem_data)

    # Stack as (n_features, ...) so each entry matches one feature name.
    dem_arr = np.array([dem_data, slope])
    dem_feats = Features(dem_arr, ['DEM-alt-20m', 'DEM-slo-20m'])

    return features.merge(dem_feats)


@functools.lru_cache(maxsize=25)
def load_croptype_model(modeldir: 'str | Path', modeltag: str):
    """Load a crop type model and its class names.

    Results are cached (LRU, up to 25 entries) per (modeldir, modeltag), so
    repeated calls in the same process reuse the loaded model.

    Args:
        modeldir: local model directory (str or pathlib.Path), or a URL
            starting with 'http', in which case the model is first fetched
            with ``cropclass.utils.download_and_unpack``.
        modeltag: tag of the model; the model is loaded from
            ``Path(modeldir) / modeltag``.

    Returns:
        Tuple ``(model, class_names)``.
    """
    from vito_crop_classification.model import Model
    from cropclass.utils import download_and_unpack

    # classify() passes DEFAULTMODELDIR, a pathlib.Path, which has no
    # startswith(); compare on the string form so both str and Path work.
    if str(modeldir).startswith('http'):
        modeldir, modeltag = download_and_unpack(modeldir, modeltag)

    # load a model from its tag
    model = Model.load(mdl_f=Path(modeldir) / modeltag)

    # Get classes that will be mapped
    class_names = model.get_class_names()

    return model, class_names


def _make_prediction(features, modeldir, modeltag, worldcovermask=True,
                     ignore_classes=None, all_probabilities=False):
    """Run the crop type model on a feature set and return raster outputs.

    Rows of ``features.df`` (one per pixel) containing any NaN feature are
    not passed to the model. They get the label ``'NaN'`` (for later
    translation to raster values) and a probability of 0.

    Args:
        features: satio ``Features``-like object exposing ``.df`` (one row
            per pixel) and ``.data`` (feature axis first, then the two
            spatial axes).
        modeldir: model directory or 'http' URL (see load_croptype_model).
        modeltag: tag of the model to load.
        worldcovermask: if True, set the prediction to ``"NoCrop"`` for
            pixels whose 'WORLDCOVER-label-10m' value is not 30 or 40, or
            whose probability is missing.
        ignore_classes: optional classes forwarded to ``run_prediction``.
        all_probabilities: if True, return the 'probabilities' column as
            the second output instead of the winning-class probability.

    Returns:
        Tuple ``(prediction, probability)``. ``prediction`` has shape
        ``(1, *spatial)``. ``probability`` holds integer percentages (0-100)
        as uint8, with shape ``(1, *spatial)``, or with an inferred
        leading axis when ``all_probabilities`` is True.
    """
    import torch
    from vito_crop_classification import run_prediction

    model, _ = load_croptype_model(modeldir, modeltag)

    df = features.df
    spatial_shape = features.data.shape[1:3]
    nan_indices = df.isna().any(axis=1)

    # initialize predictor over new data
    logger.info('Running Predictor.')
    with torch.no_grad():
        pred_columns = ['prediction_id', 'prediction_name', 'probability',
                        'probabilities', 'embedding']
        df.loc[~nan_indices, pred_columns] = run_prediction(
            df=df[~nan_indices],
            model=model,
            patch_smoothing=False,
            transform=True,
            ignore_classes=ignore_classes
        ).set_index(nan_indices[~nan_indices].index)

    # Set NaN strings where we don't have a prediction, for later
    # translation to raster values. Each assignment only targets its own
    # column(s) so the 'NaN' label is not overwritten with 0 afterwards.
    df.loc[nan_indices, ['prediction_id', 'prediction_name']] = 'NaN'
    df.loc[nan_indices, 'probability'] = 0
    df.loc[nan_indices, 'probabilities'] = 0  # TODO: CHECK!

    prediction = df['prediction_id'].values

    # Round to the closest integer, giving a value between 0 and 100
    # encoded in a byte.
    probability = ((df['probability'].values * 100.0) + 0.5).astype(np.uint8)

    if worldcovermask:
        logger.info('Applying WorldCover mask.')
        worldcoverlabel = df['WORLDCOVER-label-10m']
        outside_mask = (df['probability'].isnull()
                        | ~worldcoverlabel.isin([30, 40]))
        prediction[outside_mask.to_numpy()] = "NoCrop"

    if all_probabilities:
        # Only converted on request, so the default path does not depend on
        # the layout of the 'probabilities' column.
        probabilities = (
            (df['probabilities'].values * 100.0) + 0.5).astype(np.uint8)
        # TODO: per-class probabilities are not yet stacked into an explicit
        # (n_classes, *spatial) array; the leading axis is inferred here
        # (the previous fixed size of 0 could never match non-empty data).
        return (prediction.reshape((1, *spatial_shape)),
                probabilities.reshape((-1, *spatial_shape)))

    return (prediction.reshape((1, *spatial_shape)),
            probability.reshape((1, *spatial_shape)))


def _compute_features(inarr: xr.DataArray,
                      startdate: str,
                      enddate: str,
                      n_tsteps: int = None,
                      worldcovermask: bool = False,
                      **kwargs) -> 'satio.features.Features':
    """Compute the classification features from an openEO input cube.

    Bands are assigned to sensors by name:
    - names containing 'B' are treated as S2;
    - names containing 'V' are treated as S1;
    - names containing 'temperature_mean' are treated as METEO.

    A 'DEM' band is always required. A 'MAP' (WorldCover) band is required
    when ``worldcovermask`` is True.

    Args:
        inarr: input array with dimensions 'bands', 't', 'x' and 'y'
            (any order; it is transposed to ('bands', 't', 'x', 'y')).
        startdate: start date for the feature computation (yyyy-mm-dd).
        enddate: end date for the feature computation (yyyy-mm-dd).
        n_tsteps: number of time steps passed to the feature settings.
        worldcovermask: if True, add the 'WORLDCOVER-label-10m' feature
            (maximum of the 'MAP' band over time).
        **kwargs: processing options. Only ``SENTINEL1_AS_SHORTS``
            ('rescale_s1') is read: when truthy, S1 bands are used as-is;
            otherwise they are converted with 10 ** ((v / 1000 - 30) / 10).

    Returns:
        satio ``Features`` with the computed sensor features, the DEM
        altitude/slope features and, optionally, the WorldCover label.
    """
    from cropclass.features import FeaturesComputer
    from cropclass.features.settings import get_feature_settings
    from satio.features import Features

    # This function can be called on its own, so make sure the
    # module-global logger exists.
    try:
        logger
    except NameError:
        _setup_logging()

    # Make sure we have a controlled dimension order
    inarr = inarr.transpose('bands', 't', 'x', 'y')

    # Initialize features computer
    fc = FeaturesComputer(
        get_feature_settings(n_tsteps),
        startdate,
        enddate,
        preprocessing_settings=get_preproc_settings()
    )

    # Identify the S1, S2 and METEO bands
    bands = list(inarr.bands.values)
    S2_bands = [b for b in bands if 'B' in b]
    S1_bands = [b for b in bands if 'V' in b]

    # Only the temperature-mean band is useful now
    METEO_bands = [b for b in bands if 'temperature_mean' in b]

    # Construct sensor-specific arrays.
    sensor_arrays = {'S2': inarr.sel(bands=S2_bands)}

    # Add S1 features, if they exist
    if S1_bands:
        s1_arr = inarr.sel(bands=S1_bands)
        if kwargs.get(SENTINEL1_AS_SHORTS, False):
            sensor_arrays['S1'] = s1_arr
        else:
            sensor_arrays['S1'] = 10**(((s1_arr / 1000.0) - 30.0) / 10.0)

    # Add Meteo features, if they exist
    if METEO_bands:
        sensor_arrays['METEO'] = inarr.sel(bands=METEO_bands)

    # Compute features
    features = fc.get_features(sensor_arrays)

    # Add DEM
    features = add_dem_features(features, inarr.sel(
        bands='DEM').max(dim='t').values)

    # Add worldcover
    if worldcovermask:
        worldcover = inarr.sel(bands='MAP').max(dim='t').values
        # Give the single feature an explicit leading feature axis, matching
        # the (n_features, x, y) layout used for the DEM features.
        worldcover_feats = Features(worldcover[np.newaxis, ...],
                                    ['WORLDCOVER-label-10m'])
        features = features.merge(worldcover_feats)

    return features


def classify(
        inarr: xr.DataArray,
        startdate: str,
        enddate: str,
        modeltag: str = DEFAULTMODELTAG,
        modeldir: Path = DEFAULTMODELDIR,
        n_tsteps: int = None,
        worldcovermask: bool = False,
        ignore_classes: List[str] = None,
        custom_dependency_path: str = DEFAULT_HRL_DEP,
        **processing_opts) -> xr.DataArray:
    """Main inference classification UDF for the openEO workflow.

    Takes an input xarray fetched through openEO, computes the features,
    runs the crop type model and returns a result xarray back to openEO
    for saving.

    Args:
        inarr (xr.DataArray): preprocessed input array with dimensions
            'bands', 't', 'x' and 'y' (in any order).
        startdate (str): start date for the feature computation
            (yyyy-mm-dd).
        enddate (str): end date for the feature computation
            (yyyy-mm-dd).
        modeltag (str, optional): tag of the model to load.
        modeldir (Path or str, optional): directory holding the model, or
            an 'http' URL to download it from.
        n_tsteps (int, optional): number of time steps to extract from the
            time series.
        worldcovermask (bool, optional): whether to apply a mask based on
            the WorldCover 'MAP' band values 30 and 40 (grass and crop).
            Pixels outside this mask get the prediction "NoCrop" (per the
            earlier documentation this ends up as 255 in the output
            product after layer translation). Defaults to False.
        ignore_classes (List[str], optional): classes to ignore during
            prediction.
        custom_dependency_path (str, optional): path placed at the front
            of sys.path before the classification libraries are imported.
        **processing_opts: extra options forwarded to the feature
            computation (e.g. 'rescale_s1').

    Returns:
        xr.DataArray: classification result with a 'bands' dimension
        holding 'croptype' and 'probability', plus the spatial dimensions
        in the input's original order (the 't' dimension is dropped).
    """
    if custom_dependency_path is not None:
        dep_path = str(custom_dependency_path)
        # Move the path to the front rather than inserting it again, so
        # repeated calls in the same worker do not keep growing sys.path.
        sys.path[:] = [dep_path] + [p for p in sys.path if p != dep_path]
    import cropclass
    import vito_crop_classification
    from loguru import logger
    logger.info(f'Using cropclass library from: {cropclass.__file__}')
    logger.info(('Using vito_crop_classification '
                 f'library from: {vito_crop_classification.__file__}'))

    # Store the original dimension order for later
    orig_dims = list(inarr.dims)
    orig_dims.remove('t')  # Time dimension will be removed after the inference

    features = _compute_features(
        inarr, startdate, enddate, n_tsteps, worldcovermask, **processing_opts)

    # Make the crop type prediction
    logger.info(f'Using model: {modeltag}')
    prediction, probability = _make_prediction(
        features,
        modeldir,
        modeltag=modeltag,
        worldcovermask=worldcovermask,
        ignore_classes=ignore_classes
    )

    # Zip the prediction and probability arrays together. Index the single
    # leading band explicitly: squeeze() would also drop a spatial axis of
    # size 1 and break the DataArray construction below.
    predicted_data = np.array([prediction[0], probability[0]])

    # Finally transform result to DataArray
    result_da = xr.DataArray(predicted_data, coords=[
        ['croptype', 'probability'],
        inarr.coords["x"],
        inarr.coords["y"]], dims=["bands", "x", "y"])

    # And make sure we revert back to original dimension order
    result_da = result_da.transpose(*orig_dims)

    return result_da


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """openEO UDF entry point: classify the cube and wrap the result.

    Args:
        cube: input data cube; its array is passed to ``classify``.
        context: UDF context whose items are forwarded as keyword
            arguments to ``classify`` (e.g. startdate, enddate, modeltag).

    Returns:
        XarrayDataCube wrapping ``translate_layer`` applied to the
        classification result.
    """
    # Put the bundled environments (relative to the working directory) at
    # the front of sys.path, ending with 'tmp/venv' first and
    # 'tmp/venv_static' second. Existing entries are moved instead of
    # re-inserted so repeated invocations do not keep growing sys.path.
    for venv_path in ('tmp/venv_static', 'tmp/venv'):
        sys.path[:] = [venv_path] + [p for p in sys.path if p != venv_path]

    from cropclass.postprocessing.layer_format import translate_layer

    _setup_logging()

    # Extract xarray.DataArray from the cube
    inarr = cube.get_array()

    # Run the two-stage classification
    predictions = classify(inarr, **context)

    translated_predictions = translate_layer(predictions)

    # Wrap result in an OpenEO datacube
    return XarrayDataCube(translated_predictions)


def load_cropclass_udf() -> str:
    """Return the source code of this module, e.g. to submit it as a UDF.

    The file is opened read-only: read/write mode ('r+') is not needed and
    fails when the installed file is not writable. It is decoded as UTF-8,
    matching the module's coding declaration.
    """
    import os
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()
