"""Model constants."""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path

import numpy as np
import torch

# ---------------------------------------------------------------------------
# Numeric precision (torch dtype and matching numpy dtype kept side by side)
# ---------------------------------------------------------------------------
# Continuous spectral values.
PRECISION_FLOAT = torch.float32
PRECISION_FLOAT_NP = np.float32
# Mask values.
PRECISION_INT = torch.int8
PRECISION_INT_NP = np.int8

# ---------------------------------------------------------------------------
# Evaluation scales: each S2 variable maps to a [low, high] value range.
# ---------------------------------------------------------------------------
S2_SCALES = dict(
    s2_fapar=[0.0, 1.0],
    s2_fcover=[0.0, 1.0],
    s2_b02=[0.0, 0.3],
    s2_b03=[0.0, 0.3],
    s2_b04=[0.0, 0.3],
    s2_b08=[0.0, 0.7],
    s2_ndvi=[-0.08, 0.92],
)

# ---------------------------------------------------------------------------
# Band names the model supports
# ---------------------------------------------------------------------------
S1 = [
    "s1_asc_vv",
    "s1_des_vv",
    "s1_asc_vh",
    "s1_des_vh",
]
# NOTE(review): same names and order as the keys of S2_SCALES above; keep the
# two in sync when adding or removing an S2 variable.
S2 = [
    "s2_fapar",
    "s2_fcover",
    "s2_b02",
    "s2_b03",
    "s2_b04",
    "s2_b08",
    "s2_ndvi",
]

# Names of the dataset splits.
SPLITS = (
    "training",
    "validation",
    "testing",
)

# Integer placeholder standing in for a missing (NaN) value.
# NOTE(review): 65535 is outside the int8 range of PRECISION_INT, so this
# presumably targets a different (e.g. uint16) encoding; confirm with callers.
NAN_INT_VALUE = 65535


@lru_cache()
def get_root_folder() -> Path:
    """Return the ``data`` directory that sits beside this file's grandparent.

    The path is ``Path(__file__).parents[2] / "data"``: it is computed
    relative to this module's location, and existence is not checked.
    The result is memoised by ``lru_cache``, so every call returns the same
    ``Path`` object.
    """
    this_module = Path(__file__)
    # parents[2] climbs three levels: file -> package -> parent -> its parent.
    return this_module.parents[2].joinpath("data")


@lru_cache()
def get_data_folder() -> Path:
    """Return the ``data`` sub-directory of the root folder.

    The path is computed only; existence is not checked. The result is
    memoised by ``lru_cache``, so every call returns the same ``Path`` object.

    NOTE(review): the root folder itself already ends in ``data``, so this
    resolves to ``.../data/data``. Confirm that this nesting is intended.
    """
    base = get_root_folder()
    return base.joinpath("data")


@lru_cache()
def get_models_folder() -> Path:
    """Return the ``models`` sub-directory of the root folder.

    The path is computed only; existence is not checked. The result is
    memoised by ``lru_cache``, so every call returns the same ``Path`` object.
    """
    base = get_root_folder()
    return base.joinpath("models")
