slang_rs_object_ref_count.cpp revision 474655a402e70cb329e1bcd4ebbe00bdc5be4206
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_object_ref_count.h"
18
19#include <list>
20
21#include "clang/AST/DeclGroup.h"
22#include "clang/AST/Expr.h"
23#include "clang/AST/NestedNameSpecifier.h"
24#include "clang/AST/OperationKinds.h"
25#include "clang/AST/Stmt.h"
26#include "clang/AST/StmtVisitor.h"
27
28#include "slang_assert.h"
29#include "slang_rs.h"
30#include "slang_rs_ast_replace.h"
31#include "slang_rs_export_type.h"
32
33namespace slang {
34
/* Lookup tables, indexed by RS data type, holding the rsSetObject() and
 * rsClearObject() overloads found by GetRSRefCountingFunctions().
 * Even though these two arrays are of size DataTypeMax, only entries that
 * correspond to object types will be set; all other entries are NULL.
 */
clang::FunctionDecl *
RSObjectRefCount::RSSetObjectFD[RSExportPrimitiveType::DataTypeMax];
clang::FunctionDecl *
RSObjectRefCount::RSClearObjectFD[RSExportPrimitiveType::DataTypeMax];
42
43void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44  for (unsigned i = 0; i < RSExportPrimitiveType::DataTypeMax; i++) {
45    RSSetObjectFD[i] = NULL;
46    RSClearObjectFD[i] = NULL;
47  }
48
49  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52          E = TUDecl->decls_end(); I != E; I++) {
53    if ((I->getKind() >= clang::Decl::firstFunction) &&
54        (I->getKind() <= clang::Decl::lastFunction)) {
55      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57      // points to RSSetObjectFD or RSClearObjectFD
58      clang::FunctionDecl **RSObjectFD;
59
60      if (FD->getName() == "rsSetObject") {
61        slangAssert((FD->getNumParams() == 2) &&
62                    "Invalid rsSetObject function prototype (# params)");
63        RSObjectFD = RSSetObjectFD;
64      } else if (FD->getName() == "rsClearObject") {
65        slangAssert((FD->getNumParams() == 1) &&
66                    "Invalid rsClearObject function prototype (# params)");
67        RSObjectFD = RSClearObjectFD;
68      } else {
69        continue;
70      }
71
72      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73      clang::QualType PVT = PVD->getOriginalType();
74      // The first parameter must be a pointer like rs_allocation*
75      slangAssert(PVT->isPointerType() &&
76          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78      // The rs object type passed to the FD
79      clang::QualType RST = PVT->getPointeeType();
80      RSExportPrimitiveType::DataType DT =
81          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
82      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
83             && "must be RS object type");
84
85      if (DT >= 0 && DT < RSExportPrimitiveType::DataTypeMax) {
86          RSObjectFD[DT] = FD;
87      } else {
88          slangAssert(false && "incorrect type");
89      }
90    }
91  }
92}
93
94namespace {
95
96// This function constructs a new CompoundStmt from the input StmtList.
97static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
98      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
99  unsigned NewStmtCount = StmtList.size();
100  unsigned CompoundStmtCount = 0;
101
102  clang::Stmt **CompoundStmtList;
103  CompoundStmtList = new clang::Stmt*[NewStmtCount];
104
105  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
106  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
107  for ( ; I != E; I++) {
108    CompoundStmtList[CompoundStmtCount++] = *I;
109  }
110  slangAssert(CompoundStmtCount == NewStmtCount);
111
112  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
113      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
114
115  delete [] CompoundStmtList;
116
117  return CS;
118}
119
120static void AppendAfterStmt(clang::ASTContext &C,
121                            clang::CompoundStmt *CS,
122                            clang::Stmt *S,
123                            std::list<clang::Stmt*> &StmtList) {
124  slangAssert(CS);
125  clang::CompoundStmt::body_iterator bI = CS->body_begin();
126  clang::CompoundStmt::body_iterator bE = CS->body_end();
127  clang::Stmt **UpdatedStmtList =
128      new clang::Stmt*[CS->size() + StmtList.size()];
129
130  unsigned UpdatedStmtCount = 0;
131  unsigned Once = 0;
132  for ( ; bI != bE; bI++) {
133    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
134      // If we come across a return here, we don't have anything we can
135      // reasonably replace. We should have already inserted our destructor
136      // code in the proper spot, so we just clean up and return.
137      delete [] UpdatedStmtList;
138
139      return;
140    }
141
142    UpdatedStmtList[UpdatedStmtCount++] = *bI;
143
144    if ((*bI == S) && !Once) {
145      Once++;
146      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
147      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
148      for ( ; I != E; I++) {
149        UpdatedStmtList[UpdatedStmtCount++] = *I;
150      }
151    }
152  }
153  slangAssert(Once <= 1);
154
155  // When S is NULL, we are appending to the end of the CompoundStmt.
156  if (!S) {
157    slangAssert(Once == 0);
158    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
159    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
160    for ( ; I != E; I++) {
161      UpdatedStmtList[UpdatedStmtCount++] = *I;
162    }
163  }
164
165  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
166
167  delete [] UpdatedStmtList;
168
169  return;
170}
171
172// This class visits a compound statement and inserts DtorStmt
173// in proper locations. This includes inserting it before any
174// return statement in any sub-block, at the end of the logical enclosing
175// scope (compound statement), and/or before any break/continue statement that
176// would resume outside the declared scope. We will not handle the case for
177// goto statements that leave a local scope.
178//
179// To accomplish these goals, it collects a list of sub-Stmt's that
180// correspond to scope exit points. It then uses an RSASTReplace visitor to
181// transform the AST, inserting appropriate destructors before each of those
182// sub-Stmt's (and also before the exit of the outermost containing Stmt for
183// the scope).
184class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
185 private:
186  clang::ASTContext &mCtx;
187
188  // The loop depth of the currently visited node.
189  int mLoopDepth;
190
191  // The switch statement depth of the currently visited node.
192  // Note that this is tracked separately from the loop depth because
193  // SwitchStmt-contained ContinueStmt's should have destructors for the
194  // corresponding loop scope.
195  int mSwitchDepth;
196
197  // The outermost statement block that we are currently visiting.
198  // This should always be a CompoundStmt.
199  clang::Stmt *mOuterStmt;
200
201  // The destructor to execute for this scope/variable.
202  clang::Stmt* mDtorStmt;
203
204  // The stack of statements which should be replaced by a compound statement
205  // containing the new destructor call followed by the original Stmt.
206  std::stack<clang::Stmt*> mReplaceStmtStack;
207
208  // The source location for the variable declaration that we are trying to
209  // insert destructors for. Note that InsertDestructors() will not generate
210  // destructor calls for source locations that occur lexically before this
211  // location.
212  clang::SourceLocation mVarLoc;
213
214 public:
215  DestructorVisitor(clang::ASTContext &C,
216                    clang::Stmt* OuterStmt,
217                    clang::Stmt* DtorStmt,
218                    clang::SourceLocation VarLoc);
219
220  // This code walks the collected list of Stmts to replace and actually does
221  // the replacement. It also finishes up by appending the destructor to the
222  // current outermost CompoundStmt.
223  void InsertDestructors() {
224    clang::Stmt *S = NULL;
225    clang::SourceManager &SM = mCtx.getSourceManager();
226    std::list<clang::Stmt *> StmtList;
227    StmtList.push_back(mDtorStmt);
228
229    while (!mReplaceStmtStack.empty()) {
230      S = mReplaceStmtStack.top();
231      mReplaceStmtStack.pop();
232
233      // Skip all source locations that occur before the variable's
234      // declaration, since it won't have been initialized yet.
235      if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
236        continue;
237      }
238
239      StmtList.push_back(S);
240      clang::CompoundStmt *CS =
241          BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
242      StmtList.pop_back();
243
244      RSASTReplace R(mCtx);
245      R.ReplaceStmt(mOuterStmt, S, CS);
246    }
247    clang::CompoundStmt *CS =
248      llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
249    slangAssert(CS);
250    AppendAfterStmt(mCtx, CS, NULL, StmtList);
251  }
252
253  void VisitStmt(clang::Stmt *S);
254  void VisitCompoundStmt(clang::CompoundStmt *CS);
255
256  void VisitBreakStmt(clang::BreakStmt *BS);
257  void VisitCaseStmt(clang::CaseStmt *CS);
258  void VisitContinueStmt(clang::ContinueStmt *CS);
259  void VisitDefaultStmt(clang::DefaultStmt *DS);
260  void VisitDoStmt(clang::DoStmt *DS);
261  void VisitForStmt(clang::ForStmt *FS);
262  void VisitIfStmt(clang::IfStmt *IS);
263  void VisitReturnStmt(clang::ReturnStmt *RS);
264  void VisitSwitchCase(clang::SwitchCase *SC);
265  void VisitSwitchStmt(clang::SwitchStmt *SS);
266  void VisitWhileStmt(clang::WhileStmt *WS);
267};
268
269DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
270                         clang::Stmt *OuterStmt,
271                         clang::Stmt *DtorStmt,
272                         clang::SourceLocation VarLoc)
273  : mCtx(C),
274    mLoopDepth(0),
275    mSwitchDepth(0),
276    mOuterStmt(OuterStmt),
277    mDtorStmt(DtorStmt),
278    mVarLoc(VarLoc) {
279  return;
280}
281
282void DestructorVisitor::VisitStmt(clang::Stmt *S) {
283  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
284       I != E;
285       I++) {
286    if (clang::Stmt *Child = *I) {
287      Visit(Child);
288    }
289  }
290  return;
291}
292
293void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
294  VisitStmt(CS);
295  return;
296}
297
298void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
299  VisitStmt(BS);
300  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
301    mReplaceStmtStack.push(BS);
302  }
303  return;
304}
305
306void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
307  VisitStmt(CS);
308  return;
309}
310
311void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
312  VisitStmt(CS);
313  if (mLoopDepth == 0) {
314    // Switch statements can have nested continues.
315    mReplaceStmtStack.push(CS);
316  }
317  return;
318}
319
320void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
321  VisitStmt(DS);
322  return;
323}
324
325void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
326  mLoopDepth++;
327  VisitStmt(DS);
328  mLoopDepth--;
329  return;
330}
331
332void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
333  mLoopDepth++;
334  VisitStmt(FS);
335  mLoopDepth--;
336  return;
337}
338
339void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
340  VisitStmt(IS);
341  return;
342}
343
344void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
345  mReplaceStmtStack.push(RS);
346  return;
347}
348
349void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
350  slangAssert(false && "Both case and default have specialized handlers");
351  VisitStmt(SC);
352  return;
353}
354
355void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
356  mSwitchDepth++;
357  VisitStmt(SS);
358  mSwitchDepth--;
359  return;
360}
361
362void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
363  mLoopDepth++;
364  VisitStmt(WS);
365  mLoopDepth--;
366  return;
367}
368
369clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
370                                 clang::Expr *RefRSVar,
371                                 clang::SourceLocation Loc) {
372  slangAssert(RefRSVar);
373  const clang::Type *T = RefRSVar->getType().getTypePtr();
374  slangAssert(!T->isArrayType() &&
375              "Should not be destroying arrays with this function");
376
377  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
378  slangAssert((ClearObjectFD != NULL) &&
379              "rsClearObject doesn't cover all RS object types");
380
381  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
382  clang::QualType ClearObjectFDArgType =
383      ClearObjectFD->getParamDecl(0)->getOriginalType();
384
385  // Example destructor for "rs_font localFont;"
386  //
387  // (CallExpr 'void'
388  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
389  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
390  //   (UnaryOperator 'rs_font *' prefix '&'
391  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
392
393  // Get address of targeted RS object
394  clang::Expr *AddrRefRSVar =
395      new(C) clang::UnaryOperator(RefRSVar,
396                                  clang::UO_AddrOf,
397                                  ClearObjectFDArgType,
398                                  clang::VK_RValue,
399                                  clang::OK_Ordinary,
400                                  Loc);
401
402  clang::Expr *RefRSClearObjectFD =
403      clang::DeclRefExpr::Create(C,
404                                 clang::NestedNameSpecifierLoc(),
405                                 clang::SourceLocation(),
406                                 ClearObjectFD,
407                                 false,
408                                 ClearObjectFD->getLocation(),
409                                 ClearObjectFDType,
410                                 clang::VK_RValue,
411                                 NULL);
412
413  clang::Expr *RSClearObjectFP =
414      clang::ImplicitCastExpr::Create(C,
415                                      C.getPointerType(ClearObjectFDType),
416                                      clang::CK_FunctionToPointerDecay,
417                                      RefRSClearObjectFD,
418                                      NULL,
419                                      clang::VK_RValue);
420
421  llvm::SmallVector<clang::Expr*, 1> ArgList;
422  ArgList.push_back(AddrRefRSVar);
423
424  clang::CallExpr *RSClearObjectCall =
425      new(C) clang::CallExpr(C,
426                             RSClearObjectFP,
427                             ArgList,
428                             ClearObjectFD->getCallResultType(),
429                             clang::VK_RValue,
430                             Loc);
431
432  return RSClearObjectCall;
433}
434
435static int ArrayDim(const clang::Type *T) {
436  if (!T || !T->isArrayType()) {
437    return 0;
438  }
439
440  const clang::ConstantArrayType *CAT =
441    static_cast<const clang::ConstantArrayType *>(T);
442  return static_cast<int>(CAT->getSize().getSExtValue());
443}
444
// Forward declaration. ClearArrayRSObject() below calls this for arrays whose
// elements are structs, and the definition further down in turn calls
// ClearArrayRSObject() for array fields, so one of the two has to be declared
// ahead of its definition.
static clang::Stmt *ClearStructRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSStruct,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc);
451
452static clang::Stmt *ClearArrayRSObject(
453    clang::ASTContext &C,
454    clang::DeclContext *DC,
455    clang::Expr *RefRSArr,
456    clang::SourceLocation StartLoc,
457    clang::SourceLocation Loc) {
458  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
459  slangAssert(BaseType->isArrayType());
460
461  int NumArrayElements = ArrayDim(BaseType);
462  // Actually extract out the base RS object type for use later
463  BaseType = BaseType->getArrayElementTypeNoTypeQual();
464
465  clang::Stmt *StmtArray[2] = {NULL};
466  int StmtCtr = 0;
467
468  if (NumArrayElements <= 0) {
469    return NULL;
470  }
471
472  // Example destructor loop for "rs_font fontArr[10];"
473  //
474  // (CompoundStmt
475  //   (DeclStmt "int rsIntIter")
476  //   (ForStmt
477  //     (BinaryOperator 'int' '='
478  //       (DeclRefExpr 'int' Var='rsIntIter')
479  //       (IntegerLiteral 'int' 0))
480  //     (BinaryOperator 'int' '<'
481  //       (DeclRefExpr 'int' Var='rsIntIter')
482  //       (IntegerLiteral 'int' 10)
483  //     NULL << CondVar >>
484  //     (UnaryOperator 'int' postfix '++'
485  //       (DeclRefExpr 'int' Var='rsIntIter'))
486  //     (CallExpr 'void'
487  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
488  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
489  //       (UnaryOperator 'rs_font *' prefix '&'
490  //         (ArraySubscriptExpr 'rs_font':'rs_font'
491  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
492  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
493  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
494
495  // Create helper variable for iterating through elements
496  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
497  clang::VarDecl *IIVD =
498      clang::VarDecl::Create(C,
499                             DC,
500                             StartLoc,
501                             Loc,
502                             &II,
503                             C.IntTy,
504                             C.getTrivialTypeSourceInfo(C.IntTy),
505                             clang::SC_None);
506  clang::Decl *IID = (clang::Decl *)IIVD;
507
508  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
509  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
510
511  // Form the actual destructor loop
512  // for (Init; Cond; Inc)
513  //   RSClearObjectCall;
514
515  // Init -> "rsIntIter = 0"
516  clang::DeclRefExpr *RefrsIntIter =
517      clang::DeclRefExpr::Create(C,
518                                 clang::NestedNameSpecifierLoc(),
519                                 clang::SourceLocation(),
520                                 IIVD,
521                                 false,
522                                 Loc,
523                                 C.IntTy,
524                                 clang::VK_RValue,
525                                 NULL);
526
527  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
528      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
529
530  clang::BinaryOperator *Init =
531      new(C) clang::BinaryOperator(RefrsIntIter,
532                                   Int0,
533                                   clang::BO_Assign,
534                                   C.IntTy,
535                                   clang::VK_RValue,
536                                   clang::OK_Ordinary,
537                                   Loc,
538                                   false);
539
540  // Cond -> "rsIntIter < NumArrayElements"
541  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
542      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
543
544  clang::BinaryOperator *Cond =
545      new(C) clang::BinaryOperator(RefrsIntIter,
546                                   NumArrayElementsExpr,
547                                   clang::BO_LT,
548                                   C.IntTy,
549                                   clang::VK_RValue,
550                                   clang::OK_Ordinary,
551                                   Loc,
552                                   false);
553
554  // Inc -> "rsIntIter++"
555  clang::UnaryOperator *Inc =
556      new(C) clang::UnaryOperator(RefrsIntIter,
557                                  clang::UO_PostInc,
558                                  C.IntTy,
559                                  clang::VK_RValue,
560                                  clang::OK_Ordinary,
561                                  Loc);
562
563  // Body -> "rsClearObject(&VD[rsIntIter]);"
564  // Destructor loop operates on individual array elements
565
566  clang::Expr *RefRSArrPtr =
567      clang::ImplicitCastExpr::Create(C,
568          C.getPointerType(BaseType->getCanonicalTypeInternal()),
569          clang::CK_ArrayToPointerDecay,
570          RefRSArr,
571          NULL,
572          clang::VK_RValue);
573
574  clang::Expr *RefRSArrPtrSubscript =
575      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
576                                       RefrsIntIter,
577                                       BaseType->getCanonicalTypeInternal(),
578                                       clang::VK_RValue,
579                                       clang::OK_Ordinary,
580                                       Loc);
581
582  RSExportPrimitiveType::DataType DT =
583      RSExportPrimitiveType::GetRSSpecificType(BaseType);
584
585  clang::Stmt *RSClearObjectCall = NULL;
586  if (BaseType->isArrayType()) {
587    RSClearObjectCall =
588        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
589  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
590    RSClearObjectCall =
591        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
592  } else {
593    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
594  }
595
596  clang::ForStmt *DestructorLoop =
597      new(C) clang::ForStmt(C,
598                            Init,
599                            Cond,
600                            NULL,  // no condVar
601                            Inc,
602                            RSClearObjectCall,
603                            Loc,
604                            Loc,
605                            Loc);
606
607  StmtArray[StmtCtr++] = DestructorLoop;
608  slangAssert(StmtCtr == 2);
609
610  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
611      C, llvm::makeArrayRef(StmtArray, StmtCtr), Loc, Loc);
612
613  return CS;
614}
615
616static unsigned CountRSObjectTypes(clang::ASTContext &C,
617                                   const clang::Type *T,
618                                   clang::SourceLocation Loc) {
619  slangAssert(T);
620  unsigned RSObjectCount = 0;
621
622  if (T->isArrayType()) {
623    return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
624  }
625
626  RSExportPrimitiveType::DataType DT =
627      RSExportPrimitiveType::GetRSSpecificType(T);
628  if (DT != RSExportPrimitiveType::DataTypeUnknown) {
629    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
630  }
631
632  if (T->isUnionType()) {
633    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
634    RD = RD->getDefinition();
635    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
636           FE = RD->field_end();
637         FI != FE;
638         FI++) {
639      const clang::FieldDecl *FD = *FI;
640      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
641      if (CountRSObjectTypes(C, FT, Loc)) {
642        slangAssert(false && "can't have unions with RS object types!");
643        return 0;
644      }
645    }
646  }
647
648  if (!T->isStructureType()) {
649    return 0;
650  }
651
652  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
653  RD = RD->getDefinition();
654  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
655         FE = RD->field_end();
656       FI != FE;
657       FI++) {
658    const clang::FieldDecl *FD = *FI;
659    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
660    if (CountRSObjectTypes(C, FT, Loc)) {
661      // Sub-structs should only count once (as should arrays, etc.)
662      RSObjectCount++;
663    }
664  }
665
666  return RSObjectCount;
667}
668
669static clang::Stmt *ClearStructRSObject(
670    clang::ASTContext &C,
671    clang::DeclContext *DC,
672    clang::Expr *RefRSStruct,
673    clang::SourceLocation StartLoc,
674    clang::SourceLocation Loc) {
675  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
676
677  slangAssert(!BaseType->isArrayType());
678
679  // Structs should show up as unknown primitive types
680  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
681              RSExportPrimitiveType::DataTypeUnknown);
682
683  unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
684  slangAssert(FieldsToDestroy != 0);
685
686  unsigned StmtCount = 0;
687  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
688  for (unsigned i = 0; i < FieldsToDestroy; i++) {
689    StmtArray[i] = NULL;
690  }
691
692  // Populate StmtArray by creating a destructor for each RS object field
693  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
694  RD = RD->getDefinition();
695  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
696         FE = RD->field_end();
697       FI != FE;
698       FI++) {
699    // We just look through all field declarations to see if we find a
700    // declaration for an RS object type (or an array of one).
701    bool IsArrayType = false;
702    clang::FieldDecl *FD = *FI;
703    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
704    const clang::Type *OrigType = FT;
705    while (FT && FT->isArrayType()) {
706      FT = FT->getArrayElementTypeNoTypeQual();
707      IsArrayType = true;
708    }
709
710    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
711      clang::DeclAccessPair FoundDecl =
712          clang::DeclAccessPair::make(FD, clang::AS_none);
713      clang::MemberExpr *RSObjectMember =
714          clang::MemberExpr::Create(C,
715                                    RefRSStruct,
716                                    false,
717                                    clang::NestedNameSpecifierLoc(),
718                                    clang::SourceLocation(),
719                                    FD,
720                                    FoundDecl,
721                                    clang::DeclarationNameInfo(),
722                                    NULL,
723                                    OrigType->getCanonicalTypeInternal(),
724                                    clang::VK_RValue,
725                                    clang::OK_Ordinary);
726
727      slangAssert(StmtCount < FieldsToDestroy);
728
729      if (IsArrayType) {
730        StmtArray[StmtCount++] = ClearArrayRSObject(C,
731                                                    DC,
732                                                    RSObjectMember,
733                                                    StartLoc,
734                                                    Loc);
735      } else {
736        StmtArray[StmtCount++] = ClearSingleRSObject(C,
737                                                     RSObjectMember,
738                                                     Loc);
739      }
740    } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
741      // In this case, we have a nested struct. We may not end up filling all
742      // of the spaces in StmtArray (sub-structs should handle themselves
743      // with separate compound statements).
744      clang::DeclAccessPair FoundDecl =
745          clang::DeclAccessPair::make(FD, clang::AS_none);
746      clang::MemberExpr *RSObjectMember =
747          clang::MemberExpr::Create(C,
748                                    RefRSStruct,
749                                    false,
750                                    clang::NestedNameSpecifierLoc(),
751                                    clang::SourceLocation(),
752                                    FD,
753                                    FoundDecl,
754                                    clang::DeclarationNameInfo(),
755                                    NULL,
756                                    OrigType->getCanonicalTypeInternal(),
757                                    clang::VK_RValue,
758                                    clang::OK_Ordinary);
759
760      if (IsArrayType) {
761        StmtArray[StmtCount++] = ClearArrayRSObject(C,
762                                                    DC,
763                                                    RSObjectMember,
764                                                    StartLoc,
765                                                    Loc);
766      } else {
767        StmtArray[StmtCount++] = ClearStructRSObject(C,
768                                                     DC,
769                                                     RSObjectMember,
770                                                     StartLoc,
771                                                     Loc);
772      }
773    }
774  }
775
776  slangAssert(StmtCount > 0);
777  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
778      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
779
780  delete [] StmtArray;
781
782  return CS;
783}
784
785static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
786                                            clang::Expr *DstExpr,
787                                            clang::Expr *SrcExpr,
788                                            clang::SourceLocation StartLoc,
789                                            clang::SourceLocation Loc) {
790  const clang::Type *T = DstExpr->getType().getTypePtr();
791  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
792  slangAssert((SetObjectFD != NULL) &&
793              "rsSetObject doesn't cover all RS object types");
794
795  clang::QualType SetObjectFDType = SetObjectFD->getType();
796  clang::QualType SetObjectFDArgType[2];
797  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
798  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
799
800  clang::Expr *RefRSSetObjectFD =
801      clang::DeclRefExpr::Create(C,
802                                 clang::NestedNameSpecifierLoc(),
803                                 clang::SourceLocation(),
804                                 SetObjectFD,
805                                 false,
806                                 Loc,
807                                 SetObjectFDType,
808                                 clang::VK_RValue,
809                                 NULL);
810
811  clang::Expr *RSSetObjectFP =
812      clang::ImplicitCastExpr::Create(C,
813                                      C.getPointerType(SetObjectFDType),
814                                      clang::CK_FunctionToPointerDecay,
815                                      RefRSSetObjectFD,
816                                      NULL,
817                                      clang::VK_RValue);
818
819  llvm::SmallVector<clang::Expr*, 2> ArgList;
820  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
821                                                clang::UO_AddrOf,
822                                                SetObjectFDArgType[0],
823                                                clang::VK_RValue,
824                                                clang::OK_Ordinary,
825                                                Loc));
826  ArgList.push_back(SrcExpr);
827
828  clang::CallExpr *RSSetObjectCall =
829      new(C) clang::CallExpr(C,
830                             RSSetObjectFP,
831                             ArgList,
832                             SetObjectFD->getCallResultType(),
833                             clang::VK_RValue,
834                             Loc);
835
836  return RSSetObjectCall;
837}
838
839static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
840                                            clang::Expr *LHS,
841                                            clang::Expr *RHS,
842                                            clang::SourceLocation StartLoc,
843                                            clang::SourceLocation Loc);
844
845/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
846                                           clang::Expr *DstArr,
847                                           clang::Expr *SrcArr,
848                                           clang::SourceLocation StartLoc,
849                                           clang::SourceLocation Loc) {
850  clang::DeclContext *DC = NULL;
851  const clang::Type *BaseType = DstArr->getType().getTypePtr();
852  slangAssert(BaseType->isArrayType());
853
854  int NumArrayElements = ArrayDim(BaseType);
855  // Actually extract out the base RS object type for use later
856  BaseType = BaseType->getArrayElementTypeNoTypeQual();
857
858  clang::Stmt *StmtArray[2] = {NULL};
859  int StmtCtr = 0;
860
861  if (NumArrayElements <= 0) {
862    return NULL;
863  }
864
865  // Create helper variable for iterating through elements
866  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
867  clang::VarDecl *IIVD =
868      clang::VarDecl::Create(C,
869                             DC,
870                             StartLoc,
871                             Loc,
872                             &II,
873                             C.IntTy,
874                             C.getTrivialTypeSourceInfo(C.IntTy),
875                             clang::SC_None,
876                             clang::SC_None);
877  clang::Decl *IID = (clang::Decl *)IIVD;
878
879  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
880  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
881
882  // Form the actual loop
883  // for (Init; Cond; Inc)
884  //   RSSetObjectCall;
885
886  // Init -> "rsIntIter = 0"
887  clang::DeclRefExpr *RefrsIntIter =
888      clang::DeclRefExpr::Create(C,
889                                 clang::NestedNameSpecifierLoc(),
890                                 IIVD,
891                                 Loc,
892                                 C.IntTy,
893                                 clang::VK_RValue,
894                                 NULL);
895
896  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
897      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
898
899  clang::BinaryOperator *Init =
900      new(C) clang::BinaryOperator(RefrsIntIter,
901                                   Int0,
902                                   clang::BO_Assign,
903                                   C.IntTy,
904                                   clang::VK_RValue,
905                                   clang::OK_Ordinary,
906                                   Loc);
907
908  // Cond -> "rsIntIter < NumArrayElements"
909  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
910      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
911
912  clang::BinaryOperator *Cond =
913      new(C) clang::BinaryOperator(RefrsIntIter,
914                                   NumArrayElementsExpr,
915                                   clang::BO_LT,
916                                   C.IntTy,
917                                   clang::VK_RValue,
918                                   clang::OK_Ordinary,
919                                   Loc);
920
921  // Inc -> "rsIntIter++"
922  clang::UnaryOperator *Inc =
923      new(C) clang::UnaryOperator(RefrsIntIter,
924                                  clang::UO_PostInc,
925                                  C.IntTy,
926                                  clang::VK_RValue,
927                                  clang::OK_Ordinary,
928                                  Loc);
929
930  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
931  // Loop operates on individual array elements
932
933  clang::Expr *DstArrPtr =
934      clang::ImplicitCastExpr::Create(C,
935          C.getPointerType(BaseType->getCanonicalTypeInternal()),
936          clang::CK_ArrayToPointerDecay,
937          DstArr,
938          NULL,
939          clang::VK_RValue);
940
941  clang::Expr *DstArrPtrSubscript =
942      new(C) clang::ArraySubscriptExpr(DstArrPtr,
943                                       RefrsIntIter,
944                                       BaseType->getCanonicalTypeInternal(),
945                                       clang::VK_RValue,
946                                       clang::OK_Ordinary,
947                                       Loc);
948
949  clang::Expr *SrcArrPtr =
950      clang::ImplicitCastExpr::Create(C,
951          C.getPointerType(BaseType->getCanonicalTypeInternal()),
952          clang::CK_ArrayToPointerDecay,
953          SrcArr,
954          NULL,
955          clang::VK_RValue);
956
957  clang::Expr *SrcArrPtrSubscript =
958      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
959                                       RefrsIntIter,
960                                       BaseType->getCanonicalTypeInternal(),
961                                       clang::VK_RValue,
962                                       clang::OK_Ordinary,
963                                       Loc);
964
965  RSExportPrimitiveType::DataType DT =
966      RSExportPrimitiveType::GetRSSpecificType(BaseType);
967
968  clang::Stmt *RSSetObjectCall = NULL;
969  if (BaseType->isArrayType()) {
970    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
971                                             SrcArrPtrSubscript,
972                                             StartLoc, Loc);
973  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
974    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
975                                              SrcArrPtrSubscript,
976                                              StartLoc, Loc);
977  } else {
978    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
979                                              SrcArrPtrSubscript,
980                                              StartLoc, Loc);
981  }
982
983  clang::ForStmt *DestructorLoop =
984      new(C) clang::ForStmt(C,
985                            Init,
986                            Cond,
987                            NULL,  // no condVar
988                            Inc,
989                            RSSetObjectCall,
990                            Loc,
991                            Loc,
992                            Loc);
993
994  StmtArray[StmtCtr++] = DestructorLoop;
995  slangAssert(StmtCtr == 2);
996
997  clang::CompoundStmt *CS =
998      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
999
1000  return CS;
1001} */
1002
1003static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
1004                                            clang::Expr *LHS,
1005                                            clang::Expr *RHS,
1006                                            clang::SourceLocation StartLoc,
1007                                            clang::SourceLocation Loc) {
1008  clang::QualType QT = LHS->getType();
1009  const clang::Type *T = QT.getTypePtr();
1010  slangAssert(T->isStructureType());
1011  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1012
1013  // Keep an extra slot for the original copy (memcpy)
1014  unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
1015
1016  unsigned StmtCount = 0;
1017  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1018  for (unsigned i = 0; i < FieldsToSet; i++) {
1019    StmtArray[i] = NULL;
1020  }
1021
1022  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1023  RD = RD->getDefinition();
1024  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1025         FE = RD->field_end();
1026       FI != FE;
1027       FI++) {
1028    bool IsArrayType = false;
1029    clang::FieldDecl *FD = *FI;
1030    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1031    const clang::Type *OrigType = FT;
1032
1033    if (!CountRSObjectTypes(C, FT, Loc)) {
1034      // Skip to next if we don't have any viable RS object types
1035      continue;
1036    }
1037
1038    clang::DeclAccessPair FoundDecl =
1039        clang::DeclAccessPair::make(FD, clang::AS_none);
1040    clang::MemberExpr *DstMember =
1041        clang::MemberExpr::Create(C,
1042                                  LHS,
1043                                  false,
1044                                  clang::NestedNameSpecifierLoc(),
1045                                  clang::SourceLocation(),
1046                                  FD,
1047                                  FoundDecl,
1048                                  clang::DeclarationNameInfo(),
1049                                  NULL,
1050                                  OrigType->getCanonicalTypeInternal(),
1051                                  clang::VK_RValue,
1052                                  clang::OK_Ordinary);
1053
1054    clang::MemberExpr *SrcMember =
1055        clang::MemberExpr::Create(C,
1056                                  RHS,
1057                                  false,
1058                                  clang::NestedNameSpecifierLoc(),
1059                                  clang::SourceLocation(),
1060                                  FD,
1061                                  FoundDecl,
1062                                  clang::DeclarationNameInfo(),
1063                                  NULL,
1064                                  OrigType->getCanonicalTypeInternal(),
1065                                  clang::VK_RValue,
1066                                  clang::OK_Ordinary);
1067
1068    if (FT->isArrayType()) {
1069      FT = FT->getArrayElementTypeNoTypeQual();
1070      IsArrayType = true;
1071    }
1072
1073    RSExportPrimitiveType::DataType DT =
1074        RSExportPrimitiveType::GetRSSpecificType(FT);
1075
1076    if (IsArrayType) {
1077      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1078      DiagEngine.Report(
1079        clang::FullSourceLoc(Loc, C.getSourceManager()),
1080        DiagEngine.getCustomDiagID(
1081          clang::DiagnosticsEngine::Error,
1082          "Arrays of RS object types within structures cannot be copied"));
1083      // TODO(srhines): Support setting arrays of RS objects
1084      // StmtArray[StmtCount++] =
1085      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1086    } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
1087      StmtArray[StmtCount++] =
1088          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1089    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1090      StmtArray[StmtCount++] =
1091          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1092    } else {
1093      slangAssert(false);
1094    }
1095  }
1096
1097  slangAssert(StmtCount < FieldsToSet);
1098
1099  // We still need to actually do the overall struct copy. For simplicity,
1100  // we just do a straight-up assignment (which will still preserve all
1101  // the proper RS object reference counts).
1102  clang::BinaryOperator *CopyStruct =
1103      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1104                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1105                                   false);
1106  StmtArray[StmtCount++] = CopyStruct;
1107
1108  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1109      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1110
1111  delete [] StmtArray;
1112
1113  return CS;
1114}
1115
1116}  // namespace
1117
1118void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1119    clang::BinaryOperator *AS) {
1120
1121  clang::QualType QT = AS->getType();
1122
1123  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1124      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1125
1126  clang::SourceLocation Loc = AS->getExprLoc();
1127  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1128  clang::Stmt *UpdatedStmt = NULL;
1129
1130  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1131    // By definition, this is a struct assignment if we get here
1132    UpdatedStmt =
1133        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1134  } else {
1135    UpdatedStmt =
1136        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1137  }
1138
1139  RSASTReplace R(C);
1140  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1141  return;
1142}
1143
1144void RSObjectRefCount::Scope::AppendRSObjectInit(
1145    clang::VarDecl *VD,
1146    clang::DeclStmt *DS,
1147    RSExportPrimitiveType::DataType DT,
1148    clang::Expr *InitExpr) {
1149  slangAssert(VD);
1150
1151  if (!InitExpr) {
1152    return;
1153  }
1154
1155  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1156      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1157  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1158      RSExportPrimitiveType::DataTypeRSFont)->getLocation();
1159  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1160      RSExportPrimitiveType::DataTypeRSFont)->getInnerLocStart();
1161
1162  if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
1163    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1164    clang::DeclRefExpr *RefRSVar =
1165        clang::DeclRefExpr::Create(C,
1166                                   clang::NestedNameSpecifierLoc(),
1167                                   clang::SourceLocation(),
1168                                   VD,
1169                                   false,
1170                                   Loc,
1171                                   T->getCanonicalTypeInternal(),
1172                                   clang::VK_RValue,
1173                                   NULL);
1174
1175    clang::Stmt *RSSetObjectOps =
1176        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1177
1178    std::list<clang::Stmt*> StmtList;
1179    StmtList.push_back(RSSetObjectOps);
1180    AppendAfterStmt(C, mCS, DS, StmtList);
1181    return;
1182  }
1183
1184  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1185  slangAssert((SetObjectFD != NULL) &&
1186              "rsSetObject doesn't cover all RS object types");
1187
1188  clang::QualType SetObjectFDType = SetObjectFD->getType();
1189  clang::QualType SetObjectFDArgType[2];
1190  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1191  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1192
1193  clang::Expr *RefRSSetObjectFD =
1194      clang::DeclRefExpr::Create(C,
1195                                 clang::NestedNameSpecifierLoc(),
1196                                 clang::SourceLocation(),
1197                                 SetObjectFD,
1198                                 false,
1199                                 Loc,
1200                                 SetObjectFDType,
1201                                 clang::VK_RValue,
1202                                 NULL);
1203
1204  clang::Expr *RSSetObjectFP =
1205      clang::ImplicitCastExpr::Create(C,
1206                                      C.getPointerType(SetObjectFDType),
1207                                      clang::CK_FunctionToPointerDecay,
1208                                      RefRSSetObjectFD,
1209                                      NULL,
1210                                      clang::VK_RValue);
1211
1212  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1213  clang::DeclRefExpr *RefRSVar =
1214      clang::DeclRefExpr::Create(C,
1215                                 clang::NestedNameSpecifierLoc(),
1216                                 clang::SourceLocation(),
1217                                 VD,
1218                                 false,
1219                                 Loc,
1220                                 T->getCanonicalTypeInternal(),
1221                                 clang::VK_RValue,
1222                                 NULL);
1223
1224  llvm::SmallVector<clang::Expr*, 2> ArgList;
1225  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1226                                                clang::UO_AddrOf,
1227                                                SetObjectFDArgType[0],
1228                                                clang::VK_RValue,
1229                                                clang::OK_Ordinary,
1230                                                Loc));
1231  ArgList.push_back(InitExpr);
1232
1233  clang::CallExpr *RSSetObjectCall =
1234      new(C) clang::CallExpr(C,
1235                             RSSetObjectFP,
1236                             ArgList,
1237                             SetObjectFD->getCallResultType(),
1238                             clang::VK_RValue,
1239                             Loc);
1240
1241  std::list<clang::Stmt*> StmtList;
1242  StmtList.push_back(RSSetObjectCall);
1243  AppendAfterStmt(C, mCS, DS, StmtList);
1244
1245  return;
1246}
1247
1248void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1249  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1250          E = mRSO.end();
1251        I != E;
1252        I++) {
1253    clang::VarDecl *VD = *I;
1254    clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1255    if (RSClearObjectCall) {
1256      DestructorVisitor DV((*mRSO.begin())->getASTContext(),
1257                           mCS,
1258                           RSClearObjectCall,
1259                           VD->getSourceRange().getBegin());
1260      DV.Visit(mCS);
1261      DV.InsertDestructors();
1262    }
1263  }
1264  return;
1265}
1266
1267clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1268    clang::VarDecl *VD,
1269    clang::DeclContext *DC) {
1270  slangAssert(VD);
1271  clang::ASTContext &C = VD->getASTContext();
1272  clang::SourceLocation Loc = VD->getLocation();
1273  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1274  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1275
1276  // Reference expr to target RS object variable
1277  clang::DeclRefExpr *RefRSVar =
1278      clang::DeclRefExpr::Create(C,
1279                                 clang::NestedNameSpecifierLoc(),
1280                                 clang::SourceLocation(),
1281                                 VD,
1282                                 false,
1283                                 Loc,
1284                                 T->getCanonicalTypeInternal(),
1285                                 clang::VK_RValue,
1286                                 NULL);
1287
1288  if (T->isArrayType()) {
1289    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1290  }
1291
1292  RSExportPrimitiveType::DataType DT =
1293      RSExportPrimitiveType::GetRSSpecificType(T);
1294
1295  if (DT == RSExportPrimitiveType::DataTypeUnknown ||
1296      DT == RSExportPrimitiveType::DataTypeIsStruct) {
1297    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1298  }
1299
1300  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1301              "Should be RS object");
1302
1303  return ClearSingleRSObject(C, RefRSVar, Loc);
1304}
1305
1306bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1307                                          RSExportPrimitiveType::DataType *DT,
1308                                          clang::Expr **InitExpr) {
1309  slangAssert(VD && DT && InitExpr);
1310  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1311
1312  // Loop through array types to get to base type
1313  while (T && T->isArrayType()) {
1314    T = T->getArrayElementTypeNoTypeQual();
1315  }
1316
1317  bool DataTypeIsStructWithRSObject = false;
1318  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1319
1320  if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
1321    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1322      *DT = RSExportPrimitiveType::DataTypeIsStruct;
1323      DataTypeIsStructWithRSObject = true;
1324    } else {
1325      return false;
1326    }
1327  }
1328
1329  bool DataTypeIsRSObject = false;
1330  if (DataTypeIsStructWithRSObject) {
1331    DataTypeIsRSObject = true;
1332  } else {
1333    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1334  }
1335  *InitExpr = VD->getInit();
1336
1337  if (!DataTypeIsRSObject && *InitExpr) {
1338    // If we already have an initializer for a matrix type, we are done.
1339    return DataTypeIsRSObject;
1340  }
1341
1342  clang::Expr *ZeroInitializer =
1343      CreateZeroInitializerForRSSpecificType(*DT,
1344                                             VD->getASTContext(),
1345                                             VD->getLocation());
1346
1347  if (ZeroInitializer) {
1348    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1349    VD->setInit(ZeroInitializer);
1350  }
1351
1352  return DataTypeIsRSObject;
1353}
1354
1355clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1356    RSExportPrimitiveType::DataType DT,
1357    clang::ASTContext &C,
1358    const clang::SourceLocation &Loc) {
1359  clang::Expr *Res = NULL;
1360  switch (DT) {
1361    case RSExportPrimitiveType::DataTypeIsStruct:
1362    case RSExportPrimitiveType::DataTypeRSElement:
1363    case RSExportPrimitiveType::DataTypeRSType:
1364    case RSExportPrimitiveType::DataTypeRSAllocation:
1365    case RSExportPrimitiveType::DataTypeRSSampler:
1366    case RSExportPrimitiveType::DataTypeRSScript:
1367    case RSExportPrimitiveType::DataTypeRSMesh:
1368    case RSExportPrimitiveType::DataTypeRSPath:
1369    case RSExportPrimitiveType::DataTypeRSProgramFragment:
1370    case RSExportPrimitiveType::DataTypeRSProgramVertex:
1371    case RSExportPrimitiveType::DataTypeRSProgramRaster:
1372    case RSExportPrimitiveType::DataTypeRSProgramStore:
1373    case RSExportPrimitiveType::DataTypeRSFont: {
1374      //    (ImplicitCastExpr 'nullptr_t'
1375      //      (IntegerLiteral 0)))
1376      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1377      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1378      clang::Expr *CastToNull =
1379          clang::ImplicitCastExpr::Create(C,
1380                                          C.NullPtrTy,
1381                                          clang::CK_IntegralToPointer,
1382                                          Int0,
1383                                          NULL,
1384                                          clang::VK_RValue);
1385
1386      llvm::SmallVector<clang::Expr*, 1>InitList;
1387      InitList.push_back(CastToNull);
1388
1389      Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
1390      break;
1391    }
1392    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
1393    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
1394    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
1395      // RS matrix is not completely an RS object. They hold data by themselves.
1396      // (InitListExpr rs_matrix2x2
1397      //   (InitListExpr float[4]
1398      //     (FloatingLiteral 0)
1399      //     (FloatingLiteral 0)
1400      //     (FloatingLiteral 0)
1401      //     (FloatingLiteral 0)))
1402      clang::QualType FloatTy = C.FloatTy;
1403      // Constructor sets value to 0.0f by default
1404      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1405      clang::FloatingLiteral *Float0Val =
1406          clang::FloatingLiteral::Create(C,
1407                                         Val,
1408                                         /* isExact = */true,
1409                                         FloatTy,
1410                                         Loc);
1411
1412      unsigned N = 0;
1413      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
1414        N = 2;
1415      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
1416        N = 3;
1417      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
1418        N = 4;
1419      unsigned N_2 = N * N;
1420
1421      // Assume we are going to be allocating 16 elements, since 4x4 is max.
1422      llvm::SmallVector<clang::Expr*, 16> InitVals;
1423      for (unsigned i = 0; i < N_2; i++)
1424        InitVals.push_back(Float0Val);
1425      clang::Expr *InitExpr =
1426          new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
1427      InitExpr->setType(C.getConstantArrayType(FloatTy,
1428                                               llvm::APInt(32, N_2),
1429                                               clang::ArrayType::Normal,
1430                                               /* EltTypeQuals = */0));
1431      llvm::SmallVector<clang::Expr*, 1> InitExprVec;
1432      InitExprVec.push_back(InitExpr);
1433
1434      Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
1435      break;
1436    }
1437    case RSExportPrimitiveType::DataTypeUnknown:
1438    case RSExportPrimitiveType::DataTypeFloat16:
1439    case RSExportPrimitiveType::DataTypeFloat32:
1440    case RSExportPrimitiveType::DataTypeFloat64:
1441    case RSExportPrimitiveType::DataTypeSigned8:
1442    case RSExportPrimitiveType::DataTypeSigned16:
1443    case RSExportPrimitiveType::DataTypeSigned32:
1444    case RSExportPrimitiveType::DataTypeSigned64:
1445    case RSExportPrimitiveType::DataTypeUnsigned8:
1446    case RSExportPrimitiveType::DataTypeUnsigned16:
1447    case RSExportPrimitiveType::DataTypeUnsigned32:
1448    case RSExportPrimitiveType::DataTypeUnsigned64:
1449    case RSExportPrimitiveType::DataTypeBoolean:
1450    case RSExportPrimitiveType::DataTypeUnsigned565:
1451    case RSExportPrimitiveType::DataTypeUnsigned5551:
1452    case RSExportPrimitiveType::DataTypeUnsigned4444:
1453    case RSExportPrimitiveType::DataTypeMax: {
1454      slangAssert(false && "Not RS object type!");
1455    }
1456    // No default case will enable compiler detecting the missing cases
1457  }
1458
1459  return Res;
1460}
1461
1462void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1463  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1464       I != E;
1465       I++) {
1466    clang::Decl *D = *I;
1467    if (D->getKind() == clang::Decl::Var) {
1468      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1469      RSExportPrimitiveType::DataType DT =
1470          RSExportPrimitiveType::DataTypeUnknown;
1471      clang::Expr *InitExpr = NULL;
1472      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1473        // We need to zero-init all RS object types (including matrices), ...
1474        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1475        // ... but, only add to the list of RS objects if we have some
1476        // non-matrix RS object fields.
1477        if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(),
1478                               VD->getLocation())) {
1479          getCurrentScope()->addRSObject(VD);
1480        }
1481      }
1482    }
1483  }
1484  return;
1485}
1486
1487void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1488  if (!CS->body_empty()) {
1489    // Push a new scope
1490    Scope *S = new Scope(CS);
1491    mScopeStack.push(S);
1492
1493    VisitStmt(CS);
1494
1495    // Destroy the scope
1496    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1497    S->InsertLocalVarDestructors();
1498    mScopeStack.pop();
1499    delete S;
1500  }
1501  return;
1502}
1503
1504void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1505  clang::QualType QT = AS->getType();
1506
1507  if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1508    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1509  }
1510
1511  return;
1512}
1513
1514void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1515  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1516       I != E;
1517       I++) {
1518    if (clang::Stmt *Child = *I) {
1519      Visit(Child);
1520    }
1521  }
1522  return;
1523}
1524
1525// This function walks the list of global variables and (potentially) creates
1526// a single global static destructor function that properly decrements
1527// reference counts on the contained RS object types.
1528clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1529  Init();
1530
1531  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1532  clang::SourceLocation loc;
1533
1534  llvm::StringRef SR(".rs.dtor");
1535  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1536  clang::DeclarationName N(&II);
1537  clang::FunctionProtoType::ExtProtoInfo EPI;
1538  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1539      llvm::ArrayRef<clang::QualType>(), EPI);
1540  clang::FunctionDecl *FD = NULL;
1541
1542  // Generate rsClearObject() call chains for every global variable
1543  // (whether static or extern).
1544  std::list<clang::Stmt *> StmtList;
1545  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1546          E = DC->decls_end(); I != E; I++) {
1547    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1548    if (VD) {
1549      if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1550        if (!FD) {
1551          // Only create FD if we are going to use it.
1552          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, NULL,
1553                                           clang::SC_None);
1554        }
1555        // Make sure to create any helpers within the function's DeclContext,
1556        // not the one associated with the global translation unit.
1557        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1558        StmtList.push_back(RSClearObjectCall);
1559      }
1560    }
1561  }
1562
1563  // Nothing needs to be destroyed, so don't emit a dtor.
1564  if (StmtList.empty()) {
1565    return NULL;
1566  }
1567
1568  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1569
1570  FD->setBody(CS);
1571
1572  return FD;
1573}
1574
1575}  // namespace slang
1576