slang_rs_object_ref_count.cpp revision 283a6cf32b808c703b219862ac491df3c9fc8b4b
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <sstream>
#include <string>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32
33namespace slang {
34
/* Cached declarations of the rsSetObject() / rsClearObject() overloads,
 * indexed by the RS object DataType of their first (pointer) parameter.
 * They are (re)filled by GetRSRefCountingFunctions(), which resets every
 * entry to nullptr first. Even though both arrays have DataTypeMax entries,
 * only the entries that correspond to RS object types are expected to be set.
 */
clang::FunctionDecl *
RSObjectRefCount::RSSetObjectFD[DataTypeMax];
clang::FunctionDecl *
RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42
43void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44  for (unsigned i = 0; i < DataTypeMax; i++) {
45    RSSetObjectFD[i] = nullptr;
46    RSClearObjectFD[i] = nullptr;
47  }
48
49  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52          E = TUDecl->decls_end(); I != E; I++) {
53    if ((I->getKind() >= clang::Decl::firstFunction) &&
54        (I->getKind() <= clang::Decl::lastFunction)) {
55      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57      // points to RSSetObjectFD or RSClearObjectFD
58      clang::FunctionDecl **RSObjectFD;
59
60      if (FD->getName() == "rsSetObject") {
61        slangAssert((FD->getNumParams() == 2) &&
62                    "Invalid rsSetObject function prototype (# params)");
63        RSObjectFD = RSSetObjectFD;
64      } else if (FD->getName() == "rsClearObject") {
65        slangAssert((FD->getNumParams() == 1) &&
66                    "Invalid rsClearObject function prototype (# params)");
67        RSObjectFD = RSClearObjectFD;
68      } else {
69        continue;
70      }
71
72      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73      clang::QualType PVT = PVD->getOriginalType();
74      // The first parameter must be a pointer like rs_allocation*
75      slangAssert(PVT->isPointerType() &&
76          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78      // The rs object type passed to the FD
79      clang::QualType RST = PVT->getPointeeType();
80      DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      if (DT >= 0 && DT < DataTypeMax) {
85          RSObjectFD[DT] = FD;
86      } else {
87          slangAssert(false && "incorrect type");
88      }
89    }
90  }
91}
92
93namespace {
94
// Forward declarations of helpers defined later in this file, so that they
// can be called from functions that appear before their definitions.

// Counts the RS-object-carrying entities of T; see the definition below.
unsigned CountRSObjectTypes(const clang::Type *T);

// Builds an "rsSetObject(&DstExpr, SrcExpr)" call; see the definition below.
clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
                                     clang::Expr *DstExpr,
                                     clang::Expr *SrcExpr,
                                     clang::SourceLocation StartLoc,
                                     clang::SourceLocation Loc);
102
103// This function constructs a new CompoundStmt from the input StmtList.
104clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
105      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
106  unsigned NewStmtCount = StmtList.size();
107  unsigned CompoundStmtCount = 0;
108
109  clang::Stmt **CompoundStmtList;
110  CompoundStmtList = new clang::Stmt*[NewStmtCount];
111
112  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
113  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
114  for ( ; I != E; I++) {
115    CompoundStmtList[CompoundStmtCount++] = *I;
116  }
117  slangAssert(CompoundStmtCount == NewStmtCount);
118
119  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
120      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
121
122  delete [] CompoundStmtList;
123
124  return CS;
125}
126
127void AppendAfterStmt(clang::ASTContext &C,
128                     clang::CompoundStmt *CS,
129                     clang::Stmt *S,
130                     std::list<clang::Stmt*> &StmtList) {
131  slangAssert(CS);
132  clang::CompoundStmt::body_iterator bI = CS->body_begin();
133  clang::CompoundStmt::body_iterator bE = CS->body_end();
134  clang::Stmt **UpdatedStmtList =
135      new clang::Stmt*[CS->size() + StmtList.size()];
136
137  unsigned UpdatedStmtCount = 0;
138  unsigned Once = 0;
139  for ( ; bI != bE; bI++) {
140    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
141      // If we come across a return here, we don't have anything we can
142      // reasonably replace. We should have already inserted our destructor
143      // code in the proper spot, so we just clean up and return.
144      delete [] UpdatedStmtList;
145
146      return;
147    }
148
149    UpdatedStmtList[UpdatedStmtCount++] = *bI;
150
151    if ((*bI == S) && !Once) {
152      Once++;
153      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
154      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
155      for ( ; I != E; I++) {
156        UpdatedStmtList[UpdatedStmtCount++] = *I;
157      }
158    }
159  }
160  slangAssert(Once <= 1);
161
162  // When S is nullptr, we are appending to the end of the CompoundStmt.
163  if (!S) {
164    slangAssert(Once == 0);
165    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
166    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
167    for ( ; I != E; I++) {
168      UpdatedStmtList[UpdatedStmtCount++] = *I;
169    }
170  }
171
172  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
173
174  delete [] UpdatedStmtList;
175}
176
177// This class visits a compound statement and collects a list of all the exiting
178// statements, such as any return statement in any sub-block, and any
179// break/continue statement that would resume outside the current scope.
180// We do not handle the case for goto statements that leave a local scope.
181class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182 private:
183  // The loop depth of the currently visited node.
184  int mLoopDepth;
185
186  // The switch statement depth of the currently visited node.
187  // Note that this is tracked separately from the loop depth because
188  // SwitchStmt-contained ContinueStmt's should have destructors for the
189  // corresponding loop scope.
190  int mSwitchDepth;
191
192  // Output of the visitor: the statements that should be replaced by compound
193  // statements, each of which contains rsClearObject() calls followed by the
194  // original statement.
195  std::vector<clang::Stmt*> mExitingStmts;
196
197 public:
198  DestructorVisitor() : mLoopDepth(0), mSwitchDepth(0) {}
199
200  const std::vector<clang::Stmt*>& getExitingStmts() const {
201    return mExitingStmts;
202  }
203
204  void VisitStmt(clang::Stmt *S);
205  void VisitBreakStmt(clang::BreakStmt *BS);
206  void VisitContinueStmt(clang::ContinueStmt *CS);
207  void VisitDoStmt(clang::DoStmt *DS);
208  void VisitForStmt(clang::ForStmt *FS);
209  void VisitReturnStmt(clang::ReturnStmt *RS);
210  void VisitSwitchStmt(clang::SwitchStmt *SS);
211  void VisitWhileStmt(clang::WhileStmt *WS);
212};
213
214// Given a return statement RS that returns an rsObject, creates a temporary
215// variable, sets it to the original return expression using rsSetObject(),
216// adds these new statements into NewStmts.
217// Finally, creates and returns a new return statement that returns the
218// temporary variable.
219clang::CompoundStmt* CreateRetStmtWithTempVar(
220    clang::ASTContext& C,
221    clang::DeclContext* DC,
222    clang::ReturnStmt* RS,
223    const unsigned id) {
224  std::list<clang::Stmt*> NewStmts;
225  // Since we insert rsClearObj() calls before the return statement, we need
226  // to make sure none of the cleared RS objects are referenced in the
227  // return statement.
228  // For that, we create a new local variable named .rs.retval, assign the
229  // original return expression to it, make all necessary rsClearObj()
230  // calls, then return .rs.retval. Note rsClearObj() is not called on
231  // .rs.retval.
232
233  clang::SourceLocation Loc = RS->getLocStart();
234  std::stringstream ss;
235  ss << ".rs.retval" << id;
236  llvm::StringRef VarName(ss.str());
237
238  clang::Expr* RetVal = RS->getRetValue();
239  const clang::QualType RetTy = RetVal->getType();
240  clang::VarDecl* RSRetValDecl = clang::VarDecl::Create(
241      C,                                     // AST context
242      DC,                                    // Decl context
243      Loc,                                   // Start location
244      Loc,                                   // Id location
245      &C.Idents.get(VarName),                // Id
246      RetTy,                                 // Type
247      C.getTrivialTypeSourceInfo(RetTy),     // Type info
248      clang::SC_None                         // Storage class
249  );
250  clang::Decl* Decls[] = { RSRetValDecl };
251  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
252      C, Decls, sizeof(Decls) / sizeof(*Decls));
253  clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
254  NewStmts.push_back(DS);
255
256  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
257      C,
258      clang::NestedNameSpecifierLoc(),       // QualifierLoc
259      Loc,                                   // TemplateKWLoc
260      RSRetValDecl,
261      false,                                 // RefersToEnclosingVariableOrCapture
262      Loc,                                   // NameLoc
263      RetTy,
264      clang::VK_LValue
265  );
266  clang::Stmt* SetRetTempVar = CreateSingleRSSetObject(C, DRE, RetVal, Loc, Loc);
267  NewStmts.push_back(SetRetTempVar);
268
269  // Creates a new return statement
270  clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(RS->getReturnLoc());
271  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
272      C,
273      RetTy,
274      clang::CK_LValueToRValue,
275      DRE,
276      nullptr,
277      clang::VK_RValue
278  );
279  NewRet->setRetValue(CastExpr);
280  NewStmts.push_back(NewRet);
281
282  return BuildCompoundStmt(C, NewStmts, Loc);
283}
284
285void DestructorVisitor::VisitStmt(clang::Stmt *S) {
286  for (clang::Stmt* Child : S->children()) {
287    if (Child) {
288      Visit(Child);
289    }
290  }
291}
292
293void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
294  VisitStmt(BS);
295  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
296    mExitingStmts.push_back(BS);
297  }
298}
299
300void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
301  VisitStmt(CS);
302  if (mLoopDepth == 0) {
303    // Switch statements can have nested continues.
304    mExitingStmts.push_back(CS);
305  }
306}
307
308void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
309  mLoopDepth++;
310  VisitStmt(DS);
311  mLoopDepth--;
312}
313
314void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
315  mLoopDepth++;
316  VisitStmt(FS);
317  mLoopDepth--;
318}
319
320void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
321  mExitingStmts.push_back(RS);
322}
323
324void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
325  mSwitchDepth++;
326  VisitStmt(SS);
327  mSwitchDepth--;
328}
329
330void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
331  mLoopDepth++;
332  VisitStmt(WS);
333  mLoopDepth--;
334}
335
336clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
337                                 clang::Expr *RefRSVar,
338                                 clang::SourceLocation Loc) {
339  slangAssert(RefRSVar);
340  const clang::Type *T = RefRSVar->getType().getTypePtr();
341  slangAssert(!T->isArrayType() &&
342              "Should not be destroying arrays with this function");
343
344  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
345  slangAssert((ClearObjectFD != nullptr) &&
346              "rsClearObject doesn't cover all RS object types");
347
348  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
349  clang::QualType ClearObjectFDArgType =
350      ClearObjectFD->getParamDecl(0)->getOriginalType();
351
352  // Example destructor for "rs_font localFont;"
353  //
354  // (CallExpr 'void'
355  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
356  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
357  //   (UnaryOperator 'rs_font *' prefix '&'
358  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
359
360  // Get address of targeted RS object
361  clang::Expr *AddrRefRSVar =
362      new(C) clang::UnaryOperator(RefRSVar,
363                                  clang::UO_AddrOf,
364                                  ClearObjectFDArgType,
365                                  clang::VK_RValue,
366                                  clang::OK_Ordinary,
367                                  Loc);
368
369  clang::Expr *RefRSClearObjectFD =
370      clang::DeclRefExpr::Create(C,
371                                 clang::NestedNameSpecifierLoc(),
372                                 clang::SourceLocation(),
373                                 ClearObjectFD,
374                                 false,
375                                 ClearObjectFD->getLocation(),
376                                 ClearObjectFDType,
377                                 clang::VK_RValue,
378                                 nullptr);
379
380  clang::Expr *RSClearObjectFP =
381      clang::ImplicitCastExpr::Create(C,
382                                      C.getPointerType(ClearObjectFDType),
383                                      clang::CK_FunctionToPointerDecay,
384                                      RefRSClearObjectFD,
385                                      nullptr,
386                                      clang::VK_RValue);
387
388  llvm::SmallVector<clang::Expr*, 1> ArgList;
389  ArgList.push_back(AddrRefRSVar);
390
391  clang::CallExpr *RSClearObjectCall =
392      new(C) clang::CallExpr(C,
393                             RSClearObjectFP,
394                             ArgList,
395                             ClearObjectFD->getCallResultType(),
396                             clang::VK_RValue,
397                             Loc);
398
399  return RSClearObjectCall;
400}
401
402static int ArrayDim(const clang::Type *T) {
403  if (!T || !T->isArrayType()) {
404    return 0;
405  }
406
407  const clang::ConstantArrayType *CAT =
408    static_cast<const clang::ConstantArrayType *>(T);
409  return static_cast<int>(CAT->getSize().getSExtValue());
410}
411
// Forward declaration: ClearArrayRSObject() (below) calls ClearStructRSObject()
// for arrays of structs, and ClearStructRSObject() calls ClearArrayRSObject()
// for array fields, so the two are mutually recursive.
clang::Stmt *ClearStructRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSStruct,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc);
418
419clang::Stmt *ClearArrayRSObject(
420    clang::ASTContext &C,
421    clang::DeclContext *DC,
422    clang::Expr *RefRSArr,
423    clang::SourceLocation StartLoc,
424    clang::SourceLocation Loc) {
425  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
426  slangAssert(BaseType->isArrayType());
427
428  int NumArrayElements = ArrayDim(BaseType);
429  // Actually extract out the base RS object type for use later
430  BaseType = BaseType->getArrayElementTypeNoTypeQual();
431
432  if (NumArrayElements <= 0) {
433    return nullptr;
434  }
435
436  // Example destructor loop for "rs_font fontArr[10];"
437  //
438  // (ForStmt
439  //   (DeclStmt
440  //     (VarDecl used rsIntIter 'int' cinit
441  //       (IntegerLiteral 'int' 0)))
442  //   (BinaryOperator 'int' '<'
443  //     (ImplicitCastExpr int LValueToRValue
444  //       (DeclRefExpr 'int' Var='rsIntIter'))
445  //     (IntegerLiteral 'int' 10)
446  //   nullptr << CondVar >>
447  //   (UnaryOperator 'int' postfix '++'
448  //     (DeclRefExpr 'int' Var='rsIntIter'))
449  //   (CallExpr 'void'
450  //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
451  //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
452  //     (UnaryOperator 'rs_font *' prefix '&'
453  //       (ArraySubscriptExpr 'rs_font':'rs_font'
454  //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
455  //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
456  //         (DeclRefExpr 'int' Var='rsIntIter'))))))
457
458  // Create helper variable for iterating through elements
459  static unsigned sIterCounter = 0;
460  std::stringstream UniqueIterName;
461  UniqueIterName << "rsIntIter" << sIterCounter++;
462  clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
463  clang::VarDecl *IIVD =
464      clang::VarDecl::Create(C,
465                             DC,
466                             StartLoc,
467                             Loc,
468                             II,
469                             C.IntTy,
470                             C.getTrivialTypeSourceInfo(C.IntTy),
471                             clang::SC_None);
472  // Mark "rsIntIter" as used
473  IIVD->markUsed(C);
474
475  // Form the actual destructor loop
476  // for (Init; Cond; Inc)
477  //   RSClearObjectCall;
478
479  // Init -> "int rsIntIter = 0"
480  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
481      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
482  IIVD->setInit(Int0);
483
484  clang::Decl *IID = (clang::Decl *)IIVD;
485  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
486  clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);
487
488  // Cond -> "rsIntIter < NumArrayElements"
489  clang::DeclRefExpr *RefrsIntIterLValue =
490      clang::DeclRefExpr::Create(C,
491                                 clang::NestedNameSpecifierLoc(),
492                                 clang::SourceLocation(),
493                                 IIVD,
494                                 false,
495                                 Loc,
496                                 C.IntTy,
497                                 clang::VK_LValue,
498                                 nullptr);
499
500  clang::Expr *RefrsIntIterRValue =
501      clang::ImplicitCastExpr::Create(C,
502                                      RefrsIntIterLValue->getType(),
503                                      clang::CK_LValueToRValue,
504                                      RefrsIntIterLValue,
505                                      nullptr,
506                                      clang::VK_RValue);
507
508  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
509      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
510
511  clang::BinaryOperator *Cond =
512      new(C) clang::BinaryOperator(RefrsIntIterRValue,
513                                   NumArrayElementsExpr,
514                                   clang::BO_LT,
515                                   C.IntTy,
516                                   clang::VK_RValue,
517                                   clang::OK_Ordinary,
518                                   Loc,
519                                   false);
520
521  // Inc -> "rsIntIter++"
522  clang::UnaryOperator *Inc =
523      new(C) clang::UnaryOperator(RefrsIntIterLValue,
524                                  clang::UO_PostInc,
525                                  C.IntTy,
526                                  clang::VK_RValue,
527                                  clang::OK_Ordinary,
528                                  Loc);
529
530  // Body -> "rsClearObject(&VD[rsIntIter]);"
531  // Destructor loop operates on individual array elements
532
533  clang::Expr *RefRSArrPtr =
534      clang::ImplicitCastExpr::Create(C,
535          C.getPointerType(BaseType->getCanonicalTypeInternal()),
536          clang::CK_ArrayToPointerDecay,
537          RefRSArr,
538          nullptr,
539          clang::VK_RValue);
540
541  clang::Expr *RefRSArrPtrSubscript =
542      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
543                                       RefrsIntIterRValue,
544                                       BaseType->getCanonicalTypeInternal(),
545                                       clang::VK_RValue,
546                                       clang::OK_Ordinary,
547                                       Loc);
548
549  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
550
551  clang::Stmt *RSClearObjectCall = nullptr;
552  if (BaseType->isArrayType()) {
553    RSClearObjectCall =
554        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
555  } else if (DT == DataTypeUnknown) {
556    RSClearObjectCall =
557        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
558  } else {
559    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
560  }
561
562  clang::ForStmt *DestructorLoop =
563      new(C) clang::ForStmt(C,
564                            Init,
565                            Cond,
566                            nullptr,  // no condVar
567                            Inc,
568                            RSClearObjectCall,
569                            Loc,
570                            Loc,
571                            Loc);
572
573  return DestructorLoop;
574}
575
576unsigned CountRSObjectTypes(const clang::Type *T) {
577  slangAssert(T);
578  unsigned RSObjectCount = 0;
579
580  if (T->isArrayType()) {
581    return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
582  }
583
584  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
585  if (DT != DataTypeUnknown) {
586    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
587  }
588
589  if (T->isUnionType()) {
590    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
591    RD = RD->getDefinition();
592    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
593           FE = RD->field_end();
594         FI != FE;
595         FI++) {
596      const clang::FieldDecl *FD = *FI;
597      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
598      if (CountRSObjectTypes(FT)) {
599        slangAssert(false && "can't have unions with RS object types!");
600        return 0;
601      }
602    }
603  }
604
605  if (!T->isStructureType()) {
606    return 0;
607  }
608
609  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
610  RD = RD->getDefinition();
611  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
612         FE = RD->field_end();
613       FI != FE;
614       FI++) {
615    const clang::FieldDecl *FD = *FI;
616    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
617    if (CountRSObjectTypes(FT)) {
618      // Sub-structs should only count once (as should arrays, etc.)
619      RSObjectCount++;
620    }
621  }
622
623  return RSObjectCount;
624}
625
626clang::Stmt *ClearStructRSObject(
627    clang::ASTContext &C,
628    clang::DeclContext *DC,
629    clang::Expr *RefRSStruct,
630    clang::SourceLocation StartLoc,
631    clang::SourceLocation Loc) {
632  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
633
634  slangAssert(!BaseType->isArrayType());
635
636  // Structs should show up as unknown primitive types
637  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
638              DataTypeUnknown);
639
640  unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
641  slangAssert(FieldsToDestroy != 0);
642
643  unsigned StmtCount = 0;
644  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
645  for (unsigned i = 0; i < FieldsToDestroy; i++) {
646    StmtArray[i] = nullptr;
647  }
648
649  // Populate StmtArray by creating a destructor for each RS object field
650  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
651  RD = RD->getDefinition();
652  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
653         FE = RD->field_end();
654       FI != FE;
655       FI++) {
656    // We just look through all field declarations to see if we find a
657    // declaration for an RS object type (or an array of one).
658    bool IsArrayType = false;
659    clang::FieldDecl *FD = *FI;
660    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
661    const clang::Type *OrigType = FT;
662    while (FT && FT->isArrayType()) {
663      FT = FT->getArrayElementTypeNoTypeQual();
664      IsArrayType = true;
665    }
666
667    // Pass a DeclarationNameInfo with a valid DeclName, since name equality
668    // gets asserted during CodeGen.
669    clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
670                                              FD->getLocation());
671
672    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
673      clang::DeclAccessPair FoundDecl =
674          clang::DeclAccessPair::make(FD, clang::AS_none);
675      clang::MemberExpr *RSObjectMember =
676          clang::MemberExpr::Create(C,
677                                    RefRSStruct,
678                                    false,
679                                    clang::SourceLocation(),
680                                    clang::NestedNameSpecifierLoc(),
681                                    clang::SourceLocation(),
682                                    FD,
683                                    FoundDecl,
684                                    FDDeclNameInfo,
685                                    nullptr,
686                                    OrigType->getCanonicalTypeInternal(),
687                                    clang::VK_RValue,
688                                    clang::OK_Ordinary);
689
690      slangAssert(StmtCount < FieldsToDestroy);
691
692      if (IsArrayType) {
693        StmtArray[StmtCount++] = ClearArrayRSObject(C,
694                                                    DC,
695                                                    RSObjectMember,
696                                                    StartLoc,
697                                                    Loc);
698      } else {
699        StmtArray[StmtCount++] = ClearSingleRSObject(C,
700                                                     RSObjectMember,
701                                                     Loc);
702      }
703    } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
704      // In this case, we have a nested struct. We may not end up filling all
705      // of the spaces in StmtArray (sub-structs should handle themselves
706      // with separate compound statements).
707      clang::DeclAccessPair FoundDecl =
708          clang::DeclAccessPair::make(FD, clang::AS_none);
709      clang::MemberExpr *RSObjectMember =
710          clang::MemberExpr::Create(C,
711                                    RefRSStruct,
712                                    false,
713                                    clang::SourceLocation(),
714                                    clang::NestedNameSpecifierLoc(),
715                                    clang::SourceLocation(),
716                                    FD,
717                                    FoundDecl,
718                                    clang::DeclarationNameInfo(),
719                                    nullptr,
720                                    OrigType->getCanonicalTypeInternal(),
721                                    clang::VK_RValue,
722                                    clang::OK_Ordinary);
723
724      if (IsArrayType) {
725        StmtArray[StmtCount++] = ClearArrayRSObject(C,
726                                                    DC,
727                                                    RSObjectMember,
728                                                    StartLoc,
729                                                    Loc);
730      } else {
731        StmtArray[StmtCount++] = ClearStructRSObject(C,
732                                                     DC,
733                                                     RSObjectMember,
734                                                     StartLoc,
735                                                     Loc);
736      }
737    }
738  }
739
740  slangAssert(StmtCount > 0);
741  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
742      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
743
744  delete [] StmtArray;
745
746  return CS;
747}
748
749clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
750                                     clang::Expr *DstExpr,
751                                     clang::Expr *SrcExpr,
752                                     clang::SourceLocation StartLoc,
753                                     clang::SourceLocation Loc) {
754  const clang::Type *T = DstExpr->getType().getTypePtr();
755  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
756  slangAssert((SetObjectFD != nullptr) &&
757              "rsSetObject doesn't cover all RS object types");
758
759  clang::QualType SetObjectFDType = SetObjectFD->getType();
760  clang::QualType SetObjectFDArgType[2];
761  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
762  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
763
764  clang::Expr *RefRSSetObjectFD =
765      clang::DeclRefExpr::Create(C,
766                                 clang::NestedNameSpecifierLoc(),
767                                 clang::SourceLocation(),
768                                 SetObjectFD,
769                                 false,
770                                 Loc,
771                                 SetObjectFDType,
772                                 clang::VK_RValue,
773                                 nullptr);
774
775  clang::Expr *RSSetObjectFP =
776      clang::ImplicitCastExpr::Create(C,
777                                      C.getPointerType(SetObjectFDType),
778                                      clang::CK_FunctionToPointerDecay,
779                                      RefRSSetObjectFD,
780                                      nullptr,
781                                      clang::VK_RValue);
782
783  llvm::SmallVector<clang::Expr*, 2> ArgList;
784  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
785                                                clang::UO_AddrOf,
786                                                SetObjectFDArgType[0],
787                                                clang::VK_RValue,
788                                                clang::OK_Ordinary,
789                                                Loc));
790  ArgList.push_back(SrcExpr);
791
792  clang::CallExpr *RSSetObjectCall =
793      new(C) clang::CallExpr(C,
794                             RSSetObjectFP,
795                             ArgList,
796                             SetObjectFD->getCallResultType(),
797                             clang::VK_RValue,
798                             Loc);
799
800  return RSSetObjectCall;
801}
802
// Forward declaration; the definition is not part of this section of the
// file. Presumably it builds the rsSetObject() calls for a struct-typed
// assignment LHS = RHS — TODO confirm against the definition.
clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
                                     clang::Expr *LHS,
                                     clang::Expr *RHS,
                                     clang::SourceLocation StartLoc,
                                     clang::SourceLocation Loc);
808
809/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
810                                           clang::Expr *DstArr,
811                                           clang::Expr *SrcArr,
812                                           clang::SourceLocation StartLoc,
813                                           clang::SourceLocation Loc) {
814  clang::DeclContext *DC = nullptr;
815  const clang::Type *BaseType = DstArr->getType().getTypePtr();
816  slangAssert(BaseType->isArrayType());
817
818  int NumArrayElements = ArrayDim(BaseType);
819  // Actually extract out the base RS object type for use later
820  BaseType = BaseType->getArrayElementTypeNoTypeQual();
821
822  clang::Stmt *StmtArray[2] = {nullptr};
823  int StmtCtr = 0;
824
825  if (NumArrayElements <= 0) {
826    return nullptr;
827  }
828
829  // Create helper variable for iterating through elements
830  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
831  clang::VarDecl *IIVD =
832      clang::VarDecl::Create(C,
833                             DC,
834                             StartLoc,
835                             Loc,
836                             &II,
837                             C.IntTy,
838                             C.getTrivialTypeSourceInfo(C.IntTy),
839                             clang::SC_None,
840                             clang::SC_None);
841  clang::Decl *IID = (clang::Decl *)IIVD;
842
843  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
844  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
845
846  // Form the actual loop
847  // for (Init; Cond; Inc)
848  //   RSSetObjectCall;
849
850  // Init -> "rsIntIter = 0"
851  clang::DeclRefExpr *RefrsIntIter =
852      clang::DeclRefExpr::Create(C,
853                                 clang::NestedNameSpecifierLoc(),
854                                 IIVD,
855                                 Loc,
856                                 C.IntTy,
857                                 clang::VK_RValue,
858                                 nullptr);
859
860  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
861      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
862
863  clang::BinaryOperator *Init =
864      new(C) clang::BinaryOperator(RefrsIntIter,
865                                   Int0,
866                                   clang::BO_Assign,
867                                   C.IntTy,
868                                   clang::VK_RValue,
869                                   clang::OK_Ordinary,
870                                   Loc);
871
872  // Cond -> "rsIntIter < NumArrayElements"
873  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
874      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
875
876  clang::BinaryOperator *Cond =
877      new(C) clang::BinaryOperator(RefrsIntIter,
878                                   NumArrayElementsExpr,
879                                   clang::BO_LT,
880                                   C.IntTy,
881                                   clang::VK_RValue,
882                                   clang::OK_Ordinary,
883                                   Loc);
884
885  // Inc -> "rsIntIter++"
886  clang::UnaryOperator *Inc =
887      new(C) clang::UnaryOperator(RefrsIntIter,
888                                  clang::UO_PostInc,
889                                  C.IntTy,
890                                  clang::VK_RValue,
891                                  clang::OK_Ordinary,
892                                  Loc);
893
894  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
895  // Loop operates on individual array elements
896
897  clang::Expr *DstArrPtr =
898      clang::ImplicitCastExpr::Create(C,
899          C.getPointerType(BaseType->getCanonicalTypeInternal()),
900          clang::CK_ArrayToPointerDecay,
901          DstArr,
902          nullptr,
903          clang::VK_RValue);
904
905  clang::Expr *DstArrPtrSubscript =
906      new(C) clang::ArraySubscriptExpr(DstArrPtr,
907                                       RefrsIntIter,
908                                       BaseType->getCanonicalTypeInternal(),
909                                       clang::VK_RValue,
910                                       clang::OK_Ordinary,
911                                       Loc);
912
913  clang::Expr *SrcArrPtr =
914      clang::ImplicitCastExpr::Create(C,
915          C.getPointerType(BaseType->getCanonicalTypeInternal()),
916          clang::CK_ArrayToPointerDecay,
917          SrcArr,
918          nullptr,
919          clang::VK_RValue);
920
921  clang::Expr *SrcArrPtrSubscript =
922      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
923                                       RefrsIntIter,
924                                       BaseType->getCanonicalTypeInternal(),
925                                       clang::VK_RValue,
926                                       clang::OK_Ordinary,
927                                       Loc);
928
929  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
930
931  clang::Stmt *RSSetObjectCall = nullptr;
932  if (BaseType->isArrayType()) {
933    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
934                                             SrcArrPtrSubscript,
935                                             StartLoc, Loc);
936  } else if (DT == DataTypeUnknown) {
937    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
938                                              SrcArrPtrSubscript,
939                                              StartLoc, Loc);
940  } else {
941    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
942                                              SrcArrPtrSubscript,
943                                              StartLoc, Loc);
944  }
945
946  clang::ForStmt *DestructorLoop =
947      new(C) clang::ForStmt(C,
948                            Init,
949                            Cond,
950                            nullptr,  // no condVar
951                            Inc,
952                            RSSetObjectCall,
953                            Loc,
954                            Loc,
955                            Loc);
956
957  StmtArray[StmtCtr++] = DestructorLoop;
958  slangAssert(StmtCtr == 2);
959
960  clang::CompoundStmt *CS =
961      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
962
963  return CS;
964} */
965
966clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
967                                     clang::Expr *LHS,
968                                     clang::Expr *RHS,
969                                     clang::SourceLocation StartLoc,
970                                     clang::SourceLocation Loc) {
971  clang::QualType QT = LHS->getType();
972  const clang::Type *T = QT.getTypePtr();
973  slangAssert(T->isStructureType());
974  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
975
976  // Keep an extra slot for the original copy (memcpy)
977  unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
978
979  unsigned StmtCount = 0;
980  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
981  for (unsigned i = 0; i < FieldsToSet; i++) {
982    StmtArray[i] = nullptr;
983  }
984
985  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
986  RD = RD->getDefinition();
987  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
988         FE = RD->field_end();
989       FI != FE;
990       FI++) {
991    bool IsArrayType = false;
992    clang::FieldDecl *FD = *FI;
993    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
994    const clang::Type *OrigType = FT;
995
996    if (!CountRSObjectTypes(FT)) {
997      // Skip to next if we don't have any viable RS object types
998      continue;
999    }
1000
1001    clang::DeclAccessPair FoundDecl =
1002        clang::DeclAccessPair::make(FD, clang::AS_none);
1003    clang::MemberExpr *DstMember =
1004        clang::MemberExpr::Create(C,
1005                                  LHS,
1006                                  false,
1007                                  clang::SourceLocation(),
1008                                  clang::NestedNameSpecifierLoc(),
1009                                  clang::SourceLocation(),
1010                                  FD,
1011                                  FoundDecl,
1012                                  clang::DeclarationNameInfo(),
1013                                  nullptr,
1014                                  OrigType->getCanonicalTypeInternal(),
1015                                  clang::VK_RValue,
1016                                  clang::OK_Ordinary);
1017
1018    clang::MemberExpr *SrcMember =
1019        clang::MemberExpr::Create(C,
1020                                  RHS,
1021                                  false,
1022                                  clang::SourceLocation(),
1023                                  clang::NestedNameSpecifierLoc(),
1024                                  clang::SourceLocation(),
1025                                  FD,
1026                                  FoundDecl,
1027                                  clang::DeclarationNameInfo(),
1028                                  nullptr,
1029                                  OrigType->getCanonicalTypeInternal(),
1030                                  clang::VK_RValue,
1031                                  clang::OK_Ordinary);
1032
1033    if (FT->isArrayType()) {
1034      FT = FT->getArrayElementTypeNoTypeQual();
1035      IsArrayType = true;
1036    }
1037
1038    DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
1039
1040    if (IsArrayType) {
1041      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1042      DiagEngine.Report(
1043        clang::FullSourceLoc(Loc, C.getSourceManager()),
1044        DiagEngine.getCustomDiagID(
1045          clang::DiagnosticsEngine::Error,
1046          "Arrays of RS object types within structures cannot be copied"));
1047      // TODO(srhines): Support setting arrays of RS objects
1048      // StmtArray[StmtCount++] =
1049      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1050    } else if (DT == DataTypeUnknown) {
1051      StmtArray[StmtCount++] =
1052          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1053    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1054      StmtArray[StmtCount++] =
1055          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1056    } else {
1057      slangAssert(false);
1058    }
1059  }
1060
1061  slangAssert(StmtCount < FieldsToSet);
1062
1063  // We still need to actually do the overall struct copy. For simplicity,
1064  // we just do a straight-up assignment (which will still preserve all
1065  // the proper RS object reference counts).
1066  clang::BinaryOperator *CopyStruct =
1067      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1068                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1069                                   false);
1070  StmtArray[StmtCount++] = CopyStruct;
1071
1072  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1073      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1074
1075  delete [] StmtArray;
1076
1077  return CS;
1078}
1079
1080}  // namespace
1081
1082void RSObjectRefCount::Scope::InsertStmt(const clang::ASTContext &C,
1083                                         clang::Stmt *NewStmt) {
1084  std::vector<clang::Stmt*> newBody;
1085  for (clang::Stmt* S1 : mCS->body()) {
1086    if (S1 == mCurrent) {
1087      newBody.push_back(NewStmt);
1088    }
1089    newBody.push_back(S1);
1090  }
1091  mCS->setStmts(C, newBody.data(), newBody.size());
1092}
1093
1094void RSObjectRefCount::Scope::ReplaceStmt(const clang::ASTContext &C,
1095                                          clang::Stmt *NewStmt) {
1096  std::vector<clang::Stmt*> newBody;
1097  for (clang::Stmt* S1 : mCS->body()) {
1098    if (S1 == mCurrent) {
1099      newBody.push_back(NewStmt);
1100    } else {
1101      newBody.push_back(S1);
1102    }
1103  }
1104  mCS->setStmts(C, newBody.data(), newBody.size());
1105}
1106
1107void RSObjectRefCount::Scope::ReplaceExpr(const clang::ASTContext& C,
1108                                          clang::Expr* OldExpr,
1109                                          clang::Expr* NewExpr) {
1110  RSASTReplace R(C);
1111  R.ReplaceStmt(mCurrent, OldExpr, NewExpr);
1112}
1113
1114void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1115    clang::BinaryOperator *AS) {
1116
1117  clang::QualType QT = AS->getType();
1118
1119  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1120      DataTypeRSAllocation)->getASTContext();
1121
1122  clang::SourceLocation Loc = AS->getExprLoc();
1123  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1124  clang::Stmt *UpdatedStmt = nullptr;
1125
1126  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1127    // By definition, this is a struct assignment if we get here
1128    UpdatedStmt =
1129        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1130  } else {
1131    UpdatedStmt =
1132        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1133  }
1134
1135  RSASTReplace R(C);
1136  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1137}
1138
1139void RSObjectRefCount::Scope::AppendRSObjectInit(
1140    clang::VarDecl *VD,
1141    clang::DeclStmt *DS,
1142    DataType DT,
1143    clang::Expr *InitExpr) {
1144  slangAssert(VD);
1145
1146  if (!InitExpr) {
1147    return;
1148  }
1149
1150  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1151      DataTypeRSAllocation)->getASTContext();
1152  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1153      DataTypeRSAllocation)->getLocation();
1154  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1155      DataTypeRSAllocation)->getInnerLocStart();
1156
1157  if (DT == DataTypeIsStruct) {
1158    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1159    clang::DeclRefExpr *RefRSVar =
1160        clang::DeclRefExpr::Create(C,
1161                                   clang::NestedNameSpecifierLoc(),
1162                                   clang::SourceLocation(),
1163                                   VD,
1164                                   false,
1165                                   Loc,
1166                                   T->getCanonicalTypeInternal(),
1167                                   clang::VK_RValue,
1168                                   nullptr);
1169
1170    clang::Stmt *RSSetObjectOps =
1171        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1172
1173    std::list<clang::Stmt*> StmtList;
1174    StmtList.push_back(RSSetObjectOps);
1175    AppendAfterStmt(C, mCS, DS, StmtList);
1176    return;
1177  }
1178
1179  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1180  slangAssert((SetObjectFD != nullptr) &&
1181              "rsSetObject doesn't cover all RS object types");
1182
1183  clang::QualType SetObjectFDType = SetObjectFD->getType();
1184  clang::QualType SetObjectFDArgType[2];
1185  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1186  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1187
1188  clang::Expr *RefRSSetObjectFD =
1189      clang::DeclRefExpr::Create(C,
1190                                 clang::NestedNameSpecifierLoc(),
1191                                 clang::SourceLocation(),
1192                                 SetObjectFD,
1193                                 false,
1194                                 Loc,
1195                                 SetObjectFDType,
1196                                 clang::VK_RValue,
1197                                 nullptr);
1198
1199  clang::Expr *RSSetObjectFP =
1200      clang::ImplicitCastExpr::Create(C,
1201                                      C.getPointerType(SetObjectFDType),
1202                                      clang::CK_FunctionToPointerDecay,
1203                                      RefRSSetObjectFD,
1204                                      nullptr,
1205                                      clang::VK_RValue);
1206
1207  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1208  clang::DeclRefExpr *RefRSVar =
1209      clang::DeclRefExpr::Create(C,
1210                                 clang::NestedNameSpecifierLoc(),
1211                                 clang::SourceLocation(),
1212                                 VD,
1213                                 false,
1214                                 Loc,
1215                                 T->getCanonicalTypeInternal(),
1216                                 clang::VK_RValue,
1217                                 nullptr);
1218
1219  llvm::SmallVector<clang::Expr*, 2> ArgList;
1220  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1221                                                clang::UO_AddrOf,
1222                                                SetObjectFDArgType[0],
1223                                                clang::VK_RValue,
1224                                                clang::OK_Ordinary,
1225                                                Loc));
1226  ArgList.push_back(InitExpr);
1227
1228  clang::CallExpr *RSSetObjectCall =
1229      new(C) clang::CallExpr(C,
1230                             RSSetObjectFP,
1231                             ArgList,
1232                             SetObjectFD->getCallResultType(),
1233                             clang::VK_RValue,
1234                             Loc);
1235
1236  std::list<clang::Stmt*> StmtList;
1237  StmtList.push_back(RSSetObjectCall);
1238  AppendAfterStmt(C, mCS, DS, StmtList);
1239}
1240
1241void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1242  if (mRSO.empty()) {
1243    return;
1244  }
1245
1246  clang::DeclContext* DC = mRSO.front()->getDeclContext();
1247  clang::ASTContext& C = DC->getParentASTContext();
1248  clang::SourceManager& SM = C.getSourceManager();
1249
1250  const auto& OccursBefore = [&SM] (clang::SourceLocation L1, clang::SourceLocation L2)->bool {
1251    return SM.isBeforeInTranslationUnit(L1, L2);
1252  };
1253  typedef std::map<clang::SourceLocation, clang::Stmt*, decltype(OccursBefore)> DMap;
1254
1255  DMap dtors(OccursBefore);
1256
1257  // Create rsClearObject calls. Note the DMap entries are sorted by the SourceLocation.
1258  for (clang::VarDecl* VD : mRSO) {
1259    clang::SourceLocation Loc = VD->getSourceRange().getBegin();
1260    clang::Stmt* RSClearObjectCall = ClearRSObject(VD, DC);
1261    dtors.insert(std::make_pair(Loc, RSClearObjectCall));
1262  }
1263
1264  DestructorVisitor Visitor;
1265  Visitor.Visit(mCS);
1266
1267  // Replace each exiting statement with a block that contains the original statement
1268  // and added rsClearObject() calls before it.
1269  for (clang::Stmt* S : Visitor.getExitingStmts()) {
1270
1271    const clang::SourceLocation currentLoc = S->getLocStart();
1272
1273    DMap::iterator firstDtorIter = dtors.begin();
1274    DMap::iterator currentDtorIter = firstDtorIter;
1275    DMap::iterator lastDtorIter = dtors.end();
1276
1277    while (currentDtorIter != lastDtorIter &&
1278           OccursBefore(currentDtorIter->first, currentLoc)) {
1279      currentDtorIter++;
1280    }
1281
1282    if (currentDtorIter == firstDtorIter) {
1283      continue;
1284    }
1285
1286    std::list<clang::Stmt*> Stmts;
1287
1288    // Insert rsClearObject() calls for all rsObjects declared before the current statement
1289    for(DMap::iterator it = firstDtorIter; it != currentDtorIter; it++) {
1290      Stmts.push_back(it->second);
1291    }
1292    Stmts.push_back(S);
1293
1294    RSASTReplace R(C);
1295    clang::CompoundStmt* CS = BuildCompoundStmt(C, Stmts, S->getLocEnd());
1296    R.ReplaceStmt(mCS, S, CS);
1297  }
1298
1299  std::list<clang::Stmt*> Stmts;
1300  for(auto LocCallPair : dtors) {
1301    Stmts.push_back(LocCallPair.second);
1302  }
1303  AppendAfterStmt(C, mCS, nullptr, Stmts);
1304}
1305
1306clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1307    clang::VarDecl *VD,
1308    clang::DeclContext *DC) {
1309  slangAssert(VD);
1310  clang::ASTContext &C = VD->getASTContext();
1311  clang::SourceLocation Loc = VD->getLocation();
1312  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1313  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1314
1315  // Reference expr to target RS object variable
1316  clang::DeclRefExpr *RefRSVar =
1317      clang::DeclRefExpr::Create(C,
1318                                 clang::NestedNameSpecifierLoc(),
1319                                 clang::SourceLocation(),
1320                                 VD,
1321                                 false,
1322                                 Loc,
1323                                 T->getCanonicalTypeInternal(),
1324                                 clang::VK_RValue,
1325                                 nullptr);
1326
1327  if (T->isArrayType()) {
1328    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1329  }
1330
1331  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1332
1333  if (DT == DataTypeUnknown ||
1334      DT == DataTypeIsStruct) {
1335    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1336  }
1337
1338  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1339              "Should be RS object");
1340
1341  return ClearSingleRSObject(C, RefRSVar, Loc);
1342}
1343
1344bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1345                                          DataType *DT,
1346                                          clang::Expr **InitExpr) {
1347  slangAssert(VD && DT && InitExpr);
1348  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1349
1350  // Loop through array types to get to base type
1351  while (T && T->isArrayType()) {
1352    T = T->getArrayElementTypeNoTypeQual();
1353  }
1354
1355  bool DataTypeIsStructWithRSObject = false;
1356  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1357
1358  if (*DT == DataTypeUnknown) {
1359    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1360      *DT = DataTypeIsStruct;
1361      DataTypeIsStructWithRSObject = true;
1362    } else {
1363      return false;
1364    }
1365  }
1366
1367  bool DataTypeIsRSObject = false;
1368  if (DataTypeIsStructWithRSObject) {
1369    DataTypeIsRSObject = true;
1370  } else {
1371    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1372  }
1373  *InitExpr = VD->getInit();
1374
1375  if (!DataTypeIsRSObject && *InitExpr) {
1376    // If we already have an initializer for a matrix type, we are done.
1377    return DataTypeIsRSObject;
1378  }
1379
1380  clang::Expr *ZeroInitializer =
1381      CreateZeroInitializerForRSSpecificType(*DT,
1382                                             VD->getASTContext(),
1383                                             VD->getLocation());
1384
1385  if (ZeroInitializer) {
1386    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1387    VD->setInit(ZeroInitializer);
1388  }
1389
1390  return DataTypeIsRSObject;
1391}
1392
// Builds a zero initializer expression for a variable of RS-specific type DT.
//
// - Struct and RS object handle types: `{ (nullptr_t)0 }`.
// - rs_matrixNxN types: `{ { 0.0f, ... } }` containing N*N float zeros; the
//   inner list is typed as float[N*N].
// - Any other DataType is not an RS object type: this asserts and then
//   returns nullptr.
//
// The type of the returned outer InitListExpr is left unset here
// (InitializeRSObject sets it with setType()).
clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
    DataType DT,
    clang::ASTContext &C,
    const clang::SourceLocation &Loc) {
  clang::Expr *Res = nullptr;
  switch (DT) {
    case DataTypeIsStruct:
    case DataTypeRSElement:
    case DataTypeRSType:
    case DataTypeRSAllocation:
    case DataTypeRSSampler:
    case DataTypeRSScript:
    case DataTypeRSMesh:
    case DataTypeRSPath:
    case DataTypeRSProgramFragment:
    case DataTypeRSProgramVertex:
    case DataTypeRSProgramRaster:
    case DataTypeRSProgramStore:
    case DataTypeRSFont: {
      // Builds the AST for `{ (nullptr_t)0 }`:
      // (InitListExpr
      //   (ImplicitCastExpr 'nullptr_t'
      //     (IntegerLiteral 0)))
      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
      clang::Expr *CastToNull =
          clang::ImplicitCastExpr::Create(C,
                                          C.NullPtrTy,
                                          clang::CK_IntegralToPointer,
                                          Int0,
                                          nullptr,
                                          clang::VK_RValue);

      llvm::SmallVector<clang::Expr*, 1>InitList;
      InitList.push_back(CastToNull);

      Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
      break;
    }
    case DataTypeRSMatrix2x2:
    case DataTypeRSMatrix3x3:
    case DataTypeRSMatrix4x4: {
      // RS matrices are not really RS objects: they hold their data inline.
      // Example for rs_matrix2x2:
      // (InitListExpr rs_matrix2x2
      //   (InitListExpr float[4]
      //     (FloatingLiteral 0)
      //     (FloatingLiteral 0)
      //     (FloatingLiteral 0)
      //     (FloatingLiteral 0)))
      clang::QualType FloatTy = C.FloatTy;
      // Constructor sets value to 0.0f by default
      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
      // A single FloatingLiteral node is reused for every element below.
      clang::FloatingLiteral *Float0Val =
          clang::FloatingLiteral::Create(C,
                                         Val,
                                         /* isExact = */true,
                                         FloatTy,
                                         Loc);

      // Matrix dimension N for rs_matrixNxN.
      unsigned N = 0;
      if (DT == DataTypeRSMatrix2x2) {
        N = 2;
      } else if (DT == DataTypeRSMatrix3x3) {
        N = 3;
      } else if (DT == DataTypeRSMatrix4x4) {
        N = 4;
      }
      unsigned N_2 = N * N;

      // Inline capacity of 16 covers the largest (4x4) matrix.
      llvm::SmallVector<clang::Expr*, 16> InitVals;
      for (unsigned i = 0; i < N_2; i++)
        InitVals.push_back(Float0Val);
      clang::Expr *InitExpr =
          new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
      // The inner list is explicitly typed as float[N*N].
      InitExpr->setType(C.getConstantArrayType(FloatTy,
                                               llvm::APInt(32, N_2),
                                               clang::ArrayType::Normal,
                                               /* EltTypeQuals = */0));
      llvm::SmallVector<clang::Expr*, 1> InitExprVec;
      InitExprVec.push_back(InitExpr);

      Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
      break;
    }
    case DataTypeUnknown:
    case DataTypeFloat16:
    case DataTypeFloat32:
    case DataTypeFloat64:
    case DataTypeSigned8:
    case DataTypeSigned16:
    case DataTypeSigned32:
    case DataTypeSigned64:
    case DataTypeUnsigned8:
    case DataTypeUnsigned16:
    case DataTypeUnsigned32:
    case DataTypeUnsigned64:
    case DataTypeBoolean:
    case DataTypeUnsigned565:
    case DataTypeUnsigned5551:
    case DataTypeUnsigned4444:
    case DataTypeMax: {
      // Not an RS object type: assert; Res stays nullptr in release builds.
      slangAssert(false && "Not RS object type!");
    }
    // No default case will enable compiler detecting the missing cases
  }

  return Res;
}
1500
1501void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1502  VisitStmt(DS);
1503  getCurrentScope()->setCurrentStmt(DS);
1504  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1505       I != E;
1506       I++) {
1507    clang::Decl *D = *I;
1508    if (D->getKind() == clang::Decl::Var) {
1509      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1510      DataType DT = DataTypeUnknown;
1511      clang::Expr *InitExpr = nullptr;
1512      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1513        // We need to zero-init all RS object types (including matrices), ...
1514        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1515        // ... but, only add to the list of RS objects if we have some
1516        // non-matrix RS object fields.
1517        if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1518          getCurrentScope()->addRSObject(VD);
1519        }
1520      }
1521    }
1522  }
1523}
1524
1525void RSObjectRefCount::VisitCallExpr(clang::CallExpr* CE) {
1526  clang::QualType RetTy;
1527  const clang::FunctionDecl* FD = CE->getDirectCallee();
1528
1529  if (FD) {
1530    // Direct calls
1531
1532    RetTy = FD->getReturnType();
1533  } else {
1534    // Indirect calls through function pointers
1535
1536    const clang::Expr* Callee = CE->getCallee();
1537    const clang::Type* CalleeType = Callee->getType().getTypePtr();
1538    const clang::PointerType* PtrType = CalleeType->getAs<clang::PointerType>();
1539
1540    if (!PtrType) {
1541      return;
1542    }
1543
1544    const clang::Type* PointeeType = PtrType->getPointeeType().getTypePtr();
1545    const clang::FunctionType* FuncType = PointeeType->getAs<clang::FunctionType>();
1546
1547    if (!FuncType) {
1548      return;
1549    }
1550
1551    RetTy = FuncType->getReturnType();
1552  }
1553
1554  if (CountRSObjectTypes(RetTy.getTypePtr())==0) {
1555    return;
1556  }
1557
1558  clang::SourceLocation Loc = CE->getSourceRange().getBegin();
1559  std::stringstream ss;
1560  ss << ".rs.tmp" << getNextID();
1561  llvm::StringRef VarName(ss.str());
1562
1563  clang::VarDecl* TempVarDecl = clang::VarDecl::Create(
1564      mCtx,                                  // AST context
1565      GetDeclContext(),                      // Decl context
1566      Loc,                                   // Start location
1567      Loc,                                   // Id location
1568      &mCtx.Idents.get(VarName),             // Id
1569      RetTy,                                 // Type
1570      mCtx.getTrivialTypeSourceInfo(RetTy),  // Type info
1571      clang::SC_None                         // Storage class
1572  );
1573  TempVarDecl->setInit(CE);
1574  clang::Decl* Decls[] = { TempVarDecl };
1575  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1576      mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
1577  clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
1578
1579  getCurrentScope()->InsertStmt(mCtx, DS);
1580
1581  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1582      mCtx,                                  // AST context
1583      clang::NestedNameSpecifierLoc(),       // QualifierLoc
1584      Loc,                                   // TemplateKWLoc
1585      TempVarDecl,
1586      false,                                 // RefersToEnclosingVariableOrCapture
1587      Loc,                                   // NameLoc
1588      RetTy,
1589      clang::VK_LValue
1590  );
1591  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1592      mCtx,
1593      RetTy,
1594      clang::CK_LValueToRValue,
1595      DRE,
1596      nullptr,
1597      clang::VK_RValue
1598  );
1599
1600  getCurrentScope()->ReplaceExpr(mCtx, CE, CastExpr);
1601
1602  // Register TempVarDecl for destruction call (rsClearObj).
1603  getCurrentScope()->addRSObject(TempVarDecl);
1604}
1605
1606void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1607  if (!emptyScope()) {
1608    getCurrentScope()->setCurrentStmt(CS);
1609  }
1610
1611  if (!CS->body_empty()) {
1612    // Push a new scope
1613    Scope *S = new Scope(CS);
1614    mScopeStack.push_back(S);
1615
1616    VisitStmt(CS);
1617
1618    // Destroy the scope
1619    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1620    S->InsertLocalVarDestructors();
1621    mScopeStack.pop_back();
1622    delete S;
1623  }
1624}
1625
1626void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1627  getCurrentScope()->setCurrentStmt(AS);
1628  clang::QualType QT = AS->getType();
1629
1630  if (CountRSObjectTypes(QT.getTypePtr())) {
1631    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1632  }
1633}
1634
1635void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
1636  getCurrentScope()->setCurrentStmt(RS);
1637
1638  // If there is no local rsObject declared so far, no need to transform the
1639  // return statement.
1640
1641  bool RSObjDeclared = false;
1642
1643  for (const Scope* S : mScopeStack) {
1644    if (S->hasRSObject()) {
1645      RSObjDeclared = true;
1646      break;
1647    }
1648  }
1649
1650  if (!RSObjDeclared) {
1651    return;
1652  }
1653
1654  // If the return statement does not return anything, or if it does not return
1655  // a rsObject, no need to transform it.
1656
1657  clang::Expr* RetVal = RS->getRetValue();
1658  if (!RetVal || CountRSObjectTypes(RetVal->getType().getTypePtr()) == 0) {
1659    return;
1660  }
1661
1662  // Transform the return statement so that it does not potentially return or
1663  // reference a rsObject that has been cleared.
1664
1665  clang::CompoundStmt* NewRS;
1666  NewRS = CreateRetStmtWithTempVar(mCtx, GetDeclContext(), RS, getNextID());
1667
1668  getCurrentScope()->ReplaceStmt(mCtx, NewRS);
1669}
1670
1671void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1672  getCurrentScope()->setCurrentStmt(S);
1673  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1674       I != E;
1675       I++) {
1676    if (clang::Stmt *Child = *I) {
1677      Visit(Child);
1678    }
1679  }
1680}
1681
1682// This function walks the list of global variables and (potentially) creates
1683// a single global static destructor function that properly decrements
1684// reference counts on the contained RS object types.
1685clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1686  Init();
1687
1688  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1689  clang::SourceLocation loc;
1690
1691  llvm::StringRef SR(".rs.dtor");
1692  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1693  clang::DeclarationName N(&II);
1694  clang::FunctionProtoType::ExtProtoInfo EPI;
1695  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1696      llvm::ArrayRef<clang::QualType>(), EPI);
1697  clang::FunctionDecl *FD = nullptr;
1698
1699  // Generate rsClearObject() call chains for every global variable
1700  // (whether static or extern).
1701  std::list<clang::Stmt *> StmtList;
1702  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1703          E = DC->decls_end(); I != E; I++) {
1704    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1705    if (VD) {
1706      if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1707        if (!FD) {
1708          // Only create FD if we are going to use it.
1709          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1710                                           clang::SC_None);
1711        }
1712        // Mark VD as used.  It might be unused, except for the destructor.
1713        // 'markUsed' has side-effects that are caused only if VD is not already
1714        // used.  Hence no need for an extra check here.
1715        VD->markUsed(mCtx);
1716        // Make sure to create any helpers within the function's DeclContext,
1717        // not the one associated with the global translation unit.
1718        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1719        StmtList.push_back(RSClearObjectCall);
1720      }
1721    }
1722  }
1723
1724  // Nothing needs to be destroyed, so don't emit a dtor.
1725  if (StmtList.empty()) {
1726    return nullptr;
1727  }
1728
1729  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1730
1731  FD->setBody(CS);
1732
1733  return FD;
1734}
1735
1736bool HasRSObjectType(const clang::Type *T) {
1737  return CountRSObjectTypes(T) != 0;
1738}
1739
1740}  // namespace slang
1741