1//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Instrumentation-based profile-guided optimization
11//
12//===----------------------------------------------------------------------===//
13
14#include "CodeGenPGO.h"
15#include "CodeGenFunction.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
18#include "llvm/IR/MDBuilder.h"
19#include "llvm/ProfileData/InstrProfReader.h"
20#include "llvm/Support/Endian.h"
21#include "llvm/Support/FileSystem.h"
22#include "llvm/Support/MD5.h"
23
24using namespace clang;
25using namespace CodeGen;
26
27void CodeGenPGO::setFuncName(llvm::Function *Fn) {
28  RawFuncName = Fn->getName();
29
30  // Function names may be prefixed with a binary '1' to indicate
31  // that the backend should not modify the symbols due to any platform
32  // naming convention. Do not include that '1' in the PGO profile name.
33  if (RawFuncName[0] == '\1')
34    RawFuncName = RawFuncName.substr(1);
35
36  if (!Fn->hasLocalLinkage()) {
37    PrefixedFuncName.reset(new std::string(RawFuncName));
38    return;
39  }
40
41  // For local symbols, prepend the main file name to distinguish them.
42  // Do not include the full path in the file name since there's no guarantee
43  // that it will stay the same, e.g., if the files are checked out from
44  // version control in different locations.
45  PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName));
46  if (PrefixedFuncName->empty())
47    PrefixedFuncName->assign("<unknown>");
48  PrefixedFuncName->append(":");
49  PrefixedFuncName->append(RawFuncName);
50}
51
52static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
53  return CGM.getModule().getFunction("__llvm_profile_register_functions");
54}
55
56static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
57  // Don't do this for Darwin.  compiler-rt uses linker magic.
58  if (CGM.getTarget().getTriple().isOSDarwin())
59    return nullptr;
60
61  // Only need to insert this once per module.
62  if (llvm::Function *RegisterF = getRegisterFunc(CGM))
63    return &RegisterF->getEntryBlock();
64
65  // Construct the function.
66  auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
67  auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
68  auto *RegisterF = llvm::Function::Create(RegisterFTy,
69                                           llvm::GlobalValue::InternalLinkage,
70                                           "__llvm_profile_register_functions",
71                                           &CGM.getModule());
72  RegisterF->setUnnamedAddr(true);
73  if (CGM.getCodeGenOpts().DisableRedZone)
74    RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
75
76  // Construct and return the entry block.
77  auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
78  CGBuilderTy Builder(BB);
79  Builder.CreateRetVoid();
80  return BB;
81}
82
83static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
84  auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
85  auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
86  auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
87  return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function",
88                                             RuntimeRegisterTy);
89}
90
91static bool isMachO(const CodeGenModule &CGM) {
92  return CGM.getTarget().getTriple().isOSBinFormatMachO();
93}
94
95static StringRef getCountersSection(const CodeGenModule &CGM) {
96  return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts";
97}
98
99static StringRef getNameSection(const CodeGenModule &CGM) {
100  return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names";
101}
102
103static StringRef getDataSection(const CodeGenModule &CGM) {
104  return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data";
105}
106
107llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
108  // Create name variable.
109  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
110  auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
111                                                     false);
112  auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
113                                        true, VarLinkage, VarName,
114                                        getFuncVarName("name"));
115  Name->setSection(getNameSection(CGM));
116  Name->setAlignment(1);
117
118  // Create data variable.
119  auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
120  auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
121  auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
122  auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
123  llvm::Type *DataTypes[] = {
124    Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
125  };
126  auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
127  llvm::Constant *DataVals[] = {
128    llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
129    llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
130    llvm::ConstantInt::get(Int64Ty, FunctionHash),
131    llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
132    llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
133  };
134  auto *Data =
135    new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage,
136                             llvm::ConstantStruct::get(DataTy, DataVals),
137                             getFuncVarName("data"));
138
139  // All the data should be packed into an array in its own section.
140  Data->setSection(getDataSection(CGM));
141  Data->setAlignment(8);
142
143  // Hide all these symbols so that we correctly get a copy for each
144  // executable.  The profile format expects names and counters to be
145  // contiguous, so references into shared objects would be invalid.
146  if (!llvm::GlobalValue::isLocalLinkage(VarLinkage)) {
147    Name->setVisibility(llvm::GlobalValue::HiddenVisibility);
148    Data->setVisibility(llvm::GlobalValue::HiddenVisibility);
149    RegionCounters->setVisibility(llvm::GlobalValue::HiddenVisibility);
150  }
151
152  // Make sure the data doesn't get deleted.
153  CGM.addUsedGlobal(Data);
154  return Data;
155}
156
157void CodeGenPGO::emitInstrumentationData() {
158  if (!RegionCounters)
159    return;
160
161  // Build the data.
162  auto *Data = buildDataVar();
163
164  // Register the data.
165  auto *RegisterBB = getOrInsertRegisterBB(CGM);
166  if (!RegisterBB)
167    return;
168  CGBuilderTy Builder(RegisterBB->getTerminator());
169  auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
170  Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
171                     Builder.CreateBitCast(Data, VoidPtrTy));
172}
173
174llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
175  if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
176    return nullptr;
177
178  assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr &&
179         "profile initialization already emitted");
180
181  // Get the function to call at initialization.
182  llvm::Constant *RegisterF = getRegisterFunc(CGM);
183  if (!RegisterF)
184    return nullptr;
185
186  // Create the initialization function.
187  auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
188  auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
189                                   llvm::GlobalValue::InternalLinkage,
190                                   "__llvm_profile_init", &CGM.getModule());
191  F->setUnnamedAddr(true);
192  F->addFnAttr(llvm::Attribute::NoInline);
193  if (CGM.getCodeGenOpts().DisableRedZone)
194    F->addFnAttr(llvm::Attribute::NoRedZone);
195
196  // Add the basic block and the necessary calls.
197  CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
198  Builder.CreateCall(RegisterF);
199  Builder.CreateRetVoid();
200
201  return F;
202}
203
204namespace {
/// \brief Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note  When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
///
/// Each combined HashType value is packed into NumBitsPerType bits of the
/// 64-bit Working word. Once a word holds NumTypesPerWord values it is fed
/// to MD5 (see combine()); finalize() yields the resulting 64-bit hash.
class PGOHash {
  /// Packed HashType values not yet fed to MD5.
  uint64_t Working;
  /// Total number of HashType values combined so far.
  unsigned Count;
  /// Running MD5 state over the Working words flushed so far.
  llvm::MD5 MD5;

  /// Bits used to encode one HashType value.
  static const int NumBitsPerType = 6;
  /// How many HashType values fit in one 64-bit word (64 / 6 == 10).
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  /// Exclusive upper bound on encodable HashType values (1 << 6 == 64).
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// \brief Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  // Every real HashType (all < LastHashType) must fit in NumBitsPerType bits.
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  // TODO: When this format changes, take in a version number here, and use the
  // old hash calculation for file formats that used the old hash.
  PGOHash() : Working(0), Count(0) {}
  /// Fold one AST node type into the hash. Type must not be None.
  void combine(HashType Type);
  /// Produce the final 64-bit hash of everything combined so far.
  uint64_t finalize();
};
// Out-of-line definitions of the in-class-initialized static const members,
// which give them storage in case they are ODR-used (needed before C++17).
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
265
  /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
  ///
  /// Counters are handed out in traversal order (NextCounter++). Every
  /// statement counter assigned by VisitStmt is also folded into Hash, so the
  /// traversal order determines both the counter numbering and the function
  /// hash. Do not reorder.
  struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
    /// The next counter value to assign.
    unsigned NextCounter;
    /// The function hash.
    PGOHash Hash;
    /// The map of statements to counters.
    llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

    MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
        : NextCounter(0), CounterMap(CounterMap) {}

    // Blocks and lambdas are handled as separate functions, so we need not
    // traverse them in the parent context.
    bool TraverseBlockExpr(BlockExpr *BE) { return true; }
    bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
    bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

    /// Assign a counter to the body of each function-like declaration the
    /// traversal reaches. For the declaration being mapped, this is the
    /// function-entry counter. Unlike VisitStmt, this does not update Hash.
    /// NOTE(review): if the traversal reaches a function-like declaration
    /// without a body, getBody() is null and a counter is still consumed
    /// (keyed on nullptr). Changing that would renumber existing counters.
    bool VisitDecl(const Decl *D) {
      switch (D->getKind()) {
      default:
        break;
      case Decl::Function:
      case Decl::CXXMethod:
      case Decl::CXXConstructor:
      case Decl::CXXDestructor:
      case Decl::CXXConversion:
      case Decl::ObjCMethod:
      case Decl::Block:
      case Decl::Captured:
        CounterMap[D->getBody()] = NextCounter++;
        break;
      }
      return true;
    }

    /// Give a counter to each statement that has a HashType, and fold that
    /// type into the function hash in the same order.
    bool VisitStmt(const Stmt *S) {
      auto Type = getHashType(S);
      if (Type == PGOHash::None)
        return true;

      CounterMap[S] = NextCounter++;
      Hash.combine(Type);
      return true;
    }
    /// Map a statement to its stable HashType, or PGOHash::None if it does
    /// not get a counter. Among binary operators, only && and || qualify.
    PGOHash::HashType getHashType(const Stmt *S) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::LabelStmtClass:
        return PGOHash::LabelStmt;
      case Stmt::WhileStmtClass:
        return PGOHash::WhileStmt;
      case Stmt::DoStmtClass:
        return PGOHash::DoStmt;
      case Stmt::ForStmtClass:
        return PGOHash::ForStmt;
      case Stmt::CXXForRangeStmtClass:
        return PGOHash::CXXForRangeStmt;
      case Stmt::ObjCForCollectionStmtClass:
        return PGOHash::ObjCForCollectionStmt;
      case Stmt::SwitchStmtClass:
        return PGOHash::SwitchStmt;
      case Stmt::CaseStmtClass:
        return PGOHash::CaseStmt;
      case Stmt::DefaultStmtClass:
        return PGOHash::DefaultStmt;
      case Stmt::IfStmtClass:
        return PGOHash::IfStmt;
      case Stmt::CXXTryStmtClass:
        return PGOHash::CXXTryStmt;
      case Stmt::CXXCatchStmtClass:
        return PGOHash::CXXCatchStmt;
      case Stmt::ConditionalOperatorClass:
        return PGOHash::ConditionalOperator;
      case Stmt::BinaryConditionalOperatorClass:
        return PGOHash::BinaryConditionalOperator;
      case Stmt::BinaryOperatorClass: {
        const BinaryOperator *BO = cast<BinaryOperator>(S);
        if (BO->getOpcode() == BO_LAnd)
          return PGOHash::BinaryOperatorLAnd;
        if (BO->getOpcode() == BO_LOr)
          return PGOHash::BinaryOperatorLOr;
        break;
      }
      }
      return PGOHash::None;
    }
  };
355
  /// A StmtVisitor that propagates the raw counts through the AST and
  /// records the count at statements where the value may change.
  ///
  /// The running count lives in PGO (getCurrentRegionCount /
  /// setCurrentRegionUnreachable). Each Visit* method mirrors the control
  /// flow of one statement kind and stores the count observed on entry to
  /// each sub-region in CountMap. RegionCounter is presumably declared in
  /// CodeGenPGO.h (not shown here).
  struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
    /// PGO state.
    CodeGenPGO &PGO;

    /// A flag that is set when the current count should be recorded on the
    /// next statement, such as at the exit of a loop.
    bool RecordNextStmtCount;

    /// The map of statements to count values.
    llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

    /// BreakContinueStack - Keep counts of breaks and continues inside loops.
    /// One entry is pushed for each enclosing loop or switch being visited.
    struct BreakContinue {
      uint64_t BreakCount;
      uint64_t ContinueCount;
      BreakContinue() : BreakCount(0), ContinueCount(0) {}
    };
    SmallVector<BreakContinue, 8> BreakContinueStack;

    ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                        CodeGenPGO &PGO)
        : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

    /// If a count was requested for the next statement, record the current
    /// count for S and clear the request.
    void RecordStmtCount(const Stmt *S) {
      if (RecordNextStmtCount) {
        CountMap[S] = PGO.getCurrentRegionCount();
        RecordNextStmtCount = false;
      }
    }

    /// Default for statements with no special control flow: honor a pending
    /// record request, then visit every non-null child in order.
    void VisitStmt(const Stmt *S) {
      RecordStmtCount(S);
      for (Stmt::const_child_range I = S->children(); I; ++I) {
        if (*I)
         this->Visit(*I);
      }
    }

    /// Entry point for a function: its body's counter gives the entry count.
    void VisitFunctionDecl(const FunctionDecl *D) {
      // Counter tracks entry to the function body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    // Skip lambda expressions. We visit these as FunctionDecls when we're
    // generating them and aren't interested in the body when generating a
    // parent context.
    void VisitLambdaExpr(const LambdaExpr *LE) {}

    /// Entry point for a captured region; mirrors VisitFunctionDecl.
    void VisitCapturedDecl(const CapturedDecl *D) {
      // Counter tracks entry to the capture body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    /// Entry point for an Objective-C method; mirrors VisitFunctionDecl.
    void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
      // Counter tracks entry to the method body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    /// Entry point for a block; mirrors VisitFunctionDecl.
    void VisitBlockDecl(const BlockDecl *D) {
      // Counter tracks entry to the block body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    /// Control leaves after a return: mark the region unreachable and request
    /// a recorded count for whatever statement follows.
    void VisitReturnStmt(const ReturnStmt *S) {
      RecordStmtCount(S);
      if (S->getRetValue())
        Visit(S->getRetValue());
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// A goto leaves the current region; same handling as a return.
    void VisitGotoStmt(const GotoStmt *S) {
      RecordStmtCount(S);
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// A label can be entered by jumps as well as fall-through, so its count
    /// comes from its own counter. Any pending record request is dropped.
    void VisitLabelStmt(const LabelStmt *S) {
      RecordNextStmtCount = false;
      // Counter tracks the block following the label.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      CountMap[S] = PGO.getCurrentRegionCount();
      Visit(S->getSubStmt());
    }

    /// Add the current count to the innermost loop/switch's break total, then
    /// mark the rest of the region unreachable.
    void VisitBreakStmt(const BreakStmt *S) {
      RecordStmtCount(S);
      assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
      BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// Add the current count to the innermost entry's continue total, then
    /// mark the rest of the region unreachable.
    void VisitContinueStmt(const ContinueStmt *S) {
      RecordStmtCount(S);
      assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
      BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// while: visit the body first, then derive the condition's count from the
    /// parent count, the body's back-edge and any continues.
    void VisitWhileStmt(const WhileStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first so the break/continue adjustments can be
      // included when visiting the condition.
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // ...then go back and propagate counts through the condition. The count
      // at the start of the condition is the sum of the incoming edges,
      // the backedge from the end of the loop body, and the edges from
      // continue statements.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// do-while: the body region also includes the fall-through entry.
    void VisitDoStmt(const DoStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();
      // The count at the start of the condition is equal to the count at the
      // end of the body. The adjusted count does not include either the
      // fall-through count coming into the loop or the continue count, so add
      // both of those separately. This is coincidentally the same equation as
      // with while loops but for different reasons.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// for: like while, with the optional init visited up front and the
    /// optional increment visited between the body and the condition.
    void VisitForStmt(const ForStmt *S) {
      RecordStmtCount(S);
      if (S->getInit())
        Visit(S->getInit());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      if (S->getInc()) {
        Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                  BreakContinueStack.back().ContinueCount);
        CountMap[S->getInc()] = PGO.getCurrentRegionCount();
        Visit(S->getInc());
        Cnt.adjustForControlFlow();
      }

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      if (S->getCond()) {
        Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                  Cnt.getAdjustedCount() +
                                  BC.ContinueCount);
        CountMap[S->getCond()] = PGO.getCurrentRegionCount();
        Visit(S->getCond());
        Cnt.adjustForControlFlow();
      }
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// Range-based for: the range and begin/end setup run once in the parent
    /// region. The loop variable and body form the counted region, followed
    /// by the increment and then the condition.
    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
      RecordStmtCount(S);
      Visit(S->getRangeStmt());
      Visit(S->getBeginEndStmt());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
      Visit(S->getLoopVarStmt());
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                BreakContinueStack.back().ContinueCount);
      CountMap[S->getInc()] = PGO.getCurrentRegionCount();
      Visit(S->getInc());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() +
                                BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// Objective-C for-in: the element is visited in the parent region, and
    /// the body is the counted region. There is no separate condition.
    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
      RecordStmtCount(S);
      Visit(S->getElement());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    /// switch: after the condition, the body starts unreachable. Each
    /// case/default label begins its own region, and the switch's counter
    /// gives the exit count.
    void VisitSwitchStmt(const SwitchStmt *S) {
      RecordStmtCount(S);
      Visit(S->getCond());
      PGO.setCurrentRegionUnreachable();
      BreakContinueStack.push_back(BreakContinue());
      Visit(S->getBody());
      // If the switch is inside a loop, add the continue counts.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      if (!BreakContinueStack.empty())
        BreakContinueStack.back().ContinueCount += BC.ContinueCount;
      // Counter tracks the exit block of the switch.
      RegionCounter ExitCnt(PGO, S);
      ExitCnt.beginRegion();
      RecordNextStmtCount = true;
    }

    void VisitCaseStmt(const CaseStmt *S) {
      RecordNextStmtCount = false;
      // Counter for this particular case. This counts only jumps from the
      // switch header and does not include fallthrough from the case before
      // this one.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S] = Cnt.getCount();
      RecordNextStmtCount = true;
      Visit(S->getSubStmt());
    }

    void VisitDefaultStmt(const DefaultStmt *S) {
      RecordNextStmtCount = false;
      // Counter for this default case. This does not include fallthrough from
      // the previous case.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S] = Cnt.getCount();
      RecordNextStmtCount = true;
      Visit(S->getSubStmt());
    }

    /// if/else: the condition runs in the parent region. "then" uses the
    /// statement's counter and "else" begins the complementary region.
    void VisitIfStmt(const IfStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the "then" part of an if statement. The count for
      // the "else" part, if it exists, will be calculated from this counter.
      RegionCounter Cnt(PGO, S);
      Visit(S->getCond());

      Cnt.beginRegion();
      CountMap[S->getThen()] = PGO.getCurrentRegionCount();
      Visit(S->getThen());
      Cnt.adjustForControlFlow();

      if (S->getElse()) {
        Cnt.beginElseRegion();
        CountMap[S->getElse()] = PGO.getCurrentRegionCount();
        Visit(S->getElse());
        Cnt.adjustForControlFlow();
      }
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// try: visit the try block and each handler, then start the continuation
    /// region from the try statement's own counter.
    void VisitCXXTryStmt(const CXXTryStmt *S) {
      RecordStmtCount(S);
      Visit(S->getTryBlock());
      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
        Visit(S->getHandler(I));
      // Counter tracks the continuation block of the try statement.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      RecordNextStmtCount = true;
    }

    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
      RecordNextStmtCount = false;
      // Counter tracks the catch statement's handler block.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      CountMap[S] = PGO.getCurrentRegionCount();
      Visit(S->getHandlerBlock());
    }

    /// ?: and the GNU binary ?: share this handling. The condition is in the
    /// parent region, the true arm uses the counter, and the false arm is
    /// the complementary region.
    void VisitAbstractConditionalOperator(
        const AbstractConditionalOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the "true" part of a conditional operator. The
      // count in the "false" part will be calculated from this counter.
      RegionCounter Cnt(PGO, E);
      Visit(E->getCond());

      Cnt.beginRegion();
      CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getTrueExpr());
      Cnt.adjustForControlFlow();

      Cnt.beginElseRegion();
      CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getFalseExpr());
      Cnt.adjustForControlFlow();

      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// &&: the LHS always runs, and the RHS region is tracked by the
    /// operator's counter.
    void VisitBinLAnd(const BinaryOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the right hand side of a logical and operator.
      RegionCounter Cnt(PGO, E);
      Visit(E->getLHS());
      Cnt.beginRegion();
      CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
      Visit(E->getRHS());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// ||: the LHS always runs, and the RHS region is tracked by the
    /// operator's counter.
    void VisitBinLOr(const BinaryOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the right hand side of a logical or operator.
      RegionCounter Cnt(PGO, E);
      Visit(E->getLHS());
      Cnt.beginRegion();
      CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
      Visit(E->getRHS());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }
  };
742}
743
/// Fold one AST node type into the hash.
///
/// Types are packed NumBitsPerType bits at a time into Working. When Working
/// already holds NumTypesPerWord types, it is first fed to MD5 as 8
/// little-endian bytes and reset, and then the new type is appended.
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  // Converting to little-endian first keeps the MD5 input independent of the
  // host's byte order.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}
761
762uint64_t PGOHash::finalize() {
763  // Use Working as the hash directly if we never used MD5.
764  if (Count <= NumTypesPerWord)
765    // No need to byte swap here, since none of the math was endian-dependent.
766    // This number will be byte-swapped as required on endianness transitions,
767    // so we will see the same value on the other side.
768    return Working;
769
770  // Check for remaining work in Working.
771  if (Working)
772    MD5.update(Working);
773
774  // Finalize the MD5 and return the hash.
775  llvm::MD5::MD5Result Result;
776  MD5.final(Result);
777  using namespace llvm::support;
778  return endian::read<uint64_t, little, unaligned>(Result);
779}
780
781static void emitRuntimeHook(CodeGenModule &CGM) {
782  const char *const RuntimeVarName = "__llvm_profile_runtime";
783  const char *const RuntimeUserName = "__llvm_profile_runtime_user";
784  if (CGM.getModule().getGlobalVariable(RuntimeVarName))
785    return;
786
787  // Declare the runtime hook.
788  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
789  auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
790  auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false,
791                                       llvm::GlobalValue::ExternalLinkage,
792                                       nullptr, RuntimeVarName);
793
794  // Make a function that uses it.
795  auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false),
796                                      llvm::GlobalValue::LinkOnceODRLinkage,
797                                      RuntimeUserName, &CGM.getModule());
798  User->addFnAttr(llvm::Attribute::NoInline);
799  if (CGM.getCodeGenOpts().DisableRedZone)
800    User->addFnAttr(llvm::Attribute::NoRedZone);
801  CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User));
802  auto *Load = Builder.CreateLoad(Var);
803  Builder.CreateRet(Load);
804
805  // Create a use of the function.  Now the definition of the runtime variable
806  // should get pulled in, along with any static initializears.
807  CGM.addUsedGlobal(User);
808}
809
810void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
811  bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
812  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
813  if (!InstrumentRegions && !PGOReader)
814    return;
815  if (D->isImplicit())
816    return;
817  setFuncName(Fn);
818
819  // Set the linkage for variables based on the function linkage.  Usually, we
820  // want to match it, but available_externally and extern_weak both have the
821  // wrong semantics.
822  VarLinkage = Fn->getLinkage();
823  switch (VarLinkage) {
824  case llvm::GlobalValue::ExternalWeakLinkage:
825    VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage;
826    break;
827  case llvm::GlobalValue::AvailableExternallyLinkage:
828    VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage;
829    break;
830  default:
831    break;
832  }
833
834  mapRegionCounters(D);
835  if (InstrumentRegions) {
836    emitRuntimeHook(CGM);
837    emitCounterVariables();
838  }
839  if (PGOReader) {
840    SourceManager &SM = CGM.getContext().getSourceManager();
841    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
842    computeRegionCounts(D);
843    applyFunctionAttributes(PGOReader, Fn);
844  }
845}
846
847void CodeGenPGO::mapRegionCounters(const Decl *D) {
848  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
849  MapRegionCounters Walker(*RegionCounterMap);
850  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
851    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
852  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
853    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
854  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
855    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
856  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
857    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
858  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
859  NumRegionCounters = Walker.NextCounter;
860  FunctionHash = Walker.Hash.finalize();
861}
862
863void CodeGenPGO::computeRegionCounts(const Decl *D) {
864  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
865  ComputeRegionCounts Walker(*StmtCountMap, *this);
866  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
867    Walker.VisitFunctionDecl(FD);
868  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
869    Walker.VisitObjCMethodDecl(MD);
870  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
871    Walker.VisitBlockDecl(BD);
872  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
873    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
874}
875
876void
877CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
878                                    llvm::Function *Fn) {
879  if (!haveRegionCounts())
880    return;
881
882  uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
883  uint64_t FunctionCount = getRegionCount(0);
884  if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
885    // Turn on InlineHint attribute for hot functions.
886    // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
887    Fn->addFnAttr(llvm::Attribute::InlineHint);
888  else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
889    // Turn on Cold attribute for cold functions.
890    // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
891    Fn->addFnAttr(llvm::Attribute::Cold);
892}
893
894void CodeGenPGO::emitCounterVariables() {
895  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
896  llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
897                                                    NumRegionCounters);
898  RegionCounters =
899    new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage,
900                             llvm::Constant::getNullValue(CounterTy),
901                             getFuncVarName("counters"));
902  RegionCounters->setAlignment(8);
903  RegionCounters->setSection(getCountersSection(CGM));
904}
905
906void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
907  if (!RegionCounters)
908    return;
909  llvm::Value *Addr =
910    Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
911  llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
912  Count = Builder.CreateAdd(Count, Builder.getInt64(1));
913  Builder.CreateStore(Count, Addr);
914}
915
916void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
917                                  bool IsInMainFile) {
918  CGM.getPGOStats().addVisited(IsInMainFile);
919  RegionCounts.reset(new std::vector<uint64_t>);
920  uint64_t Hash;
921  if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) {
922    CGM.getPGOStats().addMissing(IsInMainFile);
923    RegionCounts.reset();
924  } else if (Hash != FunctionHash ||
925             RegionCounts->size() != NumRegionCounters) {
926    CGM.getPGOStats().addMismatched(IsInMainFile);
927    RegionCounts.reset();
928  }
929}
930
931void CodeGenPGO::destroyRegionCounters() {
932  RegionCounterMap.reset();
933  StmtCountMap.reset();
934  RegionCounts.reset();
935  RegionCounters = nullptr;
936}
937
/// \brief Calculate what to divide by to scale weights.
///
/// Returns 1 when \p MaxWeight already fits below UINT32_MAX. Otherwise it
/// returns the smallest-style divisor floor(MaxWeight / UINT32_MAX) + 1. After
/// division, every weight no greater than \p MaxWeight is strictly less than
/// UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight >= UINT32_MAX)
    return MaxWeight / UINT32_MAX + 1;
  return 1;
}
945
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight (typically the maximum of all weights being scaled
/// together), so that \c Weight / \c Scale + 1 fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The narrowing is safe under the precondition checked above.
  return static_cast<uint32_t>(Scaled);
}
961
962llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
963                                              uint64_t FalseCount) {
964  // Check for empty weights.
965  if (!TrueCount && !FalseCount)
966    return nullptr;
967
968  // Calculate how to scale down to 32-bits.
969  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
970
971  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
972  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
973                                      scaleBranchWeight(FalseCount, Scale));
974}
975
976llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
977  // We need at least two elements to create meaningful weights.
978  if (Weights.size() < 2)
979    return nullptr;
980
981  // Check for empty weights.
982  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
983  if (MaxWeight == 0)
984    return nullptr;
985
986  // Calculate how to scale down to 32-bits.
987  uint64_t Scale = calculateWeightScale(MaxWeight);
988
989  SmallVector<uint32_t, 16> ScaledWeights;
990  ScaledWeights.reserve(Weights.size());
991  for (uint64_t W : Weights)
992    ScaledWeights.push_back(scaleBranchWeight(W, Scale));
993
994  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
995  return MDHelper.createBranchWeights(ScaledWeights);
996}
997
998llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
999                                            RegionCounter &Cnt) {
1000  if (!haveRegionCounts())
1001    return nullptr;
1002  uint64_t LoopCount = Cnt.getCount();
1003  uint64_t CondCount = 0;
1004  bool Found = getStmtCount(Cond, CondCount);
1005  assert(Found && "missing expected loop condition count");
1006  (void)Found;
1007  if (CondCount == 0)
1008    return nullptr;
1009  return createBranchWeights(LoopCount,
1010                             std::max(CondCount, LoopCount) - LoopCount);
1011}
1012