from tensorflow_addons.layers import GroupNormalization, InstanceNormalization
from tensorflow.keras.layers import Layer, LayerNormalization, BatchNormalization
import tensorflow as tf
import numpy as np
# from icecream import ic
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import regularizers

from cropsar_px.utils.clouds import generate_cloud


class MaskedBatchNormalization(BatchNormalization):
    """Batch normalization whose moments ignore elements equal to zero.

    Every input element that is exactly ``0`` gets weight 0 and every other
    element gets weight 1 when the mean and variance are computed.
    """

    def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims):
        """Return the (mean, variance) of the non-zero entries of ``inputs``.

        Args:
            inputs: tensor to compute the statistics of.
            reduction_axes: axes over which the moments are reduced.
            keep_dims: whether the reduced axes are kept with size 1.

        Returns:
            The ``(weighted_mean, weighted_variance)`` pair produced by
            ``tf.nn.weighted_moments``.
        """
        zero = tf.constant(float(0))
        is_nonzero = tf.not_equal(inputs, zero)
        # 1.0 where the input carries data, 0.0 where it is exactly zero.
        weights = tf.cast(is_nonzero, tf.float32)
        return tf.nn.weighted_moments(
            inputs,
            reduction_axes,
            weights,
            keepdims=keep_dims,
        )


class MaskedGroupNormalization(GroupNormalization):
    """Group normalization whose statistics are weighted by an explicit mask.

    The mask is reshaped into groups exactly like the inputs and is then used
    as the weights of ``tf.nn.weighted_moments``.
    NOTE(review): relies on the private ``_reshape_into_groups`` and
    ``_get_reshaped_weights`` helpers, presumably inherited from
    ``GroupNormalization`` -- confirm against the installed
    tensorflow_addons version.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to the parent layer."""
        super().__init__(**kwargs)

    def call(self, inputs, normmask):
        """Normalize ``inputs`` using only the positions weighted by ``normmask``.

        Args:
            inputs: tensor to normalize.
            normmask: weights for the moments; cast to float32 and reshaped
                with the same shape arguments as ``inputs``
                (assumes it has the same shape as ``inputs`` -- TODO confirm).

        Returns:
            The normalized tensor, cast to float32.
        """
        static_shape = tf.keras.backend.int_shape(inputs)
        dynamic_shape = tf.shape(inputs)

        grouped_inputs, _ = self._reshape_into_groups(
            inputs, static_shape, dynamic_shape
        )
        grouped_mask, _ = self._reshape_into_groups(
            tf.cast(normmask, dtype=tf.dtypes.float32),
            static_shape,
            dynamic_shape,
        )

        normalized = self._apply_masked_normalization(
            grouped_inputs, static_shape, grouped_mask)

        # With one channel per group the result is used as-is; otherwise it
        # is reshaped back to the shape of the original inputs.
        one_channel_per_group = (static_shape[self.axis] // self.groups) == 1
        if one_channel_per_group:
            result = normalized
        else:
            result = tf.reshape(normalized, dynamic_shape)

        return tf.cast(result, tf.float32)

    def _apply_masked_normalization(self, reshaped_inputs, input_shape,
                                    reshaped_mask):
        """Normalize the grouped inputs with mask-weighted moments.

        Args:
            reshaped_inputs: inputs after ``_reshape_into_groups``.
            input_shape: static shape of the original (ungrouped) inputs.
            reshaped_mask: mask after ``_reshape_into_groups``; used as the
                weights of the moments.

        Returns:
            The normalized grouped tensor, scaled and shifted with the
            weights returned by ``_get_reshaped_weights``.
        """
        rank = len(tf.keras.backend.int_shape(reshaped_inputs))
        # Start from every axis except the batch axis, then drop one entry.
        reduction_axes = list(range(1, rank))

        one_channel_per_group = (input_shape[self.axis] // self.groups) == 1
        if self.axis == -1:
            excluded = -1 if one_channel_per_group else -2
        else:
            excluded = self.axis - 1
        reduction_axes.pop(excluded)

        # Weighted moments: only positions with non-zero mask weight
        # contribute to the mean and variance.
        mean, variance = tf.nn.weighted_moments(
            reshaped_inputs, reduction_axes, reshaped_mask, keepdims=True
        )

        gamma, beta = self._get_reshaped_weights(input_shape)
        return tf.nn.batch_normalization(
            reshaped_inputs,
            mean=mean,
            variance=variance,
            scale=gamma,
            offset=beta,
            variance_epsilon=self.epsilon,
        )


class MaskedInstanceNormalization(MaskedGroupNormalization):
    """Masked group normalization with ``groups`` forced to ``-1``.

    Any ``groups`` value supplied by the caller is overwritten (a warning is
    logged); all other keyword arguments are forwarded unchanged to
    ``MaskedGroupNormalization``.
    """

    def __init__(self, **kwargs):
        """Create the layer.

        Args:
            **kwargs: keyword arguments for ``MaskedGroupNormalization``.
                A ``groups`` entry is accepted but replaced by ``-1``.
        """
        if "groups" in kwargs:
            # Imported here because the module has no file-level
            # ``import logging``; without it this branch raised NameError.
            import logging

            logging.warning("The given value for groups will be overwritten.")

        kwargs["groups"] = -1
        super().__init__(**kwargs)


class MaskedLayerNormalization(Layer):
    """Layer normalization whose statistics only use unmasked elements.

    Adapted from Keras' ``LayerNormalization`` with a ``mask`` argument added
    to ``call``. When a mask is given, the mean and variance are computed over
    ``axis`` using only the elements where the mask is non-zero. When no mask
    is given the layer behaves like a plain layer normalization.

    Args:
        axis: int or list/tuple of ints, the axes to normalize over.
        epsilon: small constant added to the variance.
        center: if True, add the learned offset ``beta``.
        scale: if True, multiply by the learned scale ``gamma``.
        beta_initializer: initializer for ``beta``.
        gamma_initializer: initializer for ``gamma``.
        beta_regularizer: optional regularizer for ``beta``.
        gamma_regularizer: optional regularizer for ``gamma``.
        beta_constraint: optional constraint for ``beta``.
        gamma_constraint: optional constraint for ``gamma``.
        **kwargs: forwarded to ``Layer.__init__``.

    Raises:
        TypeError: if ``axis`` is neither an int nor a list/tuple.
    """

    def __init__(self,
                 axis=-1,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super().__init__(**kwargs)
        if isinstance(axis, (list, tuple)):
            # Copy so that later in-place edits do not touch the caller's list.
            self.axis = axis[:]
        elif isinstance(axis, int):
            self.axis = axis
        else:
            raise TypeError('Expected an int or a list/tuple of ints for the '
                            'argument \'axis\', but received: %r' % axis)

        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

        self.supports_masking = True

        # Indicates whether a faster fused implementation can be used. This
        # is set to True or False in build().
        self._fused = None

    def _fused_can_be_used(self, ndims):
        """Return whether the fused implementation is applicable.

        The axes must be contiguous and end at the last axis so that they can
        be collapsed into a single axis. ``self.axis`` is assumed to be a
        list without duplicates.

        Args:
            ndims: rank of the inputs.

        Returns:
            True if the fused kernel may be used, False otherwise.
        """
        axis = sorted(self.axis)
        can_use_fused = False

        if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1:
            can_use_fused = True

        # fused_batch_norm will silently raise epsilon to be at least
        # 1.001e-5, so the fused version cannot be used if epsilon is below
        # that value. Also, the variable dtype must be float32, as
        # fused_batch_norm only supports float32 variables.
        if self.epsilon < 1.001e-5 or self.dtype != 'float32':
            can_use_fused = False

        return can_use_fused

    def build(self, input_shape):
        """Resolve the axes and create the ``gamma`` / ``beta`` weights.

        Args:
            input_shape: shape of the inputs.

        Raises:
            ValueError: if an axis is out of range or listed twice.
        """
        ndims = len(input_shape)
        if ndims is None:
            raise ValueError(
                'Input shape %s has undefined rank.' % input_shape)

        # Convert axis to list and resolve negatives
        if isinstance(self.axis, int):
            self.axis = [self.axis]
        elif isinstance(self.axis, tuple):
            self.axis = list(self.axis)
        for idx, x in enumerate(self.axis):
            if x < 0:
                self.axis[idx] = ndims + x

        # Validate axes
        for x in self.axis:
            if x < 0 or x >= ndims:
                raise ValueError(
                    f'Invalid axis. Expected 0 <= axis < inputs.rank (with '
                    f'inputs.rank={ndims}). Received: layer.axis={self.axis}')
        if len(self.axis) != len(set(self.axis)):
            raise ValueError('Duplicate axis: {}'.format(tuple(self.axis)))

        # One scale / offset value per element of the normalized axes.
        param_shape = [input_shape[dim] for dim in self.axis]
        if self.scale:
            self.gamma = self.add_weight(
                name='gamma',
                shape=param_shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                trainable=True,
                experimental_autocast=False)
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(
                name='beta',
                shape=param_shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                trainable=True,
                experimental_autocast=False)
        else:
            self.beta = None

        self._fused = self._fused_can_be_used(ndims)

        self.built = True

    def _masked_moments(self, inputs, mask):
        """Compute mean and variance over ``self.axis`` from valid elements.

        Everything is expressed with tensor ops (no ``.numpy()``), so it
        works both eagerly and inside a traced graph, and the degenerate
        cases are handled per normalization group instead of once for the
        whole tensor:

        * no valid element in a group: mean 0 and variance 1;
        * exactly one valid element in a group: variance 1.

        Args:
            inputs: tensor to compute the statistics of.
            mask: tensor of weights, cast to ``inputs.dtype``. Expected to
                hold 0/1 values and to have the same shape as ``inputs``.

        Returns:
            ``(mean, variance)``, both reduced over ``self.axis`` with the
            reduced axes kept.
        """
        mask = tf.cast(mask, inputs.dtype)
        n_valid = tf.reduce_sum(mask, axis=self.axis, keepdims=True)

        # divide_no_nan yields 0 where a group has no valid element, which
        # avoids the NaN a plain division by zero would produce.
        mean = tf.math.divide_no_nan(
            tf.reduce_sum(inputs * mask, axis=self.axis, keepdims=True),
            n_valid)

        # The variance is divided by n_valid (not n_valid - 1) to stay
        # consistent with tf.nn.moments.
        variance = tf.math.divide_no_nan(
            tf.reduce_sum(tf.square(inputs - mean) * mask,
                          axis=self.axis, keepdims=True),
            n_valid)

        # With zero or one valid element the variance carries no
        # information; fall back to 1 so the normalization does not divide
        # by (almost) zero.
        variance = tf.where(tf.greater(n_valid, 1.0),
                            variance, tf.ones_like(variance))
        return mean, variance

    def call(self, inputs, mask=None):
        """Normalize ``inputs`` over ``self.axis``.

        Args:
            inputs: tensor to normalize.
            mask: optional tensor marking the valid elements (non-zero means
                valid). If given, only valid elements contribute to the mean
                and variance. If None, all elements are used.

        Returns:
            The normalized tensor, with the same shape as ``inputs``.
        """
        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)

        # Broadcasting only necessary for norm when the axis is not just
        # the last dimension
        broadcast_shape = [1] * ndims
        for dim in self.axis:
            broadcast_shape[dim] = input_shape.dims[dim].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]):
                return tf.reshape(v, broadcast_shape)
            return v

        # The fused kernel cannot take a mask into account, so it is only
        # used when no mask is supplied.
        if mask is not None or not self._fused:
            input_dtype = inputs.dtype
            if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
                # If mixed precision is used, cast inputs to float32 so that
                # this is at least as numerically stable as the fused version.
                inputs = tf.cast(inputs, 'float32')

            if mask is None:
                mean, variance = tf.nn.moments(inputs, self.axis, keepdims=True)
            else:
                mean, variance = self._masked_moments(inputs, mask)

            scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

            # Compute layer normalization using the batch_normalization
            # function.
            outputs = tf.nn.batch_normalization(
                inputs,
                mean,
                variance,
                offset=offset,
                scale=scale,
                variance_epsilon=self.epsilon)
            outputs = tf.cast(outputs, input_dtype)
        else:
            # Collapse dims before self.axis, and dims in self.axis
            pre_dim, in_dim = (1, 1)
            axis = sorted(self.axis)
            tensor_shape = tf.shape(inputs)
            for dim in range(0, ndims):
                dim_tensor = tensor_shape[dim]
                if dim < axis[0]:
                    pre_dim = pre_dim * dim_tensor
                else:
                    assert dim in axis
                    in_dim = in_dim * dim_tensor

            squeezed_shape = [1, pre_dim, in_dim, 1]
            # This fused operation requires reshaped inputs to be NCHW.
            data_format = 'NCHW'

            inputs = tf.reshape(inputs, squeezed_shape)

            # self.gamma and self.beta have the wrong shape for
            # fused_batch_norm, so they cannot be passed as the scale and
            # offset parameters. Constant tensors of the correct shape are
            # used instead and the learned scale / offset are applied
            # afterwards.
            scale = tf.ones([pre_dim], dtype=self.dtype)
            offset = tf.zeros([pre_dim], dtype=self.dtype)

            # Compute layer normalization using the fused_batch_norm function.
            outputs, _, _ = tf.compat.v1.nn.fused_batch_norm(
                inputs,
                scale=scale,
                offset=offset,
                epsilon=self.epsilon,
                data_format=data_format)

            outputs = tf.reshape(outputs, tensor_shape)

            scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

            if scale is not None:
                outputs = outputs * tf.cast(scale, outputs.dtype)
            if offset is not None:
                outputs = outputs + tf.cast(offset, outputs.dtype)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        return outputs

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape

    def get_config(self):
        """Return the layer configuration, merged with the base config."""
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


class MaskedNormalizationTest:
    """Ad-hoc driver that runs masked and unmasked normalization side by side.

    Random data and a cloud-shaped mask are generated once in ``__init__``;
    the ``run_*`` / ``test_*`` methods feed them through the layers and print
    the results for manual inspection. Nothing is asserted.
    """

    def __init__(self, b, t, y, x, c):
        """Generate data and mask of shape ``(b, t, y, x, c)``.

        Args:
            b, t, y, x, c: sizes of the five axes of the generated tensors.
                ``x`` and ``y`` must be equal.
        """
        self.data, self.mask = self._generate_raw_data(b, t, y, x, c)
        self._report(data_shape=self.data.shape, mask_shape=self.mask.shape)

    @staticmethod
    def _report(**named_values):
        """Print each keyword argument as a ``name: value`` line.

        Replaces the former ``ic(...)`` debug calls, which raised NameError
        because the icecream import is commented out in this module.
        """
        for label, value in named_values.items():
            print(f'{label}: {value}')

    def _generate_raw_data(self, b, t, y, x, c):
        """Build a random data tensor and a matching mask tensor.

        Returns:
            ``(data, mask)``: ``data`` holds uniform random values in [0, 1);
            ``mask`` is float32 and holds one ``generate_cloud`` result per
            (time step, channel), repeated over the first axis.
        """
        assert x == y, 'The images need to be squares: x = y.'
        raw_data = np.random.rand(b, t, y, x, c)  # everything between 0 and 1
        mask_data = np.zeros((b, t, y, x, c))
        for t_i in range(t):
            for c_i in range(c):
                # presumably generate_cloud returns a (y, x) array -- verify
                # against cropsar_px.utils.clouds.
                mask_data[:, t_i, :, :, c_i] = generate_cloud(windowsize=y)
        data_tensor = tf.convert_to_tensor(raw_data)
        mask_tensor = tf.convert_to_tensor(mask_data, dtype=tf.float32)
        return data_tensor, mask_tensor

    def run_instance_norm(self):
        """Print plain vs. masked instance normalization of the data."""
        norm = InstanceNormalization()
        masked_norm = MaskedInstanceNormalization()
        norm_data = norm(self.data)
        masked_norm_data = masked_norm(self.data, self.mask)
        self._report(data=self.data, mask=self.mask, norm_data=norm_data,
                     masked_norm_data=masked_norm_data)

    def run_layer_norm(self):
        """Print plain vs. masked layer normalization over axes 1, 2 and 3."""
        norm = LayerNormalization([1, 2, 3])
        masked_norm = MaskedLayerNormalization([1, 2, 3])
        norm_data = norm(self.data)
        masked_norm_data = masked_norm(self.data, self.mask)
        self._report(data=self.data, mask=self.mask, norm_data=norm_data,
                     masked_norm_data=masked_norm_data)

    def test_concat(self):
        """Concatenate normalized outputs with a tensor of ones.

        The results are discarded; the method only exercises the two
        concatenation calls.
        """
        from tensorflow.keras.layers import concatenate
        norm = InstanceNormalization(axis=4)
        masked_norm = MaskedInstanceNormalization(axis=4)
        norm_data = norm(self.data)
        masked_norm_data = masked_norm(
            self.data, tf.cast(self.mask, tf.float32))
        ones = tf.ones(masked_norm_data.shape)
        tf.concat([norm_data, ones], axis=-1)
        concatenate([masked_norm_data, ones], axis=-1)


if __name__ == '__main__':
    # Script entry point: build a small random data set and run the
    # concatenation check on it.
    test = MaskedNormalizationTest(b=4, t=37, y=32, x=32, c=1)
    test.test_concat()
