slang_rs_object_ref_count.cpp revision 292e00a0259ac28cac1055cb6077cf6fc7c6743c
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_object_ref_count.h"
18
19#include <list>
20
21#include "clang/AST/DeclGroup.h"
22#include "clang/AST/Expr.h"
23#include "clang/AST/NestedNameSpecifier.h"
24#include "clang/AST/OperationKinds.h"
25#include "clang/AST/Stmt.h"
26#include "clang/AST/StmtVisitor.h"
27
28#include "slang_assert.h"
29#include "slang_rs.h"
30#include "slang_rs_ast_replace.h"
31#include "slang_rs_export_type.h"
32
33namespace slang {
34
35clang::FunctionDecl *RSObjectRefCount::
36    RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
37                  RSExportPrimitiveType::FirstRSObjectType + 1];
38clang::FunctionDecl *RSObjectRefCount::
39    RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
40                    RSExportPrimitiveType::FirstRSObjectType + 1];
41
42void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43  for (unsigned i = 0;
44       i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
45       i++) {
46    RSSetObjectFD[i] = NULL;
47    RSClearObjectFD[i] = NULL;
48  }
49
50  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
51
52  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
53          E = TUDecl->decls_end(); I != E; I++) {
54    if ((I->getKind() >= clang::Decl::firstFunction) &&
55        (I->getKind() <= clang::Decl::lastFunction)) {
56      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
57
58      // points to RSSetObjectFD or RSClearObjectFD
59      clang::FunctionDecl **RSObjectFD;
60
61      if (FD->getName() == "rsSetObject") {
62        slangAssert((FD->getNumParams() == 2) &&
63                    "Invalid rsSetObject function prototype (# params)");
64        RSObjectFD = RSSetObjectFD;
65      } else if (FD->getName() == "rsClearObject") {
66        slangAssert((FD->getNumParams() == 1) &&
67                    "Invalid rsClearObject function prototype (# params)");
68        RSObjectFD = RSClearObjectFD;
69      } else {
70        continue;
71      }
72
73      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
74      clang::QualType PVT = PVD->getOriginalType();
75      // The first parameter must be a pointer like rs_allocation*
76      slangAssert(PVT->isPointerType() &&
77          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
78
79      // The rs object type passed to the FD
80      clang::QualType RST = PVT->getPointeeType();
81      RSExportPrimitiveType::DataType DT =
82          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
83      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
84             && "must be RS object type");
85
86      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
87    }
88  }
89}
90
91namespace {
92
93// This function constructs a new CompoundStmt from the input StmtList.
94static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
95      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
96  unsigned NewStmtCount = StmtList.size();
97  unsigned CompoundStmtCount = 0;
98
99  clang::Stmt **CompoundStmtList;
100  CompoundStmtList = new clang::Stmt*[NewStmtCount];
101
102  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
103  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
104  for ( ; I != E; I++) {
105    CompoundStmtList[CompoundStmtCount++] = *I;
106  }
107  slangAssert(CompoundStmtCount == NewStmtCount);
108
109  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(C,
110                                                       CompoundStmtList,
111                                                       CompoundStmtCount,
112                                                       Loc,
113                                                       Loc);
114
115  delete [] CompoundStmtList;
116
117  return CS;
118}
119
120static void AppendAfterStmt(clang::ASTContext &C,
121                            clang::CompoundStmt *CS,
122                            clang::Stmt *S,
123                            std::list<clang::Stmt*> &StmtList) {
124  slangAssert(CS);
125  clang::CompoundStmt::body_iterator bI = CS->body_begin();
126  clang::CompoundStmt::body_iterator bE = CS->body_end();
127  clang::Stmt **UpdatedStmtList =
128      new clang::Stmt*[CS->size() + StmtList.size()];
129
130  unsigned UpdatedStmtCount = 0;
131  unsigned Once = 0;
132  for ( ; bI != bE; bI++) {
133    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
134      // If we come across a return here, we don't have anything we can
135      // reasonably replace. We should have already inserted our destructor
136      // code in the proper spot, so we just clean up and return.
137      delete [] UpdatedStmtList;
138
139      return;
140    }
141
142    UpdatedStmtList[UpdatedStmtCount++] = *bI;
143
144    if ((*bI == S) && !Once) {
145      Once++;
146      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
147      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
148      for ( ; I != E; I++) {
149        UpdatedStmtList[UpdatedStmtCount++] = *I;
150      }
151    }
152  }
153  slangAssert(Once <= 1);
154
155  // When S is NULL, we are appending to the end of the CompoundStmt.
156  if (!S) {
157    slangAssert(Once == 0);
158    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
159    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
160    for ( ; I != E; I++) {
161      UpdatedStmtList[UpdatedStmtCount++] = *I;
162    }
163  }
164
165  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
166
167  delete [] UpdatedStmtList;
168
169  return;
170}
171
172// This class visits a compound statement and inserts the StmtList containing
173// destructors in proper locations. This includes inserting them before any
174// return statement in any sub-block, at the end of the logical enclosing
175// scope (compound statement), and/or before any break/continue statement that
176// would resume outside the declared scope. We will not handle the case for
177// goto statements that leave a local scope.
178//
179// To accomplish these goals, it collects a list of sub-Stmt's that
180// correspond to scope exit points. It then uses an RSASTReplace visitor to
181// transform the AST, inserting appropriate destructors before each of those
182// sub-Stmt's (and also before the exit of the outermost containing Stmt for
183// the scope).
184class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
185 private:
186  clang::ASTContext &mC;
187
188  // The loop depth of the currently visited node.
189  int mLoopDepth;
190
191  // The switch statement depth of the currently visited node.
192  // Note that this is tracked separately from the loop depth because
193  // SwitchStmt-contained ContinueStmt's should have destructors for the
194  // corresponding loop scope.
195  int mSwitchDepth;
196
197  // The outermost statement block that we are currently visiting.
198  // This should always be a CompoundStmt.
199  clang::Stmt *mOuterStmt;
200
201  // The list of destructors to execute for this scope.
202  std::list<clang::Stmt*> &mStmtList;
203
204  // The stack of statements which should be replaced by a compound statement
205  // containing the new destructor calls followed by the original Stmt.
206  std::stack<clang::Stmt*> mReplaceStmtStack;
207
208 public:
209  DestructorVisitor(clang::ASTContext &C,
210                    clang::Stmt* OuterStmt,
211                    std::list<clang::Stmt*> &StmtList);
212
213  // This code walks the collected list of Stmts to replace and actually does
214  // the replacement. It also finishes up by appending appropriate destructors
215  // to the current outermost CompoundStmt.
216  void InsertDestructors() {
217    clang::Stmt *S = NULL;
218    while (!mReplaceStmtStack.empty()) {
219      S = mReplaceStmtStack.top();
220      mReplaceStmtStack.pop();
221
222      mStmtList.push_back(S);
223      clang::CompoundStmt *CS =
224          BuildCompoundStmt(mC, mStmtList, S->getLocEnd());
225      mStmtList.pop_back();
226
227      RSASTReplace R(mC);
228      R.ReplaceStmt(mOuterStmt, S, CS);
229    }
230    clang::CompoundStmt *CS = dyn_cast<clang::CompoundStmt>(mOuterStmt);
231    slangAssert(CS);
232    if (CS) {
233      AppendAfterStmt(mC, CS, NULL, mStmtList);
234    }
235  }
236
237  void VisitStmt(clang::Stmt *S);
238  void VisitCompoundStmt(clang::CompoundStmt *CS);
239
240  void VisitBreakStmt(clang::BreakStmt *BS);
241  void VisitCaseStmt(clang::CaseStmt *CS);
242  void VisitContinueStmt(clang::ContinueStmt *CS);
243  void VisitDefaultStmt(clang::DefaultStmt *DS);
244  void VisitDoStmt(clang::DoStmt *DS);
245  void VisitForStmt(clang::ForStmt *FS);
246  void VisitIfStmt(clang::IfStmt *IS);
247  void VisitReturnStmt(clang::ReturnStmt *RS);
248  void VisitSwitchCase(clang::SwitchCase *SC);
249  void VisitSwitchStmt(clang::SwitchStmt *SS);
250  void VisitWhileStmt(clang::WhileStmt *WS);
251};
252
253DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
254                         clang::Stmt *OuterStmt,
255                         std::list<clang::Stmt*> &StmtList)
256  : mC(C),
257    mLoopDepth(0),
258    mSwitchDepth(0),
259    mOuterStmt(OuterStmt),
260    mStmtList(StmtList) {
261  return;
262}
263
264void DestructorVisitor::VisitStmt(clang::Stmt *S) {
265  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
266       I != E;
267       I++) {
268    if (clang::Stmt *Child = *I) {
269      Visit(Child);
270    }
271  }
272  return;
273}
274
275void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
276  VisitStmt(CS);
277  return;
278}
279
280void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
281  VisitStmt(BS);
282  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
283    mReplaceStmtStack.push(BS);
284  }
285  return;
286}
287
288void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
289  VisitStmt(CS);
290  return;
291}
292
293void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
294  VisitStmt(CS);
295  if (mLoopDepth == 0) {
296    // Switch statements can have nested continues.
297    mReplaceStmtStack.push(CS);
298  }
299  return;
300}
301
302void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
303  VisitStmt(DS);
304  return;
305}
306
307void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
308  mLoopDepth++;
309  VisitStmt(DS);
310  mLoopDepth--;
311  return;
312}
313
314void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
315  mLoopDepth++;
316  VisitStmt(FS);
317  mLoopDepth--;
318  return;
319}
320
321void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
322  VisitStmt(IS);
323  return;
324}
325
326void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
327  mReplaceStmtStack.push(RS);
328  return;
329}
330
331void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
332  slangAssert(false && "Both case and default have specialized handlers");
333  VisitStmt(SC);
334  return;
335}
336
337void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
338  mSwitchDepth++;
339  VisitStmt(SS);
340  mSwitchDepth--;
341  return;
342}
343
344void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
345  mLoopDepth++;
346  VisitStmt(WS);
347  mLoopDepth--;
348  return;
349}
350
351clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
352                                 clang::Expr *RefRSVar,
353                                 clang::SourceLocation Loc) {
354  slangAssert(RefRSVar);
355  const clang::Type *T = RefRSVar->getType().getTypePtr();
356  slangAssert(!T->isArrayType() &&
357              "Should not be destroying arrays with this function");
358
359  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
360  slangAssert((ClearObjectFD != NULL) &&
361              "rsClearObject doesn't cover all RS object types");
362
363  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
364  clang::QualType ClearObjectFDArgType =
365      ClearObjectFD->getParamDecl(0)->getOriginalType();
366
367  // Example destructor for "rs_font localFont;"
368  //
369  // (CallExpr 'void'
370  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
371  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
372  //   (UnaryOperator 'rs_font *' prefix '&'
373  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
374
375  // Get address of targeted RS object
376  clang::Expr *AddrRefRSVar =
377      new(C) clang::UnaryOperator(RefRSVar,
378                                  clang::UO_AddrOf,
379                                  ClearObjectFDArgType,
380                                  clang::VK_RValue,
381                                  clang::OK_Ordinary,
382                                  Loc);
383
384  clang::Expr *RefRSClearObjectFD =
385      clang::DeclRefExpr::Create(C,
386                                 clang::NestedNameSpecifierLoc(),
387                                 ClearObjectFD,
388                                 ClearObjectFD->getLocation(),
389                                 ClearObjectFDType,
390                                 clang::VK_RValue,
391                                 NULL);
392
393  clang::Expr *RSClearObjectFP =
394      clang::ImplicitCastExpr::Create(C,
395                                      C.getPointerType(ClearObjectFDType),
396                                      clang::CK_FunctionToPointerDecay,
397                                      RefRSClearObjectFD,
398                                      NULL,
399                                      clang::VK_RValue);
400
401  clang::CallExpr *RSClearObjectCall =
402      new(C) clang::CallExpr(C,
403                             RSClearObjectFP,
404                             &AddrRefRSVar,
405                             1,
406                             ClearObjectFD->getCallResultType(),
407                             clang::VK_RValue,
408                             Loc);
409
410  return RSClearObjectCall;
411}
412
413static int ArrayDim(const clang::Type *T) {
414  if (!T || !T->isArrayType()) {
415    return 0;
416  }
417
418  const clang::ConstantArrayType *CAT =
419    static_cast<const clang::ConstantArrayType *>(T);
420  return static_cast<int>(CAT->getSize().getSExtValue());
421}
422
423static clang::Stmt *ClearStructRSObject(
424    clang::ASTContext &C,
425    clang::DeclContext *DC,
426    clang::Expr *RefRSStruct,
427    clang::SourceLocation Loc);
428
429static clang::Stmt *ClearArrayRSObject(
430    clang::ASTContext &C,
431    clang::DeclContext *DC,
432    clang::Expr *RefRSArr,
433    clang::SourceLocation Loc) {
434  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
435  slangAssert(BaseType->isArrayType());
436
437  int NumArrayElements = ArrayDim(BaseType);
438  // Actually extract out the base RS object type for use later
439  BaseType = BaseType->getArrayElementTypeNoTypeQual();
440
441  clang::Stmt *StmtArray[2] = {NULL};
442  int StmtCtr = 0;
443
444  if (NumArrayElements <= 0) {
445    return NULL;
446  }
447
448  // Example destructor loop for "rs_font fontArr[10];"
449  //
450  // (CompoundStmt
451  //   (DeclStmt "int rsIntIter")
452  //   (ForStmt
453  //     (BinaryOperator 'int' '='
454  //       (DeclRefExpr 'int' Var='rsIntIter')
455  //       (IntegerLiteral 'int' 0))
456  //     (BinaryOperator 'int' '<'
457  //       (DeclRefExpr 'int' Var='rsIntIter')
458  //       (IntegerLiteral 'int' 10)
459  //     NULL << CondVar >>
460  //     (UnaryOperator 'int' postfix '++'
461  //       (DeclRefExpr 'int' Var='rsIntIter'))
462  //     (CallExpr 'void'
463  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
464  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
465  //       (UnaryOperator 'rs_font *' prefix '&'
466  //         (ArraySubscriptExpr 'rs_font':'rs_font'
467  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
468  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
469  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
470
471  // Create helper variable for iterating through elements
472  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
473  clang::VarDecl *IIVD =
474      clang::VarDecl::Create(C,
475                             DC,
476                             Loc,
477                             &II,
478                             C.IntTy,
479                             C.getTrivialTypeSourceInfo(C.IntTy),
480                             clang::SC_None,
481                             clang::SC_None);
482  clang::Decl *IID = (clang::Decl *)IIVD;
483
484  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
485  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
486
487  // Form the actual destructor loop
488  // for (Init; Cond; Inc)
489  //   RSClearObjectCall;
490
491  // Init -> "rsIntIter = 0"
492  clang::DeclRefExpr *RefrsIntIter =
493      clang::DeclRefExpr::Create(C,
494                                 clang::NestedNameSpecifierLoc(),
495                                 IIVD,
496                                 Loc,
497                                 C.IntTy,
498                                 clang::VK_RValue,
499                                 NULL);
500
501  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
502      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
503
504  clang::BinaryOperator *Init =
505      new(C) clang::BinaryOperator(RefrsIntIter,
506                                   Int0,
507                                   clang::BO_Assign,
508                                   C.IntTy,
509                                   clang::VK_RValue,
510                                   clang::OK_Ordinary,
511                                   Loc);
512
513  // Cond -> "rsIntIter < NumArrayElements"
514  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
515      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
516
517  clang::BinaryOperator *Cond =
518      new(C) clang::BinaryOperator(RefrsIntIter,
519                                   NumArrayElementsExpr,
520                                   clang::BO_LT,
521                                   C.IntTy,
522                                   clang::VK_RValue,
523                                   clang::OK_Ordinary,
524                                   Loc);
525
526  // Inc -> "rsIntIter++"
527  clang::UnaryOperator *Inc =
528      new(C) clang::UnaryOperator(RefrsIntIter,
529                                  clang::UO_PostInc,
530                                  C.IntTy,
531                                  clang::VK_RValue,
532                                  clang::OK_Ordinary,
533                                  Loc);
534
535  // Body -> "rsClearObject(&VD[rsIntIter]);"
536  // Destructor loop operates on individual array elements
537
538  clang::Expr *RefRSArrPtr =
539      clang::ImplicitCastExpr::Create(C,
540          C.getPointerType(BaseType->getCanonicalTypeInternal()),
541          clang::CK_ArrayToPointerDecay,
542          RefRSArr,
543          NULL,
544          clang::VK_RValue);
545
546  clang::Expr *RefRSArrPtrSubscript =
547      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
548                                       RefrsIntIter,
549                                       BaseType->getCanonicalTypeInternal(),
550                                       clang::VK_RValue,
551                                       clang::OK_Ordinary,
552                                       Loc);
553
554  RSExportPrimitiveType::DataType DT =
555      RSExportPrimitiveType::GetRSSpecificType(BaseType);
556
557  clang::Stmt *RSClearObjectCall = NULL;
558  if (BaseType->isArrayType()) {
559    RSClearObjectCall =
560        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, Loc);
561  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
562    RSClearObjectCall =
563        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, Loc);
564  } else {
565    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
566  }
567
568  clang::ForStmt *DestructorLoop =
569      new(C) clang::ForStmt(C,
570                            Init,
571                            Cond,
572                            NULL,  // no condVar
573                            Inc,
574                            RSClearObjectCall,
575                            Loc,
576                            Loc,
577                            Loc);
578
579  StmtArray[StmtCtr++] = DestructorLoop;
580  slangAssert(StmtCtr == 2);
581
582  clang::CompoundStmt *CS =
583      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
584
585  return CS;
586}
587
588static unsigned CountRSObjectTypes(const clang::Type *T) {
589  slangAssert(T);
590  unsigned RSObjectCount = 0;
591
592  if (T->isArrayType()) {
593    return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
594  }
595
596  RSExportPrimitiveType::DataType DT =
597      RSExportPrimitiveType::GetRSSpecificType(T);
598  if (DT != RSExportPrimitiveType::DataTypeUnknown) {
599    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
600  }
601
602  if (!T->isStructureType()) {
603    return 0;
604  }
605
606  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
607  RD = RD->getDefinition();
608  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
609         FE = RD->field_end();
610       FI != FE;
611       FI++) {
612    const clang::FieldDecl *FD = *FI;
613    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
614    if (CountRSObjectTypes(FT)) {
615      // Sub-structs should only count once (as should arrays, etc.)
616      RSObjectCount++;
617    }
618  }
619
620  return RSObjectCount;
621}
622
623static clang::Stmt *ClearStructRSObject(
624    clang::ASTContext &C,
625    clang::DeclContext *DC,
626    clang::Expr *RefRSStruct,
627    clang::SourceLocation Loc) {
628  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
629
630  slangAssert(!BaseType->isArrayType());
631
632  // Structs should show up as unknown primitive types
633  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
634              RSExportPrimitiveType::DataTypeUnknown);
635
636  unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
637
638  unsigned StmtCount = 0;
639  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
640  for (unsigned i = 0; i < FieldsToDestroy; i++) {
641    StmtArray[i] = NULL;
642  }
643
644  // Populate StmtArray by creating a destructor for each RS object field
645  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
646  RD = RD->getDefinition();
647  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
648         FE = RD->field_end();
649       FI != FE;
650       FI++) {
651    // We just look through all field declarations to see if we find a
652    // declaration for an RS object type (or an array of one).
653    bool IsArrayType = false;
654    clang::FieldDecl *FD = *FI;
655    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
656    const clang::Type *OrigType = FT;
657    while (FT && FT->isArrayType()) {
658      FT = FT->getArrayElementTypeNoTypeQual();
659      IsArrayType = true;
660    }
661
662    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
663      clang::DeclAccessPair FoundDecl =
664          clang::DeclAccessPair::make(FD, clang::AS_none);
665      clang::MemberExpr *RSObjectMember =
666          clang::MemberExpr::Create(C,
667                                    RefRSStruct,
668                                    false,
669                                    clang::NestedNameSpecifierLoc(),
670                                    FD,
671                                    FoundDecl,
672                                    clang::DeclarationNameInfo(),
673                                    NULL,
674                                    OrigType->getCanonicalTypeInternal(),
675                                    clang::VK_RValue,
676                                    clang::OK_Ordinary);
677
678      slangAssert(StmtCount < FieldsToDestroy);
679
680      if (IsArrayType) {
681        StmtArray[StmtCount++] = ClearArrayRSObject(C,
682                                                    DC,
683                                                    RSObjectMember,
684                                                    Loc);
685      } else {
686        StmtArray[StmtCount++] = ClearSingleRSObject(C,
687                                                     RSObjectMember,
688                                                     Loc);
689      }
690    } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
691      // In this case, we have a nested struct. We may not end up filling all
692      // of the spaces in StmtArray (sub-structs should handle themselves
693      // with separate compound statements).
694      clang::DeclAccessPair FoundDecl =
695          clang::DeclAccessPair::make(FD, clang::AS_none);
696      clang::MemberExpr *RSObjectMember =
697          clang::MemberExpr::Create(C,
698                                    RefRSStruct,
699                                    false,
700                                    clang::NestedNameSpecifierLoc(),
701                                    FD,
702                                    FoundDecl,
703                                    clang::DeclarationNameInfo(),
704                                    NULL,
705                                    OrigType->getCanonicalTypeInternal(),
706                                    clang::VK_RValue,
707                                    clang::OK_Ordinary);
708
709      if (IsArrayType) {
710        StmtArray[StmtCount++] = ClearArrayRSObject(C,
711                                                    DC,
712                                                    RSObjectMember,
713                                                    Loc);
714      } else {
715        StmtArray[StmtCount++] = ClearStructRSObject(C,
716                                                     DC,
717                                                     RSObjectMember,
718                                                     Loc);
719      }
720    }
721  }
722
723  slangAssert(StmtCount > 0);
724  clang::CompoundStmt *CS =
725      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
726
727  delete [] StmtArray;
728
729  return CS;
730}
731
732static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
733                                            clang::Diagnostic *Diags,
734                                            clang::Expr *DstExpr,
735                                            clang::Expr *SrcExpr,
736                                            clang::SourceLocation Loc) {
737  const clang::Type *T = DstExpr->getType().getTypePtr();
738  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
739  slangAssert((SetObjectFD != NULL) &&
740              "rsSetObject doesn't cover all RS object types");
741
742  clang::QualType SetObjectFDType = SetObjectFD->getType();
743  clang::QualType SetObjectFDArgType[2];
744  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
745  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
746
747  clang::Expr *RefRSSetObjectFD =
748      clang::DeclRefExpr::Create(C,
749                                 clang::NestedNameSpecifierLoc(),
750                                 SetObjectFD,
751                                 Loc,
752                                 SetObjectFDType,
753                                 clang::VK_RValue,
754                                 NULL);
755
756  clang::Expr *RSSetObjectFP =
757      clang::ImplicitCastExpr::Create(C,
758                                      C.getPointerType(SetObjectFDType),
759                                      clang::CK_FunctionToPointerDecay,
760                                      RefRSSetObjectFD,
761                                      NULL,
762                                      clang::VK_RValue);
763
764  clang::Expr *ArgList[2];
765  ArgList[0] = new(C) clang::UnaryOperator(DstExpr,
766                                           clang::UO_AddrOf,
767                                           SetObjectFDArgType[0],
768                                           clang::VK_RValue,
769                                           clang::OK_Ordinary,
770                                           Loc);
771  ArgList[1] = SrcExpr;
772
773  clang::CallExpr *RSSetObjectCall =
774      new(C) clang::CallExpr(C,
775                             RSSetObjectFP,
776                             ArgList,
777                             2,
778                             SetObjectFD->getCallResultType(),
779                             clang::VK_RValue,
780                             Loc);
781
782  return RSSetObjectCall;
783}
784
785static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
786                                            clang::Diagnostic *Diags,
787                                            clang::Expr *LHS,
788                                            clang::Expr *RHS,
789                                            clang::SourceLocation Loc);
790
791static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
792                                           clang::Diagnostic *Diags,
793                                           clang::Expr *DstArr,
794                                           clang::Expr *SrcArr,
795                                           clang::SourceLocation Loc) {
796  clang::DeclContext *DC = NULL;
797  const clang::Type *BaseType = DstArr->getType().getTypePtr();
798  slangAssert(BaseType->isArrayType());
799
800  int NumArrayElements = ArrayDim(BaseType);
801  // Actually extract out the base RS object type for use later
802  BaseType = BaseType->getArrayElementTypeNoTypeQual();
803
804  clang::Stmt *StmtArray[2] = {NULL};
805  int StmtCtr = 0;
806
807  if (NumArrayElements <= 0) {
808    return NULL;
809  }
810
811  // Create helper variable for iterating through elements
812  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
813  clang::VarDecl *IIVD =
814      clang::VarDecl::Create(C,
815                             DC,
816                             Loc,
817                             &II,
818                             C.IntTy,
819                             C.getTrivialTypeSourceInfo(C.IntTy),
820                             clang::SC_None,
821                             clang::SC_None);
822  clang::Decl *IID = (clang::Decl *)IIVD;
823
824  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
825  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
826
827  // Form the actual loop
828  // for (Init; Cond; Inc)
829  //   RSSetObjectCall;
830
831  // Init -> "rsIntIter = 0"
832  clang::DeclRefExpr *RefrsIntIter =
833      clang::DeclRefExpr::Create(C,
834                                 clang::NestedNameSpecifierLoc(),
835                                 IIVD,
836                                 Loc,
837                                 C.IntTy,
838                                 clang::VK_RValue,
839                                 NULL);
840
841  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
842      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
843
844  clang::BinaryOperator *Init =
845      new(C) clang::BinaryOperator(RefrsIntIter,
846                                   Int0,
847                                   clang::BO_Assign,
848                                   C.IntTy,
849                                   clang::VK_RValue,
850                                   clang::OK_Ordinary,
851                                   Loc);
852
853  // Cond -> "rsIntIter < NumArrayElements"
854  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
855      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
856
857  clang::BinaryOperator *Cond =
858      new(C) clang::BinaryOperator(RefrsIntIter,
859                                   NumArrayElementsExpr,
860                                   clang::BO_LT,
861                                   C.IntTy,
862                                   clang::VK_RValue,
863                                   clang::OK_Ordinary,
864                                   Loc);
865
866  // Inc -> "rsIntIter++"
867  clang::UnaryOperator *Inc =
868      new(C) clang::UnaryOperator(RefrsIntIter,
869                                  clang::UO_PostInc,
870                                  C.IntTy,
871                                  clang::VK_RValue,
872                                  clang::OK_Ordinary,
873                                  Loc);
874
875  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
876  // Loop operates on individual array elements
877
878  clang::Expr *DstArrPtr =
879      clang::ImplicitCastExpr::Create(C,
880          C.getPointerType(BaseType->getCanonicalTypeInternal()),
881          clang::CK_ArrayToPointerDecay,
882          DstArr,
883          NULL,
884          clang::VK_RValue);
885
886  clang::Expr *DstArrPtrSubscript =
887      new(C) clang::ArraySubscriptExpr(DstArrPtr,
888                                       RefrsIntIter,
889                                       BaseType->getCanonicalTypeInternal(),
890                                       clang::VK_RValue,
891                                       clang::OK_Ordinary,
892                                       Loc);
893
894  clang::Expr *SrcArrPtr =
895      clang::ImplicitCastExpr::Create(C,
896          C.getPointerType(BaseType->getCanonicalTypeInternal()),
897          clang::CK_ArrayToPointerDecay,
898          SrcArr,
899          NULL,
900          clang::VK_RValue);
901
902  clang::Expr *SrcArrPtrSubscript =
903      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
904                                       RefrsIntIter,
905                                       BaseType->getCanonicalTypeInternal(),
906                                       clang::VK_RValue,
907                                       clang::OK_Ordinary,
908                                       Loc);
909
910  RSExportPrimitiveType::DataType DT =
911      RSExportPrimitiveType::GetRSSpecificType(BaseType);
912
913  clang::Stmt *RSSetObjectCall = NULL;
914  if (BaseType->isArrayType()) {
915    RSSetObjectCall = CreateArrayRSSetObject(C, Diags, DstArrPtrSubscript,
916                                             SrcArrPtrSubscript, Loc);
917  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
918    RSSetObjectCall = CreateStructRSSetObject(C, Diags, DstArrPtrSubscript,
919                                              SrcArrPtrSubscript, Loc);
920  } else {
921    RSSetObjectCall = CreateSingleRSSetObject(C, Diags, DstArrPtrSubscript,
922                                              SrcArrPtrSubscript, Loc);
923  }
924
925  clang::ForStmt *DestructorLoop =
926      new(C) clang::ForStmt(C,
927                            Init,
928                            Cond,
929                            NULL,  // no condVar
930                            Inc,
931                            RSSetObjectCall,
932                            Loc,
933                            Loc,
934                            Loc);
935
936  StmtArray[StmtCtr++] = DestructorLoop;
937  slangAssert(StmtCtr == 2);
938
939  clang::CompoundStmt *CS =
940      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
941
942  return CS;
943}
944
945static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
946                                            clang::Diagnostic *Diags,
947                                            clang::Expr *LHS,
948                                            clang::Expr *RHS,
949                                            clang::SourceLocation Loc) {
950  clang::QualType QT = LHS->getType();
951  const clang::Type *T = QT.getTypePtr();
952  slangAssert(T->isStructureType());
953  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
954
955  // Keep an extra slot for the original copy (memcpy)
956  unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
957
958  unsigned StmtCount = 0;
959  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
960  for (unsigned i = 0; i < FieldsToSet; i++) {
961    StmtArray[i] = NULL;
962  }
963
964  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
965  RD = RD->getDefinition();
966  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
967         FE = RD->field_end();
968       FI != FE;
969       FI++) {
970    bool IsArrayType = false;
971    clang::FieldDecl *FD = *FI;
972    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
973    const clang::Type *OrigType = FT;
974
975    if (!CountRSObjectTypes(FT)) {
976      // Skip to next if we don't have any viable RS object types
977      continue;
978    }
979
980    clang::DeclAccessPair FoundDecl =
981        clang::DeclAccessPair::make(FD, clang::AS_none);
982    clang::MemberExpr *DstMember =
983        clang::MemberExpr::Create(C,
984                                  LHS,
985                                  false,
986                                  clang::NestedNameSpecifierLoc(),
987                                  FD,
988                                  FoundDecl,
989                                  clang::DeclarationNameInfo(),
990                                  NULL,
991                                  OrigType->getCanonicalTypeInternal(),
992                                  clang::VK_RValue,
993                                  clang::OK_Ordinary);
994
995    clang::MemberExpr *SrcMember =
996        clang::MemberExpr::Create(C,
997                                  RHS,
998                                  false,
999                                  clang::NestedNameSpecifierLoc(),
1000                                  FD,
1001                                  FoundDecl,
1002                                  clang::DeclarationNameInfo(),
1003                                  NULL,
1004                                  OrigType->getCanonicalTypeInternal(),
1005                                  clang::VK_RValue,
1006                                  clang::OK_Ordinary);
1007
1008    if (FT->isArrayType()) {
1009      FT = FT->getArrayElementTypeNoTypeQual();
1010      IsArrayType = true;
1011    }
1012
1013    RSExportPrimitiveType::DataType DT =
1014        RSExportPrimitiveType::GetRSSpecificType(FT);
1015
1016    if (IsArrayType) {
1017      Diags->Report(clang::FullSourceLoc(Loc, C.getSourceManager()),
1018          Diags->getCustomDiagID(clang::Diagnostic::Error,
1019            "Arrays of RS object types within structures cannot be copied"));
1020      // TODO(srhines): Support setting arrays of RS objects
1021      // StmtArray[StmtCount++] =
1022      //    CreateArrayRSSetObject(C, Diags, DstMember, SrcMember, Loc);
1023    } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
1024      StmtArray[StmtCount++] =
1025          CreateStructRSSetObject(C, Diags, DstMember, SrcMember, Loc);
1026    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1027      StmtArray[StmtCount++] =
1028          CreateSingleRSSetObject(C, Diags, DstMember, SrcMember, Loc);
1029    } else {
1030      slangAssert(false);
1031    }
1032  }
1033
1034  slangAssert(StmtCount > 0 && StmtCount < FieldsToSet);
1035
1036  // We still need to actually do the overall struct copy. For simplicity,
1037  // we just do a straight-up assignment (which will still preserve all
1038  // the proper RS object reference counts).
1039  clang::BinaryOperator *CopyStruct =
1040      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1041                                   clang::VK_RValue, clang::OK_Ordinary, Loc);
1042  StmtArray[StmtCount++] = CopyStruct;
1043
1044  clang::CompoundStmt *CS =
1045      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
1046
1047  delete [] StmtArray;
1048
1049  return CS;
1050}
1051
1052}  // namespace
1053
1054void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1055    clang::BinaryOperator *AS,
1056    clang::Diagnostic *Diags) {
1057
1058  clang::QualType QT = AS->getType();
1059
1060  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1061      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1062
1063  clang::SourceLocation Loc = AS->getExprLoc();
1064  clang::Stmt *UpdatedStmt = NULL;
1065
1066  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1067    // By definition, this is a struct assignment if we get here
1068    UpdatedStmt =
1069        CreateStructRSSetObject(C, Diags, AS->getLHS(), AS->getRHS(), Loc);
1070  } else {
1071    UpdatedStmt =
1072        CreateSingleRSSetObject(C, Diags, AS->getLHS(), AS->getRHS(), Loc);
1073  }
1074
1075  RSASTReplace R(C);
1076  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1077  return;
1078}
1079
1080void RSObjectRefCount::Scope::AppendRSObjectInit(
1081    clang::Diagnostic *Diags,
1082    clang::VarDecl *VD,
1083    clang::DeclStmt *DS,
1084    RSExportPrimitiveType::DataType DT,
1085    clang::Expr *InitExpr) {
1086  slangAssert(VD);
1087
1088  if (!InitExpr) {
1089    return;
1090  }
1091
1092  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1093      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1094  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1095      RSExportPrimitiveType::DataTypeRSFont)->getLocation();
1096
1097  if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
1098    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1099    clang::DeclRefExpr *RefRSVar =
1100        clang::DeclRefExpr::Create(C,
1101                                   clang::NestedNameSpecifierLoc(),
1102                                   VD,
1103                                   Loc,
1104                                   T->getCanonicalTypeInternal(),
1105                                   clang::VK_RValue,
1106                                   NULL);
1107
1108    clang::Stmt *RSSetObjectOps =
1109        CreateStructRSSetObject(C, Diags, RefRSVar, InitExpr, Loc);
1110
1111    std::list<clang::Stmt*> StmtList;
1112    StmtList.push_back(RSSetObjectOps);
1113    AppendAfterStmt(C, mCS, DS, StmtList);
1114    return;
1115  }
1116
1117  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1118  slangAssert((SetObjectFD != NULL) &&
1119              "rsSetObject doesn't cover all RS object types");
1120
1121  clang::QualType SetObjectFDType = SetObjectFD->getType();
1122  clang::QualType SetObjectFDArgType[2];
1123  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1124  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1125
1126  clang::Expr *RefRSSetObjectFD =
1127      clang::DeclRefExpr::Create(C,
1128                                 clang::NestedNameSpecifierLoc(),
1129                                 SetObjectFD,
1130                                 Loc,
1131                                 SetObjectFDType,
1132                                 clang::VK_RValue,
1133                                 NULL);
1134
1135  clang::Expr *RSSetObjectFP =
1136      clang::ImplicitCastExpr::Create(C,
1137                                      C.getPointerType(SetObjectFDType),
1138                                      clang::CK_FunctionToPointerDecay,
1139                                      RefRSSetObjectFD,
1140                                      NULL,
1141                                      clang::VK_RValue);
1142
1143  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1144  clang::DeclRefExpr *RefRSVar =
1145      clang::DeclRefExpr::Create(C,
1146                                 clang::NestedNameSpecifierLoc(),
1147                                 VD,
1148                                 Loc,
1149                                 T->getCanonicalTypeInternal(),
1150                                 clang::VK_RValue,
1151                                 NULL);
1152
1153  clang::Expr *ArgList[2];
1154  ArgList[0] = new(C) clang::UnaryOperator(RefRSVar,
1155                                           clang::UO_AddrOf,
1156                                           SetObjectFDArgType[0],
1157                                           clang::VK_RValue,
1158                                           clang::OK_Ordinary,
1159                                           Loc);
1160  ArgList[1] = InitExpr;
1161
1162  clang::CallExpr *RSSetObjectCall =
1163      new(C) clang::CallExpr(C,
1164                             RSSetObjectFP,
1165                             ArgList,
1166                             2,
1167                             SetObjectFD->getCallResultType(),
1168                             clang::VK_RValue,
1169                             Loc);
1170
1171  std::list<clang::Stmt*> StmtList;
1172  StmtList.push_back(RSSetObjectCall);
1173  AppendAfterStmt(C, mCS, DS, StmtList);
1174
1175  return;
1176}
1177
1178void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1179  std::list<clang::Stmt*> RSClearObjectCalls;
1180  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1181          E = mRSO.end();
1182        I != E;
1183        I++) {
1184    clang::Stmt *S = ClearRSObject(*I);
1185    if (S) {
1186      RSClearObjectCalls.push_back(S);
1187    }
1188  }
1189  if (RSClearObjectCalls.size() > 0) {
1190    DestructorVisitor DV((*mRSO.begin())->getASTContext(),
1191                         mCS,
1192                         RSClearObjectCalls);
1193    DV.Visit(mCS);
1194    DV.InsertDestructors();
1195  }
1196  return;
1197}
1198
1199clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(clang::VarDecl *VD) {
1200  slangAssert(VD);
1201  clang::ASTContext &C = VD->getASTContext();
1202  clang::DeclContext *DC = VD->getDeclContext();
1203  clang::SourceLocation Loc = VD->getLocation();
1204  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1205
1206  // Reference expr to target RS object variable
1207  clang::DeclRefExpr *RefRSVar =
1208      clang::DeclRefExpr::Create(C,
1209                                 clang::NestedNameSpecifierLoc(),
1210                                 VD,
1211                                 Loc,
1212                                 T->getCanonicalTypeInternal(),
1213                                 clang::VK_RValue,
1214                                 NULL);
1215
1216  if (T->isArrayType()) {
1217    return ClearArrayRSObject(C, DC, RefRSVar, Loc);
1218  }
1219
1220  RSExportPrimitiveType::DataType DT =
1221      RSExportPrimitiveType::GetRSSpecificType(T);
1222
1223  if (DT == RSExportPrimitiveType::DataTypeUnknown ||
1224      DT == RSExportPrimitiveType::DataTypeIsStruct) {
1225    return ClearStructRSObject(C, DC, RefRSVar, Loc);
1226  }
1227
1228  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1229              "Should be RS object");
1230
1231  return ClearSingleRSObject(C, RefRSVar, Loc);
1232}
1233
1234bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1235                                          RSExportPrimitiveType::DataType *DT,
1236                                          clang::Expr **InitExpr) {
1237  slangAssert(VD && DT && InitExpr);
1238  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1239
1240  // Loop through array types to get to base type
1241  while (T && T->isArrayType()) {
1242    T = T->getArrayElementTypeNoTypeQual();
1243  }
1244
1245  bool DataTypeIsStructWithRSObject = false;
1246  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1247
1248  if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
1249    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1250      *DT = RSExportPrimitiveType::DataTypeIsStruct;
1251      DataTypeIsStructWithRSObject = true;
1252    } else {
1253      return false;
1254    }
1255  }
1256
1257  bool DataTypeIsRSObject = false;
1258  if (DataTypeIsStructWithRSObject) {
1259    DataTypeIsRSObject = true;
1260  } else {
1261    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1262  }
1263  *InitExpr = VD->getInit();
1264
1265  if (!DataTypeIsRSObject && *InitExpr) {
1266    // If we already have an initializer for a matrix type, we are done.
1267    return DataTypeIsRSObject;
1268  }
1269
1270  clang::Expr *ZeroInitializer =
1271      CreateZeroInitializerForRSSpecificType(*DT,
1272                                             VD->getASTContext(),
1273                                             VD->getLocation());
1274
1275  if (ZeroInitializer) {
1276    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1277    VD->setInit(ZeroInitializer);
1278  }
1279
1280  return DataTypeIsRSObject;
1281}
1282
1283clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1284    RSExportPrimitiveType::DataType DT,
1285    clang::ASTContext &C,
1286    const clang::SourceLocation &Loc) {
1287  clang::Expr *Res = NULL;
1288  switch (DT) {
1289    case RSExportPrimitiveType::DataTypeIsStruct:
1290    case RSExportPrimitiveType::DataTypeRSElement:
1291    case RSExportPrimitiveType::DataTypeRSType:
1292    case RSExportPrimitiveType::DataTypeRSAllocation:
1293    case RSExportPrimitiveType::DataTypeRSSampler:
1294    case RSExportPrimitiveType::DataTypeRSScript:
1295    case RSExportPrimitiveType::DataTypeRSMesh:
1296    case RSExportPrimitiveType::DataTypeRSProgramFragment:
1297    case RSExportPrimitiveType::DataTypeRSProgramVertex:
1298    case RSExportPrimitiveType::DataTypeRSProgramRaster:
1299    case RSExportPrimitiveType::DataTypeRSProgramStore:
1300    case RSExportPrimitiveType::DataTypeRSFont: {
1301      //    (ImplicitCastExpr 'nullptr_t'
1302      //      (IntegerLiteral 0)))
1303      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1304      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1305      clang::Expr *CastToNull =
1306          clang::ImplicitCastExpr::Create(C,
1307                                          C.NullPtrTy,
1308                                          clang::CK_IntegralToPointer,
1309                                          Int0,
1310                                          NULL,
1311                                          clang::VK_RValue);
1312
1313      Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
1314      break;
1315    }
1316    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
1317    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
1318    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
1319      // RS matrix is not completely an RS object. They hold data by themselves.
1320      // (InitListExpr rs_matrix2x2
1321      //   (InitListExpr float[4]
1322      //     (FloatingLiteral 0)
1323      //     (FloatingLiteral 0)
1324      //     (FloatingLiteral 0)
1325      //     (FloatingLiteral 0)))
1326      clang::QualType FloatTy = C.FloatTy;
1327      // Constructor sets value to 0.0f by default
1328      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1329      clang::FloatingLiteral *Float0Val =
1330          clang::FloatingLiteral::Create(C,
1331                                         Val,
1332                                         /* isExact = */true,
1333                                         FloatTy,
1334                                         Loc);
1335
1336      unsigned N = 0;
1337      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
1338        N = 2;
1339      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
1340        N = 3;
1341      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
1342        N = 4;
1343
1344      // Directly allocate 16 elements instead of dynamically allocate N*N
1345      clang::Expr *InitVals[16];
1346      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
1347        InitVals[i] = Float0Val;
1348      clang::Expr *InitExpr =
1349          new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
1350      InitExpr->setType(C.getConstantArrayType(FloatTy,
1351                                               llvm::APInt(32, 4),
1352                                               clang::ArrayType::Normal,
1353                                               /* EltTypeQuals = */0));
1354
1355      Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
1356      break;
1357    }
1358    case RSExportPrimitiveType::DataTypeUnknown:
1359    case RSExportPrimitiveType::DataTypeFloat16:
1360    case RSExportPrimitiveType::DataTypeFloat32:
1361    case RSExportPrimitiveType::DataTypeFloat64:
1362    case RSExportPrimitiveType::DataTypeSigned8:
1363    case RSExportPrimitiveType::DataTypeSigned16:
1364    case RSExportPrimitiveType::DataTypeSigned32:
1365    case RSExportPrimitiveType::DataTypeSigned64:
1366    case RSExportPrimitiveType::DataTypeUnsigned8:
1367    case RSExportPrimitiveType::DataTypeUnsigned16:
1368    case RSExportPrimitiveType::DataTypeUnsigned32:
1369    case RSExportPrimitiveType::DataTypeUnsigned64:
1370    case RSExportPrimitiveType::DataTypeBoolean:
1371    case RSExportPrimitiveType::DataTypeUnsigned565:
1372    case RSExportPrimitiveType::DataTypeUnsigned5551:
1373    case RSExportPrimitiveType::DataTypeUnsigned4444:
1374    case RSExportPrimitiveType::DataTypeMax: {
1375      slangAssert(false && "Not RS object type!");
1376    }
1377    // No default case will enable compiler detecting the missing cases
1378  }
1379
1380  return Res;
1381}
1382
1383void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1384  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1385       I != E;
1386       I++) {
1387    clang::Decl *D = *I;
1388    if (D->getKind() == clang::Decl::Var) {
1389      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1390      RSExportPrimitiveType::DataType DT =
1391          RSExportPrimitiveType::DataTypeUnknown;
1392      clang::Expr *InitExpr = NULL;
1393      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1394        getCurrentScope()->addRSObject(VD);
1395        getCurrentScope()->AppendRSObjectInit(mDiags, VD, DS, DT, InitExpr);
1396      }
1397    }
1398  }
1399  return;
1400}
1401
1402void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1403  if (!CS->body_empty()) {
1404    // Push a new scope
1405    Scope *S = new Scope(CS);
1406    mScopeStack.push(S);
1407
1408    VisitStmt(CS);
1409
1410    // Destroy the scope
1411    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1412    S->InsertLocalVarDestructors();
1413    mScopeStack.pop();
1414    delete S;
1415  }
1416  return;
1417}
1418
1419void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1420  clang::QualType QT = AS->getType();
1421
1422  if (CountRSObjectTypes(QT.getTypePtr())) {
1423    getCurrentScope()->ReplaceRSObjectAssignment(AS, mDiags);
1424  }
1425
1426  return;
1427}
1428
1429void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1430  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1431       I != E;
1432       I++) {
1433    if (clang::Stmt *Child = *I) {
1434      Visit(Child);
1435    }
1436  }
1437  return;
1438}
1439
1440}  // namespace slang
1441