slang_rs_object_ref_count.cpp revision cc4d93488b344cbdb0d65c3af076f02dbf2ceb00
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <stack>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"

#include "slang_assert.h"
#include "slang.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32
33namespace slang {
34
35/* Even though those two arrays are of size DataTypeMax, only entries that
36 * correspond to object types will be set.
37 */
38clang::FunctionDecl *
39RSObjectRefCount::RSSetObjectFD[DataTypeMax];
40clang::FunctionDecl *
41RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42
43void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44  for (unsigned i = 0; i < DataTypeMax; i++) {
45    RSSetObjectFD[i] = nullptr;
46    RSClearObjectFD[i] = nullptr;
47  }
48
49  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52          E = TUDecl->decls_end(); I != E; I++) {
53    if ((I->getKind() >= clang::Decl::firstFunction) &&
54        (I->getKind() <= clang::Decl::lastFunction)) {
55      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57      // points to RSSetObjectFD or RSClearObjectFD
58      clang::FunctionDecl **RSObjectFD;
59
60      if (FD->getName() == "rsSetObject") {
61        slangAssert((FD->getNumParams() == 2) &&
62                    "Invalid rsSetObject function prototype (# params)");
63        RSObjectFD = RSSetObjectFD;
64      } else if (FD->getName() == "rsClearObject") {
65        slangAssert((FD->getNumParams() == 1) &&
66                    "Invalid rsClearObject function prototype (# params)");
67        RSObjectFD = RSClearObjectFD;
68      } else {
69        continue;
70      }
71
72      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73      clang::QualType PVT = PVD->getOriginalType();
74      // The first parameter must be a pointer like rs_allocation*
75      slangAssert(PVT->isPointerType() &&
76          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78      // The rs object type passed to the FD
79      clang::QualType RST = PVT->getPointeeType();
80      DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      if (DT >= 0 && DT < DataTypeMax) {
85          RSObjectFD[DT] = FD;
86      } else {
87          slangAssert(false && "incorrect type");
88      }
89    }
90  }
91}
92
93namespace {
94
95// This function constructs a new CompoundStmt from the input StmtList.
96static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
97      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
98  unsigned NewStmtCount = StmtList.size();
99  unsigned CompoundStmtCount = 0;
100
101  clang::Stmt **CompoundStmtList;
102  CompoundStmtList = new clang::Stmt*[NewStmtCount];
103
104  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
105  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
106  for ( ; I != E; I++) {
107    CompoundStmtList[CompoundStmtCount++] = *I;
108  }
109  slangAssert(CompoundStmtCount == NewStmtCount);
110
111  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
112      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
113
114  delete [] CompoundStmtList;
115
116  return CS;
117}
118
119static void AppendAfterStmt(clang::ASTContext &C,
120                            clang::CompoundStmt *CS,
121                            clang::Stmt *S,
122                            std::list<clang::Stmt*> &StmtList) {
123  slangAssert(CS);
124  clang::CompoundStmt::body_iterator bI = CS->body_begin();
125  clang::CompoundStmt::body_iterator bE = CS->body_end();
126  clang::Stmt **UpdatedStmtList =
127      new clang::Stmt*[CS->size() + StmtList.size()];
128
129  unsigned UpdatedStmtCount = 0;
130  unsigned Once = 0;
131  for ( ; bI != bE; bI++) {
132    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
133      // If we come across a return here, we don't have anything we can
134      // reasonably replace. We should have already inserted our destructor
135      // code in the proper spot, so we just clean up and return.
136      delete [] UpdatedStmtList;
137
138      return;
139    }
140
141    UpdatedStmtList[UpdatedStmtCount++] = *bI;
142
143    if ((*bI == S) && !Once) {
144      Once++;
145      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
146      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
147      for ( ; I != E; I++) {
148        UpdatedStmtList[UpdatedStmtCount++] = *I;
149      }
150    }
151  }
152  slangAssert(Once <= 1);
153
154  // When S is nullptr, we are appending to the end of the CompoundStmt.
155  if (!S) {
156    slangAssert(Once == 0);
157    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
158    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
159    for ( ; I != E; I++) {
160      UpdatedStmtList[UpdatedStmtCount++] = *I;
161    }
162  }
163
164  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
165
166  delete [] UpdatedStmtList;
167}
168
169// This class visits a compound statement and inserts DtorStmt
170// in proper locations. This includes inserting it before any
171// return statement in any sub-block, at the end of the logical enclosing
172// scope (compound statement), and/or before any break/continue statement that
173// would resume outside the declared scope. We will not handle the case for
174// goto statements that leave a local scope.
175//
176// To accomplish these goals, it collects a list of sub-Stmt's that
177// correspond to scope exit points. It then uses an RSASTReplace visitor to
178// transform the AST, inserting appropriate destructors before each of those
179// sub-Stmt's (and also before the exit of the outermost containing Stmt for
180// the scope).
181class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182 private:
183  clang::ASTContext &mCtx;
184
185  // The loop depth of the currently visited node.
186  int mLoopDepth;
187
188  // The switch statement depth of the currently visited node.
189  // Note that this is tracked separately from the loop depth because
190  // SwitchStmt-contained ContinueStmt's should have destructors for the
191  // corresponding loop scope.
192  int mSwitchDepth;
193
194  // The outermost statement block that we are currently visiting.
195  // This should always be a CompoundStmt.
196  clang::Stmt *mOuterStmt;
197
198  // The destructor to execute for this scope/variable.
199  clang::Stmt* mDtorStmt;
200
201  // The stack of statements which should be replaced by a compound statement
202  // containing the new destructor call followed by the original Stmt.
203  std::stack<clang::Stmt*> mReplaceStmtStack;
204
205  // The source location for the variable declaration that we are trying to
206  // insert destructors for. Note that InsertDestructors() will not generate
207  // destructor calls for source locations that occur lexically before this
208  // location.
209  clang::SourceLocation mVarLoc;
210
211 public:
212  DestructorVisitor(clang::ASTContext &C,
213                    clang::Stmt* OuterStmt,
214                    clang::Stmt* DtorStmt,
215                    clang::SourceLocation VarLoc);
216
217  // This code walks the collected list of Stmts to replace and actually does
218  // the replacement. It also finishes up by appending the destructor to the
219  // current outermost CompoundStmt.
220  void InsertDestructors() {
221    clang::Stmt *S = nullptr;
222    clang::SourceManager &SM = mCtx.getSourceManager();
223    std::list<clang::Stmt *> StmtList;
224    StmtList.push_back(mDtorStmt);
225
226    while (!mReplaceStmtStack.empty()) {
227      S = mReplaceStmtStack.top();
228      mReplaceStmtStack.pop();
229
230      // Skip all source locations that occur before the variable's
231      // declaration, since it won't have been initialized yet.
232      if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
233        continue;
234      }
235
236      StmtList.push_back(S);
237      clang::CompoundStmt *CS =
238          BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
239      StmtList.pop_back();
240
241      RSASTReplace R(mCtx);
242      R.ReplaceStmt(mOuterStmt, S, CS);
243    }
244    clang::CompoundStmt *CS =
245      llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
246    slangAssert(CS);
247    AppendAfterStmt(mCtx, CS, nullptr, StmtList);
248  }
249
250  void VisitStmt(clang::Stmt *S);
251  void VisitCompoundStmt(clang::CompoundStmt *CS);
252
253  void VisitBreakStmt(clang::BreakStmt *BS);
254  void VisitCaseStmt(clang::CaseStmt *CS);
255  void VisitContinueStmt(clang::ContinueStmt *CS);
256  void VisitDefaultStmt(clang::DefaultStmt *DS);
257  void VisitDoStmt(clang::DoStmt *DS);
258  void VisitForStmt(clang::ForStmt *FS);
259  void VisitIfStmt(clang::IfStmt *IS);
260  void VisitReturnStmt(clang::ReturnStmt *RS);
261  void VisitSwitchCase(clang::SwitchCase *SC);
262  void VisitSwitchStmt(clang::SwitchStmt *SS);
263  void VisitWhileStmt(clang::WhileStmt *WS);
264};
265
266DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
267                         clang::Stmt *OuterStmt,
268                         clang::Stmt *DtorStmt,
269                         clang::SourceLocation VarLoc)
270  : mCtx(C),
271    mLoopDepth(0),
272    mSwitchDepth(0),
273    mOuterStmt(OuterStmt),
274    mDtorStmt(DtorStmt),
275    mVarLoc(VarLoc) {
276}
277
278void DestructorVisitor::VisitStmt(clang::Stmt *S) {
279  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
280       I != E;
281       I++) {
282    if (clang::Stmt *Child = *I) {
283      Visit(Child);
284    }
285  }
286}
287
288void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
289  VisitStmt(CS);
290}
291
292void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
293  VisitStmt(BS);
294  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
295    mReplaceStmtStack.push(BS);
296  }
297}
298
299void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
300  VisitStmt(CS);
301}
302
303void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
304  VisitStmt(CS);
305  if (mLoopDepth == 0) {
306    // Switch statements can have nested continues.
307    mReplaceStmtStack.push(CS);
308  }
309}
310
311void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
312  VisitStmt(DS);
313}
314
315void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
316  mLoopDepth++;
317  VisitStmt(DS);
318  mLoopDepth--;
319}
320
321void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
322  mLoopDepth++;
323  VisitStmt(FS);
324  mLoopDepth--;
325}
326
327void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
328  VisitStmt(IS);
329}
330
331void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
332  mReplaceStmtStack.push(RS);
333}
334
335void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
336  slangAssert(false && "Both case and default have specialized handlers");
337  VisitStmt(SC);
338}
339
340void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
341  mSwitchDepth++;
342  VisitStmt(SS);
343  mSwitchDepth--;
344}
345
346void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
347  mLoopDepth++;
348  VisitStmt(WS);
349  mLoopDepth--;
350}
351
352clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
353                                 clang::Expr *RefRSVar,
354                                 clang::SourceLocation Loc) {
355  slangAssert(RefRSVar);
356  const clang::Type *T = RefRSVar->getType().getTypePtr();
357  slangAssert(!T->isArrayType() &&
358              "Should not be destroying arrays with this function");
359
360  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
361  slangAssert((ClearObjectFD != nullptr) &&
362              "rsClearObject doesn't cover all RS object types");
363
364  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
365  clang::QualType ClearObjectFDArgType =
366      ClearObjectFD->getParamDecl(0)->getOriginalType();
367
368  // Example destructor for "rs_font localFont;"
369  //
370  // (CallExpr 'void'
371  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
372  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
373  //   (UnaryOperator 'rs_font *' prefix '&'
374  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
375
376  // Get address of targeted RS object
377  clang::Expr *AddrRefRSVar =
378      new(C) clang::UnaryOperator(RefRSVar,
379                                  clang::UO_AddrOf,
380                                  ClearObjectFDArgType,
381                                  clang::VK_RValue,
382                                  clang::OK_Ordinary,
383                                  Loc);
384
385  clang::Expr *RefRSClearObjectFD =
386      clang::DeclRefExpr::Create(C,
387                                 clang::NestedNameSpecifierLoc(),
388                                 clang::SourceLocation(),
389                                 ClearObjectFD,
390                                 false,
391                                 ClearObjectFD->getLocation(),
392                                 ClearObjectFDType,
393                                 clang::VK_RValue,
394                                 nullptr);
395
396  clang::Expr *RSClearObjectFP =
397      clang::ImplicitCastExpr::Create(C,
398                                      C.getPointerType(ClearObjectFDType),
399                                      clang::CK_FunctionToPointerDecay,
400                                      RefRSClearObjectFD,
401                                      nullptr,
402                                      clang::VK_RValue);
403
404  llvm::SmallVector<clang::Expr*, 1> ArgList;
405  ArgList.push_back(AddrRefRSVar);
406
407  clang::CallExpr *RSClearObjectCall =
408      new(C) clang::CallExpr(C,
409                             RSClearObjectFP,
410                             ArgList,
411                             ClearObjectFD->getCallResultType(),
412                             clang::VK_RValue,
413                             Loc);
414
415  return RSClearObjectCall;
416}
417
418static int ArrayDim(const clang::Type *T) {
419  if (!T || !T->isArrayType()) {
420    return 0;
421  }
422
423  const clang::ConstantArrayType *CAT =
424    static_cast<const clang::ConstantArrayType *>(T);
425  return static_cast<int>(CAT->getSize().getSExtValue());
426}
427
428static clang::Stmt *ClearStructRSObject(
429    clang::ASTContext &C,
430    clang::DeclContext *DC,
431    clang::Expr *RefRSStruct,
432    clang::SourceLocation StartLoc,
433    clang::SourceLocation Loc);
434
435static clang::Stmt *ClearArrayRSObject(
436    clang::ASTContext &C,
437    clang::DeclContext *DC,
438    clang::Expr *RefRSArr,
439    clang::SourceLocation StartLoc,
440    clang::SourceLocation Loc) {
441  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
442  slangAssert(BaseType->isArrayType());
443
444  int NumArrayElements = ArrayDim(BaseType);
445  // Actually extract out the base RS object type for use later
446  BaseType = BaseType->getArrayElementTypeNoTypeQual();
447
448  clang::Stmt *StmtArray[2] = {nullptr};
449  int StmtCtr = 0;
450
451  if (NumArrayElements <= 0) {
452    return nullptr;
453  }
454
455  // Example destructor loop for "rs_font fontArr[10];"
456  //
457  // (CompoundStmt
458  //   (DeclStmt "int rsIntIter")
459  //   (ForStmt
460  //     (BinaryOperator 'int' '='
461  //       (DeclRefExpr 'int' Var='rsIntIter')
462  //       (IntegerLiteral 'int' 0))
463  //     (BinaryOperator 'int' '<'
464  //       (DeclRefExpr 'int' Var='rsIntIter')
465  //       (IntegerLiteral 'int' 10)
466  //     nullptr << CondVar >>
467  //     (UnaryOperator 'int' postfix '++'
468  //       (DeclRefExpr 'int' Var='rsIntIter'))
469  //     (CallExpr 'void'
470  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
471  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
472  //       (UnaryOperator 'rs_font *' prefix '&'
473  //         (ArraySubscriptExpr 'rs_font':'rs_font'
474  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
475  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
476  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
477
478  // Create helper variable for iterating through elements
479  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
480  clang::VarDecl *IIVD =
481      clang::VarDecl::Create(C,
482                             DC,
483                             StartLoc,
484                             Loc,
485                             &II,
486                             C.IntTy,
487                             C.getTrivialTypeSourceInfo(C.IntTy),
488                             clang::SC_None);
489  // Mark "rsIntIter" as used
490  IIVD->markUsed(C);
491  clang::Decl *IID = (clang::Decl *)IIVD;
492
493  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
494  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
495
496  // Form the actual destructor loop
497  // for (Init; Cond; Inc)
498  //   RSClearObjectCall;
499
500  // Init -> "rsIntIter = 0"
501  clang::DeclRefExpr *RefrsIntIter =
502      clang::DeclRefExpr::Create(C,
503                                 clang::NestedNameSpecifierLoc(),
504                                 clang::SourceLocation(),
505                                 IIVD,
506                                 false,
507                                 Loc,
508                                 C.IntTy,
509                                 clang::VK_RValue,
510                                 nullptr);
511
512  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
513      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
514
515  clang::BinaryOperator *Init =
516      new(C) clang::BinaryOperator(RefrsIntIter,
517                                   Int0,
518                                   clang::BO_Assign,
519                                   C.IntTy,
520                                   clang::VK_RValue,
521                                   clang::OK_Ordinary,
522                                   Loc,
523                                   false);
524
525  // Cond -> "rsIntIter < NumArrayElements"
526  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
527      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
528
529  clang::BinaryOperator *Cond =
530      new(C) clang::BinaryOperator(RefrsIntIter,
531                                   NumArrayElementsExpr,
532                                   clang::BO_LT,
533                                   C.IntTy,
534                                   clang::VK_RValue,
535                                   clang::OK_Ordinary,
536                                   Loc,
537                                   false);
538
539  // Inc -> "rsIntIter++"
540  clang::UnaryOperator *Inc =
541      new(C) clang::UnaryOperator(RefrsIntIter,
542                                  clang::UO_PostInc,
543                                  C.IntTy,
544                                  clang::VK_RValue,
545                                  clang::OK_Ordinary,
546                                  Loc);
547
548  // Body -> "rsClearObject(&VD[rsIntIter]);"
549  // Destructor loop operates on individual array elements
550
551  clang::Expr *RefRSArrPtr =
552      clang::ImplicitCastExpr::Create(C,
553          C.getPointerType(BaseType->getCanonicalTypeInternal()),
554          clang::CK_ArrayToPointerDecay,
555          RefRSArr,
556          nullptr,
557          clang::VK_RValue);
558
559  clang::Expr *RefRSArrPtrSubscript =
560      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
561                                       RefrsIntIter,
562                                       BaseType->getCanonicalTypeInternal(),
563                                       clang::VK_RValue,
564                                       clang::OK_Ordinary,
565                                       Loc);
566
567  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
568
569  clang::Stmt *RSClearObjectCall = nullptr;
570  if (BaseType->isArrayType()) {
571    RSClearObjectCall =
572        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
573  } else if (DT == DataTypeUnknown) {
574    RSClearObjectCall =
575        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
576  } else {
577    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
578  }
579
580  clang::ForStmt *DestructorLoop =
581      new(C) clang::ForStmt(C,
582                            Init,
583                            Cond,
584                            nullptr,  // no condVar
585                            Inc,
586                            RSClearObjectCall,
587                            Loc,
588                            Loc,
589                            Loc);
590
591  StmtArray[StmtCtr++] = DestructorLoop;
592  slangAssert(StmtCtr == 2);
593
594  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
595      C, llvm::makeArrayRef(StmtArray, StmtCtr), Loc, Loc);
596
597  return CS;
598}
599
600static unsigned CountRSObjectTypes(clang::ASTContext &C,
601                                   const clang::Type *T,
602                                   clang::SourceLocation Loc) {
603  slangAssert(T);
604  unsigned RSObjectCount = 0;
605
606  if (T->isArrayType()) {
607    return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
608  }
609
610  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
611  if (DT != DataTypeUnknown) {
612    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
613  }
614
615  if (T->isUnionType()) {
616    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
617    RD = RD->getDefinition();
618    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
619           FE = RD->field_end();
620         FI != FE;
621         FI++) {
622      const clang::FieldDecl *FD = *FI;
623      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
624      if (CountRSObjectTypes(C, FT, Loc)) {
625        slangAssert(false && "can't have unions with RS object types!");
626        return 0;
627      }
628    }
629  }
630
631  if (!T->isStructureType()) {
632    return 0;
633  }
634
635  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
636  RD = RD->getDefinition();
637  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
638         FE = RD->field_end();
639       FI != FE;
640       FI++) {
641    const clang::FieldDecl *FD = *FI;
642    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
643    if (CountRSObjectTypes(C, FT, Loc)) {
644      // Sub-structs should only count once (as should arrays, etc.)
645      RSObjectCount++;
646    }
647  }
648
649  return RSObjectCount;
650}
651
652static clang::Stmt *ClearStructRSObject(
653    clang::ASTContext &C,
654    clang::DeclContext *DC,
655    clang::Expr *RefRSStruct,
656    clang::SourceLocation StartLoc,
657    clang::SourceLocation Loc) {
658  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
659
660  slangAssert(!BaseType->isArrayType());
661
662  // Structs should show up as unknown primitive types
663  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
664              DataTypeUnknown);
665
666  unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
667  slangAssert(FieldsToDestroy != 0);
668
669  unsigned StmtCount = 0;
670  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
671  for (unsigned i = 0; i < FieldsToDestroy; i++) {
672    StmtArray[i] = nullptr;
673  }
674
675  // Populate StmtArray by creating a destructor for each RS object field
676  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
677  RD = RD->getDefinition();
678  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
679         FE = RD->field_end();
680       FI != FE;
681       FI++) {
682    // We just look through all field declarations to see if we find a
683    // declaration for an RS object type (or an array of one).
684    bool IsArrayType = false;
685    clang::FieldDecl *FD = *FI;
686    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
687    const clang::Type *OrigType = FT;
688    while (FT && FT->isArrayType()) {
689      FT = FT->getArrayElementTypeNoTypeQual();
690      IsArrayType = true;
691    }
692
693    // Pass a DeclarationNameInfo with a valid DeclName, since name equality
694    // gets asserted during CodeGen.
695    clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
696                                              FD->getLocation());
697
698    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
699      clang::DeclAccessPair FoundDecl =
700          clang::DeclAccessPair::make(FD, clang::AS_none);
701      clang::MemberExpr *RSObjectMember =
702          clang::MemberExpr::Create(C,
703                                    RefRSStruct,
704                                    false,
705                                    clang::SourceLocation(),
706                                    clang::NestedNameSpecifierLoc(),
707                                    clang::SourceLocation(),
708                                    FD,
709                                    FoundDecl,
710                                    FDDeclNameInfo,
711                                    nullptr,
712                                    OrigType->getCanonicalTypeInternal(),
713                                    clang::VK_RValue,
714                                    clang::OK_Ordinary);
715
716      slangAssert(StmtCount < FieldsToDestroy);
717
718      if (IsArrayType) {
719        StmtArray[StmtCount++] = ClearArrayRSObject(C,
720                                                    DC,
721                                                    RSObjectMember,
722                                                    StartLoc,
723                                                    Loc);
724      } else {
725        StmtArray[StmtCount++] = ClearSingleRSObject(C,
726                                                     RSObjectMember,
727                                                     Loc);
728      }
729    } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
730      // In this case, we have a nested struct. We may not end up filling all
731      // of the spaces in StmtArray (sub-structs should handle themselves
732      // with separate compound statements).
733      clang::DeclAccessPair FoundDecl =
734          clang::DeclAccessPair::make(FD, clang::AS_none);
735      clang::MemberExpr *RSObjectMember =
736          clang::MemberExpr::Create(C,
737                                    RefRSStruct,
738                                    false,
739                                    clang::SourceLocation(),
740                                    clang::NestedNameSpecifierLoc(),
741                                    clang::SourceLocation(),
742                                    FD,
743                                    FoundDecl,
744                                    clang::DeclarationNameInfo(),
745                                    nullptr,
746                                    OrigType->getCanonicalTypeInternal(),
747                                    clang::VK_RValue,
748                                    clang::OK_Ordinary);
749
750      if (IsArrayType) {
751        StmtArray[StmtCount++] = ClearArrayRSObject(C,
752                                                    DC,
753                                                    RSObjectMember,
754                                                    StartLoc,
755                                                    Loc);
756      } else {
757        StmtArray[StmtCount++] = ClearStructRSObject(C,
758                                                     DC,
759                                                     RSObjectMember,
760                                                     StartLoc,
761                                                     Loc);
762      }
763    }
764  }
765
766  slangAssert(StmtCount > 0);
767  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
768      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
769
770  delete [] StmtArray;
771
772  return CS;
773}
774
775static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
776                                            clang::Expr *DstExpr,
777                                            clang::Expr *SrcExpr,
778                                            clang::SourceLocation StartLoc,
779                                            clang::SourceLocation Loc) {
780  const clang::Type *T = DstExpr->getType().getTypePtr();
781  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
782  slangAssert((SetObjectFD != nullptr) &&
783              "rsSetObject doesn't cover all RS object types");
784
785  clang::QualType SetObjectFDType = SetObjectFD->getType();
786  clang::QualType SetObjectFDArgType[2];
787  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
788  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
789
790  clang::Expr *RefRSSetObjectFD =
791      clang::DeclRefExpr::Create(C,
792                                 clang::NestedNameSpecifierLoc(),
793                                 clang::SourceLocation(),
794                                 SetObjectFD,
795                                 false,
796                                 Loc,
797                                 SetObjectFDType,
798                                 clang::VK_RValue,
799                                 nullptr);
800
801  clang::Expr *RSSetObjectFP =
802      clang::ImplicitCastExpr::Create(C,
803                                      C.getPointerType(SetObjectFDType),
804                                      clang::CK_FunctionToPointerDecay,
805                                      RefRSSetObjectFD,
806                                      nullptr,
807                                      clang::VK_RValue);
808
809  llvm::SmallVector<clang::Expr*, 2> ArgList;
810  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
811                                                clang::UO_AddrOf,
812                                                SetObjectFDArgType[0],
813                                                clang::VK_RValue,
814                                                clang::OK_Ordinary,
815                                                Loc));
816  ArgList.push_back(SrcExpr);
817
818  clang::CallExpr *RSSetObjectCall =
819      new(C) clang::CallExpr(C,
820                             RSSetObjectFP,
821                             ArgList,
822                             SetObjectFD->getCallResultType(),
823                             clang::VK_RValue,
824                             Loc);
825
826  return RSSetObjectCall;
827}
828
// Forward declaration: the definition appears further below. It allows code
// placed before the definition to refer to the function, e.g. the disabled
// CreateArrayRSSetObject() helper in the comment block that follows.
static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
                                            clang::Expr *LHS,
                                            clang::Expr *RHS,
                                            clang::SourceLocation StartLoc,
                                            clang::SourceLocation Loc);
834
835/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
836                                           clang::Expr *DstArr,
837                                           clang::Expr *SrcArr,
838                                           clang::SourceLocation StartLoc,
839                                           clang::SourceLocation Loc) {
840  clang::DeclContext *DC = nullptr;
841  const clang::Type *BaseType = DstArr->getType().getTypePtr();
842  slangAssert(BaseType->isArrayType());
843
844  int NumArrayElements = ArrayDim(BaseType);
845  // Actually extract out the base RS object type for use later
846  BaseType = BaseType->getArrayElementTypeNoTypeQual();
847
848  clang::Stmt *StmtArray[2] = {nullptr};
849  int StmtCtr = 0;
850
851  if (NumArrayElements <= 0) {
852    return nullptr;
853  }
854
855  // Create helper variable for iterating through elements
856  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
857  clang::VarDecl *IIVD =
858      clang::VarDecl::Create(C,
859                             DC,
860                             StartLoc,
861                             Loc,
862                             &II,
863                             C.IntTy,
864                             C.getTrivialTypeSourceInfo(C.IntTy),
865                             clang::SC_None,
866                             clang::SC_None);
867  clang::Decl *IID = (clang::Decl *)IIVD;
868
869  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
870  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
871
872  // Form the actual loop
873  // for (Init; Cond; Inc)
874  //   RSSetObjectCall;
875
876  // Init -> "rsIntIter = 0"
877  clang::DeclRefExpr *RefrsIntIter =
878      clang::DeclRefExpr::Create(C,
879                                 clang::NestedNameSpecifierLoc(),
880                                 IIVD,
881                                 Loc,
882                                 C.IntTy,
883                                 clang::VK_RValue,
884                                 nullptr);
885
886  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
887      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
888
889  clang::BinaryOperator *Init =
890      new(C) clang::BinaryOperator(RefrsIntIter,
891                                   Int0,
892                                   clang::BO_Assign,
893                                   C.IntTy,
894                                   clang::VK_RValue,
895                                   clang::OK_Ordinary,
896                                   Loc);
897
898  // Cond -> "rsIntIter < NumArrayElements"
899  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
900      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
901
902  clang::BinaryOperator *Cond =
903      new(C) clang::BinaryOperator(RefrsIntIter,
904                                   NumArrayElementsExpr,
905                                   clang::BO_LT,
906                                   C.IntTy,
907                                   clang::VK_RValue,
908                                   clang::OK_Ordinary,
909                                   Loc);
910
911  // Inc -> "rsIntIter++"
912  clang::UnaryOperator *Inc =
913      new(C) clang::UnaryOperator(RefrsIntIter,
914                                  clang::UO_PostInc,
915                                  C.IntTy,
916                                  clang::VK_RValue,
917                                  clang::OK_Ordinary,
918                                  Loc);
919
920  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
921  // Loop operates on individual array elements
922
923  clang::Expr *DstArrPtr =
924      clang::ImplicitCastExpr::Create(C,
925          C.getPointerType(BaseType->getCanonicalTypeInternal()),
926          clang::CK_ArrayToPointerDecay,
927          DstArr,
928          nullptr,
929          clang::VK_RValue);
930
931  clang::Expr *DstArrPtrSubscript =
932      new(C) clang::ArraySubscriptExpr(DstArrPtr,
933                                       RefrsIntIter,
934                                       BaseType->getCanonicalTypeInternal(),
935                                       clang::VK_RValue,
936                                       clang::OK_Ordinary,
937                                       Loc);
938
939  clang::Expr *SrcArrPtr =
940      clang::ImplicitCastExpr::Create(C,
941          C.getPointerType(BaseType->getCanonicalTypeInternal()),
942          clang::CK_ArrayToPointerDecay,
943          SrcArr,
944          nullptr,
945          clang::VK_RValue);
946
947  clang::Expr *SrcArrPtrSubscript =
948      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
949                                       RefrsIntIter,
950                                       BaseType->getCanonicalTypeInternal(),
951                                       clang::VK_RValue,
952                                       clang::OK_Ordinary,
953                                       Loc);
954
955  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
956
957  clang::Stmt *RSSetObjectCall = nullptr;
958  if (BaseType->isArrayType()) {
959    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
960                                             SrcArrPtrSubscript,
961                                             StartLoc, Loc);
962  } else if (DT == DataTypeUnknown) {
963    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
964                                              SrcArrPtrSubscript,
965                                              StartLoc, Loc);
966  } else {
967    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
968                                              SrcArrPtrSubscript,
969                                              StartLoc, Loc);
970  }
971
972  clang::ForStmt *DestructorLoop =
973      new(C) clang::ForStmt(C,
974                            Init,
975                            Cond,
976                            nullptr,  // no condVar
977                            Inc,
978                            RSSetObjectCall,
979                            Loc,
980                            Loc,
981                            Loc);
982
983  StmtArray[StmtCtr++] = DestructorLoop;
984  slangAssert(StmtCtr == 2);
985
986  clang::CompoundStmt *CS =
987      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
988
989  return CS;
990} */
991
992static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
993                                            clang::Expr *LHS,
994                                            clang::Expr *RHS,
995                                            clang::SourceLocation StartLoc,
996                                            clang::SourceLocation Loc) {
997  clang::QualType QT = LHS->getType();
998  const clang::Type *T = QT.getTypePtr();
999  slangAssert(T->isStructureType());
1000  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1001
1002  // Keep an extra slot for the original copy (memcpy)
1003  unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
1004
1005  unsigned StmtCount = 0;
1006  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1007  for (unsigned i = 0; i < FieldsToSet; i++) {
1008    StmtArray[i] = nullptr;
1009  }
1010
1011  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1012  RD = RD->getDefinition();
1013  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1014         FE = RD->field_end();
1015       FI != FE;
1016       FI++) {
1017    bool IsArrayType = false;
1018    clang::FieldDecl *FD = *FI;
1019    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1020    const clang::Type *OrigType = FT;
1021
1022    if (!CountRSObjectTypes(C, FT, Loc)) {
1023      // Skip to next if we don't have any viable RS object types
1024      continue;
1025    }
1026
1027    clang::DeclAccessPair FoundDecl =
1028        clang::DeclAccessPair::make(FD, clang::AS_none);
1029    clang::MemberExpr *DstMember =
1030        clang::MemberExpr::Create(C,
1031                                  LHS,
1032                                  false,
1033                                  clang::SourceLocation(),
1034                                  clang::NestedNameSpecifierLoc(),
1035                                  clang::SourceLocation(),
1036                                  FD,
1037                                  FoundDecl,
1038                                  clang::DeclarationNameInfo(),
1039                                  nullptr,
1040                                  OrigType->getCanonicalTypeInternal(),
1041                                  clang::VK_RValue,
1042                                  clang::OK_Ordinary);
1043
1044    clang::MemberExpr *SrcMember =
1045        clang::MemberExpr::Create(C,
1046                                  RHS,
1047                                  false,
1048                                  clang::SourceLocation(),
1049                                  clang::NestedNameSpecifierLoc(),
1050                                  clang::SourceLocation(),
1051                                  FD,
1052                                  FoundDecl,
1053                                  clang::DeclarationNameInfo(),
1054                                  nullptr,
1055                                  OrigType->getCanonicalTypeInternal(),
1056                                  clang::VK_RValue,
1057                                  clang::OK_Ordinary);
1058
1059    if (FT->isArrayType()) {
1060      FT = FT->getArrayElementTypeNoTypeQual();
1061      IsArrayType = true;
1062    }
1063
1064    DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
1065
1066    if (IsArrayType) {
1067      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1068      DiagEngine.Report(
1069        clang::FullSourceLoc(Loc, C.getSourceManager()),
1070        DiagEngine.getCustomDiagID(
1071          clang::DiagnosticsEngine::Error,
1072          "Arrays of RS object types within structures cannot be copied"));
1073      // TODO(srhines): Support setting arrays of RS objects
1074      // StmtArray[StmtCount++] =
1075      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1076    } else if (DT == DataTypeUnknown) {
1077      StmtArray[StmtCount++] =
1078          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1079    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1080      StmtArray[StmtCount++] =
1081          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1082    } else {
1083      slangAssert(false);
1084    }
1085  }
1086
1087  slangAssert(StmtCount < FieldsToSet);
1088
1089  // We still need to actually do the overall struct copy. For simplicity,
1090  // we just do a straight-up assignment (which will still preserve all
1091  // the proper RS object reference counts).
1092  clang::BinaryOperator *CopyStruct =
1093      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1094                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1095                                   false);
1096  StmtArray[StmtCount++] = CopyStruct;
1097
1098  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1099      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1100
1101  delete [] StmtArray;
1102
1103  return CS;
1104}
1105
1106}  // namespace
1107
1108void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1109    clang::BinaryOperator *AS) {
1110
1111  clang::QualType QT = AS->getType();
1112
1113  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1114      DataTypeRSAllocation)->getASTContext();
1115
1116  clang::SourceLocation Loc = AS->getExprLoc();
1117  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1118  clang::Stmt *UpdatedStmt = nullptr;
1119
1120  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1121    // By definition, this is a struct assignment if we get here
1122    UpdatedStmt =
1123        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1124  } else {
1125    UpdatedStmt =
1126        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1127  }
1128
1129  RSASTReplace R(C);
1130  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1131}
1132
1133void RSObjectRefCount::Scope::AppendRSObjectInit(
1134    clang::VarDecl *VD,
1135    clang::DeclStmt *DS,
1136    DataType DT,
1137    clang::Expr *InitExpr) {
1138  slangAssert(VD);
1139
1140  if (!InitExpr) {
1141    return;
1142  }
1143
1144  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1145      DataTypeRSAllocation)->getASTContext();
1146  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1147      DataTypeRSAllocation)->getLocation();
1148  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1149      DataTypeRSAllocation)->getInnerLocStart();
1150
1151  if (DT == DataTypeIsStruct) {
1152    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1153    clang::DeclRefExpr *RefRSVar =
1154        clang::DeclRefExpr::Create(C,
1155                                   clang::NestedNameSpecifierLoc(),
1156                                   clang::SourceLocation(),
1157                                   VD,
1158                                   false,
1159                                   Loc,
1160                                   T->getCanonicalTypeInternal(),
1161                                   clang::VK_RValue,
1162                                   nullptr);
1163
1164    clang::Stmt *RSSetObjectOps =
1165        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1166
1167    std::list<clang::Stmt*> StmtList;
1168    StmtList.push_back(RSSetObjectOps);
1169    AppendAfterStmt(C, mCS, DS, StmtList);
1170    return;
1171  }
1172
1173  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1174  slangAssert((SetObjectFD != nullptr) &&
1175              "rsSetObject doesn't cover all RS object types");
1176
1177  clang::QualType SetObjectFDType = SetObjectFD->getType();
1178  clang::QualType SetObjectFDArgType[2];
1179  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1180  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1181
1182  clang::Expr *RefRSSetObjectFD =
1183      clang::DeclRefExpr::Create(C,
1184                                 clang::NestedNameSpecifierLoc(),
1185                                 clang::SourceLocation(),
1186                                 SetObjectFD,
1187                                 false,
1188                                 Loc,
1189                                 SetObjectFDType,
1190                                 clang::VK_RValue,
1191                                 nullptr);
1192
1193  clang::Expr *RSSetObjectFP =
1194      clang::ImplicitCastExpr::Create(C,
1195                                      C.getPointerType(SetObjectFDType),
1196                                      clang::CK_FunctionToPointerDecay,
1197                                      RefRSSetObjectFD,
1198                                      nullptr,
1199                                      clang::VK_RValue);
1200
1201  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1202  clang::DeclRefExpr *RefRSVar =
1203      clang::DeclRefExpr::Create(C,
1204                                 clang::NestedNameSpecifierLoc(),
1205                                 clang::SourceLocation(),
1206                                 VD,
1207                                 false,
1208                                 Loc,
1209                                 T->getCanonicalTypeInternal(),
1210                                 clang::VK_RValue,
1211                                 nullptr);
1212
1213  llvm::SmallVector<clang::Expr*, 2> ArgList;
1214  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1215                                                clang::UO_AddrOf,
1216                                                SetObjectFDArgType[0],
1217                                                clang::VK_RValue,
1218                                                clang::OK_Ordinary,
1219                                                Loc));
1220  ArgList.push_back(InitExpr);
1221
1222  clang::CallExpr *RSSetObjectCall =
1223      new(C) clang::CallExpr(C,
1224                             RSSetObjectFP,
1225                             ArgList,
1226                             SetObjectFD->getCallResultType(),
1227                             clang::VK_RValue,
1228                             Loc);
1229
1230  std::list<clang::Stmt*> StmtList;
1231  StmtList.push_back(RSSetObjectCall);
1232  AppendAfterStmt(C, mCS, DS, StmtList);
1233}
1234
1235void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1236  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1237          E = mRSO.end();
1238        I != E;
1239        I++) {
1240    clang::VarDecl *VD = *I;
1241    clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1242    if (RSClearObjectCall) {
1243      clang::ASTContext &C = (*mRSO.begin())->getASTContext();
1244      // Mark VD as used.  It might be unused, except for the destructor.
1245      // 'markUsed' has side-effects that are caused only if VD is not already
1246      // used.  Hence no need for an extra check here.
1247      VD->markUsed(C);
1248      DestructorVisitor DV(C,
1249                           mCS,
1250                           RSClearObjectCall,
1251                           VD->getSourceRange().getBegin());
1252      DV.Visit(mCS);
1253      DV.InsertDestructors();
1254    }
1255  }
1256}
1257
1258clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1259    clang::VarDecl *VD,
1260    clang::DeclContext *DC) {
1261  slangAssert(VD);
1262  clang::ASTContext &C = VD->getASTContext();
1263  clang::SourceLocation Loc = VD->getLocation();
1264  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1265  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1266
1267  // Reference expr to target RS object variable
1268  clang::DeclRefExpr *RefRSVar =
1269      clang::DeclRefExpr::Create(C,
1270                                 clang::NestedNameSpecifierLoc(),
1271                                 clang::SourceLocation(),
1272                                 VD,
1273                                 false,
1274                                 Loc,
1275                                 T->getCanonicalTypeInternal(),
1276                                 clang::VK_RValue,
1277                                 nullptr);
1278
1279  if (T->isArrayType()) {
1280    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1281  }
1282
1283  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1284
1285  if (DT == DataTypeUnknown ||
1286      DT == DataTypeIsStruct) {
1287    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1288  }
1289
1290  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1291              "Should be RS object");
1292
1293  return ClearSingleRSObject(C, RefRSVar, Loc);
1294}
1295
1296bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1297                                          DataType *DT,
1298                                          clang::Expr **InitExpr) {
1299  slangAssert(VD && DT && InitExpr);
1300  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1301
1302  // Loop through array types to get to base type
1303  while (T && T->isArrayType()) {
1304    T = T->getArrayElementTypeNoTypeQual();
1305  }
1306
1307  bool DataTypeIsStructWithRSObject = false;
1308  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1309
1310  if (*DT == DataTypeUnknown) {
1311    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1312      *DT = DataTypeIsStruct;
1313      DataTypeIsStructWithRSObject = true;
1314    } else {
1315      return false;
1316    }
1317  }
1318
1319  bool DataTypeIsRSObject = false;
1320  if (DataTypeIsStructWithRSObject) {
1321    DataTypeIsRSObject = true;
1322  } else {
1323    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1324  }
1325  *InitExpr = VD->getInit();
1326
1327  if (!DataTypeIsRSObject && *InitExpr) {
1328    // If we already have an initializer for a matrix type, we are done.
1329    return DataTypeIsRSObject;
1330  }
1331
1332  clang::Expr *ZeroInitializer =
1333      CreateZeroInitializerForRSSpecificType(*DT,
1334                                             VD->getASTContext(),
1335                                             VD->getLocation());
1336
1337  if (ZeroInitializer) {
1338    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1339    VD->setInit(ZeroInitializer);
1340  }
1341
1342  return DataTypeIsRSObject;
1343}
1344
1345clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1346    DataType DT,
1347    clang::ASTContext &C,
1348    const clang::SourceLocation &Loc) {
1349  clang::Expr *Res = nullptr;
1350  switch (DT) {
1351    case DataTypeIsStruct:
1352    case DataTypeRSElement:
1353    case DataTypeRSType:
1354    case DataTypeRSAllocation:
1355    case DataTypeRSSampler:
1356    case DataTypeRSScript:
1357    case DataTypeRSMesh:
1358    case DataTypeRSPath:
1359    case DataTypeRSProgramFragment:
1360    case DataTypeRSProgramVertex:
1361    case DataTypeRSProgramRaster:
1362    case DataTypeRSProgramStore:
1363    case DataTypeRSFont: {
1364      //    (ImplicitCastExpr 'nullptr_t'
1365      //      (IntegerLiteral 0)))
1366      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1367      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1368      clang::Expr *CastToNull =
1369          clang::ImplicitCastExpr::Create(C,
1370                                          C.NullPtrTy,
1371                                          clang::CK_IntegralToPointer,
1372                                          Int0,
1373                                          nullptr,
1374                                          clang::VK_RValue);
1375
1376      llvm::SmallVector<clang::Expr*, 1>InitList;
1377      InitList.push_back(CastToNull);
1378
1379      Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
1380      break;
1381    }
1382    case DataTypeRSMatrix2x2:
1383    case DataTypeRSMatrix3x3:
1384    case DataTypeRSMatrix4x4: {
1385      // RS matrix is not completely an RS object. They hold data by themselves.
1386      // (InitListExpr rs_matrix2x2
1387      //   (InitListExpr float[4]
1388      //     (FloatingLiteral 0)
1389      //     (FloatingLiteral 0)
1390      //     (FloatingLiteral 0)
1391      //     (FloatingLiteral 0)))
1392      clang::QualType FloatTy = C.FloatTy;
1393      // Constructor sets value to 0.0f by default
1394      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1395      clang::FloatingLiteral *Float0Val =
1396          clang::FloatingLiteral::Create(C,
1397                                         Val,
1398                                         /* isExact = */true,
1399                                         FloatTy,
1400                                         Loc);
1401
1402      unsigned N = 0;
1403      if (DT == DataTypeRSMatrix2x2)
1404        N = 2;
1405      else if (DT == DataTypeRSMatrix3x3)
1406        N = 3;
1407      else if (DT == DataTypeRSMatrix4x4)
1408        N = 4;
1409      unsigned N_2 = N * N;
1410
1411      // Assume we are going to be allocating 16 elements, since 4x4 is max.
1412      llvm::SmallVector<clang::Expr*, 16> InitVals;
1413      for (unsigned i = 0; i < N_2; i++)
1414        InitVals.push_back(Float0Val);
1415      clang::Expr *InitExpr =
1416          new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
1417      InitExpr->setType(C.getConstantArrayType(FloatTy,
1418                                               llvm::APInt(32, N_2),
1419                                               clang::ArrayType::Normal,
1420                                               /* EltTypeQuals = */0));
1421      llvm::SmallVector<clang::Expr*, 1> InitExprVec;
1422      InitExprVec.push_back(InitExpr);
1423
1424      Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
1425      break;
1426    }
1427    case DataTypeUnknown:
1428    case DataTypeFloat16:
1429    case DataTypeFloat32:
1430    case DataTypeFloat64:
1431    case DataTypeSigned8:
1432    case DataTypeSigned16:
1433    case DataTypeSigned32:
1434    case DataTypeSigned64:
1435    case DataTypeUnsigned8:
1436    case DataTypeUnsigned16:
1437    case DataTypeUnsigned32:
1438    case DataTypeUnsigned64:
1439    case DataTypeBoolean:
1440    case DataTypeUnsigned565:
1441    case DataTypeUnsigned5551:
1442    case DataTypeUnsigned4444:
1443    case DataTypeMax: {
1444      slangAssert(false && "Not RS object type!");
1445    }
1446    // No default case will enable compiler detecting the missing cases
1447  }
1448
1449  return Res;
1450}
1451
1452void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1453  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1454       I != E;
1455       I++) {
1456    clang::Decl *D = *I;
1457    if (D->getKind() == clang::Decl::Var) {
1458      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1459      DataType DT = DataTypeUnknown;
1460      clang::Expr *InitExpr = nullptr;
1461      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1462        // We need to zero-init all RS object types (including matrices), ...
1463        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1464        // ... but, only add to the list of RS objects if we have some
1465        // non-matrix RS object fields.
1466        if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(),
1467                               VD->getLocation())) {
1468          getCurrentScope()->addRSObject(VD);
1469        }
1470      }
1471    }
1472  }
1473}
1474
1475void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1476  if (!CS->body_empty()) {
1477    // Push a new scope
1478    Scope *S = new Scope(CS);
1479    mScopeStack.push(S);
1480
1481    VisitStmt(CS);
1482
1483    // Destroy the scope
1484    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1485    S->InsertLocalVarDestructors();
1486    mScopeStack.pop();
1487    delete S;
1488  }
1489}
1490
1491void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1492  clang::QualType QT = AS->getType();
1493
1494  if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1495    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1496  }
1497}
1498
1499void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1500  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1501       I != E;
1502       I++) {
1503    if (clang::Stmt *Child = *I) {
1504      Visit(Child);
1505    }
1506  }
1507}
1508
1509// This function walks the list of global variables and (potentially) creates
1510// a single global static destructor function that properly decrements
1511// reference counts on the contained RS object types.
1512clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1513  Init();
1514
1515  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1516  clang::SourceLocation loc;
1517
1518  llvm::StringRef SR(".rs.dtor");
1519  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1520  clang::DeclarationName N(&II);
1521  clang::FunctionProtoType::ExtProtoInfo EPI;
1522  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1523      llvm::ArrayRef<clang::QualType>(), EPI);
1524  clang::FunctionDecl *FD = nullptr;
1525
1526  // Generate rsClearObject() call chains for every global variable
1527  // (whether static or extern).
1528  std::list<clang::Stmt *> StmtList;
1529  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1530          E = DC->decls_end(); I != E; I++) {
1531    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1532    if (VD) {
1533      if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1534        if (!FD) {
1535          // Only create FD if we are going to use it.
1536          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1537                                           clang::SC_None);
1538        }
1539        // Mark VD as used.  It might be unused, except for the destructor.
1540        // 'markUsed' has side-effects that are caused only if VD is not already
1541        // used.  Hence no need for an extra check here.
1542        VD->markUsed(mCtx);
1543        // Make sure to create any helpers within the function's DeclContext,
1544        // not the one associated with the global translation unit.
1545        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1546        StmtList.push_back(RSClearObjectCall);
1547      }
1548    }
1549  }
1550
1551  // Nothing needs to be destroyed, so don't emit a dtor.
1552  if (StmtList.empty()) {
1553    return nullptr;
1554  }
1555
1556  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1557
1558  FD->setBody(CS);
1559
1560  return FD;
1561}
1562
1563}  // namespace slang
1564