1//
2// Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6
7#include "compiler/ValidateLimitations.h"
8#include "compiler/InfoSink.h"
9#include "compiler/ParseHelper.h"
10
11namespace {
12bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
13    for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
14        if (i->index.id == symbol->getId())
15            return true;
16    }
17    return false;
18}
19
20// Traverses a node to check if it represents a constant index expression.
21// Definition:
22// constant-index-expressions are a superset of constant-expressions.
23// Constant-index-expressions can include loop indices as defined in
24// GLSL ES 1.0 spec, Appendix A, section 4.
25// The following are constant-index-expressions:
26// - Constant expressions
27// - Loop indices as defined in section 4
28// - Expressions composed of both of the above
29class ValidateConstIndexExpr : public TIntermTraverser {
30public:
31    ValidateConstIndexExpr(const TLoopStack& stack)
32        : mValid(true), mLoopStack(stack) {}
33
34    // Returns true if the parsed node represents a constant index expression.
35    bool isValid() const { return mValid; }
36
37    virtual void visitSymbol(TIntermSymbol* symbol) {
38        // Only constants and loop indices are allowed in a
39        // constant index expression.
40        if (mValid) {
41            mValid = (symbol->getQualifier() == EvqConst) ||
42                     IsLoopIndex(symbol, mLoopStack);
43        }
44    }
45    virtual void visitConstantUnion(TIntermConstantUnion*) {}
46    virtual bool visitBinary(Visit, TIntermBinary*) { return true; }
47    virtual bool visitUnary(Visit, TIntermUnary*) { return true; }
48    virtual bool visitSelection(Visit, TIntermSelection*) { return true; }
49    virtual bool visitAggregate(Visit, TIntermAggregate*) { return true; }
50    virtual bool visitLoop(Visit, TIntermLoop*) { return true; }
51    virtual bool visitBranch(Visit, TIntermBranch*) { return true; }
52
53private:
54    bool mValid;
55    const TLoopStack& mLoopStack;
56};
57}  // namespace
58
59ValidateLimitations::ValidateLimitations(ShShaderType shaderType,
60                                         TInfoSinkBase& sink)
61    : mShaderType(shaderType),
62      mSink(sink),
63      mNumErrors(0)
64{
65}
66
67void ValidateLimitations::visitSymbol(TIntermSymbol*)
68{
69}
70
71void ValidateLimitations::visitConstantUnion(TIntermConstantUnion*)
72{
73}
74
75bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
76{
77    // Check if loop index is modified in the loop body.
78    validateOperation(node, node->getLeft());
79
80    // Check indexing.
81    switch (node->getOp()) {
82      case EOpIndexDirect:
83      case EOpIndexIndirect:
84        validateIndexing(node);
85        break;
86      default: break;
87    }
88    return true;
89}
90
91bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
92{
93    // Check if loop index is modified in the loop body.
94    validateOperation(node, node->getOperand());
95
96    return true;
97}
98
99bool ValidateLimitations::visitSelection(Visit, TIntermSelection*)
100{
101    return true;
102}
103
104bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
105{
106    switch (node->getOp()) {
107      case EOpFunctionCall:
108        validateFunctionCall(node);
109        break;
110      default:
111        break;
112    }
113    return true;
114}
115
116bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
117{
118    if (!validateLoopType(node))
119        return false;
120
121    TLoopInfo info;
122    memset(&info, 0, sizeof(TLoopInfo));
123    if (!validateForLoopHeader(node, &info))
124        return false;
125
126    TIntermNode* body = node->getBody();
127    if (body != NULL) {
128        mLoopStack.push_back(info);
129        body->traverse(this);
130        mLoopStack.pop_back();
131    }
132
133    // The loop is fully processed - no need to visit children.
134    return false;
135}
136
137bool ValidateLimitations::visitBranch(Visit, TIntermBranch*)
138{
139    return true;
140}
141
142void ValidateLimitations::error(TSourceLoc loc,
143                                const char *reason, const char* token)
144{
145    mSink.prefix(EPrefixError);
146    mSink.location(loc);
147    mSink << "'" << token << "' : " << reason << "\n";
148    ++mNumErrors;
149}
150
151bool ValidateLimitations::withinLoopBody() const
152{
153    return !mLoopStack.empty();
154}
155
156bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const
157{
158    return IsLoopIndex(symbol, mLoopStack);
159}
160
161bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
162    TLoopType type = node->getType();
163    if (type == ELoopFor)
164        return true;
165
166    // Reject while and do-while loops.
167    error(node->getLine(),
168          "This type of loop is not allowed",
169          type == ELoopWhile ? "while" : "do");
170    return false;
171}
172
173bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
174                                                TLoopInfo* info)
175{
176    ASSERT(node->getType() == ELoopFor);
177
178    //
179    // The for statement has the form:
180    //    for ( init-declaration ; condition ; expression ) statement
181    //
182    if (!validateForLoopInit(node, info))
183        return false;
184    if (!validateForLoopCond(node, info))
185        return false;
186    if (!validateForLoopExpr(node, info))
187        return false;
188
189    return true;
190}
191
192bool ValidateLimitations::validateForLoopInit(TIntermLoop* node,
193                                              TLoopInfo* info)
194{
195    TIntermNode* init = node->getInit();
196    if (init == NULL) {
197        error(node->getLine(), "Missing init declaration", "for");
198        return false;
199    }
200
201    //
202    // init-declaration has the form:
203    //     type-specifier identifier = constant-expression
204    //
205    TIntermAggregate* decl = init->getAsAggregate();
206    if ((decl == NULL) || (decl->getOp() != EOpDeclaration)) {
207        error(init->getLine(), "Invalid init declaration", "for");
208        return false;
209    }
210    // To keep things simple do not allow declaration list.
211    TIntermSequence& declSeq = decl->getSequence();
212    if (declSeq.size() != 1) {
213        error(decl->getLine(), "Invalid init declaration", "for");
214        return false;
215    }
216    TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
217    if ((declInit == NULL) || (declInit->getOp() != EOpInitialize)) {
218        error(decl->getLine(), "Invalid init declaration", "for");
219        return false;
220    }
221    TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
222    if (symbol == NULL) {
223        error(declInit->getLine(), "Invalid init declaration", "for");
224        return false;
225    }
226    // The loop index has type int or float.
227    TBasicType type = symbol->getBasicType();
228    if ((type != EbtInt) && (type != EbtFloat)) {
229        error(symbol->getLine(),
230              "Invalid type for loop index", getBasicString(type));
231        return false;
232    }
233    // The loop index is initialized with constant expression.
234    if (!isConstExpr(declInit->getRight())) {
235        error(declInit->getLine(),
236              "Loop index cannot be initialized with non-constant expression",
237              symbol->getSymbol().c_str());
238        return false;
239    }
240
241    info->index.id = symbol->getId();
242    return true;
243}
244
245bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
246                                              TLoopInfo* info)
247{
248    TIntermNode* cond = node->getCondition();
249    if (cond == NULL) {
250        error(node->getLine(), "Missing condition", "for");
251        return false;
252    }
253    //
254    // condition has the form:
255    //     loop_index relational_operator constant_expression
256    //
257    TIntermBinary* binOp = cond->getAsBinaryNode();
258    if (binOp == NULL) {
259        error(node->getLine(), "Invalid condition", "for");
260        return false;
261    }
262    // Loop index should be to the left of relational operator.
263    TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode();
264    if (symbol == NULL) {
265        error(binOp->getLine(), "Invalid condition", "for");
266        return false;
267    }
268    if (symbol->getId() != info->index.id) {
269        error(symbol->getLine(),
270              "Expected loop index", symbol->getSymbol().c_str());
271        return false;
272    }
273    // Relational operator is one of: > >= < <= == or !=.
274    switch (binOp->getOp()) {
275      case EOpEqual:
276      case EOpNotEqual:
277      case EOpLessThan:
278      case EOpGreaterThan:
279      case EOpLessThanEqual:
280      case EOpGreaterThanEqual:
281        break;
282      default:
283        error(binOp->getLine(),
284              "Invalid relational operator",
285              getOperatorString(binOp->getOp()));
286        break;
287    }
288    // Loop index must be compared with a constant.
289    if (!isConstExpr(binOp->getRight())) {
290        error(binOp->getLine(),
291              "Loop index cannot be compared with non-constant expression",
292              symbol->getSymbol().c_str());
293        return false;
294    }
295
296    return true;
297}
298
299bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
300                                              TLoopInfo* info)
301{
302    TIntermNode* expr = node->getExpression();
303    if (expr == NULL) {
304        error(node->getLine(), "Missing expression", "for");
305        return false;
306    }
307
308    // for expression has one of the following forms:
309    //     loop_index++
310    //     loop_index--
311    //     loop_index += constant_expression
312    //     loop_index -= constant_expression
313    //     ++loop_index
314    //     --loop_index
315    // The last two forms are not specified in the spec, but I am assuming
316    // its an oversight.
317    TIntermUnary* unOp = expr->getAsUnaryNode();
318    TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
319
320    TOperator op = EOpNull;
321    TIntermSymbol* symbol = NULL;
322    if (unOp != NULL) {
323        op = unOp->getOp();
324        symbol = unOp->getOperand()->getAsSymbolNode();
325    } else if (binOp != NULL) {
326        op = binOp->getOp();
327        symbol = binOp->getLeft()->getAsSymbolNode();
328    }
329
330    // The operand must be loop index.
331    if (symbol == NULL) {
332        error(expr->getLine(), "Invalid expression", "for");
333        return false;
334    }
335    if (symbol->getId() != info->index.id) {
336        error(symbol->getLine(),
337              "Expected loop index", symbol->getSymbol().c_str());
338        return false;
339    }
340
341    // The operator is one of: ++ -- += -=.
342    switch (op) {
343        case EOpPostIncrement:
344        case EOpPostDecrement:
345        case EOpPreIncrement:
346        case EOpPreDecrement:
347            ASSERT((unOp != NULL) && (binOp == NULL));
348            break;
349        case EOpAddAssign:
350        case EOpSubAssign:
351            ASSERT((unOp == NULL) && (binOp != NULL));
352            break;
353        default:
354            error(expr->getLine(), "Invalid operator", getOperatorString(op));
355            return false;
356    }
357
358    // Loop index must be incremented/decremented with a constant.
359    if (binOp != NULL) {
360        if (!isConstExpr(binOp->getRight())) {
361            error(binOp->getLine(),
362                  "Loop index cannot be modified by non-constant expression",
363                  symbol->getSymbol().c_str());
364            return false;
365        }
366    }
367
368    return true;
369}
370
371bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
372{
373    ASSERT(node->getOp() == EOpFunctionCall);
374
375    // If not within loop body, there is nothing to check.
376    if (!withinLoopBody())
377        return true;
378
379    // List of param indices for which loop indices are used as argument.
380    typedef std::vector<int> ParamIndex;
381    ParamIndex pIndex;
382    TIntermSequence& params = node->getSequence();
383    for (TIntermSequence::size_type i = 0; i < params.size(); ++i) {
384        TIntermSymbol* symbol = params[i]->getAsSymbolNode();
385        if (symbol && isLoopIndex(symbol))
386            pIndex.push_back(i);
387    }
388    // If none of the loop indices are used as arguments,
389    // there is nothing to check.
390    if (pIndex.empty())
391        return true;
392
393    bool valid = true;
394    TSymbolTable& symbolTable = GlobalParseContext->symbolTable;
395    TSymbol* symbol = symbolTable.find(node->getName());
396    ASSERT(symbol && symbol->isFunction());
397    TFunction* function = static_cast<TFunction*>(symbol);
398    for (ParamIndex::const_iterator i = pIndex.begin();
399         i != pIndex.end(); ++i) {
400        const TParameter& param = function->getParam(*i);
401        TQualifier qual = param.type->getQualifier();
402        if ((qual == EvqOut) || (qual == EvqInOut)) {
403            error(params[*i]->getLine(),
404                  "Loop index cannot be used as argument to a function out or inout parameter",
405                  params[*i]->getAsSymbolNode()->getSymbol().c_str());
406            valid = false;
407        }
408    }
409
410    return valid;
411}
412
413bool ValidateLimitations::validateOperation(TIntermOperator* node,
414                                            TIntermNode* operand) {
415    // Check if loop index is modified in the loop body.
416    if (!withinLoopBody() || !node->modifiesState())
417        return true;
418
419    const TIntermSymbol* symbol = operand->getAsSymbolNode();
420    if (symbol && isLoopIndex(symbol)) {
421        error(node->getLine(),
422              "Loop index cannot be statically assigned to within the body of the loop",
423              symbol->getSymbol().c_str());
424    }
425    return true;
426}
427
428bool ValidateLimitations::isConstExpr(TIntermNode* node)
429{
430    ASSERT(node != NULL);
431    return node->getAsConstantUnion() != NULL;
432}
433
434bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
435{
436    ASSERT(node != NULL);
437
438    ValidateConstIndexExpr validate(mLoopStack);
439    node->traverse(&validate);
440    return validate.isValid();
441}
442
443bool ValidateLimitations::validateIndexing(TIntermBinary* node)
444{
445    ASSERT((node->getOp() == EOpIndexDirect) ||
446           (node->getOp() == EOpIndexIndirect));
447
448    bool valid = true;
449    TIntermTyped* index = node->getRight();
450    // The index expression must have integral type.
451    if (!index->isScalar() || (index->getBasicType() != EbtInt)) {
452        error(index->getLine(),
453              "Index expression must have integral type",
454              index->getCompleteString().c_str());
455        valid = false;
456    }
457    // The index expession must be a constant-index-expression unless
458    // the operand is a uniform in a vertex shader.
459    TIntermTyped* operand = node->getLeft();
460    bool skip = (mShaderType == SH_VERTEX_SHADER) &&
461                (operand->getQualifier() == EvqUniform);
462    if (!skip && !isConstIndexExpr(index)) {
463        error(index->getLine(), "Index expression must be constant", "[]");
464        valid = false;
465    }
466    return valid;
467}
468
469