"""Utilisation functions."""

from __future__ import annotations

import numpy as np
import torch
from numpy.typing import NDArray
from scipy import ndimage

from vito_cropsar.constants import PRECISION_FLOAT_NP, PRECISION_INT_NP
from vito_cropsar.vito_logger import LogLevel, bh_logger


def apply_mask(
    x: NDArray[PRECISION_FLOAT_NP] | torch.Tensor,
    mask: NDArray[PRECISION_INT_NP] | torch.Tensor,
) -> NDArray[PRECISION_FLOAT_NP] | torch.Tensor:
    """
    Apply a mask to a sample by setting masked-out entries to NaN.

    Every entry of ``x`` whose (time, width, height) position has ``mask == 0``
    is set to NaN across all channels. The channel axis is the third axis from
    the end in both supported layouts, and the mask is broadcast over it.

    NOTE: ``x`` is modified in place and the same object is returned. It must
    have a floating dtype so that it can hold NaN.

    Parameters
    ----------
    x : NDArray[PRECISION_FLOAT_NP] | torch.Tensor
        Sample to apply the mask to
        Shape: (time, channels, width, height) or (batch, time, channels, width, height)
    mask : NDArray[PRECISION_INT_NP] | torch.Tensor
        Mask to apply; entries equal to 0 are masked out. Expected to be the
        same kind (numpy array / torch tensor) as ``x``.
        Shape: (time, width, height) or (batch, time, width, height)

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP] | torch.Tensor
        Masked sample (the same object as ``x``)
        Shape: (time, channels, width, height) or (batch, time, channels, width, height)

    Raises
    ------
    ValueError
        If ``x`` is neither 4-D nor 5-D.
    AssertionError
        If ``mask.shape`` differs from ``x.shape`` with the channel axis removed.
    """
    if len(x.shape) not in (4, 5):  # noqa: PLR2004
        raise ValueError(f"Invalid shape for x: {x.shape}")

    # The channel axis sits at position -3 in both layouts, so the mask must
    # match every other axis of x. Dropping axis -3 (rather than axis 1) keeps
    # this check correct for the batched 5-D layout as well.
    expected = tuple(x.shape[:-3]) + tuple(x.shape[-2:])
    assert expected == tuple(
        mask.shape
    ), f"Invalid shapes: x={x.shape}, mask={mask.shape}"

    # Insert a singleton channel axis and broadcast it to x's full shape
    # (a view; the mask is not copied once per channel).
    masked_out = mask[..., None, :, :] == 0
    if isinstance(x, torch.Tensor):
        x[masked_out.expand_as(x)] = torch.nan
    else:
        x[np.broadcast_to(masked_out, x.shape)] = np.nan
    return x


def process_s1(
    x: dict[str, NDArray[PRECISION_INT_NP]],
    bands: list[str],
    speckle: bool = True,
) -> NDArray[PRECISION_FLOAT_NP]:
    """Convert raw Sentinel-1 bands into one stacked float array in dB.

    For every requested band, in order:
        - digital numbers are converted to dB (``s1_dn_to_db``)
        - if ``speckle`` is set, the band makes a round trip
          dB -> power -> ``speckle_filter`` -> dB; NaN pixels are set to 0.0
          for the filter (which rejects NaN) and restored to NaN afterwards
    The bands are then stacked, moved to a time-first layout and cast to
    PRECISION_FLOAT_NP.

    Parameters
    ----------
    x : dict[str, NDArray[PRECISION_INT_NP]]
        Raw S1 data keyed by band name. Judging from the final transpose,
        each array is indexed as (width, height, n_ts).
    bands : list[str]
        Keys of ``x`` to process, in output channel order
    speckle : bool
        Whether to apply the speckle filter

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Processed S1 data, shape (n_ts, n_channels, width, height)
    """

    def _band_to_db(raw):
        db = s1_dn_to_db(raw)
        if not speckle:
            return db
        # The speckle filter operates on power values and refuses NaNs.
        power = s1_db_to_power(db)
        missing = np.isnan(power)
        power[missing] = 0.0
        despeckled = speckle_filter(power)
        despeckled[missing] = np.nan
        return s1_power_to_db(despeckled)

    channels = np.stack([_band_to_db(x[band]) for band in bands])
    # (n_channels, width, height, n_ts) -> (n_ts, n_channels, width, height)
    return np.transpose(channels, [3, 0, 1, 2]).astype(PRECISION_FLOAT_NP)


def process_s2(
    x: dict[str, NDArray[PRECISION_INT_NP]],
    bands: list[str],
) -> NDArray[PRECISION_FLOAT_NP]:
    """Convert raw Sentinel-2 bands into one stacked float array.

    Each requested band is rescaled with ``s2_dn_to_db``. The pseudo-band
    ``"s2_ndvi"`` is derived as (B08 - B04) / (B08 + B04) from the rescaled
    ``"s2_b08"`` and ``"s2_b04"`` entries of ``x``. The channels are then
    stacked, moved to a time-first layout and cast to PRECISION_FLOAT_NP.

    Parameters
    ----------
    x : dict[str, NDArray[PRECISION_INT_NP]]
        Raw S2 data keyed by band name. Judging from the final transpose,
        each array is indexed as (width, height, n_ts).
    bands : list[str]
        Bands to produce, in output channel order

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Processed S2 data, shape (n_ts, n_channels, width, height)
    """

    def _scaled(name):
        return s2_dn_to_db(x[name], band=name)

    channels = []
    for band in bands:
        if band != "s2_ndvi":
            channels.append(_scaled(band))
            continue
        nir = _scaled("s2_b08")
        red = _scaled("s2_b04")
        channels.append((nir - red) / (nir + red))

    # (n_channels, width, height, n_ts) -> (n_ts, n_channels, width, height)
    stacked = np.transpose(np.stack(channels), [3, 0, 1, 2])
    return stacked.astype(PRECISION_FLOAT_NP)


def s1_dn_to_db(
    x: NDArray[np.uint16],
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Convert Sentinel-1 UINT16 digital numbers to backscatter in dB.

    The conversion is ``20 * log10(DN) - 83``. Digital numbers 0 and 65535
    are treated as missing and become NaN.

    Parameters
    ----------
    x : NDArray[np.uint16]
        Digital numbers; the input array is not modified

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Values in dB with the same shape as ``x``
    """
    values = x.astype(PRECISION_FLOAT_NP)
    missing = (x == 0) | (x == 65535)  # noqa: PLR2004
    values[missing] = np.nan
    return 20 * np.log10(values) - 83


def s1_db_to_power(
    x: NDArray[PRECISION_FLOAT_NP],
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Convert Sentinel-1 values from dB to power level, ``10 ** (x / 10)``.

    Parameters
    ----------
    x : NDArray[PRECISION_FLOAT_NP]
        Values in dB; NaN propagates to the output

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Power level values with the same shape as ``x``
    """
    exponent = x / 10
    return 10**exponent


def s1_power_to_db(
    x: NDArray[PRECISION_FLOAT_NP],
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Transform Sentinel-1 power level values to physical dB, ``10 * log10(x)``.

    Parameters
    ----------
    x : NDArray[PRECISION_FLOAT_NP]
        Power level values (array or scalar); not modified

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Values in dB. NaN stays NaN; 0 maps to -inf (numpy emits a
        divide-by-zero RuntimeWarning).
    """
    # np.log10 returns a new result and never writes into ``x``, so no
    # defensive copy of the (potentially large) input is needed.
    return 10 * np.log10(x)


def s2_dn_to_db(
    x: NDArray[np.uint16],
    band: str,
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Rescale Sentinel-2 digital numbers to physical values in PRECISION_FLOAT_NP.

    NOTE: despite the function name, no dB conversion takes place; values are
    only scaled linearly, with a band-specific no-data value mapped to NaN:

    - ``s2_fapar``, ``s2_fcover``: no-data 255, scale 0.005
    - ``s2_b02``, ``s2_b03``, ``s2_b04``, ``s2_b08``: no-data 65535, scale 0.0001

    Parameters
    ----------
    x : NDArray[np.uint16]
        Digital numbers; the input array is not modified
    band : str
        Band name, selecting the no-data value and scale factor

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Rescaled values with the same shape as ``x``

    Raises
    ------
    ValueError
        If ``band`` is not one of the bands listed above
    """
    values = x.astype(PRECISION_FLOAT_NP)
    if band in {"s2_fapar", "s2_fcover"}:
        nodata, factor = 255, 0.005
    elif band in {"s2_b02", "s2_b03", "s2_b04", "s2_b08"}:
        nodata, factor = 65535, 0.0001
    else:
        raise ValueError(f"Unknown band {band}")
    values[x == nodata] = np.nan
    return values * factor


def _gamma_map(
    img: NDArray[PRECISION_FLOAT_NP],
    win_size: int,
    enl: int,
    ndv: float = 0.0,
) -> NDArray[PRECISION_FLOAT_NP]:
    """Apply gamma MAP filter to an s1 image.

    NOTE: Codebase from VITO

    The local mean and local mean of squares are computed with
    ``scipy.ndimage.uniform_filter`` using ``size=win_size``. Pixels whose
    squared local coefficient of variation (variance / mean**2) exceeds
    ``1 / enl``, and where ``variance - mean**2 / enl > 0``, receive the
    gamma MAP estimate. All other pixels keep the local mean. Output pixels
    whose input value is 0 are set to ``ndv``.

    NOTE: ``img`` is modified in place. Pixels equal to ``ndv`` are set to
    0.0, which is a no-op for the default ``ndv=0.0``.

    Parameters
    ----------
    img : NDArray[PRECISION_FLOAT_NP]
        S1 image. ``speckle_filter`` passes one 2-D time slice of its stack,
        expressed in power level. With a scalar ``size`` the uniform filter
        spans every axis of ``img``.
    win_size : int
        Side length of the moving window used for the local statistics
    enl : int
        Filter enl (presumably the equivalent number of looks); ``1 / enl``
        is used as the speckle variance threshold ``sig_v2``
    ndv : float
        Non-defined (no-data) value, by default 0.0

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Filtered image with the same shape as ``img``
    """
    # Zero out no-data pixels (writes into the caller's array).
    mask_ndv = img == ndv
    img[mask_ndv] = 0.0
    # Local first and second moments over a win_size window.
    img_mean = ndimage.uniform_filter(img, size=win_size)
    img_mean2 = ndimage.uniform_filter(pow(img, 2), size=win_size)
    img_mean[mask_ndv] = 0.0
    img_mean2[mask_ndv] = 0.0

    # Local variance E[x^2] - E[x]^2 and speckle variance threshold 1 / enl.
    var_z = img_mean2 - pow(img_mean, 2)
    sig_v2 = 1.0 / enl
    # Alias, not a copy: writes to `out` also change `img_mean`. This is safe
    # because every value derived from `img_mean` is computed before `out`
    # is written.
    out = img_mean

    with np.errstate(divide="ignore", invalid="ignore"):
        # Squared coefficient of variation; 0/0 (NaN) is treated as 0.
        fact1 = var_z / pow(img_mean, 2)
        fact1[np.isnan(fact1)] = 0
        # Pixels with more variation than speckle alone explains.
        mask = (fact1 > sig_v2) & (
            (var_z - pow(img_mean, 2) * sig_v2) > 0.0  # noqa: PLR2004
        )

        if mask.any():
            n = (pow(img_mean, 2) * (1.0 + sig_v2)) / (
                var_z - pow(img_mean, 2) * sig_v2
            )
            phalf = (img_mean * ((enl + 1.0) - n)) / (2 * n)
            q = enl * img_mean * img / n
            # Root -phalf + sqrt(phalf**2 + q) of t**2 + 2*phalf*t - q = 0.
            out[mask] = -phalf[mask] + np.sqrt(pow(phalf[mask], 2) + q[mask])

    # Restore the no-data value wherever the (possibly zeroed) input is 0.
    out[img == 0.0] = ndv  # noqa: PLR2004
    return out


def speckle_filter(
    stack: NDArray[PRECISION_FLOAT_NP],
    win_size: int = 7,
    enl: int = 3,
) -> NDArray[PRECISION_FLOAT_NP]:
    """Apply a multi-temporal speckle filter to a stack of S1 images.

    NOTE: Codebase from VITO

    Pass 1 filters every time slice with ``_gamma_map``. It also accumulates,
    per pixel, the sum over time of the ratios ``raw / filtered`` (NaN ratios
    counted as 0) and the number of strictly positive ratios. Pass 2 rescales
    each filtered slice by ``ratio_sum / ratio_count``. The raw value is used
    instead wherever the rescaled value is NaN or 0 over a positive raw pixel,
    or where the raw pixel is not positive. Infinite values left at the end
    are replaced by the mean of the remaining values, and a warning is logged.

    Parameters
    ----------
    stack : NDArray[PRECISION_FLOAT_NP]
        S1 stack indexed as (rows, cols, time).
        NOTE: S1 must be expressed in power level!
        NOTE: NaN values should be set to 0.0!
    win_size : int, optional
        Window size passed to ``_gamma_map``, by default 7
    enl : int, optional
        ``enl`` passed to ``_gamma_map``, by default 3

    Returns
    -------
    NDArray[PRECISION_FLOAT_NP]
        Filtered S1 stack with the same shape as ``stack`` (allocated with
        ``np.zeros``, hence float64)

    Raises
    ------
    AssertionError
        If the stack mean is not strictly between -0.5 and 1.5 (a heuristic
        check for power-level input), or the stack contains NaN values.
    """
    assert (
        -0.5 < np.mean(stack) < 1.5  # noqa: PLR2004
    ), "S1 should be express in power level."
    assert not np.isnan(stack).any(), "NaN values should be set to 0.0"

    n_rows, n_cols, n_times = stack.shape
    gamma_stack = np.zeros((n_rows, n_cols, n_times))
    ratio_sum = np.zeros((n_rows, n_cols))
    ratio_count = np.zeros((n_rows, n_cols))

    # Pass 1: per-slice gamma MAP filtering plus temporal ratio statistics.
    for t in range(n_times):
        raw = stack[:, :, t]
        rcs = _gamma_map(img=raw, win_size=win_size, enl=enl)
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = raw / rcs
            ratio[np.isnan(ratio)] = 0
        ratio_sum = ratio_sum + ratio
        ratio_count = ratio_count + (ratio > 0)
        gamma_stack[:, :, t] = rcs

    # Pass 2: rescale each filtered slice; fall back to the raw value where
    # the rescaled value is unusable.
    # TODO: Can yield inf values --> something to worry about?
    result = np.zeros((n_rows, n_cols, n_times))
    with np.errstate(invalid="ignore"):
        for t in range(n_times):
            raw = stack[:, :, t]
            rescaled = gamma_stack[:, :, t] * ratio_sum / ratio_count
            rescaled[np.isnan(rescaled)] = 0
            positive = raw > 0
            collapsed = (rescaled == 0) & positive
            rescaled[collapsed] = raw[collapsed]
            rescaled[~positive] = raw[~positive]
            result[:, :, t] = rescaled

    # Put odd values to mean value of full stack  TODO: Temporary solution
    infinite = np.isinf(result)
    if infinite.any():
        bh_logger(
            "Inf values in filtered S1 stack. Setting to mean value.",
            lvl=LogLevel.WARNING,
        )
        result[infinite] = np.nanmean(result[~infinite])
    return result


def format_cloud(x: NDArray[PRECISION_FLOAT_NP]) -> NDArray[PRECISION_INT_NP]:
    """Format the cloud data into a time-first integer array.

    NaN entries mean "masked out" and are mapped to 0. The input array is not
    modified.

    Parameters
    ----------
    x : NDArray[PRECISION_FLOAT_NP]
        Cloud data with the time axis last.
        Shape: (width, height, n_ts)

    Returns
    -------
    NDArray[PRECISION_INT_NP]
        Cloud data cast to PRECISION_INT_NP.
        Shape: (n_ts, width, height)
    """
    # NaN means masked out, so value 0. Work on a copy so the caller's array
    # is not silently overwritten.
    cleaned = x.copy()
    cleaned[np.isnan(cleaned)] = 0

    # (width, height, n_ts) -> (n_ts, width, height); return as minimal
    # precision, since the data is binary.
    return cleaned.transpose(2, 0, 1).astype(PRECISION_INT_NP)
