"""Evaluation scripts."""

from __future__ import annotations

from pathlib import Path
from time import time

import numpy as np
import torch
from numpy.typing import NDArray

from vito_cropsar.constants import (
    PRECISION_FLOAT_NP,
    S2_SCALES,
    get_data_folder,
)
from vito_cropsar.data import list_tiles
from vito_cropsar.evaluation.metric_artifact import get_artifacts_score
from vito_cropsar.evaluation.metric_scores import calculate_results
from vito_cropsar.evaluation.metric_ssim import get_ssim_score
from vito_cropsar.evaluation.plot_pixel_evolution import plot_pixel_evolution
from vito_cropsar.evaluation.plot_series import plot_series
from vito_cropsar.evaluation.utils import prepare_sample, write_result
from vito_cropsar.models import InpaintingBase


@torch.no_grad()
def main(
    model: InpaintingBase,
    n_samples: int | None = None,
    data_f: Path | None = None,
    cache_tag: str | None = None,
) -> None:
    """
    Evaluate the provided model on the testing tiles.

    For every selected tile a sample is prepared, the model is called on its
    S1 and S2 inputs, the metric scores are calculated, and both a time-series
    plot and a pixel-evolution plot are saved. All outputs are written under
    the model's "evaluation" folder, together with a final "report.json".
    The whole function runs with gradient tracking disabled.

    Parameters
    ----------
    model : InpaintingBase
        Predictive inpainting model
    n_samples : int | None
        Number of samples to evaluate (all if None)
    data_f : Path | None
        Data folder location, use default location if None
    cache_tag : str | None
        Cache tag to use during evaluation
    """
    log = model.logger
    log("")
    log(f"Start evaluating model '{model}':")
    log("")

    # Fall back on the default data folder when none is provided
    if not data_f:
        data_f = get_data_folder()

    # Switch the model to evaluation mode
    model.eval()

    # Make sure all output folders exist
    eval_f = model.model_folder / "evaluation"
    series_f = eval_f / "series"
    pixel_f = eval_f / "pixel"
    for folder in (eval_f, series_f, pixel_f):
        folder.mkdir(exist_ok=True, parents=True)

    # Collect the testing tiles, optionally limited to the first n_samples
    tiles = list_tiles(data_f=data_f, split="testing", cache_tag=cache_tag)
    if n_samples is not None:
        tiles = tiles[:n_samples]
    n_tiles = len(tiles)

    # Evaluate the model tile by tile
    results = {}
    start = time()
    for tile_idx, tile in enumerate(tiles):
        # Average time per processed tile so far (unknown before the first tile finishes)
        if tile_idx == 0:
            t_tile = float("inf")
        else:
            t_tile = (time() - start) / tile_idx
        progress = 100 * tile_idx / n_tiles
        remaining = t_tile * (n_tiles - tile_idx)
        log(
            f"Evaluate tile: {tile} ({progress:.2f}%, time remaining: {remaining:.2f}s)"
        )

        # Load in the next sample to evaluate
        log(" - Preparing the sample..")
        scaler = model.scaler
        bands_s2 = scaler.bands_s2
        sample = prepare_sample(
            tile,
            data_f=data_f,
            bands_s1=scaler.bands_s1,
            bands_s2=bands_s2,
            n_ts=model.n_ts,
            resolution=model.resolution,
            sample_s1=scaler.sample_s1,
            align=model.align_input,
            smooth_s1=model.smooth_s1,
            cache_tag=cache_tag,
        )

        # Run the model on the prepared inputs
        log(" - Predicting..")
        pred = model(s1=sample.s1, s2=sample.s2)

        # Score the prediction and keep the result for the report
        log(" - Calculating the metric scores..")
        scores = _calculate_metrics(
            bands_s2=bands_s2,
            target=sample.target,
            pred=pred,
            mask=sample.mask,
            mask_original=sample.mask_original,
        )
        results[tile] = scores
        overall = scores["global"]
        log(f"   - MAE masked: {overall['mae_masked']:.3f}")
        log(f"   - artifact: {overall['artifact']:.3f}")
        log(f"   - SSIM: {overall['ssim']:.3f}")

        # Both plots of a tile share the same file name, in different folders
        png_name = f"{tile}.png"

        # Plot the result
        log(" - Plotting the time-series..")
        plot_series(
            target=sample.target,
            inp=sample.s2,
            pred=pred,
            mask=sample.mask,
            s2_bands=bands_s2,
            save_f=series_f / png_name,
        )

        # Plot pixel-level evolutions, preferring the fAPAR band when present
        log(" - Plotting the pixel evolutions..")
        if "s2_fapar" in bands_s2:
            focus_band = "s2_fapar"
        else:
            focus_band = bands_s2[0]
        plot_pixel_evolution(
            target=sample.target,
            inp=sample.s2,
            pred=pred,
            s2_bands=bands_s2,
            save_f=pixel_f / png_name,
            band=focus_band,
        )

        log("")

    # Create report
    write_result(results, write_f=eval_f / "report.json", logger=model.logger)


def _calculate_metrics(
    bands_s2: list[str],
    target: NDArray[PRECISION_FLOAT_NP],
    pred: NDArray[PRECISION_FLOAT_NP],
    mask: NDArray[PRECISION_FLOAT_NP],
    mask_original: NDArray[PRECISION_FLOAT_NP],
) -> dict[str, dict[str, float]]:
    """
    Calculate the metric scores for every band and their average over all bands.

    Both `target` and `pred` are rescaled band by band with `_resize` before
    any metric is computed.

    Parameters
    ----------
    bands_s2 : list[str]
        S2 band names, in the order in which the bands appear along axis 1 of
        `target` and `pred`
    target : NDArray[PRECISION_FLOAT_NP]
        Target values, bands indexed along axis 1
    pred : NDArray[PRECISION_FLOAT_NP]
        Predicted values, bands indexed along axis 1
    mask : NDArray[PRECISION_FLOAT_NP]
        Mask passed to `calculate_results` and, as `mask_agg`, to
        `get_artifacts_score`
    mask_original : NDArray[PRECISION_FLOAT_NP]
        Mask passed to `get_artifacts_score` and `get_ssim_score`

    Returns
    -------
    dict[str, dict[str, float]]
        Metric scores keyed by band name, plus a "global" entry holding, for
        every metric, the mean of that metric over all bands
    """
    rescaled_pred = np.stack(
        [_resize(pred[:, i], band) for i, band in enumerate(bands_s2)],
        axis=1,
    )
    rescaled_target = np.stack(
        [_resize(target[:, i], band) for i, band in enumerate(bands_s2)],
        axis=1,
    )

    # Compute the metrics for each band (slicing keeps the band axis in place)
    results = {}
    for i, band in enumerate(bands_s2):
        band_target = rescaled_target[:, i : i + 1]
        band_pred = rescaled_pred[:, i : i + 1]
        results[band] = calculate_results(
            target=band_target,
            pred=band_pred,
            mask=mask,
        )
        results[band]["artifact"] = get_artifacts_score(
            target=band_target,
            pred=band_pred,
            mask=mask_original,
            mask_agg=mask,
        )

    # Compute SSIM for all bands at once; the result is indexed per band
    ssim = get_ssim_score(
        target=rescaled_target,
        pred=rescaled_pred,
        mask=mask_original,
    )
    for i, band in enumerate(bands_s2):
        results[band]["ssim"] = ssim[i]

    # Average every metric over the bands; the metric names are taken from the
    # first band, so all bands are expected to expose the same metrics
    metrics = list(results[bands_s2[0]])
    n_bands = len(bands_s2)
    results["global"] = {
        metric: sum(results[band][metric] for band in bands_s2) / n_bands
        for metric in metrics
    }

    return results


def _resize(x: NDArray[PRECISION_FLOAT_NP], band: str) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Rescale the values of a band to a [0..1] range.

    The values are first clipped to the (min, max) bounds registered for the
    band in `S2_SCALES`, then mapped linearly so that min becomes 0 and max
    becomes 1.
    """
    lower, upper = S2_SCALES[band]
    clipped = np.clip(x, lower, upper)
    return (clipped - lower) / (upper - lower)


if __name__ == "__main__":
    from vito_cropsar.constants import get_models_folder
    from vito_cropsar.models import InpaintingCnnTransformer

    # Load the stored model, then run the evaluation on it
    model_folder = get_models_folder() / "20230601T104756_cnn_transformer"
    trained_model = InpaintingCnnTransformer.load(mdl_f=model_folder)
    main(model=trained_model, cache_tag="fapar_rgb")
