"""Dataset class for crop classification."""

from __future__ import annotations

import random
from random import shuffle as shuffle_idx
from typing import Any, Callable

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from torch.utils.data import DataLoader, Dataset

from vito_crop_classification.training.augmentor import Augmentor


class BaseDataset(Dataset):  # type: ignore[type-arg]
    """
    Base crop classification dataset.

    The processed samples are held in one or more sub-datasets: a single one by default, or one
    per class (built by ``split_by_target``) when ``balance=True``. Every sub-dataset has its own
    shuffled queue of sample indices in ``self._indices``, and ``self._d_idx`` points at the
    sub-dataset to draw from next. Subclasses implement ``__getitem__`` on top of this state.

    NOTE(review): since the sampling state lives on the instance, a DataLoader with
    ``num_workers > 0`` gives each worker its own copy of it, which presumably duplicates
    samples across workers. Verify before enabling worker processes.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        process_f: Callable[[pd.DataFrame], torch.FloatTensor],
        classes: NDArray[np.str_],
        collate_fn: Callable[..., Any],
        balance: bool = False,
        augment: bool = False,
    ) -> None:
        """
        Pytorch dataset class used to handle data during training.

        Parameters
        ----------
        data : pd.DataFrame
            Dataframe handled, must contain a ``target_id`` column with the class of every row
        process_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Encoder processing function to transform data into the correct input to the model
        classes : NDArray[np.str_]
            One-dimensional array of classes to use for training; the position of a class in this
            array is its column in the one-hot targets
        collate_fn : Callable[..., Any]
            Collate function for the DataLoader
        balance : bool, optional
            If True, balance classes by upsampling, by default False
        augment : bool, optional
            If True, apply data augmentation, by default False

        Raises
        ------
        ValueError
            If ``classes`` is not one-dimensional or if any (per-class) sub-dataset is empty
        KeyError
            If ``balance`` is False and ``data`` contains a class that is not in ``classes``
        """
        super().__init__()
        # Explicit check rather than ``assert`` so it is not stripped when running with -O
        if classes.ndim != 1:
            raise ValueError(f"`classes` must be one-dimensional, got shape {classes.shape}")
        self._augment = augment
        self._classes_str = classes
        self._classes_mapping = {label: i for i, label in enumerate(classes)}
        data_processed = process_f(data)
        self._collate_fn = collate_fn
        self._augmentor = Augmentor()

        if balance:
            # One sub-dataset per class; subclasses cycle over them, upsampling small classes
            self._datasets, self._classes = split_by_target(
                data=data_processed,
                target=np.stack(data["target_id"].to_list()),
                classes_mapping=self._classes_mapping,
            )
        else:
            self._datasets = [data_processed]
            self._classes = [self._one_hot(data["target_id"], self._classes_mapping)]

        # Initialise every sub-dataset with a freshly shuffled index queue
        if not all(len(d) > 0 for d in self._datasets):
            raise ValueError("Empty datasets not allowed!")
        self._indices: list[list[int]] = [[] for _ in self._datasets]
        for idx in range(len(self._datasets)):
            self.shuffle(idx)
        self._d_idx = 0

    @staticmethod
    def _one_hot(
        targets: pd.Series | list[str], classes_mapping: dict[str, int]
    ) -> torch.FloatTensor:
        """
        Encode class labels as a float32 one-hot matrix.

        Parameters
        ----------
        targets : pd.Series | list[str]
            Class label of every sample
        classes_mapping : dict[str, int]
            Class label -> column index

        Returns
        -------
        torch.FloatTensor
            Matrix of shape ``(len(targets), len(classes_mapping))``

        Raises
        ------
        KeyError
            If a label is not present in ``classes_mapping``
        """
        indices = torch.tensor([classes_mapping[t] for t in targets], dtype=torch.long)
        # Row i of the identity matrix is the one-hot encoding of class i
        return torch.eye(len(classes_mapping))[indices]

    def __len__(self) -> int:
        """Get length of dataset, i.e. the total number of samples over all sub-datasets."""
        return sum(len(classes) for classes in self._classes)

    def shuffle(self, idx: int) -> None:
        """Refill the index queue of sub-dataset ``idx`` with a new random permutation."""
        self._indices[idx] = list(range(len(self._datasets[idx])))
        shuffle_idx(self._indices[idx])

    def __getitem__(self, _: int) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
        """Return datapoint and label at location index in the dataset (implemented by subclasses)."""
        raise NotImplementedError

    def get_dataloader(self, batch_size: int = 512) -> DataLoader:
        """
        Get the dataset's data loader.

        The dataset shuffles itself, so the loader keeps the DataLoader defaults (no shuffling,
        single-process loading) and only sets the batch size and this dataset's collate function.
        """
        return DataLoader(
            dataset=self,
            batch_size=batch_size,
            collate_fn=self._collate_fn,
        )


class ClassificationDataset(BaseDataset):
    """
    Classification based dataset.

    Every item is a ``{"input", "target"}`` dictionary, where ``target`` is the one-hot class row;
    batches are assembled by ``classification_collate_fn``.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        process_f: Callable[[pd.DataFrame], torch.FloatTensor],
        classes: NDArray[np.str_],
        balance: bool = False,
        augment: bool = False,
    ) -> None:
        """
        Pytorch dataset class used to handle data during training.

        Parameters
        ----------
        data : pd.DataFrame
            Dataframe handled
        process_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Encoder processing function to transform data into the correct input to the model
        classes : NDArray[np.str_]
            List of classes to use for training
        balance : bool, optional
            If True, balance classes by upsampling, by default False
        augment : bool, optional
            If True, apply data augmentation, by default False
        """
        super().__init__(
            data=data,
            process_f=process_f,
            classes=classes,
            balance=balance,
            augment=augment,
            collate_fn=classification_collate_fn,
        )

    def __getitem__(self, _: int) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
        """
        Return the next datapoint and its label.

        The index given by the DataLoader is ignored. The sample is taken from the shuffled index
        queue of the current sub-dataset, which is reshuffled once exhausted. The next call then
        moves on to the following sub-dataset (round-robin).
        """
        d_idx = self._d_idx
        dataset = self._datasets[d_idx]

        # Pop from the end of the (already randomly permuted) queue: O(1), whereas ``pop(0)``
        # shifts the whole list and makes an epoch quadratic in the dataset size
        s_idx = self._indices[d_idx].pop()
        sample, label = dataset[s_idx], self._classes[d_idx][s_idx]

        # Data augmentation with a second sample drawn uniformly from the same sub-dataset
        if self._augment:
            partner = dataset[random.randint(0, len(dataset) - 1)]
            sample = self._augmentor.augment(sample, partner)

        # Reshuffle once we went through the complete sub-dataset
        if not self._indices[d_idx]:
            self.shuffle(d_idx)

        # Advance to the next sub-dataset, wrapping around
        self._d_idx = (d_idx + 1) % len(self._datasets)

        return {
            "input": sample,
            "target": label,
        }


class SimilarityDataset(BaseDataset):
    """
    Similarity based classification dataset.

    Every item is an ``{"input", "representative", "target"}`` dictionary. ``representative`` is
    the representative tensor of the sample's class. Batches are assembled by
    ``similarity_collate_fn``.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        process_f: Callable[[pd.DataFrame], torch.FloatTensor],
        classes: NDArray[np.str_],
        representatives: dict[str, torch.FloatTensor],
        balance: bool = False,
        augment: bool = False,
    ) -> None:
        """
        Pytorch dataset class used to handle data during training.

        Parameters
        ----------
        data : pd.DataFrame
            Dataframe handled
        process_f : Callable[[pd.DataFrame], torch.FloatTensor]
            Encoder processing function to transform data into the correct input to the model
        classes : NDArray[np.str_]
            List of classes to use for training
        representatives : dict[str, torch.FloatTensor]
            Class representatives, provided as a class -> Tensor dictionary
        balance : bool, optional
            If True, balance classes by upsampling, by default False
        augment : bool, optional
            If True, apply data augmentation, by default False
        """
        super().__init__(
            data=data,
            process_f=process_f,
            classes=classes,
            balance=balance,
            augment=augment,
            collate_fn=similarity_collate_fn,
        )
        self.set_representatives(representatives)

    def set_representatives(self, representatives: dict[str, torch.FloatTensor]) -> None:
        """
        Set new representatives.

        The tensors are stacked in the order of ``classes``, so row ``i`` of ``self._rep_vec``
        belongs to the class whose one-hot target has a 1 in column ``i``. A class missing from
        ``representatives`` raises a KeyError.
        """
        self._rep_vec = torch.stack([representatives[cls] for cls in self._classes_str])

    def __getitem__(self, _: int) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
        """
        Return the next datapoint, its class representative and its label.

        The index given by the DataLoader is ignored. The sample is taken from the shuffled index
        queue of the current sub-dataset, which is reshuffled once exhausted. The next call then
        moves on to the following sub-dataset (round-robin).
        """
        d_idx = self._d_idx
        dataset = self._datasets[d_idx]

        # Pop from the end of the (already randomly permuted) queue: O(1), whereas ``pop(0)``
        # shifts the whole list and makes an epoch quadratic in the dataset size
        s_idx = self._indices[d_idx].pop()
        sample, label = dataset[s_idx], self._classes[d_idx][s_idx]

        # Data augmentation with a second sample drawn uniformly from the same sub-dataset
        if self._augment:
            partner = dataset[random.randint(0, len(dataset) - 1)]
            sample = self._augmentor.augment(sample, partner)

        # Reshuffle once we went through the complete sub-dataset
        if not self._indices[d_idx]:
            self.shuffle(d_idx)

        # Select the representative row matching the one-hot label
        rep = self._rep_vec[label == 1][0]

        # Advance to the next sub-dataset, wrapping around
        self._d_idx = (d_idx + 1) % len(self._datasets)

        return {
            "input": sample,
            "representative": rep,
            "target": label,
        }


def split_by_target(
    data: torch.FloatTensor,
    target: NDArray[np.str_],
    classes_mapping: dict[str, int],
) -> tuple[list[torch.FloatTensor], list[torch.FloatTensor]]:
    """
    Split the data by the targets, one sub-dataset per class.

    Parameters
    ----------
    data : torch.FloatTensor
        Processed samples, indexed along the first axis
    target : NDArray[np.str_]
        Class label of every sample, aligned with the first axis of ``data``
    classes_mapping : dict[str, int]
        Class label -> column index in the one-hot encoding

    Returns
    -------
    tuple[list[torch.FloatTensor], list[torch.FloatTensor]]
        Per-class sub-datasets and their matching float32 one-hot targets of shape
        ``(n_samples_of_class, len(classes_mapping))``, both ordered by sorted class label.
        Samples whose label is not a key of ``classes_mapping`` are dropped.
    """
    n_classes = len(classes_mapping)
    datasets, targets = [], []
    for label in sorted(classes_mapping):
        mask = target == label  # computed once, reused for the data and the targets
        datasets.append(data[mask])
        # Every row of this split has the same class, so its one-hot matrix is one repeated row;
        # building it directly also keeps the (0, n_classes) shape when the class is absent
        one_hot = torch.zeros(int(mask.sum()), n_classes)
        one_hot[:, classes_mapping[label]] = 1.0
        targets.append(one_hot)
    return datasets, targets


def classification_collate_fn(
    batch: list[dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]],
) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
    """
    Collate function compatible with multi-dimensional input.

    Inputs are stacked along a new leading batch axis. Numpy inputs (produced by the
    concatenator) are kept as a numpy array; tensor inputs become a single tensor. The targets
    are stacked row-wise into one tensor.
    """
    inputs = [item["input"] for item in batch]
    labels = [item["target"] for item in batch]

    if isinstance(inputs[0], np.ndarray):
        stacked_inputs = np.asarray(inputs)
    else:
        stacked_inputs = torch.vstack([x.unsqueeze(0) for x in inputs])

    return {"inputs": stacked_inputs, "targets": torch.vstack(labels)}


def similarity_collate_fn(
    batch: list[dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]],
) -> dict[str, torch.Tensor | NDArray[list[torch.Tensor]]]:
    """
    Collate function compatible with multi-dimensional input.

    Same as ``classification_collate_fn``, with the per-sample ``representative`` tensors
    additionally stacked under the ``representatives`` key.
    """
    return {
        **classification_collate_fn(batch),
        "representatives": torch.stack([entry["representative"] for entry in batch]),
    }


def get_dataset(
    train_type: str,
    data: pd.DataFrame,
    process_f: Callable[[pd.DataFrame], torch.FloatTensor],
    classes: NDArray[np.str_],
    representatives: dict[str, torch.FloatTensor] | None = None,
    balance: bool = False,
    augment: bool = False,
) -> BaseDataset:
    """
    Get the right dataset for the requested training type.

    Parameters
    ----------
    train_type : str
        Either ``"classification"`` or ``"similarity"``
    data : pd.DataFrame
        Dataframe handled
    process_f : Callable[[pd.DataFrame], torch.FloatTensor]
        Encoder processing function to transform data into the correct input to the model
    classes : NDArray[np.str_]
        List of classes to use for training
    representatives : dict[str, torch.FloatTensor] | None, optional
        Class representatives, required when ``train_type == "similarity"``
    balance : bool, optional
        If True, balance classes by upsampling, by default False
    augment : bool, optional
        If True, apply data augmentation, by default False

    Raises
    ------
    ValueError
        If ``train_type`` is not supported, or if it is ``"similarity"`` and no representatives
        are provided
    """
    if train_type == "classification":
        return ClassificationDataset(
            data=data,
            process_f=process_f,
            classes=classes,
            balance=balance,
            augment=augment,
        )
    if train_type == "similarity":
        # Explicit check rather than ``assert`` so it is not stripped when running with -O
        if representatives is None:
            raise ValueError("Representatives are required for the 'similarity' training type!")
        return SimilarityDataset(
            data=data,
            process_f=process_f,
            classes=classes,
            representatives=representatives,
            balance=balance,
            augment=augment,
        )
    # ValueError subclasses Exception, so existing ``except Exception`` handlers still work
    raise ValueError(f"Training type '{train_type}' not supported!")


if __name__ == "__main__":
    from vito_crop_classification.data import load_datasets

    # Use a 1% slice of the data so the demo stays quick
    train_df = load_datasets(dataset_ratio=0.01).get("df_train")
    demo_classes = np.asarray(sorted(set(train_df["target_id"])))

    def my_process_f(df: pd.DataFrame) -> torch.FloatTensor:
        """Example processing function: stack the ``ts_ndvi`` column into a float tensor."""
        return torch.FloatTensor(np.vstack(df["ts_ndvi"].to_list()))

    # --- Classification dataset: show shapes and the per-class counts of one batch ---
    print("\nCreating classification dataset and data loader:")
    cls_ds = ClassificationDataset(
        data=train_df,
        process_f=my_process_f,
        balance=True,
        classes=demo_classes,
    )
    print(f" - Generated a dataset of size {len(cls_ds)}")
    first = next(iter(cls_ds.get_dataloader(batch_size=512)))
    print(f" - First sample of shape {first['inputs'].shape}")
    print(f" - First targets of shape {first['targets'].shape}")
    print(" - Class distribution:")
    for col, column_values in enumerate(first["targets"].T):
        print(f"   - {col}: {int(sum(column_values))}")

    # --- Similarity dataset: random representatives, show shapes of one batch ---
    print("\nCreating similarity dataset and data loader:")
    sim_ds = SimilarityDataset(
        data=train_df,
        process_f=my_process_f,
        balance=True,
        classes=demo_classes,
        representatives={c: torch.rand(64) for c in demo_classes},
    )
    print(f" - Generated a dataset of size {len(sim_ds)}")
    first = next(iter(sim_ds.get_dataloader(batch_size=512)))
    print(f" - First sample of shape {first['inputs'].shape}")
    print(f" - First representative of shape {first['representatives'].shape}")
