"""Sync data to and from AWS's S3."""

from __future__ import annotations

import os

import boto3
from botocore.exceptions import (
    ClientError,
    NoCredentialsError,
    UnauthorizedSSOTokenError,
)
from tqdm import tqdm

from vito_cropsar.constants import SPLITS, get_data_folder


def push_to_s3(
    cache_tag: str | None = None,
    overwrite: bool = False,
    splits: list[str] | None = None,
) -> None:
    """
    Push data to S3.

    Parameters
    ----------
    cache_tag : str
        Cache postfix tag, raw datasets if None
    overwrite : bool
        Whether to re-upload files that already exist on the remote
    splits : list[str]
        The splits to push, by default all
        Options: training, validation, testing

    Raises
    ------
    ValueError
        If one or more of the requested splits is not a known split
    """
    splits = splits or SPLITS

    # Validate with an explicit raise: an `assert` is stripped under `python -O`,
    # which would let invalid split names through silently.
    invalid = [x for x in splits if x not in SPLITS]
    if invalid:
        raise ValueError(f"Invalid split(s) provided: {invalid}")

    # Load the bucket once and share it across all splits
    bucket = _get_bucket()
    for split in splits:
        upload_directory(
            dataset=split if cache_tag is None else f"{split}_{cache_tag}",
            bucket=bucket,
            overwrite=overwrite,
        )


def pull_from_s3(
    cache_tag: str | None = None,
    overwrite: bool = False,
    splits: list[str] | None = None,
) -> None:
    """
    Pull data from S3.

    Parameters
    ----------
    cache_tag : str
        Cache postfix tag, raw datasets if None
    overwrite : bool
        Whether to re-download files that already exist locally
    splits : list[str]
        The splits to pull, by default all
        Options: training, validation, testing

    Raises
    ------
    ValueError
        If one or more of the requested splits is not a known split
    """
    splits = splits or SPLITS

    # Validate with an explicit raise: an `assert` is stripped under `python -O`,
    # which would let invalid split names through silently.
    invalid = [x for x in splits if x not in SPLITS]
    if invalid:
        raise ValueError(f"Invalid split(s) provided: {invalid}")

    # Load the bucket once and share it across all splits
    bucket = _get_bucket()
    for split in splits:
        download_directory(
            dataset=split if cache_tag is None else f"{split}_{cache_tag}",
            bucket=bucket,
            overwrite=overwrite,
        )


def upload_directory(
    dataset: str,
    bucket: boto3.resources.factory.s3.Bucket | None = None,
    overwrite: bool = False,
) -> None:
    """
    Mirror a local dataset folder onto S3.

    Remote files without a local counterpart are deleted, after which the
    local npz files are uploaded.

    Parameters
    ----------
    dataset : str
        Name of the dataset folder ("training", "training_fapar", ...) to upload
    bucket : boto3.resources.factory.s3.Bucket
        Bucket to upload to, loaded on demand if None
    overwrite : bool
        Whether to re-upload npz files that already exist on the remote
    """
    bucket = _get_bucket() if bucket is None else bucket

    # Snapshot of the file names on both sides
    remote_names = _get_remote_files(dataset, bucket=bucket)
    local_names = _get_local_files(dataset)

    # Drop remote objects that no longer exist locally
    stale = remote_names - local_names
    for name in tqdm(stale, desc=f"Removing out-of-sync files for {dataset}.."):
        bucket.Object(f"cropsar/{dataset}/{name}").delete()

    # Send everything when overwriting, otherwise only what the remote is missing
    if overwrite:
        pending = local_names
    else:
        pending = local_names - remote_names
    for name in tqdm(pending, desc=f"Uploading {dataset}.."):
        with open(get_data_folder() / dataset / name, "rb") as handle:
            bucket.upload_fileobj(handle, f"cropsar/{dataset}/{name}")


def download_directory(
    dataset: str,
    bucket: boto3.resources.factory.s3.Bucket | None = None,
    overwrite: bool = False,
) -> None:
    """
    Download a directory from S3 to local.

    Local npz files that do not exist on the remote are removed, after which
    the remote files are downloaded. Each file is first written to a temporary
    ".part" file and only renamed to its final name once the transfer has
    completed, so an interrupted download never leaves a truncated npz file
    that a later run would mistake for an already-synced one.

    Parameters
    ----------
    dataset : str
        The dataset folder ("training", "training_fapar", ...) to download
    bucket : boto3.resources.factory.s3.Bucket
        Bucket to download from, loaded on demand if None
    overwrite : bool
        Whether to overwrite existing npz files
    """
    # Load bucket if not provided
    if bucket is None:
        bucket = _get_bucket()

    # Ensure the folder exists
    folder = get_data_folder() / dataset
    folder.mkdir(parents=True, exist_ok=True)

    # Check the difference between local and remote files
    files_remote = _get_remote_files(dataset, bucket=bucket)
    files_local = _get_local_files(dataset)

    # Remove files that are not in the remote
    for file in tqdm(
        files_local - files_remote,
        desc=f"Removing out-of-sync files for {dataset}..",
    ):
        (folder / file).unlink(missing_ok=True)

    # Pull all files that are remote but not local (or everything on overwrite)
    for file in tqdm(
        files_remote if overwrite else (files_remote - files_local),
        desc=f"Downloading {dataset}..",
    ):
        target = folder / file
        # The ".part" suffix keeps the in-progress file out of the "*.npz" listing
        partial = folder / f"{file}.part"
        try:
            with open(partial, "wb") as f:
                bucket.download_fileobj(f"cropsar/{dataset}/{file}", f)
            # Atomic rename: the final name only ever points at a complete file
            partial.replace(target)
        finally:
            # No-op after a successful rename; cleans up after a failed transfer
            partial.unlink(missing_ok=True)


def _get_remote_files(
    dataset: str,
    bucket: boto3.resources.factory.s3.Bucket,
) -> set[str]:
    """
    Get the names of all the existing remote files of the requested dataset.

    Parameters
    ----------
    dataset : str
        The dataset folder ("training", "training_fapar", ...) to list
    bucket : boto3.resources.factory.s3.Bucket
        Bucket in which to look for the dataset

    Returns
    -------
    set[str]
        File names (last key segment) found under "cropsar/<dataset>/",
        empty if the listing fails with a ClientError
    """
    # The trailing slash is essential: without it the prefix of "training"
    # would also match the keys of "training_fapar", "training_<tag>", ...
    prefix = f"cropsar/{dataset}/"
    try:
        return {
            obj.key.split("/")[-1]
            for obj in bucket.objects.filter(Prefix=prefix)
            if not obj.key.endswith("/")  # Skip folder placeholder keys
        }
    except ClientError:
        # Best-effort: a failed listing is treated as an empty remote
        return set()


def _get_local_files(dataset: str) -> set[str]:
    """Collect the names of the npz files in the local folder of the dataset."""
    folder = get_data_folder() / dataset
    names = set()
    for path in folder.glob("*.npz"):
        names.add(path.name)
    return names


def _get_bucket() -> boto3.resources.factory.s3.Bucket:
    """
    Connect to S3 and return the "vito-data" bucket.

    A first attempt is made without an explicit session token. If that fails,
    a second attempt passes the AWS_SESSION_TOKEN environment variable.

    Raises
    ------
    Exception
        If the second attempt fails as well
    """
    try:
        # First attempt: no explicit session token
        resource = boto3.resource("s3", region_name="eu-west-1")
        list(resource.buckets.all())  # Listing the buckets checks the connection
        return resource.Bucket("vito-data")
    except (UnauthorizedSSOTokenError, KeyError, NoCredentialsError):
        try:
            # Second attempt: session token taken from the environment variable
            resource = boto3.resource(
                "s3",
                region_name="eu-west-1",
                aws_session_token=os.environ["AWS_SESSION_TOKEN"],
            )
            list(resource.buckets.all())  # Listing the buckets checks the connection
            return resource.Bucket("vito-data")
        except (UnauthorizedSSOTokenError, KeyError, NoCredentialsError) as err:
            raise Exception("Could not load S3 bucket.") from err


if __name__ == "__main__":
    # Manual entry point: refresh the local "testing" dataset from S3,
    # re-downloading files that already exist locally.
    # To push instead, call push_to_s3(overwrite=True).
    download_directory("testing", overwrite=True)
