"""
UDF for the vectorization step: converts the raster-based delineation predictions
into a raster of integer segment labels
"""

from openeo.udf import XarrayDataCube
from typing import Dict
import os
import threading
import sys


# Thread-local storage container. Nothing in the visible part of this module
# reads from or writes to it.
_threadlocal = threading.local()

def _setup_logging():
    global logger
    from loguru import logger

def vectorize(field_pred, **processing_opts):
    """Segment a raster of delineation predictions into labelled regions.

    The values are min/max-rescaled to ``[0, 249]``, thresholded, split into
    segments with Felzenszwalb's algorithm and merged again along weak Sobel
    edges through a region-adjacency-graph cut.

    Relies on the module-level ``logger`` bound by ``_setup_logging()``, which
    therefore has to run before this function is called.

    Args:
        field_pred: two-dimensional ``xarray.DataArray`` holding the
            predictions, with a coordinate for each of its dimensions.
            The array is not modified.
        **processing_opts: accepted for call compatibility and ignored. The
            former ``run_local`` option only sliced the coordinates to the
            length of the data, which yields the same result.

    Returns:
        ``xarray.DataArray`` with the dimensions, dimension order and
        coordinates of ``field_pred``, holding one integer label per pixel.
        ``0`` marks pixels whose rescaled value falls below the threshold.
    """
    import numpy as np
    import xarray as xr
    from skimage import segmentation
    from skimage.filters import sobel
    # NOTE(review): newer scikit-image releases appear to expose these helpers
    # as `skimage.graph` and to replace `multichannel` by `channel_axis`;
    # confirm against the pinned version before upgrading.
    from skimage.future import graph

    logger.info(f'Computing segmentation for shape {field_pred.shape}')

    # Remember the input layout so the result can be labelled identically
    orig_dims = list(field_pred.dims)

    # Rescale to [0, 249]. The subtraction allocates a new array on purpose:
    # an in-place `-=` would modify the values of the caller's DataArray.
    raw = field_pred.values
    shifted = raw - np.min(raw)
    peak = np.max(shifted)

    if peak == 0:
        # Constant input: rescaling would divide by zero and every pixel would
        # end up below the threshold anyway, so there is nothing to segment.
        mergedsegment = np.zeros(shifted.shape, dtype=np.int32)
    else:
        scaled = shifted * 249. / peak
        # We currently take 0.3 as the binary threshold to distinguish between
        # segments of fields and other segments. This could definitely be
        # improved and made more objective.
        # NOTE: the data is scaled, so the threshold is scaled as well (0.3 * 250).
        image = np.clip((scaled - 0.3 * 250) * 2, 0., 249.)

        # compute edges
        edges = sobel(image)

        segment = np.array(
            segmentation.felzenszwalb(image, scale=1, sigma=0.,
                                      min_size=30, multichannel=False)
        ).astype(np.int32)
        # Perform the rag boundary analysis and merge the segments
        bgraph = graph.rag_boundary(segment, edges)
        mergedsegment = graph.cut_threshold(segment, bgraph, 0.15, in_place=False)
        # segments start from 0, therefore the 0th has to be moved so that 0
        # can be used as the "no field" label
        mergedsegment[mergedsegment == 0] = np.max(mergedsegment) + 1
        # pixels clipped to 0 are below the threshold and belong to no field
        mergedsegment[image == 0] = 0
        mergedsegment[mergedsegment < 0] = 0

    # Count the non-zero labels only; subtracting 1 from the number of unique
    # values is off by one whenever no pixel carries the background label 0.
    logger.info(f'Defined in total {np.count_nonzero(np.unique(mergedsegment))} fields')

    vectorized_segment = mergedsegment.reshape(mergedsegment.shape[0], mergedsegment.shape[1])

    # The label array has the same axis layout as the input, so attach the
    # input's own dimension names and coordinates in their original order.
    # Labelling the axes as ("x", "y") regardless of the input order would
    # transpose (or reject) data that arrives as ("y", "x").
    result_xr = xr.DataArray(
        vectorized_segment,
        coords=[field_pred.coords[dim] for dim in orig_dims],
        dims=orig_dims,
    )

    return result_xr



def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """openEO UDF entry point: segment the prediction cube.

    Args:
        cube: data cube whose array has a ``bands`` dimension of length one;
            that dimension is squeezed away before processing.
        context: UDF options. ``cache_dir``, when present, is exported through
            the ``TORCH_HOME`` and ``XDG_CACHE_HOME`` environment variables.
            All entries are forwarded to ``vectorize`` as keyword arguments.
            ``None`` is treated as an empty mapping.

    Returns:
        The labelled segments wrapped in a new ``XarrayDataCube``.
    """
    print(f'CUBE is {cube}')
    print(f'Context is {context}')

    # Tolerate a missing context instead of failing on `.get` / `**context`
    context = context or {}

    # Only export the cache location when one was supplied; assigning None to
    # os.environ raises a TypeError.
    cache_dir = context.get("cache_dir")
    if cache_dir is not None:
        cache_dir = os.fspath(cache_dir)
        os.environ['TORCH_HOME'] = cache_dir
        os.environ['XDG_CACHE_HOME'] = cache_dir

    # Put the dependency directories at the front of the import path. An entry
    # left by an earlier invocation is moved rather than duplicated, so
    # sys.path does not grow with every call. The last one listed ends up first.
    for dependency_dir in ('tmp/venv_static', 'tmp/venv_static_del', 'tmp/venv_model'):
        if dependency_dir in sys.path:
            sys.path.remove(dependency_dir)
        sys.path.insert(0, dependency_dir)

    # Must come after the sys.path changes so that loguru can be imported
    _setup_logging()

    # Extract xarray.DataArray from the cube and drop the band dimension
    inarr = cube.get_array().squeeze('bands', drop=True)
    logger.info(f'Input array opened with shape: {inarr.shape}')

    # Run the delineation workflow
    vectorized = vectorize(inarr, **context)

    # Wrap result in an OpenEO datacube
    return XarrayDataCube(vectorized)



def load_vectorization_udf() -> str:
    """Return the source code of this module as text.

    The file is opened read-only, so it also works when the module lives in a
    location the process cannot write to, and it is decoded as UTF-8 instead
    of the platform's default encoding.
    """
    with open(os.path.realpath(__file__), 'r', encoding='utf-8') as f:
        return f.read()