"""Predictor class."""

from __future__ import annotations

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from tqdm import tqdm

from vito_crop_classification.data import transform_df
from vito_crop_classification.filters import BaseFilter
from vito_crop_classification.model import Model


class Predictor:
    """Run a model over a DataFrame in batches and collect its predictions.

    The model is switched to evaluation mode on construction. Calling the
    instance encodes the data, classifies the embeddings, lets the filterer
    pick the final class per sample, and returns everything as a DataFrame.
    """

    # Column order of the DataFrame returned by `__call__`.
    _OUTPUT_COLUMNS = (
        "prediction_id",
        "prediction_name",
        "probability",
        "probabilities",
        "embedding",
    )

    def __init__(
        self,
        model: Model,
        filterer: BaseFilter,
        batch_size: int = 512,
    ):
        """Initialise the predictor.

        Parameters
        ----------
        model : Model
            Model used for inference; its `eval()` method is called here.
        filterer : BaseFilter
            Object whose `filter_proba(segment, logits)` returns the predicted
            class ids and their probabilities for one batch.
        batch_size : int, optional
            Upper bound on the number of samples per batch, by default 512

        Raises
        ------
        ValueError
            If `batch_size` is smaller than one.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.model = model
        self.filterer = filterer
        self._batch_size = batch_size
        self.model.eval()

    @staticmethod
    def _nan_mask(segment: pd.DataFrame) -> NDArray[np.bool_]:
        """Return a boolean mask flagging the rows of `segment` that contain a NaN."""
        # String cells are skipped since np.isnan cannot handle them; isinstance
        # (rather than an exact type comparison) also covers str subclasses
        # such as np.str_.
        return np.fromiter(
            (
                any(np.any(np.isnan(cell)) for cell in row if not isinstance(cell, str))
                for row in segment.values
            ),
            dtype=bool,
            count=len(segment),
        )

    def __call__(
        self,
        data: pd.DataFrame,
        transform: bool = True,
        scale: bool = True,
        level: int = 3,
    ) -> pd.DataFrame:
        """Make predictions for the provided data.

        Parameters
        ----------
        data : pd.DataFrame
            Data to predict on.
        transform : bool, optional
            If True, data is transformed (using the model's scale configuration)
            before being fed to the predictor, by default True
        scale : bool, optional
            Currently not read by this method; kept for interface
            compatibility, by default True
        level : int, optional
            Currently not read by this method; kept for interface
            compatibility, by default 3

        Returns
        -------
        pd.DataFrame
            DataFrame indexed like the (transformed) data. Columns contain in order:
             - prediction_id : The id of the predicted class ("NaN" if the sample has NaN inputs)
             - prediction_name : The name of the predicted class ("NaN" if the sample has NaN inputs)
             - probability : The corresponding probability of the predicted class
             - probabilities : The probability vector of all the classes
             - embedding : The encoded (hidden) embedding of each sample
        """
        # Transform the received data, if requested
        if transform:
            data, _ = transform_df(
                data, scale_cfg=self.model.get_scale_cfg(), scale_dynamically=False
            )

        columns = list(self._OUTPUT_COLUMNS)

        # Nothing to predict: avoid feeding an empty batch to the model
        if len(data) == 0:
            return pd.DataFrame(columns=columns, index=data.index)

        # Split row positions (not the DataFrame itself) into near-equal segments.
        # np.array_split on a DataFrame depends on the deprecated
        # DataFrame.swapaxes; splitting positions yields the same segment sizes.
        n_segments = len(data) // self._batch_size + 1
        segments = [
            rows for rows in np.array_split(np.arange(len(data)), n_segments) if len(rows) > 0
        ]

        # Per-batch results, concatenated once after the loop
        emb_chunks: list[NDArray[np.float64]] = []
        prob_emb_chunks: list[NDArray[np.float64]] = []
        pred_id_chunks: list[NDArray[np.object_]] = []
        pred_name_chunks: list[NDArray[np.object_]] = []
        prob_chunks: list[NDArray[np.float64]] = []
        id2name = self.model.get_class_mapping()

        # Pure inference: no autograd graph is needed
        with torch.no_grad():
            for rows in tqdm(segments, desc="Predicting.."):
                segment = data.iloc[rows]

                # Hidden embedding
                batch_embs = self.model.enc(segment)
                emb_chunks.append(batch_embs.cpu().detach().numpy())

                # Total class probabilities
                clf_logits = self.model.clf(batch_embs)
                prob_emb_chunks.append(torch.softmax(clf_logits, dim=1).cpu().detach().numpy())

                # Best predictions and probabilities. Labels are held as object
                # arrays so that writing "NaN" is never truncated by a
                # fixed-width string dtype.
                batch_pred_ids, batch_probs = self.filterer.filter_proba(segment, clf_logits)
                batch_pred_ids = np.asarray(batch_pred_ids, dtype=object)
                batch_pred_names = np.asarray(
                    [id2name[x] for x in batch_pred_ids], dtype=object
                )

                # Mask the predictions of samples that have NaN inputs
                nan_mask = self._nan_mask(segment)
                batch_pred_ids[nan_mask] = "NaN"
                batch_pred_names[nan_mask] = "NaN"

                pred_id_chunks.append(batch_pred_ids)
                pred_name_chunks.append(batch_pred_names)
                prob_chunks.append(batch_probs)

        embs = np.concatenate(emb_chunks)
        prob_embs = np.concatenate(prob_emb_chunks)
        preds_id = np.concatenate(pred_id_chunks)
        preds_name = np.concatenate(pred_name_chunks)
        # Seeding with an empty float64 array promotes the result to at least float64
        probs = np.concatenate([np.asarray([], dtype=np.float64), *prob_chunks])

        return pd.DataFrame(
            list(zip(preds_id, preds_name, probs, prob_embs, embs)),
            columns=columns,
            index=data.index,
        )
