HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ort_value.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <string>
7 #ifndef SHARED_PROVIDER
8 #include "core/common/common.h"
12 #include "core/framework/tensor.h"
13 
// Forward declarations so that the OrtValue accessor specializations in this
// header can name these types without including their full definitions.
namespace onnxruntime {
class TensorSeq;
#if !defined(DISABLE_SPARSE_TENSORS)
class SparseTensor;  // declared only when sparse tensor support is compiled in
#endif
}  // namespace onnxruntime
20 
21 #endif
22 
23 /**
24  Represents both tensors and non-tensors.
25 */
26 struct OrtValue {
27  public:
28  OrtValue() : data_(nullptr) {}
29  ~OrtValue() = default;
30 
32  Init(pData, type, deleter);
33  }
34 
36  data_.reset(pData, deleter);
37  type_ = type;
38  }
39 
40  void Init(void* pData, onnxruntime::MLDataType type, const std::function<void(void*)>& deleter) {
41  data_.reset(pData, deleter);
42  type_ = type;
43  }
44 
45  bool IsAllocated() const {
46  return data_ && type_;
47  }
48 
49  template <typename T>
50  const T& Get() const {
51  ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
52  return *static_cast<T*>(data_.get());
53  }
54 
55  // May return nullptr, if this OrtValue is an optional type and it is "None".
56  template <typename T>
58  ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
59  return static_cast<T*>(data_.get());
60  }
61 
62  bool IsTensor() const noexcept {
63  return (type_ != nullptr && type_->IsTensorType());
64  }
65 
66  bool IsTensorSequence() const noexcept {
67  return (type_ != nullptr && type_->IsTensorSequenceType());
68  }
69 
70  bool IsSparseTensor() const {
71 #if !defined(DISABLE_SPARSE_TENSORS)
72  return (type_ != nullptr && type_->IsSparseTensorType());
73 #else
74  ORT_THROW("Sparse tensor is not supported in this build.");
75 #endif
76  }
77 
79  return type_;
80  }
81 
82  private:
83  std::shared_ptr<void> data_;
84  onnxruntime::MLDataType type_{nullptr};
85 };
86 
87 template <>
88 inline const onnxruntime::Tensor& OrtValue::Get<onnxruntime::Tensor>() const {
89  ORT_ENFORCE(IsTensor(), "Trying to get a Tensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
90  return *static_cast<onnxruntime::Tensor*>(data_.get());
91 }
92 
93 template <>
94 inline onnxruntime::Tensor* OrtValue::GetMutable<onnxruntime::Tensor>() {
95  ORT_ENFORCE(IsTensor(), "Trying to get a Tensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
96  return static_cast<onnxruntime::Tensor*>(data_.get());
97 }
98 
99 template <>
100 inline const onnxruntime::TensorSeq& OrtValue::Get<onnxruntime::TensorSeq>() const {
101  ORT_ENFORCE(IsTensorSequence(), "Trying to get a TensorSeq, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
102  return *static_cast<onnxruntime::TensorSeq*>(data_.get());
103 }
104 
105 template <>
106 inline onnxruntime::TensorSeq* OrtValue::GetMutable<onnxruntime::TensorSeq>() {
107  ORT_ENFORCE(IsTensorSequence(), "Trying to get a TensorSeq, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
108  return static_cast<onnxruntime::TensorSeq*>(data_.get());
109 }
110 
111 #if !defined(DISABLE_SPARSE_TENSORS)
112 template <>
113 inline const onnxruntime::SparseTensor& OrtValue::Get<onnxruntime::SparseTensor>() const {
114  ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
115  return *static_cast<onnxruntime::SparseTensor*>(data_.get());
116 }
117 
118 template <>
119 inline onnxruntime::SparseTensor* OrtValue::GetMutable<onnxruntime::SparseTensor>() {
120  ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
121  return static_cast<onnxruntime::SparseTensor*>(data_.get());
122 }
123 #endif
onnxruntime::MLDataType Type() const
Definition: ort_value.h:78
static const char * ToString(MLDataType type)
bool IsTensorSequenceType() const
Definition: data_types.h:127
Base class for MLDataType.
Definition: data_types.h:81
bool IsTensor() const noexcept
Definition: ort_value.h:62
~OrtValue()=default
OrtValue()
Definition: ort_value.h:28
void Init(void *pData, onnxruntime::MLDataType type, const std::function< void(void *)> &deleter)
Definition: ort_value.h:40
void Init(void *pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter)
Definition: ort_value.h:35
#define ORT_ENFORCE(condition,...)
Definition: common.h:173
bool IsTensorSequence() const noexcept
Definition: ort_value.h:66
This class implements SparseTensor. This class holds sparse non-zero data (values) and sparse format ...
Definition: sparse_tensor.h:55
void(*)(void *) DeleteFunc
Definition: data_types.h:74
bool IsSparseTensor() const
Definition: ort_value.h:70
bool IsTensorType() const
Definition: data_types.h:123
T * GetMutable()
Definition: ort_value.h:57
#define ORT_THROW(...)
Definition: common.h:163
OrtValue(void *pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter)
Definition: ort_value.h:31
bool IsSparseTensorType() const
Definition: data_types.h:131
#define const
Definition: zconf.h:214
type
Definition: core.h:1059
const T & Get() const
Definition: ort_value.h:50
bool IsAllocated() const
Definition: ort_value.h:45