"""Dataset class for crop classification."""

from __future__ import annotations

import warnings
from random import randint
from random import shuffle as shuffle_idx
from typing import Any, Callable

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from torch.utils.data import DataLoader, Dataset

from vito_crop_classification.training.augmentor import Augmentor


class BaseDataset(Dataset):
    """Shared state and helpers for the crop classification datasets."""

    def __init__(
        self,
        data: pd.DataFrame,
        process_f: Callable[[pd.DataFrame], torch.FloatTensor],
        classes: NDArray[np.str_],
        collate_fn: Callable[[...], Any],
        balance: bool = False,
        balance_by: list[str] = ["target_id"],
        metadata: list[str] | None = ["country"],
        augment: bool = False,
        seq_len: int = 24,
    ) -> None:
        """
        Prepare the processed samples, one-hot targets and shuffled index queues.

        Parameters
        ----------
        data : pd.DataFrame
            Dataframe to serve; must contain a "target_id" column
        process_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Function that turns the dataframe into the model's input samples
        classes : NDArray[np.str_]
            One-dimensional array of class names; a class's position is its one-hot index
        collate_fn : Callable[[...], Any]
            Collate function handed to the DataLoader
        balance : bool, optional
            If True, split the data into one sub-dataset per unique `balance_by`
            key combination, by default False
        balance_by : list[str], optional
            Columns whose values define the sub-datasets when balancing,
            by default ["target_id"]
        metadata : list[str], optional
            Metadata columns to keep next to every sample, by default ["country"].
            A falsy value disables metadata.
        augment : bool, optional
            Flag stored for the subclasses to enable data augmentation, by default False
        seq_len : int, optional
            Sequence length forwarded to the Augmentor, by default 24
        """
        super().__init__()
        assert len(classes.shape) == 1
        self._augment = augment
        self._classes_str = classes
        # Class name -> position of its one-hot entry
        self._classes_mapping = dict(zip(classes, range(len(classes))))
        data_processed = process_f(data)
        self._collate_fn = collate_fn
        self._seq_len = seq_len
        self._augmentor = Augmentor(seq_len=self._seq_len)
        self._metadata_labels = metadata

        raw_metadata = None
        if metadata:
            raw_metadata = data[metadata].values

        if not balance:
            # Single sub-dataset that holds every sample
            n_classes = len(self._classes_mapping)
            one_hot_rows = [
                [pos == self._classes_mapping[target_id] for pos in range(n_classes)]
                for target_id in data["target_id"]
            ]
            self._datasets = [data_processed]
            self._metadata = [raw_metadata]
            self._classes = [torch.FloatTensor(one_hot_rows)]
        else:
            # One sub-dataset per unique combination of the balance_by columns
            self._datasets, self._classes, self._metadata = split_by_key(
                data=data_processed,
                metadata=raw_metadata,
                keys=data[balance_by].values,
                target=np.stack(data["target_id"].to_list()),
                classes_mapping=self._classes_mapping,
            )

        # Every sub-dataset starts with a freshly shuffled index queue
        assert not any(len(subset) == 0 for subset in self._datasets), "Empty datasets not allowed!"
        self._indices = [[] for _ in self._datasets]
        for idx, _ in enumerate(self._datasets):
            self.shuffle(idx)
        # Position of the sub-dataset that serves the next sample
        self._d_idx = 0

    def __len__(self) -> int:
        """Return the total number of samples over all sub-datasets."""
        return sum(map(len, self._classes))

    def shuffle(self, idx: int) -> None:
        """Refill the index queue of sub-dataset `idx` with a new random order."""
        order = list(range(len(self._datasets[idx])))
        shuffle_idx(order)
        self._indices[idx] = order

    def __getitem__(self, _: int) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
        """Return the next datapoint; must be implemented by the subclasses."""
        raise NotImplementedError

    def get_dataloader(self, batch_size: int = 512) -> DataLoader:
        """Wrap this dataset in a DataLoader that uses the stored collate function."""
        return DataLoader(self, batch_size=batch_size, collate_fn=self._collate_fn)


class ClassificationDataset(BaseDataset):
    """Dataset that serves samples with their one-hot class target."""

    def __init__(
        self,
        data: pd.DataFrame,
        process_f: Callable[[pd.DataFrame], torch.FloatTensor],
        classes: NDArray[np.str_],
        balance: bool = False,
        balance_by: list[str] = ["target_id"],
        augment: bool = False,
        seq_len: int = 24,
    ) -> None:
        """
        Build a classification dataset on top of the base dataset.

        Parameters
        ----------
        data : pd.DataFrame
            Dataframe to serve
        process_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Function that turns the dataframe into the model's input samples
        classes : NDArray[np.str_]
            Array of class names used for training
        balance : bool, optional
            If True, split the data per `balance_by` key and serve the resulting
            sub-datasets in turn, by default False
        balance_by : list[str], optional
            Columns that define the sub-datasets when balancing, by default ["target_id"]
        augment : bool, optional
            If True, apply data augmentation, by default False
        seq_len : int, optional
            Sequence length of the time series, by default 24
        """
        super().__init__(
            data=data,
            process_f=process_f,
            classes=classes,
            collate_fn=classification_collate_fn,
            balance=balance,
            balance_by=balance_by,
            augment=augment,
            seq_len=seq_len,
        )

    def __getitem__(self, _: int) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
        """
        Return the next sample, its target and (if configured) its metadata.

        The requested index is ignored: samples come from the shuffled index
        queue of the current sub-dataset, and the sub-datasets are visited in turn.
        """
        d_idx = self._d_idx
        queue = self._indices[d_idx]
        pool = self._datasets[d_idx]

        # Take the next queued position of the current sub-dataset
        s_idx = queue.pop(0)
        sample = pool[s_idx]
        label = self._classes[d_idx][s_idx]

        res = {}
        if self._metadata_labels:
            res = dict(zip(self._metadata_labels, self._metadata[d_idx][s_idx]))

        if self._augment:
            # Augment with a random partner drawn from the same sub-dataset
            partner = pool[randint(0, len(pool) - 1)]  # noqa: S311
            sample = self._augmentor.augment(sample, partner)
        elif isinstance(sample, np.ndarray):
            # Numpy samples hold several parts; one-dimensional parts pass through untouched
            # NOTE(review): apply_h_cut presumably crops to seq_len around the center - confirm
            parts = []
            for part in sample:
                if len(part.shape) == 1:
                    parts.append(part)
                else:
                    cut = self._augmentor.apply_h_cut(s1=part, seq_len=self._seq_len, center=True)
                    parts.append(cut[0])
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", FutureWarning)
                sample = np.asarray(parts, dtype=object)
        else:
            sample = self._augmentor.apply_h_cut(s1=sample, seq_len=self._seq_len, center=True)[0]

        # Refill the queue once the sub-dataset has been fully served
        if not queue:
            self.shuffle(d_idx)

        # Move on to the next sub-dataset, wrapping around at the end
        following = d_idx + 1
        self._d_idx = following if following < len(self._datasets) else 0

        res["input"] = sample
        res["target"] = label
        return res


class SimilarityDataset(BaseDataset):
    """Dataset that serves samples together with their class representative."""

    def __init__(
        self,
        data: pd.DataFrame,
        process_f: Callable[[pd.DataFrame], torch.FloatTensor],
        classes: NDArray[np.str_],
        representatives: dict[str, torch.FloatTensor],
        balance: bool = False,
        balance_by: list[str] = ["target_id"],
        augment: bool = False,
        seq_len: int = 24,
    ) -> None:
        """
        Build a similarity dataset on top of the base dataset.

        Parameters
        ----------
        data : pd.DataFrame
            Dataframe to serve
        process_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Function that turns the dataframe into the model's input samples
        classes : NDArray[np.str_]
            Array of class names used for training
        representatives : dict[str, torch.FloatTensor]
            Class representatives, provided as a class -> Tensor dictionary
        balance : bool, optional
            If True, split the data per `balance_by` key and serve the resulting
            sub-datasets in turn, by default False
        balance_by : list[str], optional
            Columns that define the sub-datasets when balancing, by default ["target_id"]
        augment : bool, optional
            If True, apply data augmentation, by default False
        seq_len : int, optional
            Sequence length of the time series, by default 24
        """
        super().__init__(
            data=data,
            process_f=process_f,
            classes=classes,
            collate_fn=similarity_collate_fn,
            balance=balance,
            balance_by=balance_by,
            augment=augment,
            seq_len=seq_len,
        )
        self.set_representatives(representatives)

    def set_representatives(self, representatives: dict[str, torch.FloatTensor]) -> None:
        """Replace the representatives, stacked in the order of the class array."""
        ordered = [representatives[cls] for cls in self._classes_str]
        self._rep_vec = torch.stack(ordered)

    def __getitem__(self, _: int) -> dict[str, torch.Tensor | np.ndarray]:
        """
        Return the next sample, its representative, its target and optional metadata.

        The requested index is ignored: samples come from the shuffled index
        queue of the current sub-dataset, and the sub-datasets are visited in turn.
        """
        d_idx = self._d_idx
        queue = self._indices[d_idx]
        pool = self._datasets[d_idx]

        # Take the next queued position of the current sub-dataset
        s_idx = queue.pop(0)
        sample = pool[s_idx]
        label = self._classes[d_idx][s_idx]

        res = {}
        if self._metadata_labels:
            res = dict(zip(self._metadata_labels, self._metadata[d_idx][s_idx]))

        if self._augment:
            # Augment with a random partner drawn from the same sub-dataset
            partner = pool[randint(0, len(pool) - 1)]  # noqa: S311
            sample = self._augmentor.augment(sample, partner)
        elif isinstance(sample, np.ndarray):
            # Numpy samples hold several parts; one-dimensional parts pass through untouched
            # NOTE(review): apply_h_cut presumably crops to seq_len around the center - confirm
            parts = []
            for part in sample:
                if len(part.shape) == 1:
                    parts.append(part)
                else:
                    cut = self._augmentor.apply_h_cut(s1=part, seq_len=self._seq_len, center=True)
                    parts.append(cut[0])
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", FutureWarning)
                sample = np.asarray(parts, dtype=object)
        else:
            sample = self._augmentor.apply_h_cut(s1=sample, seq_len=self._seq_len, center=True)[0]

        # Refill the queue once the sub-dataset has been fully served
        if not queue:
            self.shuffle(d_idx)

        # The one-hot label selects the row of the matching class representative
        rep = self._rep_vec[label == 1][0]

        # Move on to the next sub-dataset, wrapping around at the end
        following = d_idx + 1
        self._d_idx = following if following < len(self._datasets) else 0

        res["input"] = sample
        res["representative"] = rep
        res["target"] = label
        return res


def split_by_key(
    data: torch.FloatTensor,
    metadata: NDArray[Any] | None,
    keys: NDArray[Any],
    target: NDArray[np.str_],
    classes_mapping: dict[str, int],
) -> tuple[list[torch.FloatTensor], list[torch.FloatTensor], list[np.ndarray]]:
    """
    Split the data into one group per unique key (row).

    Parameters
    ----------
    data : torch.FloatTensor
        Samples to split; indexed with a boolean mask along the first axis
    metadata : NDArray[Any] | None
        Metadata rows aligned with `data`, or None if there is no metadata
    keys : NDArray[Any]
        Grouping keys aligned with `data`. A two-dimensional array is grouped
        by unique rows, a one-dimensional array by unique values. Keys are
        compared by their string representation.
    target : NDArray[np.str_]
        Class name of every sample, aligned with `data`
    classes_mapping : dict[str, int]
        Class name -> position of its one-hot entry

    Returns
    -------
    tuple[list[torch.FloatTensor], list[torch.FloatTensor], list[np.ndarray]]
        Per group (ordered by sorted string key): the samples, the one-hot
        targets, and the metadata rows (`[None]` when `metadata` is None)
    """
    # Cast once and use the SAME string view for both the unique keys and the
    # mask. Comparing the raw keys against the stringified unique keys never
    # matches for non-string columns (e.g. integer years), which used to
    # produce empty groups.
    keys_str = keys.astype(np.str_)
    grouped_by_row = len(keys_str.shape) > 1
    n_classes = len(classes_mapping)

    datasets, targets, metadatas = [], [], []
    for k in np.unique(keys_str, axis=0):
        mask = (keys_str == k).all(axis=1) if grouped_by_row else keys_str == k
        datasets.append(data[mask])
        one_hot_rows = [
            [float(i == classes_mapping[x]) for i in range(n_classes)] for x in target[mask]
        ]
        targets.append(torch.FloatTensor(one_hot_rows))
        if metadata is not None:
            metadatas.append(metadata[mask])
        else:
            # Placeholder so the three returned lists stay aligned
            metadatas.append([None])
    return datasets, targets, metadatas


def classification_collate_fn(
    batch: list[dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]],
) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
    """
    Merge a list of samples into one batch dictionary.

    Inputs end up under "inputs" and targets under "targets". Every other key
    of the first sample (e.g. metadata) is collected into a plain list.
    """
    first = batch[0]
    inputs = [item["input"] for item in batch]

    if isinstance(first["input"], np.ndarray):
        # Numpy inputs are the product of the concatenator: keep them as numpy
        samples = np.asarray(inputs)
    else:
        samples = torch.vstack([tensor.unsqueeze(0) for tensor in inputs])
    targets = torch.vstack([item["target"] for item in batch])

    # Extra keys first, so they precede "inputs" and "targets" in the result
    res = {}
    for key in first:
        if key in ("input", "target"):
            continue
        res[key] = [item[key] for item in batch]
    res["inputs"] = samples
    res["targets"] = targets
    return res


def similarity_collate_fn(
    batch: list[dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]],
) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
    """
    Collate a similarity batch.

    Extends the classification batch with the stacked representatives of all
    samples, stored under "representatives".
    """
    collated = classification_collate_fn(batch)
    reps = [item["representative"] for item in batch]
    collated["representatives"] = torch.stack(reps)
    return collated


def get_dataset(
    train_type: str,
    data: pd.DataFrame,
    process_f: Callable[[pd.DataFrame], torch.FloatTensor],
    classes: NDArray[np.str_],
    representatives: dict[str, torch.FloatTensor] | None = None,
    balance: bool = False,
    seq_len: int = 24,
    balance_by: list[str] = ["target_id"],
    augment: bool = False,
) -> BaseDataset:
    """
    Create the dataset that matches the requested training type.

    Parameters
    ----------
    train_type : str
        Either "classification" or "similarity"
    data : pd.DataFrame
        Dataframe to serve
    process_f : Callable[[pd.DataFrame], torch.FloatTensor]
        Function that turns the dataframe into the model's input samples
    classes : NDArray[np.str_]
        Array of class names used for training
    representatives : dict[str, torch.FloatTensor], optional
        Class representatives; required when `train_type` is "similarity"
    balance : bool, optional
        If True, balance the dataset over the `balance_by` columns, by default False
    seq_len : int, optional
        Sequence length of the time series, by default 24
    balance_by : list[str], optional
        Columns over which to balance, by default ["target_id"]
    augment : bool, optional
        If True, apply data augmentation, by default False

    Returns
    -------
    BaseDataset
        A ClassificationDataset or a SimilarityDataset

    Raises
    ------
    ValueError
        If `train_type` is not supported, or if it is "similarity" and no
        representatives are given
    """
    if train_type == "classification":
        return ClassificationDataset(
            data=data,
            process_f=process_f,
            classes=classes,
            balance=balance,
            balance_by=balance_by,
            augment=augment,
            seq_len=seq_len,
        )
    if train_type == "similarity":
        # Explicit check instead of an assert, which is stripped under `python -O`
        if representatives is None:
            raise ValueError("Representatives are required for training type 'similarity'!")
        return SimilarityDataset(
            data=data,
            process_f=process_f,
            classes=classes,
            representatives=representatives,
            balance=balance,
            balance_by=balance_by,
            augment=augment,
            seq_len=seq_len,
        )
    # ValueError is a subclass of Exception, so existing handlers keep working
    raise ValueError(f"Training type '{train_type}' not supported!")


if __name__ == "__main__":
    # Manual smoke test: build both dataset flavours and inspect one batch of each
    from vito_crop_classification.data_io import load_datasets

    def my_process_f(df: pd.DataFrame) -> torch.FloatTensor:
        """Stack the `ts_ndvi` column of the dataframe into one float tensor."""
        series = df["ts_ndvi"].to_list()
        return torch.FloatTensor(np.vstack(series))

    my_result = load_datasets(dataset="24ts-DEFAULT-nofilters", dataset_ratio=0.01)
    my_df_train = my_result.get("df_train")
    unique_targets = set(my_df_train["target_id"])
    my_classes = np.asarray(sorted(unique_targets))

    print("\nCreating classification dataset and data loader:")
    my_ds = ClassificationDataset(
        data=my_df_train,
        process_f=my_process_f,
        classes=my_classes,
        balance=True,
    )
    print(f" - Generated a dataset of size {len(my_ds)}")
    my_loader = my_ds.get_dataloader(batch_size=512)

    # Show the contents of the first batch
    batch = next(iter(my_loader))
    print(f" - First sample of shape {batch['inputs'].shape}")
    print(f" - First targets of shape {batch['targets'].shape}")
    print(" - Class distribution:")
    n_columns = batch["targets"].size(1)
    for i in range(n_columns):
        print(f"   - {i}: {int(sum(batch['targets'][:, i]))}")

    print("\nCreating similarity dataset and data loader:")
    # One random 64-dimensional representative per class
    my_representatives = {}
    for c in my_classes:
        my_representatives[c] = torch.rand(64)
    my_ds = SimilarityDataset(
        data=my_df_train,
        process_f=my_process_f,
        classes=my_classes,
        representatives=my_representatives,
        balance=True,
    )
    print(f" - Generated a dataset of size {len(my_ds)}")
    my_loader = my_ds.get_dataloader(batch_size=512)

    # Show the contents of the first batch
    batch = next(iter(my_loader))
    print(f" - First sample of shape {batch['inputs'].shape}")
    print(f" - First representative of shape {batch['representatives'].shape}")
