HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
stream_handles.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 #pragma once
4 
#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <unordered_map>

#include "core/common/status.h"
#include "core/framework/allocator.h"
#include "core/framework/ortdevice.h"
10 
11 namespace onnxruntime {
// Forward declarations needed by the stream abstractions declared below.
class IExecutionProvider;
namespace synchronize {
class Notification;
}  // namespace synchronize

// Opaque handle produced by the target device for a notification primitive,
// for example a CUDA event or an NPU notification implementation.
using NotificationHandle = void*;

// Opaque handle produced by the target device for a stream, for example a CUDA
// stream. Devices without stream support (such as the CPU) may use nullptr.
using StreamHandle = void*;
22 
23 // a stream abstraction which hold an opaque handle, and a reference to which OrtDevice instance this stream belong to.
24 // it need to be OrtDevice instance as we might have different stream on different OrtDevice with same type.
25 // i.e. different cuda stream on different GPU.
26 class Stream {
27  public:
28  Stream(StreamHandle h, const OrtDevice& d) : handle_(h), device_(d) {}
29 
30  virtual ~Stream() = default;
31  virtual std::unique_ptr<synchronize::Notification> CreateNotification(size_t /*num_consumers*/) {
32  return {};
33  };
34  // block the host thread until all the tasks in the stream finished.
35  virtual void Flush(){};
36  // The framework may reuse the stream instance for multiple iterations.
37  // This is the API that provide a chance to let the device stream cleanup
38  // resource at the end of a iteration.
39  virtual Status CleanUpOnRunEnd() { return Status::OK(); };
40 
41  StreamHandle GetHandle() const { return handle_; }
42 
43  const OrtDevice& GetDevice() const { return device_; }
44 
45  // We use the timestamp based vector clocks to optimize the resource sharing
46  // between different streams.
47  // Each stream maintain following data structure:
48  // 1. Current timestamp
49  // 2. A lookup table that for a given stream, what is its timestamp when the
50  // last synchronization happened with current stream.
51  // 3. When a notification is activated, it take a snapshot of current stream's
52  // lookup table.
53  // 4. When synchronization happened (current stream wait on a notification),
54  // update its lookup table with the table snapshot in notification.
55  // The memory reusing strategy is:
56  // A kernel in current stream is safe to reuse another stream's memory chunk
57  // as long as the reused chunk's timestamp is less than the last synchonized
58  // timestamp recorded in the lookup table.
59 
60  // Get the current timestamp
61  uint64_t GetCurrentTimestamp() const { return timestamp_; }
62 
63  // return the timestamp when the last synchronization happened between target stream and current stream.
64  // return 0 if no synchonization happened.
65  // if target_stream is nullptr, it means it is a sequence running on device doesn't support Stream (i.e. CPU)
66  // we can safely return 0 in that case to save a lookup.
67  uint64_t GetLastSyncTimestampWithTargetStream(Stream* target_stream) const {
68  if (!target_stream)
69  return 0;
70  auto it = other_stream_clock_.find(target_stream);
71  return it == other_stream_clock_.end() ? 0 : it->second;
72  }
73 
74  // make a copy of the current stream lookup table.
75  // this is used to create a snapshot of the stream lookup table in notification.
76  void CloneCurrentStreamSyncTable(std::unordered_map<Stream*, uint64_t>& output) const {
77  output.reserve(other_stream_clock_.size());
78  output.insert(other_stream_clock_.begin(), other_stream_clock_.end());
79  }
80 
81  // bump the current timestamp
82  // When a notification get activated, bump the snapshot in its owner.
83  // Stream is not shared across threads, BumpTimeStampAndReturn will only be invoked on the current thread
84  // where the stream is executed on, so there is no race condition.
86  return ++timestamp_;
87  }
88 
89  // update the stream lookup table with the snapshot saved in notification.
90  void UpdateStreamClock(const std::unordered_map<Stream*, uint64_t>& clock) {
91  for (const auto& kv : clock) {
92  auto ret = other_stream_clock_.insert(kv);
93  if (!ret.second) {
94  ret.first->second = std::max(ret.first->second, kv.second);
95  }
96  }
97  }
98 
99  private:
100  StreamHandle handle_;
101  const OrtDevice& device_;
102  uint64_t timestamp_{0};
103  // TODO: use inline container.
104  // currently this class is header only, but abseil doesn't compile with nvcc
105  // we need to add new symbol to provider_bridge and hide abseil from the header.
106  std::unordered_map<Stream*, uint64_t> other_stream_clock_{};
107 
108  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Stream);
109 };
110 
111 namespace synchronize {
112 // an object which record the status of the stream, and can be wait on from another stream.
114  public:
115  explicit Notification(Stream& s) : stream_(s) {}
116  virtual ~Notification() = default;
117 
118  // this api will perform three operations:
119  // 1. activate the notification on device, for example, record an event on GPU.
120  // 2. take a snapshot of the timestamp lookup table in current stream.
121  // 3. bump the timestamp for current stream.
123  Activate();
126  }
127 
128  // return the timestamp lookup table saved in the notification.
129  const std::unordered_map<Stream*, uint64_t>& GetStreamSyncTable() {
130  return stream_clock_;
131  }
132 
133  protected:
134  virtual void Activate() = 0;
135  // which stream create this notification.
137  // TODO: use inline container.
138  // currently this class is header only, but abseil doesn't compile with nvcc
139  // we need to add new symbol to provider_bridge and hide abseil from the header.
140  std::unordered_map<Stream*, uint64_t> stream_clock_{};
141 };
142 } // namespace synchronize
143 
144 // the definition for the handle for stream commands
145 // EP can register the handle to the executor.
146 // in the POC, just use primitive function pointer
147 // TODO: use a better way to dispatch handles.
148 using CreateStreamFn = std::function<std::unique_ptr<Stream>(const OrtDevice&)>;
149 
150 // an interface of a simple registry which hold the handles EP registered.
151 // make it interface so we can pass it through shared library based execution providers
153  public:
154  virtual ~IStreamCommandHandleRegistry() = default;
155  // Wait is a little special as we need to consider the source stream the notification generated, and the stream we are waiting.
156  // i.e., for an cuda event what notify the memory copy, it could be wait on a CPU stream, or on another cuda stream.
157  [[nodiscard]] virtual WaitNotificationFn GetWaitHandle(OrtDevice::DeviceType notification_ower_device_type,
158  OrtDevice::DeviceType executor_device_type) const = 0;
159  // Get the stream creation function registered on the given device type.
160  [[nodiscard]] virtual CreateStreamFn GetCreateStreamFn(OrtDevice::DeviceType execution_device_type) const = 0;
161  // register a wait methond which will be invoked when we wait a notification (created by 'notification_device_type' device) on a stream at 'device_type' device.
162  virtual void RegisterWaitFn(OrtDevice::DeviceType notification_device_type,
163  OrtDevice::DeviceType device_type,
164  WaitNotificationFn fn) = 0;
165  // register a handle about how to create stream on given device type.
166  virtual void RegisterCreateStreamFn(OrtDevice::DeviceType device_type, CreateStreamFn f) = 0;
167 };
168 
169 
170 } // namespace onnxruntime
virtual CreateStreamFn GetCreateStreamFn(OrtDevice::DeviceType execution_device_type) const =0
virtual std::unique_ptr< synchronize::Notification > CreateNotification(size_t)
void UpdateStreamClock(const std::unordered_map< Stream *, uint64_t > &clock)
virtual void RegisterWaitFn(OrtDevice::DeviceType notification_device_type, OrtDevice::DeviceType device_type, WaitNotificationFn fn)=0
virtual WaitNotificationFn GetWaitHandle(OrtDevice::DeviceType notification_ower_device_type, OrtDevice::DeviceType executor_device_type) const =0
uint64_t GetLastSyncTimestampWithTargetStream(Stream *target_stream) const
std::function< void(Stream &, synchronize::Notification &)> WaitNotificationFn
Definition: allocator.h:54
GLdouble s
Definition: glad.h:3009
const std::unordered_map< Stream *, uint64_t > & GetStreamSyncTable()
int8_t DeviceType
Definition: ortdevice.h:10
uint64_t BumpTimeStampAndReturn()
void CloneCurrentStreamSyncTable(std::unordered_map< Stream *, uint64_t > &output) const
GLfloat f
Definition: glcorearb.h:1926
const OrtDevice & GetDevice() const
StreamHandle GetHandle() const
std::unordered_map< Stream *, uint64_t > stream_clock_
void * StreamHandle
virtual Status CleanUpOnRunEnd()
GLfloat GLfloat GLfloat GLfloat h
Definition: glcorearb.h:2002
Stream(StreamHandle h, const OrtDevice &d)
ImageBuf OIIO_API max(Image_or_Const A, Image_or_Const B, ROI roi={}, int nthreads=0)
std::function< std::unique_ptr< Stream >(const OrtDevice &)> CreateStreamFn
virtual void RegisterCreateStreamFn(OrtDevice::DeviceType device_type, CreateStreamFn f)=0
uint64_t GetCurrentTimestamp() const
virtual void Flush()
void * NotificationHandle
virtual ~Stream()=default