"""Code mostly copied from https://github.com/MathiasGruber/PConv-Keras
"""

from math import ceil, floor
from numpy import ones
from numpy.random import rand
import functools
import six
from tensorflow.python.ops import nn_ops
from tensorflow.python.keras.initializers.initializers_v2 import GlorotUniform
from tensorflow import (multiply, convert_to_tensor, cast,
                        reduce_sum, equal, int32)
from tensorflow.keras.layers import Conv3D, InputSpec, LeakyReLU, Input
import tensorflow_addons as tfa
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Conv3D, InputSpec
from tensorflow.python.keras.utils import conv_utils
from tensorflow.keras import backend as K
from cropsar_px.models.layers.norm import MaskedInstanceNormalization

# from memory_profiler import memory_usage


# To use both Keras and Tensorflow layers in the same model do:
# import tensorflow as tf
# import keras
# from keras import backend as K

# tf_sess = tf.Session()
# K.set_session(tf_sess)

'''
NOTE

we'll need to remember that the PConv's as pushed to master are in
fact probably broken. To support variable size inputs I had to make
some changes and use more of the internal convolution operators,
but how it now works with padded zeroes is a bit obscure. In any
case, the PConvSimple first checks for invalid values and if none
are found, it takes the normal Conv pathway, meaning that also
padded zeroes during convolutions are not treated as masked values.
The PConv with explicit mask does not have this simplified pathway,
so the way it works now is that the padded zeroes will become part
of the mask as well. Therefore, a test on a completely valid image
that checks identical outputs from PConv and PConvSimple will fail,
because the former will still take the masked pathway in the padded
regions, while the latter will take the shortcut to ordinary
convolutions.
'''


class PConv3D(Conv3D):
    """3D partial convolution driven by an explicit mask.

    The layer is called on a two-element list ``[img, mask]`` and returns
    ``[img_output, mask_output]``.
    """

    def __init__(self, *args, padding='same', **kwargs):
        super().__init__(*args, padding=padding, **kwargs)
        # One rank-5 spec for the image, one for the mask.
        self.input_spec = [InputSpec(ndim=5), InputSpec(ndim=5)]

    def build(self, input_shape):
        """Create the image kernel, the all-ones mask kernel and the bias.

        Adapted from the original _Conv() layer of Keras.

        param input_shape: list of dimensions for [img, mask]
        """
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        img_shape = input_shape[0]

        if img_shape[channel_axis] is None:
            raise ValueError(
                'The channel dimension of the inputs should be defined. Found `None`.')
        self.input_dim = img_shape[channel_axis]

        # Trainable kernel applied to the (masked) image.
        self.kernel = self.add_weight(
            shape=self.kernel_size + (self.input_dim, self.filters),
            initializer=self.kernel_initializer,
            name='img_kernel',
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint)

        # Constant kernel of ones applied to the mask.
        self.kernel_mask = K.constant(
            1, shape=self.kernel_size + (self.input_dim, self.filters))

        # Number of positions in one kernel window, used for normalization.
        size_0, size_1, size_2 = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  self.kernel_size[2])
        self.window_size = size_0 * size_1 * size_2

        if not self.use_bias:
            self.bias = None
        else:
            self.bias = self.add_weight(
                shape=(self.filters,),
                initializer=self.bias_initializer,
                name='bias',
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint)

        # Translate the Keras padding value into the TF native one.
        if self.padding == 'causal':
            native_padding = 'VALID'  # Causal padding handled in `call`.
        elif isinstance(self.padding, six.string_types):
            native_padding = self.padding.upper()
        else:
            native_padding = self.padding

        op_name = self.__class__.__name__
        if op_name == 'Conv1D':
            op_name = 'conv1d'  # Backwards compat.

        # Single convolution callable shared by the image and the mask.
        self.convolution_op = functools.partial(
            nn_ops.convolution_v2,
            strides=list(self.strides),
            padding=native_padding,
            dilations=list(self.dilation_rate),
            data_format=self._tf_data_format,
            name=op_name)

        self.built = True

    def call(self, inputs):
        '''
        Partial convolution of ``inputs[0]`` (image) under ``inputs[1]`` (mask).

        The image is multiplied by the mask before being convolved with the
        trainable kernel; the mask is convolved with a kernel of ones. The
        image result is rescaled by window_size / mask_sum wherever the mask
        sum is non-zero, and the mask output is clipped to between 0 and 1.
        '''
        img = inputs[0]
        mask = inputs[1]

        # Convolve the mask first, then the masked image.
        mask_sum = self.convolution_op(mask, self.kernel_mask)
        conv_img = self.convolution_op(multiply(img, mask), self.kernel)

        # Per-position rescaling factor (epsilon avoids a division by zero).
        ratio = self.window_size / (mask_sum + 1e-8)

        # Output mask, clipped to [0, 1].
        mask_out = K.clip(mask_sum, 0, 1)

        # Zero the ratio wherever the output mask is zero (holes), then
        # normalize the image output.
        ratio = ratio * mask_out
        conv_img = conv_img * ratio

        # Bias is applied to the image only.
        if self.use_bias:
            conv_img = K.bias_add(
                conv_img, self.bias, data_format=self.data_format)

        # Activation is applied to the image only.
        if self.activation is not None:
            conv_img = self.activation(conv_img)

        return [conv_img, mask_out]

    def compute_output_shape(self, input_shape):
        """Return the image output shape twice: once for image, once for mask."""
        out_shape = super().compute_output_shape(input_shape[0])
        return [out_shape, out_shape]


def tf_count(t, val):
    """Return the number of elements of tensor ``t`` equal to ``val``.

    The element-wise comparison is cast to int32 and summed over all axes.
    """
    return reduce_sum(cast(equal(t, val), int32))


class PConvSimple3DXENIA(Conv3D):
    """3D partial convolution that infers its mask from a null value.

    Input elements equal to ``null_value`` are treated as missing. Padding is
    applied explicitly before the convolutions (which always run with 'valid'
    padding), so that in the masked pathway the padded zeroes become part of
    the mask.
    """

    def __init__(self, *args, padding='same', n_channels=3, mono=False, null_value=0, **kwargs):
        """Create the layer.

        param padding: 'same' or 'valid'; validated in ``build``.
        param n_channels: accepted but not used by this class.
        param mono: accepted but not used by this class.
        param null_value: input value that marks an element as missing.
        """
        super().__init__(*args, **kwargs)
        self.input_spec = InputSpec(ndim=5)
        self.padding = padding
        self.null_value = null_value

    @staticmethod
    def _pconv_axis_triple(value, label):
        """Return ``value`` as a (t, h, w) triple, or raise ``ValueError``."""
        if isinstance(value, int):
            return (value, value, value)
        if hasattr(value, '__len__') and len(value) == 3:
            return tuple(value)
        # Also reached for sequences of the wrong length, which previously
        # fell through and failed later with an unbound local variable.
        raise ValueError(
            'The {} should be either an int, a tuple of length 3. Found {}.'.format(label, value))

    @staticmethod
    def _pconv_same_padding(stride, dimension_size, kernel_size):
        """Return (padding_before, padding_after) for one dimension.

        Decided to take the ceiling of the dimension_size/stride according to
        https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks
        """
        total = (stride * ceil(dimension_size / stride) -
                 dimension_size + kernel_size - stride)
        # The total is negative when the kernel is smaller than the stride;
        # a negative padding amount is meaningless, so clamp it to zero.
        total = max(total, 0)
        return (floor(total / 2), ceil(total / 2))

    def build(self, input_shape):
        """Create the weights and pre-compute the explicit padding.

        Adapted from the original _Conv() layer of Keras.

        param input_shape: shape of the rank-5 input. For 'same' padding the
            time, height and width entries are used in arithmetic, so they
            have to be numbers.
        """

        if self.data_format == 'channels_first':
            channel_axis = 1
            time_axis = 2
            height_axis = 3
            width_axis = 4
        else:
            channel_axis = 4
            time_axis = 1
            height_axis = 2
            width_axis = 3

        if input_shape[channel_axis] is None:
            raise ValueError(
                'The channel dimension of the inputs should be defined. Found `None`.')

        self.input_dim = input_shape[channel_axis]

        # Image kernel
        kernel_shape = self.kernel_size + (self.input_dim, self.filters)
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='img_kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        # Mask kernel
        self.kernel_mask = K.constant(
            1, shape=self.kernel_size + (self.input_dim, self.filters))

        # Compute flexible padding
        if self.padding == 'same':
            stride_t, stride_h, stride_w = self._pconv_axis_triple(
                self.strides, 'strides')
            kernel_size_t, kernel_size_h, kernel_size_w = self._pconv_axis_triple(
                self.kernel_size, 'kernel_size')

            # NOTE(review): dilation_rate is not taken into account here, so
            # with a dilation above 1 the output is presumably smaller than a
            # true 'same' convolution would give - confirm before relying on it.
            self.pconv_padding = (
                self._pconv_same_padding(
                    stride_t, input_shape[time_axis], kernel_size_t),
                self._pconv_same_padding(
                    stride_h, input_shape[height_axis], kernel_size_h),
                self._pconv_same_padding(
                    stride_w, input_shape[width_axis], kernel_size_w)
            )

        elif self.padding == 'valid':
            self.pconv_padding = ((0, 0), (0, 0), (0, 0))
        else:
            raise ValueError(
                'The padding value should be valid or same. Found {}.'.format(self.padding))

        # Window size - used for normalization
        self.window_size = self.kernel_size[0] * \
            self.kernel_size[1] * self.kernel_size[2]

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, input):
        '''
        We will be using the Keras conv3d method.
        If there are no missing values (null_value) in the input,
        we proceed to a regular convolution.
        Otherwise, we infer a mask based on null_values, multiply the mask
        with the input X and then apply the convolutions. For the mask itself,
        we apply convolutions with all weights set to 1.
        Subsequently, we clip mask values to between 0 and 1.
        '''

        # If the image is not masked, proceed to a regular 3d convolution
        if tf_count(input, self.null_value) == 0:
            img_output = K.conv3d(
                K.spatial_3d_padding(
                    input, self.pconv_padding, self.data_format),
                self.kernel,
                strides=self.strides,
                padding='valid',
                data_format=self.data_format,
                dilation_rate=self.dilation_rate
            )

        # If the image is masked, infer mask and proceed to partial convolution
        else:

            # Inferring mask from image
            # Only values equal to the null_value are considered invalid
            masks = tf.cast(tf.not_equal(input, tf.constant(
                float(self.null_value))), tf.float32)

            # Padding done explicitly so that padding becomes part of the masked partial convolution
            masks = K.spatial_3d_padding(
                masks, self.pconv_padding, self.data_format)
            images = K.spatial_3d_padding(
                input, self.pconv_padding, self.data_format)

            # Apply convolutions to mask
            mask_output = K.conv3d(
                masks, self.kernel_mask,
                strides=self.strides,
                padding='valid',
                data_format=self.data_format,
                dilation_rate=self.dilation_rate
            )

            # Apply convolutions to image
            img_output = K.conv3d(
                multiply(images, masks),
                self.kernel,
                strides=self.strides,
                padding='valid',
                data_format=self.data_format,
                dilation_rate=self.dilation_rate
            )

            # Calculate the mask ratio on each pixel in the output mask
            mask_ratio = self.window_size / (mask_output + 1e-8)

            # Clip output to be between 0 and 1
            mask_output = K.clip(mask_output, 0, 1)

            # Remove ratio values where there are holes
            mask_ratio = mask_ratio * mask_output

            # Normalize image output
            img_output = img_output * mask_ratio

        # Apply bias only to the image (if chosen to do so)
        if self.use_bias:
            img_output = K.bias_add(
                img_output,
                self.bias,
                data_format=self.data_format)

        # Apply activations on the image
        if self.activation is not None:
            img_output = self.activation(img_output)

        return img_output

    def compute_output_shape(self, input_shape):
        """Return the output shape as a tuple for the given input shape.

        The spatial sizes follow the layer's own ``padding`` setting, and the
        batch size is carried over for both data formats.
        """
        channels_first = self.data_format == 'channels_first'
        space = input_shape[2:] if channels_first else input_shape[1:-1]
        new_space = tuple(
            conv_utils.conv_output_length(
                dim,
                self.kernel_size[i],
                padding=self.padding,
                stride=self.strides[i],
                dilation=self.dilation_rate[i])
            for i, dim in enumerate(space))
        if channels_first:
            return (input_shape[0], self.filters) + new_space
        return (input_shape[0],) + new_space + (self.filters,)

    def get_config(self):
        """Return the layer config, including ``null_value``.

        Without this entry a layer rebuilt from its config would fall back to
        the default null value.
        """
        config = super().get_config()
        config['null_value'] = self.null_value
        return config


class PConvSimple3D(Conv3D):
    """3D partial convolution that infers its mask from a null value.

    When the input contains no element equal to ``null_value`` the layer
    behaves like the parent convolution; otherwise a mask is derived from the
    input and a partial convolution is applied.
    """

    def __init__(self, *args, padding='same',
                 null_value=0, **kwargs):
        """Create the layer.

        param padding: forwarded to the parent convolution layer.
        param null_value: input value that marks an element as missing.
        """
        super().__init__(*args, padding=padding, **kwargs)
        self.null_value = null_value

    def call(self, inputs):
        '''
        We will be using the Keras conv3d method.
        If there are no missing values (null_value) in the input,
        we proceed to a regular convolution.
        Otherwise, we infer a mask based on null_values, multiply the mask
        with the input X and then apply the convolutions. For the mask itself,
        we apply convolutions with all weights set to 1.
        Subsequently, we clip mask values to between 0 and 1.
        '''

        # If the image is not masked, proceed to a regular 3d convolution
        if tf_count(inputs, self.null_value) == 0:
            return super().call(inputs)

        # If the image is masked, infer mask and proceed to partial convolution
        else:

            # Inferring mask from image
            # Only values equal to the null_value are considered invalid
            masks = tf.cast(tf.not_equal(inputs, tf.constant(
                float(self.null_value))), tf.float32)

            # Mask kernel (all ones), sized from the input's channel count
            kernel_mask = K.constant(
                1, shape=self.kernel_size +
                (inputs.shape[self._get_channel_axis()],  self.filters))

            # Perform convolutions on masks and inputs
            mask_output = self.convolution_op(masks, kernel_mask)
            img_output = self.convolution_op(multiply(inputs, masks),
                                              self.kernel)

            # Window size - used for normalization
            window_size = self.kernel_size[0] * \
                self.kernel_size[1] * self.kernel_size[2]

            # Calculate the mask ratio on each pixel in the output mask
            mask_ratio = window_size / (mask_output + 1e-8)

            # Clip output to be between 0 and 1
            mask_output = K.clip(mask_output, 0, 1)

            # Remove ratio values where there are holes
            mask_ratio = mask_ratio * mask_output

            # Normalize image output
            img_output = img_output * mask_ratio

            # Apply bias only to the image (if chosen to do so)
            if self.use_bias:
                img_output = K.bias_add(
                    img_output,
                    self.bias,
                    data_format=self.data_format)

            # Apply activations on the image
            if self.activation is not None:
                img_output = self.activation(img_output)

            return img_output

    def get_config(self):
        """Return the layer config, including ``null_value``.

        Without this entry a layer rebuilt from its config would fall back to
        the default null value.
        """
        config = super().get_config()
        config['null_value'] = self.null_value
        return config


class PConvSimple2DXENIA(Conv2D):
    """2D partial convolution that infers its mask from a null value.

    Input elements equal to ``null_value`` are treated as missing. Padding is
    applied explicitly before the convolutions (which always run with 'valid'
    padding), so that in the masked pathway the padded zeroes become part of
    the mask.
    """

    def __init__(self, *args, padding='same', null_value=0, **kwargs):
        """Create the layer.

        param padding: 'same' or 'valid'; validated in ``build``.
        param null_value: input value that marks an element as missing.
        """
        super().__init__(*args, **kwargs)
        self.input_spec = InputSpec(ndim=4)
        self.padding = padding
        self.null_value = null_value

    @staticmethod
    def _pconv_axis_pair(value, label):
        """Return ``value`` as a (h, w) pair, or raise ``ValueError``."""
        if isinstance(value, int):
            return (value, value)
        if hasattr(value, '__len__') and len(value) == 2:
            return tuple(value)
        # Also reached for sequences of the wrong length, which previously
        # fell through and failed later with an unbound local variable.
        raise ValueError(
            'The {} should be either an int, a tuple of length 2. Found {}.'.format(label, value))

    @staticmethod
    def _pconv_same_padding(stride, dimension_size, kernel_size):
        """Return (padding_before, padding_after) for one dimension.

        Decided to take the ceiling of the dimension_size/stride according to
        https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks
        """
        total = (stride * ceil(dimension_size / stride) -
                 dimension_size + kernel_size - stride)
        # The total is negative when the kernel is smaller than the stride;
        # a negative padding amount is meaningless, so clamp it to zero.
        total = max(total, 0)
        return (floor(total / 2), ceil(total / 2))

    def build(self, input_shape):
        """Create the weights and pre-compute the explicit padding.

        Adapted from the original _Conv() layer of Keras.

        param input_shape: shape of the rank-4 input. For 'same' padding the
            height and width entries are used in arithmetic, so they have to
            be numbers.
        """

        if self.data_format == 'channels_first':
            channel_axis = 1
            height_axis = 2
            width_axis = 3
        else:
            channel_axis = 3
            height_axis = 1
            width_axis = 2

        if input_shape[channel_axis] is None:
            raise ValueError(
                'The channel dimension of the inputs should be defined. Found `None`.')

        self.input_dim = input_shape[channel_axis]

        # Image kernel
        kernel_shape = self.kernel_size + (self.input_dim, self.filters)
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='img_kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        # Mask kernel
        # NOTE(review): the 3D variant builds this with K.constant; K.ones may
        # return a variable that Keras then tracks as a layer weight - confirm
        # whether that is intended before changing it (it would alter the
        # layer's weight list).
        self.kernel_mask = K.ones(
            shape=self.kernel_size + (self.input_dim, self.filters))

        # Compute flexible padding
        if self.padding == 'same':
            stride_h, stride_w = self._pconv_axis_pair(
                self.strides, 'strides')
            kernel_size_h, kernel_size_w = self._pconv_axis_pair(
                self.kernel_size, 'kernel_size')

            # NOTE(review): dilation_rate is not taken into account here, so
            # with a dilation above 1 the output is presumably smaller than a
            # true 'same' convolution would give - confirm before relying on it.
            self.pconv_padding = (
                self._pconv_same_padding(
                    stride_h, input_shape[height_axis], kernel_size_h),
                self._pconv_same_padding(
                    stride_w, input_shape[width_axis], kernel_size_w)
            )

        elif self.padding == 'valid':
            self.pconv_padding = ((0, 0), (0, 0))
        else:
            raise ValueError(
                'The padding value should be valid or same. Found {}.'.format(self.padding))

        # Window size - used for normalization
        self.window_size = self.kernel_size[0] * self.kernel_size[1]

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, input):
        '''
        We will be using the Keras conv2d method.
        If there are no missing values (null_value) in the input,
        we proceed to a regular convolution.
        Otherwise, we infer a mask based on null_values, multiply the mask
        with the input X and then apply the convolutions. For the mask itself,
        we apply convolutions with all weights set to 1.
        Subsequently, we clip mask values to between 0 and 1.
        '''

        # If the image is not masked, proceed to a regular 2d convolution
        if tf_count(input, self.null_value) == 0:
            img_output = K.conv2d(
                K.spatial_2d_padding(
                    input, self.pconv_padding, self.data_format),
                self.kernel,
                strides=self.strides,
                padding='valid',
                data_format=self.data_format,
                dilation_rate=self.dilation_rate
            )

        # If the image is masked, infer mask and proceed to partial convolution
        else:

            # Inferring mask from image
            # Only values equal to the null_value are considered invalid
            masks = tf.cast(tf.not_equal(input, tf.constant(
                float(self.null_value))), tf.float32)

            # Padding done explicitly so that padding becomes part of the masked partial convolution
            masks = K.spatial_2d_padding(
                masks, self.pconv_padding, self.data_format)
            images = K.spatial_2d_padding(
                input, self.pconv_padding, self.data_format)

            # Apply convolutions to mask
            mask_output = K.conv2d(
                masks, self.kernel_mask,
                strides=self.strides,
                padding='valid',
                data_format=self.data_format,
                dilation_rate=self.dilation_rate
            )

            # Apply convolutions to image
            img_output = K.conv2d(
                multiply(images, masks),
                self.kernel,
                strides=self.strides,
                padding='valid',
                data_format=self.data_format,
                dilation_rate=self.dilation_rate
            )

            # Calculate the mask ratio on each pixel in the output mask
            mask_ratio = self.window_size / (mask_output + 1e-8)

            # Clip output to be between 0 and 1
            mask_output = K.clip(mask_output, 0, 1)

            # Remove ratio values where there are holes
            mask_ratio = mask_ratio * mask_output

            # Normalize image output
            img_output = img_output * mask_ratio

        # Apply bias only to the image (if chosen to do so)
        if self.use_bias:
            img_output = K.bias_add(
                img_output,
                self.bias,
                data_format=self.data_format)

        # Apply activations on the image
        if self.activation is not None:
            img_output = self.activation(img_output)

        return img_output

    def compute_output_shape(self, input_shape):
        """Return the output shape as a tuple for the given input shape.

        The spatial sizes follow the layer's own ``padding`` setting, and the
        batch size is carried over for both data formats.
        """
        channels_first = self.data_format == 'channels_first'
        space = input_shape[2:] if channels_first else input_shape[1:-1]
        new_space = tuple(
            conv_utils.conv_output_length(
                dim,
                self.kernel_size[i],
                padding=self.padding,
                stride=self.strides[i],
                dilation=self.dilation_rate[i])
            for i, dim in enumerate(space))
        if channels_first:
            return (input_shape[0], self.filters) + new_space
        return (input_shape[0],) + new_space + (self.filters,)

    def get_config(self):
        """Return the layer config, including ``null_value``.

        Without this entry a layer rebuilt from its config would fall back to
        the default null value.
        """
        config = super().get_config()
        config['null_value'] = self.null_value
        return config


class PConvSimple2D(Conv2D):
    """2D partial convolution that infers its mask from a null value.

    When the input contains no element equal to ``null_value`` the layer
    behaves like the parent convolution; otherwise a mask is derived from the
    input and a partial convolution is applied.
    """

    def __init__(self, *args, padding='same',
                 null_value=0, **kwargs):
        """Create the layer.

        param padding: forwarded to the parent convolution layer.
        param null_value: input value that marks an element as missing.
        """
        super().__init__(*args, padding=padding, **kwargs)
        self.null_value = null_value

    def call(self, inputs):
        '''
        We will be using the Keras conv2d method.
        If there are no missing values (null_value) in the input,
        we proceed to a regular convolution.
        Otherwise, we infer a mask based on null_values, multiply the mask
        with the input X and then apply the convolutions. For the mask itself,
        we apply convolutions with all weights set to 1.
        Subsequently, we clip mask values to between 0 and 1.
        '''

        # If the image is not masked, proceed to a regular 2d convolution
        if tf_count(inputs, self.null_value) == 0:
            return super().call(inputs)

        # If the image is masked, infer mask and proceed to partial convolution
        else:

            # Inferring mask from image
            # Only values equal to the null_value are considered invalid
            masks = tf.cast(tf.not_equal(inputs, tf.constant(
                float(self.null_value))), tf.float32)

            # Mask kernel (all ones), sized from the input's channel count
            kernel_mask = K.constant(
                1, shape=self.kernel_size +
                (inputs.shape[self._get_channel_axis()],  self.filters))

            # Perform convolutions on masks and inputs
            mask_output = self.convolution_op(masks, kernel_mask)
            img_output = self.convolution_op(multiply(inputs, masks),
                                              self.kernel)

            # Window size - used for normalization
            window_size = self.kernel_size[0] * self.kernel_size[1]

            # Calculate the mask ratio on each pixel in the output mask
            mask_ratio = window_size / (mask_output + 1e-8)

            # Clip output to be between 0 and 1
            mask_output = K.clip(mask_output, 0, 1)

            # Remove ratio values where there are holes
            mask_ratio = mask_ratio * mask_output

            # Normalize image output
            img_output = img_output * mask_ratio

            # Apply bias only to the image (if chosen to do so)
            if self.use_bias:
                img_output = K.bias_add(
                    img_output,
                    self.bias,
                    data_format=self.data_format)

            # Apply activations on the image
            if self.activation is not None:
                img_output = self.activation(img_output)

            return img_output

    def get_config(self):
        """Return the layer config, including ``null_value``.

        Without this entry a layer rebuilt from its config would fall back to
        the default null value.
        """
        config = super().get_config()
        config['null_value'] = self.null_value
        return config


class PConvTest():
    """Hand-driven checks for the ``PConv3D`` and ``PConvSimple3D`` layers.

    On construction a random 5-D array of shape ``(b, t, y, x, c)`` is
    generated together with a binary mask of the same shape (1 = valid,
    0 = masked).  The methods push these tensors through the layers and
    either print or return the outcome.  Nothing here is discovered by a test
    runner; the methods are meant to be called explicitly.
    """

    def __init__(self, b, t, y, x, c, mask_limit=0.5):
        """Build the image, mask and masked-image tensors used by the checks.

        Args:
            b, t, y, x, c: sizes of the five input axes, in this order
                (presumably batch, time, height, width, channels -- confirm
                against the layers' ``data_format``).
            mask_limit: accepted for backward compatibility; the current
                mask construction does not read it.
        """
        self.images, self.masks, self.reduced_images = self._create_inputs(
            b, t, y, x, c, mask_limit)

    def _create_inputs(self, b, t, y, x, c, mask_limit):
        """Return ``(images, masks, reduced_images)`` as tensors.

        The mask is all ones except for the fixed index range ``12:22`` along
        axes 1, 2 and 3, which is set to zero.  The random input is zeroed at
        the masked positions before conversion, so ``reduced_images``
        (``images * masks``) holds the same values as ``images``.

        ``mask_limit`` is not used: the mask is a fixed cube rather than a
        threshold on the random values.
        """
        import numpy as np

        input_data = rand(b, t, y, x, c)

        # Zeros mark the masked region.
        mask_data = np.ones_like(input_data)
        mask_data[:, 12:22, 12:22, 12:22, :] = 0
        input_data[mask_data == 0] = 0

        images = convert_to_tensor(input_data)
        masks = convert_to_tensor(mask_data)
        reduced_image = multiply(images, masks)
        return images, masks, reduced_image

    def visualize_regular_layer_effect(self, filters, kernel_size, strides, padding, is_printed=True):
        """Run ``PConv3D`` on ``[images, masks]`` and optionally print both outputs."""
        pconv_layer = PConv3D(filters, kernel_size,
                              strides=strides, padding=padding)
        output_images, output_masks = pconv_layer([self.images, self.masks])
        if is_printed:
            print('Output data', output_images)
            print('Mask output', output_masks)

    def visualize_reduced_layer_effect(self, filters, kernel_size, strides, padding, is_printed=True):
        """Run ``PConvSimple3D`` on the masked images and optionally print the output."""
        pconv_layer = PConvSimple3D(
            filters, kernel_size, strides=strides, padding=padding)
        output_reduced_images = pconv_layer(self.reduced_images)
        if is_printed:
            print('Reduced data', output_reduced_images)

    @staticmethod
    def _expected_output_size(input_size, kernel_size, stride, padding):
        """Return the expected length of one convolved axis.

        ``'same'`` gives ``ceil(input_size / stride)``; ``'valid'`` gives
        ``ceil((input_size - kernel_size + 1) / stride)``.  Any other padding
        mode returns ``None`` (no expectation).
        """
        if padding == 'same':
            return ceil(input_size / stride)
        if padding == 'valid':
            return ceil((input_size + 1 - floor(kernel_size)) / stride)
        return None

    def test_padding(self, kernel_size, strides, padding, filters=1):
        """Check the output shape of ``PConvSimple3D`` for the given settings.

        Args:
            kernel_size: int, or a sequence with one entry per convolved axis.
            strides: int, or a sequence with one entry per convolved axis.
            padding: ``'same'`` or ``'valid'``; for any other value no shape
                expectation is checked and ``True`` is returned.
            filters: number of output filters of the layer under test.

        Returns:
            ``True`` when axes 1, 2 and 3 of the output all have the expected
            length, ``False`` otherwise.
        """
        pconv_layer = PConvSimple3D(
            filters, kernel_size, strides=strides, padding=padding)
        output_reduced_images = pconv_layer(self.reduced_images)
        input_shape = self.reduced_images.shape
        output_shape = output_reduced_images.shape

        # Normalise scalar arguments to one value per convolved axis.
        if isinstance(strides, int):
            strides = [strides, strides, strides]
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size, kernel_size]
        elif isinstance(kernel_size, tuple):
            kernel_size = list(kernel_size)

        if padding not in ('same', 'valid'):
            return True

        # Axes 1..3 are the convolved ones; axis 0 and the last axis are not
        # compared.
        return all(
            output_shape[axis] == self._expected_output_size(
                input_shape[axis], kernel_size[axis - 1],
                strides[axis - 1], padding)
            for axis in (1, 2, 3))

    def compare_convolutions(self, filters, kernel_size, strides, padding):
        """Compare a plain ``Conv3D`` with ``PConvSimple3D`` on unmasked data.

        Masked positions of the input are overwritten with 1 before both
        layers are applied, and both layers use the same seeded initializer
        and no bias.

        Returns:
            ``bool(tf_count(equal(output, output2), False))``.  Assuming
            ``tf_count(t, v)`` counts the elements of ``t`` equal to ``v``,
            this is ``True`` when at least one output element differs and
            ``False`` when the outputs are identical -- verify against
            ``tf_count``.
        """
        initializer = GlorotUniform(seed=1)
        conv_layer = Conv3D(filters, kernel_size, strides=strides,
                            padding=padding, use_bias=False, kernel_initializer=initializer)
        pconv_layer = PConvSimple3D(filters, kernel_size, strides=strides,
                                    padding=padding, use_bias=False, kernel_initializer=initializer)

        images = self.images
        images = tf.where(self.masks == 0, 1, images)

        output = conv_layer(images)
        # PConvSimple3D returns a single tensor.  Unpacking it with
        # ``output2, = ...`` iterated over the first axis, which only worked
        # when that axis had length 1.
        output2 = pconv_layer(images)
        res = bool(tf_count(equal(output, output2), False).numpy())
        return res

    def compare_layers(self, filters, kernel_size, strides, padding):
        """Compare ``PConv3D`` (explicit mask) with ``PConvSimple3D`` (inferred mask).

        Both layers are created with the same seeded initializer and without
        bias so that they start from identical weights.

        Returns:
            ``bool(tf_count(equal(output, output2), False))``.  Assuming
            ``tf_count(t, v)`` counts the elements of ``t`` equal to ``v``,
            this is ``True`` when at least one output element differs --
            verify against ``tf_count``.
        """
        # ! Do not initialize layers with use_bias = False and a static kernel_initializer
        # outside of very specific testing scenarios the way it is done here.
        # The aim is to initialize two similar layers with the same weights to see
        # if the operations performed yield different results.
        initializer = GlorotUniform(seed=1)
        pconv_layer = PConv3D(filters, kernel_size, strides=strides,
                              padding=padding, use_bias=False, kernel_initializer=initializer)
        output, _ = pconv_layer([self.images, self.masks])
        pconv_layer2 = PConvSimple3D(filters, kernel_size, strides=strides,
                                     padding=padding, use_bias=False, kernel_initializer=initializer)
        output2 = pconv_layer2(self.reduced_images)
        res = bool(tf_count(equal(output, output2), False).numpy())
        return res

    def _regular_layer(self):
        """Silent ``PConv3D`` run with fixed settings, used for memory profiling."""
        self.visualize_regular_layer_effect(
            1, (3, 3, 3), 1, padding='same', is_printed=False)

    def _reduced_layer(self):
        """Silent ``PConvSimple3D`` run with fixed settings, used for memory profiling."""
        self.visualize_reduced_layer_effect(
            1, (3, 3, 3), 1, padding='same', is_printed=False)

    def test_memory_usage(self):
        """Print the memory usage samples and their maximum for both layers.

        Requires the optional ``memory_profiler`` package, imported here so
        that the rest of the module works without it.
        """
        from memory_profiler import memory_usage
        # Test the memory usage of both layers.
        print('Regular Layer')
        mem_usage = memory_usage(self._regular_layer)
        print('Memory usage (in chunks of .1 seconds): %s' % mem_usage)
        print('Maximum memory usage: %s' % max(mem_usage))
        print('\nReduced Layer')
        mem_usage = memory_usage(self._reduced_layer)
        print('Memory usage (in chunks of .1 seconds): %s' % mem_usage)
        print('Maximum memory usage: %s' % max(mem_usage))

    def test_encoding_steps(self):
        """Run two encoder stages built from ``PConvSimple3D`` and print the shapes.

        Stage 1 is two stride-1 convolutions plus a 1x1x1 shortcut; stage 2 is
        a residual block with stride 2 and twice the number of filters.
        """
        def instancenorm_relu(inputs):
            """Masked instance normalisation followed by LeakyReLU(0.2).

            Values equal to 0 are treated as masked and are reset to 0 after
            normalisation.
            """
            # TODO: do not touch zero values?!

            masks = tf.cast(tf.not_equal(inputs, tf.constant(float(0))),
                            tf.float32)

            adj_inputs = MaskedInstanceNormalization(
                beta_initializer="random_uniform",
                gamma_initializer="random_uniform")(inputs, masks)

            adj_inputs = adj_inputs * masks

            return LeakyReLU(alpha=0.2)(adj_inputs)

        def residual_block_3d(layer_input, filters, kernelsize, stride, init,
                              l2_weight_regularizer=1e-5):
            """Pre-activation residual block built from ``PConvSimple3D``.

            ``init`` and ``l2_weight_regularizer`` are accepted but not used.
            """
            x = instancenorm_relu(layer_input)
            x = PConvSimple3D(filters, kernel_size=kernelsize,
                              strides=stride, padding='same')(x)
            x = instancenorm_relu(x)
            x = PConvSimple3D(filters, kernel_size=kernelsize,
                              strides=1, padding='same')(x)
            # Shortcut connection: a 1x1x1 convolution with the block's
            # stride, applied to the block input.
            s = PConvSimple3D(filters, kernel_size=1,
                              strides=stride, padding='valid')(layer_input)
            # Addition
            x = x + s

            return x

        filters = 64
        defaultkernelsize = 4
        defaultstride = 2

        input_data = self.images
        # input_data = Input(shape=(32, 32, 32, 1))
        print(f'Padded input: {input_data.shape}')

        # Encoding step 1
        x = PConvSimple3D(filters, defaultkernelsize,
                          padding="same", strides=1)(input_data)
        x = PConvSimple3D(filters, defaultkernelsize,
                          padding="same", strides=1)(x)
        s = PConvSimple3D(filters, 1, padding="same")(input_data)
        enc1 = x + s

        print(f'Output shape after first encoding: {enc1.shape}')

        # Encoding step 2
        enc2 = residual_block_3d(enc1, filters * 2, defaultkernelsize,
                                 defaultstride, None)
        print(f'Output shape after second encoding: {enc2.shape}')


if __name__ == "__main__":
    # Manual driver: build the random test tensors, then run the selected
    # checks.  Alternative input configuration:
    # tester = PConvTest(1, 37, 32, 32, 2, mask_limit=0.6)
    tester = PConvTest(b=1, t=32, y=32, x=32, c=1, mask_limit=1)

    # Optional checks -- uncomment the ones you want to run:
    # tester.visualize_regular_layer_effect(1, (3, 3, 3), 1, padding='same')
    # tester.visualize_reduced_layer_effect(1, (3, 3, 3), 1, padding='same')
    # print(tester.test_padding((4, 4, 4), 2, padding='same'))
    # print(tester.test_padding((2, 2, 2), 1, padding='same'))
    # print(tester.test_padding((3, 3, 3), 1, padding='valid'))
    # tester.test_memory_usage()
    # print(tester.compare_convolutions(1, (3, 3, 3), 1, 'same'))

    # Compare the explicit-mask and inferred-mask layers, then run the
    # encoder shape walk-through.
    comparison = tester.compare_layers(
        filters=1, kernel_size=(3, 3, 3), strides=2, padding='same')
    print(comparison)
    tester.test_encoding_steps()
