from loguru import logger
import numpy as np
import pickle
from typing import Dict
import skimage

from satio.features import Features

from cropclass.models import CroptypeModel
from cropclass.utils.scalers import minmaxscaler


def get_default_filtersettings():
    '''
    Build the default majority-filter settings.

    :return: a new dict with `kernelsize` set to 0 (Classifier.predict
        only filters when this is larger than 0) and `conf_threshold`
        set to 0.9.
    '''
    defaults = dict(kernelsize=0, conf_threshold=0.9)
    return defaults


def mask(data, mask, valid=1, maskedvalue=255):
    '''
    Overwrite, in place, every element of `data` whose mask value is
    not `valid`.

    :param data: array to be masked; it is modified in place.
    :param mask: array compared element-wise against `valid`.
    :param valid: mask value marking the elements to keep.
    :param maskedvalue: value written to all other elements.
    :return: the same `data` object that was passed in.
    '''
    to_discard = mask != valid
    data[to_discard] = maskedvalue
    return data


def majority_filter(prediction, kernel_size,
                    confidence=None, conf_thr=None,
                    no_data_value=255):
    '''
    Smooth a classified raster with a moving-window majority vote.

    The majority is computed with `skimage.filters.rank.majority`
    over a square `kernel_size` x `kernel_size` window. Pixels equal
    to `no_data_value` are excluded from the vote and are never
    modified.

    :param prediction: 2D integer array with the predicted labels.
        It is handed to the skimage rank filter as-is, so it needs a
        dtype that filter supports (uint8 or uint16).
    :param kernel_size: determines the size of the spatial window
        that is considered during the filtering operation. Must
        be an odd number.
    :param confidence: (optional) pixel-based confidence scores
        of the prediction, same shape as `prediction`.
    :param conf_thr: (optional) pixels having confidence lower
        than this threshold are replaced by the window majority;
        all other pixels keep their original label. When either
        `confidence` or `conf_thr` is not given, every valid pixel
        is replaced by the window majority.
    :param no_data_value: (optional) No data value in prediction,
        which will be ignored during the entire process.
    :return: array with the filtered prediction.
    :raises ValueError: if `kernel_size` is even.

    NOTE(review): low-confidence pixels still take part in the vote;
    only no data pixels are excluded from it.
    '''

    if kernel_size % 2 == 0:
        raise ValueError('Kernel size for majority filtering should be'
                         ' an odd number!')

    # Explicit submodule import: a bare `import skimage` does not
    # guarantee that `skimage.filters.rank` has been loaded.
    from skimage.filters import rank

    valid = prediction != no_data_value

    # Square window used for the vote
    k = np.ones((kernel_size, kernel_size), dtype=int)

    filteredprediction = rank.majority(prediction, k, mask=valid)

    # determine which cells need to be updated
    if confidence is None or conf_thr is None:
        # no confidence information: plain majority filter
        update_mask = valid
    else:
        # only low-confidence, valid pixels take the filtered label
        update_mask = (confidence < conf_thr) & valid

    # produce final result
    return np.where(update_mask, filteredprediction, prediction)


class Classifier(object):
    '''
    Run a CroptypeModel on a stack of features and post-process the
    result (optional label decoding, masking and majority filtering).
    '''

    def __init__(self, croptypemodel: CroptypeModel,
                 filtersettings: Dict = None,
                 maskdata=None, encoder=None):
        '''
        :param croptypemodel: the CroptypeModel used for prediction.
        :param filtersettings: (optional) dict with the keys
            `kernelsize` and `conf_threshold` controlling the majority
            filter; falls back to get_default_filtersettings().
        :param maskdata: (optional) array; pixels where it differs
            from 1 are set to the nodata value after prediction.
        :param encoder: (optional) path to a pickled object exposing
            `inverse_transform`, used to decode the raw predictions.
        :raises ValueError: if the model is not a CroptypeModel or
            has no model type.
        '''

        self.model = croptypemodel

        # Validate before reading any model attribute, so that a wrong
        # argument raises the intended ValueError rather than an
        # AttributeError on the attribute look-ups below.
        self._check_model()

        self.modeltype = croptypemodel.modeltype
        self.modelclass = croptypemodel.modelclass
        self.feature_names = croptypemodel.feature_names
        self.requires_scaling = croptypemodel.requires_scaling
        self.filtersettings = filtersettings or get_default_filtersettings()
        self.maskdata = maskdata
        self.encoder = encoder

    def _check_model(self):
        '''
        Make sure the associated model is a CroptypeModel with a
        known model type.
        '''
        if not isinstance(self.model, CroptypeModel):
            raise ValueError(('Associated model should be '
                              'instance of CroptypeModel '
                              f'but got: `{type(self.model)}`'))
        if self.model.modeltype is None:
            raise ValueError(('Associated model is of unknown '
                              'type. Should be `pixel` or `patch`.'))

    def predict(self, features: Features, fillnodata=0,
                nodatavalue=0):
        '''
        Predict labels and confidences for the given features.

        :param features: Features holding at least the features the
            model was trained on.
        :param fillnodata: value used to replace NaN inputs when the
            model has `impute_missing` set; None skips the step.
        :param nodatavalue: value written to masked pixels and treated
            as no data by the majority filter.
        :return: tuple (prediction as uint16, confidence as uint8,
            the latter multiplied by 100).
        :raises ValueError: if the model type is neither `pixel` nor
            `patch`.
        '''

        # Select the features on which the model was trained
        inputfeatures = features.select(self.feature_names)

        # Scale the input features if the model requires it
        if self.requires_scaling:
            inputfeatures = self._scale(inputfeatures)

        # Get rid of any remaining NaN values
        if fillnodata is not None and self.model.impute_missing:
            inputfeatures.data[np.isnan(inputfeatures.data)] = fillnodata

        # do prediction
        if self.modeltype == 'pixel':
            prediction, confidence = self._predict_pixel_based(inputfeatures)
        elif self.modeltype == 'patch':
            prediction, confidence = self._predict_patch_based(inputfeatures)
        else:
            raise ValueError(('Unknown modeltype: '
                              f'{self.modeltype}'))

        # Set confidence of 0 to 0.01 to not interfere with nodata
        confidence[confidence == 0] = 0.01

        # If we have an encoder, decode predictions
        if self.encoder is not None:
            logger.info('Decoding predictions ...')
            # NOTE(security): unpickling can execute arbitrary code;
            # the encoder file must come from a trusted source.
            with open(self.encoder, 'rb') as encoder_file:
                encoder = pickle.load(encoder_file)
            origshape = prediction.shape
            prediction = encoder.inverse_transform(
                prediction.ravel()).reshape(origshape)

        # mask prediction and confidence if necessary
        if self.maskdata is not None:
            prediction = mask(prediction, self.maskdata,
                              maskedvalue=nodatavalue)
            confidence = mask(confidence, self.maskdata,
                              maskedvalue=nodatavalue)

        # perform majority filtering
        if self.filtersettings['kernelsize'] > 0:
            # apply majority filter on prediction
            prediction = majority_filter(
                prediction.astype(np.uint16),
                self.filtersettings['kernelsize'],
                confidence=confidence,
                conf_thr=self.filtersettings['conf_threshold'],
                no_data_value=nodatavalue
            )

        # Rescale confidence to percentages, leaving nodata untouched.
        # The uint8 cast below truncates (e.g. 99.9 becomes 99).
        confidence[confidence != nodatavalue] *= 100

        return prediction.astype(np.uint16), confidence.astype(np.uint8)

    def _predict_pixel_based(self, features):
        '''
        Predict every pixel independently.

        The feature stack is flattened from (features, x, y) to
        (x * y, features) before it is handed to the model, and the
        outputs are reshaped back to (x, y).
        '''
        logger.debug('Start pixel-based prediction ...')
        orig_shape = features.data.shape[1:3]
        inputs = features.data.transpose(
            (1, 2, 0)).reshape((-1, len(self.feature_names)))
        prediction, confidence = self.model.predict(inputs,
                                                    orig_shape=orig_shape)
        prediction = prediction.reshape(orig_shape)
        confidence = confidence.reshape(orig_shape)

        return prediction, confidence

    def _predict_patch_based(self, features):
        '''
        First implementation of patch-based classifier.
        Should be improved

        The block is processed in square windows whose size is read
        from the model parameters (`windowsize`). Windows that would
        run past the block edge are shifted back inside it, so they
        overlap the previous window.

        :raises ValueError: if a spatial dimension of the block is
            smaller than the window size.
        '''
        logger.debug('Start patch-based prediction ...')
        logger.info("Running classification ...")
        windowsize = self.model.parameters['windowsize']
        xdim = features.data.shape[1]
        ydim = features.data.shape[2]

        if xdim == ydim == windowsize:
            # Features are already in correct spatial shape
            # we can directly make prediction
            prediction, confidence = self.model.predict(
                features.data.transpose((1, 2, 0)).reshape(
                    (1,
                     windowsize * windowsize,
                     -1)))
            prediction = prediction.squeeze().reshape((windowsize, windowsize))
            confidence = confidence.squeeze().reshape((windowsize, windowsize))
        else:
            # A full window has to fit in the block, otherwise the
            # shifted window start below would become negative.
            if xdim < windowsize or ydim < windowsize:
                raise ValueError(
                    'Patch-based prediction requires both spatial '
                    'dimensions to be at least the model window size '
                    f'({windowsize}), but got ({xdim}, {ydim}).')

            # Slide through the block with overlap and make predictions
            prediction = np.empty((xdim, ydim))
            confidence = np.empty((xdim, ydim))

            for xStart in range(0, xdim, windowsize):
                # At the end of the block, shift the window back so we
                # always have a full subtile; the resulting overlap
                # with the previous subtile is not an issue.
                x0 = min(xStart, xdim - windowsize)
                x1 = x0 + windowsize
                for yStart in range(0, ydim, windowsize):
                    y0 = min(yStart, ydim - windowsize)
                    y1 = y0 + windowsize

                    features_patch = features.data[:, x0:x1, y0:y1]
                    patchprediction, patchconfidence = self.model.predict(
                        features_patch.transpose((1, 2, 0)).reshape(
                            (1, windowsize * windowsize, -1)))

                    patchprediction = patchprediction.squeeze().reshape(
                        (windowsize, windowsize))
                    patchconfidence = patchconfidence.squeeze().reshape(
                        (windowsize, windowsize))

                    prediction[x0:x1, y0:y1] = patchprediction
                    confidence[x0:x1, y0:y1] = patchconfidence

        return prediction, confidence

    def _scale(self, features):
        '''
        Scale each feature separately with `minmaxscaler` and return
        the scaled features combined into a new Features object.
        '''

        scaled_features = []
        logger.info('Scaling input features ...')

        # Scale the data one feature at a time; the scaler receives
        # the feature name together with the data.
        for ft in features.names:
            ftdata = features.select([ft]).data
            ftscaled = minmaxscaler(ftdata,
                                    ft_name=ft,
                                    clamp=(-0.1, 1.1),
                                    nodata=0)
            scaled_features.append(Features(data=ftscaled,
                                            names=[ft]))

        logger.info('Scaling done.')
        return Features.from_features(*scaled_features)
