# -*- coding: utf-8 -*-
# Uncomment the import only for coding support
from openeo_udf.api.datacube import DataCube
from typing import Dict
import functools
import xarray
import numpy
from xarray.core.dataarray import DataArray
import pandas

from cropsar_px.model import load_generator_model

NDVI = 'ndvi'
S2id = 'S2ndvi'
VHid = 'VH'
VVid = 'VV'


@functools.lru_cache(maxsize=25)
def load_datafusion_model():
    return load_generator_model()


class Scaler():
    def minmaxscaler(self, data, source):
        ranges = {}
        ranges[NDVI] = [-0.08, 1]
        ranges[VVid] = [-20, -2]
        ranges[VHid] = [-33, -8]
        # Scale between -1 and 1
        datarescaled = 2 * (data - ranges[source][0]) / (ranges[source][1] - ranges[source][0]) - 1
        return datarescaled

    def minmaxunscaler(self, data, source):
        ranges = {}
        ranges[NDVI] = [-0.08, 1]
        ranges[VVid] = [-20, -2]
        ranges[VHid] = [-33, -8]
        # Unscale
        dataunscaled = 0.5 * (data + 1) * (ranges[source][1] - ranges[source][0]) + ranges[source][0]
        return dataunscaled


def process(
        inarr: xarray.DataArray,
        startdate: str,
        enddate: str,
        gan_window_half: str = "90D",
        acquisition_steps: str = "5D",
        gan_window_size: int = 256,
        gan_steps: str = "5D",
        gan_samples: int = 37  # this is 2*gan_window_half/gan_steps+1
) -> xarray.DataArray:
    # rescale
    inarr.loc[{'bands': VHid}] = 10. * numpy.log10(inarr.sel(bands=VHid))
    inarr.loc[{'bands': VVid}] = 10. * numpy.log10(inarr.sel(bands=VVid))

    # compute windows
    xsize, ysize = inarr.x.shape[0], inarr.y.shape[0]
    windowlist = [((0, gan_window_size), (0, gan_window_size))]

    # init scaler
    sc = Scaler()

    # load the model
    model = load_datafusion_model()

    # compute acquisition dates
    acquisition_dates = pandas.date_range(
        pandas.to_datetime(startdate),
        pandas.to_datetime(enddate),
        freq=acquisition_steps
    )

    # result buffer
    shape = [len(acquisition_dates), 1, 1, 1]
    shape[inarr.dims.index('x')] = xsize
    shape[inarr.dims.index('y')] = ysize
    predictions = DataArray(numpy.full(shape, numpy.nan, dtype=numpy.float32), dims=inarr.dims,
                            coords={'bands': ['predictions'], 't': acquisition_dates})

    # run processing
    for idate in acquisition_dates:
        for iwin in windowlist:
            data = inarr.isel({
                'x': slice(iwin[0][0], iwin[0][1]),
                'y': slice(iwin[1][0], iwin[1][1]),

            }).sel(t=slice(idate - pandas.to_timedelta(gan_window_half), idate + pandas.to_timedelta(gan_window_half)))
            ires = process_window(data, model, sc, gan_window_size, gan_steps, gan_samples).astype(numpy.float32)
            predictions.loc[{'t': idate, 'x': range(iwin[0][0], iwin[0][1]), 'y': range(iwin[1][0], iwin[1][1])}] = ires

    return predictions


def process_window(
        inarr: xarray.DataArray,
        model,
        scaler: Scaler,
        windowsize: int,
        gan_steps: str,
        gan_samples: int,
        nodata: float = 0.
):
    inarr = inarr.ffill(dim='t').resample(t='1D').ffill().resample(t=gan_steps).ffill()

    # older tensorflows expect exact number of samples in every dimension
    if len(inarr.t) > gan_samples:
        trimfront = int((len(inarr.t) - gan_samples) / 2)
        trimback = trimfront + (0 if (len(inarr.t) - gan_samples) % 2 == 0 else 1)
        inarr = inarr.sel(t=inarr.t[trimfront:-trimback])
    if len(inarr.t) < gan_samples:
        trimfront = int((gan_samples - len(inarr.t)) / 2)
        trimback = trimfront + (0 if (gan_samples - len(inarr.t)) % 2 == 0 else 1)
        front = pandas.date_range(end=inarr.t.values.min() - pandas.to_timedelta(gan_steps), periods=trimfront,
                                  freq=gan_steps).values.astype(inarr.t.dtype)
        back = pandas.date_range(start=inarr.t.values.max() + pandas.to_timedelta(gan_steps), periods=trimback,
                                 freq=gan_steps).values.astype(inarr.t.dtype)
        inarr = inarr.reindex({'t': numpy.concatenate((front, inarr.t.values, back))})

    # grow it to 5 dimensions
    inarr = inarr.expand_dims(dim=['d0', 'd5'], axis=[0, 5])

    # select bands
    S2 = inarr.sel(bands=S2id)
    VH = inarr.sel(bands=VHid)
    VV = inarr.sel(bands=VVid)

    # Scale S1
    VV = scaler.minmaxscaler(VV, VVid)
    VH = scaler.minmaxscaler(VH, VHid)

    # Concatenate s1 data
    s1_backscatter = xarray.concat((VV, VH), dim='d5')

    # Scale NDVI
    s2_ndvi = scaler.minmaxscaler(S2, NDVI)

    # Remove any nan values
    # Passing in numpy arrays because reduces RAM usage (newer tensorflows copy out from xarray into a numpy array) and backwards compatibility goes further back in time
    s2_ndvi = s2_ndvi.fillna(nodata).values
    s1_backscatter = s1_backscatter.fillna(nodata).values

    # Run neural network
    predictions = model.predict((s1_backscatter, s2_ndvi))

    # Unscale
    predictions = scaler.minmaxunscaler(predictions, NDVI)

    return predictions.reshape((windowsize, windowsize))


def apply_datacube(cube: DataCube, context: Dict) -> DataCube:
    # extract xarray
    inarr = cube.get_array()
    # get predictions
    predictions = process(inarr, **context)
    # wrap predictions in an OpenEO datacube
    return DataCube(predictions)


def load_cropsar_px_udf()->str:
    import os
    with open(os.path.realpath(__file__), 'r+') as f:
        return f.read()
