"""Transformer encoder."""

from __future__ import annotations

from math import log
from typing import Any, Callable

import pandas as pd
import torch
from torch import nn

from vito_crop_classification.model.encoders.base import BaseEncoder


class TransformerEncoder(BaseEncoder):
    """Encoder that embeds multivariate time series with a transformer stack."""

    def __init__(
        self,
        enc_tag: str,
        processing_f: Callable[[pd.DataFrame], torch.FloatTensor],
        n_ts: int,
        seq_len: int,
        output_size: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.1,
        **kwargs,
    ):
        """
        Model encoder.

        Parameters
        ----------
        enc_tag : str
            Name of the encoder, used to save it
        processing_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Processing function that turns a dataframe into an embedding of the desired shape
        n_ts : int
            Number of time series (input channels) per time step
        seq_len : int
            Length of the input time series
        output_size : int
            Per-time-step embedding size (the transformer's hidden size); the
            flattened encoder output has ``output_size * seq_len`` features
        n_heads : int
            Number of attention heads used in one encoder layer
        n_layers : int
            Number of encoder layers used
        dropout : float
            Dropout to use during training
        **kwargs
            Extra arguments forwarded to ``BaseEncoder`` after ``_filter_kwargs``
            has removed the keys it excludes
        """
        base_kwargs = _filter_kwargs(kwargs)
        super().__init__(
            enc_tag=enc_tag,
            processing_f=processing_f,
            n_ts=n_ts,
            seq_len=seq_len,
            output_size=output_size,
            **base_kwargs,
        )

        # Transformer-specific hyper-parameters, also reported by get_metadata()
        self._n_heads = n_heads
        self._n_layers = n_layers
        self._dropout = dropout

        # Sizes are read back from the attributes set by BaseEncoder
        self._model = CustomTransformer(
            n_ts=self._n_ts,
            hid_size=self._output_size,
            max_seq_len=self._seq_len,
            n_heads=self._n_heads,
            n_layers=self._n_layers,
            dropout=self._dropout,
        )

    def get_output_size(self) -> int:
        """Return the flattened output size: one embedding per time step."""
        per_step = self._output_size
        return per_step * self._seq_len

    def forward_process(self, inp: torch.FloatTensor) -> torch.FloatTensor:
        """Run an already-processed input tensor through the transformer."""
        model = self._model
        assert model is not None
        return model(inp)

    def get_metadata(self) -> dict[str, Any]:
        """Return the base metadata extended with the transformer hyper-parameters."""
        metadata = super().get_metadata()
        extra = (
            ("n_heads", self._n_heads),
            ("n_layers", self._n_layers),
            ("dropout", self._dropout),
        )
        for key, value in extra:
            metadata[key] = value
        return metadata


class PositionalEncoder(nn.Module):
    """
    Positional encoder.

    The authors of the original transformer paper describe very succinctly what
    the positional encoding layer does and why it is needed:
        "Since our model contains no recurrence and no convolution, in order for the
        model to make use of the order of the sequence, we must inject some
        information about the relative or absolute position of the tokens in the
        sequence." (Vaswani et al, 2017)

    Adapted from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(
        self,
        dropout: float = 0.1,
        max_seq_len: int = 1028,
        d_model: int = 512,
        batch_first: bool = True,
    ) -> None:
        """
        Initialise the positional encoder.

        Parameters
        ----------
        dropout : float
            The dropout rate applied after adding the encodings
        max_seq_len : int
            The maximum length of the input sequences
        d_model : int
            The dimension of the output of sub-layers in the model (may be odd)
        batch_first : bool
            Whether the batch is the first dimension, i.e. inputs are
            (batch, seq, d_model) rather than (seq, batch, d_model)
        """
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first
        # Index of the sequence dimension in the input tensor
        self.x_dim = 1 if batch_first else 0

        # Sinusoidal table, following the PyTorch tutorial:
        #   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        #   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        position = torch.arange(max_seq_len).unsqueeze(1)  # (max_seq_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-log(10000.0) / d_model))
        pe = torch.zeros(max_seq_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (1, max_seq_len, d_model) regardless of ``batch_first``;
        # forward() adapts the layout to the input.
        pe = pe.transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add the positional encoding to ``x`` and apply dropout.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, seq_len, d_model) if ``batch_first`` else
            (seq_len, batch, d_model), with ``seq_len <= max_seq_len``

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``

        Raises
        ------
        ValueError
            If the sequence is longer than the precomputed table
        """
        seq_len = x.size(self.x_dim)
        max_seq_len = self.pe.size(1)
        if seq_len > max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        # Slice along the sequence axis (dim 1 of the stored table), not the
        # size-1 leading axis, so shorter sequences broadcast correctly.
        pe = self.pe[:, :seq_len]  # (1, seq_len, d_model)
        if not self.batch_first:
            pe = pe.transpose(0, 1)  # (seq_len, 1, d_model)
        return self.dropout(x + pe)


def _filter_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Filter the kwargs."""
    for key in ("enc_tag", "input_size", "output_size", "processing_f"):
        if key in kwargs:
            del kwargs[key]
    return kwargs


class CustomTransformer(nn.Module):
    """
    Transformer that encodes a multivariate sequence into one flat vector.

    Processing steps:

    1. Each time step is projected linearly from ``n_ts`` input channels to ``hid_size``.
    2. Sinusoidal positional encodings are added.
    3. The sequence passes through a stack of ``nn.TransformerEncoderLayer`` modules.
    4. The result is flattened and batch-normalised.

    No classification token is used. The representations of all time steps are
    concatenated into the output.
    """

    def __init__(
        self,
        n_ts: int,
        hid_size: int,
        max_seq_len: int,
        n_heads: int,
        n_layers: int,
        dropout: float,
    ) -> None:
        """
        Build the transformer.

        Parameters
        ----------
        n_ts : int
            Number of input channels per time step (last input dimension)
        hid_size : int
            Model dimension (``d_model``) of the transformer
        max_seq_len : int
            Sequence length of the inputs. The final BatchNorm is sized for
            exactly ``max_seq_len * hid_size`` features, so inputs must have
            this length.
        n_heads : int
            Number of attention heads per encoder layer
        n_layers : int
            Number of stacked encoder layers
        dropout : float
            Dropout rate used in the positional encoder and the encoder layers
        """
        super().__init__()
        # Per-time-step projection: n_ts -> hid_size
        self.inp_layer = nn.Linear(
            in_features=n_ts,
            out_features=hid_size,
        )
        self.pos_enc = PositionalEncoder(
            dropout=dropout,
            max_seq_len=max_seq_len,
            d_model=hid_size,
            batch_first=True,
        )
        enc_layer = nn.TransformerEncoderLayer(
            d_model=hid_size,
            dim_feedforward=4 * hid_size,
            nhead=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.enc = nn.TransformerEncoder(
            encoder_layer=enc_layer,
            num_layers=n_layers,
        )
        # Normalises the flattened (seq_len * hid_size) output
        self.bn = nn.BatchNorm1d(num_features=max_seq_len * hid_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward through the transformer.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to encode.
            Shape: (batch_size, max_seq_len, n_ts)

        Returns
        -------
        torch.Tensor
            The flattened, batch-normalised encoding.
            Shape: (batch_size, max_seq_len * hid_size)
        """
        # Embed the inputs: (batch, seq, n_ts) -> (batch, seq, hid_size)
        x = self.inp_layer(x)
        x = self.pos_enc(x)

        # Encode
        x = self.enc(x)

        # Concatenate the per-token representations: (batch, seq * hid_size)
        x = x.flatten(1)

        # Return batch-normalised output
        return self.bn(x)
