1//===-- JumpInstrTables.cpp: Jump-Instruction Tables ----------------------===//
2//
3// This file is distributed under the University of Illinois Open Source
4// License. See LICENSE.TXT for details.
5//
6//===----------------------------------------------------------------------===//
7///
8/// \file
9/// \brief An implementation of jump-instruction tables.
10///
11//===----------------------------------------------------------------------===//
12
13#define DEBUG_TYPE "jt"
14
15#include "llvm/CodeGen/JumpInstrTables.h"
16
17#include "llvm/ADT/Statistic.h"
18#include "llvm/Analysis/JumpInstrTableInfo.h"
19#include "llvm/CodeGen/Passes.h"
20#include "llvm/IR/Attributes.h"
21#include "llvm/IR/CallSite.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/DerivedTypes.h"
24#include "llvm/IR/Function.h"
25#include "llvm/IR/LLVMContext.h"
26#include "llvm/IR/Module.h"
27#include "llvm/IR/Operator.h"
28#include "llvm/IR/Type.h"
29#include "llvm/IR/Verifier.h"
30#include "llvm/Support/CommandLine.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/Support/raw_ostream.h"
33
34#include <vector>
35
36using namespace llvm;
37
38char JumpInstrTables::ID = 0;
39
40INITIALIZE_PASS_BEGIN(JumpInstrTables, "jump-instr-tables",
41                      "Jump-Instruction Tables", true, true)
42INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
43INITIALIZE_PASS_END(JumpInstrTables, "jump-instr-tables",
44                    "Jump-Instruction Tables", true, true)
45
46STATISTIC(NumJumpTables, "Number of indirect call tables generated");
47STATISTIC(NumFuncsInJumpTables, "Number of functions in the jump tables");
48
49ModulePass *llvm::createJumpInstrTablesPass() {
50  // The default implementation uses a single table for all functions.
51  return new JumpInstrTables(JumpTable::Single);
52}
53
54ModulePass *llvm::createJumpInstrTablesPass(JumpTable::JumpTableType JTT) {
55  return new JumpInstrTables(JTT);
56}
57
58namespace {
59static const char jump_func_prefix[] = "__llvm_jump_instr_table_";
60static const char jump_section_prefix[] = ".jump.instr.table.text.";
61
62// Checks to see if a given CallSite is making an indirect call, including
63// cases where the indirect call is made through a bitcast.
64bool isIndirectCall(CallSite &CS) {
65  if (CS.getCalledFunction())
66    return false;
67
68  // Check the value to see if it is merely a bitcast of a function. In
69  // this case, it will translate to a direct function call in the resulting
70  // assembly, so we won't treat it as an indirect call here.
71  const Value *V = CS.getCalledValue();
72  if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
73    return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
74  }
75
76  // Otherwise, since we know it's a call, it must be an indirect call
77  return true;
78}
79
80// Replaces Functions and GlobalAliases with a different Value.
81bool replaceGlobalValueIndirectUse(GlobalValue *GV, Value *V, Use *U) {
82  User *Us = U->getUser();
83  if (!Us)
84    return false;
85  if (Instruction *I = dyn_cast<Instruction>(Us)) {
86    CallSite CS(I);
87
88    // Don't do the replacement if this use is a direct call to this function.
89    // If the use is not the called value, then replace it.
90    if (CS && (isIndirectCall(CS) || CS.isCallee(U))) {
91      return false;
92    }
93
94    U->set(V);
95  } else if (Constant *C = dyn_cast<Constant>(Us)) {
96    // Don't replace calls to bitcasts of function symbols, since they get
97    // translated to direct calls.
98    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Us)) {
99      if (CE->getOpcode() == Instruction::BitCast) {
100        // This bitcast must have exactly one user.
101        if (CE->user_begin() != CE->user_end()) {
102          User *ParentUs = *CE->user_begin();
103          if (CallInst *CI = dyn_cast<CallInst>(ParentUs)) {
104            CallSite CS(CI);
105            Use &CEU = *CE->use_begin();
106            if (CS.isCallee(&CEU)) {
107              return false;
108            }
109          }
110        }
111      }
112    }
113
114    // GlobalAlias doesn't support replaceUsesOfWithOnConstant. And the verifier
115    // requires alias to point to a defined function. So, GlobalAlias is handled
116    // as a separate case in runOnModule.
117    if (!isa<GlobalAlias>(C))
118      C->replaceUsesOfWithOnConstant(GV, V, U);
119  } else {
120    assert(false && "The Use of a Function symbol is neither an instruction nor"
121                    " a constant");
122  }
123
124  return true;
125}
126
127// Replaces all replaceable address-taken uses of GV with a pointer to a
128// jump-instruction table entry.
129void replaceValueWithFunction(GlobalValue *GV, Function *F) {
130  // Go through all uses of this function and replace the uses of GV with the
131  // jump-table version of the function. Get the uses as a vector before
132  // replacing them, since replacing them changes the use list and invalidates
133  // the iterator otherwise.
134  for (Value::use_iterator I = GV->use_begin(), E = GV->use_end(); I != E;) {
135    Use &U = *I++;
136
137    // Replacement of constants replaces all instances in the constant. So, some
138    // uses might have already been handled by the time we reach them here.
139    if (U.get() == GV)
140      replaceGlobalValueIndirectUse(GV, F, &U);
141  }
142
143  return;
144}
145} // end anonymous namespace
146
147JumpInstrTables::JumpInstrTables()
148    : ModulePass(ID), Metadata(), JITI(nullptr), TableCount(0),
149      JTType(JumpTable::Single) {
150  initializeJumpInstrTablesPass(*PassRegistry::getPassRegistry());
151}
152
153JumpInstrTables::JumpInstrTables(JumpTable::JumpTableType JTT)
154    : ModulePass(ID), Metadata(), JITI(nullptr), TableCount(0), JTType(JTT) {
155  initializeJumpInstrTablesPass(*PassRegistry::getPassRegistry());
156}
157
158JumpInstrTables::~JumpInstrTables() {}
159
160void JumpInstrTables::getAnalysisUsage(AnalysisUsage &AU) const {
161  AU.addRequired<JumpInstrTableInfo>();
162}
163
164Function *JumpInstrTables::insertEntry(Module &M, Function *Target) {
165  FunctionType *OrigFunTy = Target->getFunctionType();
166  FunctionType *FunTy = transformType(OrigFunTy);
167
168  JumpMap::iterator it = Metadata.find(FunTy);
169  if (Metadata.end() == it) {
170    struct TableMeta Meta;
171    Meta.TableNum = TableCount;
172    Meta.Count = 0;
173    Metadata[FunTy] = Meta;
174    it = Metadata.find(FunTy);
175    ++NumJumpTables;
176    ++TableCount;
177  }
178
179  it->second.Count++;
180
181  std::string NewName(jump_func_prefix);
182  NewName += (Twine(it->second.TableNum) + "_" + Twine(it->second.Count)).str();
183  Function *JumpFun =
184      Function::Create(OrigFunTy, GlobalValue::ExternalLinkage, NewName, &M);
185  // The section for this table
186  JumpFun->setSection((jump_section_prefix + Twine(it->second.TableNum)).str());
187  JITI->insertEntry(FunTy, Target, JumpFun);
188
189  ++NumFuncsInJumpTables;
190  return JumpFun;
191}
192
193bool JumpInstrTables::hasTable(FunctionType *FunTy) {
194  FunctionType *TransTy = transformType(FunTy);
195  return Metadata.end() != Metadata.find(TransTy);
196}
197
198FunctionType *JumpInstrTables::transformType(FunctionType *FunTy) {
199  // Returning nullptr forces all types into the same table, since all types map
200  // to the same type
201  Type *VoidPtrTy = Type::getInt8PtrTy(FunTy->getContext());
202
203  // Ignore the return type.
204  Type *RetTy = VoidPtrTy;
205  bool IsVarArg = FunTy->isVarArg();
206  std::vector<Type *> ParamTys(FunTy->getNumParams());
207  FunctionType::param_iterator PI, PE;
208  int i = 0;
209
210  std::vector<Type *> EmptyParams;
211  Type *Int32Ty = Type::getInt32Ty(FunTy->getContext());
212  FunctionType *VoidFnTy = FunctionType::get(
213      Type::getVoidTy(FunTy->getContext()), EmptyParams, false);
214  switch (JTType) {
215  case JumpTable::Single:
216
217    return FunctionType::get(RetTy, EmptyParams, false);
218  case JumpTable::Arity:
219    // Transform all types to void* so that all functions with the same arity
220    // end up in the same table.
221    for (PI = FunTy->param_begin(), PE = FunTy->param_end(); PI != PE;
222         PI++, i++) {
223      ParamTys[i] = VoidPtrTy;
224    }
225
226    return FunctionType::get(RetTy, ParamTys, IsVarArg);
227  case JumpTable::Simplified:
228    // Project all parameters types to one of 3 types: composite, integer, and
229    // function, matching the three subclasses of Type.
230    for (PI = FunTy->param_begin(), PE = FunTy->param_end(); PI != PE;
231         ++PI, ++i) {
232      assert((isa<IntegerType>(*PI) || isa<FunctionType>(*PI) ||
233              isa<CompositeType>(*PI)) &&
234             "This type is not an Integer or a Composite or a Function");
235      if (isa<CompositeType>(*PI)) {
236        ParamTys[i] = VoidPtrTy;
237      } else if (isa<FunctionType>(*PI)) {
238        ParamTys[i] = VoidFnTy;
239      } else if (isa<IntegerType>(*PI)) {
240        ParamTys[i] = Int32Ty;
241      }
242    }
243
244    return FunctionType::get(RetTy, ParamTys, IsVarArg);
245  case JumpTable::Full:
246    // Don't transform this type at all.
247    return FunTy;
248  }
249
250  return nullptr;
251}
252
253bool JumpInstrTables::runOnModule(Module &M) {
254  // Make sure the module is well-formed, especially with respect to jumptable.
255  if (verifyModule(M))
256    return false;
257
258  JITI = &getAnalysis<JumpInstrTableInfo>();
259
260  // Get the set of jumptable-annotated functions.
261  DenseMap<Function *, Function *> Functions;
262  for (Function &F : M) {
263    if (F.hasFnAttribute(Attribute::JumpTable)) {
264      assert(F.hasUnnamedAddr() &&
265             "Attribute 'jumptable' requires 'unnamed_addr'");
266      Functions[&F] = nullptr;
267    }
268  }
269
270  // Create the jump-table functions.
271  for (auto &KV : Functions) {
272    Function *F = KV.first;
273    KV.second = insertEntry(M, F);
274  }
275
276  // GlobalAlias is a special case, because the target of an alias statement
277  // must be a defined function. So, instead of replacing a given function in
278  // the alias, we replace all uses of aliases that target jumptable functions.
279  // Note that there's no need to create these functions, since only aliases
280  // that target known jumptable functions are replaced, and there's no way to
281  // put the jumptable annotation on a global alias.
282  DenseMap<GlobalAlias *, Function *> Aliases;
283  for (GlobalAlias &GA : M.aliases()) {
284    Constant *Aliasee = GA.getAliasee();
285    if (Function *F = dyn_cast<Function>(Aliasee)) {
286      auto it = Functions.find(F);
287      if (it != Functions.end()) {
288        Aliases[&GA] = it->second;
289      }
290    }
291  }
292
293  // Replace each address taken function with its jump-instruction table entry.
294  for (auto &KV : Functions)
295    replaceValueWithFunction(KV.first, KV.second);
296
297  for (auto &KV : Aliases)
298    replaceValueWithFunction(KV.first, KV.second);
299
300  return !Functions.empty();
301}
302