#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/aligneddataset.py ( Python )
#
# COMMENTS:     Defines classes for custom datasets and dataloaders
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.

import os
import torchvision.transforms as transforms
import torch
import numpy as np

from torch.utils.data import DataLoader
from PIL import Image

from pdgml.pix2pix.options import Options
from pdgml.pix2pix.utils import make_dataset

class NPYDataset:
    """A dataset class for paired NPY dataset.

    Assumes that the directory '/path/to/data/train' contains NPY pairs in the form of {A,B}:
    each array holds A in its left half (first ``W // 2`` columns) and B in the
    remaining right-hand columns.
    During test time, you need to prepare a directory '/path/to/data/test'.

    Important:
    ----------
    Assumes that all NPYs are of the same dimensions and are pre normalized to [0,1].
    Arrays may be laid out as (H, W, C), or as (H, W) for single-channel data.
    No additional preprocessing is applied.
    """

    def __init__(self, opt):
        """NPYDataset constructor: creates dataset of paths to NPY files

        Parameters:
           opt (Options) -- contains experiment config; reads ``dataroot``,
                            ``phase``, ``max_dataset_size`` and ``channels``
        """

        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))
        # Stored for parity with EXRDataset; the NPY loader below does not use it.
        self.channels = opt.channels

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Returns:
            dict with
              "A"        -- float32 tensor (C, H, W // 2), values scaled to [-1, 1]
              "B"        -- float32 tensor (C, H, W - W // 2), values scaled to [-1, 1]
              "AB_paths" -- path of the source .npy file
        """
        AB_path = self.AB_paths[index]
        # Values are expected on the range 0 to 1 (see class docstring).
        AB = np.load(AB_path).astype(float)

        # Promote single-channel (H, W) arrays to (H, W, 1) so the HWC -> CHW
        # transpose below also works for grayscale data.
        if AB.ndim == 2:
            AB = AB[:, :, np.newaxis]

        # Scale [0, 1] -> [-1, 1].
        AB = (AB - 0.5) / 0.5

        w_half = AB.shape[1] // 2
        A = AB[:, :w_half]
        B = AB[:, w_half:]

        # Here A is a numpy array representing an image, e.g. [256, 256, 3]
        # for a 256 rgb image. Torch expects channel-first data, so reorder
        # to e.g. [3, 256, 256].
        A = torch.from_numpy(A.transpose((2, 0, 1))).type(torch.FloatTensor)
        B = torch.from_numpy(B.transpose((2, 0, 1))).type(torch.FloatTensor)

        return {"A": A, "B": B, "AB_paths": AB_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.AB_paths)


class EXRDataset:
    """Paired-image dataset backed by side-by-side EXR files.

    Every file found under ``<dataroot>/<phase>`` (e.g. '/path/to/data/train'
    or '/path/to/data/test') stores one {A,B} pair: the left half of the image
    is A and the right half is B.

    Important:
    ----------
    The EXRs are expected to share the same dimensions and to already be
    normalized to [0,1]. No additional preprocessing is applied.
    """

    def __init__(self, opt):
        """Gather the sorted list of EXR paths for the configured phase.

        Parameters:
           opt (Options) -- experiment config; reads ``dataroot``, ``phase``,
                            ``max_dataset_size`` and ``channels``
        """
        root = os.path.join(opt.dataroot, opt.phase)
        self.dir_AB = root
        self.AB_paths = sorted(make_dataset(root, opt.max_dataset_size))
        self.channels = opt.channels

    def __getitem__(self, index):
        """Load pair number ``index``.

        Returns a dict with float32 tensors "A" (left half) and "B" (right
        half), rescaled from [0, 1] to [-1, 1] and reordered channel-first,
        plus "AB_paths", the path of the source file.
        """
        path = self.AB_paths[index]
        # NOTE(review): ``exr_to_np`` is not imported or defined anywhere in
        # this module, so this call raises NameError unless the name is
        # provided elsewhere -- confirm which module the helper should come from.
        # Assumes it returns an (H, W, C) array; the transpose below relies on that.
        pixels = exr_to_np(path, channels=self.channels).astype(float)
        pixels = (pixels - 0.5) / 0.5

        split_col = pixels.shape[1] // 2
        left, right = pixels[:, :split_col], pixels[:, split_col:]

        # (H, W, C) -> (C, H, W), the layout torch expects.
        A, B = (
            torch.from_numpy(half_img.transpose((2, 0, 1))).type(torch.FloatTensor)
            for half_img in (left, right)
        )
        return {"A": A, "B": B, "AB_paths": path}

    def __len__(self):
        """Number of EXR pairs available."""
        return len(self.AB_paths)


class AlignedDataset:
    """A dataset class for paired image dataset.

    Assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}:
    each image holds A in its left half and B in its right half.
    During test time, you need to prepare a directory '/path/to/data/test'.

    Assumes that all images are of the same dimensions. Each half is resized
    (bicubic) to (crop_size, crop_size), rounded up to a multiple of 4 if
    needed, converted to a tensor and normalized with mean = std = 0.5 per
    channel. No cropping is applied.
    """

    def __init__(self, options):
        """AlignedDataset constructor: creates dataset of image paths and initializes preprocessing

        Parameters:
           options (Options) -- contains experiment config; reads ``dataroot``,
                                ``phase``, ``max_dataset_size``, ``channels``
                                and ``crop_size``
        """

        self.dir_AB = os.path.join(options.dataroot, options.phase)
        self.AB_paths = sorted(make_dataset(self.dir_AB, options.max_dataset_size))

        mean = (0.5,) * options.channels
        std = (0.5,) * options.channels

        transform_list = [
            transforms.Resize((options.crop_size, options.crop_size), Image.BICUBIC),
            # Pass the module-level helper itself instead of wrapping it in a
            # lambda: lambdas cannot be pickled, which breaks DataLoader worker
            # processes started with the "spawn" method (num_workers > 0 on
            # Windows/macOS). The helper is defined at module level under its
            # name-mangled spelling.
            transforms.Lambda(_AlignedDataset__adjust),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        self.transform_img = transforms.Compose(transform_list)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Returns:
            dict with "A" (left half) and "B" (right half) as transformed
            tensors, and "AB_paths", the path of the source image.
        """

        AB_path = self.AB_paths[index]  # Get image path at index
        # Open inside a context manager so the file handle is released
        # promptly. crop() loads the pixel data and returns independent
        # images, so A and B remain valid after the file is closed.
        with Image.open(AB_path) as AB:
            w, h = AB.size
            w2 = w // 2
            A = AB.crop((0, 0, w2, h))  # Split left half as A
            B = AB.crop((w2, 0, w, h))  # Split right half as B

        # Apply the standard transformation pipeline to both halves.
        A = self.transform_img(A)
        B = self.transform_img(B)

        # Note: the DataLoader's default collation batches this dict key-wise,
        # so each tensor value gains a leading batch dimension (e.g. 'A' ->
        # [batch_size, ...A.shape]) and 'AB_paths' becomes a list of strings
        # of length batch_size.
        return {"A": A, "B": B, "AB_paths": AB_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.AB_paths)


class CustomDatasetDataLoader:
    """Pairs a dataset chosen from the options with a torch DataLoader.

    Iterating over an instance yields batches until roughly
    ``options.max_dataset_size`` samples have been produced.
    """

    def __init__(self, options):
        """Pick the dataset implementation and wrap it in a DataLoader.

        The dataset is EXRDataset when ``options.use_exr is True``, otherwise
        NPYDataset when ``options.use_npy`` is truthy, otherwise AlignedDataset.
        Batches are produced in file order (no shuffling), using
        ``options.num_workers`` worker processes.
        """
        self.options = options
        if options.use_exr is True:
            dataset_cls = EXRDataset
        elif options.use_npy:
            dataset_cls = NPYDataset
        else:
            dataset_cls = AlignedDataset
        self.dataset = dataset_cls(options)
        print("dataset [%s] was created" % type(self.dataset).__name__)

        loader_kwargs = dict(
            batch_size=options.batch_size,
            shuffle=False,
            num_workers=options.num_workers,
        )
        self.dataloader = DataLoader(self.dataset, **loader_kwargs)

    def load_data(self):
        """Return this wrapper itself, which is already iterable."""
        return self

    def __len__(self):
        """Number of samples, capped at ``options.max_dataset_size``."""
        dataset_len = len(self.dataset)
        return min(dataset_len, self.options.max_dataset_size)

    def __iter__(self):
        """Yield batches, stopping once max_dataset_size samples were reached."""
        for batch_idx, batch in enumerate(self.dataloader):
            consumed = batch_idx * self.options.batch_size
            if consumed >= self.options.max_dataset_size:
                return
            yield batch


def create_dataset(options):
    """Build the iterable dataset described by ``options``.

    Thin wrapper around CustomDatasetDataLoader. It is the main interface
    between this package and the 'train.py'/'test.py' scripts.

    Parameters:
        options (Options) -- experiment config, passed straight through

    Returns:
        the CustomDatasetDataLoader produced for ``options``, which yields
        batches when iterated

    Example:
        >>> dataset = create_dataset(opt)
        >>> for batch in dataset:
        ...     pass
    """
    return CustomDatasetDataLoader(options).load_data()


def _AlignedDataset__adjust(img):
    """Round an image's width and height up to the next multiple of 4.

    The generator network can change spatial size for inputs that are not
    multiples of 4, which later causes size-mismatch errors, so such images
    are resized (bicubic) up to the nearest multiple.

    This module-level name is what ``__adjust`` resolves to inside the
    AlignedDataset class body, due to Python name mangling.

    Parameters:
        img (PIL image) -- input image

    Returns:
        ``img`` itself when both sides are already multiples of 4, otherwise
        a resized copy whose sides are.
    """
    mult = 4
    ow, oh = img.size
    # Ceiling division: round each side up to a multiple of ``mult``.
    w = -(-ow // mult) * mult
    h = -(-oh // mult) * mult

    if (w, h) == (ow, oh):
        return img

    __print_size_warning(ow, oh, w, h)
    return img.resize((w, h), Image.BICUBIC)


def __print_size_warning(ow, oh, w, h):
    """Print the multiple-of-4 resize notice, only on the first call.

    Whether it has already been printed is tracked with a ``has_printed``
    attribute set on this function object.

    Parameters:
        ow, oh (int) -- original width and height
        w, h (int)   -- adjusted width and height
    """
    if hasattr(__print_size_warning, "has_printed"):
        return

    message = (
        "The image size needs to be a multiple of 4. "
        "The loaded image size was (%d, %d), so it was adjusted to "
        "(%d, %d). This adjustment will be done to all images "
        "whose sizes are not multiples of 4" % (ow, oh, w, h)
    )
    print(message)
    __print_size_warning.has_printed = True
