slang_rs_object_ref_count.cpp revision 579e4f481774275980d6e3fbb7978c674da6d302
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_object_ref_count.h"
18
19#include "clang/AST/DeclGroup.h"
20#include "clang/AST/Expr.h"
21#include "clang/AST/NestedNameSpecifier.h"
22#include "clang/AST/OperationKinds.h"
23#include "clang/AST/Stmt.h"
24#include "clang/AST/StmtVisitor.h"
25
26#include "slang_assert.h"
27#include "slang.h"
28#include "slang_rs_ast_replace.h"
29#include "slang_rs_export_type.h"
30
31namespace slang {
32
33/* Even though those two arrays are of size DataTypeMax, only entries that
34 * correspond to object types will be set.
35 */
36clang::FunctionDecl *
37RSObjectRefCount::RSSetObjectFD[DataTypeMax];
38clang::FunctionDecl *
39RSObjectRefCount::RSClearObjectFD[DataTypeMax];
40
41void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
42  for (unsigned i = 0; i < DataTypeMax; i++) {
43    RSSetObjectFD[i] = nullptr;
44    RSClearObjectFD[i] = nullptr;
45  }
46
47  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
48
49  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
50          E = TUDecl->decls_end(); I != E; I++) {
51    if ((I->getKind() >= clang::Decl::firstFunction) &&
52        (I->getKind() <= clang::Decl::lastFunction)) {
53      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
54
55      // points to RSSetObjectFD or RSClearObjectFD
56      clang::FunctionDecl **RSObjectFD;
57
58      if (FD->getName() == "rsSetObject") {
59        slangAssert((FD->getNumParams() == 2) &&
60                    "Invalid rsSetObject function prototype (# params)");
61        RSObjectFD = RSSetObjectFD;
62      } else if (FD->getName() == "rsClearObject") {
63        slangAssert((FD->getNumParams() == 1) &&
64                    "Invalid rsClearObject function prototype (# params)");
65        RSObjectFD = RSClearObjectFD;
66      } else {
67        continue;
68      }
69
70      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
71      clang::QualType PVT = PVD->getOriginalType();
72      // The first parameter must be a pointer like rs_allocation*
73      slangAssert(PVT->isPointerType() &&
74          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
75
76      // The rs object type passed to the FD
77      clang::QualType RST = PVT->getPointeeType();
78      DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
79      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
80             && "must be RS object type");
81
82      if (DT >= 0 && DT < DataTypeMax) {
83          RSObjectFD[DT] = FD;
84      } else {
85          slangAssert(false && "incorrect type");
86      }
87    }
88  }
89}
90
91namespace {
92
// Forward declarations for helpers whose definitions appear later in this
// file.
unsigned CountRSObjectTypes(const clang::Type *T);

clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
                                     clang::Expr *DstExpr,
                                     clang::Expr *SrcExpr,
                                     clang::SourceLocation StartLoc,
                                     clang::SourceLocation Loc);
100
101// This function constructs a new CompoundStmt from the input StmtList.
102clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
103      std::vector<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
104  unsigned NewStmtCount = StmtList.size();
105  unsigned CompoundStmtCount = 0;
106
107  clang::Stmt **CompoundStmtList;
108  CompoundStmtList = new clang::Stmt*[NewStmtCount];
109
110  std::vector<clang::Stmt*>::const_iterator I = StmtList.begin();
111  std::vector<clang::Stmt*>::const_iterator E = StmtList.end();
112  for ( ; I != E; I++) {
113    CompoundStmtList[CompoundStmtCount++] = *I;
114  }
115  slangAssert(CompoundStmtCount == NewStmtCount);
116
117  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
118      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
119
120  delete [] CompoundStmtList;
121
122  return CS;
123}
124
125void AppendAfterStmt(clang::ASTContext &C,
126                     clang::CompoundStmt *CS,
127                     clang::Stmt *S,
128                     std::list<clang::Stmt*> &StmtList) {
129  slangAssert(CS);
130  clang::CompoundStmt::body_iterator bI = CS->body_begin();
131  clang::CompoundStmt::body_iterator bE = CS->body_end();
132  clang::Stmt **UpdatedStmtList =
133      new clang::Stmt*[CS->size() + StmtList.size()];
134
135  unsigned UpdatedStmtCount = 0;
136  unsigned Once = 0;
137  for ( ; bI != bE; bI++) {
138    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
139      // If we come across a return here, we don't have anything we can
140      // reasonably replace. We should have already inserted our destructor
141      // code in the proper spot, so we just clean up and return.
142      delete [] UpdatedStmtList;
143
144      return;
145    }
146
147    UpdatedStmtList[UpdatedStmtCount++] = *bI;
148
149    if ((*bI == S) && !Once) {
150      Once++;
151      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
152      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
153      for ( ; I != E; I++) {
154        UpdatedStmtList[UpdatedStmtCount++] = *I;
155      }
156    }
157  }
158  slangAssert(Once <= 1);
159
160  // When S is nullptr, we are appending to the end of the CompoundStmt.
161  if (!S) {
162    slangAssert(Once == 0);
163    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
164    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
165    for ( ; I != E; I++) {
166      UpdatedStmtList[UpdatedStmtCount++] = *I;
167    }
168  }
169
170  CS->setStmts(C, llvm::makeArrayRef(UpdatedStmtList, UpdatedStmtCount));
171
172  delete [] UpdatedStmtList;
173}
174
175// This class visits a compound statement and collects a list of all the exiting
176// statements, such as any return statement in any sub-block, and any
177// break/continue statement that would resume outside the current scope.
178// We do not handle the case for goto statements that leave a local scope.
179class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
180 private:
181  // The loop depth of the currently visited node.
182  int mLoopDepth;
183
184  // The switch statement depth of the currently visited node.
185  // Note that this is tracked separately from the loop depth because
186  // SwitchStmt-contained ContinueStmt's should have destructors for the
187  // corresponding loop scope.
188  int mSwitchDepth;
189
190  // Output of the visitor: the statements that should be replaced by compound
191  // statements, each of which contains rsClearObject() calls followed by the
192  // original statement.
193  std::vector<clang::Stmt*> mExitingStmts;
194
195 public:
196  DestructorVisitor() : mLoopDepth(0), mSwitchDepth(0) {}
197
198  const std::vector<clang::Stmt*>& getExitingStmts() const {
199    return mExitingStmts;
200  }
201
202  void VisitStmt(clang::Stmt *S);
203  void VisitBreakStmt(clang::BreakStmt *BS);
204  void VisitContinueStmt(clang::ContinueStmt *CS);
205  void VisitDoStmt(clang::DoStmt *DS);
206  void VisitForStmt(clang::ForStmt *FS);
207  void VisitReturnStmt(clang::ReturnStmt *RS);
208  void VisitSwitchStmt(clang::SwitchStmt *SS);
209  void VisitWhileStmt(clang::WhileStmt *WS);
210};
211
212void DestructorVisitor::VisitStmt(clang::Stmt *S) {
213  for (clang::Stmt* Child : S->children()) {
214    if (Child) {
215      Visit(Child);
216    }
217  }
218}
219
220void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
221  VisitStmt(BS);
222  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
223    mExitingStmts.push_back(BS);
224  }
225}
226
227void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
228  VisitStmt(CS);
229  if (mLoopDepth == 0) {
230    // Switch statements can have nested continues.
231    mExitingStmts.push_back(CS);
232  }
233}
234
235void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
236  mLoopDepth++;
237  VisitStmt(DS);
238  mLoopDepth--;
239}
240
241void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
242  mLoopDepth++;
243  VisitStmt(FS);
244  mLoopDepth--;
245}
246
247void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
248  mExitingStmts.push_back(RS);
249}
250
251void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
252  mSwitchDepth++;
253  VisitStmt(SS);
254  mSwitchDepth--;
255}
256
257void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
258  mLoopDepth++;
259  VisitStmt(WS);
260  mLoopDepth--;
261}
262
263clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
264                                 clang::Expr *RefRSVar,
265                                 clang::SourceLocation Loc) {
266  slangAssert(RefRSVar);
267  const clang::Type *T = RefRSVar->getType().getTypePtr();
268  slangAssert(!T->isArrayType() &&
269              "Should not be destroying arrays with this function");
270
271  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
272  slangAssert((ClearObjectFD != nullptr) &&
273              "rsClearObject doesn't cover all RS object types");
274
275  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
276  clang::QualType ClearObjectFDArgType =
277      ClearObjectFD->getParamDecl(0)->getOriginalType();
278
279  // Example destructor for "rs_font localFont;"
280  //
281  // (CallExpr 'void'
282  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
283  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
284  //   (UnaryOperator 'rs_font *' prefix '&'
285  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
286
287  // Get address of targeted RS object
288  clang::Expr *AddrRefRSVar =
289      new(C) clang::UnaryOperator(RefRSVar,
290                                  clang::UO_AddrOf,
291                                  ClearObjectFDArgType,
292                                  clang::VK_RValue,
293                                  clang::OK_Ordinary,
294                                  Loc);
295
296  clang::Expr *RefRSClearObjectFD =
297      clang::DeclRefExpr::Create(C,
298                                 clang::NestedNameSpecifierLoc(),
299                                 clang::SourceLocation(),
300                                 ClearObjectFD,
301                                 false,
302                                 ClearObjectFD->getLocation(),
303                                 ClearObjectFDType,
304                                 clang::VK_RValue,
305                                 nullptr);
306
307  clang::Expr *RSClearObjectFP =
308      clang::ImplicitCastExpr::Create(C,
309                                      C.getPointerType(ClearObjectFDType),
310                                      clang::CK_FunctionToPointerDecay,
311                                      RefRSClearObjectFD,
312                                      nullptr,
313                                      clang::VK_RValue);
314
315  llvm::SmallVector<clang::Expr*, 1> ArgList;
316  ArgList.push_back(AddrRefRSVar);
317
318  clang::CallExpr *RSClearObjectCall =
319      new(C) clang::CallExpr(C,
320                             RSClearObjectFP,
321                             ArgList,
322                             ClearObjectFD->getCallResultType(),
323                             clang::VK_RValue,
324                             Loc);
325
326  return RSClearObjectCall;
327}
328
329static int ArrayDim(const clang::Type *T) {
330  if (!T || !T->isArrayType()) {
331    return 0;
332  }
333
334  const clang::ConstantArrayType *CAT =
335    static_cast<const clang::ConstantArrayType *>(T);
336  return static_cast<int>(CAT->getSize().getSExtValue());
337}
338
// Forward declaration: ClearArrayRSObject (below) calls this for array
// elements whose DataType is DataTypeUnknown, but the definition comes after
// it in this file.
clang::Stmt *ClearStructRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSStruct,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc);
345
// Builds a destructor loop for the array expression RefRSArr.
//
// The result is a ForStmt over the outermost dimension. Its iterator is a
// fresh int named "rsIntIter<N>", where N comes from a function-local counter
// so that every generated loop, including nested ones, gets a distinct name.
// The loop body depends on the element type:
//   - element is itself an array      -> recursive ClearArrayRSObject() loop;
//   - element DataType is Unknown     -> ClearStructRSObject() (e.g. structs);
//   - otherwise                       -> a single rsClearObject() call.
// Returns nullptr (no loop) when ArrayDim() does not report a positive
// element count.
//
// DC becomes the DeclContext of the iterator VarDecl. StartLoc is the
// iterator VarDecl's start location. Loc is used for the generated nodes.
clang::Stmt *ClearArrayRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSArr,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc) {
  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
  slangAssert(BaseType->isArrayType());

  int NumArrayElements = ArrayDim(BaseType);
  // Actually extract out the base RS object type for use later. This strips
  // only one array level; for multi-dimensional arrays BaseType is still an
  // array type and is handled by recursion below.
  BaseType = BaseType->getArrayElementTypeNoTypeQual();

  // Nothing to destroy when the array has no positive constant size.
  if (NumArrayElements <= 0) {
    return nullptr;
  }

  // Example destructor loop for "rs_font fontArr[10];"
  //
  // (ForStmt
  //   (DeclStmt
  //     (VarDecl used rsIntIter 'int' cinit
  //       (IntegerLiteral 'int' 0)))
  //   (BinaryOperator 'int' '<'
  //     (ImplicitCastExpr int LValueToRValue
  //       (DeclRefExpr 'int' Var='rsIntIter'))
  //     (IntegerLiteral 'int' 10)
  //   nullptr << CondVar >>
  //   (UnaryOperator 'int' postfix '++'
  //     (DeclRefExpr 'int' Var='rsIntIter'))
  //   (CallExpr 'void'
  //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
  //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
  //     (UnaryOperator 'rs_font *' prefix '&'
  //       (ArraySubscriptExpr 'rs_font':'rs_font'
  //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
  //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
  //         (DeclRefExpr 'int' Var='rsIntIter'))))))

  // Create helper variable for iterating through elements.
  // NOTE(review): sIterCounter is a plain function-local static incremented
  // without synchronization; concurrent callers would race on it.
  static unsigned sIterCounter = 0;
  std::stringstream UniqueIterName;
  UniqueIterName << "rsIntIter" << sIterCounter++;
  clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
  clang::VarDecl *IIVD =
      clang::VarDecl::Create(C,
                             DC,
                             StartLoc,
                             Loc,
                             II,
                             C.IntTy,
                             C.getTrivialTypeSourceInfo(C.IntTy),
                             clang::SC_None);
  // Mark "rsIntIter" as used
  IIVD->markUsed(C);

  // Form the actual destructor loop
  // for (Init; Cond; Inc)
  //   RSClearObjectCall;

  // Init -> "int rsIntIter = 0"
  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
  IIVD->setInit(Int0);

  // Wrap the single VarDecl in a DeclStmt so it can serve as the for-init.
  clang::Decl *IID = (clang::Decl *)IIVD;
  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
  clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);

  // Cond -> "rsIntIter < NumArrayElements"
  // NOTE(review): this one DeclRefExpr node is used both under the
  // LValueToRValue cast below and as the operand of the increment, so it has
  // two parents in the generated AST.
  clang::DeclRefExpr *RefrsIntIterLValue =
      clang::DeclRefExpr::Create(C,
                                 clang::NestedNameSpecifierLoc(),
                                 clang::SourceLocation(),
                                 IIVD,
                                 false,
                                 Loc,
                                 C.IntTy,
                                 clang::VK_LValue,
                                 nullptr);

  // NOTE(review): likewise shared by the loop condition and the subscript
  // index of the loop body.
  clang::Expr *RefrsIntIterRValue =
      clang::ImplicitCastExpr::Create(C,
                                      RefrsIntIterLValue->getType(),
                                      clang::CK_LValueToRValue,
                                      RefrsIntIterLValue,
                                      nullptr,
                                      clang::VK_RValue);

  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);

  clang::BinaryOperator *Cond =
      new(C) clang::BinaryOperator(RefrsIntIterRValue,
                                   NumArrayElementsExpr,
                                   clang::BO_LT,
                                   C.IntTy,
                                   clang::VK_RValue,
                                   clang::OK_Ordinary,
                                   Loc,
                                   false);

  // Inc -> "rsIntIter++"
  clang::UnaryOperator *Inc =
      new(C) clang::UnaryOperator(RefrsIntIterLValue,
                                  clang::UO_PostInc,
                                  C.IntTy,
                                  clang::VK_RValue,
                                  clang::OK_Ordinary,
                                  Loc);

  // Body -> "rsClearObject(&VD[rsIntIter]);"
  // Destructor loop operates on individual array elements

  clang::Expr *RefRSArrPtr =
      clang::ImplicitCastExpr::Create(C,
          C.getPointerType(BaseType->getCanonicalTypeInternal()),
          clang::CK_ArrayToPointerDecay,
          RefRSArr,
          nullptr,
          clang::VK_RValue);

  // NOTE(review): the element access is built as VK_RValue although the
  // single-object path (ClearSingleRSObject) takes its address with '&';
  // presumably downstream CodeGen tolerates this. Confirm before changing.
  clang::Expr *RefRSArrPtrSubscript =
      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
                                       RefrsIntIterRValue,
                                       BaseType->getCanonicalTypeInternal(),
                                       clang::VK_RValue,
                                       clang::OK_Ordinary,
                                       Loc);

  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);

  // Choose how to destroy one element: nested array, struct-like (unknown
  // DataType), or a single RS object.
  clang::Stmt *RSClearObjectCall = nullptr;
  if (BaseType->isArrayType()) {
    RSClearObjectCall =
        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
  } else if (DT == DataTypeUnknown) {
    RSClearObjectCall =
        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
  } else {
    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
  }

  clang::ForStmt *DestructorLoop =
      new(C) clang::ForStmt(C,
                            Init,
                            Cond,
                            nullptr,  // no condVar
                            Inc,
                            RSClearObjectCall,
                            Loc,
                            Loc,
                            Loc);

  return DestructorLoop;
}
502
503unsigned CountRSObjectTypes(const clang::Type *T) {
504  slangAssert(T);
505  unsigned RSObjectCount = 0;
506
507  if (T->isArrayType()) {
508    return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
509  }
510
511  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
512  if (DT != DataTypeUnknown) {
513    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
514  }
515
516  if (T->isUnionType()) {
517    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
518    RD = RD->getDefinition();
519    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
520           FE = RD->field_end();
521         FI != FE;
522         FI++) {
523      const clang::FieldDecl *FD = *FI;
524      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
525      if (CountRSObjectTypes(FT)) {
526        slangAssert(false && "can't have unions with RS object types!");
527        return 0;
528      }
529    }
530  }
531
532  if (!T->isStructureType()) {
533    return 0;
534  }
535
536  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
537  RD = RD->getDefinition();
538  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
539         FE = RD->field_end();
540       FI != FE;
541       FI++) {
542    const clang::FieldDecl *FD = *FI;
543    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
544    if (CountRSObjectTypes(FT)) {
545      // Sub-structs should only count once (as should arrays, etc.)
546      RSObjectCount++;
547    }
548  }
549
550  return RSObjectCount;
551}
552
553clang::Stmt *ClearStructRSObject(
554    clang::ASTContext &C,
555    clang::DeclContext *DC,
556    clang::Expr *RefRSStruct,
557    clang::SourceLocation StartLoc,
558    clang::SourceLocation Loc) {
559  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
560
561  slangAssert(!BaseType->isArrayType());
562
563  // Structs should show up as unknown primitive types
564  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
565              DataTypeUnknown);
566
567  unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
568  slangAssert(FieldsToDestroy != 0);
569
570  unsigned StmtCount = 0;
571  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
572  for (unsigned i = 0; i < FieldsToDestroy; i++) {
573    StmtArray[i] = nullptr;
574  }
575
576  // Populate StmtArray by creating a destructor for each RS object field
577  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
578  RD = RD->getDefinition();
579  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
580         FE = RD->field_end();
581       FI != FE;
582       FI++) {
583    // We just look through all field declarations to see if we find a
584    // declaration for an RS object type (or an array of one).
585    bool IsArrayType = false;
586    clang::FieldDecl *FD = *FI;
587    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
588    const clang::Type *OrigType = FT;
589    while (FT && FT->isArrayType()) {
590      FT = FT->getArrayElementTypeNoTypeQual();
591      IsArrayType = true;
592    }
593
594    // Pass a DeclarationNameInfo with a valid DeclName, since name equality
595    // gets asserted during CodeGen.
596    clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
597                                              FD->getLocation());
598
599    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
600      clang::DeclAccessPair FoundDecl =
601          clang::DeclAccessPair::make(FD, clang::AS_none);
602      clang::MemberExpr *RSObjectMember =
603          clang::MemberExpr::Create(C,
604                                    RefRSStruct,
605                                    false,
606                                    clang::SourceLocation(),
607                                    clang::NestedNameSpecifierLoc(),
608                                    clang::SourceLocation(),
609                                    FD,
610                                    FoundDecl,
611                                    FDDeclNameInfo,
612                                    nullptr,
613                                    OrigType->getCanonicalTypeInternal(),
614                                    clang::VK_RValue,
615                                    clang::OK_Ordinary);
616
617      slangAssert(StmtCount < FieldsToDestroy);
618
619      if (IsArrayType) {
620        StmtArray[StmtCount++] = ClearArrayRSObject(C,
621                                                    DC,
622                                                    RSObjectMember,
623                                                    StartLoc,
624                                                    Loc);
625      } else {
626        StmtArray[StmtCount++] = ClearSingleRSObject(C,
627                                                     RSObjectMember,
628                                                     Loc);
629      }
630    } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
631      // In this case, we have a nested struct. We may not end up filling all
632      // of the spaces in StmtArray (sub-structs should handle themselves
633      // with separate compound statements).
634      clang::DeclAccessPair FoundDecl =
635          clang::DeclAccessPair::make(FD, clang::AS_none);
636      clang::MemberExpr *RSObjectMember =
637          clang::MemberExpr::Create(C,
638                                    RefRSStruct,
639                                    false,
640                                    clang::SourceLocation(),
641                                    clang::NestedNameSpecifierLoc(),
642                                    clang::SourceLocation(),
643                                    FD,
644                                    FoundDecl,
645                                    clang::DeclarationNameInfo(),
646                                    nullptr,
647                                    OrigType->getCanonicalTypeInternal(),
648                                    clang::VK_RValue,
649                                    clang::OK_Ordinary);
650
651      if (IsArrayType) {
652        StmtArray[StmtCount++] = ClearArrayRSObject(C,
653                                                    DC,
654                                                    RSObjectMember,
655                                                    StartLoc,
656                                                    Loc);
657      } else {
658        StmtArray[StmtCount++] = ClearStructRSObject(C,
659                                                     DC,
660                                                     RSObjectMember,
661                                                     StartLoc,
662                                                     Loc);
663      }
664    }
665  }
666
667  slangAssert(StmtCount > 0);
668  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
669      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
670
671  delete [] StmtArray;
672
673  return CS;
674}
675
676clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
677                                     clang::Expr *DstExpr,
678                                     clang::Expr *SrcExpr,
679                                     clang::SourceLocation StartLoc,
680                                     clang::SourceLocation Loc) {
681  const clang::Type *T = DstExpr->getType().getTypePtr();
682  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
683  slangAssert((SetObjectFD != nullptr) &&
684              "rsSetObject doesn't cover all RS object types");
685
686  clang::QualType SetObjectFDType = SetObjectFD->getType();
687  clang::QualType SetObjectFDArgType[2];
688  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
689  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
690
691  clang::Expr *RefRSSetObjectFD =
692      clang::DeclRefExpr::Create(C,
693                                 clang::NestedNameSpecifierLoc(),
694                                 clang::SourceLocation(),
695                                 SetObjectFD,
696                                 false,
697                                 Loc,
698                                 SetObjectFDType,
699                                 clang::VK_RValue,
700                                 nullptr);
701
702  clang::Expr *RSSetObjectFP =
703      clang::ImplicitCastExpr::Create(C,
704                                      C.getPointerType(SetObjectFDType),
705                                      clang::CK_FunctionToPointerDecay,
706                                      RefRSSetObjectFD,
707                                      nullptr,
708                                      clang::VK_RValue);
709
710  llvm::SmallVector<clang::Expr*, 2> ArgList;
711  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
712                                                clang::UO_AddrOf,
713                                                SetObjectFDArgType[0],
714                                                clang::VK_RValue,
715                                                clang::OK_Ordinary,
716                                                Loc));
717  ArgList.push_back(SrcExpr);
718
719  clang::CallExpr *RSSetObjectCall =
720      new(C) clang::CallExpr(C,
721                             RSSetObjectFP,
722                             ArgList,
723                             SetObjectFD->getCallResultType(),
724                             clang::VK_RValue,
725                             Loc);
726
727  return RSSetObjectCall;
728}
729
// Forward declaration; the definition is not part of this declaration.
// NOTE(review): judging by its name and parameters, this presumably builds
// the rsSetObject() code for a struct-typed assignment LHS = RHS (the struct
// counterpart of CreateSingleRSSetObject). Confirm against its definition.
clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
                                     clang::Expr *LHS,
                                     clang::Expr *RHS,
                                     clang::SourceLocation StartLoc,
                                     clang::SourceLocation Loc);
735
736/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
737                                           clang::Expr *DstArr,
738                                           clang::Expr *SrcArr,
739                                           clang::SourceLocation StartLoc,
740                                           clang::SourceLocation Loc) {
741  clang::DeclContext *DC = nullptr;
742  const clang::Type *BaseType = DstArr->getType().getTypePtr();
743  slangAssert(BaseType->isArrayType());
744
745  int NumArrayElements = ArrayDim(BaseType);
746  // Actually extract out the base RS object type for use later
747  BaseType = BaseType->getArrayElementTypeNoTypeQual();
748
749  clang::Stmt *StmtArray[2] = {nullptr};
750  int StmtCtr = 0;
751
752  if (NumArrayElements <= 0) {
753    return nullptr;
754  }
755
756  // Create helper variable for iterating through elements
757  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
758  clang::VarDecl *IIVD =
759      clang::VarDecl::Create(C,
760                             DC,
761                             StartLoc,
762                             Loc,
763                             &II,
764                             C.IntTy,
765                             C.getTrivialTypeSourceInfo(C.IntTy),
766                             clang::SC_None,
767                             clang::SC_None);
768  clang::Decl *IID = (clang::Decl *)IIVD;
769
770  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
771  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
772
773  // Form the actual loop
774  // for (Init; Cond; Inc)
775  //   RSSetObjectCall;
776
777  // Init -> "rsIntIter = 0"
778  clang::DeclRefExpr *RefrsIntIter =
779      clang::DeclRefExpr::Create(C,
780                                 clang::NestedNameSpecifierLoc(),
781                                 IIVD,
782                                 Loc,
783                                 C.IntTy,
784                                 clang::VK_RValue,
785                                 nullptr);
786
787  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
788      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
789
790  clang::BinaryOperator *Init =
791      new(C) clang::BinaryOperator(RefrsIntIter,
792                                   Int0,
793                                   clang::BO_Assign,
794                                   C.IntTy,
795                                   clang::VK_RValue,
796                                   clang::OK_Ordinary,
797                                   Loc);
798
799  // Cond -> "rsIntIter < NumArrayElements"
800  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
801      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
802
803  clang::BinaryOperator *Cond =
804      new(C) clang::BinaryOperator(RefrsIntIter,
805                                   NumArrayElementsExpr,
806                                   clang::BO_LT,
807                                   C.IntTy,
808                                   clang::VK_RValue,
809                                   clang::OK_Ordinary,
810                                   Loc);
811
812  // Inc -> "rsIntIter++"
813  clang::UnaryOperator *Inc =
814      new(C) clang::UnaryOperator(RefrsIntIter,
815                                  clang::UO_PostInc,
816                                  C.IntTy,
817                                  clang::VK_RValue,
818                                  clang::OK_Ordinary,
819                                  Loc);
820
821  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
822  // Loop operates on individual array elements
823
824  clang::Expr *DstArrPtr =
825      clang::ImplicitCastExpr::Create(C,
826          C.getPointerType(BaseType->getCanonicalTypeInternal()),
827          clang::CK_ArrayToPointerDecay,
828          DstArr,
829          nullptr,
830          clang::VK_RValue);
831
832  clang::Expr *DstArrPtrSubscript =
833      new(C) clang::ArraySubscriptExpr(DstArrPtr,
834                                       RefrsIntIter,
835                                       BaseType->getCanonicalTypeInternal(),
836                                       clang::VK_RValue,
837                                       clang::OK_Ordinary,
838                                       Loc);
839
840  clang::Expr *SrcArrPtr =
841      clang::ImplicitCastExpr::Create(C,
842          C.getPointerType(BaseType->getCanonicalTypeInternal()),
843          clang::CK_ArrayToPointerDecay,
844          SrcArr,
845          nullptr,
846          clang::VK_RValue);
847
848  clang::Expr *SrcArrPtrSubscript =
849      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
850                                       RefrsIntIter,
851                                       BaseType->getCanonicalTypeInternal(),
852                                       clang::VK_RValue,
853                                       clang::OK_Ordinary,
854                                       Loc);
855
856  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
857
858  clang::Stmt *RSSetObjectCall = nullptr;
859  if (BaseType->isArrayType()) {
860    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
861                                             SrcArrPtrSubscript,
862                                             StartLoc, Loc);
863  } else if (DT == DataTypeUnknown) {
864    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
865                                              SrcArrPtrSubscript,
866                                              StartLoc, Loc);
867  } else {
868    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
869                                              SrcArrPtrSubscript,
870                                              StartLoc, Loc);
871  }
872
873  clang::ForStmt *DestructorLoop =
874      new(C) clang::ForStmt(C,
875                            Init,
876                            Cond,
877                            nullptr,  // no condVar
878                            Inc,
879                            RSSetObjectCall,
880                            Loc,
881                            Loc,
882                            Loc);
883
884  StmtArray[StmtCtr++] = DestructorLoop;
885  slangAssert(StmtCtr == 2);
886
887  clang::CompoundStmt *CS =
888      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
889
890  return CS;
891} */
892
893clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
894                                     clang::Expr *LHS,
895                                     clang::Expr *RHS,
896                                     clang::SourceLocation StartLoc,
897                                     clang::SourceLocation Loc) {
898  clang::QualType QT = LHS->getType();
899  const clang::Type *T = QT.getTypePtr();
900  slangAssert(T->isStructureType());
901  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
902
903  // Keep an extra slot for the original copy (memcpy)
904  unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
905
906  unsigned StmtCount = 0;
907  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
908  for (unsigned i = 0; i < FieldsToSet; i++) {
909    StmtArray[i] = nullptr;
910  }
911
912  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
913  RD = RD->getDefinition();
914  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
915         FE = RD->field_end();
916       FI != FE;
917       FI++) {
918    bool IsArrayType = false;
919    clang::FieldDecl *FD = *FI;
920    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
921    const clang::Type *OrigType = FT;
922
923    if (!CountRSObjectTypes(FT)) {
924      // Skip to next if we don't have any viable RS object types
925      continue;
926    }
927
928    clang::DeclAccessPair FoundDecl =
929        clang::DeclAccessPair::make(FD, clang::AS_none);
930    clang::MemberExpr *DstMember =
931        clang::MemberExpr::Create(C,
932                                  LHS,
933                                  false,
934                                  clang::SourceLocation(),
935                                  clang::NestedNameSpecifierLoc(),
936                                  clang::SourceLocation(),
937                                  FD,
938                                  FoundDecl,
939                                  clang::DeclarationNameInfo(
940                                      FD->getDeclName(),
941                                      clang::SourceLocation()),
942                                  nullptr,
943                                  OrigType->getCanonicalTypeInternal(),
944                                  clang::VK_RValue,
945                                  clang::OK_Ordinary);
946
947    clang::MemberExpr *SrcMember =
948        clang::MemberExpr::Create(C,
949                                  RHS,
950                                  false,
951                                  clang::SourceLocation(),
952                                  clang::NestedNameSpecifierLoc(),
953                                  clang::SourceLocation(),
954                                  FD,
955                                  FoundDecl,
956                                  clang::DeclarationNameInfo(
957                                      FD->getDeclName(),
958                                      clang::SourceLocation()),
959                                  nullptr,
960                                  OrigType->getCanonicalTypeInternal(),
961                                  clang::VK_RValue,
962                                  clang::OK_Ordinary);
963
964    if (FT->isArrayType()) {
965      FT = FT->getArrayElementTypeNoTypeQual();
966      IsArrayType = true;
967    }
968
969    DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
970
971    if (IsArrayType) {
972      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
973      DiagEngine.Report(
974        clang::FullSourceLoc(Loc, C.getSourceManager()),
975        DiagEngine.getCustomDiagID(
976          clang::DiagnosticsEngine::Error,
977          "Arrays of RS object types within structures cannot be copied"));
978      // TODO(srhines): Support setting arrays of RS objects
979      // StmtArray[StmtCount++] =
980      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
981    } else if (DT == DataTypeUnknown) {
982      StmtArray[StmtCount++] =
983          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
984    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
985      StmtArray[StmtCount++] =
986          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
987    } else {
988      slangAssert(false);
989    }
990  }
991
992  slangAssert(StmtCount < FieldsToSet);
993
994  // We still need to actually do the overall struct copy. For simplicity,
995  // we just do a straight-up assignment (which will still preserve all
996  // the proper RS object reference counts).
997  clang::BinaryOperator *CopyStruct =
998      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
999                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1000                                   false);
1001  StmtArray[StmtCount++] = CopyStruct;
1002
1003  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1004      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1005
1006  delete [] StmtArray;
1007
1008  return CS;
1009}
1010
1011}  // namespace
1012
1013void RSObjectRefCount::Scope::InsertStmt(const clang::ASTContext &C,
1014                                         clang::Stmt *NewStmt) {
1015  std::vector<clang::Stmt*> newBody;
1016  for (clang::Stmt* S1 : mCS->body()) {
1017    if (S1 == mCurrent) {
1018      newBody.push_back(NewStmt);
1019    }
1020    newBody.push_back(S1);
1021  }
1022  mCS->setStmts(C, newBody);
1023}
1024
1025void RSObjectRefCount::Scope::ReplaceStmt(const clang::ASTContext &C,
1026                                          clang::Stmt *NewStmt) {
1027  std::vector<clang::Stmt*> newBody;
1028  for (clang::Stmt* S1 : mCS->body()) {
1029    if (S1 == mCurrent) {
1030      newBody.push_back(NewStmt);
1031    } else {
1032      newBody.push_back(S1);
1033    }
1034  }
1035  mCS->setStmts(C, newBody);
1036}
1037
1038void RSObjectRefCount::Scope::ReplaceExpr(const clang::ASTContext& C,
1039                                          clang::Expr* OldExpr,
1040                                          clang::Expr* NewExpr) {
1041  RSASTReplace R(C);
1042  R.ReplaceStmt(mCurrent, OldExpr, NewExpr);
1043}
1044
1045void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1046    clang::BinaryOperator *AS) {
1047
1048  clang::QualType QT = AS->getType();
1049
1050  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1051      DataTypeRSAllocation)->getASTContext();
1052
1053  clang::SourceLocation Loc = AS->getExprLoc();
1054  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1055  clang::Stmt *UpdatedStmt = nullptr;
1056
1057  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1058    // By definition, this is a struct assignment if we get here
1059    UpdatedStmt =
1060        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1061  } else {
1062    UpdatedStmt =
1063        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1064  }
1065
1066  RSASTReplace R(C);
1067  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1068}
1069
1070void RSObjectRefCount::Scope::AppendRSObjectInit(
1071    clang::VarDecl *VD,
1072    clang::DeclStmt *DS,
1073    DataType DT,
1074    clang::Expr *InitExpr) {
1075  slangAssert(VD);
1076
1077  if (!InitExpr) {
1078    return;
1079  }
1080
1081  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1082      DataTypeRSAllocation)->getASTContext();
1083  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1084      DataTypeRSAllocation)->getLocation();
1085  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1086      DataTypeRSAllocation)->getInnerLocStart();
1087
1088  if (DT == DataTypeIsStruct) {
1089    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1090    clang::DeclRefExpr *RefRSVar =
1091        clang::DeclRefExpr::Create(C,
1092                                   clang::NestedNameSpecifierLoc(),
1093                                   clang::SourceLocation(),
1094                                   VD,
1095                                   false,
1096                                   Loc,
1097                                   T->getCanonicalTypeInternal(),
1098                                   clang::VK_RValue,
1099                                   nullptr);
1100
1101    clang::Stmt *RSSetObjectOps =
1102        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1103    // Fix for b/37363420; consider:
1104    //
1105    // struct foo { rs_matrix m; };
1106    // void bar() {
1107    //   struct foo M = {...};
1108    // }
1109    //
1110    // slang modifies that declaration with initialization to a
1111    // declaration plus an assignment of the initialization values.
1112    //
1113    // void bar() {
1114    //   struct foo M = {};
1115    //   M = {...}; // by CreateStructRSSetObject() above
1116    // }
1117    //
1118    // the slang-generated statement (M = {...}) is a use of M, and we
1119    // need to mark M (clang::VarDecl *VD) as used.
1120    VD->markUsed(C);
1121
1122    std::list<clang::Stmt*> StmtList;
1123    StmtList.push_back(RSSetObjectOps);
1124    AppendAfterStmt(C, mCS, DS, StmtList);
1125    return;
1126  }
1127
1128  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1129  slangAssert((SetObjectFD != nullptr) &&
1130              "rsSetObject doesn't cover all RS object types");
1131
1132  clang::QualType SetObjectFDType = SetObjectFD->getType();
1133  clang::QualType SetObjectFDArgType[2];
1134  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1135  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1136
1137  clang::Expr *RefRSSetObjectFD =
1138      clang::DeclRefExpr::Create(C,
1139                                 clang::NestedNameSpecifierLoc(),
1140                                 clang::SourceLocation(),
1141                                 SetObjectFD,
1142                                 false,
1143                                 Loc,
1144                                 SetObjectFDType,
1145                                 clang::VK_RValue,
1146                                 nullptr);
1147
1148  clang::Expr *RSSetObjectFP =
1149      clang::ImplicitCastExpr::Create(C,
1150                                      C.getPointerType(SetObjectFDType),
1151                                      clang::CK_FunctionToPointerDecay,
1152                                      RefRSSetObjectFD,
1153                                      nullptr,
1154                                      clang::VK_RValue);
1155
1156  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1157  clang::DeclRefExpr *RefRSVar =
1158      clang::DeclRefExpr::Create(C,
1159                                 clang::NestedNameSpecifierLoc(),
1160                                 clang::SourceLocation(),
1161                                 VD,
1162                                 false,
1163                                 Loc,
1164                                 T->getCanonicalTypeInternal(),
1165                                 clang::VK_RValue,
1166                                 nullptr);
1167
1168  llvm::SmallVector<clang::Expr*, 2> ArgList;
1169  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1170                                                clang::UO_AddrOf,
1171                                                SetObjectFDArgType[0],
1172                                                clang::VK_RValue,
1173                                                clang::OK_Ordinary,
1174                                                Loc));
1175  ArgList.push_back(InitExpr);
1176
1177  clang::CallExpr *RSSetObjectCall =
1178      new(C) clang::CallExpr(C,
1179                             RSSetObjectFP,
1180                             ArgList,
1181                             SetObjectFD->getCallResultType(),
1182                             clang::VK_RValue,
1183                             Loc);
1184
1185  std::list<clang::Stmt*> StmtList;
1186  StmtList.push_back(RSSetObjectCall);
1187  AppendAfterStmt(C, mCS, DS, StmtList);
1188}
1189
1190void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1191  if (mRSO.empty()) {
1192    return;
1193  }
1194
1195  clang::DeclContext* DC = mRSO.front()->getDeclContext();
1196  clang::ASTContext& C = DC->getParentASTContext();
1197  clang::SourceManager& SM = C.getSourceManager();
1198
1199  const auto& OccursBefore = [&SM] (clang::SourceLocation L1, clang::SourceLocation L2)->bool {
1200    return SM.isBeforeInTranslationUnit(L1, L2);
1201  };
1202  typedef std::map<clang::SourceLocation, clang::Stmt*, decltype(OccursBefore)> DMap;
1203
1204  DMap dtors(OccursBefore);
1205
1206  // Create rsClearObject calls. Note the DMap entries are sorted by the SourceLocation.
1207  for (clang::VarDecl* VD : mRSO) {
1208    clang::SourceLocation Loc = VD->getSourceRange().getBegin();
1209    clang::Stmt* RSClearObjectCall = ClearRSObject(VD, DC);
1210    dtors.insert(std::make_pair(Loc, RSClearObjectCall));
1211  }
1212
1213  DestructorVisitor Visitor;
1214  Visitor.Visit(mCS);
1215
1216  // Replace each exiting statement with a block that contains the original statement
1217  // and added rsClearObject() calls before it.
1218  for (clang::Stmt* S : Visitor.getExitingStmts()) {
1219
1220    const clang::SourceLocation currentLoc = S->getLocStart();
1221
1222    DMap::iterator firstDtorIter = dtors.begin();
1223    DMap::iterator currentDtorIter = firstDtorIter;
1224    DMap::iterator lastDtorIter = dtors.end();
1225
1226    while (currentDtorIter != lastDtorIter &&
1227           OccursBefore(currentDtorIter->first, currentLoc)) {
1228      currentDtorIter++;
1229    }
1230
1231    if (currentDtorIter == firstDtorIter) {
1232      continue;
1233    }
1234
1235    std::vector<clang::Stmt*> Stmts;
1236
1237    // Insert rsClearObject() calls for all rsObjects declared before the current statement
1238    for(DMap::iterator it = firstDtorIter; it != currentDtorIter; it++) {
1239      Stmts.push_back(it->second);
1240    }
1241    Stmts.push_back(S);
1242
1243    RSASTReplace R(C);
1244    clang::CompoundStmt* CS = BuildCompoundStmt(C, Stmts, S->getLocEnd());
1245    R.ReplaceStmt(mCS, S, CS);
1246  }
1247
1248  std::list<clang::Stmt*> Stmts;
1249  for(auto LocCallPair : dtors) {
1250    Stmts.push_back(LocCallPair.second);
1251  }
1252  AppendAfterStmt(C, mCS, nullptr, Stmts);
1253}
1254
1255clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1256    clang::VarDecl *VD,
1257    clang::DeclContext *DC) {
1258  slangAssert(VD);
1259  clang::ASTContext &C = VD->getASTContext();
1260  clang::SourceLocation Loc = VD->getLocation();
1261  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1262  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1263
1264  // Reference expr to target RS object variable
1265  clang::DeclRefExpr *RefRSVar =
1266      clang::DeclRefExpr::Create(C,
1267                                 clang::NestedNameSpecifierLoc(),
1268                                 clang::SourceLocation(),
1269                                 VD,
1270                                 false,
1271                                 Loc,
1272                                 T->getCanonicalTypeInternal(),
1273                                 clang::VK_RValue,
1274                                 nullptr);
1275
1276  if (T->isArrayType()) {
1277    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1278  }
1279
1280  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1281
1282  if (DT == DataTypeUnknown ||
1283      DT == DataTypeIsStruct) {
1284    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1285  }
1286
1287  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1288              "Should be RS object");
1289
1290  return ClearSingleRSObject(C, RefRSVar, Loc);
1291}
1292
1293bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1294                                          DataType *DT,
1295                                          clang::Expr **InitExpr) {
1296  slangAssert(VD && DT && InitExpr);
1297  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1298
1299  // Loop through array types to get to base type
1300  while (T && T->isArrayType()) {
1301    T = T->getArrayElementTypeNoTypeQual();
1302  }
1303
1304  bool DataTypeIsStructWithRSObject = false;
1305  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1306
1307  if (*DT == DataTypeUnknown) {
1308    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1309      *DT = DataTypeIsStruct;
1310      DataTypeIsStructWithRSObject = true;
1311    } else {
1312      return false;
1313    }
1314  }
1315
1316  bool DataTypeIsRSObject = false;
1317  if (DataTypeIsStructWithRSObject) {
1318    DataTypeIsRSObject = true;
1319  } else {
1320    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1321  }
1322  *InitExpr = VD->getInit();
1323
1324  if (!DataTypeIsRSObject && *InitExpr) {
1325    // If we already have an initializer for a matrix type, we are done.
1326    return DataTypeIsRSObject;
1327  }
1328
1329  clang::Expr *ZeroInitializer =
1330      CreateEmptyInitListExpr(VD->getASTContext(), VD->getLocation());
1331
1332  if (ZeroInitializer) {
1333    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1334    VD->setInit(ZeroInitializer);
1335  }
1336
1337  return DataTypeIsRSObject;
1338}
1339
1340clang::Expr *RSObjectRefCount::CreateEmptyInitListExpr(
1341    clang::ASTContext &C,
1342    const clang::SourceLocation &Loc) {
1343
1344  // We can cheaply construct a zero initializer by just creating an empty
1345  // initializer list. Clang supports this extension to C(99), and will create
1346  // any necessary constructs to zero out the entire variable.
1347  llvm::SmallVector<clang::Expr*, 1> EmptyInitList;
1348  return new(C) clang::InitListExpr(C, Loc, EmptyInitList, Loc);
1349}
1350
1351clang::DeclRefExpr *RSObjectRefCount::CreateGuard(clang::ASTContext &C,
1352                                                  clang::DeclContext *DC,
1353                                                  clang::Expr *E,
1354                                                  const llvm::Twine &VarName,
1355                                                  std::vector<clang::Stmt*> &NewStmts) {
1356  clang::SourceLocation Loc = E->getLocStart();
1357  const clang::QualType Ty = E->getType();
1358  clang::VarDecl* TmpDecl = clang::VarDecl::Create(
1359      C,                                     // AST context
1360      DC,                                    // Decl context
1361      Loc,                                   // Start location
1362      Loc,                                   // Id location
1363      &C.Idents.get(VarName.str()),          // Id
1364      Ty,                                    // Type
1365      C.getTrivialTypeSourceInfo(Ty),        // Type info
1366      clang::SC_None                         // Storage class
1367  );
1368  const clang::Type *T = Ty.getTypePtr();
1369  clang::Expr *ZeroInitializer =
1370      RSObjectRefCount::CreateEmptyInitListExpr(C, Loc);
1371  ZeroInitializer->setType(T->getCanonicalTypeInternal());
1372  TmpDecl->setInit(ZeroInitializer);
1373  TmpDecl->markUsed(C);
1374  clang::Decl* Decls[] = { TmpDecl };
1375  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1376      C, Decls, sizeof(Decls) / sizeof(*Decls));
1377  clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
1378  NewStmts.push_back(DS);
1379
1380  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1381      C,
1382      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1383      Loc,                                   // TemplateKWLoc
1384      TmpDecl,
1385      false,                                 // RefersToEnclosingVariableOrCapture
1386      Loc,                                   // NameLoc
1387      Ty,
1388      clang::VK_LValue
1389  );
1390
1391  clang::Stmt *UpdatedStmt = nullptr;
1392  if (!RSExportPrimitiveType::IsRSObjectType(Ty.getTypePtr())) {
1393    // By definition, this is a struct assignment if we get here
1394    UpdatedStmt =
1395        CreateStructRSSetObject(C, DRE, E, Loc, Loc);
1396  } else {
1397    UpdatedStmt =
1398        CreateSingleRSSetObject(C, DRE, E, Loc, Loc);
1399  }
1400  NewStmts.push_back(UpdatedStmt);
1401
1402  return DRE;
1403}
1404
1405void RSObjectRefCount::CreateParameterGuard(clang::ASTContext &C,
1406                                            clang::DeclContext *DC,
1407                                            clang::ParmVarDecl *PD,
1408                                            std::vector<clang::Stmt*> &NewStmts) {
1409  clang::SourceLocation Loc = PD->getLocStart();
1410  clang::DeclRefExpr* ParamDRE = clang::DeclRefExpr::Create(
1411      C,
1412      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1413      Loc,                                   // TemplateKWLoc
1414      PD,
1415      false,                                 // RefersToEnclosingVariableOrCapture
1416      Loc,                                   // NameLoc
1417      PD->getType(),
1418      clang::VK_RValue
1419  );
1420
1421  CreateGuard(C, DC, ParamDRE,
1422              llvm::Twine(".rs.param.") + llvm::Twine(PD->getName()), NewStmts);
1423}
1424
1425void RSObjectRefCount::HandleParamsAndLocals(clang::FunctionDecl *FD) {
1426  std::vector<clang::Stmt*> NewStmts;
1427  std::list<clang::ParmVarDecl*> ObjParams;
1428  for (clang::ParmVarDecl *Param : FD->parameters()) {
1429    clang::QualType QT = Param->getType();
1430    if (CountRSObjectTypes(QT.getTypePtr())) {
1431      // Ignore non-object types
1432      RSObjectRefCount::CreateParameterGuard(mCtx, FD, Param, NewStmts);
1433      ObjParams.push_back(Param);
1434    }
1435  }
1436
1437  clang::Stmt *OldBody = FD->getBody();
1438  if (ObjParams.empty()) {
1439    Visit(OldBody);
1440    return;
1441  }
1442
1443  NewStmts.push_back(OldBody);
1444
1445  clang::SourceLocation Loc = FD->getLocStart();
1446  clang::CompoundStmt *NewBody = BuildCompoundStmt(mCtx, NewStmts, Loc);
1447  Scope S(NewBody);
1448  for (clang::ParmVarDecl *Param : ObjParams) {
1449    S.addRSObject(Param);
1450  }
1451  mScopeStack.push_back(&S);
1452
1453  // To avoid adding unnecessary ref counting artifacts to newly added temporary
1454  // local variables for parameters, visits only the old function body here.
1455  Visit(OldBody);
1456
1457  FD->setBody(NewBody);
1458
1459  S.InsertLocalVarDestructors();
1460  mScopeStack.pop_back();
1461}
1462
1463clang::CompoundStmt* RSObjectRefCount::CreateRetStmtWithTempVar(
1464    clang::ASTContext& C,
1465    clang::DeclContext* DC,
1466    clang::ReturnStmt* RS,
1467    const unsigned id) {
1468  std::vector<clang::Stmt*> NewStmts;
1469  // Since we insert rsClearObj() calls before the return statement, we need
1470  // to make sure none of the cleared RS objects are referenced in the
1471  // return statement.
1472  // For that, we create a new local variable named .rs.retval, assign the
1473  // original return expression to it, make all necessary rsClearObj()
1474  // calls, then return .rs.retval. Note rsClearObj() is not called on
1475  // .rs.retval.
1476  clang::SourceLocation Loc = RS->getLocStart();
1477  clang::Expr* RetVal = RS->getRetValue();
1478  const clang::QualType RetTy = RetVal->getType();
1479  clang::DeclRefExpr *DRE = CreateGuard(C, DC, RetVal,
1480                                        llvm::Twine(".rs.retval") + llvm::Twine(id),
1481                                        NewStmts);
1482
1483  // Creates a new return statement
1484  clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(Loc);
1485  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1486      C,
1487      RetTy,
1488      clang::CK_LValueToRValue,
1489      DRE,
1490      nullptr,
1491      clang::VK_RValue
1492  );
1493  NewRet->setRetValue(CastExpr);
1494  NewStmts.push_back(NewRet);
1495
1496  return BuildCompoundStmt(C, NewStmts, Loc);
1497}
1498
1499void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1500  VisitStmt(DS);
1501  getCurrentScope()->setCurrentStmt(DS);
1502  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1503       I != E;
1504       I++) {
1505    clang::Decl *D = *I;
1506    if (D->getKind() == clang::Decl::Var) {
1507      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1508      DataType DT = DataTypeUnknown;
1509      clang::Expr *InitExpr = nullptr;
1510      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1511        // We need to zero-init all RS object types (including matrices), ...
1512        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1513        // ... but, only add to the list of RS objects if we have some
1514        // non-matrix RS object fields.
1515        if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1516          getCurrentScope()->addRSObject(VD);
1517        }
1518      }
1519    }
1520  }
1521}
1522
1523void RSObjectRefCount::VisitCallExpr(clang::CallExpr* CE) {
1524  clang::QualType RetTy;
1525  const clang::FunctionDecl* FD = CE->getDirectCallee();
1526
1527  if (FD) {
1528    // Direct calls
1529
1530    RetTy = FD->getReturnType();
1531  } else {
1532    // Indirect calls through function pointers
1533
1534    const clang::Expr* Callee = CE->getCallee();
1535    const clang::Type* CalleeType = Callee->getType().getTypePtr();
1536    const clang::PointerType* PtrType = CalleeType->getAs<clang::PointerType>();
1537
1538    if (!PtrType) {
1539      return;
1540    }
1541
1542    const clang::Type* PointeeType = PtrType->getPointeeType().getTypePtr();
1543    const clang::FunctionType* FuncType = PointeeType->getAs<clang::FunctionType>();
1544
1545    if (!FuncType) {
1546      return;
1547    }
1548
1549    RetTy = FuncType->getReturnType();
1550  }
1551
1552  // The RenderScript runtime API maintains the invariant that the sysRef of a new RS object would
1553  // be 1, with the exception of rsGetAllocation() (deprecated in API 22), which leaves the sysRef
1554  // 0 for a new allocation. It is the responsibility of the callee of the API to decrement the
1555  // sysRef when a reference of the RS object goes out of scope. The compiler generates code to do
1556  // just that, by creating a temporary variable named ".rs.tmpN" with the result of
1557  // an RS-object-returning API directly assigned to it, and calling rsClearObject() on .rs.tmpN
1558  // right before it exits the current scope. Such code generation is skipped for rsGetAllocation()
1559  // to avoid decrementing its sysRef below zero.
1560
1561  if (CountRSObjectTypes(RetTy.getTypePtr())==0 ||
1562      (FD && FD->getName() == "rsGetAllocation")) {
1563    return;
1564  }
1565
1566  clang::SourceLocation Loc = CE->getSourceRange().getBegin();
1567  std::stringstream ss;
1568  ss << ".rs.tmp" << getNextID();
1569  llvm::StringRef VarName(ss.str());
1570
1571  clang::VarDecl* TempVarDecl = clang::VarDecl::Create(
1572      mCtx,                                  // AST context
1573      GetDeclContext(),                      // Decl context
1574      Loc,                                   // Start location
1575      Loc,                                   // Id location
1576      &mCtx.Idents.get(VarName),             // Id
1577      RetTy,                                 // Type
1578      mCtx.getTrivialTypeSourceInfo(RetTy),  // Type info
1579      clang::SC_None                         // Storage class
1580  );
1581  TempVarDecl->setInit(CE);
1582  TempVarDecl->markUsed(mCtx);
1583  clang::Decl* Decls[] = { TempVarDecl };
1584  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1585      mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
1586  clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
1587
1588  getCurrentScope()->InsertStmt(mCtx, DS);
1589
1590  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1591      mCtx,                                  // AST context
1592      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1593      Loc,                                   // TemplateKWLoc
1594      TempVarDecl,
1595      false,                                 // RefersToEnclosingVariableOrCapture
1596      Loc,                                   // NameLoc
1597      RetTy,
1598      clang::VK_LValue
1599  );
1600  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1601      mCtx,
1602      RetTy,
1603      clang::CK_LValueToRValue,
1604      DRE,
1605      nullptr,
1606      clang::VK_RValue
1607  );
1608
1609  getCurrentScope()->ReplaceExpr(mCtx, CE, CastExpr);
1610
1611  // Register TempVarDecl for destruction call (rsClearObj).
1612  getCurrentScope()->addRSObject(TempVarDecl);
1613}
1614
1615void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1616  if (!emptyScope()) {
1617    getCurrentScope()->setCurrentStmt(CS);
1618  }
1619
1620  if (!CS->body_empty()) {
1621    // Push a new scope
1622    Scope *S = new Scope(CS);
1623    mScopeStack.push_back(S);
1624
1625    VisitStmt(CS);
1626
1627    // Destroy the scope
1628    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1629    S->InsertLocalVarDestructors();
1630    mScopeStack.pop_back();
1631    delete S;
1632  }
1633}
1634
1635void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1636  getCurrentScope()->setCurrentStmt(AS);
1637  clang::QualType QT = AS->getType();
1638
1639  if (CountRSObjectTypes(QT.getTypePtr())) {
1640    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1641  }
1642}
1643
1644void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
1645  getCurrentScope()->setCurrentStmt(RS);
1646
1647  // If there is no local rsObject declared so far, no need to transform the
1648  // return statement.
1649
1650  bool RSObjDeclared = false;
1651
1652  for (const Scope* S : mScopeStack) {
1653    if (S->hasRSObject()) {
1654      RSObjDeclared = true;
1655      break;
1656    }
1657  }
1658
1659  if (!RSObjDeclared) {
1660    return;
1661  }
1662
1663  // If the return statement does not return anything, or if it does not return
1664  // a rsObject, no need to transform it.
1665
1666  clang::Expr* RetVal = RS->getRetValue();
1667  if (!RetVal || CountRSObjectTypes(RetVal->getType().getTypePtr()) == 0) {
1668    return;
1669  }
1670
1671  // Transform the return statement so that it does not potentially return or
1672  // reference a rsObject that has been cleared.
1673
1674  clang::CompoundStmt* NewRS;
1675  NewRS = CreateRetStmtWithTempVar(mCtx, GetDeclContext(), RS, getNextID());
1676
1677  getCurrentScope()->ReplaceStmt(mCtx, NewRS);
1678}
1679
1680void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1681  getCurrentScope()->setCurrentStmt(S);
1682  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1683       I != E;
1684       I++) {
1685    if (clang::Stmt *Child = *I) {
1686      Visit(Child);
1687    }
1688  }
1689}
1690
1691// This function walks the list of global variables and (potentially) creates
1692// a single global static destructor function that properly decrements
1693// reference counts on the contained RS object types.
1694clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1695  Init();
1696
1697  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1698  clang::SourceLocation loc;
1699
1700  llvm::StringRef SR(".rs.dtor");
1701  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1702  clang::DeclarationName N(&II);
1703  clang::FunctionProtoType::ExtProtoInfo EPI;
1704  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1705      llvm::ArrayRef<clang::QualType>(), EPI);
1706  clang::FunctionDecl *FD = nullptr;
1707
1708  // Generate rsClearObject() call chains for every global variable
1709  // (whether static or extern).
1710  std::vector<clang::Stmt *> StmtList;
1711  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1712          E = DC->decls_end(); I != E; I++) {
1713    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1714    if (VD) {
1715      if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1716        if (!FD) {
1717          // Only create FD if we are going to use it.
1718          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1719                                           clang::SC_None);
1720        }
1721        // Mark VD as used.  It might be unused, except for the destructor.
1722        // 'markUsed' has side-effects that are caused only if VD is not already
1723        // used.  Hence no need for an extra check here.
1724        VD->markUsed(mCtx);
1725        // Make sure to create any helpers within the function's DeclContext,
1726        // not the one associated with the global translation unit.
1727        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1728        StmtList.push_back(RSClearObjectCall);
1729      }
1730    }
1731  }
1732
1733  // Nothing needs to be destroyed, so don't emit a dtor.
1734  if (StmtList.empty()) {
1735    return nullptr;
1736  }
1737
1738  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1739
1740  FD->setBody(CS);
1741  // We need some way to tell if this FD is generated by slang
1742  FD->setImplicit();
1743
1744  return FD;
1745}
1746
1747bool HasRSObjectType(const clang::Type *T) {
1748  return CountRSObjectTypes(T) != 0;
1749}
1750
1751}  // namespace slang
1752