#!/usr/bin/env python3

import math
import threading

import affine
import numpy
import pyproj
import scipy
import scipy.ndimage
import scipy.signal
import shapely.geometry
import skimage.filters
import skimage.measure
import skimage.morphology
import skimage.segmentation
import skimage.transform


# Per-thread storage.  _get_tf_model() caches the loaded model here
# under the 'tf_model' attribute, so each thread keeps its own instance.
_threadlocal = threading.local()


def _get_tf_model():

    # Returns the field detector model for the calling thread.
    #
    # Keras/tensorflow models are not guaranteed to be threadsafe, so
    # the model is loaded at most once per thread and kept in
    # thread-local storage.  That way no loading happens at predict
    # time after the first call on a thread.

    cached = getattr(_threadlocal, 'tf_model', None)

    if cached is not None:
        return cached

    # Heavy dependencies are imported only when a model actually has
    # to be loaded

    import io
    import pkgutil
    import h5py
    import tensorflow.keras.models

    # Read the packaged HDF5 resource into memory and let Keras load
    # the model from the opened file object

    raw = pkgutil.get_data('fielddelineation',
                           'resources/fielddetectormodel.h5')

    with h5py.File(io.BytesIO(raw), mode='r') as handle:
        loaded = tensorflow.keras.models.load_model(handle)

    # Remember the model for subsequent calls on this thread

    _threadlocal.tf_model = loaded

    return loaded


def _utm_epsg(lon, lat):

    # Returns the EPSG code of the WGS84 / UTM zone for a position
    # given as longitude and latitude in degrees.
    #
    # The zone is derived purely from 6 degree longitude bands; the
    # irregular zones around Norway and Svalbard are not special-cased.

    # Zone number 1..60; the modulo maps lon == 180 back onto zone 1

    band = math.floor((lon + 180.0) / 6.0) % 60
    zone = band + 1

    # 326xx codes are the northern hemisphere, 327xx the southern

    base = 32600 if lat >= 0.0 else 32700

    return base + zone


def _to_utm(lon, lat):

    # Projects a WGS84 longitude/latitude position to the UTM zone it
    # falls in.
    #
    # Returns ((x, y), crs) where crs is the 'epsg:NNNNN' string of
    # the UTM zone that was used.

    target_crs = f'epsg:{_utm_epsg(lon, lat)}'

    # always_xy makes the transformer take (lon, lat) and return (x, y)
    # regardless of the axis order defined by the CRS

    transformer = pyproj.Transformer.from_crs('epsg:4326',
                                              target_crs,
                                              always_xy=True)

    easting, northing = transformer.transform(lon, lat)

    return (easting, northing), target_crs


def _gaussian_kernel(window):

    # Returns a normalized (sum == 1) 2D gaussian kernel of
    # window x window pixels, with a standard deviation of window/6.

    profile = scipy.signal.windows.gaussian(window, std=window/6, sym=True)

    # The 2D kernel is the outer product of the 1D profile with itself

    kernel = profile[:, numpy.newaxis] * profile[numpy.newaxis, :]

    return kernel / kernel.sum()


def _proximity_mask(a, window, values, threshold):

    # Returns a boolean array with the shape of 'a' that is True where
    # the gaussian-weighted share of nearby pixels having one of the
    # given 'values' exceeds 'threshold'.
    #
    # 'a' is treated as a stack of images: the smoothing is applied
    # along axes 1 and 2 only, separately for every entry on axis 0.

    # One copy of the 2D kernel per image in the stack

    smoothing = _gaussian_kernel(window)
    stacked = numpy.repeat(smoothing[numpy.newaxis, ...], a.shape[0], axis=0)

    # Flag the pixels of interest and spread them out spatially

    hits = numpy.isin(a, values)
    density = scipy.signal.fftconvolve(hits, stacked, mode='same', axes=[1, 2])

    return density > threshold


def bbox_from_lonlat(lon, lat):

    # Generates a bounding box with (lon, lat) at the center.
    #
    # The box is 3280m x 3280m, which at 10m resolution should
    # accommodate a 128x128 pixel window + 100 pixel border on all
    # sides.
    #
    # Returns ((xmin, ymin, xmax, ymax), crs), where crs is the CRS
    # of the bounding box coordinates.

    # Project to UTM and snap the center to whole coordinate units

    center, crs = _to_utm(lon, lat)

    cx, cy = numpy.round(center).astype(int)

    # Half the box size: half a window plus the border, times 10m

    half = (128//2 + 100) * 10

    bbox = (cx-half, cy-half, cx+half, cy+half)

    return bbox, crs


def predict(b04, b08, scl, transform, seed=0):

    # This function generates an image where each pixel value gives an
    # indication of the likelihood [0-1] that that pixel is part of an
    # agricultural parcel.  Parcel borders are intended to stand out
    # with clear lines of low values.
    #
    # b04, b08 and scl are stacks of images (image, row, column) and
    # are assumed to include a 100 pixel border for cloud masking.
    # This border will be ignored for the rest of the processing and
    # will not be part of the final prediction image!
    #
    # Image resolution should be 10m.
    #
    # 'transform' is the affine transform of the input images, 'seed'
    # seeds the random selection of input images so results are
    # reproducible.
    #
    # Returns (prediction, transform), where the transform has been
    # adjusted for the removed border.
    #
    # Raises ValueError if fewer than 3 images are supplied.

    rand = numpy.random.RandomState(seed=seed)

    # Each prediction needs 3 distinct images; fail early with a clear
    # message instead of deep inside the random selection below

    if scl.shape[0] < 3:
        raise ValueError('predict() needs at least 3 images, got %d'
                         % scl.shape[0])

    model = _get_tf_model()

    # Create nodata/cloudmask
    #
    # NOTE(review): the class values presumably follow the Sentinel-2
    # scene classification (nodata, defective, shadow, clouds, snow);
    # confirm against the data source.

    mask1 = _proximity_mask(scl,  17, [0, 1, 3, 8, 9, 10, 11], 0.057)
    mask2 = _proximity_mask(scl, 201,       [3, 8, 9, 10, 11], 0.025)

    cloud_mask = mask1 | mask2

    # Remove border used for generating cloudmask

    cloud_mask = cloud_mask[:, 100:-100, 100:-100]

    b04 = b04[:, 100:-100, 100:-100]
    b08 = b08[:, 100:-100, 100:-100]
    scl = scl[:, 100:-100, 100:-100]

    # Adjust transform for removed border: shift the origin by the map
    # space offset of pixel (100, 100)

    x0, y0 = transform * (0, 0)
    x1, y1 = transform * (100, 100)

    transform = affine.Affine.translation(x1-x0, y1-y0) * transform

    # Generate watermask

    water_mask = (scl == 6)

    # Find the 8 images with the most reliable data, i.e. the fewest
    # cloudy/nodata or water pixels

    unreliable_counts = numpy.sum(cloud_mask | water_mask, axis=(1,2))

    sorted_indices = numpy.argsort(unreliable_counts)
    sorted_indices = sorted_indices[0:8]

    # Keep only those images

    cloud_mask = cloud_mask[sorted_indices]

    b04 = b04[sorted_indices]
    b08 = b08[sorted_indices]

    # Calculate NDVI in floating point.  Integer bands are converted
    # first: with unsigned values 'b08 - b04' would silently wrap
    # around wherever b04 > b08.

    if not numpy.issubdtype(b04.dtype, numpy.floating):
        b04 = b04.astype(numpy.float64)

    if not numpy.issubdtype(b08.dtype, numpy.floating):
        b08 = b08.astype(numpy.float64)

    total = b08 + b04

    # NDVI is undefined where both bands sum to zero.  Leave those
    # pixels at 0, like masked pixels, rather than passing NaN/inf on
    # to the model.

    ndvi = numpy.zeros(total.shape, dtype=total.dtype)

    numpy.divide(b08 - b04, total, out=ndvi, where=(total != 0))

    # Remove nodata/cloudy pixels

    ndvi[cloud_mask] = 0

    # Make 3 predictions based on different input images

    predictions = []

    indices = numpy.arange(ndvi.shape[0])

    h = ndvi.shape[1]
    w = ndvi.shape[2]

    for _ in range(3):

        # Choose 3 random images out of the (at most 8) selected ones

        random_indices = rand.choice(indices, size=3, replace=False)

        # Convert to model data layout: one row of h*w pixels with the
        # 3 NDVI values of each pixel as features

        inputs = ndvi[random_indices, :, :]
        inputs = inputs.transpose([1, 2, 0])
        inputs = inputs.reshape((1, h*w, 3))

        # Run tensorflow model

        output = model.predict(inputs)
        output = output.reshape((h, w))

        predictions.append(output)

    # Final prediction is the per-pixel median of the 3 predictions

    final = numpy.median(predictions, axis=0)

    return final, transform


def segment(prediction, transform, upscale=4):

    # Uses watershed segmentation to partition a field delineation
    # prediction image into field segments.
    #
    # Returns (segmentation, transform).  Each pixel of the
    # segmentation image is either 0 (excluded area, small object or
    # border between fields) or a field segment ID.  The returned
    # transform matches the (possibly upscaled) segmentation image.

    # Upscaling leads to less coarse segment edges; the pixel size of
    # the transform shrinks by the same factor

    if upscale > 1:
        prediction = skimage.transform.pyramid_expand(prediction, upscale)

        transform = transform * affine.Affine.scale(1/upscale)

    # Maximum number of pixels for 'small objects'

    min_pixels = (2 * upscale) ** 2

    # Local maxima in the prediction image are probably inside fields.
    # These seed the basins of the watershed, each with a distinct ID.

    peaks = skimage.morphology.h_maxima(prediction, 0.1).astype(bool)

    peak_ids, _ = scipy.ndimage.label(peaks)

    # Pixels that are almost certainly not part of a field, ignoring
    # small specks

    background = skimage.morphology.remove_small_objects(prediction < 0.1,
                                                         min_pixels)

    # Combine both into one marker image: 0 means the pixel is still to
    # be 'flooded', 1 is the special 'exclude' basin and field basins
    # get values of 3 and up (label numbering starts at 1).

    seeds = numpy.zeros_like(peak_ids)
    seeds[peaks] = peak_ids[peaks] + 2
    seeds[background] = 1

    # The negated prediction acts as 'elevation', so field borders
    # with low prediction values become ridges (watersheds).  Enabling
    # 'watershed_line' generates 0-value pixels on the borders between
    # basins.

    segmentation = skimage.segmentation.watershed(-prediction, seeds,
                                                  watershed_line=True)

    # Drop segments that are too small

    keep = skimage.morphology.remove_small_objects(segmentation, min_pixels)
    keep = keep.astype(bool)

    segmentation[~keep] = 0

    # Finally subtract 1: this sets the 'exclude' segments to 0, just
    # like the small objects and watershed_line pixels.  Field segment
    # IDs therefore start at 2.

    segmentation[segmentation >= 1] -= 1

    return segmentation, transform


def polygonize(segmentation, transform, close_edges=False, simplify_tolerance=0):

    # Returns a shapely GeometryCollection containing one polygon for
    # every contour around the non-zero pixels of a segmentation image,
    # in the map coordinates defined by 'transform'.
    #
    # Note that contours intersecting with the edge of the image are
    # not closed and by default not included.  Set 'close_edges' to
    # True to automatically close these and include them in the output.
    #
    # 'simplify_tolerance' is passed to shapely's simplify() for every
    # polygon (expressed in CRS units, so normally meter at this point).

    # Find contours for all non-zero pixels

    contours = skimage.measure.find_contours(segmentation, 0.5)
    polygons = []

    # Optionally remove contours that are not closed (typically at the
    # edge of the image).  If not removed these will be closed
    # automatically by shapely by simply connecting the first and
    # last vertex and will be included in the result.

    if not close_edges:
        contours = [c for c in contours if (c[0,:] == c[-1,:]).all()]

    # A polygon needs at least 3 vertices.  Open contours can be
    # shorter (e.g. one that only clips a corner of the image) and
    # would make the Polygon constructor raise, so skip those.

    contours = [c for c in contours if len(c) >= 3]

    # Contours are 'center of pixel', but coordinate transforms are
    # relative to the upper-left of the image, so translate by half
    # a pixel

    transform = transform * affine.Affine.translation(0.5, 0.5)

    for contour in contours:

        # Contour vertices are (row, column); swap to (x, y) and
        # convert pixel to map coordinates

        contour = contour[:,[1,0]]
        contour = numpy.array(transform * contour.T).T

        # Convert to a polygon object

        polygon = shapely.geometry.Polygon(contour)
        polygon = polygon.simplify(simplify_tolerance)

        polygons.append(polygon)

    return shapely.geometry.GeometryCollection(polygons)


def isolate_center_segment(segmentation):

    # Returns a new segmentation image (same shape and dtype) that
    # contains only the field segment at the center of the image:
    # pixels of that segment are 1, all other pixels are 0.
    #
    # If the center pixel is not part of a segment (value 0) the
    # result is all zeros.

    dtype = segmentation.dtype

    # Images are indexed [row, column], so the row index comes from
    # shape[0] and the column index from shape[1].  Mixing these up
    # only goes unnoticed on square images.

    row = segmentation.shape[0] // 2
    col = segmentation.shape[1] // 2

    center_id = segmentation[row, col]

    if center_id != 0:
        isolated = (segmentation == center_id)
    else:
        isolated = numpy.zeros_like(segmentation)

    return isolated.astype(dtype)


def find_field_polygons(datacube):

    # Takes a datacube and returns a GeometryCollection containing
    # all field polygons.
    #
    # The datacube is a mapping with the image stacks 'B04', 'B08' and
    # 'SCL' and a 'profile' that holds the affine 'transform'.

    # Prediction image -> field segments -> polygons; every step hands
    # on the transform that matches its output image

    prediction, transform = predict(datacube['B04'],
                                    datacube['B08'],
                                    datacube['SCL'],
                                    datacube['profile']['transform'])

    segmentation, transform = segment(prediction, transform, 4)

    return polygonize(segmentation, transform)


def find_field_polygon_at_center(datacube):

    # Takes a datacube and returns a GeometryCollection containing
    # just the polygon(s) of the field at the center of the image.
    #
    # The datacube is a mapping with the image stacks 'B04', 'B08' and
    # 'SCL' and a 'profile' that holds the affine 'transform'.

    # Prediction image -> field segments -> center segment -> polygons;
    # every step hands on the transform that matches its output image

    prediction, transform = predict(datacube['B04'],
                                    datacube['B08'],
                                    datacube['SCL'],
                                    datacube['profile']['transform'])

    segmentation, transform = segment(prediction, transform, 4)

    center_only = isolate_center_segment(segmentation)

    return polygonize(center_only, transform)
