1//
2// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6
7#include "compiler/translator/ValidateLimitations.h"
8#include "compiler/translator/InfoSink.h"
9#include "compiler/translator/InitializeParseContext.h"
10#include "compiler/translator/ParseContext.h"
11#include "angle_gl.h"
12
13namespace
14{
15
16// Traverses a node to check if it represents a constant index expression.
17// Definition:
18// constant-index-expressions are a superset of constant-expressions.
19// Constant-index-expressions can include loop indices as defined in
20// GLSL ES 1.0 spec, Appendix A, section 4.
21// The following are constant-index-expressions:
22// - Constant expressions
23// - Loop indices as defined in section 4
24// - Expressions composed of both of the above
25class ValidateConstIndexExpr : public TIntermTraverser
26{
27  public:
28    ValidateConstIndexExpr(TLoopStack& stack)
29        : mValid(true), mLoopStack(stack) {}
30
31    // Returns true if the parsed node represents a constant index expression.
32    bool isValid() const { return mValid; }
33
34    virtual void visitSymbol(TIntermSymbol *symbol)
35    {
36        // Only constants and loop indices are allowed in a
37        // constant index expression.
38        if (mValid)
39        {
40            mValid = (symbol->getQualifier() == EvqConst) ||
41                     (mLoopStack.findLoop(symbol));
42        }
43    }
44
45  private:
46    bool mValid;
47    TLoopStack& mLoopStack;
48};
49
50const char *GetOperatorString(TOperator op)
51{
52    switch (op)
53    {
54      case EOpInitialize: return "=";
55      case EOpAssign: return "=";
56      case EOpAddAssign: return "+=";
57      case EOpSubAssign: return "-=";
58      case EOpDivAssign: return "/=";
59
60      // Fall-through.
61      case EOpMulAssign:
62      case EOpVectorTimesMatrixAssign:
63      case EOpVectorTimesScalarAssign:
64      case EOpMatrixTimesScalarAssign:
65      case EOpMatrixTimesMatrixAssign: return "*=";
66
67      // Fall-through.
68      case EOpIndexDirect:
69      case EOpIndexIndirect: return "[]";
70
71      case EOpIndexDirectStruct:
72      case EOpIndexDirectInterfaceBlock: return ".";
73      case EOpVectorSwizzle: return ".";
74      case EOpAdd: return "+";
75      case EOpSub: return "-";
76      case EOpMul: return "*";
77      case EOpDiv: return "/";
78      case EOpMod: UNIMPLEMENTED(); break;
79      case EOpEqual: return "==";
80      case EOpNotEqual: return "!=";
81      case EOpLessThan: return "<";
82      case EOpGreaterThan: return ">";
83      case EOpLessThanEqual: return "<=";
84      case EOpGreaterThanEqual: return ">=";
85
86      // Fall-through.
87      case EOpVectorTimesScalar:
88      case EOpVectorTimesMatrix:
89      case EOpMatrixTimesVector:
90      case EOpMatrixTimesScalar:
91      case EOpMatrixTimesMatrix: return "*";
92
93      case EOpLogicalOr: return "||";
94      case EOpLogicalXor: return "^^";
95      case EOpLogicalAnd: return "&&";
96      case EOpNegative: return "-";
97      case EOpVectorLogicalNot: return "not";
98      case EOpLogicalNot: return "!";
99      case EOpPostIncrement: return "++";
100      case EOpPostDecrement: return "--";
101      case EOpPreIncrement: return "++";
102      case EOpPreDecrement: return "--";
103
104      case EOpRadians: return "radians";
105      case EOpDegrees: return "degrees";
106      case EOpSin: return "sin";
107      case EOpCos: return "cos";
108      case EOpTan: return "tan";
109      case EOpAsin: return "asin";
110      case EOpAcos: return "acos";
111      case EOpAtan: return "atan";
112      case EOpExp: return "exp";
113      case EOpLog: return "log";
114      case EOpExp2: return "exp2";
115      case EOpLog2: return "log2";
116      case EOpSqrt: return "sqrt";
117      case EOpInverseSqrt: return "inversesqrt";
118      case EOpAbs: return "abs";
119      case EOpSign: return "sign";
120      case EOpFloor: return "floor";
121      case EOpCeil: return "ceil";
122      case EOpFract: return "fract";
123      case EOpLength: return "length";
124      case EOpNormalize: return "normalize";
125      case EOpDFdx: return "dFdx";
126      case EOpDFdy: return "dFdy";
127      case EOpFwidth: return "fwidth";
128      case EOpAny: return "any";
129      case EOpAll: return "all";
130
131      default: break;
132    }
133    return "";
134}
135
136}  // namespace anonymous
137
138ValidateLimitations::ValidateLimitations(sh::GLenum shaderType,
139                                         TInfoSinkBase &sink)
140    : mShaderType(shaderType),
141      mSink(sink),
142      mNumErrors(0)
143{
144}
145
146bool ValidateLimitations::visitBinary(Visit, TIntermBinary *node)
147{
148    // Check if loop index is modified in the loop body.
149    validateOperation(node, node->getLeft());
150
151    // Check indexing.
152    switch (node->getOp())
153    {
154      case EOpIndexDirect:
155      case EOpIndexIndirect:
156        validateIndexing(node);
157        break;
158      default:
159        break;
160    }
161    return true;
162}
163
164bool ValidateLimitations::visitUnary(Visit, TIntermUnary *node)
165{
166    // Check if loop index is modified in the loop body.
167    validateOperation(node, node->getOperand());
168
169    return true;
170}
171
172bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate *node)
173{
174    switch (node->getOp()) {
175      case EOpFunctionCall:
176        validateFunctionCall(node);
177        break;
178      default:
179        break;
180    }
181    return true;
182}
183
184bool ValidateLimitations::visitLoop(Visit, TIntermLoop *node)
185{
186    if (!validateLoopType(node))
187        return false;
188
189    if (!validateForLoopHeader(node))
190        return false;
191
192    TIntermNode *body = node->getBody();
193    if (body != NULL)
194    {
195        mLoopStack.push(node);
196        body->traverse(this);
197        mLoopStack.pop();
198    }
199
200    // The loop is fully processed - no need to visit children.
201    return false;
202}
203
204void ValidateLimitations::error(TSourceLoc loc,
205                                const char *reason, const char *token)
206{
207    mSink.prefix(EPrefixError);
208    mSink.location(loc);
209    mSink << "'" << token << "' : " << reason << "\n";
210    ++mNumErrors;
211}
212
213bool ValidateLimitations::withinLoopBody() const
214{
215    return !mLoopStack.empty();
216}
217
218bool ValidateLimitations::isLoopIndex(TIntermSymbol *symbol)
219{
220    return mLoopStack.findLoop(symbol) != NULL;
221}
222
223bool ValidateLimitations::validateLoopType(TIntermLoop *node)
224{
225    TLoopType type = node->getType();
226    if (type == ELoopFor)
227        return true;
228
229    // Reject while and do-while loops.
230    error(node->getLine(),
231          "This type of loop is not allowed",
232          type == ELoopWhile ? "while" : "do");
233    return false;
234}
235
236bool ValidateLimitations::validateForLoopHeader(TIntermLoop *node)
237{
238    ASSERT(node->getType() == ELoopFor);
239
240    //
241    // The for statement has the form:
242    //    for ( init-declaration ; condition ; expression ) statement
243    //
244    int indexSymbolId = validateForLoopInit(node);
245    if (indexSymbolId < 0)
246        return false;
247    if (!validateForLoopCond(node, indexSymbolId))
248        return false;
249    if (!validateForLoopExpr(node, indexSymbolId))
250        return false;
251
252    return true;
253}
254
255int ValidateLimitations::validateForLoopInit(TIntermLoop *node)
256{
257    TIntermNode *init = node->getInit();
258    if (init == NULL)
259    {
260        error(node->getLine(), "Missing init declaration", "for");
261        return -1;
262    }
263
264    //
265    // init-declaration has the form:
266    //     type-specifier identifier = constant-expression
267    //
268    TIntermAggregate *decl = init->getAsAggregate();
269    if ((decl == NULL) || (decl->getOp() != EOpDeclaration))
270    {
271        error(init->getLine(), "Invalid init declaration", "for");
272        return -1;
273    }
274    // To keep things simple do not allow declaration list.
275    TIntermSequence *declSeq = decl->getSequence();
276    if (declSeq->size() != 1)
277    {
278        error(decl->getLine(), "Invalid init declaration", "for");
279        return -1;
280    }
281    TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
282    if ((declInit == NULL) || (declInit->getOp() != EOpInitialize))
283    {
284        error(decl->getLine(), "Invalid init declaration", "for");
285        return -1;
286    }
287    TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
288    if (symbol == NULL)
289    {
290        error(declInit->getLine(), "Invalid init declaration", "for");
291        return -1;
292    }
293    // The loop index has type int or float.
294    TBasicType type = symbol->getBasicType();
295    if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat)) {
296        error(symbol->getLine(),
297              "Invalid type for loop index", getBasicString(type));
298        return -1;
299    }
300    // The loop index is initialized with constant expression.
301    if (!isConstExpr(declInit->getRight()))
302    {
303        error(declInit->getLine(),
304              "Loop index cannot be initialized with non-constant expression",
305              symbol->getSymbol().c_str());
306        return -1;
307    }
308
309    return symbol->getId();
310}
311
312bool ValidateLimitations::validateForLoopCond(TIntermLoop *node,
313                                              int indexSymbolId)
314{
315    TIntermNode *cond = node->getCondition();
316    if (cond == NULL)
317    {
318        error(node->getLine(), "Missing condition", "for");
319        return false;
320    }
321    //
322    // condition has the form:
323    //     loop_index relational_operator constant_expression
324    //
325    TIntermBinary *binOp = cond->getAsBinaryNode();
326    if (binOp == NULL)
327    {
328        error(node->getLine(), "Invalid condition", "for");
329        return false;
330    }
331    // Loop index should be to the left of relational operator.
332    TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
333    if (symbol == NULL)
334    {
335        error(binOp->getLine(), "Invalid condition", "for");
336        return false;
337    }
338    if (symbol->getId() != indexSymbolId)
339    {
340        error(symbol->getLine(),
341              "Expected loop index", symbol->getSymbol().c_str());
342        return false;
343    }
344    // Relational operator is one of: > >= < <= == or !=.
345    switch (binOp->getOp())
346    {
347      case EOpEqual:
348      case EOpNotEqual:
349      case EOpLessThan:
350      case EOpGreaterThan:
351      case EOpLessThanEqual:
352      case EOpGreaterThanEqual:
353        break;
354      default:
355        error(binOp->getLine(),
356              "Invalid relational operator",
357              GetOperatorString(binOp->getOp()));
358        break;
359    }
360    // Loop index must be compared with a constant.
361    if (!isConstExpr(binOp->getRight()))
362    {
363        error(binOp->getLine(),
364              "Loop index cannot be compared with non-constant expression",
365              symbol->getSymbol().c_str());
366        return false;
367    }
368
369    return true;
370}
371
372bool ValidateLimitations::validateForLoopExpr(TIntermLoop *node,
373                                              int indexSymbolId)
374{
375    TIntermNode *expr = node->getExpression();
376    if (expr == NULL)
377    {
378        error(node->getLine(), "Missing expression", "for");
379        return false;
380    }
381
382    // for expression has one of the following forms:
383    //     loop_index++
384    //     loop_index--
385    //     loop_index += constant_expression
386    //     loop_index -= constant_expression
387    //     ++loop_index
388    //     --loop_index
389    // The last two forms are not specified in the spec, but I am assuming
390    // its an oversight.
391    TIntermUnary *unOp = expr->getAsUnaryNode();
392    TIntermBinary *binOp = unOp ? NULL : expr->getAsBinaryNode();
393
394    TOperator op = EOpNull;
395    TIntermSymbol *symbol = NULL;
396    if (unOp != NULL)
397    {
398        op = unOp->getOp();
399        symbol = unOp->getOperand()->getAsSymbolNode();
400    }
401    else if (binOp != NULL)
402    {
403        op = binOp->getOp();
404        symbol = binOp->getLeft()->getAsSymbolNode();
405    }
406
407    // The operand must be loop index.
408    if (symbol == NULL)
409    {
410        error(expr->getLine(), "Invalid expression", "for");
411        return false;
412    }
413    if (symbol->getId() != indexSymbolId)
414    {
415        error(symbol->getLine(),
416              "Expected loop index", symbol->getSymbol().c_str());
417        return false;
418    }
419
420    // The operator is one of: ++ -- += -=.
421    switch (op)
422    {
423      case EOpPostIncrement:
424      case EOpPostDecrement:
425      case EOpPreIncrement:
426      case EOpPreDecrement:
427        ASSERT((unOp != NULL) && (binOp == NULL));
428        break;
429      case EOpAddAssign:
430      case EOpSubAssign:
431        ASSERT((unOp == NULL) && (binOp != NULL));
432        break;
433      default:
434        error(expr->getLine(), "Invalid operator", GetOperatorString(op));
435        return false;
436    }
437
438    // Loop index must be incremented/decremented with a constant.
439    if (binOp != NULL)
440    {
441        if (!isConstExpr(binOp->getRight()))
442        {
443            error(binOp->getLine(),
444                  "Loop index cannot be modified by non-constant expression",
445                  symbol->getSymbol().c_str());
446            return false;
447        }
448    }
449
450    return true;
451}
452
453bool ValidateLimitations::validateFunctionCall(TIntermAggregate *node)
454{
455    ASSERT(node->getOp() == EOpFunctionCall);
456
457    // If not within loop body, there is nothing to check.
458    if (!withinLoopBody())
459        return true;
460
461    // List of param indices for which loop indices are used as argument.
462    typedef std::vector<size_t> ParamIndex;
463    ParamIndex pIndex;
464    TIntermSequence *params = node->getSequence();
465    for (TIntermSequence::size_type i = 0; i < params->size(); ++i)
466    {
467        TIntermSymbol *symbol = (*params)[i]->getAsSymbolNode();
468        if (symbol && isLoopIndex(symbol))
469            pIndex.push_back(i);
470    }
471    // If none of the loop indices are used as arguments,
472    // there is nothing to check.
473    if (pIndex.empty())
474        return true;
475
476    bool valid = true;
477    TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
478    TSymbol* symbol = symbolTable.find(node->getName(), GetGlobalParseContext()->shaderVersion);
479    ASSERT(symbol && symbol->isFunction());
480    TFunction *function = static_cast<TFunction *>(symbol);
481    for (ParamIndex::const_iterator i = pIndex.begin();
482         i != pIndex.end(); ++i)
483    {
484        const TParameter &param = function->getParam(*i);
485        TQualifier qual = param.type->getQualifier();
486        if ((qual == EvqOut) || (qual == EvqInOut))
487        {
488            error((*params)[*i]->getLine(),
489                  "Loop index cannot be used as argument to a function out or inout parameter",
490                  (*params)[*i]->getAsSymbolNode()->getSymbol().c_str());
491            valid = false;
492        }
493    }
494
495    return valid;
496}
497
498bool ValidateLimitations::validateOperation(TIntermOperator *node,
499                                            TIntermNode* operand)
500{
501    // Check if loop index is modified in the loop body.
502    if (!withinLoopBody() || !node->isAssignment())
503        return true;
504
505    TIntermSymbol *symbol = operand->getAsSymbolNode();
506    if (symbol && isLoopIndex(symbol))
507    {
508        error(node->getLine(),
509              "Loop index cannot be statically assigned to within the body of the loop",
510              symbol->getSymbol().c_str());
511    }
512    return true;
513}
514
515bool ValidateLimitations::isConstExpr(TIntermNode *node)
516{
517    ASSERT(node != NULL);
518    return node->getAsConstantUnion() != NULL;
519}
520
521bool ValidateLimitations::isConstIndexExpr(TIntermNode *node)
522{
523    ASSERT(node != NULL);
524
525    ValidateConstIndexExpr validate(mLoopStack);
526    node->traverse(&validate);
527    return validate.isValid();
528}
529
530bool ValidateLimitations::validateIndexing(TIntermBinary *node)
531{
532    ASSERT((node->getOp() == EOpIndexDirect) ||
533           (node->getOp() == EOpIndexIndirect));
534
535    bool valid = true;
536    TIntermTyped *index = node->getRight();
537    // The index expression must have integral type.
538    if (!index->isScalarInt()) {
539        error(index->getLine(),
540              "Index expression must have integral type",
541              index->getCompleteString().c_str());
542        valid = false;
543    }
544    // The index expession must be a constant-index-expression unless
545    // the operand is a uniform in a vertex shader.
546    TIntermTyped *operand = node->getLeft();
547    bool skip = (mShaderType == GL_VERTEX_SHADER) &&
548                (operand->getQualifier() == EvqUniform);
549    if (!skip && !isConstIndexExpr(index))
550    {
551        error(index->getLine(), "Index expression must be constant", "[]");
552        valid = false;
553    }
554    return valid;
555}
556
557