1//
2// Copyright (c) 2014 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// RewriteElseBlocks.cpp: Implementation for tree transform to change
7//   all if-else blocks to if-if blocks.
8//
9
10#include "compiler/translator/RewriteElseBlocks.h"
11#include "compiler/translator/NodeSearch.h"
12#include "compiler/translator/SymbolTable.h"
13
14namespace sh
15{
16
17namespace
18{
19
20class ElseBlockRewriter : public TIntermTraverser
21{
22  public:
23    ElseBlockRewriter();
24
25  protected:
26    bool visitAggregate(Visit visit, TIntermAggregate *aggregate);
27
28  private:
29    int mTemporaryIndex;
30    const TType *mFunctionType;
31
32    TIntermNode *rewriteSelection(TIntermSelection *selection);
33};
34
35TIntermSymbol *MakeNewTemporary(const TString &name, TBasicType type)
36{
37    TType variableType(type, EbpHigh, EvqInternal);
38    return new TIntermSymbol(-1, name, variableType);
39}
40
41TIntermBinary *MakeNewBinary(TOperator op, TIntermTyped *left, TIntermTyped *right, const TType &resultType)
42{
43    TIntermBinary *binary = new TIntermBinary(op);
44    binary->setLeft(left);
45    binary->setRight(right);
46    binary->setType(resultType);
47    return binary;
48}
49
50TIntermUnary *MakeNewUnary(TOperator op, TIntermTyped *operand)
51{
52    TIntermUnary *unary = new TIntermUnary(op, operand->getType());
53    unary->setOperand(operand);
54    return unary;
55}
56
57ElseBlockRewriter::ElseBlockRewriter()
58    : TIntermTraverser(true, false, true, false),
59      mTemporaryIndex(0),
60      mFunctionType(NULL)
61{}
62
63bool ElseBlockRewriter::visitAggregate(Visit visit, TIntermAggregate *node)
64{
65    switch (node->getOp())
66    {
67      case EOpSequence:
68        if (visit == PostVisit)
69        {
70            for (size_t statementIndex = 0; statementIndex != node->getSequence()->size(); statementIndex++)
71            {
72                TIntermNode *statement = (*node->getSequence())[statementIndex];
73                TIntermSelection *selection = statement->getAsSelectionNode();
74                if (selection && selection->getFalseBlock() != NULL)
75                {
76                    // Check for if / else if
77                    TIntermSelection *elseIfBranch = selection->getFalseBlock()->getAsSelectionNode();
78                    if (elseIfBranch)
79                    {
80                        selection->replaceChildNode(elseIfBranch, rewriteSelection(elseIfBranch));
81                        delete elseIfBranch;
82                    }
83
84                    (*node->getSequence())[statementIndex] = rewriteSelection(selection);
85                    delete selection;
86                }
87            }
88        }
89        break;
90
91      case EOpFunction:
92        // Store the current function context (see comment below)
93        mFunctionType = ((visit == PreVisit) ? &node->getType() : NULL);
94        break;
95
96      default: break;
97    }
98
99    return true;
100}
101
102TIntermNode *ElseBlockRewriter::rewriteSelection(TIntermSelection *selection)
103{
104    ASSERT(selection != NULL);
105
106    TString temporaryName = "cond_" + str(mTemporaryIndex++);
107    TIntermTyped *typedCondition = selection->getCondition()->getAsTyped();
108    TType resultType(EbtBool, EbpUndefined);
109    TIntermSymbol *conditionSymbolInit = MakeNewTemporary(temporaryName, EbtBool);
110    TIntermBinary *storeCondition = MakeNewBinary(EOpInitialize, conditionSymbolInit,
111                                                  typedCondition, resultType);
112    TIntermNode *negatedElse = NULL;
113
114    TIntermSelection *falseBlock = NULL;
115
116    if (selection->getFalseBlock())
117    {
118        // crbug.com/346463
119        // D3D generates error messages claiming a function has no return value, when rewriting
120        // an if-else clause that returns something non-void in a function. By appending dummy
121        // returns (that are unreachable) we can silence this compile error.
122        if (mFunctionType && mFunctionType->getBasicType() != EbtVoid)
123        {
124            TString typeString = mFunctionType->getStruct() ? mFunctionType->getStruct()->name() :
125                mFunctionType->getBasicString();
126            TString rawText = "return (" + typeString + ")0";
127            negatedElse = new TIntermRaw(*mFunctionType, rawText);
128        }
129
130        TIntermSymbol *conditionSymbolElse = MakeNewTemporary(temporaryName, EbtBool);
131        TIntermUnary *negatedCondition = MakeNewUnary(EOpLogicalNot, conditionSymbolElse);
132        falseBlock = new TIntermSelection(negatedCondition,
133                                          selection->getFalseBlock(), negatedElse);
134    }
135
136    TIntermSymbol *conditionSymbolSel = MakeNewTemporary(temporaryName, EbtBool);
137    TIntermSelection *newSelection = new TIntermSelection(conditionSymbolSel,
138                                                          selection->getTrueBlock(), falseBlock);
139
140    TIntermAggregate *declaration = new TIntermAggregate(EOpDeclaration);
141    declaration->getSequence()->push_back(storeCondition);
142
143    TIntermAggregate *block = new TIntermAggregate(EOpSequence);
144    block->getSequence()->push_back(declaration);
145    block->getSequence()->push_back(newSelection);
146
147    return block;
148}
149
150}
151
152void RewriteElseBlocks(TIntermNode *node)
153{
154    ElseBlockRewriter rewriter;
155    node->traverse(&rewriter);
156}
157
158}
159