1//
2// Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6
7#include "compiler/translator/ScalarizeVecAndMatConstructorArgs.h"
8#include "compiler/translator/compilerdebug.h"
9
10#include <algorithm>
11
12#include "angle_gl.h"
13#include "common/angleutils.h"
14
15namespace
16{
17
18bool ContainsMatrixNode(const TIntermSequence &sequence)
19{
20    for (size_t ii = 0; ii < sequence.size(); ++ii)
21    {
22        TIntermTyped *node = sequence[ii]->getAsTyped();
23        if (node && node->isMatrix())
24            return true;
25    }
26    return false;
27}
28
29bool ContainsVectorNode(const TIntermSequence &sequence)
30{
31    for (size_t ii = 0; ii < sequence.size(); ++ii)
32    {
33        TIntermTyped *node = sequence[ii]->getAsTyped();
34        if (node && node->isVector())
35            return true;
36    }
37    return false;
38}
39
40TIntermConstantUnion *ConstructIndexNode(int index)
41{
42    ConstantUnion *u = new ConstantUnion[1];
43    u[0].setIConst(index);
44
45    TType type(EbtInt, EbpUndefined, EvqConst, 1);
46    TIntermConstantUnion *node = new TIntermConstantUnion(u, type);
47    return node;
48}
49
50TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index)
51{
52    TIntermBinary *binary = new TIntermBinary(EOpIndexDirect);
53    binary->setLeft(symbolNode);
54    TIntermConstantUnion *indexNode = ConstructIndexNode(index);
55    binary->setRight(indexNode);
56    return binary;
57}
58
59TIntermBinary *ConstructMatrixIndexBinaryNode(
60    TIntermSymbol *symbolNode, int colIndex, int rowIndex)
61{
62    TIntermBinary *colVectorNode =
63        ConstructVectorIndexBinaryNode(symbolNode, colIndex);
64
65    TIntermBinary *binary = new TIntermBinary(EOpIndexDirect);
66    binary->setLeft(colVectorNode);
67    TIntermConstantUnion *rowIndexNode = ConstructIndexNode(rowIndex);
68    binary->setRight(rowIndexNode);
69    return binary;
70}
71
72}  // namespace anonymous
73
74bool ScalarizeVecAndMatConstructorArgs::visitAggregate(Visit visit, TIntermAggregate *node)
75{
76    if (visit == PreVisit)
77    {
78        switch (node->getOp())
79        {
80          case EOpSequence:
81            mSequenceStack.push_back(TIntermSequence());
82            {
83                for (TIntermSequence::const_iterator iter = node->getSequence()->begin();
84                     iter != node->getSequence()->end(); ++iter)
85                {
86                    TIntermNode *child = *iter;
87                    ASSERT(child != NULL);
88                    child->traverse(this);
89                    mSequenceStack.back().push_back(child);
90                }
91            }
92            if (mSequenceStack.back().size() > node->getSequence()->size())
93            {
94                node->getSequence()->clear();
95                *(node->getSequence()) = mSequenceStack.back();
96            }
97            mSequenceStack.pop_back();
98            return false;
99          case EOpConstructVec2:
100          case EOpConstructVec3:
101          case EOpConstructVec4:
102          case EOpConstructBVec2:
103          case EOpConstructBVec3:
104          case EOpConstructBVec4:
105          case EOpConstructIVec2:
106          case EOpConstructIVec3:
107          case EOpConstructIVec4:
108            if (ContainsMatrixNode(*(node->getSequence())))
109                scalarizeArgs(node, false, true);
110            break;
111          case EOpConstructMat2:
112          case EOpConstructMat3:
113          case EOpConstructMat4:
114            if (ContainsVectorNode(*(node->getSequence())))
115                scalarizeArgs(node, true, false);
116            break;
117          default:
118            break;
119        }
120    }
121    return true;
122}
123
124void ScalarizeVecAndMatConstructorArgs::scalarizeArgs(
125    TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix)
126{
127    ASSERT(aggregate);
128    int size = 0;
129    switch (aggregate->getOp())
130    {
131      case EOpConstructVec2:
132      case EOpConstructBVec2:
133      case EOpConstructIVec2:
134        size = 2;
135        break;
136      case EOpConstructVec3:
137      case EOpConstructBVec3:
138      case EOpConstructIVec3:
139        size = 3;
140        break;
141      case EOpConstructVec4:
142      case EOpConstructBVec4:
143      case EOpConstructIVec4:
144      case EOpConstructMat2:
145        size = 4;
146        break;
147      case EOpConstructMat3:
148        size = 9;
149        break;
150      case EOpConstructMat4:
151        size = 16;
152        break;
153      default:
154        break;
155    }
156    TIntermSequence *sequence = aggregate->getSequence();
157    TIntermSequence original(*sequence);
158    sequence->clear();
159    for (size_t ii = 0; ii < original.size(); ++ii)
160    {
161        ASSERT(size > 0);
162        TIntermTyped *node = original[ii]->getAsTyped();
163        ASSERT(node);
164        TString varName = createTempVariable(node);
165        if (node->isScalar())
166        {
167            TIntermSymbol *symbolNode =
168                new TIntermSymbol(-1, varName, node->getType());
169            sequence->push_back(symbolNode);
170            size--;
171        }
172        else if (node->isVector())
173        {
174            if (scalarizeVector)
175            {
176                int repeat = std::min(size, node->getNominalSize());
177                size -= repeat;
178                for (int index = 0; index < repeat; ++index)
179                {
180                    TIntermSymbol *symbolNode =
181                        new TIntermSymbol(-1, varName, node->getType());
182                    TIntermBinary *newNode = ConstructVectorIndexBinaryNode(
183                        symbolNode, index);
184                    sequence->push_back(newNode);
185                }
186            }
187            else
188            {
189                TIntermSymbol *symbolNode =
190                    new TIntermSymbol(-1, varName, node->getType());
191                sequence->push_back(symbolNode);
192                size -= node->getNominalSize();
193            }
194        }
195        else
196        {
197            ASSERT(node->isMatrix());
198            if (scalarizeMatrix)
199            {
200                int colIndex = 0, rowIndex = 0;
201                int repeat = std::min(size, node->getCols() * node->getRows());
202                size -= repeat;
203                while (repeat > 0)
204                {
205                    TIntermSymbol *symbolNode =
206                        new TIntermSymbol(-1, varName, node->getType());
207                    TIntermBinary *newNode = ConstructMatrixIndexBinaryNode(
208                        symbolNode, colIndex, rowIndex);
209                    sequence->push_back(newNode);
210                    rowIndex++;
211                    if (rowIndex >= node->getRows())
212                    {
213                        rowIndex = 0;
214                        colIndex++;
215                    }
216                    repeat--;
217                }
218            }
219            else
220            {
221                TIntermSymbol *symbolNode =
222                    new TIntermSymbol(-1, varName, node->getType());
223                sequence->push_back(symbolNode);
224                size -= node->getCols() * node->getRows();
225            }
226        }
227    }
228}
229
230TString ScalarizeVecAndMatConstructorArgs::createTempVariable(TIntermTyped *original)
231{
232    TString tempVarName = "_webgl_tmp_";
233    if (original->isScalar())
234    {
235        tempVarName += "scalar_";
236    }
237    else if (original->isVector())
238    {
239        tempVarName += "vec_";
240    }
241    else
242    {
243        ASSERT(original->isMatrix());
244        tempVarName += "mat_";
245    }
246    tempVarName += Str(mTempVarCount).c_str();
247    mTempVarCount++;
248
249    ASSERT(original);
250    TType type = original->getType();
251    type.setQualifier(EvqTemporary);
252
253    if (mShaderType == GL_FRAGMENT_SHADER &&
254        type.getBasicType() == EbtFloat &&
255        type.getPrecision() == EbpUndefined)
256    {
257        // We use the highest available precision for the temporary variable
258        // to avoid computing the actual precision using the rules defined
259        // in GLSL ES 1.0 Section 4.5.2.
260        type.setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium);
261    }
262
263    TIntermBinary *init = new TIntermBinary(EOpInitialize);
264    TIntermSymbol *symbolNode = new TIntermSymbol(-1, tempVarName, type);
265    init->setLeft(symbolNode);
266    init->setRight(original);
267    init->setType(type);
268
269    TIntermAggregate *decl = new TIntermAggregate(EOpDeclaration);
270    decl->getSequence()->push_back(init);
271
272    ASSERT(mSequenceStack.size() > 0);
273    TIntermSequence &sequence = mSequenceStack.back();
274    sequence.push_back(decl);
275
276    return tempVarName;
277}
278