1//
2// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6
7#include "compiler/translator/DetectCallDepth.h"
8#include "compiler/translator/InfoSink.h"
9
10DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
11    : name(fname),
12      visit(PreVisit)
13{
14}
15
16const TString& DetectCallDepth::FunctionNode::getName() const
17{
18    return name;
19}
20
21void DetectCallDepth::FunctionNode::addCallee(
22    DetectCallDepth::FunctionNode* callee)
23{
24    for (size_t i = 0; i < callees.size(); ++i) {
25        if (callees[i] == callee)
26            return;
27    }
28    callees.push_back(callee);
29}
30
// Depth-first walk of the call graph starting at this node.
//
// @param detectCallDepth  Owning detector. It supplies the depth limit
//                         (checkExceedsMaxDepth) and the info sink used to
//                         report the call chain that hit the limit.
// @param depth            Call depth of this node
//                         (detectCallDepthForFunction passes 1 for the root).
// @return One of:
//         - the deepest call depth reached from this node (never less than
//           `depth`);
//         - kInfiniteCallDepth if a cycle (recursion) is found;
//         - a value for which checkExceedsMaxDepth() is true once the depth
//           limit is reached.
//
// A node is InVisit while it is on the current path and PostVisit once all
// of its callees have been explored. Early returns leave nodes in InVisit.
// detectCallDepthForFunction() calls resetFunctionNodes() before every walk,
// which is what the PreVisit assertion below relies on.
//
// NOTE(review): a callee that is already PostVisit is skipped without adding
// its own depth. When one function is reachable along two paths of different
// length, only the path explored first contributes. For graphs with shared
// callees this can under-report the true maximum depth. Fixing it would need
// a per-node memoized height (a header change) — confirm whether the current
// behavior is acceptable.
int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
{
    ASSERT(visit == PreVisit);
    ASSERT(detectCallDepth);

    int maxDepth = depth;
    // Mark this node as on the current path so a back edge to it is
    // recognized as recursion.
    visit = InVisit;
    for (size_t i = 0; i < callees.size(); ++i) {
        switch (callees[i]->visit) {
            case InVisit:
                // cycle detected, i.e., recursion detected.
                return kInfiniteCallDepth;
            case PostVisit:
                // Already fully explored via another caller (see NOTE above).
                break;
            case PreVisit: {
                // Check before we recurse so we don't go too deep.
                if (detectCallDepth->checkExceedsMaxDepth(depth))
                    return depth;
                int callDepth = callees[i]->detectCallDepth(detectCallDepth, depth + 1);
                // Check after we recurse so we can exit immediately and provide info.
                // Each frame appends its callee's name while unwinding, which
                // builds the "<-"-separated call chain in the info sink.
                if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
                    detectCallDepth->getInfoSink().info << "<-" << callees[i]->getName();
                    return callDepth;
                }
                maxDepth = std::max(callDepth, maxDepth);
                break;
            }
            default:
                UNREACHABLE();
                break;
        }
    }
    // All callees finished without hitting recursion or the limit.
    visit = PostVisit;
    return maxDepth;
}
66
67void DetectCallDepth::FunctionNode::reset()
68{
69    visit = PreVisit;
70}
71
72DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
73    : TIntermTraverser(true, false, true, false),
74      currentFunction(NULL),
75      infoSink(infoSink),
76      maxDepth(limitCallStackDepth ? maxCallStackDepth : FunctionNode::kInfiniteCallDepth)
77{
78}
79
80DetectCallDepth::~DetectCallDepth()
81{
82    for (size_t i = 0; i < functions.size(); ++i)
83        delete functions[i];
84}
85
86bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
87{
88    switch (node->getOp())
89    {
90        case EOpPrototype:
91            // Function declaration.
92            // Don't add FunctionNode here because node->getName() is the
93            // unmangled function name.
94            break;
95        case EOpFunction: {
96            // Function definition.
97            if (visit == PreVisit) {
98                currentFunction = findFunctionByName(node->getName());
99                if (currentFunction == NULL) {
100                    currentFunction = new FunctionNode(node->getName());
101                    functions.push_back(currentFunction);
102                }
103            } else if (visit == PostVisit) {
104                currentFunction = NULL;
105            }
106            break;
107        }
108        case EOpFunctionCall: {
109            // Function call.
110            if (visit == PreVisit) {
111                FunctionNode* func = findFunctionByName(node->getName());
112                if (func == NULL) {
113                    func = new FunctionNode(node->getName());
114                    functions.push_back(func);
115                }
116                if (currentFunction)
117                    currentFunction->addCallee(func);
118            }
119            break;
120        }
121        default:
122            break;
123    }
124    return true;
125}
126
127bool DetectCallDepth::checkExceedsMaxDepth(int depth)
128{
129    return depth >= maxDepth;
130}
131
132void DetectCallDepth::resetFunctionNodes()
133{
134    for (size_t i = 0; i < functions.size(); ++i) {
135        functions[i]->reset();
136    }
137}
138
139DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
140{
141    currentFunction = NULL;
142    resetFunctionNodes();
143
144    int maxCallDepth = func->detectCallDepth(this, 1);
145
146    if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
147        return kErrorRecursion;
148
149    if (maxCallDepth >= maxDepth)
150        return kErrorMaxDepthExceeded;
151
152    return kErrorNone;
153}
154
155DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
156{
157    if (maxDepth != FunctionNode::kInfiniteCallDepth) {
158        // Check all functions because the driver may fail on them
159        // TODO: Before detectingRecursion, strip unused functions.
160        for (size_t i = 0; i < functions.size(); ++i) {
161            ErrorCode error = detectCallDepthForFunction(functions[i]);
162            if (error != kErrorNone)
163                return error;
164        }
165    } else {
166        FunctionNode* main = findFunctionByName("main(");
167        if (main == NULL)
168            return kErrorMissingMain;
169
170        return detectCallDepthForFunction(main);
171    }
172
173    return kErrorNone;
174}
175
176DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
177    const TString& name)
178{
179    for (size_t i = 0; i < functions.size(); ++i) {
180        if (functions[i]->getName() == name)
181            return functions[i];
182    }
183    return NULL;
184}
185
186