"""Utilisation functions."""

from __future__ import annotations

import random
from pathlib import Path
from typing import Any

import numpy as np
import torch

from vito_lot_delineation.data.constants import get_data_folder


def load_folder(split: str, data_f: Path | None = None) -> list[Path]:
    """
    Load all the data paths under a specified folder.

    The ``*.npz`` files directly under ``data_f / split`` are collected, sorted, and then
    shuffled with a fixed seed, so the returned order is reproducible across calls.

    Parameters
    ----------
    split : str
        The folder (split) to load, any of "training", "testing", "testing_spain",
        "validation", or "streamlit"
    data_f : Path | None
        The data folder to load from, falls back to get_data_folder() if not provided

    Returns
    -------
    paths : list[Path]
        List of data sample paths

    Raises
    ------
    AssertionError
        If the split is not recognised, or if its folder does not exist under the data folder
    """
    # Checks raise explicitly instead of using `assert`, which is stripped under `python -O`.
    # AssertionError (and the messages) are kept so existing callers see the same failure.
    if split not in {
        "training",
        "testing",
        "testing_spain",
        "validation",
        "streamlit",
    }:
        raise AssertionError(f"Folder '{split}' not recognised")

    data_f = data_f or get_data_folder()
    split_f = data_f / split
    if not split_f.is_dir():
        raise AssertionError(
            f"Folder (Split) '{split}' not found under path '{data_f}'"
        )

    # Sort first so the seeded shuffle starts from a filesystem-independent order
    lst = sorted(split_f.glob("*.npz"))
    random.Random(42).shuffle(lst)  # Fixed shuffling
    return lst


def load_file(
    cfg: dict[str, Any],
) -> tuple[torch.Tensor, torch.IntTensor, torch.Tensor]:
    """
    Load in the data residing under the provided path.

    Parameters
    ----------
    cfg : dict[str, Any]
        File configuration, the following keys are read:
         - "path": location of the .npz archive to load
         - "bands": names of the arrays in the archive to stack as model input

    Returns
    -------
    input : torch.Tensor
        Input tensor in torch.float32 format, with the requested bands stacked along a new first
        dimension (each band is expected to be stored as (width, height, channels, time-steps);
        this layout is not verified here)
    instance : torch.IntTensor
        Instance tensor in torch.int32 format, loaded from the "instance" array (expected to hold
        a unique ID for each field, 0 if no field)
    extent : torch.Tensor
        Extent tensor in torch.float32 format, loaded from the "extent" array (expected to hold
        1 for fields and 0 for background)
    """
    # Use a context manager so the archive's file handle is released once everything is read;
    # arrays are fully materialised on access, so they stay valid after the archive is closed
    with np.load(cfg["path"]) as archive:
        bands = [
            torch.as_tensor(archive[band], dtype=torch.float32)
            for band in cfg["bands"]
        ]
        instance = torch.as_tensor(archive["instance"], dtype=torch.int32)
        # Cast through int32 before float32, matching the previous IntTensor -> float conversion
        extent = torch.as_tensor(archive["extent"], dtype=torch.int32).to(
            torch.float32
        )

    return torch.stack(bands), instance, extent


if __name__ == "__main__":
    # Manual smoke test: load the first sample of the "testing" split and report tensor shapes
    sample_paths = load_folder(split="testing", data_f=Path("data/data/"))
    sample_cfg = {
        "path": sample_paths[0],
        "bands": ["s2_ndvi", "s2_b02"],
    }
    loaded = load_file(sample_cfg)

    # load_file returns (input, instance, extent), in that order
    for label, tensor in zip(("Input", "Instance", "Extent"), loaded):
        print(f"{label} shape: {tensor.shape}")
