slang_rs_object_ref_count.cpp revision 6e6578a360497f78a181e63d7783422a9c9bfb15
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang_rs.h"
#include "slang_rs_export_type.h"
30
31namespace slang {
32
33clang::FunctionDecl *RSObjectRefCount::Scope::
34    RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
35                  RSExportPrimitiveType::FirstRSObjectType + 1];
36clang::FunctionDecl *RSObjectRefCount::Scope::
37    RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
38                    RSExportPrimitiveType::FirstRSObjectType + 1];
39
40void RSObjectRefCount::Scope::GetRSRefCountingFunctions(
41    clang::ASTContext &C) {
42  for (unsigned i = 0;
43       i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
44       i++) {
45    RSSetObjectFD[i] = NULL;
46    RSClearObjectFD[i] = NULL;
47  }
48
49  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52          E = TUDecl->decls_end(); I != E; I++) {
53    if ((I->getKind() >= clang::Decl::firstFunction) &&
54        (I->getKind() <= clang::Decl::lastFunction)) {
55      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57      // points to RSSetObjectFD or RSClearObjectFD
58      clang::FunctionDecl **RSObjectFD;
59
60      if (FD->getName() == "rsSetObject") {
61        slangAssert((FD->getNumParams() == 2) &&
62                    "Invalid rsSetObject function prototype (# params)");
63        RSObjectFD = RSSetObjectFD;
64      } else if (FD->getName() == "rsClearObject") {
65        slangAssert((FD->getNumParams() == 1) &&
66                    "Invalid rsClearObject function prototype (# params)");
67        RSObjectFD = RSClearObjectFD;
68      } else {
69        continue;
70      }
71
72      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73      clang::QualType PVT = PVD->getOriginalType();
74      // The first parameter must be a pointer like rs_allocation*
75      slangAssert(PVT->isPointerType() &&
76          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78      // The rs object type passed to the FD
79      clang::QualType RST = PVT->getPointeeType();
80      RSExportPrimitiveType::DataType DT =
81          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
82      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
83             && "must be RS object type");
84
85      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
86    }
87  }
88}
89
90namespace {
91
92static void AppendToCompoundStatement(clang::ASTContext& C,
93                                      clang::CompoundStmt *CS,
94                                      std::list<clang::Stmt*> &StmtList,
95                                      bool InsertAtEndOfBlock) {
96  // Destructor code will be inserted before any return statement.
97  // Any subsequent statements in the compound statement are then placed
98  // after our new code.
99  // TODO(srhines): This should also handle the case of goto/break/continue.
100
101  clang::CompoundStmt::body_iterator bI = CS->body_begin();
102
103  unsigned OldStmtCount = 0;
104  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
105    OldStmtCount++;
106  }
107
108  unsigned NewStmtCount = StmtList.size();
109
110  clang::Stmt **UpdatedStmtList;
111  UpdatedStmtList = new clang::Stmt*[OldStmtCount+NewStmtCount];
112
113  unsigned UpdatedStmtCount = 0;
114  bool FoundReturn = false;
115  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
116    if ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass) {
117      FoundReturn = true;
118      break;
119    }
120    UpdatedStmtList[UpdatedStmtCount++] = *bI;
121  }
122
123  // Always insert before a return that we found, or if we are told
124  // to insert at the end of the block
125  if (FoundReturn || InsertAtEndOfBlock) {
126    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
127    for (std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
128         I != StmtList.end();
129         I++) {
130      UpdatedStmtList[UpdatedStmtCount++] = *I;
131    }
132  }
133
134  // Pick up anything left over after a return statement
135  for ( ; bI != CS->body_end(); bI++) {
136    UpdatedStmtList[UpdatedStmtCount++] = *bI;
137  }
138
139  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
140
141  delete [] UpdatedStmtList;
142
143  return;
144}
145
146static void AppendAfterStmt(clang::ASTContext& C,
147                            clang::CompoundStmt *CS,
148                            clang::Stmt *OldStmt,
149                            clang::Stmt *NewStmt) {
150  slangAssert(CS && OldStmt && NewStmt);
151  clang::CompoundStmt::body_iterator bI = CS->body_begin();
152  unsigned StmtCount = 1;  // Take into account new statement
153  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
154    StmtCount++;
155  }
156
157  clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
158
159  unsigned UpdatedStmtCount = 0;
160  unsigned Once = 0;
161  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
162    UpdatedStmtList[UpdatedStmtCount++] = *bI;
163    if (*bI == OldStmt) {
164      Once++;
165      slangAssert(Once == 1);
166      UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
167    }
168  }
169
170  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
171
172  delete [] UpdatedStmtList;
173
174  return;
175}
176
177static void ReplaceInCompoundStmt(clang::ASTContext& C,
178                                  clang::CompoundStmt *CS,
179                                  clang::Stmt* OldStmt,
180                                  clang::Stmt* NewStmt) {
181  clang::CompoundStmt::body_iterator bI = CS->body_begin();
182
183  unsigned StmtCount = 0;
184  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
185    StmtCount++;
186  }
187
188  clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
189
190  unsigned UpdatedStmtCount = 0;
191  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
192    if (*bI == OldStmt) {
193      UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
194    } else {
195      UpdatedStmtList[UpdatedStmtCount++] = *bI;
196    }
197  }
198
199  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
200
201  delete [] UpdatedStmtList;
202
203  return;
204}
205
206
207// This class visits a compound statement and inserts the StmtList containing
208// destructors in proper locations. This includes inserting them before any
209// return statement in any sub-block, at the end of the logical enclosing
210// scope (compound statement), and/or before any break/continue statement that
211// would resume outside the declared scope. We will not handle the case for
212// goto statements that leave a local scope.
213// TODO(srhines): Make this work properly for break/continue.
214class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
215 private:
216  clang::ASTContext &mC;
217  std::list<clang::Stmt*> &mStmtList;
218  bool mTopLevel;
219 public:
220  DestructorVisitor(clang::ASTContext &C, std::list<clang::Stmt*> &StmtList);
221  void VisitStmt(clang::Stmt *S);
222  void VisitCompoundStmt(clang::CompoundStmt *CS);
223};
224
225DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
226                                     std::list<clang::Stmt*> &StmtList)
227  : mC(C),
228    mStmtList(StmtList),
229    mTopLevel(true) {
230  return;
231}
232
233void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
234  if (!CS->body_empty()) {
235    AppendToCompoundStatement(mC, CS, mStmtList, mTopLevel);
236    mTopLevel = false;
237    VisitStmt(CS);
238  }
239  return;
240}
241
242void DestructorVisitor::VisitStmt(clang::Stmt *S) {
243  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
244       I != E;
245       I++) {
246    if (clang::Stmt *Child = *I) {
247      Visit(Child);
248    }
249  }
250  return;
251}
252
253static int ArrayDim(clang::VarDecl *VD) {
254  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
255
256  if (!T || !T->isArrayType()) {
257    return 0;
258  }
259
260  const clang::ConstantArrayType *CAT =
261    static_cast<const clang::ConstantArrayType *>(T);
262  return static_cast<int>(CAT->getSize().getSExtValue());
263}
264
265static clang::Stmt *ClearArrayRSObject(clang::VarDecl *VD,
266    const clang::Type *T,
267    clang::FunctionDecl *ClearObjectFD) {
268  clang::ASTContext &C = VD->getASTContext();
269  clang::SourceRange Range = VD->getQualifierRange();
270  clang::SourceLocation Loc = Range.getEnd();
271
272  clang::Stmt *StmtArray[2] = {NULL};
273  int StmtCtr = 0;
274
275  int NumArrayElements = ArrayDim(VD);
276  if (NumArrayElements <= 0) {
277    return NULL;
278  }
279
280  // Example destructor loop for "rs_font fontArr[10];"
281  //
282  // (CompoundStmt
283  //   (DeclStmt "int rsIntIter")
284  //   (ForStmt
285  //     (BinaryOperator 'int' '='
286  //       (DeclRefExpr 'int' Var='rsIntIter')
287  //       (IntegerLiteral 'int' 0))
288  //     (BinaryOperator 'int' '<'
289  //       (DeclRefExpr 'int' Var='rsIntIter')
290  //       (IntegerLiteral 'int' 10)
291  //     NULL << CondVar >>
292  //     (UnaryOperator 'int' postfix '++'
293  //       (DeclRefExpr 'int' Var='rsIntIter'))
294  //     (CallExpr 'void'
295  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
296  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
297  //       (UnaryOperator 'rs_font *' prefix '&'
298  //         (ArraySubscriptExpr 'rs_font':'rs_font'
299  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
300  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
301  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
302
303  // Create helper variable for iterating through elements
304  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
305  clang::VarDecl *IIVD =
306      clang::VarDecl::Create(C,
307                             VD->getDeclContext(),
308                             Loc,
309                             &II,
310                             C.IntTy,
311                             C.getTrivialTypeSourceInfo(C.IntTy),
312                             clang::SC_None,
313                             clang::SC_None);
314  clang::Decl *IID = (clang::Decl *)IIVD;
315
316  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
317  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
318
319  // Form the actual destructor loop
320  // for (Init; Cond; Inc)
321  //   RSClearObjectCall;
322
323  // Init -> "rsIntIter = 0"
324  clang::DeclRefExpr *RefrsIntIter =
325      clang::DeclRefExpr::Create(C,
326                                 NULL,
327                                 Range,
328                                 IIVD,
329                                 Loc,
330                                 C.IntTy);
331
332  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
333      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
334
335  clang::BinaryOperator *Init =
336      new(C) clang::BinaryOperator(RefrsIntIter,
337                                   Int0,
338                                   clang::BO_Assign,
339                                   C.IntTy,
340                                   Loc);
341
342  // Cond -> "rsIntIter < NumArrayElements"
343  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
344      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
345
346  clang::BinaryOperator *Cond =
347      new(C) clang::BinaryOperator(RefrsIntIter,
348                                   NumArrayElementsExpr,
349                                   clang::BO_LT,
350                                   C.IntTy,
351                                   Loc);
352
353  // Inc -> "rsIntIter++"
354  clang::UnaryOperator *Inc =
355      new(C) clang::UnaryOperator(RefrsIntIter,
356                                  clang::UO_PostInc,
357                                  C.IntTy,
358                                  Loc);
359
360  // Body -> "rsClearObject(&VD[rsIntIter]);"
361  // Destructor loop operates on individual array elements
362  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
363  clang::QualType ClearObjectFDArgType =
364      ClearObjectFD->getParamDecl(0)->getOriginalType();
365
366  const clang::Type *VT = RSExportType::GetTypeOfDecl(VD);
367  clang::DeclRefExpr *RefRSVar =
368      clang::DeclRefExpr::Create(C,
369                                 NULL,
370                                 Range,
371                                 VD,
372                                 Loc,
373                                 VT->getCanonicalTypeInternal());
374
375  clang::Expr *RefRSVarPtr =
376      clang::ImplicitCastExpr::Create(C,
377          C.getPointerType(T->getCanonicalTypeInternal()),
378          clang::CK_ArrayToPointerDecay,
379          RefRSVar,
380          NULL,
381          clang::VK_RValue);
382
383  clang::Expr *RefRSVarPtrSubscript =
384      new(C) clang::ArraySubscriptExpr(RefRSVarPtr,
385                                       RefrsIntIter,
386                                       T->getCanonicalTypeInternal(),
387                                       VD->getLocation());
388
389  clang::Expr *AddrRefRSVarPtrSubscript =
390      new(C) clang::UnaryOperator(RefRSVarPtrSubscript,
391                                  clang::UO_AddrOf,
392                                  ClearObjectFDArgType,
393                                  VD->getLocation());
394
395  clang::Expr *RefRSClearObjectFD =
396      clang::DeclRefExpr::Create(C,
397                                 NULL,
398                                 Range,
399                                 ClearObjectFD,
400                                 Loc,
401                                 ClearObjectFDType);
402
403  clang::Expr *RSClearObjectFP =
404      clang::ImplicitCastExpr::Create(C,
405                                      C.getPointerType(ClearObjectFDType),
406                                      clang::CK_FunctionToPointerDecay,
407                                      RefRSClearObjectFD,
408                                      NULL,
409                                      clang::VK_RValue);
410
411  clang::CallExpr *RSClearObjectCall =
412      new(C) clang::CallExpr(C,
413                             RSClearObjectFP,
414                             &AddrRefRSVarPtrSubscript,
415                             1,
416                             ClearObjectFD->getCallResultType(),
417                             Loc);
418
419  clang::ForStmt *DestructorLoop =
420      new(C) clang::ForStmt(C,
421                            Init,
422                            Cond,
423                            NULL,  // no condVar
424                            Inc,
425                            RSClearObjectCall,
426                            Loc,
427                            Loc,
428                            Loc);
429
430  StmtArray[StmtCtr++] = DestructorLoop;
431  slangAssert(StmtCtr == 2);
432
433  clang::CompoundStmt *CS =
434      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
435
436  return CS;
437}
438
439}  // namespace
440
441void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
442    clang::BinaryOperator *AS) {
443
444  clang::QualType QT = AS->getType();
445  RSExportPrimitiveType::DataType DT =
446      RSExportPrimitiveType::GetRSSpecificType(QT.getTypePtr());
447
448  clang::FunctionDecl *SetObjectFD =
449      RSSetObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)];
450  slangAssert((SetObjectFD != NULL) &&
451              "rsSetObject doesn't cover all RS object types");
452  clang::ASTContext &C = SetObjectFD->getASTContext();
453
454  clang::QualType SetObjectFDType = SetObjectFD->getType();
455  clang::QualType SetObjectFDArgType[2];
456  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
457  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
458
459  clang::SourceLocation Loc = SetObjectFD->getLocation();
460  clang::Expr *RefRSSetObjectFD =
461      clang::DeclRefExpr::Create(C,
462                                 NULL,
463                                 SetObjectFD->getQualifierRange(),
464                                 SetObjectFD,
465                                 Loc,
466                                 SetObjectFDType);
467
468  clang::Expr *RSSetObjectFP =
469      clang::ImplicitCastExpr::Create(C,
470                                      C.getPointerType(SetObjectFDType),
471                                      clang::CK_FunctionToPointerDecay,
472                                      RefRSSetObjectFD,
473                                      NULL,
474                                      clang::VK_RValue);
475
476  clang::Expr *ArgList[2];
477  ArgList[0] = new(C) clang::UnaryOperator(AS->getLHS(),
478                                           clang::UO_AddrOf,
479                                           SetObjectFDArgType[0],
480                                           Loc);
481  ArgList[1] = AS->getRHS();
482
483  clang::CallExpr *RSSetObjectCall =
484      new(C) clang::CallExpr(C,
485                             RSSetObjectFP,
486                             ArgList,
487                             2,
488                             SetObjectFD->getCallResultType(),
489                             Loc);
490
491  ReplaceInCompoundStmt(C, mCS, AS, RSSetObjectCall);
492
493  return;
494}
495
496void RSObjectRefCount::Scope::AppendRSObjectInit(
497    clang::VarDecl *VD,
498    clang::DeclStmt *DS,
499    RSExportPrimitiveType::DataType DT,
500    clang::Expr *InitExpr) {
501  slangAssert(VD);
502
503  if (!InitExpr) {
504    return;
505  }
506
507  clang::FunctionDecl *SetObjectFD =
508      RSSetObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)];
509  slangAssert((SetObjectFD != NULL) &&
510              "rsSetObject doesn't cover all RS object types");
511  clang::ASTContext &C = SetObjectFD->getASTContext();
512
513  clang::QualType SetObjectFDType = SetObjectFD->getType();
514  clang::QualType SetObjectFDArgType[2];
515  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
516  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
517
518  clang::SourceLocation Loc = SetObjectFD->getLocation();
519  clang::Expr *RefRSSetObjectFD =
520      clang::DeclRefExpr::Create(C,
521                                 NULL,
522                                 SetObjectFD->getQualifierRange(),
523                                 SetObjectFD,
524                                 Loc,
525                                 SetObjectFDType);
526
527  clang::Expr *RSSetObjectFP =
528      clang::ImplicitCastExpr::Create(C,
529                                      C.getPointerType(SetObjectFDType),
530                                      clang::CK_FunctionToPointerDecay,
531                                      RefRSSetObjectFD,
532                                      NULL,
533                                      clang::VK_RValue);
534
535  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
536  clang::DeclRefExpr *RefRSVar =
537      clang::DeclRefExpr::Create(C,
538                                 NULL,
539                                 VD->getQualifierRange(),
540                                 VD,
541                                 Loc,
542                                 T->getCanonicalTypeInternal());
543
544  clang::Expr *ArgList[2];
545  ArgList[0] = new(C) clang::UnaryOperator(RefRSVar,
546                                           clang::UO_AddrOf,
547                                           SetObjectFDArgType[0],
548                                           Loc);
549  ArgList[1] = InitExpr;
550
551  clang::CallExpr *RSSetObjectCall =
552      new(C) clang::CallExpr(C,
553                             RSSetObjectFP,
554                             ArgList,
555                             2,
556                             SetObjectFD->getCallResultType(),
557                             Loc);
558
559  AppendAfterStmt(C, mCS, DS, RSSetObjectCall);
560
561  return;
562}
563
564void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
565  std::list<clang::Stmt*> RSClearObjectCalls;
566  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
567          E = mRSO.end();
568        I != E;
569        I++) {
570    clang::Stmt *S = ClearRSObject(*I);
571    if (S) {
572      RSClearObjectCalls.push_back(S);
573    }
574  }
575  if (RSClearObjectCalls.size() > 0) {
576    DestructorVisitor DV((*mRSO.begin())->getASTContext(), RSClearObjectCalls);
577    DV.Visit(mCS);
578  }
579  return;
580}
581
582clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(clang::VarDecl *VD) {
583  bool IsArrayType = false;
584  clang::ASTContext &C = VD->getASTContext();
585  clang::SourceLocation Loc = VD->getLocation();
586  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
587
588  // Loop through array types to get to base type
589  while (T && T->isArrayType()) {
590    T = T->getArrayElementTypeNoTypeQual();
591    IsArrayType = true;
592  }
593
594  RSExportPrimitiveType::DataType DT =
595      RSExportPrimitiveType::GetRSSpecificType(T);
596
597  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
598              "Should be RS object");
599
600  // Find the rsClearObject() for VD of RS object type DT
601  clang::FunctionDecl *ClearObjectFD =
602      RSClearObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)];
603  slangAssert((ClearObjectFD != NULL) &&
604              "rsClearObject doesn't cover all RS object types");
605
606  if (IsArrayType) {
607    return ClearArrayRSObject(VD, T, ClearObjectFD);
608  }
609
610  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
611  clang::QualType ClearObjectFDArgType =
612      ClearObjectFD->getParamDecl(0)->getOriginalType();
613
614  // Example destructor for "rs_font localFont;"
615  //
616  // (CallExpr 'void'
617  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
618  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
619  //   (UnaryOperator 'rs_font *' prefix '&'
620  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
621
622  // Reference expr to target RS object variable
623  clang::DeclRefExpr *RefRSVar =
624      clang::DeclRefExpr::Create(C,
625                                 NULL,
626                                 VD->getQualifierRange(),
627                                 VD,
628                                 Loc,
629                                 T->getCanonicalTypeInternal());
630
631  // Get address of RSObject in VD
632  clang::Expr *AddrRefRSVar =
633      new(C) clang::UnaryOperator(RefRSVar,
634                                  clang::UO_AddrOf,
635                                  ClearObjectFDArgType,
636                                  Loc);
637
638  clang::Expr *RefRSClearObjectFD =
639      clang::DeclRefExpr::Create(C,
640                                 NULL,
641                                 ClearObjectFD->getQualifierRange(),
642                                 ClearObjectFD,
643                                 ClearObjectFD->getLocation(),
644                                 ClearObjectFDType);
645
646  clang::Expr *RSClearObjectFP =
647      clang::ImplicitCastExpr::Create(C,
648                                      C.getPointerType(ClearObjectFDType),
649                                      clang::CK_FunctionToPointerDecay,
650                                      RefRSClearObjectFD,
651                                      NULL,
652                                      clang::VK_RValue);
653
654  clang::CallExpr *RSClearObjectCall =
655      new(C) clang::CallExpr(C,
656                             RSClearObjectFP,
657                             &AddrRefRSVar,
658                             1,
659                             ClearObjectFD->getCallResultType(),
660                             clang::SourceLocation());
661
662  return RSClearObjectCall;
663}
664
665bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
666                                          RSExportPrimitiveType::DataType *DT,
667                                          clang::Expr **InitExpr) {
668  slangAssert(VD && DT && InitExpr);
669  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
670
671  // Loop through array types to get to base type
672  while (T && T->isArrayType()) {
673    T = T->getArrayElementTypeNoTypeQual();
674  }
675
676  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
677
678  if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
679    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
680      *DT = RSExportPrimitiveType::DataTypeIsStruct;
681    } else {
682      return false;
683    }
684  }
685
686  bool DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
687  *InitExpr = VD->getInit();
688
689  if (!DataTypeIsRSObject && *InitExpr) {
690    // If we already have an initializer for a matrix type, we are done.
691    return DataTypeIsRSObject;
692  }
693
694  clang::Expr *ZeroInitializer =
695      CreateZeroInitializerForRSSpecificType(*DT,
696                                             VD->getASTContext(),
697                                             VD->getLocation());
698
699  if (ZeroInitializer) {
700    ZeroInitializer->setType(T->getCanonicalTypeInternal());
701    VD->setInit(ZeroInitializer);
702  }
703
704  return DataTypeIsRSObject;
705}
706
707clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
708    RSExportPrimitiveType::DataType DT,
709    clang::ASTContext &C,
710    const clang::SourceLocation &Loc) {
711  clang::Expr *Res = NULL;
712  switch (DT) {
713    case RSExportPrimitiveType::DataTypeIsStruct:
714    case RSExportPrimitiveType::DataTypeRSElement:
715    case RSExportPrimitiveType::DataTypeRSType:
716    case RSExportPrimitiveType::DataTypeRSAllocation:
717    case RSExportPrimitiveType::DataTypeRSSampler:
718    case RSExportPrimitiveType::DataTypeRSScript:
719    case RSExportPrimitiveType::DataTypeRSMesh:
720    case RSExportPrimitiveType::DataTypeRSProgramFragment:
721    case RSExportPrimitiveType::DataTypeRSProgramVertex:
722    case RSExportPrimitiveType::DataTypeRSProgramRaster:
723    case RSExportPrimitiveType::DataTypeRSProgramStore:
724    case RSExportPrimitiveType::DataTypeRSFont: {
725      //    (ImplicitCastExpr 'nullptr_t'
726      //      (IntegerLiteral 0)))
727      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
728      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
729      clang::Expr *CastToNull =
730          clang::ImplicitCastExpr::Create(C,
731                                          C.NullPtrTy,
732                                          clang::CK_IntegralToPointer,
733                                          Int0,
734                                          NULL,
735                                          clang::VK_RValue);
736
737      Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
738      break;
739    }
740    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
741    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
742    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
743      // RS matrix is not completely an RS object. They hold data by themselves.
744      // (InitListExpr rs_matrix2x2
745      //   (InitListExpr float[4]
746      //     (FloatingLiteral 0)
747      //     (FloatingLiteral 0)
748      //     (FloatingLiteral 0)
749      //     (FloatingLiteral 0)))
750      clang::QualType FloatTy = C.FloatTy;
751      // Constructor sets value to 0.0f by default
752      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
753      clang::FloatingLiteral *Float0Val =
754          clang::FloatingLiteral::Create(C,
755                                         Val,
756                                         /* isExact = */true,
757                                         FloatTy,
758                                         Loc);
759
760      unsigned N = 0;
761      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
762        N = 2;
763      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
764        N = 3;
765      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
766        N = 4;
767
768      // Directly allocate 16 elements instead of dynamically allocate N*N
769      clang::Expr *InitVals[16];
770      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
771        InitVals[i] = Float0Val;
772      clang::Expr *InitExpr =
773          new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
774      InitExpr->setType(C.getConstantArrayType(FloatTy,
775                                               llvm::APInt(32, 4),
776                                               clang::ArrayType::Normal,
777                                               /* EltTypeQuals = */0));
778
779      Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
780      break;
781    }
782    case RSExportPrimitiveType::DataTypeUnknown:
783    case RSExportPrimitiveType::DataTypeFloat16:
784    case RSExportPrimitiveType::DataTypeFloat32:
785    case RSExportPrimitiveType::DataTypeFloat64:
786    case RSExportPrimitiveType::DataTypeSigned8:
787    case RSExportPrimitiveType::DataTypeSigned16:
788    case RSExportPrimitiveType::DataTypeSigned32:
789    case RSExportPrimitiveType::DataTypeSigned64:
790    case RSExportPrimitiveType::DataTypeUnsigned8:
791    case RSExportPrimitiveType::DataTypeUnsigned16:
792    case RSExportPrimitiveType::DataTypeUnsigned32:
793    case RSExportPrimitiveType::DataTypeUnsigned64:
794    case RSExportPrimitiveType::DataTypeBoolean:
795    case RSExportPrimitiveType::DataTypeUnsigned565:
796    case RSExportPrimitiveType::DataTypeUnsigned5551:
797    case RSExportPrimitiveType::DataTypeUnsigned4444:
798    case RSExportPrimitiveType::DataTypeMax: {
799      slangAssert(false && "Not RS object type!");
800    }
801    // No default case will enable compiler detecting the missing cases
802  }
803
804  return Res;
805}
806
807void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
808  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
809       I != E;
810       I++) {
811    clang::Decl *D = *I;
812    if (D->getKind() == clang::Decl::Var) {
813      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
814      RSExportPrimitiveType::DataType DT =
815          RSExportPrimitiveType::DataTypeUnknown;
816      clang::Expr *InitExpr = NULL;
817      if (InitializeRSObject(VD, &DT, &InitExpr)) {
818        getCurrentScope()->addRSObject(VD);
819        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
820      }
821    }
822  }
823  return;
824}
825
826void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
827  if (!CS->body_empty()) {
828    // Push a new scope
829    Scope *S = new Scope(CS);
830    mScopeStack.push(S);
831
832    VisitStmt(CS);
833
834    // Destroy the scope
835    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
836    S->InsertLocalVarDestructors();
837    mScopeStack.pop();
838    delete S;
839  }
840  return;
841}
842
843void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
844  clang::QualType QT = AS->getType();
845  RSExportPrimitiveType::DataType DT =
846      RSExportPrimitiveType::GetRSSpecificType(QT.getTypePtr());
847
848  if (RSExportPrimitiveType::IsRSObjectType(DT)) {
849    getCurrentScope()->ReplaceRSObjectAssignment(AS);
850  }
851
852  return;
853}
854
855void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
856  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
857       I != E;
858       I++) {
859    if (clang::Stmt *Child = *I) {
860      Visit(Child);
861    }
862  }
863  return;
864}
865
866}  // namespace slang
867