HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
onnxruntime_cxx_inline.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5 // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6 //
7 // These are the inline implementations of the C++ header APIs. They live in this separate file so as not to clutter
8 // the main C++ header with implementation details.
9 
10 #include <cstring>
11 #include <functional>
12 
// Evaluates `expression` exactly once; it is expected to yield an OrtStatus* (non-null means
// failure). On failure, returns `Status(<that pointer>)` from the *enclosing* function, which must
// therefore return a Status (or something constructible from one).
//
// The do { } while (0) wrapper makes the macro a single statement, so
// `if (c) RETURN_ON_API_FAIL(x); else ...` parses correctly. The temporary has a deliberately
// unusual name so that an argument mentioning a caller's variable (e.g. `err`) is not shadowed
// by (or self-initialised from) the macro's own local.
#define RETURN_ON_API_FAIL(expression)        \
  do {                                        \
    auto ort_api_fail_status_ = (expression); \
    if (ort_api_fail_status_) {               \
      return Status(ort_api_fail_status_);    \
    }                                         \
  } while (0)
20 
21 namespace Ort {
22 
23 namespace detail {
24 inline void ThrowStatus(const Status& st) {
25  std::string error_message = st.GetErrorMessage();
26  OrtErrorCode error_code = st.GetErrorCode();
27  ORT_CXX_API_THROW(std::move(error_message), error_code);
28 }
29 } // namespace detail
30 
31 inline void ThrowOnError(OrtStatus* ort_status) {
32  if (ort_status) {
33  Ort::Status st(ort_status);
35  }
36 }
37 
38 inline void ThrowOnError(const Status& st) {
39  if (st) {
41  }
42 }
43 
44 inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
45 }
46 
47 inline Status::Status(const std::exception& e) noexcept {
48  p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
49 }
50 
51 inline Status::Status(const Exception& e) noexcept {
52  p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
53 }
54 
55 inline Status::Status(const char* message, OrtErrorCode code) noexcept {
56  p_ = GetApi().CreateStatus(code, message);
57 }
58 
59 inline std::string Status::GetErrorMessage() const {
61  return message;
62 }
63 
64 inline OrtErrorCode Status::GetErrorCode() const {
65  return GetApi().GetErrorCode(p_);
66 }
67 
68 inline bool Status::IsOK() const noexcept {
69  return (p_ == nullptr);
70 }
71 
72 // This template converts a C++ type into it's ONNXTensorElementDataType
73 template <typename T>
74 struct TypeToTensorType;
75 template <>
76 struct TypeToTensorType<float> {
77  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
78 };
79 template <>
80 struct TypeToTensorType<Float16_t> {
81  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
82 };
83 template <>
84 struct TypeToTensorType<BFloat16_t> {
85  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
86 };
87 template <>
88 struct TypeToTensorType<double> {
89  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
90 };
91 template <>
92 struct TypeToTensorType<int8_t> {
93  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
94 };
95 template <>
96 struct TypeToTensorType<int16_t> {
97  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
98 };
99 template <>
100 struct TypeToTensorType<int32_t> {
101  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
102 };
103 template <>
104 struct TypeToTensorType<int64_t> {
105  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
106 };
107 template <>
108 struct TypeToTensorType<uint8_t> {
109  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
110 };
111 template <>
112 struct TypeToTensorType<uint16_t> {
113  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
114 };
115 template <>
116 struct TypeToTensorType<uint32_t> {
117  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
118 };
119 template <>
120 struct TypeToTensorType<uint64_t> {
121  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
122 };
123 template <>
124 struct TypeToTensorType<bool> {
125  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
126 };
127 
128 template <>
129 struct TypeToTensorType<Float8E4M3FN_t> {
130  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
131 };
132 template <>
133 struct TypeToTensorType<Float8E4M3FNUZ_t> {
134  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
135 };
136 template <>
137 struct TypeToTensorType<Float8E5M2_t> {
138  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
139 };
140 template <>
141 struct TypeToTensorType<Float8E5M2FNUZ_t> {
142  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
143 };
144 
145 inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
146  if (IsNaN() || rhs.IsNaN()) {
147  // IEEE defines that NaN is not equal to anything, including itself.
148  return false;
149  }
150  return val == rhs.val;
151 }
152 
153 inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
154  if (IsNaN() || rhs.IsNaN()) {
155  // IEEE defines that NaN is unordered with respect to everything, including itself.
156  return false;
157  }
158 
159  const bool left_is_negative = IsNegative();
160  if (left_is_negative != rhs.IsNegative()) {
161  // When the signs of left and right differ, we know that left is less than right if it is
162  // the negative value. The exception to this is if both values are zero, in which case IEEE
163  // says they should be equal, even if the signs differ.
164  return left_is_negative && !AreZero(*this, rhs);
165  }
166  return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
167 }
168 
169 inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
170  : allocator_(allocator), p_(p), size_(size) {
171 }
172 
174  if (p_ != nullptr) {
175  // We do not throw out of destructor
176  auto ret = GetApi().AllocatorFree(allocator_, p_);
177  static_cast<void>(ret);
178  }
179 }
180 
181 inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
182  *this = std::move(o);
183 }
184 
186  OrtAllocator* alloc = nullptr;
187  void* p = nullptr;
188  size_t sz = 0;
189 
190  // Swap out this
191  std::swap(alloc, allocator_);
192  std::swap(p, p_);
193  std::swap(sz, size_);
194 
195  // Swap with incoming
196  std::swap(allocator_, o.allocator_);
197  std::swap(p_, o.p_);
198  std::swap(size_, o.size_);
199 
200  // Destroy this instance if needed
201  MemoryAllocation this_alloc(alloc, p, sz);
202  return *this;
203 }
204 
205 namespace detail {
206 
207 template <typename T>
208 inline void* AllocatorImpl<T>::Alloc(size_t size) {
209  void* out;
210  ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
211  return out;
212 }
213 
214 template <typename T>
216  void* out;
217  ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
218  MemoryAllocation result(this->p_, out, size);
219  return result;
220 }
221 
222 template <typename T>
223 inline void AllocatorImpl<T>::Free(void* p) {
224  ThrowOnError(GetApi().AllocatorFree(this->p_, p));
225 }
226 
227 template <typename T>
229  const OrtMemoryInfo* out;
230  ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
231  return ConstMemoryInfo{out};
232 }
233 
234 } // namespace detail
235 
237  ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
238 }
239 
240 inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
241  ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
242 }
243 
244 namespace detail {
245 
246 template <typename T>
248  const char* name = nullptr;
249  ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
250  return std::string(name);
251 }
252 
253 template <typename T>
254 inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
255  OrtAllocatorType type;
256  ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
257  return type;
258 }
259 
260 template <typename T>
261 inline int MemoryInfoImpl<T>::GetDeviceId() const {
262  int id = 0;
263  ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
264  return id;
265 }
266 
267 template <typename T>
268 inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
269  OrtMemoryInfoDeviceType type;
270  GetApi().MemoryInfoGetDeviceType(this->p_, &type);
271  return type;
272 }
273 
274 template <typename T>
275 inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
276  OrtMemType type;
277  ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
278  return type;
279 }
280 
281 template <typename T>
282 template <typename U>
284  int comp_result = 0;
285  ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
286  return comp_result == 0;
287 }
288 
289 } // namespace detail
290 
291 inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
292  OrtMemoryInfo* p;
293  ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
294  return MemoryInfo(p);
295 }
296 
297 inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
298  ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
299 }
300 
301 namespace detail {
302 template <typename T>
303 inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
304  AllocatorWithDefaultOptions allocator;
305  return binding_utils::GetOutputNamesHelper(this->p_, allocator);
306 }
307 
308 template <typename T>
309 inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
310  return binding_utils::GetOutputNamesHelper(this->p_, allocator);
311 }
312 
313 template <typename T>
314 inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
315  AllocatorWithDefaultOptions allocator;
316  return binding_utils::GetOutputValuesHelper(this->p_, allocator);
317 }
318 
319 template <typename T>
320 inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
321  return binding_utils::GetOutputValuesHelper(this->p_, allocator);
322 }
323 
324 template <typename T>
325 inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
326  ThrowOnError(GetApi().BindInput(this->p_, name, value));
327 }
328 
329 template <typename T>
330 inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
331  ThrowOnError(GetApi().BindOutput(this->p_, name, value));
332 }
333 
334 template <typename T>
335 inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
336  ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
337 }
338 
339 template <typename T>
341  GetApi().ClearBoundInputs(this->p_);
342 }
343 
344 template <typename T>
346  GetApi().ClearBoundOutputs(this->p_);
347 }
348 
349 template <typename T>
351  ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
352 }
353 
354 template <typename T>
356  ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
357 }
358 
359 namespace binding_utils {
360 inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
361  std::vector<std::string> result;
362  auto free_fn = detail::AllocatedFree(allocator);
363  using Ptr = std::unique_ptr<void, decltype(free_fn)>;
364 
365  char* buffer = nullptr;
366  size_t* lengths = nullptr;
367  size_t count = 0;
368  ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
369 
370  if (count == 0) {
371  return result;
372  }
373 
374  Ptr buffer_g(buffer, free_fn);
375  Ptr lengths_g(lengths, free_fn);
376 
377  result.reserve(count);
378  for (size_t i = 0; i < count; ++i) {
379  auto sz = *lengths;
380  result.emplace_back(buffer, sz);
381  buffer += sz;
382  ++lengths;
383  }
384  return result;
385 }
386 
387 inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
388  std::vector<Value> result;
389  size_t owned = 0;
390  size_t output_count = 0;
391  // Lambda to release the buffer when no longer needed and
392  // make sure that we destroy all instances on exception
393  auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
394  if (buffer) {
395  while (owned < output_count) {
396  auto* p = buffer + owned++;
397  GetApi().ReleaseValue(*p);
398  }
399  allocator->Free(allocator, buffer);
400  }
401  };
402  using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
403 
404  OrtValue** output_buffer = nullptr;
405  ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
406  if (output_count == 0) {
407  return result;
408  }
409 
410  Ptr buffer_g(output_buffer, free_fn);
411 
412  result.reserve(output_count);
413  for (size_t i = 0; i < output_count; ++i) {
414  result.emplace_back(output_buffer[i]);
415  ++owned;
416  }
417  return result;
418 }
419 
420 } // namespace binding_utils
421 } // namespace detail
422 
423 inline IoBinding::IoBinding(Session& session) {
424  ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
425 }
426 
427 inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
428  ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
429 }
430 
432  ThrowOnError(GetApi().CreateThreadingOptions(&p_));
433 }
434 
436  ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
437  return *this;
438 }
439 
441  ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
442  return *this;
443 }
444 
446  ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
447  return *this;
448 }
449 
452  return *this;
453 }
454 
455 inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
456  ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
457  return *this;
458 }
459 
460 inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
461  ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
462  return *this;
463 }
464 
465 inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
466  ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
467  return *this;
468 }
469 
470 inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
471  ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
472  if (strcmp(logid, "onnxruntime-node") == 0) {
473  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
474  } else {
475  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
476  }
477 }
478 
479 inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
480  ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
481  if (strcmp(logid, "onnxruntime-node") == 0) {
482  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
483  } else {
484  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
485  }
486 }
487 
488 inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
489  ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
490  if (strcmp(logid, "onnxruntime-node") == 0) {
491  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
492  } else {
493  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
494  }
495 }
496 
497 inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
498  OrtLoggingLevel logging_level, _In_ const char* logid) {
499  ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
500  if (strcmp(logid, "onnxruntime-node") == 0) {
501  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
502  } else {
503  ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
504  }
505 }
506 
507 inline Env& Env::EnableTelemetryEvents() {
508  ThrowOnError(GetApi().EnableTelemetryEvents(p_));
509  return *this;
510 }
511 
514  return *this;
515 }
516 
517 inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
518  ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
519  return *this;
520 }
521 
522 inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
523  ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
524  return *this;
525 }
526 
527 inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
528  std::vector<const char*> keys, values;
529  auto num_entries = options.size();
530  if (num_entries > 0) {
531  keys.reserve(num_entries);
532  values.reserve(num_entries);
533  for (const auto& entry : options) {
534  keys.push_back(entry.first.c_str());
535  values.push_back(entry.second.c_str());
536  }
537  }
538  ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries));
539  return *this;
540 }
541 
542 inline CustomOpDomain::CustomOpDomain(const char* domain) {
543  ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
544 }
545 
546 inline void CustomOpDomain::Add(const OrtCustomOp* op) {
547  ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
548 }
549 
551  ThrowOnError(GetApi().CreateRunOptions(&p_));
552 }
553 
555  ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
556  return *this;
557 }
558 
560  ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
561  return *this;
562 }
563 
565  int out;
566  ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
567  return out;
568 }
569 
571  int out;
572  ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
573  return out;
574 }
575 
576 inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
577  ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
578  return *this;
579 }
580 
581 inline const char* RunOptions::GetRunTag() const {
582  const char* out;
583  ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
584  return out;
585 }
586 
587 inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
588  ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
589  return *this;
590 }
591 
593  ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
594  return *this;
595 }
596 
598  ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
599  return *this;
600 }
601 
602 namespace detail {
603 
604 template <typename T>
606  OrtSessionOptions* out;
607  ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
608  return SessionOptions{out};
609 }
610 
611 template <typename T>
612 inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
613  size_t size = 0;
614  // Feed nullptr for the data buffer to query the true size of the string value
615  Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
616 
617  std::string out;
618  out.resize(size);
619  Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
620  out.resize(size - 1); // remove the terminating character '\0'
621 
622  return out;
623 }
624 
625 template <typename T>
626 inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
627  int out = 0;
628  Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
629  return static_cast<bool>(out);
630 }
631 
632 template <typename T>
634  if (!this->HasConfigEntry(config_key)) {
635  return def;
636  }
637 
638  return this->GetConfigEntry(config_key);
639 }
640 
641 template <typename T>
643  ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
644  return *this;
645 }
646 
647 template <typename T>
649  ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
650  return *this;
651 }
652 
653 template <typename T>
654 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
655  ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
656  return *this;
657 }
658 
659 template <typename T>
661  ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value));
662  return *this;
663 }
664 
665 template <typename T>
666 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
667  ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
668  return *this;
669 }
670 
671 template <typename T>
672 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
673  ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
674  return *this;
675 }
676 
677 template <typename T>
679  ThrowOnError(GetApi().DisableProfiling(this->p_));
680  return *this;
681 }
682 
683 template <typename T>
685  ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
686  return *this;
687 }
688 
689 template <typename T>
691  ThrowOnError(GetApi().EnableMemPattern(this->p_));
692  return *this;
693 }
694 
695 template <typename T>
697  ThrowOnError(GetApi().DisableMemPattern(this->p_));
698  return *this;
699 }
700 
701 template <typename T>
703  ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
704  return *this;
705 }
706 
707 template <typename T>
709  ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
710  return *this;
711 }
712 
713 template <typename T>
714 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
715  ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
716  return *this;
717 }
718 
719 template <typename T>
721  ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
722  return *this;
723 }
724 
725 template <typename T>
727  ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
728  return *this;
729 }
730 
731 template <typename T>
733  ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
734  return *this;
735 }
736 
737 template <typename T>
738 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
739  ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
740  return *this;
741 }
742 
743 template <typename T>
744 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
745  ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
746  return *this;
747 }
748 
749 template <typename T>
751  ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
752  return *this;
753 }
754 
755 template <typename T>
756 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
757  const std::vector<Value>& ort_values) {
758  const size_t inputs_num = names.size();
759  if (inputs_num != ort_values.size()) {
760  ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
761  }
762  std::vector<const char*> names_ptr;
763  std::vector<const OrtValue*> ort_values_ptrs;
764  names_ptr.reserve(inputs_num);
765  ort_values_ptrs.reserve(inputs_num);
766  for (size_t i = 0; i < inputs_num; ++i) {
767  names_ptr.push_back(names[i].c_str());
768  ort_values_ptrs.push_back(ort_values[i]);
769  }
770  ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
771  return *this;
772 }
773 
774 template <typename T>
775 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
776  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
777  return *this;
778 }
779 
780 template <typename T>
782  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
783  return *this;
784 }
785 
786 template <typename T>
787 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
788  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
789  return *this;
790 }
791 
792 template <typename T>
793 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
794  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
795  return *this;
796 }
797 
798 template <typename T>
800  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
801  return *this;
802 }
803 
804 template <typename T>
805 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
806  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
807  return *this;
808 }
809 
810 template <typename T>
812  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
813  return *this;
814 }
815 
816 template <typename T>
818  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
819  return *this;
820 }
821 
822 template <typename T>
824  const std::string& provider_name,
825  const std::unordered_map<std::string, std::string>& provider_options) {
826  auto num_entries = provider_options.size();
827  std::vector<const char*> keys, values;
828  if (num_entries > 0) {
829  keys.reserve(num_entries);
830  values.reserve(num_entries);
831 
832  for (const auto& entry : provider_options) {
833  keys.push_back(entry.first.c_str());
834  values.push_back(entry.second.c_str());
835  }
836  }
837 
838  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
839  keys.data(), values.data(), num_entries));
840 
841  return *this;
842 }
843 
844 template <typename T>
845 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
846  ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
847  return *this;
848 }
849 
850 template <typename T>
851 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
852  ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
853  return *this;
854 }
855 
856 template <typename T>
857 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
858  ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
859  return *this;
860 }
861 
862 template <typename T>
863 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
864  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
865  return *this;
866 }
867 
868 template <typename T>
869 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options) {
870  auto num_entries = provider_options.size();
871  std::vector<const char*> keys, values;
872  if (num_entries > 0) {
873  keys.reserve(num_entries);
874  values.reserve(num_entries);
875 
876  for (const auto& entry : provider_options) {
877  keys.push_back(entry.first.c_str());
878  values.push_back(entry.second.c_str());
879  }
880  }
881 
882  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_,
883  keys.data(), values.data(), num_entries));
884 
885  return *this;
886 }
887 
888 template <typename T>
890  const CustomOpConfigs& custom_op_configs) {
891  // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
892  // the custom op library.
893  for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
894  AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
895  }
896 
897  ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
898  return *this;
899 }
900 
901 template <typename T>
902 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
903  ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
904  return *this;
905 }
906 
907 /// Session
908 template <typename T>
909 inline size_t ConstSessionImpl<T>::GetInputCount() const {
910  size_t out;
911  ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
912  return out;
913 }
914 
915 template <typename T>
916 inline size_t ConstSessionImpl<T>::GetOutputCount() const {
917  size_t out;
918  ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
919  return out;
920 }
921 
922 template <typename T>
924  size_t out;
925  ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
926  return out;
927 }
928 
929 template <typename T>
930 inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
931  char* out;
932  ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
933  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
934 }
935 
936 template <typename T>
937 inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
938  char* out;
939  ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
940  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
941 }
942 
943 template <typename T>
945  char* out;
946  ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
947  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
948 }
949 
950 template <typename T>
952  uint64_t out;
953  ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
954  return out;
955 }
956 
957 template <typename T>
959  OrtModelMetadata* out;
960  ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
961  return ModelMetadata{out};
962 }
963 
// NOTE(review): the signature line (listing line 965) is missing from this rendering.
// Returns a TypeInfo wrapping the OrtTypeInfo* that SessionGetInputTypeInfo yields for input `index`.
964 template <typename T>
966  OrtTypeInfo* out;
967  ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
968  return TypeInfo{out};
969 }
970 
// NOTE(review): the signature line (listing line 972) is missing from this rendering.
// Returns a TypeInfo wrapping the OrtTypeInfo* that SessionGetOutputTypeInfo yields for output `index`.
971 template <typename T>
973  OrtTypeInfo* out;
974  ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
975  return TypeInfo{out};
976 }
977 
// NOTE(review): the signature line (listing line 979) is missing from this rendering.
// Returns a TypeInfo wrapping the OrtTypeInfo* that SessionGetOverridableInitializerTypeInfo
// yields for overridable initializer `index`.
978 template <typename T>
980  OrtTypeInfo* out;
981  ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
982  return TypeInfo{out};
983 }
984 
985 template <typename T>
986 inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
987  const char* const* output_names, size_t output_count) {
988  std::vector<Value> output_values;
989  output_values.reserve(output_count);
990  for (size_t i = 0; i < output_count; i++)
991  output_values.emplace_back(nullptr);
992  Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
993  return output_values;
994 }
995 
996 template <typename T>
997 inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
998  const char* const* output_names, Value* output_values, size_t output_count) {
999  static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1000  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1001  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1002  ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
1003 }
1004 
1005 template <typename T>
1006 inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
1007  ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
1008 }
1009 
1010 template <typename T>
1011 inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1012  const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
1013  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1014  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1015  ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
1016  ort_input_values, input_count, output_names, output_count,
1017  ort_output_values, callback, user_data));
1018 }
1019 
// NOTE(review): the signature line (listing line 1021) is missing from this rendering.
// Presumably SessionImpl<T>::EndProfilingAllocated(OrtAllocator*); confirm against the header.
// Calls SessionEndProfiling and returns the produced string, owned by an
// AllocatedStringPtr that frees it with `allocator`.
1020 template <typename T>
1022  char* out = nullptr;
1023  ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
1024  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1025 }
1026 
1027 } // namespace detail
1028 
// NOTE(review): the constructor signature line (listing line 1029) is missing from this
// rendering. This is presumably the default SessionOptions constructor.
// It creates the underlying OrtSessionOptions via CreateSessionOptions and stores it in p_.
1030  ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
1031 }
1032 
1033 /// CustomOpConfigs
1034 inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
1035  std::string config_key = "custom_op.";
1036 
1037  config_key += custom_op_name;
1038  config_key += ".";
1039  config_key += config;
1040 
1041  return config_key;
1042 }
1043 
1044 inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
1045  const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
1046  flat_configs_[full_flat_key] = config_value;
1047  return *this;
1048 }
1049 
1050 inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
1051  return flat_configs_;
1052 }
1053 
1054 inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
1055  ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
1056 }
1057 
1058 inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1059  OrtPrepackedWeightsContainer* prepacked_weights_container) {
1060  ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
1061 }
1062 
1063 inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
1064  ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
1065 }
1066 
1067 inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
1068  const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
1069  ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
1070  prepacked_weights_container, &this->p_));
1071 }
1072 
1073 inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
1074  char* out;
1075  ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
1076  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1077 }
1078 
1079 inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
1080  char* out;
1081  ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
1082  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1083 }
1084 
1085 inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
1086  char* out;
1087  ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
1088  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1089 }
1090 
// NOTE(review): the signature line (listing line 1091) is missing from this rendering.
// Presumably ModelMetadata::GetDescriptionAllocated(OrtAllocator*) const; confirm against the header.
// Returns the ModelMetadataGetDescription string, freed with `allocator` by AllocatedStringPtr.
1092  char* out;
1093  ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
1094  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1095 }
1096 
// NOTE(review): the signature line (listing line 1097) is missing from this rendering.
// Presumably ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator*) const; confirm against the header.
// Returns the ModelMetadataGetGraphDescription string, freed with `allocator` by AllocatedStringPtr.
1098  char* out;
1099  ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
1100  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1101 }
1102 
1103 inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
1104  char* out;
1105  ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
1106  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1107 }
1108 
// Returns every key of the model's custom metadata map. Each key string and the
// array holding them come from ModelMetadataGetCustomMetadataMapKeys and are freed with `allocator`.
// Exception safety relies on the exact order below: the array guard is declared before the
// strings guard, so on a throw the strings are freed first, then the array.
// NOTE(review): if the API ever returned a non-null array with num_keys <= 0, the early return
// would not free it; presumably the API returns nullptr in that case — confirm.
1109 inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
1110  auto deletor = detail::AllocatedFree(allocator);
1111  std::vector<AllocatedStringPtr> result;
1112 
1113  char** out = nullptr;
1114  int64_t num_keys = 0;
1115  ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
1116  if (num_keys <= 0) {
1117  return result;
1118  }
1119 
1120  // The array of pointers itself is always freed when this function exits.
1121  std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
// If reserve() throws, this guard frees every individual key string.
1122  // reserve may throw
1123  auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1124  std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1125  result.reserve(static_cast<size_t>(num_keys));
// After reserve succeeds, push_back cannot reallocate. Ownership of each string moves into
// an AllocatedStringPtr, so the strings guard is released.
1126  strings_guard.release();
1127  for (int64_t i = 0; i < num_keys; ++i) {
1128  result.push_back(AllocatedStringPtr(out[i], deletor));
1129  }
1130 
1131  return result;
1132 }
1133 
1134 inline int64_t ModelMetadata::GetVersion() const {
1135  int64_t out;
1136  ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1137  return out;
1138 }
1139 
1140 namespace detail {
1141 
1142 template <typename T>
1143 inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1144  ONNXTensorElementDataType out;
1145  ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1146  return out;
1147 }
1148 
// NOTE(review): the signature line (listing line 1150) is missing from this rendering.
// Returns the element count from GetTensorShapeElementCount. The static_cast is a no-op
// because `out` is already size_t.
1149 template <typename T>
1151  size_t out;
1152  ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1153  return static_cast<size_t>(out);
1154 }
1155 
// NOTE(review): the signature line (listing line 1157) is missing from this rendering.
// Returns the number of dimensions reported by GetDimensionsCount.
1156 template <typename T>
1158  size_t out;
1159  ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1160  return out;
1161 }
1162 
1163 template <typename T>
1164 inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1165  ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1166 }
1167 
1168 template <typename T>
1169 inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1170  ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1171 }
1172 
1173 template <typename T>
1174 inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1175  std::vector<int64_t> out(GetDimensionsCount(), 0);
1176  ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1177  return out;
1178 }
1179 
// NOTE(review): the signature line (listing line 1181) is missing from this rendering.
// Casts this type info to tensor info (CastTypeInfoToTensorInfo) and returns a
// ConstTensorTypeAndShapeInfo wrapping the const pointer.
1180 template <typename T>
1182  const OrtTensorTypeAndShapeInfo* out;
1183  ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1184  return ConstTensorTypeAndShapeInfo{out};
1185 }
1186 
// NOTE(review): the signature line (listing line 1188) is missing from this rendering.
// Casts this type info to sequence type info (CastTypeInfoToSequenceTypeInfo) and wraps it
// in a ConstSequenceTypeInfo.
1187 template <typename T>
1189  const OrtSequenceTypeInfo* out;
1190  ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1191  return ConstSequenceTypeInfo{out};
1192 }
1193 
// NOTE(review): the signature line (listing line 1195) is missing from this rendering.
// Casts this type info to map type info (CastTypeInfoToMapTypeInfo) and wraps it in a ConstMapTypeInfo.
1194 template <typename T>
1196  const OrtMapTypeInfo* out;
1197  ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1198  return ConstMapTypeInfo{out};
1199 }
1200 
1201 template <typename T>
1202 inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1203  ONNXType out;
1204  ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1205  return out;
1206 }
1207 
// NOTE(review): the signature line (listing line 1209) is missing from this rendering.
// Returns an owning TypeInfo for the element type reported by GetSequenceElementType.
1208 template <typename T>
1210  OrtTypeInfo* output;
1211  ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1212  return TypeInfo{output};
1213 }
1214 
// NOTE(review): the signature line (listing line 1216) is missing from this rendering.
// Returns a TypeInfo for the contained type reported by GetOptionalContainedTypeInfo.
1215 template <typename T>
1217  OrtTypeInfo* info;
1218  ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
1219  return TypeInfo{info};
1220 }
1221 
1222 template <typename T>
1223 inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1224  ONNXTensorElementDataType out;
1225  ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1226  return out;
1227 }
1228 
// NOTE(review): the signature line (listing line 1230) is missing from this rendering.
// Returns a TypeInfo for the map value type reported by GetMapValueType.
1229 template <typename T>
1231  OrtTypeInfo* output;
1232  ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1233  return TypeInfo{output};
1234 }
1235 
// NOTE(review): the signature line (listing line 1237) is missing from this rendering.
// Casts this type info to optional type info (CastTypeInfoToOptionalTypeInfo) and wraps it
// in a ConstOptionalTypeInfo.
1236 template <typename T>
1238  const OrtOptionalTypeInfo* info;
1239  ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
1240  return ConstOptionalTypeInfo{info};
1241 }
1242 
1243 } // namespace detail
1244 
1245 namespace detail {
1246 
1247 template <typename T>
1248 template <typename R>
1249 inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1250  ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1251 }
1252 
1253 template <typename T>
1254 inline bool ConstValueImpl<T>::IsTensor() const {
1255  int out;
1256  ThrowOnError(GetApi().IsTensor(this->p_, &out));
1257  return out != 0;
1258 }
1259 
1260 template <typename T>
1261 inline bool ConstValueImpl<T>::HasValue() const {
1262  int out;
1263  ThrowOnError(GetApi().HasValue(this->p_, &out));
1264  return out != 0;
1265 }
1266 
1267 template <typename T>
1268 inline size_t ConstValueImpl<T>::GetCount() const {
1269  size_t out;
1270  ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1271  return out;
1272 }
1273 
1274 template <typename T>
1275 inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1276  OrtValue* out;
1277  ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1278  return Value{out};
1279 }
1280 
// NOTE(review): the signature line (listing line 1282) is missing from this rendering.
// Returns the length reported by GetStringTensorDataLength.
1281 template <typename T>
1283  size_t out;
1284  ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1285  return out;
1286 }
1287 
1288 template <typename T>
1289 inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1290  size_t out;
1291  ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1292  return out;
1293 }
1294 
1295 template <typename T>
1296 template <typename R>
1297 inline const R* ConstValueImpl<T>::GetTensorData() const {
1298  R* out;
1299  ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1300  return out;
1301 }
1302 
1303 template <typename T>
1304 inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1305  void* out;
1306  ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1307  return out;
1308 }
1309 
// NOTE(review): the signature line (listing line 1311) is missing from this rendering.
// Returns an owning TypeInfo for this value (C API GetTypeInfo).
1310 template <typename T>
1312  OrtTypeInfo* output;
1313  ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1314  return TypeInfo{output};
1315 }
1316 
// NOTE(review): the signature line (listing line 1318) is missing from this rendering.
// Returns an owning TensorTypeAndShapeInfo from GetTensorTypeAndShape.
1317 template <typename T>
1319  OrtTensorTypeAndShapeInfo* output;
1320  ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1321  return TensorTypeAndShapeInfo{output};
1322 }
1323 
// NOTE(review): the signature line (listing line 1325) is missing from this rendering.
// Returns a non-owning ConstMemoryInfo wrapping the pointer from GetTensorMemoryInfo.
1324 template <typename T>
1326  const OrtMemoryInfo* mem_info;
1327  ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1328  return ConstMemoryInfo(mem_info);
1329 }
1330 
1331 template <typename T>
1332 inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1333  ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1334 }
1335 
1336 template <typename T>
1337 inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
1338  size_t buffer_length;
1339  ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
1340 
1341  std::string s;
1342  s.resize(buffer_length);
1343  ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
1344  return s;
1345 }
1346 
1347 template <typename T>
1348 inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1349  ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1350 }
1351 
1352 #if !defined(DISABLE_SPARSE_TENSORS)
1353 template <typename T>
1354 inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1355  OrtSparseFormat format;
1356  ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1357  return format;
1358 }
1359 
// NOTE(review): the signature line (listing line 1361) is missing from this rendering.
// Returns an owning TensorTypeAndShapeInfo for the sparse values (GetSparseTensorValuesTypeAndShape).
1360 template <typename T>
1362  OrtTensorTypeAndShapeInfo* output;
1363  ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1364  return TensorTypeAndShapeInfo{output};
1365 }
1366 
1367 template <typename T>
1368 inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1369  OrtTensorTypeAndShapeInfo* output;
1370  ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1371  return TensorTypeAndShapeInfo{output};
1372 }
1373 
1374 template <typename T>
1375 template <typename R>
1376 inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1377  const void* out;
1378  ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1379  return reinterpret_cast<const R*>(out);
1380 }
1381 
// NOTE(review): the signature line (listing line 1383) is missing from this rendering.
// True when the C API IsSparseTensor reports a non-zero flag.
1382 template <typename T>
1384  int out;
1385  ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1386  return out != 0;
1387 }
1388 
// NOTE(review): the signature line (listing line 1391) is missing from this rendering.
// Returns the sparse values buffer from GetSparseTensorValues, reinterpreted as const R*.
1389 template <typename T>
1390 template <typename R>
1392  const void* out;
1393  ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1394  return reinterpret_cast<const R*>(out);
1395 }
1396 
1397 #endif
1398 
1399 template <typename T>
1400 void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1401  ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1402 }
1403 
1404 template <typename T>
1405 void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1406  ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1407 }
1408 
1409 template <typename T>
1410 inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
1411  char* result;
1412  ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
1413  return result;
1414 }
1415 
// NOTE(review): the signature line (listing line 1417) is missing from this rendering.
// Returns the untyped mutable tensor buffer from GetTensorMutableData.
1416 template <typename T>
1418  void* out;
1419  ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1420  return out;
1421 }
1422 
// NOTE(review): the signature line (listing line 1425) is missing from this rendering.
// Returns the mutable tensor buffer as R*. The C-style (void**) cast writes the void*
// result straight into an R* variable.
1423 template <typename T>
1424 template <typename R>
1426  R* out;
1427  ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1428  return out;
1429 }
1430 
1431 template <typename T>
1432 template <typename R>
1433 R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1434  static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1435  R* out;
1436  ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1437  return *out;
1438 }
1439 
1440 #if !defined(DISABLE_SPARSE_TENSORS)
1441 template <typename T>
1442 void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1443  ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1444 }
1445 
1446 template <typename T>
1447 void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1448  ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1449 }
1450 
1451 template <typename T>
1452 void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1453  ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1454 }
1455 
// NOTE(review): the first signature line (listing line 1457) is missing from this rendering.
// It presumably declares ValueImpl<T>::FillSparseTensorCoo(mem_info, values_param, ...).
// Forwards the values shape/data and COO indices to FillSparseTensorCoo.
1456 template <typename T>
1458  const int64_t* indices_data, size_t indices_num) {
1459  ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1460  values_param.values_shape_len, values_param.data.p_data,
1461  indices_data, indices_num));
1462 }
1463 
// NOTE(review): the first signature lines (listing lines 1465-1466) are missing from this
// rendering; they presumably declare data_mem_info and `values`.
// Forwards values plus CSR inner/outer indices to FillSparseTensorCsr.
1464 template <typename T>
1467  const int64_t* inner_indices_data, size_t inner_indices_num,
1468  const int64_t* outer_indices_data, size_t outer_indices_num) {
1469  ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1470  inner_indices_data, inner_indices_num,
1471  outer_indices_data, outer_indices_num));
1472 }
1473 
// NOTE(review): the first signature lines (listing lines 1475-1476) are missing from this
// rendering; they presumably declare data_mem_info and `values`.
// Forwards values plus the block-sparse indices shape/data to FillSparseTensorBlockSparse.
1474 template <typename T>
1477  const Shape& indices_shape,
1478  const int32_t* indices_data) {
1479  ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1480  indices_shape.shape, indices_shape.shape_len,
1481  indices_data));
1482 }
1483 
1484 #endif // !defined(DISABLE_SPARSE_TENSORS)
1485 
1486 } // namespace detail
1487 
1488 template <typename T>
1489 inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1490  return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1491 }
1492 
1493 inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1494  ONNXTensorElementDataType type) {
1495  OrtValue* out;
1496  ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1497  return Value{out};
1498 }
1499 
1500 template <typename T>
1501 inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1502  return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1503 }
1504 
1505 inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1506  OrtValue* out;
1507  ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1508  return Value{out};
1509 }
1510 
1511 #if !defined(DISABLE_SPARSE_TENSORS)
1512 
1513 template <typename T>
1514 inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1515  const Shape& values_shape) {
1516  return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1517 }
1518 
1519 inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1520  const Shape& values_shape, ONNXTensorElementDataType type) {
1521  OrtValue* out;
1522  ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1523  values_shape.shape, values_shape.shape_len, type, &out));
1524  return Value{out};
1525 }
1526 
1527 template <typename T>
1528 inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1529  return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1530 }
1531 
1532 inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1533  ONNXTensorElementDataType type) {
1534  OrtValue* out;
1535  ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1536  return Value{out};
1537 }
1538 #endif // !defined(DISABLE_SPARSE_TENSORS)
1539 
1540 inline Value Value::CreateMap(const Value& keys, const Value& values) {
1541  OrtValue* out;
1542  const OrtValue* inputs[2] = {keys, values};
1543  ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1544  return Value{out};
1545 }
1546 
1547 inline Value Value::CreateSequence(const std::vector<Value>& values) {
1548  OrtValue* out;
1549  std::vector<const OrtValue*> values_ort{values.data(), values.data() + values.size()};
1550  ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1551  return Value{out};
1552 }
1553 
1554 template <typename T>
1555 inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1556  OrtValue* out;
1557  ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1558  return Value{out};
1559 }
1560 
1561 //
1562 // Custom OP Inlines
1563 //
1564 inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
1565  Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
1566 }
1567 
1568 inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
1569  return cached_severity_level_;
1570 }
1571 
1572 inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1573  const char* func_name, const char* message) const noexcept {
1574  OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
1575  func_name);
1576  return Status{status};
1577 }
1578 
1579 // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
1580 // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
1581 // __attribute__(format(printf...)), which does not work with variadic templates.
1582 #if defined(__GNUC__)
1583 #pragma GCC diagnostic push
1584 #pragma GCC diagnostic ignored "-Wformat-nonliteral"
1585 #pragma GCC diagnostic ignored "-Wformat-security"
1586 #elif defined(__clang__)
1587 #pragma clang diagnostic push
1588 #pragma clang diagnostic ignored "-Wformat-nonliteral"
1589 #pragma clang diagnostic ignored "-Wformat-security"
1590 #endif
1591 template <typename... Args>
1592 inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
1593  int line_number, const char* func_name, const char* format,
1594  Args&&... args) const noexcept {
1595  int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
1596 
1597  if (msg_len < 0) { // Formatting error
1598  return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
1599  }
1600 
1601  OrtStatus* status = nullptr;
1602  const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
1603 
1604  constexpr size_t kStackBufferSize = 1024;
1605 
1606  if (buffer_size < kStackBufferSize) {
1607  char buffer[kStackBufferSize];
1608  snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
1609  status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
1610  } else {
1611  // std::make_unique is only supported starting at C++14.
1612 #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
1613  auto buffer = std::make_unique<char[]>(buffer_size);
1614 #else
1615  std::unique_ptr<char[]> buffer(new char[buffer_size]);
1616 #endif
1617  std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
1618  status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
1619  }
1620 
1621  return Status{status};
1622 }
1623 // Re-enable -Wformat-nonliteral and -Wformat-security
1624 #if defined(__GNUC__)
1625 #pragma GCC diagnostic pop
1626 #elif defined(__clang__)
1627 #pragma clang diagnostic pop
1628 #endif
1629 
1630 inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1631 }
1632 
1633 inline size_t KernelContext::GetInputCount() const {
1634  size_t out = 0;
1635  Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1636  return out;
1637 }
1638 
1639 inline size_t KernelContext::GetOutputCount() const {
1640  size_t out = 0;
1641  Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1642  return out;
1643 }
1644 
// NOTE(review): the signature line (listing line 1645) is missing from this rendering.
// Presumably KernelContext::GetInput(size_t index) const; confirm against the header.
// Returns a non-owning ConstValue for input `index` (KernelContext_GetInput).
1646  const OrtValue* out = nullptr;
1647  Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1648  return ConstValue{out};
1649 }
1650 
1651 inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1652  OrtValue* out = nullptr;
1653  Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1654  return UnownedValue(out);
1655 }
1656 
1657 inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1658  OrtValue* out = nullptr;
1659  Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1660  return UnownedValue(out);
1661 }
1662 
// NOTE(review): the signature line (listing line 1663) is missing from this rendering.
// Presumably KernelContext::GetGPUComputeStream() const.
// Returns the opaque stream pointer from KernelContext_GetGPUComputeStream.
1664  void* out = nullptr;
1665  Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1666  return out;
1667 }
1668 
1669 inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
1670  OrtAllocator* out = nullptr;
1671  Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
1672  return out;
1673 }
1674 
// NOTE(review): the signature line (listing line 1675) is missing from this rendering.
// Presumably KernelContext::GetLogger() const.
// Wraps the OrtLogger* from KernelContext_GetLogger in a Logger.
1676  const OrtLogger* out = nullptr;
1677  ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
1678  return Logger{out};
1679 }
1680 
1681 inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const {
1682  ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
1683 }
1684 
1685 inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1686  Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1687 }
1688 
1689 namespace detail {
// NOTE(review): the signature line (listing line 1691) is missing from this rendering.
// Presumably KernelInfoImpl<T>::Copy() const.
// Returns an owning KernelInfo wrapping the copy made by CopyKernelInfo.
1690 template <typename T>
1692  OrtKernelInfo* info_copy = nullptr;
1693  Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1694  return KernelInfo{info_copy};
1695 }
1696 
1697 template <typename T>
1698 inline size_t KernelInfoImpl<T>::GetInputCount() const {
1699  size_t out = 0;
1700  ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1701  return out;
1702 }
1703 
1704 template <typename T>
1705 inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1706  size_t out = 0;
1707  ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1708  return out;
1709 }
1710 
// NOTE(review): the signature line (listing line 1712) is missing from this rendering.
// Presumably KernelInfoImpl<T>::GetInputName(size_t index) const.
// Two-call pattern: query the size (which includes the '\0'), fill the string, then drop the '\0'.
// NOTE(review): if the API ever reported size == 0, `size - 1` would wrap to SIZE_MAX.
1711 template <typename T>
1713  size_t size = 0;
1714 
1715  // Feed nullptr for the data buffer to query the true size of the string value
1716  Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1717 
1718  std::string out;
1719  out.resize(size);
1720  Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1721  out.resize(size - 1); // remove the terminating character '\0'
1722 
1723  return out;
1724 }
1725 
// NOTE(review): the signature line (listing line 1727) is missing from this rendering.
// Presumably KernelInfoImpl<T>::GetOutputName(size_t index) const.
// Two-call pattern: query the size (which includes the '\0'), fill the string, then drop the '\0'.
// NOTE(review): if the API ever reported size == 0, `size - 1` would wrap to SIZE_MAX.
1726 template <typename T>
1728  size_t size = 0;
1729 
1730  // Feed nullptr for the data buffer to query the true size of the string value
1731  Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1732 
1733  std::string out;
1734  out.resize(size);
1735  Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1736  out.resize(size - 1); // remove the terminating character '\0'
1737 
1738  return out;
1739 }
1740 
// NOTE(review): the signature line (listing line 1742) is missing from this rendering.
// Returns an owning TypeInfo for input `index` (KernelInfo_GetInputTypeInfo).
1741 template <typename T>
1743  OrtTypeInfo* out = nullptr;
1744  ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1745  return TypeInfo{out};
1746 }
1747 
// NOTE(review): the signature line (listing line 1749) is missing from this rendering.
// Returns an owning TypeInfo for output `index` (KernelInfo_GetOutputTypeInfo).
1748 template <typename T>
1750  OrtTypeInfo* out = nullptr;
1751  ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1752  return TypeInfo{out};
1753 }
1754 
1755 template <typename T>
1756 inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1757  OrtValue* out = nullptr;
1758  ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1759  return Value{out};
1760 }
1761 
1762 template <typename T>
1763 inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
1764  const OrtValue* out = nullptr;
1765  ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
1766  return ConstValue{out};
1767 }
1768 
// NOTE(review): the signature line (listing line 1770) is missing from this rendering.
// Presumably KernelInfoImpl<T>::GetNodeName() const.
// Two-call pattern: query the size (which includes the '\0'), fill the string, then drop the '\0'.
// NOTE(review): if the API ever reported size == 0, `size - 1` would wrap to SIZE_MAX.
1769 template <typename T>
1771  size_t size = 0;
1772 
1773  // Feed nullptr for the data buffer to query the true size of the string value
1774  Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
1775 
1776  std::string out;
1777  out.resize(size);
1778  Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
1779  out.resize(size - 1); // remove the terminating character '\0'
1780 
1781  return out;
1782 }
1783 
// NOTE(review): the signature line (listing line 1785) is missing from this rendering.
// Wraps the OrtLogger* from KernelInfo_GetLogger in a Logger.
1784 template <typename T>
1786  const OrtLogger* out = nullptr;
1787  ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
1788  return Logger{out};
1789 }
1790 
1791 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1792  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1793 }
1794 
1795 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1796  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1797 }
1798 
1799 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1800  size_t size = 0;
1801  // Feed nullptr for the data buffer to query the true size of the string attribute
1802  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1803 
1804  std::string out;
1805  out.resize(size);
1806  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1807  out.resize(size - 1); // remove the terminating character '\0'
1808  out.swap(result);
1809 }
1810 
1811 inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1812  size_t size = 0;
1813  // Feed nullptr for the data buffer to query the true size of the attribute
1814  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1815 
1816  std::vector<float> out;
1817  out.resize(size);
1818  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1819  out.swap(result);
1820 }
1821 
1822 inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1823  size_t size = 0;
1824 
1825  // Feed nullptr for the data buffer to query the true size of the attribute
1826  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1827 
1828  std::vector<int64_t> out;
1829  out.resize(size);
1830  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1831  out.swap(result);
1832 }
1833 } // namespace detail
1834 
1835 inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1836 
1837 inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1838 
1839 inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1840  const char** type_constraint_names,
1841  const ONNXTensorElementDataType* type_constraint_values,
1842  size_t type_constraint_count,
1843  const OpAttr* attr_values, size_t attr_count,
1844  size_t input_count, size_t output_count) {
1845  static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1846  "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1847  auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1848  OrtOp* op;
1849  Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1850  static_cast<int>(type_constraint_count),
1851  attr_input_values,
1852  static_cast<int>(attr_count),
1853  static_cast<int>(input_count),
1854  static_cast<int>(output_count), &op));
1855  return Op{op};
1856 }
1857 
1858 inline void Op::Invoke(const OrtKernelContext* context,
1859  const Value* input_values,
1860  size_t input_count,
1861  Value* output_values,
1862  size_t output_count) {
1863  static_assert(sizeof(Value) == sizeof(OrtValue*),
1864  "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1865  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1866  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1867  Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1868  ort_output_values, static_cast<int>(output_count)));
1869 }
1870 
1871 inline void Op::Invoke(const OrtKernelContext* context,
1872  const OrtValue* const* input_values,
1873  size_t input_count,
1874  OrtValue* const* output_values,
1875  size_t output_count) {
1876  Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1877  output_values, static_cast<int>(output_count)));
1878 }
1879 
1880 inline std::string GetVersionString() {
1881  return OrtGetApiBase()->GetVersionString();
1882 }
1883 
1885  return GetApi().GetBuildInfoString();
1886 }
1887 
1888 inline std::vector<std::string> GetAvailableProviders() {
1889  char** providers;
1890  int len;
1891 
1892  auto release_fn = [&len](char** providers) {
1893  // This should always return nullptr.
1894  ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1895  };
1896 
1897  ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1898  std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
1899  std::vector<std::string> available_providers;
1900  available_providers.reserve(static_cast<size_t>(len));
1901  for (int i = 0; i < len; ++i) {
1902  available_providers.emplace_back(providers[i]);
1903  }
1904  return available_providers;
1905 }
1906 
1907 template <typename TOp, typename TKernel, bool WithStatus>
1908 void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1909  ConstSessionOptions options) const {
1910  const TOp* derived = static_cast<const TOp*>(this);
1911  std::vector<std::string> keys = derived->GetSessionConfigKeys();
1912 
1913  out.reserve(keys.size());
1914 
1915  std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1916  const size_t prefix_size = config_entry_key.length();
1917 
1918  for (const auto& key : keys) {
1919  config_entry_key.resize(prefix_size);
1920  config_entry_key.append(key);
1921  out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1922  }
1923 }
1924 
1925 inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api,
1926  OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) {
1927  size_t input_count = 0;
1928  Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count));
1929  for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
1930  OrtTensorTypeAndShapeInfo* info{};
1931  Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info));
1932  TensorTypeAndShapeInfo type_shape_info(info);
1933  auto integer_shape = type_shape_info.GetShape();
1934  std::vector<const char*> symbolic_shape(integer_shape.size(), {});
1935  type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size());
1936  Shape shape;
1937  for (size_t ith = 0; ith < integer_shape.size(); ++ith) {
1938  if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) {
1939  shape.emplace_back(symbolic_shape[ith]);
1940  } else {
1941  shape.emplace_back(integer_shape[ith]);
1942  }
1943  }
1944  input_shapes_.push_back(std::move(shape));
1945  type_shape_info.release();
1946  }
1947 }
1948 
1949 inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) {
1950  OrtTensorTypeAndShapeInfo* info = {};
1951  RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info));
1952 
1953  using InfoPtr = std::unique_ptr<OrtTensorTypeAndShapeInfo, std::function<void(OrtTensorTypeAndShapeInfo*)>>;
1954 
1955  InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) {
1956  ort_api_->ReleaseTensorTypeAndShapeInfo(obj);
1957  });
1958 
1959  std::vector<int64_t> integer_dims;
1960  std::vector<const char*> symbolic_dims;
1961 
1962  for (const auto dim : shape) {
1963  if (dim.IsInt()) {
1964  integer_dims.push_back(dim.IsInt());
1965  symbolic_dims.push_back("");
1966  } else {
1967  if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) {
1968  ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT);
1969  }
1970  integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM);
1971  symbolic_dims.push_back(dim.AsSym());
1972  }
1973  }
1974 
1975  RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size()));
1976  RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size()));
1977  RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info));
1978  return Status{nullptr};
1979 }
1980 
1981 inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) {
1982  const auto* attr = GetAttrHdl(attr_name);
1983  int64_t i = {};
1984  size_t out = {};
1985  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out));
1986  return i;
1987 }
1988 
1989 inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) {
1990  const auto* attr = GetAttrHdl(attr_name);
1991  int64_t i = {};
1992  size_t out = {};
1993  // first call to get the bytes needed
1994  auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out);
1995  if (status) {
1996  size_t num_i = out / sizeof(int64_t);
1997  ShapeInferContext::Ints ints(num_i, 0);
1998  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out));
1999  return ints;
2000  } else {
2001  return {i};
2002  }
2003 }
2004 
2005 inline float ShapeInferContext::GetAttrFloat(const char* attr_name) {
2006  const auto* attr = GetAttrHdl(attr_name);
2007  float f = {};
2008  size_t out = {};
2009  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out));
2010  return f;
2011 }
2012 
2013 inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) {
2014  const auto* attr = GetAttrHdl(attr_name);
2015  float f = {};
2016  size_t out = {};
2017  // first call to get the bytes needed
2018  auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out);
2019  if (status) {
2020  size_t num_f = out / sizeof(float);
2021  ShapeInferContext::Floats floats(num_f, 0);
2022  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out));
2023  return floats;
2024  } else {
2025  return {f};
2026  }
2027 }
2028 
2029 inline std::string ShapeInferContext::GetAttrString(const char* attr_name) {
2030  const auto* attr = GetAttrHdl(attr_name);
2031  char c = {};
2032  size_t out = {};
2033  // first call to get the bytes needed
2034  auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out);
2035  if (status) {
2036  std::vector<char> chars(out, '\0');
2037  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out));
2038  return {chars.data()};
2039  } else {
2040  return {c};
2041  }
2042 }
2043 
2044 inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) {
2045  const auto* attr = GetAttrHdl(attr_name);
2046  char c = {};
2047  size_t out = {};
2048  // first call to get the bytes needed
2049  auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out);
2050  if (status) {
2051  std::vector<char> chars(out, '\0');
2052  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out));
2054  char* char_st = chars.data();
2055  char* char_ed = char_st + out;
2056  while (char_st < char_ed) {
2057  strings.emplace_back(char_st);
2058  while (*char_st != '\0') {
2059  char_st++;
2060  }
2061  char_st++;
2062  }
2063  return strings;
2064  } else {
2065  return {std::string{c}};
2066  }
2067 }
2068 
2069 inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const {
2070  const OrtOpAttr* attr_hdl = {};
2071  Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl));
2072  return attr_hdl;
2073 }
2074 
2075 } // namespace Ort
OrtMemoryInfoDeviceType GetDeviceType() const
Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T *file_path, int line_number, const char *func_name, const char *format, Args &&...args) const noexcept
std::vector< int64_t > Ints
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
std::string GetBuildInfoString()
This function returns the onnxruntime build information: including git branch, git commit id...
SessionOptionsImpl & SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn.
GLuint GLsizei const GLchar * message
Definition: glcorearb.h:2543
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
MemoryAllocation & operator=(const MemoryAllocation &)=delete
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
union Ort::detail::OrtSparseValuesParam::@164 data
OrtAllocator * GetAllocator(const OrtMemoryInfo &memory_info) const
std::string GetErrorMessage() const
png_const_structrp png_const_inforp info_ptr
Definition: png.h:1939
size_t GetInputCount() const
Returns the number of model inputs.
SessionOptionsImpl & DisablePerSessionThreads()
Wraps OrtApi::DisablePerSessionThreads.
std::vector< std::string > Strings
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
size_t GetOutputCount() const
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
SessionOptionsImpl & EnableMemPattern()
Wraps OrtApi::EnableMemPattern.
SessionOptionsImpl & EnableCpuMemArena()
Wraps OrtApi::EnableCpuMemArena.
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
std::vector< float > Floats
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the output name at the specified index.
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to a user defined data for experimental purposes
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
ConstMemoryInfo GetInfo() const
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
SessionOptionsImpl & SetLogId(const char *logid)
Wraps OrtApi::SetSessionLogId.
void swap(UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &a, UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &b)
Definition: UT_ArraySet.h:1631
Value GetValue(int index, OrtAllocator *allocator) const
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
SessionOptionsImpl & AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX.
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
static Value CreateSequence(const std::vector< Value > &values)
Creates an OrtValue with a Sequence Onnx type representation. The API would ref-count the supplied Or...
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
void ThrowOnError(OrtStatus *ort_status)
GLsizei const GLfloat * value
Definition: glcorearb.h:824
void BindInput(const char *name, const Value &)
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
ConstValue GetInput(size_t index) const
SessionOptionsImpl & AppendExecutionProvider_CANN(const OrtCANNProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN.
MemoryAllocation GetAllocation(size_t size)
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
SessionOptionsImpl & DisableCpuMemArena()
Wraps OrtApi::DisableCpuMemArena.
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
bool HasConfigEntry(const char *config_key) const
Wraps OrtApi::HasSessionConfigEntry.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
static Value CreateOpaque(const char *domain, const char *type_name, const T &value)
Creates an OrtValue wrapping an Opaque type. This is used for experimental support of non-tensor type...
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor
std::vector< SymbolicInteger > Shape
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
Wrapper around ::OrtModelMetadata.
GLint level
Definition: glcorearb.h:108
SessionOptionsImpl & AddInitializer(const char *name, const OrtValue *ort_val)
Wraps OrtApi::AddInitializer.
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
std::string GetConfigEntryOrDefault(const char *config_key, const std::string &def)
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
SessionOptionsImpl & AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT.
GLdouble s
Definition: glad.h:3009
SessionOptionsImpl & EnableOrtCustomOps()
Wraps OrtApi::EnableOrtCustomOps.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
R & At(const std::vector< int64_t > &location)
std::string GetAllocatorName() const
void RunAsync(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, Value *output_values, size_t output_count, RunAsyncCallbackFn callback, void *user_data)
Run the model asynchronously in a thread owned by intra op thread pool.
SessionOptionsImpl & AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA.
Wrapper around OrtValue.
SessionOptionsImpl & SetIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetIntraOpNumThreads.
**But if you need a result
Definition: thread.h:613
SessionOptionsImpl & RegisterCustomOpsUsingFunction(const char *function_name)
Wraps OrtApi::RegisterCustomOpsUsingFunction.
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
The Env (Environment)
OrtLoggingLevel GetLoggingSeverityLevel() const noexcept
ConstOptionalTypeInfo GetOptionalTypeInfo() const
wraps OrtApi::CastTypeInfoToOptionalTypeInfo
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Returns a vector of copies of the custom metadata keys.
SessionOptionsImpl & SetOptimizedModelFilePath(const ORTCHAR_T *optimized_model_file)
Wraps OrtApi::SetOptimizedModelFilePath.
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
float8e4m3fnuz (Float8 Floating Point) data type
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values...
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
GLuint buffer
Definition: glcorearb.h:660
SessionOptionsImpl & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::SetSessionGraphOptimizationLevel.
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed, the caller must ensure the type matches the tensor type.
SessionOptionsImpl & SetCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions.
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
std::string GetConfigEntry(const char *config_key) const
Wraps OrtApi::GetSessionConfigEntry.
void ThrowStatus(const Status &st)
std::vector< std::string > GetOutputNames() const
float8e4m3fn (Float8 Floating Point) data type
SessionOptionsImpl & AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2.
ConstValue GetTensorConstantInput(size_t index, int *is_constant) const
GLuint GLsizei const GLuint const GLintptr * offsets
Definition: glcorearb.h:2621
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Used for interop with the C API.
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetOutputTypeInfo(size_t index) const
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
Wrapper around OrtMemoryInfo.
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
GLfloat f
Definition: glcorearb.h:1926
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
bool operator<(const BFloat16_t &rhs) const noexcept
SessionOptions Clone() const
Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions.
Definition: core.h:760
Wrapper around ::OrtIoBinding.
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< float > &)
IMATH_NAMESPACE::V2f float
char * GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length)
Allocate if necessary and obtain a pointer to a UTF-8 encoded string element buffer indexed by the fl...
SessionOptionsImpl & SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn.
The Status that holds ownership of OrtStatus received from C API Use it to safely destroy OrtStatus* ...
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
A generic, discriminated value, whose type may be queried dynamically.
Definition: Value.h:44
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
void GetAttr(const OrtKernelInfo *p, const char *name, float &)
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
SessionOptionsImpl & AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO.
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
SessionOptionsImpl & AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2.
constexpr std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size(){return subtype_count< typename std::tuple_element< I, T >::type >::value+tuple_type_size< T, I+1 >);}template< typename T > struct type_count< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size< T, 0 >)};};template< typename T > struct subtype_count{static constexpr int value{is_mutable_container< T >::value?expected_max_vector_size:type_count< T >::value};};template< typename T, typename Enable=void > struct type_count_min{static const int value{0};};template< typename T >struct type_count_min< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_tuple_like< T >::value &&!is_wrapper< T >::value &&!is_complex< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{type_count< T >::value};};template< typename T > struct type_count_min< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr int value{1};};template< typename T >struct type_count_min< T, typename std::enable_if< is_wrapper< T >::value &&!is_complex< T >::value &&!is_tuple_like< T >::value >::type >{static constexpr int value{subtype_count_min< typename T::value_type >::value};};template< typename T, std::size_t I >constexpr typename std::enable_if< I==type_count_base< T >::value, int >::type tuple_type_size_min(){return 0;}template< typename T, std::size_t I > constexpr typename std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size_min(){return subtype_count_min< typename std::tuple_element< I, T >::type >::value+tuple_type_size_min< T, I+1 >);}template< typename T > struct type_count_min< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size_min< T, 0 >)};};template< typename T > struct subtype_count_min{static constexpr int value{is_mutable_container< T >::value?((type_count< T >::value< expected_max_vector_size)?type_count< T 
>::value:0):type_count_min< T >::value};};template< typename T, typename Enable=void > struct expected_count{static const int value{0};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_wrapper< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{1};};template< typename T > struct expected_count< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr int value{expected_max_vector_size};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&is_wrapper< T >::value >::type >{static constexpr int value{expected_count< typename T::value_type >::value};};enum class object_category:int{char_value=1, integral_value=2, unsigned_integral=4, enumeration=6, boolean_value=8, floating_point=10, number_constructible=12, double_constructible=14, integer_constructible=16, string_assignable=23, string_constructible=24, other=45, wrapper_value=50, complex_number=60, tuple_value=70, container_value=80,};template< typename T, typename Enable=void > struct classify_object{static constexpr object_category value{object_category::other};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&!std::is_same< T, char >::value &&std::is_signed< T >::value &&!is_bool< T >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::integral_value};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&std::is_unsigned< T >::value &&!std::is_same< T, char >::value &&!is_bool< T >::value >::type >{static constexpr object_category value{object_category::unsigned_integral};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_same< T, char >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category 
value{object_category::char_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_bool< T >::value >::type >{static constexpr object_category value{object_category::boolean_value};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_floating_point< T >::value >::type >{static constexpr object_category value{object_category::floating_point};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&std::is_assignable< T &, std::string >::value >::type >{static constexpr object_category value{object_category::string_assignable};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&(type_count< T >::value==1)&&std::is_constructible< T, std::string >::value >::type >{static constexpr object_category value{object_category::string_constructible};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::enumeration};};template< typename T > struct classify_object< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr object_category value{object_category::complex_number};};template< typename T > struct uncommon_type{using type=typename std::conditional<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&!std::is_constructible< T, std::string >::value &&!is_complex< T >::value &&!is_mutable_container< T >::value &&!std::is_enum< T >::value, std::true_type, std::false_type >::type;static constexpr bool value=type::value;};template< typename T >struct classify_object< T, typename std::enable_if<(!is_mutable_container< T >::value &&is_wrapper< T >::value &&!is_tuple_like< T >::value 
&&uncommon_type< T >::value)>::type >{static constexpr object_category value{object_category::wrapper_value};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::number_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&!is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::integer_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::double_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< is_tuple_like< T >::value &&((type_count< T >::value >=2 &&!is_wrapper< T >::value)||(uncommon_type< T >::value &&!is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value)||(uncommon_type< T >::value &&type_count< T >::value >=2))>::type >{static constexpr object_category value{object_category::tuple_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr object_category value{object_category::container_value};};template< typename T, enable_if_t< classify_object< T >::value==object_category::char_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"CHAR";}template< typename T, enable_if_t< classify_object< T 
>::value==object_category::integral_value||classify_object< T >::value==object_category::integer_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"INT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::unsigned_integral, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"UINT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::floating_point||classify_object< T >::value==object_category::number_constructible||classify_object< T >::value==object_category::double_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"FLOAT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::enumeration, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"ENUM";}template< typename T, enable_if_t< classify_object< T >::value==object_category::boolean_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"BOOLEAN";}template< typename T, enable_if_t< classify_object< T >::value==object_category::complex_number, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"COMPLEX";}template< typename T, enable_if_t< classify_object< T >::value >=object_category::string_assignable &&classify_object< T >::value<=object_category::other, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"TEXT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::container_value||classify_object< T >::value==object_category::wrapper_value, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value==1, 
detail::enabler >=detail::dummy >inline std::string type_name(){return type_name< typename std::decay< typename std::tuple_element< 0, T >::type >::type >);}template< typename T, std::size_t I >inline typename std::enable_if< I==type_count_base< T >::value, std::string >::type tuple_name(){return std::string{};}template< typename T, std::size_t I >inline typename std::enable_if<(I< type_count_base< T >::value), std::string >::type tuple_name(){auto str=std::string{type_name< typename std::decay< typename std::tuple_element< I, T >::type >::type >)}+ ','+tuple_name< T, I+1 >);if(str.back()== ',') str.pop_back();return str;}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler > > std::string type_name()
Recursively generate the tuple type name.
Definition: CLI11.h:1729
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
GLint GLint GLsizei GLint GLenum format
Definition: glcorearb.h:108
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
Value GetTensorAttribute(const char *name, OrtAllocator *allocator) const
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime...
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
IEEE 754 half-precision floating point data type.
GLint location
Definition: glcorearb.h:805
GLuint id
Definition: glcorearb.h:655
static Value CreateMap(const Value &keys, const Value &values)
Creates an OrtValue with a Map Onnx type representation. The API would ref-count the supplied OrtValu...
SessionOptionsImpl & SetInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetInterOpNumThreads.
float8e5m2fnuz (Float8 Floating Point) data type
Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V...
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer No type checking is performe...
SessionOptionsImpl & AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions &provider_options)
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
void * GetGPUComputeStream() const
void BindOutput(const char *name, const Value &)
GLuint const GLchar * name
Definition: glcorearb.h:786
RunOptions & SetRunTag(const char *run_tag)
wraps OrtApi::RunOptionsSetRunTag
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
SessionOptionsImpl & EnableProfiling(const ORTCHAR_T *profile_file_prefix)
Wraps OrtApi::EnableProfiling.
size_t GetStringTensorDataLength() const
This API returns a full length of string data contained within either a tensor or a sparse Tensor...
detail::ConstSessionOptionsImpl< detail::Unowned< const OrtSessionOptions >> ConstSessionOptions
GLsizei const GLchar *const * strings
Definition: glcorearb.h:1933
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance...
OrtAllocatorType GetAllocatorType() const
SessionOptionsImpl & AppendExecutionProvider_OpenVINO_V2(const std::unordered_map< std::string, std::string > &provider_options={})
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
size_t GetStringTensorElementLength(size_t element_index) const
The API returns a byte length of UTF-8 encoded string element contained in either a tensor or a spare...
detail::ValueImpl< detail::Unowned< OrtValue >> UnownedValue
SessionOptionsImpl & Add(OrtCustomOpDomain *custom_op_domain)
Wraps OrtApi::AddCustomOpDomain.
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor...
GT_API const UT_StringHolder version
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
bool operator==(const MemoryInfoImpl< U > &o) const
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
RunOptions()
Wraps OrtApi::CreateRunOptions.
This struct owns the OrtKernInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
std::vector< Value > GetOutputValues() const
GLsizeiptr size
Definition: glcorearb.h:664
OrtErrorCode GetErrorCode() const
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor
void ParallelFor(void(*fn)(void *, size_t), size_t total, size_t num_batch, void *usr_data) const
KernelContext(OrtKernelContext *context)
This class represents an ONNX Runtime logger that can be used to log information with an associated s...
Env & CreateAndRegisterAllocatorV2(const std::string &provider_type, const OrtMemoryInfo *mem_info, const std::unordered_map< std::string, std::string > &options, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocatorV2.
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
SessionOptionsImpl & SetDeterministicCompute(bool value)
Wraps OrtApi::SetDeterministicCompute.
GT_API const UT_StringHolder st
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
std::string GetOutputName(size_t index) const
The ThreadingOptions.
bool operator==(const BFloat16_t &rhs) const noexcept
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
SessionOptionsImpl & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddSessionConfigEntry.
bool IsOK() const noexcept
Returns true if instance represents an OK (non-error) status.
SessionOptionsImpl & SetExecutionMode(ExecutionMode execution_mode)
Wraps OrtApi::SetSessionExecutionMode.
GLuint index
Definition: glcorearb.h:786
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
ShapeInferContext(const OrtApi *ort_api, OrtShapeInferContext *ctx)
Logger()=default
SessionOptionsImpl & AppendExecutionProvider(const std::string &provider_name, const std::unordered_map< std::string, std::string > &provider_options={})
Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK...
GLuint GLfloat * val
Definition: glcorearb.h:1608
SessionOptionsImpl & DisableProfiling()
Wraps OrtApi::DisableProfiling.
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
float8e5m2 (Float8 Floating Point) data type
**If you just want to fire and args
Definition: thread.h:609
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
std::string GetVersionString()
This function returns the onnxruntime version string
Definition: core.h:1131
size_t GetOutputCount() const
Returns the number of model outputs.
Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T *file_path, int line_number, const char *func_name, const char *message) const noexcept
SessionOptionsImpl & AddExternalInitializers(const std::vector< std::string > &names, const std::vector< Value > &ort_values)
Wraps OrtApi::AddExternalInitializers.
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of input name at the specified index.
Wrapper around ::OrtTensorTypeAndShapeInfo.
MemoryInfo(std::nullptr_t)
No instance is created.
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
CustomOpConfigs.
size_t GetCount() const
< Return true if OrtValue contains data and returns false if the OrtValue is a None ...
std::string GetInputName(size_t index) const
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
Wrapper around ::OrtSessionOptions.
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
detail::MemoryInfoImpl< detail::Unowned< const OrtMemoryInfo >> ConstMemoryInfo
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
type
Definition: core.h:1059
TypeInfo GetInputTypeInfo(size_t index) const
SessionOptionsImpl & SetLogSeverityLevel(int level)
Wraps OrtApi::SetSessionLogSeverityLevel.
Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime.
Wrapper around ::OrtSession.
SessionOptionsImpl & DisableMemPattern()
Wraps OrtApi::DisableMemPattern.
SessionOptionsImpl & AppendExecutionProvider_ROCM(const OrtROCMProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM.
Options for the CUDA provider that are passed to SessionOptionsAppendExecutionProvider_CUDA_V2. Please note that this struct is similar to OrtCUDAProviderOptions but only to be used internally. Going forward, new cuda provider options are to be supported via this struct and usage of the publicly defined OrtCUDAProviderOptions will be deprecated over time. User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
#define RETURN_ON_API_FAIL(expression)
TypeInfo GetOptionalElementType() const
Wraps OrtApi::CastOptionalTypeToContainedTypeInfo.
bfloat16 (Brain Floating Point) data type
Class that represents session configuration entries for one or more custom operators.
Status(std::nullptr_t) noexcept
Create an empty object, must be assigned a valid one to be used.
GLint GLsizei count
Definition: glcorearb.h:405
Definition: format.h:895
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
#define ORT_CXX_API_THROW(string, code)
SessionOptionsImpl & RegisterCustomOpsLibrary(const ORTCHAR_T *library_name, const CustomOpConfigs &custom_op_configs={})
GLsizei GLenum GLenum GLuint GLenum GLsizei * lengths
Definition: glcorearb.h:2542
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.