# -*- coding: utf-8 -*-
# Uncomment the import only for coding support
from openeo.udf import XarrayDataCube
from typing import Dict
import xarray as xr
import numpy as np
import threading
import functools
import pandas as pd
from pathlib import Path
import sys
from openeo.udf.debug import inspect
import os

# Thread-local storage object. It is not referenced anywhere in the visible
# part of this file (NOTE(review): possibly leftover/reserved for per-thread
# state — confirm whether it is still needed).
_threadlocal = threading.local()

def _setup_logging():
    """Bind loguru's ``logger`` object to the module-level global ``logger``.

    Other helpers in this module log through the global name ``logger``,
    which does not exist until this function has been called.
    """
    import loguru

    global logger
    logger = loguru.logger

def take_best_acquisition(x, n_steps=6):
    """Reduce a (t, x, y) time series to one acquisition per temporal bin.

    The time axis is split into ``n_steps`` bins with ``groupby_bins``. Within
    each bin, the single time step with the largest number of non-NaN pixels
    is kept (ties resolve to the earliest step, per ``argmax``).

    Args:
        x: array-like data with dimensions ('t', 'x', 'y').
        n_steps: number of temporal bins passed to ``groupby_bins``.

    Returns:
        numpy array holding the selected acquisitions, one per bin produced by
        ``groupby_bins`` (NOTE(review): axis placement of the bin dimension is
        decided by xarray's groupby combine step — confirm if it matters).
    """
    def _most_complete(chunk: xr.DataArray):
        # Count valid (non-NaN) pixels per time step and keep the best one.
        valid_counts = chunk.notnull().sum(dim=('x', 'y'))
        return chunk.isel(t=valid_counts.argmax())

    cube = xr.DataArray(x, dims=('t', 'x', 'y'))
    binned = cube.groupby_bins('t', bins=n_steps)
    return binned.map(_most_complete).values

def get_preproc_settings(n_tsteps):
    """Build the preprocessing settings handed to the features computer.

    Only the 'S2' sensor is configured. It has a single "tsteps" entry that
    reduces the time series with ``take_best_acquisition`` to ``n_tsteps``
    steps and names the outputs ``ts0`` ... ``ts<n_tsteps - 1>``.

    Args:
        n_tsteps: number of time steps to keep.

    Returns:
        Nested settings dict keyed by sensor, then by feature name.
    """
    tsteps_config = {
        "function": take_best_acquisition,
        "parameters": {'n_steps': n_tsteps},
        "names": ['ts' + str(step) for step in range(n_tsteps)],
    }
    return {'S2': {"tsteps": tsteps_config}}



@functools.lru_cache(maxsize=25)
def load_delineation_model(modeldir: str, modeltag: str, cache_dir: str):
    """Load a pre-trained ``SemanticModel``, memoised per argument triple.

    Args:
        modeldir: local model directory, or a URL starting with 'http'. A URL
            is first fetched with ``download_and_unpack``, which returns the
            local directory and tag to load from.
        modeltag: model name; for URLs also the archive file name, whose
            extension (without the dot) is passed as the archive ``format``.
            A trailing ``.zip`` is dropped before loading.
        cache_dir: cache directory handed to ``download_and_unpack``.

    Returns:
        The loaded ``SemanticModel`` instance (cached by ``lru_cache``, up to
        25 distinct argument combinations).
    """
    import os
    from vito_lot_delineation import SemanticModel
    from fielddelineation.utils import download_and_unpack

    if modeldir.startswith('http'):
        archive_format = os.path.splitext(modeltag)[1].replace('.', '')
        modeldir, modeltag = download_and_unpack(modeldir, modeltag, cache_dir,
                                                 format=archive_format)

    # Only strip '.zip' as a suffix; str.replace would also remove it from the
    # middle of a model name.
    model_name = modeltag[:-len('.zip')] if modeltag.endswith('.zip') else modeltag

    # load a model from its tag
    model = SemanticModel.load(mdl_f=Path(modeldir) / model_name)

    return model


def _apply_delineation(features, modeldir, modeltag, cache_dir):
    """Run the delineation model on precomputed features.

    Args:
        features: per-band feature dict as produced by ``_compute_features``;
            passed unchanged to ``vito_lot_delineation.inference.preprocess_raw``.
        modeldir: model directory or URL (see ``load_delineation_model``).
        modeltag: model name / archive name.
        cache_dir: cache directory used when the model must be downloaded.

    Returns:
        numpy array with the ``'instance'`` output of ``run_prediction``. When
        that output has more than two dimensions, only index 0 of the leading
        axis is kept.
    """
    from vito_lot_delineation import run_prediction
    from vito_lot_delineation.inference import preprocess_raw

    # `logger` is a module global that only exists after _setup_logging() has
    # run; this function may be called on its own, so bind it if needed.
    if 'logger' not in globals():
        _setup_logging()

    logger.info('******* Start actual delineation ********')
    # load the pre-trained model
    model = load_delineation_model(modeldir, modeltag, cache_dir)
    logger.info(f'Delineation model loaded: {modeltag}')

    # preprocess satellite dataset before feeding it into a batch
    logger.info('Preprocessing of features and transforming it to Torch sensor')
    sample = preprocess_raw(sample=features,  # input satellite features
                            cfg=model.cfg["input"])  # model input configuration
    # Add a leading batch axis (batch size 1). NOTE(review): assumes `sample`
    # is a torch tensor of at most 4 dims, since repeat() is given 5 factors.
    batch = sample.unsqueeze(0).repeat(1, 1, 1, 1, 1)
    logger.info(f'Created batch sample of shape: {batch.shape}')

    logger.info('Apply model prediction')
    # run now actually the prediction
    preds = run_prediction(
        model=model,  # model to use for prediction
        batch=batch  # batch of images over which want to predict
    )
    logger.info('Model prediction succeeded')
    # keep only the instance segmentation
    preds_instance = preds['instance'].numpy()

    # for the moment reduce outcome from 3D to 2D by keeping the first entry
    # of the leading axis
    if len(preds_instance.shape) > 2:
        logger.info('Model output is 3D --> convert to 2D')
        preds_instance = preds_instance[0, :, :]

    return preds_instance

def _compute_features(inarr: xr.DataArray,
                      startdate: str,
                      enddate: str,
                      n_tsteps: int,
                      **kwargs) -> xr.DataArray:
    """Compute time-step features for every Sentinel-2 band in ``inarr``.

    Args:
        inarr: cube with dimensions 'bands', 't', 'x' and 'y' (any order).
            Bands whose name contains 'B' are treated as Sentinel-2 bands.
        startdate: start of the period, forwarded to ``FeaturesComputer``.
        enddate: end of the period, forwarded to ``FeaturesComputer``.
        n_tsteps: number of time steps, used to build the preprocessing
            settings via ``get_preproc_settings``.
        **kwargs: when ``run_local`` is truthy, each band is cropped to
            ``sample_size`` pixels in x and y (local debugging only).

    Returns:
        A dict (despite the ``xr.DataArray`` annotation) mapping
        ``'s2_' + band.lower()`` to a uint16 numpy array of that band's
        features. NaN values are replaced by 65535 before the cast; note that
        this fill is done in place on the array returned by
        ``features.select(...).data``.
    """
    from cropclass.features import FeaturesComputer

    # The global `logger` is only bound by _setup_logging(); this function can
    # also be called on its own, so make sure it exists.
    if 'logger' not in globals():
        _setup_logging()

    # Controlled dimension order for the features computer
    inarr = inarr.transpose('bands', 't', 'x', 'y')
    computer = FeaturesComputer(get_preproc_settings(n_tsteps), startdate, enddate)

    s2_bands = [name for name in inarr.bands.values if 'B' in name]
    features = computer.get_features({'S2': inarr.sel(bands=s2_bands)})
    logger.info(f'Features computed of shape: {features.shape}')

    local_debug = kwargs.get('run_local')
    result = {}
    for band in s2_bands:
        # Feature names use upper case and '-' instead of '_'
        pattern = band.upper().replace('_', '-')
        matching = [name for name in features.names if pattern in name]

        data = features.select(matching).data
        data[np.isnan(data)] = 65535

        # Crop to the model's input patch size; only needed when debugging
        # the workflow locally on a test patch.
        if local_debug:
            size = kwargs.get('sample_size')
            data = data[:, :size, :size]

        # Key naming must match what the inference workflow expects
        result['s2_' + band.lower()] = data.astype(np.uint16)

    return result


def delineate(
        inarr: xr.DataArray,
        start_month: int,
        end_month: int,
        year: int,
        **processing_opts) -> xr.DataArray:
    """Main inference delineation UDF for OpenEO workflow.

    This function takes an input xarray fetched through OpenEO
    and returns a result xarray back to OpenEO for saving.

    Args:
        inarr (xr.DataArray): preprocessed input cube with (at least) the
            dimensions 'bands', 't', 'x' and 'y'.
        start_month (int): first month of the feature computation period
            (period starts on day 1 of this month).
        end_month (int): last month of the feature computation period
            (period ends on the last day of this month).
        year (int): Year in which mapped season ends. When ``start_month`` is
            later than ``end_month`` the season starts in the previous year.
        **processing_opts: options read here are ``scale_factor`` (input is
            multiplied by it and values above 2 are masked), ``model_dir``,
            ``model_tag``, ``cache_dir`` and ``run_local``. All options are
            also forwarded to ``_compute_features``.

    Returns:
        xr.DataArray: delineation result with a single band 'field_ids',
        in the input's dimension order with 't' removed.
    """
    import calendar
    import datetime

    import vito_lot_delineation
    import cropclass
    from cropclass.features.settings import compute_n_tsteps
    from loguru import logger

    logger.info(f'Using cropclass library from: {cropclass.__file__}')
    logger.info(f'Using parcel delineation library from: {vito_lot_delineation.__file__}')

    # Infer exact processing dates from months and year. The period ends on
    # the real last day of `end_month` (a hard-coded 31 raised ValueError for
    # shorter months). A season whose start month lies after its end month
    # started in the previous year.
    start_year = year - 1 if start_month > end_month else year
    last_day = calendar.monthrange(year, end_month)[1]
    startdate = datetime.date(start_year, start_month, 1)
    enddate = datetime.date(year, end_month, last_day)
    logger.info(f'Inferred processing date: {startdate} - {enddate}')

    # Infer required amount of tsteps
    n_tsteps = compute_n_tsteps(start_month, end_month)
    logger.info(f'Inferred n_tsteps: {n_tsteps}')

    # Store the original dimension order for later
    orig_dims = list(inarr.dims)
    orig_dims.remove('t')  # Time dimension will be removed after the inference

    scale_factor = processing_opts.get('scale_factor')
    # Daily time axis every band is reindexed onto (same for all bands)
    daily_index = pd.date_range(startdate, enddate, freq='1D')

    resampled_bands = []
    for datasource in list(inarr.bands.data):
        band_data = inarr.sel(bands=datasource)

        # Need to set spurious values to NaN; it seems that clouds are marked
        # with (scaled) value 2.1. `where` builds a new array instead of
        # writing NaN into the caller's `inarr` in place, and also works for
        # integer inputs.
        band_data = band_data.where(~(band_data * scale_factor > 2))

        # Resample to daily for further processing
        resampled_bands.append(band_data.reindex(t=daily_index))

    logger.info('Reindexed Xarray to daily scale and removed spurious values')

    # create new xarray based on rescaled time axis
    inarr_resampled = xr.concat(resampled_bands, dim='bands')

    features = _compute_features(
        inarr_resampled, startdate, enddate, n_tsteps, **processing_opts)

    # Apply the field delineation detection (Radix implementation)
    logger.info(f'Using model: {processing_opts.get("model_tag")}')
    prediction = _apply_delineation(
        features,
        processing_opts.get("model_dir"),
        modeltag=processing_opts.get("model_tag"),
        cache_dir=processing_opts.get("cache_dir")
    )

    logger.info('Transforming model prediction into Xarray')

    # first put in 3D again (leading 'bands' axis of size 1)
    prediction = prediction.reshape(1, prediction.shape[0], prediction.shape[1])

    # In local debug mode the prediction may cover only a cropped patch, so
    # the x/y coordinates are cut to match its shape.
    if processing_opts.get('run_local'):
        result_xr = xr.DataArray(prediction, coords=[
            np.array(['field_ids']),
            inarr.coords["x"][0: prediction.shape[-2]],
            inarr.coords["y"][0: prediction.shape[-1]]],
            dims=["bands", "x", "y"])
    else:
        result_xr = xr.DataArray(prediction, coords=[
            ['field_ids'],
            inarr.coords["x"],
            inarr.coords["y"]], dims=["bands", "x", "y"])

    # And make sure we revert back to original dimension order
    result_xr = result_xr.transpose(*orig_dims)

    return result_xr

def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """OpenEO UDF entry point: run the delineation workflow on ``cube``.

    Args:
        cube: incoming OpenEO data cube.
        context: UDF context. ``cache_dir`` (if present) is exported as
            ``TORCH_HOME`` and ``XDG_CACHE_HOME``. The whole context is passed
            as keyword arguments to ``delineate`` (so it must contain
            ``start_month``, ``end_month`` and ``year``).

    Returns:
        XarrayDataCube wrapping the delineation result.
    """
    cache_dir = context.get("cache_dir")
    # os.environ only accepts strings; previously a missing cache_dir raised
    # TypeError here.
    if cache_dir is not None:
        os.environ['TORCH_HOME'] = cache_dir
        os.environ['XDG_CACHE_HOME'] = cache_dir

    # Make the unpacked dependency environments importable. This function may
    # be invoked repeatedly in the same process, so add each path only once
    # instead of growing sys.path on every call. The last entry ends up first,
    # matching the original insert(0, ...) order.
    for venv_path in ('tmp/venv_static', 'tmp/venv_static_del',
                      'tmp/venv_model', 'tmp/venv'):
        if venv_path not in sys.path:
            sys.path.insert(0, venv_path)

    _setup_logging()

    # Extract xarray.DataArray from the cube
    inarr = cube.get_array()
    logger.info(f'Input array opened with shape: {inarr.shape}')

    # Run the delineation workflow
    predictions = delineate(inarr, **context)

    # Wrap result in an OpenEO datacube
    return XarrayDataCube(predictions)


def load_delineation_udf() -> str:
    """Return the full source text of this module file.

    The file is opened read-only ('r+' needlessly required write permission
    and fails on read-only deployments) and decoded as UTF-8, matching the
    module's coding declaration instead of the platform locale.
    """
    import os
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()
