from satio.collections import AgERA5Collection
from cropclass.geoloader import GDALWarpLoader


class AgERA5YearlyCollection(AgERA5Collection):
    """AgERA5 collection backed by one GeoTIFF per band per year.

    Each row of the collection dataframe points (column ``path``) to a
    directory holding files named ``AgERA5_{band}_{year}.tif``; rasters are
    read through a ``GDALWarpLoader``.
    """

    sensor = 'AgERA5'

    def __init__(self, *args, **kwargs):
        """
        The collection csv/dataframe should have the directory of each
        yearly product in column 'path' and a date in column 'date', e.g.

        date      |    path
        20200101   .../ewoc-agera5-yearly/2020

        All arguments are forwarded unchanged to ``AgERA5Collection``.
        """
        super().__init__(*args, **kwargs)
        # Read the yearly files with a GDAL-warp based loader.
        # NOTE(review): presumably the parent class reads rasters through
        # ``self._loader`` — confirm against AgERA5Collection.
        self._loader = GDALWarpLoader()

    def filter_dates(self, start_date, end_date):
        """Return a clone restricted to the years spanned by the date range.

        Filtering happens at calendar-year granularity. Rows dated from
        January 1st of ``start_date``'s year up to (but excluding)
        January 1st of the year after ``end_date``'s year are kept, so the
        year containing ``end_date`` is always included.

        Args:
            start_date: start of the range. Only the leading four characters
                of its string form (the year) are used, so ``'20200101'``,
                ``'2020-01-01'`` and ``datetime``/``Timestamp`` values all
                work.
            end_date: end of the range, same accepted forms as
                ``start_date``.

        Returns:
            A new collection (built by ``_clone``) holding the filtered rows,
            with ``start_date``/``end_date`` passed through unmodified.
        """
        # str() lets datetime-like inputs through; for strings it is a no-op.
        start_year = int(str(start_date)[:4])
        end_year = int(str(end_date)[:4]) + 1

        df = self.df[(self.df.date >= f'{start_year}0101')
                     & (self.df.date < f'{end_year}0101')]
        return self._clone(df=df, start_date=start_date, end_date=end_date)

    def get_band_filenames(self, band, resolution=None):
        """Return the yearly GeoTIFF path of ``band`` for every row.

        Args:
            band: band name inserted into the file name.
            resolution: unused here; presumably kept for signature
                compatibility with the parent collection.

        Returns:
            list of str with one ``{path}/AgERA5_{band}_{year}.tif`` entry per
            row of ``self.df``, in row order (empty if ``self.df`` is empty).
            ``year`` is taken from the row's ``date`` value.
        """
        # Iterate the two needed columns directly rather than a row-wise
        # DataFrame.apply(axis=1), which materializes a Series per row.
        return [
            f"{path}/AgERA5_{band}_{date.year}.tif"
            for path, date in zip(self.df.path, self.df.date)
        ]

    def load_timeseries(self,
                        *bands,
                        resolution=100,
                        resampling='cubic',
                        **kwargs):
        """Load ``bands`` through the parent implementation.

        This override only supplies this collection's defaults
        (``resolution=100``, ``resampling='cubic'``); every other keyword
        argument is forwarded unchanged.
        """
        return super().load_timeseries(*bands,
                                       resolution=resolution,
                                       resampling=resampling,
                                       **kwargs)
