'''
Created on Jan 17, 2021

@author: banyait
'''
import sys
import glob
import os
from openeo.rest.conversions import _load_DataArray_from_NetCDF
from pathlib import Path
import json
from matplotlib import pyplot
import datetime
from shapely.geometry.base import geom_from_wkt
import contextily as ctx
import pyproj
from shapely.ops import transform
import matplotlib
import shapely.geometry
from functools import partial


# Name of the tiling-grid parameter file shipped next to this script; it maps a
# tile ID to a record whose element [0] is used as an EPSG code and element [2]
# as a WKT footprint of the tile.
TILPAR_FILE = "S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.json"


def month_label(month):
    """Return the ``YYYY/MM`` label of a ``numpy.datetime64`` month key.

    The value is converted through ``datetime64[s]`` so the label is computed
    in UTC. ``datetime.datetime.fromtimestamp`` would apply the host's local
    timezone and could shift a first-of-month timestamp into the previous
    month on machines west of UTC.
    """
    return month.astype('datetime64[s]').item().strftime('%Y/%m')


def bbox_feature(xmin, ymin, xmax, ymax, numlayers):
    """Build a GeoJSON Polygon ``Feature`` for an axis-aligned bounding box.

    The ring is closed (first vertex repeated at the end) and the number of
    time layers of the sample is stored under ``properties.numlayers``.
    """
    return {
        "type": "Feature",
        "geometry": {
            "type": "Polygon",
            "coordinates": [[
                [xmin, ymin],
                [xmax, ymin],
                [xmax, ymax],
                [xmin, ymax],
                [xmin, ymin],
            ]],
        },
        "properties": {
            "numlayers": numlayers,
        },
    }


def gather_samples(files):
    """Collect per-month valid-pixel statistics and extents of the samples.

    Returns ``(stats, features)`` where

    * ``stats`` maps a ``numpy.datetime64`` (first day of the month) to a list
      with one entry per sample layer in that month: the percentage of valid
      values (``< 255``) among all values of the layer, truncated to int;
    * ``features`` is a list of GeoJSON Features, one bounding box per file,
      in the coordinates of the sample's x/y axes.
    """
    stats = {}
    features = []
    for sample_file in files:
        data = _load_DataArray_from_NetCDF(sample_file)
        # Collapse every timestamp onto the first day of its month.
        months = data.t.to_pandas().dt.normalize().apply(lambda x: x.replace(day=1)).to_numpy()
        # 255 is treated as nodata: mask it out, then count what remains.
        counts = data.where(data < 255).count(dim=('x', 'y', 'bands')).values
        # The denominator must cover the same dimensions that were counted,
        # otherwise multi-band samples would report more than 100%.
        npts = float(data.sizes['x'] * data.sizes['y'] * data.sizes['bands'])
        for month, count in zip(months, counts):
            stats.setdefault(month, []).append(int(100. * float(count) / npts))
        features.append(bbox_feature(
            float(data.x.min()), float(data.y.min()),
            float(data.x.max()), float(data.y.max()),
            len(data.t.values),
        ))
    return stats, features


def plot_histograms(s, path):
    """Plot the monthly layer counts and per-month valid-pixel histograms.

    ``s`` is the month-sorted list of ``(month, percentages)`` pairs. The top
    panel is a bar chart of layers per month; each following panel is the
    histogram of valid-pixel percentages for one month. Saved to ``path``.
    """
    nrows = len(s) + 1
    fig = pyplot.figure(figsize=(3, 2 * nrows), dpi=200)
    gs = pyplot.GridSpec(nrows, 1, wspace=0., hspace=0.25, top=1. - 0.01, bottom=0.01,
                         left=0.165, right=1. - 0.006875)

    ax0 = pyplot.subplot(gs[0, 0])
    # NOTE(review): operator precedence makes this 200/len(s) + 1 (x data
    # units, presumably days on this date axis); confirm whether
    # 200/(len(s)+1) was intended.
    ax0.bar([month for month, _ in s], [len(percents) for _, percents in s], width=200/len(s)+1)
    ax0.text(-0.17, 0.5, 'Monthly distribution', va="center", ha="center", rotation=90,
             transform=ax0.transAxes)
    ax0.tick_params(axis='x', rotation=30, labelsize=5)
    ax0.tick_params(axis='y', labelsize=5)

    for row, (month, percents) in enumerate(s, start=1):
        ax = pyplot.subplot(gs[row, 0])
        ax.hist(percents, bins=list(range(0, 101)))
        ax.set_xlim(0, 100)
        ax.text(-0.17, 0.5, month_label(month), va="center", ha="center", rotation=90,
                transform=ax.transAxes)
        ax.ticklabel_format(axis='both', style='plain', scilimits=(0, 1000000000), useOffset=False)
        ax.tick_params(axis='both', labelsize=5)

    pyplot.savefig(path)
    pyplot.close(fig)


def write_summary(s, nfiles, path):
    """Write the textual per-month summary of ``s`` for ``nfiles`` samples to ``path``."""
    with open(path, 'w') as f:
        f.write('Found {} files\n'.format(str(nfiles)))
        f.write('Month numsamplelayerss mincountpercent maxcountpercent\n')
        for month, percents in s:
            f.write("  {} {:>8} {:>3} {:>3}\n".format(
                month_label(month), str(len(percents)), str(min(percents)), str(max(percents))))
        total_layers = sum(len(percents) for _, percents in s)
        f.write('Average number of layers in a sample: {}\n'.format(
            str(float(total_layers) / float(nfiles))))


def plot_inspection(features, tileID, path):
    """Draw the sample footprints over a grayscale basemap of the tile.

    Each footprint is filled with a RdYlGn color scaled by its ``numlayers``
    (15 layers maps to the top of the colormap). Footprints and the tile
    outline are reprojected from the tile's EPSG code to EPSG:3857 to match
    the web-mercator basemap. Saved to ``path``.
    """
    tilepar = Path(os.path.dirname(os.path.realpath(__file__)), TILPAR_FILE)
    with open(tilepar) as f:
        tileinfo = json.load(f)[tileID]
    epsg = "EPSG:" + tileinfo[0]

    # Build the transformation once; always_xy keeps (easting, northing) order
    # like the legacy Proj(init=...) objects did.
    to3857 = pyproj.Transformer.from_crs(epsg, "EPSG:3857", always_xy=True).transform

    tile_geom = geom_from_wkt(tileinfo[2])
    tile = transform(to3857, tile_geom.buffer(0))
    # Map extent: tile footprint grown by 15000 (tile CRS units) on each side.
    extt = transform(to3857, tile_geom.buffer(15000)).bounds

    pyplot.figure(figsize=(10, 10))
    ax = pyplot.gca()
    ax.set_xlim([extt[0], extt[2]])
    ax.set_ylim([extt[1], extt[3]])

    pyplot.plot(*tile.exterior.xy, c="white")
    cmap = pyplot.get_cmap('RdYlGn')

    img, ext = ctx.bounds2img(*extt, zoom=11, source=ctx.providers.Esri.WorldImagery)
    # Luma-weighted grayscale, normalized to [0, 1]; vmin=-1 keeps it dim.
    gray = 0.2989 * img[:, :, 0] + 0.5870 * img[:, :, 1] + 0.1140 * img[:, :, 2]
    gray = gray - gray.min()
    gray = gray / gray.max()
    pyplot.imshow(gray, cmap='gray', vmin=-1, vmax=1, extent=ext)

    total = len(features)
    for i, feature in enumerate(features, start=1):
        if i % 1000 == 0:
            print(str(datetime.datetime.now()) + ": " + str(i) + "/" + str(total))
        color = matplotlib.colors.to_hex(cmap(float(feature['properties']['numlayers']) / 15.))
        polygon = shapely.geometry.Polygon(*feature["geometry"]["coordinates"]).buffer(0)
        xy = transform(to3857, polygon).exterior.xy
        # Drop the closing vertex; fill() closes the polygon itself.
        pyplot.fill('x', 'y', color, data={'x': xy[0][:-1], 'y': xy[1][:-1]}, alpha=0.2)

    pyplot.axis('off')
    pyplot.tight_layout()
    pyplot.savefig(path, dpi=400)
    pyplot.close()


def main(argv):
    """Summarize the ``sample_*`` NetCDF files of ``argv[1]`` into ``argv[2]``.

    Writes ``bboxes.json``, ``distribution.png``, ``resultsummary.txt`` and
    ``numlayers.png`` into the output directory.
    """
    print('Start: ' + str(datetime.datetime.now()))
    if len(argv) < 3:
        sys.exit('Usage: {} <datadir> <outdir>'.format(argv[0] if argv else 'script'))
    datadir, outdir = argv[1], argv[2]

    files = sorted(glob.glob(str(Path(datadir, 'sample_*'))))
    if not files:
        sys.exit("No 'sample_*' files found in {}".format(datadir))
    # File names look like sample_<tileID>_...; split the name, not the full
    # path, so underscores in datadir do not break the lookup.
    tileID = Path(files[0]).name.split('_')[1]

    stats, features = gather_samples(files)

    geoms = {
        "type": "FeatureCollection",
        "features": features,
    }
    with open(Path(outdir, 'bboxes.json'), 'w') as f:
        json.dump(geoms, f, indent=2)

    s = sorted(stats.items(), key=lambda item: item[0])
    plot_histograms(s, Path(outdir, 'distribution.png'))
    write_summary(s, len(files), Path(outdir, 'resultsummary.txt'))
    plot_inspection(features, tileID, Path(outdir, 'numlayers.png'))

    print("Finished: " + str(datetime.datetime.now()))


if __name__ == '__main__':
    main(sys.argv)