#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/model.py ( Python )
#
# COMMENTS:     Defines the pix2pix model classes, including the generator and
#               discriminator.
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.

import os

import torch
import torch.nn as nn

from collections import OrderedDict
from torch.optim import lr_scheduler

from pdgml.pix2pix.ssim import ssim

def get_scheduler(optimizer, hyper_params):
    """Initializes the desired learning rate scheduler based on hyper-parameters.

    Parameters:
        optimizer (torch.optim.Optimizer) -- optimizer to apply LR policy
        hyper_params (HyperParams) -- session hyper-parameters; reads
            ``lr_policy`` plus the fields that policy needs
            (``niter``/``niter_decay`` for "linear", ``lr_decay_iters`` for
            "step", ``niter`` for "cosine")

    Returns:
        the scheduler object for the requested policy

    Raises:
        NotImplementedError -- if ``hyper_params.lr_policy`` is not one of
            "linear", "step", "plateau" or "cosine"
    """
    policy = hyper_params.lr_policy

    if policy == "linear":

        def lambda_rule(epoch):
            # LR multiplier stays at 1.0 for the first `niter` epochs, then drops
            # by 1 / (niter_decay + 1) per epoch.
            lr_l = 1.0 - max(0, epoch + 1 - hyper_params.niter) / float(
                hyper_params.niter_decay + 1
            )
            return lr_l

        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

    if policy == "step":
        return lr_scheduler.StepLR(
            optimizer, step_size=hyper_params.lr_decay_iters, gamma=0.1
        )

    if policy == "plateau":
        # Note: ReduceLROnPlateau.step() must be given the monitored metric.
        return lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, threshold=0.01, patience=5
        )

    if policy == "cosine":
        return lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=hyper_params.niter, eta_min=0
        )

    # Previously this error was *returned* (and its message never formatted),
    # handing the caller an exception object instead of a scheduler.
    raise NotImplementedError(
        "learning rate policy [%s] is not implemented" % policy
    )


def get_activation(act_type):
    """Returns a newly constructed activation module of the desired type.

    Parameters:
        act_type (string) -- one of "relu", "sigmoid", "tanh", "logsoft",
                             "rrelu", "leaky_relu"

    Returns:
        nn.Module: a fresh activation module. "relu" and "leaky_relu"
        (negative slope 0.2) are constructed with inplace=True.

    Raises:
        NotImplementedError -- if act_type is not one of the names above
    """
    if act_type == "relu":
        return nn.ReLU(True)
    elif act_type == "sigmoid":
        return nn.Sigmoid()
    elif act_type == "tanh":
        return nn.Tanh()
    elif act_type == "logsoft":
        # Explicit channel axis: this is what the deprecated implicit-dim
        # behavior selected for 4-D (NCHW) inputs, minus the warning.
        return nn.LogSoftmax(dim=1)
    elif act_type == "rrelu":
        return nn.RReLU()
    elif act_type == "leaky_relu":
        return nn.LeakyReLU(0.2, True)

    # Previously an unknown name returned None, which only failed later (and
    # obscurely) when the enclosing nn.Sequential tried to call it.
    raise NotImplementedError("activation [%s] is not implemented" % act_type)


def init_weights(net, init_type="normal", init_gain=0.02):
    """Initialize the parameters of ``net`` in place.

    Every submodule whose class name contains "Conv" or "Linear" (and that has
    a ``weight``) gets its weight drawn with the selected scheme and its bias,
    when present, set to zero. Every BatchNorm2d gets weight ~ N(1, init_gain)
    and a zero bias. Other submodules are left untouched.

    Parameters:
        net (nn.Module) -- network to be initialized
        init_type (string) -- "normal", "xavier", "kaiming" or "orthogonal"
        init_gain (float) -- std for "normal" / gain for "xavier" and
                             "orthogonal" (ignored by "kaiming")

    Raises:
        NotImplementedError -- if a Conv/Linear layer is met and init_type is
                               not one of the supported names
    """

    def _init_module(module):
        # Called by net.apply() on every submodule of net (and net itself).
        layer_name = type(module).__name__
        is_weighted_layer = "Conv" in layer_name or "Linear" in layer_name

        if is_weighted_layer and hasattr(module, "weight"):
            weight = module.weight.data
            if init_type == "normal":
                nn.init.normal_(weight, 0.0, init_gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(weight, gain=init_gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(weight, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(weight, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )

            bias = getattr(module, "bias", None)
            if bias is not None:
                nn.init.constant_(bias.data, 0.0)
        elif "BatchNorm2d" in layer_name:
            # BatchNorm weight is a per-channel scale vector, not a matrix, so
            # only a normal distribution centred on 1 applies.
            nn.init.normal_(module.weight.data, 1.0, init_gain)
            nn.init.constant_(module.bias.data, 0.0)

    print("initialize network with %s" % init_type)
    net.apply(_init_module)


class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective: "lsgan", "vanilla",
                              "l1" or "wgangp" ("nll" is accepted here but
                              cannot be evaluated, see forward)
            target_real_label (float) -- label value for a real image
            target_fake_label (float) -- label value for a fake image

        Raises:
            NotImplementedError -- if gan_mode is not a recognized name

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Scalar label tensors registered as buffers so .to(device) moves them
        # with the module. forward() expands them to the prediction's shape.
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))
        self.gan_mode = gan_mode

        if gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        elif gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "nll":
            self.loss = nn.NLLLoss()
        elif gan_mode == "l1":
            self.loss = nn.L1Loss()
        elif gan_mode == "wgangp":
            self.loss = None
        else:
            raise NotImplementedError("gan mode %s not implemented" % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) -- the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Invoked through ``criterion(prediction, target_is_real)`` via
        nn.Module.__call__, so registered hooks also run.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.

        Raises:
            NotImplementedError -- for gan_mode "nll", whose class-index targets
                                   have no meaning for a per-pixel score map
        """
        if self.gan_mode in ("lsgan", "vanilla", "l1"):
            # Element-wise losses: compare every prediction to a constant label.
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)
        if self.gan_mode == "wgangp":
            return -prediction.mean() if target_is_real else prediction.mean()
        # Previously this fell through and returned None.
        raise NotImplementedError(
            "gan mode %s is not supported for computing the loss" % self.gan_mode
        )


class UnetGenerator(nn.Module):
    """Constructor for U-Net generator"""

    def __init__(
        self,
        input_nc,  # (int) -- number of channels in input images
        output_nc,  # (int) -- number of channels in output images
        num_downs,  # (int) -- number of downsamplings in U-Net (must be >= 1)
        #             E.g. if |num_downs| == 7, image of size 128x128
        #                  will become of size 1x1 at the bottleneck
        ngf=64,  # (int) -- number of filters in the outermost (first) down conv layer
        use_dropout=False,  # (bool) -- whether or not to use dropout
        up_activation="relu",  # (string) -- name of activation type in up conv layers
        down_activation="leaky_relu",  # (string) -- name of activation type in down conv layers
        filter_size=4,  # (int) -- filter size
    ):
        """
        Generator is constructed recursively with pairs of down convolution and up convolution layers.
        Overall network is built from the innermost pair (bottleneck) to the outermost pair.

        Raises:
            ValueError -- if num_downs < 1
        """
        super(UnetGenerator, self).__init__()
        if num_downs < 1:
            raise ValueError("num_downs must be >= 1, got %s" % num_downs)

        unet_block = None

        # why -4 ? The innermost (ie. the bottom of the U) layers are quite deep
        # because every time we convolve and downsample, generally the number of
        # channels/filters grows.  So what we are saying here is that the bottom
        # num_downs - 4 pairs of the U have the number of channels fixed at
        # ngf*8, which at the default settings means 512.
        # Clamped at 0: for num_downs < 4 there is no fixed-depth section, and a
        # negative count would otherwise *increase* num_downs below, silently
        # building a deeper network than requested.
        num_layers_with_fixed_depth = max(num_downs - 4, 0)
        for _ in range(num_layers_with_fixed_depth):
            is_initial_block = unet_block is None
            unet_block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                ngf * 8,
                submodule=unet_block,
                use_dropout=False if is_initial_block else use_dropout,
                innermost=is_initial_block,
                up_activation=up_activation,
                down_activation=down_activation,
                filter_size=filter_size,
            )
        num_downs -= num_layers_with_fixed_depth

        # Remaining levels (at most 3 here): channel count halves per level on
        # the way out, from ngf * 2**(i-1) inner channels down to ngf.
        for i in range(num_downs, 1, -1):
            inner_nc = ngf * (2 ** (i - 1))
            in_out_nc = ngf * (2 ** (i - 2))

            is_initial_block = unet_block is None
            unet_block = UnetSkipConnectionBlock(
                in_out_nc,
                inner_nc,
                in_out_nc,
                submodule=unet_block,
                innermost=is_initial_block,
                up_activation=up_activation,
                down_activation=down_activation,
                filter_size=filter_size,
            )

        # Outermost (also innermost when num_downs == 1)
        is_initial_block = unet_block is None
        self.model = UnetSkipConnectionBlock(
            output_nc,
            ngf,
            input_nc,
            submodule=unet_block,
            innermost=is_initial_block,
            outermost=True,
            up_activation=up_activation,
            down_activation=down_activation,
            filter_size=filter_size,
        )

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


# class Interpolate(nn.Module):
#     def __init__(self, scale_factor, mode):
#         super(Interpolate, self).__init__()
#         self.interp = nn.functional.interpolate
#         self.scale_factor = scale_factor
#         self.mode = mode

#     def forward(self, x):
#         x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode)
#         return x


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
       +--------------------identity---------------------------+
       |                                                       |
    ---+-- downsampling -- |submodule| -- upsampling --[concat channels]---

    Blocks nest recursively: every block except the innermost wraps the next
    deeper block as ``submodule``. Every block except the outermost returns
    its input concatenated with its output along dim 1 (see ``forward``).

    NOTE(review): the layers are collected into an nn.Sequential, so their
    order fixes the state_dict keys ("model.0", "model.1", ...). Reordering
    them would break loading of previously saved checkpoints.
    """

    # "inner" is the output of the downsampling and the input of the upsampling
    # "outer" is the output of the upsampling
    # "innermost" is the bottom of the U net and "outermost" is the top of
    # the U net.
    def __init__(
        self,
        outer_nc,  # (int) -- the number of filters in the outer conv layer (channels produced by the up-convolution)
        inner_nc,  # (int) -- the number of filters in the inner conv layer (channels produced by the down-convolution)
        input_nc,  # (int) -- the number of channels in the tensor fed to this block
        submodule=None,  # (UnetSkipConnectionBlock) -- previously defined submodules through recursive process
        outermost=False,  # (bool) -- whether this module is the outermost (direct from input and out to output)
        innermost=False,  # (bool) -- whether this module is the innermost (two layers adjacent to bottleneck)
        use_dropout=False,  # (bool) -- whether to use dropout (only applied to blocks that are neither innermost nor outermost)
        up_activation="relu",  # (string) name of up_activation
        down_activation="leaky_relu",  # (string) name of down_activation
        filter_size=4,  # (int) size of filters (kernel size of both the down and up convolutions)
    ):
        super(UnetSkipConnectionBlock, self).__init__()

        # With stride 2 and padding (filter_size - 2) / 2, the down-conv halves
        # and the transposed up-conv doubles the spatial size for even filter
        # sizes. NOTE(review): odd filter sizes give mismatched sizes between
        # the two paths — confirm callers only pass even values.
        kw = filter_size
        pw = int((filter_size - 2) / 2)

        downconv = nn.Conv2d(
            input_nc, inner_nc, kernel_size=kw, stride=2, padding=pw, bias=False
        )
        # All candidate layers are built up front; each branch below keeps the
        # ones it needs and the others are simply discarded.
        downnorm = nn.BatchNorm2d(inner_nc)
        upnorm = nn.BatchNorm2d(outer_nc)
        uprelu = get_activation(up_activation)
        downrelu = get_activation(down_activation)

        if innermost and outermost:
            # this is only true if there is only 1 layer
            upconv = nn.ConvTranspose2d(
                inner_nc, outer_nc, kernel_size=kw, stride=2, padding=pw
            )
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + up

        elif outermost:
            # ie. at the top of the U net: no activation before the first conv,
            # no normalization, and a Tanh on the final output.
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=kw, stride=2, padding=pw
            )
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up

        elif innermost:
            # ie. at the bottom of the U net: there is no submodule, so the
            # up-conv consumes the inner_nc channels of the down-conv directly.
            upconv = nn.ConvTranspose2d(
                inner_nc, outer_nc, kernel_size=kw, stride=2, padding=pw, bias=False
            )
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up

        else:
            # The up-conv must accept inner_nc * 2 channels: the submodule
            # returns its input (inner_nc channels, the output of this block's
            # down path) concatenated with its own output (inner_nc channels,
            # as UnetGenerator wires the blocks). See the SKIP CONNECTION
            # IMPLEMENTATION in forward.
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=kw, stride=2, padding=pw, bias=False
            )
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            model = down + [submodule] + up

            if use_dropout:
                model = model + [nn.Dropout(0.5)]

        # Disabled alternative kept for reference: resize-convolution upsampling
        # (nearest-neighbour Interpolate + reflection pad + 3x3 Conv2d) instead of
        # ConvTranspose2d. It is not runnable as written: it references names that
        # are not defined in the active code (Interpolate, inner_outer, use_bias).
        # upsampl = Interpolate(2, 'nearest')
        # pad = nn.ReflectionPad2d(1)
        # if inner_outer:
        #     upconv = nn.Conv2d(inner_nc,outer_nc,kernel_size=3,stride=1,padding=0)
        #     down = [downconv]
        #     up = [uprelu, upsampl, pad, upconv, nn.Tanh()]
        #     model = down + up
        # elif outermost:
        #     upconv = nn.Conv2d(inner_nc * 2,outer_nc,kernel_size=3,stride=1,padding=0)
        #     down = [downconv]
        #     up = [uprelu, upsampl, pad, upconv, nn.Tanh()]
        #     model = down + [submodule] + up
        # elif innermost:
        #     upconv = nn.Conv2d(inner_nc,outer_nc,kernel_size=3,stride=1,padding=0,bias=use_bias)
        #     down = [downrelu, downconv]
        #     up = [uprelu, upsampl, pad, upconv, upnorm]
        #     model = down + up
        # else:
        #     upconv = nn.Conv2d(inner_nc * 2,outer_nc,kernel_size=3,stride=1,padding=0,bias=use_bias)
        #     down = [downrelu, downconv, downnorm]
        #     up = [uprelu, upsampl, pad, upconv, upnorm]

        #     if use_dropout:
        #         model = down + [submodule] + up + [nn.Dropout(0.5)]
        #     else:
        #         model = down + [submodule] + up

        self.outermost = outermost
        self.model = nn.Sequential(*model)

    def forward(self, x):
        # Outermost: plain output (output_nc channels).
        # Otherwise: input concatenated with output along the channel dim.
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            # SKIP CONNECTION IMPLEMENTATION
            # NOTE(review): for non-outermost blocks the first layer is downrelu.
            # When that activation is in-place (get_activation builds "relu" and
            # "leaky_relu" with inplace=True), it overwrites x, so the skip
            # branch concatenates the *activated* input rather than the raw one.
            # Changing this would alter the outputs of already-trained checkpoints.
            return torch.cat([x, self.model(x)], 1)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator

    The network ends in a plain Conv2d with one output channel, so it produces a
    map of raw (un-activated) per-patch scores rather than a single scalar.
    """

    def __init__(
        self, input_nc, ndf=64, n_layers=3, down_activation="leaky_relu", filter_size=4
    ):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            n_layers (int)  -- the number of stride-2 conv layers in the discriminator
                               (two stride-1 conv layers follow them)
            down_activation (string) -- activation name passed to get_activation
            filter_size (int) -- kernel size of every conv layer

        NOTE(review): layers are collected into an nn.Sequential, so their order
        fixes the state_dict keys used by saved checkpoints.
        """
        super(NLayerDiscriminator, self).__init__()

        # For even filter sizes this padding halves the spatial size at stride 2
        # and shrinks it by exactly 1 at stride 1.
        kw = filter_size
        pw = int((filter_size - 2) / 2)
        # A single activation instance is reused after every conv below; the
        # modules returned by get_activation carry no learnable parameters.
        downrelu = get_activation(down_activation)

        # An example with, say, 256 x 256 images (the 6 channels come from the
        # concatenation of realA,fakeB or realA,realB pairs).
        # 6x256x256 -> 64x128x128
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=pw),
            downrelu,
        ]

        # 64x128x128 -> 128x64x64 -> 256x32x32
        for i in range(1, n_layers):  # gradually increase the number of filters (capped at ndf * 8)
            sequence += [
                nn.Conv2d(
                    ndf * min(2 ** (i - 1), 8),
                    ndf * min(2 ** i, 8),
                    kernel_size=kw,
                    stride=2,
                    padding=pw,
                    bias=False,
                ),
                nn.BatchNorm2d(ndf * min(2 ** i, 8)),
                downrelu,
            ]

        # 256x32x32 -> 512x31x31 -> 1x30x30
        sequence += [
            nn.Conv2d(
                ndf * min(2 ** (n_layers - 1), 8),
                ndf * min(2 ** n_layers, 8),
                kernel_size=kw,
                stride=1,
                padding=pw,
                bias=False,
            ),
            nn.BatchNorm2d(ndf * min(2 ** n_layers, 8)),
            downrelu,
            nn.Conv2d(
                ndf * min(2 ** n_layers, 8), 1, kernel_size=kw, stride=1, padding=pw
            ),  # output 1 channel prediction map
        ]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


def create_model(hyper_params, options):
    """Build the Pix2PixModel for the current session and announce it.

    This is the entry point the train/test scripts use to obtain a model.

    Parameters:
        hyper_params (HyperParams) -- experiment hyper-parameters
        options (Options) -- session train and test options

    Returns:
        Pix2PixModel -- the newly constructed model for this run
    """
    model = Pix2PixModel(hyper_params, options)
    class_name = type(model).__name__
    print("model [%s] was created" % class_name)
    return model


class Pix2PixModel:
    """This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """

    def __init__(self, hyper_params, options):
        """Initialize the pix2pix class.

        Parameters:
            hyper_params (HyperParams) -- contains all the experiment hyper-parameters
            options (Options) -- contains all config options
        """
        self.hyper_params = hyper_params
        self.opt = options
        self.isTrain = options.phase == "train"
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        """specify the loss values and image visuals to display"""
        self.loss_names = [
            "G_GAN",
            "G_L1",
            "D_real",
            "D_fake",
            "prediction_D",
            "prediction_D_fake",
            "SSIM",
        ]
        self.visual_names = ["real_A", "fake_B", "real_B"]

        """define and init generator"""
        self.netG = UnetGenerator(
            options.channels,
            options.channels,
            hyper_params.n_layers_G,
            64,
            use_dropout=False,
            up_activation=hyper_params.up_activation,
            down_activation=hyper_params.down_activation,
            filter_size=hyper_params.filter_size,
        ).to(self.device)
        init_weights(self.netG, hyper_params.init_type, hyper_params.init_gain)
        self.model_names = ["G"]

        if self.isTrain:

            """define and init discriminator"""
            self.netD = NLayerDiscriminator(
                options.channels * 2,
                64,
                hyper_params.n_layers_D,
                down_activation=hyper_params.down_activation,
                filter_size=hyper_params.filter_size,
            ).to(self.device)
            init_weights(self.netD, hyper_params.init_type, hyper_params.init_gain)
            self.model_names.append("D")

            """define loss functions"""
            # This instantiates the GANLoss object, and the .to is a nn.Module
            # function that puts the object onto the requested device.
            self.criterionGAN = GANLoss(hyper_params.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.sigmoid = nn.Sigmoid()

            """initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>."""
            if hyper_params.optimizer == "adam":
                self.optimizer_G = torch.optim.Adam(
                    self.netG.parameters(),
                    lr=hyper_params.lr,
                    betas=(hyper_params.beta1, 0.999),
                )
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=hyper_params.lr,
                    betas=(hyper_params.beta1, 0.999),
                )
            elif hyper_params.optimizer == "sgd":
                self.optimizer_G = torch.optim.SGD(
                    self.netG.parameters(),
                    lr=hyper_params.lr,
                    momentum=hyper_params.momentum,
                )
                self.optimizer_D = torch.optim.SGD(
                    self.netD.parameters(),
                    lr=hyper_params.lr,
                    momentum=hyper_params.momentum,
                )
            elif hyper_params.optimizer == "asgd":
                self.optimizer_G = torch.optim.ASGD(
                    self.netG.parameters(), lr=hyper_params.lr
                )
                self.optimizer_D = torch.optim.ASGD(
                    self.netD.parameters(), lr=hyper_params.lr
                )
            elif hyper_params.optimizer == "adadelta":
                self.optimizer_G = torch.optim.Adadelta(
                    self.netG.parameters(), lr=hyper_params.lr, rho=hyper_params.rho
                )
                self.optimizer_D = torch.optim.Adadelta(
                    self.netD.parameters(), lr=hyper_params.lr, rho=hyper_params.rho
                )

            self.optimizers = [self.optimizer_G, self.optimizer_D]

    def setup(self, hyper_params, epoch):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [
                get_scheduler(optimizer, hyper_params) for optimizer in self.optimizers
            ]
        else:
            self.load_networks(epoch)

        self.print_networks(False)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, "net" + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        if self.image_paths != None:
            return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]["lr"]
        print("learning rate = %.7f" % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(
                    getattr(self, "loss_" + name)
                )  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s.net.%s.pth' % (epoch, name)
        """
        save_dir = os.path.join(
            self.opt.checkpoints_dir, self.opt.name);
        os.makedirs(save_dir, exist_ok=True)

        model_paths = []
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = os.path.join(
                    save_dir, "{}.net.{}.pth".format(epoch, name))
                model_paths.append(save_filename)
                net = getattr(self, "net" + name)

                if torch.cuda.is_available():
                    torch.save(net.cpu().state_dict(), save_filename)
                    net.to(self.device)
                else:
                    torch.save(net.cpu().state_dict(), save_filename)
        return model_paths

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s.net.%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_path = os.path.join(
                    self.opt.checkpoints_dir, 
                    self.opt.name,
                    "{}.net.{}.pth".format(epoch, name))
                net = getattr(self, "net" + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print("loading the model from %s" % load_path)

                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, "_metadata"):
                    del state_dict._metadata

                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print("---------- Networks initialized -------------")
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, "net" + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print(
                    "[Network %s] Total number of parameters : %.3f M"
                    % (name, num_params / 1e6)
                )
        print("-----------------------------------------------")

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
        AtoB = self.opt.direction == "AtoB"
        self.real_A = input["A" if AtoB else "B"].to(self.device)
        self.real_B = input["B" if AtoB else "A"].to(self.device)
        self.image_paths = input["AB_paths"]

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def backward_D(self):
        """Calculate D loss on fake images"""
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # feed both A and G(A) to D
        # detach surgically removes all the tracked computations up to fake_B from the
        # compute graph of the generator.  If we don't do this, then when we backprop,
        # the generator weights will be affected.  no_grad is too blunt a hammer, it would
        # prevent the construction of the compute graph in the first place.
        pred_fake = self.netD(fake_AB.detach())  # prevent backprop to G weights
        self.loss_D_fake = self.criterionGAN(pred_fake, False)  # compute loss

        """Calculate D loss on real images"""
        real_AB = torch.cat((self.real_A, self.real_B), 1)  # feed both A and B to D
        pred_real = self.netD(real_AB)  # get D prediction on real
        self.loss_D_real = self.criterionGAN(pred_real, True)  # compute loss

        """Ratio of D being correct, NOT backpropagated"""
        # the .to(dtype=torch.float) is just a simple type conversion from ints
        correct_pred_fake = (
            torch.sum(self.sigmoid(pred_fake) <= 0.5).to(dtype=torch.float)
            / pred_fake.nelement()
        )
        correct_pred_true = (
            torch.sum(self.sigmoid(pred_real) >= 0.5).to(dtype=torch.float)
            / pred_real.nelement()
        )
        self.loss_prediction_D_fake = correct_pred_fake.item()
        self.loss_prediction_D = (
            correct_pred_fake.item() + correct_pred_true.item()
        ) / 2.0

        """Combine loss and calculate gradients"""
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute loss based on how well G fools D
        Note: D has gradients disabled in this step"""
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        # our sample is by definition a fake, but we need the discriminator to
        # to think it is real.  So if pred_fake is all ones, the discriminator
        # thinks this is real, and so we did a good job and have no loss.
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        """Compute L1 between G(A) and B"""
        self.loss_G_L1 = (
            self.criterionL1(self.fake_B, self.real_B) * self.hyper_params.lambda_L1
        )

        """Compute SSIM during training, NOT backpropagated"""
        # the 0.5 is to recenter to [0,1]
        self.loss_SSIM = ssim(
            (self.fake_B * 0.5) + 0.5, (self.real_B * 0.5) + 0.5
        ).item()

        """Combine loss and calculate gradients"""
        # we are not just averaging here because the loss_G_L1 is already weighted
        # by the lambda_L1 parameter above
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()  # compute fake images: G(A)

        """Update D"""
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()  # set D's gradients to zero
        self.backward_D()  # calculate gradients for D
        self.optimizer_D.step()  # update D's weights

        """update G"""
        self.set_requires_grad(
            self.netD, False
        )  # disable backprop in D when optimizing G
        self.optimizer_G.zero_grad()  # set G's gradients to zero
        self.backward_G()  # calculate graidents for G
        self.optimizer_G.step()  # udpate G's weights
