#!/usr/bin/env python3

import datetime
import hmac
import inspect
import io
import math

import numpy
import pyproj
import shapely.geometry
import rasterio.io

import flask
from flask import Blueprint, request, jsonify, send_file
from flask_httpauth import HTTPBasicAuth

from werkzeug.exceptions import BadRequest, HTTPException

from .dataretrieval import get_datacube
from .algorithm import bbox_from_lonlat, predict, segment, polygonize, isolate_center_segment


# Blueprint that owns every API route in this module; its 'static' folder is
# where the served openapi.yaml lives.
blueprint = Blueprint(name='api', import_name=__name__, static_folder='static')

# HTTP Basic authentication for the protected routes (password check: verify_password).
auth = HTTPBasicAuth()


def _unrestricted_users():
    return {'admin': '4Dm1n!',
            'vito':  'V1t0!',
            'demo':  'demo123',
            'niab':  'hUgBfL2jR7'}


def _get_date_param(name, default=None):
    """Read a ``YYYY-MM-DD`` date from the request's query string.

    Args:
        name: query-string parameter name.
        default: value used when the parameter is absent; with the default
            of None the parameter is mandatory.

    Returns:
        The date re-rendered with the ``%Y-%m-%d`` format.

    Raises:
        BadRequest: if the parameter is missing or does not parse as a date.
    """
    date_format = '%Y-%m-%d'
    raw = request.args.get(name, default)

    if raw is None:
        raise BadRequest('Missing parameter [{}]'.format(name))

    try:
        parsed = datetime.datetime.strptime(raw, date_format)
    except ValueError:
        raise BadRequest('Invalid parameter value [{}={}]'.format(name, raw))
    else:
        # Round-tripping normalises e.g. '2020-1-5' to '2020-01-05'.
        return parsed.strftime(date_format)


def _get_float_param(name, default=None):
    """Read a finite floating-point value from the request's query string.

    Args:
        name: query-string parameter name.
        default: value used when the parameter is absent; with the default
            of None the parameter is mandatory.

    Returns:
        The parameter converted to a finite float.

    Raises:
        BadRequest: if the parameter is missing, is not a number, or is
            NaN/infinite.
    """
    val = request.args.get(name, default)

    if val is None:
        raise BadRequest('Missing parameter [{}]'.format(name))

    try:
        number = float(val)
    except ValueError:
        raise BadRequest('Invalid parameter value [{}={}]'.format(name, val)) from None

    # float() happily accepts 'nan', 'inf' and '-inf'. None of them is a
    # usable coordinate, so reject them here with a 400 instead of passing
    # them on to the geometry/data-retrieval code.
    if not math.isfinite(number):
        raise BadRequest('Invalid parameter value [{}={}]'.format(name, val))

    return number


def _to_memory_image(data, transform, crs):
    """Encode a raster array as a GeoTIFF in memory and return the file bytes.

    Args:
        data: array laid out as (bands, rows, cols); a 2-D (rows, cols)
            array is treated as a single band.
        transform: geotransform stored in the GeoTIFF.
        crs: coordinate reference system stored in the GeoTIFF.

    Returns:
        The complete GeoTIFF file as bytes.
    """
    bands = data.reshape((1,) + data.shape) if len(data.shape) == 2 else data

    profile = dict(driver='GTiff',
                   dtype=str(bands.dtype),
                   count=bands.shape[0],
                   height=bands.shape[1],
                   width=bands.shape[2],
                   crs=crs,
                   transform=transform)

    with rasterio.io.MemoryFile() as buffer:
        with buffer.open(**profile) as dataset:
            dataset.write(bands)
        # The bytes are read only after the dataset above has been closed.
        return buffer.read()


def _polygons_to_lonlat(polygons, crs_from):
    """Reproject polygons from *crs_from* to WGS84 (EPSG:4326).

    Args:
        polygons: iterable of shapely Polygons expressed in *crs_from*.
        crs_from: source CRS, in any form accepted by
            ``pyproj.Transformer.from_crs``.

    Returns:
        A shapely GeometryCollection with one reprojected Polygon per input
        polygon, in input order. Coordinates are (lon, lat) because the
        transformer is built with ``always_xy=True``.
    """
    crs_to = 'epsg:4326'

    tf = pyproj.Transformer.from_crs(crs_from,
                                     crs_to,
                                     always_xy=True)

    def _ring_to_lonlat(ring):
        # Reproject one ring and return its coordinates as an (N, 2) array.
        xy = tf.transform(*ring.xy)
        return numpy.array(xy).T

    out = []

    for p in polygons:
        # Interior rings (holes) are reprojected too, so they are preserved
        # in the output instead of being silently filled in.
        shell = _ring_to_lonlat(p.exterior)
        holes = [_ring_to_lonlat(ring) for ring in p.interiors]
        out.append(shapely.geometry.Polygon(shell, holes))

    return shapely.geometry.GeometryCollection(out)


@auth.verify_password
def verify_password(username, password):
    """Check HTTP Basic credentials against ``_unrestricted_users()``.

    Args:
        username: user name supplied by the client.
        password: password supplied by the client.

    Returns:
        True if *username* is known and *password* matches it, else False.
    """
    expected = _unrestricted_users().get(username)

    if expected is None or password is None:
        return False

    # Constant-time comparison, so response timing does not reveal how much
    # of a guessed password was correct. Compare bytes: compare_digest
    # raises TypeError for str arguments containing non-ASCII characters.
    return hmac.compare_digest(password.encode('utf-8'),
                               expected.encode('utf-8'))


@blueprint.errorhandler(Exception)
def handle_error(e):
    """Render any exception raised by this blueprint's views as JSON.

    HTTP exceptions keep their own status code; anything else becomes a 500.
    In both cases the body is ``{"error": str(e)}``.
    """
    if isinstance(e, HTTPException):
        return jsonify(error=str(e)), e.code

    # This handler turns the exception into a normal response, so without an
    # explicit log the traceback of an unexpected failure would be lost.
    flask.current_app.logger.error('Unhandled exception in API request',
                                   exc_info=e)

    # NOTE(review): str(e) of an internal exception is sent to the client and
    # may expose implementation details; consider a generic message.
    return jsonify(error=str(e)), 500


@blueprint.route('/openapi.yaml')
def openapi_yaml():
    """Serve the OpenAPI specification from the blueprint's static folder."""
    spec_file = 'openapi.yaml'
    return blueprint.send_static_file(spec_file)


def _get_request_params():
    """Parse the query parameters shared by every v1.0 endpoint.

    Returns:
        (lat, lon, start, end): floats for the location and ``YYYY-MM-DD``
        strings for the period.

    Raises:
        BadRequest: if a parameter is missing or invalid, or if start is
            after end.
    """
    lat = _get_float_param('lat')
    lon = _get_float_param('lon')
    start = _get_date_param('start')
    end = _get_date_param('end')

    # Both dates come back from _get_date_param in '%Y-%m-%d' form, so for
    # four-digit years string order equals chronological order.
    if start > end:
        raise BadRequest('Invalid period: start [{}] is after end [{}]'.format(start, end))

    return lat, lon, start, end


def _predict_for_request():
    """Retrieve the datacube for the current request and run the prediction.

    Returns:
        (pred, pred_tf, crs): the prediction image, its transform, and the CRS
        chosen by ``bbox_from_lonlat`` for the requested location.
    """
    lat, lon, start, end = _get_request_params()

    # Retrieve input data
    bbox, crs = bbox_from_lonlat(lon, lat)
    data = get_datacube(start, end, bbox, crs, maxcc=75)

    # Generate field prediction image
    pred, pred_tf = predict(data['B04'], data['B08'], data['SCL'],
                            data['profile']['transform'])

    return pred, pred_tf, crs


def _send_tiff(image, filename):
    """Send GeoTIFF bytes to the client as an ``image/tiff`` file download.

    Args:
        image: the encoded GeoTIFF file contents (bytes).
        filename: file name suggested to the client.

    Returns:
        The Flask response produced by ``send_file``.
    """
    # send_file renamed `attachment_filename` to `download_name` in Flask 2.0
    # and removed the old name in 2.2. Detect the supported keyword instead of
    # comparing `flask.__version__` as a string: that comparison is
    # lexicographic, and the attribute is deprecated since Flask 3.0.
    params = inspect.signature(send_file).parameters
    name_kwarg = 'download_name' if 'download_name' in params else 'attachment_filename'

    # as_attachment=True on every Flask version, so the download behaviour no
    # longer differs between the old and new keyword.
    # TODO: with feature detection, the `flask<2.2` pin in setup.py should no
    # longer be needed; revisit it together with Flask 1.1 support.
    return send_file(io.BytesIO(image),
                     mimetype='image/tiff',
                     as_attachment=True,
                     **{name_kwarg: filename})


@blueprint.route('/v1.0/geometry/', methods=['GET'], strict_slashes=False)
@auth.login_required
def v1_0_geometry():
    """Return a single field polygon (lon/lat GeoJSON) for the requested location.

    The segmentation is reduced with ``isolate_center_segment`` before being
    polygonized. The first resulting polygon is returned, or an empty Polygon
    if none was found.
    """
    pred, pred_tf, crs = _predict_for_request()

    # Generate field segmentation image, reduced to the centre segment
    seg, seg_tf = segment(pred, pred_tf)
    seg = isolate_center_segment(seg)

    # Convert to lon/lat polygons
    polys = _polygons_to_lonlat(polygonize(seg, seg_tf), crs)

    if polys.is_empty:
        poly = shapely.geometry.Polygon()
    else:
        poly = polys.geoms[0]

    return jsonify(shapely.geometry.mapping(poly))


@blueprint.route('/v1.0/geometries/', methods=['GET'], strict_slashes=False)
@auth.login_required
def v1_0_geometries():
    """Return all field polygons for the requested area as a GeoJSON GeometryCollection (lon/lat)."""
    pred, pred_tf, crs = _predict_for_request()

    # Generate field segmentation image
    seg, seg_tf = segment(pred, pred_tf)

    # Convert to lon/lat polygons
    polys = _polygons_to_lonlat(polygonize(seg, seg_tf), crs)

    return jsonify(shapely.geometry.mapping(polys))


@blueprint.route('/v1.0/prediction/', methods=['GET'], strict_slashes=False)
@auth.login_required
def v1_0_prediction():
    """Return the field prediction image as a GeoTIFF download."""
    pred, pred_tf, crs = _predict_for_request()

    image = _to_memory_image(pred, pred_tf, crs)

    return _send_tiff(image, 'prediction.tiff')


@blueprint.route('/v1.0/segmentation/', methods=['GET'], strict_slashes=False)
@auth.login_required
def v1_0_segmentation():
    """Return the field segmentation image as a GeoTIFF download."""
    pred, pred_tf, crs = _predict_for_request()

    # Generate field segmentation image
    seg, seg_tf = segment(pred, pred_tf)

    image = _to_memory_image(seg, seg_tf, crs)

    return _send_tiff(image, 'segmentation.tiff')
