"""Evaluation utilisation functions."""


from __future__ import annotations

import itertools
import json
from pathlib import Path
from typing import List

import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel, Field

from vito_cropsar.constants import PRECISION_FLOAT_NP, PRECISION_INT_NP, get_data_folder
from vito_cropsar.data import apply_mask, load_sample
from vito_cropsar.vito_logger import Logger


class Sample(BaseModel):
    """
    Preprocessed evaluation sample, as returned by ``prepare_sample``.

    Bundles the model inputs (S1 and masked S2), the evaluation targets, the
    masks used to build them and the band names of each sensor.
    """

    # Target S2 data with only the original (actual) mask applied
    target: NDArray[PRECISION_FLOAT_NP] = Field(
        ...,
        description="Target image (unmasked)",
    )
    # Target S2 data restricted to the pixels that were hidden from the input
    # (see prepare_sample for how the mask is built)
    target_mask: NDArray[PRECISION_FLOAT_NP] = Field(
        ...,
        description="Target image (masked)",
    )
    # Merged mask (original & artificial) that was applied to ``s2``
    # NOTE(review): prepare_sample passes a boolean array here although the
    # annotation says PRECISION_INT_NP — confirm which dtype is intended.
    mask: NDArray[PRECISION_INT_NP] = Field(
        ...,
        description="Mask applied on the input S2 data (actual mask with artificial obscurene)",
    )
    # Original (actual) mask of the S2 data; also passed as a boolean array
    mask_original: NDArray[PRECISION_INT_NP] = Field(
        ...,
        description="Original mask of the input S2 data (actual mask)",
    )
    # Sentinel-1 model input
    s1: NDArray[PRECISION_FLOAT_NP] = Field(
        ...,
        description="Sentinel-1 input data",
    )
    # Sentinel-2 model input, with the merged mask applied
    s2: NDArray[PRECISION_FLOAT_NP] = Field(
        ...,
        description="Sentinel-2 input data",
    )
    # Names of the Sentinel-1 bands (typing.List kept on purpose, see noqa)
    bands_s1: List[str] = Field(  # noqa: UP006
        ...,
        description="Sentinel-1 bands used",
    )
    # Names of the Sentinel-2 bands
    bands_s2: List[str] = Field(  # noqa: UP006
        ...,
        description="Sentinel-2 bands used",
    )

    class Config:
        """Configure the sample class."""

        # Allow field types pydantic cannot validate natively (numpy arrays)
        arbitrary_types_allowed = True


def prepare_sample(
    tile: str,
    bands_s1: list[str],
    bands_s2: list[str],
    n_ts: int,
    resolution: int,
    sample_s1: bool,
    data_f: str | None = None,
    align: bool = True,
    smooth_s1: bool = True,
    cache_tag: str | None = None,
) -> Sample:
    """
    Prepare the requested testing sample for evaluation.

    Loads the tile from the testing split, keeps the requested bands, takes a
    centre cut in time and space, optionally samples the S1 orbit direction,
    and builds the masked S2 model input together with the evaluation targets.

    Parameters
    ----------
    tile : str
        Tile to load (must be from the testing split)
    bands_s1 : list[str]
        Sentinel-1 bands to use
    bands_s2 : list[str]
        Sentinel-2 bands to use
    n_ts : int
        Number of time steps to keep
        Note: This must overlap with how the model was trained
        Note: For consistency, the center cut will be taken
    resolution : int
        Resolution (width/height) of each patch (square)
        Note: This must overlap with how the model was trained
        Note: For consistency, the center cut will be taken
    sample_s1 : bool
        Whether to use the most present ascending or descending S1 band
        Note: This must overlap with how the model was trained
        Note: This will halve the number of S1 bands
    data_f : str | None
        Folder from which to get data; ``get_data_folder()`` is used if None
    align : bool
        Whether to align the time series (jitter removal in S2)
    smooth_s1 : bool
        Whether to smooth the S1 time series
    cache_tag : str | None
        Cache postfix to check for cached samples

    Returns
    -------
    Sample
        The S1 input, the S2 input with the merged mask applied, the merged and
        original masks, both evaluation targets, and the band names.
    """
    # get data folder
    data_f = data_f or get_data_folder()

    # get sample
    sample = load_sample(
        split="testing",
        tile=tile,
        data_f=data_f,
        align=align,
        smooth_s1=smooth_s1,
        cache_tag=cache_tag,
    )

    # Extract the right bands
    sample = _extract_bands(sample, bands_s1=bands_s1, bands_s2=bands_s2)

    # Cut the time series
    sample = _cut_ts(sample, n_ts=n_ts)

    # Cut the resolution
    sample = _cut_resolution(sample, resolution=resolution)

    # Sample the S1 band
    sample = _sample_s1(sample, sample_s1=sample_s1)

    # Extract the masks: "mask" is the original mask, "mask_" presumably the
    # artificial one (cf. the Sample field descriptions) — TODO confirm
    mask = sample["mask"].astype(np.bool_)
    mask_ = sample["mask_"].astype(np.bool_)

    # Manipulate the input (merged --> 0 if any mask has cloud, 1 if both visible)
    s2 = sample["s2"].copy()
    mask_merged = mask & mask_  # type: ignore[operator]
    s2 = apply_mask(s2, mask_merged)

    # Define the targets
    target = sample["s2"].copy()
    target = apply_mask(target, mask=mask)  # No info, so nothing to evaluate
    target_mask = sample["s2"].copy()
    target_mask = apply_mask(
        target_mask, mask=~mask | ~mask_
    )  # NaN if no info or visible in artificial input

    # Forward the required attributes, incl. the repeated masks
    return Sample(
        target=target,
        target_mask=target_mask,
        mask=mask_merged,
        mask_original=mask,
        s1=sample["s1"],
        s2=s2,
        bands_s1=sample["bands_s1"],
        bands_s2=sample["bands_s2"],
    )


def _extract_bands(
    sample: dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]],
    bands_s1: list[str],
    bands_s2: list[str],
) -> dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]:
    """
    Extract the requested bands.

    Parameters
    ----------
    sample : dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]
        Sample to extract from
    bands_s1 : list[str]
        Sentinel-1 bands to use
    bands_s2 : list[str]
        Sentinel-2 bands to use

    Returns
    -------
    dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]
        Sample with extracted bands
    """
    sample["s1"] = sample["s1"][
        :, [i for i, b in enumerate(sample["bands_s1"]) if b in bands_s1]
    ]
    sample["s2"] = sample["s2"][
        :, [i for i, b in enumerate(sample["bands_s2"]) if b in bands_s2]
    ]
    sample["bands_s1"] = bands_s1
    sample["bands_s2"] = bands_s2
    return sample


def _cut_ts(
    sample: dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]], n_ts: int
) -> dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]:
    """
    Cut the time series.

    Parameters
    ----------
    sample : dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]
        Sample to cut
    n_ts : int
        Number of time series to plot
        Note: This must overlap with how the model was trained
        Note: For consistency, the center cut will be taken

    Returns
    -------
    dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]
        Sample with cut time series
    """
    n_ts_total = sample["s2"].shape[0]
    if n_ts_total > n_ts:
        # Cut the time series
        n_ts_cut = (n_ts_total - n_ts) // 2
        sample["s2"] = sample["s2"][n_ts_cut:-n_ts_cut]
        sample["s1"] = sample["s1"][n_ts_cut:-n_ts_cut]
        sample["mask"] = sample["mask"][n_ts_cut:-n_ts_cut]
        sample["mask_"] = sample["mask_"][n_ts_cut:-n_ts_cut]
    return sample


def _cut_resolution(
    sample: dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]], resolution: int
) -> dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]:
    """
    Cut the resolution.

    Parameters
    ----------
    sample : dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]
        Sample to cut
    resolution : int
        Resolution (width/height) of each patch (square)
        Note: This must overlap with how the model was trained
        Note: For consistency, the center cut will be taken

    Returns
    -------
    dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]
        Sample with cut resolution
    """
    n_res_total = sample["s2"].shape[-1]
    if n_res_total > resolution:
        # Cut the resolution
        n_res_cut = (n_res_total - resolution) // 2
        sample["s2"] = sample["s2"][..., n_res_cut:-n_res_cut, n_res_cut:-n_res_cut]
        sample["s1"] = sample["s1"][..., n_res_cut:-n_res_cut, n_res_cut:-n_res_cut]
        sample["mask"] = sample["mask"][..., n_res_cut:-n_res_cut, n_res_cut:-n_res_cut]
        sample["mask_"] = sample["mask_"][
            ..., n_res_cut:-n_res_cut, n_res_cut:-n_res_cut
        ]
    return sample


def _sample_s1(
    sample: dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]],
    sample_s1: bool,
) -> dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]:
    """
    Sample the S1 band.

    Parameters
    ----------
    sample : dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]
        Sample to sample
    sample_s1 : bool
        Whether to use the most present ascending or descending S1 band
        Note: This must overlap with how the model was trained
        Note: This will halve the number of S1 bands

    Returns
    -------
    dict[str, NDArray[PRECISION_FLOAT_NP | PRECISION_INT_NP]]
        Sample with sampled S1 band
    """
    if sample_s1:
        s1, result = sample["s1"], []
        result.extend(
            (
                s1[:, i]
                if np.sum(np.isnan(s1[:, i])) <= np.sum(np.isnan(s1[:, i + 1]))
                else s1[:, i + 1]
            )[:, None]
            for i in range(0, s1.shape[1], 2)
        )
        sample["s1"] = np.concatenate(result, axis=1)
    return sample


def write_result(
    result: dict[str, dict[str, float]],
    write_f: Path,
    logger: Logger | None = None,
) -> None:
    """
    Write away the result, extended with a per-band/per-metric summary.

    The summary is the mean of every metric over all tiles. It is stored in
    ``result["summary"]``, so ``result`` is modified in place. A ``"summary"``
    entry left by a previous call is never treated as a tile.

    Parameters
    ----------
    result : dict[str, dict[str, float]]
        Generated result for each evaluation tile (key)
    write_f : Path
        File (+path) to write result to
    logger : Logger
        Log the evaluation, if logger is provided

    Raises
    ------
    ValueError
        If ``write_f`` is not a ``.json`` file or ``result`` has no tiles.
    """
    if write_f.suffix != ".json":
        raise ValueError(f"Result file must have a '.json' suffix, got: {write_f}")

    # gather keys (skip a summary from an earlier call; it would be averaged in)
    tiles = [tile for tile in result if tile != "summary"]
    if not tiles:
        raise ValueError("No evaluation results to summarise")
    bands = list(result[tiles[0]])
    metrics = list(result[tiles[0]][bands[0]]) if bands else []

    # register summary
    result["summary"] = {band: {metric: 0.0 for metric in metrics} for band in bands}
    for band, metric in itertools.product(bands, metrics):
        scores = [result[tile][band][metric] for tile in tiles]
        result["summary"][band][metric] = sum(scores) / len(scores)

    # Print out evaluation results
    if logger is not None:
        logger("")
        logger("Evaluation summary:")
        for band in bands:
            logger(f" - {band}:")
            for metric, score in result["summary"][band].items():
                logger(f"   - {metric}: {score:.5f}")
        logger("")

    # Write away the result
    with open(write_f, "w", encoding="utf-8") as f:
        json.dump(result, fp=f, indent=2)


if __name__ == "__main__":
    # Manual smoke test: prepare a single testing tile and report its S1 bands.
    demo_config = {
        "tile": "31UDS_4096_4352_7424_7680_2020-08-11",
        "bands_s1": ["s1_asc_vv", "s1_dsc_vv", "s1_asc_vh", "s1_dsc_vh"],
        "bands_s2": ["s2_fapar"],
        "n_ts": 32,
        "resolution": 128,
        "sample_s1": True,
    }
    sample = prepare_sample(**demo_config)
    print("bands_s1", sample.bands_s1)
