slang_rs_object_ref_count.cpp revision c9344ace5050b7d5ee9a677ce996fc0b731c0830
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <sstream>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/Support/Casting.h"

#include "slang_assert.h"
#include "slang.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32
33namespace slang {
34
35/* Even though those two arrays are of size DataTypeMax, only entries that
36 * correspond to object types will be set.
37 */
38clang::FunctionDecl *
39RSObjectRefCount::RSSetObjectFD[DataTypeMax];
40clang::FunctionDecl *
41RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42
43void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44  for (unsigned i = 0; i < DataTypeMax; i++) {
45    RSSetObjectFD[i] = nullptr;
46    RSClearObjectFD[i] = nullptr;
47  }
48
49  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52          E = TUDecl->decls_end(); I != E; I++) {
53    if ((I->getKind() >= clang::Decl::firstFunction) &&
54        (I->getKind() <= clang::Decl::lastFunction)) {
55      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57      // points to RSSetObjectFD or RSClearObjectFD
58      clang::FunctionDecl **RSObjectFD;
59
60      if (FD->getName() == "rsSetObject") {
61        slangAssert((FD->getNumParams() == 2) &&
62                    "Invalid rsSetObject function prototype (# params)");
63        RSObjectFD = RSSetObjectFD;
64      } else if (FD->getName() == "rsClearObject") {
65        slangAssert((FD->getNumParams() == 1) &&
66                    "Invalid rsClearObject function prototype (# params)");
67        RSObjectFD = RSClearObjectFD;
68      } else {
69        continue;
70      }
71
72      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73      clang::QualType PVT = PVD->getOriginalType();
74      // The first parameter must be a pointer like rs_allocation*
75      slangAssert(PVT->isPointerType() &&
76          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78      // The rs object type passed to the FD
79      clang::QualType RST = PVT->getPointeeType();
80      DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      if (DT >= 0 && DT < DataTypeMax) {
85          RSObjectFD[DT] = FD;
86      } else {
87          slangAssert(false && "incorrect type");
88      }
89    }
90  }
91}
92
93namespace {
94
95unsigned CountRSObjectTypes(const clang::Type *T);
96
97clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
98                                     clang::Expr *DstExpr,
99                                     clang::Expr *SrcExpr,
100                                     clang::SourceLocation StartLoc,
101                                     clang::SourceLocation Loc);
102
103// This function constructs a new CompoundStmt from the input StmtList.
104clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
105      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
106  unsigned NewStmtCount = StmtList.size();
107  unsigned CompoundStmtCount = 0;
108
109  clang::Stmt **CompoundStmtList;
110  CompoundStmtList = new clang::Stmt*[NewStmtCount];
111
112  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
113  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
114  for ( ; I != E; I++) {
115    CompoundStmtList[CompoundStmtCount++] = *I;
116  }
117  slangAssert(CompoundStmtCount == NewStmtCount);
118
119  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
120      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
121
122  delete [] CompoundStmtList;
123
124  return CS;
125}
126
127void AppendAfterStmt(clang::ASTContext &C,
128                     clang::CompoundStmt *CS,
129                     clang::Stmt *S,
130                     std::list<clang::Stmt*> &StmtList) {
131  slangAssert(CS);
132  clang::CompoundStmt::body_iterator bI = CS->body_begin();
133  clang::CompoundStmt::body_iterator bE = CS->body_end();
134  clang::Stmt **UpdatedStmtList =
135      new clang::Stmt*[CS->size() + StmtList.size()];
136
137  unsigned UpdatedStmtCount = 0;
138  unsigned Once = 0;
139  for ( ; bI != bE; bI++) {
140    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
141      // If we come across a return here, we don't have anything we can
142      // reasonably replace. We should have already inserted our destructor
143      // code in the proper spot, so we just clean up and return.
144      delete [] UpdatedStmtList;
145
146      return;
147    }
148
149    UpdatedStmtList[UpdatedStmtCount++] = *bI;
150
151    if ((*bI == S) && !Once) {
152      Once++;
153      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
154      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
155      for ( ; I != E; I++) {
156        UpdatedStmtList[UpdatedStmtCount++] = *I;
157      }
158    }
159  }
160  slangAssert(Once <= 1);
161
162  // When S is nullptr, we are appending to the end of the CompoundStmt.
163  if (!S) {
164    slangAssert(Once == 0);
165    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
166    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
167    for ( ; I != E; I++) {
168      UpdatedStmtList[UpdatedStmtCount++] = *I;
169    }
170  }
171
172  CS->setStmts(C, llvm::makeArrayRef(UpdatedStmtList, UpdatedStmtCount));
173
174  delete [] UpdatedStmtList;
175}
176
177// This class visits a compound statement and collects a list of all the exiting
178// statements, such as any return statement in any sub-block, and any
179// break/continue statement that would resume outside the current scope.
180// We do not handle the case for goto statements that leave a local scope.
181class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182 private:
183  // The loop depth of the currently visited node.
184  int mLoopDepth;
185
186  // The switch statement depth of the currently visited node.
187  // Note that this is tracked separately from the loop depth because
188  // SwitchStmt-contained ContinueStmt's should have destructors for the
189  // corresponding loop scope.
190  int mSwitchDepth;
191
192  // Output of the visitor: the statements that should be replaced by compound
193  // statements, each of which contains rsClearObject() calls followed by the
194  // original statement.
195  std::vector<clang::Stmt*> mExitingStmts;
196
197 public:
198  DestructorVisitor() : mLoopDepth(0), mSwitchDepth(0) {}
199
200  const std::vector<clang::Stmt*>& getExitingStmts() const {
201    return mExitingStmts;
202  }
203
204  void VisitStmt(clang::Stmt *S);
205  void VisitBreakStmt(clang::BreakStmt *BS);
206  void VisitContinueStmt(clang::ContinueStmt *CS);
207  void VisitDoStmt(clang::DoStmt *DS);
208  void VisitForStmt(clang::ForStmt *FS);
209  void VisitReturnStmt(clang::ReturnStmt *RS);
210  void VisitSwitchStmt(clang::SwitchStmt *SS);
211  void VisitWhileStmt(clang::WhileStmt *WS);
212};
213
214void DestructorVisitor::VisitStmt(clang::Stmt *S) {
215  for (clang::Stmt* Child : S->children()) {
216    if (Child) {
217      Visit(Child);
218    }
219  }
220}
221
222void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
223  VisitStmt(BS);
224  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
225    mExitingStmts.push_back(BS);
226  }
227}
228
229void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
230  VisitStmt(CS);
231  if (mLoopDepth == 0) {
232    // Switch statements can have nested continues.
233    mExitingStmts.push_back(CS);
234  }
235}
236
237void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
238  mLoopDepth++;
239  VisitStmt(DS);
240  mLoopDepth--;
241}
242
243void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
244  mLoopDepth++;
245  VisitStmt(FS);
246  mLoopDepth--;
247}
248
249void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
250  mExitingStmts.push_back(RS);
251}
252
253void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
254  mSwitchDepth++;
255  VisitStmt(SS);
256  mSwitchDepth--;
257}
258
259void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
260  mLoopDepth++;
261  VisitStmt(WS);
262  mLoopDepth--;
263}
264
265clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
266                                 clang::Expr *RefRSVar,
267                                 clang::SourceLocation Loc) {
268  slangAssert(RefRSVar);
269  const clang::Type *T = RefRSVar->getType().getTypePtr();
270  slangAssert(!T->isArrayType() &&
271              "Should not be destroying arrays with this function");
272
273  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
274  slangAssert((ClearObjectFD != nullptr) &&
275              "rsClearObject doesn't cover all RS object types");
276
277  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
278  clang::QualType ClearObjectFDArgType =
279      ClearObjectFD->getParamDecl(0)->getOriginalType();
280
281  // Example destructor for "rs_font localFont;"
282  //
283  // (CallExpr 'void'
284  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
285  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
286  //   (UnaryOperator 'rs_font *' prefix '&'
287  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
288
289  // Get address of targeted RS object
290  clang::Expr *AddrRefRSVar =
291      new(C) clang::UnaryOperator(RefRSVar,
292                                  clang::UO_AddrOf,
293                                  ClearObjectFDArgType,
294                                  clang::VK_RValue,
295                                  clang::OK_Ordinary,
296                                  Loc);
297
298  clang::Expr *RefRSClearObjectFD =
299      clang::DeclRefExpr::Create(C,
300                                 clang::NestedNameSpecifierLoc(),
301                                 clang::SourceLocation(),
302                                 ClearObjectFD,
303                                 false,
304                                 ClearObjectFD->getLocation(),
305                                 ClearObjectFDType,
306                                 clang::VK_RValue,
307                                 nullptr);
308
309  clang::Expr *RSClearObjectFP =
310      clang::ImplicitCastExpr::Create(C,
311                                      C.getPointerType(ClearObjectFDType),
312                                      clang::CK_FunctionToPointerDecay,
313                                      RefRSClearObjectFD,
314                                      nullptr,
315                                      clang::VK_RValue);
316
317  llvm::SmallVector<clang::Expr*, 1> ArgList;
318  ArgList.push_back(AddrRefRSVar);
319
320  clang::CallExpr *RSClearObjectCall =
321      new(C) clang::CallExpr(C,
322                             RSClearObjectFP,
323                             ArgList,
324                             ClearObjectFD->getCallResultType(),
325                             clang::VK_RValue,
326                             Loc);
327
328  return RSClearObjectCall;
329}
330
331static int ArrayDim(const clang::Type *T) {
332  if (!T || !T->isArrayType()) {
333    return 0;
334  }
335
336  const clang::ConstantArrayType *CAT =
337    static_cast<const clang::ConstantArrayType *>(T);
338  return static_cast<int>(CAT->getSize().getSExtValue());
339}
340
341clang::Stmt *ClearStructRSObject(
342    clang::ASTContext &C,
343    clang::DeclContext *DC,
344    clang::Expr *RefRSStruct,
345    clang::SourceLocation StartLoc,
346    clang::SourceLocation Loc);
347
348clang::Stmt *ClearArrayRSObject(
349    clang::ASTContext &C,
350    clang::DeclContext *DC,
351    clang::Expr *RefRSArr,
352    clang::SourceLocation StartLoc,
353    clang::SourceLocation Loc) {
354  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
355  slangAssert(BaseType->isArrayType());
356
357  int NumArrayElements = ArrayDim(BaseType);
358  // Actually extract out the base RS object type for use later
359  BaseType = BaseType->getArrayElementTypeNoTypeQual();
360
361  if (NumArrayElements <= 0) {
362    return nullptr;
363  }
364
365  // Example destructor loop for "rs_font fontArr[10];"
366  //
367  // (ForStmt
368  //   (DeclStmt
369  //     (VarDecl used rsIntIter 'int' cinit
370  //       (IntegerLiteral 'int' 0)))
371  //   (BinaryOperator 'int' '<'
372  //     (ImplicitCastExpr int LValueToRValue
373  //       (DeclRefExpr 'int' Var='rsIntIter'))
374  //     (IntegerLiteral 'int' 10)
375  //   nullptr << CondVar >>
376  //   (UnaryOperator 'int' postfix '++'
377  //     (DeclRefExpr 'int' Var='rsIntIter'))
378  //   (CallExpr 'void'
379  //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
380  //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
381  //     (UnaryOperator 'rs_font *' prefix '&'
382  //       (ArraySubscriptExpr 'rs_font':'rs_font'
383  //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
384  //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
385  //         (DeclRefExpr 'int' Var='rsIntIter'))))))
386
387  // Create helper variable for iterating through elements
388  static unsigned sIterCounter = 0;
389  std::stringstream UniqueIterName;
390  UniqueIterName << "rsIntIter" << sIterCounter++;
391  clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
392  clang::VarDecl *IIVD =
393      clang::VarDecl::Create(C,
394                             DC,
395                             StartLoc,
396                             Loc,
397                             II,
398                             C.IntTy,
399                             C.getTrivialTypeSourceInfo(C.IntTy),
400                             clang::SC_None);
401  // Mark "rsIntIter" as used
402  IIVD->markUsed(C);
403
404  // Form the actual destructor loop
405  // for (Init; Cond; Inc)
406  //   RSClearObjectCall;
407
408  // Init -> "int rsIntIter = 0"
409  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
410      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
411  IIVD->setInit(Int0);
412
413  clang::Decl *IID = (clang::Decl *)IIVD;
414  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
415  clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);
416
417  // Cond -> "rsIntIter < NumArrayElements"
418  clang::DeclRefExpr *RefrsIntIterLValue =
419      clang::DeclRefExpr::Create(C,
420                                 clang::NestedNameSpecifierLoc(),
421                                 clang::SourceLocation(),
422                                 IIVD,
423                                 false,
424                                 Loc,
425                                 C.IntTy,
426                                 clang::VK_LValue,
427                                 nullptr);
428
429  clang::Expr *RefrsIntIterRValue =
430      clang::ImplicitCastExpr::Create(C,
431                                      RefrsIntIterLValue->getType(),
432                                      clang::CK_LValueToRValue,
433                                      RefrsIntIterLValue,
434                                      nullptr,
435                                      clang::VK_RValue);
436
437  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
438      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
439
440  clang::BinaryOperator *Cond =
441      new(C) clang::BinaryOperator(RefrsIntIterRValue,
442                                   NumArrayElementsExpr,
443                                   clang::BO_LT,
444                                   C.IntTy,
445                                   clang::VK_RValue,
446                                   clang::OK_Ordinary,
447                                   Loc,
448                                   false);
449
450  // Inc -> "rsIntIter++"
451  clang::UnaryOperator *Inc =
452      new(C) clang::UnaryOperator(RefrsIntIterLValue,
453                                  clang::UO_PostInc,
454                                  C.IntTy,
455                                  clang::VK_RValue,
456                                  clang::OK_Ordinary,
457                                  Loc);
458
459  // Body -> "rsClearObject(&VD[rsIntIter]);"
460  // Destructor loop operates on individual array elements
461
462  clang::Expr *RefRSArrPtr =
463      clang::ImplicitCastExpr::Create(C,
464          C.getPointerType(BaseType->getCanonicalTypeInternal()),
465          clang::CK_ArrayToPointerDecay,
466          RefRSArr,
467          nullptr,
468          clang::VK_RValue);
469
470  clang::Expr *RefRSArrPtrSubscript =
471      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
472                                       RefrsIntIterRValue,
473                                       BaseType->getCanonicalTypeInternal(),
474                                       clang::VK_RValue,
475                                       clang::OK_Ordinary,
476                                       Loc);
477
478  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
479
480  clang::Stmt *RSClearObjectCall = nullptr;
481  if (BaseType->isArrayType()) {
482    RSClearObjectCall =
483        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
484  } else if (DT == DataTypeUnknown) {
485    RSClearObjectCall =
486        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
487  } else {
488    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
489  }
490
491  clang::ForStmt *DestructorLoop =
492      new(C) clang::ForStmt(C,
493                            Init,
494                            Cond,
495                            nullptr,  // no condVar
496                            Inc,
497                            RSClearObjectCall,
498                            Loc,
499                            Loc,
500                            Loc);
501
502  return DestructorLoop;
503}
504
505unsigned CountRSObjectTypes(const clang::Type *T) {
506  slangAssert(T);
507  unsigned RSObjectCount = 0;
508
509  if (T->isArrayType()) {
510    return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
511  }
512
513  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
514  if (DT != DataTypeUnknown) {
515    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
516  }
517
518  if (T->isUnionType()) {
519    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
520    RD = RD->getDefinition();
521    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
522           FE = RD->field_end();
523         FI != FE;
524         FI++) {
525      const clang::FieldDecl *FD = *FI;
526      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
527      if (CountRSObjectTypes(FT)) {
528        slangAssert(false && "can't have unions with RS object types!");
529        return 0;
530      }
531    }
532  }
533
534  if (!T->isStructureType()) {
535    return 0;
536  }
537
538  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
539  RD = RD->getDefinition();
540  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
541         FE = RD->field_end();
542       FI != FE;
543       FI++) {
544    const clang::FieldDecl *FD = *FI;
545    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
546    if (CountRSObjectTypes(FT)) {
547      // Sub-structs should only count once (as should arrays, etc.)
548      RSObjectCount++;
549    }
550  }
551
552  return RSObjectCount;
553}
554
555clang::Stmt *ClearStructRSObject(
556    clang::ASTContext &C,
557    clang::DeclContext *DC,
558    clang::Expr *RefRSStruct,
559    clang::SourceLocation StartLoc,
560    clang::SourceLocation Loc) {
561  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
562
563  slangAssert(!BaseType->isArrayType());
564
565  // Structs should show up as unknown primitive types
566  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
567              DataTypeUnknown);
568
569  unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
570  slangAssert(FieldsToDestroy != 0);
571
572  unsigned StmtCount = 0;
573  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
574  for (unsigned i = 0; i < FieldsToDestroy; i++) {
575    StmtArray[i] = nullptr;
576  }
577
578  // Populate StmtArray by creating a destructor for each RS object field
579  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
580  RD = RD->getDefinition();
581  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
582         FE = RD->field_end();
583       FI != FE;
584       FI++) {
585    // We just look through all field declarations to see if we find a
586    // declaration for an RS object type (or an array of one).
587    bool IsArrayType = false;
588    clang::FieldDecl *FD = *FI;
589    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
590    const clang::Type *OrigType = FT;
591    while (FT && FT->isArrayType()) {
592      FT = FT->getArrayElementTypeNoTypeQual();
593      IsArrayType = true;
594    }
595
596    // Pass a DeclarationNameInfo with a valid DeclName, since name equality
597    // gets asserted during CodeGen.
598    clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
599                                              FD->getLocation());
600
601    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
602      clang::DeclAccessPair FoundDecl =
603          clang::DeclAccessPair::make(FD, clang::AS_none);
604      clang::MemberExpr *RSObjectMember =
605          clang::MemberExpr::Create(C,
606                                    RefRSStruct,
607                                    false,
608                                    clang::SourceLocation(),
609                                    clang::NestedNameSpecifierLoc(),
610                                    clang::SourceLocation(),
611                                    FD,
612                                    FoundDecl,
613                                    FDDeclNameInfo,
614                                    nullptr,
615                                    OrigType->getCanonicalTypeInternal(),
616                                    clang::VK_RValue,
617                                    clang::OK_Ordinary);
618
619      slangAssert(StmtCount < FieldsToDestroy);
620
621      if (IsArrayType) {
622        StmtArray[StmtCount++] = ClearArrayRSObject(C,
623                                                    DC,
624                                                    RSObjectMember,
625                                                    StartLoc,
626                                                    Loc);
627      } else {
628        StmtArray[StmtCount++] = ClearSingleRSObject(C,
629                                                     RSObjectMember,
630                                                     Loc);
631      }
632    } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
633      // In this case, we have a nested struct. We may not end up filling all
634      // of the spaces in StmtArray (sub-structs should handle themselves
635      // with separate compound statements).
636      clang::DeclAccessPair FoundDecl =
637          clang::DeclAccessPair::make(FD, clang::AS_none);
638      clang::MemberExpr *RSObjectMember =
639          clang::MemberExpr::Create(C,
640                                    RefRSStruct,
641                                    false,
642                                    clang::SourceLocation(),
643                                    clang::NestedNameSpecifierLoc(),
644                                    clang::SourceLocation(),
645                                    FD,
646                                    FoundDecl,
647                                    clang::DeclarationNameInfo(),
648                                    nullptr,
649                                    OrigType->getCanonicalTypeInternal(),
650                                    clang::VK_RValue,
651                                    clang::OK_Ordinary);
652
653      if (IsArrayType) {
654        StmtArray[StmtCount++] = ClearArrayRSObject(C,
655                                                    DC,
656                                                    RSObjectMember,
657                                                    StartLoc,
658                                                    Loc);
659      } else {
660        StmtArray[StmtCount++] = ClearStructRSObject(C,
661                                                     DC,
662                                                     RSObjectMember,
663                                                     StartLoc,
664                                                     Loc);
665      }
666    }
667  }
668
669  slangAssert(StmtCount > 0);
670  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
671      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
672
673  delete [] StmtArray;
674
675  return CS;
676}
677
678clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
679                                     clang::Expr *DstExpr,
680                                     clang::Expr *SrcExpr,
681                                     clang::SourceLocation StartLoc,
682                                     clang::SourceLocation Loc) {
683  const clang::Type *T = DstExpr->getType().getTypePtr();
684  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
685  slangAssert((SetObjectFD != nullptr) &&
686              "rsSetObject doesn't cover all RS object types");
687
688  clang::QualType SetObjectFDType = SetObjectFD->getType();
689  clang::QualType SetObjectFDArgType[2];
690  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
691  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
692
693  clang::Expr *RefRSSetObjectFD =
694      clang::DeclRefExpr::Create(C,
695                                 clang::NestedNameSpecifierLoc(),
696                                 clang::SourceLocation(),
697                                 SetObjectFD,
698                                 false,
699                                 Loc,
700                                 SetObjectFDType,
701                                 clang::VK_RValue,
702                                 nullptr);
703
704  clang::Expr *RSSetObjectFP =
705      clang::ImplicitCastExpr::Create(C,
706                                      C.getPointerType(SetObjectFDType),
707                                      clang::CK_FunctionToPointerDecay,
708                                      RefRSSetObjectFD,
709                                      nullptr,
710                                      clang::VK_RValue);
711
712  llvm::SmallVector<clang::Expr*, 2> ArgList;
713  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
714                                                clang::UO_AddrOf,
715                                                SetObjectFDArgType[0],
716                                                clang::VK_RValue,
717                                                clang::OK_Ordinary,
718                                                Loc));
719  ArgList.push_back(SrcExpr);
720
721  clang::CallExpr *RSSetObjectCall =
722      new(C) clang::CallExpr(C,
723                             RSSetObjectFP,
724                             ArgList,
725                             SetObjectFD->getCallResultType(),
726                             clang::VK_RValue,
727                             Loc);
728
729  return RSSetObjectCall;
730}
731
732clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
733                                     clang::Expr *LHS,
734                                     clang::Expr *RHS,
735                                     clang::SourceLocation StartLoc,
736                                     clang::SourceLocation Loc);
737
738/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
739                                           clang::Expr *DstArr,
740                                           clang::Expr *SrcArr,
741                                           clang::SourceLocation StartLoc,
742                                           clang::SourceLocation Loc) {
743  clang::DeclContext *DC = nullptr;
744  const clang::Type *BaseType = DstArr->getType().getTypePtr();
745  slangAssert(BaseType->isArrayType());
746
747  int NumArrayElements = ArrayDim(BaseType);
748  // Actually extract out the base RS object type for use later
749  BaseType = BaseType->getArrayElementTypeNoTypeQual();
750
751  clang::Stmt *StmtArray[2] = {nullptr};
752  int StmtCtr = 0;
753
754  if (NumArrayElements <= 0) {
755    return nullptr;
756  }
757
758  // Create helper variable for iterating through elements
759  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
760  clang::VarDecl *IIVD =
761      clang::VarDecl::Create(C,
762                             DC,
763                             StartLoc,
764                             Loc,
765                             &II,
766                             C.IntTy,
767                             C.getTrivialTypeSourceInfo(C.IntTy),
768                             clang::SC_None,
769                             clang::SC_None);
770  clang::Decl *IID = (clang::Decl *)IIVD;
771
772  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
773  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
774
775  // Form the actual loop
776  // for (Init; Cond; Inc)
777  //   RSSetObjectCall;
778
779  // Init -> "rsIntIter = 0"
780  clang::DeclRefExpr *RefrsIntIter =
781      clang::DeclRefExpr::Create(C,
782                                 clang::NestedNameSpecifierLoc(),
783                                 IIVD,
784                                 Loc,
785                                 C.IntTy,
786                                 clang::VK_RValue,
787                                 nullptr);
788
789  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
790      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
791
792  clang::BinaryOperator *Init =
793      new(C) clang::BinaryOperator(RefrsIntIter,
794                                   Int0,
795                                   clang::BO_Assign,
796                                   C.IntTy,
797                                   clang::VK_RValue,
798                                   clang::OK_Ordinary,
799                                   Loc);
800
801  // Cond -> "rsIntIter < NumArrayElements"
802  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
803      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
804
805  clang::BinaryOperator *Cond =
806      new(C) clang::BinaryOperator(RefrsIntIter,
807                                   NumArrayElementsExpr,
808                                   clang::BO_LT,
809                                   C.IntTy,
810                                   clang::VK_RValue,
811                                   clang::OK_Ordinary,
812                                   Loc);
813
814  // Inc -> "rsIntIter++"
815  clang::UnaryOperator *Inc =
816      new(C) clang::UnaryOperator(RefrsIntIter,
817                                  clang::UO_PostInc,
818                                  C.IntTy,
819                                  clang::VK_RValue,
820                                  clang::OK_Ordinary,
821                                  Loc);
822
823  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
824  // Loop operates on individual array elements
825
826  clang::Expr *DstArrPtr =
827      clang::ImplicitCastExpr::Create(C,
828          C.getPointerType(BaseType->getCanonicalTypeInternal()),
829          clang::CK_ArrayToPointerDecay,
830          DstArr,
831          nullptr,
832          clang::VK_RValue);
833
834  clang::Expr *DstArrPtrSubscript =
835      new(C) clang::ArraySubscriptExpr(DstArrPtr,
836                                       RefrsIntIter,
837                                       BaseType->getCanonicalTypeInternal(),
838                                       clang::VK_RValue,
839                                       clang::OK_Ordinary,
840                                       Loc);
841
842  clang::Expr *SrcArrPtr =
843      clang::ImplicitCastExpr::Create(C,
844          C.getPointerType(BaseType->getCanonicalTypeInternal()),
845          clang::CK_ArrayToPointerDecay,
846          SrcArr,
847          nullptr,
848          clang::VK_RValue);
849
850  clang::Expr *SrcArrPtrSubscript =
851      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
852                                       RefrsIntIter,
853                                       BaseType->getCanonicalTypeInternal(),
854                                       clang::VK_RValue,
855                                       clang::OK_Ordinary,
856                                       Loc);
857
858  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
859
860  clang::Stmt *RSSetObjectCall = nullptr;
861  if (BaseType->isArrayType()) {
862    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
863                                             SrcArrPtrSubscript,
864                                             StartLoc, Loc);
865  } else if (DT == DataTypeUnknown) {
866    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
867                                              SrcArrPtrSubscript,
868                                              StartLoc, Loc);
869  } else {
870    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
871                                              SrcArrPtrSubscript,
872                                              StartLoc, Loc);
873  }
874
875  clang::ForStmt *DestructorLoop =
876      new(C) clang::ForStmt(C,
877                            Init,
878                            Cond,
879                            nullptr,  // no condVar
880                            Inc,
881                            RSSetObjectCall,
882                            Loc,
883                            Loc,
884                            Loc);
885
886  StmtArray[StmtCtr++] = DestructorLoop;
887  slangAssert(StmtCtr == 2);
888
889  clang::CompoundStmt *CS =
890      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
891
892  return CS;
893} */
894
895clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
896                                     clang::Expr *LHS,
897                                     clang::Expr *RHS,
898                                     clang::SourceLocation StartLoc,
899                                     clang::SourceLocation Loc) {
900  clang::QualType QT = LHS->getType();
901  const clang::Type *T = QT.getTypePtr();
902  slangAssert(T->isStructureType());
903  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
904
905  // Keep an extra slot for the original copy (memcpy)
906  unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
907
908  unsigned StmtCount = 0;
909  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
910  for (unsigned i = 0; i < FieldsToSet; i++) {
911    StmtArray[i] = nullptr;
912  }
913
914  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
915  RD = RD->getDefinition();
916  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
917         FE = RD->field_end();
918       FI != FE;
919       FI++) {
920    bool IsArrayType = false;
921    clang::FieldDecl *FD = *FI;
922    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
923    const clang::Type *OrigType = FT;
924
925    if (!CountRSObjectTypes(FT)) {
926      // Skip to next if we don't have any viable RS object types
927      continue;
928    }
929
930    clang::DeclAccessPair FoundDecl =
931        clang::DeclAccessPair::make(FD, clang::AS_none);
932    clang::MemberExpr *DstMember =
933        clang::MemberExpr::Create(C,
934                                  LHS,
935                                  false,
936                                  clang::SourceLocation(),
937                                  clang::NestedNameSpecifierLoc(),
938                                  clang::SourceLocation(),
939                                  FD,
940                                  FoundDecl,
941                                  clang::DeclarationNameInfo(
942                                      FD->getDeclName(),
943                                      clang::SourceLocation()),
944                                  nullptr,
945                                  OrigType->getCanonicalTypeInternal(),
946                                  clang::VK_RValue,
947                                  clang::OK_Ordinary);
948
949    clang::MemberExpr *SrcMember =
950        clang::MemberExpr::Create(C,
951                                  RHS,
952                                  false,
953                                  clang::SourceLocation(),
954                                  clang::NestedNameSpecifierLoc(),
955                                  clang::SourceLocation(),
956                                  FD,
957                                  FoundDecl,
958                                  clang::DeclarationNameInfo(),
959                                  nullptr,
960                                  OrigType->getCanonicalTypeInternal(),
961                                  clang::VK_RValue,
962                                  clang::OK_Ordinary);
963
964    if (FT->isArrayType()) {
965      FT = FT->getArrayElementTypeNoTypeQual();
966      IsArrayType = true;
967    }
968
969    DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
970
971    if (IsArrayType) {
972      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
973      DiagEngine.Report(
974        clang::FullSourceLoc(Loc, C.getSourceManager()),
975        DiagEngine.getCustomDiagID(
976          clang::DiagnosticsEngine::Error,
977          "Arrays of RS object types within structures cannot be copied"));
978      // TODO(srhines): Support setting arrays of RS objects
979      // StmtArray[StmtCount++] =
980      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
981    } else if (DT == DataTypeUnknown) {
982      StmtArray[StmtCount++] =
983          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
984    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
985      StmtArray[StmtCount++] =
986          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
987    } else {
988      slangAssert(false);
989    }
990  }
991
992  slangAssert(StmtCount < FieldsToSet);
993
994  // We still need to actually do the overall struct copy. For simplicity,
995  // we just do a straight-up assignment (which will still preserve all
996  // the proper RS object reference counts).
997  clang::BinaryOperator *CopyStruct =
998      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
999                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1000                                   false);
1001  StmtArray[StmtCount++] = CopyStruct;
1002
1003  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1004      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1005
1006  delete [] StmtArray;
1007
1008  return CS;
1009}
1010
1011}  // namespace
1012
1013void RSObjectRefCount::Scope::InsertStmt(const clang::ASTContext &C,
1014                                         clang::Stmt *NewStmt) {
1015  std::vector<clang::Stmt*> newBody;
1016  for (clang::Stmt* S1 : mCS->body()) {
1017    if (S1 == mCurrent) {
1018      newBody.push_back(NewStmt);
1019    }
1020    newBody.push_back(S1);
1021  }
1022  mCS->setStmts(C, newBody);
1023}
1024
1025void RSObjectRefCount::Scope::ReplaceStmt(const clang::ASTContext &C,
1026                                          clang::Stmt *NewStmt) {
1027  std::vector<clang::Stmt*> newBody;
1028  for (clang::Stmt* S1 : mCS->body()) {
1029    if (S1 == mCurrent) {
1030      newBody.push_back(NewStmt);
1031    } else {
1032      newBody.push_back(S1);
1033    }
1034  }
1035  mCS->setStmts(C, newBody);
1036}
1037
1038void RSObjectRefCount::Scope::ReplaceExpr(const clang::ASTContext& C,
1039                                          clang::Expr* OldExpr,
1040                                          clang::Expr* NewExpr) {
1041  RSASTReplace R(C);
1042  R.ReplaceStmt(mCurrent, OldExpr, NewExpr);
1043}
1044
1045void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1046    clang::BinaryOperator *AS) {
1047
1048  clang::QualType QT = AS->getType();
1049
1050  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1051      DataTypeRSAllocation)->getASTContext();
1052
1053  clang::SourceLocation Loc = AS->getExprLoc();
1054  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1055  clang::Stmt *UpdatedStmt = nullptr;
1056
1057  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1058    // By definition, this is a struct assignment if we get here
1059    UpdatedStmt =
1060        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1061  } else {
1062    UpdatedStmt =
1063        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1064  }
1065
1066  RSASTReplace R(C);
1067  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1068}
1069
1070void RSObjectRefCount::Scope::AppendRSObjectInit(
1071    clang::VarDecl *VD,
1072    clang::DeclStmt *DS,
1073    DataType DT,
1074    clang::Expr *InitExpr) {
1075  slangAssert(VD);
1076
1077  if (!InitExpr) {
1078    return;
1079  }
1080
1081  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1082      DataTypeRSAllocation)->getASTContext();
1083  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1084      DataTypeRSAllocation)->getLocation();
1085  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1086      DataTypeRSAllocation)->getInnerLocStart();
1087
1088  if (DT == DataTypeIsStruct) {
1089    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1090    clang::DeclRefExpr *RefRSVar =
1091        clang::DeclRefExpr::Create(C,
1092                                   clang::NestedNameSpecifierLoc(),
1093                                   clang::SourceLocation(),
1094                                   VD,
1095                                   false,
1096                                   Loc,
1097                                   T->getCanonicalTypeInternal(),
1098                                   clang::VK_RValue,
1099                                   nullptr);
1100
1101    clang::Stmt *RSSetObjectOps =
1102        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1103
1104    std::list<clang::Stmt*> StmtList;
1105    StmtList.push_back(RSSetObjectOps);
1106    AppendAfterStmt(C, mCS, DS, StmtList);
1107    return;
1108  }
1109
1110  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1111  slangAssert((SetObjectFD != nullptr) &&
1112              "rsSetObject doesn't cover all RS object types");
1113
1114  clang::QualType SetObjectFDType = SetObjectFD->getType();
1115  clang::QualType SetObjectFDArgType[2];
1116  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1117  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1118
1119  clang::Expr *RefRSSetObjectFD =
1120      clang::DeclRefExpr::Create(C,
1121                                 clang::NestedNameSpecifierLoc(),
1122                                 clang::SourceLocation(),
1123                                 SetObjectFD,
1124                                 false,
1125                                 Loc,
1126                                 SetObjectFDType,
1127                                 clang::VK_RValue,
1128                                 nullptr);
1129
1130  clang::Expr *RSSetObjectFP =
1131      clang::ImplicitCastExpr::Create(C,
1132                                      C.getPointerType(SetObjectFDType),
1133                                      clang::CK_FunctionToPointerDecay,
1134                                      RefRSSetObjectFD,
1135                                      nullptr,
1136                                      clang::VK_RValue);
1137
1138  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1139  clang::DeclRefExpr *RefRSVar =
1140      clang::DeclRefExpr::Create(C,
1141                                 clang::NestedNameSpecifierLoc(),
1142                                 clang::SourceLocation(),
1143                                 VD,
1144                                 false,
1145                                 Loc,
1146                                 T->getCanonicalTypeInternal(),
1147                                 clang::VK_RValue,
1148                                 nullptr);
1149
1150  llvm::SmallVector<clang::Expr*, 2> ArgList;
1151  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1152                                                clang::UO_AddrOf,
1153                                                SetObjectFDArgType[0],
1154                                                clang::VK_RValue,
1155                                                clang::OK_Ordinary,
1156                                                Loc));
1157  ArgList.push_back(InitExpr);
1158
1159  clang::CallExpr *RSSetObjectCall =
1160      new(C) clang::CallExpr(C,
1161                             RSSetObjectFP,
1162                             ArgList,
1163                             SetObjectFD->getCallResultType(),
1164                             clang::VK_RValue,
1165                             Loc);
1166
1167  std::list<clang::Stmt*> StmtList;
1168  StmtList.push_back(RSSetObjectCall);
1169  AppendAfterStmt(C, mCS, DS, StmtList);
1170}
1171
1172void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1173  if (mRSO.empty()) {
1174    return;
1175  }
1176
1177  clang::DeclContext* DC = mRSO.front()->getDeclContext();
1178  clang::ASTContext& C = DC->getParentASTContext();
1179  clang::SourceManager& SM = C.getSourceManager();
1180
1181  const auto& OccursBefore = [&SM] (clang::SourceLocation L1, clang::SourceLocation L2)->bool {
1182    return SM.isBeforeInTranslationUnit(L1, L2);
1183  };
1184  typedef std::map<clang::SourceLocation, clang::Stmt*, decltype(OccursBefore)> DMap;
1185
1186  DMap dtors(OccursBefore);
1187
1188  // Create rsClearObject calls. Note the DMap entries are sorted by the SourceLocation.
1189  for (clang::VarDecl* VD : mRSO) {
1190    clang::SourceLocation Loc = VD->getSourceRange().getBegin();
1191    clang::Stmt* RSClearObjectCall = ClearRSObject(VD, DC);
1192    dtors.insert(std::make_pair(Loc, RSClearObjectCall));
1193  }
1194
1195  DestructorVisitor Visitor;
1196  Visitor.Visit(mCS);
1197
1198  // Replace each exiting statement with a block that contains the original statement
1199  // and added rsClearObject() calls before it.
1200  for (clang::Stmt* S : Visitor.getExitingStmts()) {
1201
1202    const clang::SourceLocation currentLoc = S->getLocStart();
1203
1204    DMap::iterator firstDtorIter = dtors.begin();
1205    DMap::iterator currentDtorIter = firstDtorIter;
1206    DMap::iterator lastDtorIter = dtors.end();
1207
1208    while (currentDtorIter != lastDtorIter &&
1209           OccursBefore(currentDtorIter->first, currentLoc)) {
1210      currentDtorIter++;
1211    }
1212
1213    if (currentDtorIter == firstDtorIter) {
1214      continue;
1215    }
1216
1217    std::list<clang::Stmt*> Stmts;
1218
1219    // Insert rsClearObject() calls for all rsObjects declared before the current statement
1220    for(DMap::iterator it = firstDtorIter; it != currentDtorIter; it++) {
1221      Stmts.push_back(it->second);
1222    }
1223    Stmts.push_back(S);
1224
1225    RSASTReplace R(C);
1226    clang::CompoundStmt* CS = BuildCompoundStmt(C, Stmts, S->getLocEnd());
1227    R.ReplaceStmt(mCS, S, CS);
1228  }
1229
1230  std::list<clang::Stmt*> Stmts;
1231  for(auto LocCallPair : dtors) {
1232    Stmts.push_back(LocCallPair.second);
1233  }
1234  AppendAfterStmt(C, mCS, nullptr, Stmts);
1235}
1236
1237clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1238    clang::VarDecl *VD,
1239    clang::DeclContext *DC) {
1240  slangAssert(VD);
1241  clang::ASTContext &C = VD->getASTContext();
1242  clang::SourceLocation Loc = VD->getLocation();
1243  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1244  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1245
1246  // Reference expr to target RS object variable
1247  clang::DeclRefExpr *RefRSVar =
1248      clang::DeclRefExpr::Create(C,
1249                                 clang::NestedNameSpecifierLoc(),
1250                                 clang::SourceLocation(),
1251                                 VD,
1252                                 false,
1253                                 Loc,
1254                                 T->getCanonicalTypeInternal(),
1255                                 clang::VK_RValue,
1256                                 nullptr);
1257
1258  if (T->isArrayType()) {
1259    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1260  }
1261
1262  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1263
1264  if (DT == DataTypeUnknown ||
1265      DT == DataTypeIsStruct) {
1266    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1267  }
1268
1269  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1270              "Should be RS object");
1271
1272  return ClearSingleRSObject(C, RefRSVar, Loc);
1273}
1274
1275bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1276                                          DataType *DT,
1277                                          clang::Expr **InitExpr) {
1278  slangAssert(VD && DT && InitExpr);
1279  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1280
1281  // Loop through array types to get to base type
1282  while (T && T->isArrayType()) {
1283    T = T->getArrayElementTypeNoTypeQual();
1284  }
1285
1286  bool DataTypeIsStructWithRSObject = false;
1287  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1288
1289  if (*DT == DataTypeUnknown) {
1290    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1291      *DT = DataTypeIsStruct;
1292      DataTypeIsStructWithRSObject = true;
1293    } else {
1294      return false;
1295    }
1296  }
1297
1298  bool DataTypeIsRSObject = false;
1299  if (DataTypeIsStructWithRSObject) {
1300    DataTypeIsRSObject = true;
1301  } else {
1302    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1303  }
1304  *InitExpr = VD->getInit();
1305
1306  if (!DataTypeIsRSObject && *InitExpr) {
1307    // If we already have an initializer for a matrix type, we are done.
1308    return DataTypeIsRSObject;
1309  }
1310
1311  clang::Expr *ZeroInitializer =
1312      CreateEmptyInitListExpr(VD->getASTContext(), VD->getLocation());
1313
1314  if (ZeroInitializer) {
1315    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1316    VD->setInit(ZeroInitializer);
1317  }
1318
1319  return DataTypeIsRSObject;
1320}
1321
1322clang::Expr *RSObjectRefCount::CreateEmptyInitListExpr(
1323    clang::ASTContext &C,
1324    const clang::SourceLocation &Loc) {
1325
1326  // We can cheaply construct a zero initializer by just creating an empty
1327  // initializer list. Clang supports this extension to C(99), and will create
1328  // any necessary constructs to zero out the entire variable.
1329  llvm::SmallVector<clang::Expr*, 1> EmptyInitList;
1330  return new(C) clang::InitListExpr(C, Loc, EmptyInitList, Loc);
1331}
1332
1333clang::CompoundStmt* RSObjectRefCount::CreateRetStmtWithTempVar(
1334    clang::ASTContext& C,
1335    clang::DeclContext* DC,
1336    clang::ReturnStmt* RS,
1337    const unsigned id) {
1338  std::list<clang::Stmt*> NewStmts;
1339  // Since we insert rsClearObj() calls before the return statement, we need
1340  // to make sure none of the cleared RS objects are referenced in the
1341  // return statement.
1342  // For that, we create a new local variable named .rs.retval, assign the
1343  // original return expression to it, make all necessary rsClearObj()
1344  // calls, then return .rs.retval. Note rsClearObj() is not called on
1345  // .rs.retval.
1346
1347  clang::SourceLocation Loc = RS->getLocStart();
1348  std::stringstream ss;
1349  ss << ".rs.retval" << id;
1350  llvm::StringRef VarName(ss.str());
1351
1352  clang::Expr* RetVal = RS->getRetValue();
1353  const clang::QualType RetTy = RetVal->getType();
1354  clang::VarDecl* RSRetValDecl = clang::VarDecl::Create(
1355      C,                                     // AST context
1356      DC,                                    // Decl context
1357      Loc,                                   // Start location
1358      Loc,                                   // Id location
1359      &C.Idents.get(VarName),                // Id
1360      RetTy,                                 // Type
1361      C.getTrivialTypeSourceInfo(RetTy),     // Type info
1362      clang::SC_None                         // Storage class
1363  );
1364  const clang::Type *T = RetTy.getTypePtr();
1365  clang::Expr *ZeroInitializer =
1366      RSObjectRefCount::CreateEmptyInitListExpr(C, Loc);
1367  ZeroInitializer->setType(T->getCanonicalTypeInternal());
1368  RSRetValDecl->setInit(ZeroInitializer);
1369  RSRetValDecl->markUsed(C);
1370  clang::Decl* Decls[] = { RSRetValDecl };
1371  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1372      C, Decls, sizeof(Decls) / sizeof(*Decls));
1373  clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
1374  NewStmts.push_back(DS);
1375
1376  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1377      C,
1378      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1379      Loc,                                   // TemplateKWLoc
1380      RSRetValDecl,
1381      false,                                 // RefersToEnclosingVariableOrCapture
1382      Loc,                                   // NameLoc
1383      RetTy,
1384      clang::VK_LValue
1385  );
1386  clang::Stmt* SetRetTempVar = CreateSingleRSSetObject(C, DRE, RetVal, Loc, Loc);
1387  NewStmts.push_back(SetRetTempVar);
1388
1389  // Creates a new return statement
1390  clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(RS->getReturnLoc());
1391  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1392      C,
1393      RetTy,
1394      clang::CK_LValueToRValue,
1395      DRE,
1396      nullptr,
1397      clang::VK_RValue
1398  );
1399  NewRet->setRetValue(CastExpr);
1400  NewStmts.push_back(NewRet);
1401
1402  return BuildCompoundStmt(C, NewStmts, Loc);
1403}
1404
1405void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1406  VisitStmt(DS);
1407  getCurrentScope()->setCurrentStmt(DS);
1408  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1409       I != E;
1410       I++) {
1411    clang::Decl *D = *I;
1412    if (D->getKind() == clang::Decl::Var) {
1413      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1414      DataType DT = DataTypeUnknown;
1415      clang::Expr *InitExpr = nullptr;
1416      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1417        // We need to zero-init all RS object types (including matrices), ...
1418        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1419        // ... but, only add to the list of RS objects if we have some
1420        // non-matrix RS object fields.
1421        if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1422          getCurrentScope()->addRSObject(VD);
1423        }
1424      }
1425    }
1426  }
1427}
1428
1429void RSObjectRefCount::VisitCallExpr(clang::CallExpr* CE) {
1430  clang::QualType RetTy;
1431  const clang::FunctionDecl* FD = CE->getDirectCallee();
1432
1433  if (FD) {
1434    // Direct calls
1435
1436    RetTy = FD->getReturnType();
1437  } else {
1438    // Indirect calls through function pointers
1439
1440    const clang::Expr* Callee = CE->getCallee();
1441    const clang::Type* CalleeType = Callee->getType().getTypePtr();
1442    const clang::PointerType* PtrType = CalleeType->getAs<clang::PointerType>();
1443
1444    if (!PtrType) {
1445      return;
1446    }
1447
1448    const clang::Type* PointeeType = PtrType->getPointeeType().getTypePtr();
1449    const clang::FunctionType* FuncType = PointeeType->getAs<clang::FunctionType>();
1450
1451    if (!FuncType) {
1452      return;
1453    }
1454
1455    RetTy = FuncType->getReturnType();
1456  }
1457
1458  // The RenderScript runtime API maintains the invariant that the sysRef of a new RS object would
1459  // be 1, with the exception of rsGetAllocation() (deprecated in API 22), which leaves the sysRef
1460  // 0 for a new allocation. It is the responsibility of the callee of the API to decrement the
1461  // sysRef when a reference of the RS object goes out of scope. The compiler generates code to do
1462  // just that, by creating a temporary variable named ".rs.tmpN" with the result of
1463  // an RS-object-returning API directly assigned to it, and calling rsClearObject() on .rs.tmpN
1464  // right before it exits the current scope. Such code generation is skipped for rsGetAllocation()
1465  // to avoid decrementing its sysRef below zero.
1466
1467  if (CountRSObjectTypes(RetTy.getTypePtr())==0 ||
1468      (FD && FD->getName() == "rsGetAllocation")) {
1469    return;
1470  }
1471
1472  clang::SourceLocation Loc = CE->getSourceRange().getBegin();
1473  std::stringstream ss;
1474  ss << ".rs.tmp" << getNextID();
1475  llvm::StringRef VarName(ss.str());
1476
1477  clang::VarDecl* TempVarDecl = clang::VarDecl::Create(
1478      mCtx,                                  // AST context
1479      GetDeclContext(),                      // Decl context
1480      Loc,                                   // Start location
1481      Loc,                                   // Id location
1482      &mCtx.Idents.get(VarName),             // Id
1483      RetTy,                                 // Type
1484      mCtx.getTrivialTypeSourceInfo(RetTy),  // Type info
1485      clang::SC_None                         // Storage class
1486  );
1487  TempVarDecl->setInit(CE);
1488  TempVarDecl->markUsed(mCtx);
1489  clang::Decl* Decls[] = { TempVarDecl };
1490  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1491      mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
1492  clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
1493
1494  getCurrentScope()->InsertStmt(mCtx, DS);
1495
1496  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1497      mCtx,                                  // AST context
1498      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1499      Loc,                                   // TemplateKWLoc
1500      TempVarDecl,
1501      false,                                 // RefersToEnclosingVariableOrCapture
1502      Loc,                                   // NameLoc
1503      RetTy,
1504      clang::VK_LValue
1505  );
1506  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1507      mCtx,
1508      RetTy,
1509      clang::CK_LValueToRValue,
1510      DRE,
1511      nullptr,
1512      clang::VK_RValue
1513  );
1514
1515  getCurrentScope()->ReplaceExpr(mCtx, CE, CastExpr);
1516
1517  // Register TempVarDecl for destruction call (rsClearObj).
1518  getCurrentScope()->addRSObject(TempVarDecl);
1519}
1520
1521void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1522  if (!emptyScope()) {
1523    getCurrentScope()->setCurrentStmt(CS);
1524  }
1525
1526  if (!CS->body_empty()) {
1527    // Push a new scope
1528    Scope *S = new Scope(CS);
1529    mScopeStack.push_back(S);
1530
1531    VisitStmt(CS);
1532
1533    // Destroy the scope
1534    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1535    S->InsertLocalVarDestructors();
1536    mScopeStack.pop_back();
1537    delete S;
1538  }
1539}
1540
1541void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1542  getCurrentScope()->setCurrentStmt(AS);
1543  clang::QualType QT = AS->getType();
1544
1545  if (CountRSObjectTypes(QT.getTypePtr())) {
1546    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1547  }
1548}
1549
1550void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
1551  getCurrentScope()->setCurrentStmt(RS);
1552
1553  // If there is no local rsObject declared so far, no need to transform the
1554  // return statement.
1555
1556  bool RSObjDeclared = false;
1557
1558  for (const Scope* S : mScopeStack) {
1559    if (S->hasRSObject()) {
1560      RSObjDeclared = true;
1561      break;
1562    }
1563  }
1564
1565  if (!RSObjDeclared) {
1566    return;
1567  }
1568
1569  // If the return statement does not return anything, or if it does not return
1570  // a rsObject, no need to transform it.
1571
1572  clang::Expr* RetVal = RS->getRetValue();
1573  if (!RetVal || CountRSObjectTypes(RetVal->getType().getTypePtr()) == 0) {
1574    return;
1575  }
1576
1577  // Transform the return statement so that it does not potentially return or
1578  // reference a rsObject that has been cleared.
1579
1580  clang::CompoundStmt* NewRS;
1581  NewRS = CreateRetStmtWithTempVar(mCtx, GetDeclContext(), RS, getNextID());
1582
1583  getCurrentScope()->ReplaceStmt(mCtx, NewRS);
1584}
1585
1586void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1587  getCurrentScope()->setCurrentStmt(S);
1588  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1589       I != E;
1590       I++) {
1591    if (clang::Stmt *Child = *I) {
1592      Visit(Child);
1593    }
1594  }
1595}
1596
1597// This function walks the list of global variables and (potentially) creates
1598// a single global static destructor function that properly decrements
1599// reference counts on the contained RS object types.
1600clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1601  Init();
1602
1603  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1604  clang::SourceLocation loc;
1605
1606  llvm::StringRef SR(".rs.dtor");
1607  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1608  clang::DeclarationName N(&II);
1609  clang::FunctionProtoType::ExtProtoInfo EPI;
1610  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1611      llvm::ArrayRef<clang::QualType>(), EPI);
1612  clang::FunctionDecl *FD = nullptr;
1613
1614  // Generate rsClearObject() call chains for every global variable
1615  // (whether static or extern).
1616  std::list<clang::Stmt *> StmtList;
1617  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1618          E = DC->decls_end(); I != E; I++) {
1619    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1620    if (VD) {
1621      if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1622        if (!FD) {
1623          // Only create FD if we are going to use it.
1624          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1625                                           clang::SC_None);
1626        }
1627        // Mark VD as used.  It might be unused, except for the destructor.
1628        // 'markUsed' has side-effects that are caused only if VD is not already
1629        // used.  Hence no need for an extra check here.
1630        VD->markUsed(mCtx);
1631        // Make sure to create any helpers within the function's DeclContext,
1632        // not the one associated with the global translation unit.
1633        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1634        StmtList.push_back(RSClearObjectCall);
1635      }
1636    }
1637  }
1638
1639  // Nothing needs to be destroyed, so don't emit a dtor.
1640  if (StmtList.empty()) {
1641    return nullptr;
1642  }
1643
1644  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1645
1646  FD->setBody(CS);
1647  // We need some way to tell if this FD is generated by slang
1648  FD->setImplicit();
1649
1650  return FD;
1651}
1652
1653bool HasRSObjectType(const clang::Type *T) {
1654  return CountRSObjectTypes(T) != 0;
1655}
1656
1657}  // namespace slang
1658