"""RNN encoder."""

from __future__ import annotations

from enum import Enum
from typing import Any, Callable

import pandas as pd
import torch
import torch.nn as nn

from vito_crop_classification.model.encoders.base import BaseEncoder


class RnnType(str, Enum):
    """Supported recurrent cell types; each member's value is its lowercase name."""

    RNN = "rnn"
    LSTM = "lstm"
    GRU = "gru"

    def __str__(self) -> str:
        """Render the member as its plain string value, e.g. ``"lstm"``."""
        return str(self._value_)


class RnnEncoder(BaseEncoder):
    """RNN encoder: embeds time series with a stack of recurrent layers."""

    def __init__(
        self,
        enc_tag: str,
        processing_f: Callable[[pd.DataFrame], torch.FloatTensor],
        cell_type: RnnType | str,
        n_ts: int,
        seq_len: int,
        hidden_dims: list[int],
        output_size: int,
        bidirectional: bool = True,
        dropout: float = 0.1,
        return_sequence: bool = False,  # TODO: Use transformer 'Aggregate' approach!
        **kwargs,
    ):
        """
        Initialize RNN encoder.

        Parameters
        ----------
        enc_tag : str
            Encoder name
        processing_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Processing function that puts a dataframe into an embedding of the desirable shape
        cell_type : RnnType | str
            Type of RNN cell to use ("rnn", "lstm" or "gru")
        n_ts : int
            Number of time series observed; used as the input size of the first
            recurrent layer
        seq_len : int
            Length of one time series
        hidden_dims : list[int]
            Hidden sizes of the recurrent layers preceding the last one
        output_size : int
            Hidden size of the last recurrent layer. The size of the embedding that is
            actually returned is this value multiplied by ``seq_len`` when
            ``return_sequence`` is set, and by 2 when ``bidirectional`` is set
        bidirectional : bool, optional
            Whether every recurrent layer is bi-directional, by default True
        dropout : float, optional
            Dropout probability between consecutive recurrent layers (only active
            during training; no dropout layer is added when 0), by default 0.1
        return_sequence : bool, optional
            If True, return the flattened outputs of all time steps of the last layer;
            otherwise return only its final hidden state(s), by default False
        **kwargs
            Extra arguments forwarded to the base encoder
        """
        super().__init__(
            enc_tag=enc_tag,
            n_ts=n_ts,
            seq_len=seq_len,
            output_size=output_size,
            processing_f=processing_f,
            **_filter_kwargs(kwargs),
        )
        self._cell_type = cell_type
        self._dropout = dropout
        self._hidden_dims = hidden_dims
        self._bidirectional = bidirectional
        self._return_sequence = return_sequence

        self._model = _create_model(
            cell_type=self._cell_type,
            n_ts=self._n_ts,
            out_size=self._output_size,
            bidirectional=self._bidirectional,
            hidden_dims=self._hidden_dims,
            dropout=self._dropout,
        )

        # Turn the configured size into the size of the embedding returned by
        # forward_process: one `output_size` vector per time step when the whole
        # sequence is returned, and one per direction when bi-directional.
        if self._return_sequence:
            # TODO: This is bad.. (conflicts with cfg during load!)
            self._output_size *= self._seq_len
        if self._bidirectional:
            self._output_size *= 2

    def forward_process(self, inp: torch.FloatTensor) -> torch.FloatTensor:
        """
        Forward the input that has been processed in advance.

        Parameters
        ----------
        inp : torch.FloatTensor
            Batch-first input, (N, L, n_ts); a 2-D (N, L) input is treated as a single
            feature per time step

        Returns
        -------
        torch.FloatTensor
            Embedding of shape (N, self._output_size) (assuming L == seq_len when
            ``return_sequence`` is set)
        """
        assert self._model is not None
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(-1)  # N, L -> N, L, 1

        if self._return_sequence:
            out, _ = self._model_forward(inp)
            # last layer keeps both directions: (N, L, D * H) -> (N, L * D * H)
            return out.flatten(1)

        # LSTM additionally returns its cell state next to the hidden state
        if self._cell_type == RnnType.LSTM:
            _, (h_n, _) = self._model_forward(inp)
        else:
            _, h_n = self._model_forward(inp)

        # Each recurrent module has num_layers=1, so h_n only holds the final state
        # of each direction of the last layer.
        if self._bidirectional:
            return torch.cat([h_n[-2, :, :], h_n[-1, :, :]], dim=1)  # type: ignore
        return h_n[-1, :, :]

    def get_metadata(self) -> dict[str, Any]:
        """Get the model's metadata, extended with the RNN-specific configuration."""
        metadata = super().get_metadata()
        metadata["cell_type"] = self._cell_type
        # NOTE(review): this is the *scaled* embedding size computed in __init__, not
        # the `output_size` argument; passing it back as `output_size` scales it
        # again (see the TODO in __init__).
        metadata["output_size"] = self._output_size
        metadata["hidden_dims"] = self._hidden_dims
        metadata["bidirectional"] = self._bidirectional
        metadata["return_sequence"] = self._return_sequence
        # dropout > 0 inserts nn.Dropout modules between the recurrent layers of
        # self._model (shifting their indices), so it is needed to rebuild the same
        # architecture from this metadata.
        metadata["dropout"] = self._dropout
        return metadata

    def _model_forward(self, x: torch.FloatTensor) -> tuple[Any, ...]:
        """
        Run ``x`` through every layer of ``self._model``.

        Intermediate recurrent layers only pass their output sequence on. The raw
        return value of the last recurrent layer is returned unchanged:
        ``(output, h_n)``, or ``(output, (h_n, c_n))`` for an LSTM.
        """
        for layer in self._model[:-1]:
            if isinstance(layer, nn.Dropout):
                x = layer(x)
                continue
            x = layer(x)[0]  # keep the output sequence, drop the hidden state(s)
            if self._bidirectional:
                # Average both directions so the next layer's input size equals this
                # layer's hidden size, which is what _create_model sizes it for.
                half = x.size(2) // 2
                x = (x[:, :, half:] + x[:, :, :half]) / 2
        return self._model[-1](x)


def _filter_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """
    Filter the kwargs.

    Presumably this prevents clashes with the arguments that are passed to the base
    encoder explicitly.

    Parameters
    ----------
    kwargs : dict[str, Any]
        Extra keyword arguments received by the encoder

    Returns
    -------
    dict[str, Any]
        A new dict without the keys ``enc_tag``, ``input_size``, ``output_size`` and
        ``processing_f``; the given dict is left untouched
    """
    excluded = {"enc_tag", "input_size", "output_size", "processing_f"}
    return {key: value for key, value in kwargs.items() if key not in excluded}


def _create_model(
    cell_type: RnnType | str,
    n_ts: int,
    out_size: int,
    hidden_dims: list[int],
    dropout: float,
    bidirectional: bool,
) -> nn.Sequential:
    """
    Create a stack of single-layer recurrent modules (RNN, LSTM or GRU).

    One recurrent layer is created per entry in ``hidden_dims``, plus a final layer of
    size ``out_size``. The first layer takes ``n_ts`` input features; every later layer
    takes the previous layer's hidden size. For bidirectional layers, this assumes the
    caller merges both directions back to that hidden size before feeding the next
    layer. The layers are collected in an ``nn.Sequential`` but are meant to be
    iterated by the caller, because recurrent modules return tuples.

    Parameters
    ----------
    cell_type : RnnType | str
        Type of recurrent cell ("rnn", "lstm" or "gru")
    n_ts : int
        Input size of the first recurrent layer
    out_size : int
        Hidden size of the last recurrent layer
    hidden_dims : list[int]
        Hidden sizes of the recurrent layers preceding the last one (any sequence)
    dropout : float
        If > 0, an ``nn.Dropout`` with this probability is inserted between
        consecutive recurrent layers
    bidirectional : bool
        Whether every recurrent layer is bi-directional

    Returns
    -------
    nn.Sequential
        Recurrent layers (batch-first, ``num_layers=1``) interleaved with optional
        dropout layers
    """
    in_dims = [n_ts, *hidden_dims]
    out_dims = [*hidden_dims, out_size]
    cell_cls = _get_cell(cell_type)  # resolved once; identical for every layer
    model = nn.Sequential()
    for idx, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
        model.append(
            cell_cls(
                input_size=in_dim,
                hidden_size=out_dim,
                bidirectional=bidirectional,
                batch_first=True,
                num_layers=1,
            )
        )
        is_last = idx == len(out_dims) - 1
        if not is_last and dropout > 0.0:
            model.append(nn.Dropout(dropout))
    return model


def _get_cell(cell_type: RnnType | str) -> type[nn.RNNBase]:
    """
    Get the recurrent module class that matches ``cell_type``.

    Parameters
    ----------
    cell_type : RnnType | str
        One of ``"rnn"``, ``"lstm"`` or ``"gru"``, or the equivalent ``RnnType`` member

    Returns
    -------
    type[nn.RNNBase]
        The class ``nn.RNN``, ``nn.LSTM`` or ``nn.GRU`` (not an instance)

    Raises
    ------
    ValueError
        If ``cell_type`` is not supported
    """
    # RnnType is a ``str`` Enum, so its members hash and compare equal to their
    # string values and can be looked up with plain-string keys.
    cells: dict[str, type[nn.RNNBase]] = {
        "rnn": nn.RNN,
        "lstm": nn.LSTM,
        "gru": nn.GRU,
    }
    try:
        return cells[cell_type]  # type: ignore[index]
    except (KeyError, TypeError):  # TypeError: unhashable cell_type
        raise ValueError(f"Cell type '{cell_type}' not supported!") from None
