import os
import numpy as np
import geopandas as gpd
from shapely import geometry
from skimage import segmentation
from skimage.filters import sobel
from skimage.future import graph

from rasterio.windows import Window
from rasterio.features import shapes
from rasterio import Affine
from math import ceil
import numpy
import gc
import rasterio
import pandas as pd
from itertools import chain
from functools import reduce
import operator
import glob
import uuid

'''
This legacy script was used to combine the delineation results (as vector polygons) of the parts of a Sentinel tile.
It is kept in case full-tile delineation with openEO is revived.
'''

def assign_boundary(geom, total_bounds, tol=0.5):
    """Return a bitmask telling which sides of ``total_bounds`` ``geom`` reaches.

    Bits: 1 = min x, 2 = min y, 4 = max x, 8 = max y. A side counts as
    reached when the geometry's own bound lies within ``tol`` of it (or
    beyond it).

    Parameters
    ----------
    geom : object
        Anything with a ``bounds`` attribute ordered (minx, miny, maxx, maxy),
        e.g. a shapely geometry.
    total_bounds : sequence
        (minx, miny, maxx, maxy) of the enclosing extent.
    tol : float, default 0.5
        Tolerance, in coordinate units, for deciding that a side is touched.

    Returns
    -------
    int
        Sum of the bits of all touched sides (0 when no side is touched).
    """
    gb = geom.bounds
    bnd = 0
    if gb[0] <= total_bounds[0] + tol:
        bnd |= 1
    if gb[1] <= total_bounds[1] + tol:
        bnd |= 2
    if gb[2] >= total_bounds[2] - tol:
        bnd |= 4
    if gb[3] >= total_bounds[3] - tol:
        bnd |= 8
    return bnd

def filter_small_areas(geodf, minArea_Ha, keep_bnd=False):
    """Drop rows whose geometry area is below ``minArea_Ha`` hectares.

    The area is compared against ``minArea_Ha * 10000`` (square meters when the
    CRS unit is the meter). When ``keep_bnd`` is truthy, rows whose
    ``boundary`` value is non-zero are kept regardless of their size.
    """
    min_area = minArea_Ha * 10000
    keep = geodf['geometry'].area >= min_area
    if keep_bnd:
        keep = keep | (geodf['boundary'] != 0)
    return geodf[keep]

# for horizontal: left=top, right=bottom
def merge_polygons(leftdf,rightdf,is_vert):
    """Stitch polygons that were cut by the shared edge of two adjacent blocks.

    ``leftdf`` / ``rightdf`` are the polygon frames of two neighbouring blocks,
    each carrying a ``boundary`` bitmask as produced by ``assign_boundary``
    (1: min x, 2: min y, 4: max x, 8: max y side of the block's extent).

    With ``is_vert=True`` the blocks share a vertical edge: ``leftdf`` is the
    block on the left (edge polygons have bit 4) and ``rightdf`` the block on
    the right (bit 1). With ``is_vert=False`` they share a horizontal edge:
    ``leftdf`` is the top block (bit 2) and ``rightdf`` the bottom block (bit 8).

    The edge polygons are grouped into connected components of overlapping
    left/right pairs. Left is nudged 0.5 coordinate units toward right for this
    test so that polygons that merely touch along the edge do overlap. Each
    component is dissolved into one polygon. Its boundary mask is the OR of its
    members with the two shared-edge bits cleared.

    Returns
    -------
    tuple
        ``(left, right)``. ``left`` holds leftdf's off-edge polygons plus every
        dissolved group, including unmatched edge polygons of either side.
        ``right`` holds rightdf's off-edge polygons. Both get a fresh
        RangeIndex. If either input is None or empty, the inputs are returned
        unchanged.
    """

    def split_geopandas_frame(df, eval_func):
        """Split ``df`` into (rows where eval_func(df) is True, the other rows)."""
        istrue = eval_func(df)
        return df[istrue], df[~istrue]

    def build_merge_groups(l, r, start):
        """Return (left indices, right indices) connected to left index ``start``.

        ``l`` maps left index -> overlapping right indices and ``r`` maps right
        index -> overlapping left indices. Every visited node is popped from its
        dict so it cannot end up in a later group. The walk is iterative: a long
        chain of touching polygons would otherwise hit the recursion limit.
        """
        members = (set(), set())
        adjacency = (l, r)
        stack = [(0, start)]
        while stack:
            side, node = stack.pop()
            if node in members[side]:
                continue
            members[side].add(node)
            other = 1 - side
            stack.extend((other, nb) for nb in adjacency[side].pop(node))
        return members

    # obvious no-op
    if leftdf is None or rightdf is None: return leftdf,rightdf
    if (leftdf.shape[0]==0) or (rightdf.shape[0]==0): return leftdf,rightdf

    # edge bits of (left, right) and the nudge applied to left for overlap tests
    if is_vert:
        bndidx = (4, 1)
        xtr, ytr = 0.5, 0.0
    else:
        bndidx = (2, 8)
        xtr, ytr = 0.0, -0.5

    # separate the polygons lying on the shared edge from the rest
    (left, lcommon) = split_geopandas_frame(leftdf, lambda d: d['boundary'] & bndidx[0] > 0)
    (right, rcommon) = split_geopandas_frame(rightdf, lambda d: d['boundary'] & bndidx[1] > 0)

    # nothing on the shared edge: the full path would only re-index
    if left.shape[0] == 0 and right.shape[0] == 0:
        return leftdf.reset_index(drop=True), rightdf.reset_index(drop=True)

    left = left.copy()
    right = right.copy()

    # Overlap tests use a shifted copy of the left geometries. Translating left
    # itself forth and back could leave rounding noise in its coordinates.
    shifted = left.geometry.translate(xtr, ytr, 0.)

    # adjacency in both directions (overlaps is symmetric)
    r = {index: left.index[shifted.overlaps(geom).to_numpy()].tolist()
         for index, geom in right.geometry.items()}
    l = {index: right.index[right.geometry.overlaps(geom).to_numpy()].tolist()
         for index, geom in shifted.items()}

    # connected components, seeded from the remaining left polygons in order
    groups = []
    while l:
        groups.append(build_merge_groups(l, r, next(iter(l))))
    # right edge polygons that overlapped nothing become their own groups
    groups.extend((set(), {inode}) for inode in r)

    left['group'] = -1
    right['group'] = -1
    for igrp, (lset, rset) in enumerate(groups):
        # pandas >= 2.0 rejects sets as .loc indexers, hence the lists
        lidx = list(lset)
        ridx = list(rset)
        bnd = reduce(operator.or_, left.loc[lidx, 'boundary'], 0) | reduce(operator.or_, right.loc[ridx, 'boundary'], 0)
        # the shared edge is no longer a boundary once the pieces are joined
        bnd &= ~bndidx[0]
        bnd &= ~bndidx[1]
        if lidx:
            left.loc[lidx, 'group'] = igrp
            left.loc[lidx, 'boundary'] = bnd
        if ridx:
            right.loc[ridx, 'group'] = igrp
            right.loc[ridx, 'boundary'] = bnd
    mergedboundary = pd.concat([left, right], sort=False).dissolve('group', as_index=False).drop('group', axis=1)

    # merged polygons go to the left/top side, the right/bottom keeps the rest
    return pd.concat([lcommon, mergedboundary], sort=False).reset_index(drop=True), rcommon.reset_index(drop=True)


###################################################
# Segmentation
###################################################
def raster_segmentation(image,profile,window,outFilePolygonized,keep_tmp_files):
    """Segment ``image`` block by block and return the resulting polygons.

    The image is cut into blocks of ``bs`` x ``bs`` pixels. For every block
    that is not entirely nodata the function:
      1. runs Felzenszwalb segmentation on the block,
      2. merges segments with a RAG built on the Sobel edge image (cut
         threshold 0.15),
      3. zeroes labels where the pixel value is below ``0.3 * 250``, where it
         equals nodata, or where the label is negative,
      4. vectorizes the non-zero labels with ``rasterio.features.shapes``,
      5. flags polygons reaching the extent of the block's polygons
         (``assign_boundary``) and drops unflagged polygons below ``minArea`` ha.
    Polygons on internal block edges are then stitched with ``merge_polygons``,
    first across vertical edges and then across horizontal ones. Everything is
    concatenated and a final minimum-area filter is applied.

    Parameters
    ----------
    image : numpy.ndarray
        Single band indexed as ``image[row, col]``, covering ``window``.
    profile : dict-like
        rasterio profile. ``transform`` and ``crs`` are read, and ``nodata``
        when present. It is not modified.
    window : rasterio.windows.Window
        Extent of ``image``. Only ``width`` and ``height`` are used.
    outFilePolygonized : str
        Path prefix for the intermediate files written when ``keep_tmp_files``.
    keep_tmp_files : bool
        When true, also writes the image, the Sobel edges and the raw and merged
        segment labels as rasters, plus per-block edge polygons as shapefiles.

    Returns
    -------
    geopandas.GeoDataFrame
        Only a ``geometry`` column, with a fresh RangeIndex.
    """

    # block size in pixels (1024 is an alternative)
    bs = 1098
    # minimum polygon area in hectares
    minArea = 0.25

    # Reduce the number of blocks being computed.
    # ONLY FOR DEBUGGING!!! keep nblocks=None and steps=(1,1) for a normal run
    nblocks = None
    steps = (1, 1)

    # Work on a copy: the dtype updates for the temporary rasters below must
    # not leak into the caller's profile.
    profile = dict(profile)
    # Without a nodata value every block is processed and no pixel is masked as nodata.
    nodata = profile.get('nodata')
    if nodata is not None:
        nodata = int(nodata)

    # Calculate the edges using sobel
    print('  Calculating edges ...')
    edges = sobel(image)

    # image-global label rasters, shaped (rows, cols) like the [row, col]
    # block slices written into them below
    if keep_tmp_files:
        segments = numpy.zeros((window.height, window.width), dtype=numpy.int32)
        mergedSegments = numpy.zeros((window.height, window.width), dtype=numpy.int32)
    ijblk = 0

    if nblocks is None:
        nblocks = (int(ceil(window.width / bs)), int(ceil(window.height / bs)))
    # polys[i][j]: polygons of the block in column i, row j (None if skipped)
    polys = [[None for j in range(nblocks[1])] for i in range(nblocks[0])]

    print('  Calculating window-by-window ...')
    for i in range(0, nblocks[0], steps[0]):
        for j in range(0, nblocks[1], steps[1]):
            ijblk += 1
            # the last column/row of blocks is clipped to the image extent
            ijwin = Window(i * bs, j * bs,
                           min(bs, window.width - i * bs),
                           min(bs, window.height - j * bs))
            rows = slice(ijwin.row_off, ijwin.row_off + ijwin.height)
            cols = slice(ijwin.col_off, ijwin.col_off + ijwin.width)
            ijband = image[rows, cols]
            # only do work when the block has data
            if nodata is not None and ijband.min() == nodata and ijband.max() == nodata:
                continue
            print("    processing win {}/{} {}".format(ijblk, nblocks[0] * nblocks[1], ijwin), end=' ')
            ijedge = edges[rows, cols]
            # Perform felzenszwalb segmentation
            ijsegment = np.array(segmentation.felzenszwalb(ijband, scale=1, sigma=0, min_size=30, multichannel=False)).astype(numpy.int32)
            # RAG boundary analysis on the edge image, then merge segments
            ijgraph = graph.rag_boundary(ijsegment, ijedge)
            ijmergedsegment = graph.cut_threshold(ijsegment, ijgraph, 0.15, in_place=False)
            # 0.3 is the binary threshold separating field segments from other
            # segments. It could be improved and made more objective.
            # NOTE: the input is scaled data, so the threshold is scaled as well.
            ijmergedsegment[ijband < 0.3 * 250] = 0
            if nodata is not None:
                ijmergedsegment[ijband == nodata] = 0
            ijmergedsegment[ijmergedsegment < 0] = 0
            ijmask = ijmergedsegment != 0
            # vectorize, placing the block at its offset within the image
            results = [
                {'properties': {'boundary': 0}, 'geometry': s}
                for (s, v) in shapes(ijmergedsegment, mask=ijmask, transform=profile['transform'] * Affine.translation(i * bs, j * bs))
            ]
            ijpolys = gpd.GeoDataFrame.from_features(results, crs=profile['crs'])
            if ijpolys.shape[0] > 0:
                # mark the boundary polygons
                ijbounds = ijpolys.total_bounds
                ijpolys['boundary'] = ijpolys['geometry'].apply(lambda x: assign_boundary(x, ijbounds))
                # removing non-boundary small areas
                ijpolys = filter_small_areas(ijpolys, minArea, keep_bnd=True)
            polys[i][j] = ijpolys
            print("number of polys: {}".format(str(polys[i][j].shape[0])))
            if keep_tmp_files:
                segments[rows, cols] = ijsegment
                mergedSegments[rows, cols] = ijmergedsegment

    if keep_tmp_files:
        with rasterio.open(outFilePolygonized + '_image.tif', 'w', **profile) as dstimg: dstimg.write(image, 1)
        profile.update(dtype=numpy.float64)
        with rasterio.open(outFilePolygonized + '_edges.tif', 'w', **profile) as dstedge: dstedge.write(edges, 1)
        profile.update(dtype=numpy.int32)
        with rasterio.open(outFilePolygonized + '_segments.tif', 'w', **profile) as dstseg: dstseg.write(segments.astype(numpy.int32), 1)
        with rasterio.open(outFilePolygonized + '_mergedSegments.tif', 'w', **profile) as dstmergseg: dstmergseg.write(mergedSegments.astype(numpy.int32), 1)

    # explicit garbage collection of what is possible
    edges = None
    if keep_tmp_files:
        segments = None
        mergedSegments = None
    gc.collect()

    # save boundary polygons block by block if keeping temporary files
    if keep_tmp_files:
        for i in range(nblocks[0]):
            for j in range(nblocks[1]):
                ijpolys = polys[i][j]
                # skipped (all-nodata) and empty blocks have nothing to write
                if ijpolys is None or ijpolys.shape[0] == 0:
                    continue
                ijbnd = ijpolys[ijpolys['boundary'] > 0]
                if ijbnd.shape[0] > 0:
                    ijbnd.to_file(outFilePolygonized + '_bnd_blocks_{}_{}.shp'.format(str(i), str(j)))

    # sew the polygons on the internal boundaries
    print("  Stitching internal boundaries")
    for i in range(nblocks[0] - 1):
        for j in range(nblocks[1]):
            polys[i + 0][j], polys[i + 1][j] = merge_polygons(polys[i + 0][j], polys[i + 1][j], True)
    for i in range(nblocks[0]):
        for j in range(nblocks[1] - 1):
            polys[i][j + 0], polys[i][j + 1] = merge_polygons(polys[i][j + 0], polys[i][j + 1], False)

    # join the polygons of all processed blocks
    allpolys = pd.concat([p for p in chain.from_iterable(polys) if p is not None])
    # blocks without polygons may not carry the column
    allpolys = allpolys.drop('boundary', axis=1, errors='ignore')

    # filter out small sizes (what remains on boundaries after boundary sewing)
    allpolys = filter_small_areas(allpolys, minArea)

    # keep only the geometry and reindex for brevity
    allpolys = allpolys.filter(['geometry'])
    allpolys.reset_index(drop=True, inplace=True)

    polys = None
    gc.collect()
    return allpolys
    

###### USED UP TO HERE!!!!

###################################################
# Post processing shapefile
###################################################
def polygon_postproc(tile,imgBounds,outFilePolygonizedGpd):
    """Give every polygon a unique id and keep those in the tile's inner area.

    Parameters
    ----------
    tile : str
        Tile name, used as the prefix of the generated ``CODE_OBJ`` ids.
    imgBounds : sequence
        ``(minx, miny, maxx, maxy)`` of the source image, in the CRS of
        ``outFilePolygonizedGpd``.
    outFilePolygonizedGpd : geopandas.GeoDataFrame
        Polygons to post-process. A ``CODE_OBJ`` column is added to this frame
        in place.

    Returns
    -------
    geopandas.GeoDataFrame
        Rows selected with ``.cx`` on the image bounds shrunk by 4500 CRS units
        (meters are assumed) on every side. Only the ``CODE_OBJ`` and
        ``geometry`` columns are kept, with a fresh RangeIndex. Nothing is
        written to disk.
    """
    print('  Post-processing geopandas dataframe ...')

    # Image extent as a polygon, walking its corners from (minx, miny).
    xmin, ymin, xmax, ymax = imgBounds[0], imgBounds[1], imgBounds[2], imgBounds[3]
    corners = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
    extent = gpd.GeoSeries(geometry.Polygon(corners))
    extent.crs = outFilePolygonizedGpd.crs

    # Shrink the extent by 4500 on each side. Per the original note, tiles
    # overlap by about 10 km, and some of that overlap is deliberately kept.
    print('    Buffering bounding box (4.5km inward) ...')
    inner = extent.buffer(-4500).total_bounds

    # Ids look like "<tile>_<uppercase hex of a random UUID4>".
    print('    Generating unique CODE_OBJ ...')
    n_polys = len(outFilePolygonizedGpd)
    outFilePolygonizedGpd['CODE_OBJ'] = [tile + "_" + uuid.uuid4().hex.upper() for _ in range(n_polys)]

    print('    Subsetting on buffered bounding box ...')
    selected = outFilePolygonizedGpd.cx[inner[0]:inner[2], inner[1]:inner[3]]

    # keep only the id and the geometry
    selected = selected[['CODE_OBJ', 'geometry']]
    selected.reset_index(drop=True, inplace=True)
    return selected


###################################################
# Post processing shapefile
###################################################

def polygon_buffer(outFilePolygonizedGpd, distance=-10):
    """Buffer every polygon by ``distance`` and drop geometries that vanish.

    With the default negative distance the polygons shrink by 10 CRS units,
    using mitre joins (``join_style=2``) and ``resolution=4``. Polygons thinner
    than that collapse to empty geometries; those and missing geometries are
    removed.

    The input frame is not modified. A filtered, buffered copy is returned.

    Parameters
    ----------
    outFilePolygonizedGpd : geopandas.GeoDataFrame
        Polygons to buffer.
    distance : float, default -10
        Buffer distance in CRS units. Negative values shrink the polygons.
    """
    # Copy first so the caller's geometries are left untouched.
    buffered = outFilePolygonizedGpd.copy()
    buffered['geometry'] = buffered.geometry.buffer(distance, cap_style=1,
                                                    join_style=2,
                                                    resolution=4)
    geoms = buffered['geometry']
    return buffered[(~geoms.isna()) & (~geoms.is_empty)]


###################################################
# Post processing shapefile
###################################################

def polygon_epsg4326(outFilePolygonizedGpd):
    """Return ``outFilePolygonizedGpd`` reprojected to WGS 84 (EPSG:4326) via ``to_crs``."""
    wgs84_code = 4326
    reprojected = outFilePolygonizedGpd.to_crs(epsg=wgs84_code)
    return reprojected

###################################################
# Merge tiles
###################################################

def merge_tiles(outdir,filepattern='[0-9][0-9][A-Z][A-Z][A-Z]_segmentedImage_felzenszwalbRAG_015threshold_025haSelected_4500mCropped.shp'):
    """Merge the per-tile shapefiles in ``outdir`` into one shapefile.

    Tiles matching ``filepattern`` are read in sorted order and accumulated.
    For each new tile, polygons of both sides that intersect the overlap of the
    accumulated bounding box and the tile's bounding box are spatially joined.
    Joined pairs keep the accumulated side's geometry (dissolved per
    ``CODE_OBJ`` on both sides) under the left ``CODE_OBJ``. The matched
    right-side polygons are dropped, and everything else is appended. The
    result is written to ``<outdir>/parcel_delineation_merged.shp``.

    Raises
    ------
    FileNotFoundError
        If no file in ``outdir`` matches ``filepattern``.
    """
    # TODO try out using assign_boundary (won't work out of the box because polygons hang out of the 4500m cropped region)
    # probably this works:
    # 1.: get the two bboxes and find their intersection rectangle
    # 2.: select what is inside on both sides
    # 3.: do the join
    #
    # TODO: other way around (the bounding-box intersection selects too many polys, e.g. in a 2x2 tile case
    #       when adding the last tile, because the first three form an L-shape):
    #       * create an outer rectangular ring of 1000m (if 4500m cropped -> has 1km overlap)
    #       * select all that intersects with that
    #       * move the rest from the inside to the collector dataframe (lshp)
    # TODO: still generates overlapping polygons -> needs a post sweep to join
    print('Merging individual tiles ...')

    # sorted: glob order is arbitrary, and the merge result depends on the order
    tiles = sorted(glob.glob(os.path.join(outdir, filepattern)))
    if not tiles:
        raise FileNotFoundError('No tile shapefiles matching {} found in {}'.format(filepattern, outdir))
    print('Reading first shp: '+tiles[0])
    lshp = gpd.read_file(tiles[0])
    lbox = geometry.box(*lshp.total_bounds)
    for tile in tiles[1:]:
        print('Appending '+tile)
        rshp = gpd.read_file(tile)
        rbox = geometry.box(*rshp.total_bounds)
        # find the common box
        ints = lbox.intersection(rbox)
        cint = None
        # only when there is a meaningful common box
        if ints.area > 0.:
            # extract possibly overlapping polygons from left and right
            # TODO: if per-tile building recorded boundary information and preselected on that,
            #       things would speed up and the internal polygons could be concatenated immediately
            #       (use the -1m boundary line segment with intersect)
            print("  Intersecting")
            lint = lshp[lshp.intersects(ints)]
            rint = rshp[rshp.intersects(ints)]
            print("  Reducing")
            lshp = lshp.drop(lint.index)
            rshp = rshp.drop(rint.index)
            print("  Joining")
            try:
                mint = gpd.sjoin(lint, rint, predicate='intersects', how='inner')
            except TypeError:
                # geopandas < 0.10 only accepts the older `op=` keyword
                mint = gpd.sjoin(lint, rint, op='intersects', how='inner')
            lint = lint.drop(mint.index)
            rint = rint.drop(mint.index_right)
            # several polygons may overlap each other: dissolve per id on both sides
            mint = mint.drop('index_right', axis=1)
            mint = mint.dissolve(['CODE_OBJ_left'], as_index=False).dissolve(['CODE_OBJ_right'], as_index=False)
            mint = mint.drop('CODE_OBJ_right', axis=1)
            mint = mint.rename({'CODE_OBJ_left': 'CODE_OBJ'}, axis=1)
            # combine the intersected geometries (DataFrame.append was removed in pandas 2.0)
            cint = pd.concat([lint, mint, rint], sort=False)
        # shunt everything from right into left
        if cint is not None:
            print("  Adding merged boundary geometries")
            lshp = pd.concat([lshp, cint], sort=False)
        print("  Adding the rest from 'right'")
        lshp = pd.concat([lshp, rshp], sort=False)
        lshp.reset_index(drop=True, inplace=True)
        # The accumulated area has grown. Without this update, later tiles were
        # only compared against the first tile's extent.
        lbox = geometry.box(*lshp.total_bounds)
    print("Writing results to disk")
    lshp.to_file(os.path.join(outdir, 'parcel_delineation_merged.shp'))


###################################################
# Main routine
###################################################
def process_tile(rootdir, outdir, tile, overwrite, keep_tmp_files):
    """Delineate parcels for one tile and write the two output shapefiles.

    Reads band 1 of ``<rootdir>/<tile>_segmentedImage.tif``, then:
      1. writes ``<outdir>/<tile>_segmentedImage_felzenszwalbRAG_015threshold_025haSelected_4500mCropped.shp``
         with the polygons from ``raster_segmentation``, cropped and given
         ``CODE_OBJ`` ids by ``polygon_postproc``;
      2. writes the same path with the extension replaced by
         ``_Buffer10m_WGS84.shp``. This holds step 1's polygons after
         ``polygon_buffer`` and reprojection with ``polygon_epsg4326``, keeping
         only the ``geometry`` and ``CODE_OBJ`` columns.
    Each step is skipped when its output already exists, unless ``overwrite``
    is set. The input image is read in every case.

    Side effect: sets the ``GDAL_DATA`` environment variable to ``/usr/share/gdal``.

    Raises
    ------
    Exception
        ``Exception("Error opening input image", err)`` if the input raster
        cannot be opened or read. ``err`` is also chained as the cause.
    """

    # NOTE: this overrides any GDAL_DATA already set in the environment
    os.environ['GDAL_DATA'] = '/usr/share/gdal'

    print('-' * 75)
    print('PROCESSING TILE: {}'.format(tile))

    # Define the input file
    inFile = os.path.join(rootdir, tile + '_segmentedImage.tif')

    # Read the image into memory
    try:
        print('  Reading image ...')
        with rasterio.open(inFile) as src:
            image = src.read(1)
            window = Window(0, 0, src.width, src.height)
            profile = src.profile
            imgBounds = src.bounds[:]
    except Exception as e:
        # same exception args as before, now explicitly chained to the cause
        raise Exception("Error opening input image", e) from e

    ###################################################
    # Vectorization
    ###################################################

    outFilePolygonized = os.path.join(outdir, tile + '_segmentedImage_felzenszwalbRAG_015threshold_025haSelected_4500mCropped.shp')
    if not os.path.exists(outFilePolygonized) or overwrite:
        # compute the polygons
        outGPDFPolygonized = raster_segmentation(image, profile, window, outFilePolygonized, keep_tmp_files)
        # do the 4.5km cropping
        outGPDFPolygonized = polygon_postproc(tile, imgBounds, outGPDFPolygonized)
        print('Writing to new SHP file ...')
        outGPDFPolygonized.to_file(outFilePolygonized)
    else:
        print('Output polygonized file exists -> Skipping')
        outGPDFPolygonized = None

    # release the raster before the vector post-processing
    image = None

    ###################################################
    # Post processing shapefile
    ###################################################

    outFilePostProc = os.path.splitext(outFilePolygonized)[0] + '_Buffer10m_WGS84.shp'
    if not os.path.exists(outFilePostProc) or overwrite:

        # favor the in-memory result, fall back to the file written earlier
        if outGPDFPolygonized is None: outGPDFPolygonized = gpd.read_file(outFilePolygonized)
        outGPDFPostProc = outGPDFPolygonized

        # do the 10m buffer
        outGPDFPostProc = polygon_buffer(outGPDFPostProc)

        # change crs
        outGPDFPostProc = polygon_epsg4326(outGPDFPostProc)

        # Write to final shapefile
        print('Writing to new SHP file ...')
        # keep only the output columns and reindex for brevity
        outGPDFPostProc = outGPDFPostProc.filter(['geometry', 'CODE_OBJ'])
        outGPDFPostProc.reset_index(drop=True, inplace=True)
        outGPDFPostProc.to_file(outFilePostProc)

    else:
        print('Output post-processed SHP file exists -> Skipping')


