from typing import Dict, List

from loguru import logger
import numpy as np
from pathlib import Path
from satio.features import Features
from worldcereal.fp import (L2AFeaturesProcessor,
                            L8ThermalFeaturesProcessor,
                            SARFeaturesProcessor,
                            WorldCerealSARFeaturesProcessor,
                            WorldCerealOpticalFeaturesProcessor,
                            WorldCerealThermalFeaturesProcessor,
                            WorldCerealAgERA5FeaturesProcessor,
                            AgERA5FeaturesProcessor)
from cropclass.classifier import CropclassClassifier
from worldcereal import SUPPORTED_SEASONS, TBASE
from worldcereal.processors import ClassificationProcessor

from cropclass.models import CroptypeModel


class CropClassProcessor(ClassificationProcessor):
    """Classification processor that predicts crop type for a block.

    On top of ``ClassificationProcessor`` this class stores the crop type
    model configuration (and an optional mask model), configures one
    features processor per sensor, and implements :meth:`classify`, which
    runs the model(s) on precomputed features and writes the results
    through the ``save*`` methods (not defined in this class; presumably
    inherited from ``ClassificationProcessor``).
    """

    supported_sensors = ['OPTICAL',
                         'SAR',
                         'METEO',
                         'DEM',
                         'L8']

    def __init__(self,
                 output_folder,
                 collections: Dict,
                 class_model,
                 mask_model=None,
                 class_encoder=None,
                 mask_encoder=None,
                 season: str = 'annual',
                 settings: Dict = None,
                 rsi_meta: Dict = None,
                 features_meta: Dict = None,
                 ignore_def_feat: Dict = None,
                 gdd_normalization: bool = False,
                 fps: Dict = None,
                 aez: int = None,
                 start_date='20190101',
                 end_date='20200101',
                 save_confidence=False,
                 save_meta=False,
                 save_features=False,
                 filtersettings: Dict = None,
                 featresolution: int = 10,
                 avg_segm: bool = True,
                 segm_feat: List = None,
                 **kwargs):
        """Store the configuration and prepare the per-sensor settings.

        Args:
            output_folder: Output location, stored as a ``Path`` in
                ``self.output``.
            collections: Mapping of sensor name to collection; entries
                whose value is ``None`` are dropped.
            class_model: Configuration handed to
                ``CroptypeModel.from_config`` for the crop type model.
            mask_model: Optional configuration handed to
                ``CroptypeModel.from_config`` for a mask model. When
                ``None`` no mask is applied.
            class_encoder: Encoder passed to the crop type classifier.
            mask_encoder: Encoder passed to the mask classifier.
            season: Season identifier; used as the key into ``TBASE``
                when ``gdd_normalization`` is enabled.
            settings: Per-collection settings. Updated IN PLACE: every
                entry gets ``composite.start`` / ``composite.end`` (each
                entry must already hold a ``'composite'`` mapping) and,
                optionally, ``normalize_gdd``. ``None`` is treated as an
                empty dict.
            rsi_meta: Stored unchanged.
            features_meta: Per-sensor features metadata. If
                ``features_meta['OPTICAL']`` contains
                ``'sen2agri_temp_feat'``, its ``parameters.time_start``
                is set to ``start_date`` IN PLACE. ``None`` is treated as
                an empty dict.
            ignore_def_feat: Stored unchanged.
            gdd_normalization: When true, a ``normalize_gdd`` entry is
                added to the settings of every collection.
            fps: Optional overrides of the features processor per sensor
                (keys ``'OPTICAL'``, ``'SAR'``, ``'METEO'``, ``'DEM'``,
                ``'L8'``); missing keys fall back to the defaults below.
            aez: Stored unchanged and forwarded to the ``save*`` calls.
            start_date: Start of the processing period, written to the
                collection settings.
            end_date: End of the processing period, written to the
                collection settings.
            save_confidence: Whether :meth:`classify` also saves the
                confidence output.
            save_meta: Whether :meth:`classify` also calls ``save_meta``.
            save_features: Whether :meth:`classify` also calls
                ``save_features``.
            filtersettings: Passed to every ``CropclassClassifier``.
            featresolution: Stored unchanged.
            avg_segm: Stored unchanged.
            segm_feat: Feature names stored in ``self.segm_feat``; a
                falsy value selects the built-in default list.
            **kwargs: Accepted for call compatibility and ignored.
        """

        logger.debug("Initializing ClassificationProcessor")

        self.collections = {k: v for k, v in collections.items()
                            if v is not None}
        self.aez = aez
        self.output = Path(output_folder)
        self.class_model = class_model
        self.mask_model = mask_model
        self.class_encoder = class_encoder
        self.mask_encoder = mask_encoder
        self.filtersettings = filtersettings
        self.season = season
        # Both dicts are dereferenced further down, so the documented
        # ``None`` defaults used to raise AttributeError. An explicit
        # ``is None`` test (rather than ``or {}``) keeps a caller-provided
        # empty dict as the same object, so in-place updates still reach it.
        self.settings = settings if settings is not None else {}
        self.rsi_meta = rsi_meta
        self.features_meta = (features_meta
                              if features_meta is not None else {})
        self.ignore_def_feat = ignore_def_feat
        self.gdd_normalization = gdd_normalization
        self._start_date = start_date
        self._end_date = end_date
        self._save_features = save_features
        self._save_confidence = save_confidence
        self._save_meta = save_meta
        self.featresolution = featresolution
        self.avg_segm = avg_segm
        self.segm_feat = segm_feat or ['SAR-VH-std-20m',
                                       'OPTICAL-B12-p10-20m',
                                       'OPTICAL-ndmi-p10-20m',
                                       'OPTICAL-evi-p10-10m',
                                       'OPTICAL-anir-std-20m',
                                       'OPTICAL-B12-std-20m']

        # Not defined in this class; presumably inherited from
        # ClassificationProcessor and validating self.collections.
        self._check_sources()

        # Configure the featuresprocessors: caller overrides win,
        # otherwise the WorldCereal defaults are used.
        fps = fps or {}
        self._fps = {
            'OPTICAL': fps.get('OPTICAL', WorldCerealOpticalFeaturesProcessor),
            'SAR': fps.get('SAR', WorldCerealSARFeaturesProcessor),
            'METEO': fps.get('METEO', WorldCerealAgERA5FeaturesProcessor),
            'DEM': fps.get('DEM', Features.from_dem),
            'L8': fps.get('L8', WorldCerealThermalFeaturesProcessor)
        }

        # Now make sure start_date and end_date are set in the settings
        # of all collections
        for coll_settings in self.settings.values():
            coll_settings['composite']['start'] = start_date
            coll_settings['composite']['end'] = end_date

        optical_meta = self.features_meta.get('OPTICAL', {})
        if 'sen2agri_temp_feat' in optical_meta:
            optical_meta['sen2agri_temp_feat'][
                'parameters']['time_start'] = start_date

        # Add optional GDD normalization to the settings as well.
        # TBASE is looked up inside the loop on purpose: with no
        # collections configured an unknown season must not raise.
        if self.gdd_normalization:
            for coll_settings in self.settings.values():
                coll_settings['normalize_gdd'] = dict(
                    tbase=TBASE[season],
                    season=season
                )

        logger.debug("ClassificationProcessor initialized")

    def classify(self, features, output_folder, tile, bounds,
                 epsg, block_id):
        """Predict crop type for one block and write the results.

        If a mask model is configured it is run first and its prediction
        is turned into a binary mask that is handed to the crop type
        classifier. The crop type prediction is always saved; confidence,
        features and metadata are saved only when the matching
        ``save_*`` flag was set at construction time.

        Args:
            features: Features passed to ``CropclassClassifier.predict``
                and, optionally, to ``save_features`` / ``save_meta``.
            output_folder: Forwarded to all ``save*`` calls.
            tile: Tile identifier, forwarded to the ``save*`` calls and
                used in the error message.
            bounds: Block bounds, forwarded to the ``save*`` calls and
                used in the error message.
            epsg: Forwarded to the ``save*`` calls.
            block_id: Block identifier, forwarded to the ``save*`` calls
                and used in the error message.

        Raises:
            Exception: Any error raised while predicting or saving is
                logged and re-raised unchanged.
        """
        # Now do the predictions based on all provided models
        # and write result to disk

        try:

            if self.mask_model is not None:
                mask_model = CroptypeModel.from_config(self.mask_model)

                # Create a CropclassClassifier from the mask model
                classifier = CropclassClassifier(
                    mask_model,
                    filtersettings=self.filtersettings,
                    encoder=self.mask_encoder)

                # Get prediction (the confidence returned here is
                # replaced by the crop type confidence further down)
                prediction, confidence = classifier.predict(
                    features, nodatavalue=255)

                # Binary mask: 1 where the mask model predicts label 11.
                # NOTE(review): 11 is presumably the target (cropland)
                # label of the mask model - confirm against the model.
                mask = np.zeros_like(prediction)
                mask[prediction == 11] = 1

            else:
                mask = None

            # Load the model
            model = CroptypeModel.from_config(self.class_model)

            # Create a CropclassClassifier from the crop type model
            self.classifier = CropclassClassifier(
                model,
                filtersettings=self.filtersettings,
                maskdata=mask,
                encoder=self.class_encoder)

            # Single nodata value shared by the classifier and by both
            # saved outputs, so they cannot drift apart.
            nodata = 0

            # Get prediction and confidence
            prediction, confidence = self.classifier.predict(
                features,
                nodatavalue=nodata)

            self.save(prediction, tile, bounds, epsg,
                      self.aez, output_folder, block_id=block_id,
                      product='croptype', tag='classification',
                      nodata=nodata)

            if self._save_confidence:
                self.save(confidence, tile, bounds, epsg,
                          self.aez, output_folder, block_id=block_id,
                          product='croptype', tag='confidence',
                          nodata=nodata)

            if self._save_features:
                self.save_features(features, output_folder, 'croptype',
                                   tile, block_id, bounds, epsg,
                                   self.aez)

            if self._save_meta:
                self.save_meta(features, output_folder, 'croptype',
                               tile, block_id, bounds, epsg,
                               self.aez)

        except Exception as e:
            logger.error(f"Error predicting for {tile} - {bounds} "
                         f"- {block_id}: {e}")
            raise
