slang_rs_object_ref_count.cpp revision 56916bec2734a6059e1d0f58959514c825c746f7
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <sstream>
#include <stack>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
33namespace slang {
34
35/* Even though those two arrays are of size DataTypeMax, only entries that
36 * correspond to object types will be set.
37 */
38clang::FunctionDecl *
39RSObjectRefCount::RSSetObjectFD[DataTypeMax];
40clang::FunctionDecl *
41RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42
43void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44  for (unsigned i = 0; i < DataTypeMax; i++) {
45    RSSetObjectFD[i] = nullptr;
46    RSClearObjectFD[i] = nullptr;
47  }
48
49  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52          E = TUDecl->decls_end(); I != E; I++) {
53    if ((I->getKind() >= clang::Decl::firstFunction) &&
54        (I->getKind() <= clang::Decl::lastFunction)) {
55      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57      // points to RSSetObjectFD or RSClearObjectFD
58      clang::FunctionDecl **RSObjectFD;
59
60      if (FD->getName() == "rsSetObject") {
61        slangAssert((FD->getNumParams() == 2) &&
62                    "Invalid rsSetObject function prototype (# params)");
63        RSObjectFD = RSSetObjectFD;
64      } else if (FD->getName() == "rsClearObject") {
65        slangAssert((FD->getNumParams() == 1) &&
66                    "Invalid rsClearObject function prototype (# params)");
67        RSObjectFD = RSClearObjectFD;
68      } else {
69        continue;
70      }
71
72      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73      clang::QualType PVT = PVD->getOriginalType();
74      // The first parameter must be a pointer like rs_allocation*
75      slangAssert(PVT->isPointerType() &&
76          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78      // The rs object type passed to the FD
79      clang::QualType RST = PVT->getPointeeType();
80      DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      if (DT >= 0 && DT < DataTypeMax) {
85          RSObjectFD[DT] = FD;
86      } else {
87          slangAssert(false && "incorrect type");
88      }
89    }
90  }
91}
92
93namespace {
94
95// This function constructs a new CompoundStmt from the input StmtList.
96static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
97      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
98  unsigned NewStmtCount = StmtList.size();
99  unsigned CompoundStmtCount = 0;
100
101  clang::Stmt **CompoundStmtList;
102  CompoundStmtList = new clang::Stmt*[NewStmtCount];
103
104  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
105  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
106  for ( ; I != E; I++) {
107    CompoundStmtList[CompoundStmtCount++] = *I;
108  }
109  slangAssert(CompoundStmtCount == NewStmtCount);
110
111  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
112      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
113
114  delete [] CompoundStmtList;
115
116  return CS;
117}
118
119static void AppendAfterStmt(clang::ASTContext &C,
120                            clang::CompoundStmt *CS,
121                            clang::Stmt *S,
122                            std::list<clang::Stmt*> &StmtList) {
123  slangAssert(CS);
124  clang::CompoundStmt::body_iterator bI = CS->body_begin();
125  clang::CompoundStmt::body_iterator bE = CS->body_end();
126  clang::Stmt **UpdatedStmtList =
127      new clang::Stmt*[CS->size() + StmtList.size()];
128
129  unsigned UpdatedStmtCount = 0;
130  unsigned Once = 0;
131  for ( ; bI != bE; bI++) {
132    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
133      // If we come across a return here, we don't have anything we can
134      // reasonably replace. We should have already inserted our destructor
135      // code in the proper spot, so we just clean up and return.
136      delete [] UpdatedStmtList;
137
138      return;
139    }
140
141    UpdatedStmtList[UpdatedStmtCount++] = *bI;
142
143    if ((*bI == S) && !Once) {
144      Once++;
145      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
146      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
147      for ( ; I != E; I++) {
148        UpdatedStmtList[UpdatedStmtCount++] = *I;
149      }
150    }
151  }
152  slangAssert(Once <= 1);
153
154  // When S is nullptr, we are appending to the end of the CompoundStmt.
155  if (!S) {
156    slangAssert(Once == 0);
157    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
158    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
159    for ( ; I != E; I++) {
160      UpdatedStmtList[UpdatedStmtCount++] = *I;
161    }
162  }
163
164  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
165
166  delete [] UpdatedStmtList;
167}
168
169// This class visits a compound statement and inserts DtorStmt
170// in proper locations. This includes inserting it before any
171// return statement in any sub-block, at the end of the logical enclosing
172// scope (compound statement), and/or before any break/continue statement that
173// would resume outside the declared scope. We will not handle the case for
174// goto statements that leave a local scope.
175//
176// To accomplish these goals, it collects a list of sub-Stmt's that
177// correspond to scope exit points. It then uses an RSASTReplace visitor to
178// transform the AST, inserting appropriate destructors before each of those
179// sub-Stmt's (and also before the exit of the outermost containing Stmt for
180// the scope).
181class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182 private:
183  clang::ASTContext &mCtx;
184
185  // The loop depth of the currently visited node.
186  int mLoopDepth;
187
188  // The switch statement depth of the currently visited node.
189  // Note that this is tracked separately from the loop depth because
190  // SwitchStmt-contained ContinueStmt's should have destructors for the
191  // corresponding loop scope.
192  int mSwitchDepth;
193
194  // The outermost statement block that we are currently visiting.
195  // This should always be a CompoundStmt.
196  clang::Stmt *mOuterStmt;
197
198  // The destructor to execute for this scope/variable.
199  clang::Stmt* mDtorStmt;
200
201  // The stack of statements which should be replaced by a compound statement
202  // containing the new destructor call followed by the original Stmt.
203  std::stack<clang::Stmt*> mReplaceStmtStack;
204
205  // The source location for the variable declaration that we are trying to
206  // insert destructors for. Note that InsertDestructors() will not generate
207  // destructor calls for source locations that occur lexically before this
208  // location.
209  clang::SourceLocation mVarLoc;
210
211 public:
212  DestructorVisitor(clang::ASTContext &C,
213                    clang::Stmt* OuterStmt,
214                    clang::Stmt* DtorStmt,
215                    clang::SourceLocation VarLoc);
216
217  // This code walks the collected list of Stmts to replace and actually does
218  // the replacement. It also finishes up by appending the destructor to the
219  // current outermost CompoundStmt.
220  void InsertDestructors() {
221    clang::Stmt *S = nullptr;
222    clang::SourceManager &SM = mCtx.getSourceManager();
223    std::list<clang::Stmt *> StmtList;
224    StmtList.push_back(mDtorStmt);
225
226    while (!mReplaceStmtStack.empty()) {
227      S = mReplaceStmtStack.top();
228      mReplaceStmtStack.pop();
229
230      // Skip all source locations that occur before the variable's
231      // declaration, since it won't have been initialized yet.
232      if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
233        continue;
234      }
235
236      StmtList.push_back(S);
237      clang::CompoundStmt *CS =
238          BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
239      StmtList.pop_back();
240
241      RSASTReplace R(mCtx);
242      R.ReplaceStmt(mOuterStmt, S, CS);
243    }
244    clang::CompoundStmt *CS =
245      llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
246    slangAssert(CS);
247    AppendAfterStmt(mCtx, CS, nullptr, StmtList);
248  }
249
250  void VisitStmt(clang::Stmt *S);
251
252  void VisitBreakStmt(clang::BreakStmt *BS);
253  void VisitContinueStmt(clang::ContinueStmt *CS);
254  void VisitDoStmt(clang::DoStmt *DS);
255  void VisitForStmt(clang::ForStmt *FS);
256  void VisitReturnStmt(clang::ReturnStmt *RS);
257  void VisitSwitchStmt(clang::SwitchStmt *SS);
258  void VisitWhileStmt(clang::WhileStmt *WS);
259};
260
261DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
262                         clang::Stmt *OuterStmt,
263                         clang::Stmt *DtorStmt,
264                         clang::SourceLocation VarLoc)
265  : mCtx(C),
266    mLoopDepth(0),
267    mSwitchDepth(0),
268    mOuterStmt(OuterStmt),
269    mDtorStmt(DtorStmt),
270    mVarLoc(VarLoc) {
271}
272
273void DestructorVisitor::VisitStmt(clang::Stmt *S) {
274  for (clang::Stmt* Child : S->children()) {
275    if (Child) {
276      Visit(Child);
277    }
278  }
279}
280
281void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
282  VisitStmt(BS);
283  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
284    mReplaceStmtStack.push(BS);
285  }
286}
287
288void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
289  VisitStmt(CS);
290  if (mLoopDepth == 0) {
291    // Switch statements can have nested continues.
292    mReplaceStmtStack.push(CS);
293  }
294}
295
296void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
297  mLoopDepth++;
298  VisitStmt(DS);
299  mLoopDepth--;
300}
301
302void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
303  mLoopDepth++;
304  VisitStmt(FS);
305  mLoopDepth--;
306}
307
308void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
309  mReplaceStmtStack.push(RS);
310}
311
312void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
313  mSwitchDepth++;
314  VisitStmt(SS);
315  mSwitchDepth--;
316}
317
318void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
319  mLoopDepth++;
320  VisitStmt(WS);
321  mLoopDepth--;
322}
323
324clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
325                                 clang::Expr *RefRSVar,
326                                 clang::SourceLocation Loc) {
327  slangAssert(RefRSVar);
328  const clang::Type *T = RefRSVar->getType().getTypePtr();
329  slangAssert(!T->isArrayType() &&
330              "Should not be destroying arrays with this function");
331
332  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
333  slangAssert((ClearObjectFD != nullptr) &&
334              "rsClearObject doesn't cover all RS object types");
335
336  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
337  clang::QualType ClearObjectFDArgType =
338      ClearObjectFD->getParamDecl(0)->getOriginalType();
339
340  // Example destructor for "rs_font localFont;"
341  //
342  // (CallExpr 'void'
343  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
344  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
345  //   (UnaryOperator 'rs_font *' prefix '&'
346  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
347
348  // Get address of targeted RS object
349  clang::Expr *AddrRefRSVar =
350      new(C) clang::UnaryOperator(RefRSVar,
351                                  clang::UO_AddrOf,
352                                  ClearObjectFDArgType,
353                                  clang::VK_RValue,
354                                  clang::OK_Ordinary,
355                                  Loc);
356
357  clang::Expr *RefRSClearObjectFD =
358      clang::DeclRefExpr::Create(C,
359                                 clang::NestedNameSpecifierLoc(),
360                                 clang::SourceLocation(),
361                                 ClearObjectFD,
362                                 false,
363                                 ClearObjectFD->getLocation(),
364                                 ClearObjectFDType,
365                                 clang::VK_RValue,
366                                 nullptr);
367
368  clang::Expr *RSClearObjectFP =
369      clang::ImplicitCastExpr::Create(C,
370                                      C.getPointerType(ClearObjectFDType),
371                                      clang::CK_FunctionToPointerDecay,
372                                      RefRSClearObjectFD,
373                                      nullptr,
374                                      clang::VK_RValue);
375
376  llvm::SmallVector<clang::Expr*, 1> ArgList;
377  ArgList.push_back(AddrRefRSVar);
378
379  clang::CallExpr *RSClearObjectCall =
380      new(C) clang::CallExpr(C,
381                             RSClearObjectFP,
382                             ArgList,
383                             ClearObjectFD->getCallResultType(),
384                             clang::VK_RValue,
385                             Loc);
386
387  return RSClearObjectCall;
388}
389
390static int ArrayDim(const clang::Type *T) {
391  if (!T || !T->isArrayType()) {
392    return 0;
393  }
394
395  const clang::ConstantArrayType *CAT =
396    static_cast<const clang::ConstantArrayType *>(T);
397  return static_cast<int>(CAT->getSize().getSExtValue());
398}
399
400static clang::Stmt *ClearStructRSObject(
401    clang::ASTContext &C,
402    clang::DeclContext *DC,
403    clang::Expr *RefRSStruct,
404    clang::SourceLocation StartLoc,
405    clang::SourceLocation Loc);
406
407static clang::Stmt *ClearArrayRSObject(
408    clang::ASTContext &C,
409    clang::DeclContext *DC,
410    clang::Expr *RefRSArr,
411    clang::SourceLocation StartLoc,
412    clang::SourceLocation Loc) {
413  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
414  slangAssert(BaseType->isArrayType());
415
416  int NumArrayElements = ArrayDim(BaseType);
417  // Actually extract out the base RS object type for use later
418  BaseType = BaseType->getArrayElementTypeNoTypeQual();
419
420  if (NumArrayElements <= 0) {
421    return nullptr;
422  }
423
424  // Example destructor loop for "rs_font fontArr[10];"
425  //
426  // (ForStmt
427  //   (DeclStmt
428  //     (VarDecl used rsIntIter 'int' cinit
429  //       (IntegerLiteral 'int' 0)))
430  //   (BinaryOperator 'int' '<'
431  //     (ImplicitCastExpr int LValueToRValue
432  //       (DeclRefExpr 'int' Var='rsIntIter'))
433  //     (IntegerLiteral 'int' 10)
434  //   nullptr << CondVar >>
435  //   (UnaryOperator 'int' postfix '++'
436  //     (DeclRefExpr 'int' Var='rsIntIter'))
437  //   (CallExpr 'void'
438  //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
439  //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
440  //     (UnaryOperator 'rs_font *' prefix '&'
441  //       (ArraySubscriptExpr 'rs_font':'rs_font'
442  //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
443  //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
444  //         (DeclRefExpr 'int' Var='rsIntIter'))))))
445
446  // Create helper variable for iterating through elements
447  static unsigned sIterCounter = 0;
448  std::stringstream UniqueIterName;
449  UniqueIterName << "rsIntIter" << sIterCounter++;
450  clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
451  clang::VarDecl *IIVD =
452      clang::VarDecl::Create(C,
453                             DC,
454                             StartLoc,
455                             Loc,
456                             II,
457                             C.IntTy,
458                             C.getTrivialTypeSourceInfo(C.IntTy),
459                             clang::SC_None);
460  // Mark "rsIntIter" as used
461  IIVD->markUsed(C);
462
463  // Form the actual destructor loop
464  // for (Init; Cond; Inc)
465  //   RSClearObjectCall;
466
467  // Init -> "int rsIntIter = 0"
468  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
469      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
470  IIVD->setInit(Int0);
471
472  clang::Decl *IID = (clang::Decl *)IIVD;
473  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
474  clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);
475
476  // Cond -> "rsIntIter < NumArrayElements"
477  clang::DeclRefExpr *RefrsIntIterLValue =
478      clang::DeclRefExpr::Create(C,
479                                 clang::NestedNameSpecifierLoc(),
480                                 clang::SourceLocation(),
481                                 IIVD,
482                                 false,
483                                 Loc,
484                                 C.IntTy,
485                                 clang::VK_LValue,
486                                 nullptr);
487
488  clang::Expr *RefrsIntIterRValue =
489      clang::ImplicitCastExpr::Create(C,
490                                      RefrsIntIterLValue->getType(),
491                                      clang::CK_LValueToRValue,
492                                      RefrsIntIterLValue,
493                                      nullptr,
494                                      clang::VK_RValue);
495
496  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
497      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
498
499  clang::BinaryOperator *Cond =
500      new(C) clang::BinaryOperator(RefrsIntIterRValue,
501                                   NumArrayElementsExpr,
502                                   clang::BO_LT,
503                                   C.IntTy,
504                                   clang::VK_RValue,
505                                   clang::OK_Ordinary,
506                                   Loc,
507                                   false);
508
509  // Inc -> "rsIntIter++"
510  clang::UnaryOperator *Inc =
511      new(C) clang::UnaryOperator(RefrsIntIterLValue,
512                                  clang::UO_PostInc,
513                                  C.IntTy,
514                                  clang::VK_RValue,
515                                  clang::OK_Ordinary,
516                                  Loc);
517
518  // Body -> "rsClearObject(&VD[rsIntIter]);"
519  // Destructor loop operates on individual array elements
520
521  clang::Expr *RefRSArrPtr =
522      clang::ImplicitCastExpr::Create(C,
523          C.getPointerType(BaseType->getCanonicalTypeInternal()),
524          clang::CK_ArrayToPointerDecay,
525          RefRSArr,
526          nullptr,
527          clang::VK_RValue);
528
529  clang::Expr *RefRSArrPtrSubscript =
530      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
531                                       RefrsIntIterRValue,
532                                       BaseType->getCanonicalTypeInternal(),
533                                       clang::VK_RValue,
534                                       clang::OK_Ordinary,
535                                       Loc);
536
537  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
538
539  clang::Stmt *RSClearObjectCall = nullptr;
540  if (BaseType->isArrayType()) {
541    RSClearObjectCall =
542        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
543  } else if (DT == DataTypeUnknown) {
544    RSClearObjectCall =
545        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
546  } else {
547    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
548  }
549
550  clang::ForStmt *DestructorLoop =
551      new(C) clang::ForStmt(C,
552                            Init,
553                            Cond,
554                            nullptr,  // no condVar
555                            Inc,
556                            RSClearObjectCall,
557                            Loc,
558                            Loc,
559                            Loc);
560
561  return DestructorLoop;
562}
563
564static unsigned CountRSObjectTypes(clang::ASTContext &C,
565                                   const clang::Type *T,
566                                   clang::SourceLocation Loc) {
567  slangAssert(T);
568  unsigned RSObjectCount = 0;
569
570  if (T->isArrayType()) {
571    return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
572  }
573
574  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
575  if (DT != DataTypeUnknown) {
576    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
577  }
578
579  if (T->isUnionType()) {
580    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
581    RD = RD->getDefinition();
582    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
583           FE = RD->field_end();
584         FI != FE;
585         FI++) {
586      const clang::FieldDecl *FD = *FI;
587      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
588      if (CountRSObjectTypes(C, FT, Loc)) {
589        slangAssert(false && "can't have unions with RS object types!");
590        return 0;
591      }
592    }
593  }
594
595  if (!T->isStructureType()) {
596    return 0;
597  }
598
599  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
600  RD = RD->getDefinition();
601  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
602         FE = RD->field_end();
603       FI != FE;
604       FI++) {
605    const clang::FieldDecl *FD = *FI;
606    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
607    if (CountRSObjectTypes(C, FT, Loc)) {
608      // Sub-structs should only count once (as should arrays, etc.)
609      RSObjectCount++;
610    }
611  }
612
613  return RSObjectCount;
614}
615
616static clang::Stmt *ClearStructRSObject(
617    clang::ASTContext &C,
618    clang::DeclContext *DC,
619    clang::Expr *RefRSStruct,
620    clang::SourceLocation StartLoc,
621    clang::SourceLocation Loc) {
622  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
623
624  slangAssert(!BaseType->isArrayType());
625
626  // Structs should show up as unknown primitive types
627  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
628              DataTypeUnknown);
629
630  unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
631  slangAssert(FieldsToDestroy != 0);
632
633  unsigned StmtCount = 0;
634  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
635  for (unsigned i = 0; i < FieldsToDestroy; i++) {
636    StmtArray[i] = nullptr;
637  }
638
639  // Populate StmtArray by creating a destructor for each RS object field
640  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
641  RD = RD->getDefinition();
642  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
643         FE = RD->field_end();
644       FI != FE;
645       FI++) {
646    // We just look through all field declarations to see if we find a
647    // declaration for an RS object type (or an array of one).
648    bool IsArrayType = false;
649    clang::FieldDecl *FD = *FI;
650    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
651    const clang::Type *OrigType = FT;
652    while (FT && FT->isArrayType()) {
653      FT = FT->getArrayElementTypeNoTypeQual();
654      IsArrayType = true;
655    }
656
657    // Pass a DeclarationNameInfo with a valid DeclName, since name equality
658    // gets asserted during CodeGen.
659    clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
660                                              FD->getLocation());
661
662    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
663      clang::DeclAccessPair FoundDecl =
664          clang::DeclAccessPair::make(FD, clang::AS_none);
665      clang::MemberExpr *RSObjectMember =
666          clang::MemberExpr::Create(C,
667                                    RefRSStruct,
668                                    false,
669                                    clang::SourceLocation(),
670                                    clang::NestedNameSpecifierLoc(),
671                                    clang::SourceLocation(),
672                                    FD,
673                                    FoundDecl,
674                                    FDDeclNameInfo,
675                                    nullptr,
676                                    OrigType->getCanonicalTypeInternal(),
677                                    clang::VK_RValue,
678                                    clang::OK_Ordinary);
679
680      slangAssert(StmtCount < FieldsToDestroy);
681
682      if (IsArrayType) {
683        StmtArray[StmtCount++] = ClearArrayRSObject(C,
684                                                    DC,
685                                                    RSObjectMember,
686                                                    StartLoc,
687                                                    Loc);
688      } else {
689        StmtArray[StmtCount++] = ClearSingleRSObject(C,
690                                                     RSObjectMember,
691                                                     Loc);
692      }
693    } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
694      // In this case, we have a nested struct. We may not end up filling all
695      // of the spaces in StmtArray (sub-structs should handle themselves
696      // with separate compound statements).
697      clang::DeclAccessPair FoundDecl =
698          clang::DeclAccessPair::make(FD, clang::AS_none);
699      clang::MemberExpr *RSObjectMember =
700          clang::MemberExpr::Create(C,
701                                    RefRSStruct,
702                                    false,
703                                    clang::SourceLocation(),
704                                    clang::NestedNameSpecifierLoc(),
705                                    clang::SourceLocation(),
706                                    FD,
707                                    FoundDecl,
708                                    clang::DeclarationNameInfo(),
709                                    nullptr,
710                                    OrigType->getCanonicalTypeInternal(),
711                                    clang::VK_RValue,
712                                    clang::OK_Ordinary);
713
714      if (IsArrayType) {
715        StmtArray[StmtCount++] = ClearArrayRSObject(C,
716                                                    DC,
717                                                    RSObjectMember,
718                                                    StartLoc,
719                                                    Loc);
720      } else {
721        StmtArray[StmtCount++] = ClearStructRSObject(C,
722                                                     DC,
723                                                     RSObjectMember,
724                                                     StartLoc,
725                                                     Loc);
726      }
727    }
728  }
729
730  slangAssert(StmtCount > 0);
731  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
732      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
733
734  delete [] StmtArray;
735
736  return CS;
737}
738
739static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
740                                            clang::Expr *DstExpr,
741                                            clang::Expr *SrcExpr,
742                                            clang::SourceLocation StartLoc,
743                                            clang::SourceLocation Loc) {
744  const clang::Type *T = DstExpr->getType().getTypePtr();
745  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
746  slangAssert((SetObjectFD != nullptr) &&
747              "rsSetObject doesn't cover all RS object types");
748
749  clang::QualType SetObjectFDType = SetObjectFD->getType();
750  clang::QualType SetObjectFDArgType[2];
751  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
752  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
753
754  clang::Expr *RefRSSetObjectFD =
755      clang::DeclRefExpr::Create(C,
756                                 clang::NestedNameSpecifierLoc(),
757                                 clang::SourceLocation(),
758                                 SetObjectFD,
759                                 false,
760                                 Loc,
761                                 SetObjectFDType,
762                                 clang::VK_RValue,
763                                 nullptr);
764
765  clang::Expr *RSSetObjectFP =
766      clang::ImplicitCastExpr::Create(C,
767                                      C.getPointerType(SetObjectFDType),
768                                      clang::CK_FunctionToPointerDecay,
769                                      RefRSSetObjectFD,
770                                      nullptr,
771                                      clang::VK_RValue);
772
773  llvm::SmallVector<clang::Expr*, 2> ArgList;
774  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
775                                                clang::UO_AddrOf,
776                                                SetObjectFDArgType[0],
777                                                clang::VK_RValue,
778                                                clang::OK_Ordinary,
779                                                Loc));
780  ArgList.push_back(SrcExpr);
781
782  clang::CallExpr *RSSetObjectCall =
783      new(C) clang::CallExpr(C,
784                             RSSetObjectFP,
785                             ArgList,
786                             SetObjectFD->getCallResultType(),
787                             clang::VK_RValue,
788                             Loc);
789
790  return RSSetObjectCall;
791}
792
793static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
794                                            clang::Expr *LHS,
795                                            clang::Expr *RHS,
796                                            clang::SourceLocation StartLoc,
797                                            clang::SourceLocation Loc);
798
799/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
800                                           clang::Expr *DstArr,
801                                           clang::Expr *SrcArr,
802                                           clang::SourceLocation StartLoc,
803                                           clang::SourceLocation Loc) {
804  clang::DeclContext *DC = nullptr;
805  const clang::Type *BaseType = DstArr->getType().getTypePtr();
806  slangAssert(BaseType->isArrayType());
807
808  int NumArrayElements = ArrayDim(BaseType);
809  // Actually extract out the base RS object type for use later
810  BaseType = BaseType->getArrayElementTypeNoTypeQual();
811
812  clang::Stmt *StmtArray[2] = {nullptr};
813  int StmtCtr = 0;
814
815  if (NumArrayElements <= 0) {
816    return nullptr;
817  }
818
819  // Create helper variable for iterating through elements
820  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
821  clang::VarDecl *IIVD =
822      clang::VarDecl::Create(C,
823                             DC,
824                             StartLoc,
825                             Loc,
826                             &II,
827                             C.IntTy,
828                             C.getTrivialTypeSourceInfo(C.IntTy),
829                             clang::SC_None,
830                             clang::SC_None);
831  clang::Decl *IID = (clang::Decl *)IIVD;
832
833  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
834  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
835
836  // Form the actual loop
837  // for (Init; Cond; Inc)
838  //   RSSetObjectCall;
839
840  // Init -> "rsIntIter = 0"
841  clang::DeclRefExpr *RefrsIntIter =
842      clang::DeclRefExpr::Create(C,
843                                 clang::NestedNameSpecifierLoc(),
844                                 IIVD,
845                                 Loc,
846                                 C.IntTy,
847                                 clang::VK_RValue,
848                                 nullptr);
849
850  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
851      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
852
853  clang::BinaryOperator *Init =
854      new(C) clang::BinaryOperator(RefrsIntIter,
855                                   Int0,
856                                   clang::BO_Assign,
857                                   C.IntTy,
858                                   clang::VK_RValue,
859                                   clang::OK_Ordinary,
860                                   Loc);
861
862  // Cond -> "rsIntIter < NumArrayElements"
863  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
864      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
865
866  clang::BinaryOperator *Cond =
867      new(C) clang::BinaryOperator(RefrsIntIter,
868                                   NumArrayElementsExpr,
869                                   clang::BO_LT,
870                                   C.IntTy,
871                                   clang::VK_RValue,
872                                   clang::OK_Ordinary,
873                                   Loc);
874
875  // Inc -> "rsIntIter++"
876  clang::UnaryOperator *Inc =
877      new(C) clang::UnaryOperator(RefrsIntIter,
878                                  clang::UO_PostInc,
879                                  C.IntTy,
880                                  clang::VK_RValue,
881                                  clang::OK_Ordinary,
882                                  Loc);
883
884  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
885  // Loop operates on individual array elements
886
887  clang::Expr *DstArrPtr =
888      clang::ImplicitCastExpr::Create(C,
889          C.getPointerType(BaseType->getCanonicalTypeInternal()),
890          clang::CK_ArrayToPointerDecay,
891          DstArr,
892          nullptr,
893          clang::VK_RValue);
894
895  clang::Expr *DstArrPtrSubscript =
896      new(C) clang::ArraySubscriptExpr(DstArrPtr,
897                                       RefrsIntIter,
898                                       BaseType->getCanonicalTypeInternal(),
899                                       clang::VK_RValue,
900                                       clang::OK_Ordinary,
901                                       Loc);
902
903  clang::Expr *SrcArrPtr =
904      clang::ImplicitCastExpr::Create(C,
905          C.getPointerType(BaseType->getCanonicalTypeInternal()),
906          clang::CK_ArrayToPointerDecay,
907          SrcArr,
908          nullptr,
909          clang::VK_RValue);
910
911  clang::Expr *SrcArrPtrSubscript =
912      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
913                                       RefrsIntIter,
914                                       BaseType->getCanonicalTypeInternal(),
915                                       clang::VK_RValue,
916                                       clang::OK_Ordinary,
917                                       Loc);
918
919  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
920
921  clang::Stmt *RSSetObjectCall = nullptr;
922  if (BaseType->isArrayType()) {
923    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
924                                             SrcArrPtrSubscript,
925                                             StartLoc, Loc);
926  } else if (DT == DataTypeUnknown) {
927    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
928                                              SrcArrPtrSubscript,
929                                              StartLoc, Loc);
930  } else {
931    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
932                                              SrcArrPtrSubscript,
933                                              StartLoc, Loc);
934  }
935
936  clang::ForStmt *DestructorLoop =
937      new(C) clang::ForStmt(C,
938                            Init,
939                            Cond,
940                            nullptr,  // no condVar
941                            Inc,
942                            RSSetObjectCall,
943                            Loc,
944                            Loc,
945                            Loc);
946
947  StmtArray[StmtCtr++] = DestructorLoop;
948  slangAssert(StmtCtr == 2);
949
950  clang::CompoundStmt *CS =
951      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
952
953  return CS;
954} */
955
956static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
957                                            clang::Expr *LHS,
958                                            clang::Expr *RHS,
959                                            clang::SourceLocation StartLoc,
960                                            clang::SourceLocation Loc) {
961  clang::QualType QT = LHS->getType();
962  const clang::Type *T = QT.getTypePtr();
963  slangAssert(T->isStructureType());
964  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
965
966  // Keep an extra slot for the original copy (memcpy)
967  unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
968
969  unsigned StmtCount = 0;
970  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
971  for (unsigned i = 0; i < FieldsToSet; i++) {
972    StmtArray[i] = nullptr;
973  }
974
975  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
976  RD = RD->getDefinition();
977  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
978         FE = RD->field_end();
979       FI != FE;
980       FI++) {
981    bool IsArrayType = false;
982    clang::FieldDecl *FD = *FI;
983    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
984    const clang::Type *OrigType = FT;
985
986    if (!CountRSObjectTypes(C, FT, Loc)) {
987      // Skip to next if we don't have any viable RS object types
988      continue;
989    }
990
991    clang::DeclAccessPair FoundDecl =
992        clang::DeclAccessPair::make(FD, clang::AS_none);
993    clang::MemberExpr *DstMember =
994        clang::MemberExpr::Create(C,
995                                  LHS,
996                                  false,
997                                  clang::SourceLocation(),
998                                  clang::NestedNameSpecifierLoc(),
999                                  clang::SourceLocation(),
1000                                  FD,
1001                                  FoundDecl,
1002                                  clang::DeclarationNameInfo(),
1003                                  nullptr,
1004                                  OrigType->getCanonicalTypeInternal(),
1005                                  clang::VK_RValue,
1006                                  clang::OK_Ordinary);
1007
1008    clang::MemberExpr *SrcMember =
1009        clang::MemberExpr::Create(C,
1010                                  RHS,
1011                                  false,
1012                                  clang::SourceLocation(),
1013                                  clang::NestedNameSpecifierLoc(),
1014                                  clang::SourceLocation(),
1015                                  FD,
1016                                  FoundDecl,
1017                                  clang::DeclarationNameInfo(),
1018                                  nullptr,
1019                                  OrigType->getCanonicalTypeInternal(),
1020                                  clang::VK_RValue,
1021                                  clang::OK_Ordinary);
1022
1023    if (FT->isArrayType()) {
1024      FT = FT->getArrayElementTypeNoTypeQual();
1025      IsArrayType = true;
1026    }
1027
1028    DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
1029
1030    if (IsArrayType) {
1031      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1032      DiagEngine.Report(
1033        clang::FullSourceLoc(Loc, C.getSourceManager()),
1034        DiagEngine.getCustomDiagID(
1035          clang::DiagnosticsEngine::Error,
1036          "Arrays of RS object types within structures cannot be copied"));
1037      // TODO(srhines): Support setting arrays of RS objects
1038      // StmtArray[StmtCount++] =
1039      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1040    } else if (DT == DataTypeUnknown) {
1041      StmtArray[StmtCount++] =
1042          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1043    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1044      StmtArray[StmtCount++] =
1045          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1046    } else {
1047      slangAssert(false);
1048    }
1049  }
1050
1051  slangAssert(StmtCount < FieldsToSet);
1052
1053  // We still need to actually do the overall struct copy. For simplicity,
1054  // we just do a straight-up assignment (which will still preserve all
1055  // the proper RS object reference counts).
1056  clang::BinaryOperator *CopyStruct =
1057      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1058                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1059                                   false);
1060  StmtArray[StmtCount++] = CopyStruct;
1061
1062  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1063      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1064
1065  delete [] StmtArray;
1066
1067  return CS;
1068}
1069
1070}  // namespace
1071
1072void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1073    clang::BinaryOperator *AS) {
1074
1075  clang::QualType QT = AS->getType();
1076
1077  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1078      DataTypeRSAllocation)->getASTContext();
1079
1080  clang::SourceLocation Loc = AS->getExprLoc();
1081  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1082  clang::Stmt *UpdatedStmt = nullptr;
1083
1084  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1085    // By definition, this is a struct assignment if we get here
1086    UpdatedStmt =
1087        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1088  } else {
1089    UpdatedStmt =
1090        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1091  }
1092
1093  RSASTReplace R(C);
1094  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1095}
1096
1097void RSObjectRefCount::Scope::AppendRSObjectInit(
1098    clang::VarDecl *VD,
1099    clang::DeclStmt *DS,
1100    DataType DT,
1101    clang::Expr *InitExpr) {
1102  slangAssert(VD);
1103
1104  if (!InitExpr) {
1105    return;
1106  }
1107
1108  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1109      DataTypeRSAllocation)->getASTContext();
1110  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1111      DataTypeRSAllocation)->getLocation();
1112  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1113      DataTypeRSAllocation)->getInnerLocStart();
1114
1115  if (DT == DataTypeIsStruct) {
1116    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1117    clang::DeclRefExpr *RefRSVar =
1118        clang::DeclRefExpr::Create(C,
1119                                   clang::NestedNameSpecifierLoc(),
1120                                   clang::SourceLocation(),
1121                                   VD,
1122                                   false,
1123                                   Loc,
1124                                   T->getCanonicalTypeInternal(),
1125                                   clang::VK_RValue,
1126                                   nullptr);
1127
1128    clang::Stmt *RSSetObjectOps =
1129        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1130
1131    std::list<clang::Stmt*> StmtList;
1132    StmtList.push_back(RSSetObjectOps);
1133    AppendAfterStmt(C, mCS, DS, StmtList);
1134    return;
1135  }
1136
1137  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1138  slangAssert((SetObjectFD != nullptr) &&
1139              "rsSetObject doesn't cover all RS object types");
1140
1141  clang::QualType SetObjectFDType = SetObjectFD->getType();
1142  clang::QualType SetObjectFDArgType[2];
1143  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1144  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1145
1146  clang::Expr *RefRSSetObjectFD =
1147      clang::DeclRefExpr::Create(C,
1148                                 clang::NestedNameSpecifierLoc(),
1149                                 clang::SourceLocation(),
1150                                 SetObjectFD,
1151                                 false,
1152                                 Loc,
1153                                 SetObjectFDType,
1154                                 clang::VK_RValue,
1155                                 nullptr);
1156
1157  clang::Expr *RSSetObjectFP =
1158      clang::ImplicitCastExpr::Create(C,
1159                                      C.getPointerType(SetObjectFDType),
1160                                      clang::CK_FunctionToPointerDecay,
1161                                      RefRSSetObjectFD,
1162                                      nullptr,
1163                                      clang::VK_RValue);
1164
1165  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1166  clang::DeclRefExpr *RefRSVar =
1167      clang::DeclRefExpr::Create(C,
1168                                 clang::NestedNameSpecifierLoc(),
1169                                 clang::SourceLocation(),
1170                                 VD,
1171                                 false,
1172                                 Loc,
1173                                 T->getCanonicalTypeInternal(),
1174                                 clang::VK_RValue,
1175                                 nullptr);
1176
1177  llvm::SmallVector<clang::Expr*, 2> ArgList;
1178  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1179                                                clang::UO_AddrOf,
1180                                                SetObjectFDArgType[0],
1181                                                clang::VK_RValue,
1182                                                clang::OK_Ordinary,
1183                                                Loc));
1184  ArgList.push_back(InitExpr);
1185
1186  clang::CallExpr *RSSetObjectCall =
1187      new(C) clang::CallExpr(C,
1188                             RSSetObjectFP,
1189                             ArgList,
1190                             SetObjectFD->getCallResultType(),
1191                             clang::VK_RValue,
1192                             Loc);
1193
1194  std::list<clang::Stmt*> StmtList;
1195  StmtList.push_back(RSSetObjectCall);
1196  AppendAfterStmt(C, mCS, DS, StmtList);
1197}
1198
1199void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1200  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1201          E = mRSO.end();
1202        I != E;
1203        I++) {
1204    clang::VarDecl *VD = *I;
1205    clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1206    if (RSClearObjectCall) {
1207      clang::ASTContext &C = (*mRSO.begin())->getASTContext();
1208      // Mark VD as used.  It might be unused, except for the destructor.
1209      // 'markUsed' has side-effects that are caused only if VD is not already
1210      // used.  Hence no need for an extra check here.
1211      VD->markUsed(C);
1212      DestructorVisitor DV(C,
1213                           mCS,
1214                           RSClearObjectCall,
1215                           VD->getSourceRange().getBegin());
1216      DV.Visit(mCS);
1217      DV.InsertDestructors();
1218    }
1219  }
1220}
1221
1222clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1223    clang::VarDecl *VD,
1224    clang::DeclContext *DC) {
1225  slangAssert(VD);
1226  clang::ASTContext &C = VD->getASTContext();
1227  clang::SourceLocation Loc = VD->getLocation();
1228  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1229  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1230
1231  // Reference expr to target RS object variable
1232  clang::DeclRefExpr *RefRSVar =
1233      clang::DeclRefExpr::Create(C,
1234                                 clang::NestedNameSpecifierLoc(),
1235                                 clang::SourceLocation(),
1236                                 VD,
1237                                 false,
1238                                 Loc,
1239                                 T->getCanonicalTypeInternal(),
1240                                 clang::VK_RValue,
1241                                 nullptr);
1242
1243  if (T->isArrayType()) {
1244    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1245  }
1246
1247  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1248
1249  if (DT == DataTypeUnknown ||
1250      DT == DataTypeIsStruct) {
1251    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1252  }
1253
1254  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1255              "Should be RS object");
1256
1257  return ClearSingleRSObject(C, RefRSVar, Loc);
1258}
1259
1260bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1261                                          DataType *DT,
1262                                          clang::Expr **InitExpr) {
1263  slangAssert(VD && DT && InitExpr);
1264  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1265
1266  // Loop through array types to get to base type
1267  while (T && T->isArrayType()) {
1268    T = T->getArrayElementTypeNoTypeQual();
1269  }
1270
1271  bool DataTypeIsStructWithRSObject = false;
1272  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1273
1274  if (*DT == DataTypeUnknown) {
1275    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1276      *DT = DataTypeIsStruct;
1277      DataTypeIsStructWithRSObject = true;
1278    } else {
1279      return false;
1280    }
1281  }
1282
1283  bool DataTypeIsRSObject = false;
1284  if (DataTypeIsStructWithRSObject) {
1285    DataTypeIsRSObject = true;
1286  } else {
1287    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1288  }
1289  *InitExpr = VD->getInit();
1290
1291  if (!DataTypeIsRSObject && *InitExpr) {
1292    // If we already have an initializer for a matrix type, we are done.
1293    return DataTypeIsRSObject;
1294  }
1295
1296  clang::Expr *ZeroInitializer =
1297      CreateZeroInitializerForRSSpecificType(*DT,
1298                                             VD->getASTContext(),
1299                                             VD->getLocation());
1300
1301  if (ZeroInitializer) {
1302    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1303    VD->setInit(ZeroInitializer);
1304  }
1305
1306  return DataTypeIsRSObject;
1307}
1308
1309clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1310    DataType DT,
1311    clang::ASTContext &C,
1312    const clang::SourceLocation &Loc) {
1313  clang::Expr *Res = nullptr;
1314  switch (DT) {
1315    case DataTypeIsStruct:
1316    case DataTypeRSElement:
1317    case DataTypeRSType:
1318    case DataTypeRSAllocation:
1319    case DataTypeRSSampler:
1320    case DataTypeRSScript:
1321    case DataTypeRSMesh:
1322    case DataTypeRSPath:
1323    case DataTypeRSProgramFragment:
1324    case DataTypeRSProgramVertex:
1325    case DataTypeRSProgramRaster:
1326    case DataTypeRSProgramStore:
1327    case DataTypeRSFont: {
1328      //    (ImplicitCastExpr 'nullptr_t'
1329      //      (IntegerLiteral 0)))
1330      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1331      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1332      clang::Expr *CastToNull =
1333          clang::ImplicitCastExpr::Create(C,
1334                                          C.NullPtrTy,
1335                                          clang::CK_IntegralToPointer,
1336                                          Int0,
1337                                          nullptr,
1338                                          clang::VK_RValue);
1339
1340      llvm::SmallVector<clang::Expr*, 1>InitList;
1341      InitList.push_back(CastToNull);
1342
1343      Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
1344      break;
1345    }
1346    case DataTypeRSMatrix2x2:
1347    case DataTypeRSMatrix3x3:
1348    case DataTypeRSMatrix4x4: {
1349      // RS matrix is not completely an RS object. They hold data by themselves.
1350      // (InitListExpr rs_matrix2x2
1351      //   (InitListExpr float[4]
1352      //     (FloatingLiteral 0)
1353      //     (FloatingLiteral 0)
1354      //     (FloatingLiteral 0)
1355      //     (FloatingLiteral 0)))
1356      clang::QualType FloatTy = C.FloatTy;
1357      // Constructor sets value to 0.0f by default
1358      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1359      clang::FloatingLiteral *Float0Val =
1360          clang::FloatingLiteral::Create(C,
1361                                         Val,
1362                                         /* isExact = */true,
1363                                         FloatTy,
1364                                         Loc);
1365
1366      unsigned N = 0;
1367      if (DT == DataTypeRSMatrix2x2)
1368        N = 2;
1369      else if (DT == DataTypeRSMatrix3x3)
1370        N = 3;
1371      else if (DT == DataTypeRSMatrix4x4)
1372        N = 4;
1373      unsigned N_2 = N * N;
1374
1375      // Assume we are going to be allocating 16 elements, since 4x4 is max.
1376      llvm::SmallVector<clang::Expr*, 16> InitVals;
1377      for (unsigned i = 0; i < N_2; i++)
1378        InitVals.push_back(Float0Val);
1379      clang::Expr *InitExpr =
1380          new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
1381      InitExpr->setType(C.getConstantArrayType(FloatTy,
1382                                               llvm::APInt(32, N_2),
1383                                               clang::ArrayType::Normal,
1384                                               /* EltTypeQuals = */0));
1385      llvm::SmallVector<clang::Expr*, 1> InitExprVec;
1386      InitExprVec.push_back(InitExpr);
1387
1388      Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
1389      break;
1390    }
1391    case DataTypeUnknown:
1392    case DataTypeFloat16:
1393    case DataTypeFloat32:
1394    case DataTypeFloat64:
1395    case DataTypeSigned8:
1396    case DataTypeSigned16:
1397    case DataTypeSigned32:
1398    case DataTypeSigned64:
1399    case DataTypeUnsigned8:
1400    case DataTypeUnsigned16:
1401    case DataTypeUnsigned32:
1402    case DataTypeUnsigned64:
1403    case DataTypeBoolean:
1404    case DataTypeUnsigned565:
1405    case DataTypeUnsigned5551:
1406    case DataTypeUnsigned4444:
1407    case DataTypeMax: {
1408      slangAssert(false && "Not RS object type!");
1409    }
1410    // No default case will enable compiler detecting the missing cases
1411  }
1412
1413  return Res;
1414}
1415
1416void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1417  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1418       I != E;
1419       I++) {
1420    clang::Decl *D = *I;
1421    if (D->getKind() == clang::Decl::Var) {
1422      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1423      DataType DT = DataTypeUnknown;
1424      clang::Expr *InitExpr = nullptr;
1425      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1426        // We need to zero-init all RS object types (including matrices), ...
1427        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1428        // ... but, only add to the list of RS objects if we have some
1429        // non-matrix RS object fields.
1430        if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(),
1431                               VD->getLocation())) {
1432          getCurrentScope()->addRSObject(VD);
1433        }
1434      }
1435    }
1436  }
1437}
1438
1439void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1440  if (!CS->body_empty()) {
1441    // Push a new scope
1442    Scope *S = new Scope(CS);
1443    mScopeStack.push(S);
1444
1445    VisitStmt(CS);
1446
1447    // Destroy the scope
1448    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1449    S->InsertLocalVarDestructors();
1450    mScopeStack.pop();
1451    delete S;
1452  }
1453}
1454
1455void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1456  clang::QualType QT = AS->getType();
1457
1458  if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1459    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1460  }
1461}
1462
1463void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1464  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1465       I != E;
1466       I++) {
1467    if (clang::Stmt *Child = *I) {
1468      Visit(Child);
1469    }
1470  }
1471}
1472
1473// This function walks the list of global variables and (potentially) creates
1474// a single global static destructor function that properly decrements
1475// reference counts on the contained RS object types.
1476clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1477  Init();
1478
1479  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1480  clang::SourceLocation loc;
1481
1482  llvm::StringRef SR(".rs.dtor");
1483  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1484  clang::DeclarationName N(&II);
1485  clang::FunctionProtoType::ExtProtoInfo EPI;
1486  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1487      llvm::ArrayRef<clang::QualType>(), EPI);
1488  clang::FunctionDecl *FD = nullptr;
1489
1490  // Generate rsClearObject() call chains for every global variable
1491  // (whether static or extern).
1492  std::list<clang::Stmt *> StmtList;
1493  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1494          E = DC->decls_end(); I != E; I++) {
1495    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1496    if (VD) {
1497      if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1498        if (!FD) {
1499          // Only create FD if we are going to use it.
1500          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1501                                           clang::SC_None);
1502        }
1503        // Mark VD as used.  It might be unused, except for the destructor.
1504        // 'markUsed' has side-effects that are caused only if VD is not already
1505        // used.  Hence no need for an extra check here.
1506        VD->markUsed(mCtx);
1507        // Make sure to create any helpers within the function's DeclContext,
1508        // not the one associated with the global translation unit.
1509        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1510        StmtList.push_back(RSClearObjectCall);
1511      }
1512    }
1513  }
1514
1515  // Nothing needs to be destroyed, so don't emit a dtor.
1516  if (StmtList.empty()) {
1517    return nullptr;
1518  }
1519
1520  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1521
1522  FD->setBody(CS);
1523
1524  return FD;
1525}
1526
1527}  // namespace slang
1528