"""Generate masks."""

from __future__ import annotations

from pathlib import Path
from random import randint, random

import numpy as np
from numpy.typing import NDArray
from tqdm import tqdm

from vito_cropsar.constants import (
    PRECISION_FLOAT_NP,
    PRECISION_INT_NP,
    get_data_folder,
)
from vito_cropsar.data.io import list_tiles, load_masks
from vito_cropsar.data.utils import format_cloud
from vito_cropsar.vito_logger import bh_logger


def extract_masks(
    ratio_min: float = 0.2,
    ratio_max: float = 0.95,
    data_f: Path | None = None,
    max_masks: int = 2000,
) -> None:
    """
    Extract all possible masks from the training data, used to generate masks later on.

    The result is written to ``data_f / "masks.npy"`` as one stacked array.

    Note: 1 means cloud-free (visible), 0 means cloud (masked out)
    Note: Fully cloud-obstructed or cloud-free images are added in post (while sampling for masks)

    Parameters
    ----------
    ratio_min : float
        Minimal 'cloud' ratio within one tile's time-step
    ratio_max : float
        Maximal 'cloud' ratio within one tile's time-step
    data_f : Path | None
        Path of the data folder where the masks are stored (get_data_folder() by default)
    max_masks : int
        Maximum number of masks to extract (the saved array never holds more)

    Raises
    ------
    ValueError
        If no mask with a ratio within [ratio_min, ratio_max] was found.
    """
    data_f = get_data_folder() if data_f is None else data_f

    # Generate the masks
    masks = []
    tiles = list_tiles(data_f=data_f, split="training")
    for tile in tqdm(tiles, desc="Extracting.."):
        # Load the sample and extract the masks; the context manager closes the NpzFile.
        # NOTE: allow_pickle=True is only acceptable because these are local training files.
        with np.load(data_f / "training" / f"{tile}.npz", allow_pickle=True) as sample:
            sample_masks = format_cloud(sample["s2_mask"])
        masks += [
            mask
            for mask in sample_masks
            if ratio_min <= mask.sum() / mask.size <= ratio_max
        ]

        # Check if enough masks sampled; one tile can overshoot, so trim to the limit
        if len(masks) >= max_masks:
            masks = masks[:max_masks]
            tqdm.write(f" - Total of {len(masks)} new masks (maximum reached)")
            break

    if not masks:
        raise ValueError(
            f"No masks found with a ratio within [{ratio_min}, {ratio_max}] in {data_f}"
        )

    # Save the masks
    tqdm.write(f" - Saving {len(masks)} new masks..")
    np.save(data_f / "masks.npy", np.stack(masks))


def inject_equidistant_mask(
    npz: dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]],
    tile_path: Path,
    overwrite: bool = False,
    ratio_mask: float = 0.33,
    ratio_visible: float = 0.5,
    data_f: Path | None = None,
) -> None:
    """
    Inject new masks into the npz files that are equidistantly distributed.

    A copy of ``npz["mask"]`` is obscured by ``obscure_equidistant`` and stored
    under ``npz["mask_"]``. ``npz["mask"]`` itself is not modified.

    Parameters
    ----------
    npz : dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]]
        Dictionary of the npz file; the "mask_" key is added (or overwritten) in place
    tile_path : Path
        Path of the tile this npz belongs to; only used in the log message
    overwrite : bool
        Whether to overwrite existing masks
    ratio_mask : float
        The ratio of currently visible tiles that will be masked out
        Note: Possibly only partially masked out
    ratio_visible : float
        Ratio of non-masked tiles (one time step) before considering the tile "visible"
    data_f : Path | None
        Path of the data folder where the masks are stored (get_data_folder() by default)
    """
    # Skip if mask already exists
    if (not overwrite) and ("mask_" in npz):
        bh_logger(
            f" - Mask already exists for tile {tile_path.parent.name}/{tile_path.name}, skipping.."
        )
        return

    # Generate the extra mask on a copy so the real cloud mask stays intact
    mask_ = npz["mask"].copy()
    obscure_equidistant(
        mask=mask_,
        ratio_mask=ratio_mask,
        ratio_visible=ratio_visible,
        data_f=data_f,
    )
    npz["mask_"] = mask_


def obscure_equidistant(
    mask: NDArray[PRECISION_INT_NP],
    ratio_mask: float = 0.33,
    ratio_visible: float = 0.5,
    data_f: Path | None = None,
) -> None:
    """
    Create a mask with equidistantly distributed artificial obscurences.

    Time steps whose share of pixels equal to 1 is at least ``ratio_visible`` are
    candidates. ``round(n_candidates * ratio_mask)`` of them are handed one at a time
    to ``_obscure_furthest``. Each time, the candidate furthest from those already
    picked is chosen. The series borders (-1 and T) count as picked from the start.

    Note
    ----
     - ``mask`` may be modified in place; nothing is returned
     - The mask will build further on the original mask

    Parameters
    ----------
    mask : NDArray[PRECISION_INT_NP]
        The original mask (real clouds)
        Shape: (time, width, height)
    ratio_mask : float
        The ratio of currently visible tiles that will be masked out
        Note: Possibly only partially masked out
    ratio_visible : float
        Ratio of non-masked tiles (one time step) before considering the tile "visible"
    data_f : Path | None
        Path of the data folder where the masks are stored (get_data_folder() by default)
    """
    frame_size = np.prod(mask.shape[1:])
    visible_share = (mask == 1).sum((1, 2)) / frame_size
    candidates = np.flatnonzero(visible_share >= ratio_visible)

    n_to_obscure = round(len(candidates) * ratio_mask)
    picked = [-1, mask.shape[0]]
    for _ in range(n_to_obscure):
        candidates, picked = _obscure_furthest(
            options=candidates,
            used=picked,
            mask=mask,
            data_f=data_f,
        )


def enforce_big_gap(
    mask: NDArray[PRECISION_INT_NP],
    ratio: float = 0.5,
    ratio_visible: float = 0.5,
) -> None:
    """
    Add a complete obscurence in the worst possible position.

    The call returns early when ``random() < ratio``. Otherwise one visible time step
    of ``mask`` is set entirely to 0, in place. A time step is visible when its share
    of pixels equal to 1 is at least ``ratio_visible``. The chosen step is the one
    whose removal opens the widest gap between its visible neighbours. The series
    borders, at -1 and T, act as neighbours. When at most 3 time steps are visible,
    the mask is left untouched.
    """
    # Don't always add a gap
    skip = random() < ratio
    if skip:
        return

    frame_size = np.prod(mask.shape[1:])
    visible_share = (mask == 1).sum((1, 2)) / frame_size
    candidates = np.flatnonzero(visible_share >= ratio_visible)
    if candidates.size <= 3:  # noqa: PLR2004 to prevent distructive modification
        return

    # Pad with the borders, then: gap[i] = next visible - previous visible
    bounded = np.concatenate(([-1], candidates, [mask.shape[0]]))
    gaps = bounded[2:] - bounded[:-2]

    # Drop the worst possible position (first one on ties)
    mask[candidates[gaps.argmax()]] = 0


def clean_cloud_mask(
    npz: dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]]
) -> None:
    """
    Clean single stack of masks according to nans in s2.

    Works in place on ``npz["mask"]`` and ``npz["s2"]`` in two steps:
     1. a pixel whose s2 value is nan in any channel (axis 1) gets mask value 0;
     2. every channel of each pixel whose mask value is then 0 is set to nan in s2.
    """
    nan_pixels = np.isnan(npz["s2"]).any(axis=1)
    npz["mask"][nan_pixels] = 0

    # Broadcast the per-pixel mask over the channel axis to index s2 directly
    cloudy = npz["mask"] == 0
    npz["s2"][np.broadcast_to(cloudy[:, None], npz["s2"].shape)] = np.nan


def _sample_single_mask(
    shape: tuple[int, int],
    p_complete: float = 0.5,
    data_f: Path | None = None,
) -> NDArray[PRECISION_INT_NP]:
    """
    Sample a single tile's mask.

    Note: 1 means cloud-free (visible), 0 means cloud (masked out)

    Parameters
    ----------
    shape : tuple[int, int]
        Shape of the mask to sample in (width, height)
        Note: Can be at most the shape of the stored masks (e.g. (256, 256))
    p_complete : float
        Probability that the tile is completely masked out (complete cloud)
    data_f : Path | None
        Path of the data folder where the masks are stored (get_data_folder() by default)

    Returns
    -------
    NDArray[PRECISION_INT_NP]
        Single tile mask, sampled from all masks available
        Shape: (width, height)
        Values: 1 means cloud-free (visible), 0 means cloud (masked out)

    Raises
    ------
    ValueError
        If ``shape`` is larger than the sampled stored mask.
    """
    assert p_complete <= 1.0, "p_complete must be <= 1.0"  # noqa: PLR2004

    # random() lies in [0, 1): a strict comparison makes p_complete=0.0 mean "never"
    # and p_complete=1.0 mean "always"
    if random() < p_complete:
        return np.zeros(shape, dtype=PRECISION_INT_NP)

    # Only load the stored masks when a partial mask is actually needed
    masks = load_masks(data_f)
    sample = masks[randint(0, len(masks) - 1)]
    if shape[0] > sample.shape[0] or shape[1] > sample.shape[1]:
        raise ValueError(
            f"Requested mask shape {tuple(shape)} exceeds stored mask shape {sample.shape[:2]}"
        )

    # Random crop of the requested size
    o_w = randint(0, sample.shape[0] - shape[0])
    o_h = randint(0, sample.shape[1] - shape[1])
    return sample[o_w : shape[0] + o_w, o_h : shape[1] + o_h].astype(PRECISION_INT_NP)


# TODO: There should be a better way to do this
def _obscure_furthest(
    options: list[int],
    used: list[int],
    mask: NDArray[PRECISION_INT_NP],
    data_f: Path | None = None,
) -> tuple[list[int], list[int]]:
    """
    Obscure the option located the furthest from the used options.

    Each option is scored by ``sum((1 / (option - u)) ** 2 for u in used)``; the
    lowest score (the option least "crowded" by used time steps) is selected.
    Its frame in ``mask`` is intersected, in place, with freshly sampled masks
    until at least one more pixel is hidden. Pixels that were already masked out
    (0) stay masked out. A frame without visible pixels is left unchanged.

    Parameters
    ----------
    options : list[int]
        Candidate time steps (must not contain values from ``used``)
    used : list[int]
        Already used time steps; appended to in place
    mask : NDArray[PRECISION_INT_NP]
        Mask of shape (time, width, height), modified in place
    data_f : Path | None
        Path of the data folder where the masks are stored (get_data_folder() by default)

    Returns
    -------
    tuple[list[int], list[int]]
        The remaining options (without the selected one) and ``used`` with the
        selected option appended.
    """
    if len(options) == 0:
        return [], used

    # Select the furthest option
    scores = [sum((1 / (option - u)) ** 2 for u in used) for option in options]
    furthest = options[np.argmin(scores)]

    # Mask it out, building on the existing mask; without visible pixels nothing can be
    # hidden and the resampling loop would never terminate
    n_visible = np.sum(mask[furthest])
    if n_visible > 0:
        combined = mask[furthest] & _sample_single_mask(mask.shape[1:], data_f=data_f)
        while np.sum(combined) >= n_visible:
            combined = mask[furthest] & _sample_single_mask(mask.shape[1:], data_f=data_f)
        mask[furthest] = combined

    # Return the results
    used.append(furthest)
    return [o for o in options if o != furthest], used


if __name__ == "__main__":
    # Running this module directly rebuilds "masks.npy" in the default data folder
    # from the training split, using extract_masks' default settings.
    extract_masks()
    # NOTE(review): np.load on an .npz returns an NpzFile that re-reads arrays on each
    # key access, so clean_cloud_mask's in-place edits would not persist on it; load
    # into a dict first if this snippet is re-enabled.
    # npz = np.load("data/data/testing/31UDS_4608_4864_5376_5632_2019-11-03.npz")
    # clean_cloud_mask(npz)
