# Copyright (c) <2023> Side Effects Software Inc.
#  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. The name of Side Effects Software may not be used to endorse or
#     promote products derived from this software without specific prior
#     written permission.
#
#  THIS SOFTWARE IS PROVIDED BY SIDE EFFECTS SOFTWARE "AS IS" AND ANY EXPRESS
#  OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
#  OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN
#  NO EVENT SHALL SIDE EFFECTS SOFTWARE BE LIABLE FOR ANY DIRECT, INDIRECT,
#  INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
#  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
#  OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
#  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
#  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
#  EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

""" Training script for an ML deformer based on random poses
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import numpy

import os

class DynamicDisplacementModel(nn.Module):
    """Fully connected network: Linear -> Tanh -> (Linear -> Tanh) * n -> Linear.

    The input layer maps input_dimension to hidden_layer_width, the hidden
    stack keeps that width, and the output layer maps it to target_dimension.

    hidden_layer_count is the number of (Linear, Tanh) pairs that follow the
    initial Tanh; it is expected to be at least 1.
    """

    def __init__(self, input_dimension, target_dimension, hidden_layer_count, hidden_layer_width):
        super().__init__()
        self.input_layer = nn.Linear(input_dimension, hidden_layer_width)

        # The hidden stack opens with an activation applied to the output
        # of the input layer.
        first_activation = nn.Tanh()

        # hidden_functions: every module of the hidden stack, in call order.
        # activation_functions / linear_functions: the same modules split by kind.
        self.hidden_functions = [first_activation]
        self.activation_functions = [first_activation]
        self.linear_functions = []

        for _ in range(hidden_layer_count):
            # Each hidden layer is a square linear map followed by Tanh.
            linear = nn.Linear(hidden_layer_width, hidden_layer_width)
            tanh = nn.Tanh()

            self.linear_functions.append(linear)
            self.activation_functions.append(tanh)
            self.hidden_functions.extend((linear, tanh))

        self.hidden_layers = nn.Sequential(*self.hidden_functions)

        self.output_layer = nn.Linear(hidden_layer_width, target_dimension)

        # Complete pipeline evaluated by forward().
        self.net = nn.Sequential(self.input_layer, self.hidden_layers, self.output_layer)

    def forward(self, x):
        """Evaluate the network on x."""
        return self.net(x)

class MyDataSet(Dataset):
    """Dataset exposing the index window [begin, end) of paired inputs and targets.

    Item idx of the dataset corresponds to entry begin + idx of the
    underlying inputs and targets.
    """

    def __init__(self, inputs, targets, begin, end):
        super().__init__()
        self._inputs, self._targets = inputs, targets
        self._begin, self._end = begin, end

    def __len__(self):
        # Number of entries in the window
        return self._end - self._begin

    def __getitem__(self, idx):
        # Translate the window-relative index into an absolute one
        position = self._begin + idx
        return self._inputs[position], self._targets[position]

def regularization_term(model, regularization, my_requires_grad, my_device):
    """Evaluate a loss term that aims to keep the linear (nonbias) parameters small.

    The term is the regularization coefficient times the sum of the 2-norms
    of every model parameter whose name does not contain 'bias'.
    This is useful for preventing overfitting.

    model: module whose named_parameters() are penalized
    regularization: scalar coefficient (int or float)
    my_requires_grad: whether the zero-valued accumulator requires gradients
    my_device: device on which the accumulator is created

    Returns a scalar tensor.
    """
    t = torch.tensor(0.0, requires_grad=my_requires_grad, device=my_device)

    for name, parameter in model.named_parameters():
        if 'bias' not in name:
            t = t + torch.norm(parameter, p=2)

    # Apply the coefficient as a plain float: wrapping it in a tensor that
    # requires gradients fails for integer values such as 0, because only
    # floating point tensors can require gradients.
    return float(regularization) * t

def train_single_epoch(my_device, log_file, model, criterion, regularization, optimizer, training_loader):
    """Run one optimization pass over every batch of the training loader.

    For each batch the loss is the criterion value plus the regularization
    term; its gradients are back-propagated and the optimizer takes one step.
    The loader is expected to contain only the training poses, not the poses
    that have been set aside for validation.
    """
    model.train(True)
    my_requires_grad = True

    for batch_inputs, batch_targets in training_loader:
        batch_inputs = batch_inputs.to(my_device)
        batch_targets = batch_targets.to(my_device)

        # Discard the gradients left over from the previous step
        optimizer.zero_grad()

        data_loss = criterion(model(batch_inputs), batch_targets)
        penalty = regularization_term(model, regularization, my_requires_grad, my_device)
        total_loss = data_loss + penalty

        total_loss.backward()
        optimizer.step()

def evaluated_loss(my_device, model, criterion, regularization, loader):
    """Evaluate the loss without computing gradients.

    This allows periodic checking of the loss function for both the training set and the validation set.

    The model is switched to evaluation mode. The result is the mean of the
    per-batch criterion values plus the regularization term, as a scalar
    tensor. If the loader yields no batches, the mean is 0/0 and the result
    is NaN.
    """
    model.train(False)
    my_requires_grad = False

    # Disable autograd for the whole evaluation; otherwise a graph is
    # recorded and kept alive for every batch through the accumulated loss.
    with torch.no_grad():
        loss = torch.tensor(0.0, device=my_device)
        batch_count = 0

        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(my_device)
            batch_targets = batch_targets.to(my_device)

            loss = loss + criterion(model(batch_inputs), batch_targets)
            batch_count += 1

        return (loss / batch_count) + regularization_term(model, regularization, my_requires_grad, my_device)

def export_to_onnx(model, dummy_input, export_path, my_device):
    """
    Export the model to ONNX format so that it can be read in and used by the ONNX SOP

    The model and the dummy input are moved to the CPU for the export and the
    model is put in evaluation mode. Afterwards the model is moved back to
    my_device and its previous train/eval mode is restored, even when the
    export raises.
    """
    cpu_device = torch.device('cpu')
    was_training = model.training

    model = model.to(cpu_device)
    model.eval()

    try:
        torch.onnx.export(model, dummy_input.to(cpu_device), export_path)
    finally:
        # Always hand the model back in the state the caller gave it
        model.to(my_device)
        model.train(was_training)

def train_using_inputs_and_targets(
    my_device,
    inputs, targets,
    log_file, 
    input_dimension,
    target_dimension,
    observation_count,
    onnx_path,
    /, *,
    hidden_layer_count = 2,
    hidden_layer_width = -1,
    regularization = 0,
    epoch_max = 1_000_000_000
):
    """
    Perform a complete supervised training given input and target tensors

    The first part of the observations is used for training and the last
    part (one in 16, at least one) for validation. Losses are logged every
    16 epochs. Every 1024 epochs, and on the last epoch, the model is
    exported to onnx_path if the validation loss is strictly less than at
    the previous export; otherwise training stops.

    Progress and problems are written to log_file. Nothing is returned.
    If there are too few observations to form a non-empty training set,
    this is logged and the function returns without training.
    """

    log_file.write("train_using_inputs_and_targets has started\n")

    if observation_count <= 0:
        log_file.write("The number of observations is zero\n")
        return

    # It is assumed that inputs, targets already have
    # a random ordering

    # Use only the initial subset_count entries of
    # the observations for training and validation
    subset_count = observation_count

    # Choose ratio of training samples over validation samples
    validate_one_per = 16

    # The validation set always gets at least one observation, so it is
    # the training set that can end up empty.
    validation_count = max(1, subset_count // validate_one_per)
    training_count = subset_count - validation_count

    if training_count <= 0:
        log_file.write("Not enough observations available for training set\n")
        return

    # Normalize the coefficient so that an integer value (such as the
    # default 0) behaves exactly like the corresponding float downstream.
    regularization = float(regularization)

    # If the hidden layer width is not specified by the user of the script,
    # take the maximum of the input_dimension and the target dimension
    if hidden_layer_width == -1:
        hidden_layer_width = max(input_dimension, target_dimension)

    model = DynamicDisplacementModel(input_dimension, target_dimension, hidden_layer_count, hidden_layer_width)
    model = model.to(my_device)

    log_file.write(f"Use only the first {subset_count=} observations \n")
    log_file.write(f"size of training set {training_count=}\n")
    log_file.write(f"Size of validation set {validation_count=}\n")

    training_begin = 0
    training_end = training_begin + training_count

    validation_begin = training_end
    validation_end = validation_begin + validation_count

    training_data = MyDataSet(inputs, targets, training_begin, training_end)
    validation_data = MyDataSet(inputs, targets, validation_begin, validation_end)

    max_batch_size = 256

    training_loader = DataLoader(training_data, batch_size = min(max_batch_size, training_count), shuffle = True)
    validation_loader = DataLoader(validation_data, batch_size = max_batch_size, shuffle = True)

    criterion = nn.MSELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999))

    model_save_count = 0
    last_saved_validation_loss = -1

    for epoch in range(epoch_max):
        train_single_epoch(my_device, log_file, model, criterion, regularization, optimizer, training_loader)

        is_last_epoch = epoch == epoch_max - 1

        #TODO: Put below code in a function and call only once
        # in so many iterations

        if ( epoch % 16 == 0 ) or is_last_epoch:
            training_loss = evaluated_loss(my_device, model, criterion, regularization, training_loader)
            validation_loss = evaluated_loss(my_device, model, criterion, regularization, validation_loader)
            current_validation_loss = validation_loss.item()

            log_file.write("{0: <11}".format(epoch))
            log_file.write(f"training loss {training_loss.item():<24}")
            log_file.write(f"validation loss {current_validation_loss:<24}\n")
            #TODO: print regularization

            save_current_model = False

            # Checkpoint epochs: every 1024th epoch (but not epoch 0) and the last one
            if ( epoch > 0 ) and ( ( epoch % 1024 == 0 ) or is_last_epoch ):
                if ( model_save_count >= 1 ) and ( current_validation_loss >= last_saved_validation_loss ):
                    # Validation loss stopped improving since the last export: stop training
                    log_file.write(f"Don't save the current model because current validation loss {current_validation_loss} is not less than last saved validation loss {last_saved_validation_loss}\n Terminate training loop.\n")

                    return

                if model_save_count == 0:
                    log_file.write("Save the current model (first time)\n")
                else:
                    log_file.write(f"Save the current model because current validation loss {current_validation_loss} is strictly less than last saved validation loss {last_saved_validation_loss}\n")

                save_current_model = True

            if save_current_model:
                log_file.write(f"ONNX file path: {onnx_path}\n")

                # The export only needs an input of the right shape
                dummy_input = torch.zeros_like(inputs[0])

                export_to_onnx(model, dummy_input, onnx_path, my_device)

                model_save_count = model_save_count + 1
                last_saved_validation_loss = current_validation_loss

        log_file.flush()

    log_file.write("train_using_inputs_and_targets has completed\n")

def _read_raw_observations(path):
    """Read a raw data file: two 4-byte little-endian counts, then the payload bytes."""
    with open(path, 'rb') as raw_file:
        observation_count = int.from_bytes(raw_file.read(4), byteorder='little')
        dimension = int.from_bytes(raw_file.read(4), byteorder='little')
        payload = raw_file.read()
    return observation_count, dimension, payload


def _first_nan_row(values):
    """Return the index of the first row of a 2D tensor that contains NaN, or None."""
    nan_rows = torch.isnan(values).any(dim=1).cpu()
    if not bool(nan_rows.any()):
        return None
    return int(torch.nonzero(nan_rows)[0])


def train_and_save_model(
    input_data_path, target_data_path,
    onnx_path,
    log_path,
    /, *,
    hp_hidden_layer_count = 2,
    hp_hidden_layer_width = -1,
    hp_regularization = 0.0,
    hp_epoch_max = 1_000_000_000,
):
    """
    Given file paths to the input and target data,
    construct corresponding tensors and perform a training.

    Each data file starts with the observation count and the per-observation
    dimension (4-byte little-endian integers each), followed by the values
    as float32. The log file at log_path is overwritten.

    Returns (True, "") after training was run, or (False, message) when the
    data files are inconsistent or contain NaN values.
    """
    torch.manual_seed(2012)

    with open(log_path, 'w') as log_file:
        log_file.write(f"{input_data_path=}\n")
        log_file.write(f"{target_data_path=}\n")

        log_file.write(f"{onnx_path=}\n")
        log_file.write(f"{log_path=}\n")

        # Prefer CUDA, then Apple MPS, and fall back to the CPU
        if torch.cuda.is_available():
            my_device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            my_device = torch.device('mps')
        else:
            my_device = torch.device('cpu')

        log_file.write(f"{my_device=}\n")

        input_observation_count, input_dimension, inputs_raw = _read_raw_observations(input_data_path)
        target_observation_count, target_dimension, targets_raw = _read_raw_observations(target_data_path)

        if input_dimension <= 0:
            message = f"Each observation has an input dimension of {input_dimension}"
            log_file.write(message)
            return (False, message)

        if target_dimension <= 0:
            message = f"Each observation has an target dimension of {target_dimension}"
            log_file.write(message)
            return (False, message)

        if input_observation_count != target_observation_count:
            message = f"Observation counts for input file and target file do not match: input has {input_observation_count} observations, but target has {target_observation_count} observations"
            log_file.write(message)
            return (False, message)

        observation_count = input_observation_count

        log_file.write(f"{observation_count=}\n")
        log_file.write(f"{input_dimension=}\n")
        log_file.write(f"{target_dimension=}\n")

        my_dtype = numpy.float32
        my_dtype_bytes = numpy.dtype(my_dtype).itemsize

        # The payload size must match what the header announces
        input_data_expected_bytes = observation_count * input_dimension * my_dtype_bytes
        if len(inputs_raw) != input_data_expected_bytes:
            message = f"Raw input data is {len(inputs_raw)} bytes ({len(inputs_raw) // (input_dimension * my_dtype_bytes)} inputs), expected {input_data_expected_bytes} bytes ({observation_count} inputs)\n"
            log_file.write(message)
            return (False, message)

        target_data_expected_bytes = observation_count * target_dimension * my_dtype_bytes
        if len(targets_raw) != target_data_expected_bytes:
            message = f"Raw target data is {len(targets_raw)} bytes ({len(targets_raw) // (target_dimension * my_dtype_bytes)} targets), expected {target_data_expected_bytes} bytes ({observation_count} targets)\n"
            log_file.write(message)
            return (False, message)

        inputs = torch.tensor(numpy.frombuffer(inputs_raw, my_dtype), device=my_device)
        inputs = torch.reshape(inputs, (observation_count, input_dimension))

        bad_observation = _first_nan_row(inputs)
        if bad_observation is not None:
            message = f"Input for observation #{bad_observation} contains NAN\n"
            log_file.write(message)
            return (False, message)

        targets = torch.tensor(numpy.frombuffer(targets_raw, my_dtype), device=my_device)
        targets = torch.reshape(targets, (observation_count, target_dimension))

        # Report the observation that actually holds the NaN
        bad_observation = _first_nan_row(targets)
        if bad_observation is not None:
            message = f"Target for observation #{bad_observation} contains NAN\n"
            log_file.write(message)
            return (False, message)

        log_file.write(f"{hp_hidden_layer_count=}\n")
        log_file.write(f"{hp_hidden_layer_width=}\n")
        log_file.write(f"{hp_regularization=}\n")
        log_file.write(f"{hp_epoch_max=}\n")

        train_using_inputs_and_targets(
            my_device,
            inputs,
            targets,
            log_file,
            input_dimension,
            target_dimension,
            observation_count,
            onnx_path,
            hidden_layer_count=hp_hidden_layer_count,
            hidden_layer_width=hp_hidden_layer_width,
            regularization=hp_regularization,
            epoch_max=hp_epoch_max
        )

    return (True, "")

# Allow running this file directly as a stand-alone script, outside of the hip file
if __name__ == "__main__":
    # Hard-coded locations of the data, model and log files;
    # these must be adjusted manually to match the current setup.
    input_data_path, target_data_path = "Preprocessed/inputs.raw", "Preprocessed/targets.raw"
    onnx_path, log_path = 'Models/model_standalone.onnx', 'Logs/log_standalone.txt'

    train_and_save_model(input_data_path, target_data_path, onnx_path, log_path)

    # Echo the training log to standard output
    with open(log_path, 'r') as log_file:
        print(log_file.read())
  
    
