"""Custom dataset for lot delineation."""

from __future__ import annotations

from logging import warning
from math import ceil
from pathlib import Path
from random import shuffle
from typing import Any, Iterator

import cv2
import numpy as np
import torch
from joblib import Parallel, cpu_count, delayed
from scipy.ndimage import distance_transform_cdt
from torch.utils.data import Dataset

from vito_lot_delineation.data.augmentor import Augmentor
from vito_lot_delineation.data.utils import load_file, load_folder


class DelineationDataset(Dataset):
    """
    Custom dataset for lot delineation.

    Samples are generated from the file paths of one dataset split.
    ``__getitem__`` gives random access by index. ``get_batch`` and
    ``get_iterator`` walk through the paths cyclically. On the training split
    the paths are reshuffled after every full pass, and at construction time.
    """

    def __init__(
        self,
        split: str,
        data_dir: Path | None = None,
        augment: bool = False,
        return_watersheds: bool = True,
        return_borders: bool = True,
        return_distance: bool = True,
        bands: list[str] | None = None,
        n_ts: int = 4,
    ) -> None:
        """
        Initialise the dataset.

        Parameters
        ----------
        split : str
            Name of the dataset split
        data_dir : Path | None
            Path to dataset directory, if not specified, default data folder is used
        augment : bool
            If True, apply data augmentation
        return_watersheds : bool
            Return watershed weights for loss computation
        return_borders : bool
            Return border weights for loss computation
        return_distance : bool
            Return distance weights for loss computation
        bands : list[str] | None
            List of bands to extract from the data image, defaults to ["s2_ndvi"]
        n_ts : int
            Number of time stamps
        """
        super().__init__()
        if augment and split != "training":
            warning(f"Augmentation will be applied to a non-training split ({split})!")
            warning("If this is not intended, set augment=False on this split.")

        self._idx = 0
        # Copy so the in-place shuffle in reset() never mutates a list owned
        # by the loader or by another dataset instance.
        self._paths = list(load_folder(split, data_dir))
        self._is_train = split == "training"
        self._augment = augment
        self._augmentor = Augmentor(n_ts=n_ts)
        self._return_watersheds = return_watersheds
        self._return_borders = return_borders
        self._return_distance = return_distance
        # A fresh list per instance; avoids a shared mutable default argument.
        self._bands = ["s2_ndvi"] if bands is None else list(bands)
        self._n_ts = n_ts
        self.reset()

    def reset(self) -> None:
        """Rewind the sequential cursor and, on the training split, reshuffle the paths."""
        self._idx = 0
        if self._is_train:
            shuffle(self._paths)

    def __len__(self) -> int:
        """Get length of dataset."""
        return len(self._paths)

    def __getitem__(self, i: int) -> dict[str, torch.Tensor]:
        """
        Return the generated sample for the path at position ``i``.

        On the training split the path order is reshuffled by ``reset``, so the
        index-to-file mapping is not fixed across resets.
        """
        return _generate(self._make_config(self._paths[i]))

    def get_iterator(
        self,
        batch_size: int,
        n: int,
        reset_first: bool = False,
        num_workers: int = 0,
    ) -> Iterator[dict[str, torch.Tensor]]:
        """
        Create a batch generating iterator for a specified number of batches.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch
        n : int
            Number of batches to yield
        reset_first : bool
            If True, call ``reset`` before generating the first batch
        num_workers : int
            Forwarded to ``get_batch``

        Yields
        ------
        dict[str, torch.Tensor]
            One stacked batch per iteration
        """
        if reset_first:
            self.reset()

        for _ in range(n):
            yield self.get_batch(batch_size, num_workers=num_workers)

    def get_batch(
        self, batch_size: int, num_workers: int = 0
    ) -> dict[str, torch.Tensor]:
        """
        Get a batch of samples, stacked per key.

        Parameters
        ----------
        batch_size : int
            Number of samples in the batch, must be positive
        num_workers : int
            0 generates sequentially. A positive value uses at most that many
            joblib jobs, capped at half the CPU count. A negative value uses half
            the CPU count. A batch no larger than ``num_workers`` is always
            generated sequentially.

        Returns
        -------
        dict[str, torch.Tensor]
            Each sample key mapped to the samples' tensors stacked on a new dim 0

        Raises
        ------
        ValueError
            If ``batch_size`` is not positive
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        configs = [self._get_config() for _ in range(batch_size)]
        if (num_workers == 0) or (batch_size <= num_workers):
            samples = _generate_samples(configs)
        else:
            half_cpus = cpu_count() // 2
            n_jobs = half_cpus if num_workers < 0 else min(num_workers, half_cpus)
            # half_cpus is 0 on a single-core machine; keep at least one job so
            # the chunk size below never divides by zero.
            n_jobs = max(1, n_jobs)
            chunk = ceil(batch_size / n_jobs)
            chunks = Parallel(n_jobs=n_jobs)(
                delayed(_generate_samples)(configs[i : i + chunk])
                for i in range(0, batch_size, chunk)
            )
            samples = [sample for part in chunks for sample in part]

        return {k: torch.stack([s[k] for s in samples]) for k in samples[0]}

    def _make_config(self, path: Any) -> dict[str, Any]:
        """Build the sample-generation configuration for one data path."""
        return {
            "path": path,
            "augment": self._augment,
            "augmentor": self._augmentor,
            "return_watersheds": self._return_watersheds,
            "return_borders": self._return_borders,
            "return_distance": self._return_distance,
            "bands": self._bands,
            "n_ts": self._n_ts,
        }

    def _get_config(self) -> dict[str, Any]:
        """Get the configuration of the next path, wrapping (and resetting) at the end."""
        path = self._paths[self._idx]
        self._idx += 1
        if self._idx >= len(self._paths):
            self.reset()
        return self._make_config(path)


def _generate_samples(configs: list[dict[str, Any]]) -> list[dict[str, torch.Tensor]]:
    """
    Generate one sample per configuration, preserving the input order.

    Parameters
    ----------
    configs : list[dict[str, Any]]
        Sample-generation configurations, each passed to ``_generate``

    Returns
    -------
    list[dict[str, torch.Tensor]]
        The generated samples, in the same order as ``configs``
    """
    return list(map(_generate, configs))


def _generate(cfg: dict[str, Any]) -> dict[str, torch.Tensor]:
    """
    Generate a single sample following the provided configuration.

    The raw input/instance/extent triple is loaded and passed to the augmentor.
    The augmentor applies ``augment`` when augmentation is enabled and a
    centred ``tscut`` otherwise. Then each weighting target whose
    ``return_*`` flag is set is computed from the resulting extent.

    Parameters
    ----------
    cfg : dict[str, Any]
        Configuration as built by ``DelineationDataset``

    Returns
    -------
    dict[str, torch.Tensor]
        The sample, with "input", "instance" and "extent" entries plus any
        requested "watersheds", "borders" and "distance" entries
    """
    inp, instance, extent = load_file(cfg)
    augmentor: Augmentor = cfg["augmentor"]
    sample = {"input": inp, "instance": instance, "extent": extent}

    if cfg["augment"]:
        sample = augmentor.augment(sample)
    else:
        sample = augmentor.tscut(sample, center=True)

    # Optional loss-weighting targets, evaluated in a fixed order on the
    # (possibly augmented) extent map.
    optional_targets = (
        ("return_watersheds", _compute_watersheds),
        ("return_borders", _compute_borders),
        ("return_distance", _compute_distance),
    )
    for flag, compute in optional_targets:
        if cfg[flag]:
            sample.update(compute(sample["extent"]))

    return sample


def _compute_watersheds(extent: torch.Tensor) -> dict[str, torch.Tensor]:
    """
    Compute watershed importance given an extent map (semantic segmentation target).

    Every pixel where ``extent <= 0`` receives the sum of the extent values in
    its 5x5 neighbourhood, using cv2's default border handling. Pixels with
    ``extent > 0`` receive zero. The map is then divided by its maximum
    (plus a tiny epsilon to avoid division by zero).

    Parameters
    ----------
    extent : torch.Tensor
        Semantic map (CPU tensor, converted to float32 for OpenCV)

    Returns
    -------
    dict[str, torch.Tensor]
        {"watersheds": weights}, meant to be added to a ones-like weights mask
    """
    extent_np = extent.float().numpy()

    # 5x5 box sum of the extent; filter2D returns a new float32 array
    # (ddepth=-1), so it can be edited in place without touching `extent`.
    watersheds = cv2.filter2D(extent_np, -1, np.ones((5, 5)))
    # Only pixels outside the extent carry a watershed weight.
    watersheds[extent_np > 0] = 0

    # normalize by the maximum
    watersheds /= watersheds.max() + 1e-12
    return {"watersheds": torch.Tensor(watersheds)}


def _compute_borders(extent: torch.Tensor) -> dict[str, torch.Tensor]:
    """
    Compute border ground truth given an extent map (semantic segmentation target).

    The extent is dilated with a 3x3 kernel and the original extent is
    subtracted. For a binary extent, this leaves 1 on the background pixels
    that touch a field (including diagonally) and 0 everywhere else.

    Parameters
    ----------
    extent : torch.Tensor
        Semantic map (CPU tensor, converted once to float32)

    Returns
    -------
    dict[str, torch.Tensor]
        {"borders": mask}, a binary mask when ``extent`` is binary
    """
    # Convert once, so the dilation and the subtraction use the same float32
    # array. A bare `extent.numpy()` fails for dtypes numpy lacks (e.g. bfloat16).
    extent_np = extent.float().numpy()

    # The dilation grows every field by one pixel; subtracting the original
    # extent keeps only the ring of pixels added by the dilation.
    borders = cv2.dilate(extent_np, np.ones((3, 3))) - extent_np
    return {"borders": torch.Tensor(borders)}


def _compute_distance(extent: torch.Tensor) -> dict[str, torch.Tensor]:
    """
    Compute distance importance given extention map (semantic segmentation target).

    Parameters
    ----------
    extent : torch.Tensor
        Field extent of the target

    Returns
    -------
    dict[str, torch.Tensor]
        Distance weighting of each field
    """
    distance = distance_transform_cdt(extent > 0)
    distance = torch.Tensor(distance).float()
    distance /= distance.max() + 1e-12
    return {"distance": distance}


if __name__ == "__main__":
    # Demo: plot one validation sample, then time batch generation with
    # different worker settings.
    import matplotlib.pyplot as plt
    from tqdm import tqdm

    my_dataset = DelineationDataset(
        data_dir=Path("data/data"),
        split="validation",
        augment=True,
    )

    # Show one example; fall back to the last sample when the split has
    # fewer than 11 entries.
    sample = my_dataset[min(10, len(my_dataset) - 1)]
    print("Keys:", sample.keys())

    # Plot the sample
    instance = sample["instance"]
    extent = sample["extent"]
    watersheds = sample["watersheds"]
    distance = sample["distance"]
    fig, ax = plt.subplots(1, 5, figsize=(20, 5))
    ax[0].set_title("Instance")
    ax[0].imshow(instance)
    ax[1].set_title("Extent")
    ax[1].imshow(extent)
    ax[2].set_title("watersheds")
    ax[2].imshow(watersheds)
    ax[3].set_title("distance")
    ax[3].imshow(distance)
    ax[4].set_title("watersheds + distance")
    ax[4].imshow(torch.ones_like(watersheds) + watersheds + distance)
    plt.show()

    # Stress test; the trailing comments look like tqdm timings from an
    # earlier run.
    for _ in tqdm(
        range(10), desc="Generating sequential.."
    ):  # [00:02<00:00,  4.73it/s]
        _ = my_dataset.get_batch(batch_size=64, num_workers=0)
    for _ in tqdm(
        range(10), desc="Generating parallel (4).."
    ):  # [00:06<00:00,  1.61it/s]
        _ = my_dataset.get_batch(batch_size=64, num_workers=4)
    for _ in tqdm(
        range(10), desc="Generating parallel (-1).."
    ):  # [00:08<00:00,  1.12it/s]
        _ = my_dataset.get_batch(batch_size=64, num_workers=-1)
