"""Cleaning functions for faulty samples."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP, PRECISION_INT_NP
from vito_cropsar.vito_logger import bh_logger


def check_faulty_sample(
    npz: dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]],
    tile_path: Path,
    writer: Any = None,
) -> bool:
    """Check a sample and move its file out of the dataset if it is faulty.

    The sample is considered faulty when ``is_faulty_s1`` (with
    ``max_s1_missing=0.66``) or ``is_faulty_s2`` (with ``vis_window=32``,
    ``vis_threshold=0.8``, ``min_visible=2``) flags it. A faulty sample file
    is moved from ``<parent>/<name>`` into the sibling directory
    ``<parent>_removed/<name>`` (created if needed), and the reasons are
    reported through ``writer``.

    Parameters
    ----------
    npz : dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]]
        Dictionary containing the sample data to check.
    tile_path : Path
        Path of the sample file on disk; it is renamed when the sample is
        faulty.
    writer : Any, optional
        Callable receiving one message string per call, by default
        ``bh_logger``.

    Returns
    -------
    bool
        True if the sample is faulty (and its file was moved), False otherwise.
    """
    writer = writer if writer is not None else bh_logger

    tile_name = tile_path.name
    tile_parent = tile_path.parent.name

    # Both checks always run so that every applicable reason gets reported.
    remove_s1 = is_faulty_s1(npz, max_s1_missing=0.66)
    remove_s2 = is_faulty_s2(npz, vis_window=32, vis_threshold=0.8, min_visible=2)
    if not (remove_s1 or remove_s2):
        return False

    # Sibling of the parent directory, e.g. ".../testing" -> ".../testing_removed".
    removed_dir = Path(f"{tile_path.parent}_removed")
    removed_dir.mkdir(parents=True, exist_ok=True)
    if remove_s1:
        writer(
            f"{tile_parent}/{tile_name} has more than 66% of nan pixels in a single stack in s1"
        )
    if remove_s2:
        writer(
            f"{tile_parent}/{tile_name} has less than 2 visible images in a 32 temporal window"
        )
    writer(
        f"Moving {tile_parent}/{tile_name} into {tile_parent}_removed/{tile_name}\n"
    )
    tile_path.rename(removed_dir / tile_name)
    return True


def is_faulty_s1(
    npz: dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]],
    max_s1_missing: float = 0.66,
) -> bool:
    """Check whether a sample is faulty in s1.

    Only bands whose name contains ``"vv"`` are considered. For each of them,
    NaN values are counted per pixel along the first axis of the stack
    (presumably time). The worst pixel's count is then divided by the stack
    length. The sample is faulty when every VV band reaches a NaN ratio of at
    least ``max_s1_missing``. This also holds when no VV band is present.

    Parameters
    ----------
    npz : dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]]
        Mapping with ``"bands_s1"`` (band names, as a list or a numpy array)
        and ``"s1"``. The ``"s1"`` array is assumed to be indexed as
        ``[time, band, y, x]`` (TODO confirm).
    max_s1_missing : float, optional
        Fraction (0-1) of NaN values along one pixel stack at which a VV band
        counts as missing, by default 0.66 -> 2/3 of one pixel stack is nan.

    Returns
    -------
    bool
        True if the sample is faulty, False otherwise
    """
    # enumerate works for both lists and numpy string arrays (as loaded from
    # .npz files, which have no ``.index``) and keeps duplicate names distinct.
    band_names = npz["bands_s1"]
    vv_bands = [idx for idx, name in enumerate(band_names) if "vv" in name]
    s1 = npz["s1"][:, vv_bands]
    n_steps = len(s1)
    # Per VV band: the largest NaN count over all pixels.
    nancounts_s1 = np.isnan(s1).sum(0).max((1, 2))
    nanratio_s1 = nancounts_s1 / n_steps
    # Faulty unless at least one VV band stays below the missing threshold.
    return not bool((nanratio_s1 < max_s1_missing).any())


def is_faulty_s2(
    npz: dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]],
    vis_window: int = 32,
    vis_threshold: float = 0.8,
    min_visible: int = 2,
) -> bool:
    """Check whether a sample is faulty in s2.

    An image counts as visible when its fraction of non-NaN pixels in the
    first s2 band exceeds ``vis_threshold``. The sample is faulty when any
    full window of ``vis_window`` consecutive images contains fewer than
    ``min_visible`` visible images. A stack shorter than ``vis_window`` has
    no full window and is therefore never flagged.

    Parameters
    ----------
    npz : dict[str, NDArray[PRECISION_INT_NP | PRECISION_FLOAT_NP]]
        Mapping with an ``"s2"`` array. It is indexed here as
        ``[time, band, y, x]`` (layout assumed, TODO confirm).
    vis_window : int, optional
        Size of the temporal window to consider, by default 32
    vis_threshold : float, optional
        Fraction of visible pixels in an image to consider it visible, by default 0.8
    min_visible : int, optional
        Minimum number of visible images in the temporal window to consider the stack valid, by default 2

    Returns
    -------
    bool
        True if the stack is faulty, False otherwise
    """
    # read one of the s2 bands (all bands should be almost identical in terms
    # of visibility); load it once, since an NpzFile re-reads on every access
    first_band = npz["s2"][:, 0]
    nan_counts = np.isnan(first_band).sum((1, 2))
    n_pixels = first_band.shape[1] * first_band.shape[2]  # also valid for non-square tiles
    is_visible = (1 - nan_counts / n_pixels) > vis_threshold

    # check every full temporal window, including the last one
    n_windows = len(is_visible) - vis_window + 1
    for start in range(n_windows):
        if is_visible[start : start + vis_window].sum() < min_visible:
            return True
    return False


if __name__ == "__main__":
    # Manual smoke test. NOTE: if the sample is flagged as faulty, its file is
    # moved into the sibling "<parent>_removed" directory.
    sample_path = Path("data/data/testing/31UDS_4608_4864_5376_5632_2019-11-03.npz")
    # Materialise the arrays and close the archive before the file may be renamed.
    with np.load(sample_path) as archive:
        sample = {key: archive[key] for key in archive.files}
    print(check_faulty_sample(npz=sample, tile_path=sample_path))
