"""Custom trainer class."""

from __future__ import annotations

from typing import Any

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard.writer import SummaryWriter

from vito_lot_delineation.models import BaseModel


class BaseTrainer:
    """
    Base class for model trainers.

    Sets up the state shared by all trainers: an Adam optimiser over the model's
    parameters, a ``ReduceLROnPlateau`` learning-rate scheduler, a TensorBoard
    ``SummaryWriter`` logging to ``model.model_folder / "logs"``, and the
    early-stopping configuration. Subclasses implement :meth:`train`.
    """

    def __init__(
        self,
        model: BaseModel,
        es_tolerance: int = 5,
        es_improvement: float = 0.05,
        es_metric: str = "loss",
        lr_tolerance: int = 3,
        lr: float = 1e-3,
    ) -> None:
        """
        Initialise the trainer.

        Parameters
        ----------
        model : BaseModel
            Model to train. Its ``parameters()`` are handed to the optimiser and
            its ``model_folder`` is joined with ``"logs"`` (via ``/``) to form the
            TensorBoard log directory.
        es_tolerance : int
            Early stopping tolerance, by default 5. Only stored here; it is
            enforced by the subclass's training loop.
        es_improvement : float
            Early stopping improvement required before the best model gets
            replaced, by default 0.05. Only stored here.
        es_metric : str
            Metric used to monitor early stopping, by default ``"loss"``.
            ``"loss"`` is treated as lower-is-better; any other metric name is
            treated as higher-is-better (the LR scheduler runs in ``"max"`` mode).
        lr_tolerance : int
            Scheduler ``patience``: number of ``scheduler.step`` calls without
            improvement before the learning rate is reduced, by default 3.
        lr : float
            Initial learning rate of the Adam optimiser, by default 1e-3.
        """
        self.model = model
        self._optimizer = Adam(self.model.parameters(), lr=lr)
        self._es_tol = es_tolerance
        self._es_improv = es_improvement
        self._es_metric = es_metric
        self._writer = SummaryWriter(log_dir=str(self.model.model_folder / "logs"))
        self._scheduler = ReduceLROnPlateau(
            optimizer=self._optimizer,
            # Loss must decrease; every other monitored metric must increase.
            mode="min" if (self._es_metric == "loss") else "max",
            patience=lr_tolerance,
            # NOTE(review): `verbose` is deprecated for LR schedulers since
            # PyTorch 2.2 (emits a warning); drop it once the minimum supported
            # torch version allows, logging the LR via get_last_lr() instead.
            verbose=True,
        )

    def train(
        self,
        steps_max: float = 1e9,
        steps_val: int = 50,
        batch_size: int = 32,
        *args: Any,
        **kwargs: Any,
    ) -> BaseModel:
        """
        Training script; must be implemented by subclasses.

        Parameters
        ----------
        steps_max : float, optional
            Maximum number of steps used to train a model, by default 1e9
        steps_val : int, optional
            Number of steps used to evaluate a model, by default 50
        batch_size : int, optional
            Batch size used during training, by default 32
        *args, **kwargs
            Extra, subclass-specific training options.

        Returns
        -------
        BaseModel
            Best model obtained during training.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
