slang_rs_object_ref_count.cpp revision f2174cfd6a556b51aadf2b8765e50df080e8f18e
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_object_ref_count.h"
18
19#include <list>
20
21#include "clang/AST/DeclGroup.h"
22#include "clang/AST/Expr.h"
23#include "clang/AST/OperationKinds.h"
24#include "clang/AST/Stmt.h"
25#include "clang/AST/StmtVisitor.h"
26
27#include "slang_assert.h"
28#include "slang_rs.h"
29#include "slang_rs_export_type.h"
30
31namespace slang {
32
33clang::FunctionDecl *RSObjectRefCount::
34    RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
35                  RSExportPrimitiveType::FirstRSObjectType + 1];
36clang::FunctionDecl *RSObjectRefCount::
37    RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
38                    RSExportPrimitiveType::FirstRSObjectType + 1];
39
40void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
41  for (unsigned i = 0;
42       i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
43       i++) {
44    RSSetObjectFD[i] = NULL;
45    RSClearObjectFD[i] = NULL;
46  }
47
48  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
49
50  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
51          E = TUDecl->decls_end(); I != E; I++) {
52    if ((I->getKind() >= clang::Decl::firstFunction) &&
53        (I->getKind() <= clang::Decl::lastFunction)) {
54      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
55
56      // points to RSSetObjectFD or RSClearObjectFD
57      clang::FunctionDecl **RSObjectFD;
58
59      if (FD->getName() == "rsSetObject") {
60        slangAssert((FD->getNumParams() == 2) &&
61                    "Invalid rsSetObject function prototype (# params)");
62        RSObjectFD = RSSetObjectFD;
63      } else if (FD->getName() == "rsClearObject") {
64        slangAssert((FD->getNumParams() == 1) &&
65                    "Invalid rsClearObject function prototype (# params)");
66        RSObjectFD = RSClearObjectFD;
67      } else {
68        continue;
69      }
70
71      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
72      clang::QualType PVT = PVD->getOriginalType();
73      // The first parameter must be a pointer like rs_allocation*
74      slangAssert(PVT->isPointerType() &&
75          "Invalid rs{Set,Clear}Object function prototype (pointer param)");
76
77      // The rs object type passed to the FD
78      clang::QualType RST = PVT->getPointeeType();
79      RSExportPrimitiveType::DataType DT =
80          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81      slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82             && "must be RS object type");
83
84      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
85    }
86  }
87}
88
89namespace {
90
91static void AppendToCompoundStatement(clang::ASTContext& C,
92                                      clang::CompoundStmt *CS,
93                                      std::list<clang::Stmt*> &StmtList,
94                                      bool InsertAtEndOfBlock) {
95  // Destructor code will be inserted before any return statement.
96  // Any subsequent statements in the compound statement are then placed
97  // after our new code.
98  // TODO(srhines): This should also handle the case of goto/break/continue.
99
100  clang::CompoundStmt::body_iterator bI = CS->body_begin();
101
102  unsigned OldStmtCount = 0;
103  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
104    OldStmtCount++;
105  }
106
107  unsigned NewStmtCount = StmtList.size();
108
109  clang::Stmt **UpdatedStmtList;
110  UpdatedStmtList = new clang::Stmt*[OldStmtCount+NewStmtCount];
111
112  unsigned UpdatedStmtCount = 0;
113  bool FoundReturn = false;
114  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
115    if ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass) {
116      FoundReturn = true;
117      break;
118    }
119    UpdatedStmtList[UpdatedStmtCount++] = *bI;
120  }
121
122  // Always insert before a return that we found, or if we are told
123  // to insert at the end of the block
124  if (FoundReturn || InsertAtEndOfBlock) {
125    std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
126    for (std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
127         I != StmtList.end();
128         I++) {
129      UpdatedStmtList[UpdatedStmtCount++] = *I;
130    }
131  }
132
133  // Pick up anything left over after a return statement
134  for ( ; bI != CS->body_end(); bI++) {
135    UpdatedStmtList[UpdatedStmtCount++] = *bI;
136  }
137
138  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
139
140  delete [] UpdatedStmtList;
141
142  return;
143}
144
145static void AppendAfterStmt(clang::ASTContext& C,
146                            clang::CompoundStmt *CS,
147                            clang::Stmt *OldStmt,
148                            clang::Stmt *NewStmt) {
149  slangAssert(CS && OldStmt && NewStmt);
150  clang::CompoundStmt::body_iterator bI = CS->body_begin();
151  unsigned StmtCount = 1;  // Take into account new statement
152  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
153    StmtCount++;
154  }
155
156  clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
157
158  unsigned UpdatedStmtCount = 0;
159  unsigned Once = 0;
160  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
161    UpdatedStmtList[UpdatedStmtCount++] = *bI;
162    if (*bI == OldStmt) {
163      Once++;
164      slangAssert(Once == 1);
165      UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
166    }
167  }
168
169  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
170
171  delete [] UpdatedStmtList;
172
173  return;
174}
175
176static void ReplaceInCompoundStmt(clang::ASTContext& C,
177                                  clang::CompoundStmt *CS,
178                                  clang::Stmt* OldStmt,
179                                  clang::Stmt* NewStmt) {
180  clang::CompoundStmt::body_iterator bI = CS->body_begin();
181
182  unsigned StmtCount = 0;
183  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
184    StmtCount++;
185  }
186
187  clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
188
189  unsigned UpdatedStmtCount = 0;
190  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
191    if (*bI == OldStmt) {
192      UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
193    } else {
194      UpdatedStmtList[UpdatedStmtCount++] = *bI;
195    }
196  }
197
198  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
199
200  delete [] UpdatedStmtList;
201
202  return;
203}
204
205
206// This class visits a compound statement and inserts the StmtList containing
207// destructors in proper locations. This includes inserting them before any
208// return statement in any sub-block, at the end of the logical enclosing
209// scope (compound statement), and/or before any break/continue statement that
210// would resume outside the declared scope. We will not handle the case for
211// goto statements that leave a local scope.
212// TODO(srhines): Make this work properly for break/continue.
213class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
214 private:
215  clang::ASTContext &mC;
216  std::list<clang::Stmt*> &mStmtList;
217  bool mTopLevel;
218 public:
219  DestructorVisitor(clang::ASTContext &C, std::list<clang::Stmt*> &StmtList);
220  void VisitStmt(clang::Stmt *S);
221  void VisitCompoundStmt(clang::CompoundStmt *CS);
222};
223
224DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
225                                     std::list<clang::Stmt*> &StmtList)
226  : mC(C),
227    mStmtList(StmtList),
228    mTopLevel(true) {
229  return;
230}
231
232void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
233  if (!CS->body_empty()) {
234    AppendToCompoundStatement(mC, CS, mStmtList, mTopLevel);
235    mTopLevel = false;
236    VisitStmt(CS);
237  }
238  return;
239}
240
241void DestructorVisitor::VisitStmt(clang::Stmt *S) {
242  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
243       I != E;
244       I++) {
245    if (clang::Stmt *Child = *I) {
246      Visit(Child);
247    }
248  }
249  return;
250}
251
252clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
253                                 clang::Expr *RefRSVar,
254                                 clang::SourceLocation Loc) {
255  slangAssert(RefRSVar);
256  const clang::Type *T = RefRSVar->getType().getTypePtr();
257  slangAssert(!T->isArrayType() &&
258              "Should not be destroying arrays with this function");
259
260  clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
261  slangAssert((ClearObjectFD != NULL) &&
262              "rsClearObject doesn't cover all RS object types");
263
264  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
265  clang::QualType ClearObjectFDArgType =
266      ClearObjectFD->getParamDecl(0)->getOriginalType();
267
268  // Example destructor for "rs_font localFont;"
269  //
270  // (CallExpr 'void'
271  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
272  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
273  //   (UnaryOperator 'rs_font *' prefix '&'
274  //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
275
276  // Get address of targeted RS object
277  clang::Expr *AddrRefRSVar =
278      new(C) clang::UnaryOperator(RefRSVar,
279                                  clang::UO_AddrOf,
280                                  ClearObjectFDArgType,
281                                  Loc);
282
283  clang::Expr *RefRSClearObjectFD =
284      clang::DeclRefExpr::Create(C,
285                                 NULL,
286                                 ClearObjectFD->getQualifierRange(),
287                                 ClearObjectFD,
288                                 ClearObjectFD->getLocation(),
289                                 ClearObjectFDType);
290
291  clang::Expr *RSClearObjectFP =
292      clang::ImplicitCastExpr::Create(C,
293                                      C.getPointerType(ClearObjectFDType),
294                                      clang::CK_FunctionToPointerDecay,
295                                      RefRSClearObjectFD,
296                                      NULL,
297                                      clang::VK_RValue);
298
299  clang::CallExpr *RSClearObjectCall =
300      new(C) clang::CallExpr(C,
301                             RSClearObjectFP,
302                             &AddrRefRSVar,
303                             1,
304                             ClearObjectFD->getCallResultType(),
305                             Loc);
306
307  return RSClearObjectCall;
308}
309
310static int ArrayDim(const clang::Type *T) {
311  if (!T || !T->isArrayType()) {
312    return 0;
313  }
314
315  const clang::ConstantArrayType *CAT =
316    static_cast<const clang::ConstantArrayType *>(T);
317  return static_cast<int>(CAT->getSize().getSExtValue());
318}
319
320static clang::Stmt *ClearStructRSObject(
321    clang::ASTContext &C,
322    clang::DeclContext *DC,
323    clang::Expr *RefRSStruct,
324    clang::SourceRange Range,
325    clang::SourceLocation Loc);
326
327static clang::Stmt *ClearArrayRSObject(
328    clang::ASTContext &C,
329    clang::DeclContext *DC,
330    clang::Expr *RefRSArr,
331    clang::SourceRange Range,
332    clang::SourceLocation Loc) {
333  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
334  slangAssert(BaseType->isArrayType());
335
336  int NumArrayElements = ArrayDim(BaseType);
337  // Actually extract out the base RS object type for use later
338  BaseType = BaseType->getArrayElementTypeNoTypeQual();
339
340  clang::Stmt *StmtArray[2] = {NULL};
341  int StmtCtr = 0;
342
343  if (NumArrayElements <= 0) {
344    return NULL;
345  }
346
347  // Example destructor loop for "rs_font fontArr[10];"
348  //
349  // (CompoundStmt
350  //   (DeclStmt "int rsIntIter")
351  //   (ForStmt
352  //     (BinaryOperator 'int' '='
353  //       (DeclRefExpr 'int' Var='rsIntIter')
354  //       (IntegerLiteral 'int' 0))
355  //     (BinaryOperator 'int' '<'
356  //       (DeclRefExpr 'int' Var='rsIntIter')
357  //       (IntegerLiteral 'int' 10)
358  //     NULL << CondVar >>
359  //     (UnaryOperator 'int' postfix '++'
360  //       (DeclRefExpr 'int' Var='rsIntIter'))
361  //     (CallExpr 'void'
362  //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
363  //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
364  //       (UnaryOperator 'rs_font *' prefix '&'
365  //         (ArraySubscriptExpr 'rs_font':'rs_font'
366  //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
367  //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
368  //           (DeclRefExpr 'int' Var='rsIntIter')))))))
369
370  // Create helper variable for iterating through elements
371  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
372  clang::VarDecl *IIVD =
373      clang::VarDecl::Create(C,
374                             DC,
375                             Loc,
376                             &II,
377                             C.IntTy,
378                             C.getTrivialTypeSourceInfo(C.IntTy),
379                             clang::SC_None,
380                             clang::SC_None);
381  clang::Decl *IID = (clang::Decl *)IIVD;
382
383  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
384  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
385
386  // Form the actual destructor loop
387  // for (Init; Cond; Inc)
388  //   RSClearObjectCall;
389
390  // Init -> "rsIntIter = 0"
391  clang::DeclRefExpr *RefrsIntIter =
392      clang::DeclRefExpr::Create(C,
393                                 NULL,
394                                 Range,
395                                 IIVD,
396                                 Loc,
397                                 C.IntTy);
398
399  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
400      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
401
402  clang::BinaryOperator *Init =
403      new(C) clang::BinaryOperator(RefrsIntIter,
404                                   Int0,
405                                   clang::BO_Assign,
406                                   C.IntTy,
407                                   Loc);
408
409  // Cond -> "rsIntIter < NumArrayElements"
410  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
411      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
412
413  clang::BinaryOperator *Cond =
414      new(C) clang::BinaryOperator(RefrsIntIter,
415                                   NumArrayElementsExpr,
416                                   clang::BO_LT,
417                                   C.IntTy,
418                                   Loc);
419
420  // Inc -> "rsIntIter++"
421  clang::UnaryOperator *Inc =
422      new(C) clang::UnaryOperator(RefrsIntIter,
423                                  clang::UO_PostInc,
424                                  C.IntTy,
425                                  Loc);
426
427  // Body -> "rsClearObject(&VD[rsIntIter]);"
428  // Destructor loop operates on individual array elements
429
430  clang::Expr *RefRSArrPtr =
431      clang::ImplicitCastExpr::Create(C,
432          C.getPointerType(BaseType->getCanonicalTypeInternal()),
433          clang::CK_ArrayToPointerDecay,
434          RefRSArr,
435          NULL,
436          clang::VK_RValue);
437
438  clang::Expr *RefRSArrPtrSubscript =
439      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
440                                       RefrsIntIter,
441                                       BaseType->getCanonicalTypeInternal(),
442                                       Loc);
443
444  RSExportPrimitiveType::DataType DT =
445      RSExportPrimitiveType::GetRSSpecificType(BaseType);
446
447  clang::Stmt *RSClearObjectCall = NULL;
448  if (BaseType->isArrayType()) {
449    RSClearObjectCall =
450        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, Range, Loc);
451  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
452    RSClearObjectCall =
453        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, Range, Loc);
454  } else {
455    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
456  }
457
458  clang::ForStmt *DestructorLoop =
459      new(C) clang::ForStmt(C,
460                            Init,
461                            Cond,
462                            NULL,  // no condVar
463                            Inc,
464                            RSClearObjectCall,
465                            Loc,
466                            Loc,
467                            Loc);
468
469  StmtArray[StmtCtr++] = DestructorLoop;
470  slangAssert(StmtCtr == 2);
471
472  clang::CompoundStmt *CS =
473      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
474
475  return CS;
476}
477
478static unsigned CountRSObjectTypesInStruct(const clang::Type *T) {
479  slangAssert(T);
480  unsigned RSObjectCount = 0;
481
482  if (T->isArrayType()) {
483    return CountRSObjectTypesInStruct(T->getArrayElementTypeNoTypeQual());
484  }
485
486  RSExportPrimitiveType::DataType DT =
487      RSExportPrimitiveType::GetRSSpecificType(T);
488  if (DT != RSExportPrimitiveType::DataTypeUnknown) {
489    return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
490  }
491
492  if (!T->isStructureType()) {
493    return 0;
494  }
495
496  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
497  RD = RD->getDefinition();
498  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
499         FE = RD->field_end();
500       FI != FE;
501       FI++) {
502    const clang::FieldDecl *FD = *FI;
503    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
504    if (CountRSObjectTypesInStruct(FT)) {
505      // Sub-structs should only count once (as should arrays, etc.)
506      RSObjectCount++;
507    }
508  }
509
510  return RSObjectCount;
511}
512
513static clang::Stmt *ClearStructRSObject(
514    clang::ASTContext &C,
515    clang::DeclContext *DC,
516    clang::Expr *RefRSStruct,
517    clang::SourceRange Range,
518    clang::SourceLocation Loc) {
519  const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
520
521  slangAssert(!BaseType->isArrayType());
522
523  RSExportPrimitiveType::DataType DT =
524      RSExportPrimitiveType::GetRSSpecificType(BaseType);
525
526  // Structs should show up as unknown primitive types
527  slangAssert(DT == RSExportPrimitiveType::DataTypeUnknown);
528
529  unsigned FieldsToDestroy = CountRSObjectTypesInStruct(BaseType);
530
531  unsigned StmtCount = 0;
532  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
533
534  // Populate StmtArray by creating a destructor for each RS object field
535  clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
536  RD = RD->getDefinition();
537  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
538         FE = RD->field_end();
539       FI != FE;
540       FI++) {
541    // We just look through all field declarations to see if we find a
542    // declaration for an RS object type (or an array of one).
543    bool IsArrayType = false;
544    clang::FieldDecl *FD = *FI;
545    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
546    const clang::Type *OrigType = FT;
547    while (FT && FT->isArrayType()) {
548      FT = FT->getArrayElementTypeNoTypeQual();
549      IsArrayType = true;
550    }
551
552    if (RSExportPrimitiveType::IsRSObjectType(FT)) {
553      clang::DeclAccessPair FoundDecl =
554          clang::DeclAccessPair::make(FD, clang::AS_none);
555      clang::MemberExpr *RSObjectMember =
556          clang::MemberExpr::Create(C,
557                                    RefRSStruct,
558                                    false,
559                                    NULL,
560                                    Range,
561                                    FD,
562                                    FoundDecl,
563                                    clang::DeclarationNameInfo(),
564                                    NULL,
565                                    OrigType->getCanonicalTypeInternal());
566
567      slangAssert(StmtCount < FieldsToDestroy);
568
569      if (IsArrayType) {
570        StmtArray[StmtCount++] = ClearArrayRSObject(C,
571                                                    DC,
572                                                    RSObjectMember,
573                                                    Range,
574                                                    Loc);
575      } else {
576        StmtArray[StmtCount++] = ClearSingleRSObject(C,
577                                                     RSObjectMember,
578                                                     Loc);
579      }
580    } else if (FT->isStructureType() && CountRSObjectTypesInStruct(FT)) {
581      // In this case, we have a nested struct. We may not end up filling all
582      // of the spaces in StmtArray (sub-structs should handle themselves
583      // with separate compound statements).
584      clang::DeclAccessPair FoundDecl =
585          clang::DeclAccessPair::make(FD, clang::AS_none);
586      clang::MemberExpr *RSObjectMember =
587          clang::MemberExpr::Create(C,
588                                    RefRSStruct,
589                                    false,
590                                    NULL,
591                                    Range,
592                                    FD,
593                                    FoundDecl,
594                                    clang::DeclarationNameInfo(),
595                                    NULL,
596                                    OrigType->getCanonicalTypeInternal());
597
598      if (IsArrayType) {
599        StmtArray[StmtCount++] = ClearArrayRSObject(C,
600                                                    DC,
601                                                    RSObjectMember,
602                                                    Range,
603                                                    Loc);
604      } else {
605        StmtArray[StmtCount++] = ClearStructRSObject(C,
606                                                     DC,
607                                                     RSObjectMember,
608                                                     Range,
609                                                     Loc);
610      }
611    }
612  }
613
614  slangAssert(StmtCount > 0);
615  clang::CompoundStmt *CS =
616      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
617
618  delete [] StmtArray;
619
620  return CS;
621}
622
623}  // namespace
624
625void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
626    clang::BinaryOperator *AS) {
627
628  clang::QualType QT = AS->getType();
629
630  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(
631      QT.getTypePtr());
632  slangAssert((SetObjectFD != NULL) &&
633              "rsSetObject doesn't cover all RS object types");
634  clang::ASTContext &C = SetObjectFD->getASTContext();
635
636  clang::QualType SetObjectFDType = SetObjectFD->getType();
637  clang::QualType SetObjectFDArgType[2];
638  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
639  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
640
641  clang::SourceLocation Loc = SetObjectFD->getLocation();
642  clang::Expr *RefRSSetObjectFD =
643      clang::DeclRefExpr::Create(C,
644                                 NULL,
645                                 SetObjectFD->getQualifierRange(),
646                                 SetObjectFD,
647                                 Loc,
648                                 SetObjectFDType);
649
650  clang::Expr *RSSetObjectFP =
651      clang::ImplicitCastExpr::Create(C,
652                                      C.getPointerType(SetObjectFDType),
653                                      clang::CK_FunctionToPointerDecay,
654                                      RefRSSetObjectFD,
655                                      NULL,
656                                      clang::VK_RValue);
657
658  clang::Expr *ArgList[2];
659  ArgList[0] = new(C) clang::UnaryOperator(AS->getLHS(),
660                                           clang::UO_AddrOf,
661                                           SetObjectFDArgType[0],
662                                           Loc);
663  ArgList[1] = AS->getRHS();
664
665  clang::CallExpr *RSSetObjectCall =
666      new(C) clang::CallExpr(C,
667                             RSSetObjectFP,
668                             ArgList,
669                             2,
670                             SetObjectFD->getCallResultType(),
671                             Loc);
672
673  ReplaceInCompoundStmt(C, mCS, AS, RSSetObjectCall);
674
675  return;
676}
677
678void RSObjectRefCount::Scope::AppendRSObjectInit(
679    clang::VarDecl *VD,
680    clang::DeclStmt *DS,
681    RSExportPrimitiveType::DataType DT,
682    clang::Expr *InitExpr) {
683  slangAssert(VD);
684
685  if (!InitExpr) {
686    return;
687  }
688
689  if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
690    // TODO(srhines): Skip struct initialization right now
691    return;
692  }
693
694  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
695  slangAssert((SetObjectFD != NULL) &&
696              "rsSetObject doesn't cover all RS object types");
697  clang::ASTContext &C = SetObjectFD->getASTContext();
698
699  clang::QualType SetObjectFDType = SetObjectFD->getType();
700  clang::QualType SetObjectFDArgType[2];
701  SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
702  SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
703
704  clang::SourceLocation Loc = SetObjectFD->getLocation();
705  clang::Expr *RefRSSetObjectFD =
706      clang::DeclRefExpr::Create(C,
707                                 NULL,
708                                 SetObjectFD->getQualifierRange(),
709                                 SetObjectFD,
710                                 Loc,
711                                 SetObjectFDType);
712
713  clang::Expr *RSSetObjectFP =
714      clang::ImplicitCastExpr::Create(C,
715                                      C.getPointerType(SetObjectFDType),
716                                      clang::CK_FunctionToPointerDecay,
717                                      RefRSSetObjectFD,
718                                      NULL,
719                                      clang::VK_RValue);
720
721  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
722  clang::DeclRefExpr *RefRSVar =
723      clang::DeclRefExpr::Create(C,
724                                 NULL,
725                                 VD->getQualifierRange(),
726                                 VD,
727                                 Loc,
728                                 T->getCanonicalTypeInternal());
729
730  clang::Expr *ArgList[2];
731  ArgList[0] = new(C) clang::UnaryOperator(RefRSVar,
732                                           clang::UO_AddrOf,
733                                           SetObjectFDArgType[0],
734                                           Loc);
735  ArgList[1] = InitExpr;
736
737  clang::CallExpr *RSSetObjectCall =
738      new(C) clang::CallExpr(C,
739                             RSSetObjectFP,
740                             ArgList,
741                             2,
742                             SetObjectFD->getCallResultType(),
743                             Loc);
744
745  AppendAfterStmt(C, mCS, DS, RSSetObjectCall);
746
747  return;
748}
749
750void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
751  std::list<clang::Stmt*> RSClearObjectCalls;
752  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
753          E = mRSO.end();
754        I != E;
755        I++) {
756    clang::Stmt *S = ClearRSObject(*I);
757    if (S) {
758      RSClearObjectCalls.push_back(S);
759    }
760  }
761  if (RSClearObjectCalls.size() > 0) {
762    DestructorVisitor DV((*mRSO.begin())->getASTContext(), RSClearObjectCalls);
763    DV.Visit(mCS);
764  }
765  return;
766}
767
768clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(clang::VarDecl *VD) {
769  slangAssert(VD);
770  clang::ASTContext &C = VD->getASTContext();
771  clang::DeclContext *DC = VD->getDeclContext();
772  clang::SourceRange Range = VD->getQualifierRange();
773  clang::SourceLocation Loc = VD->getLocation();
774  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
775
776  // Reference expr to target RS object variable
777  clang::DeclRefExpr *RefRSVar =
778      clang::DeclRefExpr::Create(C,
779                                 NULL,
780                                 Range,
781                                 VD,
782                                 Loc,
783                                 T->getCanonicalTypeInternal());
784
785  if (T->isArrayType()) {
786    return ClearArrayRSObject(C, DC, RefRSVar, Range, Loc);
787  }
788
789  RSExportPrimitiveType::DataType DT =
790      RSExportPrimitiveType::GetRSSpecificType(T);
791
792  if (DT == RSExportPrimitiveType::DataTypeUnknown ||
793      DT == RSExportPrimitiveType::DataTypeIsStruct) {
794    return ClearStructRSObject(C, DC, RefRSVar, Range, Loc);
795  }
796
797  slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
798              "Should be RS object");
799
800  return ClearSingleRSObject(C, RefRSVar, Loc);
801}
802
803bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
804                                          RSExportPrimitiveType::DataType *DT,
805                                          clang::Expr **InitExpr) {
806  slangAssert(VD && DT && InitExpr);
807  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
808
809  // Loop through array types to get to base type
810  while (T && T->isArrayType()) {
811    T = T->getArrayElementTypeNoTypeQual();
812  }
813
814  bool DataTypeIsStructWithRSObject = false;
815  *DT = RSExportPrimitiveType::GetRSSpecificType(T);
816
817  if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
818    if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
819      *DT = RSExportPrimitiveType::DataTypeIsStruct;
820      DataTypeIsStructWithRSObject = true;
821    } else {
822      return false;
823    }
824  }
825
826  bool DataTypeIsRSObject = false;
827  if (DataTypeIsStructWithRSObject) {
828    DataTypeIsRSObject = true;
829  } else {
830    DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
831  }
832  *InitExpr = VD->getInit();
833
834  if (!DataTypeIsRSObject && *InitExpr) {
835    // If we already have an initializer for a matrix type, we are done.
836    return DataTypeIsRSObject;
837  }
838
839  clang::Expr *ZeroInitializer =
840      CreateZeroInitializerForRSSpecificType(*DT,
841                                             VD->getASTContext(),
842                                             VD->getLocation());
843
844  if (ZeroInitializer) {
845    ZeroInitializer->setType(T->getCanonicalTypeInternal());
846    VD->setInit(ZeroInitializer);
847  }
848
849  return DataTypeIsRSObject;
850}
851
852clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
853    RSExportPrimitiveType::DataType DT,
854    clang::ASTContext &C,
855    const clang::SourceLocation &Loc) {
856  clang::Expr *Res = NULL;
857  switch (DT) {
858    case RSExportPrimitiveType::DataTypeIsStruct:
859    case RSExportPrimitiveType::DataTypeRSElement:
860    case RSExportPrimitiveType::DataTypeRSType:
861    case RSExportPrimitiveType::DataTypeRSAllocation:
862    case RSExportPrimitiveType::DataTypeRSSampler:
863    case RSExportPrimitiveType::DataTypeRSScript:
864    case RSExportPrimitiveType::DataTypeRSMesh:
865    case RSExportPrimitiveType::DataTypeRSProgramFragment:
866    case RSExportPrimitiveType::DataTypeRSProgramVertex:
867    case RSExportPrimitiveType::DataTypeRSProgramRaster:
868    case RSExportPrimitiveType::DataTypeRSProgramStore:
869    case RSExportPrimitiveType::DataTypeRSFont: {
870      //    (ImplicitCastExpr 'nullptr_t'
871      //      (IntegerLiteral 0)))
872      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
873      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
874      clang::Expr *CastToNull =
875          clang::ImplicitCastExpr::Create(C,
876                                          C.NullPtrTy,
877                                          clang::CK_IntegralToPointer,
878                                          Int0,
879                                          NULL,
880                                          clang::VK_RValue);
881
882      Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
883      break;
884    }
885    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
886    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
887    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
888      // RS matrix is not completely an RS object. They hold data by themselves.
889      // (InitListExpr rs_matrix2x2
890      //   (InitListExpr float[4]
891      //     (FloatingLiteral 0)
892      //     (FloatingLiteral 0)
893      //     (FloatingLiteral 0)
894      //     (FloatingLiteral 0)))
895      clang::QualType FloatTy = C.FloatTy;
896      // Constructor sets value to 0.0f by default
897      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
898      clang::FloatingLiteral *Float0Val =
899          clang::FloatingLiteral::Create(C,
900                                         Val,
901                                         /* isExact = */true,
902                                         FloatTy,
903                                         Loc);
904
905      unsigned N = 0;
906      if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
907        N = 2;
908      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
909        N = 3;
910      else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
911        N = 4;
912
913      // Directly allocate 16 elements instead of dynamically allocate N*N
914      clang::Expr *InitVals[16];
915      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
916        InitVals[i] = Float0Val;
917      clang::Expr *InitExpr =
918          new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
919      InitExpr->setType(C.getConstantArrayType(FloatTy,
920                                               llvm::APInt(32, 4),
921                                               clang::ArrayType::Normal,
922                                               /* EltTypeQuals = */0));
923
924      Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
925      break;
926    }
927    case RSExportPrimitiveType::DataTypeUnknown:
928    case RSExportPrimitiveType::DataTypeFloat16:
929    case RSExportPrimitiveType::DataTypeFloat32:
930    case RSExportPrimitiveType::DataTypeFloat64:
931    case RSExportPrimitiveType::DataTypeSigned8:
932    case RSExportPrimitiveType::DataTypeSigned16:
933    case RSExportPrimitiveType::DataTypeSigned32:
934    case RSExportPrimitiveType::DataTypeSigned64:
935    case RSExportPrimitiveType::DataTypeUnsigned8:
936    case RSExportPrimitiveType::DataTypeUnsigned16:
937    case RSExportPrimitiveType::DataTypeUnsigned32:
938    case RSExportPrimitiveType::DataTypeUnsigned64:
939    case RSExportPrimitiveType::DataTypeBoolean:
940    case RSExportPrimitiveType::DataTypeUnsigned565:
941    case RSExportPrimitiveType::DataTypeUnsigned5551:
942    case RSExportPrimitiveType::DataTypeUnsigned4444:
943    case RSExportPrimitiveType::DataTypeMax: {
944      slangAssert(false && "Not RS object type!");
945    }
946    // No default case will enable compiler detecting the missing cases
947  }
948
949  return Res;
950}
951
952void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
953  for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
954       I != E;
955       I++) {
956    clang::Decl *D = *I;
957    if (D->getKind() == clang::Decl::Var) {
958      clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
959      RSExportPrimitiveType::DataType DT =
960          RSExportPrimitiveType::DataTypeUnknown;
961      clang::Expr *InitExpr = NULL;
962      if (InitializeRSObject(VD, &DT, &InitExpr)) {
963        getCurrentScope()->addRSObject(VD);
964        getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
965      }
966    }
967  }
968  return;
969}
970
971void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
972  if (!CS->body_empty()) {
973    // Push a new scope
974    Scope *S = new Scope(CS);
975    mScopeStack.push(S);
976
977    VisitStmt(CS);
978
979    // Destroy the scope
980    slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
981    S->InsertLocalVarDestructors();
982    mScopeStack.pop();
983    delete S;
984  }
985  return;
986}
987
988void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
989  clang::QualType QT = AS->getType();
990
991  if (RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
992    getCurrentScope()->ReplaceRSObjectAssignment(AS);
993  }
994
995  return;
996}
997
998void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
999  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1000       I != E;
1001       I++) {
1002    if (clang::Stmt *Child = *I) {
1003      Visit(Child);
1004    }
1005  }
1006  return;
1007}
1008
1009}  // namespace slang
1010