"""Utils functions for data formatting."""

from __future__ import annotations

from pathlib import Path

import cv2
import numpy as np
import torch
from scipy.ndimage import label

# Sentinel stored in the raw bands for missing pixels; `transform` replaces it with NaN.
NAN_VALUE = 65535


def list_npz(folder: Path) -> list[Path]:
    """Collect the paths of every ``.npz`` file below ``folder``, recursively."""
    npz_paths = folder.rglob("*.npz")
    return list(npz_paths)


def clean_instance(instance: np.ndarray, min_size: int = 10) -> torch.Tensor:
    """Remove too small or scattered fields from an instance map.

    Each field id is split into its connected partitions. Partitions with more
    than ``min_size`` pixels are kept and renumbered consecutively starting
    from 1; smaller ones are dropped. Holes of area ``<= min_size`` inside a
    kept partition are filled with the partition's new id.

    Parameters
    ----------
    instance : np.ndarray or torch.Tensor
        Instance map where 0 is background and every other value identifies a
        field. Expected to be 2D, as required by the contour extraction.
    min_size : int
        Size threshold (in pixels) used both to discard partitions and to
        decide which holes get filled.

    Returns
    -------
    torch.Tensor
        Cleaned instance map with the shape and dtype of ``instance``.
    """
    # Work on a numpy view for scipy/OpenCV; the canvas stays a torch tensor.
    if isinstance(instance, torch.Tensor):
        ids_map = instance.numpy()
        canvas = torch.zeros_like(instance)
    else:
        ids_map = np.asarray(instance)
        canvas = torch.zeros_like(torch.as_tensor(ids_map))

    next_id = 1

    # Skip the background explicitly rather than "the smallest value", so the
    # first field is not lost when the tile contains no background pixel.
    field_ids = np.unique(ids_map)
    for field_id in field_ids[field_ids != 0]:
        labelled, n_partitions = label(ids_map == field_id)

        # Labels run from 1 to n_partitions; iterating the count also covers
        # a field that fills the whole tile (no background label present).
        for lbl in range(1, n_partitions + 1):
            partition = labelled == lbl

            # keep only partitions which are bigger than min_size
            if partition.sum() <= min_size:
                continue

            # draw partition
            canvas[torch.from_numpy(partition)] = next_id

            # fill holes in partition
            contours, hierarchy = cv2.findContours(
                partition.astype(np.uint8),
                mode=cv2.RETR_TREE,
                method=cv2.CHAIN_APPROX_SIMPLE,
            )
            for contour, node in zip(contours, hierarchy[0]):
                # The last hierarchy entry is the parent contour index
                # (-1 for top-level contours), so only nested contours pass.
                if node[-1] != -1 and cv2.contourArea(contour) <= min_size:
                    # Size the scratch mask from the tile instead of assuming
                    # a fixed 128x128 input.
                    hole_mask = np.zeros(partition.shape, dtype=np.uint8)
                    cv2.fillPoly(hole_mask, pts=[contour], color=1)
                    canvas[torch.from_numpy(hole_mask > 0)] = next_id
            next_id += 1

    return canvas


def compute_extent(target: torch.Tensor, min_size: int = 10) -> torch.Tensor:
    """Convert an instance-segmentation target into a semantic extent mask.

    For every non-zero instance id, the Canny edges of its binary mask are
    recorded as borders and removed from the field; what remains is marked
    as extent. Border pixels are then cleared from the extent, and connected
    extent regions of at most ``min_size`` pixels are removed.

    Parameters
    ----------
    target : torch.Tensor
        Instance map where 0 is treated as background.
    min_size : int
        Connected extent regions with this many pixels or fewer are erased.

    Returns
    -------
    torch.Tensor
        Tensor shaped like ``target`` holding 1 on field interiors, 0 elsewhere.
    """
    extent = torch.zeros_like(target)
    borders = torch.zeros_like(target)

    field_ids = [value for value in torch.unique(target) if value != 0]
    for field_id in field_ids:
        field = target == field_id

        # edge pixels of this field's binary mask
        edges = (
            cv2.Canny(
                field.numpy().astype(np.uint8), 0, 0, apertureSize=3, L2gradient=True
            )
            > 0
        )

        # keep the interior only, and remember where the borders are
        field[edges] = False
        extent[field] = 1
        borders[edges] = 1

    # a border of one field may overlap the interior of a neighbour
    extent[borders > 0] = 0

    # remove artifacts: label 0 is the background (already zero), so only the
    # labelled regions 1..n need to be inspected
    labelled, n_regions = label(extent)
    for region_id in range(1, n_regions + 1):
        region = labelled == region_id
        if region.sum() > min_size:
            continue
        extent[region] = 0

    return extent


def scale(bands: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Normalize S2 bands with expert rules.

    RGB bands are clipped to [0, 0.3] and stretched to [0, 1]; every other
    band is clipped to [0, 1]. The dict is updated in place and returned.
    """
    rgb_bands = ("s2_b02", "s2_b03", "s2_b04")
    for name, values in bands.items():
        if name not in rgb_bands:
            bands[name] = values.clip(0, 1)
            continue
        bands[name] = values.clip(0, 0.3) / 0.3
    return bands


def transform(bands: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Convert raw bands to float32, mapping the NaN sentinel and rescaling.

    Each band is cast to float32, pixels equal to ``NAN_VALUE`` become NaN and
    the remaining values are multiplied by 0.0001. The dict is updated in
    place and returned.
    """
    for name in bands:
        as_float = bands[name].astype(np.float32)
        missing = as_float == NAN_VALUE
        as_float[missing] = np.nan
        as_float *= 0.0001
        bands[name] = as_float
    return bands


def rerange(bands: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Range bands between -1 and 1.

    Assumes that the reflectance bands are between 0 and 1. They are rescaled
    in place (the arrays themselves are modified), while NDVI is only clipped
    to [-1, 1].

    Parameters
    ----------
    bands : dict[str, np.ndarray]
        Must contain "s2_b02", "s2_b03", "s2_b04", "s2_b08" and "s2_ndvi".

    Returns
    -------
    dict[str, np.ndarray]
        The same dict, updated in place.
    """
    for k in ("s2_b02", "s2_b03", "s2_b04", "s2_b08"):
        # [0, 1] -> [-1, 1]
        bands[k] *= 2
        bands[k] -= 1
    bands["s2_ndvi"] = bands["s2_ndvi"].clip(-1, 1)
    return bands


def fill_and_interpolate(bands: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Fill and interpolate timeseries of samples over entire set of available bands.

    Parameters
    ----------
    bands : dict[str, np.ndarray]
        Raw sample: one time series of tiles per band/sensor name.

    Returns
    -------
    dict[str, np.ndarray]
        Final interpolated sample, with the same keys as ``bands``.
    """
    final = {}
    for sensor, batch in bands.items():
        # STEP 1: remove sparse tiles
        cleaned, valid_ix = _rm_sparse_tiles(batch)

        # STEP 2: create patches and interpolate
        interpolated = _interpolate_valid_tiles(cleaned, valid_ix)

        # STEP 3: linear interpolation with close samples
        final[sensor] = _fill_empty_tiles(interpolated, valid_ix)

    return final


def compute_ndvi(final: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Compute the NDVI index and store it under "s2_ndvi".

    Parameters
    ----------
    final : dict[str, np.ndarray]
        Bands containing at least "s2_b08" (NIR) and "s2_b04" (red).

    Returns
    -------
    dict[str, np.ndarray]
        The same dict, with the "s2_ndvi" entry added or overwritten.
    """
    nir = final["s2_b08"]
    red = final["s2_b04"]
    # the small epsilon avoids a division by zero where nir + red == 0
    final["s2_ndvi"] = (nir - red) / (nir + red + 1e-10)
    return final


def _fill_empty_tiles(
    interpolated: list[np.ndarray], valid_ix: list[int]
) -> np.ndarray:
    """Fill empty tiles via linear interpolation.

    An empty tile lying between two valid tiles becomes their distance-weighted
    average; an empty tile before the first (after the last) valid tile is a
    copy of that nearest valid tile. When there is no valid tile at all, every
    tile is returned as NaN.

    Parameters
    ----------
    interpolated : list[np.ndarray]
        Time series of images
    valid_ix : list[int]
        Non-empty tiles.

    Returns
    -------
    np.ndarray
        Filled time series
    """
    # Sort once so the neighbours found below are always the nearest ones,
    # and use a set for constant-time membership tests.
    ordered = sorted(valid_ix)
    valid = set(ordered)

    filled = []
    for i, original in enumerate(interpolated):
        if i in valid:
            filled.append(original)
            continue

        # ordered[pos - 1] < i < ordered[pos] whenever those entries exist
        pos = int(np.searchsorted(ordered, i))
        has_left = pos > 0
        has_right = pos < len(ordered)

        if not has_left and not has_right:  # nothing to interpolate from
            filled.append(original * np.nan)
        elif not has_left:  # I am the first one
            filled.append(interpolated[ordered[pos]])
        elif not has_right:  # I am the last one
            filled.append(interpolated[ordered[pos - 1]])
        else:  # I am in the middle
            left_ix, right_ix = ordered[pos - 1], ordered[pos]
            d_r = right_ix - i
            d_l = i - left_ix
            # the closer neighbour gets the larger weight
            filled.append(
                (interpolated[right_ix] * d_l + interpolated[left_ix] * d_r)
                / (d_r + d_l)
            )
    return np.array(filled)


def _interpolate_valid_tiles(
    cleaned: list[np.ndarray], valid_ix: list[int]
) -> np.ndarray:
    """Fill gaps in valid tiles via interpolation.

    Every NaN pixel of a valid tile is replaced by the mean of that pixel over
    all valid tiles (NaNs ignored). Tiles that are not listed in ``valid_ix``
    are returned unchanged. The input is never modified.

    Parameters
    ----------
    cleaned : list[np.ndarray]
        Time series of images (a list of tiles or a stacked array)
    valid_ix : list[int]
        Valid tiles in the time series

    Returns
    -------
    np.ndarray
        Time series with interpolated valid tiles
    """
    # np.array copies, so the caller's data is left untouched.
    result = np.array(cleaned)
    if len(valid_ix) == 0:
        return result

    valid_tiles = result[valid_ix]

    # The per-pixel temporal mean does not depend on the tile being patched,
    # so it is computed once. Pixels missing in every valid tile stay NaN
    # (numpy reports them with a "Mean of empty slice" RuntimeWarning).
    patch = np.nanmean(valid_tiles, axis=0)

    # NOTE: a per-island median correction of the patch was sketched here
    # in the past but is not applied.
    result[valid_ix] = np.where(np.isnan(valid_tiles), patch, valid_tiles)
    return result


def _rm_sparse_tiles(
    batch: list[np.ndarray], nan_fraction: float = 0.5
) -> tuple[np.ndarray, list[int]]:
    """Remove sparse tiles. Hence, with less than 50% of clean pixels.

    A tile is sparse when its number of NaN values reaches ``nan_fraction`` of
    its size; sparse tiles are replaced by all-NaN tiles so the time series
    keeps its length.

    Parameters
    ----------
    batch : list[np.ndarray]
        Time series of images.
    nan_fraction : float
        Fraction of NaN values from which a tile is considered sparse
        (0.5 by default).

    Returns
    -------
    tuple[np.ndarray, list[int]]
        Tuple containing:
        - cleaned time series and
        - list of valid indices (indices of clean tiles in original time series)
    """
    valid_ix, cleaned = [], []
    for i, img in enumerate(batch):
        # Threshold derived from the tile itself rather than a fixed 128x128.
        if np.isnan(img).sum() >= np.size(img) * nan_fraction:
            cleaned.append(img * np.nan)
        else:
            cleaned.append(img)
            valid_ix.append(i)
    return np.array(cleaned), valid_ix
