"""Inference pipeline."""

from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path

import torch

from vito_lot_delineation.models import BaseModel, load_model


def main(
    model: BaseModel,
    batch: torch.Tensor,
    batch_size: int = 8,
) -> dict[str, torch.Tensor | None]:
    """Run inference over a batch of samples.

    The input is forwarded through the model in chunks of ``batch_size``
    samples; the per-chunk predictions are concatenated along dimension 0.

    Parameters
    ----------
    model : BaseModel
        Model to use to run inference. It is switched to evaluation mode.
    batch : torch.Tensor
        Batch of images, shape must be (B, CH, T, W, H).
    batch_size : int, optional
        Number of samples forwarded through the model at once, by default 8.

    Returns
    -------
    dict[str, torch.Tensor | None]
        ``{"semantic": ..., "instance": ...}`` holding the predictions for all
        B samples, concatenated along dimension 0. Singleton dimensions other
        than the batch dimension are squeezed out of the semantic output. Both
        values are ``None`` when ``batch`` is empty.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    semantic_chunks: list[torch.Tensor] = []
    instance_chunks: list[torch.Tensor] = []
    model.eval()
    # Pure inference: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        for start in range(0, len(batch), batch_size):
            inp = batch[start : start + batch_size].to(model.device)
            semantic = model.forward_process(inp)
            # post_process receives the raw forward output (possibly a list of heads).
            instance = model.post_process(semantic)

            if isinstance(semantic, list):
                semantic = model.combine_heads(semantic)
            # Squeeze singleton dims but keep the batch dim: a chunk holding a
            # single sample would otherwise lose it and fail to concatenate.
            single_sample = semantic.shape[0] == 1
            semantic = semantic.squeeze().detach()
            if single_sample:
                semantic = semantic.unsqueeze(0)

            semantic_chunks.append(semantic)
            instance_chunks.append(instance)

    # Concatenate once at the end rather than re-copying an accumulator per chunk.
    return {
        "semantic": torch.cat(semantic_chunks) if semantic_chunks else None,
        "instance": torch.cat(instance_chunks) if instance_chunks else None,
    }


def _iterate(
    seq: torch.Tensor | list[Path], batch_size: int
) -> tuple[torch.Tensor | list[Path]]:
    """Create a tuple of batches to iterate over."""
    return (seq[pos : pos + batch_size] for pos in range(0, len(seq), batch_size))


if __name__ == "__main__":
    # Smoke run: load a trained model and push a few test-split samples through it.
    # NOTE(review): the model path below is resolved against the current working
    # directory, while the data directory is anchored to this file's location;
    # confirm the script is meant to be launched from the directory holding data/.
    model = load_model(Path("data/models/20230608T161609-test"))
    print(f"model: {model}")

    # Build a batch of input tensors from the dataset.
    from vito_lot_delineation.data import DelineationDataset

    data_root = Path(__file__).parent.parent.parent.parent / "data/data"
    input_cfg = model.cfg["input"]
    dataset = DelineationDataset(
        split="testing",
        data_dir=data_root,
        augment=False,
        return_distance=False,
        return_watersheds=False,
        bands=input_cfg["bands"],
        n_ts=input_cfg["n_ts"],
    )
    samples = dataset.get_batch(20)
    batch = samples["input"]
    print(f"\nBatch shape: {batch.shape}")

    # Run the pipeline and report the shape of each output.
    preds = main(model, batch)
    print(f"\nSemantic Shape: {preds['semantic'].shape}")
    print(f"Instance Shape: {preds['instance'].shape}")
