NVVMReflect.cpp revision cd81d94322a39503e4a3e87b6ee03d4fcb3465fb
1//===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass replaces occurrences of __nvvm_reflect("string") with an
11// integer based on -nvvm-reflect-list string=<int> option given to this pass.
12// If an undefined string value is seen in a call to __nvvm_reflect("string"),
13// a default value of 0 will be used.
14//
15//===----------------------------------------------------------------------===//
16
17#include "NVPTX.h"
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/StringMap.h"
21#include "llvm/IR/Constants.h"
22#include "llvm/IR/DerivedTypes.h"
23#include "llvm/IR/Function.h"
24#include "llvm/IR/Instructions.h"
25#include "llvm/IR/Intrinsics.h"
26#include "llvm/IR/Module.h"
27#include "llvm/IR/Type.h"
28#include "llvm/Pass.h"
29#include "llvm/Support/CommandLine.h"
30#include "llvm/Support/Debug.h"
31#include "llvm/Support/raw_os_ostream.h"
32#include "llvm/Transforms/Scalar.h"
33#include <map>
34#include <sstream>
35#include <string>
36#include <vector>
37
38#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
39
40using namespace llvm;
41
42#define DEBUG_TYPE "nvptx-reflect"
43
44namespace llvm { void initializeNVVMReflectPass(PassRegistry &); }
45
46namespace {
47class NVVMReflect : public ModulePass {
48private:
49  StringMap<int> VarMap;
50  typedef DenseMap<std::string, int>::iterator VarMapIter;
51
52public:
53  static char ID;
54  NVVMReflect() : ModulePass(ID) {
55    initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
56    VarMap.clear();
57  }
58
59  NVVMReflect(const StringMap<int> &Mapping)
60  : ModulePass(ID) {
61    initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
62    for (StringMap<int>::const_iterator I = Mapping.begin(), E = Mapping.end();
63         I != E; ++I) {
64      VarMap[(*I).getKey()] = (*I).getValue();
65    }
66  }
67
68  void getAnalysisUsage(AnalysisUsage &AU) const override {
69    AU.setPreservesAll();
70  }
71  bool runOnModule(Module &) override;
72
73private:
74  bool handleFunction(Function *ReflectFunction);
75  void setVarMap();
76};
77}
78
79ModulePass *llvm::createNVVMReflectPass() {
80  return new NVVMReflect();
81}
82
83ModulePass *llvm::createNVVMReflectPass(const StringMap<int>& Mapping) {
84  return new NVVMReflect(Mapping);
85}
86
87static cl::opt<bool>
88NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
89                   cl::desc("NVVM reflection, enabled by default"));
90
91char NVVMReflect::ID = 0;
92INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
93                "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
94                false)
95
96static cl::list<std::string>
97ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
98            cl::desc("A list of string=num assignments"),
99            cl::ValueRequired);
100
101/// The command line can look as follows :
102/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
103/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
104/// ReflectList vector. First, each of ReflectList[i] is 'split'
105/// using "," as the delimiter. Then each of this part is split
106/// using "=" as the delimiter.
107void NVVMReflect::setVarMap() {
108  for (unsigned i = 0, e = ReflectList.size(); i != e; ++i) {
109    DEBUG(dbgs() << "Option : "  << ReflectList[i] << "\n");
110    SmallVector<StringRef, 4> NameValList;
111    StringRef(ReflectList[i]).split(NameValList, ",");
112    for (unsigned j = 0, ej = NameValList.size(); j != ej; ++j) {
113      SmallVector<StringRef, 2> NameValPair;
114      NameValList[j].split(NameValPair, "=");
115      assert(NameValPair.size() == 2 && "name=val expected");
116      std::stringstream ValStream(NameValPair[1]);
117      int Val;
118      ValStream >> Val;
119      assert((!(ValStream.fail())) && "integer value expected");
120      VarMap[NameValPair[0]] = Val;
121    }
122  }
123}
124
125bool NVVMReflect::handleFunction(Function *ReflectFunction) {
126  // Validate _reflect function
127  assert(ReflectFunction->isDeclaration() &&
128         "_reflect function should not have a body");
129  assert(ReflectFunction->getReturnType()->isIntegerTy() &&
130         "_reflect's return type should be integer");
131
132  std::vector<Instruction *> ToRemove;
133
134  // Go through the uses of ReflectFunction in this Function.
135  // Each of them should a CallInst with a ConstantArray argument.
136  // First validate that. If the c-string corresponding to the
137  // ConstantArray can be found successfully, see if it can be
138  // found in VarMap. If so, replace the uses of CallInst with the
139  // value found in VarMap. If not, replace the use  with value 0.
140  for (User *U : ReflectFunction->users()) {
141    assert(isa<CallInst>(U) && "Only a call instruction can use _reflect");
142    CallInst *Reflect = cast<CallInst>(U);
143
144    assert((Reflect->getNumOperands() == 2) &&
145           "Only one operand expect for _reflect function");
146    // In cuda, we will have an extra constant-to-generic conversion of
147    // the string.
148    const Value *Str = Reflect->getArgOperand(0);
149    if (isa<CallInst>(Str)) {
150      // CUDA path
151      const CallInst *ConvCall = cast<CallInst>(Str);
152      Str = ConvCall->getArgOperand(0);
153    }
154    assert(isa<ConstantExpr>(Str) &&
155           "Format of _reflect function not recognized");
156    const ConstantExpr *GEP = cast<ConstantExpr>(Str);
157
158    const Value *Sym = GEP->getOperand(0);
159    assert(isa<Constant>(Sym) && "Format of _reflect function not recognized");
160
161    const Constant *SymStr = cast<Constant>(Sym);
162
163    assert(isa<ConstantDataSequential>(SymStr->getOperand(0)) &&
164           "Format of _reflect function not recognized");
165
166    assert(cast<ConstantDataSequential>(SymStr->getOperand(0))->isCString() &&
167           "Format of _reflect function not recognized");
168
169    std::string ReflectArg =
170        cast<ConstantDataSequential>(SymStr->getOperand(0))->getAsString();
171
172    ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
173    DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
174
175    int ReflectVal = 0; // The default value is 0
176    if (VarMap.find(ReflectArg) != VarMap.end()) {
177      ReflectVal = VarMap[ReflectArg];
178    }
179    Reflect->replaceAllUsesWith(
180        ConstantInt::get(Reflect->getType(), ReflectVal));
181    ToRemove.push_back(Reflect);
182  }
183  if (ToRemove.size() == 0)
184    return false;
185
186  for (unsigned i = 0, e = ToRemove.size(); i != e; ++i)
187    ToRemove[i]->eraseFromParent();
188  return true;
189}
190
191bool NVVMReflect::runOnModule(Module &M) {
192  if (!NVVMReflectEnabled)
193    return false;
194
195  setVarMap();
196
197
198  bool Res = false;
199  std::string Name;
200  Type *Tys[1];
201  Type *I8Ty = Type::getInt8Ty(M.getContext());
202  Function *ReflectFunction;
203
204  // Check for standard overloaded versions of llvm.nvvm.reflect
205
206  for (unsigned i = 0; i != 5; ++i) {
207    Tys[0] = PointerType::get(I8Ty, i);
208    Name = Intrinsic::getName(Intrinsic::nvvm_reflect, Tys);
209    ReflectFunction = M.getFunction(Name);
210    if(ReflectFunction != 0) {
211      Res |= handleFunction(ReflectFunction);
212    }
213  }
214
215  ReflectFunction = M.getFunction(NVVM_REFLECT_FUNCTION);
216  // If reflect function is not used, then there will be
217  // no entry in the module.
218  if (ReflectFunction != 0)
219    Res |= handleFunction(ReflectFunction);
220
221  return Res;
222}
223