# -*- coding: utf-8 -*-
# XarrayDataCube is used at runtime by apply_datacube (signature and result wrapping)
from openeo.udf import XarrayDataCube
from typing import Dict
import xarray as xr
import numpy as np
import threading
import functools
import datetime
from pathlib import Path
import sys
import pandas as pd

from typing import List

# Thread-local storage object.
# NOTE(review): appears unused in this module as shown — confirm before removing.
_threadlocal = threading.local()

# Default model directory and model tag used by `classify` when the caller
# does not pass `modeldir` / `modeltag`.
DEFAULTMODELDIR = Path('/vitodata/EEA_HRL_VLCC/data/inference/crop_type/models/')  # NOQA
DEFAULTMODELTAG = '20221102T184737-transformer_optical_dem'


def _setup_logging():
    """Bind the module-level ``logger``, preferring loguru.

    When loguru cannot be imported, a standard-library logger named after
    this module is used instead.
    """
    global logger
    try:
        from loguru import logger as _loguru_logger
    except ImportError:
        import logging
        logger = logging.getLogger(__name__)
    else:
        logger = _loguru_logger


def scaled_s1_to_pwr(scaled_db):
    """Remove the scaling from S1 data and convert it to power.

    S1 data fetched from the openEO preprocessing workflow is scaled; it is
    converted back to dB with ``20 * log10(x) - 83`` and then to power with
    ``cropclass.features.to_pwr``. Any non-finite value in the converted
    result is replaced by NaN.

    Args:
        scaled_db: array of scaled S1 backscatter values.

    Returns:
        Array of backscatter in power, with NaN where the result was not
        finite.
    """
    # The import must come after the docstring, otherwise the string above
    # is not the function's __doc__.
    from cropclass.features import to_pwr

    result = to_pwr(20 * np.log10(scaled_db) - 83)
    result[~np.isfinite(result)] = np.nan

    return result


def get_preproc_settings():
    """Build the per-sensor preprocessing settings for the FeaturesComputer.

    Every sensor is read as float32 with 65535 (the UINT16 nodata value in
    OpenEO) mapped to NaN. Sensor-specific entries:

    * S2: pre_func ``divide_10000``, interpolated.
    * S1: bands VV/VH, pre_func ``scaled_s1_to_pwr``, post_func ``to_db``,
      interpolated.
    * METEO: pre_func ``divide_100``, not interpolated.

    Returns:
        dict: sensor name -> settings dict, freshly built on every call.
    """
    from cropclass.features import divide_100, divide_10000, to_db  # NOQA

    # Entries common to all sensors; unpacked first in each sensor dict so
    # they precede the sensor-specific keys.
    shared = {
        'dtype': np.float32,
        'src_nodata': 65535,  # Nodata value for UINT16 in OpenEO
        'dst_nodata': np.nan,
    }

    s2_settings = {**shared,
                   'pre_func': divide_10000,
                   'interpolate': True}
    s1_settings = {**shared,
                   'bands': ['VV', 'VH'],
                   'pre_func': scaled_s1_to_pwr,
                   'post_func': to_db,
                   'interpolate': True}
    meteo_settings = {**shared,
                      'pre_func': divide_100,
                      'interpolate': False}  # Should not be needed

    return {'S2': s2_settings, 'S1': s1_settings, 'METEO': meteo_settings}


def convert_meteojson_to_xarray(xrarray_frame, meteo_json, startdate, enddate):
    """
    Convert meteo JSON data into an xarray shaped like ``xrarray_frame``.

    All data in ``meteo_json`` is averaged over axes 1-3 (nodata values
    masked out), giving one value per timestamp. Only timestamps in
    [startdate, enddate) are kept, and the resulting series is tiled over
    the x/y grid of ``xrarray_frame``.

    @param xrarray_frame: xarray frame that provides the target coords.
        Assumed ordered (bands, t, x, y) with a single band and as many
        time steps as selected meteo dates (TODO confirm), because the
        tiled data must match its shape.
    @param meteo_json: input json with the meteo data from agera5 as
        extracted by preprocessing_agera5 script. Reads "data",
        "attrs.nodata", "coords.t.data" (timestamps formatted
        "%Y-%m-%d %H:%M:%S") and "coords.bands.data".
    @param startdate: first date (inclusive) of the data that is needed
    @param enddate: end date (exclusive) of the data that is needed
    @return: xarray with the coords of ``xrarray_frame``, the band
        coordinate taken from ``meteo_json``, and NaN for timestamps whose
        values were all nodata.
    """
    nodata = meteo_json.get("attrs").get("nodata")

    # Aggregate per timestamp while ignoring nodata values. `.filled()`
    # returns a new array, so its result must be assigned. Fully-masked
    # timestamps become NaN explicitly instead of reaching xarray as a
    # masked array.
    masked_meteo = np.ma.masked_values(meteo_json.get("data"), nodata)
    aggregate_meteo = masked_meteo.mean(axis=(1, 2, 3)).filled(np.nan)

    # Keep only timestamps within [startdate, enddate)
    startdate = pd.to_datetime(startdate)
    enddate = pd.to_datetime(enddate)
    dates = [datetime.datetime.strptime(date, "%Y-%m-%d %H:%M:%S")
             for date in meteo_json.get("coords").get("t").get("data")]
    in_period = np.array([startdate <= date < enddate for date in dates],
                         dtype=bool)
    aggregate_meteo = aggregate_meteo[in_period]

    # Tile the per-timestamp series to (1, t, x, y)
    n_steps = len(aggregate_meteo)
    tiled_data = np.tile(
        aggregate_meteo.reshape((1, n_steps, 1, 1)),
        (1, 1, len(xrarray_frame['x']), len(xrarray_frame['y'])))

    meteo_array = xr.ones_like(xrarray_frame)
    meteo_array.data = tiled_data
    return meteo_array.assign_coords(
        {'bands': meteo_json.get("coords").get("bands").get("data")})


def add_dem_features(features, dem_data):
    """Merge DEM altitude and slope features into ``features``.

    Values below -10000 in ``dem_data`` are treated as nodata (NaN). The
    remaining altitudes are sanity-checked before the slope is derived with
    ``satio.utils.dem_attrs``.

    Args:
        features: features object exposing ``merge`` (presumably a satio
            ``Features``).
        dem_data: DEM array, presumably 2-D (x, y) — TODO confirm. It is
            not modified.

    Returns:
        The result of ``features.merge`` with the 'DEM-alt-20m' and
        'DEM-slo-20m' features added.

    Raises:
        UnexpectedRangeError: if the highest valid altitude exceeds 8000 or
            the lowest valid altitude is below -50.
    """
    from satio.utils import dem_attrs
    from satio.features import Features
    from cropclass.features.inference import UnexpectedRangeError

    # astype() returns a copy, so the caller's array is left untouched
    altitude = dem_data.astype(np.float32)
    altitude[altitude < -10000] = np.nan

    MAX_ALLOWED_ALTITUDE = 8000
    MIN_ALLOWED_ALTITUDE = -50

    # Reject DEM values outside the plausible altitude range
    highest = np.nanmax(altitude)
    if highest > MAX_ALLOWED_ALTITUDE:
        raise UnexpectedRangeError(
            'Detected unrealistic DEM data: '
            f'{highest} is higher '
            f'than max allowed {MAX_ALLOWED_ALTITUDE}')

    lowest = np.nanmin(altitude)
    if lowest < MIN_ALLOWED_ALTITUDE:
        raise UnexpectedRangeError(
            'Detected unrealistic DEM data: '
            f'{lowest} is lower '
            f'than min allowed {MIN_ALLOWED_ALTITUDE}')

    # dem_attrs expects -9999 as nodata; map that back to NaN in the slope
    filled_altitude = np.where(np.isfinite(altitude), altitude, -9999)
    slope, _ = dem_attrs(filled_altitude)
    slope[slope == -9999] = np.nan

    dem_feats = Features(np.array([altitude, slope]),
                         ['DEM-alt-20m', 'DEM-slo-20m'])
    return features.merge(dem_feats)


@functools.lru_cache(maxsize=25)
def load_croptype_model(modeldir: str, modeltag: str):
    """Load a crop type model together with its class names.

    Results are memoised per (modeldir, modeltag) by ``lru_cache``, so
    repeated calls return the same model object.

    Args:
        modeldir: local model directory, or an http(s) location that is
            first resolved through ``download_and_unpack`` (presumably
            downloading the model locally — TODO confirm).
        modeltag: name of the model inside ``modeldir``.

    Returns:
        tuple: (model, class_names)
    """
    from vito_crop_classification.model import Model
    from cropclass.utils import download_and_unpack

    is_remote = str(modeldir).startswith('http')
    if is_remote:
        modeldir, modeltag = download_and_unpack(modeldir, modeltag)

    model = Model.load(mdl_f=Path(modeldir) / modeltag)
    return model, model.get_class_names()


def _make_prediction(features, modeldir, modeltag,
                     apply_worldcovermask=True, ignore_classes=None):
    """Run the crop type model on every pixel with complete features.

    Rows of ``features.df`` that contain any NaN are not passed to the
    model. They get prediction ``'NaN'``, probability 0 and all-zero class
    probabilities.

    Args:
        features: features object exposing ``.df`` (one row per pixel,
            presumably) and ``.data``, whose axes 1-2 give the output
            spatial shape.
        modeldir: model directory or http location, see
            ``load_croptype_model``.
        modeltag: tag of the model to load.
        apply_worldcovermask: if True, pixels whose
            ``WORLDCOVER-label-10m`` value is not 30 or 40 get prediction
            ``'NoCrop'``.
        ignore_classes: class names forwarded to ``run_prediction``.
            ``None`` (default) means no classes are ignored.

    Returns:
        tuple: (prediction, probability, probabilities, classes)
        - prediction and probability have shape
          ``(1, *features.data.shape[1:3])``;
        - probability holds rounded percentages (0-100) as uint8;
        - probabilities has shape
          ``(len(classes), *features.data.shape[1:3])`` as uint8
          percentages;
        - classes are the model's class names.
    """
    from vito_crop_classification import run_prediction

    # This function may be called on its own, before `classify` has set up
    # the module-level logger.
    try:
        logger
    except NameError:
        _setup_logging()

    # None instead of a shared mutable [] default
    if ignore_classes is None:
        ignore_classes = []

    model, classes = load_croptype_model(modeldir, modeltag)

    df = features.df
    valid_indices = df[~df.isna().any(axis=1)].index
    # Positional row mask. The column reindex below does not touch rows,
    # so the mask stays aligned with `df`.
    is_valid = df.index.isin(valid_indices)
    is_invalid = ~is_valid

    # Reset on a fresh frame instead of in place on a .loc selection. In-place
    # ops on selections can trigger pandas' chained-assignment warnings.
    # The old index is still kept as a column, as before.
    df_filter = df.loc[is_valid].reset_index()

    pred_columns = ['prediction_id', 'prediction_name', 'probability',
                    'probabilities', 'embedding']

    df = df.reindex(columns=[*df.columns, *pred_columns])

    # initialize predictor over new data
    logger.info('Running Predictor.')
    import torch
    with torch.no_grad():
        df.loc[is_valid, pred_columns] = run_prediction(
            df=df_filter,
            model=model,
            patch_smoothing=False,
            transform=True,
            ignore_classes=ignore_classes).set_index(valid_indices)

    # Set NaN strings where we don't have prediction for later
    # translation to raster values
    df.loc[is_invalid, ['prediction_id', 'prediction_name']] = 'NaN'
    df.loc[is_invalid, 'probability'] = 0
    df.loc[is_invalid, 'probabilities'] = (
        df.loc[is_invalid, 'probabilities'].apply(
            lambda _: np.zeros(len(classes), dtype=np.float32)))

    # Get prediction and probability arrays
    prediction = df['prediction_id'].values
    probability = df['probability'].values
    probabilities = df['probabilities'].values

    # Rounding to the closest integer, obtaining a value between
    # 0 and 100 encoded in bytes
    probability = ((probability * 100.0) + 0.5).astype(np.uint8)
    probabilities = np.moveaxis(
        (np.stack(probabilities) * 100.0 + 0.5).astype(np.uint8), 0, 1
    )

    if apply_worldcovermask:
        logger.info('Applying WorldCover mask.')
        worldcoverlabel = features.df['WORLDCOVER-label-10m']
        # Object dtype so the 'NoCrop' string fits even if the ids may have
        # ended up in a numeric column (no 'NaN' placeholder forced object).
        prediction = prediction.astype(object)
        prediction[~worldcoverlabel.isin([30, 40]).to_numpy()] = "NoCrop"

    spatial_shape = features.data.shape[1:3]
    return (
        prediction.reshape(1, *spatial_shape),
        probability.reshape(1, *spatial_shape),
        probabilities.reshape(len(classes), *spatial_shape),
        classes
    )


def _compute_features(inarr: xr.DataArray,
                      startdate: str,
                      enddate: str,
                      n_tsteps: int,
                      apply_worldcovermask: bool = False,
                      segmentation: bool = False,
                      segm_settings: Dict = None,
                      outofrange_toleranceratio: np.float32 = 0.001,
                      **kwargs):
    """Compute the classification features for one input cube.

    Args:
        inarr: input cube with dims 'bands', 't', 'x', 'y' (any order).
            Bands whose name contains 'B' are treated as S2, bands containing
            'V' as S1, and bands containing 'temperature_mean' as METEO.
            A 'DEM' band is required. A 'MAP' (WorldCover) band is also
            required when ``apply_worldcovermask`` is set.
        startdate: start of the processing period.
        enddate: end of the processing period.
        n_tsteps: number of time steps passed to ``get_feature_settings``.
        apply_worldcovermask: also add the 'WORLDCOVER-label-10m' feature.
        segmentation: forwarded to ``FeaturesComputer``.
        segm_settings: forwarded to ``FeaturesComputer``.
        outofrange_toleranceratio: forwarded to ``FeaturesComputer``.
        **kwargs: ``METEO_data`` (optional) is a meteo JSON. When it is not
            None, it is converted with ``convert_meteojson_to_xarray`` and
            appended as an extra band. Other keys are ignored.

    Returns:
        The merged features object (result of ``Features.merge`` via
        ``add_dem_features``), not an ``xr.DataArray``.
    """
    from cropclass.features import FeaturesComputer
    from cropclass.features.settings import get_feature_settings
    from satio.features import Features

    # Fix to see if the logger is defined or not, as the compute_features
    # function can be called alone
    try:
        logger
    except NameError:
        _setup_logging()

    # Make sure we have a controlled dimension order
    inarr = inarr.transpose('bands', 't', 'x', 'y')

    # Add meteo band if provided; the first band only serves as a
    # coordinate template for the meteo array.
    meteo_json = kwargs.get('METEO_data')
    if meteo_json is not None:
        first_band = list(inarr.bands.values)[0]
        meteo_array = convert_meteojson_to_xarray(
            inarr.sel(bands=[first_band]), meteo_json, startdate, enddate)
        inarr = xr.concat([inarr, meteo_array], dim="bands")

    # Initialize features computer
    fc = FeaturesComputer(
        get_feature_settings(n_tsteps),
        startdate,
        enddate,
        preprocessing_settings=get_preproc_settings(),
        segmentation=segmentation,
        segm_settings=segm_settings,
        outofrange_toleranceratio=outofrange_toleranceratio
    )

    # Identify the S1, S2 and METEO bands
    bands = list(inarr.bands.values)
    S2_bands = [b for b in bands if 'B' in b]
    S1_bands = [b for b in bands if 'V' in b]

    # Only the temperature-mean band is useful now
    METEO_bands = [b for b in bands if 'temperature_mean' in b]

    # Construct sensor-specific arrays.
    sensor_arrays = {'S2': inarr.sel(bands=S2_bands)}

    # Add S1 features, if exists
    # at this point they will still be scaled!
    if S1_bands:
        sensor_arrays['S1'] = inarr.sel(bands=S1_bands)

    # Add Meteo features, if exists
    if METEO_bands:
        sensor_arrays['METEO'] = inarr.sel(bands=METEO_bands)

    # Compute features
    features = fc.get_features(sensor_arrays)

    # Add DEM (collapsed over time with max)
    features = add_dem_features(
        features, inarr.sel(bands='DEM').max(dim='t').values)

    # Add worldcover
    if apply_worldcovermask:
        worldcover_feats = Features(inarr.sel(bands='MAP').max(
            dim='t').values, ['WORLDCOVER-label-10m'])
        features = features.merge(worldcover_feats)

    return features


def classify(
        inarr: xr.DataArray,
        start_month: int,
        end_month: int,
        year: int,
        modeltag: str = DEFAULTMODELTAG,
        modeldir: Path = DEFAULTMODELDIR,
        apply_worldcovermask: bool = False,
        ignore_classes: List[str] = None,
        all_probabilities: bool = False,
        **processing_opts) -> xr.DataArray:
    """Main inference classification UDF for OpenEO workflow.

    This function takes an input xarray fetched through OpenEO, computes
    the features, runs the crop type model and returns a result xarray back
    to OpenEO for saving.

    Args:
        inarr (xr.DataArray): preprocessed input arrays for the different
            sensors, with dims 'bands', 't', 'x' and 'y'.
        start_month (int): start month for the feature computation
        end_month (int): end month for the feature computation
        year (int): Year in which mapped season ends.
        modeltag (str, optional): tag of the model to load.
            Defaults to DEFAULTMODELTAG.
        modeldir (Path, optional): model directory, or an http location
            the model is fetched from. Defaults to DEFAULTMODELDIR.
        apply_worldcovermask (bool, optional): Whether or not to apply
            mask based on worldcover 30 and 40 values (grass and crop).
            Pixels outside this mask get prediction 'NoCrop' (expected to
            become value 255 in the output product after layer
            translation — TODO confirm). Defaults to False.
        ignore_classes (List[str], optional): class names to ignore during
            prediction. Defaults to None (no classes ignored).
        all_probabilities (bool, optional): also return one
            'probability_<class>' band per model class. Defaults to False.
        **processing_opts: forwarded to ``_compute_features``
            (e.g. ``METEO_data``, ``segmentation``, ``segm_settings``).

    Returns:
        xr.DataArray: bands 'croptype' and 'probability' (followed by the
        per-class probability bands when requested), in the original
        dimension order of ``inarr`` without 't'.
    """
    import cropclass
    from cropclass.utils.seasons import get_processing_dates
    from cropclass.features.settings import compute_n_tsteps
    import vito_crop_classification

    try:
        logger
    except NameError:
        _setup_logging()

    # None instead of a shared mutable [] default
    if ignore_classes is None:
        ignore_classes = []

    logger.info(f'Using cropclass library from: {cropclass.__file__}')
    logger.info(('Using vito_crop_classification '
                 f'library from: {vito_crop_classification.__file__}'))

    # Infer exact processing dates from months and year
    startdate, enddate = get_processing_dates(start_month, end_month, year)
    logger.info(f'Inferred processing date: {startdate} - {enddate}')

    # Infer required amount of tsteps
    n_tsteps = compute_n_tsteps(start_month, end_month)
    logger.info(f'Inferred n_tsteps: {n_tsteps}')

    # Store the original dimension order for later
    orig_dims = list(inarr.dims)
    orig_dims.remove('t')  # Time dimension will be removed after the inference

    features = _compute_features(inarr, startdate, enddate, n_tsteps,
                                 apply_worldcovermask, **processing_opts)

    # Make the crop type prediction
    logger.info(f'Using model: {modeltag}')
    prediction, probability, cls_probabilities, cls_names = _make_prediction(
        features,
        modeldir,
        modeltag=modeltag,
        apply_worldcovermask=apply_worldcovermask,
        ignore_classes=ignore_classes
    )

    # Release the features; they are no longer needed
    del features

    # prediction/probability are expected as (1, x, y) and cls_probabilities
    # as (n_classes, x, y). Index instead of .squeeze(): squeeze would also
    # drop an x or y axis of length 1 and break the DataArray below.
    band_data = [prediction[0], probability[0]]
    band_names = ['croptype', 'probability']
    if all_probabilities:
        # Then the class probabilities one by one
        band_data.extend(cls_probabilities)
        band_names.extend(
            'probability_' + cls_name.replace(' ', '_').replace(',', '')
            for cls_name in cls_names)

    result_da = xr.DataArray(
        np.array(band_data),
        coords=[
            band_names,
            inarr.coords['x'],
            inarr.coords['y']
        ],
        dims=['bands', 'x', 'y']
    )

    # And make sure we revert back to original dimension order
    return result_da.transpose(*orig_dims)


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """OpenEO UDF entry point: classify the cube and translate the layer.

    Args:
        cube: input datacube; its xarray is passed to ``classify``.
        context: keyword arguments forwarded to ``classify``. Its optional
            ``translation_table`` entry is also passed to
            ``translate_layer``.

    Returns:
        XarrayDataCube wrapping the translated predictions.
    """
    # Put the bundled dependency paths at the front of sys.path. The UDF
    # may run many times in one process, so drop any earlier copy first
    # instead of growing sys.path on every call. The resulting order is
    # unchanged: vitocropclassification, cropclasslib, venv_static.
    for extra_path in ('tmp/venv_static',
                       'tmp/cropclasslib',
                       'tmp/vitocropclassification'):
        sys.path[:] = [p for p in sys.path if p != extra_path]
        sys.path.insert(0, extra_path)

    from cropclass.postprocessing.layer_format import translate_layer

    _setup_logging()

    # Extract xarray.DataArray from the cube
    inarr = cube.get_array()

    # Run the two-stage classification
    predictions = classify(inarr, **context)

    translated_predictions = translate_layer(
        predictions, translation_table=context.get('translation_table', None))

    # Wrap result in an OpenEO datacube
    return XarrayDataCube(translated_predictions)
