# -*- coding: utf-8 -*-
# Uncomment the import only for coding support
from openeo.udf import XarrayDataCube
from typing import Dict
import xarray as xr
import numpy as np
import threading
import functools
import pandas as pd
from pathlib import Path
import sys
import os
from openeo.udf.debug import inspect

# Per-thread storage container.
# NOTE(review): nothing in the visible part of this file reads or writes it;
# confirm whether it is still needed.
_threadlocal = threading.local()


# def _setup_logging():
#     global logger
#     from loguru import logger


def take_best_acquisition(x, n_steps=6):
    """Reduce an image time series to the best acquisition per time bin.

    The input is wrapped as a ``('t', 'x', 'y')`` array and its ``t`` axis is
    split into ``n_steps`` bins. Within every bin the acquisition holding the
    largest number of non-null pixels is kept.

    Args:
        x: array-like that can be interpreted with dims ``('t', 'x', 'y')``.
        n_steps: number of bins the ``t`` axis is divided into.

    Returns:
        The ``.values`` of the binned result, i.e. the selected acquisitions
        stacked in the order xarray yields the bins.
    """

    def _most_complete(chunk: xr.DataArray):
        # Count the valid pixels of every time step and keep the fullest one.
        valid_counts = chunk.notnull().sum(dim=('x', 'y'))
        return chunk.isel(t=valid_counts.argmax())

    series = xr.DataArray(x, dims=('t', 'x', 'y'))
    binned = series.groupby_bins('t', bins=n_steps)
    return binned.map(_most_complete).values


def get_preproc_settings(n_tsteps):
    """Build the preprocessing settings for the Sentinel-2 (``'S2'``) source.

    Args:
        n_tsteps: number of time steps; used as the ``n_steps`` parameter of
            ``take_best_acquisition`` and to generate the feature names
            ``ts0`` ... ``ts{n_tsteps - 1}``.

    Returns:
        dict: nested settings of the form
        ``{'S2': {'tsteps': {'function', 'parameters', 'names'}}}``.
    """
    tstep_names = [f'ts{idx}' for idx in range(n_tsteps)]
    tsteps_cfg = dict(
        function=take_best_acquisition,
        parameters=dict(n_steps=n_tsteps),
        names=tstep_names,
    )
    return {'S2': {'tsteps': tsteps_cfg}}


def _compute_features(inarr: xr.DataArray,
                      startdate: str,
                      enddate: str,
                      n_tsteps: int,
                      **kwargs) -> Dict[str, np.ndarray]:
    """Compute the per-band time-step features used by the delineation model.

    Args:
        inarr: input array with the dimensions ``bands``, ``t``, ``x`` and
            ``y`` (any order; it is transposed to that order here).
        startdate: start of the processing period, forwarded unchanged to
            ``FeaturesComputer`` (``delineate`` passes ``datetime.date``
            objects).
        enddate: end of the processing period, forwarded unchanged to
            ``FeaturesComputer``.
        n_tsteps: number of time steps to derive per band.
        **kwargs: optional settings. Keys read here:
            ``run_local``: when truthy, every band array is cropped to
            ``sample_size`` along its last two axes.
            ``sample_size``: crop size used together with ``run_local``.

    Returns:
        dict: maps ``'s2_<band name in lower case>'`` to a ``uint16`` array
        of the features of that band, with NaN replaced by 65535.
    """
    from cropclass.features import FeaturesComputer

    # Make sure we have a controlled dimension order
    inarr = inarr.transpose('bands', 't', 'x', 'y')
    preproc_settings = get_preproc_settings(n_tsteps)

    # Initialize features computer
    fc = FeaturesComputer(preproc_settings, startdate, enddate)

    # Identify the S2 bands (every band whose name contains a 'B')
    bands = list(inarr.bands.values)
    S2_bands = [b for b in bands if 'B' in b]

    # Construct sensor-specific arrays.
    sensor_arrays = {'S2': inarr.sel(bands=S2_bands)}

    # Compute features
    features = fc.get_features(sensor_arrays)
    inspect(
        message=f'Computed features of shape: {features.shape}',
        code='_compute_features',
        level='debug'
    )

    # These options do not change between bands, so read them only once.
    run_local = kwargs.get('run_local')
    sample_size = kwargs.get('sample_size')

    # Write out the features per band into a dictionary
    features_dict = {}
    for datasource in S2_bands:
        # Ensure that naming of bands is compliant with inference workflow
        datasource_rename = 's2_' + datasource.lower()

        # Feature names are matched on the upper-case band name with
        # underscores replaced by dashes.
        name_token = datasource.upper().replace('_', '-')
        band_names = [b for b in features.names if name_token in b]

        banddata = features.select(band_names).data
        banddata[np.isnan(banddata)] = 65535  # Nodata value

        # Ensure that all the ndarrays match with the model input patch size;
        # this is only needed when locally debugging the workflow on a test
        # patch.
        if run_local:
            banddata = banddata[:, 0:sample_size, 0:sample_size]

        features_dict[datasource_rename] = banddata.astype(np.uint16)

    return features_dict


def delineate(
        inarr: xr.DataArray,
        start_month: int,
        end_month: int,
        year: int,
        **processing_opts) -> xr.DataArray:
    """Main inference delineation UDF for OpenEO workflow.

    This function takes an input xarray fetched through OpenEO
    and returns a result xarray back to OpenEO for saving.

    Args:
        inarr (xr.DataArray): preprocessed input array with the dimensions
            ``bands``, ``t``, ``x`` and ``y``.
        start_month (int): start month for the feature computation
        end_month (int): end month for the feature computation
        year (int): Year in which mapped season ends.
        **processing_opts: optional settings. Keys read here:
            ``temp_aggr``: when truthy the input is treated as already
            temporally aggregated: no feature computation is done and the
            time axis is padded with nodata where needed.
            ``run_local``: when truthy, arrays and output coordinates are
            cropped for local debugging on a test patch.
            ``sample_size``: crop size used together with ``run_local``.
            ``model_dir``, ``model_tag``, ``cache_dir``: forwarded to
            ``_apply_delineation``.
            All options are also forwarded to ``_compute_features``.

    Returns:
        xr.DataArray: resulting delineation xarray with the single band
        ``field_ids``, in the dimension order of ``inarr`` without ``t``.

    Raises:
        ValueError: if ``inarr`` has no ``t`` dimension.
    """

    import vito_lot_delineation
    import cropclass
    from cropclass.features.settings import compute_n_tsteps
    import calendar
    import datetime
    from fielddelineation.utils.padding import calc_padding

    inspect(message= f'Using cropclass library from: {cropclass.__file__}',
            code='delineate', level='info')
    inspect(message= f'Using parcel delineation library from: {vito_lot_delineation.__file__}',
            code='delineate', level='info')
    from fielddelineation.utils.delineation import _apply_delineation

    # Infer exact processing dates from months and year. The end date is the
    # real last day of ``end_month``: a hard-coded day 31 raises ValueError
    # for every month that has fewer days.
    # NOTE(review): both dates use the same ``year``, so start_month >
    # end_month gives startdate > enddate - confirm whether seasons that
    # cross a year boundary have to be supported.
    last_day = calendar.monthrange(year, end_month)[1]
    startdate = datetime.date(year, start_month, 1)
    enddate = datetime.date(year, end_month, last_day)
    inspect(
        message=f'Inferred processing date: {startdate} - {enddate}',
        code='delineate',
        level='info'
    )

    # Infer required amount of tsteps
    n_tsteps = compute_n_tsteps(start_month, end_month)
    inspect(
        message=f'Inferred n_tsteps: {n_tsteps}',
        code='delineate',
        level='info'
    )

    # Check if n_tsteps aligns with the input data and how many time steps
    # have to be added before/after to get the proper dimensions
    n_ts_add_before, n_ts_add_after = calc_padding(inarr,
                                                   start_month,
                                                   end_month,
                                                   n_tsteps,
                                                   year)

    # Store the original dimension order for later
    orig_dims = list(inarr.dims)
    orig_dims.remove('t')  # Time dimension will be removed after the inference

    windowsize = processing_opts.get('sample_size')
    temp_aggr = processing_opts.get('temp_aggr')

    # Mask spurious reflectance values
    arrays = {}
    for datasource in list(inarr.bands.data):
        band_data = inarr.sel(bands=datasource)

        # Ensure it is a float before setting wrong values to nan
        band_data = band_data.astype(float)

        # Need to set spurious values to NaN
        # it seems that clouds are marked with value 2.1
        band_data.values[band_data.values > 20000] = np.nan

        band_data.values[np.isnan(band_data.values)] = 65535  # Nodata value
        band_data = band_data.astype(np.uint16)

        # Set to daily scale if the data is not already temporally aggregated
        if not temp_aggr:
            # Resample to daily for further processing
            daily_index = pd.date_range(startdate, enddate, freq='1D')
            band_data = band_data.reindex(t=daily_index)
            arrays[datasource] = band_data

        else:
            # If no feature computation should be done the naming should be
            # set correctly such that the model can deal with it. Otherwise
            # this is done after the feature computation.
            datasource_name = 's2_' + datasource.lower()

            # Position of the time axis in the per-band array. Selecting one
            # band drops the ``bands`` dimension, so the index of ``t`` in
            # ``inarr.dims`` cannot be used here.
            t_axis = band_data.dims.index('t')
            band_data = band_data.values

            # Extend the array with nodata if there is a mismatch with the
            # time dimension that was requested. Padding is applied along
            # the time axis only and keeps the uint16 dtype.
            if n_ts_add_before > 0 or n_ts_add_after > 0:
                pad_width = [(0, 0)] * band_data.ndim
                pad_width[t_axis] = (n_ts_add_before, n_ts_add_after)
                band_data = np.pad(band_data,
                                   pad_width,
                                   mode='constant',
                                   constant_values=65535)

            # Ensure that all the ndarrays match with the model input patch
            # size; this is only needed when locally debugging the workflow
            # on a test patch.
            # NOTE(review): the crop assumes the time axis comes first.
            if processing_opts.get('run_local'):
                band_data = band_data[:, 0:windowsize, 0:windowsize]

            arrays[datasource_name] = band_data

    inspect(message= 'Reindexed Xarray to daily scale and removed spurious values',
            code='delineate',
            level='debug')

    # Create new xarray based on rescaled time axis
    if not temp_aggr:
        inarr_resampled = xr.concat(list(arrays.values()), dim='bands')

        features = _compute_features(
            inarr_resampled, startdate, enddate, n_tsteps, **processing_opts)
    else:
        features = arrays

    # Apply the field delineation detection (Radix implementation)
    inspect(message=f'Using model: {processing_opts.get("model_tag")}',
            code='delineate',
            level='info')
    prediction_sem = _apply_delineation(
        features,
        processing_opts.get("model_dir"),
        modeltag=processing_opts.get("model_tag"),
        cache_dir=processing_opts.get("cache_dir"))

    # First put in 3D again
    prediction_sem = np.array([prediction_sem.squeeze().T])

    # Force the coordinates to have the same extent as the prediction when
    # in debug mode
    if processing_opts.get('run_local'):
        result_xr = xr.DataArray(prediction_sem, coords=[
            np.array(['field_ids']),
            inarr.coords["x"][0: prediction_sem.shape[-2]],
            inarr.coords["y"][0: prediction_sem.shape[-1]]],
            dims=["bands", "x", "y"])
    else:
        result_xr = xr.DataArray(prediction_sem, coords=[
            ['field_ids'],
            inarr.coords["x"],
            inarr.coords["y"]], dims=["bands", "x", "y"])

    # And make sure we revert back to original dimension order
    result_xr = result_xr.transpose(*orig_dims)

    return result_xr


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """OpenEO UDF entry point: run the delineation on one data cube.

    Args:
        cube: data cube handed over by OpenEO.
        context: user settings. ``cache_dir`` (optional) selects the
            directory exported as ``TORCH_HOME`` and ``XDG_CACHE_HOME``; the
            current working directory is used when it is missing or None.
            The complete mapping is forwarded to ``delineate`` as keyword
            arguments.

    Returns:
        XarrayDataCube: the delineation result wrapped in a data cube.
    """
    # Point the model caches to the requested directory.
    cache_root = context.get('cache_dir')
    if cache_root is None:
        cache_root = os.getcwd()
    for env_name in ('TORCH_HOME', 'XDG_CACHE_HOME'):
        os.environ[env_name] = cache_root

    # Put the dependency folders at the front of the import path; the last
    # one inserted gets the highest priority.
    for venv_dir in ('tmp/venv_model',
                     'tmp/venv_static',
                     'tmp/venv_static_del'):
        sys.path.insert(0, venv_dir)

    # Extract xarray.DataArray from the cube
    data = cube.get_array()
    inspect(message=f'Input array opened with shape: {data.shape}',
            code='delineate',
            level='info')

    # Run the delineation workflow and wrap the result in an OpenEO datacube
    return XarrayDataCube(delineate(data, **context))


def load_delineation_udf() -> str:
    """Return the source code of this module as a string.

    The file is opened read-only: the former ``'r+'`` mode also requested
    write access and therefore failed on read-only files, although nothing
    is ever written. The encoding is fixed to UTF-8, matching the coding
    declaration at the top of the file, so the result does not depend on the
    locale of the machine.

    Returns:
        str: full text of the file this function is defined in.
    """
    import os
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()
