"""Model creation functionality."""

from __future__ import annotations

from typing import Callable

import torch
from torch import nn


def parse_activation(
    activation: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build the activation module named by ``activation``.

    Args:
        activation: One of ``"relu"``, ``"leaky_relu"``, ``"elu"``,
            ``"gelu"``, ``"selu"``, ``"silu"``, ``"tanh"``, ``"sigmoid"``
            or ``"none"``. Matching is exact and case-sensitive;
            ``"none"`` maps to an identity module.

    Returns:
        A newly constructed ``torch.nn`` module built with its default
        constructor arguments.

    Raises:
        ValueError: If ``activation`` is not one of the names above.
    """
    # Ordered (name, module class) pairs. Names are compared with ``==`` in
    # this order, and only the matching class is instantiated, so every call
    # hands back a fresh module.
    known_activations = (
        ("relu", nn.ReLU),
        ("leaky_relu", nn.LeakyReLU),
        ("elu", nn.ELU),
        ("gelu", nn.GELU),
        ("selu", nn.SELU),
        ("silu", nn.SiLU),
        ("tanh", nn.Tanh),
        ("sigmoid", nn.Sigmoid),
        ("none", nn.Identity),
    )
    for name, module_cls in known_activations:
        if activation == name:
            return module_cls()
    raise ValueError(f"Unknown activation function: {activation}")
