"""Custom learning rate scheduler."""

from __future__ import annotations

from typing import Any

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ReduceLROnPlateau

from vito_crop_classification.vito_logger import Logger


class LearningRateScheduler:
    """
    Custom learning rate scheduler.

    Wraps an (optional) PyTorch learning rate scheduler around an optimizer, steps it on a
    monitored metric, and reports its status through an optional logger.
    """

    def __init__(
        self,
        tag: str | None,
        optimizer: Optimizer,
        mode: str = "min",
        logger: Logger | None = None,
    ) -> None:
        """
        Initialise the custom learning rate scheduler.

        Parameters
        ----------
        tag : str, optional
            Type of learning rate scheduler to use
            Options: None (no scheduling), reduce_on_plateau
        optimizer : Optimizer
            Optimizer around which the learning rate scheduler is wrapped
        mode : str
            Optimisation mode of the monitored metric, forwarded to the scheduler
            Options: min, max
        logger : Logger, optional
            Logging function used to log results with

        Raises
        ------
        AssertionError
            If the tag or mode is not one of the supported options
        """
        # NOTE: assertions are stripped under `python -O`; kept (rather than raising
        # ValueError) so the exception type seen by existing callers is unchanged.
        assert tag in {None, "reduce_on_plateau"}
        assert mode in {"min", "max"}

        self._mode = mode
        self._tag = tag
        self._optimizer = optimizer
        self._scheduler = get_lr_scheduler(
            scheduler=self._tag,
            optimizer=self._optimizer,
            mode=self._mode,
        )
        self._logger = logger

    def __str__(self) -> str:
        """String representation of the learning rate scheduler."""
        return (
            f"LearningRateScheduler(scheduler={self._tag}, mode={self._mode}, lr={self.get_lr()})"
        )

    def __repr__(self) -> str:
        """String representation of the learning rate scheduler."""
        return str(self)

    def __call__(self, metric: float | torch.Tensor) -> None:
        """
        Step the learning rate scheduler on the given metric and log its status.

        Does nothing (and logs nothing) when no scheduler is configured.

        Parameters
        ----------
        metric : float | torch.Tensor
            Monitored (scalar) metric consumed by plateau-based scheduling
        """
        if self._scheduler is None:
            return

        # Only ReduceLROnPlateau consumes the monitored metric
        if isinstance(self._scheduler, ReduceLROnPlateau):
            self._scheduler.step(metric)
        self.log_status()

    def get_lr(self) -> float | None:
        """Get the learning rate of the optimizer's first parameter group (None if unset)."""
        return self._optimizer.param_groups[0].get("lr", None)

    def log_status(self) -> None:
        """Log the scheduler's current status, if both a logger and a scheduler exist."""
        if self._logger is None or self._scheduler is None:
            return
        self._logger(f"Learning rate scheduler: {self._tag}")
        self._logger(f" - Current learning rate: {self.get_lr()}")
        # Patience bookkeeping only exists on plateau-based schedulers. ReduceLROnPlateau
        # lowers the LR once num_bad_epochs *exceeds* patience, so "p/p" means the next
        # non-improving step triggers a reduction (after which the counter resets to 0).
        if isinstance(self._scheduler, ReduceLROnPlateau):
            self._logger(
                f" - Patience: {self._scheduler.num_bad_epochs}/{self._scheduler.patience}"
            )


def get_lr_scheduler(
    scheduler: str | None,
    optimizer: Any,
    mode: str,
    *,
    patience: int = 1,
) -> ReduceLROnPlateau | None:
    """
    Load in the proper learning rate scheduler.

    Parameters
    ----------
    scheduler : str, optional
        Type of learning rate scheduler to create
        Options: None, reduce_on_plateau
    optimizer : Any
        Optimizer whose learning rate is scheduled
    mode : str
        Optimisation mode of the monitored metric, forwarded to the scheduler
        Options: min, max
    patience : int
        Number of non-improving steps tolerated by ``reduce_on_plateau`` before the
        learning rate is reduced (default 1)

    Returns
    -------
    ReduceLROnPlateau | None
        The created scheduler, or None when no scheduling is requested

    Raises
    ------
    ValueError
        If the requested scheduler is not supported
    """
    if scheduler is None:
        return None
    if scheduler == "reduce_on_plateau":
        # ReduceLROnPlateau only lowers the LR once the count of non-improving steps
        # *exceeds* `patience`, so patience=1 reduces on the second consecutive bad step
        # (this is why the effective patience looks like "two" in the logs).
        return ReduceLROnPlateau(
            optimizer=optimizer,
            mode=mode,
            patience=patience,
        )
    raise ValueError(f"Learning rate scheduling '{scheduler}' not supported!")
