slang_rs_object_ref_count.cpp revision d1123c29614eb9e0df7485f9e0775470db2f0384
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang_rs.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32
33namespace slang {
34
// Cached rsSetObject() overloads, one slot per RS object type, indexed by
// (DataType - RSExportPrimitiveType::FirstRSObjectType). Reset and filled in
// by GetRSRefCountingFunctions(); a slot stays NULL if no matching overload
// was found in the translation unit.
clang::FunctionDecl *RSObjectRefCount::
    RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
                  RSExportPrimitiveType::FirstRSObjectType + 1];
// Cached rsClearObject() overloads, laid out exactly like RSSetObjectFD.
clang::FunctionDecl *RSObjectRefCount::
    RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
                    RSExportPrimitiveType::FirstRSObjectType + 1];
41
42void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43  for (unsigned i = 0;
44       i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
45       i++) {
46    RSSetObjectFD[i] = NULL;
47    RSClearObjectFD[i] = NULL;
48  }
49
50  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
51
52  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
53          E = TUDecl->decls_end(); I != E; I++) {
54    if ((I->getKind() >= clang::Decl::firstFunction) &&
55        (I->getKind() <= clang::Decl::lastFunction)) {
56      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
57
58      // points to RSSetObjectFD or RSClearObjectFD
59      clang::FunctionDecl **RSObjectFD;
60
61      if (FD->getName() == "rsSetObject") {
62        slangAssert((FD->getNumParams() == 2) &&
63                    "Invalid rsSetObject function prototype (# params)");
64        RSObjectFD = RSSetObjectFD;
65      } else if (FD->getName() == "rsClearObject") {
66        slangAssert((FD->getNumParams() == 1) &&
67                    "Invalid rsClearObject function prototype (# params)");
68        RSObjectFD = RSClearObjectFD;
69      } else {
70        continue;
71      }
72
73      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
74      clang::QualType PVT = PVD->getOriginalType();
75      // The first parameter must be a pointer like rs_allocation*
76      slangAssert(PVT->isPointerType() &&
77          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
78
79      // The rs object type passed to the FD
80      clang::QualType RST = PVT->getPointeeType();
81      RSExportPrimitiveType::DataType DT =
82          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
83      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
84             && "must be RS object type");
85
86      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
87    }
88  }
89}
90
91namespace {
92
93// This function constructs a new CompoundStmt from the input StmtList.
94static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
95      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
96  unsigned NewStmtCount = StmtList.size();
97  unsigned CompoundStmtCount = 0;
98
99  clang::Stmt **CompoundStmtList;
100  CompoundStmtList = new clang::Stmt*[NewStmtCount];
101
102  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
103  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
104  for ( ; I != E; I++) {
105    CompoundStmtList[CompoundStmtCount++] = *I;
106  }
107  slangAssert(CompoundStmtCount == NewStmtCount);
108
109  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(C,
110                                                       CompoundStmtList,
111                                                       CompoundStmtCount,
112                                                       Loc,
113                                                       Loc);
114
115  delete [] CompoundStmtList;
116
117  return CS;
118}
119
120static void AppendAfterStmt(clang::ASTContext &C,
121                            clang::CompoundStmt *CS,
122                            clang::Stmt *S,
123                            std::list<clang::Stmt*> &StmtList) {
124  slangAssert(CS);
125  clang::CompoundStmt::body_iterator bI = CS->body_begin();
126  clang::CompoundStmt::body_iterator bE = CS->body_end();
127  clang::Stmt **UpdatedStmtList =
128      new clang::Stmt*[CS->size() + StmtList.size()];
129
130  unsigned UpdatedStmtCount = 0;
131  unsigned Once = 0;
132  for ( ; bI != bE; bI++) {
133    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
134      // If we come across a return here, we don't have anything we can
135      // reasonably replace. We should have already inserted our destructor
136      // code in the proper spot, so we just clean up and return.
137      delete [] UpdatedStmtList;
138
139      return;
140    }
141
142    UpdatedStmtList[UpdatedStmtCount++] = *bI;
143
144    if ((*bI == S) && !Once) {
145      Once++;
146      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
147      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
148      for ( ; I != E; I++) {
149        UpdatedStmtList[UpdatedStmtCount++] = *I;
150      }
151    }
152  }
153  slangAssert(Once <= 1);
154
155  // When S is NULL, we are appending to the end of the CompoundStmt.
156  if (!S) {
157    slangAssert(Once == 0);
158    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
159    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
160    for ( ; I != E; I++) {
161      UpdatedStmtList[UpdatedStmtCount++] = *I;
162    }
163  }
164
165  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
166
167  delete [] UpdatedStmtList;
168
169  return;
170}
171
172// This class visits a compound statement and inserts DtorStmt
173// in proper locations. This includes inserting it before any
174// return statement in any sub-block, at the end of the logical enclosing
175// scope (compound statement), and/or before any break/continue statement that
176// would resume outside the declared scope. We will not handle the case for
177// goto statements that leave a local scope.
178//
179// To accomplish these goals, it collects a list of sub-Stmt's that
180// correspond to scope exit points. It then uses an RSASTReplace visitor to
181// transform the AST, inserting appropriate destructors before each of those
182// sub-Stmt's (and also before the exit of the outermost containing Stmt for
183// the scope).
184class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
185 private:
186  clang::ASTContext &mCtx;
187
188  // The loop depth of the currently visited node.
189  int mLoopDepth;
190
191  // The switch statement depth of the currently visited node.
192  // Note that this is tracked separately from the loop depth because
193  // SwitchStmt-contained ContinueStmt's should have destructors for the
194  // corresponding loop scope.
195  int mSwitchDepth;
196
197  // The outermost statement block that we are currently visiting.
198  // This should always be a CompoundStmt.
199  clang::Stmt *mOuterStmt;
200
201  // The destructor to execute for this scope/variable.
202  clang::Stmt* mDtorStmt;
203
204  // The stack of statements which should be replaced by a compound statement
205  // containing the new destructor call followed by the original Stmt.
206  std::stack<clang::Stmt*> mReplaceStmtStack;
207
208  // The source location for the variable declaration that we are trying to
209  // insert destructors for. Note that InsertDestructors() will not generate
210  // destructor calls for source locations that occur lexically before this
211  // location.
212  clang::SourceLocation mVarLoc;
213
214 public:
215  DestructorVisitor(clang::ASTContext &C,
216                    clang::Stmt* OuterStmt,
217                    clang::Stmt* DtorStmt,
218                    clang::SourceLocation VarLoc);
219
220  // This code walks the collected list of Stmts to replace and actually does
221  // the replacement. It also finishes up by appending the destructor to the
222  // current outermost CompoundStmt.
223  void InsertDestructors() {
224    clang::Stmt *S = NULL;
225    clang::SourceManager &SM = mCtx.getSourceManager();
226    std::list<clang::Stmt *> StmtList;
227    StmtList.push_back(mDtorStmt);
228
229    while (!mReplaceStmtStack.empty()) {
230      S = mReplaceStmtStack.top();
231      mReplaceStmtStack.pop();
232
233      // Skip all source locations that occur before the variable's
234      // declaration, since it won't have been initialized yet.
235      if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
236        continue;
237      }
238
239      StmtList.push_back(S);
240      clang::CompoundStmt *CS =
241          BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
242      StmtList.pop_back();
243
244      RSASTReplace R(mCtx);
245      R.ReplaceStmt(mOuterStmt, S, CS);
246    }
247    clang::CompoundStmt *CS =
248      llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
249    slangAssert(CS);
250    AppendAfterStmt(mCtx, CS, NULL, StmtList);
251  }
252
253  void VisitStmt(clang::Stmt *S);
254  void VisitCompoundStmt(clang::CompoundStmt *CS);
255
256  void VisitBreakStmt(clang::BreakStmt *BS);
257  void VisitCaseStmt(clang::CaseStmt *CS);
258  void VisitContinueStmt(clang::ContinueStmt *CS);
259  void VisitDefaultStmt(clang::DefaultStmt *DS);
260  void VisitDoStmt(clang::DoStmt *DS);
261  void VisitForStmt(clang::ForStmt *FS);
262  void VisitIfStmt(clang::IfStmt *IS);
263  void VisitReturnStmt(clang::ReturnStmt *RS);
264  void VisitSwitchCase(clang::SwitchCase *SC);
265  void VisitSwitchStmt(clang::SwitchStmt *SS);
266  void VisitWhileStmt(clang::WhileStmt *WS);
267};
268
269DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
270                         clang::Stmt *OuterStmt,
271                         clang::Stmt *DtorStmt,
272                         clang::SourceLocation VarLoc)
273  : mCtx(C),
274    mLoopDepth(0),
275    mSwitchDepth(0),
276    mOuterStmt(OuterStmt),
277    mDtorStmt(DtorStmt),
278    mVarLoc(VarLoc) {
279  return;
280}
281
282void DestructorVisitor::VisitStmt(clang::Stmt *S) {
283  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
284       I != E;
285       I++) {
286    if (clang::Stmt *Child = *I) {
287      Visit(Child);
288    }
289  }
290  return;
291}
292
293void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
294  VisitStmt(CS);
295  return;
296}
297
298void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
299  VisitStmt(BS);
300  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
301    mReplaceStmtStack.push(BS);
302  }
303  return;
304}
305
306void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
307  VisitStmt(CS);
308  return;
309}
310
311void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
312  VisitStmt(CS);
313  if (mLoopDepth == 0) {
314    // Switch statements can have nested continues.
315    mReplaceStmtStack.push(CS);
316  }
317  return;
318}
319
320void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
321  VisitStmt(DS);
322  return;
323}
324
325void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
326  mLoopDepth++;
327  VisitStmt(DS);
328  mLoopDepth--;
329  return;
330}
331
332void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
333  mLoopDepth++;
334  VisitStmt(FS);
335  mLoopDepth--;
336  return;
337}
338
339void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
340  VisitStmt(IS);
341  return;
342}
343
344void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
345  mReplaceStmtStack.push(RS);
346  return;
347}
348
349void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
350  slangAssert(false && "Both case and default have specialized handlers");
351  VisitStmt(SC);
352  return;
353}
354
355void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
356  mSwitchDepth++;
357  VisitStmt(SS);
358  mSwitchDepth--;
359  return;
360}
361
362void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
363  mLoopDepth++;
364  VisitStmt(WS);
365  mLoopDepth--;
366  return;
367}
368
369clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
370                                 clang::Expr *RefRSVar,
371                                 clang::SourceLocation Loc) {
372  slangAssert(RefRSVar);
373  const clang::Type *T = RefRSVar->getType().getTypePtr();
374  slangAssert(!T->isArrayType() &&
375              "Should not be destroying arrays with this function");
376
377  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
378  slangAssert((ClearObjectFD != NULL) &&
379              "rsClearObject doesn't cover all RS object types");
380
381  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
382  clang::QualType ClearObjectFDArgType =
383      ClearObjectFD->getParamDecl(0)->getOriginalType();
384
385  // Example destructor for "rs_font localFont;"
386  //
387  // (CallExpr 'void'
388  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
389  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
390  //   (UnaryOperator 'rs_font *' prefix '&'
391  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
392
393  // Get address of targeted RS object
394  clang::Expr *AddrRefRSVar =
395      new(C) clang::UnaryOperator(RefRSVar,
396                                  clang::UO_AddrOf,
397                                  ClearObjectFDArgType,
398                                  clang::VK_RValue,
399                                  clang::OK_Ordinary,
400                                  Loc);
401
402  clang::Expr *RefRSClearObjectFD =
403      clang::DeclRefExpr::Create(C,
404                                 clang::NestedNameSpecifierLoc(),
405                                 ClearObjectFD,
406                                 ClearObjectFD->getLocation(),
407                                 ClearObjectFDType,
408                                 clang::VK_RValue,
409                                 NULL);
410
411  clang::Expr *RSClearObjectFP =
412      clang::ImplicitCastExpr::Create(C,
413                                      C.getPointerType(ClearObjectFDType),
414                                      clang::CK_FunctionToPointerDecay,
415                                      RefRSClearObjectFD,
416                                      NULL,
417                                      clang::VK_RValue);
418
419  clang::CallExpr *RSClearObjectCall =
420      new(C) clang::CallExpr(C,
421                             RSClearObjectFP,
422                             &AddrRefRSVar,
423                             1,
424                             ClearObjectFD->getCallResultType(),
425                             clang::VK_RValue,
426                             Loc);
427
428  return RSClearObjectCall;
429}
430
431static int ArrayDim(const clang::Type *T) {
432  if (!T || !T->isArrayType()) {
433    return 0;
434  }
435
436  const clang::ConstantArrayType *CAT =
437    static_cast<const clang::ConstantArrayType *>(T);
438  return static_cast<int>(CAT->getSize().getSExtValue());
439}
440
// Forward declaration: ClearArrayRSObject() and ClearStructRSObject() call
// each other to handle arrays of structs and structs containing arrays.
static clang::Stmt *ClearStructRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSStruct,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc);
447
448static clang::Stmt *ClearArrayRSObject(
449    clang::ASTContext &C,
450    clang::DeclContext *DC,
451    clang::Expr *RefRSArr,
452    clang::SourceLocation StartLoc,
453    clang::SourceLocation Loc) {
454  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
455  slangAssert(BaseType->isArrayType());
456
457  int NumArrayElements = ArrayDim(BaseType);
458  // Actually extract out the base RS object type for use later
459  BaseType = BaseType->getArrayElementTypeNoTypeQual();
460
461  clang::Stmt *StmtArray[2] = {NULL};
462  int StmtCtr = 0;
463
464  if (NumArrayElements <= 0) {
465    return NULL;
466  }
467
468  // Example destructor loop for "rs_font fontArr[10];"
469  //
470  // (CompoundStmt
471  //   (DeclStmt "int rsIntIter")
472  //   (ForStmt
473  //     (BinaryOperator 'int' '='
474  //       (DeclRefExpr 'int' Var='rsIntIter')
475  //       (IntegerLiteral 'int' 0))
476  //     (BinaryOperator 'int' '<'
477  //       (DeclRefExpr 'int' Var='rsIntIter')
478  //       (IntegerLiteral 'int' 10)
479  //     NULL << CondVar >>
480  //     (UnaryOperator 'int' postfix '++'
481  //       (DeclRefExpr 'int' Var='rsIntIter'))
482  //     (CallExpr 'void'
483  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
484  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
485  //       (UnaryOperator 'rs_font *' prefix '&'
486  //         (ArraySubscriptExpr 'rs_font':'rs_font'
487  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
488  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
489  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
490
491  // Create helper variable for iterating through elements
492  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
493  clang::VarDecl *IIVD =
494      clang::VarDecl::Create(C,
495                             DC,
496                             StartLoc,
497                             Loc,
498                             &II,
499                             C.IntTy,
500                             C.getTrivialTypeSourceInfo(C.IntTy),
501                             clang::SC_None,
502                             clang::SC_None);
503  clang::Decl *IID = (clang::Decl *)IIVD;
504
505  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
506  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
507
508  // Form the actual destructor loop
509  // for (Init; Cond; Inc)
510  //   RSClearObjectCall;
511
512  // Init -> "rsIntIter = 0"
513  clang::DeclRefExpr *RefrsIntIter =
514      clang::DeclRefExpr::Create(C,
515                                 clang::NestedNameSpecifierLoc(),
516                                 IIVD,
517                                 Loc,
518                                 C.IntTy,
519                                 clang::VK_RValue,
520                                 NULL);
521
522  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
523      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
524
525  clang::BinaryOperator *Init =
526      new(C) clang::BinaryOperator(RefrsIntIter,
527                                   Int0,
528                                   clang::BO_Assign,
529                                   C.IntTy,
530                                   clang::VK_RValue,
531                                   clang::OK_Ordinary,
532                                   Loc);
533
534  // Cond -> "rsIntIter < NumArrayElements"
535  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
536      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
537
538  clang::BinaryOperator *Cond =
539      new(C) clang::BinaryOperator(RefrsIntIter,
540                                   NumArrayElementsExpr,
541                                   clang::BO_LT,
542                                   C.IntTy,
543                                   clang::VK_RValue,
544                                   clang::OK_Ordinary,
545                                   Loc);
546
547  // Inc -> "rsIntIter++"
548  clang::UnaryOperator *Inc =
549      new(C) clang::UnaryOperator(RefrsIntIter,
550                                  clang::UO_PostInc,
551                                  C.IntTy,
552                                  clang::VK_RValue,
553                                  clang::OK_Ordinary,
554                                  Loc);
555
556  // Body -> "rsClearObject(&VD[rsIntIter]);"
557  // Destructor loop operates on individual array elements
558
559  clang::Expr *RefRSArrPtr =
560      clang::ImplicitCastExpr::Create(C,
561          C.getPointerType(BaseType->getCanonicalTypeInternal()),
562          clang::CK_ArrayToPointerDecay,
563          RefRSArr,
564          NULL,
565          clang::VK_RValue);
566
567  clang::Expr *RefRSArrPtrSubscript =
568      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
569                                       RefrsIntIter,
570                                       BaseType->getCanonicalTypeInternal(),
571                                       clang::VK_RValue,
572                                       clang::OK_Ordinary,
573                                       Loc);
574
575  RSExportPrimitiveType::DataType DT =
576      RSExportPrimitiveType::GetRSSpecificType(BaseType);
577
578  clang::Stmt *RSClearObjectCall = NULL;
579  if (BaseType->isArrayType()) {
580    RSClearObjectCall =
581        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
582  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
583    RSClearObjectCall =
584        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
585  } else {
586    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
587  }
588
589  clang::ForStmt *DestructorLoop =
590      new(C) clang::ForStmt(C,
591                            Init,
592                            Cond,
593                            NULL,  // no condVar
594                            Inc,
595                            RSClearObjectCall,
596                            Loc,
597                            Loc,
598                            Loc);
599
600  StmtArray[StmtCtr++] = DestructorLoop;
601  slangAssert(StmtCtr == 2);
602
603  clang::CompoundStmt *CS =
604      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
605
606  return CS;
607}
608
609static unsigned CountRSObjectTypes(clang::ASTContext &C,
610                                   const clang::Type *T,
611                                   clang::SourceLocation Loc) {
612  slangAssert(T);
613  unsigned RSObjectCount = 0;
614
615  if (T->isArrayType()) {
616    return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
617  }
618
619  RSExportPrimitiveType::DataType DT =
620      RSExportPrimitiveType::GetRSSpecificType(T);
621  if (DT != RSExportPrimitiveType::DataTypeUnknown) {
622    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
623  }
624
625  if (T->isUnionType()) {
626    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
627    RD = RD->getDefinition();
628    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
629           FE = RD->field_end();
630         FI != FE;
631         FI++) {
632      const clang::FieldDecl *FD = *FI;
633      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
634      if (CountRSObjectTypes(C, FT, Loc)) {
635        slangAssert(false && "can't have unions with RS object types!");
636        return 0;
637      }
638    }
639  }
640
641  if (!T->isStructureType()) {
642    return 0;
643  }
644
645  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
646  RD = RD->getDefinition();
647  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
648         FE = RD->field_end();
649       FI != FE;
650       FI++) {
651    const clang::FieldDecl *FD = *FI;
652    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
653    if (CountRSObjectTypes(C, FT, Loc)) {
654      // Sub-structs should only count once (as should arrays, etc.)
655      RSObjectCount++;
656    }
657  }
658
659  return RSObjectCount;
660}
661
662static clang::Stmt *ClearStructRSObject(
663    clang::ASTContext &C,
664    clang::DeclContext *DC,
665    clang::Expr *RefRSStruct,
666    clang::SourceLocation StartLoc,
667    clang::SourceLocation Loc) {
668  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
669
670  slangAssert(!BaseType->isArrayType());
671
672  // Structs should show up as unknown primitive types
673  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
674              RSExportPrimitiveType::DataTypeUnknown);
675
676  unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
677
678  unsigned StmtCount = 0;
679  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
680  for (unsigned i = 0; i < FieldsToDestroy; i++) {
681    StmtArray[i] = NULL;
682  }
683
684  // Populate StmtArray by creating a destructor for each RS object field
685  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
686  RD = RD->getDefinition();
687  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
688         FE = RD->field_end();
689       FI != FE;
690       FI++) {
691    // We just look through all field declarations to see if we find a
692    // declaration for an RS object type (or an array of one).
693    bool IsArrayType = false;
694    clang::FieldDecl *FD = *FI;
695    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
696    const clang::Type *OrigType = FT;
697    while (FT && FT->isArrayType()) {
698      FT = FT->getArrayElementTypeNoTypeQual();
699      IsArrayType = true;
700    }
701
702    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
703      clang::DeclAccessPair FoundDecl =
704          clang::DeclAccessPair::make(FD, clang::AS_none);
705      clang::MemberExpr *RSObjectMember =
706          clang::MemberExpr::Create(C,
707                                    RefRSStruct,
708                                    false,
709                                    clang::NestedNameSpecifierLoc(),
710                                    FD,
711                                    FoundDecl,
712                                    clang::DeclarationNameInfo(),
713                                    NULL,
714                                    OrigType->getCanonicalTypeInternal(),
715                                    clang::VK_RValue,
716                                    clang::OK_Ordinary);
717
718      slangAssert(StmtCount < FieldsToDestroy);
719
720      if (IsArrayType) {
721        StmtArray[StmtCount++] = ClearArrayRSObject(C,
722                                                    DC,
723                                                    RSObjectMember,
724                                                    StartLoc,
725                                                    Loc);
726      } else {
727        StmtArray[StmtCount++] = ClearSingleRSObject(C,
728                                                     RSObjectMember,
729                                                     Loc);
730      }
731    } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
732      // In this case, we have a nested struct. We may not end up filling all
733      // of the spaces in StmtArray (sub-structs should handle themselves
734      // with separate compound statements).
735      clang::DeclAccessPair FoundDecl =
736          clang::DeclAccessPair::make(FD, clang::AS_none);
737      clang::MemberExpr *RSObjectMember =
738          clang::MemberExpr::Create(C,
739                                    RefRSStruct,
740                                    false,
741                                    clang::NestedNameSpecifierLoc(),
742                                    FD,
743                                    FoundDecl,
744                                    clang::DeclarationNameInfo(),
745                                    NULL,
746                                    OrigType->getCanonicalTypeInternal(),
747                                    clang::VK_RValue,
748                                    clang::OK_Ordinary);
749
750      if (IsArrayType) {
751        StmtArray[StmtCount++] = ClearArrayRSObject(C,
752                                                    DC,
753                                                    RSObjectMember,
754                                                    StartLoc,
755                                                    Loc);
756      } else {
757        StmtArray[StmtCount++] = ClearStructRSObject(C,
758                                                     DC,
759                                                     RSObjectMember,
760                                                     StartLoc,
761                                                     Loc);
762      }
763    }
764  }
765
766  slangAssert(StmtCount > 0);
767  clang::CompoundStmt *CS =
768      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
769
770  delete [] StmtArray;
771
772  return CS;
773}
774
775static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
776                                            clang::Expr *DstExpr,
777                                            clang::Expr *SrcExpr,
778                                            clang::SourceLocation StartLoc,
779                                            clang::SourceLocation Loc) {
780  const clang::Type *T = DstExpr->getType().getTypePtr();
781  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
782  slangAssert((SetObjectFD != NULL) &&
783              "rsSetObject doesn't cover all RS object types");
784
785  clang::QualType SetObjectFDType = SetObjectFD->getType();
786  clang::QualType SetObjectFDArgType[2];
787  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
788  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
789
790  clang::Expr *RefRSSetObjectFD =
791      clang::DeclRefExpr::Create(C,
792                                 clang::NestedNameSpecifierLoc(),
793                                 SetObjectFD,
794                                 Loc,
795                                 SetObjectFDType,
796                                 clang::VK_RValue,
797                                 NULL);
798
799  clang::Expr *RSSetObjectFP =
800      clang::ImplicitCastExpr::Create(C,
801                                      C.getPointerType(SetObjectFDType),
802                                      clang::CK_FunctionToPointerDecay,
803                                      RefRSSetObjectFD,
804                                      NULL,
805                                      clang::VK_RValue);
806
807  clang::Expr *ArgList[2];
808  ArgList[0] = new(C) clang::UnaryOperator(DstExpr,
809                                           clang::UO_AddrOf,
810                                           SetObjectFDArgType[0],
811                                           clang::VK_RValue,
812                                           clang::OK_Ordinary,
813                                           Loc);
814  ArgList[1] = SrcExpr;
815
816  clang::CallExpr *RSSetObjectCall =
817      new(C) clang::CallExpr(C,
818                             RSSetObjectFP,
819                             ArgList,
820                             2,
821                             SetObjectFD->getCallResultType(),
822                             clang::VK_RValue,
823                             Loc);
824
825  return RSSetObjectCall;
826}
827
// Forward declaration of the struct variant of the rsSetObject builders.
// NOTE(review): its definition and callers are outside the visible part of
// this file.
static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
                                            clang::Expr *LHS,
                                            clang::Expr *RHS,
                                            clang::SourceLocation StartLoc,
                                            clang::SourceLocation Loc);
833
834static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
835                                           clang::Expr *DstArr,
836                                           clang::Expr *SrcArr,
837                                           clang::SourceLocation StartLoc,
838                                           clang::SourceLocation Loc) {
839  clang::DeclContext *DC = NULL;
840  const clang::Type *BaseType = DstArr->getType().getTypePtr();
841  slangAssert(BaseType->isArrayType());
842
843  int NumArrayElements = ArrayDim(BaseType);
844  // Actually extract out the base RS object type for use later
845  BaseType = BaseType->getArrayElementTypeNoTypeQual();
846
847  clang::Stmt *StmtArray[2] = {NULL};
848  int StmtCtr = 0;
849
850  if (NumArrayElements <= 0) {
851    return NULL;
852  }
853
854  // Create helper variable for iterating through elements
855  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
856  clang::VarDecl *IIVD =
857      clang::VarDecl::Create(C,
858                             DC,
859                             StartLoc,
860                             Loc,
861                             &II,
862                             C.IntTy,
863                             C.getTrivialTypeSourceInfo(C.IntTy),
864                             clang::SC_None,
865                             clang::SC_None);
866  clang::Decl *IID = (clang::Decl *)IIVD;
867
868  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
869  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
870
871  // Form the actual loop
872  // for (Init; Cond; Inc)
873  //   RSSetObjectCall;
874
875  // Init -> "rsIntIter = 0"
876  clang::DeclRefExpr *RefrsIntIter =
877      clang::DeclRefExpr::Create(C,
878                                 clang::NestedNameSpecifierLoc(),
879                                 IIVD,
880                                 Loc,
881                                 C.IntTy,
882                                 clang::VK_RValue,
883                                 NULL);
884
885  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
886      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
887
888  clang::BinaryOperator *Init =
889      new(C) clang::BinaryOperator(RefrsIntIter,
890                                   Int0,
891                                   clang::BO_Assign,
892                                   C.IntTy,
893                                   clang::VK_RValue,
894                                   clang::OK_Ordinary,
895                                   Loc);
896
897  // Cond -> "rsIntIter < NumArrayElements"
898  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
899      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
900
901  clang::BinaryOperator *Cond =
902      new(C) clang::BinaryOperator(RefrsIntIter,
903                                   NumArrayElementsExpr,
904                                   clang::BO_LT,
905                                   C.IntTy,
906                                   clang::VK_RValue,
907                                   clang::OK_Ordinary,
908                                   Loc);
909
910  // Inc -> "rsIntIter++"
911  clang::UnaryOperator *Inc =
912      new(C) clang::UnaryOperator(RefrsIntIter,
913                                  clang::UO_PostInc,
914                                  C.IntTy,
915                                  clang::VK_RValue,
916                                  clang::OK_Ordinary,
917                                  Loc);
918
919  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
920  // Loop operates on individual array elements
921
922  clang::Expr *DstArrPtr =
923      clang::ImplicitCastExpr::Create(C,
924          C.getPointerType(BaseType->getCanonicalTypeInternal()),
925          clang::CK_ArrayToPointerDecay,
926          DstArr,
927          NULL,
928          clang::VK_RValue);
929
930  clang::Expr *DstArrPtrSubscript =
931      new(C) clang::ArraySubscriptExpr(DstArrPtr,
932                                       RefrsIntIter,
933                                       BaseType->getCanonicalTypeInternal(),
934                                       clang::VK_RValue,
935                                       clang::OK_Ordinary,
936                                       Loc);
937
938  clang::Expr *SrcArrPtr =
939      clang::ImplicitCastExpr::Create(C,
940          C.getPointerType(BaseType->getCanonicalTypeInternal()),
941          clang::CK_ArrayToPointerDecay,
942          SrcArr,
943          NULL,
944          clang::VK_RValue);
945
946  clang::Expr *SrcArrPtrSubscript =
947      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
948                                       RefrsIntIter,
949                                       BaseType->getCanonicalTypeInternal(),
950                                       clang::VK_RValue,
951                                       clang::OK_Ordinary,
952                                       Loc);
953
954  RSExportPrimitiveType::DataType DT =
955      RSExportPrimitiveType::GetRSSpecificType(BaseType);
956
957  clang::Stmt *RSSetObjectCall = NULL;
958  if (BaseType->isArrayType()) {
959    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
960                                             SrcArrPtrSubscript,
961                                             StartLoc, Loc);
962  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
963    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
964                                              SrcArrPtrSubscript,
965                                              StartLoc, Loc);
966  } else {
967    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
968                                              SrcArrPtrSubscript,
969                                              StartLoc, Loc);
970  }
971
972  clang::ForStmt *DestructorLoop =
973      new(C) clang::ForStmt(C,
974                            Init,
975                            Cond,
976                            NULL,  // no condVar
977                            Inc,
978                            RSSetObjectCall,
979                            Loc,
980                            Loc,
981                            Loc);
982
983  StmtArray[StmtCtr++] = DestructorLoop;
984  slangAssert(StmtCtr == 2);
985
986  clang::CompoundStmt *CS =
987      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
988
989  return CS;
990}
991
992static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
993                                            clang::Expr *LHS,
994                                            clang::Expr *RHS,
995                                            clang::SourceLocation StartLoc,
996                                            clang::SourceLocation Loc) {
997  clang::QualType QT = LHS->getType();
998  const clang::Type *T = QT.getTypePtr();
999  slangAssert(T->isStructureType());
1000  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1001
1002  // Keep an extra slot for the original copy (memcpy)
1003  unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
1004
1005  unsigned StmtCount = 0;
1006  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1007  for (unsigned i = 0; i < FieldsToSet; i++) {
1008    StmtArray[i] = NULL;
1009  }
1010
1011  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1012  RD = RD->getDefinition();
1013  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1014         FE = RD->field_end();
1015       FI != FE;
1016       FI++) {
1017    bool IsArrayType = false;
1018    clang::FieldDecl *FD = *FI;
1019    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1020    const clang::Type *OrigType = FT;
1021
1022    if (!CountRSObjectTypes(C, FT, Loc)) {
1023      // Skip to next if we don't have any viable RS object types
1024      continue;
1025    }
1026
1027    clang::DeclAccessPair FoundDecl =
1028        clang::DeclAccessPair::make(FD, clang::AS_none);
1029    clang::MemberExpr *DstMember =
1030        clang::MemberExpr::Create(C,
1031                                  LHS,
1032                                  false,
1033                                  clang::NestedNameSpecifierLoc(),
1034                                  FD,
1035                                  FoundDecl,
1036                                  clang::DeclarationNameInfo(),
1037                                  NULL,
1038                                  OrigType->getCanonicalTypeInternal(),
1039                                  clang::VK_RValue,
1040                                  clang::OK_Ordinary);
1041
1042    clang::MemberExpr *SrcMember =
1043        clang::MemberExpr::Create(C,
1044                                  RHS,
1045                                  false,
1046                                  clang::NestedNameSpecifierLoc(),
1047                                  FD,
1048                                  FoundDecl,
1049                                  clang::DeclarationNameInfo(),
1050                                  NULL,
1051                                  OrigType->getCanonicalTypeInternal(),
1052                                  clang::VK_RValue,
1053                                  clang::OK_Ordinary);
1054
1055    if (FT->isArrayType()) {
1056      FT = FT->getArrayElementTypeNoTypeQual();
1057      IsArrayType = true;
1058    }
1059
1060    RSExportPrimitiveType::DataType DT =
1061        RSExportPrimitiveType::GetRSSpecificType(FT);
1062
1063    if (IsArrayType) {
1064      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1065      DiagEngine.Report(
1066        clang::FullSourceLoc(Loc, C.getSourceManager()),
1067        DiagEngine.getCustomDiagID(
1068          clang::DiagnosticsEngine::Error,
1069          "Arrays of RS object types within structures cannot be copied"));
1070      // TODO(srhines): Support setting arrays of RS objects
1071      // StmtArray[StmtCount++] =
1072      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1073    } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
1074      StmtArray[StmtCount++] =
1075          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1076    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1077      StmtArray[StmtCount++] =
1078          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1079    } else {
1080      slangAssert(false);
1081    }
1082  }
1083
1084  slangAssert(StmtCount > 0 && StmtCount < FieldsToSet);
1085
1086  // We still need to actually do the overall struct copy. For simplicity,
1087  // we just do a straight-up assignment (which will still preserve all
1088  // the proper RS object reference counts).
1089  clang::BinaryOperator *CopyStruct =
1090      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1091                                   clang::VK_RValue, clang::OK_Ordinary, Loc);
1092  StmtArray[StmtCount++] = CopyStruct;
1093
1094  clang::CompoundStmt *CS =
1095      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
1096
1097  delete [] StmtArray;
1098
1099  return CS;
1100}
1101
1102}  // namespace
1103
1104void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1105    clang::BinaryOperator *AS) {
1106
1107  clang::QualType QT = AS->getType();
1108
1109  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1110      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1111
1112  clang::SourceLocation Loc = AS->getExprLoc();
1113  clang::SourceLocation StartLoc = AS->getExprLoc();
1114  clang::Stmt *UpdatedStmt = NULL;
1115
1116  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1117    // By definition, this is a struct assignment if we get here
1118    UpdatedStmt =
1119        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1120  } else {
1121    UpdatedStmt =
1122        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1123  }
1124
1125  RSASTReplace R(C);
1126  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1127  return;
1128}
1129
1130void RSObjectRefCount::Scope::AppendRSObjectInit(
1131    clang::VarDecl *VD,
1132    clang::DeclStmt *DS,
1133    RSExportPrimitiveType::DataType DT,
1134    clang::Expr *InitExpr) {
1135  slangAssert(VD);
1136
1137  if (!InitExpr) {
1138    return;
1139  }
1140
1141  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1142      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1143  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1144      RSExportPrimitiveType::DataTypeRSFont)->getLocation();
1145  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1146      RSExportPrimitiveType::DataTypeRSFont)->getInnerLocStart();
1147
1148  if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
1149    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1150    clang::DeclRefExpr *RefRSVar =
1151        clang::DeclRefExpr::Create(C,
1152                                   clang::NestedNameSpecifierLoc(),
1153                                   VD,
1154                                   Loc,
1155                                   T->getCanonicalTypeInternal(),
1156                                   clang::VK_RValue,
1157                                   NULL);
1158
1159    clang::Stmt *RSSetObjectOps =
1160        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1161
1162    std::list<clang::Stmt*> StmtList;
1163    StmtList.push_back(RSSetObjectOps);
1164    AppendAfterStmt(C, mCS, DS, StmtList);
1165    return;
1166  }
1167
1168  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1169  slangAssert((SetObjectFD != NULL) &&
1170              "rsSetObject doesn't cover all RS object types");
1171
1172  clang::QualType SetObjectFDType = SetObjectFD->getType();
1173  clang::QualType SetObjectFDArgType[2];
1174  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1175  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1176
1177  clang::Expr *RefRSSetObjectFD =
1178      clang::DeclRefExpr::Create(C,
1179                                 clang::NestedNameSpecifierLoc(),
1180                                 SetObjectFD,
1181                                 Loc,
1182                                 SetObjectFDType,
1183                                 clang::VK_RValue,
1184                                 NULL);
1185
1186  clang::Expr *RSSetObjectFP =
1187      clang::ImplicitCastExpr::Create(C,
1188                                      C.getPointerType(SetObjectFDType),
1189                                      clang::CK_FunctionToPointerDecay,
1190                                      RefRSSetObjectFD,
1191                                      NULL,
1192                                      clang::VK_RValue);
1193
1194  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1195  clang::DeclRefExpr *RefRSVar =
1196      clang::DeclRefExpr::Create(C,
1197                                 clang::NestedNameSpecifierLoc(),
1198                                 VD,
1199                                 Loc,
1200                                 T->getCanonicalTypeInternal(),
1201                                 clang::VK_RValue,
1202                                 NULL);
1203
1204  clang::Expr *ArgList[2];
1205  ArgList[0] = new(C) clang::UnaryOperator(RefRSVar,
1206                                           clang::UO_AddrOf,
1207                                           SetObjectFDArgType[0],
1208                                           clang::VK_RValue,
1209                                           clang::OK_Ordinary,
1210                                           Loc);
1211  ArgList[1] = InitExpr;
1212
1213  clang::CallExpr *RSSetObjectCall =
1214      new(C) clang::CallExpr(C,
1215                             RSSetObjectFP,
1216                             ArgList,
1217                             2,
1218                             SetObjectFD->getCallResultType(),
1219                             clang::VK_RValue,
1220                             Loc);
1221
1222  std::list<clang::Stmt*> StmtList;
1223  StmtList.push_back(RSSetObjectCall);
1224  AppendAfterStmt(C, mCS, DS, StmtList);
1225
1226  return;
1227}
1228
1229void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1230  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1231          E = mRSO.end();
1232        I != E;
1233        I++) {
1234    clang::VarDecl *VD = *I;
1235    clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1236    if (RSClearObjectCall) {
1237      DestructorVisitor DV((*mRSO.begin())->getASTContext(),
1238                           mCS,
1239                           RSClearObjectCall,
1240                           VD->getSourceRange().getBegin());
1241      DV.Visit(mCS);
1242      DV.InsertDestructors();
1243    }
1244  }
1245  return;
1246}
1247
1248clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1249    clang::VarDecl *VD,
1250    clang::DeclContext *DC) {
1251  slangAssert(VD);
1252  clang::ASTContext &C = VD->getASTContext();
1253  clang::SourceLocation Loc = VD->getLocation();
1254  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1255  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1256
1257  // Reference expr to target RS object variable
1258  clang::DeclRefExpr *RefRSVar =
1259      clang::DeclRefExpr::Create(C,
1260                                 clang::NestedNameSpecifierLoc(),
1261                                 VD,
1262                                 Loc,
1263                                 T->getCanonicalTypeInternal(),
1264                                 clang::VK_RValue,
1265                                 NULL);
1266
1267  if (T->isArrayType()) {
1268    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1269  }
1270
1271  RSExportPrimitiveType::DataType DT =
1272      RSExportPrimitiveType::GetRSSpecificType(T);
1273
1274  if (DT == RSExportPrimitiveType::DataTypeUnknown ||
1275      DT == RSExportPrimitiveType::DataTypeIsStruct) {
1276    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1277  }
1278
1279  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1280              "Should be RS object");
1281
1282  return ClearSingleRSObject(C, RefRSVar, Loc);
1283}
1284
1285bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1286                                          RSExportPrimitiveType::DataType *DT,
1287                                          clang::Expr **InitExpr) {
1288  slangAssert(VD && DT && InitExpr);
1289  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1290
1291  // Loop through array types to get to base type
1292  while (T && T->isArrayType()) {
1293    T = T->getArrayElementTypeNoTypeQual();
1294  }
1295
1296  bool DataTypeIsStructWithRSObject = false;
1297  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1298
1299  if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
1300    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1301      *DT = RSExportPrimitiveType::DataTypeIsStruct;
1302      DataTypeIsStructWithRSObject = true;
1303    } else {
1304      return false;
1305    }
1306  }
1307
1308  bool DataTypeIsRSObject = false;
1309  if (DataTypeIsStructWithRSObject) {
1310    DataTypeIsRSObject = true;
1311  } else {
1312    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1313  }
1314  *InitExpr = VD->getInit();
1315
1316  if (!DataTypeIsRSObject && *InitExpr) {
1317    // If we already have an initializer for a matrix type, we are done.
1318    return DataTypeIsRSObject;
1319  }
1320
1321  clang::Expr *ZeroInitializer =
1322      CreateZeroInitializerForRSSpecificType(*DT,
1323                                             VD->getASTContext(),
1324                                             VD->getLocation());
1325
1326  if (ZeroInitializer) {
1327    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1328    VD->setInit(ZeroInitializer);
1329  }
1330
1331  return DataTypeIsRSObject;
1332}
1333
1334clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1335    RSExportPrimitiveType::DataType DT,
1336    clang::ASTContext &C,
1337    const clang::SourceLocation &Loc) {
1338  clang::Expr *Res = NULL;
1339  switch (DT) {
1340    case RSExportPrimitiveType::DataTypeIsStruct:
1341    case RSExportPrimitiveType::DataTypeRSElement:
1342    case RSExportPrimitiveType::DataTypeRSType:
1343    case RSExportPrimitiveType::DataTypeRSAllocation:
1344    case RSExportPrimitiveType::DataTypeRSSampler:
1345    case RSExportPrimitiveType::DataTypeRSScript:
1346    case RSExportPrimitiveType::DataTypeRSMesh:
1347    case RSExportPrimitiveType::DataTypeRSPath:
1348    case RSExportPrimitiveType::DataTypeRSProgramFragment:
1349    case RSExportPrimitiveType::DataTypeRSProgramVertex:
1350    case RSExportPrimitiveType::DataTypeRSProgramRaster:
1351    case RSExportPrimitiveType::DataTypeRSProgramStore:
1352    case RSExportPrimitiveType::DataTypeRSFont: {
1353      //    (ImplicitCastExpr 'nullptr_t'
1354      //      (IntegerLiteral 0)))
1355      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1356      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1357      clang::Expr *CastToNull =
1358          clang::ImplicitCastExpr::Create(C,
1359                                          C.NullPtrTy,
1360                                          clang::CK_IntegralToPointer,
1361                                          Int0,
1362                                          NULL,
1363                                          clang::VK_RValue);
1364
1365      Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
1366      break;
1367    }
1368    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
1369    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
1370    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
1371      // RS matrix is not completely an RS object. They hold data by themselves.
1372      // (InitListExpr rs_matrix2x2
1373      //   (InitListExpr float[4]
1374      //     (FloatingLiteral 0)
1375      //     (FloatingLiteral 0)
1376      //     (FloatingLiteral 0)
1377      //     (FloatingLiteral 0)))
1378      clang::QualType FloatTy = C.FloatTy;
1379      // Constructor sets value to 0.0f by default
1380      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1381      clang::FloatingLiteral *Float0Val =
1382          clang::FloatingLiteral::Create(C,
1383                                         Val,
1384                                         /* isExact = */true,
1385                                         FloatTy,
1386                                         Loc);
1387
1388      unsigned N = 0;
1389      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
1390        N = 2;
1391      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
1392        N = 3;
1393      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
1394        N = 4;
1395
1396      // Directly allocate 16 elements instead of dynamically allocate N*N
1397      clang::Expr *InitVals[16];
1398      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
1399        InitVals[i] = Float0Val;
1400      clang::Expr *InitExpr =
1401          new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
1402      InitExpr->setType(C.getConstantArrayType(FloatTy,
1403                                               llvm::APInt(32, N * N),
1404                                               clang::ArrayType::Normal,
1405                                               /* EltTypeQuals = */0));
1406
1407      Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
1408      break;
1409    }
1410    case RSExportPrimitiveType::DataTypeUnknown:
1411    case RSExportPrimitiveType::DataTypeFloat16:
1412    case RSExportPrimitiveType::DataTypeFloat32:
1413    case RSExportPrimitiveType::DataTypeFloat64:
1414    case RSExportPrimitiveType::DataTypeSigned8:
1415    case RSExportPrimitiveType::DataTypeSigned16:
1416    case RSExportPrimitiveType::DataTypeSigned32:
1417    case RSExportPrimitiveType::DataTypeSigned64:
1418    case RSExportPrimitiveType::DataTypeUnsigned8:
1419    case RSExportPrimitiveType::DataTypeUnsigned16:
1420    case RSExportPrimitiveType::DataTypeUnsigned32:
1421    case RSExportPrimitiveType::DataTypeUnsigned64:
1422    case RSExportPrimitiveType::DataTypeBoolean:
1423    case RSExportPrimitiveType::DataTypeUnsigned565:
1424    case RSExportPrimitiveType::DataTypeUnsigned5551:
1425    case RSExportPrimitiveType::DataTypeUnsigned4444:
1426    case RSExportPrimitiveType::DataTypeMax: {
1427      slangAssert(false && "Not RS object type!");
1428    }
1429    // No default case will enable compiler detecting the missing cases
1430  }
1431
1432  return Res;
1433}
1434
1435void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1436  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1437       I != E;
1438       I++) {
1439    clang::Decl *D = *I;
1440    if (D->getKind() == clang::Decl::Var) {
1441      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1442      RSExportPrimitiveType::DataType DT =
1443          RSExportPrimitiveType::DataTypeUnknown;
1444      clang::Expr *InitExpr = NULL;
1445      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1446        getCurrentScope()->addRSObject(VD);
1447        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1448      }
1449    }
1450  }
1451  return;
1452}
1453
1454void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1455  if (!CS->body_empty()) {
1456    // Push a new scope
1457    Scope *S = new Scope(CS);
1458    mScopeStack.push(S);
1459
1460    VisitStmt(CS);
1461
1462    // Destroy the scope
1463    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1464    S->InsertLocalVarDestructors();
1465    mScopeStack.pop();
1466    delete S;
1467  }
1468  return;
1469}
1470
1471void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1472  clang::QualType QT = AS->getType();
1473
1474  if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1475    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1476  }
1477
1478  return;
1479}
1480
1481void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1482  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1483       I != E;
1484       I++) {
1485    if (clang::Stmt *Child = *I) {
1486      Visit(Child);
1487    }
1488  }
1489  return;
1490}
1491
1492// This function walks the list of global variables and (potentially) creates
1493// a single global static destructor function that properly decrements
1494// reference counts on the contained RS object types.
1495clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1496  Init();
1497
1498  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1499  clang::SourceLocation loc;
1500
1501  llvm::StringRef SR(".rs.dtor");
1502  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1503  clang::DeclarationName N(&II);
1504  clang::FunctionProtoType::ExtProtoInfo EPI;
1505  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy, NULL, 0, EPI);
1506  clang::FunctionDecl *FD = NULL;
1507
1508  // Generate rsClearObject() call chains for every global variable
1509  // (whether static or extern).
1510  std::list<clang::Stmt *> StmtList;
1511  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1512          E = DC->decls_end(); I != E; I++) {
1513    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1514    if (VD) {
1515      if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1516        if (!FD) {
1517          // Only create FD if we are going to use it.
1518          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, NULL);
1519        }
1520        // Make sure to create any helpers within the function's DeclContext,
1521        // not the one associated with the global translation unit.
1522        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1523        StmtList.push_back(RSClearObjectCall);
1524      }
1525    }
1526  }
1527
1528  // Nothing needs to be destroyed, so don't emit a dtor.
1529  if (StmtList.empty()) {
1530    return NULL;
1531  }
1532
1533  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1534
1535  FD->setBody(CS);
1536
1537  return FD;
1538}
1539
1540}  // namespace slang
1541