slang_rs_check_ast.cpp revision 44f10063c2c08dab103a44cded0c3a288d65d43b
1/*
2 * Copyright 2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_check_ast.h"
18
19#include "slang_assert.h"
20#include "slang_rs.h"
21#include "slang_rs_export_foreach.h"
22#include "slang_rs_export_type.h"
23
24namespace slang {
25
26void RSCheckAST::VisitStmt(clang::Stmt *S) {
27  // This function does the actual iteration through all sub-Stmt's within
28  // a given Stmt. Note that this function is skipped by all of the other
29  // Visit* functions if we have already found a higher-level match.
30  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
31       I != E;
32       I++) {
33    if (clang::Stmt *Child = *I) {
34      Visit(Child);
35    }
36  }
37}
38
39void RSCheckAST::ValidateFunctionDecl(clang::FunctionDecl *FD) {
40  if (!FD) {
41    return;
42  }
43
44  if (mIsFilterscript) {
45    // Validate parameters for Filterscript.
46    size_t numParams = FD->getNumParams();
47
48    clang::QualType resultType = FD->getResultType().getCanonicalType();
49
50    // We use FD as our NamedDecl in the case of a bad return type.
51    if (!RSExportType::ValidateType(C, resultType, FD,
52                                    FD->getLocStart(), mTargetAPI,
53                                    mIsFilterscript)) {
54      mValid = false;
55    }
56
57    for (size_t i = 0; i < numParams; i++) {
58      clang::ParmVarDecl *PVD = FD->getParamDecl(i);
59      clang::QualType QT = PVD->getType().getCanonicalType();
60      if (!RSExportType::ValidateType(C, QT, PVD, PVD->getLocStart(),
61                                      mTargetAPI, mIsFilterscript)) {
62        mValid = false;
63      }
64    }
65  }
66
67  bool saveKernel = mInKernel;
68  mInKernel = RSExportForEach::isRSForEachFunc(mTargetAPI, &mDiagEngine, FD);
69
70  if (clang::Stmt *Body = FD->getBody()) {
71    Visit(Body);
72  }
73
74  mInKernel = saveKernel;
75}
76
77
78void RSCheckAST::ValidateVarDecl(clang::VarDecl *VD) {
79  if (!VD) {
80    return;
81  }
82
83  clang::QualType QT = VD->getType();
84
85  if (VD->getFormalLinkage() == clang::ExternalLinkage) {
86    llvm::StringRef TypeName;
87    const clang::Type *T = QT.getTypePtr();
88    if (!RSExportType::NormalizeType(T, TypeName, &mDiagEngine, VD)) {
89      mValid = false;
90    }
91  }
92
93  // We don't allow static (non-const) variables within kernels.
94  if (mInKernel && VD->isStaticLocal()) {
95    if (!QT.isConstQualified()) {
96      mDiagEngine.Report(
97        clang::FullSourceLoc(VD->getLocation(), mSM),
98        mDiagEngine.getCustomDiagID(
99          clang::DiagnosticsEngine::Error,
100          "Non-const static variables are not allowed in kernels: '%0'"))
101          << VD->getName();
102      mValid = false;
103    }
104  }
105
106  if (!RSExportType::ValidateVarDecl(VD, mTargetAPI, mIsFilterscript)) {
107    mValid = false;
108  } else if (clang::Expr *Init = VD->getInit()) {
109    // Only check the initializer if the decl is already ok.
110    Visit(Init);
111  }
112}
113
114
115void RSCheckAST::VisitDeclStmt(clang::DeclStmt *DS) {
116  if (!SlangRS::IsLocInRSHeaderFile(DS->getLocStart(), mSM)) {
117    for (clang::DeclStmt::decl_iterator I = DS->decl_begin(),
118                                        E = DS->decl_end();
119         I != E;
120         ++I) {
121      if (clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I)) {
122        ValidateVarDecl(VD);
123      } else if (clang::FunctionDecl *FD =
124            llvm::dyn_cast<clang::FunctionDecl>(*I)) {
125        ValidateFunctionDecl(FD);
126      }
127    }
128  }
129}
130
131
132void RSCheckAST::VisitCastExpr(clang::CastExpr *CE) {
133  if (CE->getCastKind() == clang::CK_BitCast) {
134    clang::QualType QT = CE->getType();
135    const clang::Type *T = QT.getTypePtr();
136    if (T->isVectorType()) {
137      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
138      if (llvm::isa<clang::ImplicitCastExpr>(CE)) {
139        DiagEngine.Report(
140          clang::FullSourceLoc(CE->getExprLoc(),
141                               DiagEngine.getSourceManager()),
142          DiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
143                                     "invalid implicit vector cast"));
144      } else {
145        DiagEngine.Report(
146          clang::FullSourceLoc(CE->getExprLoc(),
147                               DiagEngine.getSourceManager()),
148          DiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
149                                     "invalid vector cast"));
150      }
151      mValid = false;
152    }
153  }
154  Visit(CE->getSubExpr());
155}
156
157
158void RSCheckAST::VisitExpr(clang::Expr *E) {
159  // This is where FS checks for code using pointer and/or 64-bit expressions
160  // (i.e. things like casts).
161
162  // First we skip implicit casts (things like function calls and explicit
163  // array accesses rely heavily on them and they are valid.
164  E = E->IgnoreImpCasts();
165  if (mIsFilterscript &&
166      !SlangRS::IsLocInRSHeaderFile(E->getExprLoc(), mSM) &&
167      !RSExportType::ValidateType(C, E->getType(), NULL, E->getExprLoc(),
168                                  mTargetAPI, mIsFilterscript)) {
169    mValid = false;
170  } else {
171    // Only visit sub-expressions if we haven't already seen a violation.
172    VisitStmt(E);
173  }
174}
175
176
177bool RSCheckAST::Validate() {
178  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
179  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
180          DE = TUDecl->decls_end();
181       DI != DE;
182       DI++) {
183    if (!SlangRS::IsLocInRSHeaderFile(DI->getLocStart(), mSM)) {
184      if (clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI)) {
185        ValidateVarDecl(VD);
186      } else if (clang::FunctionDecl *FD =
187            llvm::dyn_cast<clang::FunctionDecl>(*DI)) {
188        ValidateFunctionDecl(FD);
189      } else if (clang::Stmt *Body = (*DI)->getBody()) {
190        Visit(Body);
191      }
192    }
193  }
194
195  return mValid;
196}
197
198}  // namespace slang
199