HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
graph_nodes.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <vector>
9 
10 namespace onnxruntime {
11 
12 class Node;
13 
14 /**
15 Class to filter out null entries from either a vector of unique_ptr<Node> or a vector of [const] Node* and
16 provide an iterator interface that returns [const] Node& for the valid entries.
17 */
18 template <typename TNodesContainer>
19 class ValidNodes {
20  public:
21  template <typename TIterator>
22  class NodeIterator;
23 
24  // optional filtering function to return a subset of nodes
25  using NodeFilterFunc = std::function<bool(NodeIndex)>;
26 
27  /**
28  Construct a ValidNodes instance to provide iteration over all valid nodes in the TNodesCollection
29  @param[in] nodes Nodes to iterate, skipping invalid entries.
30  */
31  explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(&nodes) {}
32 
33  explicit ValidNodes(TNodesContainer& nodes, NodeFilterFunc&& filter_node_fn) noexcept
34  : nodes_(&nodes), filter_node_fn_{std::move(filter_node_fn)} {}
35 
36  using ConstNodeIterator = NodeIterator<typename TNodesContainer::const_iterator>;
37  using MutableNodeIterator = NodeIterator<typename TNodesContainer::iterator>;
38  using ConstReverseNodeIterator = NodeIterator<typename TNodesContainer::const_reverse_iterator>;
39 
41  return {nodes_->cbegin(), nodes_->cend(), filter_node_fn_};
42  }
43 
45  return {nodes_->cend(), nodes_->cend(), filter_node_fn_};
46  }
47 
49  return cbegin();
50  }
51 
53  return cend();
54  }
55 
57  return {nodes_->crbegin(), nodes_->crend(), filter_node_fn_};
58  }
59 
61  return {nodes_->crend(), nodes_->crend(), filter_node_fn_};
62  }
63 
64  // we only allow mutable access if the container is non-const.
65  // we need to templatize the functions for enable_if to work at this level, but mandate T2 being TNodesContainer
66  template <typename T2 = TNodesContainer>
68  static_assert(std::is_same<T2, TNodesContainer>::value, "Explicit specialization is not allowed");
69  return MutableNodeIterator(nodes_->begin(), nodes_->end(), filter_node_fn_);
70  }
71 
72  template <typename T2 = TNodesContainer>
74  static_assert(std::is_same<T2, TNodesContainer>::value, "Explicit specialization is not allowed");
75  return MutableNodeIterator(nodes_->end(), nodes_->end(), filter_node_fn_);
76  }
77 
78  bool empty() const noexcept { return nodes_->empty(); }
79 
80  /**
81  @class NodeIterator
82  Iterator to provide const and non-const access to valid Node instances in a Graph.
83  @remarks Skips invalid nodes.
84  */
85  template <typename TIterator>
86  class NodeIterator {
87  // get the type being returned by the iterator. can't use TIterator::value_type as that is always non-const
88  using IterType = typename std::remove_reference<typename std::iterator_traits<TIterator>::reference>::type;
89  // and determine what we will return based on its constness
91  const Node, // return const Node if this is a const iterator
92  Node>::type; // else return Node
93 
94  public:
95  using iterator_category = std::input_iterator_tag;
96  using value_type = T;
97  using difference_type = typename TIterator::difference_type;
98  using pointer = T*;
99  using reference = T&;
100  using const_reference = const T&;
101 
102  /** Construct a NodeInterator and move to the first valid node. */
103  NodeIterator(const TIterator current, const TIterator end, const NodeFilterFunc& filter_fn) noexcept
104  : current_{current}, end_{end}, apply_filter_{filter_fn != nullptr}, filter_func_{&filter_fn} {
105  // skip to next valid node, stopping at end if none are found
106  while (current_ < end && (*current_ == nullptr ||
107  (apply_filter_ && (*filter_func_)((*current_)->Index()) == true))) {
108  ++current_;
109  }
110  }
111 
112  bool operator==(const NodeIterator<TIterator>& other) const noexcept {
113  return (current_ == other.current_);
114  }
115 
116  bool operator!=(const NodeIterator<TIterator>& other) const noexcept {
117  return (current_ != other.current_);
118  }
119 
120  void operator++() {
121  if (current_ < end_) {
122  while (++current_ != end_) {
123  if (*current_ != nullptr && (!apply_filter_ || (*filter_func_)((*current_)->Index()) == false))
124  break;
125  }
126  }
127  }
128 
130  NodeIterator<TIterator> tmp{*this};
131  ++(*this);
132 
133  return tmp;
134  }
135 
136  /** Return the current Node&. This will be const if the iterator was returned from a const GraphNodes instance. */
138  // if iterator is valid we always have a non-nullptr node
139  // if this is a nullptr we're at end_ and this shouldn't be being called
140  return **current_;
141  }
142 
144  return current_->get();
145  }
146 
147  private:
148  TIterator current_;
149  TIterator end_;
150  bool apply_filter_; // store whether filter_func_ is not nullptr and contains a callable
151  const NodeFilterFunc* filter_func_; // store as pointer so iterator is copyable
152  };
153 
154  private:
155  gsl::not_null<TNodesContainer*> nodes_; // always set by ctor
156 
157  // no filtering if not set. this instance owns the filter func if set.
158  NodeFilterFunc filter_node_fn_;
159 };
160 
161 /**
162 Class that provides iteration over all valid nodes in the Graph.
163 */
164 class GraphNodes : public ValidNodes<std::vector<std::unique_ptr<Node>>> {
165  public:
166  GraphNodes(std::vector<std::unique_ptr<Node>>& nodes) : ValidNodes(nodes) {
167  }
168 };
169 
170 // Variant that only ever allows const access to nodes and optionally allows filtering of the nodes.
171 class ConstGraphNodes : public ValidNodes<const std::vector<std::unique_ptr<Node>>> {
172  public:
173  ConstGraphNodes(const std::vector<std::unique_ptr<Node>>& nodes) : ValidNodes(nodes) {
174  }
175 
176  ConstGraphNodes(const std::vector<std::unique_ptr<Node>>& nodes,
177  GraphNodes::NodeFilterFunc&& filter_func)
178  : ValidNodes(nodes, std::move(filter_func)) {
179  }
180 };
181 
182 } // namespace onnxruntime
bool empty() const noexcept
Definition: graph_nodes.h:78
typename TIterator::difference_type difference_type
Definition: graph_nodes.h:97
Definition: Node.h:52
NodeIterator< TIterator > operator++(int)
Definition: graph_nodes.h:129
GLsizei const GLfloat * value
Definition: glcorearb.h:824
ConstNodeIterator cend() const noexcept
Definition: graph_nodes.h:44
ConstNodeIterator begin() const noexcept
Definition: graph_nodes.h:48
ConstNodeIterator cbegin() const noexcept
Definition: graph_nodes.h:40
GraphNodes(std::vector< std::unique_ptr< Node >> &nodes)
Definition: graph_nodes.h:166
NodeIterator< typename std::vector< std::unique_ptr< Node > >::const_reverse_iterator > ConstReverseNodeIterator
Definition: graph_nodes.h:38
ValidNodes(TNodesContainer &nodes) noexcept
Definition: graph_nodes.h:31
std::enable_if<!std::is_const< T2 >::value, MutableNodeIterator >::type begin() noexcept
Definition: graph_nodes.h:67
GLuint GLuint end
Definition: glcorearb.h:475
std::enable_if<!std::is_const< T2 >::value, MutableNodeIterator >::type end() noexcept
Definition: graph_nodes.h:73
bool operator!=(const NodeIterator< TIterator > &other) const noexcept
Definition: graph_nodes.h:116
ConstReverseNodeIterator rbegin() const noexcept
Definition: graph_nodes.h:56
ConstNodeIterator end() const noexcept
Definition: graph_nodes.h:52
std::input_iterator_tag iterator_category
Definition: graph_nodes.h:95
bool operator==(const NodeIterator< TIterator > &other) const noexcept
Definition: graph_nodes.h:112
ConstGraphNodes(const std::vector< std::unique_ptr< Node >> &nodes, GraphNodes::NodeFilterFunc &&filter_func)
Definition: graph_nodes.h:176
ConstGraphNodes(const std::vector< std::unique_ptr< Node >> &nodes)
Definition: graph_nodes.h:173
NodeIterator(const TIterator current, const TIterator end, const NodeFilterFunc &filter_fn) noexcept
Definition: graph_nodes.h:103
ValidNodes(TNodesContainer &nodes, NodeFilterFunc &&filter_node_fn) noexcept
Definition: graph_nodes.h:33
ConstReverseNodeIterator rend() const noexcept
Definition: graph_nodes.h:60
NodeIterator< typename std::vector< std::unique_ptr< Node > >::const_iterator > ConstNodeIterator
Definition: graph_nodes.h:36
#define const
Definition: zconf.h:214
type
Definition: core.h:1059
NodeIterator< typename std::vector< std::unique_ptr< Node > >::iterator > MutableNodeIterator
Definition: graph_nodes.h:37