"""Predictor class."""

from __future__ import annotations

from logging import warning

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from torch.nn.functional import softmax

from vito_crop_classification.data_process import transform_df
from vito_crop_classification.filters.base import BaseFilter
from vito_crop_classification.model import Model


class Predictor:
    """Run a trained model, followed by a filterer, over a DataFrame of samples.

    The model's encoder (``model.enc``) produces a hidden embedding per sample, its
    classifier (``model.clf``) turns that embedding into class logits, and the
    filterer's ``filter_proba`` picks the final class id and its probability.
    """

    # Column names of the DataFrame returned by ``__call__``, in order.
    _OUTPUT_COLUMNS = (
        "prediction_id",
        "prediction_name",
        "probability",
        "probabilities",
        "embedding",
    )

    def __init__(
        self,
        model: Model,
        filterer: BaseFilter,
        batch_size: int = 512,
    ):
        """Create a predictor and switch the model to evaluation mode.

        Parameters
        ----------
        model : Model
            Trained model providing ``enc``, ``clf`` and the class metadata used for prediction.
        filterer : BaseFilter
            Object whose ``filter_proba`` selects the final class id and probability per sample.
        batch_size : int, optional
            Maximum number of samples fed to the model at once, by default 512

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.model = model
        self.filterer = filterer
        self._batch_size = batch_size
        self.model.eval()

    def __call__(
        self,
        data: pd.DataFrame,
        transform: bool = True,
        ignore_classes: list[str] | None = None,
        allow_ts_cut: bool = False,
    ) -> pd.DataFrame:
        """Make predictions for the provided data.

        Parameters
        ----------
        data : pd.DataFrame
            Data to predict on. It is never modified in place.
        transform : bool, optional
            If True, data is transformed (with the model's scaling config) before being fed
            to the predictor, by default True
        ignore_classes : list[str] | None, optional
            Class ids to ignore while predicting, by default None (no class is ignored)
        allow_ts_cut : bool, optional
            Allow centre-cropping of time series that are longer than the model's sequence
            length, by default False

        Returns
        -------
        pd.DataFrame
            DataFrame indexed like the (transformed) data, with columns in order:
             - prediction_id : The predicted class id ("NaN" for rows with a NaN scalar input)
             - prediction_name : The name of the predicted class ("NaN" for those same rows)
             - probability : The probability of the predicted class, as given by the filterer
             - probabilities : The softmax probability vector over all classes, with ignored
               classes set to 0 (the vector is not renormalised)
             - embedding : The encoded (hidden) embedding of each sample
            If there are no rows, an empty DataFrame with these columns is returned.

        Raises
        ------
        ValueError
            If the time series are longer than the model's sequence length and
            ``allow_ts_cut`` is False.
        """
        if ignore_classes is None:
            ignore_classes = []

        # Transform the received data, if requested
        if transform:
            data, _ = transform_df(
                data,
                scale_cfg=self.model.get_scale_cfg(),
                scale_dynamically=False,
                drop_classes=None,
                inference=True,
            )

        # Nothing to predict on: return an empty result with the expected layout
        if len(data) == 0:
            return pd.DataFrame(columns=list(self._OUTPUT_COLUMNS), index=data.index)

        data = self._cut_time_series(data, allow_ts_cut)

        id2name = self.model.get_class_mapping()
        ignore_mask = np.isin(self.model.get_class_ids(), ignore_classes)
        batch_starts = range(0, len(data), self._batch_size)

        # Per-batch results; concatenated once at the end to avoid quadratic re-copying
        id_chunks: list[NDArray[np.str_]] = []
        name_chunks: list[NDArray[np.str_]] = []
        prob_chunks: list[NDArray[np.float64]] = []
        prob_vec_chunks: list[NDArray[np.float64]] = []
        emb_chunks: list[NDArray[np.float64]] = []

        dots_printed = 0
        print("Predicting: ", end="", flush=True)
        # Inference only: no autograd graph is needed
        with torch.no_grad():
            for i, start in enumerate(batch_starts):
                segment = data.iloc[start : start + self._batch_size]

                # Hidden embedding of each sample
                batch_embs = self.model.enc(segment)
                emb_chunks.append(batch_embs.cpu().detach().numpy())

                # Softmax over all classes; ignored classes are zeroed (not renormalised)
                clf_logits = self.model.clf(batch_embs)
                clf_probs = softmax(clf_logits, dim=1).cpu().detach().numpy()
                clf_probs[:, ignore_mask] = 0.0
                prob_vec_chunks.append(clf_probs)

                # Best prediction and its probability, as chosen by the filterer
                batch_pred_ids, batch_probs = self.filterer.filter_proba(
                    segment, clf_logits, ignore_mask
                )
                batch_pred_names = [id2name[x] for x in batch_pred_ids]

                # Rows with a NaN input get "NaN" instead of a prediction. np.where widens
                # fixed-width string arrays, so "NaN" is never truncated (as in-place
                # assignment into e.g. a '<U2' array would do).
                nan_mask = self._nan_rows(segment)
                id_chunks.append(np.where(nan_mask, "NaN", np.asarray(batch_pred_ids)))
                name_chunks.append(np.where(nan_mask, "NaN", np.asarray(batch_pred_names)))
                prob_chunks.append(np.asarray(batch_probs, dtype=np.float64))

                # Display progress (every 10%)
                if (i * 10 / len(batch_starts)) > dots_printed:
                    print(".", end="", flush=True)
                    dots_printed += 1
        print(flush=True)  # terminate the progress line

        return pd.DataFrame(
            zip(
                np.concatenate(id_chunks),
                np.concatenate(name_chunks),
                np.concatenate(prob_chunks),
                np.concatenate(prob_vec_chunks),
                np.concatenate(emb_chunks),
            ),
            columns=list(self._OUTPUT_COLUMNS),
            index=data.index,
        )

    def _cut_time_series(self, data: pd.DataFrame, allow_ts_cut: bool) -> pd.DataFrame:
        """Centre-crop time-series columns to the encoder's sequence length when longer.

        Time-series columns are those whose (string) name contains ``"ts_"``. The length
        is read from the first row of the first such column (assumes all series share it
        — TODO confirm). A cropped copy is returned; ``data`` itself is never modified.
        """
        ts_cols = [c for c in data.columns if isinstance(c, str) and "ts_" in c]
        if not ts_cols:
            return data
        data_len = len(data[ts_cols[0]].iloc[0])
        seq_len = self.model.enc.get_seq_len()
        # NOTE(review): shorter series are passed through unchanged; whether the encoder
        # accepts them is not visible here.
        if seq_len >= data_len:
            return data
        if not allow_ts_cut:
            msg = f" Data length {data_len}ts is larger than model length {seq_len}ts. "
            msg += "Please set allow_ts_cut=True to allow cutting of time series."
            raise ValueError(msg)
        msg = f" Cutting time series to model length {seq_len}ts from original {data_len}ts !! "
        msg += "If this is not intended, please check your model configuration and the data you are using!"
        warning(msg)
        start = (data_len - seq_len) // 2
        end = start + seq_len
        data = data.copy()  # never write the crop back into the caller's frame
        for ts in ts_cols:
            data[ts] = data[ts].apply(lambda x: x[start:end])
        return data

    @staticmethod
    def _nan_rows(segment: pd.DataFrame) -> NDArray[np.bool_]:
        """Flag the rows of ``segment`` that hold a scalar float NaN in any column.

        Only scalar float cells are inspected (numpy float scalars included); NaNs inside
        array-valued (time-series) cells are not checked.
        """
        return np.asarray(
            [
                any(isinstance(x, float) and np.isnan(x) for x in row)
                for row in segment.to_numpy(dtype=object)
            ],
            dtype=bool,
        )
