#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/ssim.py ( Python )
#
# COMMENTS:     Quantifies differences in images using SSIM
#
#               Original source from:
#                   https://github.com/Po-Hsun-Su/pytorch-ssim
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.

import numpy as np
import torch
import torch.nn.functional as F

from math import exp

def create_window(window_size, channel, sigma):
    """Build a normalized 2D Gaussian kernel replicated once per channel.

    Parameters
    ----------
        window_size (int): height and width of the square kernel
        channel (int): number of channels of the input image; the kernel is
            duplicated along the first axis so it can be used in a grouped
            (groups=channel) convolution
        sigma (float): standard deviation of the Gaussian curve

    Return
    ------
        window (4D tensor<channel><1><window_size><window_size>): float32
            kernel; each <window_size><window_size> slice sums to
            (approximately) 1

    """
    center = window_size // 2
    two_var = float(2 * sigma ** 2)

    # Sample an unnormalized 1D Gaussian around the center tap, then divide by
    # the sum so the taps add up to one. This replaces the textbook
    # 1 / (sqrt(2*pi) * sigma) factor and keeps the filter energy-preserving.
    taps = torch.Tensor(
        [exp(-((offset - center) ** 2) / two_var) for offset in range(window_size)]
    )
    taps = taps / taps.sum()

    # The outer product of the 1D profile with itself is the separable 2D
    # kernel. Two leading singleton axes are then added for the conv weight
    # layout <out_channels><in_channels / groups><kH><kW>.
    kernel_2d = (taps[:, None] * taps[None, :]).float()[None, None]

    # One identical copy of the kernel per channel, materialized contiguously.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()


def ssim(img1, img2, window_size=11, *, sigma=1.5, data_range=1.0):
    """Structural similarity measurement

    Computes the mean SSIM between two image batches, using a Gaussian
    weighted local window to estimate local means, variances and covariance
    (Wang et al., 2004).

    Parameters
    ----------
        img1 (4D tensor <batch><channel><height><width>): input image to compare to img2
        img2 (4D tensor <batch><channel><height><width>): input image to compare to img1;
            expected to match img1's shape, device and dtype
        window_size (int): height and width of the Gaussian kernel window
        sigma (float): standard deviation of the Gaussian kernel (default 1.5)
        data_range (float): dynamic range of the pixel values (e.g. 1.0 for
            images in [0, 1], 255.0 for 8-bit images). It scales the
            stabilizing constants C1 and C2 (default 1.0)

    Returns
    -------
        ssim_map.mean() (0D tensor): mean SSIM over every pixel, channel and
            batch item. It equals 1 for identical inputs and lies in [-1, 1]
            (it can be negative)

    """
    (_, channel, _, _) = img1.size()  # Extract the number of channels from the image

    # Create the filter and move it onto the input's device and dtype in one
    # step. This works for CPU, CUDA and any other backend; type_as() alone
    # changes only the dtype, not the device.
    window = create_window(window_size, channel, sigma).to(
        device=img1.device, dtype=img1.dtype
    )
    padding = window_size // 2

    def _local_mean(x):
        # Depthwise (groups=channel) Gaussian filter: each channel is smoothed
        # independently, yielding a Gaussian-weighted local average.
        return F.conv2d(x, window, padding=padding, groups=channel)

    mu1 = _local_mean(img1)  # local means
    mu2 = _local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = _local_mean(img1 * img1) - mu1_sq
    sigma2_sq = _local_mean(img2 * img2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2

    # Stabilizing constants from the SSIM definition: C = (K * L)^2 with
    # K1 = 0.01, K2 = 0.03 and L the dynamic range. They avoid division by
    # ~0 in flat or dark regions.
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )  # Calculate the localized SSIM

    return ssim_map.mean()
