HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
cuda_context.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 // This header is to expose a context for cuda custom ops.
5 // Through this context, a custom CUDA operator can fetch existing resources,
6 // such as the CUDA stream and the cuDNN handle, for reuse.
7 
8 // For concrete usage, please see the page here:
9 // https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#custom-ops-for-cuda-and-rocm
10 
11 #pragma once
12 
13 #define ORT_CUDA_CTX
14 
15 #include <cuda.h>
16 #include <cuda_runtime.h>
17 #ifndef USE_CUDA_MINIMAL
18 #include <cublas_v2.h>
19 #include <cudnn.h>
20 #endif
21 
24 
25 namespace Ort {
26 
27 namespace Custom {
28 
29 struct CudaContext : public CustomOpContext {
// Execution context exposed to CUDA custom ops. Carries the CUDA EP's
// stream, library handles, a deferred CPU allocator, and selected EP
// options; all members are populated by Init() from the kernel context.
// NOTE(review): this Doxygen dump omits source line 30 here — the
// `cudaStream_t cuda_stream` member assigned in Init() (the footer
// cross-reference places its definition at cuda_context.h:30).
31  cudnnHandle_t cudnn_handle = {};
32  cublasHandle_t cublas_handle = {};
// Allocator used by AllocDeferredCpuMem()/FreeDeferredCpuMem() below.
33  OrtAllocator* deferred_cpu_allocator = {};
34  // below are cuda ep options
35  int16_t device_id = 0;
36  int32_t arena_extend_strategy = 0;
// NOTE(review): source lines 37-40 are missing from this dump. Judging by
// the assignments in Init() they declare cudnn_conv_algo_search,
// cudnn_conv1d_pad_to_nc1d and enable_skip_layer_norm_strict_mode —
// recover them from the original header.
41  bool prefer_nhwc = false;
42  bool use_tf32 = true;
43  bool fuse_conv_bias = true;
// Populate every member of this context by querying the CUDA EP's
// resources and option values from the given kernel context. Each fetch
// may throw (see FetchResource below).
45  void Init(const OrtKernelContext& kernel_ctx) {
46  cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
47  cudnn_handle = FetchResource<cudnnHandle_t>(kernel_ctx, CudaResource::cudnn_handle_t);
48  cublas_handle = FetchResource<cublasHandle_t>(kernel_ctx, CudaResource::cublas_handle_t);
49  deferred_cpu_allocator = FetchResource<OrtAllocator*>(kernel_ctx, CudaResource::deferred_cpu_allocator_t);
50 
// EP configuration options follow.
51  device_id = FetchResource<int16_t>(kernel_ctx, CudaResource::device_id_t);
52  arena_extend_strategy = FetchResource<int32_t>(kernel_ctx, CudaResource::arena_extend_strategy_t);
53  cudnn_conv_algo_search = FetchResource<int32_t>(kernel_ctx, CudaResource::cudnn_conv_algo_search_t);
55 
56  cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
// NOTE(review): the Doxygen dump dropped source lines 54 and 58 here; the
// next statement is cut mid-call — its missing arguments are presumably
// `kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);`.
// Recover from the original header before building.
57  enable_skip_layer_norm_strict_mode = FetchResource<bool>(
59  prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
60  use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
61  fuse_conv_bias = FetchResource<bool>(kernel_ctx, CudaResource::fuse_conv_bias_t);
62  }
63 
64  template <typename T>
65  T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
66  if constexpr (sizeof(T) > sizeof(void*)) {
67  ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type),
68  OrtErrorCode::ORT_INVALID_ARGUMENT);
69  }
70  const auto& ort_api = Ort::GetApi();
71  void* resource = {};
72  OrtStatus* status = ort_api.KernelContext_GetResource(
73  &kernel_ctx, ORT_CUDA_RESOURCE_VERSION, resource_type, &resource);
74  if (status) {
75  ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resource type: " + std::to_string(resource_type),
76  OrtErrorCode::ORT_RUNTIME_EXCEPTION);
77  }
78  T t = {};
79  memcpy(&t, &resource, sizeof(T));
80  return t;
81  }
82 
83  void* AllocDeferredCpuMem(size_t size) const {
84  if (0 == size) {
85  return {};
86  }
87  const auto& ort_api = Ort::GetApi();
88  void* mem = {};
89  auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
90  if (status) {
91  ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
92  }
93  return mem;
94  }
95 
96  void FreeDeferredCpuMem(void* mem) const {
97  if (mem) {
98  const auto& ort_api = Ort::GetApi();
99  auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
100  if (status) {
101  ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
102  }
103  }
104  }
105 };
106 
107 } // namespace Custom
108 } // namespace Ort
auto to_string(const T &value) -> std::string
Definition: format.h:4527
cudaStream_t cuda_stream
Definition: cuda_context.h:30
OrtAllocator * deferred_cpu_allocator
Definition: cuda_context.h:33
void FreeDeferredCpuMem(void *mem) const
Definition: cuda_context.h:96
CudaResource
Definition: cuda_resource.h:8
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
void Init(const OrtKernelContext &kernel_ctx)
Definition: cuda_context.h:45
void * AllocDeferredCpuMem(size_t size) const
Definition: cuda_context.h:83
T FetchResource(const OrtKernelContext &kernel_ctx, CudaResource resource_type)
Definition: cuda_context.h:65
Use a manually-specified time code range.
cudnnHandle_t cudnn_handle
Definition: cuda_context.h:31
struct CUstream_st * cudaStream_t
Definition: oidn.h:24
GLdouble t
Definition: glad.h:2397
#define ORT_CUDA_RESOURCE_VERSION
Definition: cuda_resource.h:6
GLsizeiptr size
Definition: glcorearb.h:664
cublasHandle_t cublas_handle
Definition: cuda_context.h:32
#define ORT_CXX_API_THROW(string, code)