"""Post-processing functions for the Vito Lot Delineation model."""

import numpy as np
import torch
from skimage import segmentation
from skimage.filters import sobel
from skimage import graph


def apply_felzenswalb(
    semantic_output: torch.Tensor,
    rm_smaller_than: int = 10,
    merge_threshold: float = 0.15,
):
    """
    Apply Felzenszwalb segmentation to the semantic output.

    The map is over-segmented with skimage's Felzenszwalb method. Adjacent
    segments are then merged when the mean Sobel edge value along their
    shared boundary is below ``merge_threshold``. Finally, the background is
    recovered: every pixel where ``semantic_output == 0`` is set to 0, as is
    the whole segment that covers most of those pixels.

    Code by Kasper from VITO

    :param semantic_output: Semantic map to segment. A ``torch.Tensor`` is
        detached, moved to CPU and converted to a NumPy array; NumPy arrays
        and other array-likes are accepted as well. Must be 2-D, since
        felzenszwalb is called with ``channel_axis=None``.
    :param rm_smaller_than: ``min_size`` forwarded to felzenszwalb, i.e. the
        minimum segment size in pixels.
    :param merge_threshold: Boundary edge-weight threshold below which
        adjacent segments are merged in the region adjacency graph.
    :return: Integer label array with the same shape as ``semantic_output``.
        0 marks background; positive values identify segments (labels are
        not guaranteed to be contiguous).
    """
    # skimage and the boolean masking below need a plain NumPy array; a
    # tensor on GPU or with requires_grad=True cannot be used directly.
    if isinstance(semantic_output, torch.Tensor):
        semantic_output = semantic_output.detach().cpu().numpy()
    semantic_output = np.asarray(semantic_output)

    # Edge map used to weight the boundaries between neighbouring segments.
    edges = sobel(semantic_output)

    segments = np.asarray(
        segmentation.felzenszwalb(
            semantic_output,
            scale=1,
            channel_axis=None,
            sigma=0.0,
            min_size=rm_smaller_than,
        )
    ).astype(int)

    if segments.max() == 0:
        # A single segment has no boundaries to merge, so skip the RAG step,
        # but still run background recovery below so the output follows the
        # same "0 = background" convention as the multi-segment case.
        merged = segments
    else:
        rag = graph.rag_boundary(segments, edges)
        merged = graph.cut_threshold(segments, rag, merge_threshold, in_place=False)

    # Shift labels to start at 1 so that 0 is free to mean "background".
    merged = merged + 1

    background = semantic_output == 0
    # Without this guard np.argmax would raise on an empty array when the
    # semantic output contains no exact zeros (i.e. no background at all).
    if background.any():
        ids, counts = np.unique(merged[background], return_counts=True)
        merged[background] = 0
        # The segment covering most background pixels is treated as
        # background too (presumably the region felzenszwalb formed over
        # the background area).
        merged[merged == ids[np.argmax(counts)]] = 0

    return merged