"""Inference scrips."""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path
from typing import Any

import pandas as pd

from vito_crop_classification.filters import BaseFilter
from vito_crop_classification.inference.patch import smooth_patch
from vito_crop_classification.inference.predictor import Predictor
from vito_crop_classification.model import Model


@lru_cache(maxsize=1)
def load_predictor(
    model: Model | None = None,
    mdl_f: Path | None = None,
    filter_cls: Any = BaseFilter,
) -> Predictor:
    """
    Load in a model used for predictions.

    Parameters
    ----------
    model : Model, optional
        Model to use for inference
    mdl_f : Path, optional
        Folder of the model to use for inference
    filter_cls: Any
        Filter class used to create the filter applied during inference

    Returns
    -------
    predictor : Predictor
        Predictor class to run inference with

    Raises
    ------
    AssertionError
        If both or neither of model and mdl_f are provided

    Note
    ----
     - Exactly one of model or mdl_f should be provided
     - The result is memoised with ``lru_cache(maxsize=1)``: calling again with the same
       arguments returns the previously created predictor, and all arguments must be hashable
    """
    # Raise explicitly instead of using an `assert` statement: asserts are stripped when
    # Python runs with -O, which would let invalid input through. The exception type and
    # message are kept identical to the former assert for backward compatibility.
    if (model is None) == (mdl_f is None):
        raise AssertionError("Either model or mdl_f should be provided!")

    # Exactly one of the two is set at this point; load from disk when a folder was given
    if mdl_f is not None:
        model = Model.load(mdl_f=mdl_f)

    return Predictor(
        model=model,
        filterer=filter_cls(classes=model.get_class_ids(), filter_f=model.model_folder / "filters"),
    )


def main(
    df: pd.DataFrame,
    model: Model | None = None,
    mdl_f: Path | None = None,
    patch_smoothing: bool = False,
    patch_shape: tuple[int, int] | None = None,
    transform: bool = True,
) -> pd.DataFrame:
    """
    Predict on a dataframe, optionally followed by patch-level smoothing.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame to make predictions on
    model : Model, optional
        Model to use for inference
    mdl_f : Path, optional
        Folder of the model to use for inference
    patch_smoothing : bool
        Whether to call the patch smoothing step on the predictions
    patch_shape : tuple[int, int], optional
        Shape of the patch, should be provided if patch_smoothing=True
    transform : bool
        Forwarded unchanged to the predictor call

    Returns
    -------
    prediction : pd.DataFrame
        Prediction on the provided dataframe

    Note
    ----
     - Either model or mdl_f should be provided
     - The return value of smooth_patch is not used; the dataframe produced by the predictor
       is what gets returned (presumably smoothing happens in place - TODO confirm)
    """
    # Pixel-level predictions come straight from the (cached) predictor
    predictor = load_predictor(model=model, mdl_f=mdl_f)
    predictions = predictor(df, transform=transform)

    # Nothing more to do when no smoothing was requested
    if not patch_smoothing:
        return predictions

    # Smooth over the patch, using the class metadata of the predictor's model
    class_ids = predictor.model.get_class_ids()
    class_names = predictor.model.get_class_names()
    smooth_patch(
        df=predictions,
        shape=patch_shape,
        class_ids=class_ids,
        class_names=class_names,
    )
    return predictions
