slang_rs_check_ast.cpp revision c0c5dd85f2d2df2bcf0cb284001f544d6c42eff9
1/*
2 * Copyright 2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "clang/AST/Attr.h"
18
19#include "slang_rs_check_ast.h"
20
21#include "slang_assert.h"
22#include "slang.h"
23#include "slang_rs_export_foreach.h"
24#include "slang_rs_export_reduce.h"
25#include "slang_rs_export_type.h"
26
27namespace slang {
28
29void RSCheckAST::VisitStmt(clang::Stmt *S) {
30  // This function does the actual iteration through all sub-Stmt's within
31  // a given Stmt. Note that this function is skipped by all of the other
32  // Visit* functions if we have already found a higher-level match.
33  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
34       I != E;
35       I++) {
36    if (clang::Stmt *Child = *I) {
37      Visit(Child);
38    }
39  }
40}
41
42void RSCheckAST::WarnOnSetElementAt(clang::CallExpr *E) {
43  clang::FunctionDecl *Decl;
44  Decl = clang::dyn_cast_or_null<clang::FunctionDecl>(E->getCalleeDecl());
45
46  if (!Decl || Decl->getNameAsString() != std::string("rsSetElementAt")) {
47    return;
48  }
49
50  clang::Expr *Expr;
51  clang::ImplicitCastExpr *ImplCast;
52  Expr = E->getArg(1);
53  ImplCast = clang::dyn_cast_or_null<clang::ImplicitCastExpr>(Expr);
54
55  if (!ImplCast) {
56    return;
57  }
58
59  const clang::Type *Ty;
60  const clang::VectorType *VectorTy;
61  const clang::BuiltinType *ElementTy;
62  Ty = ImplCast->getSubExpr()->getType()->getPointeeType()
63    ->getUnqualifiedDesugaredType();
64  VectorTy = clang::dyn_cast_or_null<clang::VectorType>(Ty);
65
66  if (VectorTy) {
67    ElementTy = clang::dyn_cast_or_null<clang::BuiltinType>(
68      VectorTy->getElementType()->getUnqualifiedDesugaredType());
69  } else {
70    ElementTy = clang::dyn_cast_or_null<clang::BuiltinType>(
71      Ty->getUnqualifiedDesugaredType());
72  }
73
74  if (!ElementTy) {
75    return;
76  }
77
78  // We only support vectors with 2, 3 or 4 elements.
79  if (VectorTy) {
80    switch (VectorTy->getNumElements()) {
81    default:
82      return;
83    case 2:
84    case 3:
85    case 4:
86      break;
87    }
88  }
89
90  const char *Name;
91
92  switch (ElementTy->getKind()) {
93    case clang::BuiltinType::Float:
94      Name = "float";
95      break;
96    case clang::BuiltinType::Double:
97      Name = "double";
98      break;
99    case clang::BuiltinType::Char_S:
100      Name = "char";
101      break;
102    case clang::BuiltinType::Short:
103      Name = "short";
104      break;
105    case clang::BuiltinType::Int:
106      Name = "int";
107      break;
108    case clang::BuiltinType::Long:
109      Name = "long";
110      break;
111    case clang::BuiltinType::UChar:
112      Name = "uchar";
113      break;
114    case clang::BuiltinType::UShort:
115      Name = "ushort";
116      break;
117    case clang::BuiltinType::UInt:
118      Name = "uint";
119      break;
120    case clang::BuiltinType::ULong:
121      Name = "ulong";
122      break;
123    default:
124      return;
125  }
126
127  clang::DiagnosticBuilder DiagBuilder =
128      Context->ReportWarning(E->getLocStart(),
129                             "untyped rsSetElementAt() can reduce performance. "
130                             "Use rsSetElementAt_%0%1() instead.");
131  DiagBuilder << Name;
132
133  if (VectorTy) {
134    DiagBuilder << VectorTy->getNumElements();
135  } else {
136    DiagBuilder << "";
137  }
138}
139
140void RSCheckAST::VisitCallExpr(clang::CallExpr *E) {
141  WarnOnSetElementAt(E);
142
143  for (clang::CallExpr::arg_iterator AI = E->arg_begin(), AE = E->arg_end();
144       AI != AE; ++AI) {
145    Visit(*AI);
146  }
147}
148
149void RSCheckAST::ValidateFunctionDecl(clang::FunctionDecl *FD) {
150  if (!FD) {
151    return;
152  }
153
154  if (FD->hasAttr<clang::KernelAttr>()) {
155    // Validate that the kernel attribute is not used with static.
156    if (FD->getStorageClass() == clang::SC_Static) {
157      Context->ReportError(FD->getLocation(),
158                           "Invalid use of attribute kernel with "
159                           "static function declaration: %0")
160        << FD->getName();
161      mValid = false;
162    }
163
164    // We allow no arguments to the attribute, or an expected single
165    // argument. If there is an expected single argument, we verify
166    // that it is one of the recognized kernel kinds.
167    llvm::StringRef KernelKind =
168      FD->getAttr<clang::KernelAttr>()->getKernelKind();
169
170    if (!KernelKind.empty() && !KernelKind.equals("reduce")) {
171      Context->ReportError(FD->getLocation(),
172                           "Unknown kernel attribute argument '%0' "
173                           "in declaration of function '%1'")
174        << KernelKind << FD->getName();
175      mValid = false;
176    }
177  }
178
179  clang::QualType resultType = FD->getReturnType().getCanonicalType();
180  bool isExtern = (FD->getFormalLinkage() == clang::ExternalLinkage);
181
182  // We use FD as our NamedDecl in the case of a bad return type.
183  if (!RSExportType::ValidateType(Context, C, resultType, FD,
184                                  FD->getLocStart(), mTargetAPI,
185                                  mIsFilterscript, isExtern)) {
186    mValid = false;
187  }
188
189  size_t numParams = FD->getNumParams();
190  for (size_t i = 0; i < numParams; i++) {
191    clang::ParmVarDecl *PVD = FD->getParamDecl(i);
192    clang::QualType QT = PVD->getType().getCanonicalType();
193    if (!RSExportType::ValidateType(Context, C, QT, PVD, PVD->getLocStart(),
194                                    mTargetAPI, mIsFilterscript, isExtern)) {
195      mValid = false;
196    }
197  }
198
199  bool saveKernel = mInKernel;
200  mInKernel = RSExportForEach::isRSForEachFunc(mTargetAPI, FD) ||
201              RSExportReduce::isRSReduceFunc(mTargetAPI, FD);
202
203  if (clang::Stmt *Body = FD->getBody()) {
204    Visit(Body);
205  }
206
207  mInKernel = saveKernel;
208}
209
210
211void RSCheckAST::ValidateVarDecl(clang::VarDecl *VD) {
212  if (!VD) {
213    return;
214  }
215
216  clang::QualType QT = VD->getType();
217
218  if (VD->getFormalLinkage() == clang::ExternalLinkage) {
219    llvm::StringRef TypeName;
220    const clang::Type *T = QT.getTypePtr();
221    if (!RSExportType::NormalizeType(T, TypeName, Context, VD)) {
222      mValid = false;
223    }
224  }
225
226  // We don't allow static (non-const) variables within kernels.
227  if (mInKernel && VD->isStaticLocal()) {
228    if (!QT.isConstQualified()) {
229      Context->ReportError(
230          VD->getLocation(),
231          "Non-const static variables are not allowed in kernels: '%0'")
232          << VD->getName();
233      mValid = false;
234    }
235  }
236
237  if (!RSExportType::ValidateVarDecl(Context, VD, mTargetAPI, mIsFilterscript)) {
238    mValid = false;
239  } else if (clang::Expr *Init = VD->getInit()) {
240    // Only check the initializer if the decl is already ok.
241    Visit(Init);
242  }
243}
244
245
246void RSCheckAST::VisitDeclStmt(clang::DeclStmt *DS) {
247  if (!Slang::IsLocInRSHeaderFile(DS->getLocStart(), mSM)) {
248    for (clang::DeclStmt::decl_iterator I = DS->decl_begin(),
249                                        E = DS->decl_end();
250         I != E;
251         ++I) {
252      if (clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I)) {
253        ValidateVarDecl(VD);
254      } else if (clang::FunctionDecl *FD =
255            llvm::dyn_cast<clang::FunctionDecl>(*I)) {
256        ValidateFunctionDecl(FD);
257      }
258    }
259  }
260}
261
262
263void RSCheckAST::VisitCastExpr(clang::CastExpr *CE) {
264  if (CE->getCastKind() == clang::CK_BitCast) {
265    clang::QualType QT = CE->getType();
266    const clang::Type *T = QT.getTypePtr();
267    if (T->isVectorType()) {
268      if (llvm::isa<clang::ImplicitCastExpr>(CE)) {
269        Context->ReportError(CE->getExprLoc(), "invalid implicit vector cast");
270      } else {
271        Context->ReportError(CE->getExprLoc(), "invalid vector cast");
272      }
273      mValid = false;
274    }
275  }
276  Visit(CE->getSubExpr());
277}
278
279
280void RSCheckAST::VisitExpr(clang::Expr *E) {
281  // This is where FS checks for code using pointer and/or 64-bit expressions
282  // (i.e. things like casts).
283
284  // First we skip implicit casts (things like function calls and explicit
285  // array accesses rely heavily on them and they are valid.
286  E = E->IgnoreImpCasts();
287
288  // Expressions at this point in the checker are not externally visible.
289  static const bool kIsExtern = false;
290
291  if (mIsFilterscript &&
292      !Slang::IsLocInRSHeaderFile(E->getExprLoc(), mSM) &&
293      !RSExportType::ValidateType(Context, C, E->getType(), nullptr, E->getExprLoc(),
294                                  mTargetAPI, mIsFilterscript, kIsExtern)) {
295    mValid = false;
296  } else {
297    // Only visit sub-expressions if we haven't already seen a violation.
298    VisitStmt(E);
299  }
300}
301
302
303bool RSCheckAST::Validate() {
304  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
305  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
306          DE = TUDecl->decls_end();
307       DI != DE;
308       DI++) {
309    if (!Slang::IsLocInRSHeaderFile(DI->getLocStart(), mSM)) {
310      if (clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI)) {
311        ValidateVarDecl(VD);
312      } else if (clang::FunctionDecl *FD =
313            llvm::dyn_cast<clang::FunctionDecl>(*DI)) {
314        ValidateFunctionDecl(FD);
315      } else if (clang::Stmt *Body = (*DI)->getBody()) {
316        Visit(Body);
317      }
318    }
319  }
320
321  return mValid;
322}
323
324}  // namespace slang
325