"""Altitude filter."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray

from vito_crop_classification.filters.base import BaseFilter


class AltitudeFilter(BaseFilter):
    """
    Altitude filter.

    During training, learns one upper altitude bound per class (a quantile of the
    ``sc_altitude`` values observed for that class). At inference time, prediction
    scores are zeroed for every class whose bound is not strictly above the
    sample's altitude.
    """

    def __init__(
        self,
        classes: NDArray[np.str_],
        filter_f: Path | None = None,
    ) -> None:
        """
        Initialize AltitudeFilter.

        Parameters
        ----------
        classes : NDArray[np.str_]
            Classes used for training
        filter_f : Path | None
            Folder to store the filters in, get_filters_folder() by default
        """
        # Placeholder until train() sets the quantile actually used; set before the
        # base initializer runs, as in the original implementation.
        self._max_quantile = 1.0
        super().__init__(
            classes=classes,
            columns=["MaxAltitude"],
            filter_f=filter_f,
        )

    def forward_process(self, data: pd.DataFrame, preds: torch.Tensor) -> torch.Tensor:
        """
        Apply filter over the predictions.

        Parameters
        ----------
        data : pd.DataFrame
            Samples being classified; must contain an ``sc_altitude`` column with one
            row per row of ``preds``
        preds : torch.Tensor
            Prediction scores with one column per class, in the order of the learned
            filter's rows; modified in place

        Returns
        -------
        torch.Tensor
            ``preds`` with the scores of altitude-incompatible classes set to 0
        """
        # Cast to float so missing altitudes become NaN (NaN compares False below).
        altitudes = np.asarray(data["sc_altitude"], dtype=float)
        max_altitudes = np.asarray(self._filter["MaxAltitude"], dtype=float)

        # allowed[i, j] is True when class j's learned max altitude lies strictly
        # above sample i's altitude. Broadcasting also handles empty ``data``.
        # NOTE(review): strict '>' rejects a sample lying exactly on the learned
        # bound; kept as-is for backward compatibility.
        allowed = max_altitudes[None, :] > altitudes[:, None]

        # If no class passes for a sample, keep every class instead of zeroing the
        # whole row.
        allowed[~allowed.any(axis=1)] = True

        preds[torch.as_tensor(~allowed, device=preds.device)] = 0.0
        return preds

    def train(
        self,
        df: pd.DataFrame,
        max_quantile: float = 0.99,
    ) -> None:
        """
        Train the filter on the provided dataset.

        Parameters
        ----------
        df : pd.DataFrame
            Dataset to extract patterns from; needs ``target_id`` and ``sc_altitude``
            columns
        max_quantile : float
            Maximum quantile used for cut-off
        """
        super().train(df)
        self._max_quantile = max_quantile
        quantiles = df.groupby("target_id")["sc_altitude"].quantile(self._max_quantile)

        # groupby() returns groups sorted by key. When the group keys are exactly the
        # class labels, align on the labels so each bound lands on its own class even
        # if self._classes is not sorted. Otherwise the result is assigned to
        # self._classes by position, as before.
        if set(quantiles.index) == set(self._classes):
            quantiles = quantiles.reindex(self._classes)

        self._filter = pd.DataFrame(
            np.asarray(quantiles.values),
            index=self._classes,
            columns=self._columns,
        )

    def explain(self) -> None:
        """Print a human-readable explanation and the learned max altitude per class."""
        # NOTE(review): _max_quantile is only updated by train(); whether a filter
        # loaded by the base class restores it cannot be seen here — confirm.
        print(f"Explanation -> {self}")
        print(" ---")
        print(f" This filter checks what is the maximum altitude ({self._max_quantile} quantile)")
        print(" at which a certain crop can grow, and discard all crops that can not grow")
        print(" at the specified input altitude.")
        print(" ---")
        print(" The following max-altitudes have been derived from the training set:")
        print("  - <class> -> <maxAltitude>")
        for c, a in zip(self._classes, self._filter["MaxAltitude"]):
            print(f"  - {c} -> {a:.7f}")


if __name__ == "__main__":
    from random import randint

    from vito_crop_classification.data_io import load_datasets

    # Demo: fit the filter on a 1% slice of the hybrid dataset using random dummy
    # labels, persist it, reload it, and run it on the validation split.
    datasets = load_datasets(dataset="hybrid_dataset", dataset_ratio=0.01)
    train_df, val_df = datasets["df_train"], datasets["df_val"]

    # Replace the real labels with three random dummy classes "1", "2" and "3".
    train_df["target_id"] = [str(randint(1, 3)) for _ in range(len(train_df))]
    demo_labels = ["1", "2", "3"]

    # Create filter
    print("\nCreating filter..")
    altitude_filter = AltitudeFilter(classes=np.asarray(demo_labels))
    altitude_filter.train(train_df, max_quantile=0.9)
    altitude_filter.save()
    print(" - Filter:", altitude_filter)

    # Load in a new filter
    print("\nLoading in old filter..")
    altitude_filter = AltitudeFilter(classes=np.asarray(demo_labels))
    print(" - Filter:", altitude_filter)

    # Run inference on the filter
    print("\nRunning the filter..")
    random_scores = torch.rand(len(val_df), 3)
    altitude_filter(val_df, random_scores)
    altitude_filter.explain()
