"""SMP model class."""

from __future__ import annotations

from typing import Any

import segmentation_models_pytorch as smp
import torch
from torch import nn


class SMPModel(nn.Module):
    """
    Wrap a ``segmentation_models_pytorch`` model for time-series input.

    The wrapped model is built from a configuration dictionary. Inputs that
    carry a separate time dimension are folded into the channel dimension
    before being passed to it.
    """

    def __init__(
        self,
        smp_config: dict[str, Any],
    ) -> None:
        """
        Initialise the wrapper and build the underlying smp model.

        Parameters
        ----------
        smp_config : dict[str, Any]
            Configuration of a segmentation_models_pytorch model: a
            ``"class"`` key naming a class exported by
            ``segmentation_models_pytorch`` and the keyword arguments for
            its constructor under ``"params"``. Presumably the model's input
            channel count must equal ``CH * TS`` of the inputs given to
            ``forward`` -- confirm against the config.
        """
        super().__init__()
        self.model = _parse_smp_model(smp_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Fold time into channels and run the wrapped model.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(B, CH, TS, W, H)``.

        Returns
        -------
        torch.Tensor
            Output of the wrapped smp model applied to the
            ``(B, CH * TS, W, H)`` tensor.

        Raises
        ------
        ValueError
            If ``x`` is not 5-dimensional. Flattening dims 1-2 of any other
            rank would silently merge the wrong axes.
        """
        if x.ndim != 5:
            msg = (
                "SMPModel expects a 5D input (B, CH, TS, W, H), "
                f"got shape {tuple(x.shape)}"
            )
            raise ValueError(msg)
        # flatten the time dimension (B, CH, TS, W, H) -> (B, CH * TS, W, H)
        # so the wrapped model sees every time step as extra channels
        x = x.flatten(start_dim=1, end_dim=2)
        return self.model(x)


def _parse_smp_model(model_cfg: dict[str, Any]) -> torch.nn.Module:
    """
    Build a segmentation_models_pytorch model from a configuration dict.

    Parameters
    ----------
    model_cfg : dict[str, Any]
        Mapping with a ``"class"`` key naming an attribute of
        ``segmentation_models_pytorch`` and an optional ``"params"`` mapping
        of keyword arguments for its constructor. A missing or ``None``
        ``"params"`` entry means the constructor is called with no arguments.

    Returns
    -------
    torch.nn.Module
        The instantiated model.

    Raises
    ------
    KeyError
        If ``"class"`` is missing from ``model_cfg``.
    AttributeError
        If ``segmentation_models_pytorch`` has no attribute with that name.
    """
    class_name = model_cfg["class"]
    try:
        # getattr yields the class itself, not an instance of it
        model_class: type[torch.nn.Module] = getattr(smp, class_name)
    except AttributeError as err:
        msg = f"segmentation_models_pytorch has no model class {class_name!r}"
        raise AttributeError(msg) from err

    params = model_cfg.get("params")
    if params is None:
        # treat an absent or null entry as "no keyword arguments"
        params = {}
    return model_class(**params)
