import atexit
import importlib.resources as pkg_resources
from contextlib import ExitStack
from math import ceil

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT
from rio_tiler.reader import point
from skimage.transform import resize

# Maps a year to the path of its packaged meteo raster
# (``meteo_biomes_<year>.tif``); filled in at import time by the loop below.
METEO_PATHS = {}


def buffer_bounds(bounds, buffer):
    """Grow a bounding box outward by ``buffer`` on every side.

    Args:
        bounds: Four numbers; the first two are decreased and the last two
            increased by ``buffer`` (i.e. a ``(left, bottom, right, top)``
            layout, matching how ``shape_from_bounds`` indexes bounds).
        buffer: Distance to add on each side, in the same units as
            ``bounds``. May be an int or a float.

    Returns:
        A new list with the four expanded values. The input is not modified.
    """
    offsets = np.array([-buffer, -buffer, buffer, buffer])
    # Add out of place so NumPy can promote the result dtype: an in-place
    # ``+=`` on integer bounds raises a casting error for a fractional buffer.
    return (np.array(bounds) + offsets).tolist()


# A path handed out by ``importlib.resources.path`` is only guaranteed to be
# valid while its context manager is open: when the package does not live on
# a regular filesystem, the resource is extracted to a temporary file that is
# removed on exit. Keep every context open until interpreter shutdown so the
# paths stored in METEO_PATHS stay usable.
_RESOURCE_STACK = ExitStack()
atexit.register(_RESOURCE_STACK.close)

for year in range(2018, 2024):
    METEO_PATHS[year] = _RESOURCE_STACK.enter_context(
        pkg_resources.path(
            "evotrain.meteo.annual_agera_embeddings_v1",
            f"meteo_biomes_{year}.tif",
        )
    )


def _read_warped_tif(
    fn, bounds, epsg, out_shape=None, resampling=Resampling.cubic_spline
):
    """Read a raster through a warped VRT and return scaled float32 data.

    The file is opened with rasterio and wrapped in a ``WarpedVRT`` targeting
    the CRS for ``epsg``; the window covering ``bounds`` is then read.

    Args:
        fn: Path (or anything ``rasterio.open`` accepts) of the raster.
        bounds: Bounds in the target CRS; forwarded to the VRT and used to
            build the read window.
        epsg: EPSG code of the target CRS.
        out_shape: Optional output shape forwarded to ``vrt.read``.
        resampling: rasterio resampling method used by the VRT.

    Returns:
        The array read from the VRT, with cells equal to the source nodata
        value set to NaN and then multiplied by the first band's scale.

    NOTE(review): ``src.scales[0]`` is applied to every band; confirm that
    all bands of the input share the same scale factor.
    """
    with rasterio.open(fn) as src:
        vrt_options = {
            "crs": CRS.from_epsg(epsg),
            "bounds": bounds,
            "resampling": resampling,
            "dtype": "float32",
        }
        with WarpedVRT(src, **vrt_options) as vrt:
            read_window = vrt.window(*bounds)
            meteo = vrt.read(window=read_window, out_shape=out_shape)

        # Dataset metadata has to be fetched while the source is still open.
        nodata_value = src.nodata
        first_band_scale = src.scales[0]

    is_nodata = meteo == nodata_value
    meteo[is_nodata] = np.nan
    meteo *= first_band_scale

    return meteo


def shape_from_bounds(bounds, res):
    """Return the ``(height, width)`` pixel grid covering ``bounds``.

    Args:
        bounds: Indexable of at least four numbers, read as
            ``bounds[0]``/``bounds[2]`` along the width axis and
            ``bounds[1]``/``bounds[3]`` along the height axis.
        res: Pixel size, in the same units as ``bounds``.

    Returns:
        Tuple ``(height, width)`` of ints, each extent divided by ``res`` and
        rounded up.
    """
    y_extent = bounds[3] - bounds[1]
    x_extent = bounds[2] - bounds[0]
    n_rows = int(ceil(y_extent / res))
    n_cols = int(ceil(x_extent / res))
    return n_rows, n_cols


def read_warped_lowres_tif(
    fn,
    bounds,
    epsg,
    resolution=10,
    bounds_buffer=None,
    resampling=Resampling.cubic_spline,
    order=3,
):
    """Read a coarse raster over ``bounds`` and upsample it to ``resolution``.

    The raster is read (warped to ``epsg``) over ``bounds`` grown by
    ``bounds_buffer``, each band is resized with scikit-image to the pixel
    grid implied by ``resolution``, and the buffer margin is cropped off
    again so interpolation edge effects fall outside the returned area.

    Args:
        fn: Path of the raster to read.
        bounds: ``(left, bottom, right, top)`` in the CRS given by ``epsg``.
        epsg: EPSG code of the target CRS.
        resolution: Output pixel size, in the units of ``bounds``.
        bounds_buffer: Margin added around ``bounds`` before reading, in the
            units of ``bounds``. ``None`` or ``0`` disables buffering.
        resampling: rasterio resampling method used while warping.
        order: Spline order passed to ``skimage.transform.resize``.

    Returns:
        A float64 array of shape ``(bands, height, width)``.
    """
    use_buffer = bool(bounds_buffer)
    if use_buffer:
        buf_bounds = buffer_bounds(bounds, bounds_buffer)
    else:
        # No margin requested: read exactly the requested extent.
        buf_bounds = list(bounds)

    data = _read_warped_tif(
        fn,
        buf_bounds,
        epsg,
        out_shape=None,
        resampling=resampling,
    )

    out_shape = shape_from_bounds(buf_bounds, resolution)
    high_res_data = np.zeros((data.shape[0], *out_shape))
    for i, band in enumerate(data):
        high_res_data[i] = resize(
            band,
            out_shape,
            order=order,
            anti_aliasing=True,
            preserve_range=True,
        )

    # ``//`` yields a float when either operand is a float, and slice indices
    # must be integers.
    buffered_pixels = int(bounds_buffer // resolution) if use_buffer else 0

    # Only crop for a positive margin: ``a[0:-0]`` is ``a[0:0]``, which would
    # silently return an empty array.
    if buffered_pixels > 0:
        high_res_data = high_res_data[
            :,
            buffered_pixels:-buffered_pixels,
            buffered_pixels:-buffered_pixels,
        ]

    return high_res_data


def load_meteo_embeddings(
    bounds, epsg, year, resolution=10, bounds_buffer=3000, order=3
):
    """Load the packaged meteo raster of ``year`` upsampled over ``bounds``.

    Looks up the raster registered for ``year`` in ``METEO_PATHS`` and
    delegates to ``read_warped_lowres_tif``; the warp resampling method is
    left at that function's default.

    Args:
        bounds: Bounds in the CRS given by ``epsg``.
        epsg: EPSG code of the target CRS.
        year: Key into ``METEO_PATHS``.
        resolution: Output pixel size, in the units of ``bounds``.
        bounds_buffer: Margin read around ``bounds`` and cropped afterwards.
        order: Spline order used for the upsampling.

    Returns:
        Whatever ``read_warped_lowres_tif`` returns for these arguments.

    Raises:
        KeyError: If no raster is registered for ``year``.
    """
    raster_path = METEO_PATHS[year]
    upsampling_options = {
        "resolution": resolution,
        "bounds_buffer": bounds_buffer,
        "order": order,
    }
    return read_warped_lowres_tif(
        raster_path, bounds, epsg, **upsampling_options
    )


def load_meteo_point(lon, lat, year):
    """Sample the packaged meteo raster of ``year`` at a single location.

    Args:
        lon: First coordinate passed to ``rio_tiler.reader.point``.
        lat: Second coordinate passed to ``rio_tiler.reader.point``.
        year: Key into ``METEO_PATHS``.

    Returns:
        The ``data`` attribute of the object returned by
        ``rio_tiler.reader.point``.

    Raises:
        KeyError: If no raster is registered for ``year``.

    NOTE(review): no coordinate CRS is passed, so rio_tiler's default applies
    (presumably WGS84 lon/lat; confirm). Unlike ``_read_warped_tif``, no
    scale factor or nodata-to-NaN conversion is applied here; confirm that
    this is intended.
    """
    coordinates = (lon, lat)
    with rasterio.open(METEO_PATHS[year]) as dataset:
        sample = point(dataset, coordinates)
    return sample.data
