"""Base encoder."""

from __future__ import annotations

import json
import pickle
from pathlib import Path
from typing import Any, Callable

import pandas as pd
import torch
from torch import nn


class BaseEncoder:
    """
    Base encoder.

    Wraps a torch module (``self._model``, assigned by subclasses or by ``load``)
    together with a preprocessing function that turns a dataframe into the tensor
    fed to that module. The base class itself never builds a model: until one is
    assigned, every method that needs the model fails its ``assert``.
    """

    def __init__(
        self,
        enc_tag: str,
        processing_f: Callable[[pd.DataFrame], torch.FloatTensor],
        n_ts: int,
        seq_len: int,
        output_size: int,
        **kwargs: Any,
    ):
        """
        Initialize base encoder.

        Parameters
        ----------
        enc_tag : str
            Name of the encoder, used to save it
        processing_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Processing function that puts a dataframe into an embedding of the desirable shape
        n_ts : int
            Number of time series observed
        seq_len : int
            Length of one time series
        output_size : int
            Size of the output embedding
        **kwargs
            Accepted for signature compatibility and ignored by the base class
        """
        self._tag = enc_tag
        self._n_ts = n_ts
        self._seq_len = seq_len
        self._output_size = output_size
        self._process_f = processing_f
        # Default to the GPU when one is visible; `to`/`train`/`eval`/`load` can override.
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        # No model is created here; it has to be assigned before the encoder is used.
        self._model: torch.nn.Module | None = None

    def __str__(self) -> str:
        """Representation of the encoder."""
        return (
            f"{self.__class__.__name__}("
            f"n_ts={self._n_ts}, "
            f"seq_len={self._seq_len}, "
            f"out_dim={self._output_size})"
        )

    def __repr__(self) -> str:
        """Representation of the encoder."""
        return str(self)

    def __call__(self, df: pd.DataFrame) -> torch.FloatTensor:
        """Encode the provided dataframe (preprocess it, then forward it through the model)."""
        assert self._model is not None
        return self.forward_process(self.preprocess_df(df))

    def forward_process(self, inp: torch.FloatTensor) -> torch.FloatTensor:
        """Forward the input that has been processed in advance."""
        assert self._model is not None
        # The input is moved to the encoder's device before the forward pass.
        return self._model(inp.to(self._device))

    def preprocess_df(self, df: pd.DataFrame) -> torch.FloatTensor:
        """Preprocess the provided dataframe with the encoder's processing function."""
        return self._process_f(df)

    def parameters(self) -> list[nn.parameter.Parameter]:
        """Return model parameters to train."""
        assert self._model is not None
        return list(self._model.parameters())

    def train(self, device: str | None = None) -> None:
        """Put the model in training mode and move it to the (optionally new) device."""
        assert self._model is not None
        self._model.train()
        self.to(device)

    def eval(self, device: str | None = None) -> None:
        """Put the model in evaluation mode and move it to the (optionally new) device."""
        assert self._model is not None
        self._model.eval()
        self.to(device)

    def to(self, device: str | None = None) -> None:
        """
        Put the model on the requested device.

        Parameters
        ----------
        device : str | None
            Device to move the model to; when None, the currently stored device is used
        """
        # Guard first, like every other model-dependent method, so a missing model
        # fails the same way everywhere and the stored device is left untouched.
        assert self._model is not None
        if device is not None:
            self._device = device
        self._model.to(self._device)

    def get_type(self) -> str:
        """Get the encoder's class name."""
        return self.__class__.__name__

    def get_tag(self) -> str:
        """Get the model's tag."""
        return self._tag

    def get_n_ts(self) -> int:
        """Get the number of time series."""
        return self._n_ts

    def get_seq_len(self) -> int:
        """Get the length of the sequences."""
        return self._seq_len

    def get_output_size(self) -> int:
        """Get the output size of the encoder."""
        return self._output_size

    def get_metadata(self) -> dict[str, Any]:
        """Get the model's metadata (the constructor arguments persisted by ``save``)."""
        return {
            "n_ts": self._n_ts,
            "seq_len": self._seq_len,
            "output_size": self._output_size,
        }

    @classmethod
    def load(cls, mdl_f: Path, enc_tag: str, device: str | None = None) -> BaseEncoder:
        """
        Load in the corresponding encoder.

        Parameters
        ----------
        mdl_f : Path
            Root folder; the encoder is read from ``mdl_f / "modules" / enc_tag``
        enc_tag : str
            Name of the encoder to load
        device : str | None
            Device to load the model on; when None, "cuda" if available, else "cpu"

        Returns
        -------
        BaseEncoder
            Instance of ``cls`` with its metadata, processing function and model restored
        """
        # Load the metadata
        load_path = mdl_f / "modules" / enc_tag
        with open(load_path / "enc_metadata.json", encoding="utf-8") as f:
            metadata = json.load(f)
        # SECURITY: unpickling executes arbitrary code, so only load encoder folders
        # that come from a trusted source.
        with open(load_path / "enc_processing.pkl", "rb") as f2:
            processing_f = pickle.load(f2)  # noqa: S301
        enc = cls(
            enc_tag=enc_tag,
            processing_f=processing_f,
            **metadata,
        )
        device = ("cuda" if torch.cuda.is_available() else "cpu") if (device is None) else device
        weights_path = load_path / "enc_weights"
        map_location = torch.device(device)
        try:
            # `save` stores the whole pickled module rather than a state dict, so a
            # full unpickle has to be requested explicitly (newer PyTorch releases
            # default to weights-only loading). Same trust caveat as above.
            model = torch.load(weights_path, map_location=map_location, weights_only=False)  # type: ignore[no-untyped-call]
        except TypeError:
            # PyTorch releases that predate the `weights_only` argument.
            model = torch.load(weights_path, map_location=map_location)  # type: ignore[no-untyped-call]
        enc._model = model
        # Remember where the model lives so `forward_process` sends inputs to the same device.
        enc._device = device
        return enc

    def save(self, mdl_f: Path) -> None:
        """
        Save the encoder.

        Writes the metadata (JSON), the pickled processing function and the whole
        pickled model under ``mdl_f / "modules" / <tag>``, creating the folder if needed.

        Parameters
        ----------
        mdl_f : Path
            Root folder to save the encoder under
        """
        # Save the metadata
        save_path = mdl_f / "modules" / self._tag
        save_path.mkdir(parents=True, exist_ok=True)
        with open(save_path / "enc_metadata.json", "w", encoding="utf-8") as f:
            json.dump(self.get_metadata(), f, indent=2)
        # The processing function must be picklable (e.g. a module-level function).
        with open(save_path / "enc_processing.pkl", "wb") as f2:
            pickle.dump(self._process_f, f2)
        torch.save(self._model, save_path / "enc_weights")
