"""Data preprocessing for inference."""

from __future__ import annotations

from typing import Any

import numpy as np
import torch

from vito_lot_delineation.data import Augmentor
from vito_lot_delineation.data_prepr.utils import (
    compute_ndvi,
    fill_and_interpolate,
    rerange,
    scale,
    transform,
)


def preprocess_raw(
    sample: dict[str, np.ndarray], cfg: dict[str, Any]
) -> torch.Tensor:
    """Transform raw sample into torch.Tensor ready to be fed to a model.

    The configuration is read and validated before any preprocessing runs, so
    a malformed ``cfg`` fails immediately instead of after the band pipeline.

    Parameters
    ----------
    sample : dict[str, np.ndarray]
        Raw sample, keyed by band name. Any mapping that iterates over its
        keys and supports ``sample[key]`` is accepted (e.g. the object
        returned by ``np.load`` for an ``.npz`` file). The ``"parcelids"``
        and ``"CT"`` entries are ignored.
    cfg : dict[str, Any]
        Configuration dictionary used to preprocess sample. Two keys are
        read: ``"bands"`` (names of the preprocessed bands to stack, in that
        order) and ``"n_ts"`` (forwarded to ``Augmentor``).

    Returns
    -------
    torch.Tensor
        Preprocessed sample: the ``"input"`` entry returned by
        ``Augmentor.tscut(..., center=True)`` for the stacked bands.

    Raises
    ------
    KeyError
        If ``cfg`` lacks ``"bands"`` or ``"n_ts"``, or if a requested band is
        not present after preprocessing.
    ValueError
        If ``cfg["bands"]`` is empty.
    """
    # Read the configuration first: fail fast on a bad config rather than
    # after running the whole preprocessing pipeline.
    band_names = list(cfg["bands"])
    n_ts = cfg["n_ts"]
    if not band_names:
        raise ValueError("cfg['bands'] must list at least one band to stack")

    # preprocess sample (everything except the non-band entries)
    excluded = ("parcelids", "CT")
    bands = transform({k: sample[k] for k in sample if k not in excluded})
    bands = compute_ndvi(bands)
    bands = scale(bands)
    bands = fill_and_interpolate(bands)
    bands = rerange(bands)

    # dict of all bands -> list of input bands, NaNs replaced by np.nan_to_num
    try:
        selected = [np.nan_to_num(bands[k]) for k in band_names]
    except KeyError as err:
        raise KeyError(
            f"Band {err} requested in cfg['bands'] is missing from the "
            "preprocessed sample"
        ) from err

    # np.stack adds a new leading axis, so axis 0 follows the order of
    # cfg["bands"].
    stacked = torch.tensor(np.stack(selected))

    # extract model input
    # NOTE(review): tscut presumably cuts the time series down to n_ts steps
    # around the center — confirm against Augmentor.
    augmentor: Augmentor = Augmentor(n_ts=n_ts)
    return augmentor.tscut({"input": stacked}, center=True)["input"]


if __name__ == "__main__":
    # Manual smoke test: preprocess one raw sample and print the resulting
    # tensor shape. The path is relative to the current working directory.
    #
    # np.load on an .npz archive returns a lazily-read object that keeps the
    # file open; the context manager closes it once preprocessing has pulled
    # the arrays it needs.
    with np.load(
        "data/data_raw/E464N414_2021-1-1_2021-12-31/E464N414_0_128_0_128.npz"
    ) as data:
        x = preprocess_raw(
            sample=data, cfg={"bands": ["s2_b02", "s2_b03"], "n_ts": 4}
        )
    print(x.shape)
