import abc
from typing import List
from pathlib import Path
import shutil
import json
import tempfile
from retry import retry
import os
import numpy as np
import pandas as pd
from sklearn.metrics import (accuracy_score, f1_score,
                             precision_score,
                             recall_score)

from loguru import logger
import tensorflow as tf
from tensorflow.keras.layers import (Dropout, BatchNormalization,
                                     Dense, concatenate, LSTM, Flatten)
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.losses import binary_crossentropy


from cropclass.utils import (get_sensor_config,
                             get_sensor_config_timeseries)

# Model class names that `CroptypeModel.from_config` accepts in
# `settings.modelclass` of a config file.
SUPPORTED_MODELS = ['CroptypePixelLSTM', 'CroptypeCatBoostModel']

# Retry policy for (remote) model/config loading, passed to the `retry`
# decorator: number of attempts, delay between attempts and the backoff
# multiplier applied to that delay.
TRIES, BACKOFF, DELAY = 5, 10, 5


def DiceBCELoss(y_true, y_pred, smooth=1):
    '''
    Combined binary cross-entropy and soft Dice loss.

    dice = 1 - (2 * sum(|y_true * y_pred|) + smooth)
               / (sum(y_true**2) + sum(y_pred**2) + smooth)

    with the sums taken over the last axis.

    Args:
        y_true: target tensor; cast to the dtype of `y_pred`.
        y_pred: predicted probabilities.
        smooth: additive smoothing term that keeps the Dice ratio
            defined when both tensors are all zeros.

    Returns:
        Tensor with `binary_crossentropy(y_true, y_pred) + dice`.
    '''
    # Integer label tensors would make `y_true * y_pred` fail with a dtype
    # mismatch; the built-in Keras losses perform the same cast.
    y_true = K.cast(y_true, y_pred.dtype)

    bce = binary_crossentropy(y_true, y_pred)
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    denominator = (K.sum(K.square(y_true), axis=-1) +
                   K.sum(K.square(y_pred), axis=-1) + smooth)
    dice_loss = 1 - (2. * intersection + smooth) / denominator

    return bce + dice_loss


class CroptypeModel(object, metaclass=abc.ABCMeta):
    '''
    Abstract base class for WorldCereal model implementations
    '''

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def predict(self):
        pass

    @abc.abstractmethod
    def save(self):
        pass

    # `abc.abstractclassmethod` is deprecated; stacking `classmethod` on
    # top of `abstractmethod` is the documented equivalent.
    @classmethod
    @abc.abstractmethod
    def load(cls):
        pass

    @abc.abstractmethod
    def summary(self):
        pass

    @abc.abstractmethod
    def evaluate(self):
        pass

    @staticmethod
    @retry(exceptions=TimeoutError, tries=TRIES, delay=DELAY,
           backoff=BACKOFF, logger=logger)
    def from_config(configfile):
        '''
        Restore a model from a JSON config file.

        The config must provide `settings.modelclass` (one of
        SUPPORTED_MODELS), `paths.modelfile`, `feature_names` and
        `parameters`. `paths.basedir` is optional; a new temporary
        directory is used when it is absent.

        Args:
            configfile: local path or https URL whose name ends in `json`.

        Returns:
            An instance of the configured model class wrapping the model
            loaded from `paths.modelfile`.

        Raises:
            ValueError: if the file is not json, the model class is not
                supported, or no model file is configured.
        '''
        logger.info(('Loading Croptype model '
                     f'from config: {configfile} ...'))
        configfile = str(configfile)
        if not configfile.endswith('json'):
            raise ValueError('Configfile should be json.')
        if configfile.startswith("https"):
            from urllib.request import urlopen

            # Close the HTTP response once its body has been read
            with urlopen(configfile) as response:
                config = json.loads(response.read())
        else:
            with open(configfile) as f:
                config = json.load(f)

        modelclass = config['settings']['modelclass']
        if modelclass not in SUPPORTED_MODELS:
            raise ValueError((f'Model class `{modelclass}` not known. '
                              f'Should be one of: {SUPPORTED_MODELS}'))
        # Resolve the whitelisted class name from this module's namespace
        # rather than evaluating arbitrary text from the config.
        model_cls = globals()[modelclass]

        paths = config['paths']
        # Only create a temporary directory when none is configured;
        # `dict.get(key, tempfile.mkdtemp())` would create one every call.
        if 'basedir' in paths:
            basedir = paths['basedir']
        else:
            basedir = tempfile.mkdtemp()

        # Load the model
        modelfile = paths.get('modelfile')
        if modelfile is None:
            raise ValueError('Config file has no path to a model')
        model = model_cls.load(modelfile)

        # Get other parameters
        feature_names = config['feature_names']
        parameters = config['parameters']

        return model_cls(
            model=model,
            basedir=basedir,
            feature_names=feature_names,
            parameters=parameters,
            exist_ok=True,
            config=config
        )


class CroptypeBaseModel(CroptypeModel):
    '''
    Shared implementation for concrete croptype models: working
    directory and `config.json` management, evaluation and transfer.

    `load`, `save` and `predict` raise NotImplementedError here and are
    expected to be provided by subclasses, as are the abstract `train`
    and `summary`.
    '''

    def __init__(self,
                 model=None,
                 modeltype=None,
                 feature_names=None,
                 requires_scaling=False,
                 parameters=None,
                 basedir=None,
                 overwrite=False,
                 exist_ok=False,
                 parentmodel=None,
                 config=None):
        '''
        Args:
            model: the wrapped model object (may be None).
            modeltype: model type identifier stored in the config
                (required).
            feature_names: ordered input feature names (required).
            requires_scaling: flag stored in the config settings.
            parameters: dict of model parameters (defaults to {}).
            basedir: working directory for config, plots and metrics
                (required); created if missing.
            overwrite: delete a non-empty `basedir` instead of raising.
            exist_ok: reuse a non-empty `basedir` as-is.
            parentmodel: model file this model was transferred from.
            config: existing config dict; when None a new one is built
                and written to `basedir/config.json`.

        Raises:
            ValueError: if a required argument is None, or `basedir` is
                non-empty and neither `exist_ok` nor `overwrite` is set.
        '''
        self.model = model
        self.modeltype = modeltype
        self.modelclass = type(self).__name__
        self.basedir = basedir
        self.feature_names = feature_names
        self.requires_scaling = requires_scaling
        self.parameters = parameters or {}
        self.parentmodel = parentmodel
        self.config = config
        self.impute_missing = True

        if modeltype is None:
            raise ValueError('`modeltype` cannot be None')

        if feature_names is None:
            raise ValueError('`feature_names` cannot be None')

        if basedir is None:
            raise ValueError('`basedir` cannot be None')

        basepath = Path(basedir)
        # Only a directory that already has content is a clash; an empty
        # directory (e.g. from tempfile.mkdtemp()) can be used directly.
        if basepath.is_dir() and any(basepath.iterdir()) and not exist_ok:
            if not overwrite:
                raise ValueError((f'Basedir `{basedir}` is '
                                  'not empty. Please delete '
                                  'or use `overwrite=True`'))
            shutil.rmtree(basedir)

        basepath.mkdir(parents=True, exist_ok=True)

        if self.config is None:
            self.create_config()

    @classmethod
    def load(cls, file):
        raise NotImplementedError('No model loader available.')

    def save(self, file):
        raise NotImplementedError(('Cannot save model directly '
                                   'from base class.'))

    def predict(self, *args, **kwargs):
        raise NotImplementedError('No prediction method available.')

    @staticmethod
    def _plot_confusion_matrix(outputs, predictions, outfile,
                               normalize=None, values_format=None):
        '''Plot a confusion matrix over the labels in `outputs` and save
        it as an image to `outfile`.'''
        from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
        import matplotlib.pyplot as plt

        labels = np.unique(outputs)
        cm = confusion_matrix(outputs, predictions, normalize=normalize,
                              labels=labels)
        disp = ConfusionMatrixDisplay(cm, display_labels=labels)
        fig, ax = plt.subplots(figsize=(10, 10))
        disp.plot(ax=ax, cmap=plt.cm.Blues, values_format=values_format,
                  colorbar=False)
        plt.tight_layout()
        plt.savefig(str(outfile))
        plt.close(fig)

    def evaluate(self, inputs, outputs, original_labels=None,
                 outdir=None, pattern='', encoder=None):
        '''
        Evaluate the model on labelled data and write the results.

        Files written to `outdir`, each prefixed with `pattern`:
        `CM_abs.png` and `CM_norm.png` (absolute and row-normalised
        confusion matrix plots), `metrics.txt`, `confusion_matrix.txt`
        and, when `original_labels` is given,
        `confusion_matrix_original_labels.txt` with counts per
        (original label, prediction) pair.

        Args:
            inputs: model inputs passed to `self.predict`.
            outputs: reference labels.
            original_labels: optional per-sample labels (cast to int) to
                cross-tabulate against the predictions.
            outdir: output directory; defaults to `self.basedir`.
            pattern: filename prefix for all outputs.
            encoder: optional encoder whose `inverse_transform` maps the
                raw predictions back to label space.

        Returns:
            dict with 'OA', 'F1', 'Precision' and 'Recall' rounded to 3
            decimals; F1/precision/recall use binary averaging when
            `outputs` has two classes and macro averaging otherwise.
        '''
        from sklearn.metrics import confusion_matrix

        if self.model is None:
            raise ValueError('No model initialized yet.')
        predictions, _ = self.predict(inputs)

        outdir = Path(outdir or self.basedir)

        if encoder is not None:
            predictions = encoder.inverse_transform(predictions)

        # Make sure predictions are now 1D
        predictions = predictions.squeeze()

        self._plot_confusion_matrix(
            outputs, predictions, outdir / f'{pattern}CM_abs.png')
        self._plot_confusion_matrix(
            outputs, predictions, outdir / f'{pattern}CM_norm.png',
            normalize='true', values_format='.1f')

        # 'binary' is sklearn's default averaging for these scores
        average = 'binary' if len(np.unique(outputs)) == 2 else 'macro'
        metrics = {
            'OA': np.round(accuracy_score(outputs, predictions), 3),
            'F1': np.round(f1_score(
                outputs, predictions, average=average), 3),
            'Precision': np.round(precision_score(
                outputs, predictions, average=average), 3),
            'Recall': np.round(recall_score(
                outputs, predictions, average=average), 3),
        }

        # Write metrics to disk
        with open(str(outdir / f'{pattern}metrics.txt'), 'w') as f:
            f.write('Test results:\n')
            for key, value in metrics.items():
                f.write(f'{key}: {value}\n')
                logger.info(f'{key} = {value}')

        # Confusion matrix over every label seen in either the reference
        # or the predictions; one sorted label list feeds both the matrix
        # and its index so rows/columns cannot get out of sync.
        labels = np.union1d(outputs, predictions)
        cm = confusion_matrix(outputs, predictions, labels=labels)
        index = labels.astype(int)
        cm_df = pd.DataFrame(data=cm, index=index, columns=index)
        cm_df.to_csv(outdir / f'{pattern}confusion_matrix.txt')

        if original_labels is not None:
            data = pd.DataFrame({
                'ori': np.asarray(original_labels).astype(int),
                'pred': predictions.astype(int)})
            count = data.groupby(['ori', 'pred']).size()
            result = count.to_frame(name='count').reset_index()
            outfile = (outdir /
                       f'{pattern}confusion_matrix_original_labels.txt')
            result.to_csv(outfile, index=False)

        return metrics

    def create_config(self):
        '''Build the config dict from the current attributes and write
        it to `basedir/config.json`.'''
        config = {}
        config['parameters'] = self.parameters
        config['settings'] = dict(
            modeltype=self.modeltype,
            modelclass=self.modelclass,
            requires_scaling=self.requires_scaling
        )
        config['feature_names'] = self.feature_names
        config['paths'] = dict(
            basedir=str(self.basedir),
            # Filled in by the subclass `save` methods
            modelfile=None,
            modelweights=None,
            parentmodel=self.parentmodel
        )
        self.config = config
        self.save_config()

    def save_config(self):
        '''Write `self.config` as indented JSON to `basedir/config.json`.'''
        configpath = Path(self.basedir) / 'config.json'
        with open(configpath, 'w') as f:
            json.dump(self.config, f, indent=4)

    def transfer(self, basedir):
        '''
        Method to transfer the model to a new directory
        where it can be used to retrain the model on other data

        Returns a new instance of the same class in `basedir` that wraps
        the same model object and records the current model file as its
        parent model.
        '''
        if self.model is None:
            raise ValueError('No model loaded to transfer.')
        logger.info(f'Transferring model to: {basedir}')
        return self.__class__(basedir=basedir,
                              feature_names=self.feature_names,
                              parameters=self.parameters,
                              model=self.model,
                              parentmodel=self.config['paths']['modelfile'])


class CroptypeKerasModel(CroptypeBaseModel):
    '''
    Base class for croptype models whose `self.model` is a Keras model.
    '''

    def train(self,
              calibrationx=None,
              calibrationy=None,
              validationdata=None,
              steps_per_epoch=None,
              validation_steps=None,
              epochs=100,
              modelfile=None,
              weightsfile=None,
              learning_rate=None,
              earlystopping=True,
              reducelronplateau=True,
              tensorboard=False,
              csvlogger=False,
              customcallbacks: List = None,
              **kwargs
              ):
        '''
        Fit `self.model` with a configurable set of Keras callbacks.

        Args:
            calibrationx, calibrationy: passed to `Model.fit` as x / y.
            validationdata: passed as `validation_data`; when None,
                early stopping is disabled.
            steps_per_epoch, validation_steps, epochs: passed to `fit`.
            modelfile: if given, checkpoint the best full model here and
                write an architecture summary (`<stem>_architecture.log`)
                and plot (`<stem>_architecture.png`) next to it.
            weightsfile: if given, checkpoint the best weights here.
            learning_rate: if given, set on the optimizer before fitting.
            earlystopping: add EarlyStopping on `val_loss`.
            reducelronplateau: add ReduceLROnPlateau on `val_loss`.
            tensorboard, csvlogger: write TensorBoard logs /
                `kerastraininglog.csv` next to `modelfile`, or into
                `basedir` when no `modelfile` is given.
            customcallbacks: extra callbacks appended to the list.
            **kwargs: forwarded to `Model.fit`.
        '''
        if self.model is None:
            raise Exception('No model loaded yet')

        if validationdata is None:
            earlystopping = False

        # Directory for training logs; without a modelfile, `basedir` is
        # used instead of failing on Path(None).
        if modelfile is not None:
            logdir = Path(modelfile).parent
        else:
            logdir = Path(self.basedir)

        callbacks = []

        if modelfile is not None:
            checkpointermodel = tf.keras.callbacks.ModelCheckpoint(
                filepath=modelfile, save_best_only=True
            )
            callbacks.append(checkpointermodel)

            stem = str(Path(modelfile).stem)

            # Also save the network architecture to a text file
            with open(logdir / (stem + '_architecture.log'), 'w') as f:
                f.write("--- " + stem + " ---" + "\n\n\n")
                self.model.summary(
                    print_fn=lambda x: f.write(x + "\n"))

            # Plot the model
            try:
                plot_model(
                    self.model,
                    to_file=logdir / (stem + '_architecture.png'),
                    show_shapes=True,
                    dpi=96,
                )
            except ImportError:
                logger.warning('Could not plot model!')

        if weightsfile is not None:
            checkpointerweights = tf.keras.callbacks.ModelCheckpoint(
                filepath=weightsfile,
                save_best_only=True,
                save_weights_only=True
            )
            callbacks.append(checkpointerweights)

        if earlystopping:
            earlystoppingcallback = tf.keras.callbacks.EarlyStopping(
                monitor='val_loss', min_delta=0.001, patience=5,
                restore_best_weights=True, verbose=1
            )
            callbacks.append(earlystoppingcallback)

        if reducelronplateau:
            reducelrcallback = tf.keras.callbacks.ReduceLROnPlateau(
                monitor="val_loss", factor=0.1, patience=3,
                verbose=1, mode="auto", min_delta=0.0001)
            callbacks.append(reducelrcallback)

        if tensorboard:
            tensorboardcallback = tf.keras.callbacks.TensorBoard(
                log_dir=logdir / 'tensorboardlogs'
            )
            callbacks.append(tensorboardcallback)

        if csvlogger:
            log_file = logdir / 'kerastraininglog.csv'

            # Delete log file if it exists so each run starts a fresh log
            if log_file.exists():
                log_file.unlink()
            csvloggercallback = tf.keras.callbacks.CSVLogger(
                log_file, separator=",", append=True)
            callbacks.append(csvloggercallback)

        if customcallbacks is not None:
            callbacks += customcallbacks

        if learning_rate is not None:
            logger.info(f'Adjusting learning rate to: {learning_rate}')
            K.set_value(self.model.optimizer.learning_rate, learning_rate)

        logger.info('-'*30)
        logger.info('Starting model training ...')
        logger.info('-'*30)

        self.model.fit(
            x=calibrationx,
            y=calibrationy,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            callbacks=callbacks,
            validation_data=validationdata,
            validation_steps=validation_steps,
            **kwargs
        )

    @classmethod
    @retry(exceptions=TimeoutError, tries=TRIES, delay=DELAY,
           backoff=BACKOFF, logger=logger)
    def load(cls, file: str):
        '''
        Load a Keras model from a local path or an https URL (downloaded
        into memory and read as HDF5), registering `DiceBCELoss` as a
        custom object.
        '''
        custom_objects = {'DiceBCELoss': DiceBCELoss}
        # str() so that pathlib.Path inputs are accepted as well
        if str(file).startswith("https"):
            import h5py
            from urllib.request import urlopen
            from io import BytesIO

            with urlopen(str(file)) as response:
                payload = BytesIO(response.read())
            # Close the in-memory HDF5 handle once the model is built
            with h5py.File(payload, 'r') as h5file:
                return tf.keras.models.load_model(
                    h5file, custom_objects=custom_objects)

        return tf.keras.models.load_model(
            file,
            custom_objects=custom_objects
        )

    def save(self, file):
        '''
        Save the full model to `<name>.h5` and its weights to
        `<name>_weights.h5` in the directory of `file`, then record both
        paths in the config.

        `<name>` is the stem of `file` with `_weights` and `.h5` removed,
        so either the model path or the weights path can be passed.
        '''
        if self.model is None:
            raise Exception('No model loaded yet')

        bn = Path(file).stem.replace('_weights', '').replace('.h5', '')
        bn_weights = str(Path(file).parent / (bn + '_weights.h5'))
        logger.info(f'Saving model weights to: {bn_weights}')
        self.model.save_weights(bn_weights)

        bn_model = Path(file).parent / (bn + '.h5')
        logger.info(f'Saving model to: {bn_model}')
        self.model.save(bn_model)

        # Update config
        self.config['paths']['modelfile'] = str(bn_model)
        self.config['paths']['modelweights'] = str(bn_weights)
        self.save_config()

    def summary(self, **kwargs):
        '''Print the Keras model summary (kwargs go to `Model.summary`).'''
        if self.model is None:
            raise ValueError('No model associated with this object')
        self.model.summary(**kwargs)

    def retrain(self, **kwargs):
        raise NotImplementedError(
            'Method should be implemented for specific model!')


class CroptypePixelLSTM(CroptypeKerasModel):
    '''
    Pixel-level classifier combining per-sensor LSTM encoders for time
    series features with dense encoders for auxiliary features.
    '''

    def __init__(self,
                 model=None,
                 feature_names=None,
                 parameters: dict = None,
                 outchannels: int = 1,
                 **kwargs
                 ):
        '''
        Args:
            model: existing Keras model; a new network is built with
                `lstmmodel()` when None.
            feature_names: ordered input feature names, used to derive
                the time-series and auxiliary sensor configurations.
            parameters: extra model parameters. The dict is copied and
                extended with `sensorconfigts` / `sensorconfigaux`; the
                caller's dict is not modified.
            outchannels: output units of a newly built network
                (1 = sigmoid, >1 = softmax). When `model` is given, its
                own output size is used instead.
            **kwargs: forwarded to CroptypeBaseModel (basedir, overwrite,
                exist_ok, parentmodel, config, ...).
        '''
        if feature_names is None:
            raise ValueError('`feature_names` cannot be None')

        # Copy: a shared mutable default (or the caller's dict) must not
        # be modified by the sensor config assignments below.
        parameters = dict(parameters) if parameters is not None else {}

        # Extension point: parameter names this model requires
        _req_params = []

        for param in _req_params:
            if param not in parameters:
                raise ValueError((f'Parameter `{param}` is '
                                  'compulsory for this model '
                                  'but was not found.'))

        self.dropoutfraction = 0.15
        # A supplied (e.g. restored) model dictates the output size, so
        # the loss chosen below matches it even when `outchannels` is
        # left at its default.
        if model is None:
            self.outchannels = outchannels
        else:
            self.outchannels = int(model.output_shape[-1])
        self.sensorconfigts = get_sensor_config_timeseries(feature_names)
        self.sensorconfigaux = get_sensor_config(feature_names)

        parameters['sensorconfigts'] = self.sensorconfigts
        parameters['sensorconfigaux'] = self.sensorconfigaux

        super().__init__(modeltype='pixel',
                         parameters=parameters,
                         requires_scaling=True,
                         feature_names=feature_names,
                         **kwargs)

        if model is None:
            self.model = self.lstmmodel()
        else:
            self.model = model

        self.loss = ('binary_crossentropy' if self.outchannels == 1
                     else 'categorical_crossentropy')

        self.model.compile(optimizer=Adam(0.0002, 0.5),
                           loss=self.loss,
                           weighted_metrics=['accuracy'])

    def lstmmodel(self):
        '''
        Build the pixel LSTM network.

        Each time-series sensor is encoded by two stacked LSTMs and each
        auxiliary sensor by two dense layers; all encodings are
        concatenated and passed through a dense classification head with
        `self.outchannels` outputs (sigmoid for 1, softmax otherwise).
        '''

        # Get a local copy of this because we'll modify it
        sensorconfigts = get_sensor_config_timeseries(self.feature_names)
        sensorconfigaux = get_sensor_config(self.feature_names)

        def _time_encoder(enc_input, sensor, units=None):
            # Axis 2 holds the stacked channels (see tf.stack below)
            units = units or self.estimate_units(enc_input.shape[2])

            # Masking slows down training terribly
            # and does not seem to impact the results

            regularizer = 0.001

            enc_1 = LSTM(
                units,
                return_sequences=True,
                kernel_regularizer=l2(regularizer),
                recurrent_regularizer=l2(regularizer),
                bias_regularizer=l2(regularizer),
                name=f'{sensor}_lstm_initial'
            )(enc_input)
            enc_output = Flatten()(BatchNormalization()(LSTM(
                units * 2,
                return_sequences=True,
                kernel_regularizer=l2(regularizer),
                recurrent_regularizer=l2(regularizer),
                bias_regularizer=l2(regularizer),
                name=f'{sensor}_lstm_1'
            )(enc_1)))

            return enc_output

        def _encoder(enc_input, sensor):
            # NOTE(review): units grow as 8**(n_features // 5 + 1), e.g.
            # 512 units for 10-14 features and 4096 for 15-19; confirm
            # this scaling is intended.
            units = 8**(int(enc_input.shape[1]/5) + 1)

            enc1 = BatchNormalization()(
                Dropout(self.dropoutfraction)(Dense(
                    units,
                    activation='relu',
                    name=f'{sensor}_dense_initial'
                )(enc_input)))
            enc_output = BatchNormalization()(
                Dropout(self.dropoutfraction)(Dense(
                    units*2,
                    activation='relu'
                )(enc1)))

            return enc_output

        # --------------------------------------
        # Inputs
        # --------------------------------------
        modelinput = Input(
            shape=(len(self.feature_names),),
            name='full_input'
        )

        # The tf.gather operator selects the
        # sensor-specific channels from the full input
        for sensor in sensorconfigts.keys():
            sensorinputs = []
            for sensorchannel in sensorconfigts[sensor]['channels']:
                sensorinputs.append(
                    tf.gather(
                        params=modelinput,
                        indices=sensorconfigts[sensor][
                            sensorchannel]['positions'],
                        axis=1,
                    ),
                )
            # Stack the different channels in a new channel dimension
            sensorconfigts[sensor]['input'] = tf.stack(
                sensorinputs, axis=2, name=f'{sensor}_ts_input')

        for sensor in sensorconfigaux.keys():
            sensorconfigaux[sensor]['input'] = tf.gather(
                params=modelinput,
                indices=sensorconfigaux[sensor]['positions'],
                axis=1,
                name=f'{sensor}_aux_input'
            )

        # --------------------------------------
        # Temporal encoding step
        # --------------------------------------
        for sensor in sensorconfigts.keys():
            sensorconfigts[sensor]['encoded'] = _time_encoder(
                sensorconfigts[sensor]['input'],
                sensor=sensor
            )

        # Concatenate temporal features
        time_encoded = [
            sensorconfigts[sensor]['encoded']
            for sensor in sensorconfigts.keys()]
        if len(time_encoded) > 1:
            time_concatenated = concatenate(time_encoded)
        else:
            time_concatenated = time_encoded[0]

        # --------------------------------------
        # Non temporal encoding step
        # --------------------------------------
        for sensor in sensorconfigaux.keys():
            sensorconfigaux[sensor]['encoded'] = _encoder(
                sensorconfigaux[sensor]['input'],
                sensor=sensor
            )

        # --------------------------------------
        # Concatenation of all encoded features
        # --------------------------------------
        encoded = [time_concatenated] + [
            sensorconfigaux[sensor]['encoded']
            for sensor in sensorconfigaux.keys()
        ]

        if len(encoded) > 1:
            concatenated = concatenate(encoded)
        else:
            concatenated = encoded[0]

        # --------------------------------------
        # Classification head
        # --------------------------------------
        clf1 = Dropout(self.dropoutfraction)(
            Dense(512, activation='relu')(concatenated))
        clf2 = Dropout(self.dropoutfraction)(
            Dense(256, activation='relu')(clf1))
        clf3 = Dropout(self.dropoutfraction)(
            Dense(128, activation='relu')(clf2))
        if self.outchannels == 1:
            output = Dense(1, activation='sigmoid')(clf3)
        else:
            output = Dense(self.outchannels, activation='softmax')(clf3)

        return Model(inputs=modelinput, outputs=output)

    def predict(self, inputs, **kwargs):
        '''
        Predict a class index and a confidence score per sample.

        Returns:
            predictions: 1D int array; argmax of a multi-channel output,
                or the single sigmoid output thresholded at 0.5.
            confidence: 1D array with the probability of the predicted
                class, i.e. for a single sigmoid output it is highest
                near 0 and 1 and lowest at 0.5 (undecided).
        '''
        if self.model is None:
            raise Exception('No model loaded yet')
        probs = self.model.predict(inputs,
                                   verbose=1)

        if probs.ndim == 2 and probs.shape[1] > 1:
            # One-hot / softmax output
            predictions = np.argmax(probs, axis=1)
            confidence = np.max(probs, axis=1)
        else:
            # Single sigmoid channel: an argmax over one column would
            # always give class 0, so threshold the probability instead.
            probs = probs.reshape(-1)
            predictions = (probs > 0.5).astype(int)
            confidence = np.maximum(probs, 1 - probs)

        return predictions, confidence

    @staticmethod
    def estimate_units(inputchannels):
        '''Heuristic LSTM width for a given number of input channels.'''
        if inputchannels < 5:
            return 32
        if inputchannels < 12:
            return 64
        if inputchannels < 25:
            return 128
        if inputchannels < 50:
            return 256
        return 512


# Derives from CroptypeBaseModel: this class relies on its __init__
# (basedir/config handling), `evaluate`, `save_config`, `feature_names`
# and `parameters`. The abstract CroptypeModel provides none of these,
# and leaving `evaluate` abstract would make the class uninstantiable.
class CroptypeCatBoostModel(CroptypeBaseModel):
    '''
    Croptype model wrapping a `catboost.CatBoostClassifier`.
    '''

    def __init__(self, gpu=False, model=None,
                 iterations=8000, depth=8,
                 random_seed=1234, classes_count=None,
                 learning_rate=0.05, early_stopping_rounds=20,
                 **kwargs):
        '''
        Args:
            gpu: train on GPU device '0' instead of the CPU.
            model: existing CatBoostClassifier; when None a new one is
                created from the hyperparameters below.
            iterations, depth, random_seed, classes_count, learning_rate,
            early_stopping_rounds: CatBoostClassifier hyperparameters,
                only used when `model` is None.
            **kwargs: forwarded to CroptypeBaseModel (basedir,
                feature_names, parameters, ...).
        '''
        from catboost import CatBoostClassifier

        if gpu:
            task_type = "GPU"
            devices = '0'
        else:
            task_type = "CPU"
            devices = None

        if model is None:
            model = CatBoostClassifier(
                iterations=iterations, depth=depth,
                random_seed=random_seed,
                learning_rate=learning_rate,
                early_stopping_rounds=early_stopping_rounds,
                task_type=task_type,
                classes_count=classes_count,
                devices=devices,
                l2_leaf_reg=3
            )
        super().__init__(model=model, modeltype='pixel',
                         **kwargs)

        self.impute_missing = False  # CatBoost can handle NaN

    def train(self, inputs, outputs=None, cat_features=None, **kwargs):
        '''Fit the classifier; kwargs (e.g. `init_model`) go to `fit`.'''
        if inputs.shape[1] != len(self.feature_names):
            raise ValueError(('Model was initialized for '
                              f'{len(self.feature_names)} '
                              'features but got '
                              f'{inputs.shape[1]} for '
                              'fitting.'))

        if self.model is None:
            raise ValueError('No model initialized yet.')

        if 'init_model' in kwargs:
            logger.info('Continuing training from previous model!')

        self.model.fit(inputs, outputs,
                       cat_features=cat_features, **kwargs)

    def grid_search(self, grid, X):
        '''Run CatBoost's grid search (F1 metric) on a fresh classifier
        and return its results.'''
        from catboost import CatBoostClassifier
        model = CatBoostClassifier(early_stopping_rounds=20,
                                   eval_metric='F1')
        logger.info(f'Starting grid search for parameter grid: {grid}')
        results = model.grid_search(grid, X, verbose=False)
        return results

    def save(self, modelfile):
        '''Save the model as `.cbm` and record the path in the config.'''
        modelfile = str(modelfile)
        if self.model is None:
            raise ValueError('No model initialized yet.')
        if not modelfile.endswith('cbm'):
            modelfile += '.cbm'
        logger.info(f'Saving model to: {modelfile}')
        self.model.save_model(modelfile)

        # Update config
        self.config['paths']['modelfile'] = modelfile
        self.save_config()

    @classmethod
    @retry(exceptions=TimeoutError, tries=TRIES, delay=DELAY,
           backoff=BACKOFF, logger=logger)
    def load(cls, modelfile):
        '''Load a CatBoostClassifier from a local path or an https URL
        (downloaded to a temporary file first).'''
        from catboost import CatBoostClassifier
        logger.info(f'Restoring model from: {modelfile}')
        # str() so that pathlib.Path inputs are accepted as well
        modelfile = str(modelfile)
        if modelfile.startswith("https"):
            # `import urllib` alone does not guarantee that the
            # `urllib.request` submodule is loaded.
            from urllib.request import urlretrieve
            modelfile, _ = urlretrieve(modelfile)

        model = CatBoostClassifier()

        return model.load_model(modelfile)

    def summary(self):
        '''Log the classifier's parameters.'''
        if self.model is None:
            raise ValueError('No model initialized yet.')
        logger.info(self.model.get_params())

    def transfer(self, basedir, **kwargs):
        '''Override parent method because the model
        will need to be re-initialized instead of
        transferred.
        '''
        if self.model is None:
            raise ValueError('No model loaded to transfer.')
        logger.info(f'Transferring and resetting model to: {basedir}')
        return self.__class__(basedir=basedir,
                              feature_names=self.feature_names,
                              parameters=self.parameters,
                              parentmodel=self.config['paths']['modelfile'],
                              **kwargs)

    def retrain(self, init_model, inputs, outputs=None,
                cat_features=None, **kwargs):
        '''Continue training from `init_model` on new data.'''
        if self.model is None:
            raise ValueError('No model loaded yet.')

        self.train(inputs, outputs=outputs, cat_features=cat_features,
                   init_model=init_model, **kwargs)

    @staticmethod
    def convolve_probs(probs, kernel):
        """
        Perform 2d convolution of kernel to array of probabilities
        along the labels axis

        Each slice `probs[..., i]` is convolved separately with
        symmetric boundary handling and the output has the shape of
        `probs`. Because `convolve2d` needs 2D inputs, `probs` must be 3D
        (presumably rows, cols, labels; TODO confirm with callers).
        """
        # Import the submodule explicitly; `import scipy` alone does not
        # load `scipy.signal` on older SciPy versions.
        from scipy.signal import convolve2d

        filtered_probs = np.zeros(probs.shape)
        for i in range(probs.shape[-1]):
            filtered_probs[..., i] = convolve2d(
                probs[..., i],
                kernel,
                mode='same',
                boundary='symm')

        return filtered_probs

    @staticmethod
    def get_gaussian_kernel(kernlen=7, std=1):
        """Returns a 2D Gaussian kernel array.

        The kernel is the outer product of a 1D Gaussian window of
        length `kernlen` with itself, normalised to sum to 1.
        """
        # `scipy.signal.gaussian` was removed in SciPy 1.13; the window
        # lives in `scipy.signal.windows`.
        from scipy.signal.windows import gaussian

        gkern1d = gaussian(kernlen, std=std)
        gkern2d = np.outer(gkern1d, gkern1d)
        return gkern2d / gkern2d.sum()

    def predict(self, inputs):
        '''
        Predict labels and the max class probability per sample.

        Args:
            inputs: ndarray (columns in `feature_names` order) or
                DataFrame. A caller's DataFrame is not modified.

        Returns:
            (predictions, confidence): the output of
            `CatBoostClassifier.predict` and the row-wise maximum of
            `predict_proba`.
        '''
        if self.model is None:
            raise ValueError('No model initialized yet.')

        if isinstance(inputs, np.ndarray):
            inputs = pd.DataFrame(data=inputs,
                                  columns=self.feature_names)

        # Make sure categorical features are categorical
        cat_indices = self.model.get_cat_feature_indices()
        if cat_indices:
            # Work on a copy so the caller's frame is untouched, and
            # replace whole columns: `iloc[:, i] = ...` may write in place
            # and keep a float dtype, which CatBoost rejects for
            # categorical features.
            inputs = inputs.copy()
            for ft in cat_indices:
                column = inputs.columns[ft]
                inputs[column] = inputs[column].astype(int)

        predictions = self.model.predict(inputs)
        confidence = np.max(self.model.predict_proba(inputs),
                            axis=1)

        return predictions, confidence


if __name__ == '__main__':
    # Smoke test: build an untrained LSTM for 18 timesteps of SIGMA0 and
    # L2A bands plus two DEM features, then print its architecture.
    n_timesteps = 18
    sigma0_bands = ['VV', 'VH']
    l2a_bands = ['B02', 'B03', 'B04', 'B05',
                 'B06', 'B07', 'B08', 'B11', 'B12']

    features = (
        [f'SIGMA0-{band}-ts{t}-20m'
         for band in sigma0_bands for t in range(n_timesteps)]
        + [f'L2A-{band}-ts{t}-20m'
           for band in l2a_bands for t in range(n_timesteps)]
        + ['DEM-alt-20m', 'DEM-slo-20m']
    )

    # `tempfile` is already imported at module level
    workdir = tempfile.mkdtemp()
    model = CroptypePixelLSTM(feature_names=features, outchannels=25,
                              basedir=workdir, overwrite=True)
    model.summary()
