"""Train and evaluate a model of choice."""

from __future__ import annotations

from vito_cropsar.evaluation import evaluate
from vito_cropsar.models import InpaintingBase, TrainerBase


def main(
    model: InpaintingBase,
    trainer: TrainerBase,
    cache_tag: str | None = None,
) -> None:
    """
    Train a model with the given trainer and evaluate the result.

    Parameters
    ----------
    model : InpaintingBase
        Model to train and evaluate. This object is not used directly: the
        model that gets evaluated is the one returned by ``trainer.train()``,
        so ``trainer`` is expected to have been built around this same model.
    trainer : TrainerBase
        Trainer used to train the model; its ``train()`` return value is the
        model passed on to evaluation.
    cache_tag : str | None, optional
        Forwarded unchanged to ``evaluate`` as its ``cache_tag`` keyword.
        Defaults to ``None``.
    """
    # Train the model. Bind the result to its own name instead of reassigning
    # the `model` parameter, so it is explicit that the trainer's output (not
    # the argument) is what gets evaluated.
    trained_model = trainer.train()

    # Evaluate the trained model
    evaluate(trained_model, cache_tag=cache_tag)


if __name__ == "__main__":
    from datetime import datetime

    from vito_cropsar.data import Scaler
    from vito_cropsar.models import configuration_cnn_transformer

    BANDS_S1 = ["s1_asc_vv", "s1_des_vv", "s1_asc_vh", "s1_des_vh"]
    BANDS_S2 = ["s2_fapar", "s2_b02", "s2_b03", "s2_b04"]
    CACHE_TAG = "fapar_rgb"
    TAG = "cnn_transformer_multi_repr2"
    CONFIG = configuration_cnn_transformer

    # Create the trainer and model configurations
    trainer_cfg = CONFIG["config_trainer"](
        cache_tag=CACHE_TAG,
    )
    model_cfg = CONFIG["config_model"](
        # Prefix the tag with a second-resolution timestamp so that separate
        # runs of this script produce different tags.
        tag=f"{datetime.now().strftime('%Y%m%dT%H%M%S')}_{TAG}",
        scaler=Scaler.load(
            bands_s1=BANDS_S1,
            bands_s2=BANDS_S2,
            sample_s1=True,
        ),
    )

    # Create the model and trainer. To load a previously trained model
    # instead, use a snippet like the one below; note that it needs
    # `InpaintingCnnTransformer` and `get_models_folder`, which this script
    # does not import.
    # my_model = InpaintingCnnTransformer.load(
    #     mdl_f=get_models_folder() / "20230610T142128_cnn_transformer_multi",
    # )
    my_model = CONFIG["model"](cfg=model_cfg)
    my_trainer = CONFIG["trainer"](model=my_model, cfg=trainer_cfg)

    # Train and evaluate the model. A failure is reported through the model's
    # logger and then re-raised, so the traceback is preserved and the script
    # exits with a non-zero status instead of silently reporting success.
    try:
        main(model=my_model, trainer=my_trainer, cache_tag=CACHE_TAG)
    except Exception as exc:
        my_model.logger(exc)
        raise
