HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
rocm_context.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #define ORT_ROCM_CTX
5 
6 #include "rocm_resource.h"
8 #include <hip/hip_runtime.h>
9 #include <miopen/miopen.h>
10 #include <hipblas/hipblas.h>
11 
12 namespace Ort {
13 
14 namespace Custom {
15 
16 struct RocmContext : public CustomOpContext {
18  miopenHandle_t miopen_handle = {};
19  hipblasHandle_t blas_handle = {};
20 
21  void Init(const OrtKernelContext& kernel_ctx) {
22  const auto& ort_api = Ort::GetApi();
23  void* resource = {};
24  OrtStatus* status = nullptr;
25 
26  status = ort_api.KernelContext_GetResource(
27  &kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::hip_stream_t, &resource);
28  if (status) {
29  ORT_CXX_API_THROW("failed to fetch hip stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
30  }
31  hip_stream = reinterpret_cast<hipStream_t>(resource);
32 
33  resource = {};
34  status = ort_api.KernelContext_GetResource(
36  if (status) {
37  ORT_CXX_API_THROW("failed to fetch miopen handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
38  }
39  miopen_handle = reinterpret_cast<miopenHandle_t>(resource);
40 
41  resource = {};
42  status = ort_api.KernelContext_GetResource(
44  if (status) {
45  ORT_CXX_API_THROW("failed to fetch hipblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
46  }
47  blas_handle = reinterpret_cast<hipblasHandle_t>(resource);
48  }
49 };
50 
51 } // namespace Custom
52 } // namespace Ort
void Init(const OrtKernelContext &kernel_ctx)
Definition: rocm_context.h:21
miopenHandle_t miopen_handle
Definition: rocm_context.h:18
struct ihipStream_t * hipStream_t
Definition: oidn.h:25
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
Use a manually-specified time code range.
#define ORT_ROCM_RESOURCE_VERSION
Definition: rocm_resource.h:6
hipblasHandle_t blas_handle
Definition: rocm_context.h:19
#define ORT_CXX_API_THROW(string, code)