#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/options.py ( Python )
#
# COMMENTS:     Defines classes for holding and saving training options.
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.

class HyperParams:
    """
    Class contains all of the hyper-parameters that describe the network
    architecture and the parameter learning/optimization process.

    Instances are created either through the constructor (preset defaults)
    or from a saved string mapping via ``init_from_dict``.
    """

    def __init__(
        self,
        n_layers_G=8,
        n_layers_D=3,
        init_type="normal",
        init_gain=0.02,
        niter=100,
        niter_decay=100,
        optimizer="adam",
        momentum=0.0,
        rho=0.9,
        lambda_L1=100.0,
        beta1=0.5,
        lr=0.0002,
        gan_mode="vanilla",
        lr_policy="linear",
        lr_decay_iters=50,
        up_activation="relu",
        down_activation="leaky_relu",
        filter_size=4,
    ):
        """Default constructor with preset values."""
        # (int) -- number of downsample-upsample pairs in U-Net Generator (default 8)
        self.n_layers_G = n_layers_G
        # (int) -- number of layers in patchGAN discriminator (default 3)
        self.n_layers_D = n_layers_D
        # (string) -- "normal" | "xavier" | "kaiming" | "orthogonal" (default "normal")
        self.init_type = init_type
        # (float) -- scaling factor for normal, xavier, and orthogonal (default 0.02)
        self.init_gain = init_gain
        # (int) -- number of epochs at initial learning rate (default 100)
        self.niter = niter
        # (int) -- number of epochs to linearly decay learning rate to zero (default 100)
        self.niter_decay = niter_decay
        # (string) -- "adam" | "sgd" | "asgd" | "adadelta" (default "adam")
        self.optimizer = optimizer
        # (float) -- momentum for sgd (default 0.0)
        self.momentum = momentum
        # (float) -- rho value for Adadelta (default 0.9)
        self.rho = rho
        # (float) -- beta1 for Adam (default 0.5)
        self.beta1 = beta1
        # (float) -- weight for L1 loss in generator backprop (default 100.0)
        self.lambda_L1 = lambda_L1
        # (float) -- learning rate (default 0.0002)
        self.lr = lr
        # (string) -- "lsgan" (MSELoss) | "vanilla" (BCEWithLogitsLoss) | "nll" (NLLLoss)
        #             | "l1" (L1Loss) | "wgangp" (none) (default "vanilla")
        self.gan_mode = gan_mode
        # (string) -- "linear" | "step" | "plateau" | "cosine" (default "linear")
        self.lr_policy = lr_policy
        # (int) -- step_size parameter for stepLR optimization scheduler only (default 50)
        self.lr_decay_iters = lr_decay_iters
        # (string) -- "relu" | "sigmoid" | "tanh" | "logsoft" | "rrelu" | "leaky_relu" (default "relu")
        self.up_activation = up_activation
        # (string) -- "leaky_relu" | "sigmoid" | "tanh" | "logsoft" | "rrelu" (default "leaky_relu")
        self.down_activation = down_activation
        # (int) -- choose a multiple of 2 (default 4)
        self.filter_size = filter_size

    @classmethod
    def init_from_dict(cls, dict):
        """
        Construct a new HyperParams from a dict( string -> string ) object,
        used at test-time.

        Args:
            dict: mapping with one entry per constructor argument. Values are
                converted with int()/float() where the attribute is numeric.
                (The parameter name shadows the builtin but is kept so that
                keyword callers keep working.)

        Returns:
            A new HyperParams instance.

        Raises:
            KeyError: if any expected key is missing.
            ValueError: if a numeric value cannot be parsed.
        """
        params = dict
        # Build a fresh instance instead of assigning onto the class object,
        # so separate calls cannot leak values into each other via class state.
        return cls(
            n_layers_G=int(params["n_layers_G"]),
            n_layers_D=int(params["n_layers_D"]),
            init_type=params["init_type"],
            init_gain=float(params["init_gain"]),
            niter=int(params["niter"]),
            niter_decay=int(params["niter_decay"]),
            optimizer=params["optimizer"],
            momentum=float(params["momentum"]),
            rho=float(params["rho"]),
            beta1=float(params["beta1"]),
            lambda_L1=float(params["lambda_L1"]),
            lr=float(params["lr"]),
            gan_mode=params["gan_mode"],
            lr_policy=params["lr_policy"],
            lr_decay_iters=int(params["lr_decay_iters"]),
            up_activation=params["up_activation"],
            down_activation=params["down_activation"],
            filter_size=int(params["filter_size"]),
        )


class Options:
    """
    Class contains all of the session options that describe directories,
    datasets, and training resources.

    Instances are created either through the constructor (preset defaults)
    or from a saved string mapping via ``init_from_dict``.
    """

    def __init__(
        self,
        dataroot,
        name,
        max_dataset_size,
        checkpoints_dir="./checkpoints",
        results_dir="./results",
        tensorboard_dir="./tensorboard",
        plots_dir="./plots",
        onnx_dir="./onnx",
        batch_size=1,
        print_freq=100,
        save_epoch_freq=5,
        num_workers=0,
        phase="train",
        crop_size=256,
        use_exr=False,
        channels=3,
        direction="BtoA",
        use_tensorboard=True,
        use_npy=False,
        save_plots=True,
        save_onnx=True,
    ):
        """Default constructor with preset values."""
        # (string) -- path to dataset root dir, assumes "train" and "test" subdirectories
        self.dataroot = dataroot
        # (string) -- name of experiment to be used in output dir paths
        self.name = name
        # (int) -- the max number of samples from dataset that is loaded
        self.max_dataset_size = max_dataset_size
        # (string) -- path to output dir for model checkpoints, will be created if does not exist (default "./checkpoints")
        self.checkpoints_dir = checkpoints_dir
        # (string) -- path to output dir for test results, will be created if does not exist (default "./results")
        self.results_dir = results_dir
        # (string) -- path to tensorboard dir for logging, will be created if does not exist (default "./tensorboard")
        self.tensorboard_dir = tensorboard_dir
        # (string) -- path to matplotlib output showing training progress (default "./plots")
        self.plots_dir = plots_dir
        # (string) -- path to the directory to save ONNX models (default "./onnx")
        self.onnx_dir = onnx_dir
        # (int) -- batch size to be used at training time (default 1)
        self.batch_size = batch_size
        # (int) -- number of image iterations between each print log (default 100)
        self.print_freq = print_freq
        # (int) -- number of epochs between each checkpoint save and validation (default 5)
        self.save_epoch_freq = save_epoch_freq
        # (int) -- number of dataloader workers (default 0)
        self.num_workers = num_workers
        # (string) -- "train" | "test" (default "train")
        self.phase = phase
        # (int) -- will crop image to square with side length crop_size, does not apply for EXR images (default 256)
        self.crop_size = crop_size
        # (string) -- direction of mapping in input image pair AB: "AtoB" | "BtoA" (default "BtoA")
        self.direction = direction
        # (bool) -- flag for EXR image option (default False)
        self.use_exr = use_exr
        # (int) -- number of color channels in input images (default 3)
        self.channels = channels
        # (bool) -- flag for the usage of tensorboardX to view results dynamically during train (default True)
        self.use_tensorboard = use_tensorboard
        # (bool) -- NOTE(review): presumably selects .npy data files; confirm against the dataset loader (default False)
        self.use_npy = use_npy
        # (bool) -- flag that determines whether or not SSIM plots should be saved (default True)
        self.save_plots = save_plots
        # (bool) -- flag that determines whether or not the model is converted to ONNX (default True)
        self.save_onnx = save_onnx

    @staticmethod
    def _parse_bool(value):
        """Return True for the string "True" (as produced by str(True)) or the bool True."""
        return str(value) == "True"

    @classmethod
    def init_from_dict(cls, dict):
        """
        Construct a new Options from a dict( string -> string ) object,
        used at test-time.

        Args:
            dict: mapping with one entry per constructor argument. Numeric
                values are converted with int()/float(). Boolean values are
                True only for "True" (or an actual bool True). "save_onnx" is
                optional and falls back to the constructor default (True).
                (The parameter name shadows the builtin but is kept so that
                keyword callers keep working.)

        Returns:
            A new Options instance.

        Raises:
            KeyError: if any required key is missing.
            ValueError: if a numeric value cannot be parsed.
        """
        params = dict
        parse_bool = cls._parse_bool
        # Build a fresh instance instead of assigning onto the class object,
        # so separate calls cannot leak values into each other via class state.
        return cls(
            dataroot=params["dataroot"],
            name=params["name"],
            # Parsed as float, which also accepts "inf".
            max_dataset_size=float(params["max_dataset_size"]),
            checkpoints_dir=params["checkpoints_dir"],
            results_dir=params["results_dir"],
            tensorboard_dir=params["tensorboard_dir"],
            plots_dir=params["plots_dir"],
            onnx_dir=params["onnx_dir"],
            batch_size=int(params["batch_size"]),
            print_freq=int(params["print_freq"]),
            save_epoch_freq=int(params["save_epoch_freq"]),
            num_workers=int(params["num_workers"]),
            phase=params["phase"],
            crop_size=int(params["crop_size"]),
            use_exr=parse_bool(params["use_exr"]),
            use_npy=parse_bool(params["use_npy"]),
            direction=params["direction"],
            channels=int(params["channels"]),
            use_tensorboard=parse_bool(params["use_tensorboard"]),
            save_plots=parse_bool(params["save_plots"]),
            # Previously never restored; optional so mappings saved without it still load.
            save_onnx=parse_bool(params.get("save_onnx", True)),
        )
