"""Augmentation class."""

from __future__ import annotations

from random import randint, random, sample

import torch


class Augmentor:
    """Apply randomised augmentations to a dictionary of tensors.

    Spatial augmentations (flips, rotation, roll) are applied to every tensor
    in the dictionary over its last two dimensions. Intensity augmentations
    (vertical shift, time-stamp masking, circles) only modify
    ``data["input"]``. The time cut is applied to every tensor that has at
    least three dimensions.
    """

    def __init__(
        self,
        p_vflip: float = 0.0,
        p_hflip: float = 0.0,
        p_rotate: float = 0.0,
        p_tcut: float = 0.5,
        p_roll: float = 0.0,
        p_vshift: float = 0.0,
        p_mask_ts: float = 0.0,
        p_circles: float = 0.0,
        n_ts: int = 16,
    ) -> None:
        """
        Randomised augmentation.

        Parameters
        ----------
        p_vflip : float
            Probability of vertical flip, default 0.0
        p_hflip : float
            Probability of horizontal flip, default 0.0
        p_rotate : float
            Probability of rotating by a random multiple of 90 degrees
            (the multiple may be 0), default 0.0
        p_tcut : float
            Probability of cutting the time window at a random offset;
            otherwise the centered window is taken, default 0.5
        p_roll : float
            Probability of roll, default 0.0
        p_vshift : float
            Probability of vertical (intensity) shift, default 0.0
        p_mask_ts : float
            Probability of masking random time stamps, default 0.0
        p_circles : float
            Probability of applying random circles, default 0.0
        n_ts : int
            Number of time stamps kept by the time cut, default 16
        """
        self.p_vflip = p_vflip
        self.p_hflip = p_hflip
        self.p_rotate = p_rotate
        self.p_tcut = p_tcut
        self.p_roll = p_roll
        self.p_vshift = p_vshift
        self.p_mask_ts = p_mask_ts
        self.p_circles = p_circles
        self.n_ts = n_ts

    def augment(
        self,
        data: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Apply every augmentation with its configured probability.

        The entries of ``data`` are replaced and the same dictionary is
        returned. The time cut is always applied last: at a random offset
        with probability ``p_tcut``, otherwise centered.
        """
        # random() is in [0, 1): with ``<`` a probability of 0 never fires
        # and a probability of 1 always fires.
        if random() < self.p_vflip:
            data = self.vflip(data)
        if random() < self.p_hflip:
            data = self.hflip(data)
        if random() < self.p_rotate:
            data = self.rotate(data)
        if random() < self.p_roll:
            data = self.roll(data)
        if random() < self.p_vshift:
            data = self.vshift(data)
        if random() < self.p_mask_ts:
            data = self.mask_ts(data)
        if random() < self.p_circles:
            data = self.circles(data)
        # The random cut is the augmentation governed by ``p_tcut``; the
        # centered cut is the non-augmented fallback.
        data = self.tscut(data, center=random() >= self.p_tcut)
        return data

    def vflip(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Flip every tensor along its second-to-last dimension."""
        for k in data:
            data[k] = data[k].flip(dims=[-2])
        return data

    def hflip(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Flip every tensor along its last dimension."""
        for k in data:
            data[k] = data[k].flip(dims=[-1])
        return data

    def rotate(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Rotate every tensor by the same random number (0-3) of quarter turns."""
        rotate_idx = randint(0, 3)
        for k in data:
            data[k] = data[k].rot90(k=rotate_idx, dims=[-2, -1])
        return data

    def tscut(
        self, data: dict[str, torch.Tensor], center: bool = True
    ) -> dict[str, torch.Tensor]:
        """Cut a window of ``n_ts`` time stamps out of every time series.

        NOTE: Time dimension is assumed to be the third to last dimension;
        tensors with fewer than three dimensions are left untouched.

        Parameters
        ----------
        data : dict[str, torch.Tensor]
            Tensors to cut; the series length is read from ``data["input"]``.
        center : bool
            If True take the centered window, otherwise a random one.

        Raises
        ------
        ValueError
            If ``data["input"]`` has fewer than ``n_ts`` time stamps.
        """
        ts_len = data["input"].shape[-3]  # B, T, W, H
        slack = ts_len - self.n_ts
        if slack < 0:
            msg = (
                f"Time series has {ts_len} time stamps, "
                f"fewer than the {self.n_ts} required by the time cut."
            )
            raise ValueError(msg)
        # NOTE(review): the random start is drawn from [0, slack // 2], so a
        # random window never starts later than the centered one; confirm
        # this bias is intended rather than [0, slack].
        start = slack // 2 if center else randint(0, slack // 2)
        end = start + self.n_ts
        for k in data:
            if data[k].dim() >= 3:  # if time dimension is present # noqa: PLR2004
                data[k] = data[k][..., start:end, :, :]
        return data

    def vshift(
        self, data: dict[str, torch.Tensor], max_shift: float = 0.2
    ) -> dict[str, torch.Tensor]:
        """Add one random offset, of magnitude below ``max_shift`` and random sign, to the input."""
        shift = random() * max_shift
        shift = -shift if random() < 0.5 else shift  # noqa: PLR2004
        data["input"] = data["input"] + shift
        return data

    def roll(
        self, data: dict[str, torch.Tensor], max_shift: int = 40
    ) -> dict[str, torch.Tensor]:
        """Roll every tensor (with wrap-around) by the same random offsets.

        The offsets along the last two dimensions each have magnitude below
        ``max_shift`` and an independent random sign.
        """
        h_shift = int(random() * max_shift)
        v_shift = int(random() * max_shift)
        h_shift = -h_shift if random() < 0.5 else h_shift  # noqa: PLR2004
        v_shift = -v_shift if random() < 0.5 else v_shift  # noqa: PLR2004

        for k in data:
            data[k] = torch.roll(data[k], shifts=(v_shift, h_shift), dims=(-2, -1))

        return data

    def mask_ts(
        self, data: dict[str, torch.Tensor], max_mask: float = 0.2
    ) -> dict[str, torch.Tensor]:
        """Zero out randomly chosen time stamps of the input.

        At least one time stamp is masked, and about ``max_mask`` of them at
        most. The time dimension is the third to last one, as in ``tscut``.
        The input is cloned first, so the caller's tensor is not modified.
        """
        inp = data["input"].clone()
        n_steps = inp.shape[-3]
        n_mask = max(1, round(random() * max_mask * n_steps))
        indexes = sample(range(n_steps), n_mask)
        inp[..., indexes, :, :] = 0
        data["input"] = inp
        return data

    def circles(
        self,
        data: dict[str, torch.Tensor],
        max_circles: int = 5,
        max_radius: int = 80,
        max_shift: float = 0.2,
    ) -> dict[str, torch.Tensor]:
        """Add random offsets to the input inside randomly placed circles.

        Between 1 and ``max_circles`` circles are drawn. Each circle has:

        - a random center within the spatial extent;
        - radius ``max(5, U * max_radius)``, where ``U`` is uniform in [0, 1);
        - one random time stamp;
        - a shift of magnitude below ``max_shift`` with random sign.

        The shift is added to the pixels of that time stamp that lie inside
        the circle.

        The spatial size ``(w, h)`` is read from ``data["extent"]``. It must
        match the last two dimensions of ``data["input"]``; the time
        dimension is the third to last one. The input is cloned, so the
        caller's tensor is not modified in place.
        """
        inp = data["input"].clone()
        w, h = data["extent"].shape
        # ``rows`` indexes dim -2 (size w), ``cols`` indexes dim -1 (size h).
        rows, cols = torch.meshgrid(
            torch.arange(w, device=inp.device),
            torch.arange(h, device=inp.device),
            indexing="ij",
        )
        for _ in range(randint(1, max(1, max_circles))):
            # prepare random circle
            center_row = random() * w
            center_col = random() * h
            radius = max(5, random() * max_radius)
            ts = int(random() * inp.shape[-3])
            shift = random() * max_shift
            shift = -shift if random() < 0.5 else shift  # noqa: PLR2004

            # apply circle: each center coordinate is compared with the
            # matching axis, which also holds for non-square extents
            inside = (rows - center_row) ** 2 + (cols - center_col) ** 2 <= radius**2
            inp[..., ts, inside] += shift
        data["input"] = inp
        return data
