"""Custom CnnTransformer trainer."""

from __future__ import annotations

import json

from vito_cropsar.models.CnnTransformer.config import MultiTrainerConfig, TrainerConfig
from vito_cropsar.models.CnnTransformer.main import InpaintingModel
from vito_cropsar.models.CnnTransformer.trainers import (
    CnnTransformerTrainerAutoencoder,
    CnnTransformerTrainerBase,
    CnnTransformerTrainerDecoder,
    CnnTransformerTrainerFinal,
    CnnTransformerTrainerTransformer,
)
from vito_cropsar.models.utils import JSONEncoder


class MultiTrainer:
    """
    Custom CnnTransformer multi-staged trainer.

    This class trains the CnnTransformer in a multi-staged fashion:
        1. Train the encoder and decoder (autoencoder)
            Trainable: input block, encoder, decoder, output block
            Frozen: /
            Not used: transformer
        2. Train the transformer (time)
            Trainable: transformer, decoder, output block
            Frozen: input block, encoder
            Not used: /
        3. Train the image sharpening (spatial)
            Trainable: output block
            Frozen: input block, encoder, transformer, decoder
            Not used: /

    The stages that actually run, and their order, are taken from the entries
    of ``cfg.trainers``; every tag is resolved to a trainer class by
    ``_get_trainer``.
    """

    def __init__(
        self,
        model: InpaintingModel,
        cfg: MultiTrainerConfig,
    ) -> None:
        """
        Initialise the trainer.

        Parameters
        ----------
        model : InpaintingModel
            Model to train
        cfg : MultiTrainerConfig
            Multi-stage trainer configuration; it is written to
            ``config_trainer.json`` in the model's folder
        """
        self._model = model
        self.cfg = cfg

        # Save the configuration under the model's folder.
        # The encoding is pinned so the file does not depend on the locale.
        cfg_path = self._model.model_folder / "config_trainer.json"
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(self.cfg.dict(), f, cls=JSONEncoder, indent=4)

    def train(self, global_step: int = 0) -> InpaintingModel:
        """
        Train the network over the different stages.

        Parameters
        ----------
        global_step : int
            Step count to start from; every stage receives the step count
            returned by the stage before it

        Returns
        -------
        InpaintingModel
            Model returned by the last stage (the initial model when no stage
            is configured)
        """
        self.log("")
        self.log(f"Model: {self._model}")
        self.log("")

        # The same state dictionary is handed to every stage trainer
        shared_state: dict = {}
        for tag, stage_cfg in self.cfg.trainers.items():
            trainer_cls = _get_trainer(tag=tag)
            stage_trainer = trainer_cls(
                model=self._model,
                cfg=self._update_cfg(stage_cfg, global_step=global_step),
                shared_state=shared_state,
            )
            # Each stage hands back the updated step count and model
            global_step, self._model = stage_trainer.train(global_step)
        return self._model

    def log(self, msg: str) -> None:
        """Log the message using the model's logger, if exists."""
        logger = self._model.logger
        if logger is not None:
            logger(msg)

    def _update_cfg(self, cfg: TrainerConfig, global_step: int) -> TrainerConfig:
        """
        Update the config with the optional fields in the root's config.

        Parameters
        ----------
        cfg : TrainerConfig
            Configuration of a single stage
        global_step : int
            Step count at which the stage starts

        Returns
        -------
        TrainerConfig
            New stage configuration; the objects passed in are not modified
        """
        # Every root field except the per-stage mapping itself
        root_fields = {k: v for k, v in dict(self.cfg).items() if k != "trainers"}
        # Root values always take precedence over the stage's own values.
        # NOTE(review): this includes root fields that are None - confirm that
        # overriding a stage value with None is intended.
        merged = {**dict(cfg), **root_fields}
        if merged.get("steps_max") is not None:
            # Relative step: shift the stage's limit by the steps already done
            merged["steps_max"] += global_step
        return TrainerConfig(**merged)


def _get_trainer(tag: str) -> type[CnnTransformerTrainerBase]:
    """
    Get trainer based on tag.

    Parameters
    ----------
    tag : str
        Stage tag, one of "autoencoder", "transformer", "decoder" or "final"

    Returns
    -------
    type[CnnTransformerTrainerBase]
        Trainer class for the stage; the class itself, not an instance

    Raises
    ------
    ValueError
        If the tag does not match any known stage
    """
    if tag == "autoencoder":
        trainer_cls = CnnTransformerTrainerAutoencoder
    elif tag == "transformer":
        trainer_cls = CnnTransformerTrainerTransformer
    elif tag == "decoder":
        trainer_cls = CnnTransformerTrainerDecoder
    elif tag == "final":
        trainer_cls = CnnTransformerTrainerFinal
    else:
        raise ValueError(f"Unknown trainer tag: {tag}")
    return trainer_cls
