slang_rs_object_ref_count.cpp revision d5f9d6c8b6944dfc30d4fea68479c2fcc250a62c
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_object_ref_count.h"

#include <list>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_rs.h"
#include "slang_rs_export_type.h"
29
30namespace slang {
31
32clang::FunctionDecl *RSObjectRefCount::Scope::
33    RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
34                  RSExportPrimitiveType::FirstRSObjectType + 1];
35clang::FunctionDecl *RSObjectRefCount::Scope::
36    RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
37                    RSExportPrimitiveType::FirstRSObjectType + 1];
38
39void RSObjectRefCount::Scope::GetRSRefCountingFunctions(
40    clang::ASTContext &C) {
41  for (unsigned i = 0;
42       i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
43       i++) {
44    RSSetObjectFD[i] = NULL;
45    RSClearObjectFD[i] = NULL;
46  }
47
48  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
49
50  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
51          E = TUDecl->decls_end(); I != E; I++) {
52    if ((I->getKind() >= clang::Decl::firstFunction) &&
53        (I->getKind() <= clang::Decl::lastFunction)) {
54      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
55
56      // points to RSSetObjectFD or RSClearObjectFD
57      clang::FunctionDecl **RSObjectFD;
58
59      if (FD->getName() == "rsSetObject") {
60        assert((FD->getNumParams() == 2) &&
61               "Invalid rsSetObject function prototype (# params)");
62        RSObjectFD = RSSetObjectFD;
63      } else if (FD->getName() == "rsClearObject") {
64        assert((FD->getNumParams() == 1) &&
65               "Invalid rsClearObject function prototype (# params)");
66        RSObjectFD = RSClearObjectFD;
67      } else {
68        continue;
69      }
70
71      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
72      clang::QualType PVT = PVD->getOriginalType();
73      // The first parameter must be a pointer like rs_allocation*
74      assert(PVT->isPointerType() &&
75             "Invalid rs{Set,Clear}Object function prototype (pointer param)");
76
77      // The rs object type passed to the FD
78      clang::QualType RST = PVT->getPointeeType();
79      RSExportPrimitiveType::DataType DT =
80          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      assert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
85    }
86  }
87}
88
89namespace {
90
91static void AppendToCompoundStatement(clang::ASTContext& C,
92                                      clang::CompoundStmt *CS,
93                                      std::list<clang::Stmt*> &StmtList,
94                                      bool InsertAtEndOfBlock) {
95  // Destructor code will be inserted before any return statement.
96  // Any subsequent statements in the compound statement are then placed
97  // after our new code.
98  // TODO(srhines): This should also handle the case of goto/break/continue.
99
100  clang::CompoundStmt::body_iterator bI = CS->body_begin();
101  clang::CompoundStmt::body_iterator bE = CS->body_end();
102
103  unsigned OldStmtCount = 0;
104  for (bI = CS->body_begin(); bI != bE; bI++) {
105    OldStmtCount++;
106  }
107
108  unsigned NewStmtCount = StmtList.size();
109
110  clang::Stmt **UpdatedStmtList;
111  UpdatedStmtList = new clang::Stmt*[OldStmtCount+NewStmtCount];
112
113  unsigned UpdatedStmtCount = 0;
114  bool FoundReturn = false;
115  for (bI = CS->body_begin(); bI != bE; bI++) {
116    if ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass) {
117      FoundReturn = true;
118      break;
119    }
120    UpdatedStmtList[UpdatedStmtCount++] = *bI;
121  }
122
123  // Always insert before a return that we found, or if we are told
124  // to insert at the end of the block
125  if (FoundReturn || InsertAtEndOfBlock) {
126    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
127    for (std::list<clang::Stmt*>::const_iterator I = StmtList.begin(),
128            E = StmtList.end();
129         I != E;
130         I++) {
131      UpdatedStmtList[UpdatedStmtCount++] = *I;
132    }
133  }
134
135  // Pick up anything left over after a return statement
136  for ( ; bI != bE; bI++) {
137    UpdatedStmtList[UpdatedStmtCount++] = *bI;
138  }
139
140  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
141
142  delete [] UpdatedStmtList;
143
144  return;
145}
146
147// This class visits a compound statement and inserts the StmtList containing
148// destructors in proper locations. This includes inserting them before any
149// return statement in any sub-block, at the end of the logical enclosing
150// scope (compound statement), and/or before any break/continue statement that
151// would resume outside the declared scope. We will not handle the case for
152// goto statements that leave a local scope.
153// TODO(srhines): Make this work properly for break/continue.
154class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
155 private:
156  clang::ASTContext &mC;
157  std::list<clang::Stmt*> &mStmtList;
158  bool mTopLevel;
159 public:
160  DestructorVisitor(clang::ASTContext &C, std::list<clang::Stmt*> &StmtList);
161  void VisitStmt(clang::Stmt *S);
162  void VisitCompoundStmt(clang::CompoundStmt *CS);
163};
164
165DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
166                                     std::list<clang::Stmt*> &StmtList)
167  : mC(C),
168    mStmtList(StmtList),
169    mTopLevel(true) {
170  return;
171}
172
173void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
174  if (!CS->body_empty()) {
175    AppendToCompoundStatement(mC, CS, mStmtList, mTopLevel);
176    mTopLevel = false;
177    VisitStmt(CS);
178  }
179  return;
180}
181
182void DestructorVisitor::VisitStmt(clang::Stmt *S) {
183  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
184       I != E;
185       I++) {
186    if (clang::Stmt *Child = *I) {
187      Visit(Child);
188    }
189  }
190  return;
191}
192
193}  // namespace
194
195void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
196  std::list<clang::Stmt*> RSClearObjectCalls;
197  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
198          E = mRSO.end();
199        I != E;
200        I++) {
201    clang::Stmt *E = ClearRSObject(*I);
202    if (E) {
203      RSClearObjectCalls.push_back(E);
204    }
205  }
206  if (RSClearObjectCalls.size() > 0) {
207    DestructorVisitor DV((*mRSO.begin())->getASTContext(), RSClearObjectCalls);
208    DV.Visit(mCS);
209  }
210  return;
211}
212
213clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(clang::VarDecl *VD) {
214  clang::ASTContext &C = VD->getASTContext();
215  clang::SourceLocation Loc = VD->getLocation();
216  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
217  RSExportPrimitiveType::DataType DT =
218      RSExportPrimitiveType::GetRSSpecificType(T);
219
220  assert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
221      "Should be RS object");
222
223  // Find the rsClearObject() for VD of RS object type DT
224  clang::FunctionDecl *ClearObjectFD =
225      RSClearObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)];
226  assert((ClearObjectFD != NULL) &&
227      "rsClearObject doesn't cover all RS object types");
228
229  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
230  clang::QualType ClearObjectFDArgType =
231      ClearObjectFD->getParamDecl(0)->getOriginalType();
232
233  // We generate a call to rsClearObject passing &VD as the parameter
234  // (CallExpr 'void'
235  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
236  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
237  //   (UnaryOperator 'rs_font *' prefix '&'
238  //     (DeclRefExpr 'rs_font':'rs_font' Var='[var name]')))
239
240  // Reference expr to target RS object variable
241  clang::DeclRefExpr *RefRSVar =
242      clang::DeclRefExpr::Create(C,
243                                 NULL,
244                                 VD->getQualifierRange(),
245                                 VD,
246                                 Loc,
247                                 T->getCanonicalTypeInternal(),
248                                 NULL);
249
250  // Get address of RSObject in VD
251  clang::Expr *AddrRefRSVar =
252      new(C) clang::UnaryOperator(RefRSVar,
253                                  clang::UO_AddrOf,
254                                  ClearObjectFDArgType,
255                                  Loc);
256
257  clang::Expr *RefRSClearObjectFD =
258      clang::DeclRefExpr::Create(C,
259                                 NULL,
260                                 ClearObjectFD->getQualifierRange(),
261                                 ClearObjectFD,
262                                 ClearObjectFD->getLocation(),
263                                 ClearObjectFDType,
264                                 NULL);
265
266  clang::Expr *RSClearObjectFP =
267      clang::ImplicitCastExpr::Create(C,
268                                      C.getPointerType(ClearObjectFDType),
269                                      clang::CK_FunctionToPointerDecay,
270                                      RefRSClearObjectFD,
271                                      NULL,
272                                      clang::VK_RValue);
273
274  clang::CallExpr *RSClearObjectCall =
275      new(C) clang::CallExpr(C,
276                             RSClearObjectFP,
277                             &AddrRefRSVar,
278                             1,
279                             ClearObjectFD->getCallResultType(),
280                             clang::SourceLocation());
281
282  return RSClearObjectCall;
283}
284
285bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD) {
286  bool IsArrayType = false;
287  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
288
289  // Loop through array types to get to base type
290  while (T && T->isArrayType()) {
291    T = T->getArrayElementTypeNoTypeQual();
292    IsArrayType = true;
293  }
294
295  RSExportPrimitiveType::DataType DT =
296      RSExportPrimitiveType::GetRSSpecificType(T);
297
298  if (DT == RSExportPrimitiveType::DataTypeUnknown) {
299    return false;
300  }
301
302  if (VD->hasInit()) {
303    // TODO(srhines): Update the reference count of RS object in initializer.
304    // This can potentially be done as part of the assignment pass.
305  } else {
306    clang::Expr *ZeroInitializer =
307        CreateZeroInitializerForRSSpecificType(DT,
308                                               VD->getASTContext(),
309                                               VD->getLocation());
310
311    if (ZeroInitializer) {
312      ZeroInitializer->setType(T->getCanonicalTypeInternal());
313      VD->setInit(ZeroInitializer);
314    }
315  }
316
317  // TODO(srhines): Skip returning true in the case of array objects because
318  // we don't have looping destructor support yet.
319  return !IsArrayType && RSExportPrimitiveType::IsRSObjectType(DT);
320}
321
322clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
323    RSExportPrimitiveType::DataType DT,
324    clang::ASTContext &C,
325    const clang::SourceLocation &Loc) {
326  clang::Expr *Res = NULL;
327  switch (DT) {
328    case RSExportPrimitiveType::DataTypeRSElement:
329    case RSExportPrimitiveType::DataTypeRSType:
330    case RSExportPrimitiveType::DataTypeRSAllocation:
331    case RSExportPrimitiveType::DataTypeRSSampler:
332    case RSExportPrimitiveType::DataTypeRSScript:
333    case RSExportPrimitiveType::DataTypeRSMesh:
334    case RSExportPrimitiveType::DataTypeRSProgramFragment:
335    case RSExportPrimitiveType::DataTypeRSProgramVertex:
336    case RSExportPrimitiveType::DataTypeRSProgramRaster:
337    case RSExportPrimitiveType::DataTypeRSProgramStore:
338    case RSExportPrimitiveType::DataTypeRSFont: {
339      //    (ImplicitCastExpr 'nullptr_t'
340      //      (IntegerLiteral 0)))
341      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
342      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
343      clang::Expr *CastToNull =
344          clang::ImplicitCastExpr::Create(C,
345                                          C.NullPtrTy,
346                                          clang::CK_IntegralToPointer,
347                                          Int0,
348                                          NULL,
349                                          clang::VK_RValue);
350
351      Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
352      break;
353    }
354    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
355    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
356    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
357      // RS matrix is not completely an RS object. They hold data by themselves.
358      // (InitListExpr rs_matrix2x2
359      //   (InitListExpr float[4]
360      //     (FloatingLiteral 0)
361      //     (FloatingLiteral 0)
362      //     (FloatingLiteral 0)
363      //     (FloatingLiteral 0)))
364      clang::QualType FloatTy = C.FloatTy;
365      // Constructor sets value to 0.0f by default
366      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
367      clang::FloatingLiteral *Float0Val =
368          clang::FloatingLiteral::Create(C,
369                                         Val,
370                                         /* isExact = */true,
371                                         FloatTy,
372                                         Loc);
373
374      unsigned N = 0;
375      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
376        N = 2;
377      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
378        N = 3;
379      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
380        N = 4;
381
382      // Directly allocate 16 elements instead of dynamically allocate N*N
383      clang::Expr *InitVals[16];
384      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
385        InitVals[i] = Float0Val;
386      clang::Expr *InitExpr =
387          new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
388      InitExpr->setType(C.getConstantArrayType(FloatTy,
389                                               llvm::APInt(32, 4),
390                                               clang::ArrayType::Normal,
391                                               /* EltTypeQuals = */0));
392
393      Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
394      break;
395    }
396    case RSExportPrimitiveType::DataTypeUnknown:
397    case RSExportPrimitiveType::DataTypeFloat16:
398    case RSExportPrimitiveType::DataTypeFloat32:
399    case RSExportPrimitiveType::DataTypeFloat64:
400    case RSExportPrimitiveType::DataTypeSigned8:
401    case RSExportPrimitiveType::DataTypeSigned16:
402    case RSExportPrimitiveType::DataTypeSigned32:
403    case RSExportPrimitiveType::DataTypeSigned64:
404    case RSExportPrimitiveType::DataTypeUnsigned8:
405    case RSExportPrimitiveType::DataTypeUnsigned16:
406    case RSExportPrimitiveType::DataTypeUnsigned32:
407    case RSExportPrimitiveType::DataTypeUnsigned64:
408    case RSExportPrimitiveType::DataTypeBoolean:
409    case RSExportPrimitiveType::DataTypeUnsigned565:
410    case RSExportPrimitiveType::DataTypeUnsigned5551:
411    case RSExportPrimitiveType::DataTypeUnsigned4444:
412    case RSExportPrimitiveType::DataTypeMax: {
413      assert(false && "Not RS object type!");
414    }
415    // No default case will enable compiler detecting the missing cases
416  }
417
418  return Res;
419}
420
421void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
422  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
423       I != E;
424       I++) {
425    clang::Decl *D = *I;
426    if (D->getKind() == clang::Decl::Var) {
427      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
428      if (InitializeRSObject(VD))
429        getCurrentScope()->addRSObject(VD);
430    }
431  }
432  return;
433}
434
435void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
436  if (!CS->body_empty()) {
437    // Push a new scope
438    Scope *S = new Scope(CS);
439    mScopeStack.push(S);
440
441    VisitStmt(CS);
442
443    // Destroy the scope
444    // TODO(srhines): Update reference count of the RS object refenced by
445    //                getCurrentScope().
446    assert((getCurrentScope() == S) && "Corrupted scope stack!");
447    S->InsertLocalVarDestructors();
448    mScopeStack.pop();
449    delete S;
450  }
451  return;
452}
453
454void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
455  // TODO(srhines): Update reference count
456  return;
457}
458
459void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
460  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
461       I != E;
462       I++) {
463    if (clang::Stmt *Child = *I) {
464      Visit(Child);
465    }
466  }
467  return;
468}
469
470}  // namespace slang
471