slang_rs_object_ref_count.cpp revision cec9b65aa890dea58e39951900ae13efb8d11703
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <stack>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang_rs.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32
33namespace slang {
34
35/* Even though those two arrays are of size DataTypeMax, only entries that
36 * correspond to object types will be set.
37 */
38clang::FunctionDecl *
39RSObjectRefCount::RSSetObjectFD[DataTypeMax];
40clang::FunctionDecl *
41RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42
43void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44  for (unsigned i = 0; i < DataTypeMax; i++) {
45    RSSetObjectFD[i] = NULL;
46    RSClearObjectFD[i] = NULL;
47  }
48
49  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52          E = TUDecl->decls_end(); I != E; I++) {
53    if ((I->getKind() >= clang::Decl::firstFunction) &&
54        (I->getKind() <= clang::Decl::lastFunction)) {
55      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57      // points to RSSetObjectFD or RSClearObjectFD
58      clang::FunctionDecl **RSObjectFD;
59
60      if (FD->getName() == "rsSetObject") {
61        slangAssert((FD->getNumParams() == 2) &&
62                    "Invalid rsSetObject function prototype (# params)");
63        RSObjectFD = RSSetObjectFD;
64      } else if (FD->getName() == "rsClearObject") {
65        slangAssert((FD->getNumParams() == 1) &&
66                    "Invalid rsClearObject function prototype (# params)");
67        RSObjectFD = RSClearObjectFD;
68      } else {
69        continue;
70      }
71
72      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73      clang::QualType PVT = PVD->getOriginalType();
74      // The first parameter must be a pointer like rs_allocation*
75      slangAssert(PVT->isPointerType() &&
76          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78      // The rs object type passed to the FD
79      clang::QualType RST = PVT->getPointeeType();
80      DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      if (DT >= 0 && DT < DataTypeMax) {
85          RSObjectFD[DT] = FD;
86      } else {
87          slangAssert(false && "incorrect type");
88      }
89    }
90  }
91}
92
93namespace {
94
95// This function constructs a new CompoundStmt from the input StmtList.
96static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
97      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
98  unsigned NewStmtCount = StmtList.size();
99  unsigned CompoundStmtCount = 0;
100
101  clang::Stmt **CompoundStmtList;
102  CompoundStmtList = new clang::Stmt*[NewStmtCount];
103
104  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
105  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
106  for ( ; I != E; I++) {
107    CompoundStmtList[CompoundStmtCount++] = *I;
108  }
109  slangAssert(CompoundStmtCount == NewStmtCount);
110
111  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
112      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
113
114  delete [] CompoundStmtList;
115
116  return CS;
117}
118
119static void AppendAfterStmt(clang::ASTContext &C,
120                            clang::CompoundStmt *CS,
121                            clang::Stmt *S,
122                            std::list<clang::Stmt*> &StmtList) {
123  slangAssert(CS);
124  clang::CompoundStmt::body_iterator bI = CS->body_begin();
125  clang::CompoundStmt::body_iterator bE = CS->body_end();
126  clang::Stmt **UpdatedStmtList =
127      new clang::Stmt*[CS->size() + StmtList.size()];
128
129  unsigned UpdatedStmtCount = 0;
130  unsigned Once = 0;
131  for ( ; bI != bE; bI++) {
132    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
133      // If we come across a return here, we don't have anything we can
134      // reasonably replace. We should have already inserted our destructor
135      // code in the proper spot, so we just clean up and return.
136      delete [] UpdatedStmtList;
137
138      return;
139    }
140
141    UpdatedStmtList[UpdatedStmtCount++] = *bI;
142
143    if ((*bI == S) && !Once) {
144      Once++;
145      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
146      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
147      for ( ; I != E; I++) {
148        UpdatedStmtList[UpdatedStmtCount++] = *I;
149      }
150    }
151  }
152  slangAssert(Once <= 1);
153
154  // When S is NULL, we are appending to the end of the CompoundStmt.
155  if (!S) {
156    slangAssert(Once == 0);
157    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
158    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
159    for ( ; I != E; I++) {
160      UpdatedStmtList[UpdatedStmtCount++] = *I;
161    }
162  }
163
164  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
165
166  delete [] UpdatedStmtList;
167
168  return;
169}
170
171// This class visits a compound statement and inserts DtorStmt
172// in proper locations. This includes inserting it before any
173// return statement in any sub-block, at the end of the logical enclosing
174// scope (compound statement), and/or before any break/continue statement that
175// would resume outside the declared scope. We will not handle the case for
176// goto statements that leave a local scope.
177//
178// To accomplish these goals, it collects a list of sub-Stmt's that
179// correspond to scope exit points. It then uses an RSASTReplace visitor to
180// transform the AST, inserting appropriate destructors before each of those
181// sub-Stmt's (and also before the exit of the outermost containing Stmt for
182// the scope).
183class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
184 private:
185  clang::ASTContext &mCtx;
186
187  // The loop depth of the currently visited node.
188  int mLoopDepth;
189
190  // The switch statement depth of the currently visited node.
191  // Note that this is tracked separately from the loop depth because
192  // SwitchStmt-contained ContinueStmt's should have destructors for the
193  // corresponding loop scope.
194  int mSwitchDepth;
195
196  // The outermost statement block that we are currently visiting.
197  // This should always be a CompoundStmt.
198  clang::Stmt *mOuterStmt;
199
200  // The destructor to execute for this scope/variable.
201  clang::Stmt* mDtorStmt;
202
203  // The stack of statements which should be replaced by a compound statement
204  // containing the new destructor call followed by the original Stmt.
205  std::stack<clang::Stmt*> mReplaceStmtStack;
206
207  // The source location for the variable declaration that we are trying to
208  // insert destructors for. Note that InsertDestructors() will not generate
209  // destructor calls for source locations that occur lexically before this
210  // location.
211  clang::SourceLocation mVarLoc;
212
213 public:
214  DestructorVisitor(clang::ASTContext &C,
215                    clang::Stmt* OuterStmt,
216                    clang::Stmt* DtorStmt,
217                    clang::SourceLocation VarLoc);
218
219  // This code walks the collected list of Stmts to replace and actually does
220  // the replacement. It also finishes up by appending the destructor to the
221  // current outermost CompoundStmt.
222  void InsertDestructors() {
223    clang::Stmt *S = NULL;
224    clang::SourceManager &SM = mCtx.getSourceManager();
225    std::list<clang::Stmt *> StmtList;
226    StmtList.push_back(mDtorStmt);
227
228    while (!mReplaceStmtStack.empty()) {
229      S = mReplaceStmtStack.top();
230      mReplaceStmtStack.pop();
231
232      // Skip all source locations that occur before the variable's
233      // declaration, since it won't have been initialized yet.
234      if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
235        continue;
236      }
237
238      StmtList.push_back(S);
239      clang::CompoundStmt *CS =
240          BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
241      StmtList.pop_back();
242
243      RSASTReplace R(mCtx);
244      R.ReplaceStmt(mOuterStmt, S, CS);
245    }
246    clang::CompoundStmt *CS =
247      llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
248    slangAssert(CS);
249    AppendAfterStmt(mCtx, CS, NULL, StmtList);
250  }
251
252  void VisitStmt(clang::Stmt *S);
253  void VisitCompoundStmt(clang::CompoundStmt *CS);
254
255  void VisitBreakStmt(clang::BreakStmt *BS);
256  void VisitCaseStmt(clang::CaseStmt *CS);
257  void VisitContinueStmt(clang::ContinueStmt *CS);
258  void VisitDefaultStmt(clang::DefaultStmt *DS);
259  void VisitDoStmt(clang::DoStmt *DS);
260  void VisitForStmt(clang::ForStmt *FS);
261  void VisitIfStmt(clang::IfStmt *IS);
262  void VisitReturnStmt(clang::ReturnStmt *RS);
263  void VisitSwitchCase(clang::SwitchCase *SC);
264  void VisitSwitchStmt(clang::SwitchStmt *SS);
265  void VisitWhileStmt(clang::WhileStmt *WS);
266};
267
268DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
269                         clang::Stmt *OuterStmt,
270                         clang::Stmt *DtorStmt,
271                         clang::SourceLocation VarLoc)
272  : mCtx(C),
273    mLoopDepth(0),
274    mSwitchDepth(0),
275    mOuterStmt(OuterStmt),
276    mDtorStmt(DtorStmt),
277    mVarLoc(VarLoc) {
278  return;
279}
280
281void DestructorVisitor::VisitStmt(clang::Stmt *S) {
282  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
283       I != E;
284       I++) {
285    if (clang::Stmt *Child = *I) {
286      Visit(Child);
287    }
288  }
289  return;
290}
291
292void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
293  VisitStmt(CS);
294  return;
295}
296
297void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
298  VisitStmt(BS);
299  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
300    mReplaceStmtStack.push(BS);
301  }
302  return;
303}
304
305void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
306  VisitStmt(CS);
307  return;
308}
309
310void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
311  VisitStmt(CS);
312  if (mLoopDepth == 0) {
313    // Switch statements can have nested continues.
314    mReplaceStmtStack.push(CS);
315  }
316  return;
317}
318
319void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
320  VisitStmt(DS);
321  return;
322}
323
324void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
325  mLoopDepth++;
326  VisitStmt(DS);
327  mLoopDepth--;
328  return;
329}
330
331void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
332  mLoopDepth++;
333  VisitStmt(FS);
334  mLoopDepth--;
335  return;
336}
337
338void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
339  VisitStmt(IS);
340  return;
341}
342
343void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
344  mReplaceStmtStack.push(RS);
345  return;
346}
347
348void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
349  slangAssert(false && "Both case and default have specialized handlers");
350  VisitStmt(SC);
351  return;
352}
353
354void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
355  mSwitchDepth++;
356  VisitStmt(SS);
357  mSwitchDepth--;
358  return;
359}
360
361void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
362  mLoopDepth++;
363  VisitStmt(WS);
364  mLoopDepth--;
365  return;
366}
367
368clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
369                                 clang::Expr *RefRSVar,
370                                 clang::SourceLocation Loc) {
371  slangAssert(RefRSVar);
372  const clang::Type *T = RefRSVar->getType().getTypePtr();
373  slangAssert(!T->isArrayType() &&
374              "Should not be destroying arrays with this function");
375
376  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
377  slangAssert((ClearObjectFD != NULL) &&
378              "rsClearObject doesn't cover all RS object types");
379
380  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
381  clang::QualType ClearObjectFDArgType =
382      ClearObjectFD->getParamDecl(0)->getOriginalType();
383
384  // Example destructor for "rs_font localFont;"
385  //
386  // (CallExpr 'void'
387  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
388  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
389  //   (UnaryOperator 'rs_font *' prefix '&'
390  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
391
392  // Get address of targeted RS object
393  clang::Expr *AddrRefRSVar =
394      new(C) clang::UnaryOperator(RefRSVar,
395                                  clang::UO_AddrOf,
396                                  ClearObjectFDArgType,
397                                  clang::VK_RValue,
398                                  clang::OK_Ordinary,
399                                  Loc);
400
401  clang::Expr *RefRSClearObjectFD =
402      clang::DeclRefExpr::Create(C,
403                                 clang::NestedNameSpecifierLoc(),
404                                 clang::SourceLocation(),
405                                 ClearObjectFD,
406                                 false,
407                                 ClearObjectFD->getLocation(),
408                                 ClearObjectFDType,
409                                 clang::VK_RValue,
410                                 NULL);
411
412  clang::Expr *RSClearObjectFP =
413      clang::ImplicitCastExpr::Create(C,
414                                      C.getPointerType(ClearObjectFDType),
415                                      clang::CK_FunctionToPointerDecay,
416                                      RefRSClearObjectFD,
417                                      NULL,
418                                      clang::VK_RValue);
419
420  llvm::SmallVector<clang::Expr*, 1> ArgList;
421  ArgList.push_back(AddrRefRSVar);
422
423  clang::CallExpr *RSClearObjectCall =
424      new(C) clang::CallExpr(C,
425                             RSClearObjectFP,
426                             ArgList,
427                             ClearObjectFD->getCallResultType(),
428                             clang::VK_RValue,
429                             Loc);
430
431  return RSClearObjectCall;
432}
433
434static int ArrayDim(const clang::Type *T) {
435  if (!T || !T->isArrayType()) {
436    return 0;
437  }
438
439  const clang::ConstantArrayType *CAT =
440    static_cast<const clang::ConstantArrayType *>(T);
441  return static_cast<int>(CAT->getSize().getSExtValue());
442}
443
// Forward declaration. ClearArrayRSObject (below) clears arrays of structs by
// calling ClearStructRSObject. ClearStructRSObject in turn clears array fields
// by calling ClearArrayRSObject. One of the two must therefore be declared
// ahead of its definition.
static clang::Stmt *ClearStructRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSStruct,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc);
450
451static clang::Stmt *ClearArrayRSObject(
452    clang::ASTContext &C,
453    clang::DeclContext *DC,
454    clang::Expr *RefRSArr,
455    clang::SourceLocation StartLoc,
456    clang::SourceLocation Loc) {
457  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
458  slangAssert(BaseType->isArrayType());
459
460  int NumArrayElements = ArrayDim(BaseType);
461  // Actually extract out the base RS object type for use later
462  BaseType = BaseType->getArrayElementTypeNoTypeQual();
463
464  clang::Stmt *StmtArray[2] = {NULL};
465  int StmtCtr = 0;
466
467  if (NumArrayElements <= 0) {
468    return NULL;
469  }
470
471  // Example destructor loop for "rs_font fontArr[10];"
472  //
473  // (CompoundStmt
474  //   (DeclStmt "int rsIntIter")
475  //   (ForStmt
476  //     (BinaryOperator 'int' '='
477  //       (DeclRefExpr 'int' Var='rsIntIter')
478  //       (IntegerLiteral 'int' 0))
479  //     (BinaryOperator 'int' '<'
480  //       (DeclRefExpr 'int' Var='rsIntIter')
481  //       (IntegerLiteral 'int' 10)
482  //     NULL << CondVar >>
483  //     (UnaryOperator 'int' postfix '++'
484  //       (DeclRefExpr 'int' Var='rsIntIter'))
485  //     (CallExpr 'void'
486  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
487  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
488  //       (UnaryOperator 'rs_font *' prefix '&'
489  //         (ArraySubscriptExpr 'rs_font':'rs_font'
490  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
491  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
492  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
493
494  // Create helper variable for iterating through elements
495  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
496  clang::VarDecl *IIVD =
497      clang::VarDecl::Create(C,
498                             DC,
499                             StartLoc,
500                             Loc,
501                             &II,
502                             C.IntTy,
503                             C.getTrivialTypeSourceInfo(C.IntTy),
504                             clang::SC_None);
505  clang::Decl *IID = (clang::Decl *)IIVD;
506
507  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
508  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
509
510  // Form the actual destructor loop
511  // for (Init; Cond; Inc)
512  //   RSClearObjectCall;
513
514  // Init -> "rsIntIter = 0"
515  clang::DeclRefExpr *RefrsIntIter =
516      clang::DeclRefExpr::Create(C,
517                                 clang::NestedNameSpecifierLoc(),
518                                 clang::SourceLocation(),
519                                 IIVD,
520                                 false,
521                                 Loc,
522                                 C.IntTy,
523                                 clang::VK_RValue,
524                                 NULL);
525
526  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
527      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
528
529  clang::BinaryOperator *Init =
530      new(C) clang::BinaryOperator(RefrsIntIter,
531                                   Int0,
532                                   clang::BO_Assign,
533                                   C.IntTy,
534                                   clang::VK_RValue,
535                                   clang::OK_Ordinary,
536                                   Loc,
537                                   false);
538
539  // Cond -> "rsIntIter < NumArrayElements"
540  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
541      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
542
543  clang::BinaryOperator *Cond =
544      new(C) clang::BinaryOperator(RefrsIntIter,
545                                   NumArrayElementsExpr,
546                                   clang::BO_LT,
547                                   C.IntTy,
548                                   clang::VK_RValue,
549                                   clang::OK_Ordinary,
550                                   Loc,
551                                   false);
552
553  // Inc -> "rsIntIter++"
554  clang::UnaryOperator *Inc =
555      new(C) clang::UnaryOperator(RefrsIntIter,
556                                  clang::UO_PostInc,
557                                  C.IntTy,
558                                  clang::VK_RValue,
559                                  clang::OK_Ordinary,
560                                  Loc);
561
562  // Body -> "rsClearObject(&VD[rsIntIter]);"
563  // Destructor loop operates on individual array elements
564
565  clang::Expr *RefRSArrPtr =
566      clang::ImplicitCastExpr::Create(C,
567          C.getPointerType(BaseType->getCanonicalTypeInternal()),
568          clang::CK_ArrayToPointerDecay,
569          RefRSArr,
570          NULL,
571          clang::VK_RValue);
572
573  clang::Expr *RefRSArrPtrSubscript =
574      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
575                                       RefrsIntIter,
576                                       BaseType->getCanonicalTypeInternal(),
577                                       clang::VK_RValue,
578                                       clang::OK_Ordinary,
579                                       Loc);
580
581  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
582
583  clang::Stmt *RSClearObjectCall = NULL;
584  if (BaseType->isArrayType()) {
585    RSClearObjectCall =
586        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
587  } else if (DT == DataTypeUnknown) {
588    RSClearObjectCall =
589        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
590  } else {
591    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
592  }
593
594  clang::ForStmt *DestructorLoop =
595      new(C) clang::ForStmt(C,
596                            Init,
597                            Cond,
598                            NULL,  // no condVar
599                            Inc,
600                            RSClearObjectCall,
601                            Loc,
602                            Loc,
603                            Loc);
604
605  StmtArray[StmtCtr++] = DestructorLoop;
606  slangAssert(StmtCtr == 2);
607
608  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
609      C, llvm::makeArrayRef(StmtArray, StmtCtr), Loc, Loc);
610
611  return CS;
612}
613
614static unsigned CountRSObjectTypes(clang::ASTContext &C,
615                                   const clang::Type *T,
616                                   clang::SourceLocation Loc) {
617  slangAssert(T);
618  unsigned RSObjectCount = 0;
619
620  if (T->isArrayType()) {
621    return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
622  }
623
624  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
625  if (DT != DataTypeUnknown) {
626    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
627  }
628
629  if (T->isUnionType()) {
630    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
631    RD = RD->getDefinition();
632    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
633           FE = RD->field_end();
634         FI != FE;
635         FI++) {
636      const clang::FieldDecl *FD = *FI;
637      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
638      if (CountRSObjectTypes(C, FT, Loc)) {
639        slangAssert(false && "can't have unions with RS object types!");
640        return 0;
641      }
642    }
643  }
644
645  if (!T->isStructureType()) {
646    return 0;
647  }
648
649  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
650  RD = RD->getDefinition();
651  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
652         FE = RD->field_end();
653       FI != FE;
654       FI++) {
655    const clang::FieldDecl *FD = *FI;
656    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
657    if (CountRSObjectTypes(C, FT, Loc)) {
658      // Sub-structs should only count once (as should arrays, etc.)
659      RSObjectCount++;
660    }
661  }
662
663  return RSObjectCount;
664}
665
666static clang::Stmt *ClearStructRSObject(
667    clang::ASTContext &C,
668    clang::DeclContext *DC,
669    clang::Expr *RefRSStruct,
670    clang::SourceLocation StartLoc,
671    clang::SourceLocation Loc) {
672  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
673
674  slangAssert(!BaseType->isArrayType());
675
676  // Structs should show up as unknown primitive types
677  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
678              DataTypeUnknown);
679
680  unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
681  slangAssert(FieldsToDestroy != 0);
682
683  unsigned StmtCount = 0;
684  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
685  for (unsigned i = 0; i < FieldsToDestroy; i++) {
686    StmtArray[i] = NULL;
687  }
688
689  // Populate StmtArray by creating a destructor for each RS object field
690  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
691  RD = RD->getDefinition();
692  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
693         FE = RD->field_end();
694       FI != FE;
695       FI++) {
696    // We just look through all field declarations to see if we find a
697    // declaration for an RS object type (or an array of one).
698    bool IsArrayType = false;
699    clang::FieldDecl *FD = *FI;
700    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
701    const clang::Type *OrigType = FT;
702    while (FT && FT->isArrayType()) {
703      FT = FT->getArrayElementTypeNoTypeQual();
704      IsArrayType = true;
705    }
706
707    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
708      clang::DeclAccessPair FoundDecl =
709          clang::DeclAccessPair::make(FD, clang::AS_none);
710      clang::MemberExpr *RSObjectMember =
711          clang::MemberExpr::Create(C,
712                                    RefRSStruct,
713                                    false,
714                                    clang::NestedNameSpecifierLoc(),
715                                    clang::SourceLocation(),
716                                    FD,
717                                    FoundDecl,
718                                    clang::DeclarationNameInfo(),
719                                    NULL,
720                                    OrigType->getCanonicalTypeInternal(),
721                                    clang::VK_RValue,
722                                    clang::OK_Ordinary);
723
724      slangAssert(StmtCount < FieldsToDestroy);
725
726      if (IsArrayType) {
727        StmtArray[StmtCount++] = ClearArrayRSObject(C,
728                                                    DC,
729                                                    RSObjectMember,
730                                                    StartLoc,
731                                                    Loc);
732      } else {
733        StmtArray[StmtCount++] = ClearSingleRSObject(C,
734                                                     RSObjectMember,
735                                                     Loc);
736      }
737    } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
738      // In this case, we have a nested struct. We may not end up filling all
739      // of the spaces in StmtArray (sub-structs should handle themselves
740      // with separate compound statements).
741      clang::DeclAccessPair FoundDecl =
742          clang::DeclAccessPair::make(FD, clang::AS_none);
743      clang::MemberExpr *RSObjectMember =
744          clang::MemberExpr::Create(C,
745                                    RefRSStruct,
746                                    false,
747                                    clang::NestedNameSpecifierLoc(),
748                                    clang::SourceLocation(),
749                                    FD,
750                                    FoundDecl,
751                                    clang::DeclarationNameInfo(),
752                                    NULL,
753                                    OrigType->getCanonicalTypeInternal(),
754                                    clang::VK_RValue,
755                                    clang::OK_Ordinary);
756
757      if (IsArrayType) {
758        StmtArray[StmtCount++] = ClearArrayRSObject(C,
759                                                    DC,
760                                                    RSObjectMember,
761                                                    StartLoc,
762                                                    Loc);
763      } else {
764        StmtArray[StmtCount++] = ClearStructRSObject(C,
765                                                     DC,
766                                                     RSObjectMember,
767                                                     StartLoc,
768                                                     Loc);
769      }
770    }
771  }
772
773  slangAssert(StmtCount > 0);
774  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
775      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
776
777  delete [] StmtArray;
778
779  return CS;
780}
781
782static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
783                                            clang::Expr *DstExpr,
784                                            clang::Expr *SrcExpr,
785                                            clang::SourceLocation StartLoc,
786                                            clang::SourceLocation Loc) {
787  const clang::Type *T = DstExpr->getType().getTypePtr();
788  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
789  slangAssert((SetObjectFD != NULL) &&
790              "rsSetObject doesn't cover all RS object types");
791
792  clang::QualType SetObjectFDType = SetObjectFD->getType();
793  clang::QualType SetObjectFDArgType[2];
794  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
795  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
796
797  clang::Expr *RefRSSetObjectFD =
798      clang::DeclRefExpr::Create(C,
799                                 clang::NestedNameSpecifierLoc(),
800                                 clang::SourceLocation(),
801                                 SetObjectFD,
802                                 false,
803                                 Loc,
804                                 SetObjectFDType,
805                                 clang::VK_RValue,
806                                 NULL);
807
808  clang::Expr *RSSetObjectFP =
809      clang::ImplicitCastExpr::Create(C,
810                                      C.getPointerType(SetObjectFDType),
811                                      clang::CK_FunctionToPointerDecay,
812                                      RefRSSetObjectFD,
813                                      NULL,
814                                      clang::VK_RValue);
815
816  llvm::SmallVector<clang::Expr*, 2> ArgList;
817  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
818                                                clang::UO_AddrOf,
819                                                SetObjectFDArgType[0],
820                                                clang::VK_RValue,
821                                                clang::OK_Ordinary,
822                                                Loc));
823  ArgList.push_back(SrcExpr);
824
825  clang::CallExpr *RSSetObjectCall =
826      new(C) clang::CallExpr(C,
827                             RSSetObjectFP,
828                             ArgList,
829                             SetObjectFD->getCallResultType(),
830                             clang::VK_RValue,
831                             Loc);
832
833  return RSSetObjectCall;
834}
835
// Forward declaration; the definition is elsewhere in this file (outside the
// portion shown here). Presumably the struct counterpart of
// CreateSingleRSSetObject, assigning the RS object fields of LHS from RHS.
// TODO(review): confirm against the definition.
static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
                                            clang::Expr *LHS,
                                            clang::Expr *RHS,
                                            clang::SourceLocation StartLoc,
                                            clang::SourceLocation Loc);
841
842/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
843                                           clang::Expr *DstArr,
844                                           clang::Expr *SrcArr,
845                                           clang::SourceLocation StartLoc,
846                                           clang::SourceLocation Loc) {
847  clang::DeclContext *DC = NULL;
848  const clang::Type *BaseType = DstArr->getType().getTypePtr();
849  slangAssert(BaseType->isArrayType());
850
851  int NumArrayElements = ArrayDim(BaseType);
852  // Actually extract out the base RS object type for use later
853  BaseType = BaseType->getArrayElementTypeNoTypeQual();
854
855  clang::Stmt *StmtArray[2] = {NULL};
856  int StmtCtr = 0;
857
858  if (NumArrayElements <= 0) {
859    return NULL;
860  }
861
862  // Create helper variable for iterating through elements
863  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
864  clang::VarDecl *IIVD =
865      clang::VarDecl::Create(C,
866                             DC,
867                             StartLoc,
868                             Loc,
869                             &II,
870                             C.IntTy,
871                             C.getTrivialTypeSourceInfo(C.IntTy),
872                             clang::SC_None,
873                             clang::SC_None);
874  clang::Decl *IID = (clang::Decl *)IIVD;
875
876  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
877  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
878
879  // Form the actual loop
880  // for (Init; Cond; Inc)
881  //   RSSetObjectCall;
882
883  // Init -> "rsIntIter = 0"
884  clang::DeclRefExpr *RefrsIntIter =
885      clang::DeclRefExpr::Create(C,
886                                 clang::NestedNameSpecifierLoc(),
887                                 IIVD,
888                                 Loc,
889                                 C.IntTy,
890                                 clang::VK_RValue,
891                                 NULL);
892
893  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
894      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
895
896  clang::BinaryOperator *Init =
897      new(C) clang::BinaryOperator(RefrsIntIter,
898                                   Int0,
899                                   clang::BO_Assign,
900                                   C.IntTy,
901                                   clang::VK_RValue,
902                                   clang::OK_Ordinary,
903                                   Loc);
904
905  // Cond -> "rsIntIter < NumArrayElements"
906  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
907      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
908
909  clang::BinaryOperator *Cond =
910      new(C) clang::BinaryOperator(RefrsIntIter,
911                                   NumArrayElementsExpr,
912                                   clang::BO_LT,
913                                   C.IntTy,
914                                   clang::VK_RValue,
915                                   clang::OK_Ordinary,
916                                   Loc);
917
918  // Inc -> "rsIntIter++"
919  clang::UnaryOperator *Inc =
920      new(C) clang::UnaryOperator(RefrsIntIter,
921                                  clang::UO_PostInc,
922                                  C.IntTy,
923                                  clang::VK_RValue,
924                                  clang::OK_Ordinary,
925                                  Loc);
926
927  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
928  // Loop operates on individual array elements
929
930  clang::Expr *DstArrPtr =
931      clang::ImplicitCastExpr::Create(C,
932          C.getPointerType(BaseType->getCanonicalTypeInternal()),
933          clang::CK_ArrayToPointerDecay,
934          DstArr,
935          NULL,
936          clang::VK_RValue);
937
938  clang::Expr *DstArrPtrSubscript =
939      new(C) clang::ArraySubscriptExpr(DstArrPtr,
940                                       RefrsIntIter,
941                                       BaseType->getCanonicalTypeInternal(),
942                                       clang::VK_RValue,
943                                       clang::OK_Ordinary,
944                                       Loc);
945
946  clang::Expr *SrcArrPtr =
947      clang::ImplicitCastExpr::Create(C,
948          C.getPointerType(BaseType->getCanonicalTypeInternal()),
949          clang::CK_ArrayToPointerDecay,
950          SrcArr,
951          NULL,
952          clang::VK_RValue);
953
954  clang::Expr *SrcArrPtrSubscript =
955      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
956                                       RefrsIntIter,
957                                       BaseType->getCanonicalTypeInternal(),
958                                       clang::VK_RValue,
959                                       clang::OK_Ordinary,
960                                       Loc);
961
962  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
963
964  clang::Stmt *RSSetObjectCall = NULL;
965  if (BaseType->isArrayType()) {
966    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
967                                             SrcArrPtrSubscript,
968                                             StartLoc, Loc);
969  } else if (DT == DataTypeUnknown) {
970    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
971                                              SrcArrPtrSubscript,
972                                              StartLoc, Loc);
973  } else {
974    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
975                                              SrcArrPtrSubscript,
976                                              StartLoc, Loc);
977  }
978
979  clang::ForStmt *DestructorLoop =
980      new(C) clang::ForStmt(C,
981                            Init,
982                            Cond,
983                            NULL,  // no condVar
984                            Inc,
985                            RSSetObjectCall,
986                            Loc,
987                            Loc,
988                            Loc);
989
990  StmtArray[StmtCtr++] = DestructorLoop;
991  slangAssert(StmtCtr == 2);
992
993  clang::CompoundStmt *CS =
994      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
995
996  return CS;
997} */
998
999static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
1000                                            clang::Expr *LHS,
1001                                            clang::Expr *RHS,
1002                                            clang::SourceLocation StartLoc,
1003                                            clang::SourceLocation Loc) {
1004  clang::QualType QT = LHS->getType();
1005  const clang::Type *T = QT.getTypePtr();
1006  slangAssert(T->isStructureType());
1007  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1008
1009  // Keep an extra slot for the original copy (memcpy)
1010  unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
1011
1012  unsigned StmtCount = 0;
1013  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1014  for (unsigned i = 0; i < FieldsToSet; i++) {
1015    StmtArray[i] = NULL;
1016  }
1017
1018  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1019  RD = RD->getDefinition();
1020  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1021         FE = RD->field_end();
1022       FI != FE;
1023       FI++) {
1024    bool IsArrayType = false;
1025    clang::FieldDecl *FD = *FI;
1026    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1027    const clang::Type *OrigType = FT;
1028
1029    if (!CountRSObjectTypes(C, FT, Loc)) {
1030      // Skip to next if we don't have any viable RS object types
1031      continue;
1032    }
1033
1034    clang::DeclAccessPair FoundDecl =
1035        clang::DeclAccessPair::make(FD, clang::AS_none);
1036    clang::MemberExpr *DstMember =
1037        clang::MemberExpr::Create(C,
1038                                  LHS,
1039                                  false,
1040                                  clang::NestedNameSpecifierLoc(),
1041                                  clang::SourceLocation(),
1042                                  FD,
1043                                  FoundDecl,
1044                                  clang::DeclarationNameInfo(),
1045                                  NULL,
1046                                  OrigType->getCanonicalTypeInternal(),
1047                                  clang::VK_RValue,
1048                                  clang::OK_Ordinary);
1049
1050    clang::MemberExpr *SrcMember =
1051        clang::MemberExpr::Create(C,
1052                                  RHS,
1053                                  false,
1054                                  clang::NestedNameSpecifierLoc(),
1055                                  clang::SourceLocation(),
1056                                  FD,
1057                                  FoundDecl,
1058                                  clang::DeclarationNameInfo(),
1059                                  NULL,
1060                                  OrigType->getCanonicalTypeInternal(),
1061                                  clang::VK_RValue,
1062                                  clang::OK_Ordinary);
1063
1064    if (FT->isArrayType()) {
1065      FT = FT->getArrayElementTypeNoTypeQual();
1066      IsArrayType = true;
1067    }
1068
1069    DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
1070
1071    if (IsArrayType) {
1072      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1073      DiagEngine.Report(
1074        clang::FullSourceLoc(Loc, C.getSourceManager()),
1075        DiagEngine.getCustomDiagID(
1076          clang::DiagnosticsEngine::Error,
1077          "Arrays of RS object types within structures cannot be copied"));
1078      // TODO(srhines): Support setting arrays of RS objects
1079      // StmtArray[StmtCount++] =
1080      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1081    } else if (DT == DataTypeUnknown) {
1082      StmtArray[StmtCount++] =
1083          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1084    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1085      StmtArray[StmtCount++] =
1086          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1087    } else {
1088      slangAssert(false);
1089    }
1090  }
1091
1092  slangAssert(StmtCount < FieldsToSet);
1093
1094  // We still need to actually do the overall struct copy. For simplicity,
1095  // we just do a straight-up assignment (which will still preserve all
1096  // the proper RS object reference counts).
1097  clang::BinaryOperator *CopyStruct =
1098      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1099                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1100                                   false);
1101  StmtArray[StmtCount++] = CopyStruct;
1102
1103  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1104      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1105
1106  delete [] StmtArray;
1107
1108  return CS;
1109}
1110
1111}  // namespace
1112
1113void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1114    clang::BinaryOperator *AS) {
1115
1116  clang::QualType QT = AS->getType();
1117
1118  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1119      DataTypeRSFont)->getASTContext();
1120
1121  clang::SourceLocation Loc = AS->getExprLoc();
1122  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1123  clang::Stmt *UpdatedStmt = NULL;
1124
1125  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1126    // By definition, this is a struct assignment if we get here
1127    UpdatedStmt =
1128        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1129  } else {
1130    UpdatedStmt =
1131        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1132  }
1133
1134  RSASTReplace R(C);
1135  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1136  return;
1137}
1138
1139void RSObjectRefCount::Scope::AppendRSObjectInit(
1140    clang::VarDecl *VD,
1141    clang::DeclStmt *DS,
1142    DataType DT,
1143    clang::Expr *InitExpr) {
1144  slangAssert(VD);
1145
1146  if (!InitExpr) {
1147    return;
1148  }
1149
1150  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1151      DataTypeRSFont)->getASTContext();
1152  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1153      DataTypeRSFont)->getLocation();
1154  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1155      DataTypeRSFont)->getInnerLocStart();
1156
1157  if (DT == DataTypeIsStruct) {
1158    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1159    clang::DeclRefExpr *RefRSVar =
1160        clang::DeclRefExpr::Create(C,
1161                                   clang::NestedNameSpecifierLoc(),
1162                                   clang::SourceLocation(),
1163                                   VD,
1164                                   false,
1165                                   Loc,
1166                                   T->getCanonicalTypeInternal(),
1167                                   clang::VK_RValue,
1168                                   NULL);
1169
1170    clang::Stmt *RSSetObjectOps =
1171        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1172
1173    std::list<clang::Stmt*> StmtList;
1174    StmtList.push_back(RSSetObjectOps);
1175    AppendAfterStmt(C, mCS, DS, StmtList);
1176    return;
1177  }
1178
1179  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1180  slangAssert((SetObjectFD != NULL) &&
1181              "rsSetObject doesn't cover all RS object types");
1182
1183  clang::QualType SetObjectFDType = SetObjectFD->getType();
1184  clang::QualType SetObjectFDArgType[2];
1185  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1186  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1187
1188  clang::Expr *RefRSSetObjectFD =
1189      clang::DeclRefExpr::Create(C,
1190                                 clang::NestedNameSpecifierLoc(),
1191                                 clang::SourceLocation(),
1192                                 SetObjectFD,
1193                                 false,
1194                                 Loc,
1195                                 SetObjectFDType,
1196                                 clang::VK_RValue,
1197                                 NULL);
1198
1199  clang::Expr *RSSetObjectFP =
1200      clang::ImplicitCastExpr::Create(C,
1201                                      C.getPointerType(SetObjectFDType),
1202                                      clang::CK_FunctionToPointerDecay,
1203                                      RefRSSetObjectFD,
1204                                      NULL,
1205                                      clang::VK_RValue);
1206
1207  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1208  clang::DeclRefExpr *RefRSVar =
1209      clang::DeclRefExpr::Create(C,
1210                                 clang::NestedNameSpecifierLoc(),
1211                                 clang::SourceLocation(),
1212                                 VD,
1213                                 false,
1214                                 Loc,
1215                                 T->getCanonicalTypeInternal(),
1216                                 clang::VK_RValue,
1217                                 NULL);
1218
1219  llvm::SmallVector<clang::Expr*, 2> ArgList;
1220  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1221                                                clang::UO_AddrOf,
1222                                                SetObjectFDArgType[0],
1223                                                clang::VK_RValue,
1224                                                clang::OK_Ordinary,
1225                                                Loc));
1226  ArgList.push_back(InitExpr);
1227
1228  clang::CallExpr *RSSetObjectCall =
1229      new(C) clang::CallExpr(C,
1230                             RSSetObjectFP,
1231                             ArgList,
1232                             SetObjectFD->getCallResultType(),
1233                             clang::VK_RValue,
1234                             Loc);
1235
1236  std::list<clang::Stmt*> StmtList;
1237  StmtList.push_back(RSSetObjectCall);
1238  AppendAfterStmt(C, mCS, DS, StmtList);
1239
1240  return;
1241}
1242
1243void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1244  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1245          E = mRSO.end();
1246        I != E;
1247        I++) {
1248    clang::VarDecl *VD = *I;
1249    clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1250    if (RSClearObjectCall) {
1251      DestructorVisitor DV((*mRSO.begin())->getASTContext(),
1252                           mCS,
1253                           RSClearObjectCall,
1254                           VD->getSourceRange().getBegin());
1255      DV.Visit(mCS);
1256      DV.InsertDestructors();
1257    }
1258  }
1259  return;
1260}
1261
1262clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1263    clang::VarDecl *VD,
1264    clang::DeclContext *DC) {
1265  slangAssert(VD);
1266  clang::ASTContext &C = VD->getASTContext();
1267  clang::SourceLocation Loc = VD->getLocation();
1268  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1269  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1270
1271  // Reference expr to target RS object variable
1272  clang::DeclRefExpr *RefRSVar =
1273      clang::DeclRefExpr::Create(C,
1274                                 clang::NestedNameSpecifierLoc(),
1275                                 clang::SourceLocation(),
1276                                 VD,
1277                                 false,
1278                                 Loc,
1279                                 T->getCanonicalTypeInternal(),
1280                                 clang::VK_RValue,
1281                                 NULL);
1282
1283  if (T->isArrayType()) {
1284    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1285  }
1286
1287  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1288
1289  if (DT == DataTypeUnknown ||
1290      DT == DataTypeIsStruct) {
1291    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1292  }
1293
1294  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1295              "Should be RS object");
1296
1297  return ClearSingleRSObject(C, RefRSVar, Loc);
1298}
1299
1300bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1301                                          DataType *DT,
1302                                          clang::Expr **InitExpr) {
1303  slangAssert(VD && DT && InitExpr);
1304  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1305
1306  // Loop through array types to get to base type
1307  while (T && T->isArrayType()) {
1308    T = T->getArrayElementTypeNoTypeQual();
1309  }
1310
1311  bool DataTypeIsStructWithRSObject = false;
1312  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1313
1314  if (*DT == DataTypeUnknown) {
1315    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1316      *DT = DataTypeIsStruct;
1317      DataTypeIsStructWithRSObject = true;
1318    } else {
1319      return false;
1320    }
1321  }
1322
1323  bool DataTypeIsRSObject = false;
1324  if (DataTypeIsStructWithRSObject) {
1325    DataTypeIsRSObject = true;
1326  } else {
1327    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1328  }
1329  *InitExpr = VD->getInit();
1330
1331  if (!DataTypeIsRSObject && *InitExpr) {
1332    // If we already have an initializer for a matrix type, we are done.
1333    return DataTypeIsRSObject;
1334  }
1335
1336  clang::Expr *ZeroInitializer =
1337      CreateZeroInitializerForRSSpecificType(*DT,
1338                                             VD->getASTContext(),
1339                                             VD->getLocation());
1340
1341  if (ZeroInitializer) {
1342    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1343    VD->setInit(ZeroInitializer);
1344  }
1345
1346  return DataTypeIsRSObject;
1347}
1348
1349clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1350    DataType DT,
1351    clang::ASTContext &C,
1352    const clang::SourceLocation &Loc) {
1353  clang::Expr *Res = NULL;
1354  switch (DT) {
1355    case DataTypeIsStruct:
1356    case DataTypeRSElement:
1357    case DataTypeRSType:
1358    case DataTypeRSAllocation:
1359    case DataTypeRSSampler:
1360    case DataTypeRSScript:
1361    case DataTypeRSMesh:
1362    case DataTypeRSPath:
1363    case DataTypeRSProgramFragment:
1364    case DataTypeRSProgramVertex:
1365    case DataTypeRSProgramRaster:
1366    case DataTypeRSProgramStore:
1367    case DataTypeRSFont: {
1368      //    (ImplicitCastExpr 'nullptr_t'
1369      //      (IntegerLiteral 0)))
1370      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1371      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1372      clang::Expr *CastToNull =
1373          clang::ImplicitCastExpr::Create(C,
1374                                          C.NullPtrTy,
1375                                          clang::CK_IntegralToPointer,
1376                                          Int0,
1377                                          NULL,
1378                                          clang::VK_RValue);
1379
1380      llvm::SmallVector<clang::Expr*, 1>InitList;
1381      InitList.push_back(CastToNull);
1382
1383      Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
1384      break;
1385    }
1386    case DataTypeRSMatrix2x2:
1387    case DataTypeRSMatrix3x3:
1388    case DataTypeRSMatrix4x4: {
1389      // RS matrix is not completely an RS object. They hold data by themselves.
1390      // (InitListExpr rs_matrix2x2
1391      //   (InitListExpr float[4]
1392      //     (FloatingLiteral 0)
1393      //     (FloatingLiteral 0)
1394      //     (FloatingLiteral 0)
1395      //     (FloatingLiteral 0)))
1396      clang::QualType FloatTy = C.FloatTy;
1397      // Constructor sets value to 0.0f by default
1398      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1399      clang::FloatingLiteral *Float0Val =
1400          clang::FloatingLiteral::Create(C,
1401                                         Val,
1402                                         /* isExact = */true,
1403                                         FloatTy,
1404                                         Loc);
1405
1406      unsigned N = 0;
1407      if (DT == DataTypeRSMatrix2x2)
1408        N = 2;
1409      else if (DT == DataTypeRSMatrix3x3)
1410        N = 3;
1411      else if (DT == DataTypeRSMatrix4x4)
1412        N = 4;
1413      unsigned N_2 = N * N;
1414
1415      // Assume we are going to be allocating 16 elements, since 4x4 is max.
1416      llvm::SmallVector<clang::Expr*, 16> InitVals;
1417      for (unsigned i = 0; i < N_2; i++)
1418        InitVals.push_back(Float0Val);
1419      clang::Expr *InitExpr =
1420          new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
1421      InitExpr->setType(C.getConstantArrayType(FloatTy,
1422                                               llvm::APInt(32, N_2),
1423                                               clang::ArrayType::Normal,
1424                                               /* EltTypeQuals = */0));
1425      llvm::SmallVector<clang::Expr*, 1> InitExprVec;
1426      InitExprVec.push_back(InitExpr);
1427
1428      Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
1429      break;
1430    }
1431    case DataTypeUnknown:
1432    case DataTypeFloat16:
1433    case DataTypeFloat32:
1434    case DataTypeFloat64:
1435    case DataTypeSigned8:
1436    case DataTypeSigned16:
1437    case DataTypeSigned32:
1438    case DataTypeSigned64:
1439    case DataTypeUnsigned8:
1440    case DataTypeUnsigned16:
1441    case DataTypeUnsigned32:
1442    case DataTypeUnsigned64:
1443    case DataTypeBoolean:
1444    case DataTypeUnsigned565:
1445    case DataTypeUnsigned5551:
1446    case DataTypeUnsigned4444:
1447    case DataTypeMax: {
1448      slangAssert(false && "Not RS object type!");
1449    }
1450    // No default case will enable compiler detecting the missing cases
1451  }
1452
1453  return Res;
1454}
1455
1456void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1457  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1458       I != E;
1459       I++) {
1460    clang::Decl *D = *I;
1461    if (D->getKind() == clang::Decl::Var) {
1462      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1463      DataType DT = DataTypeUnknown;
1464      clang::Expr *InitExpr = NULL;
1465      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1466        // We need to zero-init all RS object types (including matrices), ...
1467        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1468        // ... but, only add to the list of RS objects if we have some
1469        // non-matrix RS object fields.
1470        if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(),
1471                               VD->getLocation())) {
1472          getCurrentScope()->addRSObject(VD);
1473        }
1474      }
1475    }
1476  }
1477  return;
1478}
1479
1480void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1481  if (!CS->body_empty()) {
1482    // Push a new scope
1483    Scope *S = new Scope(CS);
1484    mScopeStack.push(S);
1485
1486    VisitStmt(CS);
1487
1488    // Destroy the scope
1489    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1490    S->InsertLocalVarDestructors();
1491    mScopeStack.pop();
1492    delete S;
1493  }
1494  return;
1495}
1496
1497void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1498  clang::QualType QT = AS->getType();
1499
1500  if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1501    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1502  }
1503
1504  return;
1505}
1506
1507void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1508  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1509       I != E;
1510       I++) {
1511    if (clang::Stmt *Child = *I) {
1512      Visit(Child);
1513    }
1514  }
1515  return;
1516}
1517
1518// This function walks the list of global variables and (potentially) creates
1519// a single global static destructor function that properly decrements
1520// reference counts on the contained RS object types.
1521clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1522  Init();
1523
1524  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1525  clang::SourceLocation loc;
1526
1527  llvm::StringRef SR(".rs.dtor");
1528  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1529  clang::DeclarationName N(&II);
1530  clang::FunctionProtoType::ExtProtoInfo EPI;
1531  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1532      llvm::ArrayRef<clang::QualType>(), EPI);
1533  clang::FunctionDecl *FD = NULL;
1534
1535  // Generate rsClearObject() call chains for every global variable
1536  // (whether static or extern).
1537  std::list<clang::Stmt *> StmtList;
1538  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1539          E = DC->decls_end(); I != E; I++) {
1540    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1541    if (VD) {
1542      if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1543        if (!FD) {
1544          // Only create FD if we are going to use it.
1545          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, NULL,
1546                                           clang::SC_None);
1547        }
1548        // Make sure to create any helpers within the function's DeclContext,
1549        // not the one associated with the global translation unit.
1550        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1551        StmtList.push_back(RSClearObjectCall);
1552      }
1553    }
1554  }
1555
1556  // Nothing needs to be destroyed, so don't emit a dtor.
1557  if (StmtList.empty()) {
1558    return NULL;
1559  }
1560
1561  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1562
1563  FD->setBody(CS);
1564
1565  return FD;
1566}
1567
1568}  // namespace slang
1569