slang_rs_backend.cpp revision 43730fe3c839af391efe6bdf56b0479860121924
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/Frontend/CodeGenOptions.h"
24
25#include "llvm/ADT/Twine.h"
26#include "llvm/ADT/StringExtras.h"
27
28#include "llvm/Constant.h"
29#include "llvm/Constants.h"
30#include "llvm/DerivedTypes.h"
31#include "llvm/Function.h"
32#include "llvm/IRBuilder.h"
33#include "llvm/Metadata.h"
34#include "llvm/Module.h"
35
36#include "llvm/Support/DebugLoc.h"
37
38#include "slang_assert.h"
39#include "slang_rs.h"
40#include "slang_rs_context.h"
41#include "slang_rs_export_foreach.h"
42#include "slang_rs_export_func.h"
43#include "slang_rs_export_type.h"
44#include "slang_rs_export_var.h"
45#include "slang_rs_metadata.h"
46
47namespace slang {
48
49RSBackend::RSBackend(RSContext *Context,
50                     clang::DiagnosticsEngine *DiagEngine,
51                     const clang::CodeGenOptions &CodeGenOpts,
52                     const clang::TargetOptions &TargetOpts,
53                     PragmaList *Pragmas,
54                     llvm::raw_ostream *OS,
55                     Slang::OutputType OT,
56                     clang::SourceManager &SourceMgr,
57                     bool AllowRSPrefix)
58  : Backend(DiagEngine, CodeGenOpts, TargetOpts, Pragmas, OS, OT),
59    mContext(Context),
60    mSourceMgr(SourceMgr),
61    mAllowRSPrefix(AllowRSPrefix),
62    mExportVarMetadata(NULL),
63    mExportFuncMetadata(NULL),
64    mExportForEachNameMetadata(NULL),
65    mExportForEachSignatureMetadata(NULL),
66    mExportTypeMetadata(NULL),
67    mRSObjectSlotsMetadata(NULL),
68    mRefCount(mContext->getASTContext()) {
69}
70
71// 1) Add zero initialization of local RS object types
72void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
73  if (FD &&
74      FD->hasBody() &&
75      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
76    mRefCount.Init();
77    mRefCount.Visit(FD->getBody());
78  }
79  return;
80}
81
82bool RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
83  // Disallow user-defined functions with prefix "rs"
84  if (!mAllowRSPrefix) {
85    // Iterate all function declarations in the program.
86    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
87         I != E; I++) {
88      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
89      if (FD == NULL)
90        continue;
91      if (!FD->getName().startswith("rs"))  // Check prefix
92        continue;
93      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
94        mDiagEngine.Report(
95          clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
96          mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
97                                      "invalid function name prefix, "
98                                      "\"rs\" is reserved: '%0'"))
99          << FD->getName();
100    }
101  }
102
103  // Process any non-static function declarations
104  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
105    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
106    if (FD && FD->isGlobal()) {
107      AnnotateFunction(FD);
108    }
109  }
110
111  return Backend::HandleTopLevelDecl(D);
112}
113
114namespace {
115
116static bool ValidateVarDecl(clang::VarDecl *VD, unsigned int TargetAPI) {
117  if (!VD) {
118    return true;
119  }
120
121  clang::ASTContext &C = VD->getASTContext();
122  const clang::Type *T = VD->getType().getTypePtr();
123  bool valid = true;
124
125  if (VD->getLinkage() == clang::ExternalLinkage) {
126    llvm::StringRef TypeName;
127    if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
128      valid = false;
129    }
130  }
131  valid &= RSExportType::ValidateVarDecl(VD, TargetAPI);
132
133  return valid;
134}
135
136static bool ValidateASTContext(clang::ASTContext &C, unsigned int TargetAPI) {
137  bool valid = true;
138  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
139  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
140          DE = TUDecl->decls_end();
141       DI != DE;
142       DI++) {
143    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI);
144    if (VD && !ValidateVarDecl(VD, TargetAPI)) {
145      valid = false;
146    }
147  }
148
149  return valid;
150}
151
152}  // namespace
153
154void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
155  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
156
157  if (!ValidateASTContext(C, getTargetAPI())) {
158    return;
159  }
160
161  int version = mContext->getVersion();
162  if (version == 0) {
163    // Not setting a version is an error
164    mDiagEngine.Report(
165        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
166        mDiagEngine.getCustomDiagID(
167            clang::DiagnosticsEngine::Error,
168            "missing pragma for version in source file"));
169  } else {
170    slangAssert(version == 1);
171  }
172
173  // Create a static global destructor if necessary (to handle RS object
174  // runtime cleanup).
175  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
176  if (FD) {
177    HandleTopLevelDecl(clang::DeclGroupRef(FD));
178  }
179
180  // Process any static function declarations
181  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
182          E = TUDecl->decls_end(); I != E; I++) {
183    if ((I->getKind() >= clang::Decl::firstFunction) &&
184        (I->getKind() <= clang::Decl::lastFunction)) {
185      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
186      if (FD && !FD->isGlobal()) {
187        AnnotateFunction(FD);
188      }
189    }
190  }
191
192  return;
193}
194
195///////////////////////////////////////////////////////////////////////////////
196void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
197  if (!mContext->processExport()) {
198    return;
199  }
200
201  // Write optimization level
202  llvm::SmallVector<llvm::Value*, 1> OptimizationOption;
203  OptimizationOption.push_back(llvm::ConstantInt::get(
204    mLLVMContext, llvm::APInt(32, mCodeGenOpts.OptimizationLevel)));
205
206  // Dump export variable info
207  if (mContext->hasExportVar()) {
208    int slotCount = 0;
209    if (mExportVarMetadata == NULL)
210      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
211
212    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
213
214    // We emit slot information (#rs_object_slots) for any reference counted
215    // RS type or pointer (which can also be bound).
216
217    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
218            E = mContext->export_vars_end();
219         I != E;
220         I++) {
221      const RSExportVar *EV = *I;
222      const RSExportType *ET = EV->getType();
223      bool countsAsRSObject = false;
224
225      // Variable name
226      ExportVarInfo.push_back(
227          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
228
229      // Type name
230      switch (ET->getClass()) {
231        case RSExportType::ExportClassPrimitive: {
232          const RSExportPrimitiveType *PT =
233              static_cast<const RSExportPrimitiveType*>(ET);
234          ExportVarInfo.push_back(
235              llvm::MDString::get(
236                mLLVMContext, llvm::utostr_32(PT->getType())));
237          if (PT->isRSObjectType()) {
238            countsAsRSObject = true;
239          }
240          break;
241        }
242        case RSExportType::ExportClassPointer: {
243          ExportVarInfo.push_back(
244              llvm::MDString::get(
245                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
246                  ->getPointeeType()->getName()).c_str()));
247          break;
248        }
249        case RSExportType::ExportClassMatrix: {
250          ExportVarInfo.push_back(
251              llvm::MDString::get(
252                mLLVMContext, llvm::utostr_32(
253                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
254                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
255          break;
256        }
257        case RSExportType::ExportClassVector:
258        case RSExportType::ExportClassConstantArray:
259        case RSExportType::ExportClassRecord: {
260          ExportVarInfo.push_back(
261              llvm::MDString::get(mLLVMContext,
262                EV->getType()->getName().c_str()));
263          break;
264        }
265      }
266
267      mExportVarMetadata->addOperand(
268          llvm::MDNode::get(mLLVMContext, ExportVarInfo));
269      ExportVarInfo.clear();
270
271      if (mRSObjectSlotsMetadata == NULL) {
272        mRSObjectSlotsMetadata =
273            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
274      }
275
276      if (countsAsRSObject) {
277        mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
278            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
279      }
280
281      slotCount++;
282    }
283  }
284
285  // Dump export function info
286  if (mContext->hasExportFunc()) {
287    if (mExportFuncMetadata == NULL)
288      mExportFuncMetadata =
289          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
290
291    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
292
293    for (RSContext::const_export_func_iterator
294            I = mContext->export_funcs_begin(),
295            E = mContext->export_funcs_end();
296         I != E;
297         I++) {
298      const RSExportFunc *EF = *I;
299
300      // Function name
301      if (!EF->hasParam()) {
302        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
303                                                     EF->getName().c_str()));
304      } else {
305        llvm::Function *F = M->getFunction(EF->getName());
306        llvm::Function *HelperFunction;
307        const std::string HelperFunctionName(".helper_" + EF->getName());
308
309        slangAssert(F && "Function marked as exported disappeared in Bitcode");
310
311        // Create helper function
312        {
313          llvm::StructType *HelperFunctionParameterTy = NULL;
314
315          if (!F->getArgumentList().empty()) {
316            std::vector<llvm::Type*> HelperFunctionParameterTys;
317            for (llvm::Function::arg_iterator AI = F->arg_begin(),
318                 AE = F->arg_end(); AI != AE; AI++)
319              HelperFunctionParameterTys.push_back(AI->getType());
320
321            HelperFunctionParameterTy =
322                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
323          }
324
325          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
326            fprintf(stderr, "Failed to export function %s: parameter type "
327                            "mismatch during creation of helper function.\n",
328                    EF->getName().c_str());
329
330            const RSExportRecordType *Expected = EF->getParamPacketType();
331            if (Expected) {
332              fprintf(stderr, "Expected:\n");
333              Expected->getLLVMType()->dump();
334            }
335            if (HelperFunctionParameterTy) {
336              fprintf(stderr, "Got:\n");
337              HelperFunctionParameterTy->dump();
338            }
339          }
340
341          std::vector<llvm::Type*> Params;
342          if (HelperFunctionParameterTy) {
343            llvm::PointerType *HelperFunctionParameterTyP =
344                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
345            Params.push_back(HelperFunctionParameterTyP);
346          }
347
348          llvm::FunctionType * HelperFunctionType =
349              llvm::FunctionType::get(F->getReturnType(),
350                                      Params,
351                                      /* IsVarArgs = */false);
352
353          HelperFunction =
354              llvm::Function::Create(HelperFunctionType,
355                                     llvm::GlobalValue::ExternalLinkage,
356                                     HelperFunctionName,
357                                     M);
358
359          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
360          HelperFunction->setCallingConv(F->getCallingConv());
361
362          // Create helper function body
363          {
364            llvm::Argument *HelperFunctionParameter =
365                &(*HelperFunction->arg_begin());
366            llvm::BasicBlock *BB =
367                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
368            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
369            llvm::SmallVector<llvm::Value*, 6> Params;
370            llvm::Value *Idx[2];
371
372            Idx[0] =
373                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
374
375            // getelementptr and load instruction for all elements in
376            // parameter .p
377            for (size_t i = 0; i < EF->getNumParameters(); i++) {
378              // getelementptr
379              Idx[1] = llvm::ConstantInt::get(
380                llvm::Type::getInt32Ty(mLLVMContext), i);
381
382              llvm::Value *Ptr =
383                IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
384
385              // load
386              llvm::Value *V = IB->CreateLoad(Ptr);
387              Params.push_back(V);
388            }
389
390            // Call and pass the all elements as parameter to F
391            llvm::CallInst *CI = IB->CreateCall(F, Params);
392
393            CI->setCallingConv(F->getCallingConv());
394
395            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
396              IB->CreateRetVoid();
397            else
398              IB->CreateRet(CI);
399
400            delete IB;
401          }
402        }
403
404        ExportFuncInfo.push_back(
405            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
406      }
407
408      mExportFuncMetadata->addOperand(
409          llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
410      ExportFuncInfo.clear();
411    }
412  }
413
414  // Dump export function info
415  if (mContext->hasExportForEach()) {
416    if (mExportForEachNameMetadata == NULL) {
417      mExportForEachNameMetadata =
418          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_NAME_MN);
419    }
420    if (mExportForEachSignatureMetadata == NULL) {
421      mExportForEachSignatureMetadata =
422          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
423    }
424
425    llvm::SmallVector<llvm::Value*, 1> ExportForEachName;
426    llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
427
428    for (RSContext::const_export_foreach_iterator
429            I = mContext->export_foreach_begin(),
430            E = mContext->export_foreach_end();
431         I != E;
432         I++) {
433      const RSExportForEach *EFE = *I;
434
435      ExportForEachName.push_back(
436          llvm::MDString::get(mLLVMContext, EFE->getName().c_str()));
437
438      mExportForEachNameMetadata->addOperand(
439          llvm::MDNode::get(mLLVMContext, ExportForEachName));
440      ExportForEachName.clear();
441
442      ExportForEachInfo.push_back(
443          llvm::MDString::get(mLLVMContext,
444                              llvm::utostr_32(EFE->getSignatureMetadata())));
445
446      mExportForEachSignatureMetadata->addOperand(
447          llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
448      ExportForEachInfo.clear();
449    }
450  }
451
452  // Dump export type info
453  if (mContext->hasExportType()) {
454    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
455
456    for (RSContext::const_export_type_iterator
457            I = mContext->export_types_begin(),
458            E = mContext->export_types_end();
459         I != E;
460         I++) {
461      // First, dump type name list to export
462      const RSExportType *ET = I->getValue();
463
464      ExportTypeInfo.clear();
465      // Type name
466      ExportTypeInfo.push_back(
467          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
468
469      if (ET->getClass() == RSExportType::ExportClassRecord) {
470        const RSExportRecordType *ERT =
471            static_cast<const RSExportRecordType*>(ET);
472
473        if (mExportTypeMetadata == NULL)
474          mExportTypeMetadata =
475              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
476
477        mExportTypeMetadata->addOperand(
478            llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
479
480        // Now, export struct field information to %[struct name]
481        std::string StructInfoMetadataName("%");
482        StructInfoMetadataName.append(ET->getName());
483        llvm::NamedMDNode *StructInfoMetadata =
484            M->getOrInsertNamedMetadata(StructInfoMetadataName);
485        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
486
487        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
488                    "Metadata with same name was created before");
489        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
490                FE = ERT->fields_end();
491             FI != FE;
492             FI++) {
493          const RSExportRecordType::Field *F = *FI;
494
495          // 1. field name
496          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
497                                                  F->getName().c_str()));
498
499          // 2. field type name
500          FieldInfo.push_back(
501              llvm::MDString::get(mLLVMContext,
502                                  F->getType()->getName().c_str()));
503
504          StructInfoMetadata->addOperand(
505              llvm::MDNode::get(mLLVMContext, FieldInfo));
506          FieldInfo.clear();
507        }
508      }   // ET->getClass() == RSExportType::ExportClassRecord
509    }
510  }
511
512  return;
513}
514
515RSBackend::~RSBackend() {
516  return;
517}
518
519}  // namespace slang
520