slang_rs_check_ast.cpp revision 616854341745b958e0c409cdb6e21abb6225aa21
1/*
2 * Copyright 2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_check_ast.h"
18
19#include "slang_assert.h"
20#include "slang_rs.h"
21#include "slang_rs_export_foreach.h"
22#include "slang_rs_export_type.h"
23
24namespace slang {
25
26void RSCheckAST::VisitStmt(clang::Stmt *S) {
27  // This function does the actual iteration through all sub-Stmt's within
28  // a given Stmt. Note that this function is skipped by all of the other
29  // Visit* functions if we have already found a higher-level match.
30  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
31       I != E;
32       I++) {
33    if (clang::Stmt *Child = *I) {
34      Visit(Child);
35    }
36  }
37}
38
39void RSCheckAST::WarnOnSetElementAt(clang::CallExpr *E) {
40  clang::FunctionDecl *Decl;
41  clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
42  Decl = clang::dyn_cast_or_null<clang::FunctionDecl>(E->getCalleeDecl());
43
44  if (!Decl || Decl->getNameAsString() != std::string("rsSetElementAt")) {
45    return;
46  }
47
48  clang::Expr *Expr;
49  clang::ImplicitCastExpr *ImplCast;
50  Expr = E->getArg(1);
51  ImplCast = clang::dyn_cast_or_null<clang::ImplicitCastExpr>(Expr);
52
53  if (!ImplCast) {
54    return;
55  }
56
57  const clang::Type *Ty;
58  const clang::VectorType *VectorTy;
59  const clang::BuiltinType *ElementTy;
60  Ty = ImplCast->getSubExpr()->getType()->getPointeeType()
61    ->getUnqualifiedDesugaredType();
62  VectorTy = clang::dyn_cast_or_null<clang::VectorType>(Ty);
63
64  if (VectorTy) {
65    ElementTy = clang::dyn_cast_or_null<clang::BuiltinType>(
66      VectorTy->getElementType()->getUnqualifiedDesugaredType());
67  } else {
68    ElementTy = clang::dyn_cast_or_null<clang::BuiltinType>(
69      Ty->getUnqualifiedDesugaredType());
70  }
71
72  if (!ElementTy) {
73    return;
74  }
75
76  // We only support vectors with 2, 3 or 4 elements.
77  if (VectorTy) {
78    switch (VectorTy->getNumElements()) {
79    default:
80      return;
81    case 2:
82    case 3:
83    case 4:
84      break;
85    }
86  }
87
88  const char *Name;
89
90  switch (ElementTy->getKind()) {
91    case clang::BuiltinType::Float:
92      Name = "float";
93      break;
94    case clang::BuiltinType::Double:
95      Name = "double";
96      break;
97    case clang::BuiltinType::Char_S:
98      Name = "char";
99      break;
100    case clang::BuiltinType::Short:
101      Name = "short";
102      break;
103    case clang::BuiltinType::Int:
104      Name = "int";
105      break;
106    case clang::BuiltinType::Long:
107      Name = "long";
108      break;
109    case clang::BuiltinType::UChar:
110      Name = "uchar";
111      break;
112    case clang::BuiltinType::UShort:
113      Name = "ushort";
114      break;
115    case clang::BuiltinType::UInt:
116      Name = "uint";
117      break;
118    case clang::BuiltinType::ULong:
119      Name = "ulong";
120      break;
121    default:
122      return;
123  }
124
125  clang::DiagnosticBuilder DiagBuilder =  DiagEngine.Report(
126    clang::FullSourceLoc(E->getLocStart(), mSM),
127    mDiagEngine.getCustomDiagID( clang::DiagnosticsEngine::Warning,
128    "untyped rsSetElementAt() can reduce performance. "
129    "Use rsSetElementAt_%0%1() instead."));
130  DiagBuilder << Name;
131
132  if (VectorTy) {
133    DiagBuilder << VectorTy->getNumElements();
134  } else {
135    DiagBuilder << "";
136  }
137
138  return;
139}
140
141void RSCheckAST::VisitCallExpr(clang::CallExpr *E) {
142  WarnOnSetElementAt(E);
143
144  for (clang::CallExpr::arg_iterator AI = E->arg_begin(), AE = E->arg_end();
145       AI != AE; ++AI) {
146    Visit(*AI);
147  }
148}
149
150void RSCheckAST::ValidateFunctionDecl(clang::FunctionDecl *FD) {
151  if (!FD) {
152    return;
153  }
154
155  if (mIsFilterscript) {
156    // Validate parameters for Filterscript.
157    size_t numParams = FD->getNumParams();
158
159    clang::QualType resultType = FD->getResultType().getCanonicalType();
160
161    // We use FD as our NamedDecl in the case of a bad return type.
162    if (!RSExportType::ValidateType(C, resultType, FD,
163                                    FD->getLocStart(), mTargetAPI,
164                                    mIsFilterscript)) {
165      mValid = false;
166    }
167
168    for (size_t i = 0; i < numParams; i++) {
169      clang::ParmVarDecl *PVD = FD->getParamDecl(i);
170      clang::QualType QT = PVD->getType().getCanonicalType();
171      if (!RSExportType::ValidateType(C, QT, PVD, PVD->getLocStart(),
172                                      mTargetAPI, mIsFilterscript)) {
173        mValid = false;
174      }
175    }
176  }
177
178  bool saveKernel = mInKernel;
179  mInKernel = RSExportForEach::isRSForEachFunc(mTargetAPI, &mDiagEngine, FD);
180
181  if (clang::Stmt *Body = FD->getBody()) {
182    Visit(Body);
183  }
184
185  mInKernel = saveKernel;
186}
187
188
189void RSCheckAST::ValidateVarDecl(clang::VarDecl *VD) {
190  if (!VD) {
191    return;
192  }
193
194  clang::QualType QT = VD->getType();
195
196  if (VD->getFormalLinkage() == clang::ExternalLinkage) {
197    llvm::StringRef TypeName;
198    const clang::Type *T = QT.getTypePtr();
199    if (!RSExportType::NormalizeType(T, TypeName, &mDiagEngine, VD)) {
200      mValid = false;
201    }
202  }
203
204  // We don't allow static (non-const) variables within kernels.
205  if (mInKernel && VD->isStaticLocal()) {
206    if (!QT.isConstQualified()) {
207      mDiagEngine.Report(
208        clang::FullSourceLoc(VD->getLocation(), mSM),
209        mDiagEngine.getCustomDiagID(
210          clang::DiagnosticsEngine::Error,
211          "Non-const static variables are not allowed in kernels: '%0'"))
212          << VD->getName();
213      mValid = false;
214    }
215  }
216
217  if (!RSExportType::ValidateVarDecl(VD, mTargetAPI, mIsFilterscript)) {
218    mValid = false;
219  } else if (clang::Expr *Init = VD->getInit()) {
220    // Only check the initializer if the decl is already ok.
221    Visit(Init);
222  }
223}
224
225
226void RSCheckAST::VisitDeclStmt(clang::DeclStmt *DS) {
227  if (!SlangRS::IsLocInRSHeaderFile(DS->getLocStart(), mSM)) {
228    for (clang::DeclStmt::decl_iterator I = DS->decl_begin(),
229                                        E = DS->decl_end();
230         I != E;
231         ++I) {
232      if (clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I)) {
233        ValidateVarDecl(VD);
234      } else if (clang::FunctionDecl *FD =
235            llvm::dyn_cast<clang::FunctionDecl>(*I)) {
236        ValidateFunctionDecl(FD);
237      }
238    }
239  }
240}
241
242
243void RSCheckAST::VisitCastExpr(clang::CastExpr *CE) {
244  if (CE->getCastKind() == clang::CK_BitCast) {
245    clang::QualType QT = CE->getType();
246    const clang::Type *T = QT.getTypePtr();
247    if (T->isVectorType()) {
248      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
249      if (llvm::isa<clang::ImplicitCastExpr>(CE)) {
250        DiagEngine.Report(
251          clang::FullSourceLoc(CE->getExprLoc(),
252                               DiagEngine.getSourceManager()),
253          DiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
254                                     "invalid implicit vector cast"));
255      } else {
256        DiagEngine.Report(
257          clang::FullSourceLoc(CE->getExprLoc(),
258                               DiagEngine.getSourceManager()),
259          DiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
260                                     "invalid vector cast"));
261      }
262      mValid = false;
263    }
264  }
265  Visit(CE->getSubExpr());
266}
267
268
269void RSCheckAST::VisitExpr(clang::Expr *E) {
270  // This is where FS checks for code using pointer and/or 64-bit expressions
271  // (i.e. things like casts).
272
273  // First we skip implicit casts (things like function calls and explicit
274  // array accesses rely heavily on them and they are valid.
275  E = E->IgnoreImpCasts();
276  if (mIsFilterscript &&
277      !SlangRS::IsLocInRSHeaderFile(E->getExprLoc(), mSM) &&
278      !RSExportType::ValidateType(C, E->getType(), NULL, E->getExprLoc(),
279                                  mTargetAPI, mIsFilterscript)) {
280    mValid = false;
281  } else {
282    // Only visit sub-expressions if we haven't already seen a violation.
283    VisitStmt(E);
284  }
285}
286
287
288bool RSCheckAST::Validate() {
289  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
290  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
291          DE = TUDecl->decls_end();
292       DI != DE;
293       DI++) {
294    if (!SlangRS::IsLocInRSHeaderFile(DI->getLocStart(), mSM)) {
295      if (clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI)) {
296        ValidateVarDecl(VD);
297      } else if (clang::FunctionDecl *FD =
298            llvm::dyn_cast<clang::FunctionDecl>(*DI)) {
299        ValidateFunctionDecl(FD);
300      } else if (clang::Stmt *Body = (*DI)->getBody()) {
301        Visit(Body);
302      }
303    }
304  }
305
306  return mValid;
307}
308
309}  // namespace slang
310