"""Model and trainer configuration."""

# Kept disabled: `from __future__ import annotations` doesn't work with Pydantic
# on Python 3.8.
# from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Optional, Union

from pydantic import BaseModel, Field, confloat, conint, root_validator, validator

from vito_cropsar.constants import get_data_folder, get_models_folder
from vito_cropsar.data import Scaler


class ModelBaseConfig(BaseModel):
    """Model configuration.

    Unknown fields are accepted (``extra = "allow"``). The ``scaler`` field holds a
    :class:`Scaler` instance and is serialised by hand in :meth:`json`.
    """

    tag: str = Field(
        ...,
        description="Model to use",
        example="cnn_transformer",
    )
    folder: Path = Field(
        get_models_folder(),
        description=(
            # Adjacent literals are concatenated verbatim, so the separator must be explicit.
            "Folder where the model gets stored. "
            "Note: This folder is not the one where your binaries are stored, this is 'folder / tag' instead"
        ),
        example=Path.cwd(),
    )
    scaler: Scaler = Field(
        ...,
        description="Data scaler applied to the model's input",
        example=None,  # Not possible to show an example here
    )
    n_ts: int = Field(
        32,
        gt=0,
        le=48,
        description="Temporal resolution of the data (number of time steps)",
        example=32,
    )
    resolution: int = Field(
        128,
        gt=0,
        le=256,
        description="Spatial resolution of the data",
        example=128,
    )
    n_channels_in: int = Field(
        6,
        description="Number of input channels",
        example=3,
    )
    n_channels_out: int = Field(
        4,
        description="Number of output channels (to predict)",
        example=1,
    )
    fill_nan: int = Field(
        0,
        description="Value to fill NaNs with",
        example=0,
    )
    align_input: bool = Field(
        False,
        description="Whether to align the input data to the output data (default False since it might generate blurry results)",
        example=False,
    )
    smooth_s1: bool = Field(
        True,
        description="Whether to smooth the S1 data (Speckle filters)",
        example=True,
    )

    @validator("scaler")
    def check_scaler(cls, v: Scaler) -> Scaler:  # noqa: N805
        """Check that the scaler is a :class:`Scaler` instance.

        Raises:
            TypeError: if ``v`` is not a ``Scaler`` (pydantic reports it as a validation error).
        """
        # Explicit raise instead of ``assert``: assertions are stripped under ``python -O``.
        if not isinstance(v, Scaler):
            raise TypeError("Scaler must be of type Scaler")
        return v

    @property
    def model_folder(self) -> Path:
        """Return the model folder (``folder / tag``)."""
        return self.folder / self.tag

    class Config:
        """Configuration of the ModelBaseConfig class."""

        extra = "allow"
        arbitrary_types_allowed = True

    def json(self, *args, **kwargs):
        """Convert the object to a JSON string.

        The scaler is left out of pydantic's own serialisation. It is re-added as a
        dict holding only its ``bands_s1``, ``bands_s2`` and ``sample_s1`` attributes.

        Keyword arguments are forwarded to pydantic's ``json()``:

        - A caller-supplied ``exclude`` is merged with the internal scaler exclusion.
        - Options that pydantic hands on to ``json.dumps`` (e.g. ``indent``) are also
          applied to the final output.
        - If the caller excludes ``scaler``, or passes an ``include`` without it, the
          scaler is omitted.
        """
        # Keywords consumed by pydantic's json() itself; everything else is a json.dumps option.
        pydantic_args = frozenset(
            {
                "include",
                "exclude",
                "by_alias",
                "skip_defaults",
                "exclude_unset",
                "exclude_defaults",
                "exclude_none",
                "encoder",
                "models_as_dict",
            }
        )
        dumps_kwargs = {k: v for k, v in kwargs.items() if k not in pydantic_args}

        exclude = kwargs.pop("exclude", None)
        include = kwargs.get("include")
        keep_scaler = (exclude is None or "scaler" not in exclude) and (
            include is None or "scaler" in include
        )
        # pydantic accepts exclude as a set of names or a (nested) mapping.
        if isinstance(exclude, dict):
            exclude = {**exclude, "scaler": ...}
        else:
            exclude = set(exclude or ()) | {"scaler"}

        x = json.loads(super().json(*args, exclude=exclude, **kwargs))
        if keep_scaler:
            x["scaler"] = {
                "bands_s1": self.scaler.bands_s1,
                "bands_s2": self.scaler.bands_s2,
                "sample_s1": self.scaler.sample_s1,
            }
        return json.dumps(x, **dumps_kwargs)

    def dict(self, *args, **kwargs):
        """Convert the object to a JSON-compatible dictionary.

        Round-trips through :meth:`json`, so values such as ``Path`` come back in
        their JSON form.
        """
        return json.loads(self.json(*args, **kwargs))


class LossConfig(BaseModel):
    """Loss configuration.

    ``weights`` maps a loss name to its (non-negative) contribution. The ``use_*``
    flags toggle the individual loss terms.
    """

    weights: Dict[str, float] = Field(  # noqa: UP006
        {
            "mae_masked": 1.0,
            "mae_unmasked": 1.0,
            "ssim": 0.5,
            "mae_time_regularization": 0.2,
            "ssim_time_regularization": 0.1,
        },
        description="Weighted contribution of each loss in the final backprop calculation (only computed if corresponding use_* is True)",
        example={"mae_masked": 0.42},
    )
    use_mae: bool = Field(
        True,
        description="Whether to use MAE loss",
        example=True,
    )
    use_ssim: bool = Field(
        True,
        description="Whether to use SSIM loss",
        example=True,
    )
    use_mae_time_reg: bool = Field(
        True,
        description="Whether to use MAE time regularization loss",
        example=True,
    )
    use_ssim_time_reg: bool = Field(
        True,
        description="Whether to use SSIM time regularization loss",
        example=True,
    )

    @validator("weights")
    def check_weigths(
        cls, v: Dict[str, float]  # noqa: UP006, N805
    ) -> Dict[str, float]:  # noqa: UP006
        """Check that every loss weight is non-negative.

        Raises:
            ValueError: if any weight is negative or NaN. NaN fails the ``>= 0``
                comparison, so it is rejected as well.
        """
        # Explicit raise instead of ``assert``: assertions are stripped under ``python -O``.
        invalid = sorted(name for name, w in v.items() if not w >= 0)
        if invalid:
            raise ValueError(f"All weights must be >= 0 (invalid: {', '.join(invalid)})")
        return v

    class Config:
        """Configuration of the LossConfig class."""

        extra = "allow"
        arbitrary_types_allowed = True


class AugmentationConfig(BaseModel):
    """Data-augmentation settings.

    The ``p_*`` fields are probabilities and the ``r_*`` fields are ratios, all
    constrained to ``[0, 1]``. ``nrt_max`` is a non-negative count of time steps.
    """

    class Config:
        """Pydantic settings for AugmentationConfig (extra fields allowed, arbitrary types allowed)."""

        arbitrary_types_allowed = True
        extra = "allow"

    p_flip_vertical: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Probability to apply vertical flipping",
        example=0.5,
    )
    p_flip_horizontal: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Probability to apply horizontal flipping",
        example=0.5,
    )
    p_rotate: float = Field(
        default=1.0,
        ge=0,
        le=1,
        description="Probability to apply random rotation",
        example=1.0,
    )
    p_gap: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Probability to drop the worst located time step in the input",
        example=0.5,
    )
    r_cloud: float = Field(
        default=0.33,
        ge=0,
        le=1,
        description="Ratio of the visible inputs that should be artificially obscured",
        example=0.33,
    )
    r_visible: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Ratio of the visible pixels in one time step before considering this step as 'visible'",
        example=0.5,
    )
    nrt_max: int = Field(
        default=3,
        ge=0,
        description="Maximum number of time steps that can get removed from the input to mimic NRT",
        example=3,
    )


class TrainerBaseConfig(BaseModel):
    """Trainer configuration.

    Unknown fields are accepted (``extra = "allow"``). ``n_plots`` may not exceed
    ``batch_size``.
    """

    lr: float = Field(
        1e-4,
        gt=0,
        description="Initial learning rate",
        example=1e-4,
    )
    es_tolerance: int = Field(
        6,
        ge=0,
        description="Early stopping tolerance",
        example=6,
    )
    es_improvement_ratio: float = Field(
        0.0005,
        ge=0,
        description="Early stopping improvement (ratio) before best model gets replaced",
        example=0.001,
    )
    cache_tag: Optional[str] = Field(  # noqa: UP007
        None,
        description="Cache tag to use (default: None, no caching)",
        example="test",
    )
    data_f: Path = Field(
        get_data_folder(),
        description="Folder where the data is stored",
        example=Path.cwd(),
    )
    steps_val: int = Field(
        300,
        gt=0,
        description="Number of training steps inbetween validation cycles",
        example=50,
    )
    steps_max: Optional[int] = Field(  # noqa: UP007
        None,
        description="Maximum number of training steps (default: None, no limit)",
        example=1000,
    )
    batch_size: int = Field(
        10,
        gt=0,
        description="Batch size used during training",
        example=8,
    )
    # The int branch is strict on purpose. With a plain ``Union[int, float]``,
    # pydantic v1 coerces a float such as ``1.0`` to the int ``1`` (the first union
    # member wins). That silently turns "all samples" into "1 sample" when a saved
    # config is reloaded. The ``gt=0`` bound lives in the constrained types, because
    # pydantic rejects Field constraints it cannot apply to these annotations.
    size_train: Union[conint(gt=0, strict=True), confloat(gt=0)] = Field(  # type: ignore[valid-type]  # noqa: UP007
        1.0,
        description="Size of the training set (int: number of samples, float: percentage of samples)",
        example=1.0,
    )
    size_val: Union[conint(gt=0, strict=True), confloat(gt=0)] = Field(  # type: ignore[valid-type]  # noqa: UP007
        1.0,
        description="Size of the validation set (int: number of samples, float: percentage of samples)",
        example=1.0,
    )
    loss_cfg: LossConfig = Field(
        LossConfig(),
        description="Loss configuration",
        example=LossConfig(),
    )
    augm_cfg: AugmentationConfig = Field(
        AugmentationConfig(),
        description="Augmentation configuration",
        example=AugmentationConfig(),
    )
    n_plots: int = Field(
        3,
        description="Number of plots to generate during training (slow!)",
        example=5,
    )
    n_data_loaders: Optional[int] = Field(  # noqa: UP007
        None,
        description="Number of data loaders to use (CPU-count if None)",
        example=4,
    )

    @validator("size_train", "size_val", pre=True)
    def parse_integer_strings(cls, v):  # noqa: N805
        """Turn integer-looking strings (e.g. ``"500"``) into ints so they remain sample counts.

        Other values are passed through unchanged. Strings such as ``"0.5"`` then fall
        through to the float branch.
        """
        if isinstance(v, str):
            try:
                return int(v)
            except ValueError:
                return v
        return v

    @root_validator
    def validate_n_plots(cls, values):  # noqa: N805
        """Validate that ``n_plots`` does not exceed ``batch_size``.

        Raises:
            ValueError: if ``n_plots > batch_size``.
        """
        n_plots = values.get("n_plots")
        batch_size = values.get("batch_size")
        # A field that failed its own validation is absent from ``values``. pydantic
        # already reports that error, so skip the cross-check rather than raise KeyError.
        if n_plots is None or batch_size is None:
            return values
        if n_plots > batch_size:
            raise ValueError("n_plots must be <= batch_size")
        return values

    class Config:
        """Configuration of the TrainerBaseConfig class."""

        extra = "allow"
        arbitrary_types_allowed = True

    def json(self, *args, **kwargs):
        """Convert the object to a JSON string.

        All arguments go straight to pydantic. That includes ``json.dumps`` options
        such as ``indent``, which a decode/re-encode round trip would discard.
        """
        return super().json(*args, **kwargs)

    def dict(self, *args, **kwargs):
        """Convert the object to a JSON-compatible dictionary.

        Round-trips through :meth:`json`, so values such as ``Path`` and nested configs
        come back in their JSON form.
        """
        return json.loads(self.json(*args, **kwargs))


if __name__ == "__main__":
    from pprint import pprint

    from vito_cropsar.data import ScalerNone

    # Small demo: build a config with a no-op scaler and one extra field ("optional"),
    # which is accepted because ModelBaseConfig allows extra fields.
    demo_scaler = ScalerNone(
        {
            "bands_s1": ["s1_vv", "s1_vh"],
            "bands_s2": ["s2_fapar"],
            "sample_s1": True,
        }
    )
    demo_cfg = ModelBaseConfig(tag="test", scaler=demo_scaler, optional="test")
    pprint(demo_cfg.json())  # noqa: T203
