"""Data scaler class."""

from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Any

import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP
from vito_cropsar.vito_logger import LogLevel, bh_logger


class Scaler:
    """
    Per-band min/max scaler for Sentinel-1 and Sentinel-2 data.

    Every band has a value range (``ranges[band] = (low, high)``) that is
    linearly mapped onto ``[v_min, v_max]``. The ranges are either read from the
    configuration or computed from data with :meth:`train`.
    """

    def __init__(self, cfg: dict[str, Any], warnings: bool = True) -> None:
        """
        Initialise the scaler.

        Parameters
        ----------
        cfg : dict[str, Any]
            Configuration dictionary
             - bands_s1 : list[str]
                List of Sentinel-1 bands to scale (required)
             - bands_s2 : list[str]
                List of Sentinel-2 bands to scale (required)
             - clip : bool
                Whether to clip within the scaled range (default: True)
             - v_min : float
                Minimum scaling range (default: -1.0)
             - v_max : float
                Maximum scaling range (default: 1.0)
             - sample_s1 : bool
                Select only one band between ascending and descending in vv and vh (default: True)
             - ranges : dict[str, tuple[float, float]]
                Dictionary of ranges for each band (default: {})
        warnings : bool
            Whether to print a warning for every optional key that falls back to its default

        Raises
        ------
        KeyError
            If `bands_s1` or `bands_s2` is missing from the configuration
        """

        # Bands are ordered by their underscore-separated parts read back-to-front
        # (last part first), which fixes the channel order expected by __call__
        def _band_key(band: str) -> tuple[str, ...]:
            return tuple(reversed(band.split("_")))

        self.bands_s1 = sorted(cfg["bands_s1"], key=_band_key)
        self.bands_s2 = sorted(cfg["bands_s2"], key=_band_key)
        self.bands = self.bands_s1 + self.bands_s2

        # Optional settings, each falling back to a default when absent from cfg
        self.clip = parse_config("clip", cfg=cfg, default=True, warnings=warnings)
        self.v_min = parse_config("v_min", cfg=cfg, default=-1.0, warnings=warnings)
        self.v_max = parse_config("v_max", cfg=cfg, default=1.0, warnings=warnings)
        self.sample_s1 = parse_config(
            "sample_s1", cfg=cfg, default=True, warnings=warnings
        )
        self.ranges: dict[str, tuple[float, float]] = parse_config(
            "ranges", cfg=cfg, default={}, warnings=warnings
        )

    def __call__(
        self,
        s1: NDArray[PRECISION_FLOAT_NP] | None = None,
        s2: NDArray[PRECISION_FLOAT_NP] | None = None,
        reverse: bool = False,
        inference: bool = False,
        safe: bool = True,
    ) -> tuple[
        NDArray[PRECISION_FLOAT_NP] | None, NDArray[PRECISION_FLOAT_NP] | None
    ]:
        """
        Scale the specified sample.

        Parameters
        ----------
        s1 : NDArray[PRECISION_FLOAT_NP] | None
            Sentinel-1 data, of shape (time, channels, width, height)
        s2 : NDArray[PRECISION_FLOAT_NP] | None
            Sentinel-2 data, of shape (time, channels, width, height)
        reverse : bool
            Whether to reverse the scaling operation
        inference : bool
            Whether to scale the data for inference
            Note: This will merge S1 ascending and descending bands if `sample_s1` is True
        safe : bool
            Whether to perform the scaling in a safe way (on a copy of the inputs)
            Note: If safe=False, the arrays fed to the scaler are modified in-place

        Returns
        -------
        tuple[NDArray[PRECISION_FLOAT_NP] | None, NDArray[PRECISION_FLOAT_NP] | None]
            Scaled Sentinel-1 and Sentinel-2 data; an element is None when the
            corresponding input was not provided

        Raises
        ------
        AssertionError
            If the scaler has no ranges yet, or if the number of channels of an
            input does not match the number of bands of the scaler
        """
        assert self.ranges != {}, "Scaler not trained yet"

        # Merge S1 if necessary
        bands_s1 = (
            self.get_s1_merged() if inference and self.sample_s1 else self.bands_s1
        )

        s1_ = self._scale_sample(s1, bands_s1, "s1", reverse=reverse, safe=safe)
        s2_ = self._scale_sample(s2, self.bands_s2, "s2", reverse=reverse, safe=safe)
        return s1_, s2_

    def _scale_sample(
        self,
        x: NDArray[PRECISION_FLOAT_NP] | None,
        bands: list[str],
        name: str,
        reverse: bool,
        safe: bool,
    ) -> NDArray[PRECISION_FLOAT_NP] | None:
        """Scale every channel of one sensor's array with the range of its band."""
        if x is None:
            return None

        # Work on a copy unless the caller explicitly allows in-place modification
        x_ = x.copy() if safe else x
        assert (
            len(bands) == x_.shape[1]
        ), f"Scaler bands ({len(bands)}) != {name} bands ({x_.shape[1]})"

        # Channel idx (second axis) corresponds to bands[idx]
        for idx, band in enumerate(bands):
            x_[:, idx] = self._scale(
                x=x_[:, idx], v_range=self.ranges[band], reverse=reverse
            )
        return x_

    def _scale(
        self,
        x: NDArray[PRECISION_FLOAT_NP],
        v_range: tuple[float, float],
        reverse: bool,
    ) -> NDArray[PRECISION_FLOAT_NP]:
        """
        Scale the provided array to the specified range.

        Note
        ----
         - Ugly function for a reason! Faster than creating vectors and doing matrix multiplication
         - The arithmetic is done in-place on `x`
         - NOTE(review): a range with equal bounds leads to a division by zero in the
           forward direction

        Parameters
        ----------
        x : NDArray[PRECISION_FLOAT_NP]
            Array to scale
        v_range : tuple[float, float]
            Range to scale to/from
        reverse : bool
            Whether to reverse the scaling operation

        Returns
        -------
        NDArray[PRECISION_FLOAT_NP]
            Scaled array
        """
        if reverse:
            # [v_min, v_max] -> [v_range[0], v_range[1]]
            x -= self.v_min
            x /= self.v_max - self.v_min
            x *= v_range[1] - v_range[0]
            x += v_range[0]
        else:
            # [v_range[0], v_range[1]] -> [v_min, v_max]
            x -= v_range[0]
            x /= v_range[1] - v_range[0]
            x *= self.v_max - self.v_min
            x += self.v_min
            if self.clip:
                # np.clip returns a new array; the caller writes it back
                x = np.clip(x, self.v_min, self.v_max)

        return x

    def __repr__(self) -> str:
        """Return the representation of the scaler."""
        return f"{self.__class__.__name__}({', '.join(sorted(self.bands))})"

    def __str__(self) -> str:
        """Return the string representation of the scaler."""
        return self.__repr__()

    def train(
        self,
        data: dict[str, list[NDArray[PRECISION_FLOAT_NP]]],
        perc: float = 0.001,
        *args,
        **kwargs,
    ) -> None:
        """
        Train the scaler on the provided sample.

        Parameters
        ----------
        data : dict[str, list[NDArray[PRECISION_FLOAT_NP]]]
            Training data containing band values for each of the specified bands
        perc : float
            Fraction of the data (as a quantile, e.g. 0.001) to ignore at each end
            when calculating the ranges
        *args, **kwargs
            Accepted but not used

        Raises
        ------
        AssertionError
            If the keys of `data` do not exactly match the bands of the scaler
        """
        # Import training specific packages
        from tqdm import tqdm

        # Check if all the required data is provided
        assert set(data.keys()) == set(
            self.bands
        ), f"Not all bands are present in the provided data (received: {set(data.keys())}, needed: {set(self.bands)})"

        # Calculate the ranges
        for k in tqdm(self.bands, desc="Calculating ranges"):
            v = np.concatenate(data[k])
            low, high = np.quantile(v, q=[perc, 1 - perc])
            # Store built-in floats: NumPy scalars (e.g. float32) are not JSON
            # serializable and would make save() fail
            self.ranges[k] = (float(low), float(high))

    def show_ranges(self) -> None:
        """Show the ranges of the scaler."""
        for k in self.bands:
            print(f"{k}: [{self.ranges[k][0]}..{self.ranges[k][1]}]")

    def get_s1_merged(self) -> list[str]:
        """Get the S1 bands if ascending and descending are merged."""
        # Keep only the bands containing "asc_" and strip that marker from the name
        return [re.sub("asc_", "", band) for band in self.bands_s1 if "asc_" in band]

    def get_metadata(self) -> dict[str, Any]:
        """Return the metadata of the scaler."""
        return {
            "bands_s1": self.bands_s1,
            "bands_s2": self.bands_s2,
            "bands": self.bands,
            "clip": self.clip,
            "v_min": self.v_min,
            "v_max": self.v_max,
            "sample_s1": self.sample_s1,
            "ranges": self.ranges,
        }

    def save(self, path: Path) -> None:
        """Save the scaler's metadata file (scaler_metadata.json) in the specified folder."""
        path.mkdir(parents=True, exist_ok=True)
        with open(path / "scaler_metadata.json", "w", encoding="utf-8") as f:
            json.dump(self.get_metadata(), f, indent=2)

    @classmethod
    def load(
        cls,
        path: Path | None = None,
        bands_s1: list[str] | None = None,
        bands_s2: list[str] | None = None,
        sample_s1: bool | None = None,
    ) -> Scaler:
        """Load the scaler from the specified path.

        Parameters
        ----------
        path : Path | None
            Folder in which the scaler metadata file is stored
            (default: the `configs` folder next to this module)
        bands_s1 : list[str] | None
            Sentinel-1 bands to use in loaded scaler (all by default)
        bands_s2 : list[str] | None
            Sentinel-2 bands to use in loaded scaler (all by default)
        sample_s1 : bool | None
            Select only one band between ascending and descending in vv and vh
            (by default the value stored in the metadata is kept)

        Returns
        -------
        Scaler
            Loaded scaler

        Raises
        ------
        AssertionError
            If `path` is not a directory
        """
        path = path or Path(__file__).parent / "configs"
        assert path.is_dir(), f"Path {path} is not a directory"

        # Load the metadata
        with open(path / "scaler_metadata.json", encoding="utf-8") as f:
            metadata = json.load(f)

        # Set the bands, if specified (None or an empty list keeps the stored bands)
        if bands_s1:
            metadata["bands_s1"] = bands_s1
        if bands_s2:
            metadata["bands_s2"] = bands_s2
        # Compare against None so that an explicit False is honoured as well
        if sample_s1 is not None:
            metadata["sample_s1"] = sample_s1

        # Create the scaler
        return cls(cfg=metadata, warnings=False)


def parse_config(
    k: str,
    cfg: dict[str, Any],
    default: Any = "_NONE",
    warnings: bool = True,
) -> Any:
    """
    Read a key from the configuration, optionally falling back to a default.

    Parameters
    ----------
    k : str
        Key to look up
    cfg : dict[str, Any]
        Configuration dictionary
    default : Any
        Value to return when `k` is not in `cfg`; the string "_NONE" is the
        sentinel meaning that no default is available
    warnings : bool
        Whether to log a warning when the default value is used

    Returns
    -------
    Any
        `cfg[k]` if present, otherwise `default`

    Raises
    ------
    KeyError
        If `k` is not in `cfg` and no default is available
    """
    if k in cfg:
        return cfg[k]

    # Check the type before comparing: `default != "_NONE"` is ambiguous (and
    # raises) for array-like defaults that compare element-wise
    no_default = isinstance(default, str) and default == "_NONE"
    if no_default:
        raise KeyError(f"Configuration key {k} not found.")

    if warnings:
        bh_logger(
            f"Value for '{k}' not found in configuration, using default value: '{default}'",
            LogLevel.WARNING,
        )
    return default


if __name__ == "__main__":
    # Script entry point: load the scaler with its default arguments and print
    # the range of every band
    Scaler.load().show_ranges()
