1//===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// This pass exports the numeric type identifiers of all llvm.bitsets entries
// found in the module in the form of a __cfi_check function, which can be
// used to verify cross-DSO call targets.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/IPO.h"
16#include "llvm/ADT/DenseSet.h"
17#include "llvm/ADT/EquivalenceClasses.h"
18#include "llvm/ADT/Statistic.h"
19#include "llvm/IR/Constant.h"
20#include "llvm/IR/Constants.h"
21#include "llvm/IR/Function.h"
22#include "llvm/IR/GlobalObject.h"
23#include "llvm/IR/GlobalVariable.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/Instructions.h"
26#include "llvm/IR/Intrinsics.h"
27#include "llvm/IR/MDBuilder.h"
28#include "llvm/IR/Module.h"
29#include "llvm/IR/Operator.h"
30#include "llvm/Pass.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/Support/raw_ostream.h"
33#include "llvm/Transforms/Utils/BasicBlockUtils.h"
34
35using namespace llvm;
36
// Category name for this pass's debug output and statistics; must be defined
// before the STATISTIC below, which uses it.
#define DEBUG_TYPE "cross-dso-cfi"

// Incremented once per case added to the __cfi_check switch, i.e. once per
// unique numeric type identifier collected from llvm.bitsets.
STATISTIC(TypeIds, "Number of unique type identifiers");
40
41namespace {
42
43struct CrossDSOCFI : public ModulePass {
44  static char ID;
45  CrossDSOCFI() : ModulePass(ID) {
46    initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry());
47  }
48
49  Module *M;
50  MDNode *VeryLikelyWeights;
51
52  ConstantInt *extractBitSetTypeId(MDNode *MD);
53  void buildCFICheck();
54
55  bool doInitialization(Module &M) override;
56  bool runOnModule(Module &M) override;
57};
58
59} // anonymous namespace
60
// Register the pass with the PassRegistry under the command-line name
// "cross-dso-cfi". Per the INITIALIZE_PASS_* macro parameters, the two
// trailing `false` arguments mark it as neither CFG-only nor an analysis.
INITIALIZE_PASS_BEGIN(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false,
                      false)
INITIALIZE_PASS_END(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, false)
// Out-of-line definition of the static pass ID declared in the struct.
char CrossDSOCFI::ID = 0;
65
66ModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; }
67
68bool CrossDSOCFI::doInitialization(Module &Mod) {
69  M = &Mod;
70  VeryLikelyWeights =
71      MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1);
72
73  return false;
74}
75
76/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode.
77ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) {
78  // This check excludes vtables for classes inside anonymous namespaces.
79  auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0));
80  if (!TM)
81    return nullptr;
82  auto C = dyn_cast_or_null<ConstantInt>(TM->getValue());
83  if (!C) return nullptr;
84  // We are looking for i64 constants.
85  if (C->getBitWidth() != 64) return nullptr;
86
87  // Sanity check.
88  auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1));
89  // Can be null if a function was removed by an optimization.
90  if (FM) {
91    auto F = dyn_cast<Function>(FM->getValue());
92    // But can never be a function declaration.
93    assert(!F || !F->isDeclaration());
94    (void)F; // Suppress unused variable warning in the no-asserts build.
95  }
96  return C;
97}
98
99/// buildCFICheck - emits __cfi_check for the current module.
100void CrossDSOCFI::buildCFICheck() {
101  // FIXME: verify that __cfi_check ends up near the end of the code section,
102  // but before the jump slots created in LowerBitSets.
103  llvm::DenseSet<uint64_t> BitSetIds;
104  NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets");
105
106  if (BitSetNM)
107    for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I)
108      if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I)))
109        BitSetIds.insert(TypeId->getZExtValue());
110
111  LLVMContext &Ctx = M->getContext();
112  Constant *C = M->getOrInsertFunction(
113      "__cfi_check",
114      FunctionType::get(
115          Type::getVoidTy(Ctx),
116          {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))},
117          false));
118  Function *F = dyn_cast<Function>(C);
119  F->setAlignment(4096);
120  auto args = F->arg_begin();
121  Argument &CallSiteTypeId = *(args++);
122  CallSiteTypeId.setName("CallSiteTypeId");
123  Argument &Addr = *(args++);
124  Addr.setName("Addr");
125  assert(args == F->arg_end());
126
127  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
128
129  BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F);
130  IRBuilder<> IRBTrap(TrapBB);
131  Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap);
132  llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn);
133  TrapCall->setDoesNotReturn();
134  TrapCall->setDoesNotThrow();
135  IRBTrap.CreateUnreachable();
136
137  BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);
138  IRBuilder<> IRBExit(ExitBB);
139  IRBExit.CreateRetVoid();
140
141  IRBuilder<> IRB(BB);
142  SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size());
143  for (uint64_t TypeId : BitSetIds) {
144    ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId);
145    BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F);
146    IRBuilder<> IRBTest(TestBB);
147    Function *BitsetTestFn =
148        Intrinsic::getDeclaration(M, Intrinsic::bitset_test);
149
150    Value *Test = IRBTest.CreateCall(
151        BitsetTestFn, {&Addr, MetadataAsValue::get(
152                                  Ctx, ConstantAsMetadata::get(CaseTypeId))});
153    BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB);
154    BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights);
155
156    SI->addCase(CaseTypeId, TestBB);
157    ++TypeIds;
158  }
159}
160
161bool CrossDSOCFI::runOnModule(Module &M) {
162  if (M.getModuleFlag("Cross-DSO CFI") == nullptr)
163    return false;
164  buildCFICheck();
165  return true;
166}
167