"""Model class."""

from __future__ import annotations

import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from torch.nn.parameter import Parameter

from vito_crop_classification.constants import get_models_folder
from vito_crop_classification.model.classifiers import (
    BaseClassifier,
    load_classifier,
    parse_classifier,
)
from vito_crop_classification.model.encoders import BaseEncoder, ConcatenatedEncoder, load_encoder
from vito_crop_classification.model.encoders.loaders import parse_encoder
from vito_crop_classification.vito_logger import Logger


class Model:
    """
    Classification model made of an encoder followed by a classifier.

    The wrapper keeps the model configuration, the class IDs/names and the
    scaling configuration together with the encoder/classifier modules, and
    persists all of them inside ``model_folder``.
    """

    def __init__(
        self,
        model_tag: str,
        config_file: dict[str, Any],
        class_ids: NDArray[np.str_],
        class_names: NDArray[np.str_],
        scale_cfg: dict[str, tuple[float, float]],
        mdl_f: Path | None = None,
        device: str | None = None,
    ) -> None:
        """
        Initialise the model wrapper.

        The encoder and classifier are not created here: use ``build_model``
        to create them from the configuration, or ``Model.load`` to restore
        them from disk. The model folder ``<mdl_f>/<model_tag>`` is created on
        construction and a model-specific logger writes into it.

        Parameters
        ----------
        model_tag : str
            Name of the model, also used as the name of the model folder
        config_file : dict[str,Any]
            Model configuration, specifying model creation
        class_ids : NDArray[np.str_]
            List of class IDs used for classification
        class_names : NDArray[np.str_]
            List of class names used for classification
        scale_cfg : dict[str, tuple[float, float]]
            Scaling configuration applied on the transformed DataFrame
        mdl_f : Path | None
            Parent folder in which the model folder is created; defaults to
            ``get_models_folder()``
        device : str | None
            Device to run the model on; defaults to "cuda" when available,
            otherwise "cpu"
        """
        self.tag: str = model_tag
        self.device: str = self._resolve_device(device)
        self.enc: BaseEncoder | None = None
        self.clf: BaseClassifier | None = None
        self.model_folder: Path = (mdl_f or get_models_folder()) / self.tag
        self._config_file: dict[str, Any] = config_file
        self._class_ids: NDArray[np.str_] = class_ids
        self._class_names: NDArray[np.str_] = class_names
        self._scale_cfg: dict[str, tuple[float, float]] = scale_cfg

        # create model folder
        self.model_folder.mkdir(exist_ok=True, parents=True)

        # Create model specific logger
        self.logger = Logger(log_f=self.model_folder)

    @staticmethod
    def _resolve_device(device: str | None) -> str:
        """Return ``device`` if given, else "cuda" when available, else "cpu"."""
        return device or ("cuda" if torch.cuda.is_available() else "cpu")

    def __str__(self) -> str:
        """Representation of the model."""
        return f"Model(enc={self.enc}, clf={self.clf})"

    def __repr__(self) -> str:
        """Representation of the model."""
        return str(self)

    def __call__(self, df: pd.DataFrame) -> torch.Tensor:
        """Preprocess ``df`` with the encoder, then run the full forward pass."""
        self._check_initialised()
        df_t = self.enc.preprocess_df(df)
        return self.forward_process(df_t)

    def forward_process(self, inp: torch.FloatTensor | NDArray) -> torch.Tensor:
        """Run the encoder and then the classifier on already-preprocessed input."""
        self._check_initialised()
        enc = self.enc.forward_process(inp)
        return self.clf.forward_process(enc)

    def train(self) -> None:
        """Set enc and clf to train mode on ``self.device``."""
        self._check_initialised()
        self.enc.train(self.device)
        self.clf.train(self.device)

    def eval(self) -> None:
        """Set enc and clf to eval mode on ``self.device``."""
        self._check_initialised()
        self.enc.eval(self.device)
        self.clf.eval(self.device)

    def get_scale_cfg(self) -> dict[str, tuple[float, float]]:
        """Get the model specific scaling configuration."""
        return self._scale_cfg

    def parameters(self) -> list[Parameter]:
        """Get the encoder parameters followed by the classifier parameters."""
        self._check_initialised()
        return list(self.enc.parameters()) + list(self.clf.parameters())

    def build_model(self) -> None:
        """
        Build the encoder and classifier from the configuration file and save them.

        Raises
        ------
        AssertionError
            If the configured number of classes differs from the number of
            provided class IDs.
        """
        self.logger(f"Building model '{self.tag}'..")

        # parse encoder
        self.enc = parse_encoder(self._config_file["encoder"])

        # parse classifier; its output dimension must match the provided classes
        assert self._config_file["classifier"]["n_classes"] == len(self._class_ids), (
            f"Configuration has dimension {self._config_file['classifier']['n_classes']}, "
            f"where {len(self._class_ids)} classes are provided!"
        )
        self.clf = parse_classifier(
            cfg_clf=self._config_file["classifier"],
            input_size=self.enc.get_output_size(),
            classes=self._class_ids,
        )
        self.save()

    def get_class_ids(self) -> NDArray[np.str_]:
        """Get the class IDs the model is trained on."""
        return self._class_ids

    def get_class_names(self) -> NDArray[np.str_]:
        """Get the class names the model is trained on."""
        return self._class_names

    def get_class_mapping(self, id2name: bool = True) -> dict[str, str]:
        """
        Get the mapping between class IDs and class names.

        Parameters
        ----------
        id2name : bool
            If True map ID -> name, otherwise map name -> ID. When mapping
            name -> ID, duplicate names collapse onto the last matching ID.

        Returns
        -------
        dict[str, str]
            The requested mapping.
        """
        assert self._class_ids.shape == self._class_names.shape
        return {
            (k if id2name else v): (v if id2name else k)
            for k, v in zip(self._class_ids, self._class_names)
        }

    def save(self) -> None:
        """
        Save the encoder, classifier, configuration and metadata to ``model_folder``.

        Writes ``config_file.json`` and ``mdl_metadata.json``. Tuples in the
        scaling configuration are stored as JSON lists.
        """
        self._check_initialised()

        # Save modules
        self.enc.save(self.model_folder)
        self.clf.save(self.model_folder)

        # Build the metadata before opening the file so a failure here does not
        # leave a truncated JSON behind. `.tolist()` yields native Python types,
        # which json can serialise regardless of the array dtype.
        metadata = {
            "class_ids": np.asarray(self._class_ids).tolist(),
            "class_names": np.asarray(self._class_names).tolist(),
            "scale_cfg": self._scale_cfg,
            "n_params": _count_parameters(self.parameters()),
        }

        # Save model additional variables
        with open(self.model_folder / "config_file.json", "w", encoding="utf-8") as f:
            json.dump(self._config_file, f, indent=2)
        with open(self.model_folder / "mdl_metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

    @classmethod
    def load(cls, mdl_f: Path, device: str | None = None) -> Model:
        """
        Load a model previously stored with ``save``.

        Parameters
        ----------
        mdl_f : Path
            The model folder itself (i.e. ``<parent>/<model_tag>``)
        device : str | None
            Device to load the model on; defaults to "cuda" when available,
            otherwise "cpu"

        Returns
        -------
        Model
            The loaded model, set to eval mode on ``device``.

        Raises
        ------
        FileNotFoundError
            If no classifier module (``modules/clf_*``) is present.
        """
        # get metadata
        with open(mdl_f / "mdl_metadata.json", encoding="utf-8") as f:
            metadata = json.load(f)

        # load config_file
        with open(mdl_f / "config_file.json", encoding="utf-8") as f:
            cfg_file = json.load(f)

        # load encoder
        device = cls._resolve_device(device)
        enc_cfg = cfg_file["encoder"]
        if enc_cfg["type"] == "ConcatenatedEncoder":
            enc = ConcatenatedEncoder.load(
                mdl_f, enc_tag=enc_cfg["params"]["enc_tag"], device=device
            )
        else:
            enc = load_encoder(
                mdl_f,
                enc_type=enc_cfg["type"],
                enc_tag=enc_cfg["params"]["enc_tag"],
                device=device,
            )

        # load classifier; sort so the choice is deterministic if several exist
        clf_paths = sorted((mdl_f / "modules").glob("clf_*"))
        if not clf_paths:
            raise FileNotFoundError(f"No classifier ('clf_*') found in {mdl_f / 'modules'}!")
        clf = load_classifier(
            mdl_f,
            clf_type=cfg_file["classifier"]["type"],
            clf_tag=clf_paths[0].name,
            device=device,
        )

        # build model; pass the device so eval() keeps the modules on it
        model = cls(
            model_tag=mdl_f.name,
            config_file=cfg_file,
            class_ids=np.asarray(metadata["class_ids"]),
            class_names=np.asarray(metadata["class_names"]),
            scale_cfg=metadata["scale_cfg"],
            mdl_f=mdl_f.parent,
            device=device,
        )
        model.enc = enc
        model.clf = clf
        model.eval()
        return model

    def _check_initialised(self) -> None:
        """Assert that both the encoder and the classifier are set."""
        assert self.enc is not None, "Encoder not found!"
        assert self.clf is not None, "Classifier not found!"


def _count_parameters(params: list[Parameter] | Parameter) -> int:
    """Count the parameters."""
    if isinstance(params, list):
        return sum(_count_parameters(p) for p in params)
    return len(params.flatten())


if __name__ == "__main__":
    # Demo: build a CnnEncoder + DenseClassifier model with dummy classes.
    my_cfg: dict[str, Any] = {
        "encoder": {
            "type": "CnnEncoder",
            "params": {
                "enc_tag": "enc_cnn_test",
                "processing_f": {
                    "type": "extract_ts",
                    "params": {"cols": ["ts_ndvi", "ts_ndwi", "ts_ndre2"]},
                },
                "n_ts": 3,
                "seq_len": 18,
                "output_size": 64,
                "hidden_dims": [16, 32, 64],
                # NOTE(review): kernel_dims has 4 entries while hidden_dims has 3;
                # confirm this is what CnnEncoder expects.
                "kernel_dims": [3, 5, 3, 5],
                "dropout": 0.1,
            },
        },
        "classifier": {
            "type": "DenseClassifier",
            "n_classes": 19,
            "params": {"clf_tag": "clf_dense_test"},
        },
        "train_type": "classification",
        "loss": "cross_entropy",
    }

    # Dummy classes: IDs and names are both "0" .. "n_classes - 1".
    dummy_classes = np.asarray([f"{i}" for i in range(my_cfg["classifier"]["n_classes"])])
    my_mdl = Model(
        model_tag="test_model",
        config_file=my_cfg,
        class_ids=dummy_classes,
        class_names=dummy_classes.copy(),
        scale_cfg={},
    )
    # build_model() already persists the freshly built model via save().
    my_mdl.build_model()
