"""Custom U-Net 2D dataset."""

from __future__ import annotations

import traceback
from copy import deepcopy
from multiprocessing import cpu_count
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset

from vito_cropsar.constants import get_data_folder
from vito_cropsar.data import Scaler, apply_mask, list_tiles, load_sample
from vito_cropsar.models import AugmentationConfig
from vito_cropsar.models.shared.augmentation import Augmentor
from vito_cropsar.vito_logger import LogLevel, bh_logger


class CropSarDataset(Dataset):
    """
    Custom U-Net 2D dataset.

    Every item is a dictionary with the tensors "s1", "s2", "mask" and "target" of a
    single tile (after passing through the scaler and the augmentor), plus the tile
    identifier under the key "tile".
    """

    def __init__(
        self,
        split: str,
        scaler: Scaler,
        data_f: str | None = None,
        size: int | float = 1.0,
        n_ts: int = 32,
        resolution: int = 128,
        cache_tag: str | None = None,
        align: bool = True,
        smooth_s1: bool = True,
        augm_cfg: AugmentationConfig = AugmentationConfig(),  # noqa: B008
    ) -> None:
        """
        Initialise the dataset.

        Parameters
        ----------
        split : str
            Dataset split to extract tiles from
        scaler : Scaler
            Data scaler to normalise the data
        data_f : str | None
            Data folder location, by default ./data/data
        size : int | float
            Size of the dataset (int: number of tiles, float: percentage of total tiles)
            Note: Every value up to and including 1.0 (so also the integer 1) is
            interpreted as a fraction of the available tiles
        n_ts : int
            Number of time steps to use
            Note: Can be used as an augmentation technique
        resolution : int
            Resolution of the tiles to use (square)
            Note: Can be used as an augmentation technique
        cache_tag : str | None
            Cache tag to use (default: None, no caching)
        align : bool
            Whether to align the time series (jitter removal in S2)
        smooth_s1 : bool
            Whether to smooth the S1 time series
        augm_cfg : AugmentationConfig
            Augmentation configuration containing the probabilistic properties

        Raises
        ------
        AssertionError
            If no scaler is given, the split is unknown, or the requested fraction
            results in zero tiles
        Exception
            If more tiles are requested than the split contains
        """
        super().__init__()
        assert scaler is not None, "Scaler must be provided"
        assert split in {
            "training",
            "testing",
            "validation",
        }, "Split must be one of 'training', 'testing' or 'validation'"

        # Get the tiles
        self.data_f = Path(data_f) if data_f is not None else get_data_folder()
        self.split = split
        self.tiles = list_tiles(
            data_f=self.data_f,
            split=self.split,
            cache_tag=cache_tag,
        )
        self.cache_tag = cache_tag
        self.align = align
        self.smooth_s1 = smooth_s1

        # Translate the requested size into an absolute number of tiles
        n_tiles = len(self.tiles)
        if size <= 1.0:  # noqa: PLR2004
            # Fractional request: keep the first `size * n_tiles` tiles
            if size < 1.0:  # noqa: PLR2004
                bh_logger(
                    f"Only using fraction ({size}) of the dataset",
                    LogLevel.WARNING,
                )
            n_selected = int(n_tiles * size)
            assert n_selected > 0, f"Requested size ({size}) is too small"
            size = n_selected
        elif size > n_tiles:
            # Absolute request that exceeds what is available
            raise Exception(f"Not enough tiles in the dataset ({n_tiles} < {size})")
        elif size < n_tiles:
            # Absolute request for a strict subset; requesting exactly `n_tiles`
            # tiles is valid and uses the complete split without a warning
            bh_logger(
                f"Only using {size}/{n_tiles} tiles from the dataset",
                LogLevel.WARNING,
            )
        self.tiles = self.tiles[:size]

        # Continue initialisation
        self._scaler = scaler
        self._augmentor = Augmentor(
            p_cfg=augm_cfg,
            n_ts=n_ts,
            resolution=resolution,
            sample_s1=self._scaler.sample_s1,
            sample_mask=(split == "training"),  # Only sample mask in training
            freeze=(split != "training"),  # Only augment training data
            data_f=self.data_f,
        )

    def __len__(self) -> int:
        """Length of the dataset (number of tiles)."""
        return len(self.tiles)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """
        Return datapoint and label at location index in the dataset.

        Parameters
        ----------
        idx : int
            Index of the tile in the (possibly truncated) tile list

        Returns
        -------
        dict[str, torch.Tensor]
            The "s1", "s2", "mask" and "target" entries of the sample, together with
            the tile identifier under "tile"

        Raises
        ------
        Exception
            If anything goes wrong while preparing the sample; the original error is
            logged, its traceback printed, and it is chained as the cause
        """
        tile = self.tiles[idx]
        try:
            # Load and unfold the sample
            sample = self._prepare_sample(tile=tile)

            # Perform augmentation (the auxiliary "mask_" entry is not forwarded)
            sample = self._augmentor(
                {k: sample[k] for k in ("s1", "s2", "mask", "target")}
            )

            # Fail-proof, at least one sample
            sample = _ensure_mask(sample)

            # Return the result
            result = {k: sample[k] for k in ("s1", "s2", "mask", "target")}
            result["tile"] = tile  # Keep tile information
            return result
        except Exception as e:  # noqa: BLE001
            # Wrap the error to print which tile caused the error
            bh_logger(f"Failed to load tile '{tile}': {e}", LogLevel.ERROR)
            traceback.print_exception(type(e), e, e.__traceback__)
            raise Exception(f"Failed to load tile '{tile}'") from e

    def _prepare_sample(self, tile: str) -> dict[str, torch.Tensor | None]:
        """
        Prepare the sample, specified by the tile ID.

        Parameters
        ----------
        tile : str
            Identifier of the tile to load

        Returns
        -------
        dict[str, torch.Tensor | None]
            The scaled "s1", "s2" and "target" tensors, the "mask" tensor, and
            "mask_" (None when the loaded sample does not contain it)
        """
        # Get the next tile
        sample = load_sample(
            data_f=self.data_f,
            split=self.split,
            tile=tile,
            cache_tag=self.cache_tag,
            smooth_s1=self.smooth_s1,
            align=self.align,
        )
        s1 = sample["s1"]
        s2 = sample["s2"]
        # The target is an independent copy of S2, taken before any masking below
        target = deepcopy(sample["s2"])

        # Extract the right bands (reordered to the scaler's band order)
        # NOTE(review): the band axis is assumed to be axis 1 - confirm with load_sample
        if sample["bands_s1"] != self._scaler.bands_s1:
            s1 = s1[:, [sample["bands_s1"].index(b) for b in self._scaler.bands_s1]]
        if sample["bands_s2"] != self._scaler.bands_s2:
            _bands = [sample["bands_s2"].index(b) for b in self._scaler.bands_s2]
            s2, target = s2[:, _bands], target[:, _bands]

        # Apply the fictive mask if not in training mode
        mask = sample["mask"]
        if self.split != "training":
            assert "mask_" in sample, "Mask must be provided for validation/testing"
            mask &= sample["mask_"]  # In-place: also updates sample["mask"]
            s2 = apply_mask(s2, mask=mask)

        # Normalise the input
        s1, s2 = self._scaler(s1=s1, s2=s2, safe=False)
        _, target = self._scaler(s2=target, safe=False)

        # Return the sample
        return {
            "s1": torch.tensor(s1),
            "s2": torch.tensor(s2),
            "mask": torch.tensor(mask),
            "mask_": torch.tensor(sample["mask_"]) if "mask_" in sample else None,
            "target": torch.tensor(target),
        }

    def get_dataloader(
        self,
        batch_size: int,
        n_loaders: int | None = None,
    ) -> DataLoader:
        """
        Get the dataset's data loader.

        Parameters
        ----------
        batch_size : int
            Batch size
        n_loaders : int
            Number of workers to use to prepare the data
            Note: multiprocessing.cpu_count() is used if None

        Note
        ----
        The following characteristics only occur during training:
            - The dataset is shuffled
            - Incomplete batches are dropped

        Workers are kept alive between iterations unless n_loaders is 0.

        Returns
        -------
        DataLoader
            Dataset's data loader
        """
        return DataLoader(
            dataset=self,
            batch_size=batch_size,
            num_workers=cpu_count() if n_loaders is None else n_loaders,
            shuffle=(self.split == "training"),  # Shuffle the dataset
            drop_last=(self.split == "training"),  # Ignores incomplete batches
            persistent_workers=(n_loaders != 0),  # Not allowed without workers
        )


def _ensure_mask(sample: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Guarantee that the input hides at least something the target still shows.

    When "s2" and "target" contain the same number of NaN values, one entry along
    the first axis of "s2" is overwritten with NaN and the matching "mask" entry is
    set to 0. The entry chosen is the middle one among those that still hold at
    least one non-NaN value. The sample is modified in place.

    Parameters
    ----------
    sample : dict[str, torch.Tensor]
        Sample holding at least the "s2", "target" and "mask" tensors

    Returns
    -------
    dict[str, torch.Tensor]
        The same sample object, with its mask ensured
    """
    nan_in_input = sample["s2"].isnan().sum()
    nan_in_target = sample["target"].isnan().sum()
    if nan_in_input != nan_in_target:
        # The NaN counts already differ, nothing to enforce
        return sample

    # Collect every index along the first axis that still carries data
    usable = []
    for position, frame in enumerate(sample["s2"]):
        if (~frame.isnan()).any():
            usable.append(position)
    assert usable, "Faulty sample encountered!"

    # Blank out the middle usable entry in both the input and its mask
    chosen = usable[len(usable) // 2]
    sample["s2"][chosen] = torch.nan
    sample["mask"][chosen] = 0
    return sample


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Load a scaler for the bands shown in this demonstration
    demo_scaler = Scaler.load(
        bands_s1=["s1_asc_vv", "s1_des_vv", "s1_asc_vh", "s1_des_vh"],
        bands_s2=["s2_fapar", "s2_b02", "s2_b03", "s2_b04"],
    )

    # Build the dataset on the testing split
    demo_dataset = CropSarDataset(
        split="testing",
        scaler=demo_scaler,
        data_f=get_data_folder(),
        cache_tag="fapar_rgb",
    )
    print(f" - Generated a dataset of size {len(demo_dataset)}")

    # Pull one batch out of the dataloader and report what it contains
    demo_loader = demo_dataset.get_dataloader(batch_size=3, n_loaders=4)
    batch = next(iter(demo_loader))
    print(f" - First s1 of shape {batch['s1'].shape}")
    print(f" - First s2 of shape {batch['s2'].shape}")
    print(f" - First targets of shape {batch['target'].shape}")
    print(f" - First masks of shape {batch['mask'].shape}")
    print(f" - First tiles of size {len(batch['tile'])}")

    # One column per time step, one row per displayed quantity
    n_steps = batch["s1"].shape[1]
    _, axes = plt.subplots(4, n_steps, figsize=(n_steps, 5))
    plt.suptitle(batch["tile"][0])
    plt.setp(axes, xticks=[], yticks=[])

    # Label the rows on the left-most column
    row_labels = ("target (index: 0)", "s2 (index: 0)", "mask", "s1 (index: 0)")
    for row, label in enumerate(row_labels):
        axes[row, 0].set_ylabel(label)

    # Draw the first batch element (and first band where applicable) per time step
    for step in range(n_steps):
        column = axes[:, step]
        column[0].imshow(batch["target"][0, step, 0], vmin=-1.0, vmax=1.0)
        column[1].imshow(batch["s2"][0, step, 0], vmin=-1.0, vmax=1.0)
        column[2].imshow(batch["mask"][0, step], vmin=0, vmax=1.0)
        column[3].imshow(batch["s1"][0, step, 0], vmin=-1.0, vmax=1.0)

    # Write the figure to disk and release it
    plt.tight_layout()
    plt.savefig("example_dataset.png")
    plt.close()
