slang_backend.cpp revision 13dba25ad988da71201c1331b34f56e2d01bcf96
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/AST/Attr.h"
24#include "clang/AST/Decl.h"
25#include "clang/AST/DeclGroup.h"
26
27#include "clang/Basic/Diagnostic.h"
28#include "clang/Basic/TargetInfo.h"
29#include "clang/Basic/TargetOptions.h"
30
31#include "clang/CodeGen/ModuleBuilder.h"
32
33#include "clang/Frontend/CodeGenOptions.h"
34#include "clang/Frontend/FrontendDiagnostic.h"
35
36#include "llvm/ADT/Twine.h"
37#include "llvm/ADT/StringExtras.h"
38
39#include "llvm/Bitcode/ReaderWriter.h"
40
41#include "llvm/CodeGen/RegAllocRegistry.h"
42#include "llvm/CodeGen/SchedulerRegistry.h"
43
44#include "llvm/IR/Constant.h"
45#include "llvm/IR/Constants.h"
46#include "llvm/IR/DataLayout.h"
47#include "llvm/IR/DebugLoc.h"
48#include "llvm/IR/DerivedTypes.h"
49#include "llvm/IR/Function.h"
50#include "llvm/IR/IRBuilder.h"
51#include "llvm/IR/IRPrintingPasses.h"
52#include "llvm/IR/LLVMContext.h"
53#include "llvm/IR/Metadata.h"
54#include "llvm/IR/Module.h"
55
56#include "llvm/Transforms/IPO/PassManagerBuilder.h"
57
58#include "llvm/Target/TargetMachine.h"
59#include "llvm/Target/TargetOptions.h"
60#include "llvm/Support/TargetRegistry.h"
61
62#include "llvm/MC/SubtargetFeature.h"
63
64#include "slang_assert.h"
65#include "slang.h"
66#include "slang_bitcode_gen.h"
67#include "slang_rs_context.h"
68#include "slang_rs_export_foreach.h"
69#include "slang_rs_export_func.h"
70#include "slang_rs_export_reduce.h"
71#include "slang_rs_export_type.h"
72#include "slang_rs_export_var.h"
73#include "slang_rs_metadata.h"
74
75#include "rs_cc_options.h"
76
77#include "strip_unknown_attributes.h"
78
79namespace slang {
80
81void Backend::CreateFunctionPasses() {
82  if (!mPerFunctionPasses) {
83    mPerFunctionPasses = new llvm::legacy::FunctionPassManager(mpModule);
84
85    llvm::PassManagerBuilder PMBuilder;
86    PMBuilder.OptLevel = mCodeGenOpts.OptimizationLevel;
87    PMBuilder.populateFunctionPassManager(*mPerFunctionPasses);
88  }
89}
90
91void Backend::CreateModulePasses() {
92  if (!mPerModulePasses) {
93    mPerModulePasses = new llvm::legacy::PassManager();
94
95    llvm::PassManagerBuilder PMBuilder;
96    PMBuilder.OptLevel = mCodeGenOpts.OptimizationLevel;
97    PMBuilder.SizeLevel = mCodeGenOpts.OptimizeSize;
98    if (mCodeGenOpts.UnitAtATime) {
99      PMBuilder.DisableUnitAtATime = 0;
100    } else {
101      PMBuilder.DisableUnitAtATime = 1;
102    }
103
104    if (mCodeGenOpts.UnrollLoops) {
105      PMBuilder.DisableUnrollLoops = 0;
106    } else {
107      PMBuilder.DisableUnrollLoops = 1;
108    }
109
110    PMBuilder.populateModulePassManager(*mPerModulePasses);
111    // Add a pass to strip off unknown/unsupported attributes.
112    mPerModulePasses->add(createStripUnknownAttributesPass());
113  }
114}
115
116bool Backend::CreateCodeGenPasses() {
117  if ((mOT != Slang::OT_Assembly) && (mOT != Slang::OT_Object))
118    return true;
119
120  // Now we add passes for code emitting
121  if (mCodeGenPasses) {
122    return true;
123  } else {
124    mCodeGenPasses = new llvm::legacy::FunctionPassManager(mpModule);
125  }
126
127  // Create the TargetMachine for generating code.
128  std::string Triple = mpModule->getTargetTriple();
129
130  std::string Error;
131  const llvm::Target* TargetInfo =
132      llvm::TargetRegistry::lookupTarget(Triple, Error);
133  if (TargetInfo == nullptr) {
134    mDiagEngine.Report(clang::diag::err_fe_unable_to_create_target) << Error;
135    return false;
136  }
137
138  // Target Machine Options
139  llvm::TargetOptions Options;
140
141  // Use soft-float ABI for ARM (which is the target used by Slang during code
142  // generation).  Codegen still uses hardware FPU by default.  To use software
143  // floating point, add 'soft-float' feature to FeaturesStr below.
144  Options.FloatABIType = llvm::FloatABI::Soft;
145
146  // BCC needs all unknown symbols resolved at compilation time. So we don't
147  // need any relocation model.
148  llvm::Reloc::Model RM = llvm::Reloc::Static;
149
150  // This is set for the linker (specify how large of the virtual addresses we
151  // can access for all unknown symbols.)
152  llvm::CodeModel::Model CM;
153  if (mpModule->getDataLayout().getPointerSize() == 4) {
154    CM = llvm::CodeModel::Small;
155  } else {
156    // The target may have pointer size greater than 32 (e.g. x86_64
157    // architecture) may need large data address model
158    CM = llvm::CodeModel::Medium;
159  }
160
161  // Setup feature string
162  std::string FeaturesStr;
163  if (mTargetOpts.CPU.size() || mTargetOpts.Features.size()) {
164    llvm::SubtargetFeatures Features;
165
166    for (std::vector<std::string>::const_iterator
167             I = mTargetOpts.Features.begin(), E = mTargetOpts.Features.end();
168         I != E;
169         I++)
170      Features.AddFeature(*I);
171
172    FeaturesStr = Features.getString();
173  }
174
175  llvm::TargetMachine *TM =
176    TargetInfo->createTargetMachine(Triple, mTargetOpts.CPU, FeaturesStr,
177                                    Options, RM, CM);
178
179  // Register scheduler
180  llvm::RegisterScheduler::setDefault(llvm::createDefaultScheduler);
181
182  // Register allocation policy:
183  //  createFastRegisterAllocator: fast but bad quality
184  //  createGreedyRegisterAllocator: not so fast but good quality
185  llvm::RegisterRegAlloc::setDefault((mCodeGenOpts.OptimizationLevel == 0) ?
186                                     llvm::createFastRegisterAllocator :
187                                     llvm::createGreedyRegisterAllocator);
188
189  llvm::CodeGenOpt::Level OptLevel = llvm::CodeGenOpt::Default;
190  if (mCodeGenOpts.OptimizationLevel == 0) {
191    OptLevel = llvm::CodeGenOpt::None;
192  } else if (mCodeGenOpts.OptimizationLevel == 3) {
193    OptLevel = llvm::CodeGenOpt::Aggressive;
194  }
195
196  llvm::TargetMachine::CodeGenFileType CGFT =
197      llvm::TargetMachine::CGFT_AssemblyFile;
198  if (mOT == Slang::OT_Object) {
199    CGFT = llvm::TargetMachine::CGFT_ObjectFile;
200  }
201  if (TM->addPassesToEmitFile(*mCodeGenPasses, mBufferOutStream,
202                              CGFT, OptLevel)) {
203    mDiagEngine.Report(clang::diag::err_fe_unable_to_interface_with_target);
204    return false;
205  }
206
207  return true;
208}
209
210Backend::Backend(RSContext *Context, clang::DiagnosticsEngine *DiagEngine,
211                 const RSCCOptions &Opts, const clang::CodeGenOptions &CodeGenOpts,
212                 const clang::TargetOptions &TargetOpts, PragmaList *Pragmas,
213                 llvm::raw_ostream *OS, Slang::OutputType OT,
214                 clang::SourceManager &SourceMgr, bool AllowRSPrefix,
215                 bool IsFilterscript)
216    : ASTConsumer(), mTargetOpts(TargetOpts), mpModule(nullptr), mpOS(OS),
217      mOT(OT), mGen(nullptr), mPerFunctionPasses(nullptr),
218      mPerModulePasses(nullptr), mCodeGenPasses(nullptr),
219      mBufferOutStream(*mpOS), mContext(Context),
220      mSourceMgr(SourceMgr), mASTPrint(Opts.mASTPrint), mAllowRSPrefix(AllowRSPrefix),
221      mIsFilterscript(IsFilterscript), mExportVarMetadata(nullptr),
222      mExportFuncMetadata(nullptr), mExportForEachNameMetadata(nullptr),
223      mExportForEachSignatureMetadata(nullptr), mExportReduceMetadata(nullptr),
224      mExportReduceNewMetadata(nullptr),
225      mExportTypeMetadata(nullptr), mRSObjectSlotsMetadata(nullptr),
226      mRefCount(mContext->getASTContext()),
227      mASTChecker(Context, Context->getTargetAPI(), IsFilterscript),
228      mForEachHandler(Context),
229      mLLVMContext(llvm::getGlobalContext()), mDiagEngine(*DiagEngine),
230      mCodeGenOpts(CodeGenOpts), mPragmas(Pragmas) {
231  mGen = CreateLLVMCodeGen(mDiagEngine, "", mCodeGenOpts, mLLVMContext);
232}
233
234void Backend::Initialize(clang::ASTContext &Ctx) {
235  mGen->Initialize(Ctx);
236
237  mpModule = mGen->GetModule();
238}
239
240void Backend::HandleTranslationUnit(clang::ASTContext &Ctx) {
241  HandleTranslationUnitPre(Ctx);
242
243  if (mASTPrint)
244    Ctx.getTranslationUnitDecl()->dump();
245
246  mGen->HandleTranslationUnit(Ctx);
247
248  // Here, we complete a translation unit (whole translation unit is now in LLVM
249  // IR). Now, interact with LLVM backend to generate actual machine code (asm
250  // or machine code, whatever.)
251
252  // Silently ignore if we weren't initialized for some reason.
253  if (!mpModule)
254    return;
255
256  llvm::Module *M = mGen->ReleaseModule();
257  if (!M) {
258    // The module has been released by IR gen on failures, do not double free.
259    mpModule = nullptr;
260    return;
261  }
262
263  slangAssert(mpModule == M &&
264              "Unexpected module change during LLVM IR generation");
265
266  // Insert #pragma information into metadata section of module
267  if (!mPragmas->empty()) {
268    llvm::NamedMDNode *PragmaMetadata =
269        mpModule->getOrInsertNamedMetadata(Slang::PragmaMetadataName);
270    for (PragmaList::const_iterator I = mPragmas->begin(), E = mPragmas->end();
271         I != E;
272         I++) {
273      llvm::SmallVector<llvm::Metadata*, 2> Pragma;
274      // Name goes first
275      Pragma.push_back(llvm::MDString::get(mLLVMContext, I->first));
276      // And then value
277      Pragma.push_back(llvm::MDString::get(mLLVMContext, I->second));
278
279      // Create MDNode and insert into PragmaMetadata
280      PragmaMetadata->addOperand(
281          llvm::MDNode::get(mLLVMContext, Pragma));
282    }
283  }
284
285  HandleTranslationUnitPost(mpModule);
286
287  // Create passes for optimization and code emission
288
289  // Create and run per-function passes
290  CreateFunctionPasses();
291  if (mPerFunctionPasses) {
292    mPerFunctionPasses->doInitialization();
293
294    for (llvm::Module::iterator I = mpModule->begin(), E = mpModule->end();
295         I != E;
296         I++)
297      if (!I->isDeclaration())
298        mPerFunctionPasses->run(*I);
299
300    mPerFunctionPasses->doFinalization();
301  }
302
303  // Create and run module passes
304  CreateModulePasses();
305  if (mPerModulePasses)
306    mPerModulePasses->run(*mpModule);
307
308  switch (mOT) {
309    case Slang::OT_Assembly:
310    case Slang::OT_Object: {
311      if (!CreateCodeGenPasses())
312        return;
313
314      mCodeGenPasses->doInitialization();
315
316      for (llvm::Module::iterator I = mpModule->begin(), E = mpModule->end();
317          I != E;
318          I++)
319        if (!I->isDeclaration())
320          mCodeGenPasses->run(*I);
321
322      mCodeGenPasses->doFinalization();
323      break;
324    }
325    case Slang::OT_LLVMAssembly: {
326      llvm::legacy::PassManager *LLEmitPM = new llvm::legacy::PassManager();
327      LLEmitPM->add(llvm::createPrintModulePass(mBufferOutStream));
328      LLEmitPM->run(*mpModule);
329      break;
330    }
331    case Slang::OT_Bitcode: {
332      writeBitcode(mBufferOutStream, *mpModule, getTargetAPI(),
333                   mCodeGenOpts.OptimizationLevel, mCodeGenOpts.getDebugInfo());
334      break;
335    }
336    case Slang::OT_Nothing: {
337      return;
338    }
339    default: {
340      slangAssert(false && "Unknown output type");
341    }
342  }
343
344  mBufferOutStream.flush();
345}
346
347void Backend::HandleTagDeclDefinition(clang::TagDecl *D) {
348  mGen->HandleTagDeclDefinition(D);
349}
350
351void Backend::CompleteTentativeDefinition(clang::VarDecl *D) {
352  mGen->CompleteTentativeDefinition(D);
353}
354
355Backend::~Backend() {
356  delete mpModule;
357  delete mGen;
358  delete mPerFunctionPasses;
359  delete mPerModulePasses;
360  delete mCodeGenPasses;
361}
362
363// 1) Add zero initialization of local RS object types
364void Backend::AnnotateFunction(clang::FunctionDecl *FD) {
365  if (FD &&
366      FD->hasBody() &&
367      !Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
368    mRefCount.Init();
369    mRefCount.Visit(FD->getBody());
370  }
371}
372
373void Backend::LowerRSForEachCall(clang::FunctionDecl *FD) {
374  // Skip this AST walking for lower API levels.
375  if (getTargetAPI() < SLANG_N_TARGET_API) {
376    return;
377  }
378
379  if (!FD || !FD->hasBody() ||
380      Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
381    return;
382  }
383
384  mForEachHandler.VisitStmt(FD->getBody());
385}
386
387bool Backend::HandleTopLevelDecl(clang::DeclGroupRef D) {
388  // Find and remember the types for rs_allocation and rs_script_call_t so
389  // they can be used later for translating rsForEach() calls.
390  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
391       (mContext->getAllocationType().isNull() ||
392        mContext->getScriptCallType().isNull()) &&
393       I != E; I++) {
394    if (clang::TypeDecl* TD = llvm::dyn_cast<clang::TypeDecl>(*I)) {
395      clang::StringRef TypeName = TD->getName();
396      if (TypeName.equals("rs_allocation")) {
397        mContext->setAllocationType(TD);
398      } else if (TypeName.equals("rs_script_call_t")) {
399        mContext->setScriptCallType(TD);
400      }
401    }
402  }
403
404  // Disallow user-defined functions with prefix "rs"
405  if (!mAllowRSPrefix) {
406    // Iterate all function declarations in the program.
407    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
408         I != E; I++) {
409      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
410      if (FD == nullptr)
411        continue;
412      if (!FD->getName().startswith("rs"))  // Check prefix
413        continue;
414      if (!Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr))
415        mContext->ReportError(FD->getLocation(),
416                              "invalid function name prefix, "
417                              "\"rs\" is reserved: '%0'")
418            << FD->getName();
419    }
420  }
421
422  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
423    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
424    if (FD) {
425      if (!FD->hasAttr<clang::UsedAttr>() && mContext->isReferencedByReducePragma(FD)) {
426        // Handle forward reference from pragma (see RSReducePragmaHandler::HandlePragma
427        // for backward reference).
428        FD->addAttr(clang::UsedAttr::CreateImplicit(mContext->getASTContext()));
429      }
430      if (FD->isGlobal()) {
431        // Check that we don't have any array parameters being misinterpreted as
432        // kernel pointers due to the C type system's array to pointer decay.
433        size_t numParams = FD->getNumParams();
434        for (size_t i = 0; i < numParams; i++) {
435          const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
436          clang::QualType QT = PVD->getOriginalType();
437          if (QT->isArrayType()) {
438            mContext->ReportError(
439                PVD->getTypeSpecStartLoc(),
440                "exported function parameters may not have array type: %0")
441                << QT;
442          }
443        }
444        AnnotateFunction(FD);
445      }
446    }
447
448    if (getTargetAPI() >= SLANG_N_TARGET_API) {
449      if (FD && FD->hasBody() &&
450          RSExportForEach::isRSForEachFunc(getTargetAPI(), FD)) {
451        // Log kernels by their names, and assign them slot numbers.
452        if (!Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
453            mContext->addForEach(FD);
454        }
455      } else {
456        // Look for any kernel launch calls and translate them into using the
457        // internal API.
458        // TODO: Simply ignores kernel launch inside a kernel for now.
459        // Needs more rigorous and comprehensive checks.
460        LowerRSForEachCall(FD);
461      }
462    }
463  }
464
465  return mGen->HandleTopLevelDecl(D);
466}
467
468void Backend::HandleTranslationUnitPre(clang::ASTContext &C) {
469  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
470
471  if (!mContext->processReducePragmas())
472    return;
473
474  // If we have an invalid RS/FS AST, don't check further.
475  if (!mASTChecker.Validate()) {
476    return;
477  }
478
479  if (mIsFilterscript) {
480    mContext->addPragma("rs_fp_relaxed", "");
481  }
482
483  int version = mContext->getVersion();
484  if (version == 0) {
485    // Not setting a version is an error
486    mDiagEngine.Report(
487        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
488        mDiagEngine.getCustomDiagID(
489            clang::DiagnosticsEngine::Error,
490            "missing pragma for version in source file"));
491  } else {
492    slangAssert(version == 1);
493  }
494
495  if (mContext->getReflectJavaPackageName().empty()) {
496    mDiagEngine.Report(
497        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
498        mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
499                                    "missing \"#pragma rs "
500                                    "java_package_name(com.foo.bar)\" "
501                                    "in source file"));
502    return;
503  }
504
505  // Create a static global destructor if necessary (to handle RS object
506  // runtime cleanup).
507  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
508  if (FD) {
509    HandleTopLevelDecl(clang::DeclGroupRef(FD));
510  }
511
512  // Process any static function declarations
513  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
514          E = TUDecl->decls_end(); I != E; I++) {
515    if ((I->getKind() >= clang::Decl::firstFunction) &&
516        (I->getKind() <= clang::Decl::lastFunction)) {
517      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
518      if (FD && !FD->isGlobal()) {
519        AnnotateFunction(FD);
520      }
521    }
522  }
523}
524
525///////////////////////////////////////////////////////////////////////////////
526void Backend::dumpExportVarInfo(llvm::Module *M) {
527  int slotCount = 0;
528  if (mExportVarMetadata == nullptr)
529    mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
530
531  llvm::SmallVector<llvm::Metadata *, 2> ExportVarInfo;
532
533  // We emit slot information (#rs_object_slots) for any reference counted
534  // RS type or pointer (which can also be bound).
535
536  for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
537          E = mContext->export_vars_end();
538       I != E;
539       I++) {
540    const RSExportVar *EV = *I;
541    const RSExportType *ET = EV->getType();
542    bool countsAsRSObject = false;
543
544    // Variable name
545    ExportVarInfo.push_back(
546        llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
547
548    // Type name
549    switch (ET->getClass()) {
550      case RSExportType::ExportClassPrimitive: {
551        const RSExportPrimitiveType *PT =
552            static_cast<const RSExportPrimitiveType*>(ET);
553        ExportVarInfo.push_back(
554            llvm::MDString::get(
555              mLLVMContext, llvm::utostr_32(PT->getType())));
556        if (PT->isRSObjectType()) {
557          countsAsRSObject = true;
558        }
559        break;
560      }
561      case RSExportType::ExportClassPointer: {
562        ExportVarInfo.push_back(
563            llvm::MDString::get(
564              mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
565                ->getPointeeType()->getName()).c_str()));
566        break;
567      }
568      case RSExportType::ExportClassMatrix: {
569        ExportVarInfo.push_back(
570            llvm::MDString::get(
571              mLLVMContext, llvm::utostr_32(
572                  /* TODO Strange value.  This pushes just a number, quite
573                   * different than the other cases.  What is this used for?
574                   * These are the metadata values that some partner drivers
575                   * want to reference (for TBAA, etc.). We may want to look
576                   * at whether these provide any reasonable value (or have
577                   * distinct enough values to actually depend on).
578                   */
579                DataTypeRSMatrix2x2 +
580                static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
581        break;
582      }
583      case RSExportType::ExportClassVector:
584      case RSExportType::ExportClassConstantArray:
585      case RSExportType::ExportClassRecord: {
586        ExportVarInfo.push_back(
587            llvm::MDString::get(mLLVMContext,
588              EV->getType()->getName().c_str()));
589        break;
590      }
591    }
592
593    mExportVarMetadata->addOperand(
594        llvm::MDNode::get(mLLVMContext, ExportVarInfo));
595    ExportVarInfo.clear();
596
597    if (mRSObjectSlotsMetadata == nullptr) {
598      mRSObjectSlotsMetadata =
599          M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
600    }
601
602    if (countsAsRSObject) {
603      mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
604          llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
605    }
606
607    slotCount++;
608  }
609}
610
611void Backend::dumpExportFunctionInfo(llvm::Module *M) {
612  if (mExportFuncMetadata == nullptr)
613    mExportFuncMetadata =
614        M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
615
616  llvm::SmallVector<llvm::Metadata *, 1> ExportFuncInfo;
617
618  for (RSContext::const_export_func_iterator
619          I = mContext->export_funcs_begin(),
620          E = mContext->export_funcs_end();
621       I != E;
622       I++) {
623    const RSExportFunc *EF = *I;
624
625    // Function name
626    if (!EF->hasParam()) {
627      ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
628                                                   EF->getName().c_str()));
629    } else {
630      llvm::Function *F = M->getFunction(EF->getName());
631      llvm::Function *HelperFunction;
632      const std::string HelperFunctionName(".helper_" + EF->getName());
633
634      slangAssert(F && "Function marked as exported disappeared in Bitcode");
635
636      // Create helper function
637      {
638        llvm::StructType *HelperFunctionParameterTy = nullptr;
639        std::vector<bool> isStructInput;
640
641        if (!F->getArgumentList().empty()) {
642          std::vector<llvm::Type*> HelperFunctionParameterTys;
643          for (llvm::Function::arg_iterator AI = F->arg_begin(),
644                   AE = F->arg_end(); AI != AE; AI++) {
645              if (AI->getType()->isPointerTy() && AI->getType()->getPointerElementType()->isStructTy()) {
646                  HelperFunctionParameterTys.push_back(AI->getType()->getPointerElementType());
647                  isStructInput.push_back(true);
648              } else {
649                  HelperFunctionParameterTys.push_back(AI->getType());
650                  isStructInput.push_back(false);
651              }
652          }
653          HelperFunctionParameterTy =
654              llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
655        }
656
657        if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
658          fprintf(stderr, "Failed to export function %s: parameter type "
659                          "mismatch during creation of helper function.\n",
660                  EF->getName().c_str());
661
662          const RSExportRecordType *Expected = EF->getParamPacketType();
663          if (Expected) {
664            fprintf(stderr, "Expected:\n");
665            Expected->getLLVMType()->dump();
666          }
667          if (HelperFunctionParameterTy) {
668            fprintf(stderr, "Got:\n");
669            HelperFunctionParameterTy->dump();
670          }
671        }
672
673        std::vector<llvm::Type*> Params;
674        if (HelperFunctionParameterTy) {
675          llvm::PointerType *HelperFunctionParameterTyP =
676              llvm::PointerType::getUnqual(HelperFunctionParameterTy);
677          Params.push_back(HelperFunctionParameterTyP);
678        }
679
680        llvm::FunctionType * HelperFunctionType =
681            llvm::FunctionType::get(F->getReturnType(),
682                                    Params,
683                                    /* IsVarArgs = */false);
684
685        HelperFunction =
686            llvm::Function::Create(HelperFunctionType,
687                                   llvm::GlobalValue::ExternalLinkage,
688                                   HelperFunctionName,
689                                   M);
690
691        HelperFunction->addFnAttr(llvm::Attribute::NoInline);
692        HelperFunction->setCallingConv(F->getCallingConv());
693
694        // Create helper function body
695        {
696          llvm::Argument *HelperFunctionParameter =
697              &(*HelperFunction->arg_begin());
698          llvm::BasicBlock *BB =
699              llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
700          llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
701          llvm::SmallVector<llvm::Value*, 6> Params;
702          llvm::Value *Idx[2];
703
704          Idx[0] =
705              llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
706
707          // getelementptr and load instruction for all elements in
708          // parameter .p
709          for (size_t i = 0; i < EF->getNumParameters(); i++) {
710            // getelementptr
711            Idx[1] = llvm::ConstantInt::get(
712              llvm::Type::getInt32Ty(mLLVMContext), i);
713
714            llvm::Value *Ptr = NULL;
715
716            Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
717
718            // Load is only required for non-struct ptrs
719            if (isStructInput[i]) {
720                Params.push_back(Ptr);
721            } else {
722                llvm::Value *V = IB->CreateLoad(Ptr);
723                Params.push_back(V);
724            }
725          }
726
727          // Call and pass the all elements as parameter to F
728          llvm::CallInst *CI = IB->CreateCall(F, Params);
729
730          CI->setCallingConv(F->getCallingConv());
731
732          if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext)) {
733            IB->CreateRetVoid();
734          } else {
735            IB->CreateRet(CI);
736          }
737
738          delete IB;
739        }
740      }
741
742      ExportFuncInfo.push_back(
743          llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
744    }
745
746    mExportFuncMetadata->addOperand(
747        llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
748    ExportFuncInfo.clear();
749  }
750}
751
752void Backend::dumpExportForEachInfo(llvm::Module *M) {
753  if (mExportForEachNameMetadata == nullptr) {
754    mExportForEachNameMetadata =
755        M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_NAME_MN);
756  }
757  if (mExportForEachSignatureMetadata == nullptr) {
758    mExportForEachSignatureMetadata =
759        M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
760  }
761
762  llvm::SmallVector<llvm::Metadata *, 1> ExportForEachName;
763  llvm::SmallVector<llvm::Metadata *, 1> ExportForEachInfo;
764
765  for (RSContext::const_export_foreach_iterator
766          I = mContext->export_foreach_begin(),
767          E = mContext->export_foreach_end();
768       I != E;
769       I++) {
770    const RSExportForEach *EFE = *I;
771
772    ExportForEachName.push_back(
773        llvm::MDString::get(mLLVMContext, EFE->getName().c_str()));
774
775    mExportForEachNameMetadata->addOperand(
776        llvm::MDNode::get(mLLVMContext, ExportForEachName));
777    ExportForEachName.clear();
778
779    ExportForEachInfo.push_back(
780        llvm::MDString::get(mLLVMContext,
781                            llvm::utostr_32(EFE->getSignatureMetadata())));
782
783    mExportForEachSignatureMetadata->addOperand(
784        llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
785    ExportForEachInfo.clear();
786  }
787}
788
789void Backend::dumpExportReduceInfo(llvm::Module *M) {
790  if (!mExportReduceMetadata) {
791    mExportReduceMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_REDUCE_MN);
792  }
793
794  llvm::SmallVector<llvm::Metadata *, 1> ExportReduceInfo;
795
796  // Add the names of the reduce-style kernel functions to the metadata node.
797  for (auto I = mContext->export_reduce_begin(),
798            E = mContext->export_reduce_end(); I != E; ++I) {
799    ExportReduceInfo.clear();
800
801    ExportReduceInfo.push_back(
802      llvm::MDString::get(mLLVMContext, (*I)->getName().c_str()));
803
804    mExportReduceMetadata->addOperand(
805      llvm::MDNode::get(mLLVMContext, ExportReduceInfo));
806  }
807}
808
809void Backend::dumpExportReduceNewInfo(llvm::Module *M) {
810  if (!mExportReduceNewMetadata) {
811    mExportReduceNewMetadata =
812      M->getOrInsertNamedMetadata(RS_EXPORT_REDUCE_NEW_MN);
813  }
814
815  llvm::SmallVector<llvm::Metadata *, 6> ExportReduceNewInfo;
816  // Add operand to ExportReduceNewInfo, padding out missing operands with
817  // nullptr.
818  auto addOperand = [&ExportReduceNewInfo](uint32_t Idx, llvm::Metadata *N) {
819    while (Idx > ExportReduceNewInfo.size())
820      ExportReduceNewInfo.push_back(nullptr);
821    ExportReduceNewInfo.push_back(N);
822  };
823  // Add string operand to ExportReduceNewInfo, padding out missing operands
824  // with nullptr.
825  // If string is empty, then do not add it unless Always is true.
826  auto addString = [&addOperand, this](uint32_t Idx, const std::string &S,
827                                       bool Always = true) {
828    if (Always || !S.empty())
829      addOperand(Idx, llvm::MDString::get(mLLVMContext, S));
830  };
831
832  // Add the description of the reduction kernels to the metadata node.
833  for (auto I = mContext->export_reduce_new_begin(),
834            E = mContext->export_reduce_new_end();
835       I != E; ++I) {
836    ExportReduceNewInfo.clear();
837
838    int Idx = 0;
839
840    addString(Idx++, (*I)->getNameReduce());
841
842    addOperand(Idx++, llvm::MDString::get(mLLVMContext, llvm::utostr_32((*I)->getAccumulatorTypeSize())));
843
844    llvm::SmallVector<llvm::Metadata *, 2> Accumulator;
845    Accumulator.push_back(
846      llvm::MDString::get(mLLVMContext, (*I)->getNameAccumulator()));
847    Accumulator.push_back(llvm::MDString::get(
848      mLLVMContext,
849      llvm::utostr_32((*I)->getAccumulatorSignatureMetadata())));
850    addOperand(Idx++, llvm::MDTuple::get(mLLVMContext, Accumulator));
851
852    addString(Idx++, (*I)->getNameInitializer(), false);
853    addString(Idx++, (*I)->getNameCombiner(), false);
854    addString(Idx++, (*I)->getNameOutConverter(), false);
855    addString(Idx++, (*I)->getNameHalter(), false);
856
857    mExportReduceNewMetadata->addOperand(
858      llvm::MDTuple::get(mLLVMContext, ExportReduceNewInfo));
859  }
860}
861
862void Backend::dumpExportTypeInfo(llvm::Module *M) {
863  llvm::SmallVector<llvm::Metadata *, 1> ExportTypeInfo;
864
865  for (RSContext::const_export_type_iterator
866          I = mContext->export_types_begin(),
867          E = mContext->export_types_end();
868       I != E;
869       I++) {
870    // First, dump type name list to export
871    const RSExportType *ET = I->getValue();
872
873    ExportTypeInfo.clear();
874    // Type name
875    ExportTypeInfo.push_back(
876        llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
877
878    if (ET->getClass() == RSExportType::ExportClassRecord) {
879      const RSExportRecordType *ERT =
880          static_cast<const RSExportRecordType*>(ET);
881
882      if (mExportTypeMetadata == nullptr)
883        mExportTypeMetadata =
884            M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
885
886      mExportTypeMetadata->addOperand(
887          llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
888
889      // Now, export struct field information to %[struct name]
890      std::string StructInfoMetadataName("%");
891      StructInfoMetadataName.append(ET->getName());
892      llvm::NamedMDNode *StructInfoMetadata =
893          M->getOrInsertNamedMetadata(StructInfoMetadataName);
894      llvm::SmallVector<llvm::Metadata *, 3> FieldInfo;
895
896      slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
897                  "Metadata with same name was created before");
898      for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
899              FE = ERT->fields_end();
900           FI != FE;
901           FI++) {
902        const RSExportRecordType::Field *F = *FI;
903
904        // 1. field name
905        FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
906                                                F->getName().c_str()));
907
908        // 2. field type name
909        FieldInfo.push_back(
910            llvm::MDString::get(mLLVMContext,
911                                F->getType()->getName().c_str()));
912
913        StructInfoMetadata->addOperand(
914            llvm::MDNode::get(mLLVMContext, FieldInfo));
915        FieldInfo.clear();
916      }
917    }   // ET->getClass() == RSExportType::ExportClassRecord
918  }
919}
920
921void Backend::HandleTranslationUnitPost(llvm::Module *M) {
922
923  if (!mContext->is64Bit()) {
924    M->setDataLayout("e-p:32:32-i64:64-v128:64:128-n32-S64");
925  }
926
927  if (!mContext->processExports())
928    return;
929
930  if (mContext->hasExportVar())
931    dumpExportVarInfo(M);
932
933  if (mContext->hasExportFunc())
934    dumpExportFunctionInfo(M);
935
936  if (mContext->hasExportForEach())
937    dumpExportForEachInfo(M);
938
939  if (mContext->hasExportReduce())
940    dumpExportReduceInfo(M);
941
942  if (mContext->hasExportReduceNew())
943    dumpExportReduceNewInfo(M);
944
945  if (mContext->hasExportType())
946    dumpExportTypeInfo(M);
947}
948
949}  // namespace slang
950