HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
onnxruntime_cxx_api.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5 //
6 // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7 // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8 // all the resources follow RAII and do not leak memory.
9 //
10 // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11 // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12 // until you assign an instance that actually holds an underlying object.
13 //
14 // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15 // Some objects have explicit 'Clone' methods for this purpose.
16 //
17 // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18 // by value or by reference. ConstXXXX types are restricted to const only interfaces.
19 //
20 // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21 //
22 // The lifetime of the corresponding owning object must outlast the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
23 // have to fall back to raw C types and the C API with its usual pitfalls. In general, do not use the C API from your C++ code.
24 
25 #pragma once
26 #include "onnxruntime_c_api.h"
27 #include <cstddef>
28 #include <array>
29 #include <memory>
30 #include <stdexcept>
31 #include <string>
32 #include <vector>
33 #include <unordered_map>
34 #include <utility>
35 #include <type_traits>
36 
37 #ifdef ORT_NO_EXCEPTIONS
38 #include <iostream>
39 #endif
40 
41 /** \brief All C++ Onnxruntime APIs are defined inside this namespace
42  *
43  */
44 namespace Ort {
45 
46 /** \brief All C++ methods that can fail will throw an exception of this type
47  *
48  * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
49  */
50 struct Exception : std::exception {
51  Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
52 
53  OrtErrorCode GetOrtErrorCode() const { return code_; }
54  const char* what() const noexcept override { return message_.c_str(); }
55 
56  private:
57  std::string message_;
58  OrtErrorCode code_;
59 };
60 
// ORT_CXX_API_THROW(string, code) is the single error-raising point used by this header.
//  - ORT_NO_EXCEPTIONS defined: unless the user pre-defined the macro, it writes the message of a
//    temporary Ort::Exception(string, code) to std::cerr and then calls abort(); it never returns.
//  - Otherwise: it throws Ort::Exception(string, code) by value.
// NOTE(review): abort() is called unqualified, and none of the headers this file includes directly
// (<cstdlib>/<stdlib.h> is not among them) is guaranteed to declare it; presumably it arrives
// transitively via onnxruntime_c_api.h — confirm.
61 #ifdef ORT_NO_EXCEPTIONS
62 // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
63 // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
64 #ifndef ORT_CXX_API_THROW
// do { ... } while (false) turns the multi-statement expansion into a single statement, so the macro
// is safe inside an unbraced if/else. std::endl flushes std::cerr before abort() ends the process.
65 #define ORT_CXX_API_THROW(string, code) \
66  do { \
67  std::cerr << Ort::Exception(string, code) \
68  .what() \
69  << std::endl; \
70  abort(); \
71  } while (false)
72 #endif
73 #else
// Exceptions enabled: callers can catch Ort::Exception, or std::exception since it derives from it.
74 #define ORT_CXX_API_THROW(string, code) \
75  throw Ort::Exception(string, code)
76 #endif
77 
78 // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
79 // it's in a template so that we can define a global variable in a header and make
80 // it transparent to the users of the API.
// Only Global<void> is used in this header (by InitApi and GetApi). The template parameter exists
// solely so that the static member below may be defined in a header included by many translation units.
81 template <typename T>
82 struct Global {
83  static const OrtApi* api_;
84 };
85 
// Definition of Global<T>::api_, in one of two modes:
//  - ORT_API_MANUAL_INIT defined: api_ is value-initialized to nullptr and must be set through one of
//    the InitApi overloads before any wrapper (or GetApi()) is used.
//  - Otherwise: api_ is dynamically initialized at static-initialization time from
//    OrtGetApiBase()->GetApi(ORT_API_VERSION).
86 // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
87 template <typename T>
88 #ifdef ORT_API_MANUAL_INIT
89 const OrtApi* Global<T>::api_{};
// Sets Global<void>::api_ from OrtGetApiBase()->GetApi(ORT_API_VERSION), i.e. the runtime this binary is linked against.
90 inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
91 
92 // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
93 // required by C++ APIs.
94 //
95 // Example mycustomop.cc:
96 //
97 // #define ORT_API_MANUAL_INIT
98 // #include <onnxruntime_cxx_api.h>
99 // #undef ORT_API_MANUAL_INIT
100 //
101 // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
102 // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
103 // // ...
104 // }
105 //
106 inline void InitApi(const OrtApi* api) { Global<void>::api_ = api; }
107 #else
108 #if defined(_MSC_VER) && !defined(__clang__)
109 #pragma warning(push)
110 // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
111 // Please define ORT_API_MANUAL_INIT if it concerns you.
112 #pragma warning(disable : 26426)
113 #endif
// NOTE(review): the result is stored without a null check, and GetApi() later dereferences it
// unconditionally; confirm what OrtApiBase::GetApi returns for an unsupported ORT_API_VERSION.
114 const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
115 #if defined(_MSC_VER) && !defined(__clang__)
116 #pragma warning(pop)
117 #endif
118 #endif
119 
120 /// This returns a reference to the OrtApi interface in use
121 inline const OrtApi& GetApi() { return *Global<void>::api_; }
122 
/// \brief C++ convenience wrapper over OrtApi::GetAvailableProviders().
///
/// \return one std::string per available execution provider.
std::vector<std::string> GetAvailableProviders();
129 
/** \brief IEEE 754 half-precision floating point data type
 * \details It is necessary for type dispatching to make use of C++ API
 * The type is implicitly convertible to/from uint16_t.
 * The size of the structure should align with uint16_t and one can freely cast
 * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
 *
 * Generally, you can feed any of your types as float16/bfloat16 data to create a tensor
 * on top of it, provided it forms a contiguous buffer of 16-bit elements with no padding.
 * You can also feed an array of uint16_t elements directly. For example,
 *
 * \code{.unparsed}
 * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
 * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
 * std::vector<int64_t> dims = {values_length}; // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
 * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
 *                                                dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
 * \endcode
 *
 * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
 * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
 * template specialization.
 *
 * \code{.unparsed}
 * namespace yours { struct half {}; } // assume this is your type, define this:
 * namespace Ort {
 * template<>
 * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
 * } //namespace Ort
 *
 * std::vector<yours::half> values;
 * // The cast is required: brace-initializing int64_t from a non-constant size_t is a narrowing error.
 * std::vector<int64_t> dims = {static_cast<int64_t>(values.size())}; // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Here we are passing element count -> values.size()
 * auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
 *
 * \endcode
 */
struct Float16_t {
  uint16_t value;  ///< Raw 16-bit pattern; this type performs no floating-point arithmetic.
  constexpr Float16_t() noexcept : value(0) {}
  constexpr Float16_t(uint16_t v) noexcept : value(v) {}
  constexpr operator uint16_t() const noexcept { return value; }
  // Equality compares the stored bit patterns; no IEEE 754 value semantics (NaN, signed zero) are applied.
  constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
179 
/** \brief bfloat16 (Brain Floating Point) data type
 * \details It is necessary for type dispatching to make use of C++ API
 * The type is implicitly convertible to/from uint16_t.
 * The size of the structure should align with uint16_t and one can freely cast
 * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
 *
 * See also code examples for Float16_t above.
 */
struct BFloat16_t {
  uint16_t value;  ///< Raw 16-bit pattern; this type performs no floating-point arithmetic.
  constexpr BFloat16_t() noexcept : value(0) {}
  constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
  constexpr operator uint16_t() const noexcept { return value; }
  // Equality compares the stored bit patterns; no IEEE 754 value semantics (NaN, signed zero) are applied.
  constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
198 
199 namespace detail {
200 // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
201 // This can't be done in the C API since C doesn't have function overloading.
// Example: ORT_DEFINE_RELEASE(Env) expands to
//   inline void OrtRelease(OrtEnv* ptr) { GetApi().ReleaseEnv(ptr); }
// detail::Base<T> calls OrtRelease(p_) in its destructor and move assignment, and overload resolution on
// the pointer type selects the matching generated function.
202 #define ORT_DEFINE_RELEASE(NAME) \
203  inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
204 
// NOTE(review): each use below is followed by ';' after a complete function definition. That is an
// empty declaration at namespace scope; it is legal since C++11 but may warn under -Wextra-semi/-pedantic.
205 ORT_DEFINE_RELEASE(Allocator);
206 ORT_DEFINE_RELEASE(MemoryInfo);
207 ORT_DEFINE_RELEASE(CustomOpDomain);
208 ORT_DEFINE_RELEASE(ThreadingOptions);
209 ORT_DEFINE_RELEASE(Env);
211 ORT_DEFINE_RELEASE(Session);
212 ORT_DEFINE_RELEASE(SessionOptions);
213 ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
214 ORT_DEFINE_RELEASE(SequenceTypeInfo);
215 ORT_DEFINE_RELEASE(MapTypeInfo);
216 ORT_DEFINE_RELEASE(TypeInfo);
218 ORT_DEFINE_RELEASE(ModelMetadata);
219 ORT_DEFINE_RELEASE(IoBinding);
220 ORT_DEFINE_RELEASE(ArenaCfg);
222 ORT_DEFINE_RELEASE(OpAttr);
224 ORT_DEFINE_RELEASE(KernelInfo);
225 
// The helper macro is undefined so it does not leak into code that includes this header.
226 #undef ORT_DEFINE_RELEASE
227 
/** \brief Tag wrapper marking a C object as NOT owned by the C++ wrapper.
 *
 * Pass it as Base<Unowned<T>> to select the non-owning, copyable specialization of Base:
 * the wrapper refers to a T but never releases it.
 */
template <typename T>
struct Unowned {
  /// The wrapped C type T.
  using Type = T;
};
235 
236 /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
237  * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
238  *
239  * All of the C++ classes
240  * a) serve as containers for pointers to objects that are created by the underlying C API.
241  * Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
242  * b) Each XXXX struct instance functions as a smart pointer to the underlying C API object.
243  * It releases the owned object automatically when it goes out of scope, and it is move-only.
244  * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
245  * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
246  * such as Onnxruntime or instances of XXXX classes.
247  * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
248  * in C++ code.
249  *
250  */
251 
252 /// <summary>
253 /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
254 /// </summary>
// T is the C object type (e.g. OrtEnv). Disposal goes through the OrtRelease(T*) overload produced by
// ORT_DEFINE_RELEASE for that type.
255 template <typename T>
256 struct Base {
257  using contained_type = T;
258 
// After construction this object owns p_ (nullptr for the default constructor).
259  constexpr Base() = default;
260  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
// Releases unconditionally, including when p_ is nullptr.
// NOTE(review): presumably relies on the C API Release* functions accepting nullptr — confirm.
261  ~Base() { OrtRelease(p_); }
262 
// Copying is disabled so that exactly one Base owns a given C object.
263  Base(const Base&) = delete;
264  Base& operator=(const Base&) = delete;
265 
// Moving transfers ownership and leaves the source holding nullptr.
266  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
// Releases the currently owned object first, then takes over v's pointer.
// NOTE(review): there is no self-assignment guard. For `x = std::move(x)`, OrtRelease(p_) frees the object,
// then v.release() returns that already-freed pointer and it is stored back into p_, so the destructor
// releases it a second time. Consider wrapping the body in `if (this != &v) { ... }`.
267  Base& operator=(Base&& v) noexcept {
268  OrtRelease(p_);
269  p_ = v.release();
270  return *this;
271  }
272 
// Implicit conversion lets a wrapper be passed directly to C API functions that expect T*.
273  constexpr operator contained_type*() const noexcept { return p_; }
274 
275  /// \brief Relinquishes ownership of the contained C object pointer
276  /// The underlying object is not destroyed
  /// \return the previously held pointer; this object holds nullptr afterwards.
278  T* p = p_;
279  p_ = nullptr;
280  return p;
281  }
282 
283  protected:
285 };
286 
287 // Undefined. For const types use Base<Unowned<const T>>
288 template <typename T>
289 struct Base<const T>;
290 
291 /// <summary>
292 /// Covers unowned pointers owned by either the ORT
293 /// or some other instance of CPP wrappers.
294 /// Used for ConstXXX and UnownedXXXX types that are copyable.
295 /// Also convenient to wrap raw OrtXX pointers .
296 /// </summary>
297 /// <typeparam name="T">The C object type; const-qualified when used for ConstXXXX wrappers.</typeparam>
298 template <typename T>
299 struct Base<Unowned<T>> {
301 
// The constructors only store the pointer. This specialization never releases anything.
302  constexpr Base() = default;
303  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
304 
// Defaulted destructor: the pointee is owned elsewhere, so it is not released here.
305  ~Base() = default;
306 
// Copies duplicate the non-owning pointer.
307  Base(const Base&) = default;
308  Base& operator=(const Base&) = default;
309 
// Moves copy the pointer and then null the source.
// NOTE(review): on self-move (v is *this) the assignment below leaves p_ == nullptr, because p_ is
// cleared before it is swapped with itself.
310  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
311  Base& operator=(Base&& v) noexcept {
312  p_ = nullptr;
313  std::swap(p_, v.p_);
314  return *this;
315  }
316 
// Implicit conversion to the raw C pointer, for passing to C API functions.
317  constexpr operator contained_type*() const noexcept { return p_; }
322 
323 // Light functor to release memory with OrtAllocator
// Deleter for std::unique_ptr (used by AllocatedStringPtr). It returns buffers to the OrtAllocator
// they came from. It only stores the allocator pointer and never owns or releases the allocator,
// so the allocator must outlive every call of this deleter.
326  explicit AllocatedFree(OrtAllocator* allocator)
327  : allocator_(allocator) {}
// A null ptr is ignored. Otherwise the buffer goes back through allocator_->Free(allocator_, ptr).
328  void operator()(void* ptr) const {
329  if (ptr) allocator_->Free(allocator_, ptr);
330  }
331 };
332 
333 } // namespace detail
334 
335 struct AllocatorWithDefaultOptions;
336 struct Env;
337 struct TypeInfo;
338 struct Value;
339 struct ModelMetadata;
340 
341 /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
342  * and release them at the end of the scope. The lifespan of the given allocator
343  * must eclipse the lifespan of AllocatedStringPtr instance
344  */
345 using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
346 
347 /** \brief The Status that holds ownership of OrtStatus received from C API
348  * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
349  * constructors to construct an instance of a Status object from exceptions.
350  */
351 struct Status : detail::Base<OrtStatus> {
352  explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
353  explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null
// For an Ort::Exception argument, this overload wins over the std::exception one (an exact match beats
// a derived-to-base conversion). Presumably it preserves the exception's OrtErrorCode — confirm in the implementation.
354  explicit Status(const Exception&); ///< Creates status instance out of exception
355  explicit Status(const std::exception&); ///< Creates status instance out of exception
// NOTE(review): presumably returns the OrtErrorCode carried by the wrapped OrtStatus — confirm.
357  OrtErrorCode GetErrorCode() const;
358 };
359 
360 /** \brief The ThreadingOptions
361  *
362  * Configures the global thread pools of an Env. Pass it to the Env constructors that take a
 * const OrtThreadingOptions*; the conversion operator inherited from detail::Base yields that raw pointer.
363  */
// Every setter returns ThreadingOptions& (presumably *this), so calls can be chained.
364 struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
365  /// \brief Wraps OrtApi::CreateThreadingOptions
367 
368  /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
369  ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
370 
371  /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
372  ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
373 
374  /// \brief Wraps OrtApi::SetGlobalSpinControl
375  ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
376 
377  /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
379 
380  /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
382 
383  /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
384  ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
385 
386  /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
388 };
389 
390 /** \brief The Env (Environment)
391  *
392  * The Env holds the logging state used by all other objects.
393  * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
394  */
395 struct Env : detail::Base<OrtEnv> {
396  explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
397 
398  /// \brief Wraps OrtApi::CreateEnv
399  Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
400 
401  /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
402  Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
403 
404  /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
405  Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
406 
407  /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
408  Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
409  OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
410 
411  /// \brief C Interop Helper
412  explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
413 
414  Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
415  Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
416 
417  Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
418 
419  Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
420 };
421 
422 /** \brief Custom Op Domain
423  *
424  */
425 struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
426  explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
427 
428  /// \brief Wraps OrtApi::CreateCustomOpDomain
429  explicit CustomOpDomain(const char* domain);
430 
431  // This does not take ownership of the op, simply registers it.
432  void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
433 };
434 
435 /** \brief RunOptions
436  *
 * Options applied to the Session::Run calls made with this instance.
437  */
// Setters return RunOptions& (presumably *this), so calls can be chained.
438 struct RunOptions : detail::Base<OrtRunOptions> {
439  explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
440  RunOptions(); ///< Wraps OrtApi::CreateRunOptions
441 
442  RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
443  int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
444 
445  RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
446  int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
447 
448  RunOptions& SetRunTag(const char* run_tag); ///< Wraps OrtApi::RunOptionsSetRunTag
// NOTE(review): the lifetime of the returned string is not visible here. Presumably it is owned by the
// underlying OrtRunOptions; confirm before storing it beyond the next SetRunTag call.
449  const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
450 
451  RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
452 
453  /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
454  *
455  * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
456  * Wraps OrtApi::RunOptionsSetTerminate
457  */
459 
460  /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
461  *
462  * Wraps OrtApi::RunOptionsUnsetTerminate
463  */
465 };
466 
467 
namespace detail {
/// Builds the SessionOptions config-entry key for one setting of one custom operator.
/// Key format: custom_op.[custom_op_name].[config]
std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
}  // namespace detail
473 
474 /// <summary>
475 /// Class that represents session configuration entries for one or more custom operators.
476 ///
477 /// Example:
478 /// Ort::CustomOpConfigs op_configs;
479 /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
480 ///
481 /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
482 /// </summary>
484  CustomOpConfigs() = default;
485  ~CustomOpConfigs() = default;
486  CustomOpConfigs(const CustomOpConfigs&) = default;
487  CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
488  CustomOpConfigs(CustomOpConfigs&& o) = default;
489  CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
// All special members are defaulted, so this is an ordinary copyable/movable value type.
// Declaring none of them would be equivalent (Rule of Zero).
490 
491  /** \brief Adds a session configuration entry/value for a specific custom operator.
492  *
493  * \param custom_op_name The name of the custom operator for which to add a configuration entry.
494  * Must match the name returned by the CustomOp's GetName() method.
495  * \param config_key The name of the configuration entry.
496  * \param config_value The value of the configuration entry.
497  * \return A reference to this object to enable call chaining.
498  */
499  CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
500 
501  /** \brief Returns a flattened map of custom operator configuration entries and their values.
502  *
503  * The keys have been flattened to include both the custom operator name and the configuration entry key name.
504  * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
505  * {"my_op.key", "value"}.
 * NOTE(review): detail::MakeCustomOpConfigEntryKey documents keys of the form
 * custom_op.[custom_op_name].[config]. If AddConfig builds keys with it, the example pair would be
 * {"custom_op.my_op.key", "value"} — confirm and reconcile the two descriptions.
506  *
507  * \return An unordered map of flattened configurations.
 * The reference presumably refers to flat_configs_ and stays valid while this object is alive and unmodified.
508  */
509  const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
510 
511  private:
512  std::unordered_map<std::string, std::string> flat_configs_;
513 };
514 
/** \brief Options object used when creating a new Session object
 *
 * Wraps ::OrtSessionOptions. It is forward-declared here because the detail helpers that follow
 * return SessionOptions by value (e.g. Clone()); the full definition comes after them.
 */
struct SessionOptions;
521 
522 namespace detail {
523 // we separate const-only methods because passing const ptr to non-const methods
524 // is only discovered when inline methods are compiled which is counter-intuitive
// Read-only accessors shared by SessionOptions-like handles. `using B::B;` inherits Base<T>'s constructors.
525 template <typename T>
527  using B = Base<T>;
528  using B::B;
529 
530  SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
531 
532  std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
533  bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
// NOTE(review): declared non-const although it sits among the const-only methods and presumably only reads
// config entries; consider marking it const so it can be called through ConstSessionOptions.
534  std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
535 };
536 
// Mutating SessionOptions methods. Each setter returns SessionOptionsImpl& (presumably *this), so calls can be chained.
537 template <typename T>
540  using B::B;
541 
542  SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
543  SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
544  SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
545 
546  SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
547  SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
548 
549  SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
550 
551  SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
552  SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
553 
554  SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
555 
556  SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
557  SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
558 
559  SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
560 
561  SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
562  SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
563 
564  SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
565 
566  SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
567 
568  SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
569 
570  SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
571  SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
572 
573  SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
574  SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
575  SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
576  SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
577  SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
578  SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2
579  SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
580  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
582  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
584  const std::unordered_map<std::string, std::string>& provider_options = {});
585 
586  SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
587  SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
588  SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
589 
590  /// Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
591  /// The custom operator configurations are optional. If provided, custom operator configs are set via
592  /// OrtApi::AddSessionConfigEntry.
593  SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
594 
595  SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
596 };
597 } // namespace detail
598 
601 
602 /** \brief Wrapper around ::OrtSessionOptions
603  *
604  */
605 struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
606  explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
607  SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
608  explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
609  UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; }
610  ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; }
611 };
612 
613 /** \brief Wrapper around ::OrtModelMetadata
614  *
615  */
616 struct ModelMetadata : detail::Base<OrtModelMetadata> {
617  explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
618  explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
619 
620  /** \brief Returns a copy of the producer name.
621  *
622  * \param allocator to allocate memory for the copy of the name returned
623  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
624  * The OrtAllocator instances must be valid at the point of memory release.
625  */
626  AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
627 
628  /** \brief Returns a copy of the graph name.
629  *
630  * \param allocator to allocate memory for the copy of the name returned
631  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
632  * The OrtAllocator instances must be valid at the point of memory release.
633  */
634  AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
635 
636  /** \brief Returns a copy of the domain name.
637  *
638  * \param allocator to allocate memory for the copy of the name returned
639  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
640  * The OrtAllocator instances must be valid at the point of memory release.
641  */
642  AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
643 
644  /** \brief Returns a copy of the description.
645  *
646  * \param allocator to allocate memory for the copy of the string returned
647  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
648  * The OrtAllocator instances must be valid at the point of memory release.
649  */
650  AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
651 
652  /** \brief Returns a copy of the graph description.
653  *
654  * \param allocator to allocate memory for the copy of the string returned
655  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
656  * The OrtAllocator instances must be valid at the point of memory release.
657  */
658  AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
659 
660  /** \brief Returns a vector of copies of the custom metadata keys.
661  *
662  * \param allocator to allocate memory for the copy of the string returned
663  * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope.
664  * The OrtAllocator instance must be valid at the point of memory release.
665  */
666  std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
667 
668  /** \brief Looks up a value by a key in the Custom Metadata map
669  *
670  * \param key zero terminated string key to lookup
671  * \param allocator to allocate memory for the copy of the string returned
672  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
673  * maybe nullptr if key is not found.
674  *
675  * The OrtAllocator instances must be valid at the point of memory release.
676  */
677  AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
678 
679  int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
680 };
681 
// Forward declaration of the IoBinding wrapper.
struct IoBinding;
683 
684 namespace detail {
685 
686 // We separate the const-only methods because passing a const ptr to non-const methods
687 // is only discovered when the inline methods are compiled, which is counter-intuitive.
688 template <typename T>
689 struct ConstSessionImpl : Base<T> {
690  using B = Base<T>;
691  using B::B;
692 
693  size_t GetInputCount() const; ///< Returns the number of model inputs
694  size_t GetOutputCount() const; ///< Returns the number of model outputs
695  size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
696 
697  /** \brief Returns a copy of the input name at the specified index.
698  *
699  * \param index must be less than the value returned by GetInputCount()
700  * \param allocator used to allocate memory for the copy of the name returned
701  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
702  * The OrtAllocator instance must be valid at the point of memory release.
703  */
705 
706  /** \brief Returns a copy of the output name at the specified index.
707  *
708  * \param index must be less than the value returned by GetOutputCount()
709  * \param allocator used to allocate memory for the copy of the name returned
710  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
711  * The OrtAllocator instance must be valid at the point of memory release.
712  */
714 
715  /** \brief Returns a copy of the overridable initializer name at the specified index.
716  *
717  * \param index must be less than the value returned by GetOverridableInitializerCount()
718  * \param allocator used to allocate memory for the copy of the name returned
719  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
720  * The OrtAllocator instance must be valid at the point of memory release.
721  */
722  AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
723 
724  uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
725  ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
726 
727  TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
728  TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
729  TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
730 };
731 
732 template <typename T>
735  using B::B;
736 
737  /** \brief Run the model, returning the results in an Ort-allocated vector.
738  *
739  * Wraps OrtApi::Run
740  *
741  * The caller provides a list of inputs and a list of the desired outputs to return.
742  *
743  * See the output logs for more information on warnings/errors that occur while processing the model.
744  * Common errors are.. (TODO)
745  *
746  * \param[in] run_options RunOptions applied to this invocation
747  * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
748  * \param[in] input_values Array of Value objects of length input_count that is the list of input values
749  * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
750  * \param[in] output_names Array of C style strings of length output_count that is the list of output names
751  * \param[in] output_count Number of outputs (the size of the output_names array)
752  * \return A std::vector of Value objects that maps directly to the output_names array (e.g. output_names[0] corresponds to the first entry of the returned vector)
753  */
754  std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
755  const char* const* output_names, size_t output_count);
756 
757  /** \brief Run the model, writing the results into user-provided outputs.
758  * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t), except that the outputs are written to output_values, an array of length output_count
759  */
760  void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
761  const char* const* output_names, Value* output_values, size_t output_count);
762 
763  void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
764 
765  /** \brief End profiling and return a copy of the profiling file name.
766  *
767  * \param allocator used to allocate memory for the copy of the string returned
768  * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
769  * The OrtAllocator instance must be valid at the point of memory release.
770  */
771  AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
772 };
773 
774 } // namespace detail
775 
778 
779 /** \brief Wrapper around ::OrtSession
780  * Owning wrapper (see the notes at the top of this file); GetConst() and GetUnowned() return non-owning views of the same OrtSession.
781  */
782 struct Session : detail::SessionImpl<OrtSession> {
783  explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
784  Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
785  Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
786  OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
787  Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
788  Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
789  OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
790 
791  ConstSession GetConst() const { return ConstSession{this->p_}; } ///< Non-owning view restricted to const-only methods; must not outlive this Session
792  UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } ///< Non-owning view that also allows non-const methods; must not outlive this Session
793 };
794 
795 namespace detail {
796 template <typename T>
797 struct MemoryInfoImpl : Base<T> {
798  using B = Base<T>;
799  using B::B;
800 
803  int GetDeviceId() const; ///< Device id of this memory info (presumably the `id` passed to the MemoryInfo constructor — confirm)
805  OrtMemType GetMemoryType() const; ///< OrtMemType of this memory info
806 
807  template <typename U>
808  bool operator==(const MemoryInfoImpl<U>& o) const; ///< Equality against a MemoryInfoImpl over any pointer type U
809 };
810 } // namespace detail
811 
812 // Const object holder that does not own the underlying object
814 
815 /** \brief Wrapper around ::OrtMemoryInfo
816  * Owning wrapper; GetConst() returns a non-owning, const-only view of the same OrtMemoryInfo.
817  */
818 struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
820  explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created; must be assigned a valid one to be used
821  explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by the C API
822  MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ///< Creates a memory info from a name, allocator type, id and memory type
823  ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } ///< Non-owning, const-only view; must not outlive this object
824 };
825 
826 namespace detail {
827 template <typename T>
829  using B = Base<T>;
830  using B::B;
831 
832  ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
833  size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
834 
835  size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
836 
837  /** \deprecated use GetShape() returning std::vector
838  * [[deprecated]]
839  * This interface is unsafe to use: the caller supplies a raw buffer and its element count (prefer GetShape())
840  */
841  [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
842 
843  void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions; fills a caller-supplied array of values_count entries
844 
845  std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
846 };
847 
848 } // namespace detail
849 
851 
852 /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
853  * Construct from nullptr for an empty object, or from a C API pointer for interop.
854  */
855 struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
856  explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
857  explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
859 };
860 
861 namespace detail {
862 template <typename T>
864  using B = Base<T>;
865  using B::B;
866  TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType; returns the type information of the sequence elements
867 };
868 
869 } // namespace detail
870 
872 
873 /** \brief Wrapper around ::OrtSequenceTypeInfo
874  * Construct from nullptr for an empty object, or from a C API pointer for interop.
875  */
876 struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
877  explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
878  explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
879  ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } ///< Non-owning, const-only view; must not outlive this object
880 };
881 
882 namespace detail {
883 template <typename T>
885  using B = Base<T>;
886  using B::B;
887  ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType; element data type of the map keys
888  TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType; type information of the map values
889 };
890 
891 } // namespace detail
892 
894 
895 /** \brief Wrapper around ::OrtMapTypeInfo
896  * Construct from nullptr for an empty object, or from a C API pointer for interop.
897  */
898 struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
899  explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
900  explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
901  ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } ///< Non-owning, const-only view; must not outlive this object
902 };
903 
904 namespace detail {
905 template <typename T>
907  using B = Base<T>;
908  using B::B;
909 
910  ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
911  ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
912  ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
913 
914  ONNXType GetONNXType() const; ///< Returns the ONNXType, which determines which of the Get*Info() casts above applies
915 };
916 } // namespace detail
917 
918 /// <summary>
919 /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
920 /// Provides access to const OrtTypeInfo APIs.
921 /// </summary>
923 
924 /// <summary>
925 /// Type information that may contain either a TensorTypeAndShapeInfo or
926 /// information about the contained sequence or map, depending on the ONNXType.
927 /// </summary>
928 struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
929  explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
930  explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
931 
932  ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } ///< Non-owning, const-only view; must not outlive this object
933 };
934 
935 namespace detail {
936 // This structure is used to feed sparse tensor values
937 // information to the FillSparseTensor<Format>() APIs.
938 // If the data type of the sparse tensor values is numeric,
939 // use data.p_data; otherwise, use the data.str pointer to feed
940 // values. data.str is an array of zero-terminated const char* strings.
941 // The number of strings in the array must match the shape size.
942 // For fully sparse tensors use shape {0} and set p_data/str
943 // to nullptr.
945  const int64_t* values_shape;
947  union {
948  const void* p_data;
949  const char** str;
950  } data;
951 };
952 
953 // Provides a way to pass a shape (dimensions pointer plus dimension count) in a single
954 // argument
955 struct Shape {
956  const int64_t* shape; ///< pointer to the dimension values
957  size_t shape_len; ///< number of dimensions pointed to by shape
958 };
959 
960 template <typename T>
961 struct ConstValueImpl : Base<T> {
962  using B = Base<T>;
963  using B::B;
964 
965  /// <summary>
966  /// Obtains a pointer to user-defined (opaque) data, for experimental purposes
967  /// </summary>
968  template <typename R>
969  void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
970 
971  bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
972  bool HasValue() const; ///< Returns true if the OrtValue contains data, false if the OrtValue is a None
973 
974  size_t GetCount() const; ///< If not a tensor, returns 2 for a map and N for a sequence, where N is the number of elements
975  Value GetValue(int index, OrtAllocator* allocator) const;
976 
977  /// <summary>
978  /// This API returns the full length of string data contained within either a tensor or a sparse Tensor.
979  /// For sparse tensors it returns the full length of the stored non-empty strings (values). The API is useful
980  /// for allocating the necessary memory and calling GetStringTensorContent().
981  /// </summary>
982  /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
983  size_t GetStringTensorDataLength() const;
984 
985  /// <summary>
986  /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
987  /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
988  /// The user must also allocate an offsets buffer with the number of entries equal to that of the contained
989  /// strings.
990  ///
991  /// Strings are always assumed to be on CPU, no X-device copy.
992  /// </summary>
993  /// <param name="buffer">user allocated buffer</param>
994  /// <param name="buffer_length">length in bytes of the allocated buffer</param>
995  /// <param name="offsets">a pointer to the user allocated offsets buffer</param>
996  /// <param name="offsets_count">count of offsets; must equal the number of strings contained,
997  /// which can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
998  /// for sparse tensors</param>
999  void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1000 
1001  /// <summary>
1002  /// Returns a const typed pointer to the tensor contained data.
1003  /// No type checking is performed, the caller must ensure the type matches the tensor type.
1004  /// </summary>
1005  /// <typeparam name="R">element type the data pointer is cast to</typeparam>
1006  /// <returns>const pointer to data, no copies made</returns>
1007  template <typename R>
1008  const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData
1009 
1010  /// <summary>
1011  /// Returns a non-typed pointer to a tensor contained data.
1012  /// </summary>
1013  /// <returns>const pointer to data, no copies made</returns>
1014  const void* GetTensorRawData() const;
1015 
1016  /// <summary>
1017  /// The API returns type information for data contained in a tensor. For sparse
1018  /// tensors it returns type information for contained non-zero values.
1019  /// It returns dense shape for sparse tensors.
1020  /// </summary>
1021  /// <returns>TypeInfo</returns>
1022  TypeInfo GetTypeInfo() const;
1023 
1024  /// <summary>
1025  /// The API returns type information for data contained in a tensor. For sparse
1026  /// tensors it returns type information for contained non-zero values.
1027  /// It returns dense shape for sparse tensors.
1028  /// </summary>
1029  /// <returns>TensorTypeAndShapeInfo</returns>
1031 
1032  /// <summary>
1033  /// This API returns information about the memory allocation used to hold data.
1034  /// </summary>
1035  /// <returns>Non owning instance of MemoryInfo</returns>
1037 
1038  /// <summary>
1039  /// The API copies UTF-8 encoded bytes for the requested string element
1040  /// contained within a tensor or a sparse tensor into a provided buffer.
1041  /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
1042  /// </summary>
1043  /// <param name="buffer_length">length in bytes of the supplied buffer</param>
1044  /// <param name="element_index">index of the string element to copy</param>
1045  /// <param name="buffer">user allocated buffer that receives the UTF-8 bytes</param>
1046  void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1047 
1048  /// <summary>
1049  /// The API returns the byte length of a UTF-8 encoded string element
1050  /// contained in either a tensor or a sparse tensor's values.
1051  /// </summary>
1052  /// <param name="element_index">index of the string element</param>
1053  /// <returns>byte length for the specified string element</returns>
1054  size_t GetStringTensorElementLength(size_t element_index) const;
1055 
1056 #if !defined(DISABLE_SPARSE_TENSORS)
1057  /// <summary>
1058  /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
1059  /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
1060  /// the value returned is ORT_SPARSE_UNDEFINED.
1061  /// </summary>
1062  /// <returns>Format enum</returns>
1064 
1065  /// <summary>
1066  /// The API returns type and shape information for stored non-zero values of the
1067  /// sparse tensor. Use GetSparseTensorValues() to obtain the values buffer pointer.
1068  /// </summary>
1069  /// <returns>TensorTypeAndShapeInfo values information</returns>
1071 
1072  /// <summary>
1073  /// The API returns type and shape information for the specified indices. Each supported kind of
1074  /// indices has its own enum value, even if a given format has more than one kind of indices.
1075  /// Use GetSparseTensorIndicesData() to obtain a pointer to the indices buffer.
1076  /// </summary>
1077  /// <param name="format">enum requested</param>
1078  /// <returns>type and shape information</returns>
1080 
1081  /// <summary>
1082  /// The API retrieves a pointer to the internal indices buffer. The API merely performs
1083  /// a convenience data type casting on the return type pointer. Make sure you are requesting
1084  /// the right type; use GetSparseTensorIndicesTypeShapeInfo() to check it.
1085  /// </summary>
1086  /// <typeparam name="R">type to cast to</typeparam>
1087  /// <param name="indices_format">requested indices kind</param>
1088  /// <param name="num_indices">number of indices entries</param>
1089  /// <returns>Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
1090  template <typename R>
1091  const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1092 
1093  /// <summary>
1094  /// Returns true if the OrtValue contains a sparse tensor
1095  /// </summary>
1096  /// <returns>true for a sparse tensor, false otherwise</returns>
1097  bool IsSparseTensor() const;
1098 
1099  /// <summary>
1100  /// The API returns a pointer to an internal buffer of the sparse tensor
1101  /// containing non-zero values. The API merely does casting. Make sure you
1102  /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
1103  /// first.
1104  /// </summary>
1105  /// <typeparam name="R">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
1106  /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
1107  template <typename R>
1108  const R* GetSparseTensorValues() const;
1109 
1110 #endif
1111 };
1112 
1113 template <typename T>
1116  using B::B;
1117 
1118  /// <summary>
1119  /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
1120  /// No type checking is performed, the caller must ensure the type matches the tensor type.
1121  /// </summary>
1122  /// <returns>non-const pointer to data, no copies made</returns>
1123  template <typename R>
1125 
1126  /// <summary>
1127  /// Returns a non-typed non-const pointer to a tensor contained data.
1128  /// </summary>
1129  /// <returns>pointer to data, no copies made</returns>
1130  void* GetTensorMutableRawData();
1131 
1132  /// <summary>
1133  /// Obtains a reference to an element of data at the location specified
1134  /// by the vector of dims.
1135  /// </summary>
1136  /// <typeparam name="R">element type of the tensor data</typeparam>
1137  /// <param name="location">[in] expressed by a vector of dimension offsets</param>
1138  /// <returns>reference to the element at location</returns>
1139  template <typename R>
1140  R& At(const std::vector<int64_t>& location);
1141 
1142  /// <summary>
1143  /// Set all strings at once in a string tensor
1144  /// </summary>
1145  /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
1146  /// <param name="s_len">[in] Count of strings in s (must match the number of elements in this tensor's shape)</param>
1147  void FillStringTensor(const char* const* s, size_t s_len);
1148 
1149  /// <summary>
1150  /// Set a single string in a string tensor
1151  /// </summary>
1152  /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
1153  /// <param name="index">[in] Index of the string in the tensor to set</param>
1154  void FillStringTensorElement(const char* s, size_t index);
1155 
1156 #if !defined(DISABLE_SPARSE_TENSORS)
1157  /// <summary>
1158  /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
1159  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the lifespan
1160  /// of the user-allocated buffers must eclipse that of the OrtValue.
1161  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1162  /// </summary>
1163  /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
1164  /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
1165  void UseCooIndices(int64_t* indices_data, size_t indices_num);
1166 
1167  /// <summary>
1168  /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
1169  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the lifespan
1170  /// of the user-allocated buffers must eclipse that of the OrtValue.
1171  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1172  /// </summary>
1173  /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
1174  /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
1175  /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
1176  /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
1177  void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1178 
1179  /// <summary>
1180  /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
1181  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the lifespan
1182  /// of the user-allocated buffers must eclipse that of the OrtValue.
1183  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1184  /// </summary>
1185  /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
1186  /// <param name="indices_data">user allocated buffer with indices or nullptr for fully sparse tensors</param>
1187  void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1188 
1189  /// <summary>
1190  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1191  /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
1192  /// on a different device than the allocator, a X-device copy will be performed if possible.
1193  /// </summary>
1194  /// <param name="data_mem_info">specified buffer memory description</param>
1195  /// <param name="values_param">values buffer information.</param>
1196  /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
1197  /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
1198  void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1199  const int64_t* indices_data, size_t indices_num);
1200 
1201  /// <summary>
1202  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1203  /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
1204  /// on a different device than the allocator, a X-device copy will be performed if possible.
1205  /// </summary>
1206  /// <param name="data_mem_info">specified buffer memory description</param>
1207  /// <param name="values">values buffer information</param>
1208  /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
1209  /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
1210  /// <param name="outer_indices_data">pointer to csr outer indices data or nullptr for fully sparse tensors</param>
1211  /// <param name="outer_indices_num">number of csr outer indices or 0 for fully sparse tensors</param>
1212  void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1214  const int64_t* inner_indices_data, size_t inner_indices_num,
1215  const int64_t* outer_indices_data, size_t outer_indices_num);
1216 
1217  /// <summary>
1218  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1219  /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
1220  /// on a different device than the allocator, a X-device copy will be performed if possible.
1221  /// </summary>
1222  /// <param name="data_mem_info">specified buffer memory description</param>
1223  /// <param name="values">values buffer information</param>
1224  /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
1225  /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
1226  void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1228  const Shape& indices_shape,
1229  const int32_t* indices_data);
1230 
1231 #endif
1232 };
1233 
1234 } // namespace detail
1235 
1238 
1239 /** \brief Wrapper around ::OrtValue
1240  * Owning, move-only wrapper; GetConst() and GetUnowned() return non-owning views of the same OrtValue.
1241  */
1242 struct Value : detail::ValueImpl<OrtValue> {
1246 
1247  explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
1248  explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
1249  Value(Value&&) = default; ///< Move constructor (Ort objects are move-only)
1250  Value& operator=(Value&&) = default; ///< Move assignment (Ort objects are move-only)
1251 
1252  ConstValue GetConst() const { return ConstValue{this->p_}; } ///< Non-owning, const-only view; must not outlive this Value
1253  UnownedValue GetUnowned() const { return UnownedValue{this->p_}; } ///< Non-owning view that also allows non-const methods; must not outlive this Value
1254 
1255  /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1256  * \tparam T The numeric datatype. This API is not suitable for strings.
1257  * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1258  * \param p_data Pointer to the data buffer.
1259  * \param p_data_element_count The number of elements in the data buffer.
1260  * \param shape Pointer to the tensor shape dimensions.
1261  * \param shape_len The number of tensor shape dimensions.
1262  */
1263  template <typename T>
1264  static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1265 
1266  /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1267  * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1268  * \param p_data Pointer to the data buffer.
1269  * \param p_data_byte_count The number of bytes in the data buffer.
1270  * \param shape Pointer to the tensor shape dimensions.
1271  * \param shape_len The number of tensor shape dimensions.
1272  * \param type The data type.
1273  */
1274  static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1276 
1277  /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1278  * \tparam T The numeric datatype. This API is not suitable for strings.
1279  * \param allocator The allocator to use.
1280  * \param shape Pointer to the tensor shape dimensions.
1281  * \param shape_len The number of tensor shape dimensions.
1282  */
1283  template <typename T>
1284  static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1285 
1286  /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1287  * \param allocator The allocator to use.
1288  * \param shape Pointer to the tensor shape dimensions.
1289  * \param shape_len The number of tensor shape dimensions.
1290  * \param type The data type.
1291  */
1292  static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1293 
1294  static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue
1295  static Value CreateSequence(std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
1296 
1297  template <typename T>
1298  static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
1299 
1300 #if !defined(DISABLE_SPARSE_TENSORS)
1301  /// <summary>
1302  /// This is a simple forwarding method to the other overload that helps deduce
1303  /// the data type enum value from the type of the buffer.
1304  /// </summary>
1305  /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
1306  /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1307  /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1308  /// <param name="dense_shape">the would-be dense shape of the tensor</param>
1309  /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1310  /// <returns>Ort::Value instance containing SparseTensor</returns>
1311  template <typename T>
1312  static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1313  const Shape& values_shape);
1314 
1315  /// <summary>
1316  /// Creates an OrtValue instance containing SparseTensor. This constructs
1317  /// a sparse tensor that makes use of user allocated buffers. It does not make copies
1318  /// of the user provided data and does not modify it. The lifespan of user provided buffers should
1319  /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contains
1320  /// a pointer to non-zero values. To fully populate the sparse tensor call the Use<Format>Indices() methods
1321  /// to supply sparse format specific indices.
1322  /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
1323  /// can be properly copied into the allocated buffer.
1324  /// </summary>
1325  /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1326  /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1327  /// <param name="dense_shape">the would-be dense shape of the tensor</param>
1328  /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1329  /// <param name="type">data type</param>
1330  /// <returns>Ort::Value instance containing SparseTensor</returns>
1331  static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1332  const Shape& values_shape, ONNXTensorElementDataType type);
1333 
1334  /// <summary>
1335  /// This is a simple forwarding method to the below CreateSparseTensor.
1336  /// This helps to specify data type enum in terms of C++ data type.
1337  /// Use CreateSparseTensor<T>
1338  /// </summary>
1339  /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
1340  /// <param name="allocator">allocator to use</param>
1341  /// <param name="dense_shape">the would-be dense shape of the tensor</param>
1342  /// <returns>Ort::Value</returns>
1343  template <typename T>
1344  static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1345 
1346  /// <summary>
1347  /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
1348  /// The data must be supplied by one of the FillSparseTensor<Format>() methods that take both non-zero values
1349  /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
1350  /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
1351  /// strings.
1352  /// </summary>
1353  /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
1354  /// <param name="dense_shape">the would-be dense shape of the tensor</param>
1355  /// <param name="type">data type</param>
1356  /// <returns>an instance of Ort::Value</returns>
1357  static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1358 
1359 #endif // !defined(DISABLE_SPARSE_TENSORS)
1360 };
1361 
1362 /// <summary>
1363 /// Represents native memory allocation coming from one of the
1364 /// OrtAllocators registered with OnnxRuntime.
1365 /// Use it to wrap an allocation made by an allocator
1366 /// so it can be automatically released when no longer needed.
1367 /// </summary>
1369  MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1371  MemoryAllocation(const MemoryAllocation&) = delete;
1372  MemoryAllocation& operator=(const MemoryAllocation&) = delete;
1373  MemoryAllocation(MemoryAllocation&&) noexcept;
1375 
1376  void* get() { return p_; }
1377  size_t size() const { return size_; }
1378 
1379  private:
1380  OrtAllocator* allocator_;
1381  void* p_;
1382  size_t size_;
1383 };
1384 
1385 namespace detail {
1386 template <typename T>
1387 struct AllocatorImpl : Base<T> {
1388  using B = Base<T>;
1389  using B::B;
1390 
1391  void* Alloc(size_t size);
1393  void Free(void* p);
1394  ConstMemoryInfo GetInfo() const;
1395 };
1396 
1397 } // namespace detail
1398 
1399 /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
1400  *
1401  */
1402 struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1403  explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1405 };
1406 
1407 /** \brief Wrapper around ::OrtAllocator
1408  *
1409  */
1410 struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1411  explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1412  Allocator(const Session& session, const OrtMemoryInfo*);
1413 };
1414 
1416 
1417 namespace detail {
1418 namespace binding_utils {
1419 // Bring these out of template
1420 std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1421 std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1422 } // namespace binding_utils
1423 
1424 template <typename T>
1426  using B = Base<T>;
1427  using B::B;
1428 
1429  std::vector<std::string> GetOutputNames() const;
1430  std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1431  std::vector<Value> GetOutputValues() const;
1432  std::vector<Value> GetOutputValues(OrtAllocator*) const;
1433 };
1434 
1435 template <typename T>
1438  using B::B;
1439 
1440  void BindInput(const char* name, const Value&);
1441  void BindOutput(const char* name, const Value&);
1442  void BindOutput(const char* name, const OrtMemoryInfo*);
1443  void ClearBoundInputs();
1444  void ClearBoundOutputs();
1445  void SynchronizeInputs();
1446  void SynchronizeOutputs();
1447 };
1448 
1449 } // namespace detail
1450 
1453 
1454 /** \brief Wrapper around ::OrtIoBinding
1455  *
1456  */
1457 struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
1458  explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
1459  explicit IoBinding(Session& session);
1460  ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
1461  UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
1462 };
1463 
1464 /*! \struct Ort::ArenaCfg
1465  * \brief it is a structure that represents the configuration of an arena based allocator
1466  * \details Please see docs/C_API.md for details
1467  */
1468 struct ArenaCfg : detail::Base<OrtArenaCfg> {
1469  explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
1470  /**
1471  * Wraps OrtApi::CreateArenaCfg
1472  * \param max_mem - use 0 to allow ORT to choose the default
1473  * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
1474  * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
1475  * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
1476  * See docs/C_API.md for details on what the following parameters mean and how to choose these values
1477  */
1478  ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1479 };
1480 
1481 //
1482 // Custom OPs (only needed to implement custom OPs)
1483 //
1484 
1485 /// <summary>
1486 /// This struct provides life time management for custom op attribute
1487 /// </summary>
1488 struct OpAttr : detail::Base<OrtOpAttr> {
1489  OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
1490 };
1491 
1492 /// <summary>
1493 /// This class wraps a raw pointer OrtKernelContext* that is being passed
1494 /// to the custom kernel Compute() method. Use it to safely access context
1495 /// attributes, input and output parameters with exception safety guarantees.
1496 /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
1497 /// </summary>
1499  explicit KernelContext(OrtKernelContext* context);
1500  size_t GetInputCount() const;
1501  size_t GetOutputCount() const;
1502  ConstValue GetInput(size_t index) const;
1503  UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
1504  UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
1505  void* GetGPUComputeStream() const;
1506 
1507  private:
1508  OrtKernelContext* ctx_;
1509 };
1510 
1511 struct KernelInfo;
1512 
1513 namespace detail {
1514 namespace attr_utils {
1515 void GetAttr(const OrtKernelInfo* p, const char* name, float&);
1516 void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
1517 void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
1518 void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
1519 void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
1520 } // namespace attr_utils
1521 
1522 template <typename T>
1523 struct KernelInfoImpl : Base<T> {
1524  using B = Base<T>;
1525  using B::B;
1526 
1527  KernelInfo Copy() const;
1528 
1529  template <typename R> // R is only implemented for float, int64_t, and string
1530  R GetAttribute(const char* name) const {
1531  R val;
1532  attr_utils::GetAttr(this->p_, name, val);
1533  return val;
1534  }
1535 
1536  template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
1537  std::vector<R> GetAttributes(const char* name) const {
1538  std::vector<R> result;
1539  attr_utils::GetAttrs(this->p_, name, result);
1540  return result;
1541  }
1542 
1543  Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
1544 
1545  size_t GetInputCount() const;
1546  size_t GetOutputCount() const;
1547 
1548  std::string GetInputName(size_t index) const;
1549  std::string GetOutputName(size_t index) const;
1550 
1551  TypeInfo GetInputTypeInfo(size_t index) const;
1552  TypeInfo GetOutputTypeInfo(size_t index) const;
1553 };
1554 
1555 } // namespace detail
1556 
1558 
1559 /// <summary>
1560 /// This struct owns the OrtKernInfo* pointer when a copy is made.
1561 /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
1562 /// and query attributes, warp the pointer with Ort::Unowned<KernelInfo> instance
1563 /// so it does not destroy the pointer the kernel does not own.
1564 /// </summary>
1565 struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
1566  explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
1567  explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
1568  ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
1569 };
1570 
1571 /// <summary>
1572 /// Create and own custom defined operation.
1573 /// </summary>
1574 struct Op : detail::Base<OrtOp> {
1575  explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
1576 
1577  explicit Op(OrtOp*); ///< Take ownership of the OrtOp
1578 
1579  static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
1580  int version, const char** type_constraint_names,
1581  const ONNXTensorElementDataType* type_constraint_values,
1582  size_t type_constraint_count,
1583  const OpAttr* attr_values,
1584  size_t attr_count,
1585  size_t input_count, size_t output_count);
1586 
1587  void Invoke(const OrtKernelContext* context,
1588  const Value* input_values,
1589  size_t input_count,
1590  Value* output_values,
1591  size_t output_count);
1592 
1593  // For easier refactoring
1594  void Invoke(const OrtKernelContext* context,
1595  const OrtValue* const* input_values,
1596  size_t input_count,
1597  OrtValue* const* output_values,
1598  size_t output_count);
1599 };
1600 
1601 /// <summary>
/// This entire structure is deprecated, but we are not marking
/// it as a whole yet since we want to preserve it for the next release.
1604 /// </summary>
1605 struct CustomOpApi {
1606  CustomOpApi(const OrtApi& api) : api_(api) {}
1607 
1608  /** \deprecated use Ort::Value::GetTensorTypeAndShape()
1609  * [[deprecated]]
1610  * This interface produces a pointer that must be released. Not exception safe.
1611  */
1612  [[deprecated("use Ort::Value::GetTensorTypeAndShape()")]] OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
1613 
1614  /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementCount()
1615  * [[deprecated]]
1616  * This interface is redundant.
1617  */
1618  [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementCount()")]] size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1619 
1620  /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementType()
1621  * [[deprecated]]
1622  * This interface is redundant.
1623  */
1624  [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementType()")]] ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
1625 
1626  /** \deprecated use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()
1627  * [[deprecated]]
1628  * This interface is redundant.
1629  */
1630  [[deprecated("use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()")]] size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1631 
1632  /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
1633  * [[deprecated]]
1634  * This interface is redundant.
1635  */
1636  [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
1637 
1638  /** \deprecated
1639  * [[deprecated]]
1640  * This interface sets dimensions to TensorTypeAndShapeInfo, but has no effect on the OrtValue.
1641  */
1642  [[deprecated("Do not use")]] void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
1643 
1644  /** \deprecated use Ort::Value::GetTensorMutableData()
1645  * [[deprecated]]
1646  * This interface is redundant.
1647  */
1648  template <typename T>
1649  [[deprecated("use Ort::Value::GetTensorMutableData()")]] T* GetTensorMutableData(_Inout_ OrtValue* value);
1650 
1651  /** \deprecated use Ort::Value::GetTensorData()
1652  * [[deprecated]]
1653  * This interface is redundant.
1654  */
1655  template <typename T>
1656  [[deprecated("use Ort::Value::GetTensorData()")]] const T* GetTensorData(_Inout_ const OrtValue* value);
1657 
1658  /** \deprecated use Ort::Value::GetTensorMemoryInfo()
1659  * [[deprecated]]
1660  * This interface is redundant.
1661  */
1662  [[deprecated("use Ort::Value::GetTensorMemoryInfo()")]] const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
1663 
1664  /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
1665  * [[deprecated]]
1666  * This interface is redundant.
1667  */
1668  [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
1669 
1670  /** \deprecated use TensorTypeAndShapeInfo instances for automatic ownership.
1671  * [[deprecated]]
1672  * This interface is not exception safe.
1673  */
1674  [[deprecated("use TensorTypeAndShapeInfo")]] void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
1675 
1676  /** \deprecated use Ort::KernelContext::GetInputCount
1677  * [[deprecated]]
1678  * This interface is redundant.
1679  */
1680  [[deprecated("use Ort::KernelContext::GetInputCount")]] size_t KernelContext_GetInputCount(const OrtKernelContext* context);
1681 
1682  /** \deprecated use Ort::KernelContext::GetInput
1683  * [[deprecated]]
1684  * This interface is redundant.
1685  */
1686  [[deprecated("use Ort::KernelContext::GetInput")]] const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
1687 
1688  /** \deprecated use Ort::KernelContext::GetOutputCount
1689  * [[deprecated]]
1690  * This interface is redundant.
1691  */
1692  [[deprecated("use Ort::KernelContext::GetOutputCount")]] size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
1693 
1694  /** \deprecated use Ort::KernelContext::GetOutput
1695  * [[deprecated]]
1696  * This interface is redundant.
1697  */
1698  [[deprecated("use Ort::KernelContext::GetOutput")]] OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
1699 
1700  /** \deprecated use Ort::KernelContext::GetGPUComputeStream
1701  * [[deprecated]]
1702  * This interface is redundant.
1703  */
1704  [[deprecated("use Ort::KernelContext::GetGPUComputeStream")]] void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context);
1705 
1706  /** \deprecated use Ort::ThrowOnError()
1707  * [[deprecated]]
1708  * This interface is redundant.
1709  */
1710  [[deprecated("use Ort::ThrowOnError()")]] void ThrowOnError(OrtStatus* result);
1711 
1712  /** \deprecated use Ort::OpAttr
1713  * [[deprecated]]
1714  * This interface is not exception safe.
1715  */
1716  [[deprecated("use Ort::OpAttr")]] OrtOpAttr* CreateOpAttr(_In_ const char* name,
1717  _In_ const void* data,
1718  _In_ int len,
1720 
1721  /** \deprecated use Ort::OpAttr
1722  * [[deprecated]]
1723  * This interface is not exception safe.
1724  */
1725  [[deprecated("use Ort::OpAttr")]] void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr);
1726 
1727  /** \deprecated use Ort::Op
1728  * [[deprecated]]
1729  * This interface is not exception safe.
1730  */
1731  [[deprecated("use Ort::Op")]] OrtOp* CreateOp(_In_ const OrtKernelInfo* info,
1732  _In_ const char* op_name,
1733  _In_ const char* domain,
1734  _In_ int version,
1735  _In_opt_ const char** type_constraint_names,
1736  _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1737  _In_opt_ int type_constraint_count,
1738  _In_opt_ const OrtOpAttr* const* attr_values,
1739  _In_opt_ int attr_count,
1740  _In_ int input_count,
1741  _In_ int output_count);
1742 
1743  /** \deprecated use Ort::Op::Invoke
1744  * [[deprecated]]
1745  * This interface is redundant
1746  */
1747  [[deprecated("use Ort::Op::Invoke")]] void InvokeOp(_In_ const OrtKernelContext* context,
1748  _In_ const OrtOp* ort_op,
1749  _In_ const OrtValue* const* input_values,
1750  _In_ int input_count,
1751  _Inout_ OrtValue* const* output_values,
1752  _In_ int output_count);
1753 
1754  /** \deprecated use Ort::Op for automatic lifespan management.
1755  * [[deprecated]]
1756  * This interface is not exception safe.
1757  */
1758  [[deprecated("use Ort::Op")]] void ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op);
1759 
1760  /** \deprecated use Ort::KernelInfo for automatic lifespan management or for
1761  * querying attributes
1762  * [[deprecated]]
1763  * This interface is redundant
1764  */
1765  template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
1766  [[deprecated("use Ort::KernelInfo::GetAttribute")]] T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
1767 
1768  /** \deprecated use Ort::KernelInfo::Copy
1769  * querying attributes
1770  * [[deprecated]]
1771  * This interface is not exception safe
1772  */
1773  [[deprecated("use Ort::KernelInfo::Copy")]] OrtKernelInfo* CopyKernelInfo(_In_ const OrtKernelInfo* info);
1774 
1775  /** \deprecated use Ort::KernelInfo for lifespan management
1776  * querying attributes
1777  * [[deprecated]]
1778  * This interface is not exception safe
1779  */
1780  [[deprecated("use Ort::KernelInfo")]] void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy);
1781 
1782  private:
1783  const OrtApi& api_;
1784 };
1785 
1786 template <typename TOp, typename TKernel>
1790  OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
1791  OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
1792 
1793  OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
1794 
1795  OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
1796  OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
1797  OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
1798 
1799  OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
1800  OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
1801 
1802  OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
1803 #if defined(_MSC_VER) && !defined(__clang__)
1804 #pragma warning(push)
1805 #pragma warning(disable : 26409)
1806 #endif
1807  OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
1808 #if defined(_MSC_VER) && !defined(__clang__)
1809 #pragma warning(pop)
1810 #endif
1811  OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
1812  OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
1813 
1814  OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
1815  OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
1816  OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
1817  OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
1818  }
1819 
1820  // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
1821  const char* GetExecutionProviderType() const { return nullptr; }
1822 
1823  // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
1824  // (inputs and outputs are required by default)
1827  }
1828 
1831  }
1832 
  // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
1834  OrtMemType GetInputMemoryType(size_t /*index*/) const {
1835  return OrtMemTypeDefault;
1836  }
1837 
1838  // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
1839  // should expect at least 1 argument.
1841  return 1;
1842  }
1843 
  // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
1845  // to a variadic input should be of the same type.
1847  return true;
1848  }
1849 
1850  // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
1851  // should produce at least 1 output value.
1853  return 1;
1854  }
1855 
  // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
1857  // produced by a variadic output should be of the same type.
1859  return true;
1860  }
1861 
1862  // Declare list of session config entries used by this Custom Op.
1863  // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
1864  // This default implementation returns an empty vector of config entries.
1865  std::vector<std::string> GetSessionConfigKeys() const {
1866  return std::vector<std::string>{};
1867  }
1868 
1869  protected:
1870  // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
1871  void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
1872 };
1873 
1874 } // namespace Ort
1875 
1876 #include "onnxruntime_cxx_inline.h"
const char * GetExecutionProviderType() const
OrtMemoryInfoDeviceType GetDeviceType() const
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo *info)
UnownedSession GetUnowned() const
void * KernelContext_GetGPUComputeStream(const OrtKernelContext *context)
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
SessionOptionsImpl & SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn.
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
static Value CreateOpaque(const char *domain, const char *type_name, const T &)
Wraps OrtApi::CreateOpaqueValue.
MemoryAllocation & operator=(const MemoryAllocation &)=delete
AllocatorWithDefaultOptions(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object has no...
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
SequenceTypeInfo(OrtSequenceTypeInfo *p)
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
constexpr Base(contained_type *p) noexcept
OrtCustomThreadHandle(* OrtCustomCreateThreadFn)(void *ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void *ort_worker_fn_param)
Ort custom thread creation function.
void *ORT_API_CALL * CreateKernel(_In_ const struct OrtCustomOp *op, _In_ const OrtApi *api, _In_ const OrtKernelInfo *info)
std::string GetErrorMessage() const
size_t GetInputCount() const
Returns the number of model inputs.
SessionOptionsImpl & DisablePerSessionThreads()
Wraps OrtApi::DisablePerSessionThreads.
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
#define _In_
void InvokeOp(_In_ const OrtKernelContext *context, _In_ const OrtOp *ort_op, _In_ const OrtValue *const *input_values, _In_ int input_count, _Inout_ OrtValue *const *output_values, _In_ int output_count)
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
size_t GetOutputCount() const
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
SessionOptionsImpl & EnableMemPattern()
Wraps OrtApi::EnableMemPattern.
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used...
#define _Frees_ptr_opt_
SessionOptionsImpl & EnableCpuMemArena()
Wraps OrtApi::EnableCpuMemArena.
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
Value(OrtValue *p)
Used for interop with the C API.
Env(OrtEnv *p)
C Interop Helper.
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the output name at the specified index.
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to a user defined data for experimental purposes
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
ConstMemoryInfo GetInfo() const
T * GetTensorMutableData(_Inout_ OrtValue *value)
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
SessionOptionsImpl & SetLogId(const char *logid)
Wraps OrtApi::SetSessionLogId.
Custom Op Domain.
void swap(UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &a, UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &b)
Definition: UT_ArraySet.h:1631
CustomOpConfigs()=default
Value GetValue(int index, OrtAllocator *allocator) const
const GLdouble * v
Definition: glcorearb.h:837
Base & operator=(Base &&v) noexcept
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
#define ORTCHAR_T
Status(std::nullptr_t)
Create an empty object, must be assigned a valid one to be used.
SessionOptionsImpl & AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX.
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
void BindInput(const char *name, const Value &)
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
OrtMemoryInfoDeviceType
This mimics OrtDevice type constants so they can be returned in the API.
ConstValue GetInput(size_t index) const
SessionOptionsImpl & AppendExecutionProvider_CANN(const OrtCANNProviderOptions &provider_options)
MemoryAllocation GetAllocation(size_t size)
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
SessionOptionsImpl & DisableCpuMemArena()
Wraps OrtApi::DisableCpuMemArena.
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
bool HasConfigEntry(const char *config_key) const
Wraps OrtApi::HasSessionConfigEntry.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
ConstMemoryInfo GetConst() const
Take ownership of a pointer created by C Api.
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor
void ReleaseOp(_Frees_ptr_opt_ OrtOp *ort_op)
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
Wrapper around ::OrtModelMetadata.
GLint level
Definition: glcorearb.h:108
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo *info, _Out_ int64_t *dim_values, size_t dim_values_length)
#define _In_opt_
SessionOptionsImpl & AddInitializer(const char *name, const OrtValue *ort_val)
Wraps OrtApi::AddInitializer.
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
Wrapper around ::OrtMapTypeInfo.
std::string GetConfigEntryOrDefault(const char *config_key, const std::string &def)
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
SessionOptionsImpl & AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT.
GLdouble s
Definition: glad.h:3009
SessionOptionsImpl & EnableOrtCustomOps()
Wraps OrtApi::EnableOrtCustomOps.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
This struct provides life time management for custom op attribute
R & At(const std::vector< int64_t > &location)
The C API.
std::string GetAllocatorName() const
TypeInfo(OrtTypeInfo *p)
detail::SequenceTypeInfoImpl< detail::Unowned< const OrtSequenceTypeInfo >> ConstSequenceTypeInfo
SessionOptionsImpl & AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA.
Wrapper around OrtValue.
MapTypeInfo(OrtMapTypeInfo *p)
std::vector< R > GetAttributes(const char *name) const
SessionOptionsImpl & SetIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetIntraOpNumThreads.
static Value CreateSequence(std::vector< Value > &values)
Wraps OrtApi::CreateValue.
**But if you need a result
Definition: thread.h:613
it is a structure that represents the configuration of an arena based allocator
SessionOptionsImpl & RegisterCustomOpsUsingFunction(const char *function_name)
Wraps OrtApi::RegisterCustomOpsUsingFunction.
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
OrtAllocatorType
The Env (Environment)
static const OrtApi * api_
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Returns a vector of copies of the custom metadata keys.
SessionOptionsImpl & SetOptimizedModelFilePath(const ORTCHAR_T *optimized_model_file)
Wraps OrtApi::SetOptimizedModelFilePath.
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values...
constexpr bool operator==(const Float16_t &rhs) const noexcept
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
SessionOptionsImpl & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::SetSessionGraphOptimizationLevel.
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed, the caller must ensure the type matches the tensor type.
SessionOptionsImpl & SetCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions.
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
union Ort::detail::OrtSparseValuesParam::@93 data
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
ConstSession GetConst() const
std::string GetConfigEntry(const char *config_key) const
Wraps OrtApi::GetSessionConfigEntry.
std::vector< std::string > GetOutputNames() const
void(ORT_API_CALL * OrtLoggingFunction)(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message)
SessionOptionsImpl & AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT.
ModelMetadata(std::nullptr_t)
Create an empty ModelMetadata object, must be assigned a valid one to be used.
#define _Inout_
constexpr bool operator!=(const BFloat16_t &rhs) const noexcept
OrtSparseIndicesFormat
GLuint GLsizei const GLuint const GLintptr * offsets
Definition: glcorearb.h:2621
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
constexpr BFloat16_t() noexcept
void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo *info_copy)
int GetVariadicInputMinArity() const
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Used for interop with the C API.
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo *info)
OrtCustomOpInputOutputCharacteristic
Wrapper around OrtAllocator.
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetOutputTypeInfo(size_t index) const
constexpr Base()=default
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
const T * GetTensorData(_Inout_ const OrtValue *value)
const OrtValue * KernelContext_GetInput(const OrtKernelContext *context, _In_ size_t index)
std::vector< int64_t > GetTensorShape(const OrtTensorTypeAndShapeInfo *info)
Wrapper around OrtMemoryInfo.
const OrtMemoryInfo * GetTensorMemoryInfo(_In_ const OrtValue *value)
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo *info)
SessionOptions Clone() const
Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions.
detail::MapTypeInfoImpl< detail::Unowned< const OrtMapTypeInfo >> ConstMapTypeInfo
Definition: core.h:760
Wrapper around ::OrtIoBinding.
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< float > &)
constexpr BFloat16_t(uint16_t v) noexcept
CustomOpApi(const OrtApi &api)
#define ORT_API_VERSION
The API version defined in this header.
SessionOptionsImpl & SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn.
OrtKernelInfo * CopyKernelInfo(_In_ const OrtKernelInfo *info)
The Status that holds ownership of OrtStatus received from C API Use it to safely destroy OrtStatus* ...
OrtOpAttr * CreateOpAttr(_In_ const char *name, _In_ const void *data, _In_ int len, _In_ OrtOpAttrType type)
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
The default allocator for execution provider.
detail::SessionOptionsImpl< detail::Unowned< OrtSessionOptions >> UnownedSessionOptions
All C++ methods that can fail will throw an exception of this type.
const char * what() const noexceptoverride
OrtMemType GetInputMemoryType(size_t) const
A generic, discriminated value, whose type may be queried dynamically.
Definition: Value.h:44
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
Base(Base &&v) noexcept
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input)
void GetAttr(const OrtKernelInfo *p, const char *name, float &)
typename Unowned< T >::Type contained_type
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
SessionOptionsImpl & AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO.
Wrapper around ::OrtSequenceTypeInfo.
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() ...
SessionOptionsImpl & AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2.
CustomOpConfigs & operator=(const CustomOpConfigs &)=default
constexpr std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size(){return subtype_count< typename std::tuple_element< I, T >::type >::value+tuple_type_size< T, I+1 >);}template< typename T > struct type_count< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size< T, 0 >)};};template< typename T > struct subtype_count{static constexpr int value{is_mutable_container< T >::value?expected_max_vector_size:type_count< T >::value};};template< typename T, typename Enable=void > struct type_count_min{static const int value{0};};template< typename T >struct type_count_min< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_tuple_like< T >::value &&!is_wrapper< T >::value &&!is_complex< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{type_count< T >::value};};template< typename T > struct type_count_min< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr int value{1};};template< typename T >struct type_count_min< T, typename std::enable_if< is_wrapper< T >::value &&!is_complex< T >::value &&!is_tuple_like< T >::value >::type >{static constexpr int value{subtype_count_min< typename T::value_type >::value};};template< typename T, std::size_t I >constexpr typename std::enable_if< I==type_count_base< T >::value, int >::type tuple_type_size_min(){return 0;}template< typename T, std::size_t I > constexpr typename std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size_min(){return subtype_count_min< typename std::tuple_element< I, T >::type >::value+tuple_type_size_min< T, I+1 >);}template< typename T > struct type_count_min< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size_min< T, 0 >)};};template< typename T > struct subtype_count_min{static constexpr int value{is_mutable_container< T >::value?((type_count< T >::value< expected_max_vector_size)?type_count< T 
>::value:0):type_count_min< T >::value};};template< typename T, typename Enable=void > struct expected_count{static const int value{0};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_wrapper< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{1};};template< typename T > struct expected_count< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr int value{expected_max_vector_size};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&is_wrapper< T >::value >::type >{static constexpr int value{expected_count< typename T::value_type >::value};};enum class object_category:int{char_value=1, integral_value=2, unsigned_integral=4, enumeration=6, boolean_value=8, floating_point=10, number_constructible=12, double_constructible=14, integer_constructible=16, string_assignable=23, string_constructible=24, other=45, wrapper_value=50, complex_number=60, tuple_value=70, container_value=80,};template< typename T, typename Enable=void > struct classify_object{static constexpr object_category value{object_category::other};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&!std::is_same< T, char >::value &&std::is_signed< T >::value &&!is_bool< T >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::integral_value};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&std::is_unsigned< T >::value &&!std::is_same< T, char >::value &&!is_bool< T >::value >::type >{static constexpr object_category value{object_category::unsigned_integral};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_same< T, char >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category 
value{object_category::char_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_bool< T >::value >::type >{static constexpr object_category value{object_category::boolean_value};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_floating_point< T >::value >::type >{static constexpr object_category value{object_category::floating_point};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&std::is_assignable< T &, std::string >::value >::type >{static constexpr object_category value{object_category::string_assignable};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&(type_count< T >::value==1)&&std::is_constructible< T, std::string >::value >::type >{static constexpr object_category value{object_category::string_constructible};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::enumeration};};template< typename T > struct classify_object< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr object_category value{object_category::complex_number};};template< typename T > struct uncommon_type{using type=typename std::conditional<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&!std::is_constructible< T, std::string >::value &&!is_complex< T >::value &&!is_mutable_container< T >::value &&!std::is_enum< T >::value, std::true_type, std::false_type >::type;static constexpr bool value=type::value;};template< typename T >struct classify_object< T, typename std::enable_if<(!is_mutable_container< T >::value &&is_wrapper< T >::value &&!is_tuple_like< T >::value 
&&uncommon_type< T >::value)>::type >{static constexpr object_category value{object_category::wrapper_value};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::number_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&!is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::integer_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::double_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< is_tuple_like< T >::value &&((type_count< T >::value >=2 &&!is_wrapper< T >::value)||(uncommon_type< T >::value &&!is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value)||(uncommon_type< T >::value &&type_count< T >::value >=2))>::type >{static constexpr object_category value{object_category::tuple_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr object_category value{object_category::container_value};};template< typename T, enable_if_t< classify_object< T >::value==object_category::char_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"CHAR";}template< typename T, enable_if_t< classify_object< T 
>::value==object_category::integral_value||classify_object< T >::value==object_category::integer_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"INT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::unsigned_integral, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"UINT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::floating_point||classify_object< T >::value==object_category::number_constructible||classify_object< T >::value==object_category::double_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"FLOAT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::enumeration, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"ENUM";}template< typename T, enable_if_t< classify_object< T >::value==object_category::boolean_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"BOOLEAN";}template< typename T, enable_if_t< classify_object< T >::value==object_category::complex_number, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"COMPLEX";}template< typename T, enable_if_t< classify_object< T >::value >=object_category::string_assignable &&classify_object< T >::value<=object_category::other, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"TEXT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::container_value||classify_object< T >::value==object_category::wrapper_value, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value==1, 
detail::enabler >=detail::dummy >inline std::string type_name(){return type_name< typename std::decay< typename std::tuple_element< 0, T >::type >::type >);}template< typename T, std::size_t I >inline typename std::enable_if< I==type_count_base< T >::value, std::string >::type tuple_name(){return std::string{};}template< typename T, std::size_t I >inline typename std::enable_if<(I< type_count_base< T >::value), std::string >::type tuple_name(){auto str=std::string{type_name< typename std::decay< typename std::tuple_element< I, T >::type >::type >)}+ ','+tuple_name< T, I+1 >);if(str.back()== ',') str.pop_back();return str;}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler > > std::string type_name()
Recursively generate the tuple type name.
Definition: CLI11.h:1729
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
GLint GLint GLsizei GLint GLenum format
Definition: glcorearb.h:108
const OrtApi & GetApi()
This returns a reference to the OrtApi interface in use.
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
GraphOptimizationLevel
Graph optimization level.
size_t KernelContext_GetOutputCount(const OrtKernelContext *context)
Value GetTensorAttribute(const char *name, OrtAllocator *allocator) const
bool GetVariadicInputHomogeneity() const
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
constexpr Float16_t() noexcept
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
detail::TypeInfoImpl< detail::Unowned< const OrtTypeInfo >> ConstTypeInfo
Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. Provides access to const OrtTypeInfo APIs.
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime...
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
IEEE 754 half-precision floating point data type.
detail::Shape Shape
GLint location
Definition: glcorearb.h:805
SessionOptions(OrtSessionOptions *p)
OpenVINO Provider Options.
Create and own custom defined operation.
SessionOptionsImpl & SetInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetInterOpNumThreads.
ConstIoBinding GetConst() const
Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V...
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer No type checking is performe...
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
void * GetGPUComputeStream() const
void BindOutput(const char *name, const Value &)
GLuint const GLchar * name
Definition: glcorearb.h:786
RunOptions & SetRunTag(const char *run_tag)
wraps OrtApi::RunOptionsSetRunTag
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
SessionOptionsImpl & EnableProfiling(const ORTCHAR_T *profile_file_prefix)
Wraps OrtApi::EnableProfiling.
constexpr bool operator!=(const Float16_t &rhs) const noexcept
size_t GetStringTensorDataLength() const
This API returns a full length of string data contained within either a tensor or a sparse Tensor...
ORT_EXPORT const OrtApiBase *ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION
The Onnxruntime library's entry point to access the C API.
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions >> ConstSessionOptions
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance...
OrtAllocatorType GetAllocatorType() const
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
OrtRunOptions RunOptions
Definition: run_options.h:48
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
const OrtApi *ORT_API_CALL * GetApi(uint32_t version) NO_EXCEPTION
Get a pointer to the requested version of the OrtApi.
int GetVariadicOutputMinArity() const
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
const char *ORT_API_CALL * GetExecutionProviderType(_In_ const struct OrtCustomOp *op)
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
size_t GetStringTensorElementLength(size_t element_index) const
The API returns a byte length of UTF-8 encoded string element contained in either a tensor or a spare...
AllocatedFree(OrtAllocator *allocator)
SessionOptionsImpl & Add(OrtCustomOpDomain *custom_op_domain)
Wraps OrtApi::AddCustomOpDomain.
ORT_DEFINE_RELEASE(Allocator)
contained_type * p_
constexpr Base(contained_type *p) noexcept
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor...
GT_API const UT_StringHolder version
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
#define _Out_
bool GetVariadicOutputHomogeneity() const
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
TensorRT Provider Options.
bool operator==(const MemoryInfoImpl< U > &o) const
_In_ OrtKernelContext * context
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
RunOptions()
Wraps OrtApi::CreateRunOptions.
OrtTensorTypeAndShapeInfo * GetTensorTypeAndShape(_In_ const OrtValue *value)
This struct owns the OrtKernInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
std::vector< Value > GetOutputValues() const
GLsizeiptr size
Definition: glcorearb.h:664
OrtErrorCode GetErrorCode() const
struct OrtKernelInfo OrtKernelInfo
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor
KernelContext(OrtKernelContext *context)
constexpr bool operator==(const BFloat16_t &rhs) const noexcept
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr *op_attr)
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
void operator()(void *ptr) const
void SetDimensions(OrtTensorTypeAndShapeInfo *info, _In_ const int64_t *dim_values, size_t dim_count)
ConstValue GetConst() const
std::string GetOutputName(size_t index) const
The ThreadingOptions.
OrtOpAttrType
static Value CreateMap(Value &keys, Value &values)
Wraps OrtApi::CreateValue.
ModelMetadata(OrtModelMetadata *p)
constexpr Float16_t(uint16_t v) noexcept
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
This entire structure is deprecated, but we not marking it as a whole yet since we want to preserve f...
MemoryInfo(OrtMemoryInfo *p)
OrtSparseFormat
SessionOptionsImpl & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddSessionConfigEntry.
Memory allocation interface.
CUDA Provider Options.
OrtOp * CreateOp(_In_ const OrtKernelInfo *info, _In_ const char *op_name, _In_ const char *domain, _In_ int version, _In_opt_ const char **type_constraint_names, _In_opt_ const ONNXTensorElementDataType *type_constraint_values, _In_opt_ int type_constraint_count, _In_opt_ const OrtOpAttr *const *attr_values, _In_opt_ int attr_count, _In_ int input_count, _In_ int output_count)
SessionOptionsImpl & SetExecutionMode(ExecutionMode execution_mode)
Wraps OrtApi::SetSessionExecutionMode.
GLuint index
Definition: glcorearb.h:786
ROCM Provider Options.
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
SessionOptionsImpl & AppendExecutionProvider(const std::string &provider_name, const std::unordered_map< std::string, std::string > &provider_options={})
Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
auto ptr(T p) -> const void *
Definition: format.h:2448
GLuint GLfloat * val
Definition: glcorearb.h:1608
SessionOptionsImpl & DisableProfiling()
Wraps OrtApi::DisableProfiling.
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
ExecutionMode
OrtValue * KernelContext_GetOutput(OrtKernelContext *context, _In_ size_t index, _In_ const int64_t *dim_values, size_t dim_count)
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
std::vector< std::string > GetSessionConfigKeys() const
OrtErrorCode GetOrtErrorCode() const
Definition: core.h:1131
size_t GetOutputCount() const
Returns the number of model outputs.
SessionOptionsImpl & AddExternalInitializers(const std::vector< std::string > &names, const std::vector< Value > &ort_values)
Wraps OrtApi::AddExternalInitializers.
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
Base & operator=(Base &&v) noexcept
ONNXTensorElementDataType
#define const
Definition: zconf.h:214
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of input name at the specified index.
Wrapper around ::OrtTensorTypeAndShapeInfo.
MemoryInfo(std::nullptr_t)
No instance is created.
Exception(std::string &&string, OrtErrorCode code)
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
CustomOpConfigs.
size_t GetCount() const
< Return true if OrtValue contains data and returns false if the OrtValue is a None ...
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
std::string GetInputName(size_t index) const
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
Wrapper around ::OrtSessionOptions.
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
OrtErrorCode
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
contained_type * release()
Relinquishes ownership of the contained C object pointer The underlying object is not destroyed...
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Base & operator=(const Base &)=delete
type
Definition: core.h:1059
void(* OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle)
Custom thread join function.
void ThrowOnError(OrtStatus *result)
TypeInfo GetInputTypeInfo(size_t index) const
SessionOptionsImpl & SetLogSeverityLevel(int level)
Wraps OrtApi::SetSessionLogSeverityLevel.
UnownedValue GetUnowned() const
R GetAttribute(const char *name) const
Wrapper around OrtAllocator default instance that is owned by Onnxruntime.
Wrapper around ::OrtSession.
SessionOptionsImpl & DisableMemPattern()
Wraps OrtApi::DisableMemPattern.
SessionOptionsImpl & AppendExecutionProvider_ROCM(const OrtROCMProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM.
~CustomOpConfigs()=default
Options for the CUDA provider that are passed to SessionOptionsAppendExecutionProvider_CUDA_V2. Please note that this struct is similar to OrtCUDAProviderOptions but only to be used internally. Going forward, new cuda provider options are to be supported via this struct and usage of the publicly defined OrtCUDAProviderOptions will be deprecated over time. User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
struct OrtKernelContext OrtKernelContext
const char *ORT_API_CALL * GetName(_In_ const struct OrtCustomOp *op)
bfloat16 (Brain Floating Point) data type
ConstKernelInfo GetConst() const
constexpr FMT_INLINE value()
Definition: core.h:1154
OrtLoggingLevel
Logging severity levels.
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used. ...
Class that represents session configuration entries for one or more custom operators.
Value & operator=(Value &&)=default
size_t KernelContext_GetInputCount(const OrtKernelContext *context)
Definition: format.h:895
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
ConstTensorTypeAndShapeInfo GetConst() const
_In_ size_t index
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
SessionOptionsImpl & RegisterCustomOpsLibrary(const ORTCHAR_T *library_name, const CustomOpConfigs &custom_op_configs={})
T KernelInfoGetAttribute(_In_ const OrtKernelInfo *info, _In_ const char *name)
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.
UnownedIoBinding GetUnowned() const
MIGraphX Provider Options.