import os
from loguru import logger as log
import numpy as np
import torch


from fielddelineation.utils.raster import read_window_raster
from fielddelineation.utils.delineation import _apply_delineation
from fielddelineation.utils.geom import get_shape_target
from cropclass.features.settings import compute_n_tsteps
from fielddelineation.utils.padding import calc_padding



def window_generation(xdim: int, ydim: int, windowsize: int, stride: int,
                      force_match_grid: bool = True):
    """
    Generate the list of unique windows that can be processed over an
    ``xdim`` x ``ydim`` grid.

    Window origins are placed every ``windowsize - 2 * stride`` pixels along
    each axis, so neighbouring windows overlap by ``2 * stride`` pixels.

    :param xdim: the size of the x-dimension
    :param ydim: the size of the y-dimension
    :param windowsize: the size that each window should have
        (windowsize x windowsize)
    :param stride: the overlap margin; consecutive windows overlap by
        ``2 * stride`` pixels
    :param force_match_grid: if True, a window that would extend past the grid
        is shifted back so that it ends exactly at the grid bound (overlapping
        its neighbour). If False together with a stride of zero, such windows
        are dropped: there is no overlap between windows, but the last pixels
        of the grid may not be covered. With ``stride > 0`` the edge window is
        always shifted.
    :return: list of ``((x_start, x_end), (y_start, y_end))`` tuples in
        x-major order, without duplicates
    :raises ValueError: if ``windowsize - 2 * stride`` is not positive
    """
    step = windowsize - 2 * stride
    if step <= 0:
        raise ValueError(
            f'windowsize ({windowsize}) must be larger than 2 * stride '
            f'({2 * stride}) to generate windows')

    # Edge windows are shifted back onto the grid unless the caller explicitly
    # asked for a non-overlapping tiling (no forcing and zero stride).
    shift_edge = force_match_grid or stride > 0

    def _axis_spans(dim):
        # Return the unique (start, end) spans along one axis, in order.
        spans = []
        for start in range(0, dim, step):
            if start + windowsize > dim:
                if not shift_edge:
                    continue
                # NOTE: if dim < windowsize this start becomes negative
                # (behaviour kept from the original implementation).
                start = dim - windowsize
            spans.append((start, start + windowsize))
        # Several origins past the last full window all shift onto the same
        # edge span; dict.fromkeys de-duplicates while keeping order.
        return list(dict.fromkeys(spans))

    x_spans = _axis_spans(xdim)
    y_spans = _axis_spans(ydim)
    return [(x_span, y_span) for x_span in x_spans for y_span in y_spans]


def _read_window_bands(ds, window, bands, n_ts_add_before=0,
                       n_ts_add_after=0):
    """
    Read the requested bands of ``ds`` for one window as cleaned uint16 arrays.

    :param ds: dataset exposing ``filter_by_attrs(long_name=band)``; each band
        variable is indexed as 3D (first axis presumably time - TODO confirm)
    :param window: ``((start, end), (start, end))`` slices applied to axes 1
        and 2 of each band
    :param bands: band names to read
    :param n_ts_add_before: number of nodata slices prepended along axis 0
    :param n_ts_add_after: number of nodata slices appended along axis 0
    :return: dict mapping ``s2_<band in lower case>`` to a uint16 array in
        which NaNs and values above 20000 are replaced by 65535 (nodata)
    """
    nodata = 65535
    (ax1_start, ax1_end), (ax2_start, ax2_end) = window
    pred_array = {}
    for band in bands:
        # Open the dataset only for the requested window and band
        ds_band = ds.filter_by_attrs(long_name=band)
        band_window = ds_band[band][:, ax1_start:ax1_end, ax2_start:ax2_end]
        # Work in float so NaNs can be detected before casting to uint16
        band_data = band_window.values.astype(float)

        # Values above 20000 are treated as invalid (an earlier note suggests
        # clouds are flagged this way - TODO confirm); they are mapped,
        # together with NaNs, onto the nodata value.
        invalid = np.isnan(band_data) | (band_data > 20000)
        band_data[invalid] = nodata
        band_data = band_data.astype(np.uint16)

        # Pad along axis 0 *before* storing so the padding reaches the model;
        # np.pad keeps the uint16 dtype.
        if n_ts_add_before > 0 or n_ts_add_after > 0:
            band_data = np.pad(band_data,
                               ((n_ts_add_before, n_ts_add_after),
                                (0, 0), (0, 0)),
                               mode='constant', constant_values=nodata)
        pred_array[f"s2_{band.lower()}"] = band_data
    return pred_array


def _apply_prediction(ds, window, windowsize,
                      start_month, end_month, year, pred_context,
                      bands=('B02', 'B03', 'B04', 'B08')):
    """
    Run the field delineation model on a single spatial window.

    :param ds: dataset holding the input bands (selected through
        ``filter_by_attrs(long_name=band)``), or None to fall back on the
        deprecated pre-exported ``.npz`` window files
    :param window: ``((x_start, x_end), (y_start, y_end))`` window, as
        produced by ``window_generation``
    :param windowsize: size of the window (currently unused, kept for
        interface compatibility)
    :param start_month: first month of the prediction period
    :param end_month: last month of the prediction period
    :param year: year of the prediction period (used for padding)
    :param pred_context: dict providing ``model_dir``, ``model_tag`` and
        optionally ``cache_dir``; ``data_dir`` and ``tile`` are read when
        ``ds`` is None
    :param bands: band names to read from ``ds``
    :return: tuple ``(window, semantic prediction)``, or ``(window, None)``
        when the deprecated ``.npz`` file for this window does not exist
    """
    # --- Model input data preparation ---
    # Added due to a package conflict in the environment: without an explicit
    # cache dir, point the torch / XDG caches at the working directory.
    if pred_context.get('cache_dir') is None:
        os.environ['TORCH_HOME'] = os.getcwd()
        os.environ['XDG_CACHE_HOME'] = os.getcwd()

    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

    if ds is not None:
        # Infer the required number of time steps and check whether padding
        # is needed to obtain the expected input dimensions.
        n_tsteps = compute_n_tsteps(start_month, end_month)
        n_ts_add_before, n_ts_add_after = calc_padding(ds,
                                                       start_month,
                                                       end_month,
                                                       n_tsteps,
                                                       year)
        pred_array = _read_window_bands(ds, window, bands,
                                        n_ts_add_before, n_ts_add_after)
    else:
        # TODO this section is deprecated
        data_dir = pred_context.get('data_dir')
        name_window_data = ''.join((f'{pred_context.get("tile")}_',
                                    f'{window[0][0]}_{window[0][1]}_',
                                    f'{window[1][0]}_{window[1][1]}.npz'))
        window_path = os.path.join(data_dir, name_window_data)
        if not os.path.exists(window_path):
            return window, None
        # NpzFile keeps the archive open; the context manager closes it after
        # every array has been loaded into memory.
        with np.load(window_path) as npz_array:
            pred_array = {band: npz_array[band] for band in npz_array}

    # --- Model prediction ---
    # Apply the field delineation detection (Radix implementation)
    log.info(f'Using model: {pred_context.get("model_tag")}')
    prediction_sem, _mdl = _apply_delineation(
        pred_array,
        pred_context.get("model_dir"),
        modeltag=pred_context.get("model_tag"),
        cache_dir=pred_context.get("cache_dir"),
        return_model=True
    )

    # TODO: instance segmentation (``mdl.post_process`` on the semantic
    # output, including the 3D -> 2D reshape for '3D' model tags) is
    # currently disabled; re-check once the final model is available.

    log.success('Model prediction finished')

    return window, prediction_sem


def stitch_pred(model_context, pred):
    """
    Post-process a stitched prediction with the delineation model.

    The model is loaded from ``model_context`` on every call.

    :param model_context: dict whose ``model_dir``, ``model_tag`` and
        ``cache_dir`` entries are passed to ``load_delineation_model``
    :param pred: stitched prediction, either a ``torch.Tensor`` (used as-is)
        or anything convertible to one (e.g. a numpy array), which is copied
        into a float32 tensor
    :return: the output of ``mdl.post_process`` converted with ``.numpy()``;
        the input is given an extra leading axis of size 1 before the call
    """
    from fielddelineation.utils.delineation import load_delineation_model
    model_tag = model_context.get('model_tag')
    mdl = load_delineation_model(model_context.get('model_dir'),
                                 model_tag,
                                 model_context.get('cache_dir'))
    if isinstance(pred, torch.Tensor):
        pred_torch = pred
    else:
        # Explicit float32 copy (replaces the legacy torch.Tensor constructor)
        pred_torch = torch.tensor(pred, dtype=torch.float32)

    # For multi-head models, explicitly tell post_process that stitching has
    # already been done so that the heads are not combined a second time.
    post_kwargs = {'stitching': True} if 'MultiHead' in model_tag else {}
    pred_stitched = mdl.post_process(pred_torch.unsqueeze(0),
                                     **post_kwargs).numpy()
    log.info('Stitching succeeded')

    return pred_stitched
