"""Algorithm to forward arbitrary input shapes through a CropSAR InpaintingModel."""

from typing import Callable, Literal

import numpy as np
from numpy.typing import NDArray

from vito_cropsar.constants import PRECISION_FLOAT_NP
from vito_cropsar.models import InpaintingBase


def main(
    s2: NDArray[PRECISION_FLOAT_NP],
    s1: NDArray[PRECISION_FLOAT_NP],
    model: InpaintingBase,
    t_overlap: int = 4,
    xy_overlap: int = 20,
    transition: Literal["hard", "soft"] = "soft",
) -> NDArray[PRECISION_FLOAT_NP]:
    """
    Do a model prediction on a timestack with arbitrary size in time & space.

    This algorithm stitches together multiple predictions of the model to cover the whole input
    stack. Consecutive chunks overlap (by at least ``t_overlap`` in time and ``xy_overlap`` in
    space) and their predictions are blended with ramp weights inside the overlap regions.

    Parameters
    ----------
    s2 : NDArray[PRECISION_FLOAT_NP]
        Sentinel-2 input stack, shape (T, C, X, Y)
    s1 : NDArray[PRECISION_FLOAT_NP]
        Sentinel-1 input stack, shape (T, C, X, Y)
    model : InpaintingBase
        Model to use for prediction; its ``n_ts``, ``resolution`` and ``n_channels_out``
        attributes define the chunk size and the number of output channels
    t_overlap : int
        Minimum time overlap between chunks; must be in ``[0, model.n_ts)`` whenever more
        than one chunk is needed along time
    xy_overlap : int
        Minimum spatial overlap between chunks; must be in ``[0, model.resolution)``
        whenever more than one chunk is needed along a spatial axis
    transition : Literal["hard", "soft"]
        Transition function to use for stitching

    Returns
    -------
    NDArray
        Predicted and stitched output stack, shape (T, model.n_channels_out, X, Y)

    Raises
    ------
    AssertionError
        If the input stack is smaller than the model input, or an overlap is out of range.

    Notes
    -----
     - We assume the model accepts square matrices (in space dimension)
     - not yet supported: images with smaller size than the model expects (would need
       padding strategy for that)
    """
    t_res_model = model.n_ts
    xy_res_model = model.resolution
    n_channels_out_model = model.n_channels_out

    # Check input stack size
    assert (
        t_res_model <= s2.shape[0]
    ), "error: time dimension of input stack must be >= model"
    assert (
        xy_res_model <= s2.shape[2]
    ), "error: first spatial dimension of input stack must be >= model"
    assert (
        xy_res_model <= s2.shape[3]
    ), "error: second spatial dimension of input stack must be >= model"

    t_res_source, _, x_res_source, y_res_source = s2.shape

    # The same transition shape is used along every axis, so fetch it once.
    weight_func = _get_weight_func(transition)
    t_starts, t_weights = _chunk_layout(
        t_res_source, t_res_model, t_overlap, weight_func
    )
    x_starts, x_weights = _chunk_layout(
        x_res_source, xy_res_model, xy_overlap, weight_func
    )
    y_starts, y_weights = _chunk_layout(
        y_res_source, xy_res_model, xy_overlap, weight_func
    )

    out_shape = (t_res_source, n_channels_out_model, x_res_source, y_res_source)
    output_sum = np.zeros(out_shape)  # weighted sum of all chunk predictions
    output_weights = np.zeros(out_shape)  # accumulated weights, for normalisation

    for t_start, t_weight in zip(t_starts, t_weights):
        t_slice = slice(t_start, t_start + t_res_model)
        for x_start, x_weight in zip(x_starts, x_weights):
            x_slice = slice(x_start, x_start + xy_res_model)
            for y_start, y_weight in zip(y_starts, y_weights):
                y_slice = slice(y_start, y_start + xy_res_model)
                # Same (t, all channels, x, y) window for inputs and outputs.
                window = (t_slice, slice(None), x_slice, y_slice)

                output_chunk = model(s2=s2[window], s1=s1[window])

                # Spatial weight is the minimum of the x and y ramps, scaled by the
                # temporal ramp. Shape (t, 1, x, y), broadcast over output channels.
                output_weight = (
                    t_weight[:, None, None, None]
                    * np.minimum.outer(x_weight, y_weight)[None, None, :, :]
                )

                output_sum[window] += output_chunk * output_weight
                output_weights[window] += output_weight

    # Normalise the weighted sum by the total weight each pixel received.
    return output_sum / output_weights


def _chunk_layout(
    res_source: int,
    res_model: int,
    min_overlap: int,
    weight_func: Callable[[NDArray[PRECISION_FLOAT_NP]], NDArray[PRECISION_FLOAT_NP]],
) -> "tuple[list[int], list[NDArray[PRECISION_FLOAT_NP]]]":
    """Plan the chunks along one axis: their start indices and blending weights.

    Chunks of length ``res_model`` are spread evenly over ``res_source``. The first
    chunk starts at 0, the last one ends exactly at ``res_source``, and consecutive
    chunks overlap by at least ``min_overlap``.

    Start indices use exact integer arithmetic. The equivalent floating-point form
    ``floor(i * (res_model - avg_overlap))`` may round one below the exact value. That
    shifts the last chunk and leaves the final pixel(s) with zero total weight, which
    becomes NaN after normalisation. Starts are returned as plain Python ints, so large
    axes cannot overflow a small integer dtype.
    """
    if res_source <= res_model:
        # A single chunk covers the whole axis; the requested overlap is irrelevant.
        starts = np.zeros(1, dtype=np.int64)
    else:
        # An overlap >= res_model gives a zero/negative stride; a negative one leaves gaps.
        assert (
            0 <= min_overlap < res_model
        ), "error: overlap must be >= 0 and smaller than the model size"
        # ceil((res_source - min_overlap) / (res_model - min_overlap)); >= 2 here.
        n_chunks = -(-(res_source - min_overlap) // (res_model - min_overlap))
        # floor(i * (res_source - res_model) / (n_chunks - 1)), computed exactly.
        starts = (
            np.arange(n_chunks, dtype=np.int64) * (res_source - res_model)
        ) // (n_chunks - 1)

    # Overlap of each chunk with its neighbours; 0 before the first and after the last.
    overlaps = np.pad(res_model - np.diff(starts), (1, 1)).tolist()
    weights = [
        _compose_weight(low, high, res_model, weight_func)
        for low, high in zip(overlaps[:-1], overlaps[1:])
    ]
    return starts.tolist(), weights


def _get_weight_func(
    transition: Literal["hard", "soft"] = "soft"
) -> Callable[[NDArray[PRECISION_FLOAT_NP]], NDArray[PRECISION_FLOAT_NP]]:
    """Look up the ramp-shaping function applied inside chunk overlap regions.

    Parameters
    ----------
    transition : Literal["hard", "soft"]
        ``"soft"`` returns its input unchanged, so the linear ramp is used as-is.
        ``"hard"`` thresholds the ramp at 0.5, giving weights of 0 or 1 cast to
        ``PRECISION_FLOAT_NP``.

    Returns
    -------
    Callable
        Function mapping an array of ramp positions to weights.

    Raises
    ------
    KeyError
        If ``transition`` is not one of the supported names.
    """

    def _step(ramp):
        return (ramp >= 0.5).astype(PRECISION_FLOAT_NP)  # noqa: PLR2004

    def _identity(ramp):
        return ramp

    transitions = {"hard": _step, "soft": _identity}
    return transitions[transition]


def _compose_weight(
    overlap_low: int,
    overlap_high: int,
    res_model: int,
    weight_func: Callable[[NDArray[PRECISION_FLOAT_NP]], NDArray[PRECISION_FLOAT_NP]],
) -> NDArray[PRECISION_FLOAT_NP]:
    """Compose a 1-D weight vector, used to weight patches in overlap regions.

    The vector has length ``res_model`` and three parts:

    - a rising ramp over the first ``overlap_low`` entries;
    - ones in the middle;
    - a falling ramp over the last ``overlap_high`` entries.

    Each ramp is truncated to at most ``res_model // 2`` entries, so the two ramps never
    overlap each other.

    Parameters
    ----------
    overlap_low : int
        Overlap with the preceding chunk (0 when there is none).
    overlap_high : int
        Overlap with the following chunk (0 when there is none).
    res_model : int
        Chunk length along this axis.
    weight_func : Callable
        Shapes the ramp; it is applied to pixel-centre positions ``(i + 0.5) / overlap``
        and to the all-ones middle section.

    Returns
    -------
    NDArray
        Weight vector of length ``res_model``.
    """
    half = int(res_model // 2)
    overlap_low_trunc = min(overlap_low, half)
    overlap_high_trunc = min(overlap_high, half)

    def ramp(overlap):
        # Pixel-centre positions in (0, 1). max() only guards the denominator: for
        # overlap <= 0 the position array is empty anyway.
        return weight_func((np.arange(overlap) + 0.5) / max(overlap, 1))

    rising = ramp(overlap_low)[:overlap_low_trunc]
    middle = weight_func(np.ones(res_model - overlap_low_trunc - overlap_high_trunc))
    # Slice from an explicit start: ``[-trunc:]`` would return the WHOLE ramp when
    # trunc == 0, making the vector longer than res_model.
    falling = (1 - ramp(overlap_high))[overlap_high - overlap_high_trunc :]

    return np.concatenate((rising, middle, falling))
