slang_backend.cpp revision 15e44e66adc350adb4fe0533a442092c64333ab5
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/AST/Decl.h"
24#include "clang/AST/DeclGroup.h"
25
26#include "clang/Basic/Diagnostic.h"
27#include "clang/Basic/TargetInfo.h"
28#include "clang/Basic/TargetOptions.h"
29
30#include "clang/CodeGen/ModuleBuilder.h"
31
32#include "clang/Frontend/CodeGenOptions.h"
33#include "clang/Frontend/FrontendDiagnostic.h"
34
35#include "llvm/ADT/Twine.h"
36#include "llvm/ADT/StringExtras.h"
37
38#include "llvm/Bitcode/ReaderWriter.h"
39
40#include "llvm/CodeGen/RegAllocRegistry.h"
41#include "llvm/CodeGen/SchedulerRegistry.h"
42
43#include "llvm/IR/Constant.h"
44#include "llvm/IR/Constants.h"
45#include "llvm/IR/DataLayout.h"
46#include "llvm/IR/DebugLoc.h"
47#include "llvm/IR/DerivedTypes.h"
48#include "llvm/IR/Function.h"
49#include "llvm/IR/IRBuilder.h"
50#include "llvm/IR/IRPrintingPasses.h"
51#include "llvm/IR/LLVMContext.h"
52#include "llvm/IR/Metadata.h"
53#include "llvm/IR/Module.h"
54
55#include "llvm/Transforms/IPO/PassManagerBuilder.h"
56
57#include "llvm/Target/TargetMachine.h"
58#include "llvm/Target/TargetOptions.h"
59#include "llvm/Support/TargetRegistry.h"
60
61#include "llvm/MC/SubtargetFeature.h"
62
63#include "slang_assert.h"
64#include "slang.h"
65#include "slang_bitcode_gen.h"
66#include "slang_rs_context.h"
67#include "slang_rs_export_foreach.h"
68#include "slang_rs_export_func.h"
69#include "slang_rs_export_reduce.h"
70#include "slang_rs_export_type.h"
71#include "slang_rs_export_var.h"
72#include "slang_rs_metadata.h"
73
74#include "rs_cc_options.h"
75
76#include "strip_unknown_attributes.h"
77
78namespace slang {
79
80void Backend::CreateFunctionPasses() {
81  if (!mPerFunctionPasses) {
82    mPerFunctionPasses = new llvm::legacy::FunctionPassManager(mpModule);
83
84    llvm::PassManagerBuilder PMBuilder;
85    PMBuilder.OptLevel = mCodeGenOpts.OptimizationLevel;
86    PMBuilder.populateFunctionPassManager(*mPerFunctionPasses);
87  }
88}
89
90void Backend::CreateModulePasses() {
91  if (!mPerModulePasses) {
92    mPerModulePasses = new llvm::legacy::PassManager();
93
94    llvm::PassManagerBuilder PMBuilder;
95    PMBuilder.OptLevel = mCodeGenOpts.OptimizationLevel;
96    PMBuilder.SizeLevel = mCodeGenOpts.OptimizeSize;
97    if (mCodeGenOpts.UnitAtATime) {
98      PMBuilder.DisableUnitAtATime = 0;
99    } else {
100      PMBuilder.DisableUnitAtATime = 1;
101    }
102
103    if (mCodeGenOpts.UnrollLoops) {
104      PMBuilder.DisableUnrollLoops = 0;
105    } else {
106      PMBuilder.DisableUnrollLoops = 1;
107    }
108
109    PMBuilder.populateModulePassManager(*mPerModulePasses);
110    // Add a pass to strip off unknown/unsupported attributes.
111    mPerModulePasses->add(createStripUnknownAttributesPass());
112  }
113}
114
115bool Backend::CreateCodeGenPasses() {
116  if ((mOT != Slang::OT_Assembly) && (mOT != Slang::OT_Object))
117    return true;
118
119  // Now we add passes for code emitting
120  if (mCodeGenPasses) {
121    return true;
122  } else {
123    mCodeGenPasses = new llvm::legacy::FunctionPassManager(mpModule);
124  }
125
126  // Create the TargetMachine for generating code.
127  std::string Triple = mpModule->getTargetTriple();
128
129  std::string Error;
130  const llvm::Target* TargetInfo =
131      llvm::TargetRegistry::lookupTarget(Triple, Error);
132  if (TargetInfo == nullptr) {
133    mDiagEngine.Report(clang::diag::err_fe_unable_to_create_target) << Error;
134    return false;
135  }
136
137  // Target Machine Options
138  llvm::TargetOptions Options;
139
140  // Use soft-float ABI for ARM (which is the target used by Slang during code
141  // generation).  Codegen still uses hardware FPU by default.  To use software
142  // floating point, add 'soft-float' feature to FeaturesStr below.
143  Options.FloatABIType = llvm::FloatABI::Soft;
144
145  // BCC needs all unknown symbols resolved at compilation time. So we don't
146  // need any relocation model.
147  llvm::Reloc::Model RM = llvm::Reloc::Static;
148
149  // This is set for the linker (specify how large of the virtual addresses we
150  // can access for all unknown symbols.)
151  llvm::CodeModel::Model CM;
152  if (mpModule->getDataLayout().getPointerSize() == 4) {
153    CM = llvm::CodeModel::Small;
154  } else {
155    // The target may have pointer size greater than 32 (e.g. x86_64
156    // architecture) may need large data address model
157    CM = llvm::CodeModel::Medium;
158  }
159
160  // Setup feature string
161  std::string FeaturesStr;
162  if (mTargetOpts.CPU.size() || mTargetOpts.Features.size()) {
163    llvm::SubtargetFeatures Features;
164
165    for (std::vector<std::string>::const_iterator
166             I = mTargetOpts.Features.begin(), E = mTargetOpts.Features.end();
167         I != E;
168         I++)
169      Features.AddFeature(*I);
170
171    FeaturesStr = Features.getString();
172  }
173
174  llvm::TargetMachine *TM =
175    TargetInfo->createTargetMachine(Triple, mTargetOpts.CPU, FeaturesStr,
176                                    Options, RM, CM);
177
178  // Register scheduler
179  llvm::RegisterScheduler::setDefault(llvm::createDefaultScheduler);
180
181  // Register allocation policy:
182  //  createFastRegisterAllocator: fast but bad quality
183  //  createGreedyRegisterAllocator: not so fast but good quality
184  llvm::RegisterRegAlloc::setDefault((mCodeGenOpts.OptimizationLevel == 0) ?
185                                     llvm::createFastRegisterAllocator :
186                                     llvm::createGreedyRegisterAllocator);
187
188  llvm::CodeGenOpt::Level OptLevel = llvm::CodeGenOpt::Default;
189  if (mCodeGenOpts.OptimizationLevel == 0) {
190    OptLevel = llvm::CodeGenOpt::None;
191  } else if (mCodeGenOpts.OptimizationLevel == 3) {
192    OptLevel = llvm::CodeGenOpt::Aggressive;
193  }
194
195  llvm::TargetMachine::CodeGenFileType CGFT =
196      llvm::TargetMachine::CGFT_AssemblyFile;
197  if (mOT == Slang::OT_Object) {
198    CGFT = llvm::TargetMachine::CGFT_ObjectFile;
199  }
200  if (TM->addPassesToEmitFile(*mCodeGenPasses, mBufferOutStream,
201                              CGFT, OptLevel)) {
202    mDiagEngine.Report(clang::diag::err_fe_unable_to_interface_with_target);
203    return false;
204  }
205
206  return true;
207}
208
209Backend::Backend(RSContext *Context, clang::DiagnosticsEngine *DiagEngine,
210                 const RSCCOptions &Opts, const clang::CodeGenOptions &CodeGenOpts,
211                 const clang::TargetOptions &TargetOpts, PragmaList *Pragmas,
212                 llvm::raw_ostream *OS, Slang::OutputType OT,
213                 clang::SourceManager &SourceMgr, bool AllowRSPrefix,
214                 bool IsFilterscript)
215    : ASTConsumer(), mTargetOpts(TargetOpts), mpModule(nullptr), mpOS(OS),
216      mOT(OT), mGen(nullptr), mPerFunctionPasses(nullptr),
217      mPerModulePasses(nullptr), mCodeGenPasses(nullptr),
218      mBufferOutStream(*mpOS), mContext(Context),
219      mSourceMgr(SourceMgr), mASTPrint(Opts.mASTPrint), mAllowRSPrefix(AllowRSPrefix),
220      mIsFilterscript(IsFilterscript), mExportVarMetadata(nullptr),
221      mExportFuncMetadata(nullptr), mExportForEachNameMetadata(nullptr),
222      mExportForEachSignatureMetadata(nullptr), mExportReduceMetadata(nullptr),
223      mExportReduceNewMetadata(nullptr),
224      mExportTypeMetadata(nullptr), mRSObjectSlotsMetadata(nullptr),
225      mRefCount(mContext->getASTContext()),
226      mASTChecker(Context, Context->getTargetAPI(), IsFilterscript),
227      mForEachHandler(Context),
228      mLLVMContext(llvm::getGlobalContext()), mDiagEngine(*DiagEngine),
229      mCodeGenOpts(CodeGenOpts), mPragmas(Pragmas) {
230  mGen = CreateLLVMCodeGen(mDiagEngine, "", mCodeGenOpts, mLLVMContext);
231}
232
233void Backend::Initialize(clang::ASTContext &Ctx) {
234  mGen->Initialize(Ctx);
235
236  mpModule = mGen->GetModule();
237}
238
239void Backend::HandleTranslationUnit(clang::ASTContext &Ctx) {
240  HandleTranslationUnitPre(Ctx);
241
242  if (mASTPrint)
243    Ctx.getTranslationUnitDecl()->dump();
244
245  mGen->HandleTranslationUnit(Ctx);
246
247  // Here, we complete a translation unit (whole translation unit is now in LLVM
248  // IR). Now, interact with LLVM backend to generate actual machine code (asm
249  // or machine code, whatever.)
250
251  // Silently ignore if we weren't initialized for some reason.
252  if (!mpModule)
253    return;
254
255  llvm::Module *M = mGen->ReleaseModule();
256  if (!M) {
257    // The module has been released by IR gen on failures, do not double free.
258    mpModule = nullptr;
259    return;
260  }
261
262  slangAssert(mpModule == M &&
263              "Unexpected module change during LLVM IR generation");
264
265  // Insert #pragma information into metadata section of module
266  if (!mPragmas->empty()) {
267    llvm::NamedMDNode *PragmaMetadata =
268        mpModule->getOrInsertNamedMetadata(Slang::PragmaMetadataName);
269    for (PragmaList::const_iterator I = mPragmas->begin(), E = mPragmas->end();
270         I != E;
271         I++) {
272      llvm::SmallVector<llvm::Metadata*, 2> Pragma;
273      // Name goes first
274      Pragma.push_back(llvm::MDString::get(mLLVMContext, I->first));
275      // And then value
276      Pragma.push_back(llvm::MDString::get(mLLVMContext, I->second));
277
278      // Create MDNode and insert into PragmaMetadata
279      PragmaMetadata->addOperand(
280          llvm::MDNode::get(mLLVMContext, Pragma));
281    }
282  }
283
284  HandleTranslationUnitPost(mpModule);
285
286  // Create passes for optimization and code emission
287
288  // Create and run per-function passes
289  CreateFunctionPasses();
290  if (mPerFunctionPasses) {
291    mPerFunctionPasses->doInitialization();
292
293    for (llvm::Module::iterator I = mpModule->begin(), E = mpModule->end();
294         I != E;
295         I++)
296      if (!I->isDeclaration())
297        mPerFunctionPasses->run(*I);
298
299    mPerFunctionPasses->doFinalization();
300  }
301
302  // Create and run module passes
303  CreateModulePasses();
304  if (mPerModulePasses)
305    mPerModulePasses->run(*mpModule);
306
307  switch (mOT) {
308    case Slang::OT_Assembly:
309    case Slang::OT_Object: {
310      if (!CreateCodeGenPasses())
311        return;
312
313      mCodeGenPasses->doInitialization();
314
315      for (llvm::Module::iterator I = mpModule->begin(), E = mpModule->end();
316          I != E;
317          I++)
318        if (!I->isDeclaration())
319          mCodeGenPasses->run(*I);
320
321      mCodeGenPasses->doFinalization();
322      break;
323    }
324    case Slang::OT_LLVMAssembly: {
325      llvm::legacy::PassManager *LLEmitPM = new llvm::legacy::PassManager();
326      LLEmitPM->add(llvm::createPrintModulePass(mBufferOutStream));
327      LLEmitPM->run(*mpModule);
328      break;
329    }
330    case Slang::OT_Bitcode: {
331      writeBitcode(mBufferOutStream, *mpModule, getTargetAPI(),
332                   mCodeGenOpts.OptimizationLevel, mCodeGenOpts.getDebugInfo());
333      break;
334    }
335    case Slang::OT_Nothing: {
336      return;
337    }
338    default: {
339      slangAssert(false && "Unknown output type");
340    }
341  }
342
343  mBufferOutStream.flush();
344}
345
346void Backend::HandleTagDeclDefinition(clang::TagDecl *D) {
347  mGen->HandleTagDeclDefinition(D);
348}
349
350void Backend::CompleteTentativeDefinition(clang::VarDecl *D) {
351  mGen->CompleteTentativeDefinition(D);
352}
353
354Backend::~Backend() {
355  delete mpModule;
356  delete mGen;
357  delete mPerFunctionPasses;
358  delete mPerModulePasses;
359  delete mCodeGenPasses;
360}
361
362// 1) Add zero initialization of local RS object types
363void Backend::AnnotateFunction(clang::FunctionDecl *FD) {
364  if (FD &&
365      FD->hasBody() &&
366      !Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
367    mRefCount.Init();
368    mRefCount.Visit(FD->getBody());
369  }
370}
371
372void Backend::LowerRSForEachCall(clang::FunctionDecl *FD) {
373  // Skip this AST walking for lower API levels.
374  if (getTargetAPI() < SLANG_DEVELOPMENT_TARGET_API) {
375    return;
376  }
377
378  if (!FD || !FD->hasBody() ||
379      Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
380    return;
381  }
382
383  mForEachHandler.VisitStmt(FD->getBody());
384}
385
386bool Backend::HandleTopLevelDecl(clang::DeclGroupRef D) {
387  // Disallow user-defined functions with prefix "rs"
388  if (!mAllowRSPrefix) {
389    // Iterate all function declarations in the program.
390    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
391         I != E; I++) {
392      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
393      if (FD == nullptr)
394        continue;
395      if (!FD->getName().startswith("rs"))  // Check prefix
396        continue;
397      if (!Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr))
398        mContext->ReportError(FD->getLocation(),
399                              "invalid function name prefix, "
400                              "\"rs\" is reserved: '%0'")
401            << FD->getName();
402    }
403  }
404
405  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
406    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
407    // Process any non-static function declarations
408    if (FD && FD->isGlobal()) {
409      // Check that we don't have any array parameters being misintrepeted as
410      // kernel pointers due to the C type system's array to pointer decay.
411      size_t numParams = FD->getNumParams();
412      for (size_t i = 0; i < numParams; i++) {
413        const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
414        clang::QualType QT = PVD->getOriginalType();
415        if (QT->isArrayType()) {
416          mContext->ReportError(
417              PVD->getTypeSpecStartLoc(),
418              "exported function parameters may not have array type: %0")
419              << QT;
420        }
421      }
422      AnnotateFunction(FD);
423    }
424
425    if (getTargetAPI() == SLANG_DEVELOPMENT_TARGET_API) {
426      if (FD && FD->hasBody() &&
427          RSExportForEach::isRSForEachFunc(getTargetAPI(), FD)) {
428        // Log kernels by their names, and assign them slot numbers.
429        if (!Slang::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
430            mContext->addForEach(FD);
431        }
432      } else {
433        // Look for any kernel launch calls and translate them into using the
434        // internal API.
435        // TODO: Simply ignores kernel launch inside a kernel for now.
436        // Needs more rigorous and comprehensive checks.
437        LowerRSForEachCall(FD);
438      }
439    }
440  }
441
442  return mGen->HandleTopLevelDecl(D);
443}
444
445void Backend::HandleTranslationUnitPre(clang::ASTContext &C) {
446  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
447
448  // If we have an invalid RS/FS AST, don't check further.
449  if (!mASTChecker.Validate()) {
450    return;
451  }
452
453  if (mIsFilterscript) {
454    mContext->addPragma("rs_fp_relaxed", "");
455  }
456
457  int version = mContext->getVersion();
458  if (version == 0) {
459    // Not setting a version is an error
460    mDiagEngine.Report(
461        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
462        mDiagEngine.getCustomDiagID(
463            clang::DiagnosticsEngine::Error,
464            "missing pragma for version in source file"));
465  } else {
466    slangAssert(version == 1);
467  }
468
469  if (mContext->getReflectJavaPackageName().empty()) {
470    mDiagEngine.Report(
471        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
472        mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
473                                    "missing \"#pragma rs "
474                                    "java_package_name(com.foo.bar)\" "
475                                    "in source file"));
476    return;
477  }
478
479  // Create a static global destructor if necessary (to handle RS object
480  // runtime cleanup).
481  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
482  if (FD) {
483    HandleTopLevelDecl(clang::DeclGroupRef(FD));
484  }
485
486  // Process any static function declarations
487  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
488          E = TUDecl->decls_end(); I != E; I++) {
489    if ((I->getKind() >= clang::Decl::firstFunction) &&
490        (I->getKind() <= clang::Decl::lastFunction)) {
491      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
492      if (FD && !FD->isGlobal()) {
493        AnnotateFunction(FD);
494      }
495    }
496  }
497}
498
499///////////////////////////////////////////////////////////////////////////////
500void Backend::dumpExportVarInfo(llvm::Module *M) {
501  int slotCount = 0;
502  if (mExportVarMetadata == nullptr)
503    mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
504
505  llvm::SmallVector<llvm::Metadata *, 2> ExportVarInfo;
506
507  // We emit slot information (#rs_object_slots) for any reference counted
508  // RS type or pointer (which can also be bound).
509
510  for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
511          E = mContext->export_vars_end();
512       I != E;
513       I++) {
514    const RSExportVar *EV = *I;
515    const RSExportType *ET = EV->getType();
516    bool countsAsRSObject = false;
517
518    // Variable name
519    ExportVarInfo.push_back(
520        llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
521
522    // Type name
523    switch (ET->getClass()) {
524      case RSExportType::ExportClassPrimitive: {
525        const RSExportPrimitiveType *PT =
526            static_cast<const RSExportPrimitiveType*>(ET);
527        ExportVarInfo.push_back(
528            llvm::MDString::get(
529              mLLVMContext, llvm::utostr_32(PT->getType())));
530        if (PT->isRSObjectType()) {
531          countsAsRSObject = true;
532        }
533        break;
534      }
535      case RSExportType::ExportClassPointer: {
536        ExportVarInfo.push_back(
537            llvm::MDString::get(
538              mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
539                ->getPointeeType()->getName()).c_str()));
540        break;
541      }
542      case RSExportType::ExportClassMatrix: {
543        ExportVarInfo.push_back(
544            llvm::MDString::get(
545              mLLVMContext, llvm::utostr_32(
546                  /* TODO Strange value.  This pushes just a number, quite
547                   * different than the other cases.  What is this used for?
548                   * These are the metadata values that some partner drivers
549                   * want to reference (for TBAA, etc.). We may want to look
550                   * at whether these provide any reasonable value (or have
551                   * distinct enough values to actually depend on).
552                   */
553                DataTypeRSMatrix2x2 +
554                static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
555        break;
556      }
557      case RSExportType::ExportClassVector:
558      case RSExportType::ExportClassConstantArray:
559      case RSExportType::ExportClassRecord: {
560        ExportVarInfo.push_back(
561            llvm::MDString::get(mLLVMContext,
562              EV->getType()->getName().c_str()));
563        break;
564      }
565    }
566
567    mExportVarMetadata->addOperand(
568        llvm::MDNode::get(mLLVMContext, ExportVarInfo));
569    ExportVarInfo.clear();
570
571    if (mRSObjectSlotsMetadata == nullptr) {
572      mRSObjectSlotsMetadata =
573          M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
574    }
575
576    if (countsAsRSObject) {
577      mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
578          llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
579    }
580
581    slotCount++;
582  }
583}
584
585void Backend::dumpExportFunctionInfo(llvm::Module *M) {
586  if (mExportFuncMetadata == nullptr)
587    mExportFuncMetadata =
588        M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
589
590  llvm::SmallVector<llvm::Metadata *, 1> ExportFuncInfo;
591
592  for (RSContext::const_export_func_iterator
593          I = mContext->export_funcs_begin(),
594          E = mContext->export_funcs_end();
595       I != E;
596       I++) {
597    const RSExportFunc *EF = *I;
598
599    // Function name
600    if (!EF->hasParam()) {
601      ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
602                                                   EF->getName().c_str()));
603    } else {
604      llvm::Function *F = M->getFunction(EF->getName());
605      llvm::Function *HelperFunction;
606      const std::string HelperFunctionName(".helper_" + EF->getName());
607
608      slangAssert(F && "Function marked as exported disappeared in Bitcode");
609
610      // Create helper function
611      {
612        llvm::StructType *HelperFunctionParameterTy = nullptr;
613        std::vector<bool> isStructInput;
614
615        if (!F->getArgumentList().empty()) {
616          std::vector<llvm::Type*> HelperFunctionParameterTys;
617          for (llvm::Function::arg_iterator AI = F->arg_begin(),
618                   AE = F->arg_end(); AI != AE; AI++) {
619              if (AI->getType()->isPointerTy() && AI->getType()->getPointerElementType()->isStructTy()) {
620                  HelperFunctionParameterTys.push_back(AI->getType()->getPointerElementType());
621                  isStructInput.push_back(true);
622              } else {
623                  HelperFunctionParameterTys.push_back(AI->getType());
624                  isStructInput.push_back(false);
625              }
626          }
627          HelperFunctionParameterTy =
628              llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
629        }
630
631        if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
632          fprintf(stderr, "Failed to export function %s: parameter type "
633                          "mismatch during creation of helper function.\n",
634                  EF->getName().c_str());
635
636          const RSExportRecordType *Expected = EF->getParamPacketType();
637          if (Expected) {
638            fprintf(stderr, "Expected:\n");
639            Expected->getLLVMType()->dump();
640          }
641          if (HelperFunctionParameterTy) {
642            fprintf(stderr, "Got:\n");
643            HelperFunctionParameterTy->dump();
644          }
645        }
646
647        std::vector<llvm::Type*> Params;
648        if (HelperFunctionParameterTy) {
649          llvm::PointerType *HelperFunctionParameterTyP =
650              llvm::PointerType::getUnqual(HelperFunctionParameterTy);
651          Params.push_back(HelperFunctionParameterTyP);
652        }
653
654        llvm::FunctionType * HelperFunctionType =
655            llvm::FunctionType::get(F->getReturnType(),
656                                    Params,
657                                    /* IsVarArgs = */false);
658
659        HelperFunction =
660            llvm::Function::Create(HelperFunctionType,
661                                   llvm::GlobalValue::ExternalLinkage,
662                                   HelperFunctionName,
663                                   M);
664
665        HelperFunction->addFnAttr(llvm::Attribute::NoInline);
666        HelperFunction->setCallingConv(F->getCallingConv());
667
668        // Create helper function body
669        {
670          llvm::Argument *HelperFunctionParameter =
671              &(*HelperFunction->arg_begin());
672          llvm::BasicBlock *BB =
673              llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
674          llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
675          llvm::SmallVector<llvm::Value*, 6> Params;
676          llvm::Value *Idx[2];
677
678          Idx[0] =
679              llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
680
681          // getelementptr and load instruction for all elements in
682          // parameter .p
683          for (size_t i = 0; i < EF->getNumParameters(); i++) {
684            // getelementptr
685            Idx[1] = llvm::ConstantInt::get(
686              llvm::Type::getInt32Ty(mLLVMContext), i);
687
688            llvm::Value *Ptr = NULL;
689
690            Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
691
692            // Load is only required for non-struct ptrs
693            if (isStructInput[i]) {
694                Params.push_back(Ptr);
695            } else {
696                llvm::Value *V = IB->CreateLoad(Ptr);
697                Params.push_back(V);
698            }
699          }
700
701          // Call and pass the all elements as parameter to F
702          llvm::CallInst *CI = IB->CreateCall(F, Params);
703
704          CI->setCallingConv(F->getCallingConv());
705
706          if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext)) {
707            IB->CreateRetVoid();
708          } else {
709            IB->CreateRet(CI);
710          }
711
712          delete IB;
713        }
714      }
715
716      ExportFuncInfo.push_back(
717          llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
718    }
719
720    mExportFuncMetadata->addOperand(
721        llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
722    ExportFuncInfo.clear();
723  }
724}
725
726void Backend::dumpExportForEachInfo(llvm::Module *M) {
727  if (mExportForEachNameMetadata == nullptr) {
728    mExportForEachNameMetadata =
729        M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_NAME_MN);
730  }
731  if (mExportForEachSignatureMetadata == nullptr) {
732    mExportForEachSignatureMetadata =
733        M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
734  }
735
736  llvm::SmallVector<llvm::Metadata *, 1> ExportForEachName;
737  llvm::SmallVector<llvm::Metadata *, 1> ExportForEachInfo;
738
739  for (RSContext::const_export_foreach_iterator
740          I = mContext->export_foreach_begin(),
741          E = mContext->export_foreach_end();
742       I != E;
743       I++) {
744    const RSExportForEach *EFE = *I;
745
746    ExportForEachName.push_back(
747        llvm::MDString::get(mLLVMContext, EFE->getName().c_str()));
748
749    mExportForEachNameMetadata->addOperand(
750        llvm::MDNode::get(mLLVMContext, ExportForEachName));
751    ExportForEachName.clear();
752
753    ExportForEachInfo.push_back(
754        llvm::MDString::get(mLLVMContext,
755                            llvm::utostr_32(EFE->getSignatureMetadata())));
756
757    mExportForEachSignatureMetadata->addOperand(
758        llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
759    ExportForEachInfo.clear();
760  }
761}
762
763void Backend::dumpExportReduceInfo(llvm::Module *M) {
764  if (!mExportReduceMetadata) {
765    mExportReduceMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_REDUCE_MN);
766  }
767
768  llvm::SmallVector<llvm::Metadata *, 1> ExportReduceInfo;
769
770  // Add the names of the reduce-style kernel functions to the metadata node.
771  for (auto I = mContext->export_reduce_begin(),
772            E = mContext->export_reduce_end(); I != E; ++I) {
773    ExportReduceInfo.clear();
774
775    ExportReduceInfo.push_back(
776      llvm::MDString::get(mLLVMContext, (*I)->getName().c_str()));
777
778    mExportReduceMetadata->addOperand(
779      llvm::MDNode::get(mLLVMContext, ExportReduceInfo));
780  }
781}
782
783void Backend::dumpExportReduceNewInfo(llvm::Module *M) {
784  if (!mExportReduceNewMetadata) {
785    mExportReduceNewMetadata =
786      M->getOrInsertNamedMetadata(RS_EXPORT_REDUCE_NEW_MN);
787  }
788
789  llvm::SmallVector<llvm::Metadata *, 6> ExportReduceNewInfo;
790  // Add operand to ExportReduceNewInfo, padding out missing operands with
791  // nullptr.
792  auto addOperand = [&ExportReduceNewInfo](uint32_t Idx, llvm::Metadata *N) {
793    while (Idx > ExportReduceNewInfo.size())
794      ExportReduceNewInfo.push_back(nullptr);
795    ExportReduceNewInfo.push_back(N);
796  };
797  // Add string operand to ExportReduceNewInfo, padding out missing operands
798  // with nullptr.
799  // If string is empty, then do not add it unless Always is true.
800  auto addString = [&addOperand, this](uint32_t Idx, const std::string &S,
801                                       bool Always = true) {
802    if (Always || !S.empty())
803      addOperand(Idx, llvm::MDString::get(mLLVMContext, S));
804  };
805
806  // Add the description of the reduction kernels to the metadata node.
807  for (auto I = mContext->export_reduce_new_begin(),
808            E = mContext->export_reduce_new_end();
809       I != E; ++I) {
810    ExportReduceNewInfo.clear();
811
812    addString(0, (*I)->getNameReduce());
813    addString(1, (*I)->getNameInitializer());
814
815    llvm::SmallVector<llvm::Metadata *, 2> Accumulator;
816    Accumulator.push_back(
817      llvm::MDString::get(mLLVMContext, (*I)->getNameAccumulator()));
818    Accumulator.push_back(llvm::MDString::get(
819      mLLVMContext,
820      llvm::utostr_32(0))); // TODO: emit actual accumulator signature bits
821    addOperand(2, llvm::MDTuple::get(mLLVMContext, Accumulator));
822
823    addString(3, (*I)->getNameCombiner(), false);
824    addString(4, (*I)->getNameOutConverter(), false);
825    addString(5, (*I)->getNameHalter(), false);
826
827    mExportReduceNewMetadata->addOperand(
828      llvm::MDTuple::get(mLLVMContext, ExportReduceNewInfo));
829  }
830}
831
832void Backend::dumpExportTypeInfo(llvm::Module *M) {
833  llvm::SmallVector<llvm::Metadata *, 1> ExportTypeInfo;
834
835  for (RSContext::const_export_type_iterator
836          I = mContext->export_types_begin(),
837          E = mContext->export_types_end();
838       I != E;
839       I++) {
840    // First, dump type name list to export
841    const RSExportType *ET = I->getValue();
842
843    ExportTypeInfo.clear();
844    // Type name
845    ExportTypeInfo.push_back(
846        llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
847
848    if (ET->getClass() == RSExportType::ExportClassRecord) {
849      const RSExportRecordType *ERT =
850          static_cast<const RSExportRecordType*>(ET);
851
852      if (mExportTypeMetadata == nullptr)
853        mExportTypeMetadata =
854            M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
855
856      mExportTypeMetadata->addOperand(
857          llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
858
859      // Now, export struct field information to %[struct name]
860      std::string StructInfoMetadataName("%");
861      StructInfoMetadataName.append(ET->getName());
862      llvm::NamedMDNode *StructInfoMetadata =
863          M->getOrInsertNamedMetadata(StructInfoMetadataName);
864      llvm::SmallVector<llvm::Metadata *, 3> FieldInfo;
865
866      slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
867                  "Metadata with same name was created before");
868      for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
869              FE = ERT->fields_end();
870           FI != FE;
871           FI++) {
872        const RSExportRecordType::Field *F = *FI;
873
874        // 1. field name
875        FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
876                                                F->getName().c_str()));
877
878        // 2. field type name
879        FieldInfo.push_back(
880            llvm::MDString::get(mLLVMContext,
881                                F->getType()->getName().c_str()));
882
883        StructInfoMetadata->addOperand(
884            llvm::MDNode::get(mLLVMContext, FieldInfo));
885        FieldInfo.clear();
886      }
887    }   // ET->getClass() == RSExportType::ExportClassRecord
888  }
889}
890
891void Backend::HandleTranslationUnitPost(llvm::Module *M) {
892
893  if (!mContext->is64Bit()) {
894    M->setDataLayout("e-p:32:32-i64:64-v128:64:128-n32-S64");
895  }
896
897  if (!mContext->processExports()) {
898    return;
899  }
900
901  if (mContext->hasExportVar())
902    dumpExportVarInfo(M);
903
904  if (mContext->hasExportFunc())
905    dumpExportFunctionInfo(M);
906
907  if (mContext->hasExportForEach())
908    dumpExportForEachInfo(M);
909
910  if (mContext->hasExportReduce())
911    dumpExportReduceInfo(M);
912
913  if (mContext->hasExportReduceNew())
914    dumpExportReduceNewInfo(M);
915
916  if (mContext->hasExportType())
917    dumpExportTypeInfo(M);
918}
919
920}  // namespace slang
921