"""Utilities to load Agera5 aggregated data stored in
/vitodata/vegteam/auxdata/meteo/agera5/aggregated/v1
Data is aggregated every 5 days with a rolling window 
of [5, 10, 30, 90, 180, 365] days.

The loader will identify the closest date available and return the data
for the aggregated interval selected.
"""
import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import rasterio
from dateutil.parser import parse
from rio_tiler.reader import point as rio_tiler_point

from vitors_landcover import parallelize


def agera_var_path_aggregated(var_name,
                              date_id='20200101',
                              interval_days=10):
    """Build the path of one aggregated AgERA5 GeoTIFF.

    Parameters
    ----------
    var_name : str
        Variable name as it appears in the file name, e.g.
        ``'temperature-max'``.
    date_id : str or datetime.date
        Composite date, either as a ``'YYYYMMDD'`` string or as a
        ``datetime.date`` / ``datetime.datetime`` (``pandas.Timestamp`` is a
        ``datetime`` subclass and is accepted too).
    interval_days : int
        Aggregation window length in days; zero-padded to three digits.

    Returns
    -------
    pathlib.Path
        ``.../v1/{NNN}D/{YYYY}/{YYYYMMDD}/AgERA5_{var}_{YYYYMMDD}_{NNN}D.tif``.
        The file's existence is not checked.
    """
    # datetime.datetime (and pandas.Timestamp) subclass datetime.date, so a
    # single check converts all date-like inputs to the 'YYYYMMDD' id.
    if isinstance(date_id, datetime.date):
        date_id = date_id.strftime('%Y%m%d')
    window = f'{interval_days:03d}D'
    aggregated_path = Path(
        '/vitodata/vegteam/auxdata/meteo/agera5/aggregated/v1') / window
    # Files are grouped per year, then per composite date.
    year = date_id[:4]
    return (aggregated_path / year / date_id
            / f"AgERA5_{var_name}_{date_id}_{window}.tif")


class Agera5Dataset:
    """Point sampling of the aggregated AgERA5 GeoTIFFs.

    The available composite dates are computed once at construction time
    (see ``_get_dates``). Every query date is snapped to the closest of them.
    """

    # Variables for which aggregated rasters are looked up on disk.
    supported_vars = [
        'dewpoint-temperature',
        'solar-radiation-flux',
        'precipitation-flux',
        'temperature-min',
        'temperature-max',
        'temperature-mean']

    # Aggregation window lengths (days) accepted by the point readers.
    supported_intervals = [5, 10, 30, 90,
                           180, 365]

    def __init__(self):
        # Lists the filesystem once; all later lookups reuse this list.
        self._dates = self._get_dates()

    def _check_var(self, var_name):
        """Raise ValueError if *var_name* is not in ``supported_vars``."""
        if var_name not in self.supported_vars:
            raise ValueError(f'var_name {var_name} not supported. '
                             f'Supported values: {self.supported_vars}')

    def _check_days(self, interval_days):
        """Raise ValueError if *interval_days* is not in ``supported_intervals``."""
        if interval_days not in self.supported_intervals:
            raise ValueError(f'interval_days {interval_days} not supported. '
                             f'Supported values: {self.supported_intervals}')

    def _get_dates(self):
        """Return the sorted list of composite dates as datetimes.

        The month/day layout is read from the 2019 folder of the 10-day
        product and replicated for 2018-2023. This assumes every year shares
        2019's ``MMDD`` folder names. ``20240101`` is appended explicitly
        (presumably as an end-of-range anchor; confirm).
        """
        template_dir = Path("/vitodata/vegteam/auxdata/meteo/"
                            "agera5/aggregated/v1/010D/2019")
        # Folder names are 'YYYYMMDD'. Slicing off the year swaps exactly the
        # prefix, unlike str.replace, which would hit every '2019' substring.
        month_days = [d.name[4:] for d in template_dir.iterdir()]
        date_ids = [year + month_day
                    for year in ('2018', '2019', '2020',
                                 '2021', '2022', '2023')
                    for month_day in month_days]
        date_ids.append('20240101')
        return [parse(d) for d in sorted(date_ids)]

    def _find_closest_date(self, date):
        """Return the entry of ``self._dates`` nearest to *date*.

        *date* may be a string (parsed with dateutil), a ``datetime``
        (including ``pandas.Timestamp``) or a plain ``datetime.date``, which
        is treated as midnight. Ties resolve to the earlier date, because
        ``min`` keeps the first minimum of the ascending date list.
        """
        if isinstance(date, str):
            date = parse(date)
        elif (isinstance(date, datetime.date)
              and not isinstance(date, datetime.datetime)):
            # Subtracting a date from a datetime raises TypeError, so promote
            # the date to a datetime first.
            date = datetime.datetime.combine(date, datetime.time())
        return min(self._dates, key=lambda d: abs(d - date))

    def _get_tif_path_closest(self, var_name, date, interval_days):
        """Return the GeoTIFF path for the composite date closest to *date*."""
        date_id = self._find_closest_date(date)
        return agera_var_path_aggregated(var_name, date_id, interval_days)

    def var_point(self, var_name, date, lon, lat, interval_days=10):
        """Sample one variable at a single location.

        Parameters
        ----------
        var_name : str
            One of ``supported_vars``.
        date : str or datetime-like
            Query date; snapped to the closest available composite date.
        lon, lat : float
            Point coordinates passed to rio_tiler's ``point`` reader. These
            are presumably WGS84 degrees, rio_tiler's default CRS; confirm.
        interval_days : int
            One of ``supported_intervals``.

        Returns
        -------
        float
            ``np.nan`` if the sampled pixel equals the raster's nodata value.
            Otherwise, the raw value multiplied by band 1's scale factor.
            NOTE(review): the band offset is not applied.

        Raises
        ------
        ValueError
            If *var_name* or *interval_days* is not supported.
        """
        self._check_var(var_name)
        self._check_days(interval_days)

        tif_path = self._get_tif_path_closest(var_name, date, interval_days)
        with rasterio.open(str(tif_path)) as src:
            scale = src.scales[0]
            nodata = src.nodata
            value = rio_tiler_point(src,
                                    coordinates=(lon, lat)).data[0]
            if value == nodata:
                return np.nan
            else:
                return value * scale

    def all_vars_point(self, date, lon, lat):
        """Sample every (variable, interval) pair at one location and date.

        Returns a ``pd.Series`` indexed by ``_get_index()``, which uses the
        same variable-major ordering as the loop below. Opens one raster per
        pair.
        """
        values = [self.var_point(var_name,
                                 date,
                                 lon,
                                 lat,
                                 interval)
                  for var_name in self.supported_vars
                  for interval in self.supported_intervals]
        return pd.Series(values, index=self._get_index())

    def _get_index(self):
        """Return the labels ``'{var_name}-{NNN}D'`` in variable-major order."""
        return [f'{var_name}-{interval:03d}D'
                for var_name in self.supported_vars
                for interval in self.supported_intervals]

    def _build_vars_df(self, lon, lat):
        """Build a DataFrame with one row per date in ``self._dates``.

        Each row is ``all_vars_point`` at (*lon*, *lat*).
        NOTE(review): rows are labelled positionally with ``self._dates``.
        This assumes ``parallelize`` returns results in input order; confirm.
        """
        vals = parallelize(lambda d: self.all_vars_point(d, lon, lat),
                           self._dates, max_workers=10)
        return pd.DataFrame(vals,
                            index=self._dates,
                            columns=self._get_index())
