"""Dense encoder."""

from __future__ import annotations

from typing import Any, Callable

import pandas as pd
import torch
import torch.nn as nn

from vito_crop_classification.model.encoders.base import BaseEncoder


class DenseEncoder(BaseEncoder):
    """Encoder that embeds the flattened input with a stack of dense (linear) layers."""

    def __init__(
        self,
        enc_tag: str,
        processing_f: Callable[[pd.DataFrame], torch.FloatTensor],
        n_ts: int,
        seq_len: int,
        hidden_dims: list[int],
        output_size: int,
        dropout: float = 0.1,
        use_batch_norm: bool = False,
        **kwargs,
    ):
        """
        Initialise the dense encoder.

        Parameters
        ----------
        enc_tag : str
            Name of the encoder, used to save it
        processing_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Processing function that puts a dataframe into an embedding of the desirable shape
        n_ts : int
            Number of time series observed
        seq_len : int
            Length of one time series
        hidden_dims : list[int]
            Sizes of the hidden layers, in order
        output_size : int
            Size of the output embedding
        dropout : float
            Dropout probability handed to every dropout layer of the model
        use_batch_norm : bool
            Whether to add batch normalisation layers to the model
        **kwargs
            Additional keyword arguments; they are passed through ``_filter_kwargs``
            and the remainder is forwarded to the parent constructor
        """
        parent_kwargs = _filter_kwargs(kwargs)
        super().__init__(
            enc_tag=enc_tag,
            n_ts=n_ts,
            seq_len=seq_len,
            output_size=output_size,
            processing_f=processing_f,
            **parent_kwargs,
        )

        # Dense-specific hyperparameters, kept so get_metadata() can report them
        self._dropout = dropout
        self._use_batch_norm = use_batch_norm
        self._hidden_dims = hidden_dims

        # The model flattens its input, so the first layer sees n_ts * seq_len features
        flat_input_size = self._n_ts * self._seq_len
        self._model = _create_model(
            inp_size=flat_input_size,
            out_size=self._output_size,
            hidden_dims=self._hidden_dims,
            dropout=self._dropout,
            use_batch_norm=self._use_batch_norm,
        )

    def get_metadata(self) -> dict[str, Any]:
        """Return the parent's metadata extended with the dense-specific settings."""
        metadata = super().get_metadata()
        metadata.update(
            {
                "dropout": self._dropout,
                "use_batch_norm": self._use_batch_norm,
                "hidden_dims": self._hidden_dims,
            }
        )
        return metadata


def _filter_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Filter the kwargs."""
    for key in ("enc_tag", "input_size", "output_size", "processing_f"):
        if key in kwargs:
            del kwargs[key]
    return kwargs


def _create_model(
    inp_size: int, out_size: int, hidden_dims: list[int], dropout: float, use_batch_norm: bool
) -> torch.nn.Module:
    """Create a dense model."""
    inp_dims = [inp_size] + hidden_dims
    out_dims = hidden_dims + [out_size]
    mdl = nn.Sequential(nn.Flatten())
    for inp, out in zip(inp_dims, out_dims):
        mdl.append(nn.Linear(in_features=inp, out_features=out))
        mdl.append(nn.ReLU())
        mdl.append(nn.Dropout(dropout))
        if use_batch_norm:
            mdl.append(nn.BatchNorm1d(out))
    return mdl
