# -*- coding: utf-8 -*-
# The openeo.udf import provides type hints for coding support, and is also needed at runtime (apply_datacube constructs an XarrayDataCube)
import os

from openeo.udf import XarrayDataCube
from typing import Dict
import functools
import xarray
import numpy
from xarray.core.dataarray import DataArray
import pandas
import tensorflow as tf
import tensorflow_addons as tfa
import threading


# Side length, in pixels, of the square spatial windows fed to the generator.
WINDOW_SIZE = 64
# Half-width of the temporal window taken on either side of each prediction date.
GAN_WINDOW_HALF = "90D"
# Spacing between consecutive prediction (output) dates.
ACQUISITION_STEPS = "5D"
# Spacing between the time steps inside one model input window.
GAN_STEPS = "5D"
# Number of time steps per model input: 2*gan_window_half/gan_steps + 1
# (= 37 for the defaults above). Derived instead of hard-coded so it stays
# consistent when GAN_WINDOW_HALF or GAN_STEPS are changed.
GAN_SAMPLES = int(
    2 * pandas.to_timedelta(GAN_WINDOW_HALF) // pandas.to_timedelta(GAN_STEPS)) + 1

# Keys used for the scaler ranges (NDVI) and the 'bands' labels of the input cube.
NDVI = 'ndvi'
S2id = 'S2ndvi'
VHid = 'VH'
VVid = 'VV'

# Per-thread storage used to cache loaded generator models.
_threadlocal = threading.local()


def tilted_loss(q, y, f):
    """Quantile (pinball) loss for quantile ``q``, averaged over the last axis.

    With residual ``e = y - f`` each element contributes ``max(q*e, (q-1)*e)``.
    It is registered as a custom object when the generator model is loaded.
    """
    residual = (y - f)
    above = q * residual
    below = (q - 1) * residual
    backend = tf.keras.backend
    return backend.mean(backend.maximum(above, below), axis=-1)


def load_generator_model(path: str = None):
    """Load the CropSAR_px generator model, cached per thread and per ``path``.

    Keras/tensorflow models are not guaranteed to be threadsafe, so every
    thread loads and keeps its own instance. This also means the model does
    not have to be loaded again at predict time.

    :param path: path of a Keras model file to load. When None, the
        ``resources/cropsar_px_generator.h5`` resource of the ``cropsar_px``
        package is read into memory and loaded from there.
    :return: the loaded Keras model.
    :raises FileNotFoundError: if the packaged resource cannot be read.
    """
    # One cache slot per path. With a single slot, a second call with a
    # different path on the same thread silently returned the first model.
    models = getattr(_threadlocal, 'generator_models', None)
    if models is None:
        models = _threadlocal.generator_models = {}

    generator_model = models.get(path)
    if generator_model is None:
        import io
        import pkgutil
        import h5py
        from tensorflow.keras.models import load_model

        # Custom objects needed to deserialise the saved model
        custom_objects = {
            'tilted_loss': tilted_loss,
            'InstanceNormalization': tfa.layers.InstanceNormalization
        }

        if path is None:
            # Load tensorflow model from in-memory HDF5 resource
            resource = 'resources/cropsar_px_generator.h5'
            data = pkgutil.get_data('cropsar_px', resource)
            if data is None:
                # get_data returns None when the resource cannot be provided;
                # fail clearly instead of handing h5py an empty buffer.
                raise FileNotFoundError(
                    "cannot read resource %r from package 'cropsar_px'" % resource)

            with h5py.File(io.BytesIO(data), mode='r') as h5:
                generator_model = load_model(h5, custom_objects=custom_objects)
        else:
            generator_model = load_model(path, custom_objects=custom_objects)

        # Store per thread (and per path)
        models[path] = generator_model
        # Keep the legacy single-slot attribute up to date as well.
        _threadlocal.generator_model = generator_model

    return generator_model


def load_datafusion_model():
    """Return the generator model bundled with the ``cropsar_px`` package.

    This delegates to ``load_generator_model()``, which already caches one
    model per thread. No process-wide cache is added here on purpose. An
    ``lru_cache`` would hand the model loaded by the first calling thread to
    every other thread, defeating the per-thread isolation that
    ``load_generator_model`` provides for non-threadsafe Keras models.
    """
    return load_generator_model()


class Scaler():
    """Min-max scaling of model inputs/outputs to and from the [-1, 1] range.

    Each source (NDVI, VV, VH) has a fixed range. Its lower bound maps to -1
    and its upper bound to 1. Values outside the range are not clipped.
    """

    @staticmethod
    def _value_range(source):
        """Return the ``(min, max)`` range used for ``source``.

        This single table is shared by both directions so the scaler and
        unscaler cannot drift apart. Raises KeyError for an unknown source.
        """
        ranges = {
            NDVI: (-0.08, 1),
            VVid: (-20, -2),   # dB (process() converts VV with 10*log10)
            VHid: (-33, -8),   # dB (process() converts VH with 10*log10)
        }
        return ranges[source]

    def minmaxscaler(self, data, source):
        """Scale ``data`` from the range of ``source`` to [-1, 1]."""
        low, high = self._value_range(source)
        return 2 * (data - low) / (high - low) - 1

    def minmaxunscaler(self, data, source):
        """Map ``data`` from [-1, 1] back to the range of ``source``."""
        low, high = self._value_range(source)
        return 0.5 * (data + 1) * (high - low) + low


def process(inarr: xarray.DataArray,
            startdate: str,
            enddate: str,
            gan_window_half: str = GAN_WINDOW_HALF,
            acquisition_steps: str = ACQUISITION_STEPS,
            gan_window_size: int = WINDOW_SIZE,
            gan_steps: str = GAN_STEPS,
            gan_samples: int = GAN_SAMPLES,  # this is 2*gan_window_half/gan_steps+1
            percentiles: bool = False,
            model_file: str = None
            ) -> xarray.DataArray:
    """Predict NDVI with the CropSAR_px generator for every prediction date.

    The input is cut into square spatial windows of ``gan_window_size`` pixels.
    For every date in ``pandas.date_range(startdate, enddate, freq=acquisition_steps)``,
    each window is passed to ``process_window`` together with
    ``gan_window_half`` of data on either side of that date.

    :param inarr: input cube with dimensions 't', 'bands', 'x' and 'y'. The
        'bands' dimension must hold S2id, VHid and VVid. VH and VV are
        converted to dB (10*log10) here, so they are presumably given in
        linear scale. The caller's array is not modified.
    :param startdate: first prediction date (parsed by ``pandas.to_datetime``).
    :param enddate: last prediction date, included if it falls on the date grid.
    :param gan_window_half: half-width of the temporal window per date
        (pandas offset string).
    :param acquisition_steps: spacing between prediction dates.
    :param gan_window_size: side length in pixels of each spatial window.
        NOTE(review): the x/y sizes appear to need to be multiples of this
        value, because process_window reshapes every window to this size.
        Confirm this constraint.
    :param gan_steps: temporal sampling step inside a window.
    :param gan_samples: number of time steps per model input.
    :param percentiles: if True, return the q10/q50/q90 outputs; otherwise
        return only q50.
    :param model_file: Keras model file to use instead of the packaged model.
    :return: float32 DataArray with the same dimension names as ``inarr``.
        't' holds the prediction dates and 'bands' is ["predictions"] or
        ["q10", "q50", "q90"]. x and y carry no coordinate labels.
    """
    # Convert S1 backscatter to dB on a copy. Assigning through .loc writes in
    # place, which would alter the caller's data; re-running process on the
    # same array would then apply log10 twice.
    inarr = inarr.copy()
    inarr.loc[{'bands': VHid}] = 10. * numpy.log10(inarr.sel(bands=VHid))
    inarr.loc[{'bands': VVid}] = 10. * numpy.log10(inarr.sel(bands=VVid))

    # compute windows as ((x_start, x_end), (y_start, y_end)) pixel indices
    xsize, ysize = inarr.sizes['x'], inarr.sizes['y']
    windowlist = [
        ((ix, ix + gan_window_size), (iy, iy + gan_window_size))
        for ix in range(0, xsize, gan_window_size)
        for iy in range(0, ysize, gan_window_size)
    ]

    # init scaler
    sc = Scaler()

    # load the model
    if model_file is None:
        model = load_datafusion_model()
    else:
        model = load_generator_model(model_file)

    # compute prediction (acquisition) dates
    acquisition_dates = pandas.date_range(
        pandas.to_datetime(startdate),
        pandas.to_datetime(enddate),
        freq=acquisition_steps
    )

    # NaN-initialised result buffer. The shape is built by dimension name, so
    # it follows whatever dimension order ``inarr`` uses (previously 't' and
    # 'bands' were assumed to be the first two dimensions).
    band_names = ["q10", "q50", "q90"] if percentiles else ["predictions"]
    sizes = {'t': len(acquisition_dates), 'bands': len(band_names),
             'x': xsize, 'y': ysize}
    shape = [sizes[dim] for dim in inarr.dims]
    predictions = DataArray(
        numpy.full(shape, numpy.nan, dtype=numpy.float32), dims=inarr.dims,
        coords={'bands': band_names, 't': acquisition_dates})

    # run processing
    half_window = pandas.to_timedelta(gan_window_half)  # loop invariant
    for idate in acquisition_dates:
        time_window = slice(idate - half_window, idate + half_window)
        for (x0, x1), (y0, y1) in windowlist:
            data = inarr.isel({
                'x': slice(x0, x1),
                'y': slice(y0, y1),
            }).sel(t=time_window)
            ires = process_window(
                data, model, sc, gan_window_size,
                gan_steps, gan_samples, percentiles=percentiles).astype(numpy.float32)
            # x/y have no coordinate labels in the buffer, so .loc selects
            # them by integer position.
            predictions.loc[{'t': idate,
                             'x': range(x0, x1),
                             'y': range(y0, y1)}] = ires

    return predictions


def process_window(
        inarr: xarray.DataArray,
        model,
        scaler: Scaler,
        windowsize: int,
        gan_steps: str,
        gan_samples: int,
        nodata: float = 0.,
        percentiles: bool = False
):
    """Run the generator on one spatial window around one prediction date.

    :param inarr: window of the input cube with 't' and 'bands' dimensions,
        already cut to the temporal window around the prediction date.
        VH/VV are expected in dB (process() converts them).
    :param model: Keras generator, called as ``model.predict((s1, s2))``.
        NOTE(review): it appears to return one output per quantile
        (q10, q50, q90), judging from the indexing below. Confirm against the model.
    :param scaler: Scaler used to map inputs to [-1, 1] and outputs back.
    :param windowsize: spatial side length; outputs are reshaped to
        ``(windowsize, windowsize)``.
    :param gan_steps: temporal sampling step (pandas offset string).
    :param gan_samples: exact number of time steps the model expects.
    :param nodata: value, in scaled space, that replaces NaN inputs.
    :param percentiles: if True, return all model outputs; otherwise only q50.
    :return: ``(windowsize, windowsize)`` array with the unscaled q50 output,
        or a ``(n_outputs, windowsize, windowsize)`` array of all unscaled
        outputs when ``percentiles`` is set.
    """
    # Regularise the time axis: fill gaps forward, upsample to daily, then
    # sample every gan_steps.
    inarr = inarr.ffill(dim='t').resample(
        t='1D').ffill().resample(t=gan_steps).ffill()

    # older tensorflows expect exact number of samples in every dimension:
    # trim, or pad with empty (NaN) time steps, symmetrically to gan_samples.
    excess = len(inarr.t) - gan_samples
    if excess > 0:
        trimfront = excess // 2
        trimback = excess - trimfront  # >= 1 whenever excess > 0
        inarr = inarr.sel(t=inarr.t[trimfront:-trimback])
    elif excess < 0:
        deficit = -excess
        padfront = deficit // 2
        padback = deficit - padfront
        step = pandas.to_timedelta(gan_steps)
        front = pandas.date_range(
            end=inarr.t.values.min() - step,
            periods=padfront,
            freq=gan_steps).values.astype(inarr.t.dtype)
        back = pandas.date_range(
            start=inarr.t.values.max() + step,
            periods=padback,
            freq=gan_steps).values.astype(inarr.t.dtype)
        inarr = inarr.reindex(
            {'t': numpy.concatenate((front, inarr.t.values, back))})

    # grow it to 6 dimensions: a leading 'd0' axis and a trailing 'd5' axis
    # (presumably batch and channel for the model — confirm)
    inarr = inarr.expand_dims(dim=['d0', 'd5'], axis=[0, 5])

    # select bands
    S2 = inarr.sel(bands=S2id)
    VH = inarr.sel(bands=VHid)
    VV = inarr.sel(bands=VVid)

    # Scale S1 and stack VV, VH along the trailing axis
    VV = scaler.minmaxscaler(VV, VVid)
    VH = scaler.minmaxscaler(VH, VHid)
    s1_backscatter = xarray.concat((VV, VH), dim='d5')

    # Scale NDVI
    s2_ndvi = scaler.minmaxscaler(S2, NDVI)

    # Remove any nan values
    # Passing in numpy arrays because reduces RAM usage
    # (newer tensorflows copy out from xarray into a numpy array)
    # and backwards compatibility goes further back in time
    s2_ndvi = s2_ndvi.fillna(nodata).values
    s1_backscatter = s1_backscatter.fillna(nodata).values

    # Run neural network
    predictions = model.predict((s1_backscatter, s2_ndvi))

    # Unscale back to NDVI. The percentiles branch previously called
    # minmaxscaler, which scaled the already-scaled outputs a second time.
    if percentiles:
        return numpy.asarray([
            scaler.minmaxunscaler(output, NDVI).reshape((windowsize, windowsize))
            for output in predictions
        ])
    median = scaler.minmaxunscaler(predictions[1], NDVI)  # only take Q50
    return median.reshape((windowsize, windowsize))


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """OpenEO UDF entry point: run CropSAR_px on the array held by ``cube``.

    Every entry of ``context`` is forwarded as a keyword argument to
    ``process``, so it must provide at least ``startdate`` and ``enddate``.
    The predictions are returned wrapped in a new XarrayDataCube.
    """
    return XarrayDataCube(process(cube.get_array(), **context))


def load_cropsar_px_udf() -> str:
    """Return the source text of this module (presumably to ship it as a UDF).

    The file is opened read-only. The previous 'r+' (read/write) mode raised
    PermissionError when the module lives in a read-only location, although
    nothing is ever written. It is decoded as utf-8, matching the coding
    declaration at the top of the file, rather than the platform default.
    """
    import os
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()
