"""Shared utilisation functions."""

from __future__ import annotations

from functools import wraps
from time import time
from typing import Any, Callable

import torch

# Accumulated run time in seconds, keyed by function name.
# Filled by the `time_func` decorator and reported by `print_time`.
TIME_CACHE: dict[str, float] = {}


def get_mask_errors(
    t: torch.Tensor,
    p: torch.Tensor,
    m: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Split the absolute prediction error into an unmasked and a masked part.

    Only positions where the target is not NaN are taken into account.

    Parameters
    ----------
    t : torch.Tensor
        Target tensor
        Shape: (time, bands, width, height)
    p : torch.Tensor
        Predicted tensor
        Shape: (time, bands, width, height)
    m : torch.Tensor
        Mask tensor, 1 marks unmasked positions and 0 marks masked-out positions
        Shape: (time, width, height)

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Flattened absolute errors of the unmasked positions (mask equals 1),
        followed by those of the masked positions (mask equals 0)
    """
    abs_err = (t - p).abs()

    # Give the mask a band axis so that it lines up with the error tensor
    band_mask = m.unsqueeze(1).expand(-1, abs_err.shape[1], -1, -1)

    # Positions for which a target value exists
    has_target = torch.isnan(t).logical_not()

    # Target exists and the position is not masked out
    unmasked_err = abs_err[has_target & (band_mask == 1)]

    # Target exists but the position is masked out
    masked_err = abs_err[has_target & (band_mask == 0)]

    return unmasked_err, masked_err


def interpolate_s1_batch(x: torch.Tensor) -> torch.Tensor:
    """
    Linearly interpolate every sample of a batch of S1 data.

    Note: Time steps that contain NaN values are discarded and re-filled from
    the neighbouring NaN-free time steps. The per-channel helper writes into
    views of ``x``, so the input tensor is modified in place as well.

    Parameters
    ----------
    x : torch.Tensor
        Sentinel-1 data
        Shape: (batch, time, channels_s1, width, height)

    Returns
    -------
    torch.Tensor
        Interpolated Sentinel-1 data (newly stacked tensor)
        Shape: (batch, time, channels_s1, width, height)
    """
    batch_size = x.shape[0]

    # Handle the samples one by one and glue them back along the batch axis
    samples = []
    for idx in range(batch_size):
        samples.append(_interpolate_s1_stack(x[idx]))
    return torch.stack(samples)


def _interpolate_s1_stack(x: torch.Tensor) -> torch.Tensor:
    """
    Linearly interpolate every channel of a single S1 time series.

    Note: Time steps that contain NaN values are discarded and re-filled from
    the neighbouring NaN-free time steps. The per-channel helper writes into
    views of ``x``, so the input tensor is modified in place as well.

    Parameters
    ----------
    x : torch.Tensor
        Sentinel-1 data
        Shape: (time, channels_s1, width, height)

    Returns
    -------
    torch.Tensor
        Interpolated Sentinel-1 data (newly stacked tensor)
        Shape: (time, channels_s1, width, height)
    """
    n_channels = x.shape[1]

    # Each channel is interpolated independently, then re-inserted on axis 1
    channels = []
    for idx in range(n_channels):
        channels.append(_interpolate_s1_channel(x[:, idx]))
    return torch.stack(channels, dim=1)


def _interpolate_s1_channel(x: torch.Tensor) -> torch.Tensor:
    """
    Linearly interpolate a S1 channel.

    Note: This script removes time steps with NaNs

    Parameters
    ----------
    x : torch.Tensor
        Sentinel-1 data
        Shape: (time, width, height)

    Returns
    -------
    torch.Tensor
        Interpolated Sentinel-1 data
        Shape: (time, width, height)
    """
    # Put the time steps to NaN if NaN occurs
    is_nan, is_not_nan = [], []
    for i in range(x.shape[0]):
        if x[i].isnan().any():
            x[i] = torch.nan
            is_nan.append(i)
        else:
            is_not_nan.append(i)

    # If no non-NaN time steps are found, replace the full-NaN pixels with 0 and try again
    if not is_not_nan:
        x[:, x.isnan().all(dim=0)] = 0
        return _interpolate_s1_channel(x)

    # Interpolate the NaN values by the closest non-NaN value
    for i in is_nan:
        a, b = _get_bounds(i, lst=is_not_nan)
        if a is None:
            x[i] = x[b].clone()
        elif b is None:
            x[i] = x[a].clone()
        else:
            x[i] = abs(i - b) * x[a] / (b - a) + abs(i - a) * x[b] / (b - a)

    return x


def _get_bounds(i: int, lst: list[int]) -> tuple[int | None, int | None]:
    """Get the closest values in the list to the provided integer."""
    if i < lst[0]:
        return None, lst[0]
    if i > lst[-1]:
        return lst[-1], None
    for j in range(len(lst) - 1):
        if lst[j] <= i <= lst[j + 1]:
            return lst[j], lst[j + 1]
    raise ValueError("This should not happen!")


def print_mem(x: str) -> None:
    """
    Print the CUDA memory reserved on device 0 next to the device's total memory.

    Parameters
    ----------
    x : str
        Label that is put in front of the printed numbers
    """
    device = 0
    bytes_per_mib = 1024**2

    # Both quantities are reported by torch in bytes; convert them to MiB
    mem = torch.cuda.memory_reserved(device) / bytes_per_mib
    total = torch.cuda.get_device_properties(device).total_memory / bytes_per_mib

    print(f" - {x}: {int(mem)}/{int(total)} MiB ({mem/total*100:.2f}%)")


def time_func(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorate a function so that its run time is accumulated in TIME_CACHE.

    The elapsed time of every successful call is added to the entry keyed by
    ``func.__name__``; calls that raise are not recorded.

    Parameters
    ----------
    func : Callable[..., Any]
        Function to time

    Returns
    -------
    Callable[..., Any]
        Wrapper with the same call signature and return value as ``func``
    """

    # wraps keeps the name and docstring of the decorated function intact
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Time the function's execution
        start = time()
        result = func(*args, **kwargs)
        elapsed = time() - start

        # Add the elapsed time to the running total of this function
        TIME_CACHE[func.__name__] = TIME_CACHE.get(func.__name__, 0.0) + elapsed
        return result

    return wrapper


def print_time(n: int | None = None, top_n: int | None = None) -> None:
    """
    Print a report of the run times collected in TIME_CACHE.

    The raw cache is printed first, followed by one line per function, ordered
    from the largest to the smallest accumulated time.

    Parameters
    ----------
    n : int | None
        Number of samples the timing went over; if given (and non-zero), the
        times are divided by it and reported per sample (default: None)
    top_n : int | None
        Number of top functions to print; all functions are printed if it is
        None or 0 (default: None)
    """
    print("\n\n\nTime report:")
    print(TIME_CACHE)

    # Slowest functions first
    times = sorted(TIME_CACHE.items(), key=lambda item: item[1], reverse=True)

    # A falsy top_n means that the report is not truncated
    limit = top_n or len(times)
    for key, value in times[:limit]:
        print(f" - {key}: {value/(n or 1):.5f}s{'/sample' if n else ''}")

    print("\n\n\n")
