slang_rs_backend.cpp revision 4a4bf92a8add68629a7e6e59ef81c3c3fe603a75
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_assert.h"
35#include "slang_rs.h"
36#include "slang_rs_context.h"
37#include "slang_rs_export_foreach.h"
38#include "slang_rs_export_func.h"
39#include "slang_rs_export_type.h"
40#include "slang_rs_export_var.h"
41#include "slang_rs_metadata.h"
42
43namespace slang {
44
45RSBackend::RSBackend(RSContext *Context,
46                     clang::Diagnostic *Diags,
47                     const clang::CodeGenOptions &CodeGenOpts,
48                     const clang::TargetOptions &TargetOpts,
49                     PragmaList *Pragmas,
50                     llvm::raw_ostream *OS,
51                     Slang::OutputType OT,
52                     clang::SourceManager &SourceMgr,
53                     bool AllowRSPrefix)
54    : Backend(Diags,
55              CodeGenOpts,
56              TargetOpts,
57              Pragmas,
58              OS,
59              OT),
60      mContext(Context),
61      mSourceMgr(SourceMgr),
62      mAllowRSPrefix(AllowRSPrefix),
63      mExportVarMetadata(NULL),
64      mExportFuncMetadata(NULL),
65      mExportForEachMetadata(NULL),
66      mExportTypeMetadata(NULL),
67      mRSObjectSlotsMetadata(NULL),
68      mRefCount(mContext->getASTContext()) {
69  return;
70}
71
72// 1) Add zero initialization of local RS object types
73void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
74  if (FD &&
75      FD->hasBody() &&
76      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
77    mRefCount.Init();
78    mRefCount.Visit(FD->getBody());
79  }
80  return;
81}
82
83void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
84  // Disallow user-defined functions with prefix "rs"
85  if (!mAllowRSPrefix) {
86    // Iterate all function declarations in the program.
87    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
88         I != E; I++) {
89      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
90      if (FD == NULL)
91        continue;
92      if (!FD->getName().startswith("rs"))  // Check prefix
93        continue;
94      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
95        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
96                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
97                                             "invalid function name prefix, "
98                                             "\"rs\" is reserved: '%0'"))
99            << FD->getName();
100    }
101  }
102
103  // Process any non-static function declarations
104  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
105    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
106    if (FD && FD->isGlobal()) {
107      AnnotateFunction(FD);
108    }
109  }
110
111  Backend::HandleTopLevelDecl(D);
112  return;
113}
114
115namespace {
116
117static bool ValidateVarDecl(clang::VarDecl *VD) {
118  if (!VD) {
119    return true;
120  }
121
122  clang::ASTContext &C = VD->getASTContext();
123  const clang::Type *T = VD->getType().getTypePtr();
124  bool valid = true;
125
126  if (VD->getLinkage() == clang::ExternalLinkage) {
127    llvm::StringRef TypeName;
128    if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
129      valid = false;
130    }
131  }
132  valid &= RSExportType::ValidateVarDecl(VD);
133
134  return valid;
135}
136
137static bool ValidateASTContext(clang::ASTContext &C) {
138  bool valid = true;
139  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
140  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
141          DE = TUDecl->decls_end();
142       DI != DE;
143       DI++) {
144    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI);
145    if (VD && !ValidateVarDecl(VD)) {
146      valid = false;
147    }
148  }
149
150  return valid;
151}
152
153}  // namespace
154
155void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
156  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
157
158  if (!ValidateASTContext(C)) {
159    return;
160  }
161
162  int version = mContext->getVersion();
163  if (version == 0) {
164    // Not setting a version is an error
165    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
166                      "Missing pragma for version in source file"));
167  } else if (version > 1) {
168    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
169                      "Pragma for version in source file must be set to 1"));
170  }
171
172  // Process any static function declarations
173  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
174          E = TUDecl->decls_end(); I != E; I++) {
175    if ((I->getKind() >= clang::Decl::firstFunction) &&
176        (I->getKind() <= clang::Decl::lastFunction)) {
177      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
178      if (FD && !FD->isGlobal()) {
179        AnnotateFunction(FD);
180      }
181    }
182  }
183
184  return;
185}
186
187///////////////////////////////////////////////////////////////////////////////
188void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
189  if (!mContext->processExport()) {
190    return;
191  }
192
193  // Dump export variable info
194  if (mContext->hasExportVar()) {
195    int slotCount = 0;
196    if (mExportVarMetadata == NULL)
197      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
198
199    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
200
201    // We emit slot information (#rs_object_slots) for any reference counted
202    // RS type or pointer (which can also be bound).
203
204    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
205            E = mContext->export_vars_end();
206         I != E;
207         I++) {
208      const RSExportVar *EV = *I;
209      const RSExportType *ET = EV->getType();
210      bool countsAsRSObject = false;
211
212      // Variable name
213      ExportVarInfo.push_back(
214          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
215
216      // Type name
217      switch (ET->getClass()) {
218        case RSExportType::ExportClassPrimitive: {
219          const RSExportPrimitiveType *PT =
220              static_cast<const RSExportPrimitiveType*>(ET);
221          ExportVarInfo.push_back(
222              llvm::MDString::get(
223                mLLVMContext, llvm::utostr_32(PT->getType())));
224          if (PT->isRSObjectType()) {
225            countsAsRSObject = true;
226          }
227          break;
228        }
229        case RSExportType::ExportClassPointer: {
230          ExportVarInfo.push_back(
231              llvm::MDString::get(
232                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
233                  ->getPointeeType()->getName()).c_str()));
234          break;
235        }
236        case RSExportType::ExportClassMatrix: {
237          ExportVarInfo.push_back(
238              llvm::MDString::get(
239                mLLVMContext, llvm::utostr_32(
240                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
241                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
242          break;
243        }
244        case RSExportType::ExportClassVector:
245        case RSExportType::ExportClassConstantArray:
246        case RSExportType::ExportClassRecord: {
247          ExportVarInfo.push_back(
248              llvm::MDString::get(mLLVMContext,
249                EV->getType()->getName().c_str()));
250          break;
251        }
252      }
253
254      mExportVarMetadata->addOperand(
255          llvm::MDNode::get(mLLVMContext, ExportVarInfo));
256      ExportVarInfo.clear();
257
258      if (mRSObjectSlotsMetadata == NULL) {
259        mRSObjectSlotsMetadata =
260            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
261      }
262
263      if (countsAsRSObject) {
264        mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
265            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
266      }
267
268      slotCount++;
269    }
270  }
271
272  // Dump export function info
273  if (mContext->hasExportFunc()) {
274    if (mExportFuncMetadata == NULL)
275      mExportFuncMetadata =
276          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
277
278    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
279
280    for (RSContext::const_export_func_iterator
281            I = mContext->export_funcs_begin(),
282            E = mContext->export_funcs_end();
283         I != E;
284         I++) {
285      const RSExportFunc *EF = *I;
286
287      // Function name
288      if (!EF->hasParam()) {
289        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
290                                                     EF->getName().c_str()));
291      } else {
292        llvm::Function *F = M->getFunction(EF->getName());
293        llvm::Function *HelperFunction;
294        const std::string HelperFunctionName(".helper_" + EF->getName());
295
296        slangAssert(F && "Function marked as exported disappeared in Bitcode");
297
298        // Create helper function
299        {
300          llvm::StructType *HelperFunctionParameterTy = NULL;
301
302          if (!F->getArgumentList().empty()) {
303            std::vector<llvm::Type*> HelperFunctionParameterTys;
304            for (llvm::Function::arg_iterator AI = F->arg_begin(),
305                 AE = F->arg_end(); AI != AE; AI++)
306              HelperFunctionParameterTys.push_back(AI->getType());
307
308            HelperFunctionParameterTy =
309                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
310          }
311
312          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
313            fprintf(stderr, "Failed to export function %s: parameter type "
314                            "mismatch during creation of helper function.\n",
315                    EF->getName().c_str());
316
317            const RSExportRecordType *Expected = EF->getParamPacketType();
318            if (Expected) {
319              fprintf(stderr, "Expected:\n");
320              Expected->getLLVMType()->dump();
321            }
322            if (HelperFunctionParameterTy) {
323              fprintf(stderr, "Got:\n");
324              HelperFunctionParameterTy->dump();
325            }
326          }
327
328          std::vector<llvm::Type*> Params;
329          if (HelperFunctionParameterTy) {
330            llvm::PointerType *HelperFunctionParameterTyP =
331                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
332            Params.push_back(HelperFunctionParameterTyP);
333          }
334
335          llvm::FunctionType * HelperFunctionType =
336              llvm::FunctionType::get(F->getReturnType(),
337                                      Params,
338                                      /* IsVarArgs = */false);
339
340          HelperFunction =
341              llvm::Function::Create(HelperFunctionType,
342                                     llvm::GlobalValue::ExternalLinkage,
343                                     HelperFunctionName,
344                                     M);
345
346          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
347          HelperFunction->setCallingConv(F->getCallingConv());
348
349          // Create helper function body
350          {
351            llvm::Argument *HelperFunctionParameter =
352                &(*HelperFunction->arg_begin());
353            llvm::BasicBlock *BB =
354                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
355            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
356            llvm::SmallVector<llvm::Value*, 6> Params;
357            llvm::Value *Idx[2];
358
359            Idx[0] =
360                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
361
362            // getelementptr and load instruction for all elements in
363            // parameter .p
364            for (size_t i = 0; i < EF->getNumParameters(); i++) {
365              // getelementptr
366              Idx[1] =
367                  llvm::ConstantInt::get(
368                      llvm::Type::getInt32Ty(mLLVMContext), i);
369              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
370                                                       Idx,
371                                                       Idx + 2);
372
373              // load
374              llvm::Value *V = IB->CreateLoad(Ptr);
375              Params.push_back(V);
376            }
377
378            // Call and pass the all elements as parameter to F
379            llvm::CallInst *CI = IB->CreateCall(F, Params);
380
381            CI->setCallingConv(F->getCallingConv());
382
383            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
384              IB->CreateRetVoid();
385            else
386              IB->CreateRet(CI);
387
388            delete IB;
389          }
390        }
391
392        ExportFuncInfo.push_back(
393            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
394      }
395
396      mExportFuncMetadata->addOperand(
397          llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
398      ExportFuncInfo.clear();
399    }
400  }
401
402  // Dump export function info
403  if (mContext->hasExportForEach()) {
404    if (mExportForEachMetadata == NULL)
405      mExportForEachMetadata =
406          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
407
408    llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
409
410    for (RSContext::const_export_foreach_iterator
411            I = mContext->export_foreach_begin(),
412            E = mContext->export_foreach_end();
413         I != E;
414         I++) {
415      const RSExportForEach *EFE = *I;
416
417      ExportForEachInfo.push_back(
418          llvm::MDString::get(mLLVMContext,
419                              llvm::utostr_32(EFE->getMetadataEncoding())));
420
421      mExportForEachMetadata->addOperand(
422          llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
423      ExportForEachInfo.clear();
424    }
425  }
426
427  // Dump export type info
428  if (mContext->hasExportType()) {
429    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
430
431    for (RSContext::const_export_type_iterator
432            I = mContext->export_types_begin(),
433            E = mContext->export_types_end();
434         I != E;
435         I++) {
436      // First, dump type name list to export
437      const RSExportType *ET = I->getValue();
438
439      ExportTypeInfo.clear();
440      // Type name
441      ExportTypeInfo.push_back(
442          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
443
444      if (ET->getClass() == RSExportType::ExportClassRecord) {
445        const RSExportRecordType *ERT =
446            static_cast<const RSExportRecordType*>(ET);
447
448        if (mExportTypeMetadata == NULL)
449          mExportTypeMetadata =
450              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
451
452        mExportTypeMetadata->addOperand(
453            llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
454
455        // Now, export struct field information to %[struct name]
456        std::string StructInfoMetadataName("%");
457        StructInfoMetadataName.append(ET->getName());
458        llvm::NamedMDNode *StructInfoMetadata =
459            M->getOrInsertNamedMetadata(StructInfoMetadataName);
460        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
461
462        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
463                    "Metadata with same name was created before");
464        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
465                FE = ERT->fields_end();
466             FI != FE;
467             FI++) {
468          const RSExportRecordType::Field *F = *FI;
469
470          // 1. field name
471          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
472                                                  F->getName().c_str()));
473
474          // 2. field type name
475          FieldInfo.push_back(
476              llvm::MDString::get(mLLVMContext,
477                                  F->getType()->getName().c_str()));
478
479          // 3. field kind
480          switch (F->getType()->getClass()) {
481            case RSExportType::ExportClassPrimitive:
482            case RSExportType::ExportClassVector: {
483              const RSExportPrimitiveType *EPT =
484                  static_cast<const RSExportPrimitiveType*>(F->getType());
485              FieldInfo.push_back(
486                  llvm::MDString::get(mLLVMContext,
487                                      llvm::itostr(EPT->getKind())));
488              break;
489            }
490
491            default: {
492              FieldInfo.push_back(
493                  llvm::MDString::get(mLLVMContext,
494                                      llvm::itostr(
495                                        RSExportPrimitiveType::DataKindUser)));
496              break;
497            }
498          }
499
500          StructInfoMetadata->addOperand(
501              llvm::MDNode::get(mLLVMContext, FieldInfo));
502          FieldInfo.clear();
503        }
504      }   // ET->getClass() == RSExportType::ExportClassRecord
505    }
506  }
507
508  return;
509}
510
511RSBackend::~RSBackend() {
512  return;
513}
514
515}  // namespace slang
516