"""Concatenated Filter class."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch

from vito_crop_classification.filters.altitude import AltitudeFilter
from vito_crop_classification.filters.base import BaseFilter
from vito_crop_classification.filters.group import GroupFilter
from vito_crop_classification.filters.slope import SlopeFilter


class ConcatenatedFilter(BaseFilter):
    """Concatenated filter.

    Applies a sequence of filters one after another to the same predictions.
    All sub-filters must report the same classes.
    """

    def __init__(
        self,
        list_filters: list[BaseFilter],
        filter_f: Path | None = None,
    ) -> None:
        """Initialize concatenated filter.

        Parameters
        ----------
        list_filters : list[BaseFilter]
            Non-empty list of filters to concatenate, applied in the given order.
            All filters must report the same classes via ``get_classes()``.
        filter_f : Path | None
            Forwarded unchanged to ``BaseFilter.__init__``.

        Raises
        ------
        ValueError
            If ``list_filters`` is empty or the filters disagree on their classes.
        """
        if not list_filters:
            raise ValueError("ConcatenatedFilter requires at least one filter.")
        classes = list_filters[0].get_classes()
        # Explicit check instead of `assert` (stripped under `python -O`);
        # np.array_equal returns False for differing lengths instead of
        # warning/raising like an element-wise `==` would.
        if not all(np.array_equal(classes, f.get_classes()) for f in list_filters):
            raise ValueError("All concatenated filters must share the same classes.")
        # Set before the base initializer in case it relies on the sub-filters.
        self._list_filters = list_filters
        super().__init__(
            classes=list_filters[0]._classes,
            filter_f=filter_f,
        )

    def forward_process(self, data: pd.DataFrame, preds: torch.Tensor) -> torch.Tensor:
        """Apply every sub-filter, in order, over the predictions.

        Rows whose filtered values sum to zero (the filters removed everything)
        are restored to their original, unfiltered values.
        """
        # Untouched copy, kept in case a sub-filter modifies `preds` in place.
        preds_recover = preds.clone()
        for f in self._list_filters:
            preds = f.forward_process(data, preds)

        # if filters have been too harsh, recover original predictions
        recover_idx = (preds.sum(1) == 0).nonzero().flatten()
        preds[recover_idx] = preds_recover[recover_idx]
        return preds

    def save(self) -> None:
        """Save every sub-filter by delegating to its own ``save()``."""
        for f in self._list_filters:
            f.save()

    def load(self) -> None:
        """Load every sub-filter and install the loaded filter via ``set_filter``."""
        for f in self._list_filters:
            f.set_filter(f.load())

    def get_filter(self) -> dict[str, Any]:
        """Return all the filters used for concatenation, keyed by ``str(filter)``."""
        return {str(f): f.get_filter() for f in self._list_filters}

    def explain(self) -> None:
        """Explain filter verbally, followed by each sub-filter's explanation."""
        print(f"Explanation -> {self} (with {len(self._list_filters)} sub-filters)")
        for i, f in enumerate(self._list_filters):
            print(f"\nFilter {f} ({i + 1}/{len(self._list_filters)})")
            f.explain()


if __name__ == "__main__":
    from random import randint

    from vito_crop_classification.data import load_datasets

    my_result = load_datasets(dataset="hybrid_dataset", dataset_ratio=0.01)
    my_df_train = my_result["df_train"]
    my_df_val = my_result["df_val"]
    # Dummy targets ("1", "2" or "3") just to exercise the filters.
    my_df_train["target_id"] = [f"{randint(1, 3)}" for _ in range(len(my_df_train))]

    # Create filter
    print("\nCreating filter..")
    my_classes = np.asarray(sorted(set(my_df_train["target_id"])))
    group_filter = GroupFilter(
        classes=my_classes,
        groups=np.asarray(sorted(set(my_df_train["sc_group"]))),
    )
    slope_filter = SlopeFilter(classes=my_classes)
    altitude_filter = AltitudeFilter(classes=my_classes)
    my_filter = ConcatenatedFilter(list_filters=[group_filter, slope_filter, altitude_filter])
    print(" - Filter:", my_filter)

    # Run inference on the filter. Use one prediction column per class actually
    # present: with random labels on a small sample, fewer than 3 may occur.
    print("\nRunning the filter..")
    my_filter(my_df_val, torch.rand(len(my_df_val), len(my_classes)))
    my_filter.explain()
