"""Custom augmentations."""

from __future__ import annotations

from pathlib import Path
from random import randint, random

import torch

from vito_cropsar.constants import PRECISION_INT, get_data_folder
from vito_cropsar.data import apply_mask, obscure_equidistant
from vito_cropsar.data.masks import enforce_big_gap
from vito_cropsar.models import AugmentationConfig


class Augmentor:
    """
    Apply augmentations to the received data.

    The data is a dictionary of tensors whose first axis is time and whose last two
    axes are spatial. The keys ``"s1"``, ``"s2"`` and ``"mask"`` receive dedicated
    treatment (S1 band sampling, artificial clouds, NRT mimicking); every other entry
    is only cropped, flipped and rotated alongside them.
    """

    def __init__(
        self,
        p_cfg: AugmentationConfig,
        n_ts: int,
        resolution: int,
        sample_s1: bool,
        sample_mask: bool = True,
        freeze: bool = False,
        data_f: Path | None = None,
    ) -> None:
        """
        Randomised augmentation.

        Parameters
        ----------
        p_cfg : AugmentationConfig
            Augmentation configuration containing the probabilistic properties
        n_ts : int
            Number of time steps to use
        resolution : int
            Resolution of the tiles to use (square)
        sample_s1 : bool
            Whether to sample the S1 (ascending/descending) data or not
            Note: This will halve the number of S1 bands ([s1_asc_vv, s1_des_vv] --> [s1_vv])
        sample_mask : bool
            Apply sampled masking (artificial clouds and NRT cut-off) on the input data
            Note: Advised to only do this during training (to get consistent evaluation results)
        freeze : bool
            Make the cropping, flipping, rotating and S1 tie-breaking deterministic
            Note: The masking steps (see ``sample_mask``) do not consult this flag
        data_f : Path | None
            Data folder used to pull masks from, defaults to ``get_data_folder()``
        """
        self.p_cfg = p_cfg
        self.n_ts = n_ts
        self.resolution = resolution
        self.sample_s1 = sample_s1
        self.sample_mask = sample_mask
        self.freeze = freeze
        self.data_f = data_f or get_data_folder()

    def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Perform the augmentations.

        Parameters
        ----------
        data : dict[str, torch.Tensor]
            The sample to augment; the received tensors themselves are not modified

        Returns
        -------
        dict[str, torch.Tensor]
            A new dictionary holding the augmented tensors
        """
        # Shape formatting and selection
        data = self._cut_ts(data)  # S2 and target are out of sync
        data = self._cut_resolution(data)

        # The crops above are views on the caller's tensors, while the masking steps
        # below write in place (the mask through its shared-memory numpy view, the NRT
        # cut-off through index assignment). Copy the (already cropped) tensors so the
        # caller's data can never be corrupted.
        data = {k: v.clone() for k, v in data.items()}

        # Select the S1 band
        data = self._sample_s1(data)

        # Harmless augmentations
        data = self._flip_horizontal(data)
        data = self._flip_vertical(data)
        data = self._rot90(data)

        # Input data augmentations, potentially harmful to the reconstructive capabilities
        data = self._add_artificial_clouds(data)
        data = self._mimic_nrt(data)

        return data

    def _cut_ts(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Crop every series to ``n_ts`` time steps (centred when frozen, random otherwise)."""
        # Check the input sizes
        sizes = [v.shape[0] for v in data.values()]
        assert len(set(sizes)) == 1, "All time series must have the same length"
        assert (
            self.n_ts <= sizes[0]
        ), f"Requested {self.n_ts} time steps, but only {sizes[0]} are available"

        # Cut the time series using a random offset
        diff = sizes[0] - self.n_ts
        offset = diff // 2 if self.freeze else randint(0, diff)
        return {k: v[offset : offset + self.n_ts] for k, v in data.items()}

    def _cut_resolution(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Crop every tile to ``resolution`` x ``resolution`` pixels.

        The crop is centred when frozen and random otherwise. Each spatial axis uses
        its own offset range, so non-square input tiles are cropped correctly as well.
        """
        # Check the input sizes (last two axes are spatial)
        shapes = {tuple(v.shape[-2:]) for v in data.values()}
        assert len(shapes) == 1, "All time series must have the same resolution"
        height, width = shapes.pop()
        assert self.resolution <= min(height, width)

        # Crop the spatial axes using a (random) offset per axis
        diff_h = height - self.resolution
        diff_w = width - self.resolution
        offset_h = diff_h // 2 if self.freeze else randint(0, diff_h)
        offset_w = diff_w // 2 if self.freeze else randint(0, diff_w)
        return {
            k: v[
                ...,
                offset_h : offset_h + self.resolution,
                offset_w : offset_w + self.resolution,
            ]
            for k, v in data.items()
        }

    def _add_artificial_clouds(
        self, data: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """
        Add extra (artificial) cloud obscurance to the S2 data.

        Parameters
        ----------
        data : dict[str, torch.Tensor]
            The data to obscure

        Returns
        -------
        dict[str, torch.Tensor]
            The data with the extra cloud obscurance
        """
        if not self.sample_mask:
            return data

        # Add extra cloud obscurance. ``.numpy()`` shares memory with ``data["mask"]``;
        # the helpers' return values are unused, so they presumably edit it in place.
        mask = data["mask"].numpy()
        if self.p_cfg.r_cloud > 0:
            obscure_equidistant(
                mask=mask,
                ratio_mask=self.p_cfg.r_cloud,
                ratio_visible=self.p_cfg.r_visible,
                data_f=self.data_f,
            )
        if self.p_cfg.p_gap > 0:
            enforce_big_gap(
                mask=mask,
                ratio=self.p_cfg.p_gap,
                ratio_visible=self.p_cfg.r_visible,
            )

        # Transform mask to tensor and apply to the input
        mask_t = torch.tensor(mask, dtype=PRECISION_INT)
        data["s2"] = apply_mask(data["s2"], mask=mask_t)
        data["mask"] = mask_t
        return data

    def _mimic_nrt(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Mimic the NRT behaviour by cutting out the last parts of S1 and S2.

        Between 0 and ``p_cfg.nrt_max`` trailing time steps are blanked: S1 and S2
        become NaN and the mask becomes 0. The number of blanked steps is clamped to
        the available number of time steps.

        Parameters
        ----------
        data : dict[str, torch.Tensor]
            The data to obscure (modified in place)

        Returns
        -------
        dict[str, torch.Tensor]
            The data with its last time steps removed
        """
        if not self.sample_mask:
            return data

        # Select the number of time steps to cut, never more than are available
        n_steps = min(data[k].shape[0] for k in ("s1", "s2", "mask"))
        n_cut = min(randint(0, self.p_cfg.nrt_max), n_steps)
        if n_cut == 0:  # guard: a ``[-0:]`` slice would select everything
            return data

        data["s1"][-n_cut:] = torch.nan
        data["s2"][-n_cut:] = torch.nan
        data["mask"][-n_cut:] = 0
        return data

    @staticmethod
    def _count_valid_steps(x: torch.Tensor) -> int:
        """Count the entries along the first (time) axis of ``x`` that contain no NaN."""
        nan = x.isnan()
        if nan.dim() > 1:
            nan = nan.flatten(start_dim=1).any(dim=1)
        return int((~nan).sum())

    def _sample_s1(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Sample the S1 bands (ascending vs descending).

        Notes
        -----
         - This will halve the number of S1 bands ([s1_asc_vv, s1_des_vv] --> [s1_vv])
         - This assumes that ascending and descending bands are always in pairs

        Parameters
        ----------
        data : dict[str, torch.Tensor]
            The data to sample from

        Returns
        -------
        dict[str, torch.Tensor]
            The data with the sampled S1 bands and untouched other data
        """
        if not self.sample_s1:
            return data

        # Sample the S1 bands
        s1 = data["s1"]
        assert s1.shape[1] % 2 == 0, "S1 bands must come in ascending/descending pairs"
        s1_result = []
        for i in range(0, s1.shape[1], 2):
            s1_asc = s1[:, i]
            s1_des = s1[:, i + 1]

            # Check the band frequency (number of fully valid time steps)
            valid_asc = self._count_valid_steps(s1_asc)
            valid_des = self._count_valid_steps(s1_des)

            # Choose band with significant majority, sample one otherwise
            if ((valid_asc / 2) > valid_des) or (valid_des == 0):
                chosen = s1_asc
            elif ((valid_des / 2) > valid_asc) or (valid_asc == 0):
                chosen = s1_des
            elif self.freeze:  # Fixed sampling
                chosen = s1_asc if valid_asc > valid_des else s1_des
            else:  # Random sampling
                chosen = s1_asc if random() < 0.5 else s1_des  # noqa: PLR2004
            s1_result.append(chosen[:, None])

        # Stack the sampled S1 bands and return the data
        data["s1"] = torch.concat(s1_result, dim=1)
        return data

    def _flip_vertical(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Flip all tensors upside down (second-to-last axis) with ``p_flip_vertical``."""
        # ``random()`` is in [0, 1), so ``<`` yields exactly probability p (never for p=0)
        if (not self.freeze) and (random() < self.p_cfg.p_flip_vertical):
            return {k: torch.flip(v, dims=(-2,)) for k, v in data.items()}
        return data

    def _flip_horizontal(
        self, data: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Flip all tensors left-right (last axis) with ``p_flip_horizontal``."""
        # ``random()`` is in [0, 1), so ``<`` yields exactly probability p (never for p=0)
        if (not self.freeze) and (random() < self.p_cfg.p_flip_horizontal):
            return {k: torch.flip(v, dims=(-1,)) for k, v in data.items()}
        return data

    def _rot90(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Rotate all tensors in the spatial plane with probability ``p_rotate``.

        The rotation is a random multiple of 90 degrees; ``randint(0, 3)`` includes 0,
        in which case the tensors are returned unrotated (as copies).
        """
        if (not self.freeze) and (random() < self.p_cfg.p_rotate):
            rot = randint(0, 3)
            return {k: torch.rot90(v, k=rot, dims=(-2, -1)) for k, v in data.items()}
        return data
