#
# Copyright (c) <2023> Side Effects Software Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# NAME:	        ml/pix2pix/train.py ( Python )
#
# COMMENTS:     Defines high-level methods to kick off training of a pix2pix
#               model.
#
#               Note that scripts in this module are not part of a stable API,
#               and are subject to change at any time.
import os
import subprocess
import sys
import time

import torch

from pdgml.pix2pix.options import HyperParams, Options
from pdgml.pix2pix.aligneddataset import create_dataset
from pdgml.pix2pix.model import create_model
from pdgml.pix2pix.utils import print_current_losses, get_data_from_file, serialize_options_and_save, create_tensorboard_graphs
from pdgml.pix2pix.computemetrics import loss_measures
from pdgml.pix2pix import testmodel

def train(options, hyper_params, work_item=None):
    """
    Train a pix2pix model, periodically saving and validating checkpoints.

    The dataset and model are built from ``options`` / ``hyper_params``.
    Checkpoints are saved every ``options.save_epoch_freq`` epochs and after
    the final epoch. At each save:

    - the validation script (``testmodel``) is run synchronously in a
      subprocess against the saved epoch;
    - if ``options.save_onnx`` is set, an ONNX copy of the generator is
      exported.

    Training losses are logged to ``<checkpoints_dir>/<name>/losses.txt``.
    If ``options.use_tensorboard`` is set, losses also go to tensorboard.
    If ``options.save_plots`` is set, a train/validation SSIM plot is written
    at the end.

    Args:
        options: run options (directories, run name, batch size, logging and
            save frequencies, and the feature flags mentioned above).
        hyper_params: model hyper-parameters; ``niter + niter_decay`` is the
            total number of training epochs.
        work_item: optional PDG work item. When given, it must expose one
            batch sub-item per epoch (``work_item.batchItems[epoch - 1]``).
            Sub-items are used for progress messages and output files.
    """
    # Write options and hyper-parameters to disk; the validation subprocess
    # is pointed at this file via --path
    options_filename = serialize_options_and_save(
        [options, hyper_params],
        options.results_dir,
        options.name)

    dataset = create_dataset(options)
    dataset_size = len(dataset)

    if work_item:
        work_item.addMessage('Training dataset size = {}'.format(dataset_size))

    # Directory for model checkpoints and the training loss log
    checkpoints_dir = os.path.join(options.checkpoints_dir, options.name)
    model_log_path = os.path.join(checkpoints_dir, 'losses.txt')
    os.makedirs(checkpoints_dir, exist_ok=True)

    # tensorboardX is imported lazily so it is only required when enabled
    writer = None
    if options.use_tensorboard:
        from tensorboardX import SummaryWriter
        tensorboard_path = os.path.join(options.tensorboard_dir, options.name)
        writer = SummaryWriter(tensorboard_path)

        if work_item:
            work_item.addMessage(
                "Writing tensorboard files to {}".format(tensorboard_path))

    # Number of validation epochs already read from the test log, so each
    # read only picks up new rows
    num_epochs_written = 0

    model = create_model(hyper_params, options)
    # Regular setup: load and print networks; create schedulers
    model.setup(hyper_params, 0)

    # Total number of training samples processed so far (in batch_size steps)
    total_iters = 0

    # [x_values, ssim_values] pairs used for the final SSIM plot
    train_ssim_cache = [[], []]
    val_ssim_cache = [[], []]

    # Loss log written by the validation script
    test_log_path = os.path.join(
        options.results_dir, options.name, 'losses.txt')

    num_epochs = hyper_params.niter + hyper_params.niter_decay

    try:
        for epoch in range(1, num_epochs + 1):
            if work_item:
                sub_item = work_item.batchItems[epoch - 1]
                sub_item.startSubItem()
            else:
                sub_item = None

            epoch_start_time = time.time()
            iter_data_time = time.time()

            # Samples processed in the current epoch
            epoch_iter = 0

            for i, batch in enumerate(dataset):
                iter_start_time = time.time()
                # Time since the previous iteration's optimization step
                # finished. Sampled every iteration so the value printed
                # below belongs to the same iteration as t_comp.
                t_data = iter_start_time - iter_data_time

                total_iters += options.batch_size
                epoch_iter += options.batch_size

                # Unpack data from the dataset and apply preprocessing
                model.set_input(batch)

                # Calculate losses, get gradients, update network weights
                model.optimize_parameters()

                iter_data_time = time.time()

                # Write the generator graph to tensorboard once, on the very
                # first batch
                if writer is not None and epoch == 1 and i == 0:
                    if sub_item:
                        sub_item.addMessage('Creating Graphs for Generator...')
                    create_tensorboard_graphs(
                        writer, options, hyper_params, batch)

                # Print training losses and save logging information to disk
                if total_iters % options.print_freq == 0:
                    losses = model.get_current_losses()
                    t_comp = (
                        (time.time() - iter_start_time) / options.batch_size)
                    print_current_losses(
                        model_log_path,
                        epoch,
                        epoch_iter,
                        losses,
                        t_comp,
                        t_data)

                    if writer is not None:
                        for name, val in losses.items():
                            writer.add_scalar(
                                'Train/' + name, val, total_iters)

                    # x is the (fractional) epoch: samples seen / dataset size
                    train_ssim_cache[0].append(
                        float(total_iters / dataset_size))
                    train_ssim_cache[1].append(losses['SSIM'])

            # Save a checkpoint every <save_epoch_freq> epochs and after the
            # last epoch
            if epoch % options.save_epoch_freq == 0 or epoch == num_epochs:
                if sub_item:
                    sub_item.addMessage(
                        'Saving model at the end of epoch {}, iters {}'.format(
                            epoch, total_iters))
                model_paths = model.save_networks(epoch)

                if sub_item:
                    for model_path in model_paths:
                        sub_item.addOutputFile(
                            model_path, "file/model/pytorch", own=True)

                # Consume validation results produced by earlier test runs.
                # The log is read before this epoch's test runs, so results
                # lag one save behind.
                num_epochs_written = _ingest_validation_losses(
                    test_log_path,
                    num_epochs_written,
                    writer,
                    val_ssim_cache)

                # Run validation synchronously to avoid a race between
                # training and testing (async is possible for large datasets)
                test_command = [
                    sys.executable,
                    '-E',
                    testmodel.__file__,
                    '--path',
                    options_filename,
                    '--epoch',
                    str(epoch)]
                return_code = subprocess.call(test_command)
                if return_code != 0 and sub_item:
                    sub_item.addMessage(
                        'Validation for epoch {} exited with code {}'.format(
                            epoch, return_code))

                if options.save_onnx:
                    onnx_options = Options.init_from_dict(vars(options))
                    onnx_options.num_workers = 1
                    onnx_options.batch_size = 1
                    onnx_options.phase = "test"
                    onnx_options.max_dataset_size = float("inf")
                    save_onnx_model(
                        onnx_options, hyper_params, epoch, sub_item)

            model.update_learning_rate()

            if sub_item:
                sub_item.addMessage(
                    'End of epoch {}/{} -- Time Taken: {} sec'.format(
                        epoch,
                        num_epochs,
                        time.time() - epoch_start_time))
                sub_item.cookSubItem()

        # The last validation run happened after the final read inside the
        # loop, so read the log once more to pick up its results
        num_epochs_written = _ingest_validation_losses(
            test_log_path,
            num_epochs_written,
            writer,
            val_ssim_cache)
    finally:
        # Flush and release tensorboard event files even if training fails
        if writer is not None:
            writer.close()

    if options.save_plots:
        save_ssim_plot(
            options,
            hyper_params,
            train_ssim_cache,
            val_ssim_cache,
            work_item)


def _ingest_validation_losses(test_log_path, num_epochs_written, writer,
                              val_ssim_cache):
    """
    Read validation losses added to ``test_log_path`` since the last read.

    Each new epoch's losses are written to tensorboard when ``writer`` is not
    None. SSIM values are appended to ``val_ssim_cache`` as
    ``[epochs, values]``.

    Args:
        test_log_path: loss log written by the validation script.
        num_epochs_written: number of epochs already consumed from the log.
        writer: tensorboard SummaryWriter, or None when tensorboard is off.
        val_ssim_cache: ``[epochs, ssim_values]`` lists, appended in place.

    Returns:
        The updated number of epochs consumed from the log. This is unchanged
        if the log does not exist yet.
    """
    if not os.path.exists(test_log_path):
        return num_epochs_written

    epochs, data = get_data_from_file(
        test_log_path,
        len(loss_measures),
        num_epochs_written)

    # data is indexed as data[measure_index][epoch_index]
    for epoch_idx, epoch in enumerate(epochs):
        for measure_idx, measure in enumerate(loss_measures):
            name = measure[0]
            value = data[measure_idx][epoch_idx]
            if writer is not None:
                writer.add_scalar('Validation/' + name, value, epoch)
            if name == 'SSIM':
                val_ssim_cache[0].append(epoch)
                val_ssim_cache[1].append(value)

    return num_epochs_written + len(epochs)

def save_ssim_plot(options, hyper_params, train_ssim, val_ssim, work_item):
    """
    Plot training and validation SSIM curves and save them as a PNG.

    The plot is written to ``<options.plots_dir>/<options.name>/SSIM.png``.

    Args:
        options: run options providing ``plots_dir`` and ``name``.
        hyper_params: used only to build the plot title (generator/
            discriminator layer counts, filter size, learning rate).
        train_ssim: ``[x_values, ssim_values]`` for the training curve.
        val_ssim: ``[epochs, ssim_values]`` for the validation curve.
        work_item: optional PDG work item; when given, the PNG is registered
            as an output file.
    """
    import matplotlib
    # Select the non-interactive Agg backend before pyplot is imported, so
    # saving the plot does not need a display
    matplotlib.use('Agg')
    import matplotlib.pyplot as pyplot

    # Use a dedicated figure instead of pyplot's implicit current figure, so
    # repeated calls do not draw onto a previous plot
    fig, ax = pyplot.subplots()
    try:
        ax.plot(
            train_ssim[0],
            train_ssim[1],
            '-',
            label='Train SSIM')
        ax.plot(
            val_ssim[0],
            val_ssim[1],
            '.-',
            label='Validation SSIM')

        ax.legend(loc='best')
        ax.set_ylim(0.6, 1.0)
        ax.set_title("G={} D={} Filter={} LR={}".format(
            hyper_params.n_layers_G,
            hyper_params.n_layers_D,
            hyper_params.filter_size,
            hyper_params.lr))

        plot_dir = os.path.join(options.plots_dir, options.name)
        plot_file = os.path.join(plot_dir, 'SSIM.png')
        os.makedirs(plot_dir, exist_ok=True)

        fig.savefig(plot_file)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed
        pyplot.close(fig)

    if work_item:
        work_item.addOutputFile(plot_file, own=True)

def save_onnx_model(options, hyper_params, epoch, work_item, opset_version=10):
    """
    Export the generator network for ``epoch`` to an ONNX file.

    The model is rebuilt from ``options`` / ``hyper_params`` and set up for
    ``epoch``. Its ``netG`` is then exported to
    ``<options.onnx_dir>/<options.name>/<epoch>.onnx``. The batch dimension
    of both input and output is dynamic.

    Args:
        options: run options providing ``channels``, ``crop_size``,
            ``onnx_dir`` and ``name``.
        hyper_params: model hyper-parameters used to build the model.
        epoch: epoch whose networks are set up and exported (also used as
            the file name).
        work_item: optional PDG work item; when given, the ONNX file is
            registered as an output file.
        opset_version: ONNX opset to target (defaults to 10, the previously
            hard-coded value).
    """
    model = create_model(hyper_params, options)
    # NOTE(review): presumably loads the networks saved for ``epoch`` —
    # confirm in model.setup
    model.setup(hyper_params, epoch)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    final_model = model.netG.to(device)
    final_model.eval()

    # Example input for tracing: a batch of one image. Zero-filled so
    # tracing is deterministic; the legacy torch.FloatTensor(sizes)
    # constructor returned uninitialized memory.
    dummy_input = torch.zeros(
        1,
        options.channels,
        options.crop_size,
        options.crop_size,
        dtype=torch.float32,
        device=device)

    onnx_dir = os.path.join(options.onnx_dir, options.name)
    onnx_file = os.path.join(onnx_dir, "{}.onnx".format(epoch))
    os.makedirs(onnx_dir, exist_ok=True)

    torch.onnx.export(
        final_model,
        dummy_input,
        onnx_file,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['modelInput'],
        output_names=['modelOutput'],
        dynamic_axes={
            'modelInput': {0: 'batch_size'},
            'modelOutput': {0: 'batch_size'},
        },
    )

    if work_item:
        work_item.addOutputFile(onnx_file, "file/model/onnx", own=True)
