1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_object_ref_count.h"
18
19#include "clang/AST/DeclGroup.h"
20#include "clang/AST/Expr.h"
21#include "clang/AST/NestedNameSpecifier.h"
22#include "clang/AST/OperationKinds.h"
23#include "clang/AST/RecursiveASTVisitor.h"
24#include "clang/AST/Stmt.h"
25#include "clang/AST/StmtVisitor.h"
26
27#include "slang_assert.h"
28#include "slang.h"
29#include "slang_rs_ast_replace.h"
30#include "slang_rs_export_type.h"
31
32namespace slang {
33
/* Lookup tables of the rsSetObject() / rsClearObject() declarations, indexed
 * by the DataType of the RS object the function operates on (the pointee of
 * its first parameter). GetRSRefCountingFunctions() resets both tables and
 * then fills them.
 * Even though those two arrays are of size DataTypeMax, only entries that
 * correspond to object types will be set; every other entry is left nullptr.
 */
clang::FunctionDecl *
RSObjectRefCount::RSSetObjectFD[DataTypeMax];
clang::FunctionDecl *
RSObjectRefCount::RSClearObjectFD[DataTypeMax];
41
42void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43  for (unsigned i = 0; i < DataTypeMax; i++) {
44    RSSetObjectFD[i] = nullptr;
45    RSClearObjectFD[i] = nullptr;
46  }
47
48  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
49
50  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
51          E = TUDecl->decls_end(); I != E; I++) {
52    if ((I->getKind() >= clang::Decl::firstFunction) &&
53        (I->getKind() <= clang::Decl::lastFunction)) {
54      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
55
56      // points to RSSetObjectFD or RSClearObjectFD
57      clang::FunctionDecl **RSObjectFD;
58
59      if (FD->getName() == "rsSetObject") {
60        slangAssert((FD->getNumParams() == 2) &&
61                    "Invalid rsSetObject function prototype (# params)");
62        RSObjectFD = RSSetObjectFD;
63      } else if (FD->getName() == "rsClearObject") {
64        slangAssert((FD->getNumParams() == 1) &&
65                    "Invalid rsClearObject function prototype (# params)");
66        RSObjectFD = RSClearObjectFD;
67      } else {
68        continue;
69      }
70
71      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
72      clang::QualType PVT = PVD->getOriginalType();
73      // The first parameter must be a pointer like rs_allocation*
74      slangAssert(PVT->isPointerType() &&
75          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
76
77      // The rs object type passed to the FD
78      clang::QualType RST = PVT->getPointeeType();
79      DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
80      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
81             && "must be RS object type");
82
83      if (DT >= 0 && DT < DataTypeMax) {
84          RSObjectFD[DT] = FD;
85      } else {
86          slangAssert(false && "incorrect type");
87      }
88    }
89  }
90}
91
92namespace {
93
94unsigned CountRSObjectTypes(const clang::Type *T);
95
96clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
97                                     clang::Expr *DstExpr,
98                                     clang::Expr *SrcExpr,
99                                     clang::SourceLocation StartLoc,
100                                     clang::SourceLocation Loc);
101
102// This function constructs a new CompoundStmt from the input StmtList.
103clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
104      std::vector<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
105  unsigned NewStmtCount = StmtList.size();
106  unsigned CompoundStmtCount = 0;
107
108  clang::Stmt **CompoundStmtList;
109  CompoundStmtList = new clang::Stmt*[NewStmtCount];
110
111  std::vector<clang::Stmt*>::const_iterator I = StmtList.begin();
112  std::vector<clang::Stmt*>::const_iterator E = StmtList.end();
113  for ( ; I != E; I++) {
114    CompoundStmtList[CompoundStmtCount++] = *I;
115  }
116  slangAssert(CompoundStmtCount == NewStmtCount);
117
118  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
119      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
120
121  delete [] CompoundStmtList;
122
123  return CS;
124}
125
126void AppendAfterStmt(clang::ASTContext &C,
127                     clang::CompoundStmt *CS,
128                     clang::Stmt *S,
129                     std::list<clang::Stmt*> &StmtList) {
130  slangAssert(CS);
131  clang::CompoundStmt::body_iterator bI = CS->body_begin();
132  clang::CompoundStmt::body_iterator bE = CS->body_end();
133  clang::Stmt **UpdatedStmtList =
134      new clang::Stmt*[CS->size() + StmtList.size()];
135
136  unsigned UpdatedStmtCount = 0;
137  unsigned Once = 0;
138  for ( ; bI != bE; bI++) {
139    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
140      // If we come across a return here, we don't have anything we can
141      // reasonably replace. We should have already inserted our destructor
142      // code in the proper spot, so we just clean up and return.
143      delete [] UpdatedStmtList;
144
145      return;
146    }
147
148    UpdatedStmtList[UpdatedStmtCount++] = *bI;
149
150    if ((*bI == S) && !Once) {
151      Once++;
152      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
153      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
154      for ( ; I != E; I++) {
155        UpdatedStmtList[UpdatedStmtCount++] = *I;
156      }
157    }
158  }
159  slangAssert(Once <= 1);
160
161  // When S is nullptr, we are appending to the end of the CompoundStmt.
162  if (!S) {
163    slangAssert(Once == 0);
164    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
165    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
166    for ( ; I != E; I++) {
167      UpdatedStmtList[UpdatedStmtCount++] = *I;
168    }
169  }
170
171  CS->setStmts(C, llvm::makeArrayRef(UpdatedStmtList, UpdatedStmtCount));
172
173  delete [] UpdatedStmtList;
174}
175
176// This class visits a compound statement and collects a list of all the exiting
177// statements, such as any return statement in any sub-block, and any
178// break/continue statement that would resume outside the current scope.
179// We do not handle the case for goto statements that leave a local scope.
180class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
181 private:
182  // The loop depth of the currently visited node.
183  int mLoopDepth;
184
185  // The switch statement depth of the currently visited node.
186  // Note that this is tracked separately from the loop depth because
187  // SwitchStmt-contained ContinueStmt's should have destructors for the
188  // corresponding loop scope.
189  int mSwitchDepth;
190
191  // Output of the visitor: the statements that should be replaced by compound
192  // statements, each of which contains rsClearObject() calls followed by the
193  // original statement.
194  std::vector<clang::Stmt*> mExitingStmts;
195
196 public:
197  DestructorVisitor() : mLoopDepth(0), mSwitchDepth(0) {}
198
199  const std::vector<clang::Stmt*>& getExitingStmts() const {
200    return mExitingStmts;
201  }
202
203  void VisitStmt(clang::Stmt *S);
204  void VisitBreakStmt(clang::BreakStmt *BS);
205  void VisitContinueStmt(clang::ContinueStmt *CS);
206  void VisitDoStmt(clang::DoStmt *DS);
207  void VisitForStmt(clang::ForStmt *FS);
208  void VisitReturnStmt(clang::ReturnStmt *RS);
209  void VisitSwitchStmt(clang::SwitchStmt *SS);
210  void VisitWhileStmt(clang::WhileStmt *WS);
211};
212
213void DestructorVisitor::VisitStmt(clang::Stmt *S) {
214  for (clang::Stmt* Child : S->children()) {
215    if (Child) {
216      Visit(Child);
217    }
218  }
219}
220
221void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
222  VisitStmt(BS);
223  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
224    mExitingStmts.push_back(BS);
225  }
226}
227
228void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
229  VisitStmt(CS);
230  if (mLoopDepth == 0) {
231    // Switch statements can have nested continues.
232    mExitingStmts.push_back(CS);
233  }
234}
235
236void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
237  mLoopDepth++;
238  VisitStmt(DS);
239  mLoopDepth--;
240}
241
242void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
243  mLoopDepth++;
244  VisitStmt(FS);
245  mLoopDepth--;
246}
247
248void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
249  mExitingStmts.push_back(RS);
250}
251
252void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
253  mSwitchDepth++;
254  VisitStmt(SS);
255  mSwitchDepth--;
256}
257
258void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
259  mLoopDepth++;
260  VisitStmt(WS);
261  mLoopDepth--;
262}
263
264clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
265                                 clang::Expr *RefRSVar,
266                                 clang::SourceLocation Loc) {
267  slangAssert(RefRSVar);
268  const clang::Type *T = RefRSVar->getType().getTypePtr();
269  slangAssert(!T->isArrayType() &&
270              "Should not be destroying arrays with this function");
271
272  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
273  slangAssert((ClearObjectFD != nullptr) &&
274              "rsClearObject doesn't cover all RS object types");
275
276  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
277  clang::QualType ClearObjectFDArgType =
278      ClearObjectFD->getParamDecl(0)->getOriginalType();
279
280  // Example destructor for "rs_font localFont;"
281  //
282  // (CallExpr 'void'
283  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
284  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
285  //   (UnaryOperator 'rs_font *' prefix '&'
286  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
287
288  // Get address of targeted RS object
289  clang::Expr *AddrRefRSVar =
290      new(C) clang::UnaryOperator(RefRSVar,
291                                  clang::UO_AddrOf,
292                                  ClearObjectFDArgType,
293                                  clang::VK_RValue,
294                                  clang::OK_Ordinary,
295                                  Loc);
296
297  clang::Expr *RefRSClearObjectFD =
298      clang::DeclRefExpr::Create(C,
299                                 clang::NestedNameSpecifierLoc(),
300                                 clang::SourceLocation(),
301                                 ClearObjectFD,
302                                 false,
303                                 ClearObjectFD->getLocation(),
304                                 ClearObjectFDType,
305                                 clang::VK_RValue,
306                                 nullptr);
307
308  clang::Expr *RSClearObjectFP =
309      clang::ImplicitCastExpr::Create(C,
310                                      C.getPointerType(ClearObjectFDType),
311                                      clang::CK_FunctionToPointerDecay,
312                                      RefRSClearObjectFD,
313                                      nullptr,
314                                      clang::VK_RValue);
315
316  llvm::SmallVector<clang::Expr*, 1> ArgList;
317  ArgList.push_back(AddrRefRSVar);
318
319  clang::CallExpr *RSClearObjectCall =
320      new(C) clang::CallExpr(C,
321                             RSClearObjectFP,
322                             ArgList,
323                             ClearObjectFD->getCallResultType(),
324                             clang::VK_RValue,
325                             Loc);
326
327  return RSClearObjectCall;
328}
329
330static int ArrayDim(const clang::Type *T) {
331  if (!T || !T->isArrayType()) {
332    return 0;
333  }
334
335  const clang::ConstantArrayType *CAT =
336    static_cast<const clang::ConstantArrayType *>(T);
337  return static_cast<int>(CAT->getSize().getSExtValue());
338}
339
340clang::Stmt *ClearStructRSObject(
341    clang::ASTContext &C,
342    clang::DeclContext *DC,
343    clang::Expr *RefRSStruct,
344    clang::SourceLocation StartLoc,
345    clang::SourceLocation Loc);
346
// Builds the destructor for the array expression RefRSArr. The result is a
// ForStmt that walks every element and clears it: with rsClearObject() for
// RS object elements, or by recursing into ClearArrayRSObject /
// ClearStructRSObject for nested arrays and structs.
//
// Returns nullptr when ArrayDim() reports no positive element count; there is
// nothing to clear in that case.
//
// DC is the DeclContext in which the loop counter ("rsIntIter<N>") is
// created. StartLoc and Loc are the source locations given to the
// synthesized nodes.
clang::Stmt *ClearArrayRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSArr,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc) {
  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
  slangAssert(BaseType->isArrayType());

  int NumArrayElements = ArrayDim(BaseType);
  // Actually extract out the base RS object type for use later.
  // This is the element type, which is itself an array for multi-dimensional
  // arrays (handled by the recursive call below).
  BaseType = BaseType->getArrayElementTypeNoTypeQual();

  if (NumArrayElements <= 0) {
    return nullptr;
  }

  // Example destructor loop for "rs_font fontArr[10];"
  //
  // (ForStmt
  //   (DeclStmt
  //     (VarDecl used rsIntIter 'int' cinit
  //       (IntegerLiteral 'int' 0)))
  //   (BinaryOperator 'int' '<'
  //     (ImplicitCastExpr int LValueToRValue
  //       (DeclRefExpr 'int' Var='rsIntIter'))
  //     (IntegerLiteral 'int' 10)
  //   nullptr << CondVar >>
  //   (UnaryOperator 'int' postfix '++'
  //     (DeclRefExpr 'int' Var='rsIntIter'))
  //   (CallExpr 'void'
  //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
  //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
  //     (UnaryOperator 'rs_font *' prefix '&'
  //       (ArraySubscriptExpr 'rs_font':'rs_font'
  //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
  //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
  //         (DeclRefExpr 'int' Var='rsIntIter'))))))

  // Create helper variable for iterating through elements.
  // The function-local static counter gives every generated loop variable a
  // distinct name (rsIntIter0, rsIntIter1, ...), so the loops generated for
  // nested arrays do not reuse the same counter name.
  static unsigned sIterCounter = 0;
  std::stringstream UniqueIterName;
  UniqueIterName << "rsIntIter" << sIterCounter++;
  clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
  clang::VarDecl *IIVD =
      clang::VarDecl::Create(C,
                             DC,
                             StartLoc,
                             Loc,
                             II,
                             C.IntTy,
                             C.getTrivialTypeSourceInfo(C.IntTy),
                             clang::SC_None);
  // Mark "rsIntIter" as used
  IIVD->markUsed(C);

  // Form the actual destructor loop
  // for (Init; Cond; Inc)
  //   RSClearObjectCall;

  // Init -> "int rsIntIter = 0"
  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
  IIVD->setInit(Int0);

  clang::Decl *IID = (clang::Decl *)IIVD;
  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
  clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);

  // Cond -> "rsIntIter < NumArrayElements"
  // NOTE(review): RefrsIntIterLValue is used both by the LValueToRValue cast
  // below and by Inc, and RefrsIntIterRValue is used by both Cond and the
  // subscript, so these nodes end up with several parents in the AST.
  // Confirm that downstream passes tolerate shared nodes.
  clang::DeclRefExpr *RefrsIntIterLValue =
      clang::DeclRefExpr::Create(C,
                                 clang::NestedNameSpecifierLoc(),
                                 clang::SourceLocation(),
                                 IIVD,
                                 false,
                                 Loc,
                                 C.IntTy,
                                 clang::VK_LValue,
                                 nullptr);

  clang::Expr *RefrsIntIterRValue =
      clang::ImplicitCastExpr::Create(C,
                                      RefrsIntIterLValue->getType(),
                                      clang::CK_LValueToRValue,
                                      RefrsIntIterLValue,
                                      nullptr,
                                      clang::VK_RValue);

  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);

  clang::BinaryOperator *Cond =
      new(C) clang::BinaryOperator(RefrsIntIterRValue,
                                   NumArrayElementsExpr,
                                   clang::BO_LT,
                                   C.IntTy,
                                   clang::VK_RValue,
                                   clang::OK_Ordinary,
                                   Loc,
                                   false);

  // Inc -> "rsIntIter++"
  clang::UnaryOperator *Inc =
      new(C) clang::UnaryOperator(RefrsIntIterLValue,
                                  clang::UO_PostInc,
                                  C.IntTy,
                                  clang::VK_RValue,
                                  clang::OK_Ordinary,
                                  Loc);

  // Body -> "rsClearObject(&VD[rsIntIter]);"
  // Destructor loop operates on individual array elements

  clang::Expr *RefRSArrPtr =
      clang::ImplicitCastExpr::Create(C,
          C.getPointerType(BaseType->getCanonicalTypeInternal()),
          clang::CK_ArrayToPointerDecay,
          RefRSArr,
          nullptr,
          clang::VK_RValue);

  clang::Expr *RefRSArrPtrSubscript =
      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
                                       RefrsIntIterRValue,
                                       BaseType->getCanonicalTypeInternal(),
                                       clang::VK_RValue,
                                       clang::OK_Ordinary,
                                       Loc);

  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);

  // Pick the element destructor: recurse for nested arrays, treat
  // DataTypeUnknown elements as structs, otherwise clear a single RS object.
  // NOTE(review): the nested ClearArrayRSObject call can return nullptr,
  // which then becomes the loop body below. Confirm this is tolerated.
  clang::Stmt *RSClearObjectCall = nullptr;
  if (BaseType->isArrayType()) {
    RSClearObjectCall =
        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
  } else if (DT == DataTypeUnknown) {
    RSClearObjectCall =
        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
  } else {
    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
  }

  clang::ForStmt *DestructorLoop =
      new(C) clang::ForStmt(C,
                            Init,
                            Cond,
                            nullptr,  // no condVar
                            Inc,
                            RSClearObjectCall,
                            Loc,
                            Loc,
                            Loc);

  return DestructorLoop;
}
503
504unsigned CountRSObjectTypes(const clang::Type *T) {
505  slangAssert(T);
506  unsigned RSObjectCount = 0;
507
508  if (T->isArrayType()) {
509    return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
510  }
511
512  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
513  if (DT != DataTypeUnknown) {
514    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
515  }
516
517  if (T->isUnionType()) {
518    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
519    RD = RD->getDefinition();
520    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
521           FE = RD->field_end();
522         FI != FE;
523         FI++) {
524      const clang::FieldDecl *FD = *FI;
525      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
526      if (CountRSObjectTypes(FT)) {
527        slangAssert(false && "can't have unions with RS object types!");
528        return 0;
529      }
530    }
531  }
532
533  if (!T->isStructureType()) {
534    return 0;
535  }
536
537  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
538  RD = RD->getDefinition();
539  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
540         FE = RD->field_end();
541       FI != FE;
542       FI++) {
543    const clang::FieldDecl *FD = *FI;
544    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
545    if (CountRSObjectTypes(FT)) {
546      // Sub-structs should only count once (as should arrays, etc.)
547      RSObjectCount++;
548    }
549  }
550
551  return RSObjectCount;
552}
553
554clang::Stmt *ClearStructRSObject(
555    clang::ASTContext &C,
556    clang::DeclContext *DC,
557    clang::Expr *RefRSStruct,
558    clang::SourceLocation StartLoc,
559    clang::SourceLocation Loc) {
560  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
561
562  slangAssert(!BaseType->isArrayType());
563
564  // Structs should show up as unknown primitive types
565  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
566              DataTypeUnknown);
567
568  unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
569  slangAssert(FieldsToDestroy != 0);
570
571  unsigned StmtCount = 0;
572  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
573  for (unsigned i = 0; i < FieldsToDestroy; i++) {
574    StmtArray[i] = nullptr;
575  }
576
577  // Populate StmtArray by creating a destructor for each RS object field
578  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
579  RD = RD->getDefinition();
580  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
581         FE = RD->field_end();
582       FI != FE;
583       FI++) {
584    // We just look through all field declarations to see if we find a
585    // declaration for an RS object type (or an array of one).
586    bool IsArrayType = false;
587    clang::FieldDecl *FD = *FI;
588    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
589    slangAssert(FT);
590    const clang::Type *OrigType = FT;
591    while (FT->isArrayType()) {
592      FT = FT->getArrayElementTypeNoTypeQual();
593      slangAssert(FT);
594      IsArrayType = true;
595    }
596
597    // Pass a DeclarationNameInfo with a valid DeclName, since name equality
598    // gets asserted during CodeGen.
599    clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
600                                              FD->getLocation());
601
602    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
603      clang::DeclAccessPair FoundDecl =
604          clang::DeclAccessPair::make(FD, clang::AS_none);
605      clang::MemberExpr *RSObjectMember =
606          clang::MemberExpr::Create(C,
607                                    RefRSStruct,
608                                    false,
609                                    clang::SourceLocation(),
610                                    clang::NestedNameSpecifierLoc(),
611                                    clang::SourceLocation(),
612                                    FD,
613                                    FoundDecl,
614                                    FDDeclNameInfo,
615                                    nullptr,
616                                    OrigType->getCanonicalTypeInternal(),
617                                    clang::VK_RValue,
618                                    clang::OK_Ordinary);
619
620      slangAssert(StmtCount < FieldsToDestroy);
621
622      if (IsArrayType) {
623        StmtArray[StmtCount++] = ClearArrayRSObject(C,
624                                                    DC,
625                                                    RSObjectMember,
626                                                    StartLoc,
627                                                    Loc);
628      } else {
629        StmtArray[StmtCount++] = ClearSingleRSObject(C,
630                                                     RSObjectMember,
631                                                     Loc);
632      }
633    } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
634      // In this case, we have a nested struct. We may not end up filling all
635      // of the spaces in StmtArray (sub-structs should handle themselves
636      // with separate compound statements).
637      clang::DeclAccessPair FoundDecl =
638          clang::DeclAccessPair::make(FD, clang::AS_none);
639      clang::MemberExpr *RSObjectMember =
640          clang::MemberExpr::Create(C,
641                                    RefRSStruct,
642                                    false,
643                                    clang::SourceLocation(),
644                                    clang::NestedNameSpecifierLoc(),
645                                    clang::SourceLocation(),
646                                    FD,
647                                    FoundDecl,
648                                    clang::DeclarationNameInfo(),
649                                    nullptr,
650                                    OrigType->getCanonicalTypeInternal(),
651                                    clang::VK_RValue,
652                                    clang::OK_Ordinary);
653
654      if (IsArrayType) {
655        StmtArray[StmtCount++] = ClearArrayRSObject(C,
656                                                    DC,
657                                                    RSObjectMember,
658                                                    StartLoc,
659                                                    Loc);
660      } else {
661        StmtArray[StmtCount++] = ClearStructRSObject(C,
662                                                     DC,
663                                                     RSObjectMember,
664                                                     StartLoc,
665                                                     Loc);
666      }
667    }
668  }
669
670  slangAssert(StmtCount > 0);
671  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
672      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
673
674  delete [] StmtArray;
675
676  return CS;
677}
678
679clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
680                                     clang::Expr *DstExpr,
681                                     clang::Expr *SrcExpr,
682                                     clang::SourceLocation StartLoc,
683                                     clang::SourceLocation Loc) {
684  const clang::Type *T = DstExpr->getType().getTypePtr();
685  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
686  slangAssert((SetObjectFD != nullptr) &&
687              "rsSetObject doesn't cover all RS object types");
688
689  clang::QualType SetObjectFDType = SetObjectFD->getType();
690  clang::QualType SetObjectFDArgType[2];
691  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
692  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
693
694  clang::Expr *RefRSSetObjectFD =
695      clang::DeclRefExpr::Create(C,
696                                 clang::NestedNameSpecifierLoc(),
697                                 clang::SourceLocation(),
698                                 SetObjectFD,
699                                 false,
700                                 Loc,
701                                 SetObjectFDType,
702                                 clang::VK_RValue,
703                                 nullptr);
704
705  clang::Expr *RSSetObjectFP =
706      clang::ImplicitCastExpr::Create(C,
707                                      C.getPointerType(SetObjectFDType),
708                                      clang::CK_FunctionToPointerDecay,
709                                      RefRSSetObjectFD,
710                                      nullptr,
711                                      clang::VK_RValue);
712
713  llvm::SmallVector<clang::Expr*, 2> ArgList;
714  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
715                                                clang::UO_AddrOf,
716                                                SetObjectFDArgType[0],
717                                                clang::VK_RValue,
718                                                clang::OK_Ordinary,
719                                                Loc));
720  ArgList.push_back(SrcExpr);
721
722  clang::CallExpr *RSSetObjectCall =
723      new(C) clang::CallExpr(C,
724                             RSSetObjectFP,
725                             ArgList,
726                             SetObjectFD->getCallResultType(),
727                             clang::VK_RValue,
728                             Loc);
729
730  return RSSetObjectCall;
731}
732
733clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
734                                     clang::Expr *LHS,
735                                     clang::Expr *RHS,
736                                     clang::SourceLocation StartLoc,
737                                     clang::SourceLocation Loc);
738
739/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
740                                           clang::Expr *DstArr,
741                                           clang::Expr *SrcArr,
742                                           clang::SourceLocation StartLoc,
743                                           clang::SourceLocation Loc) {
744  clang::DeclContext *DC = nullptr;
745  const clang::Type *BaseType = DstArr->getType().getTypePtr();
746  slangAssert(BaseType->isArrayType());
747
748  int NumArrayElements = ArrayDim(BaseType);
749  // Actually extract out the base RS object type for use later
750  BaseType = BaseType->getArrayElementTypeNoTypeQual();
751
752  clang::Stmt *StmtArray[2] = {nullptr};
753  int StmtCtr = 0;
754
755  if (NumArrayElements <= 0) {
756    return nullptr;
757  }
758
759  // Create helper variable for iterating through elements
760  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
761  clang::VarDecl *IIVD =
762      clang::VarDecl::Create(C,
763                             DC,
764                             StartLoc,
765                             Loc,
766                             &II,
767                             C.IntTy,
768                             C.getTrivialTypeSourceInfo(C.IntTy),
769                             clang::SC_None,
770                             clang::SC_None);
771  clang::Decl *IID = (clang::Decl *)IIVD;
772
773  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
774  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
775
776  // Form the actual loop
777  // for (Init; Cond; Inc)
778  //   RSSetObjectCall;
779
780  // Init -> "rsIntIter = 0"
781  clang::DeclRefExpr *RefrsIntIter =
782      clang::DeclRefExpr::Create(C,
783                                 clang::NestedNameSpecifierLoc(),
784                                 IIVD,
785                                 Loc,
786                                 C.IntTy,
787                                 clang::VK_RValue,
788                                 nullptr);
789
790  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
791      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
792
793  clang::BinaryOperator *Init =
794      new(C) clang::BinaryOperator(RefrsIntIter,
795                                   Int0,
796                                   clang::BO_Assign,
797                                   C.IntTy,
798                                   clang::VK_RValue,
799                                   clang::OK_Ordinary,
800                                   Loc);
801
802  // Cond -> "rsIntIter < NumArrayElements"
803  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
804      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
805
806  clang::BinaryOperator *Cond =
807      new(C) clang::BinaryOperator(RefrsIntIter,
808                                   NumArrayElementsExpr,
809                                   clang::BO_LT,
810                                   C.IntTy,
811                                   clang::VK_RValue,
812                                   clang::OK_Ordinary,
813                                   Loc);
814
815  // Inc -> "rsIntIter++"
816  clang::UnaryOperator *Inc =
817      new(C) clang::UnaryOperator(RefrsIntIter,
818                                  clang::UO_PostInc,
819                                  C.IntTy,
820                                  clang::VK_RValue,
821                                  clang::OK_Ordinary,
822                                  Loc);
823
824  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
825  // Loop operates on individual array elements
826
827  clang::Expr *DstArrPtr =
828      clang::ImplicitCastExpr::Create(C,
829          C.getPointerType(BaseType->getCanonicalTypeInternal()),
830          clang::CK_ArrayToPointerDecay,
831          DstArr,
832          nullptr,
833          clang::VK_RValue);
834
835  clang::Expr *DstArrPtrSubscript =
836      new(C) clang::ArraySubscriptExpr(DstArrPtr,
837                                       RefrsIntIter,
838                                       BaseType->getCanonicalTypeInternal(),
839                                       clang::VK_RValue,
840                                       clang::OK_Ordinary,
841                                       Loc);
842
843  clang::Expr *SrcArrPtr =
844      clang::ImplicitCastExpr::Create(C,
845          C.getPointerType(BaseType->getCanonicalTypeInternal()),
846          clang::CK_ArrayToPointerDecay,
847          SrcArr,
848          nullptr,
849          clang::VK_RValue);
850
851  clang::Expr *SrcArrPtrSubscript =
852      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
853                                       RefrsIntIter,
854                                       BaseType->getCanonicalTypeInternal(),
855                                       clang::VK_RValue,
856                                       clang::OK_Ordinary,
857                                       Loc);
858
859  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
860
861  clang::Stmt *RSSetObjectCall = nullptr;
862  if (BaseType->isArrayType()) {
863    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
864                                             SrcArrPtrSubscript,
865                                             StartLoc, Loc);
866  } else if (DT == DataTypeUnknown) {
867    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
868                                              SrcArrPtrSubscript,
869                                              StartLoc, Loc);
870  } else {
871    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
872                                              SrcArrPtrSubscript,
873                                              StartLoc, Loc);
874  }
875
876  clang::ForStmt *DestructorLoop =
877      new(C) clang::ForStmt(C,
878                            Init,
879                            Cond,
880                            nullptr,  // no condVar
881                            Inc,
882                            RSSetObjectCall,
883                            Loc,
884                            Loc,
885                            Loc);
886
887  StmtArray[StmtCtr++] = DestructorLoop;
888  slangAssert(StmtCtr == 2);
889
890  clang::CompoundStmt *CS =
891      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
892
893  return CS;
894} */
895
896clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
897                                     clang::Expr *LHS,
898                                     clang::Expr *RHS,
899                                     clang::SourceLocation StartLoc,
900                                     clang::SourceLocation Loc) {
901  clang::QualType QT = LHS->getType();
902  const clang::Type *T = QT.getTypePtr();
903  slangAssert(T->isStructureType());
904  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
905
906  // Keep an extra slot for the original copy (memcpy)
907  unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
908
909  unsigned StmtCount = 0;
910  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
911  for (unsigned i = 0; i < FieldsToSet; i++) {
912    StmtArray[i] = nullptr;
913  }
914
915  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
916  RD = RD->getDefinition();
917  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
918         FE = RD->field_end();
919       FI != FE;
920       FI++) {
921    bool IsArrayType = false;
922    clang::FieldDecl *FD = *FI;
923    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
924    const clang::Type *OrigType = FT;
925
926    if (!CountRSObjectTypes(FT)) {
927      // Skip to next if we don't have any viable RS object types
928      continue;
929    }
930
931    clang::DeclAccessPair FoundDecl =
932        clang::DeclAccessPair::make(FD, clang::AS_none);
933    clang::MemberExpr *DstMember =
934        clang::MemberExpr::Create(C,
935                                  LHS,
936                                  false,
937                                  clang::SourceLocation(),
938                                  clang::NestedNameSpecifierLoc(),
939                                  clang::SourceLocation(),
940                                  FD,
941                                  FoundDecl,
942                                  clang::DeclarationNameInfo(
943                                      FD->getDeclName(),
944                                      clang::SourceLocation()),
945                                  nullptr,
946                                  OrigType->getCanonicalTypeInternal(),
947                                  clang::VK_RValue,
948                                  clang::OK_Ordinary);
949
950    clang::MemberExpr *SrcMember =
951        clang::MemberExpr::Create(C,
952                                  RHS,
953                                  false,
954                                  clang::SourceLocation(),
955                                  clang::NestedNameSpecifierLoc(),
956                                  clang::SourceLocation(),
957                                  FD,
958                                  FoundDecl,
959                                  clang::DeclarationNameInfo(
960                                      FD->getDeclName(),
961                                      clang::SourceLocation()),
962                                  nullptr,
963                                  OrigType->getCanonicalTypeInternal(),
964                                  clang::VK_RValue,
965                                  clang::OK_Ordinary);
966
967    if (FT->isArrayType()) {
968      FT = FT->getArrayElementTypeNoTypeQual();
969      IsArrayType = true;
970    }
971
972    DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
973
974    if (IsArrayType) {
975      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
976      DiagEngine.Report(
977        clang::FullSourceLoc(Loc, C.getSourceManager()),
978        DiagEngine.getCustomDiagID(
979          clang::DiagnosticsEngine::Error,
980          "Arrays of RS object types within structures cannot be copied"));
981      // TODO(srhines): Support setting arrays of RS objects
982      // StmtArray[StmtCount++] =
983      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
984    } else if (DT == DataTypeUnknown) {
985      StmtArray[StmtCount++] =
986          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
987    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
988      StmtArray[StmtCount++] =
989          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
990    } else {
991      slangAssert(false);
992    }
993  }
994
995  slangAssert(StmtCount < FieldsToSet);
996
997  // We still need to actually do the overall struct copy. For simplicity,
998  // we just do a straight-up assignment (which will still preserve all
999  // the proper RS object reference counts).
1000  clang::BinaryOperator *CopyStruct =
1001      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1002                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1003                                   false);
1004  StmtArray[StmtCount++] = CopyStruct;
1005
1006  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1007      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1008
1009  delete [] StmtArray;
1010
1011  return CS;
1012}
1013
1014}  // namespace
1015
1016void RSObjectRefCount::Scope::InsertStmt(const clang::ASTContext &C,
1017                                         clang::Stmt *NewStmt) {
1018  std::vector<clang::Stmt*> newBody;
1019  for (clang::Stmt* S1 : mCS->body()) {
1020    if (S1 == mCurrent) {
1021      newBody.push_back(NewStmt);
1022    }
1023    newBody.push_back(S1);
1024  }
1025  mCS->setStmts(C, newBody);
1026}
1027
1028void RSObjectRefCount::Scope::ReplaceStmt(const clang::ASTContext &C,
1029                                          clang::Stmt *NewStmt) {
1030  std::vector<clang::Stmt*> newBody;
1031  for (clang::Stmt* S1 : mCS->body()) {
1032    if (S1 == mCurrent) {
1033      newBody.push_back(NewStmt);
1034    } else {
1035      newBody.push_back(S1);
1036    }
1037  }
1038  mCS->setStmts(C, newBody);
1039}
1040
1041void RSObjectRefCount::Scope::ReplaceExpr(const clang::ASTContext& C,
1042                                          clang::Expr* OldExpr,
1043                                          clang::Expr* NewExpr) {
1044  RSASTReplace R(C);
1045  R.ReplaceStmt(mCurrent, OldExpr, NewExpr);
1046}
1047
1048void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1049    clang::BinaryOperator *AS) {
1050
1051  clang::QualType QT = AS->getType();
1052
1053  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1054      DataTypeRSAllocation)->getASTContext();
1055
1056  clang::SourceLocation Loc = AS->getExprLoc();
1057  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1058  clang::Stmt *UpdatedStmt = nullptr;
1059
1060  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1061    // By definition, this is a struct assignment if we get here
1062    UpdatedStmt =
1063        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1064  } else {
1065    UpdatedStmt =
1066        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1067  }
1068
1069  RSASTReplace R(C);
1070  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1071}
1072
1073void RSObjectRefCount::Scope::AppendRSObjectInit(
1074    clang::VarDecl *VD,
1075    clang::DeclStmt *DS,
1076    DataType DT,
1077    clang::Expr *InitExpr) {
1078  slangAssert(VD);
1079
1080  if (!InitExpr) {
1081    return;
1082  }
1083
1084  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1085      DataTypeRSAllocation)->getASTContext();
1086  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1087      DataTypeRSAllocation)->getLocation();
1088  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1089      DataTypeRSAllocation)->getInnerLocStart();
1090
1091  if (DT == DataTypeIsStruct) {
1092    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1093    clang::DeclRefExpr *RefRSVar =
1094        clang::DeclRefExpr::Create(C,
1095                                   clang::NestedNameSpecifierLoc(),
1096                                   clang::SourceLocation(),
1097                                   VD,
1098                                   false,
1099                                   Loc,
1100                                   T->getCanonicalTypeInternal(),
1101                                   clang::VK_RValue,
1102                                   nullptr);
1103
1104    clang::Stmt *RSSetObjectOps =
1105        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1106    // Fix for b/37363420; consider:
1107    //
1108    // struct foo { rs_matrix m; };
1109    // void bar() {
1110    //   struct foo M = {...};
1111    // }
1112    //
1113    // slang modifies that declaration with initialization to a
1114    // declaration plus an assignment of the initialization values.
1115    //
1116    // void bar() {
1117    //   struct foo M = {};
1118    //   M = {...}; // by CreateStructRSSetObject() above
1119    // }
1120    //
1121    // the slang-generated statement (M = {...}) is a use of M, and we
1122    // need to mark M (clang::VarDecl *VD) as used.
1123    VD->markUsed(C);
1124
1125    std::list<clang::Stmt*> StmtList;
1126    StmtList.push_back(RSSetObjectOps);
1127    AppendAfterStmt(C, mCS, DS, StmtList);
1128    return;
1129  }
1130
1131  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1132  slangAssert((SetObjectFD != nullptr) &&
1133              "rsSetObject doesn't cover all RS object types");
1134
1135  clang::QualType SetObjectFDType = SetObjectFD->getType();
1136  clang::QualType SetObjectFDArgType[2];
1137  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1138  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1139
1140  clang::Expr *RefRSSetObjectFD =
1141      clang::DeclRefExpr::Create(C,
1142                                 clang::NestedNameSpecifierLoc(),
1143                                 clang::SourceLocation(),
1144                                 SetObjectFD,
1145                                 false,
1146                                 Loc,
1147                                 SetObjectFDType,
1148                                 clang::VK_RValue,
1149                                 nullptr);
1150
1151  clang::Expr *RSSetObjectFP =
1152      clang::ImplicitCastExpr::Create(C,
1153                                      C.getPointerType(SetObjectFDType),
1154                                      clang::CK_FunctionToPointerDecay,
1155                                      RefRSSetObjectFD,
1156                                      nullptr,
1157                                      clang::VK_RValue);
1158
1159  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1160  clang::DeclRefExpr *RefRSVar =
1161      clang::DeclRefExpr::Create(C,
1162                                 clang::NestedNameSpecifierLoc(),
1163                                 clang::SourceLocation(),
1164                                 VD,
1165                                 false,
1166                                 Loc,
1167                                 T->getCanonicalTypeInternal(),
1168                                 clang::VK_RValue,
1169                                 nullptr);
1170
1171  llvm::SmallVector<clang::Expr*, 2> ArgList;
1172  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1173                                                clang::UO_AddrOf,
1174                                                SetObjectFDArgType[0],
1175                                                clang::VK_RValue,
1176                                                clang::OK_Ordinary,
1177                                                Loc));
1178  ArgList.push_back(InitExpr);
1179
1180  clang::CallExpr *RSSetObjectCall =
1181      new(C) clang::CallExpr(C,
1182                             RSSetObjectFP,
1183                             ArgList,
1184                             SetObjectFD->getCallResultType(),
1185                             clang::VK_RValue,
1186                             Loc);
1187
1188  std::list<clang::Stmt*> StmtList;
1189  StmtList.push_back(RSSetObjectCall);
1190  AppendAfterStmt(C, mCS, DS, StmtList);
1191}
1192
1193void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1194  if (mRSO.empty()) {
1195    return;
1196  }
1197
1198  clang::DeclContext* DC = mRSO.front()->getDeclContext();
1199  clang::ASTContext& C = DC->getParentASTContext();
1200  clang::SourceManager& SM = C.getSourceManager();
1201
1202  const auto& OccursBefore = [&SM] (clang::SourceLocation L1, clang::SourceLocation L2)->bool {
1203    return SM.isBeforeInTranslationUnit(L1, L2);
1204  };
1205  typedef std::map<clang::SourceLocation, clang::Stmt*, decltype(OccursBefore)> DMap;
1206
1207  DMap dtors(OccursBefore);
1208
1209  // Create rsClearObject calls. Note the DMap entries are sorted by the SourceLocation.
1210  for (clang::VarDecl* VD : mRSO) {
1211    clang::SourceLocation Loc = VD->getSourceRange().getBegin();
1212    clang::Stmt* RSClearObjectCall = ClearRSObject(VD, DC);
1213    dtors.insert(std::make_pair(Loc, RSClearObjectCall));
1214  }
1215
1216  DestructorVisitor Visitor;
1217  Visitor.Visit(mCS);
1218
1219  // Replace each exiting statement with a block that contains the original statement
1220  // and added rsClearObject() calls before it.
1221  for (clang::Stmt* S : Visitor.getExitingStmts()) {
1222
1223    const clang::SourceLocation currentLoc = S->getLocStart();
1224
1225    DMap::iterator firstDtorIter = dtors.begin();
1226    DMap::iterator currentDtorIter = firstDtorIter;
1227    DMap::iterator lastDtorIter = dtors.end();
1228
1229    while (currentDtorIter != lastDtorIter &&
1230           OccursBefore(currentDtorIter->first, currentLoc)) {
1231      currentDtorIter++;
1232    }
1233
1234    if (currentDtorIter == firstDtorIter) {
1235      continue;
1236    }
1237
1238    std::vector<clang::Stmt*> Stmts;
1239
1240    // Insert rsClearObject() calls for all rsObjects declared before the current statement
1241    for(DMap::iterator it = firstDtorIter; it != currentDtorIter; it++) {
1242      Stmts.push_back(it->second);
1243    }
1244    Stmts.push_back(S);
1245
1246    RSASTReplace R(C);
1247    clang::CompoundStmt* CS = BuildCompoundStmt(C, Stmts, S->getLocEnd());
1248    R.ReplaceStmt(mCS, S, CS);
1249  }
1250
1251  std::list<clang::Stmt*> Stmts;
1252  for(auto LocCallPair : dtors) {
1253    Stmts.push_back(LocCallPair.second);
1254  }
1255  AppendAfterStmt(C, mCS, nullptr, Stmts);
1256}
1257
1258clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1259    clang::VarDecl *VD,
1260    clang::DeclContext *DC) {
1261  slangAssert(VD);
1262  clang::ASTContext &C = VD->getASTContext();
1263  clang::SourceLocation Loc = VD->getLocation();
1264  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1265  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1266
1267  // Reference expr to target RS object variable
1268  clang::DeclRefExpr *RefRSVar =
1269      clang::DeclRefExpr::Create(C,
1270                                 clang::NestedNameSpecifierLoc(),
1271                                 clang::SourceLocation(),
1272                                 VD,
1273                                 false,
1274                                 Loc,
1275                                 T->getCanonicalTypeInternal(),
1276                                 clang::VK_RValue,
1277                                 nullptr);
1278
1279  if (T->isArrayType()) {
1280    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1281  }
1282
1283  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1284
1285  if (DT == DataTypeUnknown ||
1286      DT == DataTypeIsStruct) {
1287    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1288  }
1289
1290  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1291              "Should be RS object");
1292
1293  return ClearSingleRSObject(C, RefRSVar, Loc);
1294}
1295
1296bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1297                                          DataType *DT,
1298                                          clang::Expr **InitExpr) {
1299  slangAssert(VD && DT && InitExpr);
1300  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1301
1302  // Loop through array types to get to base type
1303  slangAssert(T);
1304  while (T->isArrayType()) {
1305    T = T->getArrayElementTypeNoTypeQual();
1306    slangAssert(T);
1307  }
1308
1309  bool DataTypeIsStructWithRSObject = false;
1310  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1311
1312  if (*DT == DataTypeUnknown) {
1313    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1314      *DT = DataTypeIsStruct;
1315      DataTypeIsStructWithRSObject = true;
1316    } else {
1317      return false;
1318    }
1319  }
1320
1321  bool DataTypeIsRSObject = false;
1322  if (DataTypeIsStructWithRSObject) {
1323    DataTypeIsRSObject = true;
1324  } else {
1325    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1326  }
1327  *InitExpr = VD->getInit();
1328
1329  if (!DataTypeIsRSObject && *InitExpr) {
1330    // If we already have an initializer for a matrix type, we are done.
1331    return DataTypeIsRSObject;
1332  }
1333
1334  clang::Expr *ZeroInitializer =
1335      CreateEmptyInitListExpr(VD->getASTContext(), VD->getLocation());
1336
1337  if (ZeroInitializer) {
1338    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1339    VD->setInit(ZeroInitializer);
1340  }
1341
1342  return DataTypeIsRSObject;
1343}
1344
1345clang::Expr *RSObjectRefCount::CreateEmptyInitListExpr(
1346    clang::ASTContext &C,
1347    const clang::SourceLocation &Loc) {
1348
1349  // We can cheaply construct a zero initializer by just creating an empty
1350  // initializer list. Clang supports this extension to C(99), and will create
1351  // any necessary constructs to zero out the entire variable.
1352  llvm::SmallVector<clang::Expr*, 1> EmptyInitList;
1353  return new(C) clang::InitListExpr(C, Loc, EmptyInitList, Loc);
1354}
1355
1356clang::DeclRefExpr *RSObjectRefCount::CreateGuard(clang::ASTContext &C,
1357                                                  clang::DeclContext *DC,
1358                                                  clang::Expr *E,
1359                                                  const llvm::Twine &VarName,
1360                                                  std::vector<clang::Stmt*> &NewStmts) {
1361  clang::SourceLocation Loc = E->getLocStart();
1362  const clang::QualType Ty = E->getType();
1363  clang::VarDecl* TmpDecl = clang::VarDecl::Create(
1364      C,                                     // AST context
1365      DC,                                    // Decl context
1366      Loc,                                   // Start location
1367      Loc,                                   // Id location
1368      &C.Idents.get(VarName.str()),          // Id
1369      Ty,                                    // Type
1370      C.getTrivialTypeSourceInfo(Ty),        // Type info
1371      clang::SC_None                         // Storage class
1372  );
1373  const clang::Type *T = Ty.getTypePtr();
1374  clang::Expr *ZeroInitializer =
1375      RSObjectRefCount::CreateEmptyInitListExpr(C, Loc);
1376  ZeroInitializer->setType(T->getCanonicalTypeInternal());
1377  TmpDecl->setInit(ZeroInitializer);
1378  TmpDecl->markUsed(C);
1379  clang::Decl* Decls[] = { TmpDecl };
1380  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1381      C, Decls, sizeof(Decls) / sizeof(*Decls));
1382  clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
1383  NewStmts.push_back(DS);
1384
1385  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1386      C,
1387      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1388      Loc,                                   // TemplateKWLoc
1389      TmpDecl,
1390      false,                                 // RefersToEnclosingVariableOrCapture
1391      Loc,                                   // NameLoc
1392      Ty,
1393      clang::VK_LValue
1394  );
1395
1396  clang::Stmt *UpdatedStmt = nullptr;
1397  if (CountRSObjectTypes(Ty.getTypePtr()) == 0) {
1398    // The expression E is not an RS object itself. Instead of calling
1399    // rsSetObject(), create an assignment statement to set the value of the
1400    // temporary "guard" variable to the expression.
1401    // This can happen if called from RSObjectRefCount::VisitReturnStmt(),
1402    // when the return expression is not an RS object but references one.
1403    UpdatedStmt =
1404      new(C) clang::BinaryOperator(DRE, E, clang::BO_Assign, Ty,
1405                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1406                                   false);
1407
1408  } else if (!RSExportPrimitiveType::IsRSObjectType(Ty.getTypePtr())) {
1409    // By definition, this is a struct assignment if we get here
1410    UpdatedStmt =
1411        CreateStructRSSetObject(C, DRE, E, Loc, Loc);
1412  } else {
1413    UpdatedStmt =
1414        CreateSingleRSSetObject(C, DRE, E, Loc, Loc);
1415  }
1416  NewStmts.push_back(UpdatedStmt);
1417
1418  return DRE;
1419}
1420
1421void RSObjectRefCount::CreateParameterGuard(clang::ASTContext &C,
1422                                            clang::DeclContext *DC,
1423                                            clang::ParmVarDecl *PD,
1424                                            std::vector<clang::Stmt*> &NewStmts) {
1425  clang::SourceLocation Loc = PD->getLocStart();
1426  clang::DeclRefExpr* ParamDRE = clang::DeclRefExpr::Create(
1427      C,
1428      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1429      Loc,                                   // TemplateKWLoc
1430      PD,
1431      false,                                 // RefersToEnclosingVariableOrCapture
1432      Loc,                                   // NameLoc
1433      PD->getType(),
1434      clang::VK_RValue
1435  );
1436
1437  CreateGuard(C, DC, ParamDRE,
1438              llvm::Twine(".rs.param.") + llvm::Twine(PD->getName()), NewStmts);
1439}
1440
1441void RSObjectRefCount::HandleParamsAndLocals(clang::FunctionDecl *FD) {
1442  std::vector<clang::Stmt*> NewStmts;
1443  std::list<clang::ParmVarDecl*> ObjParams;
1444  for (clang::ParmVarDecl *Param : FD->parameters()) {
1445    clang::QualType QT = Param->getType();
1446    if (CountRSObjectTypes(QT.getTypePtr())) {
1447      // Ignore non-object types
1448      RSObjectRefCount::CreateParameterGuard(mCtx, FD, Param, NewStmts);
1449      ObjParams.push_back(Param);
1450    }
1451  }
1452
1453  clang::Stmt *OldBody = FD->getBody();
1454  if (ObjParams.empty()) {
1455    Visit(OldBody);
1456    return;
1457  }
1458
1459  NewStmts.push_back(OldBody);
1460
1461  clang::SourceLocation Loc = FD->getLocStart();
1462  clang::CompoundStmt *NewBody = BuildCompoundStmt(mCtx, NewStmts, Loc);
1463  Scope S(NewBody);
1464  for (clang::ParmVarDecl *Param : ObjParams) {
1465    S.addRSObject(Param);
1466  }
1467  mScopeStack.push_back(&S);
1468
1469  // To avoid adding unnecessary ref counting artifacts to newly added temporary
1470  // local variables for parameters, visits only the old function body here.
1471  Visit(OldBody);
1472
1473  FD->setBody(NewBody);
1474
1475  S.InsertLocalVarDestructors();
1476  mScopeStack.pop_back();
1477}
1478
1479clang::CompoundStmt* RSObjectRefCount::CreateRetStmtWithTempVar(
1480    clang::ASTContext& C,
1481    clang::DeclContext* DC,
1482    clang::ReturnStmt* RS,
1483    const unsigned id) {
1484  std::vector<clang::Stmt*> NewStmts;
1485  // Since we insert rsClearObj() calls before the return statement, we need
1486  // to make sure none of the cleared RS objects are referenced in the
1487  // return statement.
1488  // For that, we create a new local variable named .rs.retval, assign the
1489  // original return expression to it, make all necessary rsClearObj()
1490  // calls, then return .rs.retval. Note rsClearObj() is not called on
1491  // .rs.retval.
1492  clang::SourceLocation Loc = RS->getLocStart();
1493  clang::Expr* RetVal = RS->getRetValue();
1494  const clang::QualType RetTy = RetVal->getType();
1495  clang::DeclRefExpr *DRE = CreateGuard(C, DC, RetVal,
1496                                        llvm::Twine(".rs.retval") + llvm::Twine(id),
1497                                        NewStmts);
1498
1499  // Creates a new return statement
1500  clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(Loc);
1501  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1502      C,
1503      RetTy,
1504      clang::CK_LValueToRValue,
1505      DRE,
1506      nullptr,
1507      clang::VK_RValue
1508  );
1509  NewRet->setRetValue(CastExpr);
1510  NewStmts.push_back(NewRet);
1511
1512  return BuildCompoundStmt(C, NewStmts, Loc);
1513}
1514
1515void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1516  VisitStmt(DS);
1517  getCurrentScope()->setCurrentStmt(DS);
1518  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1519       I != E;
1520       I++) {
1521    clang::Decl *D = *I;
1522    if (D->getKind() == clang::Decl::Var) {
1523      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1524      DataType DT = DataTypeUnknown;
1525      clang::Expr *InitExpr = nullptr;
1526      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1527        // We need to zero-init all RS object types (including matrices), ...
1528        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1529        // ... but, only add to the list of RS objects if we have some
1530        // non-matrix RS object fields.
1531        if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1532          getCurrentScope()->addRSObject(VD);
1533        }
1534      }
1535    }
1536  }
1537}
1538
1539void RSObjectRefCount::VisitCallExpr(clang::CallExpr* CE) {
1540  clang::QualType RetTy;
1541  const clang::FunctionDecl* FD = CE->getDirectCallee();
1542
1543  if (FD) {
1544    // Direct calls
1545
1546    RetTy = FD->getReturnType();
1547  } else {
1548    // Indirect calls through function pointers
1549
1550    const clang::Expr* Callee = CE->getCallee();
1551    const clang::Type* CalleeType = Callee->getType().getTypePtr();
1552    const clang::PointerType* PtrType = CalleeType->getAs<clang::PointerType>();
1553
1554    if (!PtrType) {
1555      return;
1556    }
1557
1558    const clang::Type* PointeeType = PtrType->getPointeeType().getTypePtr();
1559    const clang::FunctionType* FuncType = PointeeType->getAs<clang::FunctionType>();
1560
1561    if (!FuncType) {
1562      return;
1563    }
1564
1565    RetTy = FuncType->getReturnType();
1566  }
1567
1568  // The RenderScript runtime API maintains the invariant that the sysRef of a new RS object would
1569  // be 1, with the exception of rsGetAllocation() (deprecated in API 22), which leaves the sysRef
1570  // 0 for a new allocation. It is the responsibility of the callee of the API to decrement the
1571  // sysRef when a reference of the RS object goes out of scope. The compiler generates code to do
1572  // just that, by creating a temporary variable named ".rs.tmpN" with the result of
1573  // an RS-object-returning API directly assigned to it, and calling rsClearObject() on .rs.tmpN
1574  // right before it exits the current scope. Such code generation is skipped for rsGetAllocation()
1575  // to avoid decrementing its sysRef below zero.
1576
1577  if (CountRSObjectTypes(RetTy.getTypePtr())==0 ||
1578      (FD && FD->getName() == "rsGetAllocation")) {
1579    return;
1580  }
1581
1582  clang::SourceLocation Loc = CE->getSourceRange().getBegin();
1583  std::stringstream ss;
1584  ss << ".rs.tmp" << getNextID();
1585  clang::IdentifierInfo *II = &mCtx.Idents.get(ss.str());
1586
1587  clang::VarDecl* TempVarDecl = clang::VarDecl::Create(
1588      mCtx,                                  // AST context
1589      GetDeclContext(),                      // Decl context
1590      Loc,                                   // Start location
1591      Loc,                                   // Id location
1592      II,                                    // Id
1593      RetTy,                                 // Type
1594      mCtx.getTrivialTypeSourceInfo(RetTy),  // Type info
1595      clang::SC_None                         // Storage class
1596  );
1597  TempVarDecl->setInit(CE);
1598  TempVarDecl->markUsed(mCtx);
1599  clang::Decl* Decls[] = { TempVarDecl };
1600  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1601      mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
1602  clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
1603
1604  getCurrentScope()->InsertStmt(mCtx, DS);
1605
1606  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1607      mCtx,                                  // AST context
1608      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1609      Loc,                                   // TemplateKWLoc
1610      TempVarDecl,
1611      false,                                 // RefersToEnclosingVariableOrCapture
1612      Loc,                                   // NameLoc
1613      RetTy,
1614      clang::VK_LValue
1615  );
1616  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1617      mCtx,
1618      RetTy,
1619      clang::CK_LValueToRValue,
1620      DRE,
1621      nullptr,
1622      clang::VK_RValue
1623  );
1624
1625  getCurrentScope()->ReplaceExpr(mCtx, CE, CastExpr);
1626
1627  // Register TempVarDecl for destruction call (rsClearObj).
1628  getCurrentScope()->addRSObject(TempVarDecl);
1629}
1630
1631void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1632  if (!emptyScope()) {
1633    getCurrentScope()->setCurrentStmt(CS);
1634  }
1635
1636  if (!CS->body_empty()) {
1637    // Push a new scope
1638    Scope *S = new Scope(CS);
1639    mScopeStack.push_back(S);
1640
1641    VisitStmt(CS);
1642
1643    // Destroy the scope
1644    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1645    S->InsertLocalVarDestructors();
1646    mScopeStack.pop_back();
1647    delete S;
1648  }
1649}
1650
1651void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1652  getCurrentScope()->setCurrentStmt(AS);
1653  clang::QualType QT = AS->getType();
1654
1655  if (CountRSObjectTypes(QT.getTypePtr())) {
1656    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1657  }
1658}
1659
1660namespace {
1661
1662class FindRSObjRefVisitor : public clang::RecursiveASTVisitor<FindRSObjRefVisitor> {
1663public:
1664  explicit FindRSObjRefVisitor() : mRefRSObj(false) {}
1665  bool VisitExpr(clang::Expr* Expression) {
1666    if (CountRSObjectTypes(Expression->getType().getTypePtr()) > 0) {
1667      mRefRSObj = true;
1668      // Found a reference to an RS object. Stop the AST traversal.
1669      return false;
1670    }
1671    return true;
1672  }
1673
1674  bool foundRSObjRef() const { return mRefRSObj; }
1675
1676private:
1677  bool mRefRSObj;
1678};
1679
1680}  // anonymous namespace
1681
1682void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
1683  getCurrentScope()->setCurrentStmt(RS);
1684
1685  // If there is no local rsObject declared so far, no need to transform the
1686  // return statement.
1687
1688  bool RSObjDeclared = false;
1689
1690  for (const Scope* S : mScopeStack) {
1691    if (S->hasRSObject()) {
1692      RSObjDeclared = true;
1693      break;
1694    }
1695  }
1696
1697  if (!RSObjDeclared) {
1698    return;
1699  }
1700
1701  FindRSObjRefVisitor visitor;
1702
1703  visitor.TraverseStmt(RS);
1704
1705  // If the return statement does not return anything, or if it does not reference
1706  // a rsObject, no need to transform it.
1707
1708  if (!visitor.foundRSObjRef()) {
1709    return;
1710  }
1711
1712  // Transform the return statement so that it does not potentially return or
1713  // reference a rsObject that has been cleared.
1714
1715  clang::CompoundStmt* NewRS;
1716  NewRS = CreateRetStmtWithTempVar(mCtx, GetDeclContext(), RS, getNextID());
1717
1718  getCurrentScope()->ReplaceStmt(mCtx, NewRS);
1719}
1720
1721void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1722  getCurrentScope()->setCurrentStmt(S);
1723  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1724       I != E;
1725       I++) {
1726    if (clang::Stmt *Child = *I) {
1727      Visit(Child);
1728    }
1729  }
1730}
1731
1732// This function walks the list of global variables and (potentially) creates
1733// a single global static destructor function that properly decrements
1734// reference counts on the contained RS object types.
1735clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1736  Init();
1737
1738  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1739  clang::SourceLocation loc;
1740
1741  llvm::StringRef SR(".rs.dtor");
1742  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1743  clang::DeclarationName N(&II);
1744  clang::FunctionProtoType::ExtProtoInfo EPI;
1745  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1746      llvm::ArrayRef<clang::QualType>(), EPI);
1747  clang::FunctionDecl *FD = nullptr;
1748
1749  // Generate rsClearObject() call chains for every global variable
1750  // (whether static or extern).
1751  std::vector<clang::Stmt *> StmtList;
1752  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1753          E = DC->decls_end(); I != E; I++) {
1754    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1755    if (VD) {
1756      if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1757        if (!FD) {
1758          // Only create FD if we are going to use it.
1759          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1760                                           clang::SC_None);
1761        }
1762        // Mark VD as used.  It might be unused, except for the destructor.
1763        // 'markUsed' has side-effects that are caused only if VD is not already
1764        // used.  Hence no need for an extra check here.
1765        VD->markUsed(mCtx);
1766        // Make sure to create any helpers within the function's DeclContext,
1767        // not the one associated with the global translation unit.
1768        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1769        StmtList.push_back(RSClearObjectCall);
1770      }
1771    }
1772  }
1773
1774  // Nothing needs to be destroyed, so don't emit a dtor.
1775  if (StmtList.empty()) {
1776    return nullptr;
1777  }
1778
1779  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1780
1781  slangAssert(FD);
1782  FD->setBody(CS);
1783  // We need some way to tell if this FD is generated by slang
1784  FD->setImplicit();
1785
1786  return FD;
1787}
1788
1789bool HasRSObjectType(const clang::Type *T) {
1790  return CountRSObjectTypes(T) != 0;
1791}
1792
1793}  // namespace slang
1794