slang_rs_object_ref_count.cpp revision ab992e59a36a18df49bf4878968ef0598299afd3
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_object_ref_count.h"
18
19#include <list>
20
21#include "clang/AST/DeclGroup.h"
22#include "clang/AST/Expr.h"
23#include "clang/AST/NestedNameSpecifier.h"
24#include "clang/AST/OperationKinds.h"
25#include "clang/AST/Stmt.h"
26#include "clang/AST/StmtVisitor.h"
27
28#include "slang_assert.h"
29#include "slang_rs.h"
30#include "slang_rs_ast_replace.h"
31#include "slang_rs_export_type.h"
32
33namespace slang {
34
35clang::FunctionDecl *RSObjectRefCount::
36    RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
37                  RSExportPrimitiveType::FirstRSObjectType + 1];
38clang::FunctionDecl *RSObjectRefCount::
39    RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
40                    RSExportPrimitiveType::FirstRSObjectType + 1];
41
42void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43  for (unsigned i = 0;
44       i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
45       i++) {
46    RSSetObjectFD[i] = NULL;
47    RSClearObjectFD[i] = NULL;
48  }
49
50  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
51
52  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
53          E = TUDecl->decls_end(); I != E; I++) {
54    if ((I->getKind() >= clang::Decl::firstFunction) &&
55        (I->getKind() <= clang::Decl::lastFunction)) {
56      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
57
58      // points to RSSetObjectFD or RSClearObjectFD
59      clang::FunctionDecl **RSObjectFD;
60
61      if (FD->getName() == "rsSetObject") {
62        slangAssert((FD->getNumParams() == 2) &&
63                    "Invalid rsSetObject function prototype (# params)");
64        RSObjectFD = RSSetObjectFD;
65      } else if (FD->getName() == "rsClearObject") {
66        slangAssert((FD->getNumParams() == 1) &&
67                    "Invalid rsClearObject function prototype (# params)");
68        RSObjectFD = RSClearObjectFD;
69      } else {
70        continue;
71      }
72
73      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
74      clang::QualType PVT = PVD->getOriginalType();
75      // The first parameter must be a pointer like rs_allocation*
76      slangAssert(PVT->isPointerType() &&
77          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
78
79      // The rs object type passed to the FD
80      clang::QualType RST = PVT->getPointeeType();
81      RSExportPrimitiveType::DataType DT =
82          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
83      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
84             && "must be RS object type");
85
86      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
87    }
88  }
89}
90
91namespace {
92
93// This function constructs a new CompoundStmt from the input StmtList.
94static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
95      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
96  unsigned NewStmtCount = StmtList.size();
97  unsigned CompoundStmtCount = 0;
98
99  clang::Stmt **CompoundStmtList;
100  CompoundStmtList = new clang::Stmt*[NewStmtCount];
101
102  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
103  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
104  for ( ; I != E; I++) {
105    CompoundStmtList[CompoundStmtCount++] = *I;
106  }
107  slangAssert(CompoundStmtCount == NewStmtCount);
108
109  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(C,
110                                                       CompoundStmtList,
111                                                       CompoundStmtCount,
112                                                       Loc,
113                                                       Loc);
114
115  delete [] CompoundStmtList;
116
117  return CS;
118}
119
120static void AppendAfterStmt(clang::ASTContext &C,
121                            clang::CompoundStmt *CS,
122                            clang::Stmt *S,
123                            std::list<clang::Stmt*> &StmtList) {
124  slangAssert(CS);
125  clang::CompoundStmt::body_iterator bI = CS->body_begin();
126  clang::CompoundStmt::body_iterator bE = CS->body_end();
127  clang::Stmt **UpdatedStmtList =
128      new clang::Stmt*[CS->size() + StmtList.size()];
129
130  unsigned UpdatedStmtCount = 0;
131  unsigned Once = 0;
132  for ( ; bI != bE; bI++) {
133    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
134      // If we come across a return here, we don't have anything we can
135      // reasonably replace. We should have already inserted our destructor
136      // code in the proper spot, so we just clean up and return.
137      delete [] UpdatedStmtList;
138
139      return;
140    }
141
142    UpdatedStmtList[UpdatedStmtCount++] = *bI;
143
144    if ((*bI == S) && !Once) {
145      Once++;
146      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
147      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
148      for ( ; I != E; I++) {
149        UpdatedStmtList[UpdatedStmtCount++] = *I;
150      }
151    }
152  }
153  slangAssert(Once <= 1);
154
155  // When S is NULL, we are appending to the end of the CompoundStmt.
156  if (!S) {
157    slangAssert(Once == 0);
158    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
159    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
160    for ( ; I != E; I++) {
161      UpdatedStmtList[UpdatedStmtCount++] = *I;
162    }
163  }
164
165  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
166
167  delete [] UpdatedStmtList;
168
169  return;
170}
171
172// This class visits a compound statement and inserts the StmtList containing
173// destructors in proper locations. This includes inserting them before any
174// return statement in any sub-block, at the end of the logical enclosing
175// scope (compound statement), and/or before any break/continue statement that
176// would resume outside the declared scope. We will not handle the case for
177// goto statements that leave a local scope.
178//
179// To accomplish these goals, it collects a list of sub-Stmt's that
180// correspond to scope exit points. It then uses an RSASTReplace visitor to
181// transform the AST, inserting appropriate destructors before each of those
182// sub-Stmt's (and also before the exit of the outermost containing Stmt for
183// the scope).
184class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
185 private:
186  clang::ASTContext &mCtx;
187
188  // The loop depth of the currently visited node.
189  int mLoopDepth;
190
191  // The switch statement depth of the currently visited node.
192  // Note that this is tracked separately from the loop depth because
193  // SwitchStmt-contained ContinueStmt's should have destructors for the
194  // corresponding loop scope.
195  int mSwitchDepth;
196
197  // The outermost statement block that we are currently visiting.
198  // This should always be a CompoundStmt.
199  clang::Stmt *mOuterStmt;
200
201  // The list of destructors to execute for this scope.
202  std::list<clang::Stmt*> &mStmtList;
203
204  // The stack of statements which should be replaced by a compound statement
205  // containing the new destructor calls followed by the original Stmt.
206  std::stack<clang::Stmt*> mReplaceStmtStack;
207
208 public:
209  DestructorVisitor(clang::ASTContext &C,
210                    clang::Stmt* OuterStmt,
211                    std::list<clang::Stmt*> &StmtList);
212
213  // This code walks the collected list of Stmts to replace and actually does
214  // the replacement. It also finishes up by appending appropriate destructors
215  // to the current outermost CompoundStmt.
216  void InsertDestructors() {
217    clang::Stmt *S = NULL;
218    while (!mReplaceStmtStack.empty()) {
219      S = mReplaceStmtStack.top();
220      mReplaceStmtStack.pop();
221
222      mStmtList.push_back(S);
223      clang::CompoundStmt *CS =
224          BuildCompoundStmt(mCtx, mStmtList, S->getLocEnd());
225      mStmtList.pop_back();
226
227      RSASTReplace R(mCtx);
228      R.ReplaceStmt(mOuterStmt, S, CS);
229    }
230    clang::CompoundStmt *CS =
231      llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
232    slangAssert(CS);
233    if (CS) {
234      AppendAfterStmt(mCtx, CS, NULL, mStmtList);
235    }
236  }
237
238  void VisitStmt(clang::Stmt *S);
239  void VisitCompoundStmt(clang::CompoundStmt *CS);
240
241  void VisitBreakStmt(clang::BreakStmt *BS);
242  void VisitCaseStmt(clang::CaseStmt *CS);
243  void VisitContinueStmt(clang::ContinueStmt *CS);
244  void VisitDefaultStmt(clang::DefaultStmt *DS);
245  void VisitDoStmt(clang::DoStmt *DS);
246  void VisitForStmt(clang::ForStmt *FS);
247  void VisitIfStmt(clang::IfStmt *IS);
248  void VisitReturnStmt(clang::ReturnStmt *RS);
249  void VisitSwitchCase(clang::SwitchCase *SC);
250  void VisitSwitchStmt(clang::SwitchStmt *SS);
251  void VisitWhileStmt(clang::WhileStmt *WS);
252};
253
254DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
255                         clang::Stmt *OuterStmt,
256                         std::list<clang::Stmt*> &StmtList)
257  : mCtx(C),
258    mLoopDepth(0),
259    mSwitchDepth(0),
260    mOuterStmt(OuterStmt),
261    mStmtList(StmtList) {
262  return;
263}
264
265void DestructorVisitor::VisitStmt(clang::Stmt *S) {
266  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
267       I != E;
268       I++) {
269    if (clang::Stmt *Child = *I) {
270      Visit(Child);
271    }
272  }
273  return;
274}
275
276void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
277  VisitStmt(CS);
278  return;
279}
280
281void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
282  VisitStmt(BS);
283  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
284    mReplaceStmtStack.push(BS);
285  }
286  return;
287}
288
289void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
290  VisitStmt(CS);
291  return;
292}
293
294void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
295  VisitStmt(CS);
296  if (mLoopDepth == 0) {
297    // Switch statements can have nested continues.
298    mReplaceStmtStack.push(CS);
299  }
300  return;
301}
302
303void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
304  VisitStmt(DS);
305  return;
306}
307
308void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
309  mLoopDepth++;
310  VisitStmt(DS);
311  mLoopDepth--;
312  return;
313}
314
315void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
316  mLoopDepth++;
317  VisitStmt(FS);
318  mLoopDepth--;
319  return;
320}
321
322void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
323  VisitStmt(IS);
324  return;
325}
326
327void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
328  mReplaceStmtStack.push(RS);
329  return;
330}
331
332void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
333  slangAssert(false && "Both case and default have specialized handlers");
334  VisitStmt(SC);
335  return;
336}
337
338void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
339  mSwitchDepth++;
340  VisitStmt(SS);
341  mSwitchDepth--;
342  return;
343}
344
345void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
346  mLoopDepth++;
347  VisitStmt(WS);
348  mLoopDepth--;
349  return;
350}
351
352clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
353                                 clang::Expr *RefRSVar,
354                                 clang::SourceLocation Loc) {
355  slangAssert(RefRSVar);
356  const clang::Type *T = RefRSVar->getType().getTypePtr();
357  slangAssert(!T->isArrayType() &&
358              "Should not be destroying arrays with this function");
359
360  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
361  slangAssert((ClearObjectFD != NULL) &&
362              "rsClearObject doesn't cover all RS object types");
363
364  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
365  clang::QualType ClearObjectFDArgType =
366      ClearObjectFD->getParamDecl(0)->getOriginalType();
367
368  // Example destructor for "rs_font localFont;"
369  //
370  // (CallExpr 'void'
371  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
372  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
373  //   (UnaryOperator 'rs_font *' prefix '&'
374  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
375
376  // Get address of targeted RS object
377  clang::Expr *AddrRefRSVar =
378      new(C) clang::UnaryOperator(RefRSVar,
379                                  clang::UO_AddrOf,
380                                  ClearObjectFDArgType,
381                                  clang::VK_RValue,
382                                  clang::OK_Ordinary,
383                                  Loc);
384
385  clang::Expr *RefRSClearObjectFD =
386      clang::DeclRefExpr::Create(C,
387                                 clang::NestedNameSpecifierLoc(),
388                                 ClearObjectFD,
389                                 ClearObjectFD->getLocation(),
390                                 ClearObjectFDType,
391                                 clang::VK_RValue,
392                                 NULL);
393
394  clang::Expr *RSClearObjectFP =
395      clang::ImplicitCastExpr::Create(C,
396                                      C.getPointerType(ClearObjectFDType),
397                                      clang::CK_FunctionToPointerDecay,
398                                      RefRSClearObjectFD,
399                                      NULL,
400                                      clang::VK_RValue);
401
402  clang::CallExpr *RSClearObjectCall =
403      new(C) clang::CallExpr(C,
404                             RSClearObjectFP,
405                             &AddrRefRSVar,
406                             1,
407                             ClearObjectFD->getCallResultType(),
408                             clang::VK_RValue,
409                             Loc);
410
411  return RSClearObjectCall;
412}
413
414static int ArrayDim(const clang::Type *T) {
415  if (!T || !T->isArrayType()) {
416    return 0;
417  }
418
419  const clang::ConstantArrayType *CAT =
420    static_cast<const clang::ConstantArrayType *>(T);
421  return static_cast<int>(CAT->getSize().getSExtValue());
422}
423
424static clang::Stmt *ClearStructRSObject(
425    clang::ASTContext &C,
426    clang::DeclContext *DC,
427    clang::Expr *RefRSStruct,
428    clang::SourceLocation StartLoc,
429    clang::SourceLocation Loc);
430
431static clang::Stmt *ClearArrayRSObject(
432    clang::ASTContext &C,
433    clang::DeclContext *DC,
434    clang::Expr *RefRSArr,
435    clang::SourceLocation StartLoc,
436    clang::SourceLocation Loc) {
437  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
438  slangAssert(BaseType->isArrayType());
439
440  int NumArrayElements = ArrayDim(BaseType);
441  // Actually extract out the base RS object type for use later
442  BaseType = BaseType->getArrayElementTypeNoTypeQual();
443
444  clang::Stmt *StmtArray[2] = {NULL};
445  int StmtCtr = 0;
446
447  if (NumArrayElements <= 0) {
448    return NULL;
449  }
450
451  // Example destructor loop for "rs_font fontArr[10];"
452  //
453  // (CompoundStmt
454  //   (DeclStmt "int rsIntIter")
455  //   (ForStmt
456  //     (BinaryOperator 'int' '='
457  //       (DeclRefExpr 'int' Var='rsIntIter')
458  //       (IntegerLiteral 'int' 0))
459  //     (BinaryOperator 'int' '<'
460  //       (DeclRefExpr 'int' Var='rsIntIter')
461  //       (IntegerLiteral 'int' 10)
462  //     NULL << CondVar >>
463  //     (UnaryOperator 'int' postfix '++'
464  //       (DeclRefExpr 'int' Var='rsIntIter'))
465  //     (CallExpr 'void'
466  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
467  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
468  //       (UnaryOperator 'rs_font *' prefix '&'
469  //         (ArraySubscriptExpr 'rs_font':'rs_font'
470  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
471  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
472  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
473
474  // Create helper variable for iterating through elements
475  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
476  clang::VarDecl *IIVD =
477      clang::VarDecl::Create(C,
478                             DC,
479                             StartLoc,
480                             Loc,
481                             &II,
482                             C.IntTy,
483                             C.getTrivialTypeSourceInfo(C.IntTy),
484                             clang::SC_None,
485                             clang::SC_None);
486  clang::Decl *IID = (clang::Decl *)IIVD;
487
488  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
489  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
490
491  // Form the actual destructor loop
492  // for (Init; Cond; Inc)
493  //   RSClearObjectCall;
494
495  // Init -> "rsIntIter = 0"
496  clang::DeclRefExpr *RefrsIntIter =
497      clang::DeclRefExpr::Create(C,
498                                 clang::NestedNameSpecifierLoc(),
499                                 IIVD,
500                                 Loc,
501                                 C.IntTy,
502                                 clang::VK_RValue,
503                                 NULL);
504
505  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
506      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
507
508  clang::BinaryOperator *Init =
509      new(C) clang::BinaryOperator(RefrsIntIter,
510                                   Int0,
511                                   clang::BO_Assign,
512                                   C.IntTy,
513                                   clang::VK_RValue,
514                                   clang::OK_Ordinary,
515                                   Loc);
516
517  // Cond -> "rsIntIter < NumArrayElements"
518  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
519      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
520
521  clang::BinaryOperator *Cond =
522      new(C) clang::BinaryOperator(RefrsIntIter,
523                                   NumArrayElementsExpr,
524                                   clang::BO_LT,
525                                   C.IntTy,
526                                   clang::VK_RValue,
527                                   clang::OK_Ordinary,
528                                   Loc);
529
530  // Inc -> "rsIntIter++"
531  clang::UnaryOperator *Inc =
532      new(C) clang::UnaryOperator(RefrsIntIter,
533                                  clang::UO_PostInc,
534                                  C.IntTy,
535                                  clang::VK_RValue,
536                                  clang::OK_Ordinary,
537                                  Loc);
538
539  // Body -> "rsClearObject(&VD[rsIntIter]);"
540  // Destructor loop operates on individual array elements
541
542  clang::Expr *RefRSArrPtr =
543      clang::ImplicitCastExpr::Create(C,
544          C.getPointerType(BaseType->getCanonicalTypeInternal()),
545          clang::CK_ArrayToPointerDecay,
546          RefRSArr,
547          NULL,
548          clang::VK_RValue);
549
550  clang::Expr *RefRSArrPtrSubscript =
551      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
552                                       RefrsIntIter,
553                                       BaseType->getCanonicalTypeInternal(),
554                                       clang::VK_RValue,
555                                       clang::OK_Ordinary,
556                                       Loc);
557
558  RSExportPrimitiveType::DataType DT =
559      RSExportPrimitiveType::GetRSSpecificType(BaseType);
560
561  clang::Stmt *RSClearObjectCall = NULL;
562  if (BaseType->isArrayType()) {
563    RSClearObjectCall =
564        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
565  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
566    RSClearObjectCall =
567        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
568  } else {
569    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
570  }
571
572  clang::ForStmt *DestructorLoop =
573      new(C) clang::ForStmt(C,
574                            Init,
575                            Cond,
576                            NULL,  // no condVar
577                            Inc,
578                            RSClearObjectCall,
579                            Loc,
580                            Loc,
581                            Loc);
582
583  StmtArray[StmtCtr++] = DestructorLoop;
584  slangAssert(StmtCtr == 2);
585
586  clang::CompoundStmt *CS =
587      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
588
589  return CS;
590}
591
592static unsigned CountRSObjectTypes(clang::ASTContext &C,
593                                   const clang::Type *T,
594                                   clang::SourceLocation Loc) {
595  slangAssert(T);
596  unsigned RSObjectCount = 0;
597
598  if (T->isArrayType()) {
599    return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
600  }
601
602  RSExportPrimitiveType::DataType DT =
603      RSExportPrimitiveType::GetRSSpecificType(T);
604  if (DT != RSExportPrimitiveType::DataTypeUnknown) {
605    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
606  }
607
608  if (T->isUnionType()) {
609    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
610    RD = RD->getDefinition();
611    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
612           FE = RD->field_end();
613         FI != FE;
614         FI++) {
615      const clang::FieldDecl *FD = *FI;
616      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
617      if (CountRSObjectTypes(C, FT, Loc)) {
618        slangAssert(false && "can't have unions with RS object types!");
619        return 0;
620      }
621    }
622  }
623
624  if (!T->isStructureType()) {
625    return 0;
626  }
627
628  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
629  RD = RD->getDefinition();
630  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
631         FE = RD->field_end();
632       FI != FE;
633       FI++) {
634    const clang::FieldDecl *FD = *FI;
635    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
636    if (CountRSObjectTypes(C, FT, Loc)) {
637      // Sub-structs should only count once (as should arrays, etc.)
638      RSObjectCount++;
639    }
640  }
641
642  return RSObjectCount;
643}
644
645static clang::Stmt *ClearStructRSObject(
646    clang::ASTContext &C,
647    clang::DeclContext *DC,
648    clang::Expr *RefRSStruct,
649    clang::SourceLocation StartLoc,
650    clang::SourceLocation Loc) {
651  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
652
653  slangAssert(!BaseType->isArrayType());
654
655  // Structs should show up as unknown primitive types
656  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
657              RSExportPrimitiveType::DataTypeUnknown);
658
659  unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
660
661  unsigned StmtCount = 0;
662  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
663  for (unsigned i = 0; i < FieldsToDestroy; i++) {
664    StmtArray[i] = NULL;
665  }
666
667  // Populate StmtArray by creating a destructor for each RS object field
668  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
669  RD = RD->getDefinition();
670  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
671         FE = RD->field_end();
672       FI != FE;
673       FI++) {
674    // We just look through all field declarations to see if we find a
675    // declaration for an RS object type (or an array of one).
676    bool IsArrayType = false;
677    clang::FieldDecl *FD = *FI;
678    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
679    const clang::Type *OrigType = FT;
680    while (FT && FT->isArrayType()) {
681      FT = FT->getArrayElementTypeNoTypeQual();
682      IsArrayType = true;
683    }
684
685    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
686      clang::DeclAccessPair FoundDecl =
687          clang::DeclAccessPair::make(FD, clang::AS_none);
688      clang::MemberExpr *RSObjectMember =
689          clang::MemberExpr::Create(C,
690                                    RefRSStruct,
691                                    false,
692                                    clang::NestedNameSpecifierLoc(),
693                                    FD,
694                                    FoundDecl,
695                                    clang::DeclarationNameInfo(),
696                                    NULL,
697                                    OrigType->getCanonicalTypeInternal(),
698                                    clang::VK_RValue,
699                                    clang::OK_Ordinary);
700
701      slangAssert(StmtCount < FieldsToDestroy);
702
703      if (IsArrayType) {
704        StmtArray[StmtCount++] = ClearArrayRSObject(C,
705                                                    DC,
706                                                    RSObjectMember,
707                                                    StartLoc,
708                                                    Loc);
709      } else {
710        StmtArray[StmtCount++] = ClearSingleRSObject(C,
711                                                     RSObjectMember,
712                                                     Loc);
713      }
714    } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
715      // In this case, we have a nested struct. We may not end up filling all
716      // of the spaces in StmtArray (sub-structs should handle themselves
717      // with separate compound statements).
718      clang::DeclAccessPair FoundDecl =
719          clang::DeclAccessPair::make(FD, clang::AS_none);
720      clang::MemberExpr *RSObjectMember =
721          clang::MemberExpr::Create(C,
722                                    RefRSStruct,
723                                    false,
724                                    clang::NestedNameSpecifierLoc(),
725                                    FD,
726                                    FoundDecl,
727                                    clang::DeclarationNameInfo(),
728                                    NULL,
729                                    OrigType->getCanonicalTypeInternal(),
730                                    clang::VK_RValue,
731                                    clang::OK_Ordinary);
732
733      if (IsArrayType) {
734        StmtArray[StmtCount++] = ClearArrayRSObject(C,
735                                                    DC,
736                                                    RSObjectMember,
737                                                    StartLoc,
738                                                    Loc);
739      } else {
740        StmtArray[StmtCount++] = ClearStructRSObject(C,
741                                                     DC,
742                                                     RSObjectMember,
743                                                     StartLoc,
744                                                     Loc);
745      }
746    }
747  }
748
749  slangAssert(StmtCount > 0);
750  clang::CompoundStmt *CS =
751      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
752
753  delete [] StmtArray;
754
755  return CS;
756}
757
758static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
759                                            clang::Expr *DstExpr,
760                                            clang::Expr *SrcExpr,
761                                            clang::SourceLocation StartLoc,
762                                            clang::SourceLocation Loc) {
763  const clang::Type *T = DstExpr->getType().getTypePtr();
764  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
765  slangAssert((SetObjectFD != NULL) &&
766              "rsSetObject doesn't cover all RS object types");
767
768  clang::QualType SetObjectFDType = SetObjectFD->getType();
769  clang::QualType SetObjectFDArgType[2];
770  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
771  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
772
773  clang::Expr *RefRSSetObjectFD =
774      clang::DeclRefExpr::Create(C,
775                                 clang::NestedNameSpecifierLoc(),
776                                 SetObjectFD,
777                                 Loc,
778                                 SetObjectFDType,
779                                 clang::VK_RValue,
780                                 NULL);
781
782  clang::Expr *RSSetObjectFP =
783      clang::ImplicitCastExpr::Create(C,
784                                      C.getPointerType(SetObjectFDType),
785                                      clang::CK_FunctionToPointerDecay,
786                                      RefRSSetObjectFD,
787                                      NULL,
788                                      clang::VK_RValue);
789
790  clang::Expr *ArgList[2];
791  ArgList[0] = new(C) clang::UnaryOperator(DstExpr,
792                                           clang::UO_AddrOf,
793                                           SetObjectFDArgType[0],
794                                           clang::VK_RValue,
795                                           clang::OK_Ordinary,
796                                           Loc);
797  ArgList[1] = SrcExpr;
798
799  clang::CallExpr *RSSetObjectCall =
800      new(C) clang::CallExpr(C,
801                             RSSetObjectFP,
802                             ArgList,
803                             2,
804                             SetObjectFD->getCallResultType(),
805                             clang::VK_RValue,
806                             Loc);
807
808  return RSSetObjectCall;
809}
810
811static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
812                                            clang::Expr *LHS,
813                                            clang::Expr *RHS,
814                                            clang::SourceLocation StartLoc,
815                                            clang::SourceLocation Loc);
816
817static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
818                                           clang::Expr *DstArr,
819                                           clang::Expr *SrcArr,
820                                           clang::SourceLocation StartLoc,
821                                           clang::SourceLocation Loc) {
822  clang::DeclContext *DC = NULL;
823  const clang::Type *BaseType = DstArr->getType().getTypePtr();
824  slangAssert(BaseType->isArrayType());
825
826  int NumArrayElements = ArrayDim(BaseType);
827  // Actually extract out the base RS object type for use later
828  BaseType = BaseType->getArrayElementTypeNoTypeQual();
829
830  clang::Stmt *StmtArray[2] = {NULL};
831  int StmtCtr = 0;
832
833  if (NumArrayElements <= 0) {
834    return NULL;
835  }
836
837  // Create helper variable for iterating through elements
838  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
839  clang::VarDecl *IIVD =
840      clang::VarDecl::Create(C,
841                             DC,
842                             StartLoc,
843                             Loc,
844                             &II,
845                             C.IntTy,
846                             C.getTrivialTypeSourceInfo(C.IntTy),
847                             clang::SC_None,
848                             clang::SC_None);
849  clang::Decl *IID = (clang::Decl *)IIVD;
850
851  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
852  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
853
854  // Form the actual loop
855  // for (Init; Cond; Inc)
856  //   RSSetObjectCall;
857
858  // Init -> "rsIntIter = 0"
859  clang::DeclRefExpr *RefrsIntIter =
860      clang::DeclRefExpr::Create(C,
861                                 clang::NestedNameSpecifierLoc(),
862                                 IIVD,
863                                 Loc,
864                                 C.IntTy,
865                                 clang::VK_RValue,
866                                 NULL);
867
868  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
869      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
870
871  clang::BinaryOperator *Init =
872      new(C) clang::BinaryOperator(RefrsIntIter,
873                                   Int0,
874                                   clang::BO_Assign,
875                                   C.IntTy,
876                                   clang::VK_RValue,
877                                   clang::OK_Ordinary,
878                                   Loc);
879
880  // Cond -> "rsIntIter < NumArrayElements"
881  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
882      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
883
884  clang::BinaryOperator *Cond =
885      new(C) clang::BinaryOperator(RefrsIntIter,
886                                   NumArrayElementsExpr,
887                                   clang::BO_LT,
888                                   C.IntTy,
889                                   clang::VK_RValue,
890                                   clang::OK_Ordinary,
891                                   Loc);
892
893  // Inc -> "rsIntIter++"
894  clang::UnaryOperator *Inc =
895      new(C) clang::UnaryOperator(RefrsIntIter,
896                                  clang::UO_PostInc,
897                                  C.IntTy,
898                                  clang::VK_RValue,
899                                  clang::OK_Ordinary,
900                                  Loc);
901
902  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
903  // Loop operates on individual array elements
904
905  clang::Expr *DstArrPtr =
906      clang::ImplicitCastExpr::Create(C,
907          C.getPointerType(BaseType->getCanonicalTypeInternal()),
908          clang::CK_ArrayToPointerDecay,
909          DstArr,
910          NULL,
911          clang::VK_RValue);
912
913  clang::Expr *DstArrPtrSubscript =
914      new(C) clang::ArraySubscriptExpr(DstArrPtr,
915                                       RefrsIntIter,
916                                       BaseType->getCanonicalTypeInternal(),
917                                       clang::VK_RValue,
918                                       clang::OK_Ordinary,
919                                       Loc);
920
921  clang::Expr *SrcArrPtr =
922      clang::ImplicitCastExpr::Create(C,
923          C.getPointerType(BaseType->getCanonicalTypeInternal()),
924          clang::CK_ArrayToPointerDecay,
925          SrcArr,
926          NULL,
927          clang::VK_RValue);
928
929  clang::Expr *SrcArrPtrSubscript =
930      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
931                                       RefrsIntIter,
932                                       BaseType->getCanonicalTypeInternal(),
933                                       clang::VK_RValue,
934                                       clang::OK_Ordinary,
935                                       Loc);
936
937  RSExportPrimitiveType::DataType DT =
938      RSExportPrimitiveType::GetRSSpecificType(BaseType);
939
940  clang::Stmt *RSSetObjectCall = NULL;
941  if (BaseType->isArrayType()) {
942    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
943                                             SrcArrPtrSubscript,
944                                             StartLoc, Loc);
945  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
946    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
947                                              SrcArrPtrSubscript,
948                                              StartLoc, Loc);
949  } else {
950    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
951                                              SrcArrPtrSubscript,
952                                              StartLoc, Loc);
953  }
954
955  clang::ForStmt *DestructorLoop =
956      new(C) clang::ForStmt(C,
957                            Init,
958                            Cond,
959                            NULL,  // no condVar
960                            Inc,
961                            RSSetObjectCall,
962                            Loc,
963                            Loc,
964                            Loc);
965
966  StmtArray[StmtCtr++] = DestructorLoop;
967  slangAssert(StmtCtr == 2);
968
969  clang::CompoundStmt *CS =
970      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
971
972  return CS;
973}
974
975static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
976                                            clang::Expr *LHS,
977                                            clang::Expr *RHS,
978                                            clang::SourceLocation StartLoc,
979                                            clang::SourceLocation Loc) {
980  clang::QualType QT = LHS->getType();
981  const clang::Type *T = QT.getTypePtr();
982  slangAssert(T->isStructureType());
983  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
984
985  // Keep an extra slot for the original copy (memcpy)
986  unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
987
988  unsigned StmtCount = 0;
989  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
990  for (unsigned i = 0; i < FieldsToSet; i++) {
991    StmtArray[i] = NULL;
992  }
993
994  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
995  RD = RD->getDefinition();
996  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
997         FE = RD->field_end();
998       FI != FE;
999       FI++) {
1000    bool IsArrayType = false;
1001    clang::FieldDecl *FD = *FI;
1002    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1003    const clang::Type *OrigType = FT;
1004
1005    if (!CountRSObjectTypes(C, FT, Loc)) {
1006      // Skip to next if we don't have any viable RS object types
1007      continue;
1008    }
1009
1010    clang::DeclAccessPair FoundDecl =
1011        clang::DeclAccessPair::make(FD, clang::AS_none);
1012    clang::MemberExpr *DstMember =
1013        clang::MemberExpr::Create(C,
1014                                  LHS,
1015                                  false,
1016                                  clang::NestedNameSpecifierLoc(),
1017                                  FD,
1018                                  FoundDecl,
1019                                  clang::DeclarationNameInfo(),
1020                                  NULL,
1021                                  OrigType->getCanonicalTypeInternal(),
1022                                  clang::VK_RValue,
1023                                  clang::OK_Ordinary);
1024
1025    clang::MemberExpr *SrcMember =
1026        clang::MemberExpr::Create(C,
1027                                  RHS,
1028                                  false,
1029                                  clang::NestedNameSpecifierLoc(),
1030                                  FD,
1031                                  FoundDecl,
1032                                  clang::DeclarationNameInfo(),
1033                                  NULL,
1034                                  OrigType->getCanonicalTypeInternal(),
1035                                  clang::VK_RValue,
1036                                  clang::OK_Ordinary);
1037
1038    if (FT->isArrayType()) {
1039      FT = FT->getArrayElementTypeNoTypeQual();
1040      IsArrayType = true;
1041    }
1042
1043    RSExportPrimitiveType::DataType DT =
1044        RSExportPrimitiveType::GetRSSpecificType(FT);
1045
1046    if (IsArrayType) {
1047      clang::Diagnostic &Diags = C.getDiagnostics();
1048      Diags.Report(clang::FullSourceLoc(Loc, C.getSourceManager()),
1049          Diags.getCustomDiagID(clang::Diagnostic::Error,
1050            "Arrays of RS object types within structures cannot be copied"));
1051      // TODO(srhines): Support setting arrays of RS objects
1052      // StmtArray[StmtCount++] =
1053      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1054    } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
1055      StmtArray[StmtCount++] =
1056          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1057    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1058      StmtArray[StmtCount++] =
1059          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1060    } else {
1061      slangAssert(false);
1062    }
1063  }
1064
1065  slangAssert(StmtCount > 0 && StmtCount < FieldsToSet);
1066
1067  // We still need to actually do the overall struct copy. For simplicity,
1068  // we just do a straight-up assignment (which will still preserve all
1069  // the proper RS object reference counts).
1070  clang::BinaryOperator *CopyStruct =
1071      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1072                                   clang::VK_RValue, clang::OK_Ordinary, Loc);
1073  StmtArray[StmtCount++] = CopyStruct;
1074
1075  clang::CompoundStmt *CS =
1076      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
1077
1078  delete [] StmtArray;
1079
1080  return CS;
1081}
1082
1083}  // namespace
1084
1085void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1086    clang::BinaryOperator *AS) {
1087
1088  clang::QualType QT = AS->getType();
1089
1090  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1091      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1092
1093  clang::SourceLocation Loc = AS->getExprLoc();
1094  clang::SourceLocation StartLoc = AS->getExprLoc();
1095  clang::Stmt *UpdatedStmt = NULL;
1096
1097  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1098    // By definition, this is a struct assignment if we get here
1099    UpdatedStmt =
1100        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1101  } else {
1102    UpdatedStmt =
1103        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1104  }
1105
1106  RSASTReplace R(C);
1107  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1108  return;
1109}
1110
1111void RSObjectRefCount::Scope::AppendRSObjectInit(
1112    clang::VarDecl *VD,
1113    clang::DeclStmt *DS,
1114    RSExportPrimitiveType::DataType DT,
1115    clang::Expr *InitExpr) {
1116  slangAssert(VD);
1117
1118  if (!InitExpr) {
1119    return;
1120  }
1121
1122  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1123      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1124  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1125      RSExportPrimitiveType::DataTypeRSFont)->getLocation();
1126  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1127      RSExportPrimitiveType::DataTypeRSFont)->getInnerLocStart();
1128
1129  if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
1130    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1131    clang::DeclRefExpr *RefRSVar =
1132        clang::DeclRefExpr::Create(C,
1133                                   clang::NestedNameSpecifierLoc(),
1134                                   VD,
1135                                   Loc,
1136                                   T->getCanonicalTypeInternal(),
1137                                   clang::VK_RValue,
1138                                   NULL);
1139
1140    clang::Stmt *RSSetObjectOps =
1141        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1142
1143    std::list<clang::Stmt*> StmtList;
1144    StmtList.push_back(RSSetObjectOps);
1145    AppendAfterStmt(C, mCS, DS, StmtList);
1146    return;
1147  }
1148
1149  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1150  slangAssert((SetObjectFD != NULL) &&
1151              "rsSetObject doesn't cover all RS object types");
1152
1153  clang::QualType SetObjectFDType = SetObjectFD->getType();
1154  clang::QualType SetObjectFDArgType[2];
1155  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1156  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1157
1158  clang::Expr *RefRSSetObjectFD =
1159      clang::DeclRefExpr::Create(C,
1160                                 clang::NestedNameSpecifierLoc(),
1161                                 SetObjectFD,
1162                                 Loc,
1163                                 SetObjectFDType,
1164                                 clang::VK_RValue,
1165                                 NULL);
1166
1167  clang::Expr *RSSetObjectFP =
1168      clang::ImplicitCastExpr::Create(C,
1169                                      C.getPointerType(SetObjectFDType),
1170                                      clang::CK_FunctionToPointerDecay,
1171                                      RefRSSetObjectFD,
1172                                      NULL,
1173                                      clang::VK_RValue);
1174
1175  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1176  clang::DeclRefExpr *RefRSVar =
1177      clang::DeclRefExpr::Create(C,
1178                                 clang::NestedNameSpecifierLoc(),
1179                                 VD,
1180                                 Loc,
1181                                 T->getCanonicalTypeInternal(),
1182                                 clang::VK_RValue,
1183                                 NULL);
1184
1185  clang::Expr *ArgList[2];
1186  ArgList[0] = new(C) clang::UnaryOperator(RefRSVar,
1187                                           clang::UO_AddrOf,
1188                                           SetObjectFDArgType[0],
1189                                           clang::VK_RValue,
1190                                           clang::OK_Ordinary,
1191                                           Loc);
1192  ArgList[1] = InitExpr;
1193
1194  clang::CallExpr *RSSetObjectCall =
1195      new(C) clang::CallExpr(C,
1196                             RSSetObjectFP,
1197                             ArgList,
1198                             2,
1199                             SetObjectFD->getCallResultType(),
1200                             clang::VK_RValue,
1201                             Loc);
1202
1203  std::list<clang::Stmt*> StmtList;
1204  StmtList.push_back(RSSetObjectCall);
1205  AppendAfterStmt(C, mCS, DS, StmtList);
1206
1207  return;
1208}
1209
1210void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1211  std::list<clang::Stmt*> RSClearObjectCalls;
1212  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1213          E = mRSO.end();
1214        I != E;
1215        I++) {
1216    clang::Stmt *S = ClearRSObject(*I);
1217    if (S) {
1218      RSClearObjectCalls.push_back(S);
1219    }
1220  }
1221  if (RSClearObjectCalls.size() > 0) {
1222    DestructorVisitor DV((*mRSO.begin())->getASTContext(),
1223                         mCS,
1224                         RSClearObjectCalls);
1225    DV.Visit(mCS);
1226    DV.InsertDestructors();
1227  }
1228  return;
1229}
1230
1231clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(clang::VarDecl *VD) {
1232  slangAssert(VD);
1233  clang::ASTContext &C = VD->getASTContext();
1234  clang::DeclContext *DC = VD->getDeclContext();
1235  clang::SourceLocation Loc = VD->getLocation();
1236  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1237  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1238
1239  // Reference expr to target RS object variable
1240  clang::DeclRefExpr *RefRSVar =
1241      clang::DeclRefExpr::Create(C,
1242                                 clang::NestedNameSpecifierLoc(),
1243                                 VD,
1244                                 Loc,
1245                                 T->getCanonicalTypeInternal(),
1246                                 clang::VK_RValue,
1247                                 NULL);
1248
1249  if (T->isArrayType()) {
1250    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1251  }
1252
1253  RSExportPrimitiveType::DataType DT =
1254      RSExportPrimitiveType::GetRSSpecificType(T);
1255
1256  if (DT == RSExportPrimitiveType::DataTypeUnknown ||
1257      DT == RSExportPrimitiveType::DataTypeIsStruct) {
1258    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1259  }
1260
1261  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1262              "Should be RS object");
1263
1264  return ClearSingleRSObject(C, RefRSVar, Loc);
1265}
1266
1267bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1268                                          RSExportPrimitiveType::DataType *DT,
1269                                          clang::Expr **InitExpr) {
1270  slangAssert(VD && DT && InitExpr);
1271  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1272
1273  // Loop through array types to get to base type
1274  while (T && T->isArrayType()) {
1275    T = T->getArrayElementTypeNoTypeQual();
1276  }
1277
1278  bool DataTypeIsStructWithRSObject = false;
1279  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1280
1281  if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
1282    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1283      *DT = RSExportPrimitiveType::DataTypeIsStruct;
1284      DataTypeIsStructWithRSObject = true;
1285    } else {
1286      return false;
1287    }
1288  }
1289
1290  bool DataTypeIsRSObject = false;
1291  if (DataTypeIsStructWithRSObject) {
1292    DataTypeIsRSObject = true;
1293  } else {
1294    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1295  }
1296  *InitExpr = VD->getInit();
1297
1298  if (!DataTypeIsRSObject && *InitExpr) {
1299    // If we already have an initializer for a matrix type, we are done.
1300    return DataTypeIsRSObject;
1301  }
1302
1303  clang::Expr *ZeroInitializer =
1304      CreateZeroInitializerForRSSpecificType(*DT,
1305                                             VD->getASTContext(),
1306                                             VD->getLocation());
1307
1308  if (ZeroInitializer) {
1309    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1310    VD->setInit(ZeroInitializer);
1311  }
1312
1313  return DataTypeIsRSObject;
1314}
1315
1316clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1317    RSExportPrimitiveType::DataType DT,
1318    clang::ASTContext &C,
1319    const clang::SourceLocation &Loc) {
1320  clang::Expr *Res = NULL;
1321  switch (DT) {
1322    case RSExportPrimitiveType::DataTypeIsStruct:
1323    case RSExportPrimitiveType::DataTypeRSElement:
1324    case RSExportPrimitiveType::DataTypeRSType:
1325    case RSExportPrimitiveType::DataTypeRSAllocation:
1326    case RSExportPrimitiveType::DataTypeRSSampler:
1327    case RSExportPrimitiveType::DataTypeRSScript:
1328    case RSExportPrimitiveType::DataTypeRSMesh:
1329    case RSExportPrimitiveType::DataTypeRSProgramFragment:
1330    case RSExportPrimitiveType::DataTypeRSProgramVertex:
1331    case RSExportPrimitiveType::DataTypeRSProgramRaster:
1332    case RSExportPrimitiveType::DataTypeRSProgramStore:
1333    case RSExportPrimitiveType::DataTypeRSFont: {
1334      //    (ImplicitCastExpr 'nullptr_t'
1335      //      (IntegerLiteral 0)))
1336      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1337      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1338      clang::Expr *CastToNull =
1339          clang::ImplicitCastExpr::Create(C,
1340                                          C.NullPtrTy,
1341                                          clang::CK_IntegralToPointer,
1342                                          Int0,
1343                                          NULL,
1344                                          clang::VK_RValue);
1345
1346      Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
1347      break;
1348    }
1349    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
1350    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
1351    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
1352      // RS matrix is not completely an RS object. They hold data by themselves.
1353      // (InitListExpr rs_matrix2x2
1354      //   (InitListExpr float[4]
1355      //     (FloatingLiteral 0)
1356      //     (FloatingLiteral 0)
1357      //     (FloatingLiteral 0)
1358      //     (FloatingLiteral 0)))
1359      clang::QualType FloatTy = C.FloatTy;
1360      // Constructor sets value to 0.0f by default
1361      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1362      clang::FloatingLiteral *Float0Val =
1363          clang::FloatingLiteral::Create(C,
1364                                         Val,
1365                                         /* isExact = */true,
1366                                         FloatTy,
1367                                         Loc);
1368
1369      unsigned N = 0;
1370      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
1371        N = 2;
1372      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
1373        N = 3;
1374      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
1375        N = 4;
1376
1377      // Directly allocate 16 elements instead of dynamically allocate N*N
1378      clang::Expr *InitVals[16];
1379      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
1380        InitVals[i] = Float0Val;
1381      clang::Expr *InitExpr =
1382          new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
1383      InitExpr->setType(C.getConstantArrayType(FloatTy,
1384                                               llvm::APInt(32, 4),
1385                                               clang::ArrayType::Normal,
1386                                               /* EltTypeQuals = */0));
1387
1388      Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
1389      break;
1390    }
1391    case RSExportPrimitiveType::DataTypeUnknown:
1392    case RSExportPrimitiveType::DataTypeFloat16:
1393    case RSExportPrimitiveType::DataTypeFloat32:
1394    case RSExportPrimitiveType::DataTypeFloat64:
1395    case RSExportPrimitiveType::DataTypeSigned8:
1396    case RSExportPrimitiveType::DataTypeSigned16:
1397    case RSExportPrimitiveType::DataTypeSigned32:
1398    case RSExportPrimitiveType::DataTypeSigned64:
1399    case RSExportPrimitiveType::DataTypeUnsigned8:
1400    case RSExportPrimitiveType::DataTypeUnsigned16:
1401    case RSExportPrimitiveType::DataTypeUnsigned32:
1402    case RSExportPrimitiveType::DataTypeUnsigned64:
1403    case RSExportPrimitiveType::DataTypeBoolean:
1404    case RSExportPrimitiveType::DataTypeUnsigned565:
1405    case RSExportPrimitiveType::DataTypeUnsigned5551:
1406    case RSExportPrimitiveType::DataTypeUnsigned4444:
1407    case RSExportPrimitiveType::DataTypeMax: {
1408      slangAssert(false && "Not RS object type!");
1409    }
1410    // No default case will enable compiler detecting the missing cases
1411  }
1412
1413  return Res;
1414}
1415
1416void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1417  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1418       I != E;
1419       I++) {
1420    clang::Decl *D = *I;
1421    if (D->getKind() == clang::Decl::Var) {
1422      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1423      RSExportPrimitiveType::DataType DT =
1424          RSExportPrimitiveType::DataTypeUnknown;
1425      clang::Expr *InitExpr = NULL;
1426      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1427        getCurrentScope()->addRSObject(VD);
1428        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1429      }
1430    }
1431  }
1432  return;
1433}
1434
1435void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1436  if (!CS->body_empty()) {
1437    // Push a new scope
1438    Scope *S = new Scope(CS);
1439    mScopeStack.push(S);
1440
1441    VisitStmt(CS);
1442
1443    // Destroy the scope
1444    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1445    S->InsertLocalVarDestructors();
1446    mScopeStack.pop();
1447    delete S;
1448  }
1449  return;
1450}
1451
1452void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1453  clang::QualType QT = AS->getType();
1454
1455  if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1456    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1457  }
1458
1459  return;
1460}
1461
1462void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1463  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1464       I != E;
1465       I++) {
1466    if (clang::Stmt *Child = *I) {
1467      Visit(Child);
1468    }
1469  }
1470  return;
1471}
1472
1473}  // namespace slang
1474