#!/usr/bin/env python3

import abc
import configparser
import io
import logging
import os
import pkgutil
import tempfile
import threading
from typing import List

import pandas as pd

from cropsar.training import train_RNNmodel, train_RNNfullmodel
from cropsar.preprocessing.utils import ts_to_dict, minmaxunscaler
from cropsar.preprocessing.prepare_datastack_rnn import getFullDataStacks
from cropsar.preprocessing.prepare_datastack_rnnfull import getFullDataStack_RNNfull


log = logging.getLogger(__name__)


#-----------------------------------------------------------------------------
# High-level interface that every CropSAR implementation should support
#-----------------------------------------------------------------------------


class CropSARModel(object, metaclass=abc.ABCMeta):
    '''
    Abstract base class for CropSAR model implementations.

    Subclasses must override both abstract methods below; because the class
    uses ``abc.ABCMeta``, instantiating a subclass that does not override
    them raises ``TypeError``.
    '''

    @abc.abstractmethod
    def get_margin_in_days(self):
        '''
        Returns the recommended number of additional days that should be
        included at the start and end of the CropSAR input time series to
        get a reliable result.

        :return: number of days as an integer
        '''
        raise NotImplementedError('Should have implemented this')

    @abc.abstractmethod
    def get_timeseries(self, S1VV, S1VH, S1IncidenceAngle, S2fapar,
                       identifier, startDate, endDate,
                       S2layername='FAPAR', S1var='sigma', useAfter=True):
        '''
        Function that takes S1 and S2 input data and returns the corresponding
        CropSAR time series

        :param S1VV: pandas daily time series of S1 VV backscatter in dB
        :param S1VH: pandas daily time series of S1 VH backscatter in dB
        :param S1IncidenceAngle: pandas daily time series of S1 incidence angle in degrees
        :param S2fapar: pandas daily time series of S2 fapar
        :param identifier: string that identifies the object to which the time series belong (e.g. fieldID)
        :param startDate: string representing desired start date of returned time series (format: yyyy-mm-dd)
        :param endDate: string representing desired end date of returned time series (format: yyyy-mm-dd)
        :param S2layername: string representing the S2 layername (FAPAR, FAPAR_8BAND, or FCOVER), currently only used in the scalers
        :param S1var: S1 variable used (gamma, sigma)
        :param useAfter: use data available after endDate
        :return: three pandas time series CropSAR predictions, corresponding to q10 (lower confidence), q50 (actual prediction) and q90 (upper confidence)
        '''
        raise NotImplementedError('Should have implemented this')


#-----------------------------------------------------------------------------
# Utility functions
#-----------------------------------------------------------------------------


def _load_model_weights(model, weights):
    '''
    Loads weights into ``model``, resolving ``weights`` in this order:

    1. as a path to an existing file (absolute, or relative to the CWD);
    2. as a path relative to the directory containing this module;
    3. as a resource inside the ``cropsar`` package, read with
       ``pkgutil.get_data`` and written to a temporary ``.h5`` file that is
       removed again once loading has finished (successfully or not).

    :param model: object exposing ``load_weights(path)``
    :param weights: file path or package-relative resource name of the weights
    :raises FileNotFoundError: if the package resource cannot be read
    '''

    module_dir = os.path.dirname(os.path.realpath(__file__))

    # Regular files first: the path as given, then relative to this module.
    for candidate in (weights, os.path.join(module_dir, weights)):
        if os.path.isfile(candidate):
            log.info('Loading model weights [{}]'.format(candidate))
            model.load_weights(candidate)
            return

    # Not a regular file: assume it points to a file inside the package
    # (e.g. a zipped install) and extract it to a temp file before loading.
    data = pkgutil.get_data('cropsar', weights)

    if data is None:
        # get_data returns None when the package's loader cannot provide
        # resources; fail with a clear message instead of f.write(None).
        raise FileNotFoundError(
            'Unable to read model weights [{}] from the cropsar package'.format(weights))

    fd, filename = tempfile.mkstemp(suffix='.h5', prefix='weights')
    try:
        log.info('Extracting weights [{}] as [{}]'.format(weights, filename))

        with os.fdopen(fd, 'wb') as f:
            f.write(data)

        # Load only after the file has been closed (flushed) above.
        log.info('Loading model weights [{}]'.format(filename))
        model.load_weights(filename)

    finally:
        log.info('Removing [{}]'.format(filename))

        os.remove(filename)


#-----------------------------------------------------------------------------
# Recurrent Neural Network implementation of CropSAR
#-----------------------------------------------------------------------------


class RNNModel(CropSARModel):
    '''
    CropSAR implementation based on the recurrent neural network built by
    ``train_RNNmodel.create_model`` and configured by the packaged
    ``RNNModelSettings.ini`` resource.
    '''

    # Package-relative settings resource used both when building the model
    # and when preparing the input data stack.
    _SETTINGS_RESOURCE = 'resources/models/RNNModelSettings.ini'

    def __init__(self, weights=None):
        '''
        :param weights: optional path or package resource of the model
                        weights; the packaged default is used when None
        '''
        self._model = self._get_model(weights)

    def get_margin_in_days(self):
        return 150

    def get_timeseries(self, S1VV, S1VH, S1IncidenceAngle, S2fapar,
                       identifier, startDate,
                       endDate, S2layername='FAPAR', S1var='sigma', useAfter=True):
        '''
        Evaluates the RNN model. Parameters and return values are as described
        for ``CropSARModel.get_timeseries``; the returned series are indexed by
        the index produced by ``getFullDataStacks``.
        '''

        log.info('Assuming backscatter type: {}'.format(S1var))
        log.info('Evaluating RNN model for field {}, {} to {}...'.format(identifier, startDate, endDate))

        # Convert input time series to dictionary for internal use
        inputData = ts_to_dict(S1VV, S1VH, S1IncidenceAngle,
                               S2fapar, identifier, startDate, endDate, S2layername=S2layername,
                               margin_in_days=self.get_margin_in_days())

        # Read data stack settings. ConfigParser values are always strings,
        # so "no smoothing" is expressed as a missing, empty or "none" value.
        stackSettings = self._read_settings()['DataStackSettings']
        S1smoothing = self._parse_optional_int(stackSettings.get('s1smoothing'))
        outputResolution = int(stackSettings['outputResolution'])

        # Calculate the DataStack that will be input to the CropSAR model.
        # NOTE(review): the positional 150 presumably is the margin in days
        # (it equals get_margin_in_days()); confirm against getFullDataStacks.
        inputDataStack, _, index = getFullDataStacks(identifier, inputData, startDate, endDate, 150,
                                                     outputResolution=outputResolution,
                                                     S1smoothing=S1smoothing,
                                                     useAfter=useAfter,
                                                     S2layername=S2layername,
                                                     S1var=S1var)

        # Make the CropSAR prediction: channels 0-1 and channel 2 of the
        # stack are fed to the model as two separate inputs.
        predictions = self._model.predict([inputDataStack[:, :, 0:2], inputDataStack[:, :, 2:3]])

        # Outputs 0, 1 and 2 are the q10, q50 and q90 predictions; unscale each.
        q10, q50, q90 = (self._to_series(p, index, S2layername) for p in predictions[:3])

        log.info('Done evaluating model')

        return q10, q50, q90

    @classmethod
    def _read_settings(cls):
        '''Parses the packaged model settings into a ConfigParser.'''
        data = pkgutil.get_data("cropsar", cls._SETTINGS_RESOURCE)
        if data is None:
            raise FileNotFoundError(
                'Unable to read [{}] from the cropsar package'.format(cls._SETTINGS_RESOURCE))

        settings = configparser.ConfigParser()
        settings.read_string(data.decode('utf-8'))
        return settings

    @staticmethod
    def _parse_optional_int(value):
        '''Converts a config string to int; None, blank or "none" map to None.'''
        if value is None:
            return None
        value = value.strip()
        if not value or value.lower() == 'none':
            return None
        return int(value)

    @staticmethod
    def _to_series(values, index, S2layername):
        '''Reverts the scaling of one prediction array (clipped) and wraps it in a Series.'''
        return pd.Series(index=index,
                         data=minmaxunscaler(values.ravel(),
                                             's2_' + S2layername.lower(), clip=True))

    def _get_model(self, weights=None):
        '''
        Builds the RNN from the packaged settings and loads its weights.

        :param weights: weights location passed to ``_load_model_weights``;
                        defaults to the packaged RNNModelWeights.h5
        :return: the model object returned by ``train_RNNmodel.create_model``
        '''

        log.info('Loading RNN model...')

        modelParameters = self._read_settings()['ModelParameters']

        model = train_RNNmodel.create_model(nodes=int(modelParameters['nodes']),
                                            dropoutFraction=float(modelParameters['dropoutFraction']))

        # If no weights specified, use the default file from resources
        if weights is None:
            weights = 'resources/models/RNNModelWeights.h5'

        _load_model_weights(model, weights)

        # TODO: Since tensorflow 2.2 the internal _make_predict_function
        #       function is gone, so we'll have to solve this some other
        #       way, there are some useful suggestions on these pages:
        #
        #        https://github.com/keras-team/keras/pull/13116
        #        https://github.com/keras-team/keras/issues/5640
        #
        #       specifically this seems a good candidate solution:
        #
        #        https://github.com/keras-team/keras/issues/5640#issuecomment-519134602
        #

        make_predict_function = getattr(model, '_make_predict_function', None)
        if make_predict_function is not None:
            make_predict_function()

        log.info('Done loading model')

        return model

#-----------------------------------------------------------------------------
# Full Recurrent neural network implementation of CropSAR
#-----------------------------------------------------------------------------

class RNNfullModel(CropSARModel):
    '''
    CropSAR implementation based on the "full" recurrent neural network built
    by ``train_RNNfullmodel.create_model`` and configured by the packaged
    ``RNNfullModelSettings.ini`` resource.
    '''

    # Package-relative settings resource used when building the model.
    _SETTINGS_RESOURCE = 'resources/models/RNNfullModelSettings.ini'

    def __init__(self, weights=None):
        '''
        :param weights: optional path or package resource of the model
                        weights; the packaged default is used when None
        '''
        self._model = self._get_model(weights)
        self._seed = None

    def _set_takout_seed(self, seed):
        '''
        Stores ``seed`` on the instance.

        NOTE(review): ``self._seed`` is not read by any code in this class;
        the method name is kept unchanged for compatibility with callers.
        '''
        self._seed = seed

    def get_margin_in_days(self):
        return 90

    def get_timeseries(self, S1VV, S1VH, S1IncidenceAngle, S2fapar,
                       identifier, startDate, endDate,
                       S2layername='FAPAR', S1var='sigma', useAfter=True):
        '''
        Evaluates the RNNfull model. Parameters and return values are as
        described for ``CropSARModel.get_timeseries``; the returned series are
        reindexed to a daily index from ``startDate`` to ``endDate`` (dates
        without a prediction become NaN). The outputs are post-processed so
        that q10 <= q50 <= q90 holds element-wise.
        '''

        log.info('Assuming backscatter type: {}'.format(S1var))
        log.info('Evaluating RNNfull model for field {}, {} to {}...'.format(identifier, startDate, endDate))

        # Convert input time series to dictionary for internal use
        inputData = ts_to_dict(S1VV, S1VH, S1IncidenceAngle,
                               S2fapar, identifier, startDate, endDate, S2layername=S2layername,
                               margin_in_days=self.get_margin_in_days())

        # Predictions are labelled with the index of the S1 VV input series
        # (pd.Series requires the lengths to match).
        origindex = inputData['S1']['VV'].index

        # Calculate the DataStack that will be input to the CropSAR model
        inputDataStack, _ = getFullDataStack_RNNfull(identifier, inputData,
                                                     S2layername=S2layername,
                                                     useAfter=useAfter,
                                                     endDate=endDate,
                                                     S1var=S1var)

        # Make the CropSAR prediction: channels 0-1 and channel 2 of the
        # stack are fed to the model as two separate inputs.
        predictions = self._model.predict([inputDataStack[:, :, 0:2], inputDataStack[:, :, 2:3]])
        lower, median, upper = predictions[0], predictions[1], predictions[2]

        # Make sure that q10 <= q50 <= q90 by clamping crossing quantiles to
        # the median (in place, as before).
        lowerTooHigh = lower > median
        lower[lowerTooHigh] = median[lowerTooHigh]
        upperTooLow = upper < median
        upper[upperTooLow] = median[upperTooLow]

        # Construct output datetimeindex
        index = pd.date_range(start=startDate, end=endDate, freq='1D')

        q10, q50, q90 = (self._to_series(p, origindex, S2layername).reindex(index)
                         for p in (lower, median, upper))

        log.info('Done evaluating model')

        return q10, q50, q90

    @classmethod
    def _read_settings(cls):
        '''Parses the packaged model settings into a ConfigParser.'''
        data = pkgutil.get_data("cropsar", cls._SETTINGS_RESOURCE)
        if data is None:
            raise FileNotFoundError(
                'Unable to read [{}] from the cropsar package'.format(cls._SETTINGS_RESOURCE))

        settings = configparser.ConfigParser()
        settings.read_string(data.decode('utf-8'))
        return settings

    @staticmethod
    def _to_series(values, index, S2layername):
        '''Reverts the scaling of one prediction array (clipped) and wraps it in a Series.'''
        return pd.Series(index=index,
                         data=minmaxunscaler(values.ravel(),
                                             's2_' + S2layername.lower(), clip=True))

    def _get_model(self, weights=None):
        '''
        Builds the RNNfull network from the packaged settings and loads its weights.

        :param weights: weights location passed to ``_load_model_weights``;
                        defaults to the packaged RNNfullModelWeights.h5
        :return: the model object returned by ``train_RNNfullmodel.create_model``
        '''

        log.info('Loading RNNfull model...')

        modelParameters = self._read_settings()['ModelParameters']

        model = train_RNNfullmodel.create_model(nodes=int(modelParameters['nodes']),
                                                dropoutFraction=float(modelParameters['dropoutFraction']))

        # If no weights specified, use the default file from resources
        if weights is None:
            weights = 'resources/models/RNNfullModelWeights.h5'

        _load_model_weights(model, weights)

        # TODO: Since tensorflow 2.2 the internal _make_predict_function
        #       function is gone, so we'll have to solve this some other
        #       way, there are some useful suggestions on these pages:
        #
        #        https://github.com/keras-team/keras/pull/13116
        #        https://github.com/keras-team/keras/issues/5640
        #
        #       specifically this seems a good candidate solution:
        #
        #        https://github.com/keras-team/keras/issues/5640#issuecomment-519134602
        #

        make_predict_function = getattr(model, '_make_predict_function', None)
        if make_predict_function is not None:
            make_predict_function()

        log.info('Done loading model')

        return model


#-----------------------------------------------------------------------------
# Generic model construction functions
#-----------------------------------------------------------------------------


def get_available_model_types() -> List[str]:
    '''
    Returns a list of types that can be passed to `get_model()`.

    A new list is returned on every call, so callers may modify it freely.

    :return: the list of types, default type ('RNNfull') first
    '''

    return ['RNNfull', 'RNN']


# Process-wide cache of constructed models, keyed by model type. The lock is
# held while a model is being built, so concurrent callers asking for the same
# type wait for the first construction instead of building it twice.
_get_model_lock = threading.Lock()
_get_model_cache = {}


def get_model(type='RNNfull') -> CropSARModel:
    '''
    Returns the (cached) CropSAR model of the given type, constructing it with
    default weights on first request.

    :param type: model type name, one of `get_available_model_types()`
    :return: the shared model instance for that type
    :raises ValueError: if `type` is not a known model type
    '''

    with _get_model_lock:
        model = _get_model_cache.get(type)

        if model is None:
            constructor = {'RNN': RNNModel, 'RNNfull': RNNfullModel}.get(type)
            if constructor is None:
                raise ValueError('Invalid CropSAR model type: {}'.format(type))
            model = constructor()
            _get_model_cache[type] = model

        return model


# XXX: exists for backwards-compatibility, maybe remove later...
def get_model_margin_in_days(model):
    '''
    Backwards-compatible alias for ``model.get_margin_in_days()``.

    :param model: a CropSAR model object
    :return: whatever the model's ``get_margin_in_days()`` returns
    '''
    margin = model.get_margin_in_days()
    return margin


# XXX: exists for backwards-compatibility, maybe remove later...
def get_ts_cropsar(model, *args, **kwargs):
    '''
    Backwards-compatible alias for ``model.get_timeseries(...)``; all
    positional and keyword arguments are forwarded unchanged.

    :param model: a CropSAR model object
    :return: whatever the model's ``get_timeseries`` returns
    '''
    evaluate = model.get_timeseries
    return evaluate(*args, **kwargs)
