"""Loading functions for encoders."""

from __future__ import annotations

from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Any

import torch

from vito_crop_classification.model.encoders.base import BaseEncoder
from vito_crop_classification.model.encoders.dense import DenseEncoder
from vito_crop_classification.model.encoders.identity import IdentityEncoder
from vito_crop_classification.model.encoders.rnn import RnnEncoder
from vito_crop_classification.model.encoders.transformer import TransformerEncoder
from vito_crop_classification.model.encoders.utils import get_extract_function


def load_encoder(
    mdl_f: Path,
    enc_type: str,
    enc_tag: str,
    device: str | None = None,
) -> BaseEncoder:
    """Load the specified pre-trained encoder.

    Dispatches to the ``load`` classmethod of the encoder class named by
    ``enc_type``. ``mdl_f``, ``enc_tag`` and the resolved ``device`` are
    forwarded to it.

    Args:
        mdl_f: Path forwarded to ``<EncoderClass>.load`` (presumably the
            model folder the encoder was saved in — confirm against the
            encoder classes).
        enc_type: Class name of the encoder to load. One of
            ``"DenseEncoder"``, ``"RnnEncoder"``, ``"IdentityEncoder"`` or
            ``"TransformerEncoder"``.
        enc_tag: Tag forwarded to ``load`` as ``enc_tag``.
        device: Torch device string. When ``None``, ``"cuda"`` is used if
            ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.

    Returns:
        The encoder instance returned by ``<EncoderClass>.load``.

    Raises:
        ValueError: If ``enc_type`` is not a supported encoder name.
            ``ValueError`` subclasses ``Exception``, so callers that catch
            ``Exception`` are unaffected.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Map each supported type name to the class that knows how to load it.
    loaders = {
        "DenseEncoder": DenseEncoder,
        "RnnEncoder": RnnEncoder,
        "IdentityEncoder": IdentityEncoder,
        "TransformerEncoder": TransformerEncoder,
    }
    encoder_cls = loaders.get(enc_type)
    if encoder_cls is None:
        raise ValueError(f"Cannot parse encoder type: '{enc_type}'!")
    return encoder_cls.load(mdl_f, enc_tag=enc_tag, device=device)


def parse_encoder(cfg_enc: dict[str, Any]) -> BaseEncoder | None:
    """Build an encoder instance from its configuration dictionary.

    ``cfg_enc`` must contain a ``"type"`` (the encoder class name) and a
    ``"params"`` mapping. The params are deep-copied, so ``cfg_enc`` itself is
    not modified.

    * ``"ConcatenatedEncoder"``: ``params["list_encoders"]`` holds nested
      encoder configurations. Each one is parsed recursively, and the
      resulting list is passed to ``ConcatenatedEncoder``.
    * Any other type: ``params["processing_f"]`` is a
      ``{"type": ..., "params": {...}}`` specification. It is replaced by a
      ``functools.partial`` of the extract function for that type. ``n_ts``
      is set to the number of entries in its ``"cols"`` parameter. The
      remaining params are passed to the matching encoder class.

    Returns:
        The constructed encoder, or ``None`` if ``"type"`` names an
        unrecognised simple encoder.
    """
    # Local import; presumably avoids a circular import with the concatenate module.
    from vito_crop_classification.model.encoders.concatenate import ConcatenatedEncoder

    params = deepcopy(cfg_enc["params"])
    kind = cfg_enc["type"]

    # Composite encoder: build every sub-encoder first, then wrap them.
    if kind == "ConcatenatedEncoder":
        params["list_encoders"] = [parse_encoder(sub_cfg) for sub_cfg in params["list_encoders"]]
        return ConcatenatedEncoder(**params)

    # Simple encoder: resolve the processing function before selecting the class.
    # This happens even for an unknown type, matching the original ordering.
    proc_spec = params["processing_f"]
    extractor = get_extract_function(proc_spec["type"])
    proc_kwargs = deepcopy(proc_spec["params"])
    params["n_ts"] = len(proc_kwargs["cols"])
    params["processing_f"] = partial(extractor, **proc_kwargs)

    simple_encoders = {
        "IdentityEncoder": IdentityEncoder,
        "DenseEncoder": DenseEncoder,
        "RnnEncoder": RnnEncoder,
        "TransformerEncoder": TransformerEncoder,
    }
    encoder_cls = simple_encoders.get(kind)
    if encoder_cls is None:
        return None
    return encoder_cls(**params)
