HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ShaderGraph.h
Go to the documentation of this file.
1 //
2 // Copyright Contributors to the MaterialX Project
3 // SPDX-License-Identifier: Apache-2.0
4 //
5 
6 #ifndef MATERIALX_SHADERGRAPH_H
7 #define MATERIALX_SHADERGRAPH_H
8 
9 /// @file
10 /// Shader graph class
11 
13 
19 
20 #include <MaterialXCore/Document.h>
21 #include <MaterialXCore/Node.h>
22 
24 
25 class Syntax;
26 class ShaderGraphEdge;
28 class GenOptions;
29 
30 /// An internal input socket in a shader graph,
31 /// used for connecting internal nodes to the outside
33 
34 /// An internal output socket in a shader graph,
35 /// used for connecting internal nodes to the outside
37 
38 /// A shared pointer to a shader graph; this is the type returned by ShaderGraph::create().
39 using ShaderGraphPtr = shared_ptr<class ShaderGraph>;
40 
41 /// @class ShaderGraph
42 /// Class representing a graph (DAG) for shader generation. The graph's input sockets are stored as this node's outputs and its output sockets as this node's inputs, so internal nodes can connect to them.
44 {
45  public:
46  /// Constructor.
47  ShaderGraph(const ShaderGraph* parent, const string& name, ConstDocumentPtr document, const StringSet& reservedWords);
48 
49  /// Destructor.
50  virtual ~ShaderGraph() { }
51 
52  /// Create a new shader graph from an element.
53  /// Supported elements are outputs and shader nodes.
54  static ShaderGraphPtr create(const ShaderGraph* parent, const string& name, ElementPtr element,
55  GenContext& context);
56 
57  /// Create a new shader graph from a nodegraph.
58  static ShaderGraphPtr create(const ShaderGraph* parent, const NodeGraph& nodeGraph,
59  GenContext& context);
60 
61  /// Return true if this node is a graph.
62  bool isAGraph() const override { return true; }
63 
64  /// Get an internal node by name
65  ShaderNode* getNode(const string& name);
66 
67  /// Get an internal node by name
68  const ShaderNode* getNode(const string& name) const;
69 
70  /// Get a vector of all nodes in order (the order maintained in _nodeOrder, e.g. after topologicalSort())
71  const vector<ShaderNode*>& getNodes() const { return _nodeOrder; }
72 
73  /// Get number of input sockets (equal to this node's output count)
74  size_t numInputSockets() const { return numOutputs(); }
75 
76  /// Get number of output sockets (equal to this node's input count)
77  size_t numOutputSockets() const { return numInputs(); }
78 
79  /// Get socket by index. Input sockets are looked up among this node's outputs, output sockets among its inputs.
82  const ShaderGraphInputSocket* getInputSocket(size_t index) const { return getOutput(index); }
83  const ShaderGraphOutputSocket* getOutputSocket(size_t index = 0) const { return getInput(index); }
84 
85  /// Get socket by name (same input/output inversion as the index overloads)
86  ShaderGraphInputSocket* getInputSocket(const string& name) { return getOutput(name); }
87  ShaderGraphOutputSocket* getOutputSocket(const string& name) { return getInput(name); }
88  const ShaderGraphInputSocket* getInputSocket(const string& name) const { return getOutput(name); }
89  const ShaderGraphOutputSocket* getOutputSocket(const string& name) const { return getInput(name); }
90 
91  /// Get vector of sockets: input sockets are this node's ordered outputs, output sockets its ordered inputs
92  const vector<ShaderGraphInputSocket*>& getInputSockets() const { return _outputOrder; }
93  const vector<ShaderGraphOutputSocket*>& getOutputSockets() const { return _inputOrder; }
94 
95  /// Apply color and unit transforms to each input of a node.
96  void applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNode, GenContext& context);
97 
98  /// Create a new node in the graph
99  ShaderNode* createNode(ConstNodePtr node, GenContext& context);
100 
101  /// Add an input socket with the given name and type. The const TypeDesc* overload is deprecated; it dereferences type, so type must not be null.
102  ShaderGraphInputSocket* addInputSocket(const string& name, TypeDesc type);
103  [[deprecated]] ShaderGraphInputSocket* addInputSocket(const string& name, const TypeDesc* type) { return addInputSocket(name, *type); }
104 
105  /// Add an output socket with the given name and type. The const TypeDesc* overload is deprecated; it dereferences type, so type must not be null.
106  ShaderGraphOutputSocket* addOutputSocket(const string& name, TypeDesc type);
107  [[deprecated]] ShaderGraphOutputSocket* addOutputSocket(const string& name, const TypeDesc* type) { return addOutputSocket(name, *type); }
108 
109  /// Add a default geometric node and connect to the given input.
110  void addDefaultGeomNode(ShaderInput* input, const GeomPropDef& geomprop, GenContext& context);
111 
112  /// Sort the nodes in topological order.
113  void topologicalSort();
114 
115  /// Return an iterator for traversal upstream from the given output
116  static ShaderGraphEdgeIterator traverseUpstream(ShaderOutput* output);
117 
118  /// Return the map of unique identifiers used in the scope of this graph.
119  IdentifierMap& getIdentifierMap() { return _identifiers; }
120 
121  protected:
122  /// Create node connections corresponding to the connection between a pair of elements.
123  /// @param downstreamElement Element representing the node to connect to.
124  /// @param upstreamElement Element representing the node to connect from
125  /// @param connectingElement If non-null, specifies the element on the downstream node to connect to.
126  /// @param context Context for generation.
127  void createConnectedNodes(const ElementPtr& downstreamElement,
128  const ElementPtr& upstreamElement,
129  ElementPtr connectingElement,
130  GenContext& context);
131 
132  /// Add a node to the graph
133  void addNode(ShaderNodePtr node);
134 
135  /// Add input sockets from an interface element (nodedef, nodegraph or node)
136  void addInputSockets(const InterfaceElement& elem, GenContext& context);
137 
138  /// Add output sockets from an interface element (nodedef, nodegraph or node)
139  void addOutputSockets(const InterfaceElement& elem, GenContext& context);
140 
141  /// Traverse from the given root element and add all dependencies upstream.
142  /// NOTE(review): older docs said the traversal uses a material context to include bind inputs,
143  /// but no material is passed to this function; confirm against the implementation whether that still applies.
144  void addUpstreamDependencies(const Element& root, GenContext& context);
145 
146  /// Add a color transform node and connect to the given input.
147  void addColorTransformNode(ShaderInput* input, const ColorSpaceTransform& transform, GenContext& context);
148 
149  /// Add a color transform node and connect to the given output.
150  void addColorTransformNode(ShaderOutput* output, const ColorSpaceTransform& transform, GenContext& context);
151 
152  /// Add a unit transform node and connect to the given input.
153  void addUnitTransformNode(ShaderInput* input, const UnitTransform& transform, GenContext& context);
154 
155  /// Add a unit transform node and connect to the given output.
156  void addUnitTransformNode(ShaderOutput* output, const UnitTransform& transform, GenContext& context);
157 
158  /// Perform all post-build operations on the graph.
159  void finalize(GenContext& context);
160 
161  /// Optimize the graph, removing redundant paths.
162  void optimize();
163 
164  /// Bypass a node for a particular input and output,
165  /// effectively connecting the input's upstream connection
166  /// with the output's downstream connections.
167  void bypass(ShaderNode* node, size_t inputIndex, size_t outputIndex = 0);
168 
169  /// For inputs and outputs in the graph set the variable names to be used
170  /// in generated code. Making sure variable names are valid and unique
171  /// to avoid name conflicts during shader generation.
172  void setVariableNames(GenContext& context);
173 
174  /// Populate the color transform map for the given shader port, if the provided combination of
175  /// source and target color spaces are supported for its data type.
176  void populateColorTransformMap(ColorManagementSystemPtr colorManagementSystem, ShaderPort* shaderPort,
177  const string& sourceColorSpace, const string& targetColorSpace, bool asInput);
178 
179  /// Populates the appropriate unit transform map if the provided input/parameter or output
180  /// has a unit attribute and is of the supported type
181  void populateUnitTransformMap(UnitSystemPtr unitSystem, ShaderPort* shaderPort, ValueElementPtr element, const string& targetUnitSpace, bool asInput);
182 
183  /// Break all connections on a node
184  void disconnect(ShaderNode* node) const;
185 
  // _nodeMap holds the graph's nodes by name through shared pointers;
  // _nodeOrder holds raw pointers to those nodes in the order returned by getNodes().
187  std::unordered_map<string, ShaderNodePtr> _nodeMap;
188  std::vector<ShaderNode*> _nodeOrder;
190 
191  // Temporary storage for inputs that require color transformations
192  std::unordered_map<ShaderInput*, ColorSpaceTransform> _inputColorTransformMap;
193  // Temporary storage for inputs that require unit transformations
194  std::unordered_map<ShaderInput*, UnitTransform> _inputUnitTransformMap;
195 
196  // Temporary storage for outputs that require color transformations
197  std::unordered_map<ShaderOutput*, ColorSpaceTransform> _outputColorTransformMap;
198  // Temporary storage for outputs that require unit transformations
199  std::unordered_map<ShaderOutput*, UnitTransform> _outputUnitTransformMap;
200 };
201 
202 /// @class ShaderGraphEdge
203 /// An edge returned during shader graph traversal, linking an upstream output to a downstream input.
205 {
206  public:
208  upstream(up),
209  downstream(down)
210  {
211  }
212 
213  bool operator==(const ShaderGraphEdge& rhs) const
214  {
215  return upstream == rhs.upstream && downstream == rhs.downstream;
216  }
217 
218  bool operator!=(const ShaderGraphEdge& rhs) const
219  {
220  return !(*this == rhs);
221  }
222 
  /// Strict weak ordering: compares the (upstream, downstream) pointer pair lexicographically,
  /// which lets edges be stored in ordered containers such as std::set.
223  bool operator<(const ShaderGraphEdge& rhs) const
224  {
225  return std::tie(upstream, downstream) < std::tie(rhs.upstream, rhs.downstream);
226  }
227 
230 };
231 
232 /// @class ShaderGraphEdgeIterator
233 /// Iterator class for traversing edges between nodes in a shader graph, moving upstream from a starting output.
235 {
236  public:
239 
  /// Two iterators are equal when their current edge and traversal stack match;
  /// the path and visited-edge sets are not compared.
240  bool operator==(const ShaderGraphEdgeIterator& rhs) const
241  {
242  return _upstream == rhs._upstream &&
243  _downstream == rhs._downstream &&
244  _stack == rhs._stack;
245  }
246  bool operator!=(const ShaderGraphEdgeIterator& rhs) const
247  {
248  return !(*this == rhs);
249  }
250 
251  /// Dereference this iterator, returning the current edge (upstream output, downstream input) in the traversal.
253  {
254  return ShaderGraphEdge(_upstream, _downstream);
255  }
256 
257  /// Iterate to the next edge in the traversal.
258  /// @throws ExceptionFoundCycle if a cycle is encountered.
259  ShaderGraphEdgeIterator& operator++();
260 
261  /// Return a reference to this iterator to begin traversal
263  {
264  return *this;
265  }
266 
267  /// Return the end iterator.
268  static const ShaderGraphEdgeIterator& end();
269 
270  private:
271  void extendPathUpstream(ShaderOutput* upstream, ShaderInput* downstream);
272  void returnPathDownstream(ShaderOutput* upstream);
273  bool skipOrMarkAsVisited(ShaderGraphEdge);
274 
  // Current edge, as returned by operator*().
275  ShaderOutput* _upstream;
276  ShaderInput* _downstream;
  // Traversal stack; NOTE(review): the size_t is presumably the index of the next input to visit
  // on that output's node; confirm in the implementation.
277  using StackFrame = std::pair<ShaderOutput*, size_t>;
278  std::vector<StackFrame> _stack;
  // Presumably the outputs on the current path (for cycle detection) and the edges already
  // visited (see skipOrMarkAsVisited()); confirm in the implementation.
279  std::set<ShaderOutput*> _path;
280  std::set<ShaderGraphEdge> _visitedEdges;
281 };
282 
284 
285 #endif
std::unordered_map< ShaderInput *, UnitTransform > _inputUnitTransformMap
Definition: ShaderGraph.h:194
ShaderGraphOutputSocket * addOutputSocket(const string &name, const TypeDesc *type)
Definition: ShaderGraph.h:107
ShaderGraphEdge operator*() const
Dereference this iterator, returning the current edge in the traversal.
Definition: ShaderGraph.h:252
ShaderGraphInputSocket * getInputSocket(size_t index)
Get socket by index.
Definition: ShaderGraph.h:80
friend class ShaderGraph
Definition: ShaderNode.h:503
vector< ShaderInput * > _inputOrder
Definition: ShaderNode.h:495
bool operator==(const ShaderGraphEdgeIterator &rhs) const
Definition: ShaderGraph.h:240
std::unordered_map< string, ShaderNodePtr > _nodeMap
Definition: ShaderGraph.h:187
#define MATERIALX_NAMESPACE_BEGIN
Definition: Generated.h:25
ShaderOutput * getOutput(size_t index=0)
Definition: ShaderNode.h:447
ShaderInput * getInput(size_t index)
Get inputs/outputs by index.
Definition: ShaderNode.h:446
shared_ptr< class UnitSystem > UnitSystemPtr
A shared pointer to a UnitSystem.
Definition: UnitSystem.h:26
vector< ShaderOutput * > _outputOrder
Definition: ShaderNode.h:498
ShaderGraphOutputSocket * getOutputSocket(const string &name)
Definition: ShaderGraph.h:87
shared_ptr< const Node > ConstNodePtr
A shared pointer to a const Node.
Definition: Node.h:26
size_t numOutputs() const
Definition: ShaderNode.h:443
#define MX_GENSHADER_API
Definition: Export.h:18
const vector< ShaderGraphOutputSocket * > & getOutputSockets() const
Definition: ShaderGraph.h:93
ShaderGraphOutputSocket * getOutputSocket(size_t index=0)
Definition: ShaderGraph.h:81
const ShaderGraphInputSocket * getInputSocket(size_t index) const
Definition: ShaderGraph.h:82
std::vector< ShaderNode * > _nodeOrder
Definition: ShaderGraph.h:188
bool operator!=(const ShaderGraphEdgeIterator &rhs) const
Definition: ShaderGraph.h:246
ShaderGraphEdgeIterator & begin()
Return a reference to this iterator to begin traversal.
Definition: ShaderGraph.h:262
size_t numInputSockets() const
Get number of input sockets.
Definition: ShaderGraph.h:74
ShaderGraphInputSocket * addInputSocket(const string &name, const TypeDesc *type)
Definition: ShaderGraph.h:103
ShaderGraphEdge(ShaderOutput *up, ShaderInput *down)
Definition: ShaderGraph.h:207
GLint GLint GLsizei GLint GLenum GLenum type
Definition: glcorearb.h:108
bool operator<(const ShaderGraphEdge &rhs) const
Definition: ShaderGraph.h:223
shared_ptr< class ColorManagementSystem > ColorManagementSystemPtr
A shared pointer to a ColorManagementSystem.
ShaderGraphInputSocket * getInputSocket(const string &name)
Get socket by name.
Definition: ShaderGraph.h:86
shared_ptr< class ShaderNode > ShaderNodePtr
Shared pointer to a ShaderNode.
Definition: ShaderNode.h:35
ShaderInput * downstream
Definition: ShaderGraph.h:229
GLuint GLuint end
Definition: glcorearb.h:475
bool operator==(const ShaderGraphEdge &rhs) const
Definition: ShaderGraph.h:213
static ShaderNodePtr create(const ShaderGraph *parent, const string &name, const NodeDef &nodeDef, GenContext &context)
Create a new node from a nodedef.
GLuint const GLchar * name
Definition: glcorearb.h:786
shared_ptr< class ShaderGraph > ShaderGraphPtr
A shared pointer to a shader graph.
Definition: ShaderGraph.h:39
bool operator!=(const ShaderGraphEdge &rhs) const
Definition: ShaderGraph.h:218
GA_API const UT_StringHolder transform
const ShaderGraphOutputSocket * getOutputSocket(size_t index=0) const
Definition: ShaderGraph.h:83
const ShaderGraphOutputSocket * getOutputSocket(const string &name) const
Definition: ShaderGraph.h:89
size_t numOutputSockets() const
Get number of output sockets.
Definition: ShaderGraph.h:77
virtual ~ShaderGraph()
Destructor.
Definition: ShaderGraph.h:50
std::unordered_map< ShaderOutput *, ColorSpaceTransform > _outputColorTransformMap
Definition: ShaderGraph.h:197
ShaderOutput * upstream
Definition: ShaderGraph.h:228
bool isAGraph() const override
Return true if this node is a graph.
Definition: ShaderGraph.h:62
IdentifierMap _identifiers
Definition: ShaderGraph.h:189
const vector< ShaderNode * > & getNodes() const
Get a vector of all nodes in order.
Definition: ShaderGraph.h:71
std::unordered_map< string, size_t > IdentifierMap
Definition: Syntax.h:38
GLuint index
Definition: glcorearb.h:786
const ShaderGraphInputSocket * getInputSocket(const string &name) const
Definition: ShaderGraph.h:88
const vector< ShaderGraphInputSocket * > & getInputSockets() const
Get vector of sockets.
Definition: ShaderGraph.h:92
std::set< string > StringSet
A set of strings.
Definition: Library.h:64
std::unordered_map< ShaderOutput *, UnitTransform > _outputUnitTransformMap
Definition: ShaderGraph.h:199
shared_ptr< Element > ElementPtr
A shared pointer to an Element.
Definition: Element.h:31
shared_ptr< ValueElement > ValueElementPtr
A shared pointer to a ValueElement.
Definition: Element.h:41
#define MATERIALX_NAMESPACE_END
Definition: Generated.h:26
std::unordered_map< ShaderInput *, ColorSpaceTransform > _inputColorTransformMap
Definition: ShaderGraph.h:192
size_t numInputs() const
Get number of inputs/outputs.
Definition: ShaderNode.h:442
Definition: Syntax.h:43
ConstDocumentPtr _document
Definition: ShaderGraph.h:186
shared_ptr< const Document > ConstDocumentPtr
A shared pointer to a const Document.
Definition: Document.h:24
IdentifierMap & getIdentifierMap()
Return the map of unique identifiers used in the scope of this graph.
Definition: ShaderGraph.h:119