HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
onnxruntime_cxx_api.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5 //
6 // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7 // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8 // all the resources follow RAII and do not leak memory.
9 //
10 // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11 // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12 // until you assign an instance that actually holds an underlying object.
13 //
14 // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15 // Some objects have explicit 'Clone' methods for this purpose.
16 //
17 // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18 // by value or by reference. ConstXXXX types are restricted to const only interfaces.
19 //
20 // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21 //
22 // The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
23 // have to fall back to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
24 
25 #pragma once
26 #include "onnxruntime_c_api.h"
27 #include "onnxruntime_float16.h"
28 
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <array>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include <type_traits>

#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
#endif
43 
44 /** \brief All C++ Onnxruntime APIs are defined inside this namespace
45  *
46  */
47 namespace Ort {
48 
49 /** \brief All C++ methods that can fail will throw an exception of this type
50  *
51  * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
52  */
53 struct Exception : std::exception {
54  Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
55 
56  OrtErrorCode GetOrtErrorCode() const { return code_; }
57  const char* what() const noexcept override { return message_.c_str(); }
58 
59  private:
60  std::string message_;
61  OrtErrorCode code_;
62 };
63 
// ORT_CXX_API_THROW(string, code) is how this header reports errors.
//
// With ORT_NO_EXCEPTIONS defined:
//   - The default implementation writes the message to std::cerr.
//   - It then terminates the process with std::abort() (declared in <cstdlib>).
//   - Users may pre-define ORT_CXX_API_THROW to handle errors their own way.
//   - NOTE: this header expects control flow not to continue after ORT_CXX_API_THROW.
// Otherwise, an Ort::Exception carrying the message and error code is thrown.
#ifdef ORT_NO_EXCEPTIONS
#ifndef ORT_CXX_API_THROW
#define ORT_CXX_API_THROW(string, code)                            \
  do {                                                             \
    std::cerr << Ort::Exception(string, code).what() << std::endl; \
    std::abort();                                                  \
  } while (false)
#endif
#else
#define ORT_CXX_API_THROW(string, code) \
  throw Ort::Exception(string, code)
#endif
80 
81 // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
82 // it's in a template so that we can define a global variable in a header and make
83 // it transparent to the users of the API.
84 template <typename T>
85 struct Global {
86  static const OrtApi* api_;
87 };
88 
89 // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
90 template <typename T>
91 #ifdef ORT_API_MANUAL_INIT
92 const OrtApi* Global<T>::api_{};
93 inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
94 
95 // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
96 // required by C++ APIs.
97 //
98 // Example mycustomop.cc:
99 //
100 // #define ORT_API_MANUAL_INIT
101 // #include <onnxruntime_cxx_api.h>
102 // #undef ORT_API_MANUAL_INIT
103 //
104 // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
105 // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
106 // // ...
107 // }
108 //
109 inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
110 #else
111 #if defined(_MSC_VER) && !defined(__clang__)
112 #pragma warning(push)
113 // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
114 // Please define ORT_API_MANUAL_INIT if it conerns you.
115 #pragma warning(disable : 26426)
116 #endif
117 const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
118 #if defined(_MSC_VER) && !defined(__clang__)
119 #pragma warning(pop)
120 #endif
121 #endif
122 
123 /// This returns a reference to the OrtApi interface in use
124 inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
125 
126 /// <summary>
127 /// This function returns the onnxruntime version string
128 /// </summary>
129 /// <returns>version string major.minor.rev</returns>
130 std::string GetVersionString();
131 
132 /// <summary>
133 /// This function returns the onnxruntime build information: including git branch,
134 /// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags.
135 /// </summary>
136 /// <returns>string</returns>
137 std::string GetBuildInfoString();
138 
139 /// <summary>
140 /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
141 /// returns a vector of strings representing the available execution providers.
142 /// </summary>
143 /// <returns>vector of strings</returns>
144 std::vector<std::string> GetAvailableProviders();
145 
146 /** \brief IEEE 754 half-precision floating point data type
147  *
148  * \details This struct is used for converting float to float16 and back
149  * so the user could feed inputs and fetch outputs using these type.
150  *
151  * The size of the structure should align with uint16_t and one can freely cast
152  * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
153  *
154  * \code{.unparsed}
155 * // This example demonstrates conversion from float to float16
156  * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
157  * std::vector<Ort::Float16_t> fp16_values;
158  * fp16_values.reserve(std::size(values));
159  * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
160  * [](float value) { return Ort::Float16_t(value); });
161  *
162  * \endcode
163  */
165  private:
166  /// <summary>
167  /// Constructor from a 16-bit representation of a float16 value
168  /// No conversion is done here.
169  /// </summary>
170  /// <param name="v">16-bit representation</param>
171  constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
172 
173  public:
175 
176  /// <summary>
177  /// Default constructor
178  /// </summary>
179  Float16_t() = default;
180 
181  /// <summary>
182  /// Explicit conversion to uint16_t representation of float16.
183  /// </summary>
184  /// <param name="v">uint16_t bit representation of float16</param>
185  /// <returns>new instance of Float16_t</returns>
186  constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
187 
188  /// <summary>
189  /// __ctor from float. Float is converted into float16 16-bit representation.
190  /// </summary>
191  /// <param name="v">float value</param>
192  explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
193 
194  /// <summary>
195  /// Converts float16 to float
196  /// </summary>
197  /// <returns>float representation of float16 value</returns>
198  float ToFloat() const noexcept { return Base::ToFloatImpl(); }
199 
200  /// <summary>
201  /// Checks if the value is negative
202  /// </summary>
203  /// <returns>true if negative</returns>
204  using Base::IsNegative;
205 
206  /// <summary>
207  /// Tests if the value is NaN
208  /// </summary>
209  /// <returns>true if NaN</returns>
210  using Base::IsNaN;
211 
212  /// <summary>
213  /// Tests if the value is finite
214  /// </summary>
215  /// <returns>true if finite</returns>
216  using Base::IsFinite;
217 
218  /// <summary>
219  /// Tests if the value represents positive infinity.
220  /// </summary>
221  /// <returns>true if positive infinity</returns>
223 
224  /// <summary>
225  /// Tests if the value represents negative infinity
226  /// </summary>
227  /// <returns>true if negative infinity</returns>
229 
230  /// <summary>
231  /// Tests if the value is either positive or negative infinity.
232  /// </summary>
233  /// <returns>True if absolute value is infinity</returns>
234  using Base::IsInfinity;
235 
236  /// <summary>
237  /// Tests if the value is NaN or zero. Useful for comparisons.
238  /// </summary>
239  /// <returns>True if NaN or zero.</returns>
240  using Base::IsNaNOrZero;
241 
242  /// <summary>
243  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
244  /// </summary>
245  /// <returns>True if so</returns>
246  using Base::IsNormal;
247 
248  /// <summary>
249  /// Tests if the value is subnormal (denormal).
250  /// </summary>
251  /// <returns>True if so</returns>
252  using Base::IsSubnormal;
253 
254  /// <summary>
255  /// Creates an instance that represents absolute value.
256  /// </summary>
257  /// <returns>Absolute value</returns>
258  using Base::Abs;
259 
260  /// <summary>
261  /// Creates a new instance with the sign flipped.
262  /// </summary>
263  /// <returns>Flipped sign instance</returns>
264  using Base::Negate;
265 
266  /// <summary>
267  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
268  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
269  /// and therefore equivalent, if the resulting value is still zero.
270  /// </summary>
271  /// <param name="lhs">first value</param>
272  /// <param name="rhs">second value</param>
273  /// <returns>True if both arguments represent zero</returns>
274  using Base::AreZero;
275 
276  /// <summary>
277  /// User defined conversion operator. Converts Float16_t to float.
278  /// </summary>
279  explicit operator float() const noexcept { return ToFloat(); }
280 
281  using Base::operator==;
282  using Base::operator!=;
283  using Base::operator<;
284 };
285 
286 static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
287 
288 /** \brief bfloat16 (Brain Floating Point) data type
289  *
290  * \details This struct is used for converting float to bfloat16 and back
291  * so the user could feed inputs and fetch outputs using these type.
292  *
293  * The size of the structure should align with uint16_t and one can freely cast
294  * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
295  *
296  * \code{.unparsed}
297 * // This example demonstrates conversion from float to bfloat16
298  * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
299  * std::vector<Ort::BFloat16_t> bfp16_values;
300  * bfp16_values.reserve(std::size(values));
301  * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values),
302  * [](float value) { return Ort::BFloat16_t(value); });
303  *
304  * \endcode
305  */
307  private:
308  /// <summary>
309  /// Constructor from a uint16_t representation of bfloat16
310  /// used in FromBits() to escape overload resolution issue with
311  /// constructor from float.
312  /// No conversion is done.
313  /// </summary>
314  /// <param name="v">16-bit bfloat16 value</param>
315  constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
316 
317  public:
319 
320  BFloat16_t() = default;
321 
322  /// <summary>
323  /// Explicit conversion to uint16_t representation of bfloat16.
324  /// </summary>
325  /// <param name="v">uint16_t bit representation of bfloat16</param>
326  /// <returns>new instance of BFloat16_t</returns>
327  static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
328 
329  /// <summary>
330  /// __ctor from float. Float is converted into bfloat16 16-bit representation.
331  /// </summary>
332  /// <param name="v">float value</param>
333  explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
334 
335  /// <summary>
336  /// Converts bfloat16 to float
337  /// </summary>
338  /// <returns>float representation of bfloat16 value</returns>
339  float ToFloat() const noexcept { return Base::ToFloatImpl(); }
340 
341  /// <summary>
342  /// Checks if the value is negative
343  /// </summary>
344  /// <returns>true if negative</returns>
345  using Base::IsNegative;
346 
347  /// <summary>
348  /// Tests if the value is NaN
349  /// </summary>
350  /// <returns>true if NaN</returns>
351  using Base::IsNaN;
352 
353  /// <summary>
354  /// Tests if the value is finite
355  /// </summary>
356  /// <returns>true if finite</returns>
357  using Base::IsFinite;
358 
359  /// <summary>
360  /// Tests if the value represents positive infinity.
361  /// </summary>
362  /// <returns>true if positive infinity</returns>
364 
365  /// <summary>
366  /// Tests if the value represents negative infinity
367  /// </summary>
368  /// <returns>true if negative infinity</returns>
370 
371  /// <summary>
372  /// Tests if the value is either positive or negative infinity.
373  /// </summary>
374  /// <returns>True if absolute value is infinity</returns>
375  using Base::IsInfinity;
376 
377  /// <summary>
378  /// Tests if the value is NaN or zero. Useful for comparisons.
379  /// </summary>
380  /// <returns>True if NaN or zero.</returns>
381  using Base::IsNaNOrZero;
382 
383  /// <summary>
384  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
385  /// </summary>
386  /// <returns>True if so</returns>
387  using Base::IsNormal;
388 
389  /// <summary>
390  /// Tests if the value is subnormal (denormal).
391  /// </summary>
392  /// <returns>True if so</returns>
393  using Base::IsSubnormal;
394 
395  /// <summary>
396  /// Creates an instance that represents absolute value.
397  /// </summary>
398  /// <returns>Absolute value</returns>
399  using Base::Abs;
400 
401  /// <summary>
402  /// Creates a new instance with the sign flipped.
403  /// </summary>
404  /// <returns>Flipped sign instance</returns>
405  using Base::Negate;
406 
407  /// <summary>
408  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
409  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
410  /// and therefore equivalent, if the resulting value is still zero.
411  /// </summary>
412  /// <param name="lhs">first value</param>
413  /// <param name="rhs">second value</param>
414  /// <returns>True if both arguments represent zero</returns>
415  using Base::AreZero;
416 
417  /// <summary>
418  /// User defined conversion operator. Converts BFloat16_t to float.
419  /// </summary>
420  explicit operator float() const noexcept { return ToFloat(); }
421 
422  // We do not have an inherited impl for the below operators
423  // as the internal class implements them a little differently
424  bool operator==(const BFloat16_t& rhs) const noexcept;
425  bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
426  bool operator<(const BFloat16_t& rhs) const noexcept;
427 };
428 
429 static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
430 
431 /** \brief float8e4m3fn (Float8 Floating Point) data type
432  * \details It is necessary for type dispatching to make use of C++ API
433  * The type is implicitly convertible to/from uint8_t.
434  * See https://onnx.ai/onnx/technical/float8.html for further details.
435  */
437  uint8_t value;
438  constexpr Float8E4M3FN_t() noexcept : value(0) {}
439  constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {}
440  constexpr operator uint8_t() const noexcept { return value; }
441  // nan values are treated like any other value for operator ==, !=
442  constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; };
443  constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; };
444 };
445 
446 static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
447 
448 /** \brief float8e4m3fnuz (Float8 Floating Point) data type
449  * \details It is necessary for type dispatching to make use of C++ API
450  * The type is implicitly convertible to/from uint8_t.
451  * See https://onnx.ai/onnx/technical/float8.html for further details.
452  */
454  uint8_t value;
455  constexpr Float8E4M3FNUZ_t() noexcept : value(0) {}
456  constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {}
457  constexpr operator uint8_t() const noexcept { return value; }
458  // nan values are treated like any other value for operator ==, !=
459  constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; };
460  constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; };
461 };
462 
463 static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
464 
/** \brief float8e5m2 (Float8 Floating Point) data type
 * \details Exists so the C++ API can dispatch on this element type.
 * Values convert implicitly to and from their raw uint8_t bit pattern.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E5M2_t {
  /// Raw 8-bit encoding; this type performs no numeric conversion.
  uint8_t value;

  constexpr Float8E5M2_t() noexcept : value{0} {}
  constexpr Float8E5M2_t(uint8_t bits) noexcept : value{bits} {}

  constexpr operator uint8_t() const noexcept { return value; }

  // Comparison is on the bit pattern, so NaN encodings compare like any other value.
  constexpr bool operator==(const Float8E5M2_t& other) const noexcept { return other.value == value; }
  constexpr bool operator!=(const Float8E5M2_t& other) const noexcept { return !(other.value == value); }
};

static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
481 
482 /** \brief float8e5m2fnuz (Float8 Floating Point) data type
483  * \details It is necessary for type dispatching to make use of C++ API
484  * The type is implicitly convertible to/from uint8_t.
485  * See https://onnx.ai/onnx/technical/float8.html for further details.
486  */
488  uint8_t value;
489  constexpr Float8E5M2FNUZ_t() noexcept : value(0) {}
490  constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {}
491  constexpr operator uint8_t() const noexcept { return value; }
492  // nan values are treated like any other value for operator ==, !=
493  constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; };
494  constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; };
495 };
496 
497 static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
498 
499 namespace detail {
// Internal to the C++ API: generates one OrtRelease() overload per Ort* C type.
// Each overload forwards to the matching OrtApi::Release* function.
// The C API cannot offer this itself because C has no function overloading.
#define ORT_DEFINE_RELEASE(NAME) \
  inline void OrtRelease(Ort##NAME* object) { GetApi().Release##NAME(object); }
504 
505 ORT_DEFINE_RELEASE(Allocator);
506 ORT_DEFINE_RELEASE(MemoryInfo);
507 ORT_DEFINE_RELEASE(CustomOpDomain);
508 ORT_DEFINE_RELEASE(ThreadingOptions);
509 ORT_DEFINE_RELEASE(Env);
511 ORT_DEFINE_RELEASE(LoraAdapter);
512 ORT_DEFINE_RELEASE(Session);
513 ORT_DEFINE_RELEASE(SessionOptions);
514 ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
515 ORT_DEFINE_RELEASE(SequenceTypeInfo);
516 ORT_DEFINE_RELEASE(MapTypeInfo);
517 ORT_DEFINE_RELEASE(TypeInfo);
519 ORT_DEFINE_RELEASE(ModelMetadata);
520 ORT_DEFINE_RELEASE(IoBinding);
521 ORT_DEFINE_RELEASE(ArenaCfg);
523 ORT_DEFINE_RELEASE(OpAttr);
525 ORT_DEFINE_RELEASE(KernelInfo);
526 
527 #undef ORT_DEFINE_RELEASE
528 
/** \brief Tag template used with Base<T> to mark a C++ wrapper as non-owning.
 * A Base<Unowned<T>> never releases the wrapped C object. The wrapped type is exposed as Unowned<T>::Type.
 */
template <class T>
struct Unowned {
  using Type = T;
};
536 
537 /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
538  * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
539  *
540  * All of the C++ classes
541  * a) serve as containers for pointers to objects that are created by the underlying C API.
542  * Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
543  * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
544  * they would release objects owned automatically when going out of scope, they are move-only.
545  * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
546  * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
547  * such as Onnxruntime or instances of XXXX classes.
548  * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
549  * in C++ code.
550  *
551  */
552 
553 /// <summary>
554 /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
555 /// </summary>
556 template <typename T>
557 struct Base {
558  using contained_type = T;
559 
560  constexpr Base() = default;
561  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
562  ~Base() { OrtRelease(p_); }
563 
564  Base(const Base&) = delete;
565  Base& operator=(const Base&) = delete;
566 
567  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
568  Base& operator=(Base&& v) noexcept {
569  OrtRelease(p_);
570  p_ = v.release();
571  return *this;
572  }
573 
574  constexpr operator contained_type*() const noexcept { return p_; }
575 
576  /// \brief Relinquishes ownership of the contained C object pointer
577  /// The underlying object is not destroyed
579  T* p = p_;
580  p_ = nullptr;
581  return p;
582  }
583 
584  protected:
586 };
587 
588 // Undefined. For const types use Base<Unowned<const T>>
589 template <typename T>
590 struct Base<const T>;
591 
592 /// <summary>
593 /// Covers unowned pointers owned by either the ORT
594 /// or some other instance of CPP wrappers.
595 /// Used for ConstXXX and UnownedXXXX types that are copyable.
596 /// Also convenient to wrap raw OrtXX pointers .
597 /// </summary>
598 /// <typeparam name="T"></typeparam>
599 template <typename T>
600 struct Base<Unowned<T>> {
602 
603  constexpr Base() = default;
604  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
605 
606  ~Base() = default;
607 
608  Base(const Base&) = default;
609  Base& operator=(const Base&) = default;
610 
611  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
612  Base& operator=(Base&& v) noexcept {
613  p_ = nullptr;
614  std::swap(p_, v.p_);
615  return *this;
616  }
617 
618  constexpr operator contained_type*() const noexcept { return p_; }
619 
620  protected:
622 };
623 
624 // Light functor to release memory with OrtAllocator
626  OrtAllocator* allocator_;
627  explicit AllocatedFree(OrtAllocator* allocator)
628  : allocator_(allocator) {}
629  void operator()(void* ptr) const {
630  if (ptr) allocator_->Free(allocator_, ptr);
631  }
632 };
633 
634 } // namespace detail
635 
636 struct AllocatorWithDefaultOptions;
637 struct Env;
638 struct TypeInfo;
639 struct Value;
640 struct ModelMetadata;
641 
642 /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
643  * and release them at the end of the scope. The lifespan of the given allocator
644  * must eclipse the lifespan of AllocatedStringPtr instance
645  */
646 using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
647 
648 /** \brief The Status that holds ownership of OrtStatus received from C API
649  * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
650  * constructors to construct an instance of a Status object from exceptions.
651  */
652 struct Status : detail::Base<OrtStatus> {
653  explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
654  explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
655  explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
656  explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception
657  Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message.
658  std::string GetErrorMessage() const;
659  OrtErrorCode GetErrorCode() const;
660  bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status.
661 };
662 
663 /** \brief The ThreadingOptions
664  *
665  * The ThreadingOptions used for set global threadpools' options of The Env.
666  */
667 struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
668  /// \brief Wraps OrtApi::CreateThreadingOptions
670 
671  /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
672  ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
673 
674  /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
675  ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
676 
677  /// \brief Wraps OrtApi::SetGlobalSpinControl
678  ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
679 
680  /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
681  ThreadingOptions& SetGlobalDenormalAsZero();
682 
683  /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
684  ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
685 
686  /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
687  ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
688 
689  /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
690  ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
691 };
692 
693 /** \brief The Env (Environment)
694  *
695  * The Env holds the logging state used by all other objects.
696  * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
697  */
698 struct Env : detail::Base<OrtEnv> {
699  explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
700 
701  /// \brief Wraps OrtApi::CreateEnv
702  Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
703 
704  /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
705  Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
706 
707  /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
708  Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
709 
710  /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
711  Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
712  OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
713 
714  /// \brief C Interop Helper
715  explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
716 
717  Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
718  Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
719 
720  Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
721 
722  Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
723 
724  Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2
725 };
726 
727 /** \brief Custom Op Domain
728  *
729  */
730 struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
731  explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
732 
733  /// \brief Wraps OrtApi::CreateCustomOpDomain
734  explicit CustomOpDomain(const char* domain);
735 
736  // This does not take ownership of the op, simply registers it.
737  void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
738 };
739 
740 /// \brief LoraAdapter holds a set of Lora Parameters loaded from a single file
741 struct LoraAdapter : detail::Base<OrtLoraAdapter> {
743  using Base::Base;
744 
745  explicit LoraAdapter(std::nullptr_t) {} ///< Create an empty LoraAdapter object, must be assigned a valid one to be used
746  /// \brief Wraps OrtApi::CreateLoraAdapter
747  ///
748  /// The function attempts to load the adapter from the specified file
749  /// \param adapter_path The path to the Lora adapter
750  /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still
751  /// be copied to device if required by the model at inference time.
752  static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
753  OrtAllocator* allocator);
754 
755  /// \brief Wraps OrtApi::CreateLoraAdapterFromArray
756  ///
757  /// The function attempts to load the adapter from the specified byte array.
758  /// \param bytes The byte array containing file LoraAdapter format
759  /// \param num_bytes The number of bytes in the byte array
760  /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still
761  /// be copied to device if required by the model at inference time.
762  static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
763  OrtAllocator* allocator);
764 };
765 
766 /** \brief RunOptions
767  *
768  */
769 struct RunOptions : detail::Base<OrtRunOptions> {
770  explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
771  RunOptions(); ///< Wraps OrtApi::CreateRunOptions
772 
773  RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
774  int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
775 
776  RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
777  int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
778 
779  RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
780  const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
781 
782  RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
783 
784  /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
785  *
786  * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
787  * Wraps OrtApi::RunOptionsSetTerminate
788  */
789  RunOptions& SetTerminate();
790 
791  /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
792  *
793  * Wraps OrtApi::RunOptionsUnsetTerminate
794  */
795  RunOptions& UnsetTerminate();
796 
797  /** \brief Add the LoraAdapter to the list of active adapters.
798  * The setting does not affect RunWithBinding() calls.
799  *
800  * Wraps OrtApi::RunOptionsAddActiveLoraAdapter
801  * \param adapter The LoraAdapter to be used as the active adapter
802  */
803  RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter);
804 };
805 
namespace detail {
/**
 * \brief Builds the SessionOptions config-entry key for one setting of a custom operator.
 *
 * Key format: custom_op.[custom_op_name].[config]
 *
 * \param custom_op_name Name of the custom operator.
 * \param config         Name of the setting.
 * \return The combined config-entry key.
 */
std::string MakeCustomOpConfigEntryKey(const char* custom_op_name,
                                       const char* config);
}  // namespace detail
811 
812 /// <summary>
813 /// Class that represents session configuration entries for one or more custom operators.
814 ///
815 /// Example:
816 /// Ort::CustomOpConfigs op_configs;
817 /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
818 ///
819 /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
820 /// </summary>
// NOTE(review): line 821 (presumably the CustomOpConfigs struct header) is missing from this rendered listing.
822  CustomOpConfigs() = default;
823  ~CustomOpConfigs() = default;
824  CustomOpConfigs(const CustomOpConfigs&) = default;
825  CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
826  CustomOpConfigs(CustomOpConfigs&& o) = default;
827  CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
828 
829  /** \brief Adds a session configuration entry/value for a specific custom operator.
830  *
831  * \param custom_op_name The name of the custom operator for which to add a configuration entry.
832  * Must match the name returned by the CustomOp's GetName() method.
833  * \param config_key The name of the configuration entry.
834  * \param config_value The value of the configuration entry.
835  * \return A reference to this object to enable call chaining.
836  */
837  CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
838 
839  /** \brief Returns a flattened map of custom operator configuration entries and their values.
840  *
841  * The keys have been flattened to include both the custom operator name and the configuration entry key name.
842  * For example, a prior call to AddConfig("my_op", "key", "value") is expected to correspond to the flattened key/value pair
843  * {"custom_op.my_op.key", "value"}, i.e. the "custom_op.[custom_op_name].[config]" format documented for
  * detail::MakeCustomOpConfigEntryKey. NOTE(review): confirm AddConfig builds its keys with that helper.
844  *
845  * \return An unordered map of flattened configurations.
846  */
847  const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
848 
849  private:
850  std::unordered_map<std::string, std::string> flat_configs_;
851 };
852 
/**
 * \brief Options object used when creating a new Session object.
 *
 * Wraps the ::OrtSessionOptions object and its methods. Forward-declared here because
 * ConstSessionOptionsImpl::Clone() returns a SessionOptions by value.
 */
struct SessionOptions;
859 
860 namespace detail {
861 // we separate const-only methods because passing const ptr to non-const methods
862 // is only discovered when inline methods are compiled which is counter-intuitive
// NOTE(review): line 864 (presumably the ConstSessionOptionsImpl struct header) is missing from this rendered listing.
863 template <typename T>
865  using B = Base<T>;
866  using B::B;
867 
868  SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
869 
870  std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
871  bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
  /// Presumably returns the entry for config_key, or def when it is not set
  /// (NOTE(review): confirm against the out-of-line definition, which is not visible here).
  /// NOTE(review): unlike the other accessors in this const-only struct, this method is not const-qualified,
  /// so it cannot be called through a const object or reference. Adding `const` also requires the same
  /// change in its out-of-line definition.
872  std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
873 };
874 
// NOTE(review): lines 876-877 (presumably the SessionOptionsImpl struct header and its `using B` alias)
// are missing from this rendered listing; only `using B::B;` survives.
875 template <typename T>
878  using B::B;
879 
880  SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
881  SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
882  SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
883  SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute
884 
885  SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
886  SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
887 
888  SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
889 
890  SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
891  SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
892 
893  SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
894 
895  SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
896  SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
897 
898  SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
899 
900  SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
901  SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
902 
903  SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
904 
905  SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
906 
907  SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
908 
909  SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
910  SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
911  SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names,
912  const std::vector<char*>& external_initializer_file_buffer_array,
913  const std::vector<size_t>& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory
914 
915  SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
916  SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
917  SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
918  SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
  // `///` (not `///<`) on lines that precede a member: Doxygen attaches `///<` to the *previous* member.
919  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2
920  SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
921  SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
922  SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2
923  SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
924  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
925  SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
926  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
927  SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
928  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
929  SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
930  const std::unordered_map<std::string, std::string>& provider_options = {});
931 
932  SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
933  SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
934  SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
935 
936  /// Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
937  /// The custom operator configurations are optional. If provided, custom operator configs are set via
938  /// OrtApi::AddSessionConfigEntry.
939  SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
940 
941  SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
942 
943  /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI
944  SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {});
945 };
946 } // namespace detail
947 
950 
951 /** \brief Wrapper around ::OrtSessionOptions
952  *
953  */
954 struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
955  explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
956  SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
957  explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
958  UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; }
959  ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; }
960 };
961 
962 /** \brief Wrapper around ::OrtModelMetadata
963  *
964  */
965 struct ModelMetadata : detail::Base<OrtModelMetadata> {
966  explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
967  explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
968 
969  /** \brief Returns a copy of the producer name.
970  *
971  * \param allocator to allocate memory for the copy of the name returned
972  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
973  * The OrtAllocator instances must be valid at the point of memory release.
974  */
975  AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
976 
977  /** \brief Returns a copy of the graph name.
978  *
979  * \param allocator to allocate memory for the copy of the name returned
980  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
981  * The OrtAllocator instances must be valid at the point of memory release.
982  */
983  AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
984 
985  /** \brief Returns a copy of the domain name.
986  *
987  * \param allocator to allocate memory for the copy of the name returned
988  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
989  * The OrtAllocator instances must be valid at the point of memory release.
990  */
991  AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
992 
993  /** \brief Returns a copy of the description.
994  *
995  * \param allocator to allocate memory for the copy of the string returned
996  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
997  * The OrtAllocator instances must be valid at the point of memory release.
998  */
999  AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
1000 
1001  /** \brief Returns a copy of the graph description.
1002  *
1003  * \param allocator to allocate memory for the copy of the string returned
1004  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1005  * The OrtAllocator instances must be valid at the point of memory release.
1006  */
1007  AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
1008 
1009  /** \brief Returns a vector of copies of the custom metadata keys.
1010  *
1011  * \param allocator to allocate memory for the copy of the string returned
1012  * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope.
1013  * The OrtAllocator instance must be valid at the point of memory release.
1014  */
1015  std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
1016 
1017  /** \brief Looks up a value by a key in the Custom Metadata map
1018  *
1019  * \param key zero terminated string key to lookup
1020  * \param allocator to allocate memory for the copy of the string returned
1021  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1022  * maybe nullptr if key is not found.
1023  *
1024  * The OrtAllocator instances must be valid at the point of memory release.
1025  */
1026  AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
1027 
1028  int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
1029 };
1030 
// Forward declaration: SessionImpl::Run(const RunOptions&, const IoBinding&) takes it by reference.
struct IoBinding;
1032 
1033 namespace detail {
1034 
1035 // we separate const-only methods because passing const ptr to non-const methods
1036 // is only discovered when inline methods are compiled which is counter-intuitive
1037 template <typename T>
1038 struct ConstSessionImpl : Base<T> {
1039  using B = Base<T>;
1040  using B::B;
1041 
1042  size_t GetInputCount() const; ///< Returns the number of model inputs
1043  size_t GetOutputCount() const; ///< Returns the number of model outputs
1044  size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
1045 
1046  /** \brief Returns a copy of input name at the specified index.
1047  *
1048  * \param index must less than the value returned by GetInputCount()
1049  * \param allocator to allocate memory for the copy of the name returned
1050  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1051  * The OrtAllocator instances must be valid at the point of memory release.
1052  */
1053  AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
1054 
1055  /** \brief Returns a copy of output name at then specified index.
1056  *
1057  * \param index must less than the value returned by GetOutputCount()
1058  * \param allocator to allocate memory for the copy of the name returned
1059  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1060  * The OrtAllocator instances must be valid at the point of memory release.
1061  */
1062  AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
1063 
1064  /** \brief Returns a copy of the overridable initializer name at then specified index.
1065  *
1066  * \param index must less than the value returned by GetOverridableInitializerCount()
1067  * \param allocator to allocate memory for the copy of the name returned
1068  * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1069  * The OrtAllocator instances must be valid at the point of memory release.
1070  */
1071  AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
1072 
1073  uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
1074  ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
1075 
1076  TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
1077  TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
1078  TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
1079 };
1080 
// NOTE(review): lines 1082-1083 (presumably the SessionImpl struct header and its `using B` alias)
// are missing from this rendered listing; only `using B::B;` survives.
1081 template <typename T>
1084  using B::B;
1085 
1086  /** \brief Run the model returning results in an Ort allocated vector.
1087  *
1088  * Wraps OrtApi::Run
1089  *
1090  * The caller provides a list of inputs and a list of the desired outputs to return.
1091  *
1092  * See the output logs for more information on warnings/errors that occur while processing the model.
1093  * Common errors are.. (TODO)
1094  *
1095  * \param[in] run_options Options for this call
1096  * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
1097  * \param[in] input_values Array of Value objects of length input_count that is the list of input values
1098  * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
1099  * \param[in] output_names Array of C style strings of length output_count that is the list of output names
1100  * \param[in] output_count Number of outputs (the size of the output_names array)
1101  * \return A std::vector of Value objects that directly maps to the output_names array (e.g. output_names[0] corresponds to the first entry of the returned vector)
1102  */
1103  std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1104  const char* const* output_names, size_t output_count);
1105 
1106  /** \brief Run the model returning results in user provided outputs
1107  * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
1108  */
1109  void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1110  const char* const* output_names, Value* output_values, size_t output_count);
1111 
1112  void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
1113 
1114  /** \brief Run the model asynchronously in a thread owned by intra op thread pool
1115  *
1116  * Wraps OrtApi::RunAsync
1117  *
1118  * \param[in] run_options Options for this call
1119  * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
1120  * \param[in] input_values Array of Value objects of length input_count
1121  * \param[in] input_count Number of elements in the input_names and input_values arrays
1122  * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
1123  * \param[out] output_values Array of provided Values to be filled with outputs.
1124  * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*.
1125  * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime.
1126  * Then an OrtValue** pointer will be cast from output_values and passed to the callback.
1127  * NOTE: it is the caller's responsibility to eventually release output_values and each of its members,
1128  * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the caller.
1129  * \param[in] output_count Number of elements in the output_names and output_values arrays
1130  * \param[in] callback Callback function on model run completion
1131  * \param[in] user_data User data that is passed back to the callback
1132  */
1133  void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1134  const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
1135 
1136  /** \brief End profiling and return a copy of the profiling file name.
1137  *
1138  * \param allocator to allocate memory for the copy of the string returned
1139  * \return an instance of smart pointer that would deallocate the buffer when out of scope.
1140  * The OrtAllocator instances must be valid at the point of memory release.
1141  */
1142  AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
1143 
1144  /** \brief Set DynamicOptions for EPs (Execution Providers)
1145  *
1146  * Wraps OrtApi::SetEpDynamicOptions
1147  *
1148  * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h`
1149  * Look for `kOrtEpDynamicOptions`
1150  *
1151  * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys
1152  * \param[in] values Array of null terminated UTF8 encoded strings of EP dynamic option values
1153  * \param[in] kv_len Number of elements in the keys and values arrays
1154  */
1155  void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len);
1156 };
1157 
1158 } // namespace detail
1159 
1162 
1163 /** \brief Wrapper around ::OrtSession
1164  *
1165  */
1166 struct Session : detail::SessionImpl<OrtSession> {
1167  explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
1168  Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
1169  Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1170  OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
1171  Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
1172  Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
1173  OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
1174 
1175  ConstSession GetConst() const { return ConstSession{this->p_}; }
1176  UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
1177 };
1178 
1179 namespace detail {
1180 template <typename T>
1181 struct MemoryInfoImpl : Base<T> {
1182  using B = Base<T>;
1183  using B::B;
1184 
1185  std::string GetAllocatorName() const;
1186  OrtAllocatorType GetAllocatorType() const;
1187  int GetDeviceId() const;
1188  OrtMemoryInfoDeviceType GetDeviceType() const;
1189  OrtMemType GetMemoryType() const;
1190 
1191  template <typename U>
1192  bool operator==(const MemoryInfoImpl<U>& o) const;
1193 };
1194 } // namespace detail
1195 
1196 // Const object holder that does not own the underlying object
1198 
1199 /** \brief Wrapper around ::OrtMemoryInfo
1200  *
1201  */
1202 struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
1203  static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
1204  explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
1205  explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by C Api
1206  MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
1207  ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
1208 };
1209 
1210 namespace detail {
// NOTE(review): line 1212 (presumably the TensorTypeAndShapeInfoImpl struct header) is missing from this rendered listing.
1211 template <typename T>
1213  using B = Base<T>;
1214  using B::B;
1215 
1216  ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
1217  size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
1218 
1219  size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
1220 
1221  /** \deprecated use GetShape() returning std::vector
1222  * [[deprecated]]
1223  * This interface is unsafe to use: the caller supplies a raw output buffer and its length (values_count).
  * NOTE(review): presumably values must hold at least GetDimensionsCount() entries — confirm.
1224  */
1225  [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
1226 
1227  void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
1228 
1229  std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
1230 };
1231 
1232 } // namespace detail
1233 
1235 
1236 /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
1237  * Element type, element count and shape queries are inherited from detail::TensorTypeAndShapeInfoImpl.
1238  */
1239 struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
1240  explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
1241  explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
  // NOTE(review): line 1242 is missing from this rendered listing; the sibling wrappers
  // (SequenceTypeInfo, MapTypeInfo, TypeInfo) end with a GetConst() accessor at this position.
1243 };
1244 
1245 namespace detail {
// NOTE(review): line 1247 (presumably the SequenceTypeInfoImpl struct header) is missing from this rendered listing.
1246 template <typename T>
1248  using B = Base<T>;
1249  using B::B;
1250  TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
1251 };
1252 
1253 } // namespace detail
1254 
1256 
1257 /** \brief Wrapper around ::OrtSequenceTypeInfo
1258  *
1259  */
1260 struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
1261  explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
1262  explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
1263  ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; }
1264 };
1265 
1266 namespace detail {
// NOTE(review): line 1268 (presumably the optional-type-info impl struct header) is missing from this rendered listing.
1267 template <typename T>
1269  using B = Base<T>;
1270  using B::B;
1271  TypeInfo GetOptionalElementType() const; ///< Type of the contained element. Wraps OrtApi::CastOptionalTypeToContainedTypeInfo
1272 };
1273 
1274 } // namespace detail
1275 
1276 // This is always owned by the TypeInfo and can only be obtained from it.
1278 
1279 namespace detail {
// NOTE(review): line 1281 (presumably the MapTypeInfoImpl struct header) is missing from this rendered listing.
1280 template <typename T>
1282  using B = Base<T>;
1283  using B::B;
1284  ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
1285  TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
1286 };
1287 
1288 } // namespace detail
1289 
1291 
1292 /** \brief Wrapper around ::OrtMapTypeInfo
1293  *
1294  */
1295 struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
1296  explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
1297  explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
1298  ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
1299 };
1300 
1301 namespace detail {
// NOTE(review): line 1303 (presumably the TypeInfoImpl struct header) is missing from this rendered listing.
1302 template <typename T>
1304  using B = Base<T>;
1305  using B::B;
1306 
1307  ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
1308  ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
1309  ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
1310  ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToOptionalTypeInfo
1311 
1312  ONNXType GetONNXType() const; ///< ONNX type described by this TypeInfo; selects which accessor above applies
1313 };
1314 } // namespace detail
1315 
1316 /// <summary>
1317 /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
1318 /// Provides access to const OrtTypeInfo APIs.
1319 /// </summary>
1321 
1322 /// <summary>
1323 /// Type information that may contain either TensorTypeAndShapeInfo or
1324 /// the information about contained sequence or map depending on the ONNXType.
1325 /// </summary>
1326 struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
1327  explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
1328  explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
1329 
1330  ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
1331 };
1332 
1333 namespace detail {
1334 // This structure is used to feed sparse tensor values
1335 // information to the FillSparseTensor<Format>() API.
1336 // If the data type of the sparse tensor values is numeric,
1337 // use data.p_data; otherwise use the data.str pointer to feed
1338 // values. data.str is an array of zero-terminated const char* strings;
1339 // the number of strings in the array must match the shape size.
1340 // For fully sparse tensors use shape {0} and set p_data/str
1341 // to nullptr.
// NOTE(review): lines 1342 (the struct header) and 1344 (presumably a length field for values_shape)
// are missing from this rendered listing.
1343  const int64_t* values_shape;
1345  union {
1346  const void* p_data;
1347  const char** str;
1348  } data;
1349 };
1350 
/// Bundles a shape pointer with its length so a shape can be passed as a single argument.
/// Non-owning: `shape` must point at `shape_len` dimension values that stay valid while this is used.
struct Shape {
  const int64_t* shape;  ///< Dimension values (not owned)
  size_t shape_len;      ///< Number of entries in `shape`
};
1357 
1358 template <typename T>
1359 struct ConstValueImpl : Base<T> {
1360  using B = Base<T>;
1361  using B::B;
1362 
1363  /// <summary>
1364  /// Obtains a pointer to a user defined data for experimental purposes
1365  /// </summary>
1366  template <typename R>
1367  void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
1368 
1369  bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
1370  bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None
1371 
1372  size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
1373  Value GetValue(int index, OrtAllocator* allocator) const;
1374 
1375  /// <summary>
1376  /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
1377  /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
1378  /// for allocating necessary memory and calling GetStringTensorContent().
1379  /// </summary>
1380  /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
1381  size_t GetStringTensorDataLength() const;
1382 
1383  /// <summary>
1384  /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
1385  /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
1386  /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
1387  /// strings.
1388  ///
1389  /// Strings are always assumed to be on CPU, no X-device copy.
1390  /// </summary>
1391  /// <param name="buffer">user allocated buffer</param>
1392  /// <param name="buffer_length">length in bytes of the allocated buffer</param>
1393  /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
1394  /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
1395  /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
1396  /// for sparse tensors</param>
1397  void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1398 
1399  /// <summary>
1400  /// Returns a const typed pointer to the tensor contained data.
1401  /// No type checking is performed, the caller must ensure the type matches the tensor type.
1402  /// </summary>
1403  /// <typeparam name="T"></typeparam>
1404  /// <returns>const pointer to data, no copies made</returns>
1405  template <typename R>
1406  const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// <summary>
1407 
1408  /// <summary>
1409  /// Returns a non-typed pointer to a tensor contained data.
1410  /// </summary>
1411  /// <returns>const pointer to data, no copies made</returns>
1412  const void* GetTensorRawData() const;
1413 
1414  /// <summary>
1415  /// The API returns type information for data contained in a tensor. For sparse
1416  /// tensors it returns type information for contained non-zero values.
1417  /// It returns dense shape for sparse tensors.
1418  /// </summary>
1419  /// <returns>TypeInfo</returns>
1420  TypeInfo GetTypeInfo() const;
1421 
1422  /// <summary>
1423  /// The API returns type information for data contained in a tensor. For sparse
1424  /// tensors it returns type information for contained non-zero values.
1425  /// It returns dense shape for sparse tensors.
1426  /// </summary>
1427  /// <returns>TensorTypeAndShapeInfo</returns>
1428  TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
1429 
1430  /// <summary>
1431  /// This API returns information about the memory allocation used to hold data.
1432  /// </summary>
1433  /// <returns>Non owning instance of MemoryInfo</returns>
1434  ConstMemoryInfo GetTensorMemoryInfo() const;
1435 
1436  /// <summary>
1437  /// The API copies UTF-8 encoded bytes for the requested string element
1438  /// contained within a tensor or a sparse tensor into a provided buffer.
1439  /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
1440  /// </summary>
1441  /// <param name="buffer_length"></param>
1442  /// <param name="element_index"></param>
1443  /// <param name="buffer"></param>
1444  void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1445 
1446  /// <summary>
1447  /// Returns string tensor UTF-8 encoded string element.
1448  /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer.
1449  /// </summary>
1450  /// <param name="element_index"></param>
1451  /// <returns>std::string</returns>
1452  std::string GetStringTensorElement(size_t element_index) const;
1453 
1454  /// <summary>
1455  /// The API returns a byte length of UTF-8 encoded string element
1456  /// contained in either a tensor or a spare tensor values.
1457  /// </summary>
1458  /// <param name="element_index"></param>
1459  /// <returns>byte length for the specified string element</returns>
1460  size_t GetStringTensorElementLength(size_t element_index) const;
1461 
1462 #if !defined(DISABLE_SPARSE_TENSORS)
1463  /// <summary>
1464  /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
1465  /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
1466  /// the value returned is ORT_SPARSE_UNDEFINED.
1467  /// </summary>
1468  /// <returns>Format enum</returns>
1469  OrtSparseFormat GetSparseFormat() const;
1470 
1471  /// <summary>
1472  /// The API returns type and shape information for stored non-zero values of the
1473  /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
1474  /// </summary>
1475  /// <returns>TensorTypeAndShapeInfo values information</returns>
1476  TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
1477 
1478  /// <summary>
1479  /// The API returns type and shape information for the specified indices. Each supported
1480  /// indices have their own enum values even if a give format has more than one kind of indices.
1481  /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
1482  /// </summary>
1483  /// <param name="format">enum requested</param>
1484  /// <returns>type and shape information</returns>
1485  TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
1486 
1487  /// <summary>
1488  /// The API retrieves a pointer to the internal indices buffer. The API merely performs
1489  /// a convenience data type casting on the return type pointer. Make sure you are requesting
1490  /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
1491  /// </summary>
1492  /// <typeparam name="T">type to cast to</typeparam>
1493  /// <param name="indices_format">requested indices kind</param>
1494  /// <param name="num_indices">number of indices entries</param>
1495  /// <returns>Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
1496  template <typename R>
1497  const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1498 
1499  /// <summary>
1500  /// Returns true if the OrtValue contains a sparse tensor
1501  /// </summary>
1502  /// <returns></returns>
1503  bool IsSparseTensor() const;
1504 
1505  /// <summary>
1506  /// The API returns a pointer to an internal buffer of the sparse tensor
1507  /// containing non-zero values. The API merely does casting. Make sure you
1508  /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
1509  /// first.
1510  /// </summary>
1511  /// <typeparam name="T">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
1512  /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
1513  template <typename R>
1514  const R* GetSparseTensorValues() const;
1515 
1516 #endif
1517 };
1518 
1519 template <typename T>
1522  using B::B;
1523 
1524  /// <summary>
1525  /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
1526  /// No type checking is performed, the caller must ensure the type matches the tensor type.
1527  /// </summary>
1528  /// <returns>non-const pointer to data, no copies made</returns>
1529  template <typename R>
1530  R* GetTensorMutableData();
1531 
1532  /// <summary>
1533  /// Returns a non-typed non-const pointer to a tensor contained data.
1534  /// </summary>
1535  /// <returns>pointer to data, no copies made</returns>
1536  void* GetTensorMutableRawData();
1537 
1538  /// <summary>
1539  // Obtain a reference to an element of data at the location specified
1540  /// by the vector of dims.
1541  /// </summary>
1542  /// <typeparam name="R"></typeparam>
1543  /// <param name="location">[in] expressed by a vecotr of dimensions offsets</param>
1544  /// <returns></returns>
1545  template <typename R>
1546  R& At(const std::vector<int64_t>& location);
1547 
1548  /// <summary>
1549  /// Set all strings at once in a string tensor
1550  /// </summary>
1551  /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
1552  /// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param>
1553  void FillStringTensor(const char* const* s, size_t s_len);
1554 
1555  /// <summary>
1556  /// Set a single string in a string tensor
1557  /// </summary>
1558  /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
1559  /// <param name="index">[in] Index of the string in the tensor to set</param>
1560  void FillStringTensorElement(const char* s, size_t index);
1561 
1562  /// <summary>
1563  /// Allocate if necessary and obtain a pointer to a UTF-8
1564  /// encoded string element buffer indexed by the flat element index,
1565  /// of the specified length.
1566  ///
1567  /// This API is for advanced usage. It avoids a need to construct
1568  /// an auxiliary array of string pointers, and allows to write data directly
1569  /// (do not zero terminate).
1570  /// </summary>
1571  /// <param name="index"></param>
1572  /// <param name="buffer_length"></param>
1573  /// <returns>a pointer to a writable buffer</returns>
1574  char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
1575 
1576 #if !defined(DISABLE_SPARSE_TENSORS)
1577  /// <summary>
1578  /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
1579  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1580  /// allocated buffers lifespan must eclipse that of the OrtValue.
1581  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1582  /// </summary>
1583  /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
1584  /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
1585  void UseCooIndices(int64_t* indices_data, size_t indices_num);
1586 
1587  /// <summary>
1588  /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
1589  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1590  /// allocated buffers lifespan must eclipse that of the OrtValue.
1591  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1592  /// </summary>
1593  /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
1594  /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
1595  /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
1596  /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
1597  void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1598 
1599  /// <summary>
1600  /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
1601  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1602  /// allocated buffers lifespan must eclipse that of the OrtValue.
1603  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1604  /// </summary>
1605  /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
1606  /// <param name="indices_data">user allocated buffer with indices or nullptr for fully spare tensors</param>
1607  void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1608 
1609  /// <summary>
1610  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1611  /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
1612  /// at difference device than the allocator, a X-device copy will be performed if possible.
1613  /// </summary>
1614  /// <param name="data_mem_info">specified buffer memory description</param>
1615  /// <param name="values_param">values buffer information.</param>
1616  /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
1617  /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
1618  void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1619  const int64_t* indices_data, size_t indices_num);
1620 
1621  /// <summary>
1622  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1623  /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
1624  /// at difference device than the allocator, a X-device copy will be performed if possible.
1625  /// </summary>
1626  /// <param name="data_mem_info">specified buffer memory description</param>
1627  /// <param name="values">values buffer information</param>
1628  /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
1629  /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
1630  /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
1631  /// <param name="outer_indices_num">number of csr outer indices or 0</param>
1632  void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1634  const int64_t* inner_indices_data, size_t inner_indices_num,
1635  const int64_t* outer_indices_data, size_t outer_indices_num);
1636 
1637  /// <summary>
1638  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1639  /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
1640  /// at difference device than the allocator, a X-device copy will be performed if possible.
1641  /// </summary>
1642  /// <param name="data_mem_info">specified buffer memory description</param>
1643  /// <param name="values">values buffer information</param>
1644  /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
1645  /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
1646  void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1648  const Shape& indices_shape,
1649  const int32_t* indices_data);
1650 
1651 #endif
1652 };
1653 
1654 } // namespace detail
1655 
1658 
1659 /** \brief Wrapper around ::OrtValue
1660  *
1661  */
1662 struct Value : detail::ValueImpl<OrtValue> {
1666 
1667  explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
1668  explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
1669  Value(Value&&) = default;
1670  Value& operator=(Value&&) = default;
1671 
1672  ConstValue GetConst() const { return ConstValue{this->p_}; }
1673  UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
1674 
1675  /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1676  * \tparam T The numeric datatype. This API is not suitable for strings.
1677  * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1678  * \param p_data Pointer to the data buffer.
1679  * \param p_data_element_count The number of elements in the data buffer.
1680  * \param shape Pointer to the tensor shape dimensions.
1681  * \param shape_len The number of tensor shape dimensions.
1682  */
1683  template <typename T>
1684  static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1685 
1686  /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1687  *
1688  * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1689  * \param p_data Pointer to the data buffer.
1690  * \param p_data_byte_count The number of bytes in the data buffer.
1691  * \param shape Pointer to the tensor shape dimensions.
1692  * \param shape_len The number of tensor shape dimensions.
1693  * \param type The data type.
1694  */
1695  static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1696  ONNXTensorElementDataType type);
1697 
1698  /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1699  * This overload will allocate the buffer for the tensor according to the supplied shape and data type.
1700  * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
1701  * The input data would need to be copied into the allocated buffer.
1702  * This API is not suitable for strings.
1703  *
1704  * \tparam T The numeric datatype. This API is not suitable for strings.
1705  * \param allocator The allocator to use.
1706  * \param shape Pointer to the tensor shape dimensions.
1707  * \param shape_len The number of tensor shape dimensions.
1708  */
1709  template <typename T>
1710  static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1711 
1712  /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator.
1713  * Wraps OrtApi::CreateTensorAsOrtValue.
1714  * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
1715  * The input data would need to be copied into the allocated buffer.
1716  * This API is not suitable for strings.
1717  *
1718  * \param allocator The allocator to use.
1719  * \param shape Pointer to the tensor shape dimensions.
1720  * \param shape_len The number of tensor shape dimensions.
1721  * \param type The data type.
1722  */
1723  static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1724 
1725  /** \brief Creates an OrtValue with a Map Onnx type representation.
1726  * The API would ref-count the supplied OrtValues and they will be released
1727  * when the returned OrtValue is released. The caller may release keys and values after the call
1728  * returns.
1729  *
1730  * \param keys an OrtValue containing a tensor with primitive data type keys.
1731  * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values.
1732  */
1733  static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue
1734 
1735  /** \brief Creates an OrtValue with a Sequence Onnx type representation.
1736  * The API would ref-count the supplied OrtValues and they will be released
1737  * when the returned OrtValue is released. The caller may release the values after the call
1738  * returns.
1739  *
1740  * \param values a vector of OrtValues that must have the same Onnx value type.
1741  */
1742  static Value CreateSequence(const std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
1743 
1744  /** \brief Creates an OrtValue wrapping an Opaque type.
1745  * This is used for experimental support of non-tensor types.
1746  *
1747  * \tparam T - the type of the value.
1748  * \param domain - zero terminated utf-8 string. Domain of the type.
1749  * \param type_name - zero terminated utf-8 string. Name of the type.
1750  * \param value - the value to be wrapped.
1751  */
1752  template <typename T>
1753  static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue
1754 
1755 #if !defined(DISABLE_SPARSE_TENSORS)
1756  /// <summary>
1757  /// This is a simple forwarding method to the other overload that helps deducing
1758  /// data type enum value from the type of the buffer.
1759  /// </summary>
1760  /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
1761  /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1762  /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1763  /// <param name="dense_shape">a would be dense shape of the tensor</param>
1764  /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1765  /// <returns></returns>
1766  template <typename T>
1767  static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1768  const Shape& values_shape);
1769 
1770  /// <summary>
1771  /// Creates an OrtValue instance containing SparseTensor. This constructs
1772  /// a sparse tensor that makes use of user allocated buffers. It does not make copies
1773  /// of the user provided data and does not modify it. The lifespan of user provided buffers should
1774  /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain
1775  /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
1776  /// to supply a sparse format specific indices.
1777  /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
1778  /// can be properly copied into the allocated buffer.
1779  /// </summary>
1780  /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1781  /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1782  /// <param name="dense_shape">a would be dense shape of the tensor</param>
1783  /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1784  /// <param name="type">data type</param>
1785  /// <returns>Ort::Value instance containing SparseTensor</returns>
1786  static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1787  const Shape& values_shape, ONNXTensorElementDataType type);
1788 
1789  /// <summary>
1790  /// This is a simple forwarding method to the below CreateSparseTensor.
1791  /// This helps to specify data type enum in terms of C++ data type.
1792  /// Use CreateSparseTensor<T>
1793  /// </summary>
1794  /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
1795  /// <param name="allocator">allocator to use</param>
1796  /// <param name="dense_shape">a would be dense shape of the tensor</param>
1797  /// <returns>Ort::Value</returns>
1798  template <typename T>
1799  static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1800 
1801  /// <summary>
1802  /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
1803  /// The data must be supplied by on of the FillSparseTensor<Format>() methods that take both non-zero values
1804  /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
1805  /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
1806  /// strings.
1807  /// </summary>
1808  /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
1809  /// <param name="dense_shape">a would be dense shape of the tensor</param>
1810  /// <param name="type">data type</param>
1811  /// <returns>an instance of Ort::Value</returns>
1812  static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1813 
1814 #endif // !defined(DISABLE_SPARSE_TENSORS)
1815 };
1816 
1817 /// <summary>
1818 /// Represents native memory allocation coming from one of the
1819 /// OrtAllocators registered with OnnxRuntime.
1820 /// Use it to wrap an allocation made by an allocator
1821 /// so it can be automatically released when no longer needed.
1822 /// </summary>
1824  MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1825  ~MemoryAllocation();
1826  MemoryAllocation(const MemoryAllocation&) = delete;
1827  MemoryAllocation& operator=(const MemoryAllocation&) = delete;
1828  MemoryAllocation(MemoryAllocation&&) noexcept;
1830 
1831  void* get() { return p_; }
1832  size_t size() const { return size_; }
1833 
1834  private:
1835  OrtAllocator* allocator_;
1836  void* p_;
1837  size_t size_;
1838 };
1839 
1840 namespace detail {
1841 template <typename T>
1842 struct AllocatorImpl : Base<T> {
1843  using B = Base<T>;
1844  using B::B;
1845 
1846  void* Alloc(size_t size);
1847  MemoryAllocation GetAllocation(size_t size);
1848  void Free(void* p);
1849  ConstMemoryInfo GetInfo() const;
1850 };
1851 
1852 } // namespace detail
1853 
1854 /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
1855  *
1856  */
1857 struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1858  explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1860 };
1861 
1862 /** \brief Wrapper around ::OrtAllocator
1863  *
1864  */
1865 struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1866  explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1867  Allocator(const Session& session, const OrtMemoryInfo*);
1868 };
1869 
1871 
1872 namespace detail {
1873 namespace binding_utils {
1874 // Bring these out of template
1875 std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1876 std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1877 } // namespace binding_utils
1878 
1879 template <typename T>
1881  using B = Base<T>;
1882  using B::B;
1883 
1884  std::vector<std::string> GetOutputNames() const;
1885  std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1886  std::vector<Value> GetOutputValues() const;
1887  std::vector<Value> GetOutputValues(OrtAllocator*) const;
1888 };
1889 
1890 template <typename T>
1893  using B::B;
1894 
1895  void BindInput(const char* name, const Value&);
1896  void BindOutput(const char* name, const Value&);
1897  void BindOutput(const char* name, const OrtMemoryInfo*);
1898  void ClearBoundInputs();
1899  void ClearBoundOutputs();
1900  void SynchronizeInputs();
1901  void SynchronizeOutputs();
1902 };
1903 
1904 } // namespace detail
1905 
1908 
1909 /** \brief Wrapper around ::OrtIoBinding
1910  *
1911  */
1912 struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
1913  explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
1914  explicit IoBinding(Session& session);
1915  ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
1916  UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
1917 };
1918 
1919 /*! \struct Ort::ArenaCfg
1920  * \brief it is a structure that represents the configuration of an arena based allocator
1921  * \details Please see docs/C_API.md for details
1922  */
1923 struct ArenaCfg : detail::Base<OrtArenaCfg> {
1924  explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
1925  /**
1926  * Wraps OrtApi::CreateArenaCfg
1927  * \param max_mem - use 0 to allow ORT to choose the default
1928  * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
1929  * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
1930  * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
1931  * See docs/C_API.md for details on what the following parameters mean and how to choose these values
1932  */
1933  ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1934 };
1935 
1936 //
1937 // Custom OPs (only needed to implement custom OPs)
1938 //
1939 
1940 /// <summary>
1941 /// This struct provides life time management for custom op attribute
1942 /// </summary>
1943 struct OpAttr : detail::Base<OrtOpAttr> {
1944  OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
1945 };
1946 
/**
 * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails.
 * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
 *
 * The logger and severity arguments are parenthesized and evaluated exactly once, so expressions such as
 * `*logger_ptr` or `cond ? ORT_LOGGING_LEVEL_INFO : ORT_LOGGING_LEVEL_WARNING` behave as written.
 * The message argument is only evaluated when the severity passes the logger's cached threshold.
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference.
 * \param message_severity The logging severity level of the message.
 * \param message A null-terminated UTF-8 message to log.
 */
#define ORT_CXX_LOG(logger, message_severity, message)                                            \
  do {                                                                                            \
    const auto& ort_cxx_log_logger_ = (logger);                                                   \
    const auto ort_cxx_log_severity_ = (message_severity);                                        \
    if (ort_cxx_log_severity_ >= ort_cxx_log_logger_.GetLoggingSeverityLevel()) {                 \
      Ort::ThrowOnError(ort_cxx_log_logger_.LogMessage(ort_cxx_log_severity_, ORT_FILE, __LINE__, \
                                                       static_cast<const char*>(__FUNCTION__),    \
                                                       message));                                 \
    }                                                                                             \
  } while (false)
1962 
/**
 * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored.
 * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
 *
 * The logger and severity arguments are parenthesized and evaluated exactly once, so expressions such as
 * `*logger_ptr` or `cond ? ORT_LOGGING_LEVEL_INFO : ORT_LOGGING_LEVEL_WARNING` behave as written.
 * The returned status is explicitly discarded via static_cast<void>.
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference.
 * \param message_severity The logging severity level of the message.
 * \param message A null-terminated UTF-8 message to log.
 */
#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message)                                     \
  do {                                                                                              \
    const auto& ort_cxx_log_logger_ = (logger);                                                     \
    const auto ort_cxx_log_severity_ = (message_severity);                                          \
    if (ort_cxx_log_severity_ >= ort_cxx_log_logger_.GetLoggingSeverityLevel()) {                   \
      static_cast<void>(ort_cxx_log_logger_.LogMessage(ort_cxx_log_severity_, ORT_FILE, __LINE__,   \
                                                       static_cast<const char*>(__FUNCTION__),      \
                                                       message));                                   \
    }                                                                                               \
  } while (false)
1978 
/**
 * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if
 * OrtApi::Logger_LogMessage fails or if a formatting error occurs.
 * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
 *
 * The logger and severity arguments are parenthesized and evaluated exactly once; the format and
 * variadic arguments are only evaluated when the severity passes the logger's cached threshold.
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference.
 * \param message_severity The logging severity level of the message.
 * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
 *               Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
 * \param ... Zero or more variadic arguments referenced by the format string.
 */
#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...)                                              \
  do {                                                                                                     \
    const auto& ort_cxx_log_logger_ = (logger);                                                            \
    const auto ort_cxx_log_severity_ = (message_severity);                                                 \
    if (ort_cxx_log_severity_ >= ort_cxx_log_logger_.GetLoggingSeverityLevel()) {                          \
      Ort::ThrowOnError(ort_cxx_log_logger_.LogFormattedMessage(ort_cxx_log_severity_, ORT_FILE, __LINE__, \
                                                                static_cast<const char*>(__FUNCTION__),    \
                                                                __VA_ARGS__));                             \
    }                                                                                                      \
  } while (false)
1997 
/**
 * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors
 * are silently ignored.
 * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
 *
 * The logger and severity arguments are parenthesized and evaluated exactly once; the format and
 * variadic arguments are only evaluated when the severity passes the logger's cached threshold.
 * The returned status is explicitly discarded via static_cast<void>.
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference.
 * \param message_severity The logging severity level of the message.
 * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
 *               Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
 * \param ... Zero or more variadic arguments referenced by the format string.
 */
#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...)                                     \
  do {                                                                                                     \
    const auto& ort_cxx_log_logger_ = (logger);                                                            \
    const auto ort_cxx_log_severity_ = (message_severity);                                                 \
    if (ort_cxx_log_severity_ >= ort_cxx_log_logger_.GetLoggingSeverityLevel()) {                          \
      static_cast<void>(ort_cxx_log_logger_.LogFormattedMessage(ort_cxx_log_severity_, ORT_FILE, __LINE__, \
                                                                static_cast<const char*>(__FUNCTION__),    \
                                                                __VA_ARGS__));                             \
    }                                                                                                      \
  } while (false)
2016 
2017 /// <summary>
2018 /// This class represents an ONNX Runtime logger that can be used to log information with an
2019 /// associated severity level and source code location (file path, line number, function name).
2020 ///
2021 /// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger().
2022 /// Instances of Ort::Logger are the size of two pointers and can be passed by value.
2023 ///
2024 /// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite
2025 /// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API.
2026 /// </summary>
2027 struct Logger {
2028  /**
2029  * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
2030  */
2031  Logger() = default;
2032 
2033  /**
2034  * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
2035  */
2036  explicit Logger(std::nullptr_t) {}
2037 
2038  /**
2039  * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling
2040  * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails.
2041  *
2042  * \param logger The ::OrtLogger to wrap.
2043  */
2044  explicit Logger(const OrtLogger* logger);
2045 
2046  ~Logger() = default;
2047 
2048  Logger(const Logger&) = default;
2049  Logger& operator=(const Logger&) = default;
2050 
2051  Logger(Logger&& v) noexcept = default;
2052  Logger& operator=(Logger&& v) noexcept = default;
2053 
2054  /**
2055  * Returns the logger's current severity level from the cached member.
2056  *
2057  * \return The current ::OrtLoggingLevel.
2058  */
2059  OrtLoggingLevel GetLoggingSeverityLevel() const noexcept;
2060 
2061  /**
2062  * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT
2063  * macros to properly set the source code location and to use the cached severity level to potentially bypass
2064  * calls to the underlying C API.
2065  *
2066  * \param log_severity_level The message's logging severity level.
2067  * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
2068  * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
2069  * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
2070  * \param message The message to log.
2071  * \return A Ort::Status value to indicate error or success.
2072  */
2073  Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2074  const char* func_name, const char* message) const noexcept;
2075 
2076  /**
2077  * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT
2078  * macros to properly set the source code location and to use the cached severity level to potentially bypass
2079  * calls to the underlying C API. Returns an error status if a formatting error occurs.
2080  *
2081  * \param log_severity_level The message's logging severity level.
2082  * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
2083  * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
2084  * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
2085  * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
2086  * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
2087  * \param args Zero or more variadic arguments referenced by the format string.
2088  * \return A Ort::Status value to indicate error or success.
2089  */
2090  template <typename... Args>
2091  Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2092  const char* func_name, const char* format, Args&&... args) const noexcept;
2093 
2094  private:
2095  const OrtLogger* logger_{};
2096  OrtLoggingLevel cached_severity_level_{};
2097 };
2098 
2099 /// <summary>
2100 /// This class wraps a raw pointer OrtKernelContext* that is being passed
2101 /// to the custom kernel Compute() method. Use it to safely access context
2102 /// attributes, input and output parameters with exception safety guarantees.
2103 /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
2104 /// </summary>
2106  explicit KernelContext(OrtKernelContext* context);
2107  size_t GetInputCount() const;
2108  size_t GetOutputCount() const;
  // If the input is optional and not present, the method returns an empty ConstValue
  // which can be compared to nullptr.
2111  ConstValue GetInput(size_t index) const;
  // If the output is optional and not present, the method returns an empty UnownedValue
  // which can be compared to nullptr.
2114  UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
2115  UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
2116  void* GetGPUComputeStream() const;
2117  Logger GetLogger() const;
2118  OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
2119  OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
2120  void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
2121 
2122  private:
2123  OrtKernelContext* ctx_;
2124 };
2125 
2126 struct KernelInfo;
2127 
2128 namespace detail {
2129 namespace attr_utils {
2130 void GetAttr(const OrtKernelInfo* p, const char* name, float&);
2131 void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
2132 void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
2133 void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
2134 void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
2135 } // namespace attr_utils
2136 
2137 template <typename T>
2138 struct KernelInfoImpl : Base<T> {
2139  using B = Base<T>;
2140  using B::B;
2141 
2142  KernelInfo Copy() const;
2143 
2144  template <typename R> // R is only implemented for float, int64_t, and string
2145  R GetAttribute(const char* name) const {
2146  R val;
2147  attr_utils::GetAttr(this->p_, name, val);
2148  return val;
2149  }
2150 
2151  template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
2152  std::vector<R> GetAttributes(const char* name) const {
2153  std::vector<R> result;
2154  attr_utils::GetAttrs(this->p_, name, result);
2155  return result;
2156  }
2157 
2158  Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
2159 
2160  size_t GetInputCount() const;
2161  size_t GetOutputCount() const;
2162 
2163  std::string GetInputName(size_t index) const;
2164  std::string GetOutputName(size_t index) const;
2165 
2166  TypeInfo GetInputTypeInfo(size_t index) const;
2167  TypeInfo GetOutputTypeInfo(size_t index) const;
2168 
2169  ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;
2170 
2171  std::string GetNodeName() const;
2172  Logger GetLogger() const;
2173 };
2174 
2175 } // namespace detail
2176 
2178 
2179 /// <summary>
2180 /// This struct owns the OrtKernInfo* pointer when a copy is made.
2181 /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
2182 /// and query attributes, warp the pointer with Ort::Unowned<KernelInfo> instance
2183 /// so it does not destroy the pointer the kernel does not own.
2184 /// </summary>
2185 struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
2186  explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
2187  explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
2188  ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
2189 };
2190 
2191 /// <summary>
2192 /// Create and own custom defined operation.
2193 /// </summary>
2194 struct Op : detail::Base<OrtOp> {
2195  explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
2196 
2197  explicit Op(OrtOp*); ///< Take ownership of the OrtOp
2198 
2199  static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
2200  int version, const char** type_constraint_names,
2201  const ONNXTensorElementDataType* type_constraint_values,
2202  size_t type_constraint_count,
2203  const OpAttr* attr_values,
2204  size_t attr_count,
2205  size_t input_count, size_t output_count);
2206 
2207  void Invoke(const OrtKernelContext* context,
2208  const Value* input_values,
2209  size_t input_count,
2210  Value* output_values,
2211  size_t output_count);
2212 
2213  // For easier refactoring
2214  void Invoke(const OrtKernelContext* context,
2215  const OrtValue* const* input_values,
2216  size_t input_count,
2217  OrtValue* const* output_values,
2218  size_t output_count);
2219 };
2220 
2221 /// <summary>
2222 /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
2223 /// </summary>
2226  SymbolicInteger(int64_t i) : i_(i), is_int_(true) {};
2227  SymbolicInteger(const char* s) : s_(s), is_int_(false) {};
2228  SymbolicInteger(const SymbolicInteger&) = default;
2229  SymbolicInteger(SymbolicInteger&&) = default;
2230 
2231  SymbolicInteger& operator=(const SymbolicInteger&) = default;
2233 
2234  bool operator==(const SymbolicInteger& dim) const {
2235  if (is_int_ == dim.is_int_) {
2236  if (is_int_) {
2237  return i_ == dim.i_;
2238  } else {
2239  return std::string{s_} == std::string{dim.s_};
2240  }
2241  }
2242  return false;
2243  }
2244 
2245  bool IsInt() const { return is_int_; }
2246  int64_t AsInt() const { return i_; }
2247  const char* AsSym() const { return s_; }
2248 
2249  static constexpr int INVALID_INT_DIM = -2;
2250 
2251  private:
2252  union {
2253  int64_t i_;
2254  const char* s_;
2255  };
2256  bool is_int_;
2257  };
2258 
2259  using Shape = std::vector<SymbolicInteger>;
2260 
2261  ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx);
2262 
2263  const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
2264 
2265  size_t GetInputCount() const { return input_shapes_.size(); }
2266 
2267  Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
2268 
2269  int64_t GetAttrInt(const char* attr_name);
2270 
2271  using Ints = std::vector<int64_t>;
2272  Ints GetAttrInts(const char* attr_name);
2273 
2274  float GetAttrFloat(const char* attr_name);
2275 
2276  using Floats = std::vector<float>;
2277  Floats GetAttrFloats(const char* attr_name);
2278 
2279  std::string GetAttrString(const char* attr_name);
2280 
2281  using Strings = std::vector<std::string>;
2282  Strings GetAttrStrings(const char* attr_name);
2283 
2284  private:
2285  const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
2286  const OrtApi* ort_api_;
2287  OrtShapeInferContext* ctx_;
2288  std::vector<Shape> input_shapes_;
2289 };
2290 
2292 
// Default (largest) end version of a custom op: (1 << 31) - 1 == 2147483647, i.e. INT32_MAX.
// The whole expression is parenthesized so the macro expands safely inside larger expressions.
#define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)
2294 
2295 template <typename TOp, typename TKernel, bool WithStatus = false>
2296 struct CustomOpBase : OrtCustomOp {
2298  OrtCustomOp::version = ORT_API_VERSION;
2299  OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
2300 
2301  OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
2302 
2303  OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
2304  OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
2305  OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
2306 
2307  OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
2308  OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
2309 
2310 #if defined(_MSC_VER) && !defined(__clang__)
2311 #pragma warning(push)
2312 #pragma warning(disable : 26409)
2313 #endif
2314  OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
2315 #if defined(_MSC_VER) && !defined(__clang__)
2316 #pragma warning(pop)
2317 #endif
2318  OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
2319  OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
2320 
2321  OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
2322  OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
2323  OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
2324  OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
2325 #ifdef __cpp_if_constexpr
2326  if constexpr (WithStatus) {
2327 #else
2328  if (WithStatus) {
2329 #endif
2330  OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
2331  return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
2332  };
2333  OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
2334  return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
2335  };
2336  } else {
2337  OrtCustomOp::CreateKernelV2 = nullptr;
2338  OrtCustomOp::KernelComputeV2 = nullptr;
2339 
2340  OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
2341  OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
2342  static_cast<TKernel*>(op_kernel)->Compute(context);
2343  };
2344  }
2345 
2346  SetShapeInferFn<TOp>(0);
2347 
2348  OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
2349  return static_cast<const TOp*>(this_)->start_ver_;
2350  };
2351 
2352  OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
2353  return static_cast<const TOp*>(this_)->end_ver_;
2354  };
2355 
2356  OrtCustomOp::GetMayInplace = nullptr;
2357  OrtCustomOp::ReleaseMayInplace = nullptr;
2358  OrtCustomOp::GetAliasMap = nullptr;
2359  OrtCustomOp::ReleaseAliasMap = nullptr;
2360  }
2361 
2362  // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
2363  const char* GetExecutionProviderType() const { return nullptr; }
2364 
2365  // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
2366  // (inputs and outputs are required by default)
2367  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
2368  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2369  }
2370 
2371  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
2372  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2373  }
2374 
  // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
2376  OrtMemType GetInputMemoryType(size_t /*index*/) const {
2377  return OrtMemTypeDefault;
2378  }
2379 
2380  // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
2381  // should expect at least 1 argument.
2383  return 1;
2384  }
2385 
  // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
  // to a variadic input should be of the same type.
2389  return true;
2390  }
2391 
2392  // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
2393  // should produce at least 1 output value.
2395  return 1;
2396  }
2397 
  // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
  // produced by a variadic output should be of the same type.
2401  return true;
2402  }
2403 
2404  // Declare list of session config entries used by this Custom Op.
2405  // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
2406  // This default implementation returns an empty vector of config entries.
2407  std::vector<std::string> GetSessionConfigKeys() const {
2408  return std::vector<std::string>{};
2409  }
2410 
2411  template <typename C>
2412  decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
2413  OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
2414  ShapeInferContext ctx(&GetApi(), ort_ctx);
2415  return C::InferOutputShape(ctx);
2416  };
2417  return {};
2418  }
2419 
2420  template <typename C>
2421  void SetShapeInferFn(...) {
2422  OrtCustomOp::InferOutputShapeFn = {};
2423  }
2424 
2425  protected:
2426  // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
2427  void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
2428 
2429  int start_ver_ = 1;
2430  int end_ver_ = MAX_CUSTOM_OP_END_VER;
2431 };
2432 
2433 } // namespace Ort
2434 
2435 #include "onnxruntime_cxx_inline.h"
constexpr Float8E4M3FNUZ_t() noexcept
std::vector< int64_t > Ints
UnownedSession GetUnowned() const
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id...
GLuint GLsizei const GLchar * message
Definition: glcorearb.h:2543
AllocatorWithDefaultOptions(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
Logger(std::nullptr_t)
This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object has no...
SequenceTypeInfo(OrtSequenceTypeInfo *p)
TypeInfo(std::nullptr_t)
Create an empty TypeInfo object, must be assigned a valid one to be used.
constexpr Base(contained_type *p) noexcept
bool IsNaN() const noexcept
Tests if the value is NaN
Float16_t Abs() const noexcept
Creates an instance that represents absolute value.
std::string GetErrorMessage() const
std::vector< std::string > Strings
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const
BFloat16_t Negate() const noexcept
Creates a new instance with the sign flipped.
constexpr bool operator!=(const Float8E5M2FNUZ_t &rhs) const noexcept
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
TensorTypeAndShapeInfo(std::nullptr_t)
Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used...
Value(OrtValue *p)
Used for interop with the C API.
Env(OrtEnv *p)
C Interop Helper.
std::vector< float > Floats
Value(std::nullptr_t)
Create an empty Value object, must be assigned a valid one to be used.
Custom Op Domain.
GLboolean * data
Definition: glcorearb.h:131
void swap(UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &a, UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &b)
Definition: UT_ArraySet.h:1699
const GLdouble * v
Definition: glcorearb.h:837
Base & operator=(Base &&v) noexcept
bool IsSubnormal() const noexcept
Tests if the value is subnormal (denormal).
GLsizei const GLfloat * value
Definition: glcorearb.h:824
Float16_t()=default
Default constructor
bool GetVariadicInputHomogeneity() const
constexpr Float8E5M2_t(uint8_t v) noexcept
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
ConstMemoryInfo GetConst() const
Take ownership of a pointer created by C Api.
const Shape & GetInputShape(size_t indice) const
std::vector< SymbolicInteger > Shape
Wrapper around ::OrtModelMetadata.
GLint level
Definition: glcorearb.h:108
BFloat16_t(float v) noexcept
__ctor from float. Float is converted into bfloat16 16-bit representation.
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *p)
Used for interop with the C API.
Wrapper around ::OrtMapTypeInfo.
constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept
GLdouble s
Definition: glad.h:3009
This struct provides life time management for custom op attribute
float ToFloatImpl() const noexcept
Converts bfloat16 to float
TypeInfo(OrtTypeInfo *p)
detail::SequenceTypeInfoImpl< detail::Unowned< const OrtSequenceTypeInfo >> ConstSequenceTypeInfo
Wrapper around OrtValue.
MapTypeInfo(OrtMapTypeInfo *p)
static bool AreZero(const Float16Impl &lhs, const Float16Impl &rhs) noexcept
IEEE defines that positive and negative zero are equal, this gives us a quick equality check for two ...
bool operator!=(const BFloat16_t &rhs) const noexcept
std::vector< R > GetAttributes(const char *name) const
**But if you need a result
Definition: thread.h:622
it is a structure that represents the configuration of an arena based allocator
Provide access to per-node attributes and input shapes, so one could compute and set output shapes...
OCIOEXPORT void LogMessage(LoggingLevel level, const char *message)
Log a message using the library logging function.
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
The Env (Environment)
static const OrtApi * api_
bool IsNegative() const noexcept
Checks if the value is negative
OrtMemType GetInputMemoryType(size_t) const
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
float8e4m3fnuz (Float8 Floating Point) data type
GLuint buffer
Definition: glcorearb.h:660
constexpr bool operator==(const Float8E4M3FN_t &rhs) const noexcept
ConstSession GetConst() const
float8e4m3fn (Float8 Floating Point) data type
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< int64_t > &)
static constexpr uint16_t ToUint16Impl(float v) noexcept
Converts from float to uint16_t float16 representation
ModelMetadata(std::nullptr_t)
Create an empty ModelMetadata object, must be assigned a valid one to be used.
bool IsFinite() const noexcept
Tests if the value is finite
GLuint GLsizei const GLuint const GLintptr * offsets
Definition: glcorearb.h:2621
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
bool operator==(const BaseDimensions< T > &a, const BaseDimensions< Y > &b)
Definition: Dimensions.h:137
Wrapper around ::OrtAllocator.
bool operator==(const SymbolicInteger &dim) const
constexpr Base()=default
constexpr Float8E5M2FNUZ_t() noexcept
bool IsPositiveInfinity() const noexcept
Tests if the value represents positive infinity.
Wrapper around OrtMemoryInfo.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
GLint GLint GLsizei GLint GLenum GLenum type
Definition: glcorearb.h:108
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
bool IsNormal() const noexcept
Tests if the value is normal (not zero, subnormal, infinite, or NaN).
bool operator<(const BFloat16_t &rhs) const noexcept
detail::MapTypeInfoImpl< detail::Unowned< const OrtMapTypeInfo >> ConstMapTypeInfo
Wrapper around ::OrtIoBinding.
bool IsFinite() const noexcept
Tests if the value is finite
The Status that holds ownership of OrtStatus received from C API Use it to safely destroy OrtStatus* ...
BFloat16_t()=default
void GetAttr(const OrtKernelInfo *p, const char *name, std::string &)
OrtKernelContext * GetOrtKernelContext() const
float ToFloat() const noexcept
Converts float16 to float
Shared implementation between public and internal classes. CRTP pattern.
constexpr bool operator==(const Float8E5M2_t &rhs) const noexcept
detail::SessionOptionsImpl< detail::Unowned< OrtSessionOptions >> UnownedSessionOptions
constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept
bool IsNegative() const noexcept
Checks if the value is negative
All C++ methods that can fail will throw an exception of this type.
const char * what() const noexceptoverride
A generic, discriminated value, whose type may be queried dynamically.
Definition: Value.h:45
Base(Base &&v) noexcept
typename Unowned< T >::Type contained_type
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
Wrapper around ::OrtSequenceTypeInfo.
This class wraps a raw pointer OrtKernelContext* that is being passed to the custom kernel Compute() ...
constexpr std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size(){return subtype_count< typename std::tuple_element< I, T >::type >::value+tuple_type_size< T, I+1 >);}template< typename T > struct type_count< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size< T, 0 >)};};template< typename T > struct subtype_count{static constexpr int value{is_mutable_container< T >::value?expected_max_vector_size:type_count< T >::value};};template< typename T, typename Enable=void > struct type_count_min{static const int value{0};};template< typename T >struct type_count_min< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_tuple_like< T >::value &&!is_wrapper< T >::value &&!is_complex< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{type_count< T >::value};};template< typename T > struct type_count_min< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr int value{1};};template< typename T >struct type_count_min< T, typename std::enable_if< is_wrapper< T >::value &&!is_complex< T >::value &&!is_tuple_like< T >::value >::type >{static constexpr int value{subtype_count_min< typename T::value_type >::value};};template< typename T, std::size_t I >constexpr typename std::enable_if< I==type_count_base< T >::value, int >::type tuple_type_size_min(){return 0;}template< typename T, std::size_t I > constexpr typename std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size_min(){return subtype_count_min< typename std::tuple_element< I, T >::type >::value+tuple_type_size_min< T, I+1 >);}template< typename T > struct type_count_min< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size_min< T, 0 >)};};template< typename T > struct subtype_count_min{static constexpr int value{is_mutable_container< T >::value?((type_count< T >::value< expected_max_vector_size)?type_count< T 
>::value:0):type_count_min< T >::value};};template< typename T, typename Enable=void > struct expected_count{static const int value{0};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_wrapper< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{1};};template< typename T > struct expected_count< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr int value{expected_max_vector_size};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&is_wrapper< T >::value >::type >{static constexpr int value{expected_count< typename T::value_type >::value};};enum class object_category:int{char_value=1, integral_value=2, unsigned_integral=4, enumeration=6, boolean_value=8, floating_point=10, number_constructible=12, double_constructible=14, integer_constructible=16, string_assignable=23, string_constructible=24, other=45, wrapper_value=50, complex_number=60, tuple_value=70, container_value=80,};template< typename T, typename Enable=void > struct classify_object{static constexpr object_category value{object_category::other};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&!std::is_same< T, char >::value &&std::is_signed< T >::value &&!is_bool< T >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::integral_value};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&std::is_unsigned< T >::value &&!std::is_same< T, char >::value &&!is_bool< T >::value >::type >{static constexpr object_category value{object_category::unsigned_integral};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_same< T, char >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category 
value{object_category::char_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_bool< T >::value >::type >{static constexpr object_category value{object_category::boolean_value};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_floating_point< T >::value >::type >{static constexpr object_category value{object_category::floating_point};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&std::is_assignable< T &, std::string >::value >::type >{static constexpr object_category value{object_category::string_assignable};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&(type_count< T >::value==1)&&std::is_constructible< T, std::string >::value >::type >{static constexpr object_category value{object_category::string_constructible};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::enumeration};};template< typename T > struct classify_object< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr object_category value{object_category::complex_number};};template< typename T > struct uncommon_type{using type=typename std::conditional<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&!std::is_constructible< T, std::string >::value &&!is_complex< T >::value &&!is_mutable_container< T >::value &&!std::is_enum< T >::value, std::true_type, std::false_type >::type;static constexpr bool value=type::value;};template< typename T >struct classify_object< T, typename std::enable_if<(!is_mutable_container< T >::value &&is_wrapper< T >::value &&!is_tuple_like< T >::value 
&&uncommon_type< T >::value)>::type >{static constexpr object_category value{object_category::wrapper_value};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::number_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&!is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::integer_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::double_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< is_tuple_like< T >::value &&((type_count< T >::value >=2 &&!is_wrapper< T >::value)||(uncommon_type< T >::value &&!is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value)||(uncommon_type< T >::value &&type_count< T >::value >=2))>::type >{static constexpr object_category value{object_category::tuple_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr object_category value{object_category::container_value};};template< typename T, enable_if_t< classify_object< T >::value==object_category::char_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"CHAR";}template< typename T, enable_if_t< classify_object< T 
>::value==object_category::integral_value||classify_object< T >::value==object_category::integer_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"INT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::unsigned_integral, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"UINT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::floating_point||classify_object< T >::value==object_category::number_constructible||classify_object< T >::value==object_category::double_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"FLOAT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::enumeration, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"ENUM";}template< typename T, enable_if_t< classify_object< T >::value==object_category::boolean_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"BOOLEAN";}template< typename T, enable_if_t< classify_object< T >::value==object_category::complex_number, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"COMPLEX";}template< typename T, enable_if_t< classify_object< T >::value >=object_category::string_assignable &&classify_object< T >::value<=object_category::other, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"TEXT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::container_value||classify_object< T >::value==object_category::wrapper_value, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value==1, 
detail::enabler >=detail::dummy >inline std::string type_name(){return type_name< typename std::decay< typename std::tuple_element< 0, T >::type >::type >);}template< typename T, std::size_t I >inline typename std::enable_if< I==type_count_base< T >::value, std::string >::type tuple_name(){return std::string{};}template< typename T, std::size_t I >inline typename std::enable_if<(I< type_count_base< T >::value), std::string >::type tuple_name(){auto str=std::string{type_name< typename std::decay< typename std::tuple_element< I, T >::type >::type >)}+ ','+tuple_name< T, I+1 >);if(str.back()== ',') str.pop_back();return str;}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler > > std::string type_name()
Recursively generate the tuple type name.
Definition: CLI11.h:1729
GLint GLint GLsizei GLint GLenum format
Definition: glcorearb.h:108
::OrtRunOptions RunOptions
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
detail::TypeInfoImpl< detail::Unowned< const OrtTypeInfo >> ConstTypeInfo
Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. Provides access to const OrtTypeInfo APIs.
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime...
bool IsNegativeInfinity() const noexcept
Tests if the value represents negative infinity.
float ToFloatImpl() const noexcept
Converts float16 to float
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
IEEE 754 half-precision floating point data type.
std::vector< std::string > GetSessionConfigKeys() const
constexpr Float8E4M3FN_t() noexcept
GLint location
Definition: glcorearb.h:805
SessionOptions(OrtSessionOptions *p)
Create and own custom defined operation.
ConstIoBinding GetConst() const
bool IsPositiveInfinity() const noexcept
Tests if the value represents positive infinity.
float8e5m2fnuz (Float8 Floating Point) data type
Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V...
constexpr Float8E5M2_t() noexcept
constexpr bool operator!=(const Float8E4M3FNUZ_t &rhs) const noexcept
GLuint const GLchar * name
Definition: glcorearb.h:786
int GetVariadicInputMinArity() const
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions >> ConstSessionOptions
RunOptions(std::nullptr_t)
Create an empty RunOptions object, must be assigned a valid one to be used.
Float16_t Negate() const noexcept
Creates a new instance with the sign flipped.
OCIOEXPORT const char * GetVersion()
Get the version number for the library, as a dot-delimited string (e.g., "1.0.0").
AllocatedFree(OrtAllocator *allocator)
bool IsNaNOrZero() const noexcept
Tests if the value is NaN or zero. Useful for comparisons.
ORT_DEFINE_RELEASE(Allocator)
contained_type * p_
bool GetVariadicOutputHomogeneity() const
bool IsInfinity() const noexcept
Tests if the value is either positive or negative infinity.
constexpr Base(contained_type *p) noexcept
BFloat16_t Abs() const noexcept
Creates an instance that represents absolute value.
Float16_t(float v) noexcept
Constructor from float. The float is converted into its 16-bit float16 representation.
GT_API const UT_StringHolder version
bool IsSubnormal() const noexcept
Tests if the value is subnormal (denormal).
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
GLsizeiptr size
Definition: glcorearb.h:664
OrtErrorCode GetErrorCode() const
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
IMATH_NAMESPACE::V2f IMATH_NAMESPACE::Box2i std::string this attribute is obsolete as of OpenEXR v3 float
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
Ort::Status(*)(Ort::ShapeInferContext &) ShapeInferFn
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
void operator()(void *ptr) const
ConstValue GetConst() const
constexpr bool operator==(const Float8E5M2FNUZ_t &rhs) const noexcept
The ThreadingOptions.
float ToFloat() const noexcept
Converts bfloat16 to float
bool IsNaN() const noexcept
Tests if the value is NaN
constexpr bool operator!=(const Float8E5M2_t &rhs) const noexcept
ModelMetadata(OrtModelMetadata *p)
LoraAdapter holds a set of Lora Parameters loaded from a single file.
bool operator==(const BFloat16_t &rhs) const noexcept
MemoryInfo(OrtMemoryInfo *p)
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
GLuint index
Definition: glcorearb.h:786
auto ptr(T p) -> const void *
Definition: format.h:4331
GLuint GLfloat * val
Definition: glcorearb.h:1608
static constexpr Float16_t FromBits(uint16_t v) noexcept
Explicitly constructs a Float16_t from its raw uint16_t bit representation.
float8e5m2 (Float8 Floating Point) data type
**If you just want to fire and args
Definition: thread.h:618
bool IsNormal() const noexcept
Tests if the value is normal (not zero, subnormal, infinite, or NaN).
constexpr Float8E4M3FN_t(uint8_t v) noexcept
int GetVariadicOutputMinArity() const
constexpr bool operator==(const Float8E4M3FNUZ_t &rhs) const noexcept
bool IsNaNOrZero() const noexcept
Tests if the value is NaN or zero. Useful for comparisons.
std::string GetVersionString()
This function returns the onnxruntime version string
static bool AreZero(const BFloat16Impl &lhs, const BFloat16Impl &rhs) noexcept
IEEE defines that positive and negative zero are equal; this gives us a quick equality check for two ...
OrtErrorCode GetOrtErrorCode() const
#define MAX_CUSTOM_OP_END_VER
const char * GetExecutionProviderType() const
Base & operator=(Base &&v) noexcept
Wrapper around ::OrtTensorTypeAndShapeInfo.
MemoryInfo(std::nullptr_t)
No instance is created.
Exception(std::string &&string, OrtErrorCode code)
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
CustomOpConfigs.
MapTypeInfo(std::nullptr_t)
Create an empty MapTypeInfo object, must be assigned a valid one to be used.
Wrapper around ::OrtSessionOptions.
contained_type * release()
Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed...
SessionOptions(std::nullptr_t)
Create an empty SessionOptions object, must be assigned a valid one to be used.
Base & operator=(const Base &)=delete
UnownedValue GetUnowned() const
R GetAttribute(const char *name) const
Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime.
Wrapper around ::OrtSession.
Options for the CUDA provider that are passed to SessionOptionsAppendExecutionProvider_CUDA_V2. Please note that this struct is similar to OrtCUDAProviderOptions but only to be used internally. Going forward, new cuda provider options are to be supported via this struct and usage of the publicly defined OrtCUDAProviderOptions will be deprecated over time. User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
static constexpr BFloat16_t FromBits(uint16_t v) noexcept
Explicitly constructs a BFloat16_t from its raw uint16_t bit representation.
bfloat16 (Brain Floating Point) data type
ConstKernelInfo GetConst() const
LoraAdapter(std::nullptr_t)
bool IsInfinity() const noexcept
Tests if the value is either positive or negative infinity.
Shared implementation between public and internal classes. CRTP pattern.
SequenceTypeInfo(std::nullptr_t)
Create an empty SequenceTypeInfo object, must be assigned a valid one to be used. ...
Class that represents session configuration entries for one or more custom operators.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
constexpr bool operator!=(const Float8E4M3FN_t &rhs) const noexcept
Definition: format.h:1821
Definition: format.h:4365
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
static uint16_t ToUint16Impl(float v) noexcept
Converts from float to uint16_t float16 representation
ConstTensorTypeAndShapeInfo GetConst() const
bool IsNegativeInfinity() const noexcept
Tests if the value represents negative infinity.
UnownedIoBinding GetUnowned() const