"""Model class and training."""

import json
from pathlib import Path

from vito_lot_delineation.models.base_model import BaseModel
from vito_lot_delineation.models.base_trainer import BaseTrainer
from vito_lot_delineation.models.EnanchedResUnet3D import (
    EnanchedResUnet3D,
    EnanchedResUnet3DTrainer,
)
from vito_lot_delineation.models.MultiHeadResUnet3D import (
    MultiHeadResUnet3D,
    MultiHeadResUnet3DTrainer,
)
from vito_lot_delineation.models.ResUnet3D import ResUnet3D, ResUnet3DTrainer
from vito_lot_delineation.models.SMPModel import SMPModel, SMPModelTrainer

# Public API of the models package. Every model/trainer class imported above
# is re-exported here (including the MultiHeadResUnet3D pair, which the
# dispatch helpers below support), together with the public helpers.
__all__ = [
    "BaseModel",
    "BaseTrainer",
    "EnanchedResUnet3D",
    "EnanchedResUnet3DTrainer",
    "MultiHeadResUnet3D",
    "MultiHeadResUnet3DTrainer",
    "ResUnet3D",
    "ResUnet3DTrainer",
    "SMPModel",
    "SMPModelTrainer",
    "load_model",
    "parse_model",
    "parse_trainer",
]


def parse_model(model_architecture: str) -> "type[BaseModel]":
    """Return the model class that corresponds to an architecture name.

    Args:
        model_architecture: Architecture identifier. Recognised values are
            ``"ResUnet3D"``, ``"EnanchedResUnet3D"``, ``"MultiHeadResUnet3D"``,
            and any name starting with ``"smp"``.

    Returns:
        The model class itself (not an instance); callers instantiate it or
        call its class-level constructors.

    Raises:
        ValueError: If ``model_architecture`` is not a recognised name.
    """
    if model_architecture == "ResUnet3D":
        return ResUnet3D
    if model_architecture == "EnanchedResUnet3D":
        return EnanchedResUnet3D
    if model_architecture == "MultiHeadResUnet3D":
        return MultiHeadResUnet3D
    # All segmentation_models_pytorch-style variants share one wrapper class.
    if model_architecture.startswith("smp"):
        return SMPModel
    raise ValueError(f"Unknown model architecture: {model_architecture}")


def parse_trainer(model_architecture: str) -> "type[BaseTrainer]":
    """Return the trainer class that corresponds to an architecture name.

    Args:
        model_architecture: Architecture identifier. Recognised values are
            ``"ResUnet3D"``, ``"EnanchedResUnet3D"``, ``"MultiHeadResUnet3D"``,
            and any name starting with ``"smp"``.

    Returns:
        The trainer class itself (not an instance).

    Raises:
        ValueError: If ``model_architecture`` is not a recognised name.
    """
    if model_architecture == "ResUnet3D":
        return ResUnet3DTrainer
    if model_architecture == "EnanchedResUnet3D":
        return EnanchedResUnet3DTrainer
    if model_architecture == "MultiHeadResUnet3D":
        return MultiHeadResUnet3DTrainer
    # All "smp*" architectures share one trainer class.
    if model_architecture.startswith("smp"):
        return SMPModelTrainer
    raise ValueError(f"Unknown model architecture: {model_architecture}")


def load_model(model_path: Path) -> BaseModel:
    """Load a saved model of any supported architecture.

    The architecture name is read from ``model_path / "config_file.json"``
    at ``["model"]["architecture"]``. The matching class is resolved with
    :func:`parse_model`, and that class's ``load`` is called with
    ``model_path``.

    Args:
        model_path: Directory containing ``config_file.json`` and the saved
            model. A ``str`` path is also accepted and converted to ``Path``.

    Returns:
        The model returned by the selected class's ``load`` method.

    Raises:
        ValueError: If the configured architecture is not recognised.
        KeyError: If the config has no ``["model"]["architecture"]`` entry.
        FileNotFoundError: If ``config_file.json`` does not exist.
    """
    model_path = Path(model_path)
    config_path = model_path / "config_file.json"
    # JSON is UTF-8 by specification; don't depend on the locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config_file = json.load(f)
    model_architecture = config_file["model"]["architecture"]
    # Reuse the single dispatch in parse_model: unknown names now raise a
    # clear ValueError instead of leaving a local unbound.
    model_class = parse_model(model_architecture)
    return model_class.load(model_path)
