"""Utility functions shared across the models."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import numpy as np
import torch
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP


class JSONEncoder(json.JSONEncoder):
    """JSON encoder that additionally knows how to serialise ``pathlib.Path`` objects."""

    def default(self, obj: Any) -> Any:
        """
        Return a JSON-serialisable stand-in for ``obj``.

        ``Path`` instances are converted to their string form; every other
        object is handed to the base implementation, which raises ``TypeError``
        for types it cannot encode.
        """
        return str(obj) if isinstance(obj, Path) else super().default(obj)


def interpolate_time(
    arr: torch.Tensor | NDArray[PRECISION_FLOAT_NP],
) -> torch.Tensor | NDArray[PRECISION_FLOAT_NP]:
    """
    Linearly interpolate missing (NaN) values along the time axis.

    Parameters
    ----------
    arr : torch.Tensor | NDArray[PRECISION_FLOAT_NP]
        Array to interpolate, with NaN marking missing observations
        Shape: (time, channels, width, height)

    Returns
    -------
    torch.Tensor | NDArray[PRECISION_FLOAT_NP]
        Interpolated array of the same kind as ``arr``. A tensor result has the
        same dtype and device as the input tensor and is detached from the
        autograd graph.
        Shape: (time, channels, width, height)

    Raises
    ------
    TypeError
        If ``arr`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(arr, np.ndarray):
        return _interpolate(arr)
    if isinstance(arr, torch.Tensor):
        # The interpolation itself runs in numpy on the host. Convert the result
        # back with the input's dtype/device; ``torch.Tensor(ndarray)`` would
        # always yield a CPU tensor of the default dtype instead.
        filled = _interpolate(arr.detach().cpu().numpy())
        return torch.as_tensor(filled, dtype=arr.dtype, device=arr.device)
    # TypeError subclasses Exception, so existing ``except Exception`` handlers still work.
    raise TypeError(f"Type {type(arr)} not supported!")


def _interpolate(arr: NDArray[PRECISION_FLOAT_NP]) -> NDArray[PRECISION_FLOAT_NP]:
    """Interpolation function for numpy arrays."""
    _ts, ch, w, h = arr.shape
    result = np.zeros_like(arr)
    for c_idx in range(ch):  # Iterate over the channels
        for x in range(w):  # Iterate over the width
            for y in range(h):  # Iterate over the height
                t = arr[:, c_idx, x, y]
                if not np.isnan(t).all():
                    result[:, c_idx, x, y] = np.interp(
                        np.arange(len(t)),
                        np.arange(len(t))[~np.isnan(t)],
                        t[~np.isnan(t)],
                    )
    return result


if __name__ == "__main__":
    from time import perf_counter

    # Small demo: each pixel is 0 at t=0, 1 at t=2 and missing at t=1,
    # so the interpolated middle step should be 0.5.
    my_arr = np.zeros((3, 1, 2, 2))
    my_arr[:] = np.nan
    my_arr[0] = 0
    my_arr[-1] = 1
    print("Input array:")
    print(my_arr)

    my_result = _interpolate(my_arr)
    print("\nResult:")
    print(my_result)

    # Rough benchmark. The input is built once, outside the timed loop:
    # _interpolate does not modify its argument, so every run sees the same data.
    n_runs = 10
    bench_arr = np.full((10, 3, 128, 128), np.nan)
    bench_arr[0] = 0
    bench_arr[-1] = 1
    print(f"\nInterpolating shape '{bench_arr.shape}' {n_runs} times in.. ", end="")
    start = perf_counter()
    for _ in range(n_runs):
        _ = _interpolate(bench_arr)
    # Read the clock once so the total and per-call figures are consistent.
    elapsed = perf_counter() - start
    print(f"{elapsed:.2f} seconds ({elapsed / n_runs:.3f} per call)")
