"""Custom TensorBoard logger."""

from __future__ import annotations

import warnings
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP
from vito_cropsar.data import Scaler
from vito_cropsar.evaluation.plot_pixel_evolution import plot_pixel_evolution
from vito_cropsar.evaluation.plot_series import plot_series


class TensorBoardLogger:
    """Custom TensorBoard logger.

    Writes scalar metrics, the learning rate and diagnostic figures to a
    ``SummaryWriter`` whose log directory is ``mdl_f / logs_f``. Figures are
    additionally saved as PNG files under ``mdl_f / "logs_im"``.
    """

    def __init__(
        self,
        mdl_f: Path,
        n_plots: int,
        fill_nan: int,
        scaler: Scaler | None = None,
        logs_f: str = "logs",
    ) -> None:
        """
        Initialize TensorBoard logger.

        Parameters
        ----------
        mdl_f : Path
            Model folder in which the TensorBoard logs are stored (a ``str``
            is accepted as well)
        n_plots : int
            Maximum number of stacks plotted per call to ``log_plots``
        fill_nan : int
            Value to fill NaNs with (stored on the instance; not used within
            this class)
        scaler : Scaler, optional
            Scaler used to map the data back to its original range before
            plotting. It also provides the S2 band names, so ``log_plots``
            requires it; defaults to None
        logs_f : str
            Sub-folder of ``mdl_f`` in which the logs are stored (used in
            TensorBoard for naming)
        """
        # imported here rather than at module level (presumably so TensorBoard
        # is only needed once a logger is actually created)
        from torch.utils.tensorboard.writer import SummaryWriter

        self._mdl_f = Path(mdl_f)
        self._writer = SummaryWriter(log_dir=str(self._mdl_f / logs_f))
        self.n_plots = n_plots
        self.fill_nan = fill_nan
        self.scaler = scaler

    def log_metrics(
        self,
        split: str,
        step: int,
        loss: float,
        **kwargs,
    ) -> None:
        """Write the loss and any extra scalar metrics to TensorBoard.

        Parameters
        ----------
        split : str
            Data split, either ``"train"`` or ``"val"``
        step : int
            Global step at which the values are logged
        loss : float
            Loss value, logged under the tag ``loss/{split}``
        **kwargs
            Additional scalar metrics, each logged under ``{name}/{split}``
        """
        assert split in {"train", "val"}
        self._writer.add_scalar(f"loss/{split}", loss, global_step=step)  # type: ignore[no-untyped-call]
        for name, value in kwargs.items():
            self._writer.add_scalar(f"{name}/{split}", value, global_step=step)  # type: ignore[no-untyped-call]

    def log_lr(
        self,
        step: int,
        lr: float,
    ) -> None:
        """Log the current learning rate of the scheduler under the tag ``lr``.

        Parameters
        ----------
        step : int
            Global step at which the value is logged
        lr : float
            Learning rate; must be a Python ``float``
        """
        assert isinstance(
            lr, float
        ), f"Learning rate must be a float (got {type(lr)} instead)."
        self._writer.add_scalar("lr", lr, global_step=step)  # type: ignore[no-untyped-call]

    def log_plots(
        self,
        step: int,
        preds: torch.Tensor,
        targets: torch.Tensor,
        inputs: torch.Tensor,
        masks: torch.Tensor,
        postfix: str = "",
    ) -> None:
        """Plot and log the first ``n_plots`` stacks of a batch.

        For every plotted stack, a series plot and a pixel-evolution plot are
        sent to TensorBoard and saved under ``mdl_f / "logs_im"``.

        Parameters
        ----------
        step : int
            Current training step
        preds, targets, inputs : torch.Tensor
            Batched prediction, target and input stacks (first axis is the
            batch axis). Each input stack must have 4 dimensions, with axis 1
            broadcast against the mask.
        masks : torch.Tensor
            Batched mask stacks; each stack must have 3 dimensions. Input
            values where the mask equals 0 are hidden (set to NaN) before
            plotting.
        postfix : str
            Postfix added to the saved file names

        Raises
        ------
        ValueError
            If there is something to plot but the logger has no scaler.
        """
        save_f = self._mdl_f / "logs_im"
        save_f.mkdir(exist_ok=True, parents=True)

        # Only convert the stacks that are actually plotted. detach()/cpu() make
        # the conversion work for GPU tensors and tensors tracked by autograd.
        # The max() avoids a negative n_plots turning into "all but the last".
        n_plots = max(self.n_plots, 0)
        preds_np = preds[:n_plots].detach().cpu().numpy()
        targets_np = targets[:n_plots].detach().cpu().numpy()
        inputs_np = inputs[:n_plots].detach().cpu().numpy()
        masks_np = masks[:n_plots].detach().cpu().numpy()

        # silences every warning raised while rescaling and plotting
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i, (prd, trg, inp, msk) in enumerate(
                zip(preds_np, targets_np, inputs_np, masks_np)
            ):
                scaler = self._require_scaler()

                # Hide masked-out input values (mask == 0). The copy avoids
                # mutating memory shared with the caller's tensor; the 3D mask
                # is broadcast over axis 1 of the input stack.
                inp_masked = inp.copy()
                hidden = np.broadcast_to((msk == 0)[:, None, :, :], inp_masked.shape)
                inp_masked[hidden] = np.nan

                # scale back values to original ranges
                inp_ = scaler(s1=None, s2=inp_masked, reverse=True, safe=False)[1]
                prd_ = scaler(s1=None, s2=prd, reverse=True)[1]
                trg_ = scaler(s1=None, s2=trg, reverse=True)[1]

                # plot and log
                self._log_series(
                    i,
                    step=step,
                    prd=prd_,
                    trg=trg_,
                    inp=inp_,
                    msk=msk,
                    save_f=save_f,
                    postfix=postfix,
                )
                self._log_pixel_evol(
                    i,
                    step=step,
                    prd=prd_,
                    trg=trg_,
                    inp=inp_,
                    save_f=save_f,
                    postfix=postfix,
                )

    def close(self) -> None:
        """Flush pending events and close the underlying ``SummaryWriter``."""
        self._writer.close()

    def _require_scaler(self) -> Scaler:
        """Return the configured scaler.

        Raises
        ------
        ValueError
            If the logger was created with ``scaler=None``.
        """
        if self.scaler is None:
            raise ValueError(
                "TensorBoardLogger requires a scaler for plotting: it rescales "
                "the data and provides the S2 band names."
            )
        return self.scaler

    def _log_series(
        self,
        i: int,
        step: int,
        prd: NDArray[PRECISION_FLOAT_NP],
        trg: NDArray[PRECISION_FLOAT_NP],
        inp: NDArray[PRECISION_FLOAT_NP],
        msk: NDArray[PRECISION_FLOAT_NP],
        save_f: Path,
        postfix: str = "",
    ) -> None:
        """Log predictions over a single stack.

        The figure is sent to TensorBoard under ``plot_series/stack_{i}`` and
        saved as ``series_{i}_step_{step:05d}{postfix}.png`` in ``save_f``.

        Parameters
        ----------
        i : int
            identifier of the stack
        step : int
            current training step
        prd : NDArray[PRECISION_FLOAT_NP]
            predicted stack
        trg : NDArray[PRECISION_FLOAT_NP]
            target stack
        inp : NDArray[PRECISION_FLOAT_NP]
            input stack
        msk : NDArray[PRECISION_FLOAT_NP]
            mask stack
        save_f : Path
            folder in which to save the plots
        postfix : str
            postfix to add to the plot name

        Raises
        ------
        ValueError
            If the logger has no scaler (needed for the S2 band names).
        """
        s2_bands = self._require_scaler().bands_s2
        writer_function = partial(
            self._writer.add_figure,
            tag=f"plot_series/stack_{i}",
            global_step=step,
        )
        plot_series(
            target=trg,
            inp=inp,
            pred=prd,
            mask=msk,
            s2_bands=s2_bands,
            save_f=save_f / f"series_{i}_step_{step:05d}{postfix}.png",
            writer_function=writer_function,
        )

    def _log_pixel_evol(
        self,
        i: int,
        step: int,
        prd: NDArray[PRECISION_FLOAT_NP],
        trg: NDArray[PRECISION_FLOAT_NP],
        inp: NDArray[PRECISION_FLOAT_NP],
        save_f: Path,
        postfix: str = "",
    ) -> None:
        """Log pixel evolution over time.

        The ``s2_fapar`` band is plotted when available, otherwise the first
        S2 band. The figure is sent to TensorBoard under
        ``plot_pixel/stack_{i}`` and saved as
        ``pixel_{i}_step_{step:05d}{postfix}.png`` in ``save_f``.

        Parameters
        ----------
        i : int
            identifier of the stack
        step : int
            current training step
        prd : NDArray[PRECISION_FLOAT_NP]
            predicted stack
        trg : NDArray[PRECISION_FLOAT_NP]
            target stack
        inp : NDArray[PRECISION_FLOAT_NP]
            input stack
        save_f : Path
            folder in which to save the plots
        postfix : str
            postfix to add to the plot name

        Raises
        ------
        ValueError
            If the logger has no scaler (needed for the S2 band names).
        """
        s2_bands = self._require_scaler().bands_s2
        writer_function = partial(
            self._writer.add_figure,
            tag=f"plot_pixel/stack_{i}",
            global_step=step,
        )
        plot_pixel_evolution(
            target=trg,
            inp=inp,
            pred=prd,
            s2_bands=s2_bands,
            save_f=save_f / f"pixel_{i}_step_{step:05d}{postfix}.png",
            writer_function=writer_function,
            band="s2_fapar" if "s2_fapar" in s2_bands else s2_bands[0],
        )
