// CGCUDANV.cpp (revision 3b844ba7d5be205a9b4f5f0b0d1b7978977f4b8c)
//===----- CGCUDANV.cpp - Interface to NVIDIA CUDA Runtime ----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file provides a class for CUDA code generation targeting the NVIDIA
// CUDA runtime library.
//
//===----------------------------------------------------------------------===//
14
15#include "CGCUDARuntime.h"
16#include "CodeGenFunction.h"
17#include "CodeGenModule.h"
18#include "clang/AST/Decl.h"
19#include "llvm/IR/BasicBlock.h"
20#include "llvm/IR/Constants.h"
21#include "llvm/IR/DerivedTypes.h"
22#include "llvm/Support/CallSite.h"
23#include <vector>
24
25using namespace clang;
26using namespace CodeGen;
27
28namespace {
29
30class CGNVCUDARuntime : public CGCUDARuntime {
31
32private:
33  llvm::Type *IntTy, *SizeTy;
34  llvm::PointerType *CharPtrTy, *VoidPtrTy;
35
36  llvm::Constant *getSetupArgumentFn() const;
37  llvm::Constant *getLaunchFn() const;
38
39public:
40  CGNVCUDARuntime(CodeGenModule &CGM);
41
42  void EmitDeviceStubBody(CodeGenFunction &CGF, FunctionArgList &Args);
43};
44
45}
46
47CGNVCUDARuntime::CGNVCUDARuntime(CodeGenModule &CGM) : CGCUDARuntime(CGM) {
48  CodeGen::CodeGenTypes &Types = CGM.getTypes();
49  ASTContext &Ctx = CGM.getContext();
50
51  IntTy = Types.ConvertType(Ctx.IntTy);
52  SizeTy = Types.ConvertType(Ctx.getSizeType());
53
54  CharPtrTy = llvm::PointerType::getUnqual(Types.ConvertType(Ctx.CharTy));
55  VoidPtrTy = cast<llvm::PointerType>(Types.ConvertType(Ctx.VoidPtrTy));
56}
57
58llvm::Constant *CGNVCUDARuntime::getSetupArgumentFn() const {
59  // cudaError_t cudaSetupArgument(void *, size_t, size_t)
60  std::vector<llvm::Type*> Params;
61  Params.push_back(VoidPtrTy);
62  Params.push_back(SizeTy);
63  Params.push_back(SizeTy);
64  return CGM.CreateRuntimeFunction(llvm::FunctionType::get(IntTy,
65                                                           Params, false),
66                                   "cudaSetupArgument");
67}
68
69llvm::Constant *CGNVCUDARuntime::getLaunchFn() const {
70  // cudaError_t cudaLaunch(char *)
71  std::vector<llvm::Type*> Params;
72  Params.push_back(CharPtrTy);
73  return CGM.CreateRuntimeFunction(llvm::FunctionType::get(IntTy,
74                                                           Params, false),
75                                   "cudaLaunch");
76}
77
78void CGNVCUDARuntime::EmitDeviceStubBody(CodeGenFunction &CGF,
79                                         FunctionArgList &Args) {
80  // Build the argument value list and the argument stack struct type.
81  llvm::SmallVector<llvm::Value *, 16> ArgValues;
82  std::vector<llvm::Type *> ArgTypes;
83  for (FunctionArgList::const_iterator I = Args.begin(), E = Args.end();
84       I != E; ++I) {
85    llvm::Value *V = CGF.GetAddrOfLocalVar(*I);
86    ArgValues.push_back(V);
87    assert(isa<llvm::PointerType>(V->getType()) && "Arg type not PointerType");
88    ArgTypes.push_back(cast<llvm::PointerType>(V->getType())->getElementType());
89  }
90  llvm::StructType *ArgStackTy = llvm::StructType::get(
91      CGF.getLLVMContext(), ArgTypes);
92
93  llvm::BasicBlock *EndBlock = CGF.createBasicBlock("setup.end");
94
95  // Emit the calls to cudaSetupArgument
96  llvm::Constant *cudaSetupArgFn = getSetupArgumentFn();
97  for (unsigned I = 0, E = Args.size(); I != E; ++I) {
98    llvm::Value *Args[3];
99    llvm::BasicBlock *NextBlock = CGF.createBasicBlock("setup.next");
100    Args[0] = CGF.Builder.CreatePointerCast(ArgValues[I], VoidPtrTy);
101    Args[1] = CGF.Builder.CreateIntCast(
102        llvm::ConstantExpr::getSizeOf(ArgTypes[I]),
103        SizeTy, false);
104    Args[2] = CGF.Builder.CreateIntCast(
105        llvm::ConstantExpr::getOffsetOf(ArgStackTy, I),
106        SizeTy, false);
107    llvm::CallSite CS = CGF.EmitCallOrInvoke(cudaSetupArgFn, Args);
108    llvm::Constant *Zero = llvm::ConstantInt::get(IntTy, 0);
109    llvm::Value *CSZero = CGF.Builder.CreateICmpEQ(CS.getInstruction(), Zero);
110    CGF.Builder.CreateCondBr(CSZero, NextBlock, EndBlock);
111    CGF.EmitBlock(NextBlock);
112  }
113
114  // Emit the call to cudaLaunch
115  llvm::Constant *cudaLaunchFn = getLaunchFn();
116  llvm::Value *Arg = CGF.Builder.CreatePointerCast(CGF.CurFn, CharPtrTy);
117  CGF.EmitCallOrInvoke(cudaLaunchFn, Arg);
118  CGF.EmitBranch(EndBlock);
119
120  CGF.EmitBlock(EndBlock);
121}
122
123CGCUDARuntime *CodeGen::CreateNVCUDARuntime(CodeGenModule &CGM) {
124  return new CGNVCUDARuntime(CGM);
125}
126