"""Code mostly copied from https://github.com/MathiasGruber/PConv-Keras
"""

from tensorflow.python.keras.utils import conv_utils
from tensorflow.keras import backend as K
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import Conv3D, InputSpec, LeakyReLU
from tensorflow import multiply, convert_to_tensor, cast, reduce_sum, equal, int32
from tensorflow.python.keras.initializers.initializers_v2 import GlorotUniform
from numpy.random import rand
from math import ceil, floor


# Legacy recipe (TensorFlow 1.x only; tf.Session / K.set_session are TF1 APIs)
# for mixing standalone Keras and TensorFlow layers in the same model:
# import tensorflow as tf
# import keras
# from keras import backend as K

# tf_sess = tf.Session()
# K.set_session(tf_sess)

class PConv3D(Conv3D):
    """3D partial convolution layer driven by an explicit mask input.

    Adapted from the 2D partial-convolution layer of PConv-Keras. The layer is
    called on ``[images, masks]`` (two 5D tensors; the mask must have the same
    number of channels as the image because the mask kernel is built with the
    image's channel count) and returns ``[img_output, mask_output]``.
    Masks are assumed to be binary (1 = valid voxel, 0 = hole) -- this is how
    ``PConvTest`` builds them; nothing here enforces it.

    Padding is always applied explicitly so that each spatial output size is
    ``ceil(input_size / stride)`` ("same"-sized), regardless of the ``padding``
    argument forwarded to ``Conv3D``. The zero-padded border therefore counts
    as hole in the mask convolution.
    """

    def __init__(self, *args, n_channels=3, mono=False, **kwargs):
        # ``n_channels`` and ``mono`` are accepted for signature compatibility
        # but are not used by this layer.
        super().__init__(*args, **kwargs)
        # Two 5D inputs: [images, masks].
        self.input_spec = [InputSpec(ndim=5), InputSpec(ndim=5)]

    @staticmethod
    def _centered_padding(kernel_size, dilation=1):
        """Return ``(before, after)`` explicit padding for one spatial axis.

        Pads by ``effective_kernel - 1`` in total (odd remainder goes after),
        where ``effective_kernel = (kernel_size - 1) * dilation + 1``. A
        'valid' convolution with stride ``s`` over the padded axis then yields
        ``ceil(n / s)`` outputs for an input of length ``n``.
        """
        effective_kernel = (kernel_size - 1) * dilation + 1
        total = effective_kernel - 1
        return (total // 2, total - total // 2)

    def build(self, input_shape):
        """Adapted from original _Conv() layer of Keras.

        param input_shape: list of dimensions for [img, mask]
        raises ValueError: if the image channel dimension is undefined.
        """
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        image_shape = input_shape[0]

        if image_shape[channel_axis] is None:
            raise ValueError(
                'The channel dimension of the inputs should be defined. Found `None`.')

        self.input_dim = image_shape[channel_axis]
        kernel_shape = self.kernel_size + (self.input_dim, self.filters)

        # Image kernel
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='img_kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        # Mask kernel: all ones, so convolving the mask counts the valid
        # entries (over the window and all input channels). It must stay
        # fixed: ``K.ones`` returns a trainable variable that tf.keras would
        # auto-track as a trainable weight, so it is created non-trainable.
        self.kernel_mask = self.add_weight(shape=kernel_shape,
                                           initializer='ones',
                                           name='mask_kernel',
                                           trainable=False)

        # Explicit per-axis padding (uses each axis' own kernel size and
        # dilation) so that padding becomes part of the partial convolution.
        self.pconv_padding = tuple(
            self._centered_padding(k, d)
            for k, d in zip(self.kernel_size, self.dilation_rate))

        # Number of mask entries covered by one window (all three kernel axes
        # times the input channels). Used as the numerator of the
        # partial-convolution rescaling factor sum(1) / sum(M).
        self.window_size = (self.kernel_size[0] * self.kernel_size[1]
                            * self.kernel_size[2] * self.input_dim)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs):
        '''Run the partial convolution on ``inputs = [images, masks]``.

        Both tensors are explicitly zero-padded; the image is multiplied by the
        mask before being convolved with the image kernel, and the mask is
        convolved with the all-ones mask kernel (both with 'valid' padding).
        The image response is rescaled by ``window_size / (mask_sum + 1e-8)``
        and zeroed where the window contained no valid entry; the mask response
        is clipped to [0, 1] and returned as the updated mask.

        Returns ``[img_output, mask_output]``.
        '''

        # Padding done explicitly so that padding becomes part of the masked partial convolution
        images = K.spatial_3d_padding(
            inputs[0], self.pconv_padding, self.data_format)
        masks = K.spatial_3d_padding(
            inputs[1], self.pconv_padding, self.data_format)

        # Apply convolutions to mask
        mask_output = K.conv3d(
            masks, self.kernel_mask,
            strides=self.strides,
            padding='valid',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )

        # Apply convolutions to image
        img_output = K.conv3d(
            multiply(images, masks),
            self.kernel,
            strides=self.strides,
            padding='valid',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )

        # Calculate the mask ratio on each voxel in the output mask
        mask_ratio = self.window_size / (mask_output + 1e-8)

        # Clip output to be between 0 and 1
        mask_output = K.clip(mask_output, 0, 1)

        # Remove ratio values where there are holes
        mask_ratio = mask_ratio * mask_output

        # Normalize image output
        img_output = img_output * mask_ratio

        # Apply bias only to the image (if chosen to do so)
        if self.use_bias:
            img_output = K.bias_add(
                img_output,
                self.bias,
                data_format=self.data_format)

        # Apply activations on the image
        if self.activation is not None:
            img_output = self.activation(img_output)

        return [img_output, mask_output]

    def compute_output_shape(self, input_shape):
        """Return ``[image_shape, mask_shape]`` (identical) for the outputs.

        ``input_shape`` is the list ``[image_shape, mask_shape]``; only the
        image shape is used. Spatial sizes follow 'same' semantics because the
        layer always pads explicitly (see ``build``).
        """
        image_shape = input_shape[0]
        channels_last = self.data_format == 'channels_last'
        space = image_shape[1:-1] if channels_last else image_shape[2:]
        new_space = tuple(
            conv_utils.conv_output_length(
                space[i],
                self.kernel_size[i],
                padding='same',
                stride=self.strides[i],
                dilation=self.dilation_rate[i])
            for i in range(len(space)))
        if channels_last:
            new_shape = (image_shape[0],) + new_space + (self.filters,)
        else:
            new_shape = (image_shape[0], self.filters) + new_space
        return [new_shape, new_shape]


def tf_count(t, val):
    """Return how many elements of tensor ``t`` are equal to ``val``.

    The result is a scalar ``int32`` tensor.
    """
    return reduce_sum(cast(equal(t, val), int32))


class PConvSimple3D(Conv3D):
    """3D partial convolution that infers its hole mask from the input itself.

    Every input value equal to ``null_value`` is treated as a hole (mask 0) and
    every other value as valid (mask 1). The layer takes a single 5D tensor and
    returns a single tensor; unlike ``PConv3D`` the updated mask is not
    returned.

    When the input contains no ``null_value`` at all, a plain convolution (over
    the same explicitly padded input) is used instead of the partial
    convolution. Bias and activation are applied in both cases. All mask logic
    uses TensorFlow ops, so the layer also works inside graphs / ``tf.function``.
    """

    def __init__(self, *args, padding='same', n_channels=3, mono=False, null_value=0, **kwargs):
        # ``padding`` is intercepted here (default 'same') and implemented by
        # explicit padding in ``build``/``call``, so the parent ``Conv3D`` is
        # constructed with its own default. ``n_channels`` and ``mono`` are
        # accepted for signature compatibility but unused.
        super().__init__(*args, **kwargs)
        self.input_spec = InputSpec(ndim=5)
        self.padding = padding
        self.null_value = null_value

    @staticmethod
    def _same_padding(dimension_size, kernel_size, stride, dilation=1):
        """Return ``(before, after)`` padding giving ``ceil(n / stride)`` outputs.

        Output size is taken as ceil(dimension_size / stride), following
        https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks
        The dilated kernel extent is used, and the total is clamped at zero (it
        would be negative when the kernel is smaller than the stride). When
        ``dimension_size`` is unknown (``None``) the total falls back to
        ``effective_kernel - 1``, which gives the same output length.
        """
        effective_kernel = (kernel_size - 1) * dilation + 1
        if dimension_size is None:
            total = effective_kernel - 1
        else:
            total = max(stride * ceil(dimension_size / stride)
                        - dimension_size + effective_kernel - stride, 0)
        return (floor(total / 2), ceil(total / 2))

    def build(self, input_shape):
        """Adapted from original _Conv() layer of Keras.

        raises ValueError: if the channel dimension is undefined or the
            padding is neither 'same' nor 'valid'.
        """
        if self.data_format == 'channels_first':
            channel_axis, spatial_axes = 1, (2, 3, 4)
        else:
            channel_axis, spatial_axes = 4, (1, 2, 3)

        if input_shape[channel_axis] is None:
            raise ValueError(
                'The channel dimension of the inputs should be defined. Found `None`.')

        self.input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (self.input_dim, self.filters)

        # Image kernel
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='img_kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        # Mask kernel: fixed all-ones kernel counting valid entries per window.
        # Created non-trainable; ``K.ones`` would give a trainable variable that
        # tf.keras auto-tracks as a trainable weight.
        self.kernel_mask = self.add_weight(shape=kernel_shape,
                                           initializer='ones',
                                           name='mask_kernel',
                                           trainable=False)

        # Compute flexible padding. ``kernel_size``, ``strides`` and
        # ``dilation_rate`` are treated as per-axis 3-tuples, as elsewhere in
        # this class.
        if self.padding == 'same':
            self.pconv_padding = tuple(
                self._same_padding(input_shape[axis], k, s, d)
                for axis, k, s, d in zip(spatial_axes, self.kernel_size,
                                         self.strides, self.dilation_rate))
        elif self.padding == 'valid':
            self.pconv_padding = ((0, 0), (0, 0), (0, 0))
        else:
            raise ValueError(
                'The padding value should be valid or same. Found {}.'.format(self.padding))

        # Number of mask entries covered by one window (all three kernel axes
        # times the input channels): numerator of the rescaling sum(1)/sum(M).
        self.window_size = (self.kernel_size[0] * self.kernel_size[1]
                            * self.kernel_size[2] * self.input_dim)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, input):
        '''Convolve ``input``, using a partial convolution if it has holes.

        If no element equals ``null_value``, a regular convolution is applied
        to the explicitly padded input. Otherwise a mask (1 where the input
        differs from ``null_value``, 0 elsewhere, 0 in the padding) is inferred;
        the image is multiplied by the mask before convolution, the mask is
        convolved with an all-ones kernel, and the image response is rescaled by
        ``window_size / (mask_sum + 1e-8)`` and zeroed where the window held no
        valid entry. Bias and activation are then applied in either case.

        (The parameter keeps its original name ``input`` for compatibility.)
        '''
        conv_kwargs = dict(strides=self.strides,
                           padding='valid',
                           data_format=self.data_format,
                           dilation_rate=self.dilation_rate)

        # Padding done explicitly so that padding becomes part of the masked partial convolution
        images = K.spatial_3d_padding(
            input, self.pconv_padding, self.data_format)

        def regular_convolution():
            return K.conv3d(images, self.kernel, **conv_kwargs)

        def partial_convolution():
            # Infer the mask with TF ops: only values equal to null_value are
            # invalid. Padding the mask with zeros marks the border as hole.
            valid = cast(tf.not_equal(input, self.null_value), input.dtype)
            masks = K.spatial_3d_padding(
                valid, self.pconv_padding, self.data_format)

            mask_output = K.conv3d(masks, self.kernel_mask, **conv_kwargs)
            img_output = K.conv3d(
                multiply(images, masks), self.kernel, **conv_kwargs)

            # Rescale by sum(1)/sum(M) and zero it where there are holes
            mask_ratio = self.window_size / (mask_output + 1e-8)
            mask_output = K.clip(mask_output, 0, 1)
            mask_ratio = mask_ratio * mask_output
            return img_output * mask_ratio

        # tf.cond instead of a Python `if` on a tensor keeps this graph-safe.
        has_holes = tf.reduce_any(equal(input, self.null_value))
        img_output = tf.cond(has_holes, partial_convolution, regular_convolution)

        # Apply bias only to the image (if chosen to do so)
        if self.use_bias:
            img_output = K.bias_add(
                img_output,
                self.bias,
                data_format=self.data_format)

        # Apply activations on the image
        if self.activation is not None:
            img_output = self.activation(img_output)

        return img_output

    def compute_output_shape(self, input_shape):
        """Return the single output shape for a 5D ``input_shape``.

        Spatial sizes follow ``self.padding`` ('same' or 'valid'), matching the
        explicit padding computed in ``build``.
        """
        channels_last = self.data_format == 'channels_last'
        space = input_shape[1:-1] if channels_last else input_shape[2:]
        new_space = tuple(
            conv_utils.conv_output_length(
                space[i],
                self.kernel_size[i],
                padding=self.padding,
                stride=self.strides[i],
                dilation=self.dilation_rate[i])
            for i in range(len(space)))
        if channels_last:
            return (input_shape[0],) + new_space + (self.filters,)
        return (input_shape[0], self.filters) + new_space


class PConvTest:
    """Manual harness exercising ``PConv3D`` and ``PConvSimple3D``.

    Construction draws a random ``(b, t, y, x, c)`` array with values in
    [0, 10), a 0/1 mask that is 0 wherever the value is below ``mask_limit``,
    and the "reduced" images (image * mask, i.e. holes set to 0). All three are
    kept on the instance as tensors.
    """

    def __init__(self, b, t, y, x, c, mask_limit=0.5):
        created = self._create_inputs(b, t, y, x, c, mask_limit)
        self.images, self.masks, self.reduced_images = created

    def _create_inputs(self, b, t, y, x, c, mask_limit):
        """Return ``(images, masks, reduced_images)`` tensors."""
        raw = rand(b, t, y, x, c) * 10
        # Binarise a copy: values below the limit become holes (0), the rest 1.
        binary = raw.copy()
        binary[binary < mask_limit] = 0
        binary[binary >= mask_limit] = 1
        image_tensor = convert_to_tensor(raw)
        mask_tensor = convert_to_tensor(binary)
        return image_tensor, mask_tensor, multiply(image_tensor, mask_tensor)

    def visualize_regular_layer_effect(self, filters, kernel_size, strides, padding, is_printed=True):
        """Apply ``PConv3D`` to ``[images, masks]`` and optionally print both outputs."""
        layer = PConv3D(filters, kernel_size,
                        strides=strides, padding=padding)
        images_out, masks_out = layer([self.images, self.masks])
        if not is_printed:
            return
        print('Output data', images_out)
        print('Mask output', masks_out)

    def visualize_reduced_layer_effect(self, filters, kernel_size, strides, padding, is_printed=True):
        """Apply ``PConvSimple3D`` to the reduced images and optionally print the output."""
        layer = PConvSimple3D(
            filters, kernel_size, strides=strides, padding=padding)
        reduced_out = layer(self.reduced_images)
        if not is_printed:
            return
        print('Reduced data', reduced_out)

    def test_padding(self, kernel_size, strides, padding, filters=1):
        """Check ``PConvSimple3D`` output sizes along axes 1-3.

        Expected size is ``ceil(n / s)`` for 'same' and
        ``ceil((n - k + 1) / s)`` for 'valid'. Any other padding value is not
        checked and yields True.
        """
        layer = PConvSimple3D(
            filters, kernel_size, strides=strides, padding=padding)
        outputs = layer(self.reduced_images)
        in_shape = self.reduced_images.shape
        out_shape = outputs.shape

        if type(strides) is int:
            strides = [strides] * 3
        if type(kernel_size) is int:
            kernel_size = [kernel_size] * 3

        spatial_axes = (1, 2, 3)
        if padding == 'same':
            expected = [ceil(in_shape[axis] / strides[axis - 1])
                        for axis in spatial_axes]
        elif padding == 'valid':
            expected = [ceil((in_shape[axis] + 1 - floor(kernel_size[axis - 1]))
                             / strides[axis - 1])
                        for axis in spatial_axes]
        else:
            return True
        return all(out_shape[axis] == size
                   for axis, size in zip(spatial_axes, expected))

    def compare_layers(self, filters, kernel_size, strides, padding):
        """Return True if ``PConv3D`` and ``PConvSimple3D`` outputs differ anywhere."""
        # ! Do not initialize layers with use_bias = False and a static kernel_initializer
        # outside of very specific testing scenarios the way I did it here
        # My aim was to initialize two similar layers with the same weights to see
        # if the operations performed yielded different results
        initializer = GlorotUniform(seed=1)
        layer_options = dict(strides=strides, padding=padding,
                             use_bias=False, kernel_initializer=initializer)
        explicit_out, _ = PConv3D(filters, kernel_size, **layer_options)(
            [self.images, self.masks])
        inferred_out = PConvSimple3D(filters, kernel_size, **layer_options)(
            self.reduced_images)
        mismatch_count = tf_count(equal(explicit_out, inferred_out), False)
        return bool(mismatch_count.numpy())

    def _regular_layer(self):
        self.visualize_regular_layer_effect(
            1, (3, 3, 3), 1, padding='same', is_printed=False)

    def _reduced_layer(self):
        self.visualize_reduced_layer_effect(
            1, (3, 3, 3), 1, padding='same', is_printed=False)

    def test_memory_usage(self):
        """Print memory_profiler samples for one forward pass of each layer."""
        from memory_profiler import memory_usage
        # Test the memory usage of both layers.
        runs = (('Regular Layer', self._regular_layer),
                ('\nReduced Layer', self._reduced_layer))
        for title, run in runs:
            print(title)
            samples = memory_usage(run)
            print('Memory usage (in chunks of .1 seconds): %s' % samples)
            print('Maximum memory usage: %s' % max(samples))

    def test_encoding_steps(self):
        """Build two encoder stages from ``PConvSimple3D`` and print their shapes."""
        def norm_and_activate(tensor):
            # TODO: do not touch zero values?!
            normalized = tfa.layers.InstanceNormalization(
                axis=4,
                beta_initializer="random_uniform",
                gamma_initializer="random_uniform")(tensor)
            return LeakyReLU(alpha=0.2)(normalized)

        def residual_block(block_input, n_filters, kernel, stride):
            # Main path: norm/act -> strided conv -> norm/act -> conv.
            trunk = norm_and_activate(block_input)
            trunk = PConvSimple3D(n_filters, kernel_size=kernel,
                                  strides=stride, padding='same')(trunk)
            trunk = norm_and_activate(trunk)
            trunk = PConvSimple3D(n_filters, kernel_size=kernel,
                                  strides=1, padding='same')(trunk)
            # Shortcut Connection (Identity Mapping): 1x1x1 conv, same stride
            shortcut = PConvSimple3D(n_filters, kernel_size=1,
                                     strides=stride, padding='valid')(block_input)
            return trunk + shortcut

        filters = 64
        kernel_size = 4
        stride = 2

        # Pad the temporal axis by 2 before and 1 after; other axes untouched.
        time_padding = [[0, 0], [2, 1], [0, 0], [0, 0], [0, 0]]
        input_data = tf.pad(self.images, time_padding)
        print(f'Padded input: {input_data.shape}')

        # Encoding step 1
        trunk = PConvSimple3D(filters, kernel_size,
                              padding="same", strides=1)(input_data)
        trunk = PConvSimple3D(filters, kernel_size,
                              padding="same", strides=1)(trunk)
        shortcut = PConvSimple3D(filters, 1, padding="same")(input_data)
        enc1 = trunk + shortcut
        print(f'Output shape after first encoding: {enc1.shape}')

        # Encoding step 2
        enc2 = residual_block(enc1, filters * 2, kernel_size, stride)
        print(f'Output shape after second encoding: {enc2.shape}')


if __name__ == "__main__":
    # Random input batch of shape (b, t, y, x, c) = (1, 37, 32, 32, 2).
    harness = PConvTest(1, 37, 32, 32, 2, mask_limit=0.6)
    # Smaller alternative: PConvTest(1, 2, 9, 9, 1, mask_limit=0.7)
    #
    # Other checks available on the harness (call instead of / in addition to
    # the one below):
    #   harness.visualize_regular_layer_effect(1, (3, 3, 3), 1, padding='same')
    #   harness.visualize_reduced_layer_effect(1, (3, 3, 3), 1, padding='same')
    #   print(harness.test_padding((4, 4, 4), 2, padding='same'))
    #   print(harness.test_padding((2, 2, 2), 1, padding='same'))
    #   print(harness.test_padding((3, 3, 3), 1, padding='valid'))
    #   print(harness.compare_layers(1, (3, 3, 3), 2, padding='same'))
    #   harness.test_memory_usage()
    harness.test_encoding_steps()
