#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/testmodel.py ( Python )
#
# COMMENTS:     Tests a trained model as a separate task.
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.

import os
import argparse

from tqdm import tqdm

from pdgml.pix2pix.aligneddataset import create_dataset
from pdgml.pix2pix.computemetrics import compute
from pdgml.pix2pix.model import create_model
from pdgml.pix2pix.options import Options, HyperParams
from pdgml.pix2pix.utils import save_images, save_npys

def append_losses_to_file(file_name, loss_averages, epoch):
    """Opens file and appends one row of validation metrics.

    The row is written as ``epoch,value1,value2,...,`` followed by a newline,
    with the values taken in the iteration order of ``loss_averages``. The
    file is opened in append mode, so it is created if it does not exist and
    existing rows are preserved.

    Parameters
    ----------
        file_name (string): path of file
        loss_averages (dict<string, float>): contains averages for each loss
        epoch (int): epoch number

    Return
    ------
        loss_fcns (list<string>): loss function labels, in the same order as
            the values written to the row

    """

    # Labels come from the same dict iteration order as the values below,
    # so the returned list lines up with the written columns.
    loss_fcns = list(loss_averages)

    # Every field (epoch first, then each loss value) is followed by a comma.
    row = "".join("{0},".format(item) for item in (epoch, *loss_averages.values()))

    # Context manager guarantees the handle is flushed and closed; the
    # previous implementation left the file open.
    with open(file_name, "a+") as f:
        f.write(row + "\n")

    return loss_fcns


def deserialize_as_dict(filename):
    """Opens options file on disk and parses it into a dict() object.

    Each non-blank line is expected to hold a key, then whitespace, then a
    value. Only the first whitespace run separates key from value, so values
    may themselves contain spaces (e.g. filesystem paths). Blank lines are
    ignored. All values are returned as strings.

    Parameters
    ----------
        filename (string): path to options file on disk

    Returns
    -------
        obj_dict (dict<string><string>): dictionary of options and their respective values

    Raises
    ------
        ValueError: if a non-blank line contains a key but no value

    """

    obj_dict = {}
    with open(filename, "r") as fh:
        for line in fh:
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (such as a trailing empty line) rather
                # than failing the tuple unpack below.
                continue
            # maxsplit=1 keeps values containing whitespace intact; a line with
            # only a key still fails to unpack and raises ValueError.
            key, val = stripped.split(maxsplit=1)
            obj_dict[key] = val

    return obj_dict


def test(options, hyper_params, epoch):
    """Runs the test procedure: loads model, runs inference on the test set,
    and computes validation metrics.

    Per-sample outputs are written under
    ``<results_dir>/<name>/test.<epoch>/`` (as EXR, NPY, or images depending
    on ``options.use_exr`` / ``options.use_npy``). The metric averages
    returned by ``compute`` are then appended as one row to
    ``<results_dir>/<name>/losses.txt``.

    Note that ``options`` is modified in place: ``num_workers``,
    ``batch_size``, ``phase`` and ``max_dataset_size`` are overwritten.

    Parameters
    ----------
        options: Standard Options object
        hyper_params: HyperParams object
        epoch (int): current epoch in train process

    """

    options.num_workers = 1  # test code only supports num_workers = 1
    options.batch_size = 1  # test code only supports batch_size = 1
    options.phase = "test"  # set mode to test for dataset and model creation
    options.max_dataset_size = float("inf")  # use the full size of test/val set
    dataset = create_dataset(options)
    # create a model given the hyper parameters and options
    model = create_model(hyper_params, options)
    # regular setup: load and print networks; create schedulers
    model.setup(hyper_params, epoch)

    # Test with eval mode. This only affects layers like batchnorm and dropout.
    # Batchnorm and dropout are used in the original pix2pix. You can experiment it with and without eval() mode.
    model.eval()

    # Define the results image directory
    imgs_dir = os.path.join(
        options.results_dir, options.name, "%s.%s/" % (options.phase, epoch)
    )
    print("Producing output images in {0}".format(imgs_dir))

    # Choose the output writer once; the format flags do not change per sample.
    if options.use_exr is True:
        # save_exrs is not among this module's top-level imports, so calling
        # it directly raised NameError. Import it here, from the same utils
        # module as the other writers, so other output modes do not depend on it.
        from pdgml.pix2pix.utils import save_exrs as save_outputs
    elif options.use_npy:
        save_outputs = save_npys
    else:
        save_outputs = save_images

    for data in tqdm(dataset):
        model.set_input(data)  # unpack data from data loader
        model.test()  # run inference

        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()  # get image paths
        save_outputs(imgs_dir, visuals, img_path)

    # Compute validation metrics over the outputs written above and append
    # their averages to the running losses file.
    save_name = os.path.join(options.results_dir, options.name, "losses.txt")
    loss_averages = compute(
        imgs_dir, options.channels, use_exr=options.use_exr, use_npy=options.use_npy
    )
    append_losses_to_file(save_name, loss_averages, epoch)


# Intended to run when called asynchronously from a separate process
if __name__ == "__main__":

    # Parse the command line handed over by the training process.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--path",
        type=str,
        default="./results/opt.txt",
        help="path to temporary options file with info from train.py",
    )
    cli.add_argument("--epoch", type=int, help="current training epoch value")
    cli_args = cli.parse_args()

    # Rebuild the training session's Options and HyperParams from the same
    # key/value dictionary, then run test/validation for the requested epoch.
    session = deserialize_as_dict(cli_args.path)
    test(
        Options.init_from_dict(session),
        HyperParams.init_from_dict(session),
        cli_args.epoch,
    )
