from tensorflow.keras.layers import (Input, concatenate,
                                     Conv2D, Dropout, LeakyReLU,
                                     Activation, Conv3D, Conv2D)
from tensorflow.keras.layers import (BatchNormalization, Flatten)
from cropsar_px.utils.pconv import PConv3D, PConvSimple3D, PConvSimple2D
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.initializers import RandomNormal
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.regularizers import l2
import tensorflow_addons as tfa
import tensorflow as tf

# Activation callables looked up by name in the conv helpers below.
# The Keras layer instances are created once at import time and re-used
# (shared) wherever they are applied; 'sigmoid' is a plain function.
ACTIVATIONS = dict(
    elu=tf.keras.layers.ELU(),
    relu=Activation('relu'),
    tanh=Activation('tanh'),
    leakyrelu=LeakyReLU(alpha=0.2),
    sigmoid=tf.keras.activations.sigmoid,
)


# Default sensor -> number of input channels. Insertion order matters: it
# determines the order of the generator's Keras inputs.
DEFAULT_MODEL_INPUTS = dict(
    S1=2,
    S2=1,
)


def gen_conv_2d(x, filters, kernel_size, strides=1, dilation_rate=1,
                padding='same', activation='elu', **kwargs):
    """Gated 2D convolution used by the generator.

    A ``Conv2D`` with ``filters`` output channels is applied. When
    ``activation`` is not None, its output is split into two halves along
    the last (channel) axis. The first half goes through
    ``ACTIVATIONS[activation]`` and is multiplied by the sigmoid of the
    second half (gating). The result therefore has ``filters // 2``
    channels.

    Args:
        x: Input tensor (channels on the last axis).
        filters: Number of convolution filters; must be even when gating.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        dilation_rate: Dilation rate (for dilated convolution).
        padding: Padding mode passed to ``Conv2D`` (default ``'same'``).
        activation: Key into ``ACTIVATIONS`` for the feature half, or None
            to return the raw convolution output without gating.
        **kwargs: Extra keyword arguments forwarded to ``Conv2D``
            (e.g. ``name``).

    Returns:
        tf.Tensor: output

    Raises:
        ValueError: if gating is requested with an odd number of filters.
    """
    if activation is not None and filters % 2:
        raise ValueError(
            f'Gated convolution needs an even number of filters, got {filters}.')

    # Perform convolution
    x = Conv2D(filters=filters, kernel_size=kernel_size,
               strides=strides, padding=padding, dilation_rate=dilation_rate,
               **kwargs)(x)

    if activation is not None:
        # Gating: features * sigmoid(gate), both halves of the channel axis.
        features, gate = tf.split(x, 2, axis=-1)
        x = ACTIVATIONS[activation](features) * ACTIVATIONS['sigmoid'](gate)
    return x


def gen_conv_3d(x, filters, kernel_size, strides=1, dilation_rate=1,
                padding='same', activation='elu', **kwargs):
    """Gated 3D convolution used by the generator.

    A ``Conv3D`` with ``filters`` output channels is applied. When
    ``activation`` is not None, its output is split into two halves along
    the last (channel) axis. The first half goes through
    ``ACTIVATIONS[activation]`` and is multiplied by the sigmoid of the
    second half (gating). The result therefore has ``filters // 2``
    channels.

    Args:
        x: Input 5D tensor (channels on the last axis).
        filters: Number of convolution filters; must be even when gating.
        kernel_size: Convolution kernel size (int or 3-tuple).
        strides: Convolution stride (int or 3-tuple).
        dilation_rate: Dilation rate (for dilated convolution).
        padding: Padding mode passed to ``Conv3D`` (default ``'same'``).
        activation: Key into ``ACTIVATIONS`` for the feature half, or None
            to return the raw convolution output without gating.
        **kwargs: Extra keyword arguments forwarded to ``Conv3D``
            (e.g. ``name``).

    Returns:
        tf.Tensor: output

    Raises:
        ValueError: if gating is requested with an odd number of filters.
    """
    if activation is not None and filters % 2:
        raise ValueError(
            f'Gated convolution needs an even number of filters, got {filters}.')

    # Perform convolution
    x = Conv3D(filters=filters, kernel_size=kernel_size,
               strides=strides, padding=padding, dilation_rate=dilation_rate,
               **kwargs)(x)

    if activation is not None:
        # Gating: features * sigmoid(gate), both halves of the channel axis.
        features, gate = tf.split(x, 2, axis=-1)
        x = ACTIVATIONS[activation](features) * ACTIVATIONS['sigmoid'](gate)
    return x


def gen_deconv_2d(x, filters, kernel_size=3, padding='same', strides=2,
                  **kwargs):
    """Upsample ``x`` and follow with a stride-1 gated 2D convolution.

    Args:
        x: Input 4D tensor.
        filters: Filters passed to ``gen_conv_2d``.
        kernel_size: Kernel size of the convolution.
        padding: Padding mode of the convolution.
        strides: Upsampling factor given to ``UpSampling2D``.
        **kwargs: Forwarded to ``gen_conv_2d`` (e.g. ``name``, ``activation``).

    Returns:
        tf.Tensor: the upsampled, convolved tensor.
    """
    upsampler = tf.keras.layers.UpSampling2D(size=strides)
    return gen_conv_2d(upsampler(x), filters, kernel_size, strides=1,
                       padding=padding, **kwargs)


def gen_deconv_3d(x, filters, kernel_size=3, padding='same', strides=2,
                  **kwargs):
    """Upsample ``x`` and follow with a stride-1 gated 3D convolution.

    Args:
        x: Input 5D tensor.
        filters: Filters passed to ``gen_conv_3d``.
        kernel_size: Kernel size of the convolution.
        padding: Padding mode of the convolution.
        strides: Upsampling factor given to ``UpSampling3D``.
        **kwargs: Forwarded to ``gen_conv_3d`` (e.g. ``name``, ``activation``).

    Returns:
        tf.Tensor: the upsampled, convolved tensor.
    """
    upsampler = tf.keras.layers.UpSampling3D(size=strides)
    return gen_conv_3d(upsampler(x), filters, kernel_size, strides=1,
                       padding=padding, **kwargs)


def dis_conv_2d(x, filters, kernel_size=3, strides=2,
                padding='same', final=False, batchnorm=False,
                **kwargs):
    """Spectrally normalised 2D conv block for the DeepFill-style discriminator.

    Args:
        x: Input 4D tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        padding: Padding mode of the convolution.
        final: If True, apply a sigmoid instead of the shared LeakyReLU.
        batchnorm: If True, apply ``BatchNormalization`` after the conv.
        **kwargs: Extra keyword arguments forwarded to the wrapped
            ``Conv2D`` (e.g. ``name``). Previously these were silently
            discarded.

    Returns:
        tf.Tensor: output
    """
    conv = Conv2D(filters, kernel_size, strides, padding=padding, **kwargs)
    x = tfa.layers.SpectralNormalization(conv)(x)

    if batchnorm:
        x = BatchNormalization()(x)

    if final:
        return ACTIVATIONS['sigmoid'](x)
    return ACTIVATIONS['leakyrelu'](x)


def resize_3d(x, scale=2, to_shape=None, dynamic=False,
              method='nearest'):
    """Resize axes 1, 2 and 3 of a 5D tensor with the 2D ``resize`` helper.

    The resize is done in two passes:
      1. Axes 0 and 1 are folded together and axes 2 and 3 are resized.
      2. After a transpose, axis 1 is resized while axis 3 is left unchanged.
    The original axis order is then restored.

    Args:
        x: 5D tensor; presumably (batch, time, height, width, channels) --
            TODO confirm. Axes 1-4 need static sizes, because they are read
            from ``.shape`` / ``as_list()`` to build the reshapes.
        scale: Factor applied to each of axes 1, 2 and 3 when ``to_shape``
            is None.
        to_shape: Optional target sizes ``[axis1, axis2, axis3]``.
        dynamic: Currently unused (not forwarded to ``resize``).
        method: Only ``'nearest'`` is supported.

    Returns:
        tf.Tensor: resized 5D tensor in the original axis order.

    Raises:
        NotImplementedError: if ``method`` is not ``'nearest'``.
    """

    if method != 'nearest':
        raise NotImplementedError(
            'Only nearest interpolation implemented for 5D tensors')

    # FIRST ON SPATIAL DIMENSIONS
    # Fold axes 0 and 1 into one so the 4D resize works on axes 2 and 3.
    origshape = x.shape
    newshape = tf.reshape(x, (-1,
                              origshape[2], origshape[3],
                              origshape[4]))

    to_shape_first = to_shape[1:3] if to_shape is not None else None
    x_resize_first = resize(newshape, scale=scale,
                            to_shape=to_shape_first, method=method)
    # Unfold back to 5D with the resized axes 2 and 3.
    x_resize_first = tf.reshape(
        x_resize_first,
        [-1] + [origshape.as_list()[1]] + x_resize_first.shape.as_list()[1:3] + [origshape[-1]])

    # THEN ON TEMPORAL DIMENSION
    # Swap axes 1 and 2 and fold the new axis 1 into the batch, so the 4D
    # resize sees (original axis 1, axis 3) as its two resized axes.
    x_resize_transpose_first = tf.transpose(x_resize_first, [0, 2, 1, 3, 4])
    origshape = x_resize_transpose_first.shape
    newshape = tf.reshape(x_resize_transpose_first,
                          (-1,
                           origshape[2], origshape[3],
                           origshape[4]))
    to_shape_second = [to_shape[0], to_shape[2]
                       ] if to_shape is not None else None
    # scale (scale, 1): only the original axis 1 is scaled here; axis 3
    # was already resized in the first pass.
    x_resize_second = resize(
        newshape, scale=(scale, 1), to_shape=to_shape_second,
        method='nearest')
    x_resize_second = tf.reshape(
        x_resize_second,
        [-1] + [origshape.as_list()[1]] + x_resize_second.shape.as_list()[1:3] + [origshape[-1]])

    # Undo the axis 1 <-> 2 swap.
    x_resize = tf.transpose(x_resize_second, [0, 2, 1, 3, 4])

    return x_resize


def resize(x, scale=2, to_shape=None, dynamic=False,
           method='bilinear'):
    """Resize axes 1 and 2 of a 4D tensor with ``tf.image.resize``.

    Args:
        x: 4D tensor whose axes 1 and 2 are resized (channels last).
        scale: Factor for both axes, or a tuple/list ``(scale_axis1,
            scale_axis2)``. Ignored when ``to_shape`` is given.
        to_shape: Explicit target size ``[size_axis1, size_axis2]``. It
            takes precedence over ``scale``, and ``x``'s shape is then not
            inspected, so unknown static dims are fine.
        dynamic: If True, derive the scaled size from the runtime shape
            (``tf.shape``) instead of the static shape.
        method: Interpolation method forwarded to ``tf.image.resize``.

    Returns:
        tf.Tensor: the resized tensor.
    """
    if to_shape is not None:
        return tf.image.resize(x, [to_shape[0], to_shape[1]], method=method)

    # Accept lists as well as tuples for per-axis factors.
    if not isinstance(scale, (tuple, list)):
        scale = (scale, scale)

    if dynamic:
        xs = tf.cast(tf.shape(x), tf.float32)
        new_xs = [tf.cast(xs[1] * scale[0], tf.int32),
                  tf.cast(xs[2] * scale[1], tf.int32)]
    else:
        xs = x.get_shape().as_list()
        new_xs = [int(xs[1] * scale[0]), int(xs[2] * scale[1])]
    return tf.image.resize(x, new_xs, method=method)


def resize_mask_like(mask, x):
    """Resize ``mask`` (nearest neighbour) to the non-batch, non-channel shape of ``x``.

    Works for 4D tensors (axes 1-2 resized) and 5D tensors (axes 1-3
    resized, delegated to ``resize_3d``). The channel count of ``mask`` is
    kept.

    Args:
        mask: Original mask, 4D or 5D.
        x: Tensor of the same rank whose static shape gives the target size.

    Returns:
        tf.Tensor: resized mask. Previously the 4D branch fell through and
        returned None.

    Raises:
        ValueError: if the ranks differ or the rank is not 4 or 5.
    """
    ndim = len(mask.shape)
    if ndim != len(x.shape):
        raise ValueError(
            f'mask and x must have the same rank, got {ndim} and {len(x.shape)}')

    target = x.get_shape().as_list()
    if ndim == 4:
        return resize(mask, to_shape=target[1:3], method='nearest')
    if ndim == 5:
        return resize_3d(mask, to_shape=target[1:4], method='nearest')
    raise ValueError(f'Cannot resize tensor of ndim={len(mask.shape)}')


def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                         fuse_k=3, softmax_scale=10., fuse=True):
    """Contextual attention layer, generalised from 2D to 5D tensors.

    FROM: https://github.com/JiahuiYu/generative_inpainting/blob/3a5324373ba52c68c79587ca183bc10b9e57b783/inpaint_ops.py#L256

    The algorithm, per batch item:
      1. Foreground ``f`` is convolved with L2-normalised background
         patches of ``b``, giving one score per background patch.
      2. Optionally, neighbouring scores are fused along each of the three
         non-channel axes.
      3. The scores are masked and softmax-normalised over the background
         patches.
      4. Full-resolution background patches are pasted back with a
         transposed convolution.

    Still lightly tested. The axis handling follows the 2D original,
    extended to three axes.

    Args:
        f: Foreground features, 5D with channels last and a static batch
            size (the batch is ``tf.split`` per item).
        b: Background features, same layout and channel count as ``f``.
            Its static axes 1-3 are used to compute the downscaled size.
        mask: Optional 5D mask, expected to have a single channel (the
            patch reshape assumes it). Background patches containing any
            non-zero mask value are excluded from matching.
        ksize: Patch size used for matching.
        stride: Patch extraction stride. Fusing assumes ``stride == 1``,
            as in the 2D original.
        rate: Downscaling rate for matching; patches of size ``2 * rate``
            are pasted at the original resolution.
        fuse_k: Size of the identity kernel used to fuse scores.
        softmax_scale: Multiplier applied to the scores before the softmax.
        fuse: Whether to fuse neighbouring scores.

    Returns:
        tf.Tensor: reconstruction with the runtime shape of ``f``.
    """
    # Runtime shape of the foreground; used as the output shape.
    raw_fs = tf.shape(f)
    raw_int_bs = b.get_shape().as_list()

    # Extract (2*rate)^3 background patches at full resolution; these are
    # the patches pasted back at the end.
    kernel = 2 * rate
    raw_w = tf.extract_volume_patches(
        b, [1, kernel, kernel, kernel, 1],
        [1, rate * stride, rate * stride, rate * stride, 1],
        'SAME')
    raw_w = tf.reshape(
        raw_w,
        [raw_int_bs[0], -1, kernel, kernel, kernel, raw_int_bs[4]])
    # transpose to b*k*k*k*c*(n_patches)
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 5, 1])

    # Downscale foreground and background (and mask) for matching; the
    # original-resolution background is used for reconstruction.
    f = resize_3d(f, scale=1. / rate, method='nearest')
    b = resize_3d(b, to_shape=[int(raw_int_bs[1] / rate),
                               int(raw_int_bs[2] / rate),
                               int(raw_int_bs[3] / rate)],
                  method='nearest')
    if mask is not None:
        mask = resize_3d(mask, scale=1. / rate, method='nearest')
    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()
    f_groups = tf.split(f, int_fs[0], axis=0)

    # Matching patches from the downscaled background.
    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()
    w = tf.extract_volume_patches(
        b, [1, ksize, ksize, ksize, 1], [1, stride, stride, stride, 1],
        'SAME')
    w = tf.reshape(w, [int_bs[0], -1, ksize, ksize, ksize, int_bs[4]])
    w = tf.transpose(w, [0, 2, 3, 4, 5, 1])  # b*k*k*k*c*(n_patches)

    # mm is 1 for background patches that contain no masked value, else 0.
    if mask is None:
        mask = tf.zeros([int_fs[0], bs[1], bs[2], bs[3], 1])
    m = tf.extract_volume_patches(
        mask, [1, ksize, ksize, ksize, 1], [1, stride, stride, stride, 1],
        'SAME')
    m = tf.reshape(m, [int_fs[0], -1, ksize, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 5, 1])  # b*k*k*k*1*(n_patches)
    mm = tf.cast(tf.equal(tf.reduce_mean(
        m, axis=[1, 2, 3, 4], keepdims=True), 0.), tf.float32)
    mm_groups = tf.split(mm, int_bs[0], axis=0)
    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)

    fuse_weight = tf.reshape(tf.eye(fuse_k), [fuse_k, fuse_k, 1, 1])
    n_f = fs[1] * fs[2] * fs[3]
    n_b = bs[1] * bs[2] * bs[3]
    dims_7d = [1, fs[1], fs[2], fs[3], bs[1], bs[2], bs[3]]

    def _fuse_along(scores, perm):
        # Fuse scores along one axis.
        # scores has layout [1, f1, f2, f3, b1, b2, b3]. `perm` moves the
        # axis to fuse to the innermost foreground and background
        # positions. An identity-kernel conv2d over the flattened
        # (foreground, background) score matrix then sums scores of
        # neighbouring patch pairs. Finally the permutation is undone.
        inverse = [perm.index(i) for i in range(7)]
        permuted = tf.transpose(scores, perm)
        flat = tf.reshape(permuted, [1, n_f, n_b, 1])
        flat = tf.nn.conv2d(flat, fuse_weight, strides=[1, 1, 1, 1],
                            padding='SAME')
        permuted = tf.reshape(flat, [dims_7d[p] for p in perm])
        return tf.transpose(permuted, inverse)

    y = []
    for xi, wi, raw_wi, mmi in zip(f_groups, w_groups, raw_w_groups,
                                   mm_groups):
        wi = wi[0]
        mmi = mmi[0]
        # Normalise each background patch separately (over k*k*k*c).
        wi_normed = wi / tf.maximum(
            tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2, 3])), 1e-4)
        yi = tf.nn.conv3d(xi, wi_normed, strides=[1, 1, 1, 1, 1],
                          padding="SAME")

        # Fuse scores to encourage large coherent patches: axis 3 (already
        # innermost), then axis 2, then axis 1.
        if fuse:
            yi = tf.reshape(yi, dims_7d)
            yi = _fuse_along(yi, [0, 1, 2, 3, 4, 5, 6])
            yi = _fuse_along(yi, [0, 1, 3, 2, 4, 6, 5])
            yi = _fuse_along(yi, [0, 2, 3, 1, 5, 6, 4])

        yi = tf.reshape(yi, [1, fs[1], fs[2], fs[3], n_b])

        # Softmax over background patches (last axis), masking invalid ones.
        yi *= mmi
        yi = tf.nn.softmax(yi * softmax_scale, axis=4)
        yi *= mmi

        # Paste full-resolution patches. With stride `rate` and kernel
        # 2*rate, each output voxel receives 2*2*2 = 8 contributions.
        wi_center = raw_wi[0]
        yi = tf.nn.conv3d_transpose(
            yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0),
            strides=[1, rate, rate, rate, 1]) / 8.
        y.append(yi)

    return tf.concat(y, axis=0)


def smooth_positive_labels(y):
    '''Jitter hard positive labels.

    Each entry becomes ``y - 0.15 + u`` with ``u`` drawn uniformly from
    [0, 0.25) via ``np.random.random``, so labels of 1 land in [0.85, 1.1).
    '''
    noise = np.random.random(y.shape) * 0.25
    return y - 0.15 + noise


def smooth_negative_labels(y):
    '''Jitter hard negative labels.

    Each entry becomes ``y + u`` with ``u`` drawn uniformly from [0, 0.15)
    via ``np.random.random``, so labels of 0 land in [0, 0.15).
    '''
    jitter = np.random.random(y.shape) * 0.15
    return y + jitter


def noisy_labels(y, p_flip):
    '''Return a copy of ``y`` with a random fraction of labels flipped.

    ``int(p_flip * y.shape[0])`` distinct positions along the first axis
    are drawn without replacement and replaced by ``1 - y``; for a 2D
    ``y`` whole rows are flipped. Called by
    ``CropsarPixelModel.discriminator_loss`` to make the discriminator
    more robust.

    Args:
        y: numpy array of labels.
        p_flip: Fraction of positions along axis 0 to flip, in [0, 1].

    Returns:
        numpy.ndarray: flipped copy; ``y`` itself is not modified.

    Raises:
        ValueError: if ``p_flip`` is outside [0, 1].
    '''
    if not 0.0 <= p_flip <= 1.0:
        raise ValueError(f'p_flip must be in [0, 1], got {p_flip}')

    flipped = np.array(y, copy=True)
    n_rows = flipped.shape[0]
    # Without replacement: duplicate indices would flip fewer distinct
    # labels than requested.
    flip_ix = np.random.choice(n_rows, size=int(p_flip * n_rows),
                               replace=False)
    flipped[flip_ix] = 1 - flipped[flip_ix]
    return flipped


class CropsarPixelModel:

    def __init__(self, modelinputs=DEFAULT_MODEL_INPUTS,
                 windowsize=32, tslength=37, batch_size=None):
        """Create the Keras inputs, discriminator, generator and optimizers.

        Args:
            modelinputs: Mapping sensor name -> number of input channels.
                One Keras ``Input`` of shape (tslength, windowsize,
                windowsize, channels) named ``<sensor>_input`` is created
                per sensor, in mapping order. The mapping is copied, so the
                shared module-level default cannot be mutated through
                ``self.modelinputs``.
            windowsize: Spatial patch size; only 32 is accepted (presumably
                the layer stacks are sized for 32x32 patches -- TODO
                confirm).
            tslength: Length of the temporal axis of every input.
            batch_size: Optional fixed batch size for the inputs.

        Raises:
            ValueError: if ``windowsize`` is not 32.
        """
        if windowsize != 32:
            raise ValueError('Windowsize should be 32.')

        self.windowsize = windowsize
        self.batch_size = batch_size
        self.tslength = tslength
        # Private copy: the default argument is one dict shared by every
        # instance.
        self.modelinputs = dict(modelinputs)

        # Define model inputs
        self.inputs = {
            sensor: Input(
                shape=(self.tslength, self.windowsize,
                       self.windowsize, n_channels),
                batch_size=self.batch_size,
                name=sensor + '_input')
            for sensor, n_channels in self.modelinputs.items()
        }

        # Build the discriminator
        self.discriminator = self.build_discriminator(
            windowsize=self.windowsize)
        self.disc_optimizer = Adam(learning_rate=1e-4,
                                   beta_1=0.5,
                                   beta_2=0.999)

        # Build the generator
        self.generator = self.build_generator()
        self.gen_optimizer = Adam(learning_rate=1e-4,
                                  beta_1=0.5,
                                  beta_2=0.999)

    def build_discriminator_DeepFill(self, windowsize):
        '''
        DeepFill-style discriminator on a single-channel image plus its mask.
        NOT USED ATM

        Args:
            windowsize: Spatial size of both inputs.

        Returns:
            tf.keras.Model named 'discriminator' with inputs
            ``[disc_img_input, disc_mask_input]`` and a flattened output.
        '''
        image = Input(shape=(windowsize, windowsize, 1),
                      name='disc_img_input')
        hole_mask = Input(shape=(windowsize, windowsize, 1),
                          name='disc_mask_input')
        x = concatenate([image, hole_mask])

        base_filters = 64
        # (layer name, filter multiplier, dropout afterwards)
        conv_specs = (
            ('conv1', 1, False),
            ('conv2', 2, False),
            ('conv3', 4, True),
            ('conv4', 4, True),
            ('conv5', 4, True),
        )
        for layer_name, multiplier, add_dropout in conv_specs:
            x = dis_conv_2d(x, base_filters * multiplier, batchnorm=True,
                            name=layer_name)
            if add_dropout:
                x = Dropout(0.2)(x)

        # Last block: stride 1, no batch normalisation.
        x = dis_conv_2d(x, base_filters * 4, strides=1, name='last')
        x = Dropout(0.2)(x)
        x = Flatten(name='flatten')(x)

        return Model(inputs=[image, hole_mask],
                     outputs=x,
                     name='discriminator')

    def build_generator_DeepFill(self):
        '''
        Two-stage (coarse + refinement) generator based on the DeepFill
        architecture, adapted to 3D (time x space) inputs.
        NOT USED ATM

        Only the S2 input is used. Stage 1 predicts a coarse 3D result.
        Stage 2 pastes it into the holes and runs a plain convolution branch
        and a contextual-attention branch in parallel. Strided 3D
        convolutions then collapse the temporal axis, and 2D layers finish
        the prediction.

        Returns:
            tf.keras.Model named 'generator' with ``self.inputs['S2']`` as
            input and ``[x_stage1, x_stage2]`` as outputs: the tanh coarse
            prediction averaged over time, and the tanh refined
            prediction. Both are 2D with a single channel.
        '''

        # Inputs -> for testing, currently purely S2
        inputs = self.inputs['S2']

        # Pad inputs so temporal dimension is 40 in size
        # which is more convenient than 37.
        # (2 frames before, 1 after; assumes tslength == 37 -- the
        # AveragePooling3D below hard-codes 40.)
        inputs_pad = tf.pad(inputs,
                            [
                                [0, 0],  # Batch dim
                                [2, 1],  # Temporal dim
                                [0, 0], [0, 0],  # Spatial dims
                                [0, 0]  # Channels
                            ])

        # Derive the full 3D mask
        # 1 = masked, 0 = valid
        # (every value equal to 0, including the padded frames, is masked)
        mask = tf.cast(tf.equal(inputs_pad, tf.constant(float(0))), tf.float32)
        ones_x = tf.ones_like(inputs_pad)[:, :, :, :, 0:1]

        # Network input: data, a constant-ones channel and the mask.
        x = tf.concat([inputs_pad, ones_x, ones_x * mask], axis=4)

        # Original implementation has the very strange value of 48: https://github.com/JiahuiYu/generative_inpainting/blob/3a5324373ba52c68c79587ca183bc10b9e57b783/inpaint_model.py#L44
        filters = 64

        # Largely following original code, but then in 3D
        # Mind that maybe we go already too deep here. Dilation_rate 16
        # might not make any sense on such small patches.

        # stage1
        # Integer strides=2 apply to all three axes: time 40 -> 20 -> 10,
        # space 32 -> 16 -> 8; the two gen_deconv_3d calls undo this.
        x = gen_conv_3d(x, filters, kernel_size=5, strides=1,
                        name='conv1')
        x = gen_conv_3d(x, 2 * filters, kernel_size=3, strides=2,
                        name='conv2_downsample')
        x = gen_conv_3d(x, 2 * filters, kernel_size=3, strides=1,
                        name='conv3')
        x = gen_conv_3d(x, 4 * filters, kernel_size=3, strides=2,
                        name='conv4_downsample')
        x = gen_conv_3d(x, 4 * filters, kernel_size=3, strides=1,
                        name='conv5')
        x = gen_conv_3d(x, 4 * filters, kernel_size=3, strides=1,
                        name='conv6')
        # Mask at the downsampled resolution, used by contextual attention.
        mask_s = resize_mask_like(mask, x)
        x = gen_conv_3d(x, 4 * filters, kernel_size=3, name='conv7')
        x = gen_conv_3d(x, 4 * filters, kernel_size=3, name='conv8')
        x = gen_conv_3d(x, 4 * filters, kernel_size=3, name='conv9')
        x = gen_conv_3d(x, 4 * filters, kernel_size=3, name='conv10')
        x = gen_conv_3d(x, 4 * filters, kernel_size=3, name='conv11')
        x = gen_deconv_3d(x, 2 * filters, name='conv12_upsample')
        x = gen_conv_3d(x, 2 * filters, kernel_size=3, name='conv13')
        x = gen_deconv_3d(x, filters, name='conv14_upsample')
        x = gen_conv_3d(x, filters // 2, kernel_size=3, name='conv15')
        # Single output channel, no gating, then tanh.
        x = gen_conv_3d(x, 1, kernel_size=3, activation=None, name='conv16')
        x = ACTIVATIONS['tanh'](x)
        x_stage1 = x

        # Get rid of temporal dimension
        # for the side track where we need to train
        # on the coarse resolution prediction as well
        x_stage1 = tf.keras.backend.squeeze(
            tf.keras.layers.AveragePooling3D(pool_size=(40, 1, 1))(x_stage1),
            axis=1)

        # stage2, paste result as input
        # Keep the known input values; use the stage-1 prediction in holes.
        x = x * mask + inputs_pad*(1.-mask)

        # conv branch
        xnow = x
        x = gen_conv_3d(xnow, filters, 5, 1, name='xconv1')
        x = gen_conv_3d(x, filters, 3, 2, name='xconv2_downsample')
        x = gen_conv_3d(x, 2 * filters, 3, 1, name='xconv3')
        x = gen_conv_3d(x, 2 * filters, 3, 2, name='xconv4_downsample')
        x = gen_conv_3d(x, 4 * filters, 3, 1, name='xconv5')
        x = gen_conv_3d(x, 4 * filters, 3, 1, name='xconv6')
        x = gen_conv_3d(x, 4 * filters, 3, 1, name='xconv7')
        x = gen_conv_3d(x, 4 * filters, 3, 1, name='xconv8')
        x = gen_conv_3d(x, 4 * filters, 3, 1, name='xconv9')
        x_hallu = x

        # attention branch
        x = gen_conv_3d(xnow, filters, 5, 1, name='pmconv1')
        x = gen_conv_3d(x, filters, 3, 2, name='pmconv2_downsample')
        x = gen_conv_3d(x, 2*filters, 3, 1, name='pmconv3')
        x = gen_conv_3d(x, 4*filters, 3, 2, name='pmconv4_downsample')
        x = gen_conv_3d(x, 4*filters, 3, 1, name='pmconv5')
        x = gen_conv_3d(x, 4*filters, 3, 1, name='pmconv6',
                        activation='relu')
        # Self-attention: foreground and background are the same features.
        x = contextual_attention(x, x, mask_s, 3, 1, rate=1)
        x = gen_conv_3d(x, 4*filters, 3, 1, name='pmconv9')
        x = gen_conv_3d(x, 4*filters, 3, 1, name='pmconv10')
        pm = x
        x = tf.concat([x_hallu, pm], axis=4)

        '''
        Original implementation is already in 2D. Next Conv layers
        have a stride in temporal dimension so we can finally get rid
        of it.
        '''
        # With 'same' padding the temporal axis goes 10 -> 5 -> 3 -> 1,
        # after which it is squeezed away.
        x = gen_conv_3d(x, 4*filters, 3, (2, 1, 1), name='allconv11')
        x = gen_conv_3d(x, 4*filters, 3, (2, 1, 1), name='allconv12')
        x = tf.keras.backend.squeeze(
            gen_conv_3d(x, 4*filters, 3, (3, 1, 1), name='allconv13'),
            axis=1)
        '''
        From here on we're in 2D and can follow original code
        '''
        x = gen_deconv_2d(x, 2*filters, name='allconv13_upsample')
        x = gen_conv_2d(x, 2*filters, 3, 1, name='allconv14')
        x = gen_deconv_2d(x, filters, name='allconv15_upsample')
        x = gen_conv_2d(x, filters//2, 3, 1, name='allconv16')
        x = gen_conv_2d(x, 1, 3, 1, activation=None, name='allconv17')
        x = ACTIVATIONS['tanh'](x)  # If output scaled between [-1, 1]
        x_stage2 = x

        # Define the final generator model
        generator = Model(
            inputs=inputs,
            outputs=[x_stage1, x_stage2],
            name='generator')

        return generator

    def build_discriminator(self, windowsize):
        '''
        Build the discriminator currently in use, loosely based on the
        Pix2Pix PatchGAN discriminator.

        The S1 (2 channels) and S2 (1 channel) time series are concatenated
        and reduced along the temporal axis by three partial 3D
        convolutions. The result is concatenated with the target image and
        its mask, then encoded by 2D convolutions into a flattened map of
        per-patch logits.

        Args:
            windowsize: Spatial size of all inputs.

        Returns:
            tf.keras.Model named 'discriminator' with inputs
            ``(s1_input, s2_input, target_image, mask_image)``.
        '''
        weights_init = RandomNormal(mean=0.0, stddev=0.02)

        def _norm_act(tensor):
            # LeakyReLU is instantiated before BatchNormalization.
            leaky = LeakyReLU(alpha=0.2)
            return leaky(BatchNormalization()(tensor))

        # All input shapes are fixed: the discriminator is NOT fully
        # convolutional.
        s1_input = Input(shape=(self.tslength, windowsize, windowsize, 2),
                         name='s1_input')
        s2_input = Input(shape=(self.tslength, windowsize, windowsize, 1),
                         name='s2_input')
        # Target image (real or generated, supposed to be cloud free) and
        # the mask of the centre input.
        target = Input(shape=(windowsize, windowsize, 1),
                       name='target_image')
        mask = Input(shape=(windowsize, windowsize, 1),
                     name='mask_image')

        # Collapse the temporal axis: (kernel, stride) along time per layer.
        temporal = concatenate([s1_input, s2_input])
        for t_kernel, t_stride in ((7, 5), (5, 5), (7, 2)):
            pconv = PConvSimple3D(filters=64,
                                  kernel_size=(t_kernel, 1, 1),
                                  strides=(t_stride, 1, 1),
                                  padding="same")
            temporal = _norm_act(pconv(temporal))
        temporal = tf.keras.backend.squeeze(temporal, axis=1)

        # Encode the stacked features, target and mask into patches.
        x = concatenate([temporal, target, mask])
        for n_filters in (64, 128, 256):
            x = Dropout(0.3)(_norm_act(Conv2D(
                n_filters, (4, 4), strides=2, padding='same',
                kernel_initializer=weights_init)(x)))
        x = tf.keras.layers.ZeroPadding2D()(x)
        x = Dropout(0.3)(_norm_act(Conv2D(
            512, (4, 4), strides=1, padding='valid',
            kernel_initializer=weights_init, use_bias=False)(x)))
        x = tf.keras.layers.ZeroPadding2D()(x)
        # NOTE: the output is not sigmoid-activated; use a loss that
        # expects logits. NOTE(review): dropout is applied to the logits.
        patch_out = Dropout(0.3)(Flatten()(
            Conv2D(1, (4, 4), strides=1, padding='valid',
                   kernel_initializer=weights_init)(x)))

        return Model(inputs=(s1_input, s2_input, target, mask),
                     outputs=patch_out,
                     name='discriminator')

    def build_generator(self):
        '''
        Currently used generator. Loosely based on the Pix2Pix U-Net,
        but with a temporal 3D encoder per sensor and gated convolutions.

        For every sensor input:
          - a 3D convolutional encoder first collapses the temporal axis;
          - a 2D encoder then downsamples spatially.
        The deepest features of all sensors are concatenated. A 2D decoder,
        with skip connections to every sensor's encoder, upsamples back to
        the input resolution.

        Returns:
            tf.keras.Model named 'generator' taking
            ``list(self.inputs.values())`` and producing a single-channel,
            tanh-activated 2D image.
        '''

        defaultkernelsize = 4
        defaultstride = 2
        init = RandomNormal(mean=0.0, stddev=0.02)

        def conv2d(layer_input, filters, kernelsize, stride, init,
                   batchnormalization=True, l2_weight_regularizer=1e-5,
                   activation='relu'):
            # Conv2D (L2-regularised), optional masked batchnorm
            # (self.batchnorm), then gating when ``activation`` is not None.
            # Gating splits the channels in half and multiplies the
            # activated first half by sigmoid(second half), so the gated
            # output has ``filters // 2`` channels.
            c = Conv2D(filters, kernel_size=(kernelsize, kernelsize),
                       strides=stride, padding='same',
                       activation=None,
                       kernel_initializer=init,
                       kernel_regularizer=l2(
                l2_weight_regularizer))(layer_input)
            if batchnormalization:
                c = self.batchnorm(c)
            if activation is not None:
                # Used gating mechanism
                c, y = tf.split(c, 2, 3)
                c = Activation(activation)(c)
                y = Activation('sigmoid')(y)
                c = c * y

            return c

        def deconv2d(layer_input, skip_input, filters, kernelsize, stride,
                     init, dropout, batchnormalization=True,
                     l2_weight_regularizer=1e-5):
            # Decoder block, in order:
            #   1. upsample by ``stride``;
            #   2. stride-1 ungated conv2d;
            #   3. optional masked batchnorm and dropout;
            #   4. LeakyReLU/sigmoid gating (output has filters // 2
            #      channels);
            #   5. concatenate the skip tensor(s), if any.
            d = tf.keras.layers.UpSampling2D(size=stride)(layer_input)
            d = conv2d(d, filters, kernelsize=kernelsize,
                       stride=1, init=init, batchnormalization=False,
                       l2_weight_regularizer=l2_weight_regularizer,
                       activation=None)
            if batchnormalization:
                d = self.batchnorm(d)
            if dropout:
                d = Dropout(dropout)(d)

            # Used gating mechanism
            d, y = tf.split(d, 2, 3)
            d = LeakyReLU(alpha=0.2)(d)
            y = Activation('sigmoid')(y)
            d = d * y

            if skip_input:
                if type(skip_input) is not list:
                    skip_input = [skip_input]
                d = concatenate([d] + skip_input)
            return d

        def _time_encoder_conv3D(enc_input,
                                 l2_weight_regularizer=1e-5):
            # Collapse the temporal axis (axis 1) with three gated Conv3D
            # layers, each with a (k, 1, 1) kernel and temporal stride 5,
            # then squeeze it. With 'same' padding the squeeze requires the
            # axis to reach size 1, e.g. 37 -> 8 -> 2 -> 1 (holds for
            # tslength <= 125 -- TODO confirm intended range).

            filters = 64

            def _gated_leakyrelu(x, batchnorm=False):
                # Optional masked batchnorm, then LeakyReLU/sigmoid gating
                # over the channel axis (axis 4 of the 5D tensor).
                if batchnorm:
                    x = self.batchnorm(x)
                x, y = tf.split(x, 2, 4)
                x = LeakyReLU(alpha=0.2)(x)
                y = Activation('sigmoid')(y)
                x = x * y

                return x

            # First layer no batch normalisation
            enc1 = _gated_leakyrelu(
                Conv3D(filters=filters,
                       kernel_size=(7, 1, 1),
                       strides=(5, 1, 1),
                       padding="same",
                       kernel_regularizer=l2(
                           l2_weight_regularizer))(enc_input))
            enc2 = _gated_leakyrelu(
                Conv3D(filters=filters*2,
                       kernel_size=(5, 1, 1),
                       strides=(5, 1, 1),
                       padding="same",
                       kernel_regularizer=l2(
                           l2_weight_regularizer))(enc1),
                batchnorm=True)
            enc3 = tf.keras.backend.squeeze(
                _gated_leakyrelu(
                    Conv3D(filters=filters*4,
                           kernel_size=(5, 1, 1),
                           strides=(5, 1, 1),
                           padding="same",
                           kernel_regularizer=l2(
                               l2_weight_regularizer))(enc2),
                    batchnorm=True),
                axis=1)

            return enc3

        def _encoder(enc_input):
            # Five stride-2 gated conv2d layers; with windowsize 32 the
            # spatial size goes 32 -> 16 -> 8 -> 4 -> 2 -> 1. The
            # intermediate outputs are returned for the skip connections.

            filters = 256

            enc1 = conv2d(enc_input,
                          filters=filters,
                          kernelsize=defaultkernelsize,
                          stride=defaultstride,
                          init=init)
            enc2 = conv2d(enc1,
                          filters=filters * 2,
                          kernelsize=defaultkernelsize,
                          stride=defaultstride,
                          init=init)
            enc3 = conv2d(enc2,
                          filters=filters * 2,
                          kernelsize=defaultkernelsize,
                          stride=defaultstride,
                          init=init)
            enc4 = conv2d(enc3,
                          filters=filters * 4,
                          kernelsize=defaultkernelsize,
                          stride=defaultstride,
                          init=init)

            # Bottleneck
            enc_output = conv2d(enc4,
                                filters=filters * 4,
                                kernelsize=defaultkernelsize,
                                stride=defaultstride,
                                init=init)

            return {'enc1': enc1, 'enc2': enc2, 'enc3': enc3,
                    'enc4': enc4, 'enc_output': enc_output}

        # Inputs
        inputs = self.inputs
        # Per-sensor intermediate tensors.
        layers = {}
        for sensor in inputs.keys():
            layers[sensor] = {}

        # --------------------------------------
        # Temporal encoding step
        # --------------------------------------

        # Encoders
        for sensor in inputs.keys():
            layers[sensor]['time_encoded'] = _time_encoder_conv3D(
                inputs[sensor]
            )

        # --------------------------------------
        # Spatial encoding step
        # --------------------------------------

        # Encoders
        for sensor in inputs.keys():
            layers[sensor]['encoded'] = _encoder(
                layers[sensor]['time_encoded']
            )

        # --------------------------------------
        # Concatenation of the encoded features
        # --------------------------------------
        encoded = [
            layers[sensor]['encoded']['enc_output']
            for sensor in inputs.keys()
        ]

        if len(encoded) > 1:
            concatenated = concatenate(encoded)
        else:
            concatenated = encoded[0]

        # --------------------------------------
        # Spatial decoding step
        # --------------------------------------
        # Each decoder level upsamples by 2 and concatenates the matching
        # encoder output of every sensor.

        dec4 = deconv2d(concatenated,
                        skip_input=[layers[sensor]['encoded']['enc4']
                                    for sensor in layers.keys()],
                        filters=256,
                        kernelsize=defaultkernelsize,
                        stride=defaultstride,
                        init=init, dropout=0.3)
        dec3 = deconv2d(dec4,
                        skip_input=[layers[sensor]['encoded']['enc3']
                                    for sensor in layers.keys()],
                        filters=256,
                        kernelsize=defaultkernelsize,
                        stride=defaultstride,
                        init=init, dropout=0.3)
        dec2 = deconv2d(dec3,
                        skip_input=[layers[sensor]['encoded']['enc2']
                                    for sensor in layers.keys()],
                        filters=128,
                        kernelsize=defaultkernelsize,
                        stride=defaultstride,
                        init=init, dropout=0.3)
        dec1 = deconv2d(dec2,
                        skip_input=[layers[sensor]['encoded']['enc1']
                                    for sensor in layers.keys()],
                        filters=64,
                        kernelsize=defaultkernelsize,
                        stride=defaultstride,
                        init=init, dropout=0.3,
                        batchnormalization=False)

        # OUTPUT LAYER
        # Final x2 upsample back to the input resolution, then a
        # single-channel conv with tanh.
        output = tf.keras.layers.UpSampling2D(size=2)(dec1)
        output = Activation('tanh')(Conv2D(
            1, (4, 4), strides=1, padding='same',
            kernel_initializer=init)(output))

        # Define the final generator model
        generator = Model(
            inputs=list(self.inputs.values()),
            outputs=output,
            name='generator')

        return generator

    def batchnorm(self, inputs):
        """Apply batch normalisation while preserving exact zeros.

        A fresh ``BatchNormalization`` layer normalises ``inputs``. The
        result is then multiplied by a 0/1 float32 mask. Every position
        where the input was exactly ``0.0`` is therefore forced back to
        zero, and all other positions keep their normalised value.

        NOTE(review): the zeros still take part in the batch statistics
        computed by the layer; confirm this is intended.

        Args:
            inputs: Tensor to normalise.

        Returns:
            Normalised tensor with the original zero positions kept at 0.
        """
        keep = tf.cast(tf.not_equal(inputs, tf.constant(float(0))),
                       tf.float32)
        return BatchNormalization()(inputs) * keep

    def gan_hinge_loss(self, pos, neg):
        """Compute the hinge GAN losses for generator and discriminator.

        Port of neuralgym's ``gan_hinge_loss``:
        https://github.com/JiahuiYu/neuralgym/blob/88292adb524186693a32404c0cfdc790426ea441/neuralgym/ops/gan_ops.py#L50

        Args:
            pos: Discriminator output for real samples.
            neg: Discriminator output for generated samples.

        Returns:
            tuple: ``(g_loss, d_loss)`` where
            ``d_loss = 0.5 * mean(relu(1 - pos)) + 0.5 * mean(relu(1 + neg))``
            and ``g_loss = -mean(neg)``.
        """
        real_term = tf.reduce_mean(tf.nn.relu(1 - pos))
        fake_term = tf.reduce_mean(tf.nn.relu(1 + neg))
        d_total = tf.add(.5 * real_term, .5 * fake_term)
        g_total = -tf.reduce_mean(neg)
        return g_total, d_total

    def discriminator_loss(self, disc_real_output, disc_generated_output):
        """Binary cross-entropy discriminator loss with soft, noisy labels.

        The targets start as all-ones (real) and all-zeros (generated)
        arrays. Each is passed through ``smooth_positive_labels`` /
        ``smooth_negative_labels`` and then ``noisy_labels(..., 0.1)``.
        Those helpers are defined elsewhere in this module.

        NOTE(review): the label arrays are built with NumPy
        (``np.ones_like`` / ``np.zeros_like``), which needs concrete
        (eager) tensors; this presumably blocks wrapping the training
        step in ``tf.function`` -- confirm.

        Args:
            disc_real_output: Discriminator logits for real samples.
            disc_generated_output: Discriminator logits for generated
                samples.

        Returns:
            Scalar loss ``0.5 * real_loss + 0.5 * generated_loss``.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

        # Keep the helper call order (smooth real, smooth fake, noise real,
        # noise fake) so any random draws happen in the same sequence.
        smoothed_real = smooth_positive_labels(
            np.ones_like(disc_real_output))
        smoothed_fake = smooth_negative_labels(
            np.zeros_like(disc_generated_output))
        real_targets, fake_targets = [
            noisy_labels(labels, 0.1)
            for labels in (smoothed_real, smoothed_fake)
        ]

        real_term = bce(real_targets, disc_real_output)
        fake_term = bce(fake_targets, disc_generated_output)
        return tf.add(.5 * real_term, .5 * fake_term)

    def generator_loss(self, disc_generated_output,
                       gen_output, target, l1_weight=5):
        """Compute the generator loss: adversarial BCE plus weighted L1.

        Args:
            disc_generated_output: Discriminator logits for the generated
                sample. They are scored against all-ones targets with a
                ``from_logits=True`` binary cross-entropy.
            gen_output: Generator prediction.
            target: Ground-truth tensor compared with ``gen_output``.
            l1_weight: Weight on the L1 term. Defaults to 5, the value
                that was previously hard-coded as ``LAMBDA``.

        Returns:
            tuple: ``(total_gen_loss, gan_loss, l1_loss)`` where
            ``total_gen_loss = gan_loss + l1_weight * l1_loss``.
        """
        loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

        # Adversarial term: reward the generator when the discriminator
        # classifies its output as real (target label 1).
        gan_loss = loss_object(tf.ones_like(disc_generated_output),
                               disc_generated_output)

        # Mean absolute error over the whole tensor.
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

        total_gen_loss = gan_loss + (l1_weight * l1_loss)

        return total_gen_loss, gan_loss, l1_loss

    def gan_reconstruction_loss(self, y_pred, y_true):
        """Return the mean absolute (L1) error between ``y_true`` and ``y_pred``.

        Note the argument order: prediction first, target second.
        """
        abs_error = tf.abs(y_true - y_pred)
        return tf.reduce_mean(abs_error)

    def _run_step(self, inputs, outputs, masks, training=False):
        """Run generator and discriminator on one batch and compute losses.

        Args:
            inputs: Sequence of generator input tensors. It is unpacked
                into the discriminator's input list as well.
            outputs: Ground-truth target tensor.
            masks: Mask tensor. Where it is 1 the generator's prediction is
                used; where it is 0 the ground truth is kept.
            training: Forwarded to both networks' ``training`` flag.

        Returns:
            dict: ``disc_loss``, ``gen_hinge_loss``,
            ``gen_reconstruction_loss`` and ``gen_total_loss``. Despite its
            name, ``gen_hinge_loss`` holds the BCE adversarial term from
            ``generator_loss``, not a hinge loss.
        """
        predicted = self.generator(inputs, training=training)

        # Blend: prediction inside the mask, ground truth outside it.
        completed = predicted * masks + outputs * (1. - masks)

        disc_real = self.discriminator([*inputs, outputs, masks],
                                       training=training)
        disc_fake = self.discriminator([*inputs, completed, masks],
                                       training=training)

        total, adversarial, reconstruction = self.generator_loss(
            disc_fake, predicted, outputs)

        # An alternative hinge formulation is available through
        # self.gan_hinge_loss(disc_real, disc_fake) but is not used here.
        discriminator_total = self.discriminator_loss(disc_real, disc_fake)

        return {
            'disc_loss': discriminator_total,
            'gen_hinge_loss': adversarial,
            'gen_reconstruction_loss': reconstruction,
            'gen_total_loss': total,
        }

    # @tf.function
    # NOTE(review): tf.function appears intentionally disabled; the
    # discriminator loss builds its labels with NumPy, which presumably
    # needs eager tensors -- confirm before re-enabling.
    def train_step(self, inputs, outputs, masks):
        """Perform one optimisation step for both networks.

        A single forward pass (``training=True``) is recorded on two
        gradient tapes. The discriminator optimizer is applied first and
        the generator optimizer second, each with gradients of its own
        loss.

        Returns:
            dict: the losses produced by ``_run_step``.
        """
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            losses = self._run_step(inputs, outputs, masks, training=True)

        g_vars = self.generator.trainable_variables
        d_vars = self.discriminator.trainable_variables

        g_grads = g_tape.gradient(losses['gen_total_loss'], g_vars)
        d_grads = d_tape.gradient(losses['disc_loss'], d_vars)

        self.disc_optimizer.apply_gradients(zip(d_grads, d_vars))
        self.gen_optimizer.apply_gradients(zip(g_grads, g_vars))

        return losses

    def test_step(self, inputs, outputs, masks):
        """Compute the losses for one batch without applying any update.

        Both networks run with ``training=False``. The pix2pix tutorial
        (https://www.tensorflow.org/tutorials/generative/pix2pix?hl=en)
        recommends ``training=True`` at test time so that batch
        normalisation uses batch statistics. It is deliberately not done
        here, because ``training=True`` would also activate Dropout.

        Returns:
            dict: the losses produced by ``_run_step``.
        """
        return self._run_step(inputs, outputs, masks, training=False)


if __name__ == '__main__':
    # Build the model and print a Keras summary of each sub-network,
    # generator first, then discriminator.
    cropsar = CropsarPixelModel()
    for part in ('generator', 'discriminator'):
        getattr(cropsar, part).summary()
