"""Main model class."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Iterator
from skimage.filters import sobel
from skimage.future import graph
from skimage import segmentation

import cv2
import numpy as np
import torch
from torch.nn.parameter import Parameter

from vito_lot_delineation.models.base_model import BaseModel
from vito_lot_delineation.models.semantic.utils import parse_model


class SemanticModel(BaseModel):
    """Model class for a segmentation model.

    The wrapped network is built by ``parse_model`` from ``config_file["model"]``.
    Its raw per-pixel output is turned into field instance labels by
    :meth:`post_process`. That step runs Felzenszwalb segmentation, dilates the
    labels and removes small fields.
    """

    def __init__(
        self,
        model_tag: str,
        config_file: dict[str, Any],
        mdl_f: Path | None = None,
        threshold: float = 0.5,
        **kwargs: Any,
    ) -> None:
        """
        Initialize configuration file.

        Parameters
        ----------
        model_tag : str
            Name of the model
        config_file : dict[str, Any]
            Model configuration, specifying model creation (its ``"model"``
            entry is passed to ``parse_model``)
        mdl_f : Path | None
            Folder where the model gets stored
        threshold : float
            Prediction threshold, by default 0.5
        **kwargs : Any
            Ignored. Accepted so that extra keys stored in ``mdl_metadata.json``
            (e.g. ``n_params``) can be forwarded by :meth:`load`.
        """
        super().__init__(
            model_tag=model_tag,
            mdl_f=mdl_f,
            config_file=config_file,
            threshold=threshold,
        )
        self.model: torch.nn.Module = parse_model(config_file["model"]).to(self.device)
        self.n_params: int = _count_parameters(list(self.parameters()))

    def __str__(self) -> str:
        """Representation of the model."""
        return f"{self.__class__.__name__} : {json.dumps(self.cfg, indent=2)}"

    def __repr__(self) -> str:
        """Representation of the model."""
        return str(self)

    def __call__(self, x: torch.Tensor) -> torch.IntTensor:
        """
        Make predictions on the provided batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Input batch of shape (batch, channels, time, width, height)

        Returns
        -------
        torch.Tensor
            Field prediction of shape (batch, width, height) where:
                0: No fields
                1: Index of first field
                2: Index of second field
                ...
                N: Index of Nth field
        """
        # post_process detaches the raw output, so any autograd graph built here
        # would be thrown away; skip building it.
        with torch.no_grad():
            result = self.forward_process(x)
        return self.post_process(result)

    def forward_process(self, x: torch.Tensor) -> torch.Tensor:
        """Obtain the raw model output over a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Input batch of shape (batch, channels, time, width, height)

        Returns
        -------
        torch.Tensor
            Raw model output of shape (batch, channels, width, height)
        """
        return self.model(x)

    def post_process(
        self,
        output: torch.Tensor,
        dilate_by: int = 1,
        rm_smaller_than: int = 10,
    ) -> torch.IntTensor:
        """Transform model output into instance segmentation output.

        Parameters
        ----------
        output : torch.Tensor
            Raw model output; a singleton channel dimension at ``dim=1`` is squeezed.
        dilate_by : int
            Number of 3x3 dilation iterations applied to the instance labels, to
            compensate for pixels lost on segment boundaries. By default 1.
        rm_smaller_than : int
            Fields covering fewer pixels than this (counted after dilation) are
            removed. Also used as Felzenszwalb's ``min_size``. By default 10.

        Returns
        -------
        torch.IntTensor
            One int32 label map per batch element, on ``self.device``.
            0 means no field.
        """
        semantic = output.squeeze(dim=1).detach().cpu().numpy()
        kernel = np.ones((3, 3), dtype=np.uint8)

        batched_labelled = []
        for sample in semantic:
            # transform to instance
            labelled = _apply_felzenswalb(sample, rm_smaller_than)
            # expand predictions (dilation is a max filter: at contacts the higher ID wins)
            labelled = cv2.dilate(labelled.astype(np.float32), kernel, iterations=dilate_by)
            # remove small fields
            batched_labelled.append(self._remove_small_fields(labelled, rm_smaller_than))

        return torch.from_numpy(np.asarray(batched_labelled, dtype=np.int32)).to(self.device)

    @staticmethod
    def _remove_small_fields(labelled: np.ndarray, min_size: int) -> np.ndarray:
        """Zero out, in place, every non-zero label covering fewer than ``min_size`` pixels.

        Background (0) is excluded explicitly rather than assumed to be the first
        unique value. After dilation a map may contain no background at all.
        """
        ids, counts = np.unique(labelled, return_counts=True)
        small_ids = ids[(ids != 0) & (counts < min_size)]
        if small_ids.size:
            labelled[np.isin(labelled, small_ids)] = 0
        return labelled

    def train(self) -> None:
        """Set model to train mode."""
        self.model.train()

    def eval(self) -> None:  # noqa: A003
        """Set model to eval mode."""
        self.model.eval()

    def to(self, device: str) -> None:
        """Put the model on the requested device."""
        self.device = device
        self.model.to(self.device)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Get parameters from encoder and decoder."""
        return self.model.parameters(recurse=recurse)

    def save(self) -> None:
        """Save the configuration, metadata and weights under ``self.model_folder``.

        Writes ``config_file.json`` and ``mdl_metadata.json``. Also writes
        ``modules/weights.pth``, which holds the whole pickled ``nn.Module``.
        """
        # Create the folders first: the JSON files are written into
        # `model_folder`, which may not exist yet.
        model_dir = self.model_folder / "modules"
        model_dir.mkdir(exist_ok=True, parents=True)

        # Save model additional variables
        with open(self.model_folder / "config_file.json", "w", encoding="utf-8") as f:
            json.dump(self.cfg, f, indent=2)
        with open(self.model_folder / "mdl_metadata.json", "w", encoding="utf-8") as f:
            json.dump(
                {
                    "n_params": self.n_params,
                    "threshold": self._threshold,
                },
                f,
                indent=2,
            )

        # Save model weights
        torch.save(self.model, model_dir / "weights.pth")

    @classmethod
    def load(cls, mdl_f: Path) -> SemanticModel:
        """Load a model previously written by :meth:`save`.

        Parameters
        ----------
        mdl_f : Path
            Model folder containing ``config_file.json``, ``mdl_metadata.json``
            and ``modules/weights.pth``. Its name is used as the model tag and
            its parent as the storage folder.

        Returns
        -------
        SemanticModel
            The restored model in eval mode, on CUDA when available, else on CPU.
        """
        with open(mdl_f / "mdl_metadata.json", encoding="utf-8") as f:
            metadata = json.load(f)
        with open(mdl_f / "config_file.json", encoding="utf-8") as f:
            cfg_file = json.load(f)

        mdl = cls(
            model_tag=mdl_f.name,
            config_file=cfg_file,
            mdl_f=mdl_f.parent,
            **metadata,
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): `save` pickles the whole nn.Module, so loading it
        # unpickles arbitrary objects. Only load trusted model folders.
        # torch>=2.6 defaults to `weights_only=True`, which rejects such files;
        # pass `weights_only=False` (torch>=1.13) if loading fails there.
        mdl.model = torch.load(
            mdl_f / "modules" / "weights.pth",
            map_location=torch.device(device),
        )
        # Keep `mdl.device` in sync with where the weights were mapped. Otherwise
        # inputs moved to `mdl.device` could land on a different device than the model.
        mdl.to(device)
        mdl.eval()
        return mdl

def _apply_felzenswalb(
    semantic_output: np.ndarray,
    rm_smaller_than: int = 10,
    merge_threshold: float = 0.15,
) -> np.ndarray:
    """Apply Felzenszwalb segmentation to the semantic output.

    The map is over-segmented with Felzenszwalb. Adjacent segments are then
    merged using a region adjacency graph weighted by the Sobel edge map. The
    most frequent resulting label is treated as background (0).
    Code by Kasper from VITO.

    Parameters
    ----------
    semantic_output : np.ndarray
        Single-channel 2-D map to segment (``channel_axis=None``).
    rm_smaller_than : int
        ``min_size`` passed to Felzenszwalb segmentation. By default 10.
    merge_threshold : float
        Threshold for ``graph.cut_threshold``. Adjacent segments whose boundary
        weight is below it are merged. By default 0.15.

    Returns
    -------
    np.ndarray
        Integer label map with the same shape as ``semantic_output``.
        0 marks background.
    """
    # Calculate the edges using sobel
    edges = sobel(semantic_output)

    # Perform felzenszwalb segmentation
    segments = np.asarray(
        segmentation.felzenszwalb(
            semantic_output,
            scale=1,
            channel_axis=None,
            sigma=0.0,
            min_size=rm_smaller_than,
        ),
        dtype=int,
    )

    # A single segment (only label 0) has nothing to merge; return it as-is.
    if np.max(segments) == 0:
        return segments

    # Perform the RAG boundary analysis and merge the segments
    rag = graph.rag_boundary(segments, edges)
    merged_segments = graph.cut_threshold(segments, rag, merge_threshold, in_place=False)

    # Shift labels by one so that no real segment keeps label 0, which is
    # reserved for background below.
    merged_segments += 1

    # Felzenszwalb also delineates fields in zones with no field probability.
    # Setting zero probability to NaN gives other artefacts, so as
    # post-processing the most frequent label is removed (set to background).
    unique_seg, counts_seg = np.unique(merged_segments, return_counts=True)
    merged_segments[merged_segments == unique_seg[np.argmax(counts_seg)]] = 0

    return merged_segments

def _count_parameters(params: list[Parameter] | Parameter) -> int:
    """Count the parameters."""
    if isinstance(params, list):
        return sum(_count_parameters(p) for p in params)
    return len(params.flatten())


if __name__ == "__main__":
    from vito_lot_delineation.data import DelineationDataset

    # Load model
    model = SemanticModel.load(Path("data/models/20230524T102647-ResUnet3D_ts16_bigger"))

    # Dataset
    data_test = DelineationDataset(
        split="testing",
        data_dir=Path("data/data"),
        bands=model.cfg["input"]["bands"],
        n_ts=model.cfg["input"]["n_ts"],
    )

    # Predict. This is inference only, so no autograd graph is needed.
    sample = data_test.get_batch(batch_size=2)
    inputs = sample["input"].to(model.device)
    with torch.no_grad():
        semantic = model.forward_process(inputs)
        instance = model(inputs)
    print(f"semantic output: {tuple(semantic.shape)}, instance labels: {tuple(instance.shape)}")
