"""Augmentation class."""

from __future__ import annotations

import random
import warnings

import numpy as np
import torch
from tsaug import Convolve, TimeWarp


class Augmentor:
    """Augmentor class."""

    def __init__(
        self,
        pr_merging: float = 0.2,
        pr_shifting: float = 0.8,
        pr_convolving: float = 0.0,
        pr_timewarping: float = 0.0,
    ) -> None:
        """Initialize Augmentor.

        NOTE: Sum of all probabilities must be <= 1.0!

        Parameters
        ----------
        pr_merging : float, optional
            Merging probability, by default 0.3
        pr_shifting : float, optional
            Shifting probability, by default 0.7
        pr_convolving : float, optional
            Convolving probability, by default 0.0
        pr_timewarping : float, optional
            TimeWarping probability, by default 0.0
        """
        self._aug_functions = [
            self._apply_identity,
            self.apply_merging,
            self.apply_shifting,
            self.apply_convolving,
            self.apply_timewarping,
        ]
        self._pr_weights = [
            pr_merging,
            pr_shifting,
            pr_convolving,
            pr_timewarping,
        ]
        assert sum(self._pr_weights) == 1.0, "Sum of augmentation probabilities must be <= 1"
        self._pr_weights = [1.0 - sum(self._pr_weights)] + self._pr_weights

    def augment(
        self, s1: torch.Tensor | np.ndarray, s2: torch.Tensor | np.ndarray | None = None
    ) -> torch.Tensor | np.ndarray:
        """
        Apply all augmentations available according to initialized probabilities.

        Augmentation order:
            - merging (if s2 is not None)
            - shifting
            - noising
            - timewarping
        """
        if type(s1) == np.ndarray:  # concatenation
            augs = [self.augment(ts1, ts2) for ts1, ts2 in zip(s1, s2)]
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", FutureWarning)
                augs = np.asarray(augs, dtype=object)
            return augs
        else:
            if len(s1.shape) == 1:
                return s1
            aug_i = random.choices(
                np.arange(len(self._pr_weights)),
                weights=self._pr_weights,
            )[0]
            s1 = self._aug_functions[aug_i](s1=s1, s2=s2)
        return s1

    def _apply_identity(self, s1: torch.Tensor, **kwargs):
        """Apply no augmentation."""
        return s1

    def apply_merging(
        self,
        s1: torch.Tensor,
        s2: torch.Tensor,
        max_ratio: float = 0.1,
    ) -> torch.Tensor:
        """
        Merge 2 timeseries according to a weight.

        Parameters
        ----------
        s1 : torch.Tensor
            First sample
        s2 : torch.Tensor
            Second sample
        max_ratio : float, optional
            Max weight to assign to the second sample. First sample weight is assigned as 1 - rtio, by default 0.25

        Returns
        -------
        torch.Tensor
            Merged sample
        """
        ratio = random.random() * max_ratio
        s1 = ((1 - ratio) * s1) + (ratio * s2)
        return s1

    def apply_shifting(self, s1: torch.Tensor, max_shift: float = 0.2, **kwargs) -> torch.Tensor:
        """
        Shift a timeseries according to a max_shift.

        All shifts are random in intensitya and in direction.
        Either all time series are shifted up or down.

        Parameters
        ----------
        s1 : torch.Tensor
            First sample
        max_shift : float, optional
            Maximum shift applied to timeseries.

        Returns
        -------
        torch.Tensor
            Shifted sample
        """
        side = 1 if random.random() > 0.5 else -1
        shift = random.random() * max_shift * side
        s1 += shift
        return s1

    def apply_convolving(
        self, s1: torch.Tensor, window: str = "hamming", size: int = 5, **kwargs
    ) -> torch.Tensor:
        """
        Apply convolving to all the time series of a sample.

        Parameters
        ----------
        s1 : torch.Tensor
            First sample
        window : str, optional
            type of convolution kernel to use, by default "hamming".
        size : int, optional
            size if window used to convolve, by default 5

        Returns
        -------
        torch.Tensor
            Convolved sample
        """
        seed = random.randint(0, 1000)
        augmentor = Convolve(window=window, size=size, seed=seed)

        if s1.size(1) == 1:
            s1 = np.expand_dims(augmentor.augment(s1.flatten().numpy()), 1)
        else:
            s1 = augmentor.augment(s1.numpy())
        return torch.Tensor(s1)

    def apply_timewarping(
        self, s1: torch.Tensor, n_changes: int = 1, max_speed: float = 1.1, **kwargs
    ) -> torch.Tensor:
        """
        Apply timewarping to all the time series of a sample.

        Parameters
        ----------
        s1 : torch.Tensor
            First sample
        n_changes : int, optional
            Number of timewarps to apply, by default 3
        max_speed : float, optional
            Speed modulation of the timewarp. Higher numbers gives stronger timewarps, by default 3.0

        Returns
        -------
        torch.Tensor
            Timewarped sample
        """
        seed = random.randint(0, 1000)
        speed = random.random() * (max_speed - 1) + 1 + 1e-19
        augmentor = TimeWarp(n_changes, speed, seed=seed)

        if s1.size(1) == 1:
            s1 = np.expand_dims(augmentor.augment(s1.flatten().numpy()), 1)
        else:
            s1 = augmentor.augment(s1.numpy())
        return torch.Tensor(s1)


if __name__ == "__main__":
    # Smoke run: augment one random (2, 18) sample, with a second one available for merging.
    sample_a, sample_b = torch.rand(size=(2, 18)), torch.rand(size=(2, 18))
    Augmentor().augment(sample_a, sample_b)
