slang_rs_object_ref_count.cpp revision feaca06fcb0772e9e972a0d61b17259fc5124d50
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_object_ref_count.h"
18
19#include <list>
20
21#include "clang/AST/DeclGroup.h"
22#include "clang/AST/Expr.h"
23#include "clang/AST/OperationKinds.h"
24#include "clang/AST/Stmt.h"
25#include "clang/AST/StmtVisitor.h"
26
27#include "slang_rs.h"
28#include "slang_rs_export_type.h"
29
30namespace slang {
31
32clang::FunctionDecl *RSObjectRefCount::Scope::
33    RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
34                  RSExportPrimitiveType::FirstRSObjectType + 1];
35clang::FunctionDecl *RSObjectRefCount::Scope::
36    RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
37                    RSExportPrimitiveType::FirstRSObjectType + 1];
38
39void RSObjectRefCount::Scope::GetRSRefCountingFunctions(
40    clang::ASTContext &C) {
41  for (unsigned i = 0;
42       i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
43       i++) {
44    RSSetObjectFD[i] = NULL;
45    RSClearObjectFD[i] = NULL;
46  }
47
48  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
49
50  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
51          E = TUDecl->decls_end(); I != E; I++) {
52    if ((I->getKind() >= clang::Decl::firstFunction) &&
53        (I->getKind() <= clang::Decl::lastFunction)) {
54      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
55
56      // points to RSSetObjectFD or RSClearObjectFD
57      clang::FunctionDecl **RSObjectFD;
58
59      if (FD->getName() == "rsSetObject") {
60        assert((FD->getNumParams() == 2) &&
61               "Invalid rsSetObject function prototype (# params)");
62        RSObjectFD = RSSetObjectFD;
63      } else if (FD->getName() == "rsClearObject") {
64        assert((FD->getNumParams() == 1) &&
65               "Invalid rsClearObject function prototype (# params)");
66        RSObjectFD = RSClearObjectFD;
67      } else {
68        continue;
69      }
70
71      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
72      clang::QualType PVT = PVD->getOriginalType();
73      // The first parameter must be a pointer like rs_allocation*
74      assert(PVT->isPointerType() &&
75             "Invalid rs{Set,Clear}Object function prototype (pointer param)");
76
77      // The rs object type passed to the FD
78      clang::QualType RST = PVT->getPointeeType();
79      RSExportPrimitiveType::DataType DT =
80          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      assert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
85    }
86  }
87}
88
89namespace {
90
91static void AppendToCompoundStatement(clang::ASTContext& C,
92                                      clang::CompoundStmt *CS,
93                                      std::list<clang::Stmt*> &StmtList,
94                                      bool InsertAtEndOfBlock) {
95  // Destructor code will be inserted before any return statement.
96  // Any subsequent statements in the compound statement are then placed
97  // after our new code.
98  // TODO(srhines): This should also handle the case of goto/break/continue.
99
100  clang::CompoundStmt::body_iterator bI = CS->body_begin();
101
102  unsigned OldStmtCount = 0;
103  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
104    OldStmtCount++;
105  }
106
107  unsigned NewStmtCount = StmtList.size();
108
109  clang::Stmt **UpdatedStmtList;
110  UpdatedStmtList = new clang::Stmt*[OldStmtCount+NewStmtCount];
111
112  unsigned UpdatedStmtCount = 0;
113  bool FoundReturn = false;
114  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
115    if ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass) {
116      FoundReturn = true;
117      break;
118    }
119    UpdatedStmtList[UpdatedStmtCount++] = *bI;
120  }
121
122  // Always insert before a return that we found, or if we are told
123  // to insert at the end of the block
124  if (FoundReturn || InsertAtEndOfBlock) {
125    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
126    for (std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
127         I != StmtList.end();
128         I++) {
129      UpdatedStmtList[UpdatedStmtCount++] = *I;
130    }
131  }
132
133  // Pick up anything left over after a return statement
134  for ( ; bI != CS->body_end(); bI++) {
135    UpdatedStmtList[UpdatedStmtCount++] = *bI;
136  }
137
138  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
139
140  delete [] UpdatedStmtList;
141
142  return;
143}
144
145static void AppendAfterStmt(clang::ASTContext& C,
146                            clang::CompoundStmt *CS,
147                            clang::Stmt *OldStmt,
148                            clang::Stmt *NewStmt) {
149  assert(CS && OldStmt && NewStmt);
150  clang::CompoundStmt::body_iterator bI = CS->body_begin();
151  unsigned StmtCount = 1;  // Take into account new statement
152  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
153    StmtCount++;
154  }
155
156  clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
157
158  unsigned UpdatedStmtCount = 0;
159  unsigned Once = 0;
160  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
161    UpdatedStmtList[UpdatedStmtCount++] = *bI;
162    if (*bI == OldStmt) {
163      Once++;
164      assert(Once == 1);
165      UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
166    }
167  }
168
169  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
170
171  delete [] UpdatedStmtList;
172
173  return;
174}
175
176static void ReplaceInCompoundStmt(clang::ASTContext& C,
177                                  clang::CompoundStmt *CS,
178                                  clang::Stmt* OldStmt,
179                                  clang::Stmt* NewStmt) {
180  clang::CompoundStmt::body_iterator bI = CS->body_begin();
181
182  unsigned StmtCount = 0;
183  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
184    StmtCount++;
185  }
186
187  clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
188
189  unsigned UpdatedStmtCount = 0;
190  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
191    if (*bI == OldStmt) {
192      UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
193    } else {
194      UpdatedStmtList[UpdatedStmtCount++] = *bI;
195    }
196  }
197
198  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
199
200  delete [] UpdatedStmtList;
201
202  return;
203}
204
205
206// This class visits a compound statement and inserts the StmtList containing
207// destructors in proper locations. This includes inserting them before any
208// return statement in any sub-block, at the end of the logical enclosing
209// scope (compound statement), and/or before any break/continue statement that
210// would resume outside the declared scope. We will not handle the case for
211// goto statements that leave a local scope.
212// TODO(srhines): Make this work properly for break/continue.
213class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
214 private:
215  clang::ASTContext &mC;
216  std::list<clang::Stmt*> &mStmtList;
217  bool mTopLevel;
218 public:
219  DestructorVisitor(clang::ASTContext &C, std::list<clang::Stmt*> &StmtList);
220  void VisitStmt(clang::Stmt *S);
221  void VisitCompoundStmt(clang::CompoundStmt *CS);
222};
223
224DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
225                                     std::list<clang::Stmt*> &StmtList)
226  : mC(C),
227    mStmtList(StmtList),
228    mTopLevel(true) {
229  return;
230}
231
232void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
233  if (!CS->body_empty()) {
234    AppendToCompoundStatement(mC, CS, mStmtList, mTopLevel);
235    mTopLevel = false;
236    VisitStmt(CS);
237  }
238  return;
239}
240
241void DestructorVisitor::VisitStmt(clang::Stmt *S) {
242  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
243       I != E;
244       I++) {
245    if (clang::Stmt *Child = *I) {
246      Visit(Child);
247    }
248  }
249  return;
250}
251
252static int ArrayDim(clang::VarDecl *VD) {
253  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
254
255  if (!T || !T->isArrayType()) {
256    return 0;
257  }
258
259  const clang::ConstantArrayType *CAT =
260    static_cast<const clang::ConstantArrayType *>(T);
261  return static_cast<int>(CAT->getSize().getSExtValue());
262}
263
264static clang::Stmt *ClearArrayRSObject(clang::VarDecl *VD,
265    const clang::Type *T,
266    clang::FunctionDecl *ClearObjectFD) {
267  clang::ASTContext &C = VD->getASTContext();
268  clang::SourceRange Range = VD->getQualifierRange();
269  clang::SourceLocation Loc = Range.getEnd();
270
271  clang::Stmt *StmtArray[2] = {NULL};
272  int StmtCtr = 0;
273
274  int NumArrayElements = ArrayDim(VD);
275  if (NumArrayElements <= 0) {
276    return NULL;
277  }
278
279  // Example destructor loop for "rs_font fontArr[10];"
280  //
281  // (CompoundStmt
282  //   (DeclStmt "int rsIntIter")
283  //   (ForStmt
284  //     (BinaryOperator 'int' '='
285  //       (DeclRefExpr 'int' Var='rsIntIter')
286  //       (IntegerLiteral 'int' 0))
287  //     (BinaryOperator 'int' '<'
288  //       (DeclRefExpr 'int' Var='rsIntIter')
289  //       (IntegerLiteral 'int' 10)
290  //     NULL << CondVar >>
291  //     (UnaryOperator 'int' postfix '++'
292  //       (DeclRefExpr 'int' Var='rsIntIter'))
293  //     (CallExpr 'void'
294  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
295  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
296  //       (UnaryOperator 'rs_font *' prefix '&'
297  //         (ArraySubscriptExpr 'rs_font':'rs_font'
298  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
299  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
300  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
301
302  // Create helper variable for iterating through elements
303  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
304  clang::VarDecl *IIVD =
305      clang::VarDecl::Create(C,
306                             VD->getDeclContext(),
307                             Loc,
308                             &II,
309                             C.IntTy,
310                             C.getTrivialTypeSourceInfo(C.IntTy),
311                             clang::SC_None,
312                             clang::SC_None);
313  clang::Decl *IID = (clang::Decl *)IIVD;
314
315  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
316  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
317
318  // Form the actual destructor loop
319  // for (Init; Cond; Inc)
320  //   RSClearObjectCall;
321
322  // Init -> "rsIntIter = 0"
323  clang::DeclRefExpr *RefrsIntIter =
324      clang::DeclRefExpr::Create(C,
325                                 NULL,
326                                 Range,
327                                 IIVD,
328                                 Loc,
329                                 C.IntTy);
330
331  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
332      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
333
334  clang::BinaryOperator *Init =
335      new(C) clang::BinaryOperator(RefrsIntIter,
336                                   Int0,
337                                   clang::BO_Assign,
338                                   C.IntTy,
339                                   Loc);
340
341  // Cond -> "rsIntIter < NumArrayElements"
342  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
343      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
344
345  clang::BinaryOperator *Cond =
346      new(C) clang::BinaryOperator(RefrsIntIter,
347                                   NumArrayElementsExpr,
348                                   clang::BO_LT,
349                                   C.IntTy,
350                                   Loc);
351
352  // Inc -> "rsIntIter++"
353  clang::UnaryOperator *Inc =
354      new(C) clang::UnaryOperator(RefrsIntIter,
355                                  clang::UO_PostInc,
356                                  C.IntTy,
357                                  Loc);
358
359  // Body -> "rsClearObject(&VD[rsIntIter]);"
360  // Destructor loop operates on individual array elements
361  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
362  clang::QualType ClearObjectFDArgType =
363      ClearObjectFD->getParamDecl(0)->getOriginalType();
364
365  const clang::Type *VT = RSExportType::GetTypeOfDecl(VD);
366  clang::DeclRefExpr *RefRSVar =
367      clang::DeclRefExpr::Create(C,
368                                 NULL,
369                                 Range,
370                                 VD,
371                                 Loc,
372                                 VT->getCanonicalTypeInternal());
373
374  clang::Expr *RefRSVarPtr =
375      clang::ImplicitCastExpr::Create(C,
376          C.getPointerType(T->getCanonicalTypeInternal()),
377          clang::CK_ArrayToPointerDecay,
378          RefRSVar,
379          NULL,
380          clang::VK_RValue);
381
382  clang::Expr *RefRSVarPtrSubscript =
383      new(C) clang::ArraySubscriptExpr(RefRSVarPtr,
384                                       RefrsIntIter,
385                                       T->getCanonicalTypeInternal(),
386                                       VD->getLocation());
387
388  clang::Expr *AddrRefRSVarPtrSubscript =
389      new(C) clang::UnaryOperator(RefRSVarPtrSubscript,
390                                  clang::UO_AddrOf,
391                                  ClearObjectFDArgType,
392                                  VD->getLocation());
393
394  clang::Expr *RefRSClearObjectFD =
395      clang::DeclRefExpr::Create(C,
396                                 NULL,
397                                 Range,
398                                 ClearObjectFD,
399                                 Loc,
400                                 ClearObjectFDType);
401
402  clang::Expr *RSClearObjectFP =
403      clang::ImplicitCastExpr::Create(C,
404                                      C.getPointerType(ClearObjectFDType),
405                                      clang::CK_FunctionToPointerDecay,
406                                      RefRSClearObjectFD,
407                                      NULL,
408                                      clang::VK_RValue);
409
410  clang::CallExpr *RSClearObjectCall =
411      new(C) clang::CallExpr(C,
412                             RSClearObjectFP,
413                             &AddrRefRSVarPtrSubscript,
414                             1,
415                             ClearObjectFD->getCallResultType(),
416                             Loc);
417
418  clang::ForStmt *DestructorLoop =
419      new(C) clang::ForStmt(C,
420                            Init,
421                            Cond,
422                            NULL,  // no condVar
423                            Inc,
424                            RSClearObjectCall,
425                            Loc,
426                            Loc,
427                            Loc);
428
429  StmtArray[StmtCtr++] = DestructorLoop;
430  assert(StmtCtr == 2);
431
432  clang::CompoundStmt *CS =
433      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
434
435  return CS;
436}
437
438}  // namespace
439
440void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
441    clang::BinaryOperator *AS) {
442
443  clang::QualType QT = AS->getType();
444  RSExportPrimitiveType::DataType DT =
445      RSExportPrimitiveType::GetRSSpecificType(QT.getTypePtr());
446
447  clang::FunctionDecl *SetObjectFD =
448      RSSetObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)];
449  assert((SetObjectFD != NULL) &&
450      "rsSetObject doesn't cover all RS object types");
451  clang::ASTContext &C = SetObjectFD->getASTContext();
452
453  clang::QualType SetObjectFDType = SetObjectFD->getType();
454  clang::QualType SetObjectFDArgType[2];
455  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
456  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
457
458  clang::SourceLocation Loc = SetObjectFD->getLocation();
459  clang::Expr *RefRSSetObjectFD =
460      clang::DeclRefExpr::Create(C,
461                                 NULL,
462                                 SetObjectFD->getQualifierRange(),
463                                 SetObjectFD,
464                                 Loc,
465                                 SetObjectFDType);
466
467  clang::Expr *RSSetObjectFP =
468      clang::ImplicitCastExpr::Create(C,
469                                      C.getPointerType(SetObjectFDType),
470                                      clang::CK_FunctionToPointerDecay,
471                                      RefRSSetObjectFD,
472                                      NULL,
473                                      clang::VK_RValue);
474
475  clang::Expr *ArgList[2];
476  ArgList[0] = new(C) clang::UnaryOperator(AS->getLHS(),
477                                           clang::UO_AddrOf,
478                                           SetObjectFDArgType[0],
479                                           Loc);
480  ArgList[1] = AS->getRHS();
481
482  clang::CallExpr *RSSetObjectCall =
483      new(C) clang::CallExpr(C,
484                             RSSetObjectFP,
485                             ArgList,
486                             2,
487                             SetObjectFD->getCallResultType(),
488                             Loc);
489
490  ReplaceInCompoundStmt(C, mCS, AS, RSSetObjectCall);
491
492  return;
493}
494
495void RSObjectRefCount::Scope::AppendRSObjectInit(
496    clang::VarDecl *VD,
497    clang::DeclStmt *DS,
498    RSExportPrimitiveType::DataType DT,
499    clang::Expr *InitExpr) {
500  assert(VD);
501
502  if (!InitExpr) {
503    return;
504  }
505
506  clang::FunctionDecl *SetObjectFD =
507      RSSetObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)];
508  assert((SetObjectFD != NULL) &&
509      "rsSetObject doesn't cover all RS object types");
510  clang::ASTContext &C = SetObjectFD->getASTContext();
511
512  clang::QualType SetObjectFDType = SetObjectFD->getType();
513  clang::QualType SetObjectFDArgType[2];
514  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
515  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
516
517  clang::SourceLocation Loc = SetObjectFD->getLocation();
518  clang::Expr *RefRSSetObjectFD =
519      clang::DeclRefExpr::Create(C,
520                                 NULL,
521                                 SetObjectFD->getQualifierRange(),
522                                 SetObjectFD,
523                                 Loc,
524                                 SetObjectFDType);
525
526  clang::Expr *RSSetObjectFP =
527      clang::ImplicitCastExpr::Create(C,
528                                      C.getPointerType(SetObjectFDType),
529                                      clang::CK_FunctionToPointerDecay,
530                                      RefRSSetObjectFD,
531                                      NULL,
532                                      clang::VK_RValue);
533
534  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
535  clang::DeclRefExpr *RefRSVar =
536      clang::DeclRefExpr::Create(C,
537                                 NULL,
538                                 VD->getQualifierRange(),
539                                 VD,
540                                 Loc,
541                                 T->getCanonicalTypeInternal());
542
543  clang::Expr *ArgList[2];
544  ArgList[0] = new(C) clang::UnaryOperator(RefRSVar,
545                                           clang::UO_AddrOf,
546                                           SetObjectFDArgType[0],
547                                           Loc);
548  ArgList[1] = InitExpr;
549
550  clang::CallExpr *RSSetObjectCall =
551      new(C) clang::CallExpr(C,
552                             RSSetObjectFP,
553                             ArgList,
554                             2,
555                             SetObjectFD->getCallResultType(),
556                             Loc);
557
558  AppendAfterStmt(C, mCS, DS, RSSetObjectCall);
559
560  return;
561}
562
563void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
564  std::list<clang::Stmt*> RSClearObjectCalls;
565  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
566          E = mRSO.end();
567        I != E;
568        I++) {
569    clang::Stmt *S = ClearRSObject(*I);
570    if (S) {
571      RSClearObjectCalls.push_back(S);
572    }
573  }
574  if (RSClearObjectCalls.size() > 0) {
575    DestructorVisitor DV((*mRSO.begin())->getASTContext(), RSClearObjectCalls);
576    DV.Visit(mCS);
577  }
578  return;
579}
580
581clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(clang::VarDecl *VD) {
582  bool IsArrayType = false;
583  clang::ASTContext &C = VD->getASTContext();
584  clang::SourceLocation Loc = VD->getLocation();
585  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
586
587  // Loop through array types to get to base type
588  while (T && T->isArrayType()) {
589    T = T->getArrayElementTypeNoTypeQual();
590    IsArrayType = true;
591  }
592
593  RSExportPrimitiveType::DataType DT =
594      RSExportPrimitiveType::GetRSSpecificType(T);
595
596  assert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
597      "Should be RS object");
598
599  // Find the rsClearObject() for VD of RS object type DT
600  clang::FunctionDecl *ClearObjectFD =
601      RSClearObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)];
602  assert((ClearObjectFD != NULL) &&
603      "rsClearObject doesn't cover all RS object types");
604
605  if (IsArrayType) {
606    return ClearArrayRSObject(VD, T, ClearObjectFD);
607  }
608
609  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
610  clang::QualType ClearObjectFDArgType =
611      ClearObjectFD->getParamDecl(0)->getOriginalType();
612
613  // Example destructor for "rs_font localFont;"
614  //
615  // (CallExpr 'void'
616  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
617  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
618  //   (UnaryOperator 'rs_font *' prefix '&'
619  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
620
621  // Reference expr to target RS object variable
622  clang::DeclRefExpr *RefRSVar =
623      clang::DeclRefExpr::Create(C,
624                                 NULL,
625                                 VD->getQualifierRange(),
626                                 VD,
627                                 Loc,
628                                 T->getCanonicalTypeInternal());
629
630  // Get address of RSObject in VD
631  clang::Expr *AddrRefRSVar =
632      new(C) clang::UnaryOperator(RefRSVar,
633                                  clang::UO_AddrOf,
634                                  ClearObjectFDArgType,
635                                  Loc);
636
637  clang::Expr *RefRSClearObjectFD =
638      clang::DeclRefExpr::Create(C,
639                                 NULL,
640                                 ClearObjectFD->getQualifierRange(),
641                                 ClearObjectFD,
642                                 ClearObjectFD->getLocation(),
643                                 ClearObjectFDType);
644
645  clang::Expr *RSClearObjectFP =
646      clang::ImplicitCastExpr::Create(C,
647                                      C.getPointerType(ClearObjectFDType),
648                                      clang::CK_FunctionToPointerDecay,
649                                      RefRSClearObjectFD,
650                                      NULL,
651                                      clang::VK_RValue);
652
653  clang::CallExpr *RSClearObjectCall =
654      new(C) clang::CallExpr(C,
655                             RSClearObjectFP,
656                             &AddrRefRSVar,
657                             1,
658                             ClearObjectFD->getCallResultType(),
659                             clang::SourceLocation());
660
661  return RSClearObjectCall;
662}
663
664bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
665                                          RSExportPrimitiveType::DataType *DT,
666                                          clang::Expr **InitExpr) {
667  assert(VD && DT && InitExpr);
668  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
669
670  // Loop through array types to get to base type
671  while (T && T->isArrayType()) {
672    T = T->getArrayElementTypeNoTypeQual();
673  }
674
675  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
676
677  if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
678    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
679      *DT = RSExportPrimitiveType::DataTypeIsStruct;
680    } else {
681      return false;
682    }
683  }
684
685  bool DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
686  *InitExpr = VD->getInit();
687
688  if (!DataTypeIsRSObject && *InitExpr) {
689    // If we already have an initializer for a matrix type, we are done.
690    return DataTypeIsRSObject;
691  }
692
693  clang::Expr *ZeroInitializer =
694      CreateZeroInitializerForRSSpecificType(*DT,
695                                             VD->getASTContext(),
696                                             VD->getLocation());
697
698  if (ZeroInitializer) {
699    ZeroInitializer->setType(T->getCanonicalTypeInternal());
700    VD->setInit(ZeroInitializer);
701  }
702
703  return DataTypeIsRSObject;
704}
705
706clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
707    RSExportPrimitiveType::DataType DT,
708    clang::ASTContext &C,
709    const clang::SourceLocation &Loc) {
710  clang::Expr *Res = NULL;
711  switch (DT) {
712    case RSExportPrimitiveType::DataTypeIsStruct:
713    case RSExportPrimitiveType::DataTypeRSElement:
714    case RSExportPrimitiveType::DataTypeRSType:
715    case RSExportPrimitiveType::DataTypeRSAllocation:
716    case RSExportPrimitiveType::DataTypeRSSampler:
717    case RSExportPrimitiveType::DataTypeRSScript:
718    case RSExportPrimitiveType::DataTypeRSMesh:
719    case RSExportPrimitiveType::DataTypeRSProgramFragment:
720    case RSExportPrimitiveType::DataTypeRSProgramVertex:
721    case RSExportPrimitiveType::DataTypeRSProgramRaster:
722    case RSExportPrimitiveType::DataTypeRSProgramStore:
723    case RSExportPrimitiveType::DataTypeRSFont: {
724      //    (ImplicitCastExpr 'nullptr_t'
725      //      (IntegerLiteral 0)))
726      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
727      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
728      clang::Expr *CastToNull =
729          clang::ImplicitCastExpr::Create(C,
730                                          C.NullPtrTy,
731                                          clang::CK_IntegralToPointer,
732                                          Int0,
733                                          NULL,
734                                          clang::VK_RValue);
735
736      Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
737      break;
738    }
739    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
740    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
741    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
742      // RS matrix is not completely an RS object. They hold data by themselves.
743      // (InitListExpr rs_matrix2x2
744      //   (InitListExpr float[4]
745      //     (FloatingLiteral 0)
746      //     (FloatingLiteral 0)
747      //     (FloatingLiteral 0)
748      //     (FloatingLiteral 0)))
749      clang::QualType FloatTy = C.FloatTy;
750      // Constructor sets value to 0.0f by default
751      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
752      clang::FloatingLiteral *Float0Val =
753          clang::FloatingLiteral::Create(C,
754                                         Val,
755                                         /* isExact = */true,
756                                         FloatTy,
757                                         Loc);
758
759      unsigned N = 0;
760      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
761        N = 2;
762      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
763        N = 3;
764      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
765        N = 4;
766
767      // Directly allocate 16 elements instead of dynamically allocate N*N
768      clang::Expr *InitVals[16];
769      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
770        InitVals[i] = Float0Val;
771      clang::Expr *InitExpr =
772          new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
773      InitExpr->setType(C.getConstantArrayType(FloatTy,
774                                               llvm::APInt(32, 4),
775                                               clang::ArrayType::Normal,
776                                               /* EltTypeQuals = */0));
777
778      Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
779      break;
780    }
781    case RSExportPrimitiveType::DataTypeUnknown:
782    case RSExportPrimitiveType::DataTypeFloat16:
783    case RSExportPrimitiveType::DataTypeFloat32:
784    case RSExportPrimitiveType::DataTypeFloat64:
785    case RSExportPrimitiveType::DataTypeSigned8:
786    case RSExportPrimitiveType::DataTypeSigned16:
787    case RSExportPrimitiveType::DataTypeSigned32:
788    case RSExportPrimitiveType::DataTypeSigned64:
789    case RSExportPrimitiveType::DataTypeUnsigned8:
790    case RSExportPrimitiveType::DataTypeUnsigned16:
791    case RSExportPrimitiveType::DataTypeUnsigned32:
792    case RSExportPrimitiveType::DataTypeUnsigned64:
793    case RSExportPrimitiveType::DataTypeBoolean:
794    case RSExportPrimitiveType::DataTypeUnsigned565:
795    case RSExportPrimitiveType::DataTypeUnsigned5551:
796    case RSExportPrimitiveType::DataTypeUnsigned4444:
797    case RSExportPrimitiveType::DataTypeMax: {
798      assert(false && "Not RS object type!");
799    }
800    // No default case will enable compiler detecting the missing cases
801  }
802
803  return Res;
804}
805
806void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
807  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
808       I != E;
809       I++) {
810    clang::Decl *D = *I;
811    if (D->getKind() == clang::Decl::Var) {
812      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
813      RSExportPrimitiveType::DataType DT =
814          RSExportPrimitiveType::DataTypeUnknown;
815      clang::Expr *InitExpr = NULL;
816      if (InitializeRSObject(VD, &DT, &InitExpr)) {
817        getCurrentScope()->addRSObject(VD);
818        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
819      }
820    }
821  }
822  return;
823}
824
825void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
826  if (!CS->body_empty()) {
827    // Push a new scope
828    Scope *S = new Scope(CS);
829    mScopeStack.push(S);
830
831    VisitStmt(CS);
832
833    // Destroy the scope
834    assert((getCurrentScope() == S) && "Corrupted scope stack!");
835    S->InsertLocalVarDestructors();
836    mScopeStack.pop();
837    delete S;
838  }
839  return;
840}
841
842void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
843  clang::QualType QT = AS->getType();
844  RSExportPrimitiveType::DataType DT =
845      RSExportPrimitiveType::GetRSSpecificType(QT.getTypePtr());
846
847  if (RSExportPrimitiveType::IsRSObjectType(DT)) {
848    getCurrentScope()->ReplaceRSObjectAssignment(AS);
849  }
850
851  return;
852}
853
854void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
855  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
856       I != E;
857       I++) {
858    if (clang::Stmt *Child = *I) {
859      Visit(Child);
860    }
861  }
862  return;
863}
864
865}  // namespace slang
866