"""CnnTransformer configuration."""

from typing import Dict, List

from pydantic import BaseModel, Field

from vito_cropsar.models.base_config import (
    AugmentationConfig,
    LossConfig,
    ModelBaseConfig,
    TrainerBaseConfig,
)


class ModelConfig(ModelBaseConfig):
    """Hyper-parameters of the CnnTransformer model.

    Fields are declared per sub-network (declaration order is kept stable):

    * ``enc_*``     -- spatial encoder,
    * ``temp_*``    -- temporal layer,
    * ``dec_*``     -- spatial decoder,
    * ``sha_*``     -- sharpening UNet,
    * ``edge_repr`` -- extra assistance input for the decoder.

    Each field also declares an ``example`` value for schema purposes; the
    examples do not always match the defaults.
    """

    # ---- Spatial encoder ---------------------------------------------------
    enc_channels_p: int = Field(
        default=32,
        example=32,
        description="Number of input processing channels (spatial encoder)",
    )
    enc_channels_resnet: List[int] = Field(  # noqa: UP006
        default=[96, 192, 384],
        example=[64, 128, 512],
        description="Number of channels in each spatial layer (spatial encoder)",
    )
    enc_dropout: float = Field(
        default=0.1,
        example=0.1,
        description="Dropout rate (spatial encoder)",
    )
    enc_activation: str = Field(
        default="relu",
        example="relu",
        description="Activation function to use (spatial encoder)",
    )

    # ---- Temporal layer ----------------------------------------------------
    temp_layers: int = Field(
        default=4,
        example=4,
        description="Number of layers (temporal layer)",
    )
    temp_heads: int = Field(
        default=8,
        example=8,
        description="Number of heads (temporal layer)",
    )
    temp_dropout: float = Field(
        default=0.1,
        example=0.1,
        description="Dropout rate (temporal layer)",
    )
    temp_activation: str = Field(
        default="gelu",
        example="relu",
        description="Activation function to use (temporal layer)",
    )

    # ---- Spatial decoder ---------------------------------------------------
    dec_channels_p: int = Field(
        default=32,
        example=32,
        description="Number of output processing channels (spatial decoder)",
    )
    dec_channels_resnet: List[int] = Field(  # noqa: UP006
        default=[384, 192, 96],
        example=[512, 256, 128],
        description="Number of channels in each spatial layer (spatial decoder)",
    )
    dec_dropout: float = Field(
        default=0.0,
        example=0.0,
        description="Dropout rate (spatial decoder)",
    )
    dec_activation: str = Field(
        default="leaky_relu",
        example="relu",
        description="Activation function to use (spatial decoder)",
    )

    # ---- Sharpening UNet ---------------------------------------------------
    sha_activation: str = Field(
        default="leaky_relu",
        example="relu",
        description="Activation function to use in the sharpening UNet",
    )

    # ---- Decoder assistance input -------------------------------------------
    edge_repr: bool = Field(
        default=True,
        example=False,
        description="Whether to use an edge representative as assistance input in the decoder",
    )


class TrainerConfig(TrainerBaseConfig):
    """Trainer configuration of the CnnTransformer model.

    Currently identical to ``TrainerBaseConfig``: no fields are added or
    overridden here. It serves as the value type of
    ``MultiTrainerConfig.trainers``.
    """


class MultiTrainerConfig(BaseModel):
    """Config containing multiple named trainer configs.

    ``trainers`` maps a stage name to the ``TrainerConfig`` used for that
    stage. The defaults define three stages (``"autoencoder"``,
    ``"transformer"`` and ``"final"``). Each stage overrides only selected
    settings: loss weights, augmentation, learning rate and/or the
    early-stopping improvement ratio. Everything else falls back to the
    ``TrainerConfig`` defaults.

    NOTE(review): the stages are presumably trained in dict insertion order;
    confirm against the training driver.
    """

    trainers: Dict[str, TrainerConfig] = Field(  # noqa: UP006
        {
            # Time-agnostic stage: time-related augmentation (gaps, NRT) is
            # disabled and the time-regularisation loss terms are zeroed.
            "autoencoder": TrainerConfig(
                loss_cfg=LossConfig(  # Only change the non-default values
                    weights={
                        "mae_masked": 0.1,  # Secondary objective
                        "mae_unmasked": 1.0,
                        "ssim": 1.0,
                        "mae_time_regularization": 0.0,
                        "ssim_time_regularization": 0.0,
                    },
                ),
                augm_cfg=AugmentationConfig(  # Only change the non-default values
                    p_gap=0.0,  # Time is irrelevant for this network
                    r_cloud=0.5,
                    nrt_max=0,  # Time is irrelevant for this network
                ),
                es_improvement_ratio=0.001,
            ),
            # Temporal stage: masked and unmasked MAE weighted equally, SSIM
            # halved, small time-regularisation terms switched on.
            "transformer": TrainerConfig(
                loss_cfg=LossConfig(  # Only change the non-default values
                    weights={
                        "mae_masked": 1.0,
                        "mae_unmasked": 1.0,
                        "ssim": 0.5,  # Info on structure still important to propagate
                        "mae_time_regularization": 0.01,  # Secondary objective
                        "ssim_time_regularization": 0.01,  # Secondary objective
                    },
                ),
                es_improvement_ratio=0.0005,
            ),
            # NOTE: a "decoder" stage is disabled (commented out) and is not
            # part of the default configuration.
            # "decoder": TrainerConfig(
            #     loss_cfg=LossConfig(  # Only change the non-default values
            #         weights={
            #             "mae_masked": 1.0,
            #             "mae_unmasked": 1.0,
            #             "ssim": 1.0,
            #             "mae_time_regularization": 0.05,  # Secondary objective
            #             "ssim_time_regularization": 0.05,  # Secondary objective
            #         },
            #     ),
            #     lr=5e-5,
            #     es_improvement_ratio=0.0,
            # ),
            # Final stage: all main loss terms at full weight, with an explicit
            # small learning rate (1e-5).
            "final": TrainerConfig(
                loss_cfg=LossConfig(  # Only change the non-default values
                    weights={
                        "mae_masked": 1.0,
                        "mae_unmasked": 1.0,
                        "ssim": 1.0,
                        "mae_time_regularization": 0.01,  # Secondary objective
                        "ssim_time_regularization": 0.01,  # Secondary objective
                    },
                ),
                lr=1e-5,
                es_improvement_ratio=0.0,
            ),
        },
        description="Trainer configurations",
        example={"example": TrainerConfig()},
    )

    class Config:
        """Pydantic model settings of the MultiTrainerConfig class."""

        # Accept (and keep) fields that are not declared on the model.
        extra = "allow"
        # Permit field types for which pydantic has no built-in validator.
        arbitrary_types_allowed = True


if __name__ == "__main__":
    # Manual smoke test: build a ModelConfig around a ScalerNone scaler and
    # print the resulting configuration.
    from vito_cropsar.data import ScalerNone

    band_mapping = {
        "bands_s1": ["s1_vv", "s1_vh"],
        "bands_s2": ["s2_fapar"],
    }
    test_scaler = ScalerNone(band_mapping)
    # NOTE(review): `optional` is not declared on ModelConfig in this file;
    # presumably it exercises extra-field handling of ModelBaseConfig.
    cfg = ModelConfig(
        tag="test",
        scaler=test_scaler,
        sample_s1=0,
        optional="test",
    )
    print(cfg)
