slang_rs_object_ref_count.cpp revision 65f23ed862e1a1e16477ba740f295ff4a83ac822
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_object_ref_count.h"
18
19#include <list>
20
21#include "clang/AST/DeclGroup.h"
22#include "clang/AST/Expr.h"
23#include "clang/AST/NestedNameSpecifier.h"
24#include "clang/AST/OperationKinds.h"
25#include "clang/AST/Stmt.h"
26#include "clang/AST/StmtVisitor.h"
27
28#include "slang_assert.h"
29#include "slang.h"
30#include "slang_rs_ast_replace.h"
31#include "slang_rs_export_type.h"
32
33namespace slang {
34
35/* Even though those two arrays are of size DataTypeMax, only entries that
36 * correspond to object types will be set.
37 */
38clang::FunctionDecl *
39RSObjectRefCount::RSSetObjectFD[DataTypeMax];
40clang::FunctionDecl *
41RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42
43void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44  for (unsigned i = 0; i < DataTypeMax; i++) {
45    RSSetObjectFD[i] = nullptr;
46    RSClearObjectFD[i] = nullptr;
47  }
48
49  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52          E = TUDecl->decls_end(); I != E; I++) {
53    if ((I->getKind() >= clang::Decl::firstFunction) &&
54        (I->getKind() <= clang::Decl::lastFunction)) {
55      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57      // points to RSSetObjectFD or RSClearObjectFD
58      clang::FunctionDecl **RSObjectFD;
59
60      if (FD->getName() == "rsSetObject") {
61        slangAssert((FD->getNumParams() == 2) &&
62                    "Invalid rsSetObject function prototype (# params)");
63        RSObjectFD = RSSetObjectFD;
64      } else if (FD->getName() == "rsClearObject") {
65        slangAssert((FD->getNumParams() == 1) &&
66                    "Invalid rsClearObject function prototype (# params)");
67        RSObjectFD = RSClearObjectFD;
68      } else {
69        continue;
70      }
71
72      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73      clang::QualType PVT = PVD->getOriginalType();
74      // The first parameter must be a pointer like rs_allocation*
75      slangAssert(PVT->isPointerType() &&
76          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78      // The rs object type passed to the FD
79      clang::QualType RST = PVT->getPointeeType();
80      DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      if (DT >= 0 && DT < DataTypeMax) {
85          RSObjectFD[DT] = FD;
86      } else {
87          slangAssert(false && "incorrect type");
88      }
89    }
90  }
91}
92
93namespace {
94
95// This function constructs a new CompoundStmt from the input StmtList.
96static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
97      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
98  unsigned NewStmtCount = StmtList.size();
99  unsigned CompoundStmtCount = 0;
100
101  clang::Stmt **CompoundStmtList;
102  CompoundStmtList = new clang::Stmt*[NewStmtCount];
103
104  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
105  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
106  for ( ; I != E; I++) {
107    CompoundStmtList[CompoundStmtCount++] = *I;
108  }
109  slangAssert(CompoundStmtCount == NewStmtCount);
110
111  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
112      C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
113
114  delete [] CompoundStmtList;
115
116  return CS;
117}
118
119static void AppendAfterStmt(clang::ASTContext &C,
120                            clang::CompoundStmt *CS,
121                            clang::Stmt *S,
122                            std::list<clang::Stmt*> &StmtList) {
123  slangAssert(CS);
124  clang::CompoundStmt::body_iterator bI = CS->body_begin();
125  clang::CompoundStmt::body_iterator bE = CS->body_end();
126  clang::Stmt **UpdatedStmtList =
127      new clang::Stmt*[CS->size() + StmtList.size()];
128
129  unsigned UpdatedStmtCount = 0;
130  unsigned Once = 0;
131  for ( ; bI != bE; bI++) {
132    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
133      // If we come across a return here, we don't have anything we can
134      // reasonably replace. We should have already inserted our destructor
135      // code in the proper spot, so we just clean up and return.
136      delete [] UpdatedStmtList;
137
138      return;
139    }
140
141    UpdatedStmtList[UpdatedStmtCount++] = *bI;
142
143    if ((*bI == S) && !Once) {
144      Once++;
145      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
146      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
147      for ( ; I != E; I++) {
148        UpdatedStmtList[UpdatedStmtCount++] = *I;
149      }
150    }
151  }
152  slangAssert(Once <= 1);
153
154  // When S is nullptr, we are appending to the end of the CompoundStmt.
155  if (!S) {
156    slangAssert(Once == 0);
157    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
158    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
159    for ( ; I != E; I++) {
160      UpdatedStmtList[UpdatedStmtCount++] = *I;
161    }
162  }
163
164  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
165
166  delete [] UpdatedStmtList;
167}
168
169// This class visits a compound statement and inserts DtorStmt
170// in proper locations. This includes inserting it before any
171// return statement in any sub-block, at the end of the logical enclosing
172// scope (compound statement), and/or before any break/continue statement that
173// would resume outside the declared scope. We will not handle the case for
174// goto statements that leave a local scope.
175//
176// To accomplish these goals, it collects a list of sub-Stmt's that
177// correspond to scope exit points. It then uses an RSASTReplace visitor to
178// transform the AST, inserting appropriate destructors before each of those
179// sub-Stmt's (and also before the exit of the outermost containing Stmt for
180// the scope).
181class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182 private:
183  clang::ASTContext &mCtx;
184  clang::DeclContext *mDC;
185
186  // The loop depth of the currently visited node.
187  int mLoopDepth;
188
189  // The switch statement depth of the currently visited node.
190  // Note that this is tracked separately from the loop depth because
191  // SwitchStmt-contained ContinueStmt's should have destructors for the
192  // corresponding loop scope.
193  int mSwitchDepth;
194
195  // The outermost statement block that we are currently visiting.
196  // This should always be a CompoundStmt.
197  clang::Stmt *mOuterStmt;
198
199  // The destructor to execute for this scope/variable.
200  clang::Stmt* mDtorStmt;
201
202  // The stack of statements which should be replaced by a compound statement
203  // containing the new destructor call followed by the original Stmt.
204  std::stack<clang::Stmt*> mReplaceStmtStack;
205
206  // The source location for the variable declaration that we are trying to
207  // insert destructors for. Note that InsertDestructors() will not generate
208  // destructor calls for source locations that occur lexically before this
209  // location.
210  clang::SourceLocation mVarLoc;
211
212 public:
213  DestructorVisitor(clang::DeclContext* DC,
214                    clang::Stmt* OuterStmt,
215                    clang::Stmt* DtorStmt,
216                    clang::SourceLocation VarLoc);
217
218  // This code walks the collected list of Stmts to replace and actually does
219  // the replacement. It also finishes up by appending the destructor to the
220  // current outermost CompoundStmt.
221  void InsertDestructors() {
222    clang::Stmt *S = nullptr;
223    clang::SourceManager &SM = mCtx.getSourceManager();
224    std::list<clang::Stmt *> StmtList;
225    StmtList.push_back(mDtorStmt);
226
227    while (!mReplaceStmtStack.empty()) {
228      S = mReplaceStmtStack.top();
229      mReplaceStmtStack.pop();
230
231      // Skip all source locations that occur before the variable's
232      // declaration, since it won't have been initialized yet.
233      if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
234        continue;
235      }
236
237      clang::CompoundStmt *CS;
238      clang::ReturnStmt* RS = llvm::dyn_cast<clang::ReturnStmt>(S);
239      clang::Expr* RetVal;
240      if (!RS || !(RetVal = RS->getRetValue())) {
241        StmtList.push_back(S);
242        CS = BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
243        StmtList.pop_back();
244      } else {
245        // Since we insert rsClearObj() calls before the return statement, we need
246        // to make sure none of the cleared RS objects are referenced in the
247        // return statement.
248        // For that, we create a new local variable named .rs.retval, assign the
249        // original return expression to it, make all necessary rsClearObj()
250        // calls, then return .rs.retval. Note rsSetObj() or rsClearObj() are not
251        // called on .rs.retval.
252
253        clang::SourceLocation Loc;
254        clang::QualType RetTy = RetVal->getType();
255        clang::VarDecl* RSRetValDecl = clang::VarDecl::Create(
256            mCtx,                                  // AST context
257            mDC,                                   // Decl context
258            Loc,                                   // Start location
259            Loc,                                   // Id location
260            &mCtx.Idents.get(".rs.retval"),        // Id
261            RetTy,                                 // Type
262            mCtx.getTrivialTypeSourceInfo(RetTy),  // Type info
263            clang::SC_None                         // Storage class
264        );
265        RSRetValDecl->setInit(RetVal);
266        clang::Decl* Decls[] = { RSRetValDecl };
267        const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
268            mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
269        clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
270
271        // Creates a new return statement
272        clang::ReturnStmt* NewRet = new (mCtx) clang::ReturnStmt(RS->getReturnLoc());
273        clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
274            mCtx,
275            clang::NestedNameSpecifierLoc(),  // QualifierLoc
276            Loc,                              // TemplateKWLoc
277            RSRetValDecl,
278            false,                            // RefersToEnclosingVariableOrCapture
279            Loc,                              // NameLoc
280            RetTy,
281            clang::VK_LValue
282        );
283        clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
284            mCtx,
285            RetTy,
286            clang::CK_LValueToRValue,
287            DRE,
288            nullptr,
289            clang::VK_RValue
290        );
291        NewRet->setRetValue(CastExpr);
292
293        // Insert the two new statements into StmtList
294        StmtList.push_front(DS);
295        StmtList.push_back(NewRet);
296        CS = BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
297        StmtList.pop_back();
298        StmtList.pop_front();
299      }
300
301      RSASTReplace R(mCtx);
302      R.ReplaceStmt(mOuterStmt, S, CS);
303    }
304    clang::CompoundStmt *CS =
305      llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
306    slangAssert(CS);
307    AppendAfterStmt(mCtx, CS, nullptr, StmtList);
308  }
309
310  void VisitStmt(clang::Stmt *S);
311
312  void VisitBreakStmt(clang::BreakStmt *BS);
313  void VisitContinueStmt(clang::ContinueStmt *CS);
314  void VisitDoStmt(clang::DoStmt *DS);
315  void VisitForStmt(clang::ForStmt *FS);
316  void VisitReturnStmt(clang::ReturnStmt *RS);
317  void VisitSwitchStmt(clang::SwitchStmt *SS);
318  void VisitWhileStmt(clang::WhileStmt *WS);
319};
320
321DestructorVisitor::DestructorVisitor(clang::DeclContext* DC,
322                         clang::Stmt *OuterStmt,
323                         clang::Stmt *DtorStmt,
324                         clang::SourceLocation VarLoc)
325  : mCtx(DC->getParentASTContext()),
326    mDC(DC),
327    mLoopDepth(0),
328    mSwitchDepth(0),
329    mOuterStmt(OuterStmt),
330    mDtorStmt(DtorStmt),
331    mVarLoc(VarLoc) {
332}
333
334void DestructorVisitor::VisitStmt(clang::Stmt *S) {
335  for (clang::Stmt* Child : S->children()) {
336    if (Child) {
337      Visit(Child);
338    }
339  }
340}
341
342void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
343  VisitStmt(BS);
344  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
345    mReplaceStmtStack.push(BS);
346  }
347}
348
349void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
350  VisitStmt(CS);
351  if (mLoopDepth == 0) {
352    // Switch statements can have nested continues.
353    mReplaceStmtStack.push(CS);
354  }
355}
356
357void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
358  mLoopDepth++;
359  VisitStmt(DS);
360  mLoopDepth--;
361}
362
363void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
364  mLoopDepth++;
365  VisitStmt(FS);
366  mLoopDepth--;
367}
368
369void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
370  mReplaceStmtStack.push(RS);
371}
372
373void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
374  mSwitchDepth++;
375  VisitStmt(SS);
376  mSwitchDepth--;
377}
378
379void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
380  mLoopDepth++;
381  VisitStmt(WS);
382  mLoopDepth--;
383}
384
385clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
386                                 clang::Expr *RefRSVar,
387                                 clang::SourceLocation Loc) {
388  slangAssert(RefRSVar);
389  const clang::Type *T = RefRSVar->getType().getTypePtr();
390  slangAssert(!T->isArrayType() &&
391              "Should not be destroying arrays with this function");
392
393  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
394  slangAssert((ClearObjectFD != nullptr) &&
395              "rsClearObject doesn't cover all RS object types");
396
397  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
398  clang::QualType ClearObjectFDArgType =
399      ClearObjectFD->getParamDecl(0)->getOriginalType();
400
401  // Example destructor for "rs_font localFont;"
402  //
403  // (CallExpr 'void'
404  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
405  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
406  //   (UnaryOperator 'rs_font *' prefix '&'
407  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
408
409  // Get address of targeted RS object
410  clang::Expr *AddrRefRSVar =
411      new(C) clang::UnaryOperator(RefRSVar,
412                                  clang::UO_AddrOf,
413                                  ClearObjectFDArgType,
414                                  clang::VK_RValue,
415                                  clang::OK_Ordinary,
416                                  Loc);
417
418  clang::Expr *RefRSClearObjectFD =
419      clang::DeclRefExpr::Create(C,
420                                 clang::NestedNameSpecifierLoc(),
421                                 clang::SourceLocation(),
422                                 ClearObjectFD,
423                                 false,
424                                 ClearObjectFD->getLocation(),
425                                 ClearObjectFDType,
426                                 clang::VK_RValue,
427                                 nullptr);
428
429  clang::Expr *RSClearObjectFP =
430      clang::ImplicitCastExpr::Create(C,
431                                      C.getPointerType(ClearObjectFDType),
432                                      clang::CK_FunctionToPointerDecay,
433                                      RefRSClearObjectFD,
434                                      nullptr,
435                                      clang::VK_RValue);
436
437  llvm::SmallVector<clang::Expr*, 1> ArgList;
438  ArgList.push_back(AddrRefRSVar);
439
440  clang::CallExpr *RSClearObjectCall =
441      new(C) clang::CallExpr(C,
442                             RSClearObjectFP,
443                             ArgList,
444                             ClearObjectFD->getCallResultType(),
445                             clang::VK_RValue,
446                             Loc);
447
448  return RSClearObjectCall;
449}
450
451static int ArrayDim(const clang::Type *T) {
452  if (!T || !T->isArrayType()) {
453    return 0;
454  }
455
456  const clang::ConstantArrayType *CAT =
457    static_cast<const clang::ConstantArrayType *>(T);
458  return static_cast<int>(CAT->getSize().getSExtValue());
459}
460
461static clang::Stmt *ClearStructRSObject(
462    clang::ASTContext &C,
463    clang::DeclContext *DC,
464    clang::Expr *RefRSStruct,
465    clang::SourceLocation StartLoc,
466    clang::SourceLocation Loc);
467
468static clang::Stmt *ClearArrayRSObject(
469    clang::ASTContext &C,
470    clang::DeclContext *DC,
471    clang::Expr *RefRSArr,
472    clang::SourceLocation StartLoc,
473    clang::SourceLocation Loc) {
474  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
475  slangAssert(BaseType->isArrayType());
476
477  int NumArrayElements = ArrayDim(BaseType);
478  // Actually extract out the base RS object type for use later
479  BaseType = BaseType->getArrayElementTypeNoTypeQual();
480
481  if (NumArrayElements <= 0) {
482    return nullptr;
483  }
484
485  // Example destructor loop for "rs_font fontArr[10];"
486  //
487  // (ForStmt
488  //   (DeclStmt
489  //     (VarDecl used rsIntIter 'int' cinit
490  //       (IntegerLiteral 'int' 0)))
491  //   (BinaryOperator 'int' '<'
492  //     (ImplicitCastExpr int LValueToRValue
493  //       (DeclRefExpr 'int' Var='rsIntIter'))
494  //     (IntegerLiteral 'int' 10)
495  //   nullptr << CondVar >>
496  //   (UnaryOperator 'int' postfix '++'
497  //     (DeclRefExpr 'int' Var='rsIntIter'))
498  //   (CallExpr 'void'
499  //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
500  //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
501  //     (UnaryOperator 'rs_font *' prefix '&'
502  //       (ArraySubscriptExpr 'rs_font':'rs_font'
503  //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
504  //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
505  //         (DeclRefExpr 'int' Var='rsIntIter'))))))
506
507  // Create helper variable for iterating through elements
508  static unsigned sIterCounter = 0;
509  std::stringstream UniqueIterName;
510  UniqueIterName << "rsIntIter" << sIterCounter++;
511  clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
512  clang::VarDecl *IIVD =
513      clang::VarDecl::Create(C,
514                             DC,
515                             StartLoc,
516                             Loc,
517                             II,
518                             C.IntTy,
519                             C.getTrivialTypeSourceInfo(C.IntTy),
520                             clang::SC_None);
521  // Mark "rsIntIter" as used
522  IIVD->markUsed(C);
523
524  // Form the actual destructor loop
525  // for (Init; Cond; Inc)
526  //   RSClearObjectCall;
527
528  // Init -> "int rsIntIter = 0"
529  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
530      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
531  IIVD->setInit(Int0);
532
533  clang::Decl *IID = (clang::Decl *)IIVD;
534  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
535  clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);
536
537  // Cond -> "rsIntIter < NumArrayElements"
538  clang::DeclRefExpr *RefrsIntIterLValue =
539      clang::DeclRefExpr::Create(C,
540                                 clang::NestedNameSpecifierLoc(),
541                                 clang::SourceLocation(),
542                                 IIVD,
543                                 false,
544                                 Loc,
545                                 C.IntTy,
546                                 clang::VK_LValue,
547                                 nullptr);
548
549  clang::Expr *RefrsIntIterRValue =
550      clang::ImplicitCastExpr::Create(C,
551                                      RefrsIntIterLValue->getType(),
552                                      clang::CK_LValueToRValue,
553                                      RefrsIntIterLValue,
554                                      nullptr,
555                                      clang::VK_RValue);
556
557  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
558      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
559
560  clang::BinaryOperator *Cond =
561      new(C) clang::BinaryOperator(RefrsIntIterRValue,
562                                   NumArrayElementsExpr,
563                                   clang::BO_LT,
564                                   C.IntTy,
565                                   clang::VK_RValue,
566                                   clang::OK_Ordinary,
567                                   Loc,
568                                   false);
569
570  // Inc -> "rsIntIter++"
571  clang::UnaryOperator *Inc =
572      new(C) clang::UnaryOperator(RefrsIntIterLValue,
573                                  clang::UO_PostInc,
574                                  C.IntTy,
575                                  clang::VK_RValue,
576                                  clang::OK_Ordinary,
577                                  Loc);
578
579  // Body -> "rsClearObject(&VD[rsIntIter]);"
580  // Destructor loop operates on individual array elements
581
582  clang::Expr *RefRSArrPtr =
583      clang::ImplicitCastExpr::Create(C,
584          C.getPointerType(BaseType->getCanonicalTypeInternal()),
585          clang::CK_ArrayToPointerDecay,
586          RefRSArr,
587          nullptr,
588          clang::VK_RValue);
589
590  clang::Expr *RefRSArrPtrSubscript =
591      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
592                                       RefrsIntIterRValue,
593                                       BaseType->getCanonicalTypeInternal(),
594                                       clang::VK_RValue,
595                                       clang::OK_Ordinary,
596                                       Loc);
597
598  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
599
600  clang::Stmt *RSClearObjectCall = nullptr;
601  if (BaseType->isArrayType()) {
602    RSClearObjectCall =
603        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
604  } else if (DT == DataTypeUnknown) {
605    RSClearObjectCall =
606        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
607  } else {
608    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
609  }
610
611  clang::ForStmt *DestructorLoop =
612      new(C) clang::ForStmt(C,
613                            Init,
614                            Cond,
615                            nullptr,  // no condVar
616                            Inc,
617                            RSClearObjectCall,
618                            Loc,
619                            Loc,
620                            Loc);
621
622  return DestructorLoop;
623}
624
625static unsigned CountRSObjectTypes(const clang::Type *T) {
626  slangAssert(T);
627  unsigned RSObjectCount = 0;
628
629  if (T->isArrayType()) {
630    return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
631  }
632
633  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
634  if (DT != DataTypeUnknown) {
635    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
636  }
637
638  if (T->isUnionType()) {
639    clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
640    RD = RD->getDefinition();
641    for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
642           FE = RD->field_end();
643         FI != FE;
644         FI++) {
645      const clang::FieldDecl *FD = *FI;
646      const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
647      if (CountRSObjectTypes(FT)) {
648        slangAssert(false && "can't have unions with RS object types!");
649        return 0;
650      }
651    }
652  }
653
654  if (!T->isStructureType()) {
655    return 0;
656  }
657
658  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
659  RD = RD->getDefinition();
660  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
661         FE = RD->field_end();
662       FI != FE;
663       FI++) {
664    const clang::FieldDecl *FD = *FI;
665    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
666    if (CountRSObjectTypes(FT)) {
667      // Sub-structs should only count once (as should arrays, etc.)
668      RSObjectCount++;
669    }
670  }
671
672  return RSObjectCount;
673}
674
675static clang::Stmt *ClearStructRSObject(
676    clang::ASTContext &C,
677    clang::DeclContext *DC,
678    clang::Expr *RefRSStruct,
679    clang::SourceLocation StartLoc,
680    clang::SourceLocation Loc) {
681  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
682
683  slangAssert(!BaseType->isArrayType());
684
685  // Structs should show up as unknown primitive types
686  slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
687              DataTypeUnknown);
688
689  unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
690  slangAssert(FieldsToDestroy != 0);
691
692  unsigned StmtCount = 0;
693  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
694  for (unsigned i = 0; i < FieldsToDestroy; i++) {
695    StmtArray[i] = nullptr;
696  }
697
698  // Populate StmtArray by creating a destructor for each RS object field
699  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
700  RD = RD->getDefinition();
701  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
702         FE = RD->field_end();
703       FI != FE;
704       FI++) {
705    // We just look through all field declarations to see if we find a
706    // declaration for an RS object type (or an array of one).
707    bool IsArrayType = false;
708    clang::FieldDecl *FD = *FI;
709    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
710    const clang::Type *OrigType = FT;
711    while (FT && FT->isArrayType()) {
712      FT = FT->getArrayElementTypeNoTypeQual();
713      IsArrayType = true;
714    }
715
716    // Pass a DeclarationNameInfo with a valid DeclName, since name equality
717    // gets asserted during CodeGen.
718    clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
719                                              FD->getLocation());
720
721    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
722      clang::DeclAccessPair FoundDecl =
723          clang::DeclAccessPair::make(FD, clang::AS_none);
724      clang::MemberExpr *RSObjectMember =
725          clang::MemberExpr::Create(C,
726                                    RefRSStruct,
727                                    false,
728                                    clang::SourceLocation(),
729                                    clang::NestedNameSpecifierLoc(),
730                                    clang::SourceLocation(),
731                                    FD,
732                                    FoundDecl,
733                                    FDDeclNameInfo,
734                                    nullptr,
735                                    OrigType->getCanonicalTypeInternal(),
736                                    clang::VK_RValue,
737                                    clang::OK_Ordinary);
738
739      slangAssert(StmtCount < FieldsToDestroy);
740
741      if (IsArrayType) {
742        StmtArray[StmtCount++] = ClearArrayRSObject(C,
743                                                    DC,
744                                                    RSObjectMember,
745                                                    StartLoc,
746                                                    Loc);
747      } else {
748        StmtArray[StmtCount++] = ClearSingleRSObject(C,
749                                                     RSObjectMember,
750                                                     Loc);
751      }
752    } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
753      // In this case, we have a nested struct. We may not end up filling all
754      // of the spaces in StmtArray (sub-structs should handle themselves
755      // with separate compound statements).
756      clang::DeclAccessPair FoundDecl =
757          clang::DeclAccessPair::make(FD, clang::AS_none);
758      clang::MemberExpr *RSObjectMember =
759          clang::MemberExpr::Create(C,
760                                    RefRSStruct,
761                                    false,
762                                    clang::SourceLocation(),
763                                    clang::NestedNameSpecifierLoc(),
764                                    clang::SourceLocation(),
765                                    FD,
766                                    FoundDecl,
767                                    clang::DeclarationNameInfo(),
768                                    nullptr,
769                                    OrigType->getCanonicalTypeInternal(),
770                                    clang::VK_RValue,
771                                    clang::OK_Ordinary);
772
773      if (IsArrayType) {
774        StmtArray[StmtCount++] = ClearArrayRSObject(C,
775                                                    DC,
776                                                    RSObjectMember,
777                                                    StartLoc,
778                                                    Loc);
779      } else {
780        StmtArray[StmtCount++] = ClearStructRSObject(C,
781                                                     DC,
782                                                     RSObjectMember,
783                                                     StartLoc,
784                                                     Loc);
785      }
786    }
787  }
788
789  slangAssert(StmtCount > 0);
790  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
791      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
792
793  delete [] StmtArray;
794
795  return CS;
796}
797
798static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
799                                            clang::Expr *DstExpr,
800                                            clang::Expr *SrcExpr,
801                                            clang::SourceLocation StartLoc,
802                                            clang::SourceLocation Loc) {
803  const clang::Type *T = DstExpr->getType().getTypePtr();
804  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
805  slangAssert((SetObjectFD != nullptr) &&
806              "rsSetObject doesn't cover all RS object types");
807
808  clang::QualType SetObjectFDType = SetObjectFD->getType();
809  clang::QualType SetObjectFDArgType[2];
810  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
811  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
812
813  clang::Expr *RefRSSetObjectFD =
814      clang::DeclRefExpr::Create(C,
815                                 clang::NestedNameSpecifierLoc(),
816                                 clang::SourceLocation(),
817                                 SetObjectFD,
818                                 false,
819                                 Loc,
820                                 SetObjectFDType,
821                                 clang::VK_RValue,
822                                 nullptr);
823
824  clang::Expr *RSSetObjectFP =
825      clang::ImplicitCastExpr::Create(C,
826                                      C.getPointerType(SetObjectFDType),
827                                      clang::CK_FunctionToPointerDecay,
828                                      RefRSSetObjectFD,
829                                      nullptr,
830                                      clang::VK_RValue);
831
832  llvm::SmallVector<clang::Expr*, 2> ArgList;
833  ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
834                                                clang::UO_AddrOf,
835                                                SetObjectFDArgType[0],
836                                                clang::VK_RValue,
837                                                clang::OK_Ordinary,
838                                                Loc));
839  ArgList.push_back(SrcExpr);
840
841  clang::CallExpr *RSSetObjectCall =
842      new(C) clang::CallExpr(C,
843                             RSSetObjectFP,
844                             ArgList,
845                             SetObjectFD->getCallResultType(),
846                             clang::VK_RValue,
847                             Loc);
848
849  return RSSetObjectCall;
850}
851
852static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
853                                            clang::Expr *LHS,
854                                            clang::Expr *RHS,
855                                            clang::SourceLocation StartLoc,
856                                            clang::SourceLocation Loc);
857
858/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
859                                           clang::Expr *DstArr,
860                                           clang::Expr *SrcArr,
861                                           clang::SourceLocation StartLoc,
862                                           clang::SourceLocation Loc) {
863  clang::DeclContext *DC = nullptr;
864  const clang::Type *BaseType = DstArr->getType().getTypePtr();
865  slangAssert(BaseType->isArrayType());
866
867  int NumArrayElements = ArrayDim(BaseType);
868  // Actually extract out the base RS object type for use later
869  BaseType = BaseType->getArrayElementTypeNoTypeQual();
870
871  clang::Stmt *StmtArray[2] = {nullptr};
872  int StmtCtr = 0;
873
874  if (NumArrayElements <= 0) {
875    return nullptr;
876  }
877
878  // Create helper variable for iterating through elements
879  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
880  clang::VarDecl *IIVD =
881      clang::VarDecl::Create(C,
882                             DC,
883                             StartLoc,
884                             Loc,
885                             &II,
886                             C.IntTy,
887                             C.getTrivialTypeSourceInfo(C.IntTy),
888                             clang::SC_None,
889                             clang::SC_None);
890  clang::Decl *IID = (clang::Decl *)IIVD;
891
892  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
893  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
894
895  // Form the actual loop
896  // for (Init; Cond; Inc)
897  //   RSSetObjectCall;
898
899  // Init -> "rsIntIter = 0"
900  clang::DeclRefExpr *RefrsIntIter =
901      clang::DeclRefExpr::Create(C,
902                                 clang::NestedNameSpecifierLoc(),
903                                 IIVD,
904                                 Loc,
905                                 C.IntTy,
906                                 clang::VK_RValue,
907                                 nullptr);
908
909  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
910      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
911
912  clang::BinaryOperator *Init =
913      new(C) clang::BinaryOperator(RefrsIntIter,
914                                   Int0,
915                                   clang::BO_Assign,
916                                   C.IntTy,
917                                   clang::VK_RValue,
918                                   clang::OK_Ordinary,
919                                   Loc);
920
921  // Cond -> "rsIntIter < NumArrayElements"
922  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
923      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
924
925  clang::BinaryOperator *Cond =
926      new(C) clang::BinaryOperator(RefrsIntIter,
927                                   NumArrayElementsExpr,
928                                   clang::BO_LT,
929                                   C.IntTy,
930                                   clang::VK_RValue,
931                                   clang::OK_Ordinary,
932                                   Loc);
933
934  // Inc -> "rsIntIter++"
935  clang::UnaryOperator *Inc =
936      new(C) clang::UnaryOperator(RefrsIntIter,
937                                  clang::UO_PostInc,
938                                  C.IntTy,
939                                  clang::VK_RValue,
940                                  clang::OK_Ordinary,
941                                  Loc);
942
943  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
944  // Loop operates on individual array elements
945
946  clang::Expr *DstArrPtr =
947      clang::ImplicitCastExpr::Create(C,
948          C.getPointerType(BaseType->getCanonicalTypeInternal()),
949          clang::CK_ArrayToPointerDecay,
950          DstArr,
951          nullptr,
952          clang::VK_RValue);
953
954  clang::Expr *DstArrPtrSubscript =
955      new(C) clang::ArraySubscriptExpr(DstArrPtr,
956                                       RefrsIntIter,
957                                       BaseType->getCanonicalTypeInternal(),
958                                       clang::VK_RValue,
959                                       clang::OK_Ordinary,
960                                       Loc);
961
962  clang::Expr *SrcArrPtr =
963      clang::ImplicitCastExpr::Create(C,
964          C.getPointerType(BaseType->getCanonicalTypeInternal()),
965          clang::CK_ArrayToPointerDecay,
966          SrcArr,
967          nullptr,
968          clang::VK_RValue);
969
970  clang::Expr *SrcArrPtrSubscript =
971      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
972                                       RefrsIntIter,
973                                       BaseType->getCanonicalTypeInternal(),
974                                       clang::VK_RValue,
975                                       clang::OK_Ordinary,
976                                       Loc);
977
978  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
979
980  clang::Stmt *RSSetObjectCall = nullptr;
981  if (BaseType->isArrayType()) {
982    RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
983                                             SrcArrPtrSubscript,
984                                             StartLoc, Loc);
985  } else if (DT == DataTypeUnknown) {
986    RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
987                                              SrcArrPtrSubscript,
988                                              StartLoc, Loc);
989  } else {
990    RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
991                                              SrcArrPtrSubscript,
992                                              StartLoc, Loc);
993  }
994
995  clang::ForStmt *DestructorLoop =
996      new(C) clang::ForStmt(C,
997                            Init,
998                            Cond,
999                            nullptr,  // no condVar
1000                            Inc,
1001                            RSSetObjectCall,
1002                            Loc,
1003                            Loc,
1004                            Loc);
1005
1006  StmtArray[StmtCtr++] = DestructorLoop;
1007  slangAssert(StmtCtr == 2);
1008
1009  clang::CompoundStmt *CS =
1010      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
1011
1012  return CS;
1013} */
1014
1015static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
1016                                            clang::Expr *LHS,
1017                                            clang::Expr *RHS,
1018                                            clang::SourceLocation StartLoc,
1019                                            clang::SourceLocation Loc) {
1020  clang::QualType QT = LHS->getType();
1021  const clang::Type *T = QT.getTypePtr();
1022  slangAssert(T->isStructureType());
1023  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1024
1025  // Keep an extra slot for the original copy (memcpy)
1026  unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
1027
1028  unsigned StmtCount = 0;
1029  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1030  for (unsigned i = 0; i < FieldsToSet; i++) {
1031    StmtArray[i] = nullptr;
1032  }
1033
1034  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1035  RD = RD->getDefinition();
1036  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1037         FE = RD->field_end();
1038       FI != FE;
1039       FI++) {
1040    bool IsArrayType = false;
1041    clang::FieldDecl *FD = *FI;
1042    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1043    const clang::Type *OrigType = FT;
1044
1045    if (!CountRSObjectTypes(FT)) {
1046      // Skip to next if we don't have any viable RS object types
1047      continue;
1048    }
1049
1050    clang::DeclAccessPair FoundDecl =
1051        clang::DeclAccessPair::make(FD, clang::AS_none);
1052    clang::MemberExpr *DstMember =
1053        clang::MemberExpr::Create(C,
1054                                  LHS,
1055                                  false,
1056                                  clang::SourceLocation(),
1057                                  clang::NestedNameSpecifierLoc(),
1058                                  clang::SourceLocation(),
1059                                  FD,
1060                                  FoundDecl,
1061                                  clang::DeclarationNameInfo(),
1062                                  nullptr,
1063                                  OrigType->getCanonicalTypeInternal(),
1064                                  clang::VK_RValue,
1065                                  clang::OK_Ordinary);
1066
1067    clang::MemberExpr *SrcMember =
1068        clang::MemberExpr::Create(C,
1069                                  RHS,
1070                                  false,
1071                                  clang::SourceLocation(),
1072                                  clang::NestedNameSpecifierLoc(),
1073                                  clang::SourceLocation(),
1074                                  FD,
1075                                  FoundDecl,
1076                                  clang::DeclarationNameInfo(),
1077                                  nullptr,
1078                                  OrigType->getCanonicalTypeInternal(),
1079                                  clang::VK_RValue,
1080                                  clang::OK_Ordinary);
1081
1082    if (FT->isArrayType()) {
1083      FT = FT->getArrayElementTypeNoTypeQual();
1084      IsArrayType = true;
1085    }
1086
1087    DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
1088
1089    if (IsArrayType) {
1090      clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1091      DiagEngine.Report(
1092        clang::FullSourceLoc(Loc, C.getSourceManager()),
1093        DiagEngine.getCustomDiagID(
1094          clang::DiagnosticsEngine::Error,
1095          "Arrays of RS object types within structures cannot be copied"));
1096      // TODO(srhines): Support setting arrays of RS objects
1097      // StmtArray[StmtCount++] =
1098      //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1099    } else if (DT == DataTypeUnknown) {
1100      StmtArray[StmtCount++] =
1101          CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1102    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1103      StmtArray[StmtCount++] =
1104          CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1105    } else {
1106      slangAssert(false);
1107    }
1108  }
1109
1110  slangAssert(StmtCount < FieldsToSet);
1111
1112  // We still need to actually do the overall struct copy. For simplicity,
1113  // we just do a straight-up assignment (which will still preserve all
1114  // the proper RS object reference counts).
1115  clang::BinaryOperator *CopyStruct =
1116      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1117                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
1118                                   false);
1119  StmtArray[StmtCount++] = CopyStruct;
1120
1121  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1122      C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1123
1124  delete [] StmtArray;
1125
1126  return CS;
1127}
1128
1129}  // namespace
1130
1131void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1132    clang::BinaryOperator *AS) {
1133
1134  clang::QualType QT = AS->getType();
1135
1136  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1137      DataTypeRSAllocation)->getASTContext();
1138
1139  clang::SourceLocation Loc = AS->getExprLoc();
1140  clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1141  clang::Stmt *UpdatedStmt = nullptr;
1142
1143  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1144    // By definition, this is a struct assignment if we get here
1145    UpdatedStmt =
1146        CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1147  } else {
1148    UpdatedStmt =
1149        CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1150  }
1151
1152  RSASTReplace R(C);
1153  R.ReplaceStmt(mCS, AS, UpdatedStmt);
1154}
1155
1156void RSObjectRefCount::Scope::AppendRSObjectInit(
1157    clang::VarDecl *VD,
1158    clang::DeclStmt *DS,
1159    DataType DT,
1160    clang::Expr *InitExpr) {
1161  slangAssert(VD);
1162
1163  if (!InitExpr) {
1164    return;
1165  }
1166
1167  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1168      DataTypeRSAllocation)->getASTContext();
1169  clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1170      DataTypeRSAllocation)->getLocation();
1171  clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1172      DataTypeRSAllocation)->getInnerLocStart();
1173
1174  if (DT == DataTypeIsStruct) {
1175    const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1176    clang::DeclRefExpr *RefRSVar =
1177        clang::DeclRefExpr::Create(C,
1178                                   clang::NestedNameSpecifierLoc(),
1179                                   clang::SourceLocation(),
1180                                   VD,
1181                                   false,
1182                                   Loc,
1183                                   T->getCanonicalTypeInternal(),
1184                                   clang::VK_RValue,
1185                                   nullptr);
1186
1187    clang::Stmt *RSSetObjectOps =
1188        CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1189
1190    std::list<clang::Stmt*> StmtList;
1191    StmtList.push_back(RSSetObjectOps);
1192    AppendAfterStmt(C, mCS, DS, StmtList);
1193    return;
1194  }
1195
1196  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1197  slangAssert((SetObjectFD != nullptr) &&
1198              "rsSetObject doesn't cover all RS object types");
1199
1200  clang::QualType SetObjectFDType = SetObjectFD->getType();
1201  clang::QualType SetObjectFDArgType[2];
1202  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1203  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1204
1205  clang::Expr *RefRSSetObjectFD =
1206      clang::DeclRefExpr::Create(C,
1207                                 clang::NestedNameSpecifierLoc(),
1208                                 clang::SourceLocation(),
1209                                 SetObjectFD,
1210                                 false,
1211                                 Loc,
1212                                 SetObjectFDType,
1213                                 clang::VK_RValue,
1214                                 nullptr);
1215
1216  clang::Expr *RSSetObjectFP =
1217      clang::ImplicitCastExpr::Create(C,
1218                                      C.getPointerType(SetObjectFDType),
1219                                      clang::CK_FunctionToPointerDecay,
1220                                      RefRSSetObjectFD,
1221                                      nullptr,
1222                                      clang::VK_RValue);
1223
1224  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1225  clang::DeclRefExpr *RefRSVar =
1226      clang::DeclRefExpr::Create(C,
1227                                 clang::NestedNameSpecifierLoc(),
1228                                 clang::SourceLocation(),
1229                                 VD,
1230                                 false,
1231                                 Loc,
1232                                 T->getCanonicalTypeInternal(),
1233                                 clang::VK_RValue,
1234                                 nullptr);
1235
1236  llvm::SmallVector<clang::Expr*, 2> ArgList;
1237  ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1238                                                clang::UO_AddrOf,
1239                                                SetObjectFDArgType[0],
1240                                                clang::VK_RValue,
1241                                                clang::OK_Ordinary,
1242                                                Loc));
1243  ArgList.push_back(InitExpr);
1244
1245  clang::CallExpr *RSSetObjectCall =
1246      new(C) clang::CallExpr(C,
1247                             RSSetObjectFP,
1248                             ArgList,
1249                             SetObjectFD->getCallResultType(),
1250                             clang::VK_RValue,
1251                             Loc);
1252
1253  std::list<clang::Stmt*> StmtList;
1254  StmtList.push_back(RSSetObjectCall);
1255  AppendAfterStmt(C, mCS, DS, StmtList);
1256}
1257
1258void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1259  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1260          E = mRSO.end();
1261        I != E;
1262        I++) {
1263    clang::VarDecl *VD = *I;
1264    clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1265    if (RSClearObjectCall) {
1266      clang::ASTContext &C = (*mRSO.begin())->getASTContext();
1267      // Mark VD as used.  It might be unused, except for the destructor.
1268      // 'markUsed' has side-effects that are caused only if VD is not already
1269      // used.  Hence no need for an extra check here.
1270      VD->markUsed(C);
1271      DestructorVisitor DV(VD->getDeclContext(),
1272                           mCS,
1273                           RSClearObjectCall,
1274                           VD->getSourceRange().getBegin());
1275      DV.Visit(mCS);
1276      DV.InsertDestructors();
1277    }
1278  }
1279}
1280
1281clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1282    clang::VarDecl *VD,
1283    clang::DeclContext *DC) {
1284  slangAssert(VD);
1285  clang::ASTContext &C = VD->getASTContext();
1286  clang::SourceLocation Loc = VD->getLocation();
1287  clang::SourceLocation StartLoc = VD->getInnerLocStart();
1288  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1289
1290  // Reference expr to target RS object variable
1291  clang::DeclRefExpr *RefRSVar =
1292      clang::DeclRefExpr::Create(C,
1293                                 clang::NestedNameSpecifierLoc(),
1294                                 clang::SourceLocation(),
1295                                 VD,
1296                                 false,
1297                                 Loc,
1298                                 T->getCanonicalTypeInternal(),
1299                                 clang::VK_RValue,
1300                                 nullptr);
1301
1302  if (T->isArrayType()) {
1303    return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1304  }
1305
1306  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1307
1308  if (DT == DataTypeUnknown ||
1309      DT == DataTypeIsStruct) {
1310    return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1311  }
1312
1313  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1314              "Should be RS object");
1315
1316  return ClearSingleRSObject(C, RefRSVar, Loc);
1317}
1318
1319bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1320                                          DataType *DT,
1321                                          clang::Expr **InitExpr) {
1322  slangAssert(VD && DT && InitExpr);
1323  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1324
1325  // Loop through array types to get to base type
1326  while (T && T->isArrayType()) {
1327    T = T->getArrayElementTypeNoTypeQual();
1328  }
1329
1330  bool DataTypeIsStructWithRSObject = false;
1331  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1332
1333  if (*DT == DataTypeUnknown) {
1334    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1335      *DT = DataTypeIsStruct;
1336      DataTypeIsStructWithRSObject = true;
1337    } else {
1338      return false;
1339    }
1340  }
1341
1342  bool DataTypeIsRSObject = false;
1343  if (DataTypeIsStructWithRSObject) {
1344    DataTypeIsRSObject = true;
1345  } else {
1346    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1347  }
1348  *InitExpr = VD->getInit();
1349
1350  if (!DataTypeIsRSObject && *InitExpr) {
1351    // If we already have an initializer for a matrix type, we are done.
1352    return DataTypeIsRSObject;
1353  }
1354
1355  clang::Expr *ZeroInitializer =
1356      CreateZeroInitializerForRSSpecificType(*DT,
1357                                             VD->getASTContext(),
1358                                             VD->getLocation());
1359
1360  if (ZeroInitializer) {
1361    ZeroInitializer->setType(T->getCanonicalTypeInternal());
1362    VD->setInit(ZeroInitializer);
1363  }
1364
1365  return DataTypeIsRSObject;
1366}
1367
// Builds a zero initializer for a variable whose RS-specific type is DT.
// All nodes are allocated in C and located at Loc.
//   - RS object types and DataTypeIsStruct: "{ 0 }", an InitListExpr holding
//     an integer 0 implicitly cast to the null pointer type.
//   - RS matrices (2x2, 3x3, 4x4): "{ { 0.0f, ... } }", an InitListExpr that
//     wraps a float[N*N] InitListExpr of zeros.
//   - Any other DataType: slangAssert fails; if it does not abort, nullptr is
//     returned.
// The outer InitListExpr's type is not set here (InitializeRSObject sets it
// with setType() on the returned expression).
clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
    DataType DT,
    clang::ASTContext &C,
    const clang::SourceLocation &Loc) {
  clang::Expr *Res = nullptr;
  switch (DT) {
    case DataTypeIsStruct:
    case DataTypeRSElement:
    case DataTypeRSType:
    case DataTypeRSAllocation:
    case DataTypeRSSampler:
    case DataTypeRSScript:
    case DataTypeRSMesh:
    case DataTypeRSPath:
    case DataTypeRSProgramFragment:
    case DataTypeRSProgramVertex:
    case DataTypeRSProgramRaster:
    case DataTypeRSProgramStore:
    case DataTypeRSFont: {
      // (InitListExpr
      //    (ImplicitCastExpr 'nullptr_t'
      //      (IntegerLiteral 0)))
      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
      clang::Expr *CastToNull =
          clang::ImplicitCastExpr::Create(C,
                                          C.NullPtrTy,
                                          clang::CK_IntegralToPointer,
                                          Int0,
                                          nullptr,
                                          clang::VK_RValue);

      llvm::SmallVector<clang::Expr*, 1>InitList;
      InitList.push_back(CastToNull);

      Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
      break;
    }
    case DataTypeRSMatrix2x2:
    case DataTypeRSMatrix3x3:
    case DataTypeRSMatrix4x4: {
      // RS matrix is not completely an RS object. They hold data by themselves.
      // (InitListExpr rs_matrix2x2
      //   (InitListExpr float[4]
      //     (FloatingLiteral 0)
      //     (FloatingLiteral 0)
      //     (FloatingLiteral 0)
      //     (FloatingLiteral 0)))
      clang::QualType FloatTy = C.FloatTy;
      // Constructor sets value to 0.0f by default
      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
      // A single literal node; it is reused for every element below.
      clang::FloatingLiteral *Float0Val =
          clang::FloatingLiteral::Create(C,
                                         Val,
                                         /* isExact = */true,
                                         FloatTy,
                                         Loc);

      // Matrix dimension; the inner array holds N * N floats.
      unsigned N = 0;
      if (DT == DataTypeRSMatrix2x2)
        N = 2;
      else if (DT == DataTypeRSMatrix3x3)
        N = 3;
      else if (DT == DataTypeRSMatrix4x4)
        N = 4;
      unsigned N_2 = N * N;

      // Assume we are going to be allocating 16 elements, since 4x4 is max.
      llvm::SmallVector<clang::Expr*, 16> InitVals;
      for (unsigned i = 0; i < N_2; i++)
        InitVals.push_back(Float0Val);
      clang::Expr *InitExpr =
          new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
      // The inner list is explicitly typed as float[N * N].
      InitExpr->setType(C.getConstantArrayType(FloatTy,
                                               llvm::APInt(32, N_2),
                                               clang::ArrayType::Normal,
                                               /* EltTypeQuals = */0));
      llvm::SmallVector<clang::Expr*, 1> InitExprVec;
      InitExprVec.push_back(InitExpr);

      Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
      break;
    }
    case DataTypeUnknown:
    case DataTypeFloat16:
    case DataTypeFloat32:
    case DataTypeFloat64:
    case DataTypeSigned8:
    case DataTypeSigned16:
    case DataTypeSigned32:
    case DataTypeSigned64:
    case DataTypeUnsigned8:
    case DataTypeUnsigned16:
    case DataTypeUnsigned32:
    case DataTypeUnsigned64:
    case DataTypeBoolean:
    case DataTypeUnsigned565:
    case DataTypeUnsigned5551:
    case DataTypeUnsigned4444:
    case DataTypeMax: {
      // Callers must only pass RS-specific types; Res stays nullptr here.
      slangAssert(false && "Not RS object type!");
    }
    // No default case will enable compiler detecting the missing cases
  }

  return Res;
}
1474
1475void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1476  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1477       I != E;
1478       I++) {
1479    clang::Decl *D = *I;
1480    if (D->getKind() == clang::Decl::Var) {
1481      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1482      DataType DT = DataTypeUnknown;
1483      clang::Expr *InitExpr = nullptr;
1484      if (InitializeRSObject(VD, &DT, &InitExpr)) {
1485        // We need to zero-init all RS object types (including matrices), ...
1486        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1487        // ... but, only add to the list of RS objects if we have some
1488        // non-matrix RS object fields.
1489        if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1490          getCurrentScope()->addRSObject(VD);
1491        }
1492      }
1493    }
1494  }
1495}
1496
1497void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1498  if (!CS->body_empty()) {
1499    // Push a new scope
1500    Scope *S = new Scope(CS);
1501    mScopeStack.push(S);
1502
1503    VisitStmt(CS);
1504
1505    // Destroy the scope
1506    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1507    S->InsertLocalVarDestructors();
1508    mScopeStack.pop();
1509    delete S;
1510  }
1511}
1512
1513void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1514  clang::QualType QT = AS->getType();
1515
1516  if (CountRSObjectTypes(QT.getTypePtr())) {
1517    getCurrentScope()->ReplaceRSObjectAssignment(AS);
1518  }
1519}
1520
1521void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1522  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1523       I != E;
1524       I++) {
1525    if (clang::Stmt *Child = *I) {
1526      Visit(Child);
1527    }
1528  }
1529}
1530
1531// This function walks the list of global variables and (potentially) creates
1532// a single global static destructor function that properly decrements
1533// reference counts on the contained RS object types.
1534clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1535  Init();
1536
1537  clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1538  clang::SourceLocation loc;
1539
1540  llvm::StringRef SR(".rs.dtor");
1541  clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1542  clang::DeclarationName N(&II);
1543  clang::FunctionProtoType::ExtProtoInfo EPI;
1544  clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1545      llvm::ArrayRef<clang::QualType>(), EPI);
1546  clang::FunctionDecl *FD = nullptr;
1547
1548  // Generate rsClearObject() call chains for every global variable
1549  // (whether static or extern).
1550  std::list<clang::Stmt *> StmtList;
1551  for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1552          E = DC->decls_end(); I != E; I++) {
1553    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1554    if (VD) {
1555      if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1556        if (!FD) {
1557          // Only create FD if we are going to use it.
1558          FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1559                                           clang::SC_None);
1560        }
1561        // Mark VD as used.  It might be unused, except for the destructor.
1562        // 'markUsed' has side-effects that are caused only if VD is not already
1563        // used.  Hence no need for an extra check here.
1564        VD->markUsed(mCtx);
1565        // Make sure to create any helpers within the function's DeclContext,
1566        // not the one associated with the global translation unit.
1567        clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1568        StmtList.push_back(RSClearObjectCall);
1569      }
1570    }
1571  }
1572
1573  // Nothing needs to be destroyed, so don't emit a dtor.
1574  if (StmtList.empty()) {
1575    return nullptr;
1576  }
1577
1578  clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1579
1580  FD->setBody(CS);
1581
1582  return FD;
1583}
1584
1585bool HasRSObjectType(const clang::Type *T) {
1586  return CountRSObjectTypes(T) != 0;
1587}
1588
1589}  // namespace slang
1590