#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/testmodel.py ( Python )
#
# COMMENTS:     Utility functions
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.

import ntpath
import numpy as np
import os
import torch

from PIL import Image

from pdgml.pix2pix.model import UnetGenerator

# File extensions accepted as dataset images. Matching is case-sensitive, so only
# the spellings listed here are recognised (e.g. ".EXR" and ".NPY" are not).
IMG_EXTENSIONS = [
    ".jpg", ".JPG",
    ".jpeg", ".JPEG",
    ".png", ".PNG",
    ".ppm", ".PPM",
    ".bmp", ".BMP",
    ".exr",
    ".npy",
]

def is_image_file(filename):
    """Report whether a file name carries one of the supported image extensions.

    Parameter
    ---------
        filename (string): file name or path to test

    Return
    ------
        return (bool): True when filename ends with an entry of IMG_EXTENSIONS
            (case-sensitive comparison)
    """
    # str.endswith accepts a tuple and succeeds if any suffix matches.
    accepted_suffixes = tuple(IMG_EXTENSIONS)
    return filename.endswith(accepted_suffixes)

def make_dataset(dataset_path, max_dataset_size=float("inf")):
    """Generates a dataset of valid image paths

    Searches dataset_path and all of its subdirectories for files accepted by
    is_image_file. Directories are visited in sorted order of their path and the
    files inside each directory are sorted by name, so the returned order is
    deterministic and does not depend on the order in which the filesystem
    happens to list directory entries.

    Parameters
    ----------
        dataset_path (string): path to dataset directory
        max_dataset_size (int or float): maximum number of images to return; the
            default (infinity) returns every image found

    Returns
    -------
        images (list<string>): list of paths to images of valid image formats

    Raises
    ------
        AssertionError: if dataset_path is not an existing directory
    """
    # Raised explicitly instead of via ``assert`` so the check survives
    # ``python -O``; the exception type is unchanged for existing callers.
    if not os.path.isdir(dataset_path):
        raise AssertionError("%s is not a valid directory" % dataset_path)

    images = []
    # sorted() orders the walk by directory path; os.walk gives no ordering
    # guarantee for the file names within a directory, so sort those as well.
    for root, _, fnames in sorted(os.walk(dataset_path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))

    if max_dataset_size < len(images):
        # int() lets a finite float limit (e.g. 10.0) be used as a slice bound.
        images = images[: int(max_dataset_size)]
    return images

def save_images(output_path, visuals, image_path, aspect_ratio=1.0):
    """Save images to the disk as PNG files

    Each entry of visuals is converted with tensor2im and written to
    ``<output_path>/<original name>_<label>.png``.

    Parameters
    ----------
        output_path (string): path to directory of output folder (created if missing)
        visuals (dict<string><tensor>): dictionary mapping image names (real_A, real_B, fake_B) to image tensors
        image_path (string or list<string>): path to the original image in the dataset;
            when a sequence of paths is given (e.g. a batch from a data loader), its
            first element is used
        aspect_ratio (float): > 1.0 stretches the width by this factor, < 1.0 stretches
            the height by 1 / aspect_ratio; 1.0 leaves the image size unchanged
    """
    os.makedirs(output_path, exist_ok=True)

    # Accept a single path as well as a batch of paths; indexing a plain string
    # would otherwise pick out its first character.
    if isinstance(image_path, (str, os.PathLike)):
        source_path = image_path
    else:
        source_path = image_path[0]
    # ntpath.basename splits on both "/" and "\\" (eg. 'dir/name.jpg' -> 'name.jpg').
    short_path = ntpath.basename(source_path)
    name = os.path.splitext(short_path)[0]  # eg. 'name'

    for label, im_data in visuals.items():
        image_numpy = tensor2im(im_data)  # (height, width, channels) array
        save_path = os.path.join(output_path, "%s_%s.png" % (name, label))
        image_pil = Image.fromarray(image_numpy)

        h, w = image_numpy.shape[:2]
        # PIL's resize takes (width, height).
        if aspect_ratio > 1.0:
            image_pil = image_pil.resize((int(w * aspect_ratio), h), Image.BICUBIC)  # increase the width
        elif aspect_ratio < 1.0:
            image_pil = image_pil.resize((w, int(h / aspect_ratio)), Image.BICUBIC)  # increase the height

        image_pil.save(save_path)


# def save_to_exr(img, path):
# """ Helper function to save numpy array as an exr

# Parameters
# ----------
# img (numpy array): numpy array of image data (height, width, channels)
# path (string): full path including the file name to save the exr

# """
# size_x, size_y = img.shape[1], img.shape[0] # height and width of image

# if img.shape[-1] == 3: # if image has three channels
# R, G, B = [img[:,:,channel].tostring() for channel in [0, 1, 2]] # Convert numpy array data of each channel to string
# header = OpenEXR.Header(size_x, size_y) # an EXR header is a dict that contains the defining information of an EXR file - initialized with width and height
# header['channels'] = {'R': Imath.Channel(Imath.PixelType(OpenEXR.HALF)),
# 'G': Imath.Channel(Imath.PixelType(OpenEXR.HALF)),
# 'B': Imath.Channel(Imath.PixelType(OpenEXR.HALF))} # channels key points to another dictionary with the keys as the channels, and values of class Imath.Channel - change the format of the pixeltype to HALF (float16)

# out = OpenEXR.OutputFile(path, header) # Create EXR file with "path" as the file name, and properties defined by header object
# out.writePixels({'R': R, 'G': G, 'B': B}) # Write the RGB data
# elif img.shape[-1] == 1: # if image has one channel
# header = OpenEXR.Header(size_x, size_y)
# header['channels'] = {'R': Imath.Channel(Imath.PixelType(OpenEXR.HALF))}
# out = OpenEXR.OutputFile(path, header)
# out.writePixels({'R': img})


# def save_exrs(output_path, visuals, image_path, aspect_ratio=1.0):
# """ Save images to the disk as exrs

# Parameters
# ----------
# output_path (string): path to directory of output folder
# visuals (dict<string><tensor>): dictionary mapping image names (real_A, real_B, fake_B) to image tensors
# image_path (string): full path to original image in the dataset
# aspect_ratio (float): scale the aspect ratio of the image

# """
# if not os.path.exists(output_path):
# os.makedirs(output_path)

# short_path = ntpath.basename(image_path[0])
# name = os.path.splitext(short_path)[0]

# for label, im_data in visuals.items():
# image_numpy = tensor2im(im_data, use_exr = True, no_normal = (label == 'real_A'))
# image_name = '%s_%s.exr' % (name, label)
# save_path = os.path.join(output_path, image_name)
# h, w, _ = image_numpy.shape
# if aspect_ratio > 1.0:
# image_numpy = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') # increase the width
# if aspect_ratio < 1.0:
# image_numpy = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') # increase the height
# save_to_exr(image_numpy, save_path)

def _resize_bicubic_hwc(image, out_h, out_w):
    """Bicubically resize a (height, width, channels) array, keeping its dtype."""
    # Interpolate in float32 because bicubic interpolation is not generally
    # available for float16 on CPU; the result is cast back afterwards.
    tensor = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)  # -> (1, channels, height, width)
    resized = torch.nn.functional.interpolate(
        tensor, size=(out_h, out_w), mode="bicubic", align_corners=False
    )
    return resized[0].permute(1, 2, 0).numpy().astype(image.dtype)

def save_npys(output_path, visuals, image_path, aspect_ratio=1.0):
    """Save images to the disk as NumPy .npy files

    Each entry of visuals is converted with ``tensor2im(..., use_exr=True)``
    (unnormalization is skipped for the "real_A" entry) and written with np.save
    to ``<output_path>/<original name>_<label>.npy``.

    Parameters
    ----------
        output_path (string): path to directory of output folder (created if missing)
        visuals (dict<string><tensor>): dictionary mapping image names (real_A, real_B, fake_B) to image tensors
        image_path (string or list<string>): path to the original image in the dataset;
            when a sequence of paths is given (e.g. a batch from a data loader), its
            first element is used
        aspect_ratio (float): > 1.0 stretches the width by this factor, < 1.0 stretches
            the height by 1 / aspect_ratio; 1.0 leaves the array size unchanged
    """
    os.makedirs(output_path, exist_ok=True)

    # Accept a single path as well as a batch of paths; indexing a plain string
    # would otherwise pick out its first character.
    if isinstance(image_path, (str, os.PathLike)):
        source_path = image_path
    else:
        source_path = image_path[0]
    short_path = ntpath.basename(source_path)
    name = os.path.splitext(short_path)[0]

    for label, im_data in visuals.items():
        image_numpy = tensor2im(im_data, use_exr=True, no_normal=(label == "real_A"))
        save_path = os.path.join(output_path, "%s_%s.npy" % (name, label))
        h, w = image_numpy.shape[:2]
        if aspect_ratio > 1.0:
            image_numpy = _resize_bicubic_hwc(image_numpy, h, int(w * aspect_ratio))  # increase the width
        elif aspect_ratio < 1.0:
            image_numpy = _resize_bicubic_hwc(image_numpy, int(h / aspect_ratio), w)  # increase the height
        np.save(save_path, image_numpy)

def unnormalize_exr(imgs):
    """Undo the per-channel normalization of image data before saving it as an EXR.

    Every value is mapped back with ``value * range[c] + min[c]``, where c is the
    index along the last axis. The constants below hold three entries, so the
    last axis of imgs must broadcast against length 3.

    Parameters
    ----------
        imgs (numpy array): numpy array of shape (height, width, channel)

    Returns
    -------
        imgs (numpy array): new array with each channel unnormalized
    """
    channel_min = np.array([48.625, 0.0, 0.0])
    channel_range = np.array([917.375, 389.25, 28.34375])

    scaled = imgs * channel_range
    return scaled + channel_min

def tensor2im(input_image, imtype=np.uint8, use_exr=False, no_normal=False):
    """Converts a tensor array into a numpy image array.

    Parameters
    ----------
        input_image (tensor or numpy array): the input image. For a tensor, only the
            first item along the leading (batch) axis is converted. A numpy array is
            only cast to imtype. Any other object is returned unchanged.
        imtype (type): the desired type of the converted numpy array (not used when
            use_exr is true). When it is an integer type, float values are clipped
            to its range first so out-of-range pixels saturate instead of wrapping.
        use_exr (bool): if false, prepare the numpy array to be saved as a png
            (values mapped from [-1, 1] to [0, 255], grayscale tiled to RGB), else
            prepare it to be saved as an exr (values mapped from [-1, 1] to [0, 1],
            returned as float16)
        no_normal (bool): intended to skip per-channel unnormalization of exr output;
            currently unused because that unnormalization step is disabled

    Returns
    -------
        return (numpy array): numpy array representation of the tensor, channels last
    """
    if isinstance(input_image, np.ndarray):  # already a numpy array: only cast below
        image_numpy = input_image
    elif isinstance(input_image, torch.Tensor):
        # .data drops autograd tracking; [0] takes the first item of the batch.
        image_numpy = input_image.data[0].cpu().float().numpy()  # (channels, height, width)
        if use_exr:  # process to save as an exr
            image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) * 0.5  # Scale to [0-1]
            # if not no_normal:
            #     image_numpy = unnormalize_exr(image_numpy)
            return image_numpy.astype(np.float16)  # exr uses float16 format
        # process to save as a png
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        # post-processing: transpose to channels last and scale [-1, 1] -> [0, 255]
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) * 0.5 * 255.0
    else:
        return input_image

    target = np.dtype(imtype)
    if np.issubdtype(image_numpy.dtype, np.floating) and np.issubdtype(target, np.integer):
        # Casting out-of-range floats to an integer type wraps around (or is
        # undefined), turning e.g. 256 into 0; saturate instead.
        limits = np.iinfo(target)
        image_numpy = np.clip(image_numpy, limits.min, limits.max)
    return image_numpy.astype(imtype)

def exr_to_np(path, channels=3):
    """Convert an exr file to an numpy array

    Channels are read as HALF (float16) pixels: "R", "G", "B" when channels is 3,
    only "R" when channels is 1.

    Parameters
    ----------
        path (string): path to exr file
        channels (int): number of channels in the exr image (1 or 3)

    Returns
    -------
        img (numpy array): float16 array in format (height, width, channels)

    Raises
    ------
        ValueError: if channels is not 1 or 3
        ImportError: if the OpenEXR / Imath Python bindings are not installed
    """
    if channels not in (1, 3):
        raise ValueError("channels must be 1 or 3, got %r" % (channels,))

    # Imported lazily so the rest of this module works without the OpenEXR bindings.
    import Imath
    import OpenEXR

    exr_file = OpenEXR.InputFile(path)
    try:
        # dataWindow is an Imath Box2i spanning the upper-left to bottom-right pixel
        # (inclusive), hence the +1.
        dw = exr_file.header()["dataWindow"]
        width = dw.max.x - dw.min.x + 1
        height = dw.max.y - dw.min.y + 1
        half = Imath.PixelType(Imath.PixelType.HALF)  # Equivalent to float16

        if channels == 3:
            # channels() returns the raw bytes of each requested channel.
            planes = [
                np.frombuffer(raw, dtype=np.float16).reshape(height, width)
                for raw in exr_file.channels("RGB", half)
            ]
            img = np.stack(planes, axis=-1)  # (height, width, channels)
        else:
            raw = exr_file.channel("R", half)
            # copy() because frombuffer returns a read-only view of the bytes.
            img = np.frombuffer(raw, dtype=np.float16).reshape(height, width, 1).copy()
    finally:
        exr_file.close()
    return img

def exr_to_tensor(path, channels=3):
    """Load an exr file as a float32 tensor.

    Parameters
    ----------
        path (string): path to exr file
        channels (int): number of channels in the exr image, forwarded to exr_to_np

    Returns
    -------
        return (tensor): 3D float32 tensor in format (channels, height, width)
    """
    hwc = exr_to_np(path, channels=channels)
    # (height, width, channels) -> (channels, height, width)
    chw = hwc.transpose((2, 0, 1))
    return torch.from_numpy(chw).float()

def print_current_losses(log_name, epoch, iters, losses, t_comp, t_data):
    """Print current losses in console, and append them as one line to a log file

    Parameters:
        log_name (string): path of the log file; opened in append mode
        epoch (int): current epoch
        iters (int): current training iteration during this epoch (reset to 0 at the end of every epoch)
        losses (OrderedDict<string><float>): training losses stored in the format of (name, float) pairs - same format as |losses| of plot_current_losses
        t_comp (float): computational time per data point (normalized by batch_size)
        t_data (float): data loading time per data point (normalized by batch_size)
    """
    header = "(epoch: %d, iters: %d, time: %.3f, data: %.3f) " % (epoch, iters, t_comp, t_data)
    loss_text = "".join("%s: %.3f " % (loss_name, loss_value) for loss_name, loss_value in losses.items())
    message = header + loss_text

    print(message)
    with open(log_name, "a") as log_file:
        log_file.write(message + "\n")

def get_data_from_file(file_name, num_loss, num_epochs_written):
    """Read the loss rows that have not been consumed yet from a log file

    Every line is split on ","; the first field is parsed as the epoch (int), the
    last field is discarded (presumably the remainder after a trailing comma,
    i.e. the newline - confirm against the writer), and the fields in between
    are parsed as the loss values (float). Only the last
    ``len(lines) - num_epochs_written`` lines are read.

    Parameters
    ----------
        file_name (string): full path to log file
        num_loss (int): number of loss functions (loss fields expected per line)
        num_epochs_written (int): number of rows already consumed by the main process

    Returns
    -------
        tuple(list<int>,list<list<float>>):
            epochs (list<int>): list of epoch values for each row of unwritten data
            data (list<list<float>>): one list per loss function, each holding that
                loss's value for every returned epoch

    Raises
    ------
        AssertionError: if a line does not contain exactly num_loss loss fields
    """
    epochs = []
    data = [[] for _ in range(num_loss)]

    with open(file_name, "r") as log_file:
        lines = log_file.readlines()
        pending = len(lines) - num_epochs_written
        if pending <= 0:  # nothing new since the last read
            return (epochs, data)

        for line in lines[-pending:]:
            fields = line.split(",")
            epochs.append(int(fields[0]))
            losses = fields[1:-1]

            assert len(losses) == num_loss
            for column, value in zip(data, losses):
                column.append(float(value))

    return (epochs, data)

def serialize_options_and_save(options, dir_path, model_name):
    """Write the attributes of option objects to ``<dir_path>/<model_name>.txt``

    The instance attributes (vars()) of all objects are merged in order, so a
    later object overrides an earlier one on a shared name. Each attribute is
    written as one "<name> <value>" line. An existing file is overwritten.

    Parameters
    ----------
        options: list containing an Options object and Hyperparams object
        dir_path (string): path to directory in which to place the file (created if missing)
        model_name (string): file name without the ".txt" extension

    Returns
    -------
        filename: full path to the new file
    """
    os.makedirs(dir_path, exist_ok=True)
    filename = os.path.join(dir_path, "{}.txt".format(model_name))

    merged = {key: val for option in options for key, val in vars(option).items()}

    with open(filename, "w") as out_file:
        out_file.writelines("{0} {1}\n".format(key, val) for key, val in merged.items())

    return filename

def create_tensorboard_graphs(w, options, hyper_params, data):
    """Display the generator model as a graph on tensorboard

    A fresh UnetGenerator is built from the given settings and passed with a
    sample input to the writer's add_graph.

    Parameters
    ----------
        w: tensorboard SummaryWriter object
        options: Options object; options.channels is used as both the input and
            output channel count of the generator
        hyper_params: HyperParams object supplying n_layers_G, up_activation,
            down_activation and filter_size
        data: sample batch dict; data["B"] is the dummy input used to trace the graph
    """
    base_filters = 64  # fourth positional argument of UnetGenerator
    generator = UnetGenerator(
        options.channels,
        options.channels,
        hyper_params.n_layers_G,
        base_filters,
        up_activation=hyper_params.up_activation,
        down_activation=hyper_params.down_activation,
        filter_size=hyper_params.filter_size,
    )
    sample_input = data["B"]
    # The trailing False is passed positionally (it is ``verbose`` for torch's SummaryWriter).
    w.add_graph(generator, sample_input, False)
