#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/exr.py ( Python )
#
# COMMENTS:     EXR loading and saving interface
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.

import Imath
import numpy as np
import OpenEXR
import torch

def exr_to_np(path, channels=3):
    """Load an EXR image into a half-precision NumPy array.

    Args:
        path: Path of the EXR file to read.
        channels: 3 to read the "R", "G" and "B" channels, or 1 to read the
            "A" channel.

    Returns:
        A writable ``np.float16`` array of shape ``(height, width, channels)``,
        where height/width come from the file's data window.

    Raises:
        ValueError: If ``channels`` is neither 1 nor 3.
    """
    # Validate before touching the file so a bad argument never opens it.
    if channels not in (1, 3):
        raise ValueError(f"channels must be 1 or 3, got {channels!r}")

    half = Imath.PixelType(Imath.PixelType.HALF)
    exr_file = OpenEXR.InputFile(path)
    try:
        dw = exr_file.header()["dataWindow"]
        # The data window bounds are inclusive, hence the +1.
        width = dw.max.x - dw.min.x + 1
        height = dw.max.y - dw.min.y + 1

        if channels == 3:
            R, G, B = [
                np.frombuffer(buf, dtype=np.float16).reshape(height, width)
                for buf in exr_file.channels("RGB", half)
            ]
            # np.array copies the planar buffers into (3, H, W); the transpose
            # exposes them as (H, W, 3) without another copy.
            img = np.array([R, G, B]).transpose((1, 2, 0))
        else:
            # NOTE(review): single-channel loads read "A", whereas save_to_exr
            # writes single-channel images to "R"; confirm which channel the
            # callers expect.
            # np.frombuffer returns a read-only view of the bytes object, so
            # copy to hand back a writable array.
            img = (
                np.frombuffer(exr_file.channel("A", half), dtype=np.float16)
                .reshape(height, width, 1)
                .copy()
            )
    finally:
        exr_file.close()
    return img


def exr_to_tensor(path, channels=3):
    """Load an EXR image as a float32 CPU tensor.

    Args:
        path: Path of the EXR file to read.
        channels: 3 for an RGB image or 1 for a single-channel image; passed
            through to ``exr_to_np``.

    Returns:
        For ``channels == 3`` a tensor of shape ``(3, height, width)``; for
        ``channels == 1`` a tensor in the ``(height, width, 1)`` layout that
        ``exr_to_np`` produces.
    """
    pixels = exr_to_np(path, channels=channels)

    if channels == 1:
        # NOTE(review): unlike the RGB case, this keeps the channel axis last
        # (H, W, 1) rather than (1, H, W); confirm callers rely on this.
        return torch.from_numpy(pixels).float()

    if channels == 3:
        # Move the channel axis to the front: (H, W, 3) -> (3, H, W).
        channels_first = pixels.transpose((2, 0, 1))
        return torch.from_numpy(channels_first).float()

    return None


def save_to_exr(img, path):
    """Write an image array to ``path`` as a half-float EXR file.

    Args:
        img: NumPy array of shape ``(height, width, 3)``, written to the "R",
            "G" and "B" channels, or ``(height, width, 1)`` /
            ``(height, width)``, written to the "R" channel. Values are cast
            to ``np.float16``.
        path: Destination file path.

    Raises:
        ValueError: If ``img`` does not have one of the supported shapes.
    """
    img = img.astype(np.float16)
    if img.ndim == 2:
        # Treat a plain 2-D array as a single-channel image.
        img = img[:, :, np.newaxis]
    if img.ndim != 3 or img.shape[-1] not in (1, 3):
        raise ValueError(
            f"expected an image of shape (H, W, 3) or (H, W, 1), got {img.shape}"
        )

    size_y, size_x, num_channels = img.shape
    names = "RGB" if num_channels == 3 else "R"

    header = OpenEXR.Header(size_x, size_y)
    header["channels"] = {
        name: Imath.Channel(Imath.PixelType(OpenEXR.HALF)) for name in names
    }
    # writePixels takes raw bytes per channel; tobytes() serialises each
    # (strided) channel slice in C order, i.e. row by row.
    pixels = {name: img[:, :, index].tobytes() for index, name in enumerate(names)}

    out = OpenEXR.OutputFile(path, header)
    try:
        out.writePixels(pixels)
    finally:
        # Closing finalises the file instead of relying on garbage collection.
        out.close()
