HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
int4.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <cassert>
7 #include <type_traits>
8 #include "core/common/common.h"
9 #include <gsl/gsl>
10 
11 namespace onnxruntime {
12 
/// Compile-time properties of a 4-bit integer, selected by signedness.
/// Only the two explicit specializations below are defined.
template <bool Signed>
struct Int4Traits;

/// Signed int4: held unpacked in an int8_t, two's-complement range [-2^3, 2^3 - 1] = [-8, 7].
template <>
struct Int4Traits<true> {
  using UnpackedType = int8_t;
  static constexpr UnpackedType min_val{-(1 << 3)};
  static constexpr UnpackedType max_val{(1 << 3) - 1};
};

/// Unsigned uint4: held unpacked in a uint8_t, range [0, 2^4 - 1] = [0, 15].
template <>
struct Int4Traits<false> {
  using UnpackedType = uint8_t;
  static constexpr UnpackedType min_val{0};
  static constexpr UnpackedType max_val{(1 << 4) - 1};
};
29 
30 /// <summary>
31 /// Stores 2 packed 4-bit elements in 1 byte.
32 /// </summary>
33 /// <typeparam name="Signed">Set to true if signed int4, or false if unsigned uint4.</typeparam>
34 template <bool Signed>
35 struct Int4x2Base {
39 
41 
42  Int4x2Base() = default;
43 
44  explicit Int4x2Base(std::byte bits) {
45  bits_ = bits;
46  }
47 
49  bits_ = static_cast<std::byte>(((val1 & 0xF) << 4) | (val0 & 0xF));
50  }
51 
52  static inline int8_t SignExtendLower4Bits(std::byte bits) {
53  // Sign-extend lower 4-bits by left shifting and then doing an arithmetic right shift.
54  constexpr uint8_t shift = (sizeof(int32_t) * 8) - 4;
55  return static_cast<int8_t>((static_cast<int32_t>(bits) << shift) >> shift);
56  }
57 
58  inline UnpackedType GetElem(size_t index) const {
59  assert(index <= 1);
60  const uint8_t shift = 4 * static_cast<uint8_t>(index);
61  const std::byte val = (bits_ >> shift) & std::byte{0xF};
62 
63  if constexpr (Signed) {
64  return SignExtendLower4Bits(val);
65  } else {
66  return static_cast<UnpackedType>(val);
67  }
68  }
69 
70  inline void SetElem(size_t index, UnpackedType val) {
71  assert(index <= 1);
72  const uint8_t shift = 4 * static_cast<uint8_t>(index);
73  const std::byte mask = std::byte{0xF0} >> shift;
74 
75  bits_ &= mask; // Clear 4-bit element to 0
76  bits_ |= static_cast<std::byte>((val & 0xF) << shift); // Set 4-bit element to val
77  }
78 
79  inline std::byte ToBits() const {
80  return bits_;
81  }
82 
83  static size_t CalcNumInt4Pairs(size_t num_int4_elems) {
84  return (num_int4_elems + 1) / 2;
85  }
86 
87  /// <summary>
88  /// Copy a source buffer of 4-bit elements (packed) into a destination buffer of 8-bit elements (unpacked).
89  /// </summary>
90  /// <param name="dst">Destination buffer to store unpacked 8-bit elements</param>
91  /// <param name="src">Source buffer with 4-bit elements</param>
92  /// <returns>True on success</returns>
93  static bool Unpack(gsl::span<UnpackedType> dst, gsl::span<const Int4x2Base<Signed>> src) {
94  if (CalcNumInt4Pairs(dst.size()) != src.size()) {
95  return false;
96  }
97 
98  if (src.empty()) {
99  return true;
100  }
101 
102  for (size_t i = 0; i < dst.size(); i++) {
103  size_t r = i >> 1; // i / 2;
104  size_t c = i & 0x1; // i % 2;
105  dst[i] = src[r].GetElem(c);
106  }
107 
108  return true;
109  }
110 
111  /// <summary>
112  /// Copy a source buffer of 8-bit elements (unpacked) into a destination buffer of 4-bit elements (packed).
113  /// </summary>
114  /// <param name="dst">Destination buffer to store packed 4-bit elements</param>
115  /// <param name="src">Source buffer with 8-bit elements</param>
116  /// <returns>True on success</returns>
117  static bool Pack(gsl::span<Int4x2Base<Signed>> dst, gsl::span<const UnpackedType> src) {
118  if (CalcNumInt4Pairs(src.size()) != dst.size()) {
119  return false;
120  }
121 
122  if (src.empty()) {
123  return true;
124  }
125 
126  size_t src_i = 0;
127  size_t dst_i = 0;
128 
129  for (; src_i < src.size() - 1; src_i += 2) {
130  dst[dst_i++] = Int4x2Base<Signed>(src[src_i], src[src_i + 1]);
131  }
132 
133  if (src_i < src.size()) {
134  dst[dst_i] = Int4x2Base<Signed>(src[src_i], 0);
135  }
136 
137  return true;
138  }
139 
140  /// <summary>
141  /// Returns hierarchical indices for a packed int4 element from the given element index.
142  ///
143  /// Usage:
144  /// Int4x2* data = ...;
145  /// auto indices = GetTensorElemIndices(3); // 4th int4 element
146  /// int8_t elem = data[indices.first].GetElem(indices.second);
147  /// </summary>
148  /// <param name="index">Index of 4-bit element</param>
149  /// <returns>Unpacked element</returns>
150  static inline std::pair<size_t, size_t> GetTensorElemIndices(size_t index) {
151  return {index >> 1, index & 0x1};
152  }
153 };
154 
157 static_assert(sizeof(Int4x2) == sizeof(std::byte));
158 static_assert(sizeof(UInt4x2) == sizeof(std::byte));
159 } // namespace onnxruntime
std::byte ToBits() const
Definition: int4.h:79
UnpackedType GetElem(size_t index) const
Definition: int4.h:58
static std::pair< size_t, size_t > GetTensorElemIndices(size_t index)
Returns hierarchical indices for a packed int4 element from the given element index.
Definition: int4.h:150
static int8_t SignExtendLower4Bits(std::byte bits)
Definition: int4.h:52
std::byte bits_
Definition: int4.h:40
static bool Pack(gsl::span< Int4x2Base< Signed >> dst, gsl::span< const UnpackedType > src)
Copy a source buffer of 8-bit elements (unpacked) into a destination buffer of 4-bit elements (packed).
Definition: int4.h:117
static constexpr UnpackedType max_val
Definition: int4.h:38
Stores 2 packed 4-bit elements in 1 byte.
Definition: int4.h:35
GLint GLuint mask
Definition: glcorearb.h:124
static bool Unpack(gsl::span< UnpackedType > dst, gsl::span< const Int4x2Base< Signed >> src)
Copy a source buffer of 4-bit elements (packed) into a destination buffer of 8-bit elements (unpacked).
Definition: int4.h:93
Int4x2Base(UnpackedType val0, UnpackedType val1)
Definition: int4.h:48
static constexpr UnpackedType min_val
Definition: int4.h:37
GLenum GLenum dst
Definition: glcorearb.h:1793
void SetElem(size_t index, UnpackedType val)
Definition: int4.h:70
static size_t CalcNumInt4Pairs(size_t num_int4_elems)
Definition: int4.h:83
GLuint index
Definition: glcorearb.h:786
GLuint GLfloat * val
Definition: glcorearb.h:1608
Int4x2Base(std::byte bits)
Definition: int4.h:44
GLboolean r
Definition: glcorearb.h:1222
typename Int4Traits< Signed >::UnpackedType UnpackedType
Definition: int4.h:36
GLenum GLenum GLsizei void GLsizei void void * span
Definition: glad.h:5135
GLenum src
Definition: glcorearb.h:1793
unsigned char byte
Definition: UT_Span.h:163