"""ResUNet3d configuration."""

from typing import List

from pydantic import Field

from vito_cropsar.models.base_config import ModelBaseConfig, TrainerBaseConfig


class ModelConfig(ModelBaseConfig):
    """ResUNet3d model configuration.

    Extends ``ModelBaseConfig`` with the hyper-parameters declared below. Fields are
    grouped by prefix: ``enc_*`` (spatial encoder), ``bn_*`` (bottleneck) and
    ``dec_*`` (spatial decoder). Every field has a default, so none of them is
    required by this class itself.
    """

    # Channel widths, one entry per layer.
    channels: List[int] = Field(  # noqa: UP006
        [64, 128, 256, 512],
        # Adjacent string literals are concatenated verbatim, so every fragment
        # except the last must end with an explicit space.
        description=(
            "Number of channels per layer in the UNet3d. "
            "Note: The first channel is used for input and output mapping. "
            "Note: The last channel is used for the bottleneck, "
            "with channels[-1] * bn_fraction as bottleneck channels "
            "(channels[-1] // 4 by default)."
        ),
        example=[64, 128, 256],
    )
    depth_spatial: int = Field(
        3,  # 128x128 to 16x16 (patches of 8x8)
        description="Number of spatial convolution layers in the UNet3d",
        example=3,
    )
    depth_temporal: int = Field(
        4,  # 32 to 2 (complete unit representation of full stack)
        description="Number of temporal convolution layers in the UNet3d",
        example=5,
    )

    # --- Spatial encoder ---
    enc_channels_p: int = Field(
        16,
        description="Number of input processing channels (spatial encoder)",
        example=32,
    )
    enc_dropout: float = Field(
        0.1,
        description="Dropout rate (spatial encoder)",
        example=0.1,
    )
    enc_activation: str = Field(
        "relu",
        description="Activation function to use (spatial encoder)",
        example="relu",
    )

    # --- Bottleneck ---
    bn_fraction: float = Field(
        0.25,
        description="Number of channels to use in the bottleneck, expressed as a fraction from channels[-1]",
        example=0.25,
    )
    bn_dropout: float = Field(
        0.1,
        description="Dropout rate (bottleneck)",
        example=0.1,
    )
    bn_activation: str = Field(
        "relu",
        description="Activation function to use (bottleneck)",
        example="relu",
    )

    # --- Spatial decoder ---
    dec_channels_p: int = Field(
        32,
        description="Number of input processing channels (spatial decoder)",
        example=32,
    )
    dec_dropout: float = Field(
        0.1,
        description="Dropout rate (spatial decoder)",
        example=0.1,
    )
    dec_activation: str = Field(
        "relu",
        description="Activation function to use (spatial decoder)",
        example="relu",
    )


class TrainerConfig(TrainerBaseConfig):
    """Trainer configuration for the ResUNet3d model.

    Currently a placeholder: it declares no fields of its own and relies entirely
    on what ``TrainerBaseConfig`` provides.
    """


if __name__ == "__main__":
    # Manual smoke check: build a configuration with a pass-through scaler and
    # print it. The import stays local so it is only needed when run as a script.
    from vito_cropsar.data import ScalerNone

    band_selection = {
        "bands_s1": ["s1_vv", "s1_vh"],
        "bands_s2": ["s2_fapar"],
    }
    identity_scaler = ScalerNone(band_selection)

    # NOTE(review): `tag`, `scaler`, `sample_s1` and `optional` are not declared in
    # this file; presumably they are handled by ModelBaseConfig — verify there.
    cfg = ModelConfig(
        tag="test",
        scaler=identity_scaler,
        sample_s1=0,
        optional="test",
    )
    print(cfg)
