"""Custom learning rate scheduler."""

from __future__ import annotations

import math
from typing import Literal

from torch.optim import Optimizer

from vito_cropsar.vito_logger import Logger


class LearningRateScheduler:
    """
    Custom learning rate scheduler.

    This scheduler implements the following:
     - Warmup (linearly increasing the learning rate from 0 to lr_max)
     - Cyclic learning rate (triangular: linearly decreasing the learning rate from lr_max to lr_min
       and back up to lr_max within each cycle)
     - Plateau learning rate (reducing lr_min and lr_max when the validation value does not improve)
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        cycle_steps: int,
        lr_min: float = 1e-4,
        lr_max: float = 1e-3,
        lr_floor: float = 1e-8,
        val_mode: Literal["min", "max"] = "min",
        plateau_improvement_ratio: float = 0.0,
        plateau_factor: float = 0.5,
        plateau_patience: int = 3,
        logger: Logger | None = None,
        verbosity: int = 1,
    ) -> None:
        """
        Initialize the scheduler.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            The optimizer whose parameter groups get their "lr" overwritten
        warmup_steps : int
            The number of (batch) steps to warm up the learning rate; 0 disables warmup
        cycle_steps : int
            The number of steps in a cycle (between two validations)
        lr_min : float
            The initial minimum learning rate
        lr_max : float
            The initial maximum learning rate
        lr_floor : float
            The minimum value lr_min and lr_max can be reduced to
        val_mode : Literal["min", "max"]
            Whether to minimize or maximize the validation value
        plateau_improvement_ratio : float
            The minimum relative improvement (fraction of |best value|) in the validation value
            to be considered an improvement
        plateau_factor : float
            The factor to reduce the learning rates (lr_min, lr_max) by when the plateau patience is reached
        plateau_patience : int
            The number of consecutive non-improving validations before reducing the learning rates
            (lr_min, lr_max) by the plateau factor
            Note: A value of 3 reduces on the 3rd consecutive non-improvement; an improvement resets the count
        logger : Logger
            Custom logging class (called with the message string) to log results and save to file
            Note: Nothing is logged if not provided
        verbosity : int
            Verbosity level:
             - 0: No logging
             - 1: Only log when there is a change in learning rate
             - 2: Log every step

        Raises
        ------
        AssertionError
            If lr_max <= lr_min, cycle_steps <= 0, or val_mode is not "min"/"max"
        """
        # Check the inputs
        assert (
            lr_max > lr_min
        ), f"lr_max ({lr_max}) must be larger than lr_min ({lr_min})"
        assert cycle_steps > 0, "cycle_steps must be larger than 0"
        assert val_mode in (
            "min",
            "max",
        ), f"val_mode must be 'min' or 'max', got {val_mode!r}"

        # Set the parameters
        self.optimizer = optimizer
        self.step_warmup, self.step_cycle = 0, 0
        # Ramp up from zero only when there is a warmup; otherwise start where every cycle
        # starts (lr_max) instead of training the first batch with a learning rate of 0.
        self.lr = 0.0 if warmup_steps > 0 else lr_max
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_floor = lr_floor
        self.warmup_steps = warmup_steps
        self.cycle_steps = cycle_steps
        self.val_mode = val_mode
        self.val_best = float("inf") if (self.val_mode == "min") else float("-inf")
        self.plateau_patience = 0  # counter of consecutive non-improvements
        self.plateau_improvement_ratio = plateau_improvement_ratio
        self.plateau_factor = plateau_factor
        self.plateau_max = plateau_patience  # threshold for the counter above
        self.logger = logger
        self.verbosity = verbosity

        # Apply the initial learning rate
        self._apply_lr()

    def __str__(self) -> str:
        """Representation of the learning rate scheduler."""
        return f"LearningRateScheduler(lr={self.lr}, warmup={self.warmup_steps}, cycle={self.step_cycle})"

    def __repr__(self) -> str:
        """Representation of the learning rate scheduler."""
        return str(self)

    def __call__(self, val: float | None = None) -> None:
        """Take a step with the scheduler (alias of self.step())."""
        self.step(val)

    def step(self, val: float | None = None) -> None:
        """
        Take a step with the scheduler.

        Parameters
        ----------
        val : float | None
            Optional validation value
            Note: If provided a "cycle update" is performed, otherwise (None) a "batch update".
            During warmup a validation value only updates the best value seen; the learning
            rate is not changed.
        """
        in_warmup = self.step_warmup < self.warmup_steps
        self._log(
            "Learning rate scheduler:",
            lvl=2 if (val is None) or in_warmup else 1,
        )

        if in_warmup:
            if val is not None:
                # No learning rate update for validation during warmup, only track the best value
                self.val_best = (
                    min(val, self.val_best)
                    if self.val_mode == "min"
                    else max(val, self.val_best)
                )
            else:
                self._warmup_step()

        # Perform a cycle update if a validation value is provided
        elif val is not None:
            self._cycle_step(val)

        # Otherwise perform a batch update
        else:
            self._batch_step()

    def get_lr(self) -> float:
        """Get the current learning rate."""
        return self.lr

    def _warmup_step(self) -> None:
        """Perform a warmup step (linear from ~zero to lr_max over warmup_steps)."""
        self.step_warmup += 1
        self.lr = self.lr_max * self.step_warmup / self.warmup_steps
        self._log(f" - Warmup step {self.step_warmup}/{self.warmup_steps}", lvl=2)
        self._apply_lr()

    def _cycle_step(self, val: float) -> None:
        """Perform a cycle step: update plateau state and restart the cycle at lr_max."""
        # Reset the steps made within a cycle
        self.step_cycle = 0

        # Reduce learning rates on plateau
        if self._improved(val):
            self.plateau_patience = 0
            self._log(
                f" - Validation improved ({self.val_best:.5f} -> {val:.5f}), resetting plateau patience",
                lvl=1,
            )
            self.val_best = val
        else:
            self.plateau_patience += 1
            self._log(
                f" - Validation did not improve ({self.val_best:.5f} -> {val:.5f}), increasing plateau patience ({self.plateau_patience}/{self.plateau_max})",
                lvl=1,
            )
            if self.plateau_patience >= self.plateau_max:
                # Both bounds are scaled and floored identically, so lr_min <= lr_max still holds
                self.lr_min = max(self.lr_floor, self.lr_min * self.plateau_factor)
                self.lr_max = max(self.lr_floor, self.lr_max * self.plateau_factor)
                self.plateau_patience = 0
                self._log(
                    f" - Plateau patience reached, reducing learning rates by {self.plateau_factor}",
                    lvl=1,
                )
                self._log(
                    f" - New learning rates: lr_min={self.lr_min} - lr_max={self.lr_max}",
                    lvl=1,
                )

        # Reset the learning rate
        self.lr = self.lr_max
        self._apply_lr()

    def _improved(self, val: float) -> bool:
        """
        Check if the validation value improved on the best value seen so far.

        An improvement must beat val_best by at least plateau_improvement_ratio * |val_best|.
        Using the absolute value keeps the margin pointing the right way for negative
        validation values (scaling a negative best by (1 +/- r) moves the threshold in the
        wrong direction). While val_best is still infinite (nothing seen yet) any value
        strictly better than it counts, which also avoids 0 * inf = nan when the ratio is 1.
        """
        if math.isfinite(self.val_best):
            margin = self.plateau_improvement_ratio * abs(self.val_best)
        else:
            margin = 0.0
        if self.val_mode == "min":
            return val < self.val_best - margin
        return val > self.val_best + margin

    def _batch_step(self) -> None:
        """Perform a batch step along the triangular cycle (lr_max -> lr_min -> lr_max)."""
        self.step_cycle += 1

        # Relative position: 1 at the cycle edges, 0 halfway through the cycle
        half = self.cycle_steps / 2
        _step = abs(half - (self.step_cycle % self.cycle_steps)) / half
        self.lr = self.lr_min + (self.lr_max - self.lr_min) * _step
        self._log(f" - Batch step {self.step_cycle}/{self.cycle_steps}", lvl=2)
        self._apply_lr()

    def _apply_lr(self) -> None:
        """Apply the learning rate to every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.lr
        self._log(f" - Updated learning rate to {self.lr}", lvl=2)

    def _log(self, msg: str, lvl: int) -> None:
        """Log if the logger is provided and the level is within the verbosity."""
        if (self.logger is not None) and (lvl <= self.verbosity):
            self.logger(msg)
