# -*- coding: utf-8 -*-
# Uncomment the import only for coding support
import os

from openeo.udf import XarrayDataCube
from typing import Dict, Optional
import functools
import xarray
import numpy
import pandas as pd
from xarray.core.dataarray import DataArray
import pandas
import tensorflow as tf
import threading


# Default edge length (in pixels) of the square x/y windows that `process`
# cuts the input into (default of its `gan_window_size` argument).
WINDOW_SIZE = 64
# Half width of the temporal window around each output date, as a pandas
# timedelta string (converted with `to_timedelta` where it is used).
GAN_WINDOW_HALF = "80D"
# Default spacing between the output (prediction) dates.
ACQUISITION_STEPS = "5D"
# Default temporal resolution the S1/S2 series are resampled to.
GAN_STEPS = "5D"
# Default number of time steps fed to the network.
# NOTE(review): with the defaults above the formula in the trailing comment
# gives 2*80/5 + 1 = 33, not 32; `process_window` drops the last time step
# when this constant is even, and the model is built with tslength=32.
GAN_SAMPLES = 32  # this is 2*gan_window_half/gan_steps + 1

# Key of the NDVI value range used by `Scaler`.
NDVI = 'ndvi'
# Labels on the 'bands' coordinate of the input cube that are selected in
# `process` / `process_window`.
S2id = 'S2ndvi'
VHid = 'VH'
VVid = 'VV'

# Thread-local storage used by `load_generator_model` to keep loaded models
# separate per thread.
_threadlocal = threading.local()


def load_generator_model(path: str = None):
    """
    Return the CropSAR_px generator model with weights loaded, cached per thread.

    Keras/tensorflow models are not guaranteed to be threadsafe, so every
    thread builds and keeps its own model instances. Models are cached per
    weights file: asking for a different ``path`` in the same thread loads
    a new model instead of returning whatever was loaded first.

    :param path: path to a weights file; when None the weights packaged as
        ``cropsar_px.resources/cropsar_px_generator.h5`` are used
    :return: the ``generator`` attribute of a ``CropsarPixelModel`` with the
        requested weights loaded
    """
    from cropsar_px.model import CropsarPixelModel

    # One dict of models per thread, keyed by the requested weights path
    # (None stands for the packaged default weights).
    thread_models = getattr(_threadlocal, 'generator_models', None)
    if thread_models is None:
        thread_models = {}
        _threadlocal.generator_models = thread_models

    cache_key = None if path is None else str(path)
    generator_model = thread_models.get(cache_key)
    if generator_model is None:
        generator_model = CropsarPixelModel(modelinputs={'S1': 2, 'S2': 1},
                                            windowsize=32,
                                            tslength=32).generator

        if path is None:
            try:
                import importlib.resources as pkg_resources
            except ImportError:
                # Backport package for interpreters without importlib.resources
                import importlib_resources as pkg_resources

            # Use a distinct name so the `path` argument is not shadowed
            with pkg_resources.path('cropsar_px.resources',
                                    'cropsar_px_generator.h5') as weights_path:
                generator_model.load_weights(weights_path)
        else:
            generator_model.load_weights(path)

        # Store per thread and per weights file
        thread_models[cache_key] = generator_model

    return generator_model


@functools.lru_cache(maxsize=25)
def load_datafusion_model():
    """
    Return the generator model with the default (packaged) weights.

    Thin wrapper around ``load_generator_model()`` called without a path.

    NOTE(review): the function takes no arguments, so ``lru_cache`` keeps
    the first result for the whole process and hands that same object to
    every caller, regardless of the calling thread - confirm this is the
    intended sharing behaviour.
    """
    default_model = load_generator_model()
    return default_model


class Scaler():
    """Linear scaling between a per-source value range and [-1, 1]."""

    @staticmethod
    def _bounds(source):
        """Return the ``[lower, upper]`` value range registered for `source`.

        Raises KeyError when `source` is not one of NDVI, VVid or VHid.
        """
        table = {
            NDVI: [-0.08, 1],
            VVid: [-20, -2],
            VHid: [-33, -8],
        }
        return table[source]

    def minmaxscaler(self, data, source):
        """Map `data` from the range of `source` onto [-1, 1]."""
        lower, upper = self._bounds(source)
        # lower -> -1, upper -> +1; values outside the range are not clipped
        return 2 * (data - lower) / (upper - lower) - 1

    def minmaxunscaler(self, data, source):
        """Map `data` from [-1, 1] back onto the range of `source`."""
        lower, upper = self._bounds(source)
        # Inverse of minmaxscaler
        return 0.5 * (data + 1) * (upper - lower) + lower


def multitemporal_mask(ndvicube):
    """
    Clean an NDVI time series cube with `flaglocalminima`.

    The cube is put on a daily time axis (back-filling values over at most
    one day), `flaglocalminima` from ``cropsar_px.utils.masking`` is applied
    tile by tile on the underlying array, and the result is sampled back at
    the original timestamps.

    :param ndvicube: NDVI DataArray with a `t` coordinate; the array is
        indexed as (t, <axis 1>, <axis 2>) when tiling
    :return: DataArray selected at the original timestamps
    """
    print('Running multitemporal masking ...')
    from cropsar_px.utils.masking import flaglocalminima

    original_times = list(ndvicube.t.values)

    # Daily axis from the first acquisition up to one day after the last
    daily_index = pd.date_range(
        original_times[0],
        original_times[-1] + pd.Timedelta(days=1),
        freq='D',
    ).floor('D')
    ndvi_daily = ndvicube.reindex(
        t=daily_index, method='bfill', tolerance='1D')

    # Run the dip detection tile by tile to keep memory usage bounded
    tile = 256
    for row0 in range(0, ndvi_daily.values.shape[1], tile):
        rows = slice(row0, row0 + tile)
        for col0 in range(0, ndvi_daily.values.shape[2], tile):
            cols = slice(col0, col0 + tile)
            ndvi_daily.values[:, rows, cols] = flaglocalminima(
                ndvi_daily.values[:, rows, cols],
                maxdip=0.01,
                maxdif=0.1,
                maxgap=60,
                maxpasses=5)

    # Back to the timestamps of the input cube
    return ndvi_daily.sel(t=original_times, method='ffill', tolerance='1D')


def process(
        inarr: xarray.DataArray,
        startdate: str,
        enddate: str,
        gan_window_half: str = GAN_WINDOW_HALF,
        acquisition_steps: str = ACQUISITION_STEPS,
        gan_window_size: int = WINDOW_SIZE,
        gan_steps: str = GAN_STEPS,
        gan_samples: int = GAN_SAMPLES,
        model_file: str = None,
        inpaint_only: bool = True,
        output_mask: bool = False,
        nrt_mode: bool = False,
        drop_dates: Optional[list] = None
) -> xarray.DataArray:
    """
    Apply the CropSAR_px algorithm to the provided input data.

    Note that `inarr` is modified in place: the Sentinel-2 band is
    overwritten with its masked version (and with NaN on `drop_dates`).

    :param inarr: input data (Sentinel-1 + Sentinel-2)
    :param startdate: requested start date
    :param enddate: requested end date
    :param gan_window_half: half GAN temporal window size; used here to
        slice the input in time around every output date
    :param acquisition_steps: acquisition interval in the output
    :param gan_window_size: GAN window size
    :param gan_steps: GAN steps
    :param gan_samples: number of GAN samples, this is 2*gan_window_half/gan_steps + 1
    :param model_file: path to custom GAN model file
    :param inpaint_only: keep actual NDVI acquisitions, only predict areas where there is no data
    :param output_mask: output the Sentinel-2 mask: 0 = no data, 1 = data
    :param nrt_mode: only use prior data for prediction
    :param drop_dates: drop Sentinel-2 acquisitions for provided dates
    :return: float32 DataArray with the dims of `inarr`, one entry per
        output date, and bands ["NDVI"] or ["NDVI", "mask"]
    """
    if drop_dates is not None:
        # Drop Sentinel-2 acquisitions
        # (numpy.nan instead of numpy.NaN: the alias was removed in NumPy 2.0)
        drop_dates = list(map(pandas.to_datetime, drop_dates))
        inarr.loc[dict(bands=S2id, t=[d for d in drop_dates if d in inarr.t])] = numpy.nan

    # Run multitemporal mask
    inarr.loc[dict(bands=S2id)] = multitemporal_mask(inarr.sel(bands=S2id))

    # compute windows as ((x_start, x_stop), (y_start, y_stop)) index pairs
    xsize, ysize = inarr.x.shape[0], inarr.y.shape[0]
    windowlist = [
        ((ix, ix + gan_window_size), (iy, iy + gan_window_size))
        for ix in range(0, xsize, gan_window_size)
        for iy in range(0, ysize, gan_window_size)
    ]

    # init scaler
    sc = Scaler()

    # load the model
    if model_file is None:
        model = load_datafusion_model()
    else:
        model = load_generator_model(model_file)

    # compute acquisition dates
    acquisition_dates = pandas.date_range(
        pandas.to_datetime(startdate),
        pandas.to_datetime(enddate),
        freq=acquisition_steps
    )

    # result buffer; the x/y positions in `shape` follow the dim order of inarr
    shape = [len(acquisition_dates), 2 if output_mask else 1, 1, 1]
    shape[inarr.dims.index('x')] = xsize
    shape[inarr.dims.index('y')] = ysize
    predictions = DataArray(
        numpy.full(shape, numpy.nan, dtype=numpy.float32), dims=inarr.dims,
        coords={'bands': ["NDVI", "mask"] if output_mask else ["NDVI"], 't': acquisition_dates})

    # run processing
    half_window = pandas.to_timedelta(gan_window_half)  # loop invariant
    for idate in acquisition_dates:
        # In NRT mode only data up to (and including) the output date is used
        window_start = idate - half_window
        window_end = idate if nrt_mode else idate + half_window
        for (x0, x1), (y0, y1) in windowlist:
            data = inarr.isel({
                'x': slice(x0, x1),
                'y': slice(y0, y1),
            }).sel(t=slice(window_start, window_end))
            ires = process_window(
                data, model, sc, idate, gan_window_size,
                gan_steps, gan_samples, inpaint_only=inpaint_only,
                output_mask=output_mask, drop_dates=drop_dates
            ).astype(numpy.float32)
            # predictions has no x/y coordinates, so these are positions
            predictions.loc[{'t': idate,
                             'x': range(x0, x1),
                             'y': range(y0, y1)}] = ires

    return predictions


def _process_s2(s2data: DataArray, output_index, gan_steps):
    '''Bring the Sentinel-2 series onto the requested time index.

    - Group the acquisitions in bins of `gan_steps`
    - Keep, per bin, the image with the most non-null values
    - Reindex onto `output_index`, forward-filling values over at most
      `gan_steps` (time steps that stay unmatched are left as NaN)
    '''
    print(f"Output index: {output_index}")
    print(f"S2 data t: {s2data.t}")

    def _take_best(group: xarray.DataArray):
        # Count the valid values of every time step in this bin
        valid_count = group.notnull().sum(
            dim=('x', 'y', 'batch', 'channel'))
        return group.isel(t=valid_count.argmax())

    best_per_bin = s2data.resample(t=gan_steps).map(_take_best)

    return best_per_bin.reindex(
        {'t': output_index}, method='ffill', tolerance=gan_steps)


def _process_s1(s1data: DataArray, output_index, gan_steps):
    '''Bring the Sentinel-1 series onto the requested time index.

    - Convert decibels to power values
    - Average the observations in bins of `gan_steps` (ignoring NaN)
    - Fill remaining gaps along `t` by linear interpolation
    - Reindex onto `output_index`, forward-filling values over at most
      `gan_steps`
    - Convert back to decibels
    '''
    # dB -> power
    power = numpy.power(10, s1data / 10.)

    # Average per bin, then close the gaps
    binned = power.resample(t=gan_steps).mean(skipna=True)
    gap_filled = binned.interpolate_na(dim='t', method='linear')

    # Align with the requested index
    aligned = gap_filled.reindex(
        {'t': output_index}, method='ffill', tolerance=gan_steps)

    # power -> dB
    return 10 * numpy.log10(aligned)


def _force_even_tsteps():
    """Tell whether the module-level GAN_SAMPLES is an even number."""
    return GAN_SAMPLES % 2 == 0


def process_window(
        inarr: xarray.DataArray,
        model,
        scaler: Scaler,
        idate: pd.Timestamp,
        windowsize: int,
        gan_steps: str,
        gan_samples: int,
        inpaint_only: bool = True,
        output_mask: bool = False,
        nodata: float = 0.,
        drop_dates: Optional[list] = None
):
    """
    Predict NDVI for a single spatial window at a single output date.

    The Sentinel-1 and Sentinel-2 series are resampled onto a regular time
    index centred on `idate`, scaled, passed through `model.predict` and
    the prediction is scaled back to NDVI.

    :param inarr: input data of one window (bands S2ndvi, VV and VH);
        it is not modified by this function
    :param model: object with a `predict` method taking `(s1, s2)` arrays
    :param scaler: `Scaler` used to scale the inputs and unscale the output
    :param idate: output date, centre of the temporal window
    :param windowsize: edge length of the (square) window; the data of the
        centre date must reshape to `(windowsize, windowsize)`
    :param gan_steps: temporal resolution of the resampled series
    :param gan_samples: number of time steps; `gan_samples // 2` is taken
        as the index of the centre time step
    :param inpaint_only: keep the actual NDVI values of the centre date
        and only use the prediction where there is no data
    :param output_mask: also return the mask category per pixel
    :param nodata: value that replaces NaN in the network inputs
    :param drop_dates: dates whose Sentinel-2 data was masked manually
    :return: array of shape `(windowsize, windowsize)`, or the stack of
        that array and the mask categories when `output_mask` is set
    """
    if drop_dates is not None:
        # put manually masked values to infinity, so we can track it when resampling
        # Work on a copy: `inarr` is usually a view on the caller's array
        # and the markers must not end up in the caller's data.
        inarr = inarr.copy()
        inarr.loc[dict(bands=S2id, t=[d for d in drop_dates if d in inarr.t])] = numpy.inf

    # Get the required output temporal index for this stack
    output_index = pd.date_range(idate - pd.to_timedelta(GAN_WINDOW_HALF),
                                 idate + pd.to_timedelta(GAN_WINDOW_HALF),
                                 freq=gan_steps)

    if _force_even_tsteps():
        if len(output_index) % 2 == 1:
            # In case we want an even amount of timesteps
            # we strip the last timestamp in the series
            output_index = output_index[:-1]

    # grow it to 5 dimensions (after the band selection below):
    # 'batch' is added in front and 'channel' at the end
    inarr = inarr.expand_dims(dim=['batch', 'channel'], axis=[0, 5])

    # Process S2
    S2 = _process_s2(inarr.sel(bands=S2id), output_index, gan_steps)

    # Process S1
    S1 = _process_s1(inarr.sel(bands=[VVid, VHid]), output_index, gan_steps)

    # select bands
    VH = S1.sel(bands=VHid)
    VV = S1.sel(bands=VVid)

    # Get the center acquisition to inpaint
    # (assumes axis 1 of S2 is time - TODO confirm)
    s2_ndvi_center = S2.values[:, gan_samples // 2, ...].reshape((windowsize, windowsize))

    # Get a mask and the center NDVI
    # mask categories
    #   0: no data
    #   1: data
    #   2: manually masked
    s2_mask_category = (~numpy.isnan(s2_ndvi_center)).astype(int)

    if drop_dates is not None:
        # add to mask: inf is not NaN, so these pixels are 1 already
        # and become 2 here
        s2_mask_inf = numpy.isinf(s2_ndvi_center).astype(int)
        s2_mask_category += s2_mask_inf
        # put NaN for manually masked values before feeding into the network
        # (numpy.nan instead of numpy.NaN: the alias was removed in NumPy 2.0)
        S2 = xarray.where(S2 == numpy.inf, numpy.nan, S2)
        s2_ndvi_center[s2_ndvi_center == numpy.inf] = numpy.nan

    # simplify categories to binary mask
    s2_mask = (s2_mask_category == 1).astype(int)

    # Scale S1
    VV = scaler.minmaxscaler(VV, VVid)
    VH = scaler.minmaxscaler(VH, VHid)

    # Concatenate s1 data
    s1_backscatter = xarray.concat((VV, VH), dim='channel')

    # Scale NDVI
    s2_ndvi = scaler.minmaxscaler(S2, NDVI)

    # Remove any nan values
    # Passing in numpy arrays because reduces RAM usage
    # (newer tensorflows copy out from xarray into a numpy array)
    # and backwards compatibility goes further back in time
    s2_ndvi = s2_ndvi.fillna(nodata).values
    s1_backscatter = s1_backscatter.fillna(nodata).values

    # Run neural network
    predictions = model.predict((s1_backscatter, s2_ndvi))

    # Unscale
    predictions = scaler.minmaxunscaler(predictions, NDVI)
    pred_reshaped = predictions.reshape((windowsize, windowsize))

    if inpaint_only:
        # Only predict masked regions
        # We want to avoid crisp borders
        # so first dilate the inverted mask
        s2_mask_nan = s2_mask.astype(float)
        s2_mask_inv_dilated = _dilate_mask(1 - s2_mask)

        # Put mask values to NaN
        s2_mask_nan[s2_mask_nan == 0] = numpy.nan
        s2_mask_inv_dilated[s2_mask_inv_dilated == 0] = numpy.nan

        # Stack original and predicted pixels
        # masked pixels become NaN
        stacked = numpy.stack(
            [s2_mask_nan * s2_ndvi_center,
             s2_mask_inv_dilated * pred_reshaped],
            axis=-1)

        # By taking a nanmean, we take the mean
        # of the original and predicted values
        # in overlap regions of mask borders
        completed = numpy.nanmean(stacked, axis=-1)
    else:
        # Return prediction for all pixels
        completed = pred_reshaped

    if output_mask:
        return numpy.stack([completed, s2_mask_category])
    else:
        return completed


def _dilate_mask(mask, dilate_r=5):
    """
    Dilate a binary mask with a disk-shaped footprint.

    :param mask: 2D array, non-zero where the mask is set
    :param dilate_r: radius of the disk, in pixels
    :return: float array of the same shape, 1.0 inside the dilated mask
        and 0.0 elsewhere
    """
    # `disk` is imported from `skimage.morphology` directly: the
    # `skimage.morphology.selem` module was deprecated in scikit-image 0.19
    # and removed in later releases, while this import works in both old
    # and new versions.
    from skimage.morphology import binary_dilation, disk

    dilated_mask = binary_dilation(mask, disk(dilate_r))

    return dilated_mask.astype(float)


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """
    UDF entry point: run `process` on the array held by `cube`.

    :param cube: input datacube
    :param context: keyword arguments forwarded to `process`
    :return: the predictions wrapped in a new XarrayDataCube
    """
    return XarrayDataCube(process(cube.get_array(), **context))


def load_cropsar_px_udf() -> str:
    """Return the source code of this module as a string."""
    import os
    # Open read-only: the file is never written, and 'r+' needs write
    # permission, which fails when the module sits in a read-only location.
    # The file declares utf-8 in its coding header, so decode it as such
    # instead of relying on the platform default.
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()
