"""Base filter."""

from __future__ import annotations

import hashlib
from pathlib import Path
from warnings import warn

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from torch.nn.functional import softmax

from vito_crop_classification.constants import get_filters_folder


class BaseFilter:
    """
    Base filter.

    Pass-through filter: ``forward_process`` returns the predictions unchanged.
    The filter itself is a table with one row per class and one column per
    filter name, loaded from a CSV file when one exists and filled with ones
    otherwise.
    """

    def __init__(
        self,
        classes: NDArray[np.str_],
        columns: list[str] | None = None,
        filter_f: Path | None = None,
    ):
        """
        Initialize BaseFilter.

        Parameters
        ----------
        classes : NDArray[np.str_]
            Classes used for training
        columns : list[str] | None
            List of filter names, ['Nothing'] by default
        filter_f : Path | None
            Folder to store the filters in, get_filters_folder() by default
        """
        self._classes = classes
        self._columns = columns if columns else ["Nothing"]
        self._filter_f = filter_f if filter_f else get_filters_folder()
        self._filter = self.load()

    def __str__(self) -> str:
        """String representation of the filter."""
        return f"{self.__class__.__name__}(n_classes={len(self._classes)})"

    def __repr__(self) -> str:
        """String representation of the filter (same as ``str``)."""
        return str(self)

    def __call__(self, data: pd.DataFrame, preds: torch.Tensor) -> NDArray[np.str_]:
        """
        Filter out the best performing class.

        Parameters
        ----------
        data: pd.DataFrame
            Input data that provides additional data to filter the predictions with
        preds : torch.Tensor
            Raw predictions to filter

        Returns
        -------
        NDArray[np.str_]
            Best class for every row of the predictions
        """
        return self.filter_proba(data=data, preds=preds)[0]

    def filter_proba(
        self,
        data: pd.DataFrame,
        preds: torch.Tensor,
        ignore_mask: NDArray[np.bool_] | None = None,
    ) -> tuple[NDArray[np.str_], NDArray[np.float64]]:
        """
        Filter and return best class with corresponding probability.

        Parameters
        ----------
        data: pd.DataFrame
            Input data that provides additional data to filter the predictions with
        preds : torch.Tensor
            Raw predictions to filter
        ignore_mask : NDArray[np.bool_] | None, optional
            Mask (or index list) selecting the classes to ignore, nothing is
            ignored by default

        Returns
        -------
        tuple[NDArray[np.str_], NDArray[np.float64]]
            Best class per row and the probability assigned to that class

        Raises
        ------
        AssertionError
            If data and preds do not have the same length
        """
        # Raised explicitly (instead of `assert`) so the check survives `python -O`
        if len(data) != len(preds):
            raise AssertionError(
                f"data and preds differ in length ({len(data)} != {len(preds)})"
            )
        logits = self.forward_process(data, preds)
        proba = softmax(logits, dim=1).cpu().detach().numpy()

        # `None` must be skipped: used as an index it acts as np.newaxis and
        # would zero out every probability instead of none of them
        if ignore_mask is not None:
            proba[:, ignore_mask] = 0.0

        best_idx = np.argmax(proba, axis=1)
        best = np.asarray(self._classes)[best_idx]
        probs = proba[np.arange(len(proba)), best_idx]
        return best, probs

    def forward_process(self, df: pd.DataFrame, preds: torch.Tensor) -> torch.Tensor:
        """Apply filter over the predictions; the base filter returns them unchanged."""
        return preds

    def train(
        self,
        df: pd.DataFrame,
    ) -> None:
        """
        Train the filter on the provided dataset.

        The base implementation only validates the targets and resets the
        filter to a fresh one.

        Parameters
        ----------
        df : pd.DataFrame
            Dataset to extract patterns from

        Raises
        ------
        AssertionError
            If the sorted unique targets of df differ from the filter's classes
        """
        # Compare element-wise: joining the names into one string would treat
        # e.g. ["ab", "c"] and ["a", "bc"] as identical
        targets = sorted(df.target_id.unique())
        if targets != list(self._classes):
            raise AssertionError("classes and df targets disagree")
        self._filter = self.create_new()

    def save(self) -> None:
        """Write the filter to its CSV file, creating the folder when needed."""
        path = self._get_filter_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        self._filter.to_csv(path)

    def load(self) -> pd.DataFrame:
        """
        Load the filter from its CSV file.

        Warns and falls back to a newly created filter when no file is found.
        """
        # NOTE(review): read_csv infers the index dtype, so class names that
        # look numeric may not come back as strings — confirm this is intended
        try:
            return pd.read_csv(
                self._get_filter_path(),
                index_col=0,
            )
        except FileNotFoundError:
            warn(f"No filter found for class '{self.__class__.__name__}'!")
            return self.create_new()

    def create_new(self) -> pd.DataFrame:
        """Create a new filter: ones, indexed by class, one column per filter name."""
        return pd.DataFrame(
            np.ones((len(self._classes), len(self._columns))),
            index=self._classes,
            columns=self._columns,
        )

    def get_filter(self) -> pd.DataFrame:
        """Return the filter derived over the dataset."""
        return self._filter

    def set_filter(self, f: pd.DataFrame) -> None:
        """Use the provided filter as the new filter."""
        self._filter = f

    def get_classes(self) -> NDArray[np.str_]:
        """Get the filter's classes."""
        return self._classes

    def explain(self) -> None:
        """Explain filter verbally."""
        print(f"Explanation -> {self}")
        print(" ---")
        print(" No active filtering is applied when using this filter.")

    def _get_filter_path(self) -> Path:
        """Get saving path: ``<folder>/<ClassName>_<sha256 of the class names>.csv``."""
        # The key is the plain concatenation of the class names; it is kept
        # as-is so previously saved filters are still found
        class_key = "".join(self._classes)
        hash_key = hashlib.sha256(class_key.encode("utf-8")).hexdigest()
        return self._filter_f / f"{self.__class__.__name__}_{hash_key}.csv"


if __name__ == "__main__":
    from random import randint

    from vito_crop_classification.data_io import load_datasets

    # Small demo: train and save a filter, reload it, then run it on random predictions
    class_names = ("1", "2", "3")
    datasets = load_datasets(dataset="hybrid_dataset", dataset_ratio=0.01)
    train_df, val_df = datasets["df_train"], datasets["df_val"]
    # Overwrite the targets with random labels drawn from the demo classes
    train_df["target_id"] = [str(randint(1, 3)) for _ in range(len(train_df))]

    # Create filter
    print("\nCreating filter..")
    demo_filter = BaseFilter(classes=np.asarray(class_names))
    demo_filter.train(train_df)
    demo_filter.save()
    print(" - Filter:", demo_filter)

    # Load in a new filter
    print("\nLoading in old filter..")
    demo_filter = BaseFilter(classes=np.asarray(class_names))
    print(" - Filter:", demo_filter)

    # Run inference on the filter
    print("\nRunning the filter..")
    random_preds = torch.rand(len(val_df), len(class_names))
    demo_filter(val_df, random_preds)
    demo_filter.explain()
