//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Instrumentation-based profile-guided optimization
//
//===----------------------------------------------------------------------===//
13
14#include "CodeGenPGO.h"
15#include "CodeGenFunction.h"
16#include "CoverageMappingGen.h"
17#include "clang/AST/RecursiveASTVisitor.h"
18#include "clang/AST/StmtVisitor.h"
19#include "llvm/IR/Intrinsics.h"
20#include "llvm/IR/MDBuilder.h"
21#include "llvm/ProfileData/InstrProfReader.h"
22#include "llvm/Support/Endian.h"
23#include "llvm/Support/FileSystem.h"
24#include "llvm/Support/MD5.h"
25
26using namespace clang;
27using namespace CodeGen;
28
29void CodeGenPGO::setFuncName(StringRef Name,
30                             llvm::GlobalValue::LinkageTypes Linkage) {
31  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
32  FuncName = llvm::getPGOFuncName(
33      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
34      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
35
36  // If we're generating a profile, create a variable for the name.
37  if (CGM.getCodeGenOpts().ProfileInstrGenerate)
38    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
39}
40
41void CodeGenPGO::setFuncName(llvm::Function *Fn) {
42  setFuncName(Fn->getName(), Fn->getLinkage());
43}
44
45namespace {
46/// \brief Stable hasher for PGO region counters.
47///
48/// PGOHash produces a stable hash of a given function's control flow.
49///
50/// Changing the output of this hash will invalidate all previously generated
51/// profiles -- i.e., don't do it.
52///
53/// \note  When this hash does eventually change (years?), we still need to
54/// support old hashes.  We'll need to pull in the version number from the
55/// profile data format and use the matching hash function.
56class PGOHash {
57  uint64_t Working;
58  unsigned Count;
59  llvm::MD5 MD5;
60
61  static const int NumBitsPerType = 6;
62  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
63  static const unsigned TooBig = 1u << NumBitsPerType;
64
65public:
66  /// \brief Hash values for AST nodes.
67  ///
68  /// Distinct values for AST nodes that have region counters attached.
69  ///
70  /// These values must be stable.  All new members must be added at the end,
71  /// and no members should be removed.  Changing the enumeration value for an
72  /// AST node will affect the hash of every function that contains that node.
73  enum HashType : unsigned char {
74    None = 0,
75    LabelStmt = 1,
76    WhileStmt,
77    DoStmt,
78    ForStmt,
79    CXXForRangeStmt,
80    ObjCForCollectionStmt,
81    SwitchStmt,
82    CaseStmt,
83    DefaultStmt,
84    IfStmt,
85    CXXTryStmt,
86    CXXCatchStmt,
87    ConditionalOperator,
88    BinaryOperatorLAnd,
89    BinaryOperatorLOr,
90    BinaryConditionalOperator,
91
92    // Keep this last.  It's for the static assert that follows.
93    LastHashType
94  };
95  static_assert(LastHashType <= TooBig, "Too many types in HashType");
96
97  // TODO: When this format changes, take in a version number here, and use the
98  // old hash calculation for file formats that used the old hash.
99  PGOHash() : Working(0), Count(0) {}
100  void combine(HashType Type);
101  uint64_t finalize();
102};
103const int PGOHash::NumBitsPerType;
104const unsigned PGOHash::NumTypesPerWord;
105const unsigned PGOHash::TooBig;
106
107/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
108struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
109  /// The next counter value to assign.
110  unsigned NextCounter;
111  /// The function hash.
112  PGOHash Hash;
113  /// The map of statements to counters.
114  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
115
116  MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
117      : NextCounter(0), CounterMap(CounterMap) {}
118
119  // Blocks and lambdas are handled as separate functions, so we need not
120  // traverse them in the parent context.
121  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
122  bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
123  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
124
125  bool VisitDecl(const Decl *D) {
126    switch (D->getKind()) {
127    default:
128      break;
129    case Decl::Function:
130    case Decl::CXXMethod:
131    case Decl::CXXConstructor:
132    case Decl::CXXDestructor:
133    case Decl::CXXConversion:
134    case Decl::ObjCMethod:
135    case Decl::Block:
136    case Decl::Captured:
137      CounterMap[D->getBody()] = NextCounter++;
138      break;
139    }
140    return true;
141  }
142
143  bool VisitStmt(const Stmt *S) {
144    auto Type = getHashType(S);
145    if (Type == PGOHash::None)
146      return true;
147
148    CounterMap[S] = NextCounter++;
149    Hash.combine(Type);
150    return true;
151  }
152  PGOHash::HashType getHashType(const Stmt *S) {
153    switch (S->getStmtClass()) {
154    default:
155      break;
156    case Stmt::LabelStmtClass:
157      return PGOHash::LabelStmt;
158    case Stmt::WhileStmtClass:
159      return PGOHash::WhileStmt;
160    case Stmt::DoStmtClass:
161      return PGOHash::DoStmt;
162    case Stmt::ForStmtClass:
163      return PGOHash::ForStmt;
164    case Stmt::CXXForRangeStmtClass:
165      return PGOHash::CXXForRangeStmt;
166    case Stmt::ObjCForCollectionStmtClass:
167      return PGOHash::ObjCForCollectionStmt;
168    case Stmt::SwitchStmtClass:
169      return PGOHash::SwitchStmt;
170    case Stmt::CaseStmtClass:
171      return PGOHash::CaseStmt;
172    case Stmt::DefaultStmtClass:
173      return PGOHash::DefaultStmt;
174    case Stmt::IfStmtClass:
175      return PGOHash::IfStmt;
176    case Stmt::CXXTryStmtClass:
177      return PGOHash::CXXTryStmt;
178    case Stmt::CXXCatchStmtClass:
179      return PGOHash::CXXCatchStmt;
180    case Stmt::ConditionalOperatorClass:
181      return PGOHash::ConditionalOperator;
182    case Stmt::BinaryConditionalOperatorClass:
183      return PGOHash::BinaryConditionalOperator;
184    case Stmt::BinaryOperatorClass: {
185      const BinaryOperator *BO = cast<BinaryOperator>(S);
186      if (BO->getOpcode() == BO_LAnd)
187        return PGOHash::BinaryOperatorLAnd;
188      if (BO->getOpcode() == BO_LOr)
189        return PGOHash::BinaryOperatorLOr;
190      break;
191    }
192    }
193    return PGOHash::None;
194  }
195};
196
197/// A StmtVisitor that propagates the raw counts through the AST and
198/// records the count at statements where the value may change.
199struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
200  /// PGO state.
201  CodeGenPGO &PGO;
202
203  /// A flag that is set when the current count should be recorded on the
204  /// next statement, such as at the exit of a loop.
205  bool RecordNextStmtCount;
206
207  /// The count at the current location in the traversal.
208  uint64_t CurrentCount;
209
210  /// The map of statements to count values.
211  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
212
213  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
214  struct BreakContinue {
215    uint64_t BreakCount;
216    uint64_t ContinueCount;
217    BreakContinue() : BreakCount(0), ContinueCount(0) {}
218  };
219  SmallVector<BreakContinue, 8> BreakContinueStack;
220
221  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
222                      CodeGenPGO &PGO)
223      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
224
225  void RecordStmtCount(const Stmt *S) {
226    if (RecordNextStmtCount) {
227      CountMap[S] = CurrentCount;
228      RecordNextStmtCount = false;
229    }
230  }
231
232  /// Set and return the current count.
233  uint64_t setCount(uint64_t Count) {
234    CurrentCount = Count;
235    return Count;
236  }
237
238  void VisitStmt(const Stmt *S) {
239    RecordStmtCount(S);
240    for (const Stmt *Child : S->children())
241      if (Child)
242        this->Visit(Child);
243  }
244
245  void VisitFunctionDecl(const FunctionDecl *D) {
246    // Counter tracks entry to the function body.
247    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
248    CountMap[D->getBody()] = BodyCount;
249    Visit(D->getBody());
250  }
251
252  // Skip lambda expressions. We visit these as FunctionDecls when we're
253  // generating them and aren't interested in the body when generating a
254  // parent context.
255  void VisitLambdaExpr(const LambdaExpr *LE) {}
256
257  void VisitCapturedDecl(const CapturedDecl *D) {
258    // Counter tracks entry to the capture body.
259    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
260    CountMap[D->getBody()] = BodyCount;
261    Visit(D->getBody());
262  }
263
264  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
265    // Counter tracks entry to the method body.
266    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
267    CountMap[D->getBody()] = BodyCount;
268    Visit(D->getBody());
269  }
270
271  void VisitBlockDecl(const BlockDecl *D) {
272    // Counter tracks entry to the block body.
273    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
274    CountMap[D->getBody()] = BodyCount;
275    Visit(D->getBody());
276  }
277
278  void VisitReturnStmt(const ReturnStmt *S) {
279    RecordStmtCount(S);
280    if (S->getRetValue())
281      Visit(S->getRetValue());
282    CurrentCount = 0;
283    RecordNextStmtCount = true;
284  }
285
286  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
287    RecordStmtCount(E);
288    if (E->getSubExpr())
289      Visit(E->getSubExpr());
290    CurrentCount = 0;
291    RecordNextStmtCount = true;
292  }
293
294  void VisitGotoStmt(const GotoStmt *S) {
295    RecordStmtCount(S);
296    CurrentCount = 0;
297    RecordNextStmtCount = true;
298  }
299
300  void VisitLabelStmt(const LabelStmt *S) {
301    RecordNextStmtCount = false;
302    // Counter tracks the block following the label.
303    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
304    CountMap[S] = BlockCount;
305    Visit(S->getSubStmt());
306  }
307
308  void VisitBreakStmt(const BreakStmt *S) {
309    RecordStmtCount(S);
310    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
311    BreakContinueStack.back().BreakCount += CurrentCount;
312    CurrentCount = 0;
313    RecordNextStmtCount = true;
314  }
315
316  void VisitContinueStmt(const ContinueStmt *S) {
317    RecordStmtCount(S);
318    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
319    BreakContinueStack.back().ContinueCount += CurrentCount;
320    CurrentCount = 0;
321    RecordNextStmtCount = true;
322  }
323
324  void VisitWhileStmt(const WhileStmt *S) {
325    RecordStmtCount(S);
326    uint64_t ParentCount = CurrentCount;
327
328    BreakContinueStack.push_back(BreakContinue());
329    // Visit the body region first so the break/continue adjustments can be
330    // included when visiting the condition.
331    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
332    CountMap[S->getBody()] = CurrentCount;
333    Visit(S->getBody());
334    uint64_t BackedgeCount = CurrentCount;
335
336    // ...then go back and propagate counts through the condition. The count
337    // at the start of the condition is the sum of the incoming edges,
338    // the backedge from the end of the loop body, and the edges from
339    // continue statements.
340    BreakContinue BC = BreakContinueStack.pop_back_val();
341    uint64_t CondCount =
342        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
343    CountMap[S->getCond()] = CondCount;
344    Visit(S->getCond());
345    setCount(BC.BreakCount + CondCount - BodyCount);
346    RecordNextStmtCount = true;
347  }
348
349  void VisitDoStmt(const DoStmt *S) {
350    RecordStmtCount(S);
351    uint64_t LoopCount = PGO.getRegionCount(S);
352
353    BreakContinueStack.push_back(BreakContinue());
354    // The count doesn't include the fallthrough from the parent scope. Add it.
355    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
356    CountMap[S->getBody()] = BodyCount;
357    Visit(S->getBody());
358    uint64_t BackedgeCount = CurrentCount;
359
360    BreakContinue BC = BreakContinueStack.pop_back_val();
361    // The count at the start of the condition is equal to the count at the
362    // end of the body, plus any continues.
363    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
364    CountMap[S->getCond()] = CondCount;
365    Visit(S->getCond());
366    setCount(BC.BreakCount + CondCount - LoopCount);
367    RecordNextStmtCount = true;
368  }
369
370  void VisitForStmt(const ForStmt *S) {
371    RecordStmtCount(S);
372    if (S->getInit())
373      Visit(S->getInit());
374
375    uint64_t ParentCount = CurrentCount;
376
377    BreakContinueStack.push_back(BreakContinue());
378    // Visit the body region first. (This is basically the same as a while
379    // loop; see further comments in VisitWhileStmt.)
380    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
381    CountMap[S->getBody()] = BodyCount;
382    Visit(S->getBody());
383    uint64_t BackedgeCount = CurrentCount;
384    BreakContinue BC = BreakContinueStack.pop_back_val();
385
386    // The increment is essentially part of the body but it needs to include
387    // the count for all the continue statements.
388    if (S->getInc()) {
389      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
390      CountMap[S->getInc()] = IncCount;
391      Visit(S->getInc());
392    }
393
394    // ...then go back and propagate counts through the condition.
395    uint64_t CondCount =
396        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
397    if (S->getCond()) {
398      CountMap[S->getCond()] = CondCount;
399      Visit(S->getCond());
400    }
401    setCount(BC.BreakCount + CondCount - BodyCount);
402    RecordNextStmtCount = true;
403  }
404
405  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
406    RecordStmtCount(S);
407    Visit(S->getLoopVarStmt());
408    Visit(S->getRangeStmt());
409    Visit(S->getBeginEndStmt());
410
411    uint64_t ParentCount = CurrentCount;
412    BreakContinueStack.push_back(BreakContinue());
413    // Visit the body region first. (This is basically the same as a while
414    // loop; see further comments in VisitWhileStmt.)
415    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
416    CountMap[S->getBody()] = BodyCount;
417    Visit(S->getBody());
418    uint64_t BackedgeCount = CurrentCount;
419    BreakContinue BC = BreakContinueStack.pop_back_val();
420
421    // The increment is essentially part of the body but it needs to include
422    // the count for all the continue statements.
423    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
424    CountMap[S->getInc()] = IncCount;
425    Visit(S->getInc());
426
427    // ...then go back and propagate counts through the condition.
428    uint64_t CondCount =
429        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
430    CountMap[S->getCond()] = CondCount;
431    Visit(S->getCond());
432    setCount(BC.BreakCount + CondCount - BodyCount);
433    RecordNextStmtCount = true;
434  }
435
436  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
437    RecordStmtCount(S);
438    Visit(S->getElement());
439    uint64_t ParentCount = CurrentCount;
440    BreakContinueStack.push_back(BreakContinue());
441    // Counter tracks the body of the loop.
442    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
443    CountMap[S->getBody()] = BodyCount;
444    Visit(S->getBody());
445    uint64_t BackedgeCount = CurrentCount;
446    BreakContinue BC = BreakContinueStack.pop_back_val();
447
448    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
449             BodyCount);
450    RecordNextStmtCount = true;
451  }
452
453  void VisitSwitchStmt(const SwitchStmt *S) {
454    RecordStmtCount(S);
455    Visit(S->getCond());
456    CurrentCount = 0;
457    BreakContinueStack.push_back(BreakContinue());
458    Visit(S->getBody());
459    // If the switch is inside a loop, add the continue counts.
460    BreakContinue BC = BreakContinueStack.pop_back_val();
461    if (!BreakContinueStack.empty())
462      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
463    // Counter tracks the exit block of the switch.
464    setCount(PGO.getRegionCount(S));
465    RecordNextStmtCount = true;
466  }
467
468  void VisitSwitchCase(const SwitchCase *S) {
469    RecordNextStmtCount = false;
470    // Counter for this particular case. This counts only jumps from the
471    // switch header and does not include fallthrough from the case before
472    // this one.
473    uint64_t CaseCount = PGO.getRegionCount(S);
474    setCount(CurrentCount + CaseCount);
475    // We need the count without fallthrough in the mapping, so it's more useful
476    // for branch probabilities.
477    CountMap[S] = CaseCount;
478    RecordNextStmtCount = true;
479    Visit(S->getSubStmt());
480  }
481
482  void VisitIfStmt(const IfStmt *S) {
483    RecordStmtCount(S);
484    uint64_t ParentCount = CurrentCount;
485    Visit(S->getCond());
486
487    // Counter tracks the "then" part of an if statement. The count for
488    // the "else" part, if it exists, will be calculated from this counter.
489    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
490    CountMap[S->getThen()] = ThenCount;
491    Visit(S->getThen());
492    uint64_t OutCount = CurrentCount;
493
494    uint64_t ElseCount = ParentCount - ThenCount;
495    if (S->getElse()) {
496      setCount(ElseCount);
497      CountMap[S->getElse()] = ElseCount;
498      Visit(S->getElse());
499      OutCount += CurrentCount;
500    } else
501      OutCount += ElseCount;
502    setCount(OutCount);
503    RecordNextStmtCount = true;
504  }
505
506  void VisitCXXTryStmt(const CXXTryStmt *S) {
507    RecordStmtCount(S);
508    Visit(S->getTryBlock());
509    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
510      Visit(S->getHandler(I));
511    // Counter tracks the continuation block of the try statement.
512    setCount(PGO.getRegionCount(S));
513    RecordNextStmtCount = true;
514  }
515
516  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
517    RecordNextStmtCount = false;
518    // Counter tracks the catch statement's handler block.
519    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
520    CountMap[S] = CatchCount;
521    Visit(S->getHandlerBlock());
522  }
523
524  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
525    RecordStmtCount(E);
526    uint64_t ParentCount = CurrentCount;
527    Visit(E->getCond());
528
529    // Counter tracks the "true" part of a conditional operator. The
530    // count in the "false" part will be calculated from this counter.
531    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
532    CountMap[E->getTrueExpr()] = TrueCount;
533    Visit(E->getTrueExpr());
534    uint64_t OutCount = CurrentCount;
535
536    uint64_t FalseCount = setCount(ParentCount - TrueCount);
537    CountMap[E->getFalseExpr()] = FalseCount;
538    Visit(E->getFalseExpr());
539    OutCount += CurrentCount;
540
541    setCount(OutCount);
542    RecordNextStmtCount = true;
543  }
544
545  void VisitBinLAnd(const BinaryOperator *E) {
546    RecordStmtCount(E);
547    uint64_t ParentCount = CurrentCount;
548    Visit(E->getLHS());
549    // Counter tracks the right hand side of a logical and operator.
550    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
551    CountMap[E->getRHS()] = RHSCount;
552    Visit(E->getRHS());
553    setCount(ParentCount + RHSCount - CurrentCount);
554    RecordNextStmtCount = true;
555  }
556
557  void VisitBinLOr(const BinaryOperator *E) {
558    RecordStmtCount(E);
559    uint64_t ParentCount = CurrentCount;
560    Visit(E->getLHS());
561    // Counter tracks the right hand side of a logical or operator.
562    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
563    CountMap[E->getRHS()] = RHSCount;
564    Visit(E->getRHS());
565    setCount(ParentCount + RHSCount - CurrentCount);
566    RecordNextStmtCount = true;
567  }
568};
569} // end anonymous namespace
570
571void PGOHash::combine(HashType Type) {
572  // Check that we never combine 0 and only have six bits.
573  assert(Type && "Hash is invalid: unexpected type 0");
574  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
575
576  // Pass through MD5 if enough work has built up.
577  if (Count && Count % NumTypesPerWord == 0) {
578    using namespace llvm::support;
579    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
580    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
581    Working = 0;
582  }
583
584  // Accumulate the current type.
585  ++Count;
586  Working = Working << NumBitsPerType | Type;
587}
588
589uint64_t PGOHash::finalize() {
590  // Use Working as the hash directly if we never used MD5.
591  if (Count <= NumTypesPerWord)
592    // No need to byte swap here, since none of the math was endian-dependent.
593    // This number will be byte-swapped as required on endianness transitions,
594    // so we will see the same value on the other side.
595    return Working;
596
597  // Check for remaining work in Working.
598  if (Working)
599    MD5.update(Working);
600
601  // Finalize the MD5 and return the hash.
602  llvm::MD5::MD5Result Result;
603  MD5.final(Result);
604  using namespace llvm::support;
605  return endian::read<uint64_t, little, unaligned>(Result);
606}
607
608void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
609  const Decl *D = GD.getDecl();
610  bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
611  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
612  if (!InstrumentRegions && !PGOReader)
613    return;
614  if (D->isImplicit())
615    return;
616  // Constructors and destructors may be represented by several functions in IR.
617  // If so, instrument only base variant, others are implemented by delegation
618  // to the base one, it would be counted twice otherwise.
619  if (CGM.getTarget().getCXXABI().hasConstructorVariants() &&
620      ((isa<CXXConstructorDecl>(GD.getDecl()) &&
621        GD.getCtorType() != Ctor_Base) ||
622       (isa<CXXDestructorDecl>(GD.getDecl()) &&
623        GD.getDtorType() != Dtor_Base))) {
624      return;
625  }
626  CGM.ClearUnusedCoverageMapping(D);
627  setFuncName(Fn);
628
629  mapRegionCounters(D);
630  if (CGM.getCodeGenOpts().CoverageMapping)
631    emitCounterRegionMapping(D);
632  if (PGOReader) {
633    SourceManager &SM = CGM.getContext().getSourceManager();
634    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
635    computeRegionCounts(D);
636    applyFunctionAttributes(PGOReader, Fn);
637  }
638}
639
640void CodeGenPGO::mapRegionCounters(const Decl *D) {
641  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
642  MapRegionCounters Walker(*RegionCounterMap);
643  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
644    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
645  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
646    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
647  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
648    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
649  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
650    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
651  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
652  NumRegionCounters = Walker.NextCounter;
653  FunctionHash = Walker.Hash.finalize();
654}
655
656void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
657  if (SkipCoverageMapping)
658    return;
659  // Don't map the functions inside the system headers
660  auto Loc = D->getBody()->getLocStart();
661  if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
662    return;
663
664  std::string CoverageMapping;
665  llvm::raw_string_ostream OS(CoverageMapping);
666  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
667                                CGM.getContext().getSourceManager(),
668                                CGM.getLangOpts(), RegionCounterMap.get());
669  MappingGen.emitCounterMapping(D, OS);
670  OS.flush();
671
672  if (CoverageMapping.empty())
673    return;
674
675  CGM.getCoverageMapping()->addFunctionMappingRecord(
676      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
677}
678
679void
680CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
681                                    llvm::GlobalValue::LinkageTypes Linkage) {
682  if (SkipCoverageMapping)
683    return;
684  // Don't map the functions inside the system headers
685  auto Loc = D->getBody()->getLocStart();
686  if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
687    return;
688
689  std::string CoverageMapping;
690  llvm::raw_string_ostream OS(CoverageMapping);
691  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
692                                CGM.getContext().getSourceManager(),
693                                CGM.getLangOpts());
694  MappingGen.emitEmptyMapping(D, OS);
695  OS.flush();
696
697  if (CoverageMapping.empty())
698    return;
699
700  setFuncName(Name, Linkage);
701  CGM.getCoverageMapping()->addFunctionMappingRecord(
702      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
703}
704
705void CodeGenPGO::computeRegionCounts(const Decl *D) {
706  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
707  ComputeRegionCounts Walker(*StmtCountMap, *this);
708  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
709    Walker.VisitFunctionDecl(FD);
710  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
711    Walker.VisitObjCMethodDecl(MD);
712  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
713    Walker.VisitBlockDecl(BD);
714  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
715    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
716}
717
718void
719CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
720                                    llvm::Function *Fn) {
721  if (!haveRegionCounts())
722    return;
723
724  uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
725  uint64_t FunctionCount = getRegionCount(nullptr);
726  if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
727    // Turn on InlineHint attribute for hot functions.
728    // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
729    Fn->addFnAttr(llvm::Attribute::InlineHint);
730  else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
731    // Turn on Cold attribute for cold functions.
732    // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
733    Fn->addFnAttr(llvm::Attribute::Cold);
734
735  Fn->setEntryCount(FunctionCount);
736}
737
738void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
739  if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
740    return;
741  if (!Builder.GetInsertBlock())
742    return;
743
744  unsigned Counter = (*RegionCounterMap)[S];
745  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
746  Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
747                     {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
748                      Builder.getInt64(FunctionHash),
749                      Builder.getInt32(NumRegionCounters),
750                      Builder.getInt32(Counter)});
751}
752
753void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
754                                  bool IsInMainFile) {
755  CGM.getPGOStats().addVisited(IsInMainFile);
756  RegionCounts.clear();
757  if (std::error_code EC =
758          PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
759    if (EC == llvm::instrprof_error::unknown_function)
760      CGM.getPGOStats().addMissing(IsInMainFile);
761    else if (EC == llvm::instrprof_error::hash_mismatch)
762      CGM.getPGOStats().addMismatched(IsInMainFile);
763    else if (EC == llvm::instrprof_error::malformed)
764      // TODO: Consider a more specific warning for this case.
765      CGM.getPGOStats().addMismatched(IsInMainFile);
766    RegionCounts.clear();
767  }
768}
769
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \param MaxWeight The largest weight that will be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
/// MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Small enough already: no scaling required.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
777
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \param Weight The raw 64-bit weight.
/// \param Scale  The non-zero divisor to apply.
/// \returns \c Weight / \c Scale + 1 as a 32-bit value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight; a scale derived from a smaller weight may not bring
/// \c Weight into 32-bit range.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The assertion above guarantees the value fits; make the narrowing explicit.
  return static_cast<uint32_t>(Scaled);
}
793
794llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
795                                                    uint64_t FalseCount) {
796  // Check for empty weights.
797  if (!TrueCount && !FalseCount)
798    return nullptr;
799
800  // Calculate how to scale down to 32-bits.
801  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
802
803  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
804  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
805                                      scaleBranchWeight(FalseCount, Scale));
806}
807
808llvm::MDNode *
809CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
810  // We need at least two elements to create meaningful weights.
811  if (Weights.size() < 2)
812    return nullptr;
813
814  // Check for empty weights.
815  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
816  if (MaxWeight == 0)
817    return nullptr;
818
819  // Calculate how to scale down to 32-bits.
820  uint64_t Scale = calculateWeightScale(MaxWeight);
821
822  SmallVector<uint32_t, 16> ScaledWeights;
823  ScaledWeights.reserve(Weights.size());
824  for (uint64_t W : Weights)
825    ScaledWeights.push_back(scaleBranchWeight(W, Scale));
826
827  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
828  return MDHelper.createBranchWeights(ScaledWeights);
829}
830
831llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
832                                                           uint64_t LoopCount) {
833  if (!PGO.haveRegionCounts())
834    return nullptr;
835  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
836  assert(CondCount.hasValue() && "missing expected loop condition count");
837  if (*CondCount == 0)
838    return nullptr;
839  return createProfileWeights(LoopCount,
840                              std::max(*CondCount, LoopCount) - LoopCount);
841}
842