1//
2// Copyright (c) 2012 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6
7#ifndef COMPILER_DEPGRAPH_DEPENDENCY_GRAPH_H
8#define COMPILER_DEPGRAPH_DEPENDENCY_GRAPH_H
9
10#include "compiler/intermediate.h"
11
12#include <set>
13#include <stack>
14
// Graph node hierarchy: the two base classes first, then the concrete kinds.
class TGraphNode;
class TGraphParentNode;
class TGraphSymbol;
class TGraphArgument;
class TGraphFunctionCall;
class TGraphSelection;
class TGraphLoop;
class TGraphLogicalOp;

// Helpers that operate on the graph.
class TDependencyGraphTraverser;
class TDependencyGraphOutput;

// Ordered, duplicate-free collection of node pointers.
typedef std::set<TGraphNode*> TGraphNodeSet;

// Sequences of node pointers, kept in insertion order.
typedef std::vector<TGraphNode*> TGraphNodeVector;
typedef std::vector<TGraphSymbol*> TGraphSymbolVector;
typedef std::vector<TGraphFunctionCall*> TFunctionCallVector;
30
31//
32// Base class for all dependency graph nodes.
33//
34class TGraphNode {
35public:
36    TGraphNode(TIntermNode* node) : intermNode(node) {}
37    virtual ~TGraphNode() {}
38    virtual void traverse(TDependencyGraphTraverser* graphTraverser);
39protected:
40    TIntermNode* intermNode;
41};
42
43//
44// Base class for dependency graph nodes that may have children.
45//
46class TGraphParentNode : public TGraphNode {
47public:
48    TGraphParentNode(TIntermNode* node) : TGraphNode(node) {}
49    virtual ~TGraphParentNode() {}
50    void addDependentNode(TGraphNode* node) { if (node != this) mDependentNodes.insert(node); }
51    virtual void traverse(TDependencyGraphTraverser* graphTraverser);
52private:
53    TGraphNodeSet mDependentNodes;
54};
55
56//
57// Handle function call arguments.
58//
59class TGraphArgument : public TGraphParentNode {
60public:
61    TGraphArgument(TIntermAggregate* intermFunctionCall, int argumentNumber)
62        : TGraphParentNode(intermFunctionCall)
63        , mArgumentNumber(argumentNumber) {}
64    virtual ~TGraphArgument() {}
65    const TIntermAggregate* getIntermFunctionCall() const { return intermNode->getAsAggregate(); }
66    int getArgumentNumber() const { return mArgumentNumber; }
67    virtual void traverse(TDependencyGraphTraverser* graphTraverser);
68private:
69    int mArgumentNumber;
70};
71
72//
73// Handle function calls.
74//
75class TGraphFunctionCall : public TGraphParentNode {
76public:
77    TGraphFunctionCall(TIntermAggregate* intermFunctionCall)
78        : TGraphParentNode(intermFunctionCall) {}
79    virtual ~TGraphFunctionCall() {}
80    const TIntermAggregate* getIntermFunctionCall() const { return intermNode->getAsAggregate(); }
81    virtual void traverse(TDependencyGraphTraverser* graphTraverser);
82};
83
84//
85// Handle symbols.
86//
87class TGraphSymbol : public TGraphParentNode {
88public:
89    TGraphSymbol(TIntermSymbol* intermSymbol) : TGraphParentNode(intermSymbol) {}
90    virtual ~TGraphSymbol() {}
91    const TIntermSymbol* getIntermSymbol() const { return intermNode->getAsSymbolNode(); }
92    virtual void traverse(TDependencyGraphTraverser* graphTraverser);
93};
94
95//
96// Handle if statements and ternary operators.
97//
98class TGraphSelection : public TGraphNode {
99public:
100    TGraphSelection(TIntermSelection* intermSelection) : TGraphNode(intermSelection) {}
101    virtual ~TGraphSelection() {}
102    const TIntermSelection* getIntermSelection() const { return intermNode->getAsSelectionNode(); }
103    virtual void traverse(TDependencyGraphTraverser* graphTraverser);
104};
105
106//
107// Handle for, do-while, and while loops.
108//
109class TGraphLoop : public TGraphNode {
110public:
111    TGraphLoop(TIntermLoop* intermLoop) : TGraphNode(intermLoop) {}
112    virtual ~TGraphLoop() {}
113    const TIntermLoop* getIntermLoop() const { return intermNode->getAsLoopNode(); }
114    virtual void traverse(TDependencyGraphTraverser* graphTraverser);
115};
116
117//
118// Handle logical and, or.
119//
120class TGraphLogicalOp : public TGraphNode {
121public:
122    TGraphLogicalOp(TIntermBinary* intermLogicalOp) : TGraphNode(intermLogicalOp) {}
123    virtual ~TGraphLogicalOp() {}
124    const TIntermBinary* getIntermLogicalOp() const { return intermNode->getAsBinaryNode(); }
125    const char* getOpString() const;
126    virtual void traverse(TDependencyGraphTraverser* graphTraverser);
127};
128
129//
130// A dependency graph of symbols, function calls, conditions etc.
131//
132// This class provides an interface to the entry points of the dependency graph.
133//
134// Dependency graph nodes should be created by using one of the provided "create..." methods.
135// This class (and nobody else) manages the memory of the created nodes.
136// Nodes may not be removed after being added, so all created nodes will exist while the
137// TDependencyGraph instance exists.
138//
139class TDependencyGraph {
140public:
141    TDependencyGraph(TIntermNode* intermNode);
142    ~TDependencyGraph();
143    TGraphNodeVector::const_iterator begin() const { return mAllNodes.begin(); }
144    TGraphNodeVector::const_iterator end() const { return mAllNodes.end(); }
145
146    TGraphSymbolVector::const_iterator beginSamplerSymbols() const
147    {
148        return mSamplerSymbols.begin();
149    }
150
151    TGraphSymbolVector::const_iterator endSamplerSymbols() const
152    {
153        return mSamplerSymbols.end();
154    }
155
156    TFunctionCallVector::const_iterator beginUserDefinedFunctionCalls() const
157    {
158        return mUserDefinedFunctionCalls.begin();
159    }
160
161    TFunctionCallVector::const_iterator endUserDefinedFunctionCalls() const
162    {
163        return mUserDefinedFunctionCalls.end();
164    }
165
166    TGraphArgument* createArgument(TIntermAggregate* intermFunctionCall, int argumentNumber);
167    TGraphFunctionCall* createFunctionCall(TIntermAggregate* intermFunctionCall);
168    TGraphSymbol* getOrCreateSymbol(TIntermSymbol* intermSymbol);
169    TGraphSelection* createSelection(TIntermSelection* intermSelection);
170    TGraphLoop* createLoop(TIntermLoop* intermLoop);
171    TGraphLogicalOp* createLogicalOp(TIntermBinary* intermLogicalOp);
172private:
173    typedef TMap<int, TGraphSymbol*> TSymbolIdMap;
174    typedef std::pair<int, TGraphSymbol*> TSymbolIdPair;
175
176    TGraphNodeVector mAllNodes;
177    TGraphSymbolVector mSamplerSymbols;
178    TFunctionCallVector mUserDefinedFunctionCalls;
179    TSymbolIdMap mSymbolIdMap;
180};
181
182//
183// For traversing the dependency graph. Users should derive from this,
184// put their traversal specific data in it, and then pass it to a
185// traverse method.
186//
187// When using this, just fill in the methods for nodes you want visited.
188//
189class TDependencyGraphTraverser {
190public:
191    TDependencyGraphTraverser() : mDepth(0) {}
192
193    virtual void visitSymbol(TGraphSymbol* symbol) {};
194    virtual void visitArgument(TGraphArgument* selection) {};
195    virtual void visitFunctionCall(TGraphFunctionCall* functionCall) {};
196    virtual void visitSelection(TGraphSelection* selection) {};
197    virtual void visitLoop(TGraphLoop* loop) {};
198    virtual void visitLogicalOp(TGraphLogicalOp* logicalOp) {};
199
200    int getDepth() const { return mDepth; }
201    void incrementDepth() { ++mDepth; }
202    void decrementDepth() { --mDepth; }
203
204    void clearVisited() { mVisited.clear(); }
205    void markVisited(TGraphNode* node) { mVisited.insert(node); }
206    bool isVisited(TGraphNode* node) const { return mVisited.find(node) != mVisited.end(); }
207private:
208    int mDepth;
209    TGraphNodeSet mVisited;
210};
211
212#endif
213