1// Copyright 2016 The SwiftShader Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//
16// Definition of the in-memory high-level intermediate representation
17// of shaders.  This is a tree that parser creates.
18//
19// Nodes in the tree are defined as a hierarchy of classes derived from
20// TIntermNode. Each is a node in a tree.  There is no preset branching factor;
// each node type can have its own kind of list of children.
22//
23
24#ifndef __INTERMEDIATE_H
25#define __INTERMEDIATE_H
26
27#include "Common.h"
28#include "Types.h"
29#include "ConstantUnion.h"
30
//
// Operators used by the high-level (parse tree) representation.
//
// A TOperator value is stored in every TIntermOperator-derived node
// (TIntermOperator::getOp()/setOp()). The "Branch" group is presumably what
// TIntermBranch nodes carry as their flow operator.
// NOTE(review): code elsewhere may rely on the relative order of these
// enumerators (e.g. range checks over the constructor or assignment groups);
// confirm before reordering or inserting entries.
//
enum TOperator {
	EOpNull,            // if in a node, should only mean a node is still being built
	EOpSequence,        // denotes a list of statements, or parameters, etc.
	EOpFunctionCall,
	EOpFunction,        // For function definition
	EOpParameters,      // an aggregate listing the parameters to a function

	EOpDeclaration,
	EOpInvariantDeclaration, // Specialized declarations for attributing invariance
	EOpPrototype,

	//
	// Unary operators
	//

	EOpNegative,
	EOpLogicalNot,
	EOpVectorLogicalNot,
	EOpBitwiseNot,

	EOpPostIncrement,
	EOpPostDecrement,
	EOpPreIncrement,
	EOpPreDecrement,

	//
	// binary operations
	//

	EOpAdd,
	EOpSub,
	EOpMul,
	EOpDiv,
	EOpEqual,
	EOpNotEqual,
	EOpVectorEqual,
	EOpVectorNotEqual,
	EOpLessThan,
	EOpGreaterThan,
	EOpLessThanEqual,
	EOpGreaterThanEqual,
	EOpComma,

	EOpOuterProduct,
	EOpTranspose,
	EOpDeterminant,
	EOpInverse,

	EOpVectorTimesScalar,
	EOpVectorTimesMatrix,
	EOpMatrixTimesVector,
	EOpMatrixTimesScalar,

	EOpLogicalOr,
	EOpLogicalXor,
	EOpLogicalAnd,

	EOpIMod,
	EOpBitShiftLeft,
	EOpBitShiftRight,
	EOpBitwiseAnd,
	EOpBitwiseXor,
	EOpBitwiseOr,

	// Indexing and member selection
	EOpIndexDirect,
	EOpIndexIndirect,
	EOpIndexDirectStruct,
	EOpIndexDirectInterfaceBlock,

	EOpVectorSwizzle,

	//
	// Built-in functions potentially mapped to operators
	//

	EOpRadians,
	EOpDegrees,
	EOpSin,
	EOpCos,
	EOpTan,
	EOpAsin,
	EOpAcos,
	EOpAtan,
	EOpSinh,
	EOpCosh,
	EOpTanh,
	EOpAsinh,
	EOpAcosh,
	EOpAtanh,

	EOpPow,
	EOpExp,
	EOpLog,
	EOpExp2,
	EOpLog2,
	EOpSqrt,
	EOpInverseSqrt,

	EOpAbs,
	EOpSign,
	EOpFloor,
	EOpTrunc,
	EOpRound,
	EOpRoundEven,
	EOpCeil,
	EOpFract,
	EOpMod,
	EOpModf,
	EOpMin,
	EOpMax,
	EOpClamp,
	EOpMix,
	EOpStep,
	EOpSmoothStep,
	EOpIsNan,
	EOpIsInf,
	EOpFloatBitsToInt,
	EOpFloatBitsToUint,
	EOpIntBitsToFloat,
	EOpUintBitsToFloat,
	EOpPackSnorm2x16,
	EOpPackUnorm2x16,
	EOpPackHalf2x16,
	EOpUnpackSnorm2x16,
	EOpUnpackUnorm2x16,
	EOpUnpackHalf2x16,

	EOpLength,
	EOpDistance,
	EOpDot,
	EOpCross,
	EOpNormalize,
	EOpFaceForward,
	EOpReflect,
	EOpRefract,

	EOpDFdx,            // Fragment only, OES_standard_derivatives extension
	EOpDFdy,            // Fragment only, OES_standard_derivatives extension
	EOpFwidth,          // Fragment only, OES_standard_derivatives extension

	EOpMatrixTimesMatrix,

	EOpAny,
	EOpAll,

	//
	// Branch
	//

	EOpKill,            // Fragment only
	EOpReturn,
	EOpBreak,
	EOpContinue,

	//
	// Constructors
	//

	EOpConstructInt,
	EOpConstructUInt,
	EOpConstructBool,
	EOpConstructFloat,
	EOpConstructVec2,
	EOpConstructVec3,
	EOpConstructVec4,
	EOpConstructBVec2,
	EOpConstructBVec3,
	EOpConstructBVec4,
	EOpConstructIVec2,
	EOpConstructIVec3,
	EOpConstructIVec4,
	EOpConstructUVec2,
	EOpConstructUVec3,
	EOpConstructUVec4,
	EOpConstructMat2,
	EOpConstructMat2x3,
	EOpConstructMat2x4,
	EOpConstructMat3x2,
	EOpConstructMat3,
	EOpConstructMat3x4,
	EOpConstructMat4x2,
	EOpConstructMat4x3,
	EOpConstructMat4,
	EOpConstructStruct,

	//
	// moves (assignment operators, including compound assignments)
	//

	EOpAssign,
	EOpInitialize,
	EOpAddAssign,
	EOpSubAssign,
	EOpMulAssign,
	EOpVectorTimesMatrixAssign,
	EOpVectorTimesScalarAssign,
	EOpMatrixTimesScalarAssign,
	EOpMatrixTimesMatrixAssign,
	EOpDivAssign,
	EOpIModAssign,
	EOpBitShiftLeftAssign,
	EOpBitShiftRightAssign,
	EOpBitwiseAndAssign,
	EOpBitwiseXorAssign,
	EOpBitwiseOrAssign
};
240
// Free helpers defined elsewhere in the compiler.
// NOTE(review): presumably maps `type` to its matching EOpConstruct* operator — confirm at the definition.
extern TOperator TypeToConstructorOperator(const TType &type);
// NOTE(review): presumably returns a printable name for `op` — confirm at the definition.
extern const char* getOperatorString(TOperator op);

// Forward declarations so the node classes below (notably the getAs*()
// downcast helpers in TIntermNode) can name each other before definition.
class TIntermTraverser;
class TIntermAggregate;
class TIntermBinary;
class TIntermUnary;
class TIntermConstantUnion;
class TIntermSelection;
class TIntermTyped;
class TIntermSymbol;
class TIntermLoop;
class TIntermBranch;
class TInfoSink;
class TIntermSwitch;
class TIntermCase;
257
258//
259// Base class for the tree nodes
260//
261class TIntermNode {
262public:
263	POOL_ALLOCATOR_NEW_DELETE();
264
265	TIntermNode()
266	{
267		// TODO: Move this to TSourceLoc constructor
268		// after getting rid of TPublicType.
269		line.first_file = line.last_file = 0;
270		line.first_line = line.last_line = 0;
271	}
272
273	const TSourceLoc& getLine() const { return line; }
274	void setLine(const TSourceLoc& l) { line = l; }
275
276	virtual void traverse(TIntermTraverser*) = 0;
277	virtual TIntermTyped* getAsTyped() { return 0; }
278	virtual TIntermConstantUnion* getAsConstantUnion() { return 0; }
279	virtual TIntermAggregate* getAsAggregate() { return 0; }
280	virtual TIntermBinary* getAsBinaryNode() { return 0; }
281	virtual TIntermUnary* getAsUnaryNode() { return 0; }
282	virtual TIntermSelection* getAsSelectionNode() { return 0; }
283	virtual TIntermSymbol* getAsSymbolNode() { return 0; }
284	virtual TIntermLoop* getAsLoopNode() { return 0; }
285	virtual TIntermBranch* getAsBranchNode() { return 0; }
286	virtual TIntermSwitch *getAsSwitchNode() { return 0; }
287	virtual TIntermCase *getAsCaseNode() { return 0; }
288	virtual ~TIntermNode() { }
289
290protected:
291	TSourceLoc line;
292};
293
294//
295// This is just to help yacc.
296//
297struct TIntermNodePair {
298	TIntermNode* node1;
299	TIntermNode* node2;
300};
301
302//
303// Intermediate class for nodes that have a type.
304//
305class TIntermTyped : public TIntermNode {
306public:
307	TIntermTyped(const TType& t) : type(t)  { }
308	virtual TIntermTyped* getAsTyped() { return this; }
309
310	virtual void setType(const TType& t) { type = t; }
311	const TType& getType() const { return type; }
312	TType* getTypePointer() { return &type; }
313
314	TBasicType getBasicType() const { return type.getBasicType(); }
315	TQualifier getQualifier() const { return type.getQualifier(); }
316	TPrecision getPrecision() const { return type.getPrecision(); }
317	int getNominalSize() const { return type.getNominalSize(); }
318	int getSecondarySize() const { return type.getSecondarySize(); }
319
320	bool isInterfaceBlock() const { return type.isInterfaceBlock(); }
321	bool isMatrix() const { return type.isMatrix(); }
322	bool isArray()  const { return type.isArray(); }
323	bool isVector() const { return type.isVector(); }
324	bool isScalar() const { return type.isScalar(); }
325	bool isScalarInt() const { return type.isScalarInt(); }
326	bool isRegister() const { return type.isRegister(); }   // Fits in a 4-element register
327	bool isStruct() const { return type.isStruct(); }
328	const char* getBasicString() const { return type.getBasicString(); }
329	const char* getQualifierString() const { return type.getQualifierString(); }
330	TString getCompleteString() const { return type.getCompleteString(); }
331
332	int totalRegisterCount() const { return type.totalRegisterCount(); }
333	int blockRegisterCount() const { return type.blockRegisterCount(); }
334	int elementRegisterCount() const { return type.elementRegisterCount(); }
335	int registerSize() const { return type.registerSize(); }
336	int getArraySize() const { return type.getArraySize(); }
337
338	static TIntermTyped *CreateIndexNode(int index);
339protected:
340	TType type;
341};
342
//
// Handle for, do-while, and while loops.
//
// TLoopType records which source-level loop statement a TIntermLoop node
// represents; only the for-loop form uses TIntermLoop's init and expr parts.
//
enum TLoopType {
	ELoopFor,
	ELoopWhile,
	ELoopDoWhile
};
351
352class TIntermLoop : public TIntermNode {
353public:
354	TIntermLoop(TLoopType aType,
355	            TIntermNode *aInit, TIntermTyped* aCond, TIntermTyped* aExpr,
356	            TIntermNode* aBody) :
357			type(aType),
358			init(aInit),
359			cond(aCond),
360			expr(aExpr),
361			body(aBody),
362			unrollFlag(false) { }
363
364	virtual TIntermLoop* getAsLoopNode() { return this; }
365	virtual void traverse(TIntermTraverser*);
366
367	TLoopType getType() const { return type; }
368	TIntermNode* getInit() { return init; }
369	TIntermTyped* getCondition() { return cond; }
370	TIntermTyped* getExpression() { return expr; }
371	TIntermNode* getBody() { return body; }
372
373	void setUnrollFlag(bool flag) { unrollFlag = flag; }
374	bool getUnrollFlag() { return unrollFlag; }
375
376protected:
377	TLoopType type;
378	TIntermNode* init;  // for-loop initialization
379	TIntermTyped* cond; // loop exit condition
380	TIntermTyped* expr; // for-loop expression
381	TIntermNode* body;  // loop body
382
383	bool unrollFlag; // Whether the loop should be unrolled or not.
384};
385
386//
387// Handle break, continue, return, and kill.
388//
389class TIntermBranch : public TIntermNode {
390public:
391	TIntermBranch(TOperator op, TIntermTyped* e) :
392			flowOp(op),
393			expression(e) { }
394
395	virtual TIntermBranch* getAsBranchNode() { return this; }
396	virtual void traverse(TIntermTraverser*);
397
398	TOperator getFlowOp() { return flowOp; }
399	TIntermTyped* getExpression() { return expression; }
400
401protected:
402	TOperator flowOp;
403	TIntermTyped* expression;  // non-zero except for "return exp;" statements
404};
405
406//
407// Nodes that correspond to symbols or constants in the source code.
408//
409class TIntermSymbol : public TIntermTyped {
410public:
411	// if symbol is initialized as symbol(sym), the memory comes from the poolallocator of sym. If sym comes from
412	// per process globalpoolallocator, then it causes increased memory usage per compile
413	// it is essential to use "symbol = sym" to assign to symbol
414	TIntermSymbol(int i, const TString& sym, const TType& t) :
415			TIntermTyped(t), id(i)  { symbol = sym; }
416
417	int getId() const { return id; }
418	const TString& getSymbol() const { return symbol; }
419
420	void setId(int newId) { id = newId; }
421
422	virtual void traverse(TIntermTraverser*);
423	virtual TIntermSymbol* getAsSymbolNode() { return this; }
424
425protected:
426	int id;
427	TString symbol;
428};
429
430class TIntermConstantUnion : public TIntermTyped {
431public:
432	TIntermConstantUnion(ConstantUnion *unionPointer, const TType& t) : TIntermTyped(t), unionArrayPointer(unionPointer)
433	{
434		getTypePointer()->setQualifier(EvqConstExpr);
435	}
436
437	ConstantUnion* getUnionArrayPointer() const { return unionArrayPointer; }
438
439	int getIConst(int index) const { return unionArrayPointer ? unionArrayPointer[index].getIConst() : 0; }
440	int getUConst(int index) const { return unionArrayPointer ? unionArrayPointer[index].getUConst() : 0; }
441	float getFConst(int index) const { return unionArrayPointer ? unionArrayPointer[index].getFConst() : 0.0f; }
442	bool getBConst(int index) const { return unionArrayPointer ? unionArrayPointer[index].getBConst() : false; }
443
444	// Previous union pointer freed on pool deallocation.
445	void replaceConstantUnion(ConstantUnion *safeConstantUnion) { unionArrayPointer = safeConstantUnion; }
446
447	virtual TIntermConstantUnion* getAsConstantUnion()  { return this; }
448	virtual void traverse(TIntermTraverser*);
449
450	TIntermTyped* fold(TOperator, TIntermTyped*, TInfoSink&);
451
452protected:
453	ConstantUnion *unionArrayPointer;
454};
455
456//
457// Intermediate class for node types that hold operators.
458//
459class TIntermOperator : public TIntermTyped {
460public:
461	TOperator getOp() const { return op; }
462	void setOp(TOperator o) { op = o; }
463
464	bool modifiesState() const;
465	bool isConstructor() const;
466
467protected:
468	TIntermOperator(TOperator o) : TIntermTyped(TType(EbtFloat, EbpUndefined)), op(o) {}
469	TIntermOperator(TOperator o, TType& t) : TIntermTyped(t), op(o) {}
470	TOperator op;
471};
472
473//
474// Nodes for all the basic binary math operators.
475//
476class TIntermBinary : public TIntermOperator {
477public:
478	TIntermBinary(TOperator o) : TIntermOperator(o) {}
479
480	virtual TIntermBinary* getAsBinaryNode() { return this; }
481	virtual void traverse(TIntermTraverser*);
482
483	void setType(const TType &t) override
484	{
485		type = t;
486
487		if(left->getQualifier() == EvqConstExpr && right->getQualifier() == EvqConstExpr)
488		{
489			type.setQualifier(EvqConstExpr);
490		}
491	}
492
493	void setLeft(TIntermTyped* n) { left = n; }
494	void setRight(TIntermTyped* n) { right = n; }
495	TIntermTyped* getLeft() const { return left; }
496	TIntermTyped* getRight() const { return right; }
497	bool promote(TInfoSink&);
498
499protected:
500	TIntermTyped* left;
501	TIntermTyped* right;
502};
503
504//
505// Nodes for unary math operators.
506//
507class TIntermUnary : public TIntermOperator {
508public:
509	TIntermUnary(TOperator o, TType& t) : TIntermOperator(o, t), operand(0) {}
510	TIntermUnary(TOperator o) : TIntermOperator(o), operand(0) {}
511
512	void setType(const TType &t) override
513	{
514		type = t;
515
516		if(operand->getQualifier() == EvqConstExpr)
517		{
518			type.setQualifier(EvqConstExpr);
519		}
520	}
521
522	virtual void traverse(TIntermTraverser*);
523	virtual TIntermUnary* getAsUnaryNode() { return this; }
524
525	void setOperand(TIntermTyped* o) { operand = o; }
526	TIntermTyped* getOperand() { return operand; }
527	bool promote(TInfoSink&, const TType *funcReturnType);
528
529protected:
530	TIntermTyped* operand;
531};
532
533typedef TVector<TIntermNode*> TIntermSequence;
534typedef TVector<int> TQualifierList;
535
536//
537// Nodes that operate on an arbitrary sized set of children.
538//
539class TIntermAggregate : public TIntermOperator {
540public:
541	TIntermAggregate() : TIntermOperator(EOpNull), userDefined(false) { endLine = { 0, 0, 0, 0 }; }
542	TIntermAggregate(TOperator o) : TIntermOperator(o), userDefined(false) { endLine = { 0, 0, 0, 0 }; }
543	~TIntermAggregate() { }
544
545	virtual TIntermAggregate* getAsAggregate() { return this; }
546	virtual void traverse(TIntermTraverser*);
547
548	TIntermSequence& getSequence() { return sequence; }
549
550	void setType(const TType &t) override
551	{
552		type = t;
553
554		if(op != EOpFunctionCall)
555		{
556			for(TIntermNode *node : sequence)
557			{
558				if(!node->getAsTyped() || node->getAsTyped()->getQualifier() != EvqConstExpr)
559				{
560					return;
561				}
562			}
563
564			type.setQualifier(EvqConstExpr);
565		}
566	}
567
568	void setName(const TString& n) { name = n; }
569	const TString& getName() const { return name; }
570
571	void setUserDefined() { userDefined = true; }
572	bool isUserDefined() const { return userDefined; }
573
574	void setOptimize(bool o) { optimize = o; }
575	bool getOptimize() { return optimize; }
576	void setDebug(bool d) { debug = d; }
577	bool getDebug() { return debug; }
578
579	void setEndLine(const TSourceLoc& line) { endLine = line; }
580	const TSourceLoc& getEndLine() const { return endLine; }
581
582	bool isConstantFoldable()
583	{
584		for(TIntermNode *node : sequence)
585		{
586			if(!node->getAsConstantUnion() || !node->getAsConstantUnion()->getUnionArrayPointer())
587			{
588				return false;
589			}
590		}
591
592		return true;
593	}
594
595protected:
596	TIntermAggregate(const TIntermAggregate&); // disallow copy constructor
597	TIntermAggregate& operator=(const TIntermAggregate&); // disallow assignment operator
598	TIntermSequence sequence;
599	TString name;
600	bool userDefined; // used for user defined function names
601
602	bool optimize;
603	bool debug;
604	TSourceLoc endLine;
605};
606
607//
608// For if tests.  Simplified since there is no switch statement.
609//
610class TIntermSelection : public TIntermTyped {
611public:
612	TIntermSelection(TIntermTyped* cond, TIntermNode* trueB, TIntermNode* falseB) :
613			TIntermTyped(TType(EbtVoid, EbpUndefined)), condition(cond), trueBlock(trueB), falseBlock(falseB) {}
614	TIntermSelection(TIntermTyped* cond, TIntermNode* trueB, TIntermNode* falseB, const TType& type) :
615			TIntermTyped(type), condition(cond), trueBlock(trueB), falseBlock(falseB)
616	{
617		this->type.setQualifier(EvqTemporary);
618	}
619
620	virtual void traverse(TIntermTraverser*);
621
622	bool usesTernaryOperator() const { return getBasicType() != EbtVoid; }
623	TIntermTyped* getCondition() const { return condition; }
624	TIntermNode* getTrueBlock() const { return trueBlock; }
625	TIntermNode* getFalseBlock() const { return falseBlock; }
626	TIntermSelection* getAsSelectionNode() { return this; }
627
628protected:
629	TIntermTyped* condition;
630	TIntermNode* trueBlock;
631	TIntermNode* falseBlock;
632};
633
634//
635// Switch statement.
636//
637class TIntermSwitch : public TIntermNode
638{
639public:
640	TIntermSwitch(TIntermTyped *init, TIntermAggregate *statementList)
641		: TIntermNode(), mInit(init), mStatementList(statementList)
642	{}
643
644	void traverse(TIntermTraverser *it);
645
646	TIntermSwitch *getAsSwitchNode() { return this; }
647
648	TIntermTyped *getInit() { return mInit; }
649	TIntermAggregate *getStatementList() { return mStatementList; }
650	void setStatementList(TIntermAggregate *statementList) { mStatementList = statementList; }
651
652protected:
653	TIntermTyped *mInit;
654	TIntermAggregate *mStatementList;
655};
656
657//
658// Case label.
659//
660class TIntermCase : public TIntermNode
661{
662public:
663	TIntermCase(TIntermTyped *condition)
664		: TIntermNode(), mCondition(condition)
665	{}
666
667	void traverse(TIntermTraverser *it);
668
669	TIntermCase *getAsCaseNode() { return this; }
670
671	bool hasCondition() const { return mCondition != nullptr; }
672	TIntermTyped *getCondition() const { return mCondition; }
673
674protected:
675	TIntermTyped *mCondition;
676};
677
// Tells a TIntermTraverser callback which visit of a node it is receiving.
// Which visits happen is controlled by the traverser's preVisit/inVisit/
// postVisit flags; InVisit presumably fires between a node's children.
enum Visit
{
	PreVisit,
	InVisit,
	PostVisit
};
684
685//
686// For traversing the tree.  User should derive from this,
687// put their traversal specific data in it, and then pass
688// it to a Traverse method.
689//
690// When using this, just fill in the methods for nodes you want visited.
691// Return false from a pre-visit to skip visiting that node's subtree.
692//
693class TIntermTraverser
694{
695public:
696	POOL_ALLOCATOR_NEW_DELETE();
697	TIntermTraverser(bool preVisit = true, bool inVisit = false, bool postVisit = false, bool rightToLeft = false) :
698			preVisit(preVisit),
699			inVisit(inVisit),
700			postVisit(postVisit),
701			rightToLeft(rightToLeft),
702			mDepth(0) {}
703	virtual ~TIntermTraverser() {};
704
705	virtual void visitSymbol(TIntermSymbol*) {}
706	virtual void visitConstantUnion(TIntermConstantUnion*) {}
707	virtual bool visitBinary(Visit visit, TIntermBinary*) {return true;}
708	virtual bool visitUnary(Visit visit, TIntermUnary*) {return true;}
709	virtual bool visitSelection(Visit visit, TIntermSelection*) {return true;}
710	virtual bool visitAggregate(Visit visit, TIntermAggregate*) {return true;}
711	virtual bool visitLoop(Visit visit, TIntermLoop*) {return true;}
712	virtual bool visitBranch(Visit visit, TIntermBranch*) {return true;}
713	virtual bool visitSwitch(Visit, TIntermSwitch*) { return true; }
714	virtual bool visitCase(Visit, TIntermCase*) { return true; }
715
716	void incrementDepth(TIntermNode *current)
717	{
718		mDepth++;
719		mPath.push_back(current);
720	}
721
722	void decrementDepth()
723	{
724		mDepth--;
725		mPath.pop_back();
726	}
727
728	TIntermNode *getParentNode()
729	{
730		return mPath.size() == 0 ? nullptr : mPath.back();
731	}
732
733	const bool preVisit;
734	const bool inVisit;
735	const bool postVisit;
736	const bool rightToLeft;
737
738protected:
739	int mDepth;
740
741	// All the nodes from root to the current node's parent during traversing.
742	TVector<TIntermNode *> mPath;
743
744private:
745	struct ParentBlock
746	{
747		ParentBlock(TIntermAggregate *nodeIn, TIntermSequence::size_type posIn)
748		: node(nodeIn), pos(posIn)
749		{}
750
751		TIntermAggregate *node;
752		TIntermSequence::size_type pos;
753	};
754	// All the code blocks from the root to the current node's parent during traversal.
755	std::vector<ParentBlock> mParentBlockStack;
756};
757
758#endif // __INTERMEDIATE_H
759