"""Run the training script."""

from __future__ import annotations

from typing import Any

from vito_crop_classification.data_io.loaders import load_datasets_from_cfg
from vito_crop_classification.training import run_training


# TODO: Expose more parameters as needed:
#  - Only parameters that vary between training runs (e.g. different model configurations)
#  - Not static parameters that stay fixed across runs (e.g. steps_val)
def run_base_training(
    mdl_tag: str,
    mdl_cfg: dict[str, Any],
    dataset: dict | str = "hybrid_dataset",
    dataset_cfg: str | None = None,
) -> None:
    """Fully train a model, specified by its configuration, and using default training parameters.

    Parameters
    ----------
    mdl_tag : str
        Name of the model
    mdl_cfg : dict[str,Any]
        Configuration dictionary specifying the model to use
    dataset : dict | str, optional
        Either a dictionary holding the data splits ``df_train``, ``df_val`` and ``df_test``
        (it may also hold ``scale_cfg`` and ``feat_cfg``, forwarded to the training loop),
        or the name of the dataset folder to load via ``load_datasets_from_cfg``.
    dataset_cfg : str | None, optional
        Name of the configuration file for the dataset; required when ``dataset`` is a string.

    Raises
    ------
    TypeError
        If ``dataset`` is neither a dict nor a string.
    ValueError
        If ``dataset`` is a string but ``dataset_cfg`` is not given, or if any of the
        ``df_train`` / ``df_val`` / ``df_test`` splits is missing or ``None``.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``,
    # which would let invalid input slip through to the (expensive) training loop.
    if not isinstance(dataset, (dict, str)):
        raise TypeError(
            "Dataset must be either a dict or string, if df is a Dataframe, "
            "transform it first into the right format via the load_datasets function"
        )

    if isinstance(dataset, str):
        if not dataset_cfg:
            raise ValueError("If Dataset is a string, then dataset_cfg must be specified")
        dataset, _ = load_datasets_from_cfg(data_f=dataset, data_cfg=dataset_cfg)

    # Ensure all datasets are filled; report every missing split at once
    split_keys = ("df_train", "df_val", "df_test")
    missing = [key for key in split_keys if dataset.get(key) is None]
    if missing:
        raise ValueError(f"Dataset is missing required split(s): {', '.join(missing)}")
    df_train, df_val, df_test = (dataset[key] for key in split_keys)

    # Run the training loop
    run_training(
        mdl_tag=mdl_tag,
        cfg=mdl_cfg,
        scale_cfg=dataset.get("scale_cfg"),
        feat_cfg=dataset.get("feat_cfg"),
        data_train=df_train,
        data_val=df_val,
        data_test=df_test,
        steps_max=int(1e9),
        steps_val=500,
        es_metric="loss",
        es_tolerance=6,  # Two full drops in learning rate
        es_improvement=0.0,
        augmentation=True,
        write_results=True,
        balance_by=["target_id"],
    )


if __name__ == "__main__":
    from vito_crop_classification.configs import get_cfg

    # Train the transformer classifier on the given dataset folder; the dataset is
    # loaded inside run_base_training using the given dataset configuration file.
    run_base_training(
        mdl_tag="transformer_clf",
        mdl_cfg=get_cfg("transformer_clf"),
        dataset="36ts-compDekad",
        dataset_cfg="training_splits",
    )
