1/*
2 * Copyright 2011, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_ast_replace.h"

#include <vector>

#include "slang_assert.h"

#include "llvm/Support/Casting.h"
22
23namespace slang {
24
25void RSASTReplace::ReplaceStmt(
26    clang::Stmt *OuterStmt,
27    clang::Stmt *OldStmt,
28    clang::Stmt *NewStmt) {
29  mOldStmt = OldStmt;
30  mNewStmt = NewStmt;
31  mOuterStmt = OuterStmt;
32
33  // This simplifies use in various Stmt visitor passes where the only
34  // valid type is an Expr.
35  mOldExpr = llvm::dyn_cast<clang::Expr>(OldStmt);
36  if (mOldExpr) {
37    mNewExpr = llvm::dyn_cast<clang::Expr>(NewStmt);
38  }
39  Visit(mOuterStmt);
40}
41
42void RSASTReplace::ReplaceInCompoundStmt(clang::CompoundStmt *CS) {
43  clang::Stmt **UpdatedStmtList = new clang::Stmt*[CS->size()];
44
45  unsigned UpdatedStmtCount = 0;
46  clang::CompoundStmt::body_iterator bI = CS->body_begin();
47  clang::CompoundStmt::body_iterator bE = CS->body_end();
48
49  for ( ; bI != bE; bI++) {
50    if (matchesStmt(*bI)) {
51      UpdatedStmtList[UpdatedStmtCount++] = mNewStmt;
52    } else {
53      UpdatedStmtList[UpdatedStmtCount++] = *bI;
54    }
55  }
56
57  CS->setStmts(C, llvm::makeArrayRef(UpdatedStmtList, UpdatedStmtCount));
58
59  delete [] UpdatedStmtList;
60}
61
62void RSASTReplace::VisitStmt(clang::Stmt *S) {
63  // This function does the actual iteration through all sub-Stmt's within
64  // a given Stmt. Note that this function is skipped by all of the other
65  // Visit* functions if we have already found a higher-level match.
66  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
67       I != E;
68       I++) {
69    if (clang::Stmt *Child = *I) {
70      if (!matchesStmt(Child)) {
71        Visit(Child);
72      }
73    }
74  }
75}
76
// Handles a compound statement ({ ... }).
// First recurses into the children (VisitStmt skips any child that is the
// target itself), then rewrites this block's own statement list so that a
// direct child equal to the target is swapped for the replacement. The order
// matters: descending before the swap means the newly inserted statement is
// not itself walked.
void RSASTReplace::VisitCompoundStmt(clang::CompoundStmt *CS) {
  VisitStmt(CS);
  ReplaceInCompoundStmt(CS);
}
81
82void RSASTReplace::VisitCaseStmt(clang::CaseStmt *CS) {
83  if (matchesStmt(CS->getSubStmt())) {
84    CS->setSubStmt(mNewStmt);
85  } else {
86    VisitStmt(CS);
87  }
88}
89
90void RSASTReplace::VisitDeclStmt(clang::DeclStmt* DS) {
91  VisitStmt(DS);
92  for (clang::Decl* D : DS->decls()) {
93    clang::VarDecl* VD;
94    if ((VD = llvm::dyn_cast<clang::VarDecl>(D))) {
95      if (matchesExpr(VD->getInit())) {
96        VD->setInit(mNewExpr);
97      }
98    }
99  }
100}
101
102void RSASTReplace::VisitDefaultStmt(clang::DefaultStmt *DS) {
103  if (matchesStmt(DS->getSubStmt())) {
104    DS->setSubStmt(mNewStmt);
105  } else {
106    VisitStmt(DS);
107  }
108}
109
110void RSASTReplace::VisitDoStmt(clang::DoStmt *DS) {
111  if (matchesExpr(DS->getCond())) {
112    DS->setCond(mNewExpr);
113  } else if (matchesStmt(DS->getBody())) {
114    DS->setBody(mNewStmt);
115  } else {
116    VisitStmt(DS);
117  }
118}
119
120void RSASTReplace::VisitForStmt(clang::ForStmt *FS) {
121  if (matchesStmt(FS->getInit())) {
122    FS->setInit(mNewStmt);
123  } else if (matchesExpr(FS->getCond())) {
124    FS->setCond(mNewExpr);
125  } else if (matchesExpr(FS->getInc())) {
126    FS->setInc(mNewExpr);
127  } else if (matchesStmt(FS->getBody())) {
128    FS->setBody(mNewStmt);
129  } else {
130    VisitStmt(FS);
131  }
132}
133
134void RSASTReplace::VisitIfStmt(clang::IfStmt *IS) {
135  if (matchesExpr(IS->getCond())) {
136    IS->setCond(mNewExpr);
137  } else if (matchesStmt(IS->getThen())) {
138    IS->setThen(mNewStmt);
139  } else if (matchesStmt(IS->getElse())) {
140    IS->setElse(mNewStmt);
141  } else {
142    VisitStmt(IS);
143  }
144}
145
// Generic SwitchCase fallback. CaseStmt and DefaultStmt have dedicated
// visitors (VisitCaseStmt / VisitDefaultStmt), so reaching this handler is
// unexpected and is flagged by slangAssert. If execution continues past the
// assert, the node is still walked via the generic child visit.
void RSASTReplace::VisitSwitchCase(clang::SwitchCase *SC) {
  slangAssert(false && "Both case and default have specialized handlers");
  VisitStmt(SC);
}
150
151void RSASTReplace::VisitSwitchStmt(clang::SwitchStmt *SS) {
152  if (matchesExpr(SS->getCond())) {
153    SS->setCond(mNewExpr);
154  } else {
155    VisitStmt(SS);
156  }
157}
158
159void RSASTReplace::VisitWhileStmt(clang::WhileStmt *WS) {
160  if (matchesExpr(WS->getCond())) {
161    WS->setCond(mNewExpr);
162  } else if (matchesStmt(WS->getBody())) {
163    WS->setBody(mNewStmt);
164  } else {
165    VisitStmt(WS);
166  }
167}
168
169}  // namespace slang
170