"""Training scripts."""

from __future__ import annotations

from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd

from vito_crop_classification.constants import get_models_folder
from vito_crop_classification.data_io.splits import save_splits
from vito_crop_classification.evaluation import evaluate
from vito_crop_classification.filters.base import BaseFilter
from vito_crop_classification.inference import Predictor
from vito_crop_classification.model import Model
from vito_crop_classification.training.trainer import Trainer


def main(
    mdl_tag: str,
    cfg: dict[str, Any],
    scale_cfg: dict[str, tuple[float, float]],
    data_train: pd.DataFrame,
    data_val: pd.DataFrame,
    data_test: pd.DataFrame,
    feat_cfg: dict[str, Any] | None = None,
    mdl_f: Path | None = None,
    batch_size: int = 512,
    steps_max: int = int(1e9),  # To infinity and beyond  # noqa: B008
    steps_val: int = 500,
    es_tolerance: int = 5,
    es_improvement: float = 0.0,
    es_metric: str = "f1_class",
    augmentation: bool = True,
    write_results: bool = False,
    balance_by: list[str] | None = None,
) -> None:
    """
    Train an encoder and classification pair, then evaluate it on the test set.

    The freshly built model and the train/val/test splits are saved (into the
    model's folder) before training starts. After training, the best model returned
    by the trainer is used to predict on ``data_test`` and the evaluation results are
    written to the ``evaluation`` sub-folder of that model's folder.

    Parameters
    ----------
    mdl_tag : str
        Name of the model; prefixed with a ``%Y%m%dT%H%M%S`` timestamp to form the model tag
    cfg : dict[str, Any]
        Configuration specifying the model to use; must at least contain the
        ``train_type`` and ``loss`` keys
    scale_cfg : dict[str, tuple[float, float]]
        Scaling configuration used to transform the provided data
    data_train : pd.DataFrame
        Dataset used to train the model; must contain ``target_id`` and
        ``target_name`` columns, from which the model's classes are derived
    data_val : pd.DataFrame
        Dataset used to validate the model during training
    data_test : pd.DataFrame
        Dataset used to evaluate the model; must contain a ``target_id`` column
    feat_cfg : dict[str, Any] | None, optional
        Feature configuration used to generate the data, by default None
    mdl_f : Path | None, optional
        Folder where the model gets stored, by default ``get_models_folder()``
    batch_size : int, optional
        Batch size used during training, by default 512
    steps_max : int, optional
        Maximum number of training steps the model performs, by default 1e9
        (effectively unbounded; early stopping is expected to end training)
    steps_val : int, optional
        Number of training steps the model performs before validating, by default 500
    es_tolerance : int, optional
        Early stopping tolerance, by default 5
    es_improvement : float, optional
        Early stopping improvement before best model gets replaced, by default 0.0
    es_metric : str, optional
        Metric used to monitor early stopping, by default "f1_class"
        Options: loss, f1_sample, f1_class
    augmentation : bool, optional
        Decide either to use augmentation during training or not, by default True
    write_results : bool, optional
        Whether to write a results.parquet file of the test set to the model folder
        NOTE: This file is large, by default False
        NOTE: the streamlit application uses the results data to analyse a model!
    balance_by : list[str] | None, optional
        Balance training by the specified list of columns, by default ["target_id"]

    Raises
    ------
    ValueError
        If ``data_train`` contains no rows, so no classes can be derived from it
    """
    # Resolve defaults. A fresh list per call avoids sharing one mutable default
    # list object between calls (it is handed on to the trainer below).
    balance_by = ["target_id"] if balance_by is None else balance_by
    if mdl_f is None:
        mdl_f = get_models_folder()

    # Derive the sorted, unique (class id, class name) pairs from the training data
    class_pairs = sorted(set(zip(data_train.target_id, data_train.target_name)))
    if not class_pairs:
        raise ValueError("Cannot train a model: 'data_train' contains no samples.")
    class_ids, class_names = zip(*class_pairs)

    # Compose the model
    model = Model(
        model_tag=f"{datetime.now().strftime('%Y%m%dT%H%M%S')}-{mdl_tag}",
        config_file=cfg,
        class_ids=np.asarray(class_ids),
        class_names=np.asarray(class_names),
        scale_cfg=scale_cfg,
        mdl_f=mdl_f,
    )
    model.build_model()

    # Save the untrained model and the data splits next to it
    model.save()
    save_splits(
        df_train=data_train,
        df_val=data_val,
        df_test=data_test,
        data_f=model.model_folder,
        data_cfg="splits",
        scale_cfg=scale_cfg,
        feat_cfg=feat_cfg,
    )

    # TODO: Make filter part of config file?
    filterer = BaseFilter(classes=class_ids, filter_f=model.model_folder / "filters")

    # Train the model; the trainer returns the best model it found
    trainer = Trainer(
        model,
        train_type=cfg["train_type"],
        loss_type=cfg["loss"],
        es_tolerance=es_tolerance,
        es_improvement=es_improvement,
        es_metric=es_metric,
    )
    best_model = trainer.train(
        data_train=data_train,
        data_val=data_val,
        steps_max=steps_max,
        steps_val=steps_val,
        batch_size=batch_size,
        augmentation=augmentation,
        balance_by=balance_by,
    )

    # Evaluate the best model's performance on the held-out test set
    predictor = Predictor(model=best_model, filterer=filterer)
    preds = predictor(data_test, transform=False, allow_ts_cut=True)
    # Predictions and targets are compared positionally, so this assumes `preds`
    # rows are in the same order as `data_test` rows.
    evaluate(
        y_pred=preds["prediction_id"].to_numpy(),
        y_true=data_test["target_id"].to_numpy(),
        class_ids=best_model.get_class_ids(),
        class_names=best_model.get_class_names(),
        save_dir=best_model.model_folder / "evaluation",
        logger=best_model.logger,
        title=best_model.tag,
    )

    # Write away the test results
    if write_results:
        save_dir = best_model.model_folder / "predictions"
        save_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): pd.concat(axis=1) aligns on the index; this assumes `preds`
        # shares `data_test`'s index (otherwise rows are misaligned / NaN-padded).
        results = pd.concat([data_test, preds], axis=1)
        results.to_parquet(save_dir / "results.parquet")
