"""Train a model from its configuration, then evaluate the best resulting model."""

from __future__ import annotations

from datetime import datetime
from pathlib import Path
from typing import Any

from vito_lot_delineation.configs import get_cfg
from vito_lot_delineation.evaluation import run_evaluation
from vito_lot_delineation.models import parse_model, parse_trainer


def main(
    mdl_tag: str,
    cfg: dict[str, Any],
    mdl_f: Path | None = None,
    *,
    data_dir: Path | None = None,
    batch_size: int = 16,
    steps_val: int = 250,
) -> None:
    """
    Fully train a model, specified by its configuration, then evaluate it.

    Parameters
    ----------
    mdl_tag : str
        Name of the model. It is prefixed with a ``%Y%m%dT%H%M%S`` timestamp
        of the current (local, naive) time to form the final model tag.
    cfg : dict[str, Any]
        Parsed configuration specifying the model to use. This function reads
        ``cfg["model"]["architecture"]``, ``cfg["train"]["losses"]`` and
        ``cfg["input"]``, and passes the whole dict to the model constructor.
    mdl_f : Path | None
        Place where the model gets stored; forwarded unchanged to the model
        constructor.
    data_dir : Path | None
        Directory used for both training and evaluation. When ``None``
        (default), ``data/data`` three directories above this file is used.
    batch_size : int
        Batch size forwarded to ``trainer.train`` (default 16).
    steps_val : int
        Value forwarded as ``steps_val`` to ``trainer.train`` (default 250).
    """
    # Resolve the data directory once so training and evaluation are
    # guaranteed to use the same location.
    if data_dir is None:
        data_dir = Path(__file__).parent.parent.parent / "data/data"

    architecture = cfg["model"]["architecture"]

    # Create the model (parse_model returns the class to instantiate) and save it.
    model_cls = parse_model(architecture)
    model = model_cls(
        model_tag=f"{datetime.now().strftime('%Y%m%dT%H%M%S')}-{mdl_tag}",
        config_file=cfg,
        mdl_f=mdl_f,
    )
    model.save()

    # Train the model. The es_* arguments presumably configure early
    # stopping on the "loss" metric -- confirm against the trainer class.
    trainer_cls = parse_trainer(architecture)
    trainer = trainer_cls(
        model=model,
        loss_cfg=cfg["train"]["losses"],
        input_cfg=cfg["input"],
        es_improvement=0.0001,
        es_tolerance=10,
        es_metric="loss",
    )
    best_model = trainer.train(
        steps_val=steps_val,
        batch_size=batch_size,
        use_augmentation=True,
        data_dir=data_dir,
    )

    # Evaluate the model returned by training, using the same data directory.
    run_evaluation(best_model, data_dir=data_dir)


if __name__ == "__main__":
    # Script entry point: load the MultiHeadResUnet3D config and run the full pipeline.
    main(
        mdl_tag="MultiHeadResUnet3D_ts24_newB",
        cfg=get_cfg("MultiHeadResUnet3D"),
    )
