"""Identity encoder."""

from __future__ import annotations

from typing import Any, Callable

import pandas as pd
import torch
import torch.nn as nn

from vito_crop_classification.model.encoders.base import BaseEncoder


class IdentityEncoder(BaseEncoder):
    """Encoder whose model is the pass-through network built by ``_create_model``."""

    def __init__(
        self,
        enc_tag: str,
        feature_size: int,
        processing_f: Callable[[pd.DataFrame], torch.FloatTensor],
        **kwargs,
    ):
        """
        Set up the identity encoder.

        Parameters
        ----------
        enc_tag : str
            Name of the encoder, used to save it
        feature_size : int
            Number of the features extracted by the processing function; forwarded to the base
            encoder as both ``n_ts`` and ``output_size`` (``seq_len`` is fixed to 1)
        processing_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Callable mapping a dataframe to a float tensor (per its annotation); handed to the
            base encoder unchanged
        **kwargs
            Extra keyword arguments for the base encoder; keys that this constructor sets itself
            are stripped by ``_filter_kwargs`` before forwarding
        """
        # Strip the keys set explicitly below so they are not passed twice to the base class.
        extra_options = _filter_kwargs(kwargs)
        super().__init__(
            enc_tag=enc_tag,
            n_ts=feature_size,
            seq_len=1,
            output_size=feature_size,
            processing_f=processing_f,
            **extra_options,
        )
        self._model = _create_model()

    def get_metadata(self) -> dict[str, Any]:
        """Return the base encoder's metadata extended with a ``feature_size`` entry."""
        info = super().get_metadata()
        # NOTE(review): ``_output_size`` is not assigned in this class; presumably the base
        # encoder stores the ``output_size`` argument there -- confirm against BaseEncoder.
        info["feature_size"] = self._output_size
        return info


def _filter_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """
    Return a copy of the kwargs without the keys the encoder sets explicitly.

    The input dictionary is left untouched, so a caller may pass a dictionary it still needs
    afterwards without losing entries.

    Parameters
    ----------
    kwargs : dict[str, Any]
        Keyword arguments to filter

    Returns
    -------
    dict[str, Any]
        New dictionary holding every entry of ``kwargs`` whose key is not reserved
    """
    # Keys that would collide with (or are superseded by) explicitly passed constructor arguments.
    reserved = {"enc_tag", "input_size", "output_size", "processing_f", "n_ts", "seq_len"}
    return {key: value for key, value in kwargs.items() if key not in reserved}


def _create_model() -> torch.nn.Module:
    """Build the pass-through network: a ``Sequential`` wrapping a single ``Identity`` layer."""
    passthrough = nn.Identity()
    return nn.Sequential(passthrough)
