"""CNN encoder."""

from __future__ import annotations

from typing import Any, Callable

import pandas as pd
import torch
import torch.nn as nn

from vito_crop_classification.model.encoders.base import BaseEncoder


class CnnEncoder(BaseEncoder):
    """Encoder that embeds time series using a stack of 1D convolutions."""

    def __init__(
        self,
        enc_tag: str,
        processing_f: Callable[[pd.DataFrame], torch.FloatTensor],
        n_ts: int,
        seq_len: int,
        hidden_dims: list[int],
        kernel_dims: list[int],
        output_size: int,
        dropout: float = 0.1,
        **kwargs,
    ):
        """
        Set up the encoder and build its convolutional model.

        Parameters
        ----------
        enc_tag : str
            Encoder name
        processing_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Processing function that turns a dataframe into a tensor of the desired shape
        n_ts : int
            Number of time series observed
        seq_len : int
            Length of one time series
        hidden_dims : list[int]
            Channel sizes of the hidden convolutional layers
        kernel_dims : list[int]
            Kernel sizes of the convolutional layers
        output_size : int
            Size of the output embedding (required, forwarded to the parent encoder)
        dropout : float, optional
            Dropout to use during training, by default 0.1
        **kwargs
            Additional arguments, filtered and then forwarded to the parent encoder
        """
        # Drop keys that are passed explicitly below before forwarding the rest
        parent_kwargs = _filter_kwargs(kwargs)
        super().__init__(
            enc_tag=enc_tag,
            processing_f=processing_f,
            n_ts=n_ts,
            seq_len=seq_len,
            output_size=output_size,
            **parent_kwargs,
        )

        # Keep the architecture settings around so they can be exported as metadata
        self._hidden_dims = hidden_dims
        self._kernel_dims = kernel_dims
        self._dropout = dropout

        self._model = _create_model(
            n_ts=self._n_ts,
            out_size=self._output_size,
            hidden_dims=self._hidden_dims,
            kernel_dims=self._kernel_dims,
            dropout=self._dropout,
        )

    def get_output_size(self) -> int:
        """Return the length of the flattened output (output size times sequence length)."""
        channels = self._output_size
        steps = self._seq_len
        return channels * steps

    def forward_process(self, inp: torch.FloatTensor) -> torch.FloatTensor:
        """Run an already processed input through the convolutional model."""
        assert self._model is not None
        if inp.dim() == 2:
            # A single series has no feature axis yet: N, L -> N, L, 1
            inp = torch.unsqueeze(inp, dim=-1)
        # Swap the last two axes before the convolutions: N, L, F -> N, F, L
        channels_first = inp.permute(0, 2, 1)
        return self._model(channels_first)

    def get_metadata(self) -> dict[str, Any]:
        """Return the parent's metadata extended with the CNN specific settings."""
        metadata = super().get_metadata()
        metadata.update(
            {
                "output_size": self._output_size,
                "hidden_dims": self._hidden_dims,
                "kernel_dims": self._kernel_dims,
                "dropout": self._dropout,
            }
        )
        return metadata


def _filter_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """
    Return a copy of the kwargs without the reserved encoder keys.

    The keys "enc_tag", "input_size", "output_size" and "processing_f" are removed.
    The provided dictionary itself is left untouched.

    Parameters
    ----------
    kwargs : dict[str, Any]
        Keyword arguments to filter

    Returns
    -------
    dict[str, Any]
        New dictionary holding all non-reserved entries, in their original order
    """
    reserved = ("enc_tag", "input_size", "output_size", "processing_f")
    return {key: value for key, value in kwargs.items() if key not in reserved}


def _create_model(
    n_ts: int,
    out_size: int,
    hidden_dims: list[int],
    kernel_dims: list[int],
    dropout: float,
) -> nn.Module:
    """
    Create the 1D convolutional model.

    The model has ``len(hidden_dims) + 1`` layers, each a ``Conv1d`` followed by a ``ReLU``.
    Every layer but the last is followed by ``Dropout`` when ``dropout > 0``.
    A final ``Flatten`` merges all non-batch dimensions.

    Parameters
    ----------
    n_ts : int
        Number of input channels of the first convolution
    out_size : int
        Number of output channels of the last convolution
    hidden_dims : list[int]
        Number of channels of the intermediate convolutions
    kernel_dims : list[int]
        Kernel size of every convolution, one per layer (surplus entries are ignored)
    dropout : float
        Dropout probability, no dropout layers are added if not larger than zero

    Returns
    -------
    nn.Module
        Sequential model holding the described layers

    Raises
    ------
    ValueError
        If fewer kernel sizes than layers are provided
    AssertionError
        If a used kernel size is even
    """
    # Unpacking accepts any sequence (list, tuple, ...) for the hidden dimensions
    inp_dims = [n_ts, *hidden_dims]
    out_dims = [*hidden_dims, out_size]
    n_layers = len(out_dims)

    # zip() below would silently drop the last layer(s), yielding a wrong output size
    if len(kernel_dims) < n_layers:
        raise ValueError(
            f"Expected at least {n_layers} kernel sizes (one per layer), got {len(kernel_dims)}"
        )

    mdl = nn.Sequential()
    for i, (inp, out, ks) in enumerate(zip(inp_dims, out_dims, kernel_dims)):
        assert ks % 2 == 1, "Kernel size must be odd"
        # With an odd kernel, a padding of ks // 2 keeps the sequence length unchanged
        mdl.append(nn.Conv1d(inp, out, kernel_size=ks, padding=ks // 2))
        mdl.append(nn.ReLU())
        if i + 1 != n_layers and dropout > 0.0:
            mdl.append(nn.Dropout(dropout))
    mdl.append(nn.Flatten())
    return mdl
