HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ML_Model.h
Go to the documentation of this file.
1 /*
2  * PROPRIETARY INFORMATION. This software is proprietary to
3  * Side Effects Software Inc., and is not to be reproduced,
4  * transmitted, or disclosed in any way without written permission.
5  *
6  * COMMENTS: Wrapper for the ONNX inference engine
7  *
8  */
9 
10 #pragma once
11 
12 #include "ML_API.h"
13 #include "ML_Types.h"
14 
15 #include <UT/UT_Array.h>
16 #include <UT/UT_NonCopyable.h>
17 #include <UT/UT_SharedPtr.h>
18 #include <UT/UT_StringHolder.h>
19 
20 class UT_WorkBuffer;
21 
23 {
24 public:
25  ML_Model();
26  ~ML_Model();
28 
29  class SessionInfo; // Forward declaration of our environment (session state); defined elsewhere
31 
32  /// Initializes the model from the ONNX file at \p model_filepath.
33  /// \param model_filepath path to the ONNX model file
34  /// \param provider the execution provider to run the model with
35  /// \param errors receives any error strings generated by a failure during
36  /// initialization. \param warnings, messages receive additional diagnostic text. NOTE(review): the bool result presumably reports success — confirm
37  bool initializeModel(const UT_StringRef &model_filepath,
38  ML_ExecutionProvider provider,
39  UT_WorkBuffer &errors,
40  UT_WorkBuffer &warnings,
41  UT_WorkBuffer &messages);
42  /// Runs inference on \p inputs (shapes in \p input_shapes), filling \p outputs (shapes in \p output_shapes). NOTE(review): each inner array is presumably one flattened tensor, and a false return presumably means failure described in \p error_message — confirm
43  bool run(const UT_Array<UT_Array<float>> &inputs,
44  const UT_Array<Shape> &input_shapes,
45  UT_Array<UT_Array<float>> &outputs,
46  const UT_Array<Shape> &output_shapes,
47  UT_WorkBuffer &error_message);
48  /// Fills \p input_names and \p output_names with the names of the model's inputs and outputs.
49  void getNames(UT_StringArray &input_names,
50  UT_StringArray &output_names) const;
51  /// Fills \p input_shapes and \p output_shapes with the shapes of the model's inputs and outputs.
52  void getShapes(UT_Array<Shape> &input_shapes,
53  UT_Array<Shape> &output_shapes) const;
54 
55  /// Computes the product of all non-dynamic axes of a tensor shape.
56  /// Sets \p has_dynamic_axes to report whether any dynamic axes were found.
57  /// If any dimension is zero, the result is 0.
58  /// \returns 1 if all axes are dynamic
59  static exint tensorElementsSize(const UT_Array<exint> &tensor_dimensions,
60  bool &has_dynamic_axes);
61 
62  /// Builds a tensor shape from the values stored in a 3x3 matrix parameter.
63  /// \param tensor_shape the shape to fill in. \param shape_vector matrix holding the shape values (which entries are used is not visible here — confirm)
64  static bool mat3ToShape(Shape &tensor_shape, const UT_Matrix3D &shape_vector);
65 
66  /// Parses the output data for nodes.
67  /// \param maxtuplesize maximum tuple size, or -1 for unlimited
68  static bool parseOutputData(const UT_StringHolder &srcpattern, int maxtuplesize,
70 
71  /// Returns true if the specified provider is supported by the current
72  /// platform and system configuration.
73  static bool supportsExecutionProvider(ML_ExecutionProvider provider);
74 
75  /// Returns a cached list of the execution providers supported by the
76  /// current system configuration.
77  static const UT_Array<ML_ExecutionProvider> &
78  supportedExecutionProviders();
79 
80  /// Gets a human-readable name for an execution provider (the function returns void, so the name is produced through a parameter).
81  static void executionProviderName(
82  ML_ExecutionProvider provider,
85 private:
86  /// Writes information about the model into \p model_info.
87  void info(UT_WorkBuffer &model_info) const;
88 
89  /// Writes the string describing the shape of input \p input_index into \p the_string.
90  void inputShapeString(int input_index, UT_WorkBuffer &the_string) const;
91 
92  /// Writes the string describing the shape of output \p output_index into \p the_string.
93  void outputShapeString(int output_index, UT_WorkBuffer &the_string) const;
94 
95  /// Checks whether the sizes and shapes of the inputs and outputs are
96  /// compatible with the model so that it should be able to run.
97  bool sizeAndShapeErrorChecking(const UT_Array<UT_Array<float>> &inputs,
98  const UT_Array<Shape> &specified_input_shapes,
99  UT_Array<UT_Array<float>> &outputs,
100  const UT_Array<Shape> &specified_output_shapes,
101  UT_WorkBuffer &error_message);
102 
104 
105 };
GLuint GLsizei const GLuint const GLintptr const GLsizeiptr * sizes
Definition: glcorearb.h:2621
ML_ExecutionProvider
Definition: ML_Types.h:17
GLuint GLsizei const GLchar * label
Definition: glcorearb.h:2545
int64 exint
Definition: SYS_Types.h:125
< returns > If no error
Definition: snippets.dox:2
std::shared_ptr< T > UT_SharedPtr
Wrapper around std::shared_ptr.
Definition: UT_SharedPtr.h:36
#define ML_API
Definition: ML_API.h:10
#define UT_NON_COPYABLE(CLASS)
Define deleted copy constructor and assignment operator inside a class.
GLuint const GLchar * name
Definition: glcorearb.h:786