from tensorflow.keras.layers import (Input, concatenate,
                                     Conv2D, Dropout, LeakyReLU,
                                     Activation, Conv3D, TimeDistributed,
                                     )
from tensorflow.keras.layers import (BatchNormalization,
                                     Conv2DTranspose)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras import Model
from cropsar_px.utils.pconv import PConvSimple2D, PConvSimple3D
from cropsar_px.models.layers.norm import (MaskedInstanceNormalization,
                                           MaskedBatchNormalization,
                                           MaskedGroupNormalization)
from tensorflow.keras.regularizers import l2
import tensorflow_addons as tfa
import tensorflow as tf
import threading
import numpy as np

# Per-thread storage used by load_generator_model() to cache the generator it
# loads, so each thread deserializes the model at most once.
_threadlocal = threading.local()


# Number of input channels per sensor. Each key becomes a generator input
# named ``<sensor>_input`` in CropsarPixelModel.
DEFAULT_MODEL_INPUTS = dict(
    S1=2,
    S2=1,
)

# Name -> activation lookup used by CropsarPixelModel.residual_block.
# NOTE: these layer objects are created once at import time, so every call
# site in every model built in this process shares the same instances.
ACTIVATIONS = dict(
    elu=tf.keras.layers.ELU(),
    relu=Activation('relu'),
    tanh=Activation('tanh'),
    leakyrelu=LeakyReLU(alpha=0.2),
    sigmoid=tf.keras.activations.sigmoid,
)


def load_generator_model():
    """Return the packaged generator model, loading it at most once per thread.

    Keras/TensorFlow models are not guaranteed to be thread-safe, so the
    loaded model is cached on the module-level ``threading.local`` object:
    the first call in a thread deserializes the bundled HDF5 resource, and
    later calls in the same thread return that cached instance without
    touching the disk again.

    Returns:
        The Keras model loaded from the ``cropsar_px`` package resource
        ``resources/cropsar_px_generator.h5``.

    Raises:
        FileNotFoundError: If the package loader cannot provide the resource.
    """
    # The attribute read here must be the same one written below; otherwise
    # the cache never hits and the model is reloaded on every call.
    generator_model = getattr(_threadlocal, 'generator_model', None)
    if generator_model is not None:
        return generator_model

    # Heavy imports are deferred until a load is actually needed
    import io
    import pkgutil
    import h5py
    from tensorflow.keras.models import load_model

    # Load tensorflow model from in-memory HDF5 resource
    path = 'resources/cropsar_px_generator.h5'
    data = pkgutil.get_data('cropsar_px', path)
    if data is None:
        # get_data returns None when the loader cannot serve resources;
        # fail clearly instead of handing h5py an empty buffer.
        raise FileNotFoundError(
            f"Resource {path!r} not available from package 'cropsar_px'")

    with h5py.File(io.BytesIO(data), mode='r') as h5:
        generator_model = load_model(h5)

    # Store per thread
    _threadlocal.generator_model = generator_model

    return generator_model


def smooth_positive_labels(y):
    '''Soften hard "1" labels with uniform noise.

    Returns a new array equal to ``(y - 0.15) + U[0, 0.25)``, so labels of 1
    end up roughly in ``[0.85, 1.10)``. ``y`` itself is not modified.
    '''
    shifted = y - 0.15
    jitter = np.random.random(y.shape) * 0.25
    return shifted + jitter


def smooth_negative_labels(y):
    '''Soften hard "0" labels with uniform noise.

    Returns a new array equal to ``y + U[0, 0.15)``; ``y`` itself is left
    unchanged.
    '''
    jitter = np.random.random(y.shape) * 0.15
    return y + jitter


def noisy_labels(y, p_flip):
    '''Flip a random fraction of the labels in ``y`` to improve robustness.

    Each selected entry ``v`` (a row along axis 0) becomes ``1 - v``.
    Entries are chosen without replacement, so exactly
    ``int(p_flip * len(y))`` distinct entries are flipped, clamped to
    ``[0, len(y)]``. Called from ``CropsarPixelModel.discriminator_loss``.

    Args:
        y: NumPy array of labels; modified in place along axis 0.
        p_flip: Fraction of entries along axis 0 to flip.

    Returns:
        ``y`` itself, after the in-place modification.
    '''
    n_labels = y.shape[0]
    # Clamp so a p_flip outside [0, 1] cannot request an invalid sample size
    n_select = min(max(int(p_flip * n_labels), 0), n_labels)
    if n_select == 0:
        return y
    # Sample WITHOUT replacement: an index drawn twice would otherwise be
    # flipped back to its original value, silently lowering the flip rate.
    flip_ix = np.random.choice(n_labels, size=n_select, replace=False)
    # invert the labels
    y[flip_ix] = 1 - y[flip_ix]
    return y


class CropsarPixelModel:
    """Conditional GAN that fills in masked pixels of a single target image.

    The generator is a 3D residual encoder/decoder over the time and spatial
    axes. It has one encoder per sensor in ``modelinputs`` and U-Net style
    skip connections. The discriminator scores patches of a target image,
    conditioned on one time step of the S1 and S2 inputs.

    Attributes:
        windowsize: Spatial size of the discriminator inputs (always 32).
        tslength: Number of time steps in each generator input.
        modelinputs: Mapping ``{sensor: n_channels}`` used to build inputs.
        inputs: Dict of Keras ``Input`` tensors keyed by sensor name.
        discriminator, generator: The two Keras models.
        disc_optimizer, gen_optimizer: Adam optimizers for each model.
    """

    def __init__(self, modelinputs=None,
                 windowsize=32, tslength=32):
        """Build both networks and their optimizers.

        Args:
            modelinputs: Mapping ``{sensor_name: n_channels}``. One generator
                input named ``<sensor>_input`` is created per entry, in
                iteration order. Defaults to a copy of
                ``DEFAULT_MODEL_INPUTS``.
            windowsize: Spatial size of the discriminator inputs; only 32 is
                supported.
            tslength: Number of time steps per generator input.

        Raises:
            ValueError: If ``windowsize`` is not 32.
        """
        if modelinputs is None:
            # Copy, so that mutating ``self.modelinputs`` can never leak into
            # the module-level default shared by every instance.
            modelinputs = dict(DEFAULT_MODEL_INPUTS)

        if windowsize != 32:
            raise ValueError('Windowsize should be 32.')

        self.windowsize = windowsize
        self.tslength = tslength
        self.modelinputs = modelinputs

        # One (time, y, x, channels) input per sensor; the spatial dims are
        # left unspecified (None).
        inputs = {}
        for sensor in modelinputs.keys():
            inputs[sensor] = Input(
                shape=(self.tslength, None, None,
                       self.modelinputs[sensor]),
                name=sensor + '_input')
        self.inputs = inputs

        # Build the discriminator
        self.discriminator = self.build_discriminator(
            windowsize=self.windowsize)
        self.disc_optimizer = Adam(learning_rate=1e-4,
                                   beta_1=0.5,
                                   beta_2=0.999)

        # Build the generator
        self.generator = self.build_generator()
        self.gen_optimizer = Adam(learning_rate=1e-4,
                                  beta_1=0.5,
                                  beta_2=0.999)

    def build_discriminator(self, windowsize):
        """Build the patch discriminator.

        Inputs are one S1 slice (2 channels), one S2 slice (1 channel) and a
        candidate target image (1 channel), each ``windowsize x windowsize``.
        NOTE(review): the channel counts are hard-coded rather than read from
        ``self.modelinputs``; they match ``DEFAULT_MODEL_INPUTS``.

        Returns:
            Keras ``Model`` named ``'discriminator'`` that outputs a grid of
            raw logits, one per patch. There is no activation; the losses use
            ``from_logits=True``.
        """
        init = RandomNormal(mean=0.0, stddev=0.02)

        s1_input = Input(
            shape=(windowsize, windowsize, 2),
            name='s1_input')
        s2_input = Input(
            shape=(windowsize, windowsize, 1),
            name='s2_input')

        # Target image, supposed to be cloudfree, either real or fake
        target = Input(shape=(windowsize, windowsize, 1),
                       name='target_image')

        # Stack conditioning inputs and target along the channel axis
        concatenated = concatenate([s1_input, s2_input, target])

        filters = 64
        defaultkernelsize = 4
        defaultstride = 2

        # Encoding steps 1-4: each residual block halves the spatial size
        enc1 = self.residual_block(concatenated, filters,
                                   defaultkernelsize,
                                   defaultstride, init)

        enc2 = self.residual_block(enc1, filters * 2,
                                   defaultkernelsize,
                                   defaultstride, init)

        enc3 = self.residual_block(enc2, filters * 4,
                                   defaultkernelsize,
                                   defaultstride, init)

        enc4 = self.residual_block(enc3, filters * 8,
                                   defaultkernelsize,
                                   defaultstride, init)

        # 1x1 convolution to a single logit per spatial position (patch)
        patch_out = Conv2D(1, 1, strides=1, padding='valid',
                           kernel_initializer=init)(enc4)

        discriminator = Model(inputs=(
            s1_input, s2_input, target),
            outputs=patch_out,
            name='discriminator')

        return discriminator

    def build_generator(self):
        """Build the 3D residual U-Net generator.

        Each sensor input gets its own encoder. The bottleneck features of all
        sensors are concatenated on the channel axis and decoded with skip
        connections from every encoder level of every sensor. Four
        temporal-only strided Conv3D layers then shrink the time axis. A
        final single-filter tanh Conv3D produces the output, and the time
        axis is squeezed away.

        NOTE(review): the squeeze on axis 1 requires the time axis to have
        been reduced to length 1. That appears to hold for the default
        ``tslength == 32``; confirm before using other lengths.

        Returns:
            Keras ``Model`` named ``'generator'`` with inputs
            ``list(self.inputs.values())``.
        """
        defaultkernelsize = 4
        defaultstride = 2
        init = RandomNormal(mean=0.0, stddev=0.02)

        def _encoder_3D(enc_input):
            """Residual 3D encoder; returns every level for skip connections."""
            filters = 32

            # Level 0 keeps full resolution; every later level uses stride 2
            # on all three (time, y, x) axes.
            enc0 = self.residual_block(enc_input, filters, defaultkernelsize,
                                       1, init, conv3d=True)

            enc1 = self.residual_block(enc0, filters * 2, defaultkernelsize,
                                       defaultstride, init, conv3d=True)

            enc2 = self.residual_block(enc1, filters * 4, defaultkernelsize,
                                       defaultstride, init, conv3d=True)

            enc3 = self.residual_block(enc2, filters * 8, defaultkernelsize,
                                       defaultstride, init, conv3d=True)

            enc4 = self.residual_block(enc3, filters * 8, defaultkernelsize,
                                       defaultstride, init, conv3d=True)

            # Bridge
            enc_output = self.residual_block(enc4, filters * 8,
                                             defaultkernelsize, defaultstride,
                                             init, conv3d=True)

            return {'enc0': enc0, 'enc1': enc1, 'enc2': enc2, 'enc3': enc3,
                    'enc4': enc4,
                    'enc_output': enc_output}

        inputs = self.inputs

        # --------------------------------------
        # Encoding: one independent encoder per sensor, built in input order
        # --------------------------------------
        encoders = {sensor: _encoder_3D(inputs[sensor]) for sensor in inputs}

        # --------------------------------------
        # Concatenation of the encoded features
        # --------------------------------------
        encoded = [encoders[sensor]['enc_output'] for sensor in inputs]

        if len(encoded) > 1:
            concatenated = concatenate(encoded)
        else:
            concatenated = encoded[0]

        # --------------------------------------
        # Decoding: deepest level first. Each step upsamples and merges the
        # matching encoder level of every sensor.
        # --------------------------------------
        x = concatenated
        for level, filters in (('enc4', 512), ('enc3', 256), ('enc2', 128),
                               ('enc1', 64), ('enc0', 32)):
            x = self.decoder_block(x,
                                   skip_input=[encoders[sensor][level]
                                               for sensor in encoders],
                                   filters=filters,
                                   kernelsize=defaultkernelsize,
                                   stride=defaultstride,
                                   init=init,
                                   conv3d=True,
                                   activation='leakyrelu')

        # Temporal reduction: kernel and stride act on the time axis only
        for filters in (64, 128, 256, 256):
            x = self.conv3d(x,
                            filters=filters,
                            kernelsize=(defaultkernelsize, 1, 1),
                            stride=(defaultstride, 1, 1),
                            init=init,
                            activation='relu')

        # OUTPUT LAYER: one tanh-activated channel, then drop the time axis
        output = tf.keras.backend.squeeze(
            Activation('tanh')(
                Conv3D(filters=1,
                       kernel_size=(defaultkernelsize, 1, 1),
                       strides=(defaultstride, 1, 1),
                       padding="same",
                       kernel_initializer=init,
                       name='output')(x)),
            axis=1)

        # Define the final generator model
        generator = Model(
            inputs=list(self.inputs.values()),
            outputs=output,
            name='generator')

        return generator

    def residual_block(self, layer_input, filters, kernelsize, stride, init,
                       activation='relu', conv3d=False, timedistributed=False,
                       l2_weight_regularizer=1e-5):
        """Two partial convolutions plus a 1x1 projection shortcut.

        Structure: conv(stride) -> masked batchnorm -> activation ->
        conv(stride 1), added to a 1x1 conv(stride) shortcut of the input,
        followed by masked batchnorm -> activation.

        Args:
            layer_input: Input tensor.
            filters: Number of output filters.
            kernelsize: Kernel size of the two main convolutions.
            stride: Stride of the first convolution and of the shortcut.
            init: Kernel initializer.
            activation: Key into the module-level ``ACTIVATIONS`` dict.
            conv3d: Use ``PConvSimple3D`` instead of ``PConvSimple2D``.
            timedistributed: With ``conv3d``, wrap ``PConvSimple2D`` in
                ``TimeDistributed`` instead of using a 3D convolution.
            l2_weight_regularizer: L2 factor for every kernel.

        Raises:
            ValueError: If ``timedistributed`` is set without ``conv3d``.
        """
        if not conv3d and timedistributed:
            raise ValueError(
                'Cannot do timedistributed Conv2D on a Conv2D task')

        # 3D partial conv only for a plain conv3d block; 2D otherwise
        # (including the TimeDistributed case).
        if conv3d and not timedistributed:
            convolution = PConvSimple3D
        else:
            convolution = PConvSimple2D

        def _conv(n_filters, size, strides, padding):
            # Builds one convolution layer, wrapped per time step if requested
            layer = convolution(n_filters, kernel_size=size,
                                strides=strides, padding=padding,
                                activation=None,
                                kernel_initializer=init,
                                kernel_regularizer=l2(
                                    l2_weight_regularizer))
            return TimeDistributed(layer) if timedistributed else layer

        # First convolution, then batchnorm and activation
        x = _conv(filters, kernelsize, stride, 'same')(layer_input)
        x = self.batchnorm(x)
        x = ACTIVATIONS[activation](x)

        # Second convolution
        x2 = _conv(filters, kernelsize, 1, 'same')(x)

        # Shortcut: 1x1 projection so channels/stride match the main path
        s = _conv(filters, 1, stride, 'valid')(layer_input)

        # NOTE(review): the sum uses the first conv's activation ``x`` and the
        # shortcut; ``x2`` is computed but never used, so the second
        # convolution does not contribute to the output. This looks
        # unintentional (a residual block would normally add ``x2 + s``), but
        # it is kept as-is so the graph and saved-weight compatibility are
        # unchanged.
        x = x + s

        # Batchnormalization and activation
        x = self.batchnorm(x)
        x = ACTIVATIONS[activation](x)

        return x

    def decoder_block(self, layer_input, skip_input, filters,
                      kernelsize, stride, init, activation='relu',
                      conv3d=False, l2_weight_regularizer=1e-5):
        """Upsample, concatenate skip connections, then one residual block.

        Args:
            layer_input: Tensor to upsample by ``stride``.
            skip_input: A tensor, a list/tuple of tensors, an empty list, or
                None. Any tensors given are concatenated after upsampling.
            filters, kernelsize, init, activation, l2_weight_regularizer:
                Forwarded to ``residual_block``, which runs with stride 1.
            conv3d: Use 3D upsampling and a 3D residual block.
        """
        if conv3d:
            x = tf.keras.layers.UpSampling3D(size=stride)(layer_input)
        else:
            x = tf.keras.layers.UpSampling2D(size=stride)(layer_input)

        # Check None before wrapping: evaluating a single (possibly symbolic)
        # tensor as a bool is not reliable, and tuples must be accepted too.
        if skip_input is not None:
            if not isinstance(skip_input, (list, tuple)):
                skip_input = [skip_input]
            if skip_input:
                x = concatenate([x, *skip_input])

        x = self.residual_block(x, filters, kernelsize, 1, init,
                                activation=activation, conv3d=conv3d,
                                l2_weight_regularizer=l2_weight_regularizer)

        return x

    def conv3d(self, layer_input, filters, kernelsize, stride, init,
               batchnormalization=True, l2_weight_regularizer=1e-5,
               activation='relu'):
        """Plain Conv3D, then optional masked batchnorm and activation.

        ``activation`` is a Keras activation name passed to a new
        ``Activation`` layer, or None to skip the activation.
        """
        c = Conv3D(filters, kernel_size=kernelsize,
                   strides=stride, padding='same',
                   activation=None,
                   kernel_initializer=init,
                   kernel_regularizer=l2(
                       l2_weight_regularizer))(layer_input)
        if batchnormalization:
            c = self.batchnorm(c)
        if activation is not None:
            c = Activation(activation)(c)
        return c

    def batchnorm(self, inputs):
        """Batch-normalize ``inputs`` but keep zero-valued entries at zero.

        Statistics are computed over all values, zeros included. Afterwards,
        every position where the *input* was exactly 0 is reset to 0, so
        zero (presumably masked/missing) values do not pick up the
        normalization shift.
        """
        adj_inputs = BatchNormalization()(inputs)
        masks = tf.cast(tf.not_equal(inputs, tf.constant(float(0))),
                        tf.float32)
        adj_inputs = adj_inputs * masks

        return adj_inputs

    def groupnorm(self, inputs):
        """Group-normalize ``inputs`` but keep zero-valued entries at zero.

        Uses ``min(channels / 2, 32)`` groups (truncated to int). Not used by
        the networks built in this class.
        """
        from tensorflow_addons.layers import GroupNormalization

        nrgroups = min(int(inputs.shape[-1] / 2), 32)
        adj_inputs = GroupNormalization(groups=nrgroups)(inputs)
        # Values that are equal to 0 in the input stay 0 in the output
        masks = tf.cast(tf.not_equal(inputs, tf.constant(float(0))),
                        tf.float32)
        adj_inputs = adj_inputs * masks

        return adj_inputs

    def instancenorm(self, inputs):
        """Masked instance normalization (not used by the built networks).

        WARNING: instance normalization seems to break training; the loss
        does not go down.
        """
        # Values that are equal to 0 are excluded via the mask and reset to 0
        masks = tf.cast(tf.not_equal(inputs, tf.constant(float(0))),
                        tf.float32)

        adj_inputs = MaskedInstanceNormalization(
            beta_initializer="random_uniform",
            gamma_initializer="random_uniform")(inputs, masks)

        adj_inputs = adj_inputs * masks

        return adj_inputs

    def discriminator_loss(self, disc_real_output, disc_generated_output):
        """Label-smoothed, label-flipped BCE loss for the discriminator.

        Real targets are smoothed to roughly [0.85, 1.10) and fake targets to
        [0, 0.15). Then 10% of the entries along axis 0 of each are flipped.
        The real and fake losses are averaged.

        NOTE: the labels are built with NumPy from the logits' shapes, so this
        presumably requires eager execution. ``train_step`` is not wrapped in
        ``tf.function``.
        """
        loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

        ones = smooth_positive_labels(np.ones_like(disc_real_output))
        zeros = smooth_negative_labels(np.zeros_like(disc_generated_output))
        noisy_ones = noisy_labels(ones, 0.1)
        noisy_zeros = noisy_labels(zeros, 0.1)

        real_loss = loss_object(noisy_ones, disc_real_output)
        generated_loss = loss_object(noisy_zeros, disc_generated_output)

        total_disc_loss = tf.add(.5 * real_loss, .5 * generated_loss)

        return total_disc_loss

    def generator_loss(self, disc_generated_output,
                       gen_output, target, masks):
        """Pix2pix generator loss: adversarial BCE + 100 * mean L1.

        Args:
            disc_generated_output: Discriminator logits for the fake sample.
            gen_output: Raw generator prediction.
            target: Ground-truth image.
            masks: Currently unused. An earlier, disabled variant weighted the
                L1 term 80/20 between pixels with ``masks == 0`` and
                ``masks == 1``.

        Returns:
            Tuple ``(total_gen_loss, gan_loss, l1_loss)``.
        """
        loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

        LAMBDA = 100  # Weight on L1 loss

        # Generator wants the discriminator to label its output as real
        gan_loss = loss_object(tf.ones_like(disc_generated_output),
                               disc_generated_output)

        # Mean absolute error over every pixel
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

        total_gen_loss = gan_loss + (LAMBDA * l1_loss)

        return total_gen_loss, gan_loss, l1_loss

    def _run_step(self, inputs, outputs, masks, training=False):
        """Run both networks once and compute all losses.

        Args:
            inputs: List of generator inputs, ordered like ``self.inputs``.
                Each is indexed as ``x[:, t, ...]``, so axis 1 is time.
                NOTE(review): the discriminator expects these in
                (S1, S2) order; confirm callers respect that.
            outputs: Target images, shown to the discriminator as "real".
            masks: Presumably binary. Where 1, the target pixel is kept in
                the "fake" image; where 0, the generator prediction is used.
            training: Forwarded to both Keras models.

        Returns:
            Dict with keys ``disc_loss``, ``gen_gan_loss``,
            ``gen_reconstruction_loss`` and ``gen_total_loss``.
        """
        predicted = self.generator(inputs, training=training)

        # Fake sample: prediction only where the mask is 0
        completed = predicted * (1. - masks) + outputs * masks

        # Condition the discriminator on the centre time step of each input
        disc_inputs = [x[:, self.tslength // 2, ...] for x in inputs]

        real_output = self.discriminator([*disc_inputs, outputs],
                                         training=training)
        fake_output = self.discriminator([*disc_inputs, completed],
                                         training=training)

        gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(
            fake_output, predicted, outputs, masks)

        disc_loss = self.discriminator_loss(real_output, fake_output)

        losses = dict(
            disc_loss=disc_loss,
            gen_gan_loss=gen_gan_loss,
            gen_reconstruction_loss=gen_l1_loss,
            gen_total_loss=gen_total_loss
        )

        return losses

    # @tf.function
    def train_step(self, inputs, outputs, masks):
        """One simultaneous update of discriminator and generator.

        Both losses come from the same forward pass; each optimizer only
        updates its own model's trainable variables.

        Returns:
            The loss dict produced by ``_run_step``.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            losses = self._run_step(inputs, outputs, masks, training=True)

        gradients_of_generator = gen_tape.gradient(
            losses['gen_total_loss'], self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            losses['disc_loss'], self.discriminator.trainable_variables)

        self.disc_optimizer.apply_gradients(
            zip(gradients_of_discriminator,
                self.discriminator.trainable_variables))

        self.gen_optimizer.apply_gradients(
            zip(gradients_of_generator,
                self.generator.trainable_variables))

        return losses

    def test_step(self, inputs, outputs, masks):
        '''
        Compute losses without updating any weights.

        Note: The training=True is intentional here since you want the batch
        statistics, while running the model on the test dataset.
        If you use training=False, you get the accumulated statistics
        learned from the training dataset (which you don't want).
        (https://www.tensorflow.org/tutorials/generative/pix2pix?hl=en)
        '''

        # Mind what happens when we apply Dropout!
        losses = self._run_step(inputs, outputs, masks, training=True)

        return losses


if __name__ == '__main__':

    # Build the model with both sensors and print the generator architecture
    cropsarmodel = CropsarPixelModel(modelinputs={'S1': 2, 'S2': 1})
    # ``Model.summary`` prints by itself and returns None; wrapping it in
    # ``print`` would emit a stray "None" line.
    cropsarmodel.generator.summary(line_length=150)
