slang_rs_object_ref_count.cpp revision 78e69cb06b9b0683b2ac9dcafde87b867690ef2f
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <stack>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang_rs.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32
33namespace slang {
34
35clang::FunctionDecl *RSObjectRefCount::
36    RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
37                  RSExportPrimitiveType::FirstRSObjectType + 1];
38clang::FunctionDecl *RSObjectRefCount::
39    RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
40                    RSExportPrimitiveType::FirstRSObjectType + 1];
41
42void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43  for (unsigned i = 0;
44       i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
45       i++) {
46    RSSetObjectFD[i] = NULL;
47    RSClearObjectFD[i] = NULL;
48  }
49
50  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
51
52  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
53          E = TUDecl->decls_end(); I != E; I++) {
54    if ((I->getKind() >= clang::Decl::firstFunction) &&
55        (I->getKind() <= clang::Decl::lastFunction)) {
56      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
57
58      // points to RSSetObjectFD or RSClearObjectFD
59      clang::FunctionDecl **RSObjectFD;
60
61      if (FD->getName() == "rsSetObject") {
62        slangAssert((FD->getNumParams() == 2) &&
63                    "Invalid rsSetObject function prototype (# params)");
64        RSObjectFD = RSSetObjectFD;
65      } else if (FD->getName() == "rsClearObject") {
66        slangAssert((FD->getNumParams() == 1) &&
67                    "Invalid rsClearObject function prototype (# params)");
68        RSObjectFD = RSClearObjectFD;
69      } else {
70        continue;
71      }
72
73      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
74      clang::QualType PVT = PVD->getOriginalType();
75      // The first parameter must be a pointer like rs_allocation*
76      slangAssert(PVT->isPointerType() &&
77          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
78
79      // The rs object type passed to the FD
80      clang::QualType RST = PVT->getPointeeType();
81      RSExportPrimitiveType::DataType DT =
82          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
83      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
84             && "must be RS object type");
85
86      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
87    }
88  }
89}
90
91namespace {
92
93// This function constructs a new CompoundStmt from the input StmtList.
94static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
95      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
96  unsigned NewStmtCount = StmtList.size();
97  unsigned CompoundStmtCount = 0;
98
99  clang::Stmt **CompoundStmtList;
100  CompoundStmtList = new clang::Stmt*[NewStmtCount];
101
102  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
103  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
104  for ( ; I != E; I++) {
105    CompoundStmtList[CompoundStmtCount++] = *I;
106  }
107  slangAssert(CompoundStmtCount == NewStmtCount);
108
109  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(C,
110                                                       CompoundStmtList,
111                                                       CompoundStmtCount,
112                                                       Loc,
113                                                       Loc);
114
115  delete [] CompoundStmtList;
116
117  return CS;
118}
119
120static void AppendAfterStmt(clang::ASTContext &C,
121                            clang::CompoundStmt *CS,
122                            clang::Stmt *S,
123                            std::list<clang::Stmt*> &StmtList) {
124  slangAssert(CS);
125  clang::CompoundStmt::body_iterator bI = CS->body_begin();
126  clang::CompoundStmt::body_iterator bE = CS->body_end();
127  clang::Stmt **UpdatedStmtList =
128      new clang::Stmt*[CS->size() + StmtList.size()];
129
130  unsigned UpdatedStmtCount = 0;
131  unsigned Once = 0;
132  for ( ; bI != bE; bI++) {
133    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
134      // If we come across a return here, we don't have anything we can
135      // reasonably replace. We should have already inserted our destructor
136      // code in the proper spot, so we just clean up and return.
137      delete [] UpdatedStmtList;
138
139      return;
140    }
141
142    UpdatedStmtList[UpdatedStmtCount++] = *bI;
143
144    if ((*bI == S) && !Once) {
145      Once++;
146      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
147      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
148      for ( ; I != E; I++) {
149        UpdatedStmtList[UpdatedStmtCount++] = *I;
150      }
151    }
152  }
153  slangAssert(Once <= 1);
154
155  // When S is NULL, we are appending to the end of the CompoundStmt.
156  if (!S) {
157    slangAssert(Once == 0);
158    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
159    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
160    for ( ; I != E; I++) {
161      UpdatedStmtList[UpdatedStmtCount++] = *I;
162    }
163  }
164
165  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
166
167  delete [] UpdatedStmtList;
168
169  return;
170}
171
172// This class visits a compound statement and inserts the StmtList containing
173// destructors in proper locations. This includes inserting them before any
174// return statement in any sub-block, at the end of the logical enclosing
175// scope (compound statement), and/or before any break/continue statement that
176// would resume outside the declared scope. We will not handle the case for
177// goto statements that leave a local scope.
178//
179// To accomplish these goals, it collects a list of sub-Stmt's that
180// correspond to scope exit points. It then uses an RSASTReplace visitor to
181// transform the AST, inserting appropriate destructors before each of those
182// sub-Stmt's (and also before the exit of the outermost containing Stmt for
183// the scope).
184class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
185 private:
186  clang::ASTContext &mCtx;
187
188  // The loop depth of the currently visited node.
189  int mLoopDepth;
190
191  // The switch statement depth of the currently visited node.
192  // Note that this is tracked separately from the loop depth because
193  // SwitchStmt-contained ContinueStmt's should have destructors for the
194  // corresponding loop scope.
195  int mSwitchDepth;
196
197  // The outermost statement block that we are currently visiting.
198  // This should always be a CompoundStmt.
199  clang::Stmt *mOuterStmt;
200
201  // The list of destructors to execute for this scope.
202  std::list<clang::Stmt*> &mStmtList;
203
204  // The stack of statements which should be replaced by a compound statement
205  // containing the new destructor calls followed by the original Stmt.
206  std::stack<clang::Stmt*> mReplaceStmtStack;
207
208 public:
209  DestructorVisitor(clang::ASTContext &C,
210                    clang::Stmt* OuterStmt,
211                    std::list<clang::Stmt*> &StmtList);
212
213  // This code walks the collected list of Stmts to replace and actually does
214  // the replacement. It also finishes up by appending appropriate destructors
215  // to the current outermost CompoundStmt.
216  void InsertDestructors() {
217    clang::Stmt *S = NULL;
218    while (!mReplaceStmtStack.empty()) {
219      S = mReplaceStmtStack.top();
220      mReplaceStmtStack.pop();
221
222      mStmtList.push_back(S);
223      clang::CompoundStmt *CS =
224          BuildCompoundStmt(mCtx, mStmtList, S->getLocEnd());
225      mStmtList.pop_back();
226
227      RSASTReplace R(mCtx);
228      R.ReplaceStmt(mOuterStmt, S, CS);
229    }
230    clang::CompoundStmt *CS = dyn_cast<clang::CompoundStmt>(mOuterStmt);
231    slangAssert(CS);
232    if (CS) {
233      AppendAfterStmt(mCtx, CS, NULL, mStmtList);
234    }
235  }
236
237  void VisitStmt(clang::Stmt *S);
238  void VisitCompoundStmt(clang::CompoundStmt *CS);
239
240  void VisitBreakStmt(clang::BreakStmt *BS);
241  void VisitCaseStmt(clang::CaseStmt *CS);
242  void VisitContinueStmt(clang::ContinueStmt *CS);
243  void VisitDefaultStmt(clang::DefaultStmt *DS);
244  void VisitDoStmt(clang::DoStmt *DS);
245  void VisitForStmt(clang::ForStmt *FS);
246  void VisitIfStmt(clang::IfStmt *IS);
247  void VisitReturnStmt(clang::ReturnStmt *RS);
248  void VisitSwitchCase(clang::SwitchCase *SC);
249  void VisitSwitchStmt(clang::SwitchStmt *SS);
250  void VisitWhileStmt(clang::WhileStmt *WS);
251};
252
253DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
254                         clang::Stmt *OuterStmt,
255                         std::list<clang::Stmt*> &StmtList)
256  : mCtx(C),
257    mLoopDepth(0),
258    mSwitchDepth(0),
259    mOuterStmt(OuterStmt),
260    mStmtList(StmtList) {
261  return;
262}
263
264void DestructorVisitor::VisitStmt(clang::Stmt *S) {
265  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
266       I != E;
267       I++) {
268    if (clang::Stmt *Child = *I) {
269      Visit(Child);
270    }
271  }
272  return;
273}
274
275void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
276  VisitStmt(CS);
277  return;
278}
279
280void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
281  VisitStmt(BS);
282  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
283    mReplaceStmtStack.push(BS);
284  }
285  return;
286}
287
288void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
289  VisitStmt(CS);
290  return;
291}
292
293void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
294  VisitStmt(CS);
295  if (mLoopDepth == 0) {
296    // Switch statements can have nested continues.
297    mReplaceStmtStack.push(CS);
298  }
299  return;
300}
301
302void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
303  VisitStmt(DS);
304  return;
305}
306
307void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
308  mLoopDepth++;
309  VisitStmt(DS);
310  mLoopDepth--;
311  return;
312}
313
314void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
315  mLoopDepth++;
316  VisitStmt(FS);
317  mLoopDepth--;
318  return;
319}
320
321void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
322  VisitStmt(IS);
323  return;
324}
325
326void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
327  mReplaceStmtStack.push(RS);
328  return;
329}
330
331void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
332  slangAssert(false && "Both case and default have specialized handlers");
333  VisitStmt(SC);
334  return;
335}
336
337void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
338  mSwitchDepth++;
339  VisitStmt(SS);
340  mSwitchDepth--;
341  return;
342}
343
344void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
345  mLoopDepth++;
346  VisitStmt(WS);
347  mLoopDepth--;
348  return;
349}
350
351clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
352                                 clang::Expr *RefRSVar,
353                                 clang::SourceLocation Loc) {
354  slangAssert(RefRSVar);
355  const clang::Type *T = RefRSVar->getType().getTypePtr();
356  slangAssert(!T->isArrayType() &&
357              "Should not be destroying arrays with this function");
358
359  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
360  slangAssert((ClearObjectFD != NULL) &&
361              "rsClearObject doesn't cover all RS object types");
362
363  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
364  clang::QualType ClearObjectFDArgType =
365      ClearObjectFD->getParamDecl(0)->getOriginalType();
366
367  // Example destructor for "rs_font localFont;"
368  //
369  // (CallExpr 'void'
370  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
371  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
372  //   (UnaryOperator 'rs_font *' prefix '&'
373  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
374
375  // Get address of targeted RS object
376  clang::Expr *AddrRefRSVar =
377      new(C) clang::UnaryOperator(RefRSVar,
378                                  clang::UO_AddrOf,
379                                  ClearObjectFDArgType,
380                                  clang::VK_RValue,
381                                  clang::OK_Ordinary,
382                                  Loc);
383
384  clang::Expr *RefRSClearObjectFD =
385      clang::DeclRefExpr::Create(C,
386                                 clang::NestedNameSpecifierLoc(),
387                                 ClearObjectFD,
388                                 ClearObjectFD->getLocation(),
389                                 ClearObjectFDType,
390                                 clang::VK_RValue,
391                                 NULL);
392
393  clang::Expr *RSClearObjectFP =
394      clang::ImplicitCastExpr::Create(C,
395                                      C.getPointerType(ClearObjectFDType),
396                                      clang::CK_FunctionToPointerDecay,
397                                      RefRSClearObjectFD,
398                                      NULL,
399                                      clang::VK_RValue);
400
401  clang::CallExpr *RSClearObjectCall =
402      new(C) clang::CallExpr(C,
403                             RSClearObjectFP,
404                             &AddrRefRSVar,
405                             1,
406                             ClearObjectFD->getCallResultType(),
407                             clang::VK_RValue,
408                             Loc);
409
410  return RSClearObjectCall;
411}
412
413static int ArrayDim(const clang::Type *T) {
414  if (!T || !T->isArrayType()) {
415    return 0;
416  }
417
418  const clang::ConstantArrayType *CAT =
419    static_cast<const clang::ConstantArrayType *>(T);
420  return static_cast<int>(CAT->getSize().getSExtValue());
421}
422
423static clang::Stmt *ClearStructRSObject(
424    clang::ASTContext &C,
425    clang::DeclContext *DC,
426    clang::Expr *RefRSStruct,
427    clang::SourceLocation StartLoc,
428    clang::SourceLocation Loc);
429
430static clang::Stmt *ClearArrayRSObject(
431    clang::ASTContext &C,
432    clang::DeclContext *DC,
433    clang::Expr *RefRSArr,
434    clang::SourceLocation StartLoc,
435    clang::SourceLocation Loc) {
436  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
437  slangAssert(BaseType->isArrayType());
438
439  int NumArrayElements = ArrayDim(BaseType);
440  // Actually extract out the base RS object type for use later
441  BaseType = BaseType->getArrayElementTypeNoTypeQual();
442
443  clang::Stmt *StmtArray[2] = {NULL};
444  int StmtCtr = 0;
445
446  if (NumArrayElements <= 0) {
447    return NULL;
448  }
449
450  // Example destructor loop for "rs_font fontArr[10];"
451  //
452  // (CompoundStmt
453  //   (DeclStmt "int rsIntIter")
454  //   (ForStmt
455  //     (BinaryOperator 'int' '='
456  //       (DeclRefExpr 'int' Var='rsIntIter')
457  //       (IntegerLiteral 'int' 0))
458  //     (BinaryOperator 'int' '<'
459  //       (DeclRefExpr 'int' Var='rsIntIter')
460  //       (IntegerLiteral 'int' 10)
461  //     NULL << CondVar >>
462  //     (UnaryOperator 'int' postfix '++'
463  //       (DeclRefExpr 'int' Var='rsIntIter'))
464  //     (CallExpr 'void'
465  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
466  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
467  //       (UnaryOperator 'rs_font *' prefix '&'
468  //         (ArraySubscriptExpr 'rs_font':'rs_font'
469  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
470  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
471  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
472
473  // Create helper variable for iterating through elements
474  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
475  clang::VarDecl *IIVD =
476      clang::VarDecl::Create(C,
477                             DC,
478                             StartLoc,
479                             Loc,
480                             &II,
481                             C.IntTy,
482                             C.getTrivialTypeSourceInfo(C.IntTy),
483                             clang::SC_None,
484                             clang::SC_None);
485  clang::Decl *IID = (clang::Decl *)IIVD;
486
487  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
488  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
489
490  // Form the actual destructor loop
491  // for (Init; Cond; Inc)
492  //   RSClearObjectCall;
493
494  // Init -> "rsIntIter = 0"
495  clang::DeclRefExpr *RefrsIntIter =
496      clang::DeclRefExpr::Create(C,
497                                 clang::NestedNameSpecifierLoc(),
498                                 IIVD,
499                                 Loc,
500                                 C.IntTy,
501                                 clang::VK_RValue,
502                                 NULL);
503
504  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
505      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
506
507  clang::BinaryOperator *Init =
508      new(C) clang::BinaryOperator(RefrsIntIter,
509                                   Int0,
510                                   clang::BO_Assign,
511                                   C.IntTy,
512                                   clang::VK_RValue,
513                                   clang::OK_Ordinary,
514                                   Loc);
515
516  // Cond -> "rsIntIter < NumArrayElements"
517  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
518      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
519
520  clang::BinaryOperator *Cond =
521      new(C) clang::BinaryOperator(RefrsIntIter,
522                                   NumArrayElementsExpr,
523                                   clang::BO_LT,
524                                   C.IntTy,
525                                   clang::VK_RValue,
526                                   clang::OK_Ordinary,
527                                   Loc);
528
529  // Inc -> "rsIntIter++"
530  clang::UnaryOperator *Inc =
531      new(C) clang::UnaryOperator(RefrsIntIter,
532                                  clang::UO_PostInc,
533                                  C.IntTy,
534                                  clang::VK_RValue,
535                                  clang::OK_Ordinary,
536                                  Loc);
537
538  // Body -> "rsClearObject(&VD[rsIntIter]);"
539  // Destructor loop operates on individual array elements
540
541  clang::Expr *RefRSArrPtr =
542      clang::ImplicitCastExpr::Create(C,
543          C.getPointerType(BaseType->getCanonicalTypeInternal()),
544          clang::CK_ArrayToPointerDecay,
545          RefRSArr,
546          NULL,
547          clang::VK_RValue);
548
549  clang::Expr *RefRSArrPtrSubscript =
550      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
551                                       RefrsIntIter,
552                                       BaseType->getCanonicalTypeInternal(),
553                                       clang::VK_RValue,
554                                       clang::OK_Ordinary,
555                                       Loc);
556
557  RSExportPrimitiveType::DataType DT =
558      RSExportPrimitiveType::GetRSSpecificType(BaseType);
559
560  clang::Stmt *RSClearObjectCall = NULL;
561  if (BaseType->isArrayType()) {
562    RSClearObjectCall =
563        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
564  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
565    RSClearObjectCall =
566        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
567  } else {
568    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
569  }
570
571  clang::ForStmt *DestructorLoop =
572      new(C) clang::ForStmt(C,
573                            Init,
574                            Cond,
575                            NULL,  // no condVar
576                            Inc,
577                            RSClearObjectCall,
578                            Loc,
579                            Loc,
580                            Loc);
581
582  StmtArray[StmtCtr++] = DestructorLoop;
583  slangAssert(StmtCtr == 2);
584
585  clang::CompoundStmt *CS =
586      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
587
588  return CS;
589}
590
591static unsigned CountRSObjectTypes(clang::ASTContext &C,
592                                   const clang::Type *T,
593                                   clang::SourceLocation Loc) {
594  slangAssert(T);
595  unsigned RSObjectCount = 0;
596
597  if (T->isArrayType()) {
598    return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
599  }
600
601  RSExportPrimitiveType::DataType DT =
602      RSExportPrimitiveType::GetRSSpecificType(T);
603  if (DT != RSExportPrimitiveType::DataTypeUnknown) {
604    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
605  }
606
607  if (T->isUnionType()) {
608    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
609    RD = RD->getDefinition();
610    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
611           FE = RD->field_end();
612         FI != FE;
613         FI++) {
614      const clang::FieldDecl *FD = *FI;
615      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
616      if (CountRSObjectTypes(C, FT, Loc)) {
617        slangAssert(false && "can't have unions with RS object types!");
618        return 0;
619      }
620    }
621  }
622
623  if (!T->isStructureType()) {
624    return 0;
625  }
626
627  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
628  RD = RD->getDefinition();
629  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
630         FE = RD->field_end();
631       FI != FE;
632       FI++) {
633    const clang::FieldDecl *FD = *FI;
634    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
635    if (CountRSObjectTypes(C, FT, Loc)) {
636      // Sub-structs should only count once (as should arrays, etc.)
637      RSObjectCount++;
638    }
639  }
640
641  return RSObjectCount;
642}
643
644static clang::Stmt *ClearStructRSObject(
645    clang::ASTContext &C,
646    clang::DeclContext *DC,
647    clang::Expr *RefRSStruct,
648    clang::SourceLocation StartLoc,
649    clang::SourceLocation Loc) {
650  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
651
652  slangAssert(!BaseType->isArrayType());
653
654  // Structs should show up as unknown primitive types
655  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
656              RSExportPrimitiveType::DataTypeUnknown);
657
658  unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
659
660  unsigned StmtCount = 0;
661  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
662  for (unsigned i = 0; i < FieldsToDestroy; i++) {
663    StmtArray[i] = NULL;
664  }
665
666  // Populate StmtArray by creating a destructor for each RS object field
667  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
668  RD = RD->getDefinition();
669  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
670         FE = RD->field_end();
671       FI != FE;
672       FI++) {
673    // We just look through all field declarations to see if we find a
674    // declaration for an RS object type (or an array of one).
675    bool IsArrayType = false;
676    clang::FieldDecl *FD = *FI;
677    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
678    const clang::Type *OrigType = FT;
679    while (FT && FT->isArrayType()) {
680      FT = FT->getArrayElementTypeNoTypeQual();
681      IsArrayType = true;
682    }
683
684    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
685      clang::DeclAccessPair FoundDecl =
686          clang::DeclAccessPair::make(FD, clang::AS_none);
687      clang::MemberExpr *RSObjectMember =
688          clang::MemberExpr::Create(C,
689                                    RefRSStruct,
690                                    false,
691                                    clang::NestedNameSpecifierLoc(),
692                                    FD,
693                                    FoundDecl,
694                                    clang::DeclarationNameInfo(),
695                                    NULL,
696                                    OrigType->getCanonicalTypeInternal(),
697                                    clang::VK_RValue,
698                                    clang::OK_Ordinary);
699
700      slangAssert(StmtCount < FieldsToDestroy);
701
702      if (IsArrayType) {
703        StmtArray[StmtCount++] = ClearArrayRSObject(C,
704                                                    DC,
705                                                    RSObjectMember,
706                                                    StartLoc,
707                                                    Loc);
708      } else {
709        StmtArray[StmtCount++] = ClearSingleRSObject(C,
710                                                     RSObjectMember,
711                                                     Loc);
712      }
713    } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
714      // In this case, we have a nested struct. We may not end up filling all
715      // of the spaces in StmtArray (sub-structs should handle themselves
716      // with separate compound statements).
717      clang::DeclAccessPair FoundDecl =
718          clang::DeclAccessPair::make(FD, clang::AS_none);
719      clang::MemberExpr *RSObjectMember =
720          clang::MemberExpr::Create(C,
721                                    RefRSStruct,
722                                    false,
723                                    clang::NestedNameSpecifierLoc(),
724                                    FD,
725                                    FoundDecl,
726                                    clang::DeclarationNameInfo(),
727                                    NULL,
728                                    OrigType->getCanonicalTypeInternal(),
729                                    clang::VK_RValue,
730                                    clang::OK_Ordinary);
731
732      if (IsArrayType) {
733        StmtArray[StmtCount++] = ClearArrayRSObject(C,
734                                                    DC,
735                                                    RSObjectMember,
736                                                    StartLoc,
737                                                    Loc);
738      } else {
739        StmtArray[StmtCount++] = ClearStructRSObject(C,
740                                                     DC,
741                                                     RSObjectMember,
742                                                     StartLoc,
743                                                     Loc);
744      }
745    }
746  }
747
748  slangAssert(StmtCount > 0);
749  clang::CompoundStmt *CS =
750      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
751
752  delete [] StmtArray;
753
754  return CS;
755}
756
757static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
758                                            clang::Expr *DstExpr,
759                                            clang::Expr *SrcExpr,
760                                            clang::SourceLocation StartLoc,
761                                            clang::SourceLocation Loc) {
762  const clang::Type *T = DstExpr->getType().getTypePtr();
763  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
764  slangAssert((SetObjectFD != NULL) &&
765              "rsSetObject doesn't cover all RS object types");
766
767  clang::QualType SetObjectFDType = SetObjectFD->getType();
768  clang::QualType SetObjectFDArgType[2];
769  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
770  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
771
772  clang::Expr *RefRSSetObjectFD =
773      clang::DeclRefExpr::Create(C,
774                                 clang::NestedNameSpecifierLoc(),
775                                 SetObjectFD,
776                                 Loc,
777                                 SetObjectFDType,
778                                 clang::VK_RValue,
779                                 NULL);
780
781  clang::Expr *RSSetObjectFP =
782      clang::ImplicitCastExpr::Create(C,
783                                      C.getPointerType(SetObjectFDType),
784                                      clang::CK_FunctionToPointerDecay,
785                                      RefRSSetObjectFD,
786                                      NULL,
787                                      clang::VK_RValue);
788
789  clang::Expr *ArgList[2];
790  ArgList[0] = new(C) clang::UnaryOperator(DstExpr,
791                                           clang::UO_AddrOf,
792                                           SetObjectFDArgType[0],
793                                           clang::VK_RValue,
794                                           clang::OK_Ordinary,
795                                           Loc);
796  ArgList[1] = SrcExpr;
797
798  clang::CallExpr *RSSetObjectCall =
799      new(C) clang::CallExpr(C,
800                             RSSetObjectFP,
801                             ArgList,
802                             2,
803                             SetObjectFD->getCallResultType(),
804                             clang::VK_RValue,
805                             Loc);
806
807  return RSSetObjectCall;
808}
809
810static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
811                                            clang::Expr *LHS,
812                                            clang::Expr *RHS,
813                                            clang::SourceLocation StartLoc,
814                                            clang::SourceLocation Loc);
815
816static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
817                                           clang::Expr *DstArr,
818                                           clang::Expr *SrcArr,
819                                           clang::SourceLocation StartLoc,
820                                           clang::SourceLocation Loc) {
821  clang::DeclContext *DC = NULL;
822  const clang::Type *BaseType = DstArr->getType().getTypePtr();
823  slangAssert(BaseType->isArrayType());
824
825  int NumArrayElements = ArrayDim(BaseType);
826  // Actually extract out the base RS object type for use later
827  BaseType = BaseType->getArrayElementTypeNoTypeQual();
828
829  clang::Stmt *StmtArray[2] = {NULL};
830  int StmtCtr = 0;
831
832  if (NumArrayElements <= 0) {
833    return NULL;
834  }
835
836  // Create helper variable for iterating through elements
837  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
838  clang::VarDecl *IIVD =
839      clang::VarDecl::Create(C,
840                             DC,
841                             StartLoc,
842                             Loc,
843                             &II,
844                             C.IntTy,
845                             C.getTrivialTypeSourceInfo(C.IntTy),
846                             clang::SC_None,
847                             clang::SC_None);
848  clang::Decl *IID = (clang::Decl *)IIVD;
849
850  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
851  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
852
853  // Form the actual loop
854  // for (Init; Cond; Inc)
855  //   RSSetObjectCall;
856
857  // Init -> "rsIntIter = 0"
858  clang::DeclRefExpr *RefrsIntIter =
859      clang::DeclRefExpr::Create(C,
860                                 clang::NestedNameSpecifierLoc(),
861                                 IIVD,
862                                 Loc,
863                                 C.IntTy,
864                                 clang::VK_RValue,
865                                 NULL);
866
867  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
868      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
869
870  clang::BinaryOperator *Init =
871      new(C) clang::BinaryOperator(RefrsIntIter,
872                                   Int0,
873                                   clang::BO_Assign,
874                                   C.IntTy,
875                                   clang::VK_RValue,
876                                   clang::OK_Ordinary,
877                                   Loc);
878
879  // Cond -> "rsIntIter < NumArrayElements"
880  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
881      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
882
883  clang::BinaryOperator *Cond =
884      new(C) clang::BinaryOperator(RefrsIntIter,
885                                   NumArrayElementsExpr,
886                                   clang::BO_LT,
887                                   C.IntTy,
888                                   clang::VK_RValue,
889                                   clang::OK_Ordinary,
890                                   Loc);
891
892  // Inc -> "rsIntIter++"
893  clang::UnaryOperator *Inc =
894      new(C) clang::UnaryOperator(RefrsIntIter,
895                                  clang::UO_PostInc,
896                                  C.IntTy,
897                                  clang::VK_RValue,
898                                  clang::OK_Ordinary,
899                                  Loc);
900
901  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
902  // Loop operates on individual array elements
903
904  clang::Expr *DstArrPtr =
905      clang::ImplicitCastExpr::Create(C,
906          C.getPointerType(BaseType->getCanonicalTypeInternal()),
907          clang::CK_ArrayToPointerDecay,
908          DstArr,
909          NULL,
910          clang::VK_RValue);
911
912  clang::Expr *DstArrPtrSubscript =
913      new(C) clang::ArraySubscriptExpr(DstArrPtr,
914                                       RefrsIntIter,
915                                       BaseType->getCanonicalTypeInternal(),
916                                       clang::VK_RValue,
917                                       clang::OK_Ordinary,
918                                       Loc);
919
920  clang::Expr *SrcArrPtr =
921      clang::ImplicitCastExpr::Create(C,
922          C.getPointerType(BaseType->getCanonicalTypeInternal()),
923          clang::CK_ArrayToPointerDecay,
924          SrcArr,
925          NULL,
926          clang::VK_RValue);
927
928  clang::Expr *SrcArrPtrSubscript =
929      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
930                                       RefrsIntIter,
931                                       BaseType->getCanonicalTypeInternal(),
932                                       clang::VK_RValue,
933                                       clang::OK_Ordinary,
934                                       Loc);
935
936  RSExportPrimitiveType::DataType DT =
937      RSExportPrimitiveType::GetRSSpecificType(BaseType);
938
939  clang::Stmt *RSSetObjectCall = NULL;
940  if (BaseType->isArrayType()) {
941    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
942                                             SrcArrPtrSubscript,
943                                             StartLoc, Loc);
944  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
945    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
946                                              SrcArrPtrSubscript,
947                                              StartLoc, Loc);
948  } else {
949    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
950                                              SrcArrPtrSubscript,
951                                              StartLoc, Loc);
952  }
953
954  clang::ForStmt *DestructorLoop =
955      new(C) clang::ForStmt(C,
956                            Init,
957                            Cond,
958                            NULL,  // no condVar
959                            Inc,
960                            RSSetObjectCall,
961                            Loc,
962                            Loc,
963                            Loc);
964
965  StmtArray[StmtCtr++] = DestructorLoop;
966  slangAssert(StmtCtr == 2);
967
968  clang::CompoundStmt *CS =
969      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
970
971  return CS;
972}
973
974static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
975                                            clang::Expr *LHS,
976                                            clang::Expr *RHS,
977                                            clang::SourceLocation StartLoc,
978                                            clang::SourceLocation Loc) {
979  clang::QualType QT = LHS->getType();
980  const clang::Type *T = QT.getTypePtr();
981  slangAssert(T->isStructureType());
982  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
983
984  // Keep an extra slot for the original copy (memcpy)
985  unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
986
987  unsigned StmtCount = 0;
988  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
989  for (unsigned i = 0; i < FieldsToSet; i++) {
990    StmtArray[i] = NULL;
991  }
992
993  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
994  RD = RD->getDefinition();
995  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
996         FE = RD->field_end();
997       FI != FE;
998       FI++) {
999    bool IsArrayType = false;
1000    clang::FieldDecl *FD = *FI;
1001    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1002    const clang::Type *OrigType = FT;
1003
1004    if (!CountRSObjectTypes(C, FT, Loc)) {
1005      // Skip to next if we don't have any viable RS object types
1006      continue;
1007    }
1008
1009    clang::DeclAccessPair FoundDecl =
1010        clang::DeclAccessPair::make(FD, clang::AS_none);
1011    clang::MemberExpr *DstMember =
1012        clang::MemberExpr::Create(C,
1013                                  LHS,
1014                                  false,
1015                                  clang::NestedNameSpecifierLoc(),
1016                                  FD,
1017                                  FoundDecl,
1018                                  clang::DeclarationNameInfo(),
1019                                  NULL,
1020                                  OrigType->getCanonicalTypeInternal(),
1021                                  clang::VK_RValue,
1022                                  clang::OK_Ordinary);
1023
1024    clang::MemberExpr *SrcMember =
1025        clang::MemberExpr::Create(C,
1026                                  RHS,
1027                                  false,
1028                                  clang::NestedNameSpecifierLoc(),
1029                                  FD,
1030                                  FoundDecl,
1031                                  clang::DeclarationNameInfo(),
1032                                  NULL,
1033                                  OrigType->getCanonicalTypeInternal(),
1034                                  clang::VK_RValue,
1035                                  clang::OK_Ordinary);
1036
1037    if (FT->isArrayType()) {
1038      FT = FT->getArrayElementTypeNoTypeQual();
1039      IsArrayType = true;
1040    }
1041
1042    RSExportPrimitiveType::DataType DT =
1043        RSExportPrimitiveType::GetRSSpecificType(FT);
1044
1045    if (IsArrayType) {
1046      clang::Diagnostic &Diags = C.getDiagnostics();
1047      Diags.Report(clang::FullSourceLoc(Loc, C.getSourceManager()),
1048          Diags.getCustomDiagID(clang::Diagnostic::Error,
1049            "Arrays of RS object types within structures cannot be copied"));
1050      // TODO(srhines): Support setting arrays of RS objects
1051      // StmtArray[StmtCount++] =
1052      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1053    } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
1054      StmtArray[StmtCount++] =
1055          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1056    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1057      StmtArray[StmtCount++] =
1058          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1059    } else {
1060      slangAssert(false);
1061    }
1062  }
1063
1064  slangAssert(StmtCount > 0 && StmtCount < FieldsToSet);
1065
1066  // We still need to actually do the overall struct copy. For simplicity,
1067  // we just do a straight-up assignment (which will still preserve all
1068  // the proper RS object reference counts).
1069  clang::BinaryOperator *CopyStruct =
1070      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1071                                   clang::VK_RValue, clang::OK_Ordinary, Loc);
1072  StmtArray[StmtCount++] = CopyStruct;
1073
1074  clang::CompoundStmt *CS =
1075      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
1076
1077  delete [] StmtArray;
1078
1079  return CS;
1080}
1081
1082}  // namespace
1083
1084void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1085    clang::BinaryOperator *AS) {
1086
1087  clang::QualType QT = AS->getType();
1088
1089  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1090      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1091
1092  clang::SourceLocation Loc = AS->getExprLoc();
1093  clang::SourceLocation StartLoc = AS->getExprLoc();
1094  clang::Stmt *UpdatedStmt = NULL;
1095
1096  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1097    // By definition, this is a struct assignment if we get here
1098    UpdatedStmt =
1099        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1100  } else {
1101    UpdatedStmt =
1102        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1103  }
1104
1105  RSASTReplace R(C);
1106  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1107  return;
1108}
1109
1110void RSObjectRefCount::Scope::AppendRSObjectInit(
1111    clang::VarDecl *VD,
1112    clang::DeclStmt *DS,
1113    RSExportPrimitiveType::DataType DT,
1114    clang::Expr *InitExpr) {
1115  slangAssert(VD);
1116
1117  if (!InitExpr) {
1118    return;
1119  }
1120
1121  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1122      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1123  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1124      RSExportPrimitiveType::DataTypeRSFont)->getLocation();
1125  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1126      RSExportPrimitiveType::DataTypeRSFont)->getInnerLocStart();
1127
1128  if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
1129    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1130    clang::DeclRefExpr *RefRSVar =
1131        clang::DeclRefExpr::Create(C,
1132                                   clang::NestedNameSpecifierLoc(),
1133                                   VD,
1134                                   Loc,
1135                                   T->getCanonicalTypeInternal(),
1136                                   clang::VK_RValue,
1137                                   NULL);
1138
1139    clang::Stmt *RSSetObjectOps =
1140        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1141
1142    std::list<clang::Stmt*> StmtList;
1143    StmtList.push_back(RSSetObjectOps);
1144    AppendAfterStmt(C, mCS, DS, StmtList);
1145    return;
1146  }
1147
1148  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1149  slangAssert((SetObjectFD != NULL) &&
1150              "rsSetObject doesn't cover all RS object types");
1151
1152  clang::QualType SetObjectFDType = SetObjectFD->getType();
1153  clang::QualType SetObjectFDArgType[2];
1154  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1155  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1156
1157  clang::Expr *RefRSSetObjectFD =
1158      clang::DeclRefExpr::Create(C,
1159                                 clang::NestedNameSpecifierLoc(),
1160                                 SetObjectFD,
1161                                 Loc,
1162                                 SetObjectFDType,
1163                                 clang::VK_RValue,
1164                                 NULL);
1165
1166  clang::Expr *RSSetObjectFP =
1167      clang::ImplicitCastExpr::Create(C,
1168                                      C.getPointerType(SetObjectFDType),
1169                                      clang::CK_FunctionToPointerDecay,
1170                                      RefRSSetObjectFD,
1171                                      NULL,
1172                                      clang::VK_RValue);
1173
1174  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1175  clang::DeclRefExpr *RefRSVar =
1176      clang::DeclRefExpr::Create(C,
1177                                 clang::NestedNameSpecifierLoc(),
1178                                 VD,
1179                                 Loc,
1180                                 T->getCanonicalTypeInternal(),
1181                                 clang::VK_RValue,
1182                                 NULL);
1183
1184  clang::Expr *ArgList[2];
1185  ArgList[0] = new(C) clang::UnaryOperator(RefRSVar,
1186                                           clang::UO_AddrOf,
1187                                           SetObjectFDArgType[0],
1188                                           clang::VK_RValue,
1189                                           clang::OK_Ordinary,
1190                                           Loc);
1191  ArgList[1] = InitExpr;
1192
1193  clang::CallExpr *RSSetObjectCall =
1194      new(C) clang::CallExpr(C,
1195                             RSSetObjectFP,
1196                             ArgList,
1197                             2,
1198                             SetObjectFD->getCallResultType(),
1199                             clang::VK_RValue,
1200                             Loc);
1201
1202  std::list<clang::Stmt*> StmtList;
1203  StmtList.push_back(RSSetObjectCall);
1204  AppendAfterStmt(C, mCS, DS, StmtList);
1205
1206  return;
1207}
1208
1209void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1210  std::list<clang::Stmt*> RSClearObjectCalls;
1211  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1212          E = mRSO.end();
1213        I != E;
1214        I++) {
1215    clang::Stmt *S = ClearRSObject(*I);
1216    if (S) {
1217      RSClearObjectCalls.push_back(S);
1218    }
1219  }
1220  if (RSClearObjectCalls.size() > 0) {
1221    DestructorVisitor DV((*mRSO.begin())->getASTContext(),
1222                         mCS,
1223                         RSClearObjectCalls);
1224    DV.Visit(mCS);
1225    DV.InsertDestructors();
1226  }
1227  return;
1228}
1229
1230clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(clang::VarDecl *VD) {
1231  slangAssert(VD);
1232  clang::ASTContext &C = VD->getASTContext();
1233  clang::DeclContext *DC = VD->getDeclContext();
1234  clang::SourceLocation Loc = VD->getLocation();
1235  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1236  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1237
1238  // Reference expr to target RS object variable
1239  clang::DeclRefExpr *RefRSVar =
1240      clang::DeclRefExpr::Create(C,
1241                                 clang::NestedNameSpecifierLoc(),
1242                                 VD,
1243                                 Loc,
1244                                 T->getCanonicalTypeInternal(),
1245                                 clang::VK_RValue,
1246                                 NULL);
1247
1248  if (T->isArrayType()) {
1249    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1250  }
1251
1252  RSExportPrimitiveType::DataType DT =
1253      RSExportPrimitiveType::GetRSSpecificType(T);
1254
1255  if (DT == RSExportPrimitiveType::DataTypeUnknown ||
1256      DT == RSExportPrimitiveType::DataTypeIsStruct) {
1257    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1258  }
1259
1260  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1261              "Should be RS object");
1262
1263  return ClearSingleRSObject(C, RefRSVar, Loc);
1264}
1265
1266bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1267                                          RSExportPrimitiveType::DataType *DT,
1268                                          clang::Expr **InitExpr) {
1269  slangAssert(VD && DT && InitExpr);
1270  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1271
1272  // Loop through array types to get to base type
1273  while (T && T->isArrayType()) {
1274    T = T->getArrayElementTypeNoTypeQual();
1275  }
1276
1277  bool DataTypeIsStructWithRSObject = false;
1278  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1279
1280  if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
1281    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1282      *DT = RSExportPrimitiveType::DataTypeIsStruct;
1283      DataTypeIsStructWithRSObject = true;
1284    } else {
1285      return false;
1286    }
1287  }
1288
1289  bool DataTypeIsRSObject = false;
1290  if (DataTypeIsStructWithRSObject) {
1291    DataTypeIsRSObject = true;
1292  } else {
1293    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1294  }
1295  *InitExpr = VD->getInit();
1296
1297  if (!DataTypeIsRSObject && *InitExpr) {
1298    // If we already have an initializer for a matrix type, we are done.
1299    return DataTypeIsRSObject;
1300  }
1301
1302  clang::Expr *ZeroInitializer =
1303      CreateZeroInitializerForRSSpecificType(*DT,
1304                                             VD->getASTContext(),
1305                                             VD->getLocation());
1306
1307  if (ZeroInitializer) {
1308    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1309    VD->setInit(ZeroInitializer);
1310  }
1311
1312  return DataTypeIsRSObject;
1313}
1314
1315clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1316    RSExportPrimitiveType::DataType DT,
1317    clang::ASTContext &C,
1318    const clang::SourceLocation &Loc) {
1319  clang::Expr *Res = NULL;
1320  switch (DT) {
1321    case RSExportPrimitiveType::DataTypeIsStruct:
1322    case RSExportPrimitiveType::DataTypeRSElement:
1323    case RSExportPrimitiveType::DataTypeRSType:
1324    case RSExportPrimitiveType::DataTypeRSAllocation:
1325    case RSExportPrimitiveType::DataTypeRSSampler:
1326    case RSExportPrimitiveType::DataTypeRSScript:
1327    case RSExportPrimitiveType::DataTypeRSMesh:
1328    case RSExportPrimitiveType::DataTypeRSProgramFragment:
1329    case RSExportPrimitiveType::DataTypeRSProgramVertex:
1330    case RSExportPrimitiveType::DataTypeRSProgramRaster:
1331    case RSExportPrimitiveType::DataTypeRSProgramStore:
1332    case RSExportPrimitiveType::DataTypeRSFont: {
1333      //    (ImplicitCastExpr 'nullptr_t'
1334      //      (IntegerLiteral 0)))
1335      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1336      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1337      clang::Expr *CastToNull =
1338          clang::ImplicitCastExpr::Create(C,
1339                                          C.NullPtrTy,
1340                                          clang::CK_IntegralToPointer,
1341                                          Int0,
1342                                          NULL,
1343                                          clang::VK_RValue);
1344
1345      Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
1346      break;
1347    }
1348    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
1349    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
1350    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
1351      // RS matrix is not completely an RS object. They hold data by themselves.
1352      // (InitListExpr rs_matrix2x2
1353      //   (InitListExpr float[4]
1354      //     (FloatingLiteral 0)
1355      //     (FloatingLiteral 0)
1356      //     (FloatingLiteral 0)
1357      //     (FloatingLiteral 0)))
1358      clang::QualType FloatTy = C.FloatTy;
1359      // Constructor sets value to 0.0f by default
1360      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1361      clang::FloatingLiteral *Float0Val =
1362          clang::FloatingLiteral::Create(C,
1363                                         Val,
1364                                         /* isExact = */true,
1365                                         FloatTy,
1366                                         Loc);
1367
1368      unsigned N = 0;
1369      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
1370        N = 2;
1371      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
1372        N = 3;
1373      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
1374        N = 4;
1375
1376      // Directly allocate 16 elements instead of dynamically allocate N*N
1377      clang::Expr *InitVals[16];
1378      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
1379        InitVals[i] = Float0Val;
1380      clang::Expr *InitExpr =
1381          new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
1382      InitExpr->setType(C.getConstantArrayType(FloatTy,
1383                                               llvm::APInt(32, 4),
1384                                               clang::ArrayType::Normal,
1385                                               /* EltTypeQuals = */0));
1386
1387      Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
1388      break;
1389    }
1390    case RSExportPrimitiveType::DataTypeUnknown:
1391    case RSExportPrimitiveType::DataTypeFloat16:
1392    case RSExportPrimitiveType::DataTypeFloat32:
1393    case RSExportPrimitiveType::DataTypeFloat64:
1394    case RSExportPrimitiveType::DataTypeSigned8:
1395    case RSExportPrimitiveType::DataTypeSigned16:
1396    case RSExportPrimitiveType::DataTypeSigned32:
1397    case RSExportPrimitiveType::DataTypeSigned64:
1398    case RSExportPrimitiveType::DataTypeUnsigned8:
1399    case RSExportPrimitiveType::DataTypeUnsigned16:
1400    case RSExportPrimitiveType::DataTypeUnsigned32:
1401    case RSExportPrimitiveType::DataTypeUnsigned64:
1402    case RSExportPrimitiveType::DataTypeBoolean:
1403    case RSExportPrimitiveType::DataTypeUnsigned565:
1404    case RSExportPrimitiveType::DataTypeUnsigned5551:
1405    case RSExportPrimitiveType::DataTypeUnsigned4444:
1406    case RSExportPrimitiveType::DataTypeMax: {
1407      slangAssert(false && "Not RS object type!");
1408    }
1409    // No default case will enable compiler detecting the missing cases
1410  }
1411
1412  return Res;
1413}
1414
1415void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1416  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1417       I != E;
1418       I++) {
1419    clang::Decl *D = *I;
1420    if (D->getKind() == clang::Decl::Var) {
1421      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1422      RSExportPrimitiveType::DataType DT =
1423          RSExportPrimitiveType::DataTypeUnknown;
1424      clang::Expr *InitExpr = NULL;
1425      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1426        getCurrentScope()->addRSObject(VD);
1427        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1428      }
1429    }
1430  }
1431  return;
1432}
1433
1434void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1435  if (!CS->body_empty()) {
1436    // Push a new scope
1437    Scope *S = new Scope(CS);
1438    mScopeStack.push(S);
1439
1440    VisitStmt(CS);
1441
1442    // Destroy the scope
1443    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1444    S->InsertLocalVarDestructors();
1445    mScopeStack.pop();
1446    delete S;
1447  }
1448  return;
1449}
1450
1451void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1452  clang::QualType QT = AS->getType();
1453
1454  if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1455    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1456  }
1457
1458  return;
1459}
1460
1461void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1462  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1463       I != E;
1464       I++) {
1465    if (clang::Stmt *Child = *I) {
1466      Visit(Child);
1467    }
1468  }
1469  return;
1470}
1471
1472}  // namespace slang
1473