"""Concatenated encoder class."""

from __future__ import annotations

import json
import warnings
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from vito_crop_classification.model.encoders.base import BaseEncoder
from vito_crop_classification.model.encoders.loaders import load_encoder


class ConcatenatedEncoder(BaseEncoder):
    """
    Encoder that concatenates the embeddings of several sub-encoders.

    Each sub-encoder embeds its own column of the pre-processed input. The
    resulting embeddings are concatenated along dimension 1 and passed through
    a dense head (built by ``_create_model``) that maps them to the requested
    output size.
    """

    def __init__(
        self,
        enc_tag: str,
        list_encoders: list[BaseEncoder],
        hidden_dims: list[int],
        output_size: int | None = None,
        use_batch_norm: bool = False,
        dropout: float = 0.1,
        **kwargs,
    ) -> None:
        """
        Initialize the encoder concatenation model.

        Parameters
        ----------
        enc_tag : str
            Name of the encoder, used to save it
        list_encoders : list[BaseEncoder]
            List of encoders to concatenate
        hidden_dims : list[int]
            List of hidden dimensions of the dense head
        output_size : int | None
            Desired size of the output embedding, None if the concatenated output is OK
        use_batch_norm : bool
            Define either to use batch normalization or not
        dropout : float
            Dropout to use during training
        **kwargs
            Forwarded to ``BaseEncoder.__init__`` after ``_filter_kwargs`` has removed
            the arguments this class sets itself
        """
        # The head consumes the concatenation of all sub-encoder embeddings
        agg_inp_size = sum(e.get_output_size() for e in list_encoders)
        super().__init__(
            enc_tag=enc_tag,
            n_ts=len(list_encoders),
            seq_len=agg_inp_size,
            output_size=output_size if output_size else agg_inp_size,
            # Identity: each sub-encoder does its own processing in preprocess_df
            processing_f=lambda x: x,
            **_filter_kwargs(kwargs),
        )
        self._list_encoders = list_encoders
        self._hidden_dims = hidden_dims
        self._use_batch_norm = use_batch_norm
        self._dropout = dropout

        self._model = _create_model(
            inp_size=agg_inp_size,
            out_size=self._output_size,
            hidden_dims=self._hidden_dims,
            dropout=self._dropout,
            use_batch_norm=self._use_batch_norm,
        )

    def forward_process(self, inp: np.ndarray[list[torch.FloatTensor]]) -> torch.FloatTensor:
        """
        Forward input that has been pre-processed in advance by ``preprocess_df``.

        Column ``i`` of ``inp`` holds the pre-processed samples for the ``i``-th
        sub-encoder. Each column is stacked into a batch tensor, moved to this
        encoder's device and embedded by its sub-encoder. The embeddings are
        concatenated along dim 1 and fed to the dense head.
        """
        embeddings = []
        for i, enc in enumerate(self._list_encoders):
            segment = torch.stack(tuple(inp[:, i])).to(self._device)
            embeddings.append(enc.forward_process(segment))
        return self._model(torch.concat(embeddings, dim=1))

    def preprocess_df(self, df: pd.DataFrame) -> np.ndarray[list[torch.Tensor]]:
        """
        Preprocess the provided dataframe with every sub-encoder.

        Row ``j`` of the result collects, in sub-encoder order, the ``j``-th item
        returned by each sub-encoder's ``preprocess_df``. This is what
        ``forward_process`` indexes as ``inp[:, i]``.
        """
        rows: list[list[torch.Tensor]] = []
        for enc in self._list_encoders:
            processed = enc.preprocess_df(df)
            if not rows:
                rows = [[r] for r in processed]
            else:
                # NOTE(review): assumes every sub-encoder yields the same number of
                # rows in the same order as the first one — confirm.
                for i, r in enumerate(processed):
                    rows[i].append(r)
        # np.asarray on nested tensors emits a FutureWarning that is deliberately silenced
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", FutureWarning)
            result = np.asarray(rows, dtype=object)
        return result

    def parameters(self) -> list[nn.parameter.Parameter]:
        """Get the trainable parameters of all sub-encoders followed by those of the head."""
        assert self._model is not None
        params = []
        for enc in self._list_encoders:
            params += enc.parameters()
        params += self._model.parameters()
        return params

    def get_metadata(self) -> dict[str, Any]:
        """Get the model's metadata, extended with what ``load`` needs to rebuild it."""
        metadata = super().get_metadata()
        # TODO: Can't we derive this from the configuration?
        metadata["list_encoders_types"] = [enc.get_type() for enc in self._list_encoders]
        metadata["list_encoders_tags"] = [enc.get_tag() for enc in self._list_encoders]
        metadata["hidden_dims"] = self._hidden_dims
        metadata["dropout"] = self._dropout
        metadata["use_batch_norm"] = self._use_batch_norm
        return metadata

    def train(self, device: str | None = None) -> None:
        """
        Put the head and all sub-encoders in training mode, then move them to ``device``.

        If ``device`` is None, the current device is kept.
        """
        assert self._model is not None
        self._model.train()
        for enc in self._list_encoders:
            enc.train()
        self.to(device)

    def eval(self, device: str | None = None) -> None:
        """
        Put the head and all sub-encoders in evaluation mode, then move them to ``device``.

        If ``device`` is None, the current device is kept.
        """
        assert self._model is not None
        self._model.eval()
        for enc in self._list_encoders:
            enc.eval()
        self.to(device)

    def to(self, device: str | None = None) -> None:
        """Put the head and all sub-encoders on ``device`` (or keep the current device if None)."""
        if device is not None:
            self._device = device
        self._model.to(self._device)
        for enc in self._list_encoders:
            if isinstance(enc, ConcatenatedEncoder):
                # Recurse so the nested encoder's own sub-encoders and its _device follow too
                enc.to(self._device)
            else:
                enc._model.to(self._device)

    @classmethod
    def load(cls, mdl_f: Path, enc_tag: str, device: str | None = None) -> ConcatenatedEncoder:
        """
        Load a concatenated encoder previously written by ``save``.

        Parameters
        ----------
        mdl_f : Path
            Model folder; the encoder lives under ``mdl_f / "modules" / enc_tag``
        enc_tag : str
            Tag of the concatenated encoder to load
        device : str | None
            Device to load on; defaults to "cuda" when available, else "cpu"
        """
        device = ("cuda" if torch.cuda.is_available() else "cpu") if (device is None) else device

        # Load metadata
        load_path = mdl_f / "modules" / enc_tag
        with open(load_path / "enc_metadata.json", "r") as f:
            metadata = json.load(f)

        # Load sub-encoders. Use distinct loop names so the enc_tag parameter is not shadowed.
        list_encoders = []
        for sub_type, sub_tag in zip(
            metadata["list_encoders_types"], metadata["list_encoders_tags"]
        ):
            if "concat" in sub_tag:
                list_encoders.append(
                    ConcatenatedEncoder.load(mdl_f, enc_tag=sub_tag, device=device)
                )
            else:
                list_encoders.append(
                    load_encoder(mdl_f, enc_type=sub_type, enc_tag=sub_tag, device=device)
                )

        # Initialize Concatenate encoder
        concatenator = cls(
            enc_tag=enc_tag,
            list_encoders=list_encoders,
            output_size=metadata["output_size"],
            hidden_dims=metadata["hidden_dims"],
            dropout=metadata["dropout"],
            use_batch_norm=metadata["use_batch_norm"],
        )

        # Load weights.
        # NOTE(review): this unpickles a full nn.Module, so only load trusted model folders.
        # Recent torch versions may also require weights_only=False here.
        concatenator._model = torch.load(load_path / "enc_weights", map_location=torch.device(device))  # type: ignore[no-untyped-call]
        return concatenator

    def save(self, mdl_f: Path) -> None:
        """Save all sub-encoders, then this encoder's metadata and head under ``mdl_f / "modules" / tag``."""
        # save all sub-encoders
        for enc in self._list_encoders:
            enc.save(mdl_f)

        # save metadata
        save_path = mdl_f / "modules" / self._tag
        save_path.mkdir(parents=True, exist_ok=True)
        with open(save_path / "enc_metadata.json", "w") as f:
            json.dump(self.get_metadata(), f, indent=2)

        # Save the model weights (the whole head module is pickled)
        torch.save(self._model, save_path / "enc_weights")


def _filter_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """
    Drop keyword arguments that ``ConcatenatedEncoder`` determines itself.

    ``ConcatenatedEncoder.__init__`` passes ``enc_tag``, ``n_ts``, ``seq_len``,
    ``output_size`` and ``processing_f`` to ``BaseEncoder`` explicitly. Forwarding
    them again through ``**kwargs`` would raise ``TypeError`` (multiple values for
    a keyword argument). ``input_size`` is dropped as well, as before.

    Parameters
    ----------
    kwargs : dict[str, Any]
        Extra keyword arguments; this dict is not modified

    Returns
    -------
    dict[str, Any]
        A new dict holding only the remaining keyword arguments
    """
    reserved = {"enc_tag", "n_ts", "seq_len", "input_size", "output_size", "processing_f"}
    return {key: value for key, value in kwargs.items() if key not in reserved}


def _create_model(
    inp_size: int, out_size: int, hidden_dims: list[int], dropout: float, use_batch_norm: bool
) -> torch.nn.Module:
    """
    Build the dense head that maps the concatenated embedding to ``out_size``.

    If ``inp_size`` already equals ``out_size``, the head is a pass-through
    (a Sequential wrapping ``nn.Identity``) and ``hidden_dims`` is ignored.
    Otherwise the head chains one block per size transition
    ``inp_size -> *hidden_dims -> out_size``. Each block is
    ``Linear -> ReLU -> Dropout``, followed by ``BatchNorm1d`` when
    ``use_batch_norm`` is set. The final block ends with ReLU and Dropout too.
    """
    if inp_size == out_size:
        return nn.Sequential(nn.Identity())

    sizes = [inp_size, *hidden_dims, out_size]
    layers: list[nn.Module] = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        block = [nn.Linear(in_features=n_in, out_features=n_out), nn.ReLU(), nn.Dropout(dropout)]
        if use_batch_norm:
            block.append(nn.BatchNorm1d(n_out))
        layers.extend(block)
    return nn.Sequential(*layers)
