HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
onnxruntime_cxx_inline.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5 // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6 //
7 // These are the inline implementations of the C++ header APIs. They're in this separate file so as not to clutter
8 // the main C++ file with implementation details.
9 
10 #include <algorithm>
11 #include <functional>
12 #include <iterator>
13 #include <type_traits>
14 
// Evaluates `expression` exactly once (typically a C API call yielding an OrtStatus*).
// If the result is non-null, wraps it in an Ort::Status and returns it from the
// enclosing function instead of throwing.
//
// The body is wrapped in do { ... } while (0) so that the expansion is a single
// statement: it composes correctly with an unbraced if/else at the call site.
// Call sites must terminate the macro invocation with a semicolon.
#define ORT_CXX_RETURN_ON_API_FAIL(expression) \
  do {                                         \
    auto ort_status = (expression);            \
    if (ort_status) {                          \
      return Ort::Status(ort_status);          \
    }                                          \
  } while (0)
24 
// ORT_CXX_IF_CONSTEXPR expands to `if constexpr` when the compiler advertises the
// __cpp_if_constexpr feature-test macro, and to a plain `if` otherwise.
#if defined(__cpp_if_constexpr)
#define ORT_CXX_IF_CONSTEXPR if constexpr
#else
#define ORT_CXX_IF_CONSTEXPR if
#endif
30 
31 namespace Ort {
32 
33 namespace detail {
34 inline void ThrowStatus(const Status& st) {
35  std::string error_message = st.GetErrorMessage();
36  OrtErrorCode error_code = st.GetErrorCode();
37  ORT_CXX_API_THROW(std::move(error_message), error_code);
38 }
39 } // namespace detail
40 
41 inline void ThrowOnError(OrtStatus* ort_status) {
42  if (ort_status) {
43  Ort::Status st(ort_status);
45  }
46 }
47 
48 inline void ThrowOnError(const Status& st) {
49  if (st) {
51  }
52 }
53 
54 inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
55 }
56 
57 inline Status::Status(const std::exception& e) noexcept {
58  p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
59 }
60 
61 inline Status::Status(const Exception& e) noexcept {
62  p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
63 }
64 
65 inline Status::Status(const char* message, OrtErrorCode code) noexcept {
66  p_ = GetApi().CreateStatus(code, message);
67 }
68 
69 inline std::string Status::GetErrorMessage() const {
70  std::string message(GetApi().GetErrorMessage(p_));
71  return message;
72 }
73 
74 inline OrtErrorCode Status::GetErrorCode() const {
75  return GetApi().GetErrorCode(p_);
76 }
77 
78 inline bool Status::IsOK() const noexcept {
79  return (p_ == nullptr);
80 }
81 
82 // This template converts a C++ type into it's ONNXTensorElementDataType
83 template <typename T>
84 struct TypeToTensorType;
85 template <>
86 struct TypeToTensorType<float> {
87  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
88 };
89 template <>
90 struct TypeToTensorType<Float16_t> {
91  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
92 };
93 template <>
94 struct TypeToTensorType<BFloat16_t> {
95  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
96 };
97 template <>
98 struct TypeToTensorType<double> {
99  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
100 };
101 template <>
102 struct TypeToTensorType<int8_t> {
103  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
104 };
105 template <>
106 struct TypeToTensorType<int16_t> {
107  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
108 };
109 template <>
110 struct TypeToTensorType<int32_t> {
111  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
112 };
113 template <>
114 struct TypeToTensorType<int64_t> {
115  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
116 };
117 template <>
118 struct TypeToTensorType<uint8_t> {
119  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
120 };
121 template <>
122 struct TypeToTensorType<uint16_t> {
123  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
124 };
125 template <>
126 struct TypeToTensorType<uint32_t> {
127  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
128 };
129 template <>
130 struct TypeToTensorType<uint64_t> {
131  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
132 };
133 template <>
134 struct TypeToTensorType<bool> {
135  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
136 };
137 
138 template <>
139 struct TypeToTensorType<Float8E4M3FN_t> {
140  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
141 };
142 template <>
143 struct TypeToTensorType<Float8E4M3FNUZ_t> {
144  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
145 };
146 template <>
147 struct TypeToTensorType<Float8E5M2_t> {
148  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
149 };
150 template <>
151 struct TypeToTensorType<Float8E5M2FNUZ_t> {
152  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
153 };
154 
155 inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
156  if (IsNaN() || rhs.IsNaN()) {
157  // IEEE defines that NaN is not equal to anything, including itself.
158  return false;
159  }
160  return val == rhs.val;
161 }
162 
163 inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
164  if (IsNaN() || rhs.IsNaN()) {
165  // IEEE defines that NaN is unordered with respect to everything, including itself.
166  return false;
167  }
168 
169  const bool left_is_negative = IsNegative();
170  if (left_is_negative != rhs.IsNegative()) {
171  // When the signs of left and right differ, we know that left is less than right if it is
172  // the negative value. The exception to this is if both values are zero, in which case IEEE
173  // says they should be equal, even if the signs differ.
174  return left_is_negative && !AreZero(*this, rhs);
175  }
176  return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
177 }
178 
179 inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
180  : allocator_(allocator), p_(p), size_(size) {
181 }
182 
184  if (p_ != nullptr) {
185  // We do not throw out of destructor
186  auto ret = GetApi().AllocatorFree(allocator_, p_);
187  static_cast<void>(ret);
188  }
189 }
190 
191 inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
192  *this = std::move(o);
193 }
194 
196  OrtAllocator* alloc = nullptr;
197  void* p = nullptr;
198  size_t sz = 0;
199 
200  // Swap out this
201  std::swap(alloc, allocator_);
202  std::swap(p, p_);
203  std::swap(sz, size_);
204 
205  // Swap with incoming
206  std::swap(allocator_, o.allocator_);
207  std::swap(p_, o.p_);
208  std::swap(size_, o.size_);
209 
210  // Destroy this instance if needed
211  MemoryAllocation this_alloc(alloc, p, sz);
212  return *this;
213 }
214 
215 namespace detail {
216 
217 template <typename T>
218 inline void* AllocatorImpl<T>::Alloc(size_t size) {
219  void* out;
220  ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
221  return out;
222 }
223 
224 template <typename T>
226  void* out;
227  ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
228  MemoryAllocation result(this->p_, out, size);
229  return result;
230 }
231 
232 template <typename T>
233 inline void AllocatorImpl<T>::Free(void* p) {
234  ThrowOnError(GetApi().AllocatorFree(this->p_, p));
235 }
236 
237 template <typename T>
239  const OrtMemoryInfo* out;
240  ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
241  return ConstMemoryInfo{out};
242 }
243 
244 } // namespace detail
245 
247  ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
248 }
249 
250 inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
251  ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
252 }
253 
254 namespace detail {
255 
256 template <typename T>
257 inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
258  const char* name = nullptr;
259  ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
260  return std::string(name);
261 }
262 
263 template <typename T>
264 inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
265  OrtAllocatorType type;
266  ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
267  return type;
268 }
269 
270 template <typename T>
271 inline int MemoryInfoImpl<T>::GetDeviceId() const {
272  int id = 0;
273  ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
274  return id;
275 }
276 
277 template <typename T>
278 inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
279  OrtMemoryInfoDeviceType type;
280  GetApi().MemoryInfoGetDeviceType(this->p_, &type);
281  return type;
282 }
283 
284 template <typename T>
285 inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
286  OrtMemType type;
287  ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
288  return type;
289 }
290 
291 template <typename T>
292 template <typename U>
294  int comp_result = 0;
295  ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
296  return comp_result == 0;
297 }
298 
299 } // namespace detail
300 
301 inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
302  OrtMemoryInfo* p;
303  ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
304  return MemoryInfo(p);
305 }
306 
307 inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
308  ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
309 }
310 
311 namespace detail {
312 template <typename T>
313 inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
314  AllocatorWithDefaultOptions allocator;
315  return binding_utils::GetOutputNamesHelper(this->p_, allocator);
316 }
317 
318 template <typename T>
319 inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
320  return binding_utils::GetOutputNamesHelper(this->p_, allocator);
321 }
322 
323 template <typename T>
324 inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
325  AllocatorWithDefaultOptions allocator;
326  return binding_utils::GetOutputValuesHelper(this->p_, allocator);
327 }
328 
329 template <typename T>
330 inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
331  return binding_utils::GetOutputValuesHelper(this->p_, allocator);
332 }
333 
334 template <typename T>
335 inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
336  ThrowOnError(GetApi().BindInput(this->p_, name, value));
337 }
338 
339 template <typename T>
340 inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
341  ThrowOnError(GetApi().BindOutput(this->p_, name, value));
342 }
343 
344 template <typename T>
345 inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
346  ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
347 }
348 
349 template <typename T>
351  GetApi().ClearBoundInputs(this->p_);
352 }
353 
354 template <typename T>
356  GetApi().ClearBoundOutputs(this->p_);
357 }
358 
359 template <typename T>
361  ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
362 }
363 
364 template <typename T>
366  ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
367 }
368 
369 namespace binding_utils {
370 inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
371  std::vector<std::string> result;
372  auto free_fn = detail::AllocatedFree(allocator);
373  using Ptr = std::unique_ptr<void, decltype(free_fn)>;
374 
375  char* buffer = nullptr;
376  size_t* lengths = nullptr;
377  size_t count = 0;
378  ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
379 
380  if (count == 0) {
381  return result;
382  }
383 
384  Ptr buffer_g(buffer, free_fn);
385  Ptr lengths_g(lengths, free_fn);
386 
387  result.reserve(count);
388  for (size_t i = 0; i < count; ++i) {
389  auto sz = *lengths;
390  result.emplace_back(buffer, sz);
391  buffer += sz;
392  ++lengths;
393  }
394  return result;
395 }
396 
397 inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
398  std::vector<Value> result;
399  size_t owned = 0;
400  size_t output_count = 0;
401  // Lambda to release the buffer when no longer needed and
402  // make sure that we destroy all instances on exception
403  auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
404  if (buffer) {
405  while (owned < output_count) {
406  auto* p = buffer + owned++;
407  GetApi().ReleaseValue(*p);
408  }
409  allocator->Free(allocator, buffer);
410  }
411  };
412  using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
413 
414  OrtValue** output_buffer = nullptr;
415  ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
416  if (output_count == 0) {
417  return result;
418  }
419 
420  Ptr buffer_g(output_buffer, free_fn);
421 
422  result.reserve(output_count);
423  for (size_t i = 0; i < output_count; ++i) {
424  result.emplace_back(output_buffer[i]);
425  ++owned;
426  }
427  return result;
428 }
429 
430 } // namespace binding_utils
431 } // namespace detail
432 
433 inline IoBinding::IoBinding(Session& session) {
434  ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
435 }
436 
437 inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
438  ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
439 }
440 
442  ThrowOnError(GetApi().CreateThreadingOptions(&p_));
443 }
444 
446  ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
447  return *this;
448 }
449 
451  ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
452  return *this;
453 }
454 
456  ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
457  return *this;
458 }
459 
462  return *this;
463 }
464 
465 inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
466  ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
467  return *this;
468 }
469 
470 inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
471  ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
472  return *this;
473 }
474 
475 inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
476  ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
477  return *this;
478 }
479 
480 inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
481  ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
482  if (strcmp(logid, "onnxruntime-node") == 0) {
483  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
484  } else {
485  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
486  }
487 }
488 
489 inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
490  ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
491  if (strcmp(logid, "onnxruntime-node") == 0) {
492  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
493  } else {
494  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
495  }
496 }
497 
498 inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
499  ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
500  if (strcmp(logid, "onnxruntime-node") == 0) {
501  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
502  } else {
503  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
504  }
505 }
506 
507 inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
508  OrtLoggingLevel logging_level, _In_ const char* logid) {
509  ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
510  if (strcmp(logid, "onnxruntime-node") == 0) {
511  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
512  } else {
513  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
514  }
515 }
516 
517 inline Env& Env::EnableTelemetryEvents() {
518  ThrowOnError(GetApi().EnableTelemetryEvents(p_));
519  return *this;
520 }
521 
524  return *this;
525 }
526 
527 inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
528  ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
529  return *this;
530 }
531 
532 inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
533  ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
534  return *this;
535 }
536 
537 inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
538  std::vector<const char*> keys, values;
539  auto num_entries = options.size();
540  if (num_entries > 0) {
541  keys.reserve(num_entries);
542  values.reserve(num_entries);
543  for (const auto& entry : options) {
544  keys.push_back(entry.first.c_str());
545  values.push_back(entry.second.c_str());
546  }
547  }
548  ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries));
549  return *this;
550 }
551 
552 inline CustomOpDomain::CustomOpDomain(const char* domain) {
553  ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
554 }
555 
556 inline void CustomOpDomain::Add(const OrtCustomOp* op) {
557  ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
558 }
559 
561  OrtAllocator* allocator) {
562  OrtLoraAdapter* p;
563  ThrowOnError(GetApi().CreateLoraAdapter(adapter_path.c_str(), allocator, &p));
564  return LoraAdapter{p};
565 }
566 
567 inline LoraAdapter LoraAdapter::CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
568  OrtAllocator* allocator) {
569  OrtLoraAdapter* p;
570  ThrowOnError(GetApi().CreateLoraAdapterFromArray(bytes, num_bytes, allocator, &p));
571  return LoraAdapter{p};
572 }
573 
575  ThrowOnError(GetApi().CreateRunOptions(&p_));
576 }
577 
579  ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
580  return *this;
581 }
582 
584  ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
585  return *this;
586 }
587 
589  int out;
590  ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
591  return out;
592 }
593 
595  int out;
596  ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
597  return out;
598 }
599 
600 inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
601  ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
602  return *this;
603 }
604 
605 inline const char* RunOptions::GetRunTag() const {
606  const char* out;
607  ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
608  return out;
609 }
610 
611 inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
612  ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
613  return *this;
614 }
615 
617  ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
618  return *this;
619 }
620 
622  ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
623  return *this;
624 }
625 
627  ThrowOnError(GetApi().RunOptionsAddActiveLoraAdapter(p_, adapter));
628  return *this;
629 }
630 
631 namespace detail {
632 
633 template <typename T>
635  OrtSessionOptions* out;
636  ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
637  return SessionOptions{out};
638 }
639 
640 template <typename T>
641 inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
642  size_t size = 0;
643  // Feed nullptr for the data buffer to query the true size of the string value
644  Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
645 
646  std::string out;
647  out.resize(size);
648  Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
649  out.resize(size - 1); // remove the terminating character '\0'
650 
651  return out;
652 }
653 
654 template <typename T>
655 inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
656  int out = 0;
657  Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
658  return static_cast<bool>(out);
659 }
660 
661 template <typename T>
662 inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
663  if (!this->HasConfigEntry(config_key)) {
664  return def;
665  }
666 
667  return this->GetConfigEntry(config_key);
668 }
669 
670 template <typename T>
672  ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
673  return *this;
674 }
675 
676 template <typename T>
678  ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
679  return *this;
680 }
681 
682 template <typename T>
683 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
684  ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
685  return *this;
686 }
687 
688 template <typename T>
690  ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value));
691  return *this;
692 }
693 
694 template <typename T>
695 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
696  ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
697  return *this;
698 }
699 
700 template <typename T>
701 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
702  ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
703  return *this;
704 }
705 
706 template <typename T>
708  ThrowOnError(GetApi().DisableProfiling(this->p_));
709  return *this;
710 }
711 
712 template <typename T>
714  ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
715  return *this;
716 }
717 
718 template <typename T>
720  ThrowOnError(GetApi().EnableMemPattern(this->p_));
721  return *this;
722 }
723 
724 template <typename T>
726  ThrowOnError(GetApi().DisableMemPattern(this->p_));
727  return *this;
728 }
729 
730 template <typename T>
732  ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
733  return *this;
734 }
735 
736 template <typename T>
738  ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
739  return *this;
740 }
741 
742 template <typename T>
743 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
744  ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
745  return *this;
746 }
747 
748 template <typename T>
750  ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
751  return *this;
752 }
753 
754 template <typename T>
756  ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
757  return *this;
758 }
759 
760 template <typename T>
762  ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
763  return *this;
764 }
765 
766 template <typename T>
767 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
768  ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
769  return *this;
770 }
771 
772 template <typename T>
773 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
774  ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
775  return *this;
776 }
777 
778 template <typename T>
780  ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
781  return *this;
782 }
783 
784 template <typename T>
785 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
786  const std::vector<Value>& ort_values) {
787  const size_t inputs_num = names.size();
788  if (inputs_num != ort_values.size()) {
789  ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
790  }
791  std::vector<const char*> names_ptr;
792  std::vector<const OrtValue*> ort_values_ptrs;
793  names_ptr.reserve(inputs_num);
794  ort_values_ptrs.reserve(inputs_num);
795  for (size_t i = 0; i < inputs_num; ++i) {
796  names_ptr.push_back(names[i].c_str());
797  ort_values_ptrs.push_back(ort_values[i]);
798  }
799  ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
800  return *this;
801 }
802 
803 template <typename T>
805  const std::vector<char*>& buffer_array,
806  const std::vector<size_t>& file_lengths) {
807  const size_t inputs_num = file_names.size();
808  if (inputs_num != buffer_array.size()) {
809  ORT_CXX_API_THROW("Expecting names and buffer_array to have the same length", ORT_INVALID_ARGUMENT);
810  }
811  if (inputs_num != file_lengths.size()) {
812  ORT_CXX_API_THROW("Expecting names and file_lengths to have the same length", ORT_INVALID_ARGUMENT);
813  }
814  std::vector<const ORTCHAR_T*> names_ptr;
815  names_ptr.reserve(inputs_num);
816  for (size_t i = 0; i < inputs_num; ++i) {
817  names_ptr.push_back(file_names[i].c_str());
818  }
819  ThrowOnError(GetApi().AddExternalInitializersFromFilesInMemory(this->p_, names_ptr.data(), buffer_array.data(),
820  file_lengths.data(), inputs_num));
821  return *this;
822 }
823 
824 template <typename T>
825 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
826  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
827  return *this;
828 }
829 
830 template <typename T>
832  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
833  return *this;
834 }
835 
836 template <typename T>
837 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
838  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
839  return *this;
840 }
841 
842 template <typename T>
843 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
844  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
845  return *this;
846 }
847 
848 template <typename T>
850  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
851  return *this;
852 }
853 
854 template <typename T>
855 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
856  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
857  return *this;
858 }
859 
860 template <typename T>
862  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
863  return *this;
864 }
865 
866 template <typename T>
868  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
869  return *this;
870 }
871 
872 template <typename T>
874  const std::string& provider_name,
875  const std::unordered_map<std::string, std::string>& provider_options) {
876  auto num_entries = provider_options.size();
877  std::vector<const char*> keys, values;
878  if (num_entries > 0) {
879  keys.reserve(num_entries);
880  values.reserve(num_entries);
881 
882  for (const auto& entry : provider_options) {
883  keys.push_back(entry.first.c_str());
884  values.push_back(entry.second.c_str());
885  }
886  }
887 
888  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
889  keys.data(), values.data(), num_entries));
890 
891  return *this;
892 }
893 
894 template <typename T>
895 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
896  ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
897  return *this;
898 }
899 
900 template <typename T>
901 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
902  ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
903  return *this;
904 }
905 
906 template <typename T>
907 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
908  ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
909  return *this;
910 }
911 
912 template <typename T>
913 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
914  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
915  return *this;
916 }
917 
918 template <typename T>
919 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options) {
920  auto num_entries = provider_options.size();
921  std::vector<const char*> keys, values;
922  if (num_entries > 0) {
923  keys.reserve(num_entries);
924  values.reserve(num_entries);
925 
926  for (const auto& entry : provider_options) {
927  keys.push_back(entry.first.c_str());
928  values.push_back(entry.second.c_str());
929  }
930  }
931 
932  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_,
933  keys.data(), values.data(), num_entries));
934 
935  return *this;
936 }
937 
938 template <typename T>
939 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options) {
940  auto num_entries = provider_options.size();
941  std::vector<const char*> keys, values;
942  if (num_entries > 0) {
943  keys.reserve(num_entries);
944  values.reserve(num_entries);
945 
946  for (const auto& entry : provider_options) {
947  keys.push_back(entry.first.c_str());
948  values.push_back(entry.second.c_str());
949  }
950  }
951 
952  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries));
953 
954  return *this;
955 }
956 
// Copies every flattened custom-op config into the session options via AddConfigEntry, then registers
// the custom op library through RegisterCustomOpsLibrary_V2 and returns *this.
// NOTE(review): the first signature line (listing line 958) is missing from this extract, so the function
// name and the declaration of `library_name` are not visible here; confirm against the header.
957 template <typename T>
959  const CustomOpConfigs& custom_op_configs) {
960  // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
961  // the custom op library.
962  for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
963  AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
964  }
965 
966  ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
967  return *this;
968 }
969 
970 template <typename T>
971 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
972  ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
973  return *this;
974 }
975 
976 /// Session
977 template <typename T>
978 inline size_t ConstSessionImpl<T>::GetInputCount() const {
979  size_t out;
980  ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
981  return out;
982 }
983 
984 template <typename T>
985 inline size_t ConstSessionImpl<T>::GetOutputCount() const {
986  size_t out;
987  ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
988  return out;
989 }
990 
// Returns the size_t written by SessionGetOverridableInitializerCount for this->p_.
// NOTE(review): the signature line (listing line 992) is missing from this extract; confirm it against the header.
991 template <typename T>
993  size_t out;
994  ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
995  return out;
996 }
997 
998 template <typename T>
999 inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
1000  char* out;
1001  ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
1002  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1003 }
1004 
1005 template <typename T>
1006 inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
1007  char* out;
1008  ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
1009  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1010 }
1011 
// Fetches a name through SessionGetOverridableInitializerName and returns it in an AllocatedStringPtr
// whose deleter is a detail::AllocatedFree built from `allocator`.
// NOTE(review): the signature line (listing line 1013) is missing from this extract, so the declarations
// of `index` and `allocator` are not visible here; confirm against the header.
1012 template <typename T>
1014  char* out;
1015  ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
1016  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1017 }
1018 
// Returns the uint64_t written by SessionGetProfilingStartTimeNs for this->p_.
// NOTE(review): the signature line (listing line 1020) is missing from this extract; confirm it against the header.
1019 template <typename T>
1021  uint64_t out;
1022  ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
1023  return out;
1024 }
1025 
// Obtains an OrtModelMetadata* through SessionGetModelMetadata and returns it wrapped in a ModelMetadata.
// NOTE(review): the signature line (listing line 1027) is missing from this extract; confirm it against the header.
1026 template <typename T>
1028  OrtModelMetadata* out;
1029  ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
1030  return ModelMetadata{out};
1031 }
1032 
// Obtains an OrtTypeInfo* through SessionGetInputTypeInfo for `index` and returns it wrapped in a TypeInfo.
// NOTE(review): the signature line (listing line 1034) is missing from this extract; confirm it against the header.
1033 template <typename T>
1035  OrtTypeInfo* out;
1036  ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
1037  return TypeInfo{out};
1038 }
1039 
// Obtains an OrtTypeInfo* through SessionGetOutputTypeInfo for `index` and returns it wrapped in a TypeInfo.
// NOTE(review): the signature line (listing line 1041) is missing from this extract; confirm it against the header.
1040 template <typename T>
1042  OrtTypeInfo* out;
1043  ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
1044  return TypeInfo{out};
1045 }
1046 
// Obtains an OrtTypeInfo* through SessionGetOverridableInitializerTypeInfo for `index` and returns it
// wrapped in a TypeInfo.
// NOTE(review): the signature line (listing line 1048) is missing from this extract; confirm it against the header.
1047 template <typename T>
1049  OrtTypeInfo* out;
1050  ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
1051  return TypeInfo{out};
1052 }
1053 
1054 template <typename T>
1055 inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1056  const char* const* output_names, size_t output_count) {
1057  std::vector<Value> output_values;
1058  output_values.reserve(output_count);
1059  for (size_t i = 0; i < output_count; i++)
1060  output_values.emplace_back(nullptr);
1061  Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
1062  return output_values;
1063 }
1064 
1065 template <typename T>
1066 inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1067  const char* const* output_names, Value* output_values, size_t output_count) {
1068  static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1069  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1070  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1071  ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
1072 }
1073 
1074 template <typename T>
1075 inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
1076  ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
1077 }
1078 
1079 template <typename T>
1080 inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1081  const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
1082  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1083  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1084  ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
1085  ort_input_values, input_count, output_names, output_count,
1086  ort_output_values, callback, user_data));
1087 }
1088 
// Calls SessionEndProfiling and returns the string it produces in an AllocatedStringPtr whose deleter
// is a detail::AllocatedFree built from `allocator`.
// NOTE(review): the signature line (listing line 1090) is missing from this extract; confirm it against the header.
1089 template <typename T>
1091  char* out = nullptr;
1092  ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
1093  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1094 }
1095 
1096 template <typename T>
1097 inline void SessionImpl<T>::SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len) {
1098  ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len));
1099 }
1100 
1101 } // namespace detail
1102 
// Creates the underlying session-options object via CreateSessionOptions and stores it in this->p_.
// NOTE(review): the signature line (listing line 1103) is missing from this extract; confirm it against the header.
1104  ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
1105 }
1106 
1107 /// CustomOpConfigs
1108 inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
1109  std::string config_key = "custom_op.";
1110 
1111  config_key += custom_op_name;
1112  config_key += ".";
1113  config_key += config;
1114 
1115  return config_key;
1116 }
1117 
1118 inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
1119  const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
1120  flat_configs_[full_flat_key] = config_value;
1121  return *this;
1122 }
1123 
1124 inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
1125  return flat_configs_;
1126 }
1127 
1128 inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
1129  ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
1130 }
1131 
1132 inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1133  OrtPrepackedWeightsContainer* prepacked_weights_container) {
1134  ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
1135 }
1136 
1137 inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
1138  ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
1139 }
1140 
1141 inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
1142  const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
1143  ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
1144  prepacked_weights_container, &this->p_));
1145 }
1146 
1147 inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
1148  char* out;
1149  ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
1150  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1151 }
1152 
1153 inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
1154  char* out;
1155  ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
1156  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1157 }
1158 
1159 inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
1160  char* out;
1161  ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
1162  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1163 }
1164 
// Returns the string produced by ModelMetadataGetDescription in an AllocatedStringPtr whose deleter
// is a detail::AllocatedFree built from `allocator`.
// NOTE(review): the signature line (listing line 1165) is missing from this extract; confirm it against the header.
1166  char* out;
1167  ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
1168  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1169 }
1170 
// Returns the string produced by ModelMetadataGetGraphDescription in an AllocatedStringPtr whose
// deleter is a detail::AllocatedFree built from `allocator`.
// NOTE(review): the signature line (listing line 1171) is missing from this extract; confirm it against the header.
1172  char* out;
1173  ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
1174  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1175 }
1176 
1177 inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
1178  char* out;
1179  ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
1180  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1181 }
1182 
// Returns the keys reported by ModelMetadataGetCustomMetadataMapKeys, each wrapped in an
// AllocatedStringPtr that frees it through a detail::AllocatedFree built from `allocator`.
// Returns an empty vector when the API reports num_keys <= 0.
//
// Two guards cover the window in which result.reserve() may throw:
//   * array_guard frees the array of pointers (`out`) when the function exits, on every path after it is built;
//   * strings_guard frees each individual string, but only while it still owns `out`.
// strings_guard is declared after array_guard, so it is destroyed first and the strings are released
// before the array that holds their pointers.
1183 inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
1184  auto deletor = detail::AllocatedFree(allocator);
1185  std::vector<AllocatedStringPtr> result;
1186 
1187  char** out = nullptr;
1188  int64_t num_keys = 0;
1189  ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
// Nothing to wrap: return before any guard is created.
1190  if (num_keys <= 0) {
1191  return result;
1192  }
1193 
1194  // array of pointers will be freed
1195  std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1196  // reserve may throw
1197  auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1198  std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1199  result.reserve(static_cast<size_t>(num_keys));
// reserve succeeded: drop the bulk string guard so that each string is owned by exactly one
// AllocatedStringPtr created in the loop below.
1200  strings_guard.release();
1201  for (int64_t i = 0; i < num_keys; ++i) {
1202  result.push_back(AllocatedStringPtr(out[i], deletor));
1203  }
1204 
1205  return result;
1206 }
1207 
1208 inline int64_t ModelMetadata::GetVersion() const {
1209  int64_t out;
1210  ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1211  return out;
1212 }
1213 
1214 namespace detail {
1215 
1216 template <typename T>
1217 inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1218  ONNXTensorElementDataType out;
1219  ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1220  return out;
1221 }
1222 
// Returns the size_t written by GetTensorShapeElementCount for this->p_.
// NOTE(review): the signature line (listing line 1224) is missing from this extract; confirm it against the header.
1223 template <typename T>
1225  size_t out;
1226  ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1227  return static_cast<size_t>(out);
1228 }
1229 
// Returns the size_t written by GetDimensionsCount for this->p_.
// NOTE(review): the signature line (listing line 1231) is missing from this extract; confirm it against the header.
1230 template <typename T>
1232  size_t out;
1233  ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1234  return out;
1235 }
1236 
1237 template <typename T>
1238 inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1239  ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1240 }
1241 
1242 template <typename T>
1243 inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1244  ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1245 }
1246 
1247 template <typename T>
1248 inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1249  std::vector<int64_t> out(GetDimensionsCount(), 0);
1250  ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1251  return out;
1252 }
1253 
// Obtains a const OrtTensorTypeAndShapeInfo* through CastTypeInfoToTensorInfo and returns it wrapped
// in a ConstTensorTypeAndShapeInfo.
// NOTE(review): the signature line (listing line 1255) is missing from this extract; confirm it against the header.
1254 template <typename T>
1256  const OrtTensorTypeAndShapeInfo* out;
1257  ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1258  return ConstTensorTypeAndShapeInfo{out};
1259 }
1260 
// Obtains a const OrtSequenceTypeInfo* through CastTypeInfoToSequenceTypeInfo and returns it wrapped
// in a ConstSequenceTypeInfo.
// NOTE(review): the signature line (listing line 1262) is missing from this extract; confirm it against the header.
1261 template <typename T>
1263  const OrtSequenceTypeInfo* out;
1264  ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1265  return ConstSequenceTypeInfo{out};
1266 }
1267 
// Obtains a const OrtMapTypeInfo* through CastTypeInfoToMapTypeInfo and returns it wrapped in a
// ConstMapTypeInfo.
// NOTE(review): the signature line (listing line 1269) is missing from this extract; confirm it against the header.
1268 template <typename T>
1270  const OrtMapTypeInfo* out;
1271  ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1272  return ConstMapTypeInfo{out};
1273 }
1274 
1275 template <typename T>
1276 inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1277  ONNXType out;
1278  ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1279  return out;
1280 }
1281 
// Obtains an OrtTypeInfo* through GetSequenceElementType and returns it wrapped in a TypeInfo.
// NOTE(review): the signature line (listing line 1283) is missing from this extract; confirm it against the header.
1282 template <typename T>
1284  OrtTypeInfo* output;
1285  ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1286  return TypeInfo{output};
1287 }
1288 
// Obtains an OrtTypeInfo* through GetOptionalContainedTypeInfo and returns it wrapped in a TypeInfo.
// NOTE(review): the signature line (listing line 1290) is missing from this extract; confirm it against the header.
1289 template <typename T>
1291  OrtTypeInfo* info;
1292  ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
1293  return TypeInfo{info};
1294 }
1295 
1296 template <typename T>
1297 inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1298  ONNXTensorElementDataType out;
1299  ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1300  return out;
1301 }
1302 
// Obtains an OrtTypeInfo* through GetMapValueType and returns it wrapped in a TypeInfo.
// NOTE(review): the signature line (listing line 1304) is missing from this extract; confirm it against the header.
1303 template <typename T>
1305  OrtTypeInfo* output;
1306  ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1307  return TypeInfo{output};
1308 }
1309 
// Obtains a const OrtOptionalTypeInfo* through CastTypeInfoToOptionalTypeInfo and returns it wrapped
// in a ConstOptionalTypeInfo.
// NOTE(review): the signature line (listing line 1311) is missing from this extract; confirm it against the header.
1310 template <typename T>
1312  const OrtOptionalTypeInfo* info;
1313  ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
1314  return ConstOptionalTypeInfo{info};
1315 }
1316 
1317 } // namespace detail
1318 
1319 namespace detail {
1320 
1321 template <typename T>
1322 template <typename R>
1323 inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1324  ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1325 }
1326 
1327 template <typename T>
1328 inline bool ConstValueImpl<T>::IsTensor() const {
1329  int out;
1330  ThrowOnError(GetApi().IsTensor(this->p_, &out));
1331  return out != 0;
1332 }
1333 
1334 template <typename T>
1335 inline bool ConstValueImpl<T>::HasValue() const {
1336  int out;
1337  ThrowOnError(GetApi().HasValue(this->p_, &out));
1338  return out != 0;
1339 }
1340 
1341 template <typename T>
1342 inline size_t ConstValueImpl<T>::GetCount() const {
1343  size_t out;
1344  ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1345  return out;
1346 }
1347 
1348 template <typename T>
1349 inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1350  OrtValue* out;
1351  ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1352  return Value{out};
1353 }
1354 
// Returns the size_t written by GetStringTensorDataLength for this->p_.
// NOTE(review): the signature line (listing line 1356) is missing from this extract; confirm it against the header.
1355 template <typename T>
1357  size_t out;
1358  ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1359  return out;
1360 }
1361 
1362 template <typename T>
1363 inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1364  size_t out;
1365  ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1366  return out;
1367 }
1368 
1369 template <typename T>
1370 template <typename R>
1371 inline const R* ConstValueImpl<T>::GetTensorData() const {
1372  R* out;
1373  ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1374  return out;
1375 }
1376 
1377 template <typename T>
1378 inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1379  void* out;
1380  ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1381  return out;
1382 }
1383 
// Obtains an OrtTypeInfo* through GetTypeInfo and returns it wrapped in a TypeInfo.
// NOTE(review): the signature line (listing line 1385) is missing from this extract; confirm it against the header.
1384 template <typename T>
1386  OrtTypeInfo* output;
1387  ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1388  return TypeInfo{output};
1389 }
1390 
// Obtains an OrtTensorTypeAndShapeInfo* through GetTensorTypeAndShape and returns it wrapped in a
// TensorTypeAndShapeInfo.
// NOTE(review): the signature line (listing line 1392) is missing from this extract; confirm it against the header.
1391 template <typename T>
1393  OrtTensorTypeAndShapeInfo* output;
1394  ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1395  return TensorTypeAndShapeInfo{output};
1396 }
1397 
// Obtains a const OrtMemoryInfo* through GetTensorMemoryInfo and returns it wrapped in a ConstMemoryInfo.
// NOTE(review): the signature line (listing line 1399) is missing from this extract; confirm it against the header.
1398 template <typename T>
1400  const OrtMemoryInfo* mem_info;
1401  ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1402  return ConstMemoryInfo(mem_info);
1403 }
1404 
1405 template <typename T>
1406 inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1407  ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1408 }
1409 
1410 template <typename T>
1411 inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
1412  size_t buffer_length;
1413  ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
1414 
1415  std::string s;
1416  s.resize(buffer_length);
1417  ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
1418  return s;
1419 }
1420 
1421 template <typename T>
1422 inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1423  ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1424 }
1425 
1426 #if !defined(DISABLE_SPARSE_TENSORS)
1427 template <typename T>
1428 inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1429  OrtSparseFormat format;
1430  ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1431  return format;
1432 }
1433 
// Obtains an OrtTensorTypeAndShapeInfo* through GetSparseTensorValuesTypeAndShape and returns it
// wrapped in a TensorTypeAndShapeInfo.
// NOTE(review): the signature line (listing line 1435) is missing from this extract; confirm it against the header.
1434 template <typename T>
1436  OrtTensorTypeAndShapeInfo* output;
1437  ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1438  return TensorTypeAndShapeInfo{output};
1439 }
1440 
1441 template <typename T>
1442 inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1443  OrtTensorTypeAndShapeInfo* output;
1444  ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1445  return TensorTypeAndShapeInfo{output};
1446 }
1447 
1448 template <typename T>
1449 template <typename R>
1450 inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1451  const void* out;
1452  ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1453  return reinterpret_cast<const R*>(out);
1454 }
1455 
// Returns true when the IsSparseTensor C API writes a non-zero flag for this->p_.
// NOTE(review): the signature line (listing line 1457) is missing from this extract; confirm it against the header.
1456 template <typename T>
1458  int out;
1459  ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1460  return out != 0;
1461 }
1462 
// Returns the pointer written by GetSparseTensorValues, reinterpreted as const R*.
// NOTE(review): the signature line (listing line 1465) is missing from this extract; confirm it against the header.
1463 template <typename T>
1464 template <typename R>
1466  const void* out;
1467  ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1468  return reinterpret_cast<const R*>(out);
1469 }
1470 
1471 #endif
1472 
1473 template <typename T>
1474 void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1475  ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1476 }
1477 
1478 template <typename T>
1479 void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1480  ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1481 }
1482 
1483 template <typename T>
1484 inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
1485  char* result;
1486  ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
1487  return result;
1488 }
1489 
// Returns the untyped pointer written by GetTensorMutableData for this->p_.
// NOTE(review): the signature line (listing line 1491) is missing from this extract; confirm it against the header.
1490 template <typename T>
1492  void* out;
1493  ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1494  return out;
1495 }
1496 
// Returns the pointer written by GetTensorMutableData for this->p_, typed as R*.
// The address of the R* local is cast to void** so the C API can write into it directly.
// NOTE(review): the signature line (listing line 1499) is missing from this extract; confirm it against the header.
1497 template <typename T>
1498 template <typename R>
1500  R* out;
1501  ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1502  return out;
1503 }
1504 
1505 template <typename T>
1506 template <typename R>
1507 R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1508  static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1509  R* out;
1510  ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1511  return *out;
1512 }
1513 
1514 #if !defined(DISABLE_SPARSE_TENSORS)
1515 template <typename T>
1516 void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1517  ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1518 }
1519 
1520 template <typename T>
1521 void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1522  ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1523 }
1524 
1525 template <typename T>
1526 void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1527  ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1528 }
1529 
// Forwards the values description (values_param.values_shape / values_shape_len / data.p_data) and the
// COO indices (indices_data, indices_num) to the FillSparseTensorCoo C API.
// NOTE(review): the first signature line (listing line 1531) is missing from this extract, so the
// declarations of `mem_info` and `values_param` are not visible here; confirm against the header.
1530 template <typename T>
1532  const int64_t* indices_data, size_t indices_num) {
1533  ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1534  values_param.values_shape_len, values_param.data.p_data,
1535  indices_data, indices_num));
1536 }
1537 
// Forwards the values description (values.values_shape / values_shape_len / data.p_data) and the CSR
// inner and outer indices to the FillSparseTensorCsr C API.
// NOTE(review): the first signature lines (listing lines 1539-1540) are missing from this extract, so
// the declarations of `data_mem_info` and `values` are not visible here; confirm against the header.
1538 template <typename T>
1541  const int64_t* inner_indices_data, size_t inner_indices_num,
1542  const int64_t* outer_indices_data, size_t outer_indices_num) {
1543  ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1544  inner_indices_data, inner_indices_num,
1545  outer_indices_data, outer_indices_num));
1546 }
1547 
// Forwards the values description (values.values_shape / values_shape_len / data.p_data), the indices
// shape and the indices buffer to the FillSparseTensorBlockSparse C API.
// NOTE(review): the first signature lines (listing lines 1549-1550) are missing from this extract, so
// the declarations of `data_mem_info` and `values` are not visible here; confirm against the header.
1548 template <typename T>
1551  const Shape& indices_shape,
1552  const int32_t* indices_data) {
1553  ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1554  indices_shape.shape, indices_shape.shape_len,
1555  indices_data));
1556 }
1557 
1558 #endif // !defined(DISABLE_SPARSE_TENSORS)
1559 
1560 } // namespace detail
1561 
1562 template <typename T>
1563 inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1564  return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1565 }
1566 
1567 inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1568  ONNXTensorElementDataType type) {
1569  OrtValue* out;
1570  ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1571  return Value{out};
1572 }
1573 
1574 template <typename T>
1575 inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1576  return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1577 }
1578 
1579 inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1580  OrtValue* out;
1581  ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1582  return Value{out};
1583 }
1584 
1585 #if !defined(DISABLE_SPARSE_TENSORS)
1586 
1587 template <typename T>
1588 inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1589  const Shape& values_shape) {
1590  return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1591 }
1592 
1593 inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1594  const Shape& values_shape, ONNXTensorElementDataType type) {
1595  OrtValue* out;
1596  ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1597  values_shape.shape, values_shape.shape_len, type, &out));
1598  return Value{out};
1599 }
1600 
1601 template <typename T>
1602 inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1603  return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1604 }
1605 
1606 inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1607  ONNXTensorElementDataType type) {
1608  OrtValue* out;
1609  ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1610  return Value{out};
1611 }
1612 #endif // !defined(DISABLE_SPARSE_TENSORS)
1613 
1614 inline Value Value::CreateMap(const Value& keys, const Value& values) {
1615  OrtValue* out;
1616  const OrtValue* inputs[2] = {keys, values};
1617  ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1618  return Value{out};
1619 }
1620 
1621 inline Value Value::CreateSequence(const std::vector<Value>& values) {
1622  OrtValue* out;
1623  std::vector<const OrtValue*> values_ort{values.data(), values.data() + values.size()};
1624  ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1625  return Value{out};
1626 }
1627 
1628 template <typename T>
1629 inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1630  OrtValue* out;
1631  ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1632  return Value{out};
1633 }
1634 
1635 //
1636 // Custom OP Inlines
1637 //
1638 inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
1639  Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
1640 }
1641 
1642 inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
1643  return cached_severity_level_;
1644 }
1645 
1646 inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1647  const char* func_name, const char* message) const noexcept {
1648  OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
1649  func_name);
1650  return Status{status};
1651 }
1652 
1653 // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
1654 // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
1655 // __attribute__(format(printf...)), which does not work with variadic templates.
1656 #if defined(__GNUC__)
1657 #pragma GCC diagnostic push
1658 #pragma GCC diagnostic ignored "-Wformat-nonliteral"
1659 #pragma GCC diagnostic ignored "-Wformat-security"
1660 #elif defined(__clang__)
1661 #pragma clang diagnostic push
1662 #pragma clang diagnostic ignored "-Wformat-nonliteral"
1663 #pragma clang diagnostic ignored "-Wformat-security"
1664 #endif
1665 template <typename... Args>
1666 inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
1667  int line_number, const char* func_name, const char* format,
1668  Args&&... args) const noexcept {
1669  int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
1670 
1671  if (msg_len < 0) { // Formatting error
1672  return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
1673  }
1674 
1675  OrtStatus* status = nullptr;
1676  const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
1677 
1678  constexpr size_t kStackBufferSize = 1024;
1679 
1680  if (buffer_size < kStackBufferSize) {
1681  char buffer[kStackBufferSize];
1682  snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
1683  status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
1684  } else {
1685  // std::make_unique is only supported starting at C++14.
1686 #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
1687  auto buffer = std::make_unique<char[]>(buffer_size);
1688 #else
1689  std::unique_ptr<char[]> buffer(new char[buffer_size]);
1690 #endif
1691  std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
1692  status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
1693  }
1694 
1695  return Status{status};
1696 }
1697 // Re-enable -Wformat-nonliteral and -Wformat-security
1698 #if defined(__GNUC__)
1699 #pragma GCC diagnostic pop
1700 #elif defined(__clang__)
1701 #pragma clang diagnostic pop
1702 #endif
1703 
1704 inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1705 }
1706 
1707 inline size_t KernelContext::GetInputCount() const {
1708  size_t out = 0;
1709  Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1710  return out;
1711 }
1712 
1713 inline size_t KernelContext::GetOutputCount() const {
1714  size_t out = 0;
1715  Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1716  return out;
1717 }
1718 
// Obtains a const OrtValue* for `index` through KernelContext_GetInput and returns it wrapped in a ConstValue.
// NOTE(review): the signature line (listing line 1719) is missing from this extract; confirm it against the header.
1720  const OrtValue* out = nullptr;
1721  Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1722  return ConstValue{out};
1723 }
1724 
1725 inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1726  OrtValue* out = nullptr;
1727  Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1728  return UnownedValue(out);
1729 }
1730 
1731 inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1732  OrtValue* out = nullptr;
1733  Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1734  return UnownedValue(out);
1735 }
1736 
// Returns the untyped pointer written by KernelContext_GetGPUComputeStream (nullptr if the API leaves it unset).
// NOTE(review): the signature line (listing line 1737) is missing from this extract; confirm it against the header.
1738  void* out = nullptr;
1739  Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1740  return out;
1741 }
1742 
1743 inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
1744  OrtAllocator* out = nullptr;
1745  Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
1746  return out;
1747 }
1748 
// Obtains a const OrtLogger* through KernelContext_GetLogger and returns a Logger constructed from it.
// NOTE(review): the signature line (listing line 1749) is missing from this extract; confirm it against the header.
1750  const OrtLogger* out = nullptr;
1751  ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
1752  return Logger{out};
1753 }
1754 
1755 inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const {
1756  ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
1757 }
1758 
1759 inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1760  Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1761 }
1762 
1763 namespace detail {
// Obtains a copy of the kernel info through CopyKernelInfo and returns it wrapped in a KernelInfo.
// NOTE(review): the signature line (listing line 1765) is missing from this extract; confirm it against the header.
1764 template <typename T>
1766  OrtKernelInfo* info_copy = nullptr;
1767  Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1768  return KernelInfo{info_copy};
1769 }
1770 
1771 template <typename T>
1772 inline size_t KernelInfoImpl<T>::GetInputCount() const {
1773  size_t out = 0;
1774  ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1775  return out;
1776 }
1777 
1778 template <typename T>
1779 inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1780  size_t out = 0;
1781  ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1782  return out;
1783 }
1784 
1785 template <typename T>
1786 inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
1787  size_t size = 0;
1788 
1789  // Feed nullptr for the data buffer to query the true size of the string value
1790  Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1791 
1792  std::string out;
1793  out.resize(size);
1794  Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1795  out.resize(size - 1); // remove the terminating character '\0'
1796 
1797  return out;
1798 }
1799 
1800 template <typename T>
1801 inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
1802  size_t size = 0;
1803 
1804  // Feed nullptr for the data buffer to query the true size of the string value
1805  Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1806 
1807  std::string out;
1808  out.resize(size);
1809  Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1810  out.resize(size - 1); // remove the terminating character '\0'
1811 
1812  return out;
1813 }
1814 
1815 template <typename T>
1817  OrtTypeInfo* out = nullptr;
1818  ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1819  return TypeInfo{out};
1820 }
1821 
1822 template <typename T>
1824  OrtTypeInfo* out = nullptr;
1825  ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1826  return TypeInfo{out};
1827 }
1828 
1829 template <typename T>
1830 inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1831  OrtValue* out = nullptr;
1832  ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1833  return Value{out};
1834 }
1835 
1836 template <typename T>
1837 inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
1838  const OrtValue* out = nullptr;
1839  ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
1840  return ConstValue{out};
1841 }
1842 
1843 template <typename T>
1844 inline std::string KernelInfoImpl<T>::GetNodeName() const {
1845  size_t size = 0;
1846 
1847  // Feed nullptr for the data buffer to query the true size of the string value
1848  Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
1849 
1850  std::string out;
1851  out.resize(size);
1852  Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
1853  out.resize(size - 1); // remove the terminating character '\0'
1854 
1855  return out;
1856 }
1857 
1858 template <typename T>
1860  const OrtLogger* out = nullptr;
1861  ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
1862  return Logger{out};
1863 }
1864 
1865 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1866  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1867 }
1868 
1869 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1870  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1871 }
1872 
1873 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1874  size_t size = 0;
1875  // Feed nullptr for the data buffer to query the true size of the string attribute
1876  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1877 
1878  std::string out;
1879  out.resize(size);
1880  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1881  out.resize(size - 1); // remove the terminating character '\0'
1882  out.swap(result);
1883 }
1884 
1885 inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1886  size_t size = 0;
1887  // Feed nullptr for the data buffer to query the true size of the attribute
1888  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1889 
1890  std::vector<float> out;
1891  out.resize(size);
1892  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1893  out.swap(result);
1894 }
1895 
1896 inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1897  size_t size = 0;
1898 
1899  // Feed nullptr for the data buffer to query the true size of the attribute
1900  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1901 
1902  std::vector<int64_t> out;
1903  out.resize(size);
1904  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1905  out.swap(result);
1906 }
1907 } // namespace detail
1908 
1909 inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1910 
1911 inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1912 
1913 inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1914  const char** type_constraint_names,
1915  const ONNXTensorElementDataType* type_constraint_values,
1916  size_t type_constraint_count,
1917  const OpAttr* attr_values, size_t attr_count,
1918  size_t input_count, size_t output_count) {
1919  static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1920  "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1921  auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1922  OrtOp* op;
1923  Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1924  static_cast<int>(type_constraint_count),
1925  attr_input_values,
1926  static_cast<int>(attr_count),
1927  static_cast<int>(input_count),
1928  static_cast<int>(output_count), &op));
1929  return Op{op};
1930 }
1931 
1932 inline void Op::Invoke(const OrtKernelContext* context,
1933  const Value* input_values,
1934  size_t input_count,
1935  Value* output_values,
1936  size_t output_count) {
1937  static_assert(sizeof(Value) == sizeof(OrtValue*),
1938  "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1939  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1940  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1941  Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1942  ort_output_values, static_cast<int>(output_count)));
1943 }
1944 
1945 inline void Op::Invoke(const OrtKernelContext* context,
1946  const OrtValue* const* input_values,
1947  size_t input_count,
1948  OrtValue* const* output_values,
1949  size_t output_count) {
1950  Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1951  output_values, static_cast<int>(output_count)));
1952 }
1953 
1954 inline std::string GetVersionString() {
1955  return OrtGetApiBase()->GetVersionString();
1956 }
1957 
1958 inline std::string GetBuildInfoString() {
1959  return GetApi().GetBuildInfoString();
1960 }
1961 
1962 inline std::vector<std::string> GetAvailableProviders() {
1963  char** providers;
1964  int len;
1965 
1966  auto release_fn = [&len](char** providers) {
1967  // This should always return nullptr.
1968  ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1969  };
1970 
1971  ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1972  std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
1973  std::vector<std::string> available_providers;
1974  available_providers.reserve(static_cast<size_t>(len));
1975  for (int i = 0; i < len; ++i) {
1976  available_providers.emplace_back(providers[i]);
1977  }
1978  return available_providers;
1979 }
1980 
1981 template <typename TOp, typename TKernel, bool WithStatus>
1982 void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1983  ConstSessionOptions options) const {
1984  const TOp* derived = static_cast<const TOp*>(this);
1985  std::vector<std::string> keys = derived->GetSessionConfigKeys();
1986 
1987  out.reserve(keys.size());
1988 
1989  std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1990  const size_t prefix_size = config_entry_key.length();
1991 
1992  for (const auto& key : keys) {
1993  config_entry_key.resize(prefix_size);
1994  config_entry_key.append(key);
1995  out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1996  }
1997 }
1998 
1999 inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api,
2000  OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) {
2001  size_t input_count = 0;
2002  Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count));
2003  for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
2004  OrtTensorTypeAndShapeInfo* info{};
2005  Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info));
2006  TensorTypeAndShapeInfo type_shape_info(info);
2007  auto integer_shape = type_shape_info.GetShape();
2008  std::vector<const char*> symbolic_shape(integer_shape.size(), {});
2009  if (!integer_shape.empty()) {
2010  type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size());
2011  }
2012  Shape shape;
2013  for (size_t ith = 0; ith < integer_shape.size(); ++ith) {
2014  if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) {
2015  shape.emplace_back(symbolic_shape[ith]);
2016  } else {
2017  shape.emplace_back(integer_shape[ith]);
2018  }
2019  }
2020  input_shapes_.push_back(std::move(shape));
2021  type_shape_info.release();
2022  }
2023 }
2024 
2025 inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) {
2026  OrtTensorTypeAndShapeInfo* info = {};
2027  ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info));
2028  ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type));
2029 
2030  using InfoPtr = std::unique_ptr<OrtTensorTypeAndShapeInfo, std::function<void(OrtTensorTypeAndShapeInfo*)>>;
2031 
2032  InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) {
2033  ort_api_->ReleaseTensorTypeAndShapeInfo(obj);
2034  });
2035 
2036  std::vector<int64_t> integer_dims;
2037  std::vector<const char*> symbolic_dims;
2038 
2039  for (const auto dim : shape) {
2040  if (dim.IsInt()) {
2041  integer_dims.push_back(dim.AsInt());
2042  symbolic_dims.push_back("");
2043  } else {
2044  if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) {
2045  ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT);
2046  }
2047  integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM);
2048  symbolic_dims.push_back(dim.AsSym());
2049  }
2050  }
2051 
2052  ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size()));
2053  ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size()));
2054  ORT_CXX_RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info));
2055  return Status{nullptr};
2056 }
2057 
2058 inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) {
2059  const auto* attr = GetAttrHdl(attr_name);
2060  int64_t i = {};
2061  size_t out = {};
2062  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out));
2063  return i;
2064 }
2065 
2066 inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) {
2067  const auto* attr = GetAttrHdl(attr_name);
2068  int64_t i = {};
2069  size_t out = {};
2070  // first call to get the bytes needed
2071  // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
2072  // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success).
2073  // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}.
2074  auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out);
2075  if (status) {
2076  size_t num_i = out / sizeof(int64_t);
2077  ShapeInferContext::Ints ints(num_i, 0);
2078  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out));
2079  return ints;
2080  } else {
2081  if (out == 0u) {
2082  return {};
2083  }
2084  return {i};
2085  }
2086 }
2087 
2088 inline float ShapeInferContext::GetAttrFloat(const char* attr_name) {
2089  const auto* attr = GetAttrHdl(attr_name);
2090  float f = {};
2091  size_t out = {};
2092  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out));
2093  return f;
2094 }
2095 
2096 inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) {
2097  const auto* attr = GetAttrHdl(attr_name);
2098  float f = {};
2099  size_t out = {};
2100  // first call to get the bytes needed
2101  // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
2102  // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success).
2103  // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}.
2104  auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out);
2105  if (status) {
2106  size_t num_f = out / sizeof(float);
2107  ShapeInferContext::Floats floats(num_f, 0);
2108  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out));
2109  return floats;
2110  } else {
2111  if (out == 0u) {
2112  return {};
2113  }
2114  return {f};
2115  }
2116 }
2117 
2118 inline std::string ShapeInferContext::GetAttrString(const char* attr_name) {
2119  const auto* attr = GetAttrHdl(attr_name);
2120  char c = {};
2121  size_t out = {};
2122  // first call to get the bytes needed
2123  auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out);
2124  if (status) {
2125  std::vector<char> chars(out, '\0');
2126  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out));
2127  return {chars.data()};
2128  } else {
2129  return {c};
2130  }
2131 }
2132 
2133 inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) {
2134  const auto* attr = GetAttrHdl(attr_name);
2135  char c = {};
2136  size_t out = {};
2137  // first call to get the bytes needed
2138  // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
2139  // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success).
2140  // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}.
2141  auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out);
2142  if (status) {
2143  std::vector<char> chars(out, '\0');
2144  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out));
2146  char* char_st = chars.data();
2147  char* char_ed = char_st + out;
2148  while (char_st < char_ed) {
2149  strings.emplace_back(char_st);
2150  while (*char_st != '\0') {
2151  char_st++;
2152  }
2153  char_st++;
2154  }
2155  return strings;
2156  } else {
2157  if (out == 0u) {
2158  return {};
2159  }
2160  return {std::string{c}};
2161  }
2162 }
2163 
2164 inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const {
2165  const OrtOpAttr* attr_hdl = {};
2166  Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl));
2167  return attr_hdl;
2168 }
2169 
2170 } // namespace Ort
OrtMemoryInfoDeviceType GetDeviceType() const
type
Definition: core.h:556
Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T *file_path, int line_number, const char *func_name, const char *format, Args &&...args) const noexcept
std::vector< int64_t > Ints
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id...
SessionOptionsImpl & SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn.
GLuint GLsizei const GLchar * message
Definition: glcorearb.h:2543
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
MemoryAllocation & operator=(const MemoryAllocation &)=delete
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
OrtAllocator * GetAllocator(const OrtMemoryInfo &memory_info) const
std::string GetErrorMessage() const
png_const_structrp png_const_inforp info_ptr
Definition: png.h:1939
size_t GetInputCount() const
Returns the number of model inputs.
SessionOptionsImpl & DisablePerSessionThreads()
Wraps OrtApi::DisablePerSessionThreads.
std::vector< std::string > Strings
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
size_t GetOutputCount() const
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
SessionOptionsImpl & EnableMemPattern()
Wraps OrtApi::EnableMemPattern.
SessionOptionsImpl & EnableCpuMemArena()
Wraps OrtApi::EnableCpuMemArena.
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
std::vector< float > Floats
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the output name at the specified index.
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to a user defined data for experimental purposes
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
ConstMemoryInfo GetInfo() const
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
SessionOptionsImpl & SetLogId(const char *logid)
Wraps OrtApi::SetSessionLogId.
void swap(UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &a, UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &b)
Definition: UT_ArraySet.h:1699
Value GetValue(int index, OrtAllocator *allocator) const
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
SessionOptionsImpl & AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX.
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
static Value CreateSequence(const std::vector< Value > &values)
Creates an OrtValue with a Sequence Onnx type representation. The API would ref-count the supplied Or...
void ThrowOnError(OrtStatus *ort_status)
GLsizei const GLfloat * value
Definition: glcorearb.h:824
void BindInput(const char *name, const Value &)
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
ConstValue GetInput(size_t index) const
SessionOptionsImpl & AppendExecutionProvider_CANN(const OrtCANNProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN.
MemoryAllocation GetAllocation(size_t size)
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
SessionOptionsImpl & DisableCpuMemArena()
Wraps OrtApi::DisableCpuMemArena.
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
bool HasConfigEntry(const char *config_key) const
Wraps OrtApi::HasSessionConfigEntry.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
static Value CreateOpaque(const char *domain, const char *type_name, const T &value)
Creates an OrtValue wrapping an Opaque type. This is used for experimental support of non-tensor type...
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor
union Ort::detail::OrtSparseValuesParam::@166 data
std::vector< SymbolicInteger > Shape
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
Wrapper around ::OrtModelMetadata.
GLint level
Definition: glcorearb.h:108
SessionOptionsImpl & AddInitializer(const char *name, const OrtValue *ort_val)
Wraps OrtApi::AddInitializer.
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
std::string GetConfigEntryOrDefault(const char *config_key, const std::string &def)
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
SessionOptionsImpl & AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT.
GLdouble s
Definition: glad.h:3009
SessionOptionsImpl & EnableOrtCustomOps()
Wraps OrtApi::EnableOrtCustomOps.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
R & At(const std::vector< int64_t > &location)
std::string GetAllocatorName() const
void RunAsync(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count, RunAsyncCallbackFn callback, void *user_data)
Run the model asynchronously in a thread owned by intra op thread pool.
SessionOptionsImpl & AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA.
Wrapper around OrtValue.
SessionOptionsImpl & SetIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetIntraOpNumThreads.
**But if you need a result
Definition: thread.h:622
SessionOptionsImpl & RegisterCustomOpsUsingFunction(const char *function_name)
Wraps OrtApi::RegisterCustomOpsUsingFunction.
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
The Env (Environment)
OrtLoggingLevel GetLoggingSeverityLevel() const noexcept
ConstOptionalTypeInfo GetOptionalTypeInfo() const
wraps OrtApi::CastTypeInfoToOptionalTypeInfo
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Returns a vector of copies of the custom metadata keys.
SessionOptionsImpl & SetOptimizedModelFilePath(const ORTCHAR_T *optimized_model_file)
Wraps OrtApi::SetOptimizedModelFilePath.
static LoraAdapter CreateLoraAdapterFromArray(const void *bytes, size_t num_bytes, OrtAllocator *allocator)
Wraps OrtApi::CreateLoraAdapterFromArray.
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
float8e4m3fnuz (Float8 Floating Point) data type
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values...
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
GLuint buffer
Definition: glcorearb.h:660
SessionOptionsImpl & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::SetSessionGraphOptimizationLevel.
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed, the caller must ensure the type matches the tensor type.
SessionOptionsImpl & SetCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions.
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
std::string GetConfigEntry(const char *config_key) const
Wraps OrtApi::GetSessionConfigEntry.
OutGridT const XformOp bool bool
void ThrowStatus(const Status &st)
std::vector< std::string > GetOutputNames() const
float8e4m3fn (Float8 Floating Point) data type
SessionOptionsImpl & AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2.
SessionOptionsImpl & AddExternalInitializersFromFilesInMemory(const std::vector< std::basic_string< ORTCHAR_T >> &external_initializer_file_names, const std::vector< char * > &external_initializer_file_buffer_array, const std::vector< size_t > &external_initializer_file_lengths)
Wraps OrtApi::AddExternalInitializersFromFilesInMemory.
ConstValue GetTensorConstantInput(size_t index, int *is_constant) const
GLuint GLsizei const GLuint const GLintptr * offsets
Definition: glcorearb.h:2621
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Used for interop with the C API.
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetOutputTypeInfo(size_t index) const
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
Wrapper around OrtMemoryInfo.
void SetEpDynamicOptions(const char *const *keys, const char *const *values, size_t kv_len)
Set DynamicOptions for EPs (Execution Providers)
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
GLfloat f
Definition: glcorearb.h:1926
GLint GLint GLsizei GLint GLenum GLenum type
Definition: glcorearb.h:108
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
bool operator<(const BFloat16_t &rhs) const noexcept
SessionOptions Clone() const
Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions.
static LoraAdapter CreateLoraAdapter(const std::basic_string< ORTCHAR_T > &adapter_path, OrtAllocator *allocator)
Wraps OrtApi::CreateLoraAdapter.
Wrapper around ::OrtIoBinding.
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< float > &)
char * GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length)
Allocate if necessary and obtain a pointer to a UTF-8 encoded string element buffer indexed by the fl...
SessionOptionsImpl & SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn.
The Status that holds ownership of the OrtStatus received from the C API. Use it to safely destroy OrtStatus* ...
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
A generic, discriminated value, whose type may be queried dynamically.
Definition: Value.h:45
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
void GetAttr(const OrtKernelInfo *p, const char *name, float &)
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
SessionOptionsImpl & AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO.
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
SessionOptionsImpl & AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2.
constexpr std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size(){return subtype_count< typename std::tuple_element< I, T >::type >::value+tuple_type_size< T, I+1 >);}template< typename T > struct type_count< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size< T, 0 >)};};template< typename T > struct subtype_count{static constexpr int value{is_mutable_container< T >::value?expected_max_vector_size:type_count< T >::value};};template< typename T, typename Enable=void > struct type_count_min{static const int value{0};};template< typename T >struct type_count_min< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_tuple_like< T >::value &&!is_wrapper< T >::value &&!is_complex< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{type_count< T >::value};};template< typename T > struct type_count_min< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr int value{1};};template< typename T >struct type_count_min< T, typename std::enable_if< is_wrapper< T >::value &&!is_complex< T >::value &&!is_tuple_like< T >::value >::type >{static constexpr int value{subtype_count_min< typename T::value_type >::value};};template< typename T, std::size_t I >constexpr typename std::enable_if< I==type_count_base< T >::value, int >::type tuple_type_size_min(){return 0;}template< typename T, std::size_t I > constexpr typename std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size_min(){return subtype_count_min< typename std::tuple_element< I, T >::type >::value+tuple_type_size_min< T, I+1 >);}template< typename T > struct type_count_min< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size_min< T, 0 >)};};template< typename T > struct subtype_count_min{static constexpr int value{is_mutable_container< T >::value?((type_count< T >::value< expected_max_vector_size)?type_count< T 
>::value:0):type_count_min< T >::value};};template< typename T, typename Enable=void > struct expected_count{static const int value{0};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_wrapper< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{1};};template< typename T > struct expected_count< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr int value{expected_max_vector_size};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&is_wrapper< T >::value >::type >{static constexpr int value{expected_count< typename T::value_type >::value};};enum class object_category:int{char_value=1, integral_value=2, unsigned_integral=4, enumeration=6, boolean_value=8, floating_point=10, number_constructible=12, double_constructible=14, integer_constructible=16, string_assignable=23, string_constructible=24, other=45, wrapper_value=50, complex_number=60, tuple_value=70, container_value=80,};template< typename T, typename Enable=void > struct classify_object{static constexpr object_category value{object_category::other};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&!std::is_same< T, char >::value &&std::is_signed< T >::value &&!is_bool< T >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::integral_value};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&std::is_unsigned< T >::value &&!std::is_same< T, char >::value &&!is_bool< T >::value >::type >{static constexpr object_category value{object_category::unsigned_integral};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_same< T, char >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category 
value{object_category::char_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_bool< T >::value >::type >{static constexpr object_category value{object_category::boolean_value};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_floating_point< T >::value >::type >{static constexpr object_category value{object_category::floating_point};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&std::is_assignable< T &, std::string >::value >::type >{static constexpr object_category value{object_category::string_assignable};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&(type_count< T >::value==1)&&std::is_constructible< T, std::string >::value >::type >{static constexpr object_category value{object_category::string_constructible};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::enumeration};};template< typename T > struct classify_object< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr object_category value{object_category::complex_number};};template< typename T > struct uncommon_type{using type=typename std::conditional<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&!std::is_constructible< T, std::string >::value &&!is_complex< T >::value &&!is_mutable_container< T >::value &&!std::is_enum< T >::value, std::true_type, std::false_type >::type;static constexpr bool value=type::value;};template< typename T >struct classify_object< T, typename std::enable_if<(!is_mutable_container< T >::value &&is_wrapper< T >::value &&!is_tuple_like< T >::value 
&&uncommon_type< T >::value)>::type >{static constexpr object_category value{object_category::wrapper_value};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::number_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&!is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::integer_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::double_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< is_tuple_like< T >::value &&((type_count< T >::value >=2 &&!is_wrapper< T >::value)||(uncommon_type< T >::value &&!is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value)||(uncommon_type< T >::value &&type_count< T >::value >=2))>::type >{static constexpr object_category value{object_category::tuple_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr object_category value{object_category::container_value};};template< typename T, enable_if_t< classify_object< T >::value==object_category::char_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"CHAR";}template< typename T, enable_if_t< classify_object< T 
>::value==object_category::integral_value||classify_object< T >::value==object_category::integer_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"INT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::unsigned_integral, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"UINT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::floating_point||classify_object< T >::value==object_category::number_constructible||classify_object< T >::value==object_category::double_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"FLOAT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::enumeration, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"ENUM";}template< typename T, enable_if_t< classify_object< T >::value==object_category::boolean_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"BOOLEAN";}template< typename T, enable_if_t< classify_object< T >::value==object_category::complex_number, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"COMPLEX";}template< typename T, enable_if_t< classify_object< T >::value >=object_category::string_assignable &&classify_object< T >::value<=object_category::other, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"TEXT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::container_value||classify_object< T >::value==object_category::wrapper_value, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value==1, 
detail::enabler >=detail::dummy >inline std::string type_name(){return type_name< typename std::decay< typename std::tuple_element< 0, T >::type >::type >);}template< typename T, std::size_t I >inline typename std::enable_if< I==type_count_base< T >::value, std::string >::type tuple_name(){return std::string{};}template< typename T, std::size_t I >inline typename std::enable_if<(I< type_count_base< T >::value), std::string >::type tuple_name(){auto str=std::string{type_name< typename std::decay< typename std::tuple_element< I, T >::type >::type >)}+ ','+tuple_name< T, I+1 >);if(str.back()== ',') str.pop_back();return str;}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler > > std::string type_name()
Recursively generate the tuple type name.
Definition: CLI11.h:1729
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
GLint GLint GLsizei GLint GLenum format
Definition: glcorearb.h:108
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
Value GetTensorAttribute(const char *name, OrtAllocator *allocator) const
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime...
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
IEEE 754 half-precision floating point data type.
GLint location
Definition: glcorearb.h:805
GLuint id
Definition: glcorearb.h:655
static Value CreateMap(const Value &keys, const Value &values)
Creates an OrtValue with a Map Onnx type representation. The API would ref-count the supplied OrtValu...
SessionOptionsImpl & SetInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetInterOpNumThreads.
float8e5m2fnuz (Float8 Floating Point) data type
Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V...
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer. No type checking is performe...
SessionOptionsImpl & AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions &provider_options)
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
void * GetGPUComputeStream() const
void BindOutput(const char *name, const Value &)
GLuint const GLchar * name
Definition: glcorearb.h:786
RunOptions & SetRunTag(const char *run_tag)
Wraps OrtApi::RunOptionsSetRunTag.
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
SessionOptionsImpl & EnableProfiling(const ORTCHAR_T *profile_file_prefix)
Wraps OrtApi::EnableProfiling.
size_t GetStringTensorDataLength() const
This API returns the full length of string data contained within either a tensor or a sparse tensor...
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions >> ConstSessionOptions
GLsizei const GLchar *const * strings
Definition: glcorearb.h:1933
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance...
OrtAllocatorType GetAllocatorType() const
SessionOptionsImpl & AppendExecutionProvider_OpenVINO_V2(const std::unordered_map< std::string, std::string > &provider_options={})
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
size_t GetStringTensorElementLength(size_t element_index) const
The API returns the byte length of a UTF-8 encoded string element contained in either a tensor or a sparse...
detail::ValueImpl< detail::Unowned< OrtValue >> UnownedValue
SessionOptionsImpl & Add(OrtCustomOpDomain *custom_op_domain)
Wraps OrtApi::AddCustomOpDomain.
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor...
GT_API const UT_StringHolder version
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
bool operator==(const MemoryInfoImpl< U > &o) const
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
RunOptions()
Wraps OrtApi::CreateRunOptions.
SessionOptionsImpl & AppendExecutionProvider_VitisAI(const std::unordered_map< std::string, std::string > &provider_options={})
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
std::vector< Value > GetOutputValues() const
GLsizeiptr size
Definition: glcorearb.h:664
OrtErrorCode GetErrorCode() const
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor.
void ParallelFor(void(*fn)(void *, size_t), size_t total, size_t num_batch, void *usr_data) const
KernelContext(OrtKernelContext *context)
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
Env & CreateAndRegisterAllocatorV2(const std::string &provider_type, const OrtMemoryInfo *mem_info, const std::unordered_map< std::string, std::string > &options, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocatorV2.
IMATH_NAMESPACE::V2f IMATH_NAMESPACE::Box2i std::string this attribute is obsolete as of OpenEXR v3 float
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
SessionOptionsImpl & SetDeterministicCompute(bool value)
Wraps OrtApi::SetDeterministicCompute.
GT_API const UT_StringHolder st
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
std::string GetOutputName(size_t index) const
The ThreadingOptions.
LoraAdapter holds a set of Lora Parameters loaded from a single file.
bool operator==(const BFloat16_t &rhs) const noexcept
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
SessionOptionsImpl & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddSessionConfigEntry.
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
SessionOptionsImpl & SetExecutionMode(ExecutionMode execution_mode)
Wraps OrtApi::SetSessionExecutionMode.
GLuint index
Definition: glcorearb.h:786
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
ShapeInferContext(const OrtApi *ort_api, OrtShapeInferContext *ctx)
Logger()=default
SessionOptionsImpl & AppendExecutionProvider(const std::string &provider_name, const std::unordered_map< std::string, std::string > &provider_options={})
Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK...
GLuint GLfloat * val
Definition: glcorearb.h:1608
SessionOptionsImpl & DisableProfiling()
Wraps OrtApi::DisableProfiling.
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
float8e5m2 (Float8 Floating Point) data type
**If you just want to fire and args
Definition: thread.h:618
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
#define ORT_CXX_RETURN_ON_API_FAIL(expression)
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
OIIO_UTIL_API const char * c_str(string_view str)
std::string GetVersionString()
This function returns the onnxruntime version string
size_t GetOutputCount() const
Returns the number of model outputs.
Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T *file_path, int line_number, const char *func_name, const char *message) const noexcept
SessionOptionsImpl & AddExternalInitializers(const std::vector< std::string > &names, const std::vector< Value > &ort_values)
Wraps OrtApi::AddExternalInitializers.
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of input name at the specified index.
Wrapper around ::OrtTensorTypeAndShapeInfo.
MemoryInfo(std::nullptr_t)
No instance is created.
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
CustomOpConfigs.
size_t GetCount() const
Returns true if the OrtValue contains data and returns false if the OrtValue is a None ...
std::string GetInputName(size_t index) const
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
Wrapper around ::OrtSessionOptions.
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
detail::MemoryInfoImpl< detail::Unowned< const OrtMemoryInfo >> ConstMemoryInfo
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
TypeInfo GetInputTypeInfo(size_t index) const
SessionOptionsImpl & SetLogSeverityLevel(int level)
Wraps OrtApi::SetSessionLogSeverityLevel.
Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime.
Wrapper around ::OrtSession.
SessionOptionsImpl & DisableMemPattern()
Wraps OrtApi::DisableMemPattern.
SessionOptionsImpl & AppendExecutionProvider_ROCM(const OrtROCMProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM.
Options for the CUDA provider that are passed to SessionOptionsAppendExecutionProvider_CUDA_V2. Please note that this struct is similar to OrtCUDAProviderOptions but only to be used internally. Going forward, new cuda provider options are to be supported via this struct and usage of the publicly defined OrtCUDAProviderOptions will be deprecated over time. User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
TypeInfo GetOptionalElementType() const
Wraps OrtApi::CastOptionalTypeToContainedTypeInfo.
bfloat16 (Brain Floating Point) data type
RunOptions & AddActiveLoraAdapter(const LoraAdapter &adapter)
Add the LoraAdapter to the list of active adapters. The setting does not affect RunWithBinding() call...
Class that represents session configuration entries for one or more custom operators.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
GLint GLsizei count
Definition: glcorearb.h:405
Definition: format.h:1821
Definition: format.h:4365
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
#define ORT_CXX_API_THROW(string, code)
SessionOptionsImpl & RegisterCustomOpsLibrary(const ORTCHAR_T *library_name, const CustomOpConfigs &custom_op_configs={})
GLsizei GLenum GLenum GLuint GLenum GLsizei * lengths
Definition: glcorearb.h:2542
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.