"""Run the predictive inference script."""

from __future__ import annotations

import pandas as pd

from vito_crop_classification.inference import run_inference
from vito_crop_classification.model import Model


def main(
    df: pd.DataFrame,
    model: Model,
    patch_smoothing: bool = False,
    patch_shape: tuple[int, int] | None = None,
    transform: bool = True,
    ignore_classes: list[str] | None = None,
) -> pd.DataFrame:
    """
    Run the inference pipeline.

    Thin wrapper that forwards all of its arguments to ``run_inference``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame to make predictions on
    model : Model
        Model to use for inference
    patch_smoothing : bool, optional
        Smooth on a patch level
    patch_shape : tuple[int, int], optional
        Shape of the patch, should be provided if patch_smoothing=True
    transform : bool, optional
        Run dataframe transformation
    ignore_classes : list[str], optional
        Define classes to ignore while predicting; an empty list is used when omitted

    Returns
    -------
    prediction : pd.DataFrame
        Prediction on the provided dataframe
    """
    # Build a fresh list per call instead of using a mutable default argument,
    # so no list object is ever shared between separate invocations.
    if ignore_classes is None:
        ignore_classes = []
    return run_inference(
        df=df,
        model=model,
        patch_smoothing=patch_smoothing,
        patch_shape=patch_shape,
        transform=transform,
        ignore_classes=ignore_classes,
    )


if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns

    from vito_crop_classification.constants import get_data_folder, get_models_folder

    # Shape of the demo patch: passed to the patch smoothing and reused to
    # reshape the flat per-row outputs back into 2D images for plotting.
    my_shape = (166, 126)

    def _show_heatmap(title: str, values: np.ndarray, vmax: float) -> None:
        """Show ``values`` as a tick-less heatmap in a new figure, colour-scaled from 0 to ``vmax``."""
        plt.figure()
        plt.title(title)
        sns.heatmap(
            values,
            xticklabels=False,
            yticklabels=False,
            vmin=0,
            vmax=vmax,
        )
        plt.tight_layout()
        plt.show()

    my_model = Model.load(mdl_f=get_models_folder() / "20221102T184737-transf_optical_dem")
    my_patch = pd.read_parquet(get_data_folder() / "inference-patches/testpatch_x166_y126.parquet")

    # Make the prediction
    my_preds = main(
        df=my_patch,
        model=my_model,
        patch_smoothing=True,
        patch_shape=my_shape,
        ignore_classes=["1-5-0-0"],
    )

    # Predicted class per row, encoded as its index in the model's class-name list
    my_names = list(my_model.get_class_names())
    my_arr = np.asarray([my_names.index(x) for x in my_preds.prediction_name]).reshape(my_shape)
    _show_heatmap(f"{my_model.tag} - Predictions", my_arr, vmax=len(my_names))

    # Probability per row, plotted on a fixed 0-1 colour scale
    my_arr = np.asarray(my_preds.probability).reshape(my_shape)
    _show_heatmap(f"{my_model.tag} - Probability", my_arr, vmax=1)
