"""Augmentation class."""

from __future__ import annotations

import warnings
from random import randint, random

import numpy as np
import torch
from tsaug import Convolve


class Augmentor:
    """Randomly augment time-series samples.

    Every sample is first cut to ``seq_len`` steps along its first (time) axis;
    each augmentation is then applied independently with its own probability.
    Augmentations build new tensors and never modify the caller's data in place.
    """

    def __init__(
        self,
        p_merging: float = 0.2,
        p_v_shifting: float = 0.5,
        p_convolving: float = 0.0,
        p_timewarping: float = 0.0,
        max_h_shift: int = 1,
        seq_len: int = 24,
    ) -> None:
        """Initialize Augmentor.

        Parameters
        ----------
        p_merging : float
            Probability of blending the sample with a second sample.
        p_v_shifting : float
            Probability of vertically shifting the sample.
        p_convolving : float
            Probability of smoothing the sample with a convolution window.
        p_timewarping : float
            Probability of time-warping the sample.
        max_h_shift : int
            Maximum random horizontal offset of the cut window.
        seq_len : int
            Length of the time series after cutting.
        """
        self.p_merging = p_merging
        self.p_v_shifting = p_v_shifting
        self.p_convolving = p_convolving
        self.p_timewarping = p_timewarping
        self.max_h_shift = max_h_shift
        self.seq_len = seq_len

    def augment(
        self,
        s1: torch.Tensor | np.ndarray,
        s2: torch.Tensor | np.ndarray | None = None,
    ) -> torch.Tensor | np.ndarray:
        """
        Apply all augmentations available according to initialized probabilities.

        Augmentation order:
            - horizontal cut to ``seq_len`` (always)
            - merging (if s2 is not None)
            - vertical shifting
            - convolving
            - timewarping

        Parameters
        ----------
        s1 : torch.Tensor or np.ndarray
            Sample to augment. A numpy array is treated as a batch and each of
            its elements is augmented separately.
        s2 : torch.Tensor or np.ndarray, optional
            Sample (or batch matching ``s1``) to merge with, by default None.

        Returns
        -------
        torch.Tensor or np.ndarray
            Augmented sample, or an object array of augmented samples when
            ``s1`` is a numpy array. 1-D tensors are returned unchanged.
        """
        if isinstance(s1, np.ndarray):  # batch: augment element by element
            # Without a merge partner, pair every element with None instead of
            # letting zip() fail on a None second argument.
            partners = [None] * len(s1) if s2 is None else s2
            augs = [self.augment(ts1, ts2) for ts1, ts2 in zip(s1, partners)]
            with warnings.catch_warnings():
                # FutureWarnings raised while building the object array are silenced.
                warnings.simplefilter("ignore", FutureWarning)
                augs = np.asarray(augs, dtype=object)
            return augs
        if len(s1.shape) == 1:
            return s1

        # cut ts to seq_len
        s1, s2 = self.apply_h_cut(s1=s1, s2=s2, seq_len=self.seq_len)

        # random() is in [0, 1), so a strict "<" fires with probability exactly p
        # (and never when p == 0).
        if (s2 is not None) and (random() < self.p_merging):  # noqa: S311
            s1 = self.apply_merging(s1=s1, s2=s2)
        if random() < self.p_v_shifting:  # noqa: S311
            s1 = self.apply_shifting(s1=s1)
        if random() < self.p_convolving:  # noqa: S311
            s1 = self.apply_convolving(s1=s1)
        if random() < self.p_timewarping:  # noqa: S311
            s1 = self.apply_timewarping(s1=s1)
        return s1

    def _apply_identity(self, s1: torch.Tensor, **kwargs):
        """Apply no augmentation."""
        return s1

    def apply_h_cut(
        self,
        s1: torch.Tensor,
        s2: torch.Tensor | None = None,
        seq_len: int = 24,
        center: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Cut the time series to a window of ``seq_len`` steps along axis 0.

        The window is centred and, unless ``center`` is True and when the series
        is at least ``seq_len + 4`` steps long, randomly offset by up to
        ``self.max_h_shift`` steps. The window is always kept inside the series.
        A series shorter than ``seq_len`` is returned whole.

        Parameters
        ----------
        s1 : torch.Tensor
            First sample
        s2 : torch.Tensor, optional
            Second sample, cut with the same window, by default None
        seq_len : int, optional
            Length of the output time series, by default 24
        center : bool, optional
            Whether to cut the time series in the center or not, by default False

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor | None]
            Cut first sample and cut second sample (None if ``s2`` is None).
        """
        n_steps = s1.shape[0]
        # centred start; negative when the series is shorter than seq_len
        start = (n_steps - seq_len) // 2
        if not center and n_steps >= seq_len + 4:
            start = (
                start - randint(0, self.max_h_shift)  # noqa: S311
                if random() < 0.5  # noqa: S311, PLR2004
                else start + randint(0, self.max_h_shift)  # noqa: S311
            )
        # Clamp into [0, n_steps - seq_len]: a negative start would slice from
        # the end, and a start too far right would return fewer than seq_len steps.
        start = min(max(start, 0), max(n_steps - seq_len, 0))
        end = start + seq_len
        return s1[start:end], (s2[start:end] if s2 is not None else None)

    def apply_merging(
        self,
        s1: torch.Tensor,
        s2: torch.Tensor,
        max_ratio: float = 0.1,
    ) -> torch.Tensor:
        """
        Merge 2 timeseries according to a weight.

        Parameters
        ----------
        s1 : torch.Tensor
            First sample
        s2 : torch.Tensor
            Second sample
        max_ratio : float, optional
            Max weight to assign to the second sample. The first sample gets
            weight ``1 - ratio``, by default 0.1

        Returns
        -------
        torch.Tensor
            Merged sample
        """
        ratio = random() * max_ratio  # noqa: S311
        return ((1 - ratio) * s1) + (ratio * s2)

    def apply_shifting(
        self,
        s1: torch.Tensor,
        max_shift: float = 0.2,
        **kwargs,
    ) -> torch.Tensor:
        """
        Shift a timeseries vertically by a random offset.

        The offset is random in intensity and direction, and the same offset is
        added to every value, so all series of the sample move up or down together.

        Parameters
        ----------
        s1 : torch.Tensor
            First sample
        max_shift : float, optional
            Maximum absolute shift, by default 0.2

        Returns
        -------
        torch.Tensor
            Shifted sample (a new tensor; ``s1`` is left untouched)
        """
        side = 1 if random() > 0.5 else -1  # noqa: S311, PLR2004
        shift = random() * max_shift * side  # noqa: S311
        # Out-of-place add: s1 is often a view of the caller's data (see
        # apply_h_cut), so "+=" would corrupt the original sample.
        return s1 + shift

    def apply_convolving(
        self,
        s1: torch.Tensor,
        window: str = "hamming",
        size: int = 5,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply convolving to all the time series of a sample.

        Each column of ``s1`` is treated as one time series and is convolved
        along the time axis (axis 0) with a ``tsaug.Convolve`` window.

        Parameters
        ----------
        s1 : torch.Tensor
            Sample with time along axis 0 and one series per column.
        window : str, optional
            Type of convolution kernel to use, by default "hamming".
        size : int, optional
            Size of the window used to convolve, by default 5

        Returns
        -------
        torch.Tensor
            Convolved sample with the same shape as ``s1``, built with
            ``torch.Tensor`` (default float dtype) on ``s1``'s device.
        """
        seed = randint(0, 1000)  # noqa: S311
        augmentor = Convolve(window=window, size=size, seed=seed)
        values = s1.detach().cpu().numpy()
        # tsaug reads 2-D input as (n_series, n_timestamps), so time must be the
        # last axis; passing (time, series) directly would convolve across series.
        convolved = augmentor.augment(np.ascontiguousarray(values.T)).T
        return torch.Tensor(np.ascontiguousarray(convolved)).to(s1.device)

    def apply_timewarping(
        self,
        s1: torch.Tensor,
        n_changes: int = 3,
        max_speed: float = 1.05,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply a minimal timewarp to all the time series of a sample.

        The whole sample is shifted by -1, 0 or +1 steps along axis 0; the
        vacated edge step is filled by repeating its neighbour.

        Parameters
        ----------
        s1 : torch.Tensor
            First sample
        n_changes : int, optional
            Currently unused; kept for API compatibility, by default 3
        max_speed : float, optional
            Currently unused; kept for API compatibility, by default 1.05

        Returns
        -------
        torch.Tensor
            Timewarped sample (a new tensor; ``s1`` is left untouched)
        """
        shift, n_ts = randint(-1, 1), s1.shape[0]  # noqa: S311
        # Write into a copy and read from the original: no in-place mutation of
        # the caller's data and no overlapping source/destination slices.
        warped = s1.clone()
        warped[max(0, shift) : min(n_ts + shift, n_ts)] = s1[
            max(0, -shift) : min(n_ts - shift, n_ts)
        ]
        return warped


if __name__ == "__main__":
    # Smoke run: draw two random (2, 18) samples and augment the pair twice.
    sample_a = torch.rand(size=(2, 18))
    sample_b = torch.rand(size=(2, 18))
    demo = Augmentor()
    for _run in range(2):
        demo.augment(sample_a, sample_b)
