HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
onnxruntime_cxx_inline.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5 // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6 //
7 // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
8 // the main C++ file with implementation details.
9 
10 namespace Ort {
11 
12 namespace detail {
13 inline void ThrowStatus(const Status& st) {
14  std::string error_message = st.GetErrorMessage();
15  OrtErrorCode error_code = st.GetErrorCode();
16  ORT_CXX_API_THROW(std::move(error_message), error_code);
17 }
18 } // namespace detail
19 
20 inline void ThrowOnError(OrtStatus* ort_status) {
21  if (ort_status) {
22  Ort::Status st(ort_status);
24  }
25 }
26 
27 inline void ThrowOnError(const Status& st) {
28  if (st) {
30  }
31 }
32 
33 inline Status::Status(OrtStatus* status) : Base<OrtStatus>{status} {
34 }
35 
36 inline Status::Status(const std::exception& e) {
37  p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
38 }
39 
40 inline Status::Status(const Exception& e) {
41  p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
42 }
43 
44 inline std::string Status::GetErrorMessage() const {
46  return message;
47 }
48 
49 inline OrtErrorCode Status::GetErrorCode() const {
50  return GetApi().GetErrorCode(p_);
51 }
52 
53 // This template converts a C++ type into it's ONNXTensorElementDataType
54 template <typename T>
55 struct TypeToTensorType;
56 template <>
57 struct TypeToTensorType<float> {
59 };
60 template <>
61 struct TypeToTensorType<Float16_t> {
63 };
64 template <>
65 struct TypeToTensorType<BFloat16_t> {
67 };
68 template <>
69 struct TypeToTensorType<double> {
71 };
72 template <>
73 struct TypeToTensorType<int8_t> {
75 };
76 template <>
77 struct TypeToTensorType<int16_t> {
79 };
80 template <>
81 struct TypeToTensorType<int32_t> {
83 };
84 template <>
85 struct TypeToTensorType<int64_t> {
87 };
88 template <>
89 struct TypeToTensorType<uint8_t> {
91 };
92 template <>
93 struct TypeToTensorType<uint16_t> {
95 };
96 template <>
97 struct TypeToTensorType<uint32_t> {
99 };
100 template <>
101 struct TypeToTensorType<uint64_t> {
103 };
104 template <>
105 struct TypeToTensorType<bool> {
107 };
108 
109 inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
110  : allocator_(allocator), p_(p), size_(size) {
111 }
112 
114  if (p_ != nullptr) {
115  // We do not throw out of destructor
116  auto ret = GetApi().AllocatorFree(allocator_, p_);
117  static_cast<void>(ret);
118  }
119 }
120 
121 inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
122  *this = std::move(o);
123 }
124 
126  OrtAllocator* alloc = nullptr;
127  void* p = nullptr;
128  size_t sz = 0;
129 
130  // Swap out this
131  std::swap(alloc, allocator_);
132  std::swap(p, p_);
133  std::swap(sz, size_);
134 
135  // Swap with incoming
136  std::swap(allocator_, o.allocator_);
137  std::swap(p_, o.p_);
138  std::swap(size_, o.size_);
139 
140  // Destroy this instance if needed
141  MemoryAllocation this_alloc(alloc, p, sz);
142  return *this;
143 }
144 
145 namespace detail {
146 
147 template <typename T>
148 inline void* AllocatorImpl<T>::Alloc(size_t size) {
149  void* out;
150  ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
151  return out;
152 }
153 
154 template <typename T>
156  void* out;
157  ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
158  MemoryAllocation result(this->p_, out, size);
159  return result;
160 }
161 
162 template <typename T>
163 inline void AllocatorImpl<T>::Free(void* p) {
164  ThrowOnError(GetApi().AllocatorFree(this->p_, p));
165 }
166 
167 template <typename T>
169  const OrtMemoryInfo* out;
170  ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
171  return ConstMemoryInfo{out};
172 }
173 
174 } // namespace detail
175 
177  ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
178 }
179 
180 inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
181  ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
182 }
183 
184 namespace detail {
185 
186 template <typename T>
188  const char* name = nullptr;
189  ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
190  return std::string(name);
191 }
192 
193 template <typename T>
196  ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
197  return type;
198 }
199 
200 template <typename T>
201 inline int MemoryInfoImpl<T>::GetDeviceId() const {
202  int id = 0;
203  ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
204  return id;
205 }
206 
207 template <typename T>
210  GetApi().MemoryInfoGetDeviceType(this->p_, &type);
211  return type;
212 }
213 
214 template <typename T>
217  ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
218  return type;
219 }
220 
221 template <typename T>
222 template <typename U>
224  int comp_result = 0;
225  ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
226  return comp_result == 0;
227 }
228 
229 } // namespace detail
230 
232  OrtMemoryInfo* p;
233  ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
234  return MemoryInfo(p);
235 }
236 
237 inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
238  ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
239 }
240 
241 namespace detail {
242 template <typename T>
243 inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
244  AllocatorWithDefaultOptions allocator;
245  return binding_utils::GetOutputNamesHelper(this->p_, allocator);
246 }
247 
248 template <typename T>
249 inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
250  return binding_utils::GetOutputNamesHelper(this->p_, allocator);
251 }
252 
253 template <typename T>
254 inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
255  AllocatorWithDefaultOptions allocator;
256  return binding_utils::GetOutputValuesHelper(this->p_, allocator);
257 }
258 
259 template <typename T>
260 inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
261  return binding_utils::GetOutputValuesHelper(this->p_, allocator);
262 }
263 
264 template <typename T>
265 inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
266  ThrowOnError(GetApi().BindInput(this->p_, name, value));
267 }
268 
269 template <typename T>
270 inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
271  ThrowOnError(GetApi().BindOutput(this->p_, name, value));
272 }
273 
274 template <typename T>
275 inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
276  ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
277 }
278 
279 template <typename T>
281  GetApi().ClearBoundInputs(this->p_);
282 }
283 
284 template <typename T>
286  GetApi().ClearBoundOutputs(this->p_);
287 }
288 
289 template <typename T>
291  ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
292 }
293 
294 template <typename T>
296  ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
297 }
298 
299 namespace binding_utils {
300 inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
301  std::vector<std::string> result;
302  auto free_fn = detail::AllocatedFree(allocator);
303  using Ptr = std::unique_ptr<void, decltype(free_fn)>;
304 
305  char* buffer = nullptr;
306  size_t* lengths = nullptr;
307  size_t count = 0;
308  ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
309 
310  if (count == 0) {
311  return result;
312  }
313 
314  Ptr buffer_g(buffer, free_fn);
315  Ptr lengths_g(lengths, free_fn);
316 
317  result.reserve(count);
318  for (size_t i = 0; i < count; ++i) {
319  auto sz = *lengths;
320  result.emplace_back(buffer, sz);
321  buffer += sz;
322  ++lengths;
323  }
324  return result;
325 }
326 
327 inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
328  std::vector<Value> result;
329  size_t owned = 0;
330  size_t output_count = 0;
331  // Lambda to release the buffer when no longer needed and
332  // make sure that we destroy all instances on exception
333  auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
334  if (buffer) {
335  while (owned < output_count) {
336  auto* p = buffer + owned++;
337  GetApi().ReleaseValue(*p);
338  }
339  allocator->Free(allocator, buffer);
340  }
341  };
342  using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
343 
344  OrtValue** output_buffer = nullptr;
345  ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
346  if (output_count == 0) {
347  return result;
348  }
349 
350  Ptr buffer_g(output_buffer, free_fn);
351 
352  result.reserve(output_count);
353  for (size_t i = 0; i < output_count; ++i) {
354  result.emplace_back(output_buffer[i]);
355  ++owned;
356  }
357  return result;
358 }
359 
360 } // namespace binding_utils
361 } // namespace detail
362 
363 inline IoBinding::IoBinding(Session& session) {
364  ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
365 }
366 
367 inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
368  ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
369 }
370 
372  ThrowOnError(GetApi().CreateThreadingOptions(&p_));
373 }
374 
376  ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
377  return *this;
378 }
379 
381  ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
382  return *this;
383 }
384 
386  ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
387  return *this;
388 }
389 
392  return *this;
393 }
394 
396  ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
397  return *this;
398 }
399 
400 inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
401  ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
402  return *this;
403 }
404 
406  ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
407  return *this;
408 }
409 
410 inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
411  ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
412  if (strcmp(logid, "onnxruntime-node") == 0) {
414  } else {
416  }
417 }
418 
419 inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
420  ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
421  if (strcmp(logid, "onnxruntime-node") == 0) {
423  } else {
425  }
426 }
427 
428 inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
429  ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
430  if (strcmp(logid, "onnxruntime-node") == 0) {
432  } else {
434  }
435 }
436 
437 inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
438  OrtLoggingLevel logging_level, _In_ const char* logid) {
439  ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
440  if (strcmp(logid, "onnxruntime-node") == 0) {
442  } else {
444  }
445 }
446 
447 inline Env& Env::EnableTelemetryEvents() {
448  ThrowOnError(GetApi().EnableTelemetryEvents(p_));
449  return *this;
450 }
451 
454  return *this;
455 }
456 
458  ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
459  return *this;
460 }
461 
462 inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
463  ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
464  return *this;
465 }
466 
467 inline CustomOpDomain::CustomOpDomain(const char* domain) {
468  ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
469 }
470 
471 inline void CustomOpDomain::Add(const OrtCustomOp* op) {
472  ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
473 }
474 
476  ThrowOnError(GetApi().CreateRunOptions(&p_));
477 }
478 
480  ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
481  return *this;
482 }
483 
485  ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
486  return *this;
487 }
488 
490  int out;
491  ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
492  return out;
493 }
494 
496  int out;
497  ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
498  return out;
499 }
500 
501 inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
502  ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
503  return *this;
504 }
505 
506 inline const char* RunOptions::GetRunTag() const {
507  const char* out;
508  ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
509  return out;
510 }
511 
512 inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
513  ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
514  return *this;
515 }
516 
518  ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
519  return *this;
520 }
521 
523  ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
524  return *this;
525 }
526 
527 namespace detail {
528 
529 template <typename T>
531  OrtSessionOptions* out;
532  ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
533  return SessionOptions{out};
534 }
535 
536 template <typename T>
537 inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
538  size_t size = 0;
539  // Feed nullptr for the data buffer to query the true size of the string value
540  Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
541 
542  std::string out;
543  out.resize(size);
544  Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
545  out.resize(size - 1); // remove the terminating character '\0'
546 
547  return out;
548 }
549 
550 template <typename T>
551 inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
552  int out = 0;
553  Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
554  return static_cast<bool>(out);
555 }
556 
557 template <typename T>
559  if (!this->HasConfigEntry(config_key)) {
560  return def;
561  }
562 
563  return this->GetConfigEntry(config_key);
564 }
565 
566 template <typename T>
568  ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
569  return *this;
570 }
571 
572 template <typename T>
574  ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
575  return *this;
576 }
577 
578 template <typename T>
580  ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
581  return *this;
582 }
583 
584 template <typename T>
586  ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
587  return *this;
588 }
589 
590 template <typename T>
592  ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
593  return *this;
594 }
595 
596 template <typename T>
598  ThrowOnError(GetApi().DisableProfiling(this->p_));
599  return *this;
600 }
601 
602 template <typename T>
604  ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
605  return *this;
606 }
607 
608 template <typename T>
610  ThrowOnError(GetApi().EnableMemPattern(this->p_));
611  return *this;
612 }
613 
614 template <typename T>
616  ThrowOnError(GetApi().DisableMemPattern(this->p_));
617  return *this;
618 }
619 
620 template <typename T>
622  ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
623  return *this;
624 }
625 
626 template <typename T>
628  ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
629  return *this;
630 }
631 
632 template <typename T>
634  ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
635  return *this;
636 }
637 
638 template <typename T>
640  ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
641  return *this;
642 }
643 
644 template <typename T>
646  ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
647  return *this;
648 }
649 
650 template <typename T>
651 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
652  ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
653  return *this;
654 }
655 
656 template <typename T>
657 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
658  ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
659  return *this;
660 }
661 
662 template <typename T>
663 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
664  ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
665  return *this;
666 }
667 
668 template <typename T>
670  ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
671  return *this;
672 }
673 
674 template <typename T>
675 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
676  const std::vector<Value>& ort_values) {
677  const size_t inputs_num = names.size();
678  if (inputs_num != ort_values.size()) {
679  ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
680  }
681  std::vector<const char*> names_ptr;
682  std::vector<const OrtValue*> ort_values_ptrs;
683  names_ptr.reserve(inputs_num);
684  ort_values_ptrs.reserve(inputs_num);
685  for (size_t i = 0; i < inputs_num; ++i) {
686  names_ptr.push_back(names[i].c_str());
687  ort_values_ptrs.push_back(ort_values[i]);
688  }
689  ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
690  return *this;
691 }
692 
693 template <typename T>
695  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
696  return *this;
697 }
698 
699 template <typename T>
701  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
702  return *this;
703 }
704 
705 template <typename T>
707  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
708  return *this;
709 }
710 
711 template <typename T>
713  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
714  return *this;
715 }
716 
717 template <typename T>
719  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
720  return *this;
721 }
722 
723 template <typename T>
725  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
726  return *this;
727 }
728 
729 template <typename T>
731  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
732  return *this;
733 }
734 
735 template <typename T>
737  const std::string& provider_name,
738  const std::unordered_map<std::string, std::string>& provider_options) {
739  auto num_entries = provider_options.size();
740  std::vector<const char*> keys, values;
741  if (num_entries > 0) {
742  keys.reserve(num_entries);
743  values.reserve(num_entries);
744 
745  for (const auto& entry : provider_options) {
746  keys.push_back(entry.first.c_str());
747  values.push_back(entry.second.c_str());
748  }
749  }
750 
751  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
752  keys.data(), values.data(), num_entries));
753 
754  return *this;
755 }
756 
757 template <typename T>
759  ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
760  return *this;
761 }
762 
763 template <typename T>
764 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
765  ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
766  return *this;
767 }
768 
769 template <typename T>
771  ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
772  return *this;
773 }
774 
775 template <typename T>
777  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
778  return *this;
779 }
780 
781 template <typename T>
783  const CustomOpConfigs& custom_op_configs) {
784  // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
785  // the custom op library.
786  for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
787  AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
788  }
789 
790  ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
791  return *this;
792 }
793 
794 template <typename T>
795 inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
796  ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
797  return *this;
798 }
799 
800 /// Session
801 template <typename T>
802 inline size_t ConstSessionImpl<T>::GetInputCount() const {
803  size_t out;
804  ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
805  return out;
806 }
807 
808 template <typename T>
809 inline size_t ConstSessionImpl<T>::GetOutputCount() const {
810  size_t out;
811  ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
812  return out;
813 }
814 
815 template <typename T>
817  size_t out;
818  ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
819  return out;
820 }
821 
822 template <typename T>
824  char* out;
825  ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
826  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
827 }
828 
829 template <typename T>
831  char* out;
832  ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
833  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
834 }
835 
836 template <typename T>
838  char* out;
839  ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
840  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
841 }
842 
843 template <typename T>
845  uint64_t out;
846  ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
847  return out;
848 }
849 
850 template <typename T>
852  OrtModelMetadata* out;
853  ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
854  return ModelMetadata{out};
855 }
856 
857 template <typename T>
859  OrtTypeInfo* out;
860  ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
861  return TypeInfo{out};
862 }
863 
864 template <typename T>
866  OrtTypeInfo* out;
867  ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
868  return TypeInfo{out};
869 }
870 
871 template <typename T>
873  OrtTypeInfo* out;
874  ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
875  return TypeInfo{out};
876 }
877 
878 template <typename T>
879 inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
880  const char* const* output_names, size_t output_count) {
881  std::vector<Value> output_values;
882  output_values.reserve(output_count);
883  for (size_t i = 0; i < output_count; i++)
884  output_values.emplace_back(nullptr);
885  Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
886  return output_values;
887 }
888 
889 template <typename T>
890 inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
891  const char* const* output_names, Value* output_values, size_t output_count) {
892  static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
893  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
894  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
895  ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
896 }
897 
898 template <typename T>
899 inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
900  ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
901 }
902 
903 template <typename T>
905  char* out = nullptr;
906  ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
907  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
908 }
909 
910 } // namespace detail
911 
913  ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
914 }
915 
916 /// CustomOpConfigs
917 inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
918  std::string config_key = "custom_op.";
919 
920  config_key += custom_op_name;
921  config_key += ".";
922  config_key += config;
923 
924  return config_key;
925 }
926 
927 inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
928  const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
929  flat_configs_[full_flat_key] = config_value;
930  return *this;
931 }
932 
933 inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
934  return flat_configs_;
935 }
936 
937 inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
938  ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
939 }
940 
941 inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
942  OrtPrepackedWeightsContainer* prepacked_weights_container) {
943  ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
944 }
945 
946 inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
947  ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
948 }
949 
950 inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
951  const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
952  ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
953  prepacked_weights_container, &this->p_));
954 }
955 
956 inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
957  char* out;
958  ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
959  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
960 }
961 
963  char* out;
964  ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
965  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
966 }
967 
969  char* out;
970  ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
971  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
972 }
973 
975  char* out;
976  ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
977  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
978 }
979 
981  char* out;
982  ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
983  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
984 }
985 
987  char* out;
988  ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
989  return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
990 }
991 
992 inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
993  auto deletor = detail::AllocatedFree(allocator);
994  std::vector<AllocatedStringPtr> result;
995 
996  char** out = nullptr;
997  int64_t num_keys = 0;
998  ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
999  if (num_keys <= 0) {
1000  return result;
1001  }
1002 
1003  // array of pointers will be freed
1004  std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1005  // reserve may throw
1006  auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1007  std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1008  result.reserve(static_cast<size_t>(num_keys));
1009  strings_guard.release();
1010  for (int64_t i = 0; i < num_keys; ++i) {
1011  result.push_back(AllocatedStringPtr(out[i], deletor));
1012  }
1013 
1014  return result;
1015 }
1016 
1017 inline int64_t ModelMetadata::GetVersion() const {
1018  int64_t out;
1019  ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1020  return out;
1021 }
1022 
1023 namespace detail {
1024 
1025 template <typename T>
1028  ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1029  return out;
1030 }
1031 
1032 template <typename T>
1034  size_t out;
1035  ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1036  return static_cast<size_t>(out);
1037 }
1038 
1039 template <typename T>
1041  size_t out;
1042  ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1043  return out;
1044 }
1045 
1046 template <typename T>
1047 inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1048  ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1049 }
1050 
1051 template <typename T>
1052 inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1053  ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1054 }
1055 
1056 template <typename T>
1057 inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1058  std::vector<int64_t> out(GetDimensionsCount(), 0);
1059  ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1060  return out;
1061 }
1062 
1063 } // namespace detail
1064 
1065 namespace detail {
1066 template <typename T>
1068  const OrtTensorTypeAndShapeInfo* out;
1069  ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1070  return ConstTensorTypeAndShapeInfo{out};
1071 }
1072 
1073 template <typename T>
1075  const OrtSequenceTypeInfo* out;
1076  ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1077  return ConstSequenceTypeInfo{out};
1078 }
1079 
1080 template <typename T>
1082  const OrtMapTypeInfo* out;
1083  ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1084  return ConstMapTypeInfo{out};
1085 }
1086 
1087 template <typename T>
1089  ONNXType out;
1090  ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1091  return out;
1092 }
1093 
1094 } // namespace detail
1095 
1096 namespace detail {
1097 template <typename T>
1099  OrtTypeInfo* output;
1100  ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1101  return TypeInfo{output};
1102 }
1103 
1104 } // namespace detail
1105 
1106 namespace detail {
1107 template <typename T>
1110  ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1111  return out;
1112 }
1113 
1114 template <typename T>
1116  OrtTypeInfo* output;
1117  ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1118  return TypeInfo{output};
1119 }
1120 } // namespace detail
1121 
1122 namespace detail {
1123 
1124 template <typename T>
1125 template <typename R>
1126 inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1127  ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1128 }
1129 
1130 template <typename T>
1131 inline bool ConstValueImpl<T>::IsTensor() const {
1132  int out;
1133  ThrowOnError(GetApi().IsTensor(this->p_, &out));
1134  return out != 0;
1135 }
1136 
1137 template <typename T>
1138 inline bool ConstValueImpl<T>::HasValue() const {
1139  int out;
1140  ThrowOnError(GetApi().HasValue(this->p_, &out));
1141  return out != 0;
1142 }
1143 
1144 template <typename T>
1145 inline size_t ConstValueImpl<T>::GetCount() const {
1146  size_t out;
1147  ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1148  return out;
1149 }
1150 
1151 template <typename T>
1152 inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1153  OrtValue* out;
1154  ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1155  return Value{out};
1156 }
1157 
1158 template <typename T>
1160  size_t out;
1161  ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1162  return out;
1163 }
1164 
1165 template <typename T>
1166 inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1167  size_t out;
1168  ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1169  return out;
1170 }
1171 
1172 template <typename T>
1173 template <typename R>
1174 inline const R* ConstValueImpl<T>::GetTensorData() const {
1175  R* out;
1176  ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1177  return out;
1178 }
1179 
1180 template <typename T>
1181 inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1182  void* out;
1183  ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1184  return out;
1185 }
1186 
1187 template <typename T>
1189  OrtTypeInfo* output;
1190  ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1191  return TypeInfo{output};
1192 }
1193 
1194 template <typename T>
1196  OrtTensorTypeAndShapeInfo* output;
1197  ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1198  return TensorTypeAndShapeInfo{output};
1199 }
1200 
1201 template <typename T>
1203  const OrtMemoryInfo* mem_info;
1204  ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1205  return ConstMemoryInfo(mem_info);
1206 }
1207 
1208 template <typename T>
1209 inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1210  ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1211 }
1212 
1213 template <typename T>
1214 inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1215  ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1216 }
1217 
1218 #if !defined(DISABLE_SPARSE_TENSORS)
1219 template <typename T>
1222  ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1223  return format;
1224 }
1225 
1226 template <typename T>
1228  OrtTensorTypeAndShapeInfo* output;
1229  ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1230  return TensorTypeAndShapeInfo{output};
1231 }
1232 
1233 template <typename T>
1235  OrtTensorTypeAndShapeInfo* output;
1236  ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1237  return TensorTypeAndShapeInfo{output};
1238 }
1239 
1240 template <typename T>
1241 template <typename R>
1242 inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1243  const void* out;
1244  ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1245  return reinterpret_cast<const R*>(out);
1246 }
1247 
1248 template <typename T>
1250  int out;
1251  ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1252  return out != 0;
1253 }
1254 
1255 template <typename T>
1256 template <typename R>
1258  const void* out;
1259  ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1260  return reinterpret_cast<const R*>(out);
1261 }
1262 
1263 #endif
1264 
1265 template <typename T>
1266 void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1267  ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1268 }
1269 
1270 template <typename T>
1271 void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1272  ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1273 }
1274 
1275 template <typename T>
1277  void* out;
1278  ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1279  return out;
1280 }
1281 
1282 template <typename T>
1283 template <typename R>
1285  R* out;
1286  ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1287  return out;
1288 }
1289 
1290 template <typename T>
1291 template <typename R>
1292 R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1293  static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1294  R* out;
1295  ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1296  return *out;
1297 }
1298 
1299 #if !defined(DISABLE_SPARSE_TENSORS)
1300 template <typename T>
1301 void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1302  ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1303 }
1304 
1305 template <typename T>
1306 void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1307  ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1308 }
1309 
1310 template <typename T>
1311 void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1312  ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1313 }
1314 
1315 template <typename T>
1317  const int64_t* indices_data, size_t indices_num) {
1318  ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1319  values_param.values_shape_len, values_param.data.p_data,
1320  indices_data, indices_num));
1321 }
1322 
1323 template <typename T>
1326  const int64_t* inner_indices_data, size_t inner_indices_num,
1327  const int64_t* outer_indices_data, size_t outer_indices_num) {
1328  ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1329  inner_indices_data, inner_indices_num,
1330  outer_indices_data, outer_indices_num));
1331 }
1332 
1333 template <typename T>
1336  const Shape& indices_shape,
1337  const int32_t* indices_data) {
1338  ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1339  indices_shape.shape, indices_shape.shape_len,
1340  indices_data));
1341 }
1342 
1343 #endif // !defined(DISABLE_SPARSE_TENSORS)
1344 
1345 } // namespace detail
1346 
1347 template <typename T>
1348 inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1349  return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1350 }
1351 
1352 inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1354  OrtValue* out;
1355  ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1356  return Value{out};
1357 }
1358 
1359 template <typename T>
1360 inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1361  return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1362 }
1363 
1364 inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1365  OrtValue* out;
1366  ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1367  return Value{out};
1368 }
1369 
1370 #if !defined(DISABLE_SPARSE_TENSORS)
1371 
1372 template <typename T>
1373 inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1374  const Shape& values_shape) {
1375  return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1376 }
1377 
1378 inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1379  const Shape& values_shape, ONNXTensorElementDataType type) {
1380  OrtValue* out;
1381  ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1382  values_shape.shape, values_shape.shape_len, type, &out));
1383  return Value{out};
1384 }
1385 
1386 template <typename T>
1387 inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1388  return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1389 }
1390 
1391 inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1393  OrtValue* out;
1394  ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1395  return Value{out};
1396 }
1397 #endif // !defined(DISABLE_SPARSE_TENSORS)
1398 
1400  OrtValue* out;
1401  OrtValue* inputs[2] = {keys, values};
1402  ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1403  return Value{out};
1404 }
1405 
1406 inline Value Value::CreateSequence(std::vector<Value>& values) {
1407  OrtValue* out;
1408  std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
1409  ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1410  return Value{out};
1411 }
1412 
1413 template <typename T>
1414 inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1415  OrtValue* out;
1416  ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1417  return Value{out};
1418 }
1419 
1420 //
1421 // Custom OP Inlines
1422 //
1423 inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1424 }
1425 
1426 inline size_t KernelContext::GetInputCount() const {
1427  size_t out = 0;
1428  Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1429  return out;
1430 }
1431 
1432 inline size_t KernelContext::GetOutputCount() const {
1433  size_t out = 0;
1434  Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1435  return out;
1436 }
1437 
1439  const OrtValue* out = nullptr;
1440  Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1441  return ConstValue{out};
1442 }
1443 
1444 inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1445  OrtValue* out = nullptr;
1446  Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1447  return UnownedValue(out);
1448 }
1449 
1450 inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1451  OrtValue* out = nullptr;
1452  Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1453  return UnownedValue(out);
1454 }
1455 
1457  void* out = nullptr;
1458  Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1459  return out;
1460 }
1461 
1462 inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1463  Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1464 }
1465 
1466 namespace detail {
1467 template <typename T>
1469  OrtKernelInfo* info_copy = nullptr;
1470  Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1471  return KernelInfo{info_copy};
1472 }
1473 
1474 template <typename T>
1475 inline size_t KernelInfoImpl<T>::GetInputCount() const {
1476  size_t out = 0;
1477  ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1478  return out;
1479 }
1480 
1481 template <typename T>
1482 inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1483  size_t out = 0;
1484  ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1485  return out;
1486 }
1487 
1488 template <typename T>
1490  size_t size = 0;
1491 
1492  // Feed nullptr for the data buffer to query the true size of the string value
1493  Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1494 
1495  std::string out;
1496  out.resize(size);
1497  Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1498  out.resize(size - 1); // remove the terminating character '\0'
1499 
1500  return out;
1501 }
1502 
1503 template <typename T>
1505  size_t size = 0;
1506 
1507  // Feed nullptr for the data buffer to query the true size of the string value
1508  Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1509 
1510  std::string out;
1511  out.resize(size);
1512  Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1513  out.resize(size - 1); // remove the terminating character '\0'
1514 
1515  return out;
1516 }
1517 
1518 template <typename T>
1520  OrtTypeInfo* out = nullptr;
1521  ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1522  return TypeInfo{out};
1523 }
1524 
1525 template <typename T>
1527  OrtTypeInfo* out = nullptr;
1528  ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1529  return TypeInfo{out};
1530 }
1531 
1532 template <typename T>
1533 inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1534  OrtValue* out = nullptr;
1535  ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1536  return Value{out};
1537 }
1538 
1539 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1540  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1541 }
1542 
1543 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1544  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1545 }
1546 
1547 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1548  size_t size = 0;
1549  // Feed nullptr for the data buffer to query the true size of the string attribute
1550  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1551 
1552  std::string out;
1553  out.resize(size);
1554  Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1555  out.resize(size - 1); // remove the terminating character '\0'
1556  out.swap(result);
1557 }
1558 
1559 inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1560  size_t size = 0;
1561  // Feed nullptr for the data buffer to query the true size of the attribute
1562  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1563 
1564  std::vector<float> out;
1565  out.resize(size);
1566  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1567  out.swap(result);
1568 }
1569 
1570 inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1571  size_t size = 0;
1572 
1573  // Feed nullptr for the data buffer to query the true size of the attribute
1574  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1575 
1576  std::vector<int64_t> out;
1577  out.resize(size);
1578  Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1579  out.swap(result);
1580 }
1581 } // namespace detail
1582 
1583 inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1584 
1585 inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1586 
1587 inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1588  const char** type_constraint_names,
1589  const ONNXTensorElementDataType* type_constraint_values,
1590  size_t type_constraint_count,
1591  const OpAttr* attr_values, size_t attr_count,
1592  size_t input_count, size_t output_count) {
1593  static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1594  "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1595  auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1596  OrtOp* op;
1597  Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1598  static_cast<int>(type_constraint_count),
1599  attr_input_values,
1600  static_cast<int>(attr_count),
1601  static_cast<int>(input_count),
1602  static_cast<int>(output_count), &op));
1603  return Op{op};
1604 }
1605 
1606 inline void Op::Invoke(const OrtKernelContext* context,
1607  const Value* input_values,
1608  size_t input_count,
1609  Value* output_values,
1610  size_t output_count) {
1611  static_assert(sizeof(Value) == sizeof(OrtValue*),
1612  "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1613  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1614  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1615  Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1616  ort_output_values, static_cast<int>(output_count)));
1617 }
1618 
1619 inline void Op::Invoke(const OrtKernelContext* context,
1620  const OrtValue* const* input_values,
1621  size_t input_count,
1622  OrtValue* const* output_values,
1623  size_t output_count) {
1624  Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1625  output_values, static_cast<int>(output_count)));
1626 }
1627 
1628 inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
1629  Ort::ThrowOnError(status);
1630 }
1631 
1632 template <>
1633 inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1634  float out;
1635  Ort::ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
1636  return out;
1637 }
1638 
1639 template <>
1640 inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1641  int64_t out;
1642  Ort::ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
1643  return out;
1644 }
1645 
1646 template <>
1647 inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1648  size_t size = 0;
1649  std::string out;
1650 
1651  // Feed nullptr for the data buffer to query the true size of the string attribute
1652  OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
1653 
1654  if (status == nullptr) {
1655  out.resize(size);
1656  Ort::ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
1657  out.resize(size - 1); // remove the terminating character '\0'
1658  } else {
1659  Ort::ThrowOnError(status);
1660  }
1661  return out;
1662 }
1663 
1664 template <>
1665 inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1666  size_t size = 0;
1667  std::vector<float> out;
1668 
1669  // Feed nullptr for the data buffer to query the true size of the attribute
1670  OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
1671 
1672  if (status == nullptr) {
1673  out.resize(size);
1674  Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
1675  } else {
1676  Ort::ThrowOnError(status);
1677  }
1678  return out;
1679 }
1680 
1681 template <>
1682 inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1683  size_t size = 0;
1684  std::vector<int64_t> out;
1685 
1686  // Feed nullptr for the data buffer to query the true size of the attribute
1687  OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
1688 
1689  if (status == nullptr) {
1690  out.resize(size);
1691  Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
1692  } else {
1693  Ort::ThrowOnError(status);
1694  }
1695  return out;
1696 }
1697 inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
1698  OrtTensorTypeAndShapeInfo* out;
1699  Ort::ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
1700  return out;
1701 }
1702 
1703 inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1704  size_t out;
1705  Ort::ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
1706  return out;
1707 }
1708 
1709 inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
1711  Ort::ThrowOnError(api_.GetTensorElementType(info, &out));
1712  return out;
1713 }
1714 
1715 inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1716  size_t out;
1717  Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1718  return out;
1719 }
1720 
1721 inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
1722  Ort::ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
1723 }
1724 
1725 inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
1726  Ort::ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
1727 }
1728 
1729 template <typename T>
1731  T* data;
1732  Ort::ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
1733  return data;
1734 }
1735 
1736 inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) {
1737  const OrtMemoryInfo* mem_info;
1738  Ort::ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info));
1739  return mem_info;
1740 }
1741 
1742 template <typename T>
1743 inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
1744  T* data = nullptr;
1745  Ort::ThrowOnError(api_.GetTensorMutableData(const_cast<OrtValue*>(value), reinterpret_cast<void**>(&data)));
1746  return data;
1747 }
1748 
1749 inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
1750  size_t out;
1751  Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1752  std::vector<int64_t> output(out);
1753  Ort::ThrowOnError(api_.GetDimensions(info, output.data(), out));
1754  return output;
1755 }
1756 
1757 inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
1758  api_.ReleaseTensorTypeAndShapeInfo(input);
1759 }
1760 
1761 inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
1762  size_t out;
1763  Ort::ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
1764  return out;
1765 }
1766 
1767 inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
1768  const OrtValue* out;
1769  Ort::ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
1770  return out;
1771 }
1772 
1773 inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
1774  size_t out;
1775  Ort::ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
1776  return out;
1777 }
1778 
1780  _In_ const int64_t* dim_values, size_t dim_count) {
1781  OrtValue* out;
1782  Ort::ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
1783  return out;
1784 }
1785 
1787  void* out;
1788  Ort::ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out));
1789  return out;
1790 }
1791 
1792 inline OrtOpAttr* CustomOpApi::CreateOpAttr(_In_ const char* name,
1793  _In_ const void* data,
1794  _In_ int len,
1795  _In_ OrtOpAttrType type) {
1796  OrtOpAttr* op_attr{};
1797  Ort::ThrowOnError(api_.CreateOpAttr(name, data, len, type, &op_attr));
1798  return op_attr;
1799 }
1800 
1801 inline void CustomOpApi::ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr) {
1802  api_.ReleaseOpAttr(op_attr);
1803 }
1804 
1805 inline OrtOp* CustomOpApi::CreateOp(_In_ const OrtKernelInfo* info,
1806  _In_ const char* op_name,
1807  _In_ const char* domain,
1808  _In_ int version,
1809  _In_opt_ const char** type_constraint_names,
1810  _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1811  _In_opt_ int type_constraint_count,
1812  _In_opt_ const OrtOpAttr* const* attr_values,
1813  _In_opt_ int attr_count,
1814  _In_ int input_count,
1815  _In_ int output_count) {
1816  OrtOp* ort_op{};
1817  Ort::ThrowOnError(api_.CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1818  type_constraint_count, attr_values, attr_count, input_count, output_count, &ort_op));
1819  return ort_op;
1820 }
1821 
1822 inline void CustomOpApi::InvokeOp(_In_ const OrtKernelContext* context,
1823  _In_ const OrtOp* ort_op,
1824  _In_ const OrtValue* const* input_values,
1825  _In_ int input_count,
1826  _Inout_ OrtValue* const* output_values,
1827  _In_ int output_count) {
1828  Ort::ThrowOnError(api_.InvokeOp(context, ort_op, input_values, input_count, output_values, output_count));
1829 }
1830 
1831 inline void CustomOpApi::ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op) {
1832  api_.ReleaseOp(ort_op);
1833 }
1834 
1836  OrtKernelInfo* info_copy{};
1837  Ort::ThrowOnError(api_.CopyKernelInfo(info, &info_copy));
1838  return info_copy;
1839 }
1840 
1842  api_.ReleaseKernelInfo(info_copy);
1843 }
1844 
1845 inline std::vector<std::string> GetAvailableProviders() {
1846  int len;
1847  char** providers;
1848  ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1849  std::vector<std::string> available_providers(providers, providers + len);
1850  ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1851  return available_providers;
1852 }
1853 
1854 SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
1855 
1856 template <typename TOp, typename TKernel>
1857 void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1858  ConstSessionOptions options) const {
1859  const TOp* derived = static_cast<const TOp*>(this);
1860  std::vector<std::string> keys = derived->GetSessionConfigKeys();
1861 
1862  out.reserve(keys.size());
1863 
1864  std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1865  const size_t prefix_size = config_entry_key.length();
1866 
1867  for (const auto& key : keys) {
1868  config_entry_key.resize(prefix_size);
1869  config_entry_key.append(key);
1870  out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1871  }
1872 }
1873 
1874 } // namespace Ort
OrtMemoryInfoDeviceType GetDeviceType() const
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo *info)
void * KernelContext_GetGPUComputeStream(const OrtKernelContext *context)
void Invoke(const OrtKernelContext *context, const Value *input_values, size_t input_count, Value *output_values, size_t output_count)
SessionOptionsImpl & SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn.
GLuint GLsizei const GLchar * message
Definition: glcorearb.h:2543
size_t GetElementCount() const
Wraps OrtApi::GetTensorShapeElementCount.
static Value CreateOpaque(const char *domain, const char *type_name, const T &)
Wraps OrtApi::CreateOpaqueValue.
MemoryAllocation & operator=(const MemoryAllocation &)=delete
Env & DisableTelemetryEvents()
Wraps OrtApi::DisableTelemetryEvents.
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the overridable initializer name at the specified index.
ThreadingOptions & SetGlobalSpinControl(int allow_spinning)
Wraps OrtApi::SetGlobalSpinControl.
OrtCustomThreadHandle(* OrtCustomCreateThreadFn)(void *ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void *ort_worker_fn_param)
Ort custom thread creation function.
std::string GetErrorMessage() const
size_t GetInputCount() const
Returns the number of model inputs.
SessionOptionsImpl & DisablePerSessionThreads()
Wraps OrtApi::DisablePerSessionThreads.
Env & UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level)
Wraps OrtApi::UpdateEnvWithCustomLogLevel.
#define _In_
void InvokeOp(_In_ const OrtKernelContext *context, _In_ const OrtOp *ort_op, _In_ const OrtValue *const *input_values, _In_ int input_count, _Inout_ OrtValue *const *output_values, _In_ int output_count)
void UseBlockSparseIndices(const Shape &indices_shape, int32_t *indices_data)
Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSp...
RunOptions & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddRunConfigEntry.
ConstMapTypeInfo GetMapTypeInfo() const
Wraps OrtApi::CastTypeInfoToMapTypeInfo.
void FillSparseTensorCsr(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const int64_t *inner_indices_data, size_t inner_indices_num, const int64_t *outer_indices_data, size_t outer_indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetInputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetInputTypeInfo.
size_t GetOutputCount() const
void GetSymbolicDimensions(const char **values, size_t values_count) const
Wraps OrtApi::GetSymbolicDimensions.
Type information that may contain either TensorTypeAndShapeInfo or the information about contained se...
SessionOptionsImpl & EnableMemPattern()
Wraps OrtApi::EnableMemPattern.
#define _Frees_ptr_opt_
SessionOptionsImpl & EnableCpuMemArena()
Wraps OrtApi::EnableCpuMemArena.
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1)
bool IsTensor() const
Returns true if Value is a tensor, false for other types like map/sequence/etc.
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of the output name at the specified index.
void GetOpaqueData(const char *domain, const char *type_name, R &) const
Obtains a pointer to user-defined data for experimental purposes.
const void * GetTensorRawData() const
Returns a non-typed pointer to a tensor contained data.
ConstMemoryInfo GetInfo() const
T * GetTensorMutableData(_Inout_ OrtValue *value)
void GetStringTensorElement(size_t buffer_length, size_t element_index, void *buffer) const
The API copies UTF-8 encoded bytes for the requested string element contained within a tensor or a sp...
SessionOptionsImpl & SetLogId(const char *logid)
Wraps OrtApi::SetSessionLogId.
GLboolean * data
Definition: glcorearb.h:131
void swap(UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &a, UT::ArraySet< Key, MULTI, MAX_LOAD_FACTOR_256, Clearer, Hash, KeyEqual > &b)
Definition: UT_ArraySet.h:1631
Value GetValue(int index, OrtAllocator *allocator) const
uint64_t GetProfilingStartTimeNs() const
Wraps OrtApi::SessionGetProfilingStartTimeNs.
#define ORTCHAR_T
Status(std::nullptr_t)
Create an empty object, must be assigned a valid one to be used.
SessionOptionsImpl & AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX.
UnownedValue GetOutput(size_t index, const int64_t *dim_values, size_t dim_count) const
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
void ThrowOnError(OrtStatus *ort_status)
GLsizei const GLfloat * value
Definition: glcorearb.h:824
void BindInput(const char *name, const Value &)
ONNXTensorElementDataType GetMapKeyType() const
Wraps OrtApi::GetMapKeyType.
OrtMemoryInfoDeviceType
This mimics OrtDevice type constants so they can be returned in the API.
ConstValue GetInput(size_t index) const
SessionOptionsImpl & AppendExecutionProvider_CANN(const OrtCANNProviderOptions &provider_options)
MemoryAllocation GetAllocation(size_t size)
std::unique_ptr< char, detail::AllocatedFree > AllocatedStringPtr
unique_ptr typedef used to own strings allocated by OrtAllocators and release them at the end of the ...
SessionOptionsImpl & DisableCpuMemArena()
Wraps OrtApi::DisableCpuMemArena.
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the description.
bool HasConfigEntry(const char *config_key) const
Wraps OrtApi::HasSessionConfigEntry.
void UseCsrIndices(int64_t *inner_data, size_t inner_num, int64_t *outer_data, size_t outer_num)
Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tens...
ONNXTensorElementDataType GetElementType() const
Wraps OrtApi::GetTensorElementType.
Env & CreateAndRegisterAllocator(const OrtMemoryInfo *mem_info, const OrtArenaCfg *arena_cfg)
Wraps OrtApi::CreateAndRegisterAllocator.
void FillStringTensorElement(const char *s, size_t index)
Set a single string in a string tensor
void ReleaseOp(_Frees_ptr_opt_ OrtOp *ort_op)
std::vector< Value > Run(const RunOptions &run_options, const char *const *input_names, const Value *input_values, size_t input_count, const char *const *output_names, size_t output_count)
Run the model returning results in an Ort allocated vector.
Wrapper around ::OrtModelMetadata.
GLint level
Definition: glcorearb.h:108
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo *info, _Out_ int64_t *dim_values, size_t dim_values_length)
#define _In_opt_
SessionOptionsImpl & AddInitializer(const char *name, const OrtValue *ort_val)
Wraps OrtApi::AddInitializer.
const R * GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t &num_indices) const
The API retrieves a pointer to the internal indices buffer. The API merely performs a convenience dat...
std::string GetConfigEntryOrDefault(const char *config_key, const std::string &def)
OpAttr(const char *name, const void *data, int len, OrtOpAttrType type)
SessionOptionsImpl & AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT.
GLdouble s
Definition: glad.h:3009
SessionOptionsImpl & EnableOrtCustomOps()
Wraps OrtApi::EnableOrtCustomOps.
size_t GetDimensionsCount() const
Wraps OrtApi::GetDimensionsCount.
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char *key, OrtAllocator *allocator) const
Looks up a value by a key in the Custom Metadata map.
R & At(const std::vector< int64_t > &location)
std::string GetAllocatorName() const
SessionOptionsImpl & AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA.
Wrapper around OrtValue.
SessionOptionsImpl & SetIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetIntraOpNumThreads.
static Value CreateSequence(std::vector< Value > &values)
Wraps OrtApi::CreateValue.
**But if you need a result
Definition: thread.h:613
SessionOptionsImpl & RegisterCustomOpsUsingFunction(const char *function_name)
Wraps OrtApi::RegisterCustomOpsUsingFunction.
IoBinding(std::nullptr_t)
Create an empty object for convenience. Sometimes, we want to initialize members later.
OrtAllocatorType
The Env (Environment)
std::vector< AllocatedStringPtr > GetCustomMetadataMapKeysAllocated(OrtAllocator *allocator) const
Returns a vector of copies of the custom metadata keys.
SessionOptionsImpl & SetOptimizedModelFilePath(const ORTCHAR_T *optimized_model_file)
Wraps OrtApi::SetOptimizedModelFilePath.
Env(std::nullptr_t)
Create an empty Env object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn)
Wraps OrtApi::SetGlobalCustomJoinThreadFn.
const R * GetSparseTensorValues() const
The API returns a pointer to an internal buffer of the sparse tensor containing non-zero values...
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const
The API returns type and shape information for the specified indices. Each supported indices have the...
GLuint buffer
Definition: glcorearb.h:660
SessionOptionsImpl & SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level)
Wraps OrtApi::SetSessionGraphOptimizationLevel.
const R * GetTensorData() const
Returns a const typed pointer to the tensor contained data. No type checking is performed, the caller must ensure the type matches the tensor type.
SessionOptionsImpl & SetCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions.
void * GetTensorMutableRawData()
Returns a non-typed non-const pointer to a tensor contained data.
union Ort::detail::OrtSparseValuesParam::@93 data
static Op Create(const OrtKernelInfo *info, const char *op_name, const char *domain, int version, const char **type_constraint_names, const ONNXTensorElementDataType *type_constraint_values, size_t type_constraint_count, const OpAttr *attr_values, size_t attr_count, size_t input_count, size_t output_count)
std::string GetConfigEntry(const char *config_key) const
Wraps OrtApi::GetSessionConfigEntry.
void ThrowStatus(const Status &st)
std::vector< std::string > GetOutputNames() const
void(ORT_API_CALL * OrtLoggingFunction)(void *param, OrtLoggingLevel severity, const char *category, const char *logid, const char *code_location, const char *message)
SessionOptionsImpl & AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2.
#define _Inout_
OrtSparseIndicesFormat
GLuint GLsizei const GLuint const GLintptr * offsets
Definition: glcorearb.h:2621
bool IsSparseTensor() const
Returns true if the OrtValue contains a sparse tensor
std::vector< Value > GetOutputValuesHelper(const OrtIoBinding *binding, OrtAllocator *)
void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo *info_copy)
ModelMetadata GetModelMetadata() const
Wraps OrtApi::SessionGetModelMetadata.
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator *allocator) const
Used for interop with the C API.
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo *info)
ThreadingOptions & SetGlobalInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetGlobalInterOpNumThreads.
void FillSparseTensorBlockSparse(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values, const Shape &indices_shape, const int32_t *indices_data)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
TypeInfo GetOutputTypeInfo(size_t index) const
std::vector< int64_t > GetShape() const
Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape.
const T * GetTensorData(_Inout_ const OrtValue *value)
const OrtValue * KernelContext_GetInput(const OrtKernelContext *context, _In_ size_t index)
std::vector< int64_t > GetTensorShape(const OrtTensorTypeAndShapeInfo *info)
Wrapper around OrtMemoryInfo.
const OrtMemoryInfo * GetTensorMemoryInfo(_In_ const OrtValue *value)
std::vector< std::string > GetAvailableProviders()
This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representin...
Op(std::nullptr_t)
Create an empty Operator object, must be assigned a valid one to be used.
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo *info)
SessionOptions Clone() const
Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions.
Definition: core.h:760
Wrapper around ::OrtIoBinding.
void GetAttrs(const OrtKernelInfo *p, const char *name, std::vector< float > &)
SessionOptionsImpl & SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn.
OrtKernelInfo * CopyKernelInfo(_In_ const OrtKernelInfo *info)
The Status that holds ownership of OrtStatus received from C API. Use it to safely destroy OrtStatus* ...
OrtOpAttr * CreateOpAttr(_In_ const char *name, _In_ const void *data, _In_ int len, _In_ OrtOpAttrType type)
void GetStringTensorContent(void *buffer, size_t buffer_length, size_t *offsets, size_t offsets_count) const
The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor into...
void GetDimensions(int64_t *values, size_t values_count) const
Wraps OrtApi::GetDimensions.
A generic, discriminated value, whose type may be queried dynamically.
Definition: Value.h:44
int64_t GetVersion() const
Wraps OrtApi::ModelMetadataGetVersion.
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo *input)
void GetAttr(const OrtKernelInfo *p, const char *name, float &)
CustomOpDomain(std::nullptr_t)
Create an empty CustomOpDomain object, must be assigned a valid one to be used.
SessionOptionsImpl & AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO.
const std::unordered_map< std::string, std::string > & GetFlattenedConfigs() const
Returns a flattened map of custom operator configuration entries and their values.
SessionOptionsImpl & AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2 &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2.
constexpr std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size(){return subtype_count< typename std::tuple_element< I, T >::type >::value+tuple_type_size< T, I+1 >);}template< typename T > struct type_count< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size< T, 0 >)};};template< typename T > struct subtype_count{static constexpr int value{is_mutable_container< T >::value?expected_max_vector_size:type_count< T >::value};};template< typename T, typename Enable=void > struct type_count_min{static const int value{0};};template< typename T >struct type_count_min< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_tuple_like< T >::value &&!is_wrapper< T >::value &&!is_complex< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{type_count< T >::value};};template< typename T > struct type_count_min< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr int value{1};};template< typename T >struct type_count_min< T, typename std::enable_if< is_wrapper< T >::value &&!is_complex< T >::value &&!is_tuple_like< T >::value >::type >{static constexpr int value{subtype_count_min< typename T::value_type >::value};};template< typename T, std::size_t I >constexpr typename std::enable_if< I==type_count_base< T >::value, int >::type tuple_type_size_min(){return 0;}template< typename T, std::size_t I > constexpr typename std::enable_if< I< type_count_base< T >::value, int >::type tuple_type_size_min(){return subtype_count_min< typename std::tuple_element< I, T >::type >::value+tuple_type_size_min< T, I+1 >);}template< typename T > struct type_count_min< T, typename std::enable_if< is_tuple_like< T >::value >::type >{static constexpr int value{tuple_type_size_min< T, 0 >)};};template< typename T > struct subtype_count_min{static constexpr int value{is_mutable_container< T >::value?((type_count< T >::value< expected_max_vector_size)?type_count< T 
>::value:0):type_count_min< T >::value};};template< typename T, typename Enable=void > struct expected_count{static const int value{0};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&!is_wrapper< T >::value &&!std::is_void< T >::value >::type >{static constexpr int value{1};};template< typename T > struct expected_count< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr int value{expected_max_vector_size};};template< typename T >struct expected_count< T, typename std::enable_if<!is_mutable_container< T >::value &&is_wrapper< T >::value >::type >{static constexpr int value{expected_count< typename T::value_type >::value};};enum class object_category:int{char_value=1, integral_value=2, unsigned_integral=4, enumeration=6, boolean_value=8, floating_point=10, number_constructible=12, double_constructible=14, integer_constructible=16, string_assignable=23, string_constructible=24, other=45, wrapper_value=50, complex_number=60, tuple_value=70, container_value=80,};template< typename T, typename Enable=void > struct classify_object{static constexpr object_category value{object_category::other};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&!std::is_same< T, char >::value &&std::is_signed< T >::value &&!is_bool< T >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::integral_value};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_integral< T >::value &&std::is_unsigned< T >::value &&!std::is_same< T, char >::value &&!is_bool< T >::value >::type >{static constexpr object_category value{object_category::unsigned_integral};};template< typename T >struct classify_object< T, typename std::enable_if< std::is_same< T, char >::value &&!std::is_enum< T >::value >::type >{static constexpr object_category 
value{object_category::char_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_bool< T >::value >::type >{static constexpr object_category value{object_category::boolean_value};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_floating_point< T >::value >::type >{static constexpr object_category value{object_category::floating_point};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&std::is_assignable< T &, std::string >::value >::type >{static constexpr object_category value{object_category::string_assignable};};template< typename T >struct classify_object< T, typename std::enable_if<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&(type_count< T >::value==1)&&std::is_constructible< T, std::string >::value >::type >{static constexpr object_category value{object_category::string_constructible};};template< typename T > struct classify_object< T, typename std::enable_if< std::is_enum< T >::value >::type >{static constexpr object_category value{object_category::enumeration};};template< typename T > struct classify_object< T, typename std::enable_if< is_complex< T >::value >::type >{static constexpr object_category value{object_category::complex_number};};template< typename T > struct uncommon_type{using type=typename std::conditional<!std::is_floating_point< T >::value &&!std::is_integral< T >::value &&!std::is_assignable< T &, std::string >::value &&!std::is_constructible< T, std::string >::value &&!is_complex< T >::value &&!is_mutable_container< T >::value &&!std::is_enum< T >::value, std::true_type, std::false_type >::type;static constexpr bool value=type::value;};template< typename T >struct classify_object< T, typename std::enable_if<(!is_mutable_container< T >::value &&is_wrapper< T >::value &&!is_tuple_like< T >::value 
&&uncommon_type< T >::value)>::type >{static constexpr object_category value{object_category::wrapper_value};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::number_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&!is_direct_constructible< T, double >::value &&is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::integer_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< uncommon_type< T >::value &&type_count< T >::value==1 &&!is_wrapper< T >::value &&is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value >::type >{static constexpr object_category value{object_category::double_constructible};};template< typename T >struct classify_object< T, typename std::enable_if< is_tuple_like< T >::value &&((type_count< T >::value >=2 &&!is_wrapper< T >::value)||(uncommon_type< T >::value &&!is_direct_constructible< T, double >::value &&!is_direct_constructible< T, int >::value)||(uncommon_type< T >::value &&type_count< T >::value >=2))>::type >{static constexpr object_category value{object_category::tuple_value};};template< typename T > struct classify_object< T, typename std::enable_if< is_mutable_container< T >::value >::type >{static constexpr object_category value{object_category::container_value};};template< typename T, enable_if_t< classify_object< T >::value==object_category::char_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"CHAR";}template< typename T, enable_if_t< classify_object< T 
>::value==object_category::integral_value||classify_object< T >::value==object_category::integer_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"INT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::unsigned_integral, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"UINT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::floating_point||classify_object< T >::value==object_category::number_constructible||classify_object< T >::value==object_category::double_constructible, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"FLOAT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::enumeration, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"ENUM";}template< typename T, enable_if_t< classify_object< T >::value==object_category::boolean_value, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"BOOLEAN";}template< typename T, enable_if_t< classify_object< T >::value==object_category::complex_number, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"COMPLEX";}template< typename T, enable_if_t< classify_object< T >::value >=object_category::string_assignable &&classify_object< T >::value<=object_category::other, detail::enabler >=detail::dummy >constexpr const char *type_name(){return"TEXT";}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::container_value||classify_object< T >::value==object_category::wrapper_value, detail::enabler >=detail::dummy >std::string type_name();template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value==1, 
detail::enabler >=detail::dummy >inline std::string type_name(){return type_name< typename std::decay< typename std::tuple_element< 0, T >::type >::type >);}template< typename T, std::size_t I >inline typename std::enable_if< I==type_count_base< T >::value, std::string >::type tuple_name(){return std::string{};}template< typename T, std::size_t I >inline typename std::enable_if<(I< type_count_base< T >::value), std::string >::type tuple_name(){auto str=std::string{type_name< typename std::decay< typename std::tuple_element< I, T >::type >::type >)}+ ','+tuple_name< T, I+1 >);if(str.back()== ',') str.pop_back();return str;}template< typename T, enable_if_t< classify_object< T >::value==object_category::tuple_value &&type_count_base< T >::value >=2, detail::enabler > > std::string type_name()
Recursively generate the tuple type name.
Definition: CLI11.h:1729
static Value CreateTensor(const OrtMemoryInfo *info, T *p_data, size_t p_data_element_count, const int64_t *shape, size_t shape_len)
Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
GLint GLint GLsizei GLint GLenum format
Definition: glcorearb.h:108
const OrtApi & GetApi()
This returns a reference to the OrtApi interface in use.
ThreadingOptions & SetGlobalIntraOpNumThreads(int intra_op_num_threads)
Wraps OrtApi::SetGlobalIntraOpNumThreads.
GraphOptimizationLevel
Graph optimization level.
size_t KernelContext_GetOutputCount(const OrtKernelContext *context)
Value GetTensorAttribute(const char *name, OrtAllocator *allocator) const
void Add(const OrtCustomOp *op)
Wraps CustomOpDomain_Add.
KernelInfo(std::nullptr_t)
Create an empty instance to initialize later.
AllocatedStringPtr GetDomainAllocated(OrtAllocator *allocator) const
Returns a copy of the domain name.
Represents native memory allocation coming from one of the OrtAllocators registered with OnnxRuntime...
Session(std::nullptr_t)
Create an empty Session object, must be assigned a valid one to be used.
ThreadingOptions & SetGlobalDenormalAsZero()
Wraps OrtApi::SetGlobalDenormalAsZero.
IEEE 754 half-precision floating point data type.
GLint location
Definition: glcorearb.h:805
OpenVINO Provider Options.
GLuint id
Definition: glcorearb.h:655
SessionOptionsImpl & SetInterOpNumThreads(int inter_op_num_threads)
Wraps OrtApi::SetInterOpNumThreads.
Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V...
OrtStatus *ORT_API_CALL * CreateStatus(OrtErrorCode code, _In_ const char *msg) NO_EXCEPTION ORT_ALL_ARGS_NONNULL
Create an OrtStatus from a null terminated string.
R * GetTensorMutableData()
Returns a non-const typed pointer to an OrtValue/Tensor contained buffer. No type checking is performe...
SessionOptions()
Wraps OrtApi::CreateSessionOptions.
void * GetGPUComputeStream() const
void BindOutput(const char *name, const Value &)
GLuint const GLchar * name
Definition: glcorearb.h:786
RunOptions & SetRunTag(const char *run_tag)
Wraps OrtApi::RunOptionsSetRunTag.
Allocator(std::nullptr_t)
Convenience to create a class member and then replace with an instance.
SessionOptionsImpl & EnableProfiling(const ORTCHAR_T *profile_file_prefix)
Wraps OrtApi::EnableProfiling.
size_t GetStringTensorDataLength() const
This API returns a full length of string data contained within either a tensor or a sparse Tensor...
RunOptions & SetTerminate()
Terminates all currently executing Session::Run calls that were made using this RunOptions instance...
OrtAllocatorType GetAllocatorType() const
TypeInfo GetSequenceElementType() const
Wraps OrtApi::GetSequenceElementType.
static Value CreateSparseTensor(const OrtMemoryInfo *info, T *p_data, const Shape &dense_shape, const Shape &values_shape)
This is a simple forwarding method to the other overload that helps deducing data type enum value fro...
TypeInfo GetOutputTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOutputTypeInfo.
size_t GetStringTensorElementLength(size_t element_index) const
The API returns a byte length of UTF-8 encoded string element contained in either a tensor or a sparse...
detail::ValueImpl< detail::Unowned< OrtValue >> UnownedValue
SessionOptionsImpl & Add(OrtCustomOpDomain *custom_op_domain)
Wraps OrtApi::AddCustomOpDomain.
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const
The API returns type and shape information for stored non-zero values of the sparse tensor...
GT_API const UT_StringHolder version
TypeInfo GetTypeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
#define _Out_
RunOptions & SetRunLogSeverityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogSeverityLevel.
ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
Wraps OrtApi::CastTypeInfoToTensorInfo.
TensorRT Provider Options.
bool operator==(const MemoryInfoImpl< U > &o) const
TypeInfo GetMapValueType() const
Wraps OrtApi::GetMapValueType.
RunOptions()
Wraps OrtApi::CreateRunOptions.
OrtTensorTypeAndShapeInfo * GetTensorTypeAndShape(_In_ const OrtValue *value)
This struct owns the OrtKernelInfo* pointer when a copy is made. For convenient wrapping of OrtKernelIn...
void FillSparseTensorCoo(const OrtMemoryInfo *data_mem_info, const OrtSparseValuesParam &values_param, const int64_t *indices_data, size_t indices_num)
The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API an...
std::vector< Value > GetOutputValues() const
GLsizeiptr size
Definition: glcorearb.h:664
OrtErrorCode GetErrorCode() const
const char * what() const noexcept override
Definition: Exception.h:44
struct OrtKernelInfo OrtKernelInfo
void FillStringTensor(const char *const *s, size_t s_len)
Set all strings at once in a string tensor
KernelContext(OrtKernelContext *context)
ArenaCfg(std::nullptr_t)
Create an empty ArenaCfg object, must be assigned a valid one to be used.
GT_API const UT_StringHolder st
ThreadingOptions()
Wraps OrtApi::CreateThreadingOptions.
void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr *op_attr)
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
void SetDimensions(OrtTensorTypeAndShapeInfo *info, _In_ const int64_t *dim_values, size_t dim_count)
std::string GetOutputName(size_t index) const
The ThreadingOptions.
OrtOpAttrType
static Value CreateMap(Value &keys, Value &values)
Wraps OrtApi::CreateValue.
void GetSessionConfigs(std::unordered_map< std::string, std::string > &out, ConstSessionOptions options) const
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const
Wraps OrtApi::SessionGetOverridableInitializerTypeInfo.
OrtSparseFormat
SessionOptionsImpl & AddConfigEntry(const char *config_key, const char *config_value)
Wraps OrtApi::AddSessionConfigEntry.
Memory allocation interface.
CUDA Provider Options.
OrtOp * CreateOp(_In_ const OrtKernelInfo *info, _In_ const char *op_name, _In_ const char *domain, _In_ int version, _In_opt_ const char **type_constraint_names, _In_opt_ const ONNXTensorElementDataType *type_constraint_values, _In_opt_ int type_constraint_count, _In_opt_ const OrtOpAttr *const *attr_values, _In_opt_ int attr_count, _In_ int input_count, _In_ int output_count)
SessionOptionsImpl & SetExecutionMode(ExecutionMode execution_mode)
Wraps OrtApi::SetSessionExecutionMode.
GLuint index
Definition: glcorearb.h:786
ROCM Provider Options.
AllocatedStringPtr EndProfilingAllocated(OrtAllocator *allocator)
End profiling and return a copy of the profiling file name.
SessionOptionsImpl & AppendExecutionProvider(const std::string &provider_name, const std::unordered_map< std::string, std::string > &provider_options={})
Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
SessionOptionsImpl & DisableProfiling()
Wraps OrtApi::DisableProfiling.
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator *allocator) const
Returns a copy of the graph description.
ExecutionMode
OrtValue * KernelContext_GetOutput(OrtKernelContext *context, _In_ size_t index, _In_ const int64_t *dim_values, size_t dim_count)
RunOptions & UnsetTerminate()
Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without ...
size_t GetOverridableInitializerCount() const
Returns the number of inputs that have defaults that can be overridden.
const char * GetRunTag() const
Wraps OrtApi::RunOptionsGetRunTag.
int GetRunLogVerbosityLevel() const
Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel.
Definition: core.h:1131
size_t GetOutputCount() const
Returns the number of model outputs.
SessionOptionsImpl & AddExternalInitializers(const std::vector< std::string > &names, const std::vector< Value > &ort_values)
Wraps OrtApi::AddExternalInitializers.
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const
The API returns type information for data contained in a tensor. For sparse tensors it returns type i...
ONNXTensorElementDataType
#define const
Definition: zconf.h:214
int GetRunLogSeverityLevel() const
Wraps OrtApi::RunOptionsGetRunLogSeverityLevel.
ConstMemoryInfo GetTensorMemoryInfo() const
This API returns information about the memory allocation used to hold data.
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator *allocator) const
Returns a copy of input name at the specified index.
Wrapper around ::OrtTensorTypeAndShapeInfo.
MemoryInfo(std::nullptr_t)
No instance is created.
std::string MakeCustomOpConfigEntryKey(const char *custom_op_name, const char *config)
CustomOpConfigs.
size_t GetCount() const
Returns true if OrtValue contains data and returns false if the OrtValue is a None ...
std::string GetInputName(size_t index) const
ThreadingOptions & SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn)
Wraps OrtApi::SetGlobalCustomCreateThreadFn.
Wrapper around ::OrtSessionOptions.
OrtSparseFormat GetSparseFormat() const
The API returns the sparse data format this OrtValue holds in a sparse tensor. If the sparse tensor w...
detail::MemoryInfoImpl< detail::Unowned< const OrtMemoryInfo >> ConstMemoryInfo
OrtErrorCode
void UseCooIndices(int64_t *indices_data, size_t indices_num)
Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tens...
MemoryAllocation(OrtAllocator *allocator, void *p, size_t size)
CustomOpConfigs & AddConfig(const char *custom_op_name, const char *config_key, const char *config_value)
Adds a session configuration entry/value for a specific custom operator.
void(* OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle)
Custom thread join function.
type
Definition: core.h:1059
void ThrowOnError(OrtStatus *result)
TypeInfo GetInputTypeInfo(size_t index) const
SessionOptionsImpl & SetLogSeverityLevel(int level)
Wraps OrtApi::SetSessionLogSeverityLevel.
Wrapper around OrtAllocator default instance that is owned by Onnxruntime.
Wrapper around ::OrtSession.
SessionOptionsImpl & DisableMemPattern()
Wraps OrtApi::DisableMemPattern.
SessionOptionsImpl & AppendExecutionProvider_ROCM(const OrtROCMProviderOptions &provider_options)
Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM.
Options for the CUDA provider that are passed to SessionOptionsAppendExecutionProvider_CUDA_V2. Please note that this struct is similar to OrtCUDAProviderOptions but only to be used internally. Going forward, new cuda provider options are to be supported via this struct and usage of the publicly defined OrtCUDAProviderOptions will be deprecated over time. User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
RunOptions & SetRunLogVerbosityLevel(int)
Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel.
ThreadingOptions & SetGlobalCustomThreadCreationOptions(void *ort_custom_thread_creation_options)
Wraps OrtApi::SetGlobalCustomThreadCreationOptions.
struct OrtKernelContext OrtKernelContext
bfloat16 (Brain Floating Point) data type
OrtLoggingLevel
Logging severity levels.
Class that represents session configuration entries for one or more custom operators.
size_t KernelContext_GetInputCount(const OrtKernelContext *context)
GLint GLsizei count
Definition: glcorearb.h:405
Definition: format.h:895
std::vector< std::string > GetOutputNamesHelper(const OrtIoBinding *binding, OrtAllocator *)
#define ORT_CXX_API_THROW(string, code)
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
SessionOptionsImpl & RegisterCustomOpsLibrary(const ORTCHAR_T *library_name, const CustomOpConfigs &custom_op_configs={})
T KernelInfoGetAttribute(_In_ const OrtKernelInfo *info, _In_ const char *name)
GLsizei GLenum GLenum GLuint GLenum GLsizei * lengths
Definition: glcorearb.h:2542
ConstSequenceTypeInfo GetSequenceTypeInfo() const
Wraps OrtApi::CastTypeInfoToSequenceTypeInfo.
MIGraphX Provider Options.