#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/computemetrics.py ( Python )
#
# COMMENTS:     Calculation of error and loss measurements
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.

import os
import glob
import numpy as np
import torch

from PIL import Image
from torchvision import models, transforms
from torch.autograd import Variable
from tqdm import tqdm

from pdgml.pix2pix.ssim import ssim
from pdgml.pix2pix.utils import exr_to_tensor

class VGG16Extractor(torch.nn.Module):
    """Frozen slices of an ImageNet-pretrained VGG16 used as a feature extractor.

    The features produced here are what the style loss compares (through
    Gram matrices). Three consecutive slices of ``vgg16.features`` are kept;
    the attribute names indicate each slice ends at a max-pooling layer.

    Attributes
    ----------
        max_pooling1 (torch.nn.Sequential): ``vgg16.features[:5]``
        max_pooling2 (torch.nn.Sequential): ``vgg16.features[5:10]``
        max_pooling3 (torch.nn.Sequential): ``vgg16.features[10:17]``

    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(pretrained=True).features
        self.max_pooling1 = backbone[:5]
        self.max_pooling2 = backbone[5:10]
        self.max_pooling3 = backbone[10:17]
        # The network is only used to extract features, so none of its
        # weights should ever receive gradients.
        for stage in (self.max_pooling1, self.max_pooling2, self.max_pooling3):
            stage.requires_grad_(False)

    def forward(self, image):
        """Run ``image`` through the three stages in sequence.

        Parameter
        ---------
            image (tensor<batch><channel><height><width>): input image

        Return
        ------
            results (list<tensors>): the output of each stage, in order; every
                stage consumes the previous stage's output

        """
        features = []
        current = image
        for stage in (self.max_pooling1, self.max_pooling2, self.max_pooling3):
            current = stage(current)
            features.append(current)
        return features


def sort_paths(arr, l, r):
    """Sort image paths by the numerical index contained in the file name.

    Each path is expected to look like ``<l>[separator]<index><r>``, e.g.
    ``results/12_fake_B.png`` with ``l="results"`` and ``r="_fake_B.png"``.
    The separator after ``l`` is optional, so ``l`` may be given with or
    without a trailing path separator.

    Parameters
    ----------
        arr (list<string>): image paths
        l (string): prefix preceding the index (typically the directory path)
        r (string): suffix following the index

    Return
    ------
        return (list<string>): new list of the paths sorted by numeric index

    Raises
    ------
        IndexError: if ``l`` does not occur in a path
        ValueError: if the text between ``l`` and ``r`` is not a number

    """
    separators = os.sep + (os.altsep or "")

    def _index(path):
        # Everything before the suffix, then everything after the prefix.
        stem = path.split(r)[0].split(l, 1)[1]
        # os.path.join() inserts a separator after ``l`` when ``l`` lacks a
        # trailing one; without stripping it float() would see e.g. "/12".
        return float(stem.lstrip(separators))

    return sorted(arr, key=_index)


def l1_loss(A, B):
    """Mean absolute (L1) difference between two image tensors.

    Parameters
    ----------
        A, B (tensors): RGB data of the two images being compared

    Return
    ------
        return (float): mean of ``|A - B|`` over all elements

    """
    mean_abs_error = torch.nn.functional.l1_loss(A, B)
    return mean_abs_error.item()


def style_loss(A, B):
    """Accumulate the L1 distance between Gram matrices of paired features.

    Parameters
    ----------
        A, B (list<tensor>): intermediate outputs of VGG16 feature extraction;
            ``A[i]`` is compared against ``B[i]``

    Return
    ------
        loss (float): sum over all pairs of the mean absolute difference of
            their Gram matrices

    """
    criterion = torch.nn.L1Loss()
    total = 0.0
    for idx, features_a in enumerate(A):
        gram_a = gram_matrix(features_a)
        gram_b = gram_matrix(B[idx])
        total += criterion(gram_a, gram_b).item()
    return total


def gram_matrix(m):
    """Gram matrix used in style loss computation.

    The channel-by-channel inner products of the flattened feature maps,
    normalized by ``channel * h * w``.

    Parameter
    ---------
        m (tensor<batch><channel><height><width>): intermediate or final
            output from the VGG16 extraction

    Return
    ------
        m (tensor): Gram matrix of shape ``(channel, channel)`` when
            ``batch == 1`` (the historical behavior), otherwise of shape
            ``(batch, channel, channel)`` with one Gram matrix per sample

    """
    batch, channel, h, w = m.size()
    # reshape (unlike view) also accepts non-contiguous feature maps, and
    # keeping the batch axis means batch > 1 no longer fails.
    features = m.reshape(batch, channel, h * w)
    gram = torch.bmm(features, features.transpose(1, 2)) / (channel * h * w)
    if batch == 1:
        # Preserve the original (channel, channel) result for single images.
        return gram[0]
    return gram


def bce_loss(A, B):
    """Binary cross entropy between logits ``A`` and targets ``B``.

    Parameters
    ----------
        A (tensor): raw (pre-sigmoid) values
        B (tensor): target values, same shape as ``A``

    Return
    ------
        return (float): mean BCE-with-logits loss over all elements

    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits(A, B)
    return bce.item()


def ssim_loss(image_fake_B, image_real_B):
    """Structural similarity (SSIM) between two images, via ``ssim.py``.

    Note that ``ssim`` is called with ``image_real_B`` first. ``compute()``
    passes its tensors positionally as ``(real, fake)``, i.e. swapped relative
    to these parameter names. SSIM is normally symmetric, so this presumably
    has no effect; confirm against ``pdgml.pix2pix.ssim``.

    Parameters
    ----------
        image_fake_B (tensor): RGB data of the first image being compared
        image_real_B (tensor): RGB data of the second image being compared

    Return
    ------
        return (float): value of SSIM

    """
    similarity = ssim(image_real_B, image_fake_B)
    return similarity.item()

def pic_to_tensor(pic_path, use_exr, use_npy, channels):
    """Open an image from disk and convert it to a tensor.

    Parameters
    ----------
        pic_path (string): path to image
        use_exr (bool): load the file with ``exr_to_tensor``; only an exact
            ``True`` selects this branch
        use_npy (bool): load the file with ``np.load`` and convert it with
            torchvision's ``to_tensor`` (used when ``use_exr`` is not ``True``)
        channels (int): number of channels, forwarded to ``exr_to_tensor``
            (ignored for the NPY and PNG branches)

    Returns
    -------
        return (tensor): float tensor of the image. For the PIL branch this is
            ``(3, height, width)`` in RGB; for EXR/NPY the layout comes from
            ``exr_to_tensor`` / ``to_tensor`` (presumably channel-first too).

    """
    if use_exr is True:
        return exr_to_tensor(pic_path, channels=channels)
    if use_npy:
        return transforms.functional.to_tensor(np.load(pic_path)).float()
    # Image.open keeps the file open lazily; the context manager closes it
    # once convert() has produced an independent, fully-loaded RGB copy.
    with Image.open(pic_path) as img:
        rgb = img.convert("RGB")
    return transforms.ToTensor()(rgb)


# Losses evaluated by compute(), as (name, function) pairs; each function
# returns a float. Uncomment an entry to enable it. compute() passes VGG16
# feature lists (not raw tensors) to the "StyleLoss" entry.
loss_measures = [
    ("L1Loss", l1_loss),
    # Trailing comma required: without it, uncommenting the next entry would
    # make Python "call" this tuple and fail at import time.
    ("SSIM", ssim_loss),
    # ("StyleLoss", style_loss),
    # ("BCE", bce_loss),
]


def compute(source_path, channels, use_exr=False, use_npy=False):
    """Compute average loss metrics between fake and real images in a directory.

    Files named ``<index>_fake_B.<ext>`` and ``<index>_real_B.<ext>`` in
    ``source_path`` are sorted by numeric index and paired in that order.
    Every entry of ``loss_measures`` is then evaluated on each pair.

    Parameters
    ----------
        source_path (string): path to dataset directory
        channels (int): number of channels in an image (used for EXR loading)
        use_exr (bool): read ``.exr`` files; only an exact ``True`` enables this
        use_npy (bool): read ``.npy`` files when ``use_exr`` is not ``True``;
            otherwise ``.png`` files are read

    Returns
    -------
        loss_totals (dict<string><float>): average value of each loss across
            the dataset, keyed by loss name

    Raises
    ------
        ValueError: if the numbers of fake and real images differ, or if no
            images are found

    """
    if use_exr is True:
        extension = ".exr"
    elif use_npy:
        extension = ".npy"
    else:
        extension = ".png"
    fake_path = "_fake_B" + extension
    real_path = "_real_B" + extension

    images_fake = sort_paths(
        glob.glob(os.path.join(source_path, "*" + fake_path)), source_path, fake_path
    )
    images_real = sort_paths(
        glob.glob(os.path.join(source_path, "*" + real_path)), source_path, real_path
    )

    # Explicit checks rather than assert, so they still run under `python -O`.
    if len(images_real) != len(images_fake):
        raise ValueError(
            "Found {0} fake and {1} real images in {2}; expected equal counts".format(
                len(images_fake), len(images_real), source_path
            )
        )
    num_test_images = len(images_fake)
    if num_test_images == 0:
        # Averaging below would otherwise divide by zero.
        raise ValueError(
            "No '*{0}' images found in {1}".format(fake_path, source_path)
        )

    # Loading the pretrained VGG16 is costly, so only build the extractor when
    # a loss that consumes its features is enabled. This also guarantees the
    # feature variables exist whenever "StyleLoss" is evaluated.
    needs_features = any(name == "StyleLoss" for name, _ in loss_measures)
    vgg_extract = VGG16Extractor() if needs_features else None

    loss_totals = {name: 0.0 for name, _ in loss_measures}

    print("Computing losses for test data in {0}".format(source_path))
    for path_real, path_fake in tqdm(
        zip(images_real, images_fake), total=num_test_images
    ):
        # transform Image objects to Pytorch Tensors
        tensor_real = pic_to_tensor(path_real, use_exr, use_npy, channels)
        tensor_fake = pic_to_tensor(path_fake, use_exr, use_npy, channels)

        # add batch dimension (batch_size = 1 during test)
        tensor_real.unsqueeze_(0)
        tensor_fake.unsqueeze_(0)

        if vgg_extract is not None:
            extr_real = vgg_extract(tensor_real)
            extr_fake = vgg_extract(tensor_fake)

        for name, fcn in loss_measures:
            if name == "StyleLoss":
                loss_totals[name] += fcn(extr_fake, extr_real)
            else:
                # Argument order (real, fake) is kept from the original code.
                loss_totals[name] += fcn(tensor_real, tensor_fake)

    # compute average for each loss value
    return {
        name: total / float(num_test_images) for name, total in loss_totals.items()
    }
