import json
import logging
import re
from typing import Dict, Optional, Tuple

import geopandas as gpd
import numpy as np
import pyproj
import requests
import xarray as xr
from openeo.udf.xarraydatacube import XarrayDataCube
from shapely.geometry import MultiPoint, shape
from shapely.geometry.polygon import Polygon, Point
from shapely.ops import transform

# Module-level logger shared by every helper in this UDF
_log = logging.getLogger("yieldpotentialmap_udf")

# Label of the single band in the cube built by generate_relative_diff; create_map_df selects this band again
BAND_NAME = 'relative_diff'


def get_mean_and_median(data: xr.DataArray, date: np.datetime64, band: str, mask: list) -> Tuple:
    """
    Compute the mean and the median pixel value of one band on one date, ignoring masked pixels.

    :param data: xarray DataArray with (at least) a ``t`` and a ``bands`` dimension
    :param date: Label on the ``t`` dimension to summarise
    :param band: Label on the ``bands`` dimension to summarise
    :param mask: Pixel values treated as no-data; they are left out of both statistics
    :return: Tuple ``(mean, median)`` converted to plain Python values (may be NaN when no pixel is left)
    """
    band_on_date = data.sel(t=date).sel(bands=band).astype('float32')
    valid_pixels = mask_values(data=band_on_date, values=mask, drop=True)
    mean_value = valid_pixels.mean(skipna=True).values.tolist()
    median_value = valid_pixels.median(skipna=True).values.tolist()
    return mean_value, median_value


def filter_dates(data: xr.DataArray, band: str, threshold: float, exclude_months: list, mask: list) -> dict:
    """
    Select the dates of a DataArray that are usable for the yield potential map.

    A date is kept when all of the following hold:
        * its month is not listed in ``exclude_months``
        * at least one pixel remains after masking (the mean is not NaN)
        * the MEAN pixel value is not below ``threshold`` (skipped when ``threshold`` is None)
    :param data: xarray DataArray to check (needs ``t`` and ``bands`` dimensions)
    :param band: Name of the band that contains the value to check
    :param threshold: Minimum mean value for a date, or None to disable the threshold (0 is a valid threshold)
    :param exclude_months: Month numbers (1-12) whose dates are discarded; None is treated as no exclusions
    :param mask: List of values to mask when calculating the statistics
    :return: A dictionary mapping ``str(date)`` of every kept date to the MEDIAN pixel value of that date
    """
    # Build the lookup once instead of re-creating a list for every date
    excluded_months = set(exclude_months or [])
    result = {}
    for date in data.t.values:
        _log.debug(f'Checking if {date} needs to be filtered')

        # Months since the epoch modulo 12 gives a 0-based month; +1 makes it 1-12
        month = date.astype('datetime64[M]').astype(int) % 12 + 1
        if month in excluded_months:
            _log.debug(f'Date {date} is in excluded months')
            continue

        mean, median = get_mean_and_median(data=data, date=date, band=band, mask=mask)
        if np.isnan(mean):
            _log.debug(f'Date {date} has no valid pixels')
            continue
        # Explicit None check: a threshold of 0 must still be applied
        if threshold is not None and mean < threshold:
            _log.debug(f'Date {date} is below threshold {threshold}')
            continue
        result[str(date)] = median
    return result


def categorize_map(cube: xr.DataArray) -> xr.DataArray:
    """
    Map the relative-difference values of a DataArray onto discrete categories.

    Values are binned with ``np.digitize`` (``right=False``) against the edges
    [-3e10, 92.499999, 97.499999, 102.499999, 107.499999, 3e10, NaN], which gives:
        1: -3e10 <= v < 92.499999
        2: 92.499999 <= v < 97.499999
        3: 97.499999 <= v < 102.499999
        4: 102.499999 <= v < 107.499999
        5: 107.499999 <= v < 3e10
    (0 and 6 only occur for values beyond +/-3e10). NaN pixels stay NaN in the result.
    :param cube: xarray DataArray to categorize; it is NOT modified
    :return: A copy of ``cube`` (same dims/coords) holding float32 category numbers
    """
    classification_values = [-3e10, 92.499999, 97.499999, 102.499999, 107.499999,
                             3e10, np.nan]
    categorized = np.digitize(cube.values, classification_values).astype('float32')
    # np.digitize relies on searchsorted, which orders NaN after every number, so NaN pixels land on the index
    # after the trailing NaN edge; turn those back into NaN.
    categorized[categorized == len(classification_values)] = np.nan
    # Return a new array instead of overwriting the caller's data in place
    return cube.copy(data=categorized)


def generate_map(cube: xr.DataArray, band: str, filtered_results: dict, mask: list,
                 raw: bool = True) -> xr.DataArray:
    """
    Generate the yield potential map for the selected dates.

    :param cube: xarray DataArray containing the pixel values
    :param band: Name of the band that contains the value to check
    :param filtered_results: Dictionary mapping the dates to use onto their respective median values
    :param mask: List of values to mask when generating the map
    :param raw: True to return the relative-difference values, False to return them binned into categories
    :return: xarray DataArray representing the yield potential map
    """
    relative_map = generate_relative_diff(cube=cube, band=band, filtered_results=filtered_results, mask=mask)
    return relative_map if raw else categorize_map(cube=relative_map)


def mask_values(data: xr.DataArray, values: list, drop: bool) -> xr.DataArray:
    """
    Replace every occurrence of the given values in a DataArray by NaN.

    The values are masked one after the other with ``DataArray.where``, and ``drop`` is forwarded to each call.
    :param data: xarray DataArray to mask
    :param values: Values that should be treated as no-data
    :param drop: Forwarded to ``DataArray.where(..., drop=drop)``
    :return: A new DataArray in which the masked values are NaN (the input object is not modified)
    """
    masked = data
    for no_data in values:
        masked = masked.where(masked != no_data, drop=drop)
    return masked


def generate_relative_diff(cube: xr.DataArray, band: str, filtered_results: dict, mask: list) -> xr.DataArray:
    """
    Compute, per pixel, how far it deviates from the median of each selected date, averaged over those dates.

    For every selected date d: ``(pixel - median_d) / median_d * 100``. These percentages are averaged over time
    (ignoring NaN) and shifted by +100, so 100 means "equal to the median".
    :param cube: xarray DataArray containing the pixel values (``t`` and ``bands`` dimensions)
    :param band: Name of the band that contains the value to check
    :param filtered_results: Dictionary mapping the dates to use (as strings) onto their respective median values
    :param mask: List of values to mask when calculating the differences
    :return: DataArray with a single time step (1 January of the year of the first selected date) and a single
             band named ``BAND_NAME``
    """
    selected_dates = [np.datetime64(key) for key in filtered_results]
    first_year = selected_dates[0].astype('datetime64[Y]').astype(int) + 1970
    median_per_date = xr.DataArray(list(filtered_results.values()), dims=('t'), coords={'t': selected_dates})

    keep_date = cube.t.isin(selected_dates)
    pixels = cube.where(keep_date, drop=True).sel(bands=band).astype('float32')
    pixels = mask_values(data=pixels, values=mask, drop=False)

    # Percentage deviation from the median of the pixel's own date, then the average over time around 100
    deviation_pct = (pixels - median_per_date) / median_per_date * 100
    averaged = deviation_pct.mean(dim=['t'], skipna=True) + 100

    # Reuse the spatial layout/coordinates of the first time step and band as the output template
    template = cube.isel(t=0).isel(bands=0).copy()
    template.values = averaged
    with_band = template.expand_dims(dim='bands', axis=0).assign_coords(bands=[BAND_NAME])
    return with_band.expand_dims(dim='t', axis=0).assign_coords(t=[np.datetime64(f'{first_year}-01-01T00:00:00')])


def get_field_bounds(cube: xr.DataArray, band: str, mask: list, date: str) -> object:
    """
    Approximate the field outline as the convex hull of all pixels that hold a (non-masked) value.

    :param cube: xarray DataArray containing the pixel values
    :param band: Name of the band that contains the value to check
    :param mask: List of values that mark pixels outside the field
    :param date: Date (parsable by ``np.datetime64``) for which there should be data available
    :return: shapely geometry in the cube's x/y coordinates; normally a Polygon, but a Point/LineString when only
             one or collinear pixels are valid, and an empty geometry when no pixel is valid
    """
    day = cube.sel(t=np.datetime64(date)).sel(bands=band).astype('float32')
    field = mask_values(data=day, values=mask, drop=False)
    pixels = field.to_dataframe(name='pixel_value').reset_index()
    # Drop the masked (NaN) pixels before building any geometry
    pixels = pixels[pixels['pixel_value'].notna()]
    # MultiPoint(...).convex_hull gives the same hull as the union of the points, without the deprecated
    # GeoDataFrame.unary_union and without materialising one geometry object per pixel
    return MultiPoint(list(zip(pixels['x'], pixels['y']))).convex_hull


def feature_in_field(feature: object, field: Polygon, min_overlap: float = 0.9) -> bool:
    """
    Check whether a (GeoJSON-like) feature coincides with a given field.

    The feature matches when at least ``min_overlap`` of the field's area lies inside the feature geometry.
    A field without area (a Point/LineString hull) matches when the feature covers it completely.
    :param feature: Mapping with a ``'geometry'`` entry accepted by ``shapely.geometry.shape``
    :param field: Field geometry, in the same CRS as the feature
    :param min_overlap: Minimum fraction (0-1) of the field area that must be covered by the feature
    :return: True when the feature is considered to be the field
    """
    feature_geom = shape(feature['geometry'])
    if field.area == 0:
        # An area ratio is undefined for degenerate fields; avoid ZeroDivisionError
        return feature_geom.covers(field)
    overlap = field.intersection(feature_geom).area / field.area
    return overlap >= min_overlap


def reproject_field(field: Polygon, from_proj: str, to_proj: str) -> Polygon:
    """
    Reproject a geometry from one coordinate reference system to another.

    ``always_xy=True`` makes pyproj use (x, y) / (lon, lat) axis order, matching shapely coordinates.
    :param field: Geometry to reproject
    :param from_proj: CRS definition accepted by ``pyproj.CRS`` for the input (e.g. ``'EPSG:32631'``)
    :param to_proj: CRS definition for the output
    :return: The reprojected geometry
    """
    source_crs, target_crs = pyproj.CRS(from_proj), pyproj.CRS(to_proj)
    transformer = pyproj.Transformer.from_crs(source_crs, target_crs, always_xy=True)
    return transform(transformer.transform, field)


def check_field_flanders(year: int, field: Polygon, field_epsg: str, croptype_blacklist: Optional[list] = None,
                         timeout: float = 120) -> tuple:
    """
    Check the field conditions in Flanders. Based on the DLV parcel data (WFS layer ``AGRIPRC<year>``), the
    following conditions are checked:
        1. Is the field located in Flanders? (a parcel covers at least 90% of the field)
        2. Has the field multiple matches in Flanders? If so, the field is considered to be split
        3. Has the field a croptype (``MAINCROPGROUP``) that has been blacklisted?

    :param year: Year for which to check the conditions
    :param field: Field to check
    :param field_epsg: The projection system of the field
    :param croptype_blacklist: List of croptypes that should be blacklisted (None means no blacklist)
    :param timeout: Seconds to wait for the DLV service before giving up
    :return: Tuple of a boolean, indicating if the field satisfies all conditions, and a corresponding message
    """
    _log.debug('Checking field in Flanders')
    blacklist = croptype_blacklist or []

    field_latlon = reproject_field(field=field, from_proj=field_epsg, to_proj='EPSG:4326')
    bounds = field_latlon.bounds
    url = 'https://geoservices.landbouwvlaanderen.be/PUBLIC/wfs?service=WFS&version=1.1.0&request=GetFeature&typeName=PUBLIC:AGRIPRC{}&styles=&srsName=EPSG:4326&bbox={},{},{},{},EPSG:4326&srs=EPSG:4326&outputformat=JSON'.format(
        str(year), str(bounds[0]), str(bounds[1]), str(bounds[2]), str(bounds[3]))

    # Without a timeout requests can block forever; transport errors are reported like a non-200 answer so the
    # caller can still try another region.
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        return False, f'Could not execute request to DLV: {err}'
    if response.status_code != 200:
        return False, f'Could not execute request to DLV: {response.status_code} - {response.text}'

    features = [x for x in response.json()['features'] if feature_in_field(feature=x, field=field_latlon)]
    if not features:
        return False, 'No features were found in the DLV response'
    if len(features) > 1:
        return False, f'Found {len(features)} matches in the DLV response, the field was probably split up in {year}'

    croptype = features[0]['properties']['MAINCROPGROUP']
    if croptype in blacklist:
        return False, f'Croptype {croptype} is blacklisted'

    return True, ''


def check_field_walloon(year: int, field: Polygon, field_epsg: str, croptype_blacklist: Optional[list] = None,
                        timeout: float = 120) -> tuple:
    """
    Check the field conditions in Walloon. Based on the Walloon GeoServices parcel data (SIGEC layer for ``year``),
    the following conditions are checked:
        1. Is the field located in Walloon? (a parcel covers at least 90% of the field)
        2. Has the field multiple matches in Walloon? If so, the field is considered to be split
        3. Has the field a croptype (``CULT_COD``) that has been blacklisted?

    :param year: Year for which to check the conditions
    :param field: Field to check
    :param field_epsg: The projection system of the field
    :param croptype_blacklist: List of croptypes that should be blacklisted (None means no blacklist)
    :param timeout: Seconds to wait for the Walloon service before giving up
    :return: Tuple of a boolean, indicating if the field satisfies all conditions, and a corresponding message
    """
    _log.debug('Checking field in Walloon')
    blacklist = croptype_blacklist or []

    field_lambert = reproject_field(field=field, from_proj=field_epsg, to_proj='EPSG:31370')
    bounds = field_lambert.bounds
    url = 'https://geoservices.wallonie.be/arcgis/services/WFS/SIGEC_PARC_AGRI_ANON/MapServer/WFSServer?service=WFS&request=GetFeature&version=2.0.0&typename=SIGEC_PARC_AGRI_ANON:Parcelles_agricoles___{}_&srsname=EPSG:31370&outputFormat=geojson&count=1000&bbox={},{},{},{},EPSG:31370'.format(
        str(year), str(bounds[0]), str(bounds[1]), str(bounds[2]), str(bounds[3]))

    # Without a timeout requests can block forever; transport errors are reported like a non-200 answer so the
    # caller gets a message instead of an aborted check.
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        return False, f'Could not execute request to Walloon GeoServices: {err}'
    if response.status_code != 200:
        return False, f'Could not execute request to Walloon GeoServices: {response.status_code} - {response.text}'

    # Intervention required as the Walloon response is not valid JSON: the "SURF_HA" property apparently uses a
    # decimal comma (e.g. `"SURF_HA": 1,23,`), so that property is stripped before parsing. Raw string so that
    # `\d` is a regex escape instead of an invalid string escape.
    response_json = json.loads(re.sub(r'"SURF_HA": \d+,\d+,\n', '', response.text))

    features = [x for x in response_json['features'] if feature_in_field(feature=x, field=field_lambert)]
    if not features:
        return False, 'No features were found in the Walloon response'
    if len(features) > 1:
        return False, f'Found {len(features)} matches in the Walloon response, the field was probably split up in {year}'

    croptype = features[0]['properties']['CULT_COD']
    if croptype in blacklist:
        return False, f'Croptype {croptype} is blacklisted'

    return True, ''


def check_year_valid(cube: xr.DataArray, bounds: Polygon, exclude_croptypes: list,
                     field_epsg: str = 'EPSG:32631') -> None:
    """
    Check if the current year of the yield potential map is a valid year. A year is considered valid if the following
    conditions are met:
        1. The field is not split in multiple parts for the given year
        2. The croptype, grown for the given year, is not blacklisted

    This is done by iterating through the different regions (Flanders, then Walloon). If the above conditions are
    valid for one of these regions, the year is considered valid. If all regions fail, an exception is raised.

    :param cube: xarray DataArray containing the pixel values; the year is taken from its first time step
    :param bounds: Field geometry (e.g. the convex hull from get_field_bounds) expressed in ``field_epsg``
    :param exclude_croptypes: List of croptypes to be blacklisted
    :param field_epsg: CRS of ``bounds``; defaults to EPSG:32631 (UTM zone 31N), the value previously hard-coded
    :raises Exception: When no region validates the field; the message joins the reason reported by each region
    """
    year = cube.t.values[0].astype('datetime64[Y]').astype(int) + 1970
    _log.debug(f'Checking if year {year} is valid')

    messages = []
    for check in (check_field_flanders, check_field_walloon):
        valid, message = check(year=year, field=bounds, field_epsg=field_epsg, croptype_blacklist=exclude_croptypes)
        if valid:
            return
        messages.append(message)

    raise Exception(f'Field is not valid: {", ".join(messages)}')


def create_map_df(cube: xr.DataArray, pixel_size: float = 10) -> gpd.GeoDataFrame:
    """
    Transform a xarray DataArray containing the yield potential map to a geopandas dataframe.

    Only the first time step and the ``BAND_NAME`` band are used; pixels whose value is NaN are dropped.
    :param cube: xarray DataArray to transform (dims ``t``, ``bands``, ``x``, ``y``)
    :param pixel_size: Side length of a pixel in the cube's x/y units (10 by default, as previously hard-coded)
    :return: GeoDataFrame with a ``categories`` column and one square polygon per pixel in ``geometry``
    """
    map_df = cube.isel(t=0).sel(bands=BAND_NAME).to_dataframe('categories').reset_index().drop(
        columns=['t', 'bands']).dropna()

    # NOTE(review): each pixel becomes the square [x, x + pixel_size] x [y, y + pixel_size], i.e. (x, y) is treated
    # as the lower-left corner; confirm this matches how the cube's x/y coordinates are anchored.
    polygons = [
        Polygon([(x, y), (x + pixel_size, y), (x + pixel_size, y + pixel_size), (x, y + pixel_size), (x, y)])
        for x, y in zip(map_df['x'], map_df['y'])
    ]
    # A GeoSeries aligned on the frame's index also works when no pixel is left (row-wise apply does not)
    map_df['geometry'] = gpd.GeoSeries(polygons, index=map_df.index)

    return map_df.drop(columns=['x', 'y']).set_geometry('geometry')


def polygonize_map(cube: XarrayDataCube) -> XarrayDataCube:
    """
    Polygonize the raster to create different zones delineated through polygons.

    Every non-NaN pixel of the map becomes one square polygon (built by create_map_df). Its value is returned in an
    XarrayDataCube that carries those polygons as a ``geometry`` coordinate.
    :param cube: XarrayDataCube containing the final map to polygonize
    :return: XarrayDataCube wrapping the per-pixel values with a ``geometry`` dimension
    """
    data = create_map_df(cube.get_array())
    # One value per GeoDataFrame row; the unnamed row index becomes an 'index' dimension whose labels are dropped
    result = xr.DataArray.from_series(data['categories']).reset_index('index', drop=True)
    # NOTE(review): expand_dims inserts a NEW 'geometry' dimension of length N while the existing 'index' dimension
    # (also length N) is kept, which looks like it yields an N x N array where every geometry row repeats all N
    # values. If a 1-D array indexed by geometry was intended, this needs changing - confirm with the consumer.
    result = result.expand_dims(geometry=data['geometry'])
    return XarrayDataCube(result)


def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """
    OpenEO UDF callback that will generate the yield potential map.

    Recognised context keys: ``band`` (required), ``threshold``, ``raw`` (default True), ``exclude_months``,
    ``exclude_croptypes``, ``mask_value`` (default 999.0), ``year_check`` (default False) and ``polygonize``
    (default False).
    :param cube: XarrayDataCube containing the inputs coming from OpenEO
    :param context: Additional context of the UDF
    :return: XarrayDataCube with the (float32) map, or its polygonized version when ``polygonize`` is set
    :raises ValueError: When the context does not specify a ``band``
    :raises Exception: When no date passes the filters, or when the year check fails
    """
    _log.debug('Executing UDF for yield potential map')
    data = cube.get_array()

    # Extract the context information
    band = context.get('band', None)
    if band is None:
        raise ValueError('The UDF context must specify the "band" to process')
    threshold = context.get('threshold', None)
    raw = context.get('raw', True)
    exclude_months = context.get('exclude_months') or []
    exclude_croptypes = context.get('exclude_croptypes') or []
    # Named `mask` so the module-level mask_values() helper is not shadowed inside this function
    mask = [context.get('mask_value', 999.0)]
    year_check = context.get('year_check', False)
    polygonize = context.get('polygonize', False)

    # Filter dates that are above a certain threshold and are not located in specific months (e.g. green manure)
    filtered_dates = filter_dates(data=data, band=band, threshold=threshold, exclude_months=exclude_months,
                                  mask=mask)

    if not filtered_dates:
        raise Exception('Could not find any valid dates to process.')

    # Check if the year can be processed - parcel is not split or does not contain any of the blacklisted crop types
    if year_check:
        first_date = next(iter(filtered_dates))
        bounds = get_field_bounds(cube=data, band=band, mask=mask, date=first_date)
        check_year_valid(cube=data, bounds=bounds, exclude_croptypes=exclude_croptypes)

    # Create the actual yield potential map
    yield_map = generate_map(cube=data, band=band, filtered_results=filtered_dates, mask=mask, raw=raw)
    result = XarrayDataCube(yield_map.astype('float32'))

    # Check if we need to polygonize the map
    return polygonize_map(cube=result) if polygonize else result


def execute_datefilter_udf(cube: XarrayDataCube, context: Dict) -> XarrayDataCube:
    """
    Local entry point for testing: run the UDF exactly as the OpenEO callback would.

    :param cube: XarrayDataCube containing the inputs coming from OpenEO
    :param context: UDF context dictionary, forwarded unchanged
    :return: The result of apply_datacube for these inputs
    """
    udf_result = apply_datacube(cube=cube, context=context)
    return udf_result
