"""Base trainer."""

from __future__ import annotations

import traceback
from typing import Any

from torch.optim import Adam

from vito_cropsar.models.base_config import TrainerBaseConfig
from vito_cropsar.models.base_model import InpaintingBase
from vito_cropsar.models.shared import TensorBoardLogger


class TrainerBase:
    """
    Base trainer class.

    Holds the model to train, the trainer configuration, an Adam optimizer
    over the model's parameters and a TensorBoard logger. Attributes that are
    not found on the trainer itself are looked up on ``cfg``, so configuration
    values can be read as ``trainer.lr`` or ``trainer["lr"]``.

    Subclasses are expected to implement :meth:`train`.
    """

    def __init__(
        self,
        model: InpaintingBase,
        cfg: TrainerBaseConfig,
    ) -> None:
        """
        Initialise the trainer.

        Parameters
        ----------
        model : InpaintingBase
            Model to train
        cfg : TrainerBaseConfig
            Trainer configuration
        """
        self._model = model
        self.cfg = cfg
        self._optimizer = Adam(self._model.parameters(), lr=self.cfg.lr)

        # initialise the TensorBoard logger from the model's settings
        self._tensorboard = TensorBoardLogger(
            mdl_f=self._model.model_folder,
            n_plots=self.cfg.n_plots,
            scaler=self._model.scaler,
            fill_nan=self._model.fill_nan,
        )

    def __getitem__(self, key: str) -> Any:
        """
        Get an attribute by name using subscript syntax.

        Parameters
        ----------
        key : str
            Name of the attribute (looked up on the trainer, then on ``cfg``)

        Returns
        -------
        Any
            Value of the attribute

        Raises
        ------
        AttributeError
            If neither the trainer nor its configuration has the attribute
        """
        return getattr(self, key)

    def __getattribute__(self, attr: str) -> Any:
        """
        Get the attribute, falling back to the trainer configuration.

        Parameters
        ----------
        attr : str
            Name of the attribute

        Returns
        -------
        Any
            Value found on the trainer or, failing that, on ``cfg``

        Raises
        ------
        AttributeError
            If neither the trainer nor its configuration has the attribute
        """
        try:
            return super().__getattribute__(attr)
        except AttributeError as err:
            # Never delegate the lookup of `cfg` itself: if it is not set yet
            # (e.g. on a partially initialised instance), reading `self.cfg`
            # below would re-enter this fallback and recurse without bound.
            if attr == "cfg":
                raise
            try:
                return getattr(self.cfg, attr)
            except AttributeError:
                # Report the failure against the trainer, not the config.
                raise err from None

    def train(self) -> None:
        """
        Training script.

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide the implementation
        """
        raise NotImplementedError

    def recover(self) -> InpaintingBase:
        """
        Reload the model from its model folder.

        Returns
        -------
        InpaintingBase
            Result of the model class's ``load`` called on the model folder
            (presumably the best saved version; depends on ``load``)
        """
        return self._model.__class__.load(self._model.model_folder)

    def log(self, msg: str) -> None:
        """
        Log the message using the model's logger, if it exists.

        Parameters
        ----------
        msg : str
            Message to log; printed to stdout when the model has no logger
        """
        if self._model.logger is None:
            print(msg)
        else:
            self._model.logger(msg)

    def log_exception(self, e: Exception) -> None:
        """
        Log an exception and its traceback using the model's logger, if it exists.

        Two messages are emitted: a one-line summary followed by the formatted
        traceback.

        Parameters
        ----------
        e : Exception
            Exception to report
        """
        msg_e = f"Error during forward pass: {e}"
        # Positional arguments: the `etype` keyword was removed in Python 3.10.
        msg_t = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        self.log(msg_e)
        self.log(msg_t)
