from tensorflow.keras.layers import (Input, concatenate,
                                     Conv2D, Dropout, LeakyReLU,
                                     Activation, Conv3D,
                                     )
from tensorflow.keras.layers import (BatchNormalization,
                                     Conv2DTranspose)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras import Model
from tensorflow.keras.regularizers import l2
import tensorflow as tf
import threading

# Per-thread storage object; load_generator_model() keeps the loaded
# Keras model on it so every thread works with its own model instance.
_threadlocal = threading.local()


# Default sensor configuration for CropsarPixelModel: maps a sensor name
# to the size of the last (channel) axis of that sensor's Keras Input.
DEFAULT_MODEL_INPUTS = {
    'S1': 2,
    'S2': 1,
}


def load_generator_model():
    """Return the packaged generator model, loading it once per thread.

    The model is read from the HDF5 resource
    ``resources/cropsar_px_generator.h5`` of the ``cropsar_px`` package and
    cached on the module-level thread-local object, so later calls from the
    same thread return the already loaded model.

    Returns:
        The Keras model returned by ``tensorflow.keras.models.load_model``.

    Raises:
        FileNotFoundError: if ``pkgutil.get_data`` cannot provide the
            resource (it returned ``None``).
    """
    # Keras/tensorflow models are not guaranteed to be threadsafe,
    # but by loading and storing the model once per thread we should
    # be able to safely eliminate loading at model predict time.
    # NOTE: the attribute read here must be the same one that is written
    # below, otherwise the cache never hits and the model is reloaded on
    # every call.
    generator_model = getattr(_threadlocal, 'generator_model', None)
    if generator_model is None:
        # Heavy imports are kept local so they are only paid on first load
        import io
        import pkgutil
        import h5py
        from tensorflow.keras.models import load_model

        # Load tensorflow model from in-memory HDF5 resource
        path = 'resources/cropsar_px_generator.h5'
        data = pkgutil.get_data('cropsar_px', path)
        if data is None:
            # io.BytesIO(None) would silently create an empty buffer and
            # fail later with an unrelated HDF5 error; fail clearly instead.
            raise FileNotFoundError(
                'Could not load resource %r from package cropsar_px' % path)

        with h5py.File(io.BytesIO(data), mode='r') as h5:
            generator_model = load_model(h5)

        # Store per thread
        _threadlocal.generator_model = generator_model

    return generator_model

class CropsarPixelModel:
    """Generator, discriminator and combined GAN model for CropSAR pixel.

    Constructing an instance builds three Keras models:

    * ``discriminator``: takes an S2 time series window and a target image
      and outputs sigmoid patch scores; compiled with binary cross-entropy.
    * ``generator``: takes one time series input per sensor in
      ``modelinputs`` and outputs a single-channel ``tanh`` image.
    * ``combined``: the generator followed by ``frozen_discriminator`` (a
      non-trainable model over the discriminator's inputs/outputs);
      compiled with binary cross-entropy and MAE losses.
    """

    def __init__(self, modelinputs=DEFAULT_MODEL_INPUTS,
                 windowsize=32, tslength=37):
        """Build and compile the discriminator, generator and combined model.

        Args:
            modelinputs: dict mapping sensor name to its number of input
                channels. The entry ``'S2'`` is required because it is fed
                to the discriminator.
            windowsize: spatial size of the discriminator input; only 32
                is accepted.
            tslength: number of time steps of every input time series.

        Raises:
            ValueError: if ``windowsize`` is not 32.
        """
        if windowsize != 32:
            raise ValueError('Windowsize should be 32.')

        self.windowsize = windowsize
        self.tslength = tslength
        self.dropoutfraction = 0.5
        # Copy so that the shared module-level default (or the caller's
        # dict) is never aliased by this instance.
        self.modelinputs = dict(modelinputs)

        # Positional arguments: learning rate and beta_1
        optimizer = Adam(0.001, 0.5)

        # Define model inputs: one (time, height, width, channels) input
        # per sensor, with free spatial dimensions
        inputs = {}
        for sensor, nchannels in self.modelinputs.items():
            inputs[sensor] = Input(
                shape=(self.tslength, None, None, nchannels),
                name=sensor + '_input')
        self.inputs = inputs

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator(
            windowsize=self.windowsize)
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'],
                                   loss_weights=[0.5])

        # Build the generator
        self.generator = self.build_generator()

        # Build the "frozen discriminator": a second model over the same
        # inputs/outputs as the discriminator, flagged as non-trainable
        # before the combined model is compiled
        inputs_discriminator = self.discriminator.inputs
        outputs_discriminator = self.discriminator.outputs
        self.frozen_discriminator = Model(
            inputs_discriminator, outputs_discriminator,
            name="frozen_discriminator")
        self.frozen_discriminator.trainable = False

        # By conditioning on inputs, generate fake cloudfree images
        cloudfreeimage_fake = self.generator(list(inputs.values()))

        valid = self.frozen_discriminator([inputs['S2'],
                                           cloudfreeimage_fake])

        # Combined model: adversarial output plus the generated image
        self.combined = Model(inputs=list(inputs.values()),
                              outputs=[valid, cloudfreeimage_fake])
        self.combined.compile(loss=['binary_crossentropy', 'mae'],
                              loss_weights=[1, 100], optimizer=optimizer,
                              metrics=['accuracy', 'mse'])

    def build_discriminator(self, windowsize):
        """Build the (uncompiled) discriminator model.

        Args:
            windowsize: height and width of both discriminator inputs.

        Returns:
            A Keras ``Model`` named ``'discriminator'`` with inputs
            ``(s2_input, target_image)`` and a sigmoid patch output.
        """

        init = RandomNormal(mean=0.0, stddev=0.02)

        # Set the input: the discriminator is NOT
        # fully convolutional!
        s2_input = Input(
            shape=(self.tslength, windowsize, windowsize, 1),
            name='s2_input')

        # Target image, supposed to be cloudfree, either real or fake
        target = Input(shape=(windowsize, windowsize, 1),
                       name='target_image')

        # Use conv3d to remove the temporal dimension. The strided
        # convolutions must shrink the time axis to length 1 for the final
        # squeeze; with "same" padding the default tslength of 37 goes
        # 37 -> 8 -> 2 -> 1.
        s2_conv3d_1 = LeakyReLU(alpha=0.2)(BatchNormalization()(
            Conv3D(filters=64,
                   kernel_size=(7, 1, 1),
                   strides=(5, 1, 1),
                   padding="same")(s2_input)))
        s2_conv3d_2 = LeakyReLU(alpha=0.2)(BatchNormalization()(
            Conv3D(filters=64,
                   kernel_size=(5, 1, 1),
                   strides=(5, 1, 1),
                   padding="same")(s2_conv3d_1)))
        s2_conv3d_3 = tf.keras.backend.squeeze(
            LeakyReLU(alpha=0.2)(BatchNormalization()(
                Conv3D(filters=64,
                       kernel_size=(7, 1, 1),
                       strides=(2, 1, 1),
                       padding="same")(s2_conv3d_2))), axis=1)

        # Concatenate the inputs and the target image
        concatenated = concatenate(
            [s2_conv3d_3, target])

        # Encode to patches
        enc1 = LeakyReLU(alpha=0.2)(Conv2D(
            64, (4, 4), strides=2, padding='same',
            kernel_initializer=init)(concatenated))
        enc2 = LeakyReLU(alpha=0.2)(BatchNormalization()(
            Conv2D(128, (4, 4), strides=2, padding='same',
                   kernel_initializer=init)(enc1)))
        enc3 = LeakyReLU(alpha=0.2)(BatchNormalization()(
            Conv2D(256, (4, 4), strides=2, padding='same',
                   kernel_initializer=init)(enc2)))
        zero_pad1 = tf.keras.layers.ZeroPadding2D()(enc3)
        enc4 = LeakyReLU(alpha=0.2)(BatchNormalization()(
            Conv2D(512, (4, 4), strides=1, padding='valid',
                   kernel_initializer=init, use_bias=False,)(zero_pad1)))
        zero_pad2 = tf.keras.layers.ZeroPadding2D()(enc4)
        patch_out = Activation('sigmoid')(
            Conv2D(1, (4, 4), strides=1, padding='valid',
                   kernel_initializer=init)(zero_pad2))

        discriminator = Model(inputs=(
            s2_input, target),
            outputs=patch_out,
            name='discriminator')

        return discriminator

    def build_generator(self):
        """Build the (uncompiled) generator model.

        Every sensor input in ``self.inputs`` is first reduced along the
        time axis with strided 3D convolutions, then spatially encoded with
        five strided 2D convolutions. The bottlenecks of all sensors are
        concatenated and decoded with transposed convolutions that receive
        the matching encoder outputs of every sensor as skip connections.

        Returns:
            A Keras ``Model`` named ``'generator'`` whose inputs are
            ``list(self.inputs.values())`` and whose output is a
            single-channel ``tanh`` image.
        """

        defaultkernelsize = 4
        defaultstride = 2
        init = RandomNormal(mean=0.0, stddev=0.02)

        def conv2d(layer_input, filters, kernelsize, stride, init,
                   batchnormalization=True, l2_weight_regularizer=1e-5):
            """Conv2D -> optional BatchNormalization -> LeakyReLU(0.2)."""
            c = Conv2D(filters, kernel_size=(kernelsize, kernelsize),
                       strides=stride, padding='same',
                       activation=None,
                       kernel_initializer=init,
                       kernel_regularizer=l2(
                           l2_weight_regularizer))(layer_input)
            if batchnormalization:
                c = BatchNormalization()(c)
            c = LeakyReLU(alpha=0.2)(c)
            return c

        def deconv2d(layer_input, skip_input, filters, kernelsize, stride,
                     init, dropout, batchnormalization=True,
                     l2_weight_regularizer=1e-5):
            """Conv2DTranspose -> BN -> Dropout -> ReLU -> skip concat.

            ``skip_input`` may be ``None``, a single tensor or a list/tuple
            of tensors; a falsy ``dropout`` disables the Dropout layer.
            """
            d = Conv2DTranspose(filters, kernel_size=kernelsize,
                                strides=stride, padding='same',
                                activation=None,
                                kernel_initializer=init,
                                kernel_regularizer=l2(
                                    l2_weight_regularizer))(layer_input)
            if batchnormalization:
                d = BatchNormalization()(d)
            if dropout:
                d = Dropout(dropout)(d)
            d = Activation('relu')(d)

            # Normalise the skip connections to a list without ever
            # truth-testing a tensor object itself.
            if skip_input is None:
                skips = []
            elif isinstance(skip_input, (list, tuple)):
                skips = list(skip_input)
            else:
                skips = [skip_input]
            if skips:
                d = concatenate([d] + skips)
            return d

        def _time_encoder(enc_input):
            """ConvLSTM2D based temporal encoder.

            NOTE(review): currently not used by the generator, which relies
            on ``_time_encoder_conv3D``. Judging from the Reshape below it
            seems to expect a 4D input with the channels on axis 3 -
            confirm before enabling it.
            """

            filters = 8**(int(enc_input.shape[3]/5) + 1)

            # Reshape/ConvLSTM2D are not imported at module level, so they
            # are resolved through the already imported ``tf`` namespace.
            input_reshaped = tf.keras.layers.Reshape(
                (enc_input.shape[1],
                 self.windowsize,
                 self.windowsize,
                 enc_input.shape[3]))(enc_input)

            enc1 = tf.keras.layers.ConvLSTM2D(
                filters=filters,
                kernel_size=3,
                padding='same',
                strides=(1, 1),
                return_sequences=True
            )(input_reshaped)
            enc_output = BatchNormalization()(tf.keras.layers.ConvLSTM2D(
                filters=filters,
                kernel_size=3,
                padding='same',
                strides=(1, 1),
                return_sequences=False
            )(enc1))

            return enc_output

        def _time_encoder_conv3D(enc_input,
                                 l2_weight_regularizer=1e-5):
            """Collapse the time axis with three strided Conv3D layers.

            The time axis has to end up with length 1 for the final
            squeeze; with "same" padding the default tslength of 37 goes
            37 -> 8 -> 2 -> 1.
            """

            filters = 32

            # First layer no batch normalisation
            enc1 = LeakyReLU(alpha=0.2)(
                Conv3D(filters=filters,
                       kernel_size=(7, 1, 1),
                       strides=(5, 1, 1),
                       padding="same",
                       kernel_regularizer=l2(
                           l2_weight_regularizer))(enc_input))
            enc2 = LeakyReLU(alpha=0.2)(BatchNormalization()(
                Conv3D(filters=filters*2,
                       kernel_size=(5, 1, 1),
                       strides=(5, 1, 1),
                       padding="same",
                       kernel_regularizer=l2(
                           l2_weight_regularizer))(enc1)))
            enc3 = tf.keras.backend.squeeze(
                LeakyReLU(alpha=0.2)(BatchNormalization()(
                    Conv3D(filters=filters*4,
                           kernel_size=(5, 1, 1),
                           strides=(5, 1, 1),
                           padding="same",
                           kernel_regularizer=l2(
                               l2_weight_regularizer))(enc2))),
                axis=1)

            return enc3

        def _encoder(enc_input):
            """Spatial encoder: five stride-2 conv2d blocks.

            Returns a dict with the intermediate outputs ``enc1``..``enc4``
            (used as skip connections) and the bottleneck ``enc_output``.
            """

            filters = 128

            enc1 = conv2d(enc_input,
                          filters=filters,
                          kernelsize=defaultkernelsize,
                          stride=defaultstride,
                          init=init)
            enc2 = conv2d(enc1,
                          filters=filters * 2,
                          kernelsize=defaultkernelsize,
                          stride=defaultstride,
                          init=init)
            enc3 = conv2d(enc2,
                          filters=filters * 2,
                          kernelsize=defaultkernelsize,
                          stride=defaultstride,
                          init=init)
            enc4 = conv2d(enc3,
                          filters=filters * 4,
                          kernelsize=defaultkernelsize,
                          stride=defaultstride,
                          init=init)

            # Bottleneck
            enc_output = conv2d(enc4,
                                filters=filters * 4,
                                kernelsize=defaultkernelsize,
                                stride=defaultstride,
                                init=init)

            return {'enc1': enc1, 'enc2': enc2, 'enc3': enc3,
                    'enc4': enc4, 'enc_output': enc_output}

        # Inputs
        inputs = self.inputs
        layers = {sensor: {} for sensor in inputs}

        # --------------------------------------
        # Temporal encoding step
        # --------------------------------------

        # Encoders (all temporal encoders are created before any spatial
        # encoder; keep this order so auto-generated layer names are stable)
        for sensor in inputs:
            layers[sensor]['time_encoded'] = _time_encoder_conv3D(
                inputs[sensor]
            )

        # --------------------------------------
        # Spatial encoding step
        # --------------------------------------

        # Encoders
        for sensor in inputs:
            layers[sensor]['encoded'] = _encoder(
                layers[sensor]['time_encoded']
            )

        # --------------------------------------
        # Concatenation of the encoded features
        # --------------------------------------
        encoded = [
            layers[sensor]['encoded']['enc_output']
            for sensor in inputs
        ]

        # concatenate() needs at least two tensors
        if len(encoded) > 1:
            concatenated = concatenate(encoded)
        else:
            concatenated = encoded[0]

        # --------------------------------------
        # Spatial decoding step
        # --------------------------------------

        # Each decoder stage receives the same-level encoder output of
        # every sensor as skip connections; only the two deepest stages
        # use dropout.
        dec4 = deconv2d(concatenated,
                        skip_input=[layers[sensor]['encoded']['enc4']
                                    for sensor in layers],
                        filters=256,
                        kernelsize=defaultkernelsize,
                        stride=defaultstride,
                        init=init, dropout=self.dropoutfraction)
        dec3 = deconv2d(dec4,
                        skip_input=[layers[sensor]['encoded']['enc3']
                                    for sensor in layers],
                        filters=256,
                        kernelsize=defaultkernelsize,
                        stride=defaultstride,
                        init=init, dropout=self.dropoutfraction)
        dec2 = deconv2d(dec3,
                        skip_input=[layers[sensor]['encoded']['enc2']
                                    for sensor in layers],
                        filters=128,
                        kernelsize=defaultkernelsize,
                        stride=defaultstride,
                        init=init, dropout=0)
        dec1 = deconv2d(dec2,
                        skip_input=[layers[sensor]['encoded']['enc1']
                                    for sensor in layers],
                        filters=64,
                        kernelsize=defaultkernelsize,
                        stride=defaultstride,
                        init=init, dropout=0)

        # OUTPUT LAYER
        output = Activation('tanh')(Conv2DTranspose(
            1, (4, 4), strides=2, padding='same',
            kernel_initializer=init)(dec1))

        # Define the final generator model
        generator = Model(
            inputs=list(self.inputs.values()),
            outputs=output,
            name='generator')

        return generator


if __name__ == '__main__':

    # Smoke test: build the full model and show the generator architecture.
    cropsarmodel = CropsarPixelModel(modelinputs={'S1': 2, 'S2': 1})
    # Model.summary() prints the table itself and returns None, so it must
    # not be wrapped in print() (that would emit a trailing "None").
    cropsarmodel.generator.summary()
