"""Run the training script."""

from __future__ import annotations

from typing import Any

from vito_crop_classification.data import load_datasets
from vito_crop_classification.training import run_training


# TODO: Add more parameters as needed:
#  - Do add parameters that can vary between training runs (e.g. different model configurations).
#  - Don't add new parameters for settings that stay fixed across runs (e.g. steps_val).
def main(
    mdl_tag: str,
    mdl_cfg: dict[str, Any],
    dataset: str = "hybrid_dataset",
    r_test: float = 0.1,  # ratio of test set over the entire dataset
    r_val: float = 0.1,  # ratio of validation set over the training dataset
    dataset_ratio: float = 1.0,  # fraction of the dataset to load
    random_state: int = 42,  # random state used to split data
    steps_max: int = int(1e9),  # To infinity and beyond
    steps_val: int = 500,
    es_tolerance: int = 5,
    es_improvement: float = 0.0,
    es_metric: str = "f1_class",
    augmentation: bool = True,
    run_analysis: bool = True,
) -> None:
    """Fully train a model, specified by its configuration.

    Loads the train/validation/test splits with ``load_datasets`` and forwards
    them, together with the model configuration and training settings, to
    ``run_training``.

    Args:
        mdl_tag: Tag identifying the model; forwarded to ``run_training``.
        mdl_cfg: Model configuration; forwarded to ``run_training`` as ``cfg``.
        dataset: Name of the dataset to load.
        r_test: Ratio of the test set over the entire dataset.
        r_val: Ratio of the validation set over the training dataset.
        dataset_ratio: Fraction of the dataset to load (presumably in (0, 1];
            the default 1.0 is assumed to load everything — confirm in
            ``load_datasets``).
        random_state: Random state used to split the data.
        steps_max: Maximum number of training steps; the default (1e9) is
            effectively unbounded.
        steps_val: Forwarded to ``run_training`` (presumably the number of
            steps between validation runs — confirm).
        es_tolerance: Forwarded to ``run_training`` (presumably the
            early-stopping patience — confirm).
        es_improvement: Forwarded to ``run_training`` (presumably the minimum
            improvement counted by early stopping — confirm).
        es_metric: Metric name forwarded to ``run_training`` as ``es_metric``.
        augmentation: Forwarded to ``run_training``.
        run_analysis: Forwarded to ``run_training``.

    Raises:
        ValueError: If ``load_datasets`` did not return a non-None value for
            any of ``df_train``, ``df_val`` or ``df_test``.
    """
    result = load_datasets(
        dataset=dataset,
        r_test=r_test,
        r_val=r_val,
        dataset_ratio=dataset_ratio,
        random_state=random_state,
    )

    # Ensure all datasets are filled. An explicit check is used instead of
    # ``assert`` so this validation still runs when Python is started with -O.
    required = ("df_train", "df_val", "df_test")
    missing = [key for key in required if result.get(key) is None]
    if missing:
        raise ValueError(f"load_datasets() returned no data for: {', '.join(missing)}")
    df_train = result["df_train"]
    df_val = result["df_val"]
    df_test = result["df_test"]

    # Run the training loop; scale_cfg may be absent, in which case None is passed.
    run_training(
        mdl_tag=mdl_tag,
        cfg=mdl_cfg,
        scale_cfg=result.get("scale_cfg"),
        data_train=df_train,
        data_val=df_val,
        data_test=df_test,
        steps_max=steps_max,
        steps_val=steps_val,
        es_metric=es_metric,
        es_tolerance=es_tolerance,
        es_improvement=es_improvement,
        augmentation=augmentation,
        run_analysis=run_analysis,
    )


if __name__ == "__main__":
    from vito_crop_classification.configs import get_transformer_cfg

    # Script entry point: train the optical hybrid transformer with its default configuration
    transformer_cfg = get_transformer_cfg()
    main(mdl_tag="transformer_optical_hybrid", mdl_cfg=transformer_cfg)
