import math

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import Model
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (Input, concatenate,
                                     Conv2D, Dropout, LeakyReLU,
                                     Activation, Conv3D,
                                     )
from tensorflow.keras.layers import (BatchNormalization,
                                     Conv2DTranspose,
                                     Reshape)
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.regularizers import l2

from cropsar_px.utils.pconv import PConv3D


# Default sensors fed to the model. Each value is used as the size of the
# last (channel) axis of that sensor's input tensor.
DEFAULT_MODEL_INPUTS = {'S1': 2, 'S2': 1}


def tilted_loss(q, y, f):
    """Return the tilted (pinball) loss for quantile ``q``.

    The residual ``y - f`` is weighted by ``q`` or by ``q - 1``, whichever
    product is larger element-wise, and the result is averaged over the
    last axis.

    Args:
        q: Quantile weight (the combined model in this file uses 0.5).
        y: Target values.
        f: Predicted values.

    Returns:
        The loss with the last axis reduced by the mean.
    """
    backend = tf.keras.backend
    residual = y - f
    pinball = backend.maximum(q * residual, (q - 1) * residual)
    return backend.mean(pinball, axis=-1)


def lr_scheduler(epoch, lr, hold_epochs=4, decay_rate=0.1):
    """Hold the learning rate at first, then decay it exponentially.

    For epochs below ``hold_epochs`` the incoming learning rate is returned
    unchanged. From then on each call multiplies it by
    ``exp(-decay_rate)``.

    The decay factor is computed with ``math.exp`` so the result is a plain
    number rather than a TensorFlow tensor.

    Args:
        epoch: Index of the current epoch.
        lr: Current learning rate.
        hold_epochs: Number of initial epochs during which the learning
            rate is left untouched.
        decay_rate: Exponent of the per-epoch decay factor.

    Returns:
        The learning rate to use for this epoch.
    """
    if epoch < hold_epochs:
        return lr
    return lr * math.exp(-decay_rate)


class CropsarPixelModel:
    """Generator, discriminator and combined adversarial model.

    Constructing an instance builds three Keras models:

    * ``discriminator``: takes the S1 and S2 input stacks plus one target
      image and outputs a sigmoid-activated patch map. It is compiled with
      binary cross-entropy.
    * ``generator``: a 3D-convolutional encoder/decoder that maps the sensor
      input stacks to a single ``tanh``-activated image (the q50 output).
    * ``combined``: the generator followed by a non-trainable wrapper around
      the discriminator, compiled with binary cross-entropy on the
      discriminator output and the 0.5 tilted loss on the generated image.
    """

    # Sensors that must be present in ``modelinputs``: the discriminator is
    # wired to exactly these two input stacks.
    _REQUIRED_SENSORS = ('S1', 'S2')

    def __init__(self, modelinputs=DEFAULT_MODEL_INPUTS,
                 windowsize=32, tslength=37):
        """Build and compile the discriminator, generator and combined model.

        Args:
            modelinputs: Mapping of sensor name to number of channels. One
                input tensor is created per entry; ``'S1'`` and ``'S2'``
                are required.
            windowsize: Spatial size of the patches. Only 32 is accepted.
            tslength: Number of time steps in each input stack. Both
                networks squeeze the temporal axis at the end, so the
                strided convolutions have to reduce it to a single step;
                the default of 37 does. Other values are not validated.

        Raises:
            ValueError: If ``windowsize`` is not 32 or a required sensor is
                missing from ``modelinputs``.
        """
        if windowsize != 32:
            raise ValueError('Windowsize should be 32.')

        # Fail early with a clear message instead of a KeyError after both
        # networks have already been built.
        missing = [sensor for sensor in self._REQUIRED_SENSORS
                   if sensor not in modelinputs]
        if missing:
            raise ValueError(
                'modelinputs is missing required sensor(s): {}'.format(
                    ', '.join(missing)))

        self.windowsize = windowsize
        self.tslength = tslength
        # Not read anywhere inside this class.
        self.dropoutfraction = 0.5
        # Copy so the module-level default (or the caller's dict) is never
        # aliased by the instance.
        self.modelinputs = dict(modelinputs)

        # NOTE(review): this single SGD instance is passed to both compile()
        # calls below. Confirm that the pinned TensorFlow version allows one
        # optimizer to be shared by models with different variables.
        optimizer = SGD(0.01)

        # Define model inputs: one (time, height, width, channels) tensor
        # per sensor, in the iteration order of ``modelinputs``.
        inputs = {
            sensor: Input(
                shape=(self.tslength, self.windowsize, self.windowsize,
                       channels),
                name=sensor + '_input')
            for sensor, channels in self.modelinputs.items()
        }
        self.inputs = inputs

        # Build and compile the discriminator. This happens before its
        # layers are marked non-trainable through the wrapper model below.
        self.discriminator = self.build_discriminator(
            windowsize=self.windowsize)
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'],
                                   loss_weights=[0.5])

        # Build the generator
        self.generator = self.build_generator()

        # Build the "frozen discriminator": a second Model over the same
        # input/output tensors, flagged as non-trainable.
        self.frozen_discriminator = Model(
            self.discriminator.inputs, self.discriminator.outputs,
            name="frozen_discriminator")
        self.frozen_discriminator.trainable = False

        # By conditioning on inputs, generate fake cloudfree images
        q50 = self.generator(list(inputs.values()))

        # For now, only feed q50 to discriminator
        valid = self.frozen_discriminator([inputs['S1'],
                                           inputs['S2'],
                                           q50])

        self.combined = Model(inputs=list(inputs.values()),
                              outputs=[valid, q50])
        # NOTE(review): the generator model is named 'generator'. Confirm
        # that 'generator_1' really is the output name Keras assigns to the
        # q50 output here; otherwise these metrics are not attached to it.
        self.combined.compile(
            loss=['binary_crossentropy',
                  lambda y, f: tilted_loss(0.5, y, f)],
            loss_weights=[1, 100],
            optimizer=optimizer,
            metrics={
                'frozen_discriminator': 'accuracy',
                'generator_1': ['accuracy', 'mse'],  # Q50
            })

    @staticmethod
    def _instancenorm_relu(inputs):
        """Apply instance normalization (axis 4) and a LeakyReLU(0.2)."""
        inputs = tfa.layers.InstanceNormalization(
            axis=4,
            beta_initializer="random_uniform",
            gamma_initializer="random_uniform")(inputs)
        return LeakyReLU(alpha=0.2)(inputs)

    @classmethod
    def _residual_block_3d(cls, layer_input, filters, kernelsize, stride,
                           init, l2_weight_regularizer=1e-5):
        """Pre-activation residual block built from 3D convolutions.

        Main path: norm + LeakyReLU, a convolution with ``stride``, norm +
        LeakyReLU, and a stride-1 convolution. The shortcut is a 1x1x1
        convolution with the same ``stride``, so it is a projection of
        ``layer_input`` rather than an identity mapping.

        Args:
            layer_input: 5D input tensor.
            filters: Number of output channels of every convolution.
            kernelsize: Kernel size of the two main-path convolutions.
            stride: Stride of the first main-path and shortcut convolutions.
            init: Kernel initializer for all convolutions.
            l2_weight_regularizer: L2 factor applied to all kernels.

        Returns:
            The sum of the main path and the projected shortcut.
        """
        x = cls._instancenorm_relu(layer_input)
        x = Conv3D(filters, kernel_size=kernelsize,
                   strides=stride, padding='same',
                   activation=None,
                   kernel_initializer=init,
                   kernel_regularizer=l2(
                       l2_weight_regularizer))(x)
        x = cls._instancenorm_relu(x)
        x = Conv3D(filters, kernel_size=kernelsize,
                   strides=1, padding='same',
                   activation=None,
                   kernel_initializer=init,
                   kernel_regularizer=l2(
                       l2_weight_regularizer))(x)
        # Shortcut connection (1x1x1 projection to match shape)
        s = Conv3D(filters, kernel_size=1,
                   strides=stride, padding='same',
                   activation=None,
                   kernel_initializer=init,
                   kernel_regularizer=l2(
                       l2_weight_regularizer))(layer_input)
        # Addition
        return x + s

    @classmethod
    def _input_block_3d(cls, block_input, filters, kernelsize, init):
        """First encoding step shared by the discriminator and the encoders.

        Two stride-1 convolutions with a norm + LeakyReLU in between, added
        to a 1x1x1 convolution of ``block_input``. The shortcut convolution
        uses the Keras default initializer and no regularizer.
        """
        x = Conv3D(filters, kernelsize,
                   padding="same", strides=1,
                   kernel_initializer=init,
                   kernel_regularizer=l2(1e-5))(block_input)
        x = cls._instancenorm_relu(x)
        x = Conv3D(filters, kernelsize,
                   padding="same", strides=1,
                   kernel_initializer=init,
                   kernel_regularizer=l2(1e-5))(x)
        s = Conv3D(filters, 1, padding="same")(block_input)
        return x + s

    def build_discriminator(self, windowsize):
        """Build the (uncompiled) conditional discriminator.

        The target image is repeated along the temporal axis, concatenated
        channel-wise with the S1 and S2 stacks and passed through five
        encoding steps. The temporal axis of the final sigmoid output is
        squeezed away.

        Args:
            windowsize: Spatial height and width of the input patches.

        Returns:
            A Keras ``Model`` named ``'discriminator'`` with inputs
            ``(s1_input, s2_input, target_image)``.
        """
        init = RandomNormal(mean=0.0, stddev=0.02)

        # Set the input: the discriminator is NOT
        # fully convolutional!
        # Channel counts follow ``modelinputs`` so they always match the
        # tensors the combined model feeds in.
        s1_input = Input(
            shape=(self.tslength, windowsize, windowsize,
                   self.modelinputs['S1']),
            name='s1_input')
        s2_input = Input(
            shape=(self.tslength, windowsize, windowsize,
                   self.modelinputs['S2']),
            name='s2_input')

        # Target image, supposed to be cloudfree, either real or fake
        target = Input(shape=(windowsize, windowsize, 1),
                       name='target_image')

        # Repeat target in temporal dimension to be able
        # to concatenate the inputs
        exptarget = tf.repeat(tf.expand_dims(target, 1), self.tslength, axis=1)

        # Concatenate inputs and target
        concatenated = concatenate([s1_input, s2_input, exptarget])

        # Encode to patches
        filters = 64
        defaultkernelsize = 4
        defaultstride = 2
        # Steps 2-4 use a longer kernel and stride 3 on the temporal axis.
        temporalkernel = (6, defaultkernelsize, defaultkernelsize)
        temporalstride = (3, defaultstride, defaultstride)

        enc1 = self._input_block_3d(concatenated, filters,
                                    defaultkernelsize, init)

        # Encoding step 2-4
        enc2 = self._residual_block_3d(enc1, filters * 2,
                                       temporalkernel, temporalstride, init)
        enc3 = self._residual_block_3d(enc2, filters * 4,
                                       temporalkernel, temporalstride, init)
        enc4 = self._residual_block_3d(enc3, filters * 8,
                                       temporalkernel, temporalstride, init)

        enc5 = self._residual_block_3d(enc4, filters * 8,
                                       defaultkernelsize, defaultstride,
                                       init)

        # Single-channel sigmoid map; axis 1 (time) must be of size 1 here
        # for the squeeze to succeed.
        patch_out = tf.keras.backend.squeeze(
            Activation('sigmoid')(
                Conv3D(filters=1,
                       kernel_size=1,
                       strides=1,
                       padding="same",
                       kernel_initializer=init)(enc5)),
            axis=1)

        discriminator = Model(inputs=(
            s1_input, s2_input, target),
            outputs=patch_out,
            name='discriminator')

        return discriminator

    def build_generator(self):
        """Build the (uncompiled) generator on top of ``self.inputs``.

        Every sensor stack gets its own 3D encoder. The encoder outputs are
        concatenated and decoded spatially with skip connections from all
        encoders. A chain of strided temporal convolutions then collapses
        the temporal axis, which is squeezed from the ``tanh`` output.

        Returns:
            A Keras ``Model`` named ``'generator'`` taking
            ``list(self.inputs.values())`` as inputs.
        """
        defaultkernelsize = 4
        defaultstride = 2
        init = RandomNormal(mean=0.0, stddev=0.02)
        residual_block_3d = self._residual_block_3d

        def decoder_block_3d(layer_input, skip_input, filters,
                             kernelsize, stride,
                             init, output_padding=None):
            """Upsample, optionally pad and merge skips, then refine.

            ``output_padding`` is an optional 3-sequence giving the amount
            to pad in front of axes 1, 2 and 3 after upsampling.
            ``skip_input`` may be a tensor, a list/tuple of tensors, or
            None/empty for no skip connection.
            """
            x = tf.keras.layers.UpSampling3D(size=stride)(layer_input)

            if output_padding:
                x = tf.pad(x, [[0, 0],
                               [output_padding[0], 0],
                               [output_padding[1], 0],
                               [output_padding[2], 0],
                               [0, 0]])

            # Normalise the skip connections to a list without relying on
            # the truth value of a tensor.
            if skip_input is None:
                skips = []
            elif isinstance(skip_input, (list, tuple)):
                skips = list(skip_input)
            else:
                skips = [skip_input]

            if skips:
                x = concatenate([x] + skips)

            # The refinement itself never changes the resolution.
            return residual_block_3d(x, filters, kernelsize, 1, init)

        def _encoder_3D(enc_input):
            """Encode one sensor stack; return all intermediate levels."""
            filters = 64

            enc_input = tf.pad(enc_input,
                               [
                                   [0, 0],  # Batch dim
                                   [2, 1],  # Temporal dim
                                   [0, 0], [0, 0],  # Spatial dims
                                   [0, 0]  # Channels
                               ])

            # Encoding step 1
            enc1 = self._input_block_3d(enc_input, filters,
                                        defaultkernelsize, init)

            # Encoding step 2-4
            enc2 = residual_block_3d(enc1, filters * 2, defaultkernelsize,
                                     defaultstride, init)

            enc3 = residual_block_3d(enc2, filters * 4, defaultkernelsize,
                                     defaultstride, init)

            # Bridge
            enc_output = residual_block_3d(enc3, filters * 4,
                                           defaultkernelsize, defaultstride,
                                           init)

            return {'enc1': enc1, 'enc2': enc2, 'enc3': enc3,
                    'enc_output': enc_output}

        # --------------------------------------
        # Encoding step
        # --------------------------------------

        # One encoder per sensor, in the order of ``self.inputs``.
        encoded_levels = {
            sensor: _encoder_3D(tensor)
            for sensor, tensor in self.inputs.items()
        }

        def skips(level):
            """Collect one encoder level from every sensor."""
            return [levels[level] for levels in encoded_levels.values()]

        # --------------------------------------
        # Concatenation of the encoded features
        # --------------------------------------
        encoded = skips('enc_output')

        if len(encoded) > 1:
            concatenated = concatenate(encoded)
        else:
            concatenated = encoded[0]

        # --------------------------------------
        # Spatial decoding step
        # --------------------------------------

        dec3 = decoder_block_3d(concatenated,
                                skip_input=skips('enc3'),
                                filters=256,
                                kernelsize=defaultkernelsize,
                                stride=defaultstride,
                                init=init)

        dec2 = decoder_block_3d(dec3,
                                skip_input=skips('enc2'),
                                filters=128,
                                kernelsize=defaultkernelsize,
                                stride=defaultstride,
                                init=init)

        dec1 = decoder_block_3d(dec2,
                                skip_input=skips('enc1'),
                                filters=64,
                                kernelsize=defaultkernelsize,
                                stride=defaultstride,
                                init=init)

        # --------------------------------------
        # Temporal decoding step
        # --------------------------------------

        # Kernels and strides act on the temporal axis only (size 1 on both
        # spatial axes). Each entry is (filters, temporal stride).
        dec_t = dec1
        for filters, temporalstride in ((128, defaultstride),
                                        (256, defaultstride),
                                        (256, defaultstride),
                                        (256, 3)):
            dec_t = residual_block_3d(dec_t,
                                      filters=filters,
                                      kernelsize=(defaultkernelsize, 1, 1),
                                      stride=(temporalstride, 1, 1),
                                      init=init)

        # OUTPUT LAYER
        # Axis 1 (time) must be of size 1 here for the squeeze to succeed.
        output_q50 = tf.keras.backend.squeeze(
            Activation('tanh')(
                Conv3D(filters=1,
                       kernel_size=(defaultkernelsize, 1, 1),
                       strides=(defaultstride, 1, 1),
                       padding="same",
                       kernel_initializer=init,
                       name='q50')(dec_t)),
            axis=1)

        # Define the final generator model
        generator = Model(
            inputs=list(self.inputs.values()),
            outputs=output_q50,
            name='generator')

        return generator


if __name__ == '__main__':

    # Build the model with the two default sensors and show the generator
    # layout as a quick smoke test.
    cropsarmodel = CropsarPixelModel(modelinputs={'S1': 2, 'S2': 1})
    # summary() prints the table itself and returns None, so wrapping it in
    # print() would emit a stray "None" line.
    cropsarmodel.generator.summary()
