#!/var/jenkins/workspace/opyspark-integrationtests_master/venv3.5/bin/python3
# -*- coding: utf-8 -*-
import argparse
from osgeo import gdal
from osgeo.gdalconst import *
import numpy
from pandas import DatetimeIndex
import geopandas
import os
import math
# import pandas
# import torch
# import torchvision
# import tensorflow
# import tensorboard
from pprint import pprint
from openeo_udf.api.base import RasterCollectionTile, UdfData, SpatialExtent, FeatureCollectionTile

# Module metadata (license, authorship and contact information)
__license__ = "Apache License, Version 2.0"
__author__     = "Soeren Gebbert"
__copyright__  = "Copyright 2018, Soeren Gebbert"
__maintainer__ = "Soeren Gebbert"
__email__      = "soerengebbert@googlemail.com"

# Switch the GDAL Python bindings to exception mode for the whole module,
# so GDAL errors surface as Python exceptions.
gdal.UseExceptions()
DESCR="""This program reads a list of single- or multi-band GeoTiff files and vector files
and applies a user defined function (UDF) on them. 
The GeoTiff files must be provided as comma separated list, as well as the band names. 
The vector files must be provided as comma separated list of files as well. The UDF
must be accessible on the file system. The computed results are single- or multi-band GeoTiff files 
in case of raster output and geopackage vector files in case of vector output
that are written into a specific output directory. Raster and vector files can be specified together.
However, all provided files must have the same projection.

Examples:

    The following command computes the NDVI on a raster
    image series of three multi-band tiff files. Two bands are provided with the names RED and NIR for
    the UDF. The three resulting single-band GeoTiff files are written to the /tmp directory.
    
        execute_udf -r data/red_nir_1987.tif,data/red_nir_2000.tif,data/red_nir_2002.tif \\
                    -b RED,NIR \\
                    -u src/openeo_udf/functions/raster_collections_ndvi.py

    This command computes the sum of the raster series for each band. A single raster image
    with two bands is written as GeoTiff file to the directory /tmp.
    
        execute_udf -r data/red_nir_1987.tif,data/red_nir_2000.tif,data/red_nir_2002.tif \\
                    -b RED,NIR \\
                    -u src/openeo_udf/functions/raster_collections_reduce_time_sum.py
    
    
    This command reads a feature collection stored in a geopackage file
    and applies the UDF buffer function. The result is a new geopackage
    that contains the buffers written in directory /tmp:
    
        execute_udf -v data/sampling_points.gpkg -u src/openeo_udf/functions/feature_collections_buffer.py 

    This command reads a series of raster GeoTiff files and a feature collection stored in a geopackage file
    and applies the UDF sampling function. The result is a new geopackage
    that contains the sampled attributes written in directory /tmp:
    
        execute_udf -r data/red_nir_1987.tif,data/red_nir_2000.tif,data/red_nir_2002.tif \\
                    -b RED,NIR \\
                    -v data/sampling_points.gpkg \\
                    -u src/openeo_udf/functions/raster_collections_sampling.py 

"""


def main():
    """Command line entry point: load the inputs, run the UDF, write the results.

    The function performs the following steps:

    1. Parse the command line arguments.
    2. Check that all raster files have the same number of bands and the
       same datatype (the datatype is checked on the first band only).
    3. Build one RasterCollectionTile per band, stacking the provided raster
       files along the first axis of a 3D array, and one
       FeatureCollectionTile per vector file.
    4. Execute the UDF source code with ``exec``. The UDF runs with access to
       the local variables of this function; results are read back from the
       local ``data`` object afterwards.
    5. Write raster tiles as GeoTiff files (one file per slice along the
       first axis, one band per tile) and feature tiles as geopackage files
       into the output directory.

    Raises:
        Exception: If neither raster nor vector files are given, if the
            raster files differ in band count, datatype, extent or
            projection, or if the number of band names does not match the
            number of raster bands.
    """

    parser = argparse.ArgumentParser(description=DESCR, formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument("-r", "--raster_files", type=str, required=False,
                        help="Comma separated list of raster files. If several raster "
                             "files are provided, then each raster file must have "
                             "the same number of bands.")

    parser.add_argument("-v", "--vector_files", type=str, required=False,
                        help="Comma separated list of vector files. Each vector file will be converted into a "
                             "vector collection tile.")

    parser.add_argument("-t", "--raster_time_stamps", type=str, required=False,
                        help="A comma separated list of time stamps, that must have the same number of entries "
                             "as the list of raster files.")

    parser.add_argument("-b", "--band_names", type=str, required=False,
                        help="A comma separated list of band names.")

    parser.add_argument("-o", "--raster_output_dir", type=str, required=False, default="/tmp",
                        help="The output directory to store the computed results.")

    parser.add_argument("-u", "--path_to_udf", type=str, required=True,
                        help="The UDF file to execute.")

    args = parser.parse_args()

    raster_names = []
    vector_names = []
    # NOTE: the time stamps are parsed but not attached to the tiles below.
    raster_time_stamps = []
    band_names = []

    if args.raster_files:
        raster_names = args.raster_files.split(",")
    if args.vector_files:
        vector_names = args.vector_files.split(",")
    if args.raster_time_stamps:
        raster_time_stamps = args.raster_time_stamps.split(",")
    out_dir = args.raster_output_dir
    udf_path = args.path_to_udf
    if args.band_names:
        band_names = args.band_names.split(",")

    if not raster_names and not vector_names:
        raise Exception("Raster files and/or vector files must be specified.")

    # Read the UDF source; the context manager closes the file handle again
    with open(udf_path, "r") as udf_file:
        code = udf_file.read()

    data_type = None
    proj = None

    # Make sure the correct number of raster bands is used
    num_bands = 0
    first = True
    for raster_name in raster_names:
        ds = gdal.Open(raster_name, GA_ReadOnly)

        # A zero sized window is read to get the array datatype of band 1
        b = ds.GetRasterBand(1)
        array = b.ReadAsArray(0, 0, 0, 0)

        if first:
            # The first file defines the reference band count and datatype
            num_bands = ds.RasterCount
            data_type = array.dtype
            first = False
        else:
            if num_bands != ds.RasterCount:
                raise Exception("The number of raster bands must be equal for all provided images")

            if data_type != array.dtype:
                raise Exception("The raster images must have the same datatype")

    if num_bands != len(band_names):
        raise Exception("Please provide the same number of band names a bands in the raster image files")

    num_raster = len(raster_names)
    raster_datasets = []
    vector_datasets = []

    # Create one raster collection tile per band. The first axis of the
    # 3D array indexes the raster files in the order they were provided.
    for band in range(num_bands):
        first = True
        extent = None
        array3d = None
        projection = None

        for file_index, raster_name in enumerate(raster_names):
            ds = gdal.Open(raster_name, GA_ReadOnly)

            if first:
                # The first file defines the reference extent and projection
                extent = get_extent_from_dataset(ds)
                rows = ds.RasterYSize
                cols = ds.RasterXSize
                projection = ds.GetProjection()
                # Create the 3D numpy array
                array3d = numpy.zeros(shape=[num_raster, rows, cols], dtype=data_type)
                first = False
            else:
                new_extent = get_extent_from_dataset(ds)
                if new_extent.as_polygon() != extent.as_polygon():
                    raise Exception("The extents of the provided raster files must be equal")

                if projection != ds.GetProjection():
                    raise Exception("The provided raster files must have equal projection")

            # GDAL band numbers start at 1.
            # Read the whole raster array as single 2D tile
            b = ds.GetRasterBand(band + 1)
            array3d[file_index] = b.ReadAsArray()

        rct = RasterCollectionTile(id=band_names[band], data=array3d, extent=extent)
        raster_datasets.append(rct)

        proj = {"WKT": projection}

    for vector_name in vector_names:
        gdf = geopandas.read_file(filename=vector_name)

        # The file name up to the first dot is used as tile id
        vector_id = os.path.basename(vector_name).split(".")[0]
        fct = FeatureCollectionTile(id=vector_id, data=gdf)
        vector_datasets.append(fct)

    data = UdfData(proj=proj, raster_collection_tiles=raster_datasets, feature_collection_tiles=vector_datasets)

    #################################################################
    ########### RUN THE UDF CODE ####################################
    # NOTE(security): exec() runs the UDF source with the full privileges of
    # this process. Only run UDF files from trusted sources.
    # The UDF is executed in the scope of this function, so the local
    # variable name ``data`` must not be renamed.
    exec(code)
    #################################################################
    #################################################################

    # Write the output
    fileformat = "GTiff"
    driver = gdal.GetDriverByName(fileformat)

    raster_tiles = data.get_raster_collection_tiles()

    if raster_tiles:
        # The first tile defines the number of output files, the raster
        # size, the datatype, the geo-transformation and the file names.
        # It is kept in its own variable so that the loop over all tiles
        # below cannot change which tile is used for naming.
        first_tile = raster_tiles[0]
        num_slices, rows, cols = first_tile.data.shape
        if first_tile.data.dtype == "float64":
            gdal_dtype = gdal.GDT_Float64
        else:
            # Every other datatype is written as 32 bit float
            gdal_dtype = gdal.GDT_Float32

        transform = extent_to_gdal_transform(first_tile.extent)

        # One GeoTiff per slice, each tile becomes one band of the file
        for slice_index in range(num_slices):
            output_path = os.path.join(out_dir, first_tile.id + "_%i.tif" % slice_index)
            print(output_path)

            dst_ds = driver.Create(output_path, xsize=cols, ysize=rows,
                                   bands=len(raster_tiles), eType=gdal_dtype)
            dst_ds.SetGeoTransform(transform)
            dst_ds.SetProjection(data.proj["WKT"])
            for band_number, band_tile in enumerate(raster_tiles, start=1):
                dst_ds.GetRasterBand(band_number).WriteArray(band_tile.data[slice_index])
            # Dropping the reference makes GDAL write and close the file
            dst_ds = None

    feature_tiles = data.get_feature_collection_tiles()

    if feature_tiles:
        # One geopackage per feature collection tile, named after the tile id
        for feature_tile in feature_tiles:
            output_path = os.path.join(out_dir, feature_tile.id + ".gpkg")
            feature_tile.data.to_file(filename=output_path, driver='GPKG')


def get_extent_from_dataset(ds):
    """Create a SpatialExtent from the geo-transformation of a GDAL dataset.

    The following entries of the GDAL geo-transformation are used:

        adfGeoTransform[0] /* top left x */
        adfGeoTransform[1] /* w-e pixel resolution */
        adfGeoTransform[3] /* top left y */
        adfGeoTransform[5] /* n-s pixel resolution */

    The rotation entries (index 2 and 4) are not evaluated.

    :param ds: The dataset that provides GetGeoTransform(), RasterYSize
               and RasterXSize
    :return: A SpatialExtent with the borders of the dataset and the
             absolute pixel resolution as width and height
    """

    geo = ds.GetGeoTransform()
    num_rows = ds.RasterYSize
    num_cols = ds.RasterXSize

    left, pixel_width = geo[0], geo[1]
    top, pixel_height = geo[3], geo[5]

    # The opposite borders are the origin plus the pixel count times the
    # signed resolution
    bottom = top + (num_rows * pixel_height)
    right = left + (num_cols * pixel_width)

    return SpatialExtent(top=top, bottom=bottom, left=left,
                         right=right, width=abs(pixel_width), height=abs(pixel_height))


def extent_to_gdal_transform(extent):
    """Convert a spatial extent into a GDAL geo-transformation.

    Layout of the returned geo-transformation:

        adfGeoTransform[0] /* top left x */
        adfGeoTransform[1] /* w-e pixel resolution */
        adfGeoTransform[2] /* rotation, 0 if image is "north up" */
        adfGeoTransform[3] /* top left y */
        adfGeoTransform[4] /* rotation, 0 if image is "north up" */
        adfGeoTransform[5] /* n-s pixel resolution */

    :param extent: An object with the attributes top, bottom, left, right,
                   width (pixel width) and height (pixel height)
    :return: A list with the six entries of the geo-transformation; the
             rotation entries are always 0
    """

    # The resolutions get the sign of the direction in which the coordinates
    # run from the top left corner (bottom - top and right - left).
    pixel_height = math.copysign(extent.height, extent.bottom - extent.top)
    pixel_width = math.copysign(extent.width, extent.right - extent.left)

    # The x origin (left border) belongs at index 0 and the y origin
    # (top border) at index 3.
    transform = [extent.left, pixel_width, 0, extent.top, 0, pixel_height]
    return transform



# Run the command line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
