//===-- InductiveRangeCheckElimination.cpp - ------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// The InductiveRangeCheckElimination pass splits a loop's iteration space into
// three disjoint ranges.  It does so in such a way that the middle loop (the
// "main" loop) provably does not need range checks.  As an example, it will
// convert
//
//   len = < known positive >
//   for (i = 0; i < n; i++) {
//     if (0 <= i && i < len) {
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//
// to
//
//   len = < known positive >
//   limit = smin(n, len)
//   // no first segment
//   for (i = 0; i < limit; i++) {
//     if (0 <= i && i < len) { // this check is fully redundant
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//   for (i = limit; i < n; i++) {
//     if (0 <= i && i < len) {
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Optional.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
63#include "llvm/Support/raw_ostream.h" 64#include "llvm/Transforms/Scalar.h" 65#include "llvm/Transforms/Utils/BasicBlockUtils.h" 66#include "llvm/Transforms/Utils/Cloning.h" 67#include "llvm/Transforms/Utils/LoopUtils.h" 68#include "llvm/Transforms/Utils/SimplifyIndVar.h" 69#include "llvm/Transforms/Utils/UnrollLoop.h" 70#include <array> 71 72using namespace llvm; 73 74static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden, 75 cl::init(64)); 76 77static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden, 78 cl::init(false)); 79 80static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden, 81 cl::init(false)); 82 83static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal", 84 cl::Hidden, cl::init(10)); 85 86#define DEBUG_TYPE "irce" 87 88namespace { 89 90/// An inductive range check is conditional branch in a loop with 91/// 92/// 1. a very cold successor (i.e. the branch jumps to that successor very 93/// rarely) 94/// 95/// and 96/// 97/// 2. a condition that is provably true for some contiguous range of values 98/// taken by the containing loop's induction variable. 99/// 100class InductiveRangeCheck { 101 // Classifies a range check 102 enum RangeCheckKind : unsigned { 103 // Range check of the form "0 <= I". 104 RANGE_CHECK_LOWER = 1, 105 106 // Range check of the form "I < L" where L is known positive. 107 RANGE_CHECK_UPPER = 2, 108 109 // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER 110 // conditions. 111 RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER, 112 113 // Unrecognized range check condition. 
114 RANGE_CHECK_UNKNOWN = (unsigned)-1 115 }; 116 117 static const char *rangeCheckKindToStr(RangeCheckKind); 118 119 const SCEV *Offset; 120 const SCEV *Scale; 121 Value *Length; 122 BranchInst *Branch; 123 RangeCheckKind Kind; 124 125 static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI, 126 ScalarEvolution &SE, Value *&Index, 127 Value *&Length); 128 129 static InductiveRangeCheck::RangeCheckKind 130 parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition, 131 const SCEV *&Index, Value *&UpperLimit); 132 133 InductiveRangeCheck() : 134 Offset(nullptr), Scale(nullptr), Length(nullptr), Branch(nullptr) { } 135 136public: 137 const SCEV *getOffset() const { return Offset; } 138 const SCEV *getScale() const { return Scale; } 139 Value *getLength() const { return Length; } 140 141 void print(raw_ostream &OS) const { 142 OS << "InductiveRangeCheck:\n"; 143 OS << " Kind: " << rangeCheckKindToStr(Kind) << "\n"; 144 OS << " Offset: "; 145 Offset->print(OS); 146 OS << " Scale: "; 147 Scale->print(OS); 148 OS << " Length: "; 149 if (Length) 150 Length->print(OS); 151 else 152 OS << "(null)"; 153 OS << "\n Branch: "; 154 getBranch()->print(OS); 155 OS << "\n"; 156 } 157 158#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 159 void dump() { 160 print(dbgs()); 161 } 162#endif 163 164 BranchInst *getBranch() const { return Branch; } 165 166 /// Represents an signed integer range [Range.getBegin(), Range.getEnd()). If 167 /// R.getEnd() sle R.getBegin(), then R denotes the empty range. 
168 169 class Range { 170 const SCEV *Begin; 171 const SCEV *End; 172 173 public: 174 Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) { 175 assert(Begin->getType() == End->getType() && "ill-typed range!"); 176 } 177 178 Type *getType() const { return Begin->getType(); } 179 const SCEV *getBegin() const { return Begin; } 180 const SCEV *getEnd() const { return End; } 181 }; 182 183 typedef SpecificBumpPtrAllocator<InductiveRangeCheck> AllocatorTy; 184 185 /// This is the value the condition of the branch needs to evaluate to for the 186 /// branch to take the hot successor (see (1) above). 187 bool getPassingDirection() { return true; } 188 189 /// Computes a range for the induction variable (IndVar) in which the range 190 /// check is redundant and can be constant-folded away. The induction 191 /// variable is not required to be the canonical {0,+,1} induction variable. 192 Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE, 193 const SCEVAddRecExpr *IndVar, 194 IRBuilder<> &B) const; 195 196 /// Create an inductive range check out of BI if possible, else return 197 /// nullptr. 
198 static InductiveRangeCheck *create(AllocatorTy &Alloc, BranchInst *BI, 199 Loop *L, ScalarEvolution &SE, 200 BranchProbabilityInfo &BPI); 201}; 202 203class InductiveRangeCheckElimination : public LoopPass { 204 InductiveRangeCheck::AllocatorTy Allocator; 205 206public: 207 static char ID; 208 InductiveRangeCheckElimination() : LoopPass(ID) { 209 initializeInductiveRangeCheckEliminationPass( 210 *PassRegistry::getPassRegistry()); 211 } 212 213 void getAnalysisUsage(AnalysisUsage &AU) const override { 214 AU.addRequired<LoopInfoWrapperPass>(); 215 AU.addRequiredID(LoopSimplifyID); 216 AU.addRequiredID(LCSSAID); 217 AU.addRequired<ScalarEvolution>(); 218 AU.addRequired<BranchProbabilityInfo>(); 219 } 220 221 bool runOnLoop(Loop *L, LPPassManager &LPM) override; 222}; 223 224char InductiveRangeCheckElimination::ID = 0; 225} 226 227INITIALIZE_PASS(InductiveRangeCheckElimination, "irce", 228 "Inductive range check elimination", false, false) 229 230const char *InductiveRangeCheck::rangeCheckKindToStr( 231 InductiveRangeCheck::RangeCheckKind RCK) { 232 switch (RCK) { 233 case InductiveRangeCheck::RANGE_CHECK_UNKNOWN: 234 return "RANGE_CHECK_UNKNOWN"; 235 236 case InductiveRangeCheck::RANGE_CHECK_UPPER: 237 return "RANGE_CHECK_UPPER"; 238 239 case InductiveRangeCheck::RANGE_CHECK_LOWER: 240 return "RANGE_CHECK_LOWER"; 241 242 case InductiveRangeCheck::RANGE_CHECK_BOTH: 243 return "RANGE_CHECK_BOTH"; 244 } 245 246 llvm_unreachable("unknown range check type!"); 247} 248 249/// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` 250/// cannot 251/// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set 252/// `Index` and `Length` to `nullptr`. Otherwise set `Index` to the value 253/// being 254/// range checked, and set `Length` to the upper limit `Index` is being range 255/// checked with if (and only if) the range check type is stronger or equal to 256/// RANGE_CHECK_UPPER. 
257/// 258InductiveRangeCheck::RangeCheckKind 259InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, 260 ScalarEvolution &SE, Value *&Index, 261 Value *&Length) { 262 263 auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) { 264 const SCEV *S = SE.getSCEV(V); 265 if (isa<SCEVCouldNotCompute>(S)) 266 return false; 267 268 return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant && 269 SE.isKnownNonNegative(S); 270 }; 271 272 using namespace llvm::PatternMatch; 273 274 ICmpInst::Predicate Pred = ICI->getPredicate(); 275 Value *LHS = ICI->getOperand(0); 276 Value *RHS = ICI->getOperand(1); 277 278 switch (Pred) { 279 default: 280 return RANGE_CHECK_UNKNOWN; 281 282 case ICmpInst::ICMP_SLE: 283 std::swap(LHS, RHS); 284 // fallthrough 285 case ICmpInst::ICMP_SGE: 286 if (match(RHS, m_ConstantInt<0>())) { 287 Index = LHS; 288 return RANGE_CHECK_LOWER; 289 } 290 return RANGE_CHECK_UNKNOWN; 291 292 case ICmpInst::ICMP_SLT: 293 std::swap(LHS, RHS); 294 // fallthrough 295 case ICmpInst::ICMP_SGT: 296 if (match(RHS, m_ConstantInt<-1>())) { 297 Index = LHS; 298 return RANGE_CHECK_LOWER; 299 } 300 301 if (IsNonNegativeAndNotLoopVarying(LHS)) { 302 Index = RHS; 303 Length = LHS; 304 return RANGE_CHECK_UPPER; 305 } 306 return RANGE_CHECK_UNKNOWN; 307 308 case ICmpInst::ICMP_ULT: 309 std::swap(LHS, RHS); 310 // fallthrough 311 case ICmpInst::ICMP_UGT: 312 if (IsNonNegativeAndNotLoopVarying(LHS)) { 313 Index = RHS; 314 Length = LHS; 315 return RANGE_CHECK_BOTH; 316 } 317 return RANGE_CHECK_UNKNOWN; 318 } 319 320 llvm_unreachable("default clause returns!"); 321} 322 323/// Parses an arbitrary condition into a range check. `Length` is set only if 324/// the range check is recognized to be `RANGE_CHECK_UPPER` or stronger. 
325InductiveRangeCheck::RangeCheckKind 326InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE, 327 Value *Condition, const SCEV *&Index, 328 Value *&Length) { 329 using namespace llvm::PatternMatch; 330 331 Value *A = nullptr; 332 Value *B = nullptr; 333 334 if (match(Condition, m_And(m_Value(A), m_Value(B)))) { 335 Value *IndexA = nullptr, *IndexB = nullptr; 336 Value *LengthA = nullptr, *LengthB = nullptr; 337 ICmpInst *ICmpA = dyn_cast<ICmpInst>(A), *ICmpB = dyn_cast<ICmpInst>(B); 338 339 if (!ICmpA || !ICmpB) 340 return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; 341 342 auto RCKindA = parseRangeCheckICmp(L, ICmpA, SE, IndexA, LengthA); 343 auto RCKindB = parseRangeCheckICmp(L, ICmpB, SE, IndexB, LengthB); 344 345 if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN || 346 RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) 347 return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; 348 349 if (IndexA != IndexB) 350 return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; 351 352 if (LengthA != nullptr && LengthB != nullptr && LengthA != LengthB) 353 return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; 354 355 Index = SE.getSCEV(IndexA); 356 if (isa<SCEVCouldNotCompute>(Index)) 357 return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; 358 359 Length = LengthA == nullptr ? 
LengthB : LengthA; 360 361 return (InductiveRangeCheck::RangeCheckKind)(RCKindA | RCKindB); 362 } 363 364 if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { 365 Value *IndexVal = nullptr; 366 367 auto RCKind = parseRangeCheckICmp(L, ICI, SE, IndexVal, Length); 368 369 if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) 370 return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; 371 372 Index = SE.getSCEV(IndexVal); 373 if (isa<SCEVCouldNotCompute>(Index)) 374 return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; 375 376 return RCKind; 377 } 378 379 return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; 380} 381 382 383InductiveRangeCheck * 384InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI, 385 Loop *L, ScalarEvolution &SE, 386 BranchProbabilityInfo &BPI) { 387 388 if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()) 389 return nullptr; 390 391 BranchProbability LikelyTaken(15, 16); 392 393 if (BPI.getEdgeProbability(BI->getParent(), (unsigned) 0) < LikelyTaken) 394 return nullptr; 395 396 Value *Length = nullptr; 397 const SCEV *IndexSCEV = nullptr; 398 399 auto RCKind = InductiveRangeCheck::parseRangeCheck(L, SE, BI->getCondition(), 400 IndexSCEV, Length); 401 402 if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) 403 return nullptr; 404 405 assert(IndexSCEV && "contract with SplitRangeCheckCondition!"); 406 assert((!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) || Length) && 407 "contract with SplitRangeCheckCondition!"); 408 409 const SCEVAddRecExpr *IndexAddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV); 410 bool IsAffineIndex = 411 IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine(); 412 413 if (!IsAffineIndex) 414 return nullptr; 415 416 InductiveRangeCheck *IRC = new (A.Allocate()) InductiveRangeCheck; 417 IRC->Length = Length; 418 IRC->Offset = IndexAddRec->getStart(); 419 IRC->Scale = IndexAddRec->getStepRecurrence(SE); 420 IRC->Branch = BI; 421 IRC->Kind = RCKind; 422 return IRC; 423} 424 
425namespace { 426 427// Keeps track of the structure of a loop. This is similar to llvm::Loop, 428// except that it is more lightweight and can track the state of a loop through 429// changing and potentially invalid IR. This structure also formalizes the 430// kinds of loops we can deal with -- ones that have a single latch that is also 431// an exiting block *and* have a canonical induction variable. 432struct LoopStructure { 433 const char *Tag; 434 435 BasicBlock *Header; 436 BasicBlock *Latch; 437 438 // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th 439 // successor is `LatchExit', the exit block of the loop. 440 BranchInst *LatchBr; 441 BasicBlock *LatchExit; 442 unsigned LatchBrExitIdx; 443 444 Value *IndVarNext; 445 Value *IndVarStart; 446 Value *LoopExitAt; 447 bool IndVarIncreasing; 448 449 LoopStructure() 450 : Tag(""), Header(nullptr), Latch(nullptr), LatchBr(nullptr), 451 LatchExit(nullptr), LatchBrExitIdx(-1), IndVarNext(nullptr), 452 IndVarStart(nullptr), LoopExitAt(nullptr), IndVarIncreasing(false) {} 453 454 template <typename M> LoopStructure map(M Map) const { 455 LoopStructure Result; 456 Result.Tag = Tag; 457 Result.Header = cast<BasicBlock>(Map(Header)); 458 Result.Latch = cast<BasicBlock>(Map(Latch)); 459 Result.LatchBr = cast<BranchInst>(Map(LatchBr)); 460 Result.LatchExit = cast<BasicBlock>(Map(LatchExit)); 461 Result.LatchBrExitIdx = LatchBrExitIdx; 462 Result.IndVarNext = Map(IndVarNext); 463 Result.IndVarStart = Map(IndVarStart); 464 Result.LoopExitAt = Map(LoopExitAt); 465 Result.IndVarIncreasing = IndVarIncreasing; 466 return Result; 467 } 468 469 static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &, 470 BranchProbabilityInfo &BPI, 471 Loop &, 472 const char *&); 473}; 474 475/// This class is used to constrain loops to run within a given iteration space. 476/// The algorithm this class implements is given a Loop and a range [Begin, 477/// End). 
The algorithm then tries to break out a "main loop" out of the loop 478/// it is given in a way that the "main loop" runs with the induction variable 479/// in a subset of [Begin, End). The algorithm emits appropriate pre and post 480/// loops to run any remaining iterations. The pre loop runs any iterations in 481/// which the induction variable is < Begin, and the post loop runs any 482/// iterations in which the induction variable is >= End. 483/// 484class LoopConstrainer { 485 // The representation of a clone of the original loop we started out with. 486 struct ClonedLoop { 487 // The cloned blocks 488 std::vector<BasicBlock *> Blocks; 489 490 // `Map` maps values in the clonee into values in the cloned version 491 ValueToValueMapTy Map; 492 493 // An instance of `LoopStructure` for the cloned loop 494 LoopStructure Structure; 495 }; 496 497 // Result of rewriting the range of a loop. See changeIterationSpaceEnd for 498 // more details on what these fields mean. 499 struct RewrittenRangeInfo { 500 BasicBlock *PseudoExit; 501 BasicBlock *ExitSelector; 502 std::vector<PHINode *> PHIValuesAtPseudoExit; 503 PHINode *IndVarEnd; 504 505 RewrittenRangeInfo() 506 : PseudoExit(nullptr), ExitSelector(nullptr), IndVarEnd(nullptr) {} 507 }; 508 509 // Calculated subranges we restrict the iteration space of the main loop to. 510 // See the implementation of `calculateSubRanges' for more details on how 511 // these fields are computed. `LowLimit` is None if there is no restriction 512 // on low end of the restricted iteration space of the main loop. `HighLimit` 513 // is None if there is no restriction on high end of the restricted iteration 514 // space of the main loop. 
515 516 struct SubRanges { 517 Optional<const SCEV *> LowLimit; 518 Optional<const SCEV *> HighLimit; 519 }; 520 521 // A utility function that does a `replaceUsesOfWith' on the incoming block 522 // set of a `PHINode' -- replaces instances of `Block' in the `PHINode's 523 // incoming block list with `ReplaceBy'. 524 static void replacePHIBlock(PHINode *PN, BasicBlock *Block, 525 BasicBlock *ReplaceBy); 526 527 // Compute a safe set of limits for the main loop to run in -- effectively the 528 // intersection of `Range' and the iteration space of the original loop. 529 // Return None if unable to compute the set of subranges. 530 // 531 Optional<SubRanges> calculateSubRanges() const; 532 533 // Clone `OriginalLoop' and return the result in CLResult. The IR after 534 // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- 535 // the PHI nodes say that there is an incoming edge from `OriginalPreheader` 536 // but there is no such edge. 537 // 538 void cloneLoop(ClonedLoop &CLResult, const char *Tag) const; 539 540 // Rewrite the iteration space of the loop denoted by (LS, Preheader). The 541 // iteration space of the rewritten loop ends at ExitLoopAt. The start of the 542 // iteration space is not changed. `ExitLoopAt' is assumed to be slt 543 // `OriginalHeaderCount'. 544 // 545 // If there are iterations left to execute, control is made to jump to 546 // `ContinuationBlock', otherwise they take the normal loop exit. The 547 // returned `RewrittenRangeInfo' object is populated as follows: 548 // 549 // .PseudoExit is a basic block that unconditionally branches to 550 // `ContinuationBlock'. 551 // 552 // .ExitSelector is a basic block that decides, on exit from the loop, 553 // whether to branch to the "true" exit or to `PseudoExit'. 554 // 555 // .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value 556 // for each PHINode in the loop header on taking the pseudo exit. 
557 // 558 // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate 559 // preheader because it is made to branch to the loop header only 560 // conditionally. 561 // 562 RewrittenRangeInfo 563 changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader, 564 Value *ExitLoopAt, 565 BasicBlock *ContinuationBlock) const; 566 567 // The loop denoted by `LS' has `OldPreheader' as its preheader. This 568 // function creates a new preheader for `LS' and returns it. 569 // 570 BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, 571 const char *Tag) const; 572 573 // `ContinuationBlockAndPreheader' was the continuation block for some call to 574 // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'. 575 // This function rewrites the PHI nodes in `LS.Header' to start with the 576 // correct value. 577 void rewriteIncomingValuesForPHIs( 578 LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader, 579 const LoopConstrainer::RewrittenRangeInfo &RRI) const; 580 581 // Even though we do not preserve any passes at this time, we at least need to 582 // keep the parent loop structure consistent. The `LPPassManager' seems to 583 // verify this after running a loop pass. This function adds the list of 584 // blocks denoted by BBs to this loops parent loop if required. 585 void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs); 586 587 // Some global state. 588 Function &F; 589 LLVMContext &Ctx; 590 ScalarEvolution &SE; 591 592 // Information about the original loop we started out with. 593 Loop &OriginalLoop; 594 LoopInfo &OriginalLoopInfo; 595 const SCEV *LatchTakenCount; 596 BasicBlock *OriginalPreheader; 597 598 // The preheader of the main loop. This may or may not be different from 599 // `OriginalPreheader'. 600 BasicBlock *MainLoopPreheader; 601 602 // The range we need to run the main loop in. 
603 InductiveRangeCheck::Range Range; 604 605 // The structure of the main loop (see comment at the beginning of this class 606 // for a definition) 607 LoopStructure MainLoopStructure; 608 609public: 610 LoopConstrainer(Loop &L, LoopInfo &LI, const LoopStructure &LS, 611 ScalarEvolution &SE, InductiveRangeCheck::Range R) 612 : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), 613 SE(SE), OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr), 614 OriginalPreheader(nullptr), MainLoopPreheader(nullptr), Range(R), 615 MainLoopStructure(LS) {} 616 617 // Entry point for the algorithm. Returns true on success. 618 bool run(); 619}; 620 621} 622 623void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, 624 BasicBlock *ReplaceBy) { 625 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 626 if (PN->getIncomingBlock(i) == Block) 627 PN->setIncomingBlock(i, ReplaceBy); 628} 629 630static bool CanBeSMax(ScalarEvolution &SE, const SCEV *S) { 631 APInt SMax = 632 APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); 633 return SE.getSignedRange(S).contains(SMax) && 634 SE.getUnsignedRange(S).contains(SMax); 635} 636 637static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) { 638 APInt SMin = 639 APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()); 640 return SE.getSignedRange(S).contains(SMin) && 641 SE.getUnsignedRange(S).contains(SMin); 642} 643 644Optional<LoopStructure> 645LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI, 646 Loop &L, const char *&FailureReason) { 647 assert(L.isLoopSimplifyForm() && "should follow from addRequired<>"); 648 649 BasicBlock *Latch = L.getLoopLatch(); 650 if (!L.isLoopExiting(Latch)) { 651 FailureReason = "no loop latch"; 652 return None; 653 } 654 655 BasicBlock *Header = L.getHeader(); 656 BasicBlock *Preheader = L.getLoopPreheader(); 657 if (!Preheader) { 658 FailureReason = "no preheader"; 659 return None; 
660 } 661 662 BranchInst *LatchBr = dyn_cast<BranchInst>(&*Latch->rbegin()); 663 if (!LatchBr || LatchBr->isUnconditional()) { 664 FailureReason = "latch terminator not conditional branch"; 665 return None; 666 } 667 668 unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; 669 670 BranchProbability ExitProbability = 671 BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx); 672 673 if (ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { 674 FailureReason = "short running loop, not profitable"; 675 return None; 676 } 677 678 ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition()); 679 if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) { 680 FailureReason = "latch terminator branch not conditional on integral icmp"; 681 return None; 682 } 683 684 const SCEV *LatchCount = SE.getExitCount(&L, Latch); 685 if (isa<SCEVCouldNotCompute>(LatchCount)) { 686 FailureReason = "could not compute latch count"; 687 return None; 688 } 689 690 ICmpInst::Predicate Pred = ICI->getPredicate(); 691 Value *LeftValue = ICI->getOperand(0); 692 const SCEV *LeftSCEV = SE.getSCEV(LeftValue); 693 IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType()); 694 695 Value *RightValue = ICI->getOperand(1); 696 const SCEV *RightSCEV = SE.getSCEV(RightValue); 697 698 // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence. 
699 if (!isa<SCEVAddRecExpr>(LeftSCEV)) { 700 if (isa<SCEVAddRecExpr>(RightSCEV)) { 701 std::swap(LeftSCEV, RightSCEV); 702 std::swap(LeftValue, RightValue); 703 Pred = ICmpInst::getSwappedPredicate(Pred); 704 } else { 705 FailureReason = "no add recurrences in the icmp"; 706 return None; 707 } 708 } 709 710 auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) { 711 if (AR->getNoWrapFlags(SCEV::FlagNSW)) 712 return true; 713 714 IntegerType *Ty = cast<IntegerType>(AR->getType()); 715 IntegerType *WideTy = 716 IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); 717 718 const SCEVAddRecExpr *ExtendAfterOp = 719 dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy)); 720 if (ExtendAfterOp) { 721 const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy); 722 const SCEV *ExtendedStep = 723 SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy); 724 725 bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart && 726 ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep; 727 728 if (NoSignedWrap) 729 return true; 730 } 731 732 // We may have proved this when computing the sign extension above. 733 return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; 734 }; 735 736 auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) { 737 if (!AR->isAffine()) 738 return false; 739 740 // Currently we only work with induction variables that have been proved to 741 // not wrap. This restriction can potentially be lifted in the future. 742 743 if (!HasNoSignedWrap(AR)) 744 return false; 745 746 if (const SCEVConstant *StepExpr = 747 dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) { 748 ConstantInt *StepCI = StepExpr->getValue(); 749 if (StepCI->isOne() || StepCI->isMinusOne()) { 750 IsIncreasing = StepCI->isOne(); 751 return true; 752 } 753 } 754 755 return false; 756 }; 757 758 // `ICI` is interpreted as taking the backedge if the *next* value of the 759 // induction variable satisfies some constraint. 
760 761 const SCEVAddRecExpr *IndVarNext = cast<SCEVAddRecExpr>(LeftSCEV); 762 bool IsIncreasing = false; 763 if (!IsInductionVar(IndVarNext, IsIncreasing)) { 764 FailureReason = "LHS in icmp not induction variable"; 765 return None; 766 } 767 768 ConstantInt *One = ConstantInt::get(IndVarTy, 1); 769 // TODO: generalize the predicates here to also match their unsigned variants. 770 if (IsIncreasing) { 771 bool FoundExpectedPred = 772 (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) || 773 (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0); 774 775 if (!FoundExpectedPred) { 776 FailureReason = "expected icmp slt semantically, found something else"; 777 return None; 778 } 779 780 if (LatchBrExitIdx == 0) { 781 if (CanBeSMax(SE, RightSCEV)) { 782 // TODO: this restriction is easily removable -- we just have to 783 // remember that the icmp was an slt and not an sle. 784 FailureReason = "limit may overflow when coercing sle to slt"; 785 return None; 786 } 787 788 IRBuilder<> B(&*Preheader->rbegin()); 789 RightValue = B.CreateAdd(RightValue, One); 790 } 791 792 } else { 793 bool FoundExpectedPred = 794 (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) || 795 (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0); 796 797 if (!FoundExpectedPred) { 798 FailureReason = "expected icmp sgt semantically, found something else"; 799 return None; 800 } 801 802 if (LatchBrExitIdx == 0) { 803 if (CanBeSMin(SE, RightSCEV)) { 804 // TODO: this restriction is easily removable -- we just have to 805 // remember that the icmp was an sgt and not an sge. 
806 FailureReason = "limit may overflow when coercing sge to sgt"; 807 return None; 808 } 809 810 IRBuilder<> B(&*Preheader->rbegin()); 811 RightValue = B.CreateSub(RightValue, One); 812 } 813 } 814 815 const SCEV *StartNext = IndVarNext->getStart(); 816 const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); 817 const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); 818 819 BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); 820 821 assert(SE.getLoopDisposition(LatchCount, &L) == 822 ScalarEvolution::LoopInvariant && 823 "loop variant exit count doesn't make sense!"); 824 825 assert(!L.contains(LatchExit) && "expected an exit block!"); 826 const DataLayout &DL = Preheader->getModule()->getDataLayout(); 827 Value *IndVarStartV = 828 SCEVExpander(SE, DL, "irce") 829 .expandCodeFor(IndVarStart, IndVarTy, &*Preheader->rbegin()); 830 IndVarStartV->setName("indvar.start"); 831 832 LoopStructure Result; 833 834 Result.Tag = "main"; 835 Result.Header = Header; 836 Result.Latch = Latch; 837 Result.LatchBr = LatchBr; 838 Result.LatchExit = LatchExit; 839 Result.LatchBrExitIdx = LatchBrExitIdx; 840 Result.IndVarStart = IndVarStartV; 841 Result.IndVarNext = LeftValue; 842 Result.IndVarIncreasing = IsIncreasing; 843 Result.LoopExitAt = RightValue; 844 845 FailureReason = nullptr; 846 847 return Result; 848} 849 850Optional<LoopConstrainer::SubRanges> 851LoopConstrainer::calculateSubRanges() const { 852 IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); 853 854 if (Range.getType() != Ty) 855 return None; 856 857 LoopConstrainer::SubRanges Result; 858 859 // I think we can be more aggressive here and make this nuw / nsw if the 860 // addition that feeds into the icmp for the latch's terminating branch is nuw 861 // / nsw. In any case, a wrapping 2's complement addition is safe. 
862 ConstantInt *One = ConstantInt::get(Ty, 1); 863 const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart); 864 const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt); 865 866 bool Increasing = MainLoopStructure.IndVarIncreasing; 867 868 // We compute `Smallest` and `Greatest` such that [Smallest, Greatest) is the 869 // range of values the induction variable takes. 870 871 const SCEV *Smallest = nullptr, *Greatest = nullptr; 872 873 if (Increasing) { 874 Smallest = Start; 875 Greatest = End; 876 } else { 877 // These two computations may sign-overflow. Here is why that is okay: 878 // 879 // We know that the induction variable does not sign-overflow on any 880 // iteration except the last one, and it starts at `Start` and ends at 881 // `End`, decrementing by one every time. 882 // 883 // * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the 884 // induction variable is decreasing we know that that the smallest value 885 // the loop body is actually executed with is `INT_SMIN` == `Smallest`. 886 // 887 // * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`. In 888 // that case, `Clamp` will always return `Smallest` and 889 // [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`) 890 // will be an empty range. Returning an empty range is always safe. 
891 // 892 893 Smallest = SE.getAddExpr(End, SE.getSCEV(One)); 894 Greatest = SE.getAddExpr(Start, SE.getSCEV(One)); 895 } 896 897 auto Clamp = [this, Smallest, Greatest](const SCEV *S) { 898 return SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S)); 899 }; 900 901 // In some cases we can prove that we don't need a pre or post loop 902 903 bool ProvablyNoPreloop = 904 SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Smallest); 905 if (!ProvablyNoPreloop) 906 Result.LowLimit = Clamp(Range.getBegin()); 907 908 bool ProvablyNoPostLoop = 909 SE.isKnownPredicate(ICmpInst::ICMP_SLE, Greatest, Range.getEnd()); 910 if (!ProvablyNoPostLoop) 911 Result.HighLimit = Clamp(Range.getEnd()); 912 913 return Result; 914} 915 916void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, 917 const char *Tag) const { 918 for (BasicBlock *BB : OriginalLoop.getBlocks()) { 919 BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F); 920 Result.Blocks.push_back(Clone); 921 Result.Map[BB] = Clone; 922 } 923 924 auto GetClonedValue = [&Result](Value *V) { 925 assert(V && "null values not in domain!"); 926 auto It = Result.Map.find(V); 927 if (It == Result.Map.end()) 928 return V; 929 return static_cast<Value *>(It->second); 930 }; 931 932 Result.Structure = MainLoopStructure.map(GetClonedValue); 933 Result.Structure.Tag = Tag; 934 935 for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) { 936 BasicBlock *ClonedBB = Result.Blocks[i]; 937 BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i]; 938 939 assert(Result.Map[OriginalBB] == ClonedBB && "invariant!"); 940 941 for (Instruction &I : *ClonedBB) 942 RemapInstruction(&I, Result.Map, 943 RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); 944 945 // Exit blocks will now have one more predecessor and their PHI nodes need 946 // to be edited to reflect that. No phi nodes need to be introduced because 947 // the loop is in LCSSA. 
948 949 for (auto SBBI = succ_begin(OriginalBB), SBBE = succ_end(OriginalBB); 950 SBBI != SBBE; ++SBBI) { 951 952 if (OriginalLoop.contains(*SBBI)) 953 continue; // not an exit block 954 955 for (Instruction &I : **SBBI) { 956 if (!isa<PHINode>(&I)) 957 break; 958 959 PHINode *PN = cast<PHINode>(&I); 960 Value *OldIncoming = PN->getIncomingValueForBlock(OriginalBB); 961 PN->addIncoming(GetClonedValue(OldIncoming), ClonedBB); 962 } 963 } 964 } 965} 966 967LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( 968 const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, 969 BasicBlock *ContinuationBlock) const { 970 971 // We start with a loop with a single latch: 972 // 973 // +--------------------+ 974 // | | 975 // | preheader | 976 // | | 977 // +--------+-----------+ 978 // | ----------------\ 979 // | / | 980 // +--------v----v------+ | 981 // | | | 982 // | header | | 983 // | | | 984 // +--------------------+ | 985 // | 986 // ..... | 987 // | 988 // +--------------------+ | 989 // | | | 990 // | latch >----------/ 991 // | | 992 // +-------v------------+ 993 // | 994 // | 995 // | +--------------------+ 996 // | | | 997 // +---> original exit | 998 // | | 999 // +--------------------+ 1000 // 1001 // We change the control flow to look like 1002 // 1003 // 1004 // +--------------------+ 1005 // | | 1006 // | preheader >-------------------------+ 1007 // | | | 1008 // +--------v-----------+ | 1009 // | /-------------+ | 1010 // | / | | 1011 // +--------v--v--------+ | | 1012 // | | | | 1013 // | header | | +--------+ | 1014 // | | | | | | 1015 // +--------------------+ | | +-----v-----v-----------+ 1016 // | | | | 1017 // | | | .pseudo.exit | 1018 // | | | | 1019 // | | +-----------v-----------+ 1020 // | | | 1021 // ..... 
| | | 1022 // | | +--------v-------------+ 1023 // +--------------------+ | | | | 1024 // | | | | | ContinuationBlock | 1025 // | latch >------+ | | | 1026 // | | | +----------------------+ 1027 // +---------v----------+ | 1028 // | | 1029 // | | 1030 // | +---------------^-----+ 1031 // | | | 1032 // +-----> .exit.selector | 1033 // | | 1034 // +----------v----------+ 1035 // | 1036 // +--------------------+ | 1037 // | | | 1038 // | original exit <----+ 1039 // | | 1040 // +--------------------+ 1041 // 1042 1043 RewrittenRangeInfo RRI; 1044 1045 auto BBInsertLocation = std::next(Function::iterator(LS.Latch)); 1046 RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector", 1047 &F, BBInsertLocation); 1048 RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, 1049 BBInsertLocation); 1050 1051 BranchInst *PreheaderJump = cast<BranchInst>(&*Preheader->rbegin()); 1052 bool Increasing = LS.IndVarIncreasing; 1053 1054 IRBuilder<> B(PreheaderJump); 1055 1056 // EnterLoopCond - is it okay to start executing this `LS'? 1057 Value *EnterLoopCond = Increasing 1058 ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) 1059 : B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt); 1060 1061 B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); 1062 PreheaderJump->eraseFromParent(); 1063 1064 LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); 1065 B.SetInsertPoint(LS.LatchBr); 1066 Value *TakeBackedgeLoopCond = 1067 Increasing ? B.CreateICmpSLT(LS.IndVarNext, ExitSubloopAt) 1068 : B.CreateICmpSGT(LS.IndVarNext, ExitSubloopAt); 1069 Value *CondForBranch = LS.LatchBrExitIdx == 1 1070 ? TakeBackedgeLoopCond 1071 : B.CreateNot(TakeBackedgeLoopCond); 1072 1073 LS.LatchBr->setCondition(CondForBranch); 1074 1075 B.SetInsertPoint(RRI.ExitSelector); 1076 1077 // IterationsLeft - are there any more iterations left, given the original 1078 // upper bound on the induction variable? If not, we branch to the "real" 1079 // exit. 
1080 Value *IterationsLeft = Increasing 1081 ? B.CreateICmpSLT(LS.IndVarNext, LS.LoopExitAt) 1082 : B.CreateICmpSGT(LS.IndVarNext, LS.LoopExitAt); 1083 B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); 1084 1085 BranchInst *BranchToContinuation = 1086 BranchInst::Create(ContinuationBlock, RRI.PseudoExit); 1087 1088 // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of 1089 // each of the PHI nodes in the loop header. This feeds into the initial 1090 // value of the same PHI nodes if/when we continue execution. 1091 for (Instruction &I : *LS.Header) { 1092 if (!isa<PHINode>(&I)) 1093 break; 1094 1095 PHINode *PN = cast<PHINode>(&I); 1096 1097 PHINode *NewPHI = PHINode::Create(PN->getType(), 2, PN->getName() + ".copy", 1098 BranchToContinuation); 1099 1100 NewPHI->addIncoming(PN->getIncomingValueForBlock(Preheader), Preheader); 1101 NewPHI->addIncoming(PN->getIncomingValueForBlock(LS.Latch), 1102 RRI.ExitSelector); 1103 RRI.PHIValuesAtPseudoExit.push_back(NewPHI); 1104 } 1105 1106 RRI.IndVarEnd = PHINode::Create(LS.IndVarNext->getType(), 2, "indvar.end", 1107 BranchToContinuation); 1108 RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader); 1109 RRI.IndVarEnd->addIncoming(LS.IndVarNext, RRI.ExitSelector); 1110 1111 // The latch exit now has a branch from `RRI.ExitSelector' instead of 1112 // `LS.Latch'. The PHI nodes need to be updated to reflect that. 
1113 for (Instruction &I : *LS.LatchExit) { 1114 if (PHINode *PN = dyn_cast<PHINode>(&I)) 1115 replacePHIBlock(PN, LS.Latch, RRI.ExitSelector); 1116 else 1117 break; 1118 } 1119 1120 return RRI; 1121} 1122 1123void LoopConstrainer::rewriteIncomingValuesForPHIs( 1124 LoopStructure &LS, BasicBlock *ContinuationBlock, 1125 const LoopConstrainer::RewrittenRangeInfo &RRI) const { 1126 1127 unsigned PHIIndex = 0; 1128 for (Instruction &I : *LS.Header) { 1129 if (!isa<PHINode>(&I)) 1130 break; 1131 1132 PHINode *PN = cast<PHINode>(&I); 1133 1134 for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) 1135 if (PN->getIncomingBlock(i) == ContinuationBlock) 1136 PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); 1137 } 1138 1139 LS.IndVarStart = RRI.IndVarEnd; 1140} 1141 1142BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, 1143 BasicBlock *OldPreheader, 1144 const char *Tag) const { 1145 1146 BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); 1147 BranchInst::Create(LS.Header, Preheader); 1148 1149 for (Instruction &I : *LS.Header) { 1150 if (!isa<PHINode>(&I)) 1151 break; 1152 1153 PHINode *PN = cast<PHINode>(&I); 1154 for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) 1155 replacePHIBlock(PN, OldPreheader, Preheader); 1156 } 1157 1158 return Preheader; 1159} 1160 1161void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { 1162 Loop *ParentLoop = OriginalLoop.getParentLoop(); 1163 if (!ParentLoop) 1164 return; 1165 1166 for (BasicBlock *BB : BBs) 1167 ParentLoop->addBasicBlockToLoop(BB, OriginalLoopInfo); 1168} 1169 1170bool LoopConstrainer::run() { 1171 BasicBlock *Preheader = nullptr; 1172 LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch); 1173 Preheader = OriginalLoop.getLoopPreheader(); 1174 assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr && 1175 "preconditions!"); 1176 1177 OriginalPreheader = Preheader; 1178 MainLoopPreheader 
= Preheader; 1179 1180 Optional<SubRanges> MaybeSR = calculateSubRanges(); 1181 if (!MaybeSR.hasValue()) { 1182 DEBUG(dbgs() << "irce: could not compute subranges\n"); 1183 return false; 1184 } 1185 1186 SubRanges SR = MaybeSR.getValue(); 1187 bool Increasing = MainLoopStructure.IndVarIncreasing; 1188 IntegerType *IVTy = 1189 cast<IntegerType>(MainLoopStructure.IndVarNext->getType()); 1190 1191 SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); 1192 Instruction *InsertPt = OriginalPreheader->getTerminator(); 1193 1194 // It would have been better to make `PreLoop' and `PostLoop' 1195 // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy 1196 // constructor. 1197 ClonedLoop PreLoop, PostLoop; 1198 bool NeedsPreLoop = 1199 Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue(); 1200 bool NeedsPostLoop = 1201 Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue(); 1202 1203 Value *ExitPreLoopAt = nullptr; 1204 Value *ExitMainLoopAt = nullptr; 1205 const SCEVConstant *MinusOneS = 1206 cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */)); 1207 1208 if (NeedsPreLoop) { 1209 const SCEV *ExitPreLoopAtSCEV = nullptr; 1210 1211 if (Increasing) 1212 ExitPreLoopAtSCEV = *SR.LowLimit; 1213 else { 1214 if (CanBeSMin(SE, *SR.HighLimit)) { 1215 DEBUG(dbgs() << "irce: could not prove no-overflow when computing " 1216 << "preloop exit limit. HighLimit = " << *(*SR.HighLimit) 1217 << "\n"); 1218 return false; 1219 } 1220 ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); 1221 } 1222 1223 ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); 1224 ExitPreLoopAt->setName("exit.preloop.at"); 1225 } 1226 1227 if (NeedsPostLoop) { 1228 const SCEV *ExitMainLoopAtSCEV = nullptr; 1229 1230 if (Increasing) 1231 ExitMainLoopAtSCEV = *SR.HighLimit; 1232 else { 1233 if (CanBeSMin(SE, *SR.LowLimit)) { 1234 DEBUG(dbgs() << "irce: could not prove no-overflow when computing " 1235 << "mainloop exit limit. 
LowLimit = " << *(*SR.LowLimit) 1236 << "\n"); 1237 return false; 1238 } 1239 ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); 1240 } 1241 1242 ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); 1243 ExitMainLoopAt->setName("exit.mainloop.at"); 1244 } 1245 1246 // We clone these ahead of time so that we don't have to deal with changing 1247 // and temporarily invalid IR as we transform the loops. 1248 if (NeedsPreLoop) 1249 cloneLoop(PreLoop, "preloop"); 1250 if (NeedsPostLoop) 1251 cloneLoop(PostLoop, "postloop"); 1252 1253 RewrittenRangeInfo PreLoopRRI; 1254 1255 if (NeedsPreLoop) { 1256 Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header, 1257 PreLoop.Structure.Header); 1258 1259 MainLoopPreheader = 1260 createPreheader(MainLoopStructure, Preheader, "mainloop"); 1261 PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader, 1262 ExitPreLoopAt, MainLoopPreheader); 1263 rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, 1264 PreLoopRRI); 1265 } 1266 1267 BasicBlock *PostLoopPreheader = nullptr; 1268 RewrittenRangeInfo PostLoopRRI; 1269 1270 if (NeedsPostLoop) { 1271 PostLoopPreheader = 1272 createPreheader(PostLoop.Structure, Preheader, "postloop"); 1273 PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, 1274 ExitMainLoopAt, PostLoopPreheader); 1275 rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, 1276 PostLoopRRI); 1277 } 1278 1279 BasicBlock *NewMainLoopPreheader = 1280 MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr; 1281 BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit, 1282 PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit, 1283 PostLoopRRI.ExitSelector, NewMainLoopPreheader}; 1284 1285 // Some of the above may be nullptr, filter them out before passing to 1286 // addToParentLoopIfNeeded. 
1287 auto NewBlocksEnd = 1288 std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); 1289 1290 addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd)); 1291 addToParentLoopIfNeeded(PreLoop.Blocks); 1292 addToParentLoopIfNeeded(PostLoop.Blocks); 1293 1294 return true; 1295} 1296 1297/// Computes and returns a range of values for the induction variable (IndVar) 1298/// in which the range check can be safely elided. If it cannot compute such a 1299/// range, returns None. 1300Optional<InductiveRangeCheck::Range> 1301InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, 1302 const SCEVAddRecExpr *IndVar, 1303 IRBuilder<> &) const { 1304 // IndVar is of the form "A + B * I" (where "I" is the canonical induction 1305 // variable, that may or may not exist as a real llvm::Value in the loop) and 1306 // this inductive range check is a range check on the "C + D * I" ("C" is 1307 // getOffset() and "D" is getScale()). We rewrite the value being range 1308 // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA". 1309 // Currently we support this only for "B" = "D" = { 1 or -1 }, but the code 1310 // can be generalized as needed. 1311 // 1312 // The actual inequalities we solve are of the form 1313 // 1314 // 0 <= M + 1 * IndVar < L given L >= 0 (i.e. N == 1) 1315 // 1316 // The inequality is satisfied by -M <= IndVar < (L - M) [^1]. All additions 1317 // and subtractions are twos-complement wrapping and comparisons are signed. 1318 // 1319 // Proof: 1320 // 1321 // If there exists IndVar such that -M <= IndVar < (L - M) then it follows 1322 // that -M <= (-M + L) [== Eq. 1]. Since L >= 0, if (-M + L) sign-overflows 1323 // then (-M + L) < (-M). Hence by [Eq. 1], (-M + L) could not have 1324 // overflown. 1325 // 1326 // This means IndVar = t + (-M) for t in [0, L). Hence (IndVar + M) = t. 
1327 // Hence 0 <= (IndVar + M) < L 1328 1329 // [^1]: Note that the solution does _not_ apply if L < 0; consider values M = 1330 // 127, IndVar = 126 and L = -2 in an i8 world. 1331 1332 if (!IndVar->isAffine()) 1333 return None; 1334 1335 const SCEV *A = IndVar->getStart(); 1336 const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE)); 1337 if (!B) 1338 return None; 1339 1340 const SCEV *C = getOffset(); 1341 const SCEVConstant *D = dyn_cast<SCEVConstant>(getScale()); 1342 if (D != B) 1343 return None; 1344 1345 ConstantInt *ConstD = D->getValue(); 1346 if (!(ConstD->isMinusOne() || ConstD->isOne())) 1347 return None; 1348 1349 const SCEV *M = SE.getMinusSCEV(C, A); 1350 1351 const SCEV *Begin = SE.getNegativeSCEV(M); 1352 const SCEV *UpperLimit = nullptr; 1353 1354 // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". 1355 // We can potentially do much better here. 1356 if (Value *V = getLength()) { 1357 UpperLimit = SE.getSCEV(V); 1358 } else { 1359 assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); 1360 unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); 1361 UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); 1362 } 1363 1364 const SCEV *End = SE.getMinusSCEV(UpperLimit, M); 1365 return InductiveRangeCheck::Range(Begin, End); 1366} 1367 1368static Optional<InductiveRangeCheck::Range> 1369IntersectRange(ScalarEvolution &SE, 1370 const Optional<InductiveRangeCheck::Range> &R1, 1371 const InductiveRangeCheck::Range &R2, IRBuilder<> &B) { 1372 if (!R1.hasValue()) 1373 return R2; 1374 auto &R1Value = R1.getValue(); 1375 1376 // TODO: we could widen the smaller range and have this work; but for now we 1377 // bail out to keep things simple. 
1378 if (R1Value.getType() != R2.getType()) 1379 return None; 1380 1381 const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin()); 1382 const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd()); 1383 1384 return InductiveRangeCheck::Range(NewBegin, NewEnd); 1385} 1386 1387bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { 1388 if (L->getBlocks().size() >= LoopSizeCutoff) { 1389 DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";); 1390 return false; 1391 } 1392 1393 BasicBlock *Preheader = L->getLoopPreheader(); 1394 if (!Preheader) { 1395 DEBUG(dbgs() << "irce: loop has no preheader, leaving\n"); 1396 return false; 1397 } 1398 1399 LLVMContext &Context = Preheader->getContext(); 1400 InductiveRangeCheck::AllocatorTy IRCAlloc; 1401 SmallVector<InductiveRangeCheck *, 16> RangeChecks; 1402 ScalarEvolution &SE = getAnalysis<ScalarEvolution>(); 1403 BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfo>(); 1404 1405 for (auto BBI : L->getBlocks()) 1406 if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) 1407 if (InductiveRangeCheck *IRC = 1408 InductiveRangeCheck::create(IRCAlloc, TBI, L, SE, BPI)) 1409 RangeChecks.push_back(IRC); 1410 1411 if (RangeChecks.empty()) 1412 return false; 1413 1414 auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) { 1415 OS << "irce: looking at loop "; L->print(OS); 1416 OS << "irce: loop has " << RangeChecks.size() 1417 << " inductive range checks: \n"; 1418 for (InductiveRangeCheck *IRC : RangeChecks) 1419 IRC->print(OS); 1420 }; 1421 1422 DEBUG(PrintRecognizedRangeChecks(dbgs())); 1423 1424 if (PrintRangeChecks) 1425 PrintRecognizedRangeChecks(errs()); 1426 1427 const char *FailureReason = nullptr; 1428 Optional<LoopStructure> MaybeLoopStructure = 1429 LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason); 1430 if (!MaybeLoopStructure.hasValue()) { 1431 DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason 1432 << 
"\n";); 1433 return false; 1434 } 1435 LoopStructure LS = MaybeLoopStructure.getValue(); 1436 bool Increasing = LS.IndVarIncreasing; 1437 const SCEV *MinusOne = 1438 SE.getConstant(LS.IndVarNext->getType(), Increasing ? -1 : 1, true); 1439 const SCEVAddRecExpr *IndVar = 1440 cast<SCEVAddRecExpr>(SE.getAddExpr(SE.getSCEV(LS.IndVarNext), MinusOne)); 1441 1442 Optional<InductiveRangeCheck::Range> SafeIterRange; 1443 Instruction *ExprInsertPt = Preheader->getTerminator(); 1444 1445 SmallVector<InductiveRangeCheck *, 4> RangeChecksToEliminate; 1446 1447 IRBuilder<> B(ExprInsertPt); 1448 for (InductiveRangeCheck *IRC : RangeChecks) { 1449 auto Result = IRC->computeSafeIterationSpace(SE, IndVar, B); 1450 if (Result.hasValue()) { 1451 auto MaybeSafeIterRange = 1452 IntersectRange(SE, SafeIterRange, Result.getValue(), B); 1453 if (MaybeSafeIterRange.hasValue()) { 1454 RangeChecksToEliminate.push_back(IRC); 1455 SafeIterRange = MaybeSafeIterRange.getValue(); 1456 } 1457 } 1458 } 1459 1460 if (!SafeIterRange.hasValue()) 1461 return false; 1462 1463 LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LS, 1464 SE, SafeIterRange.getValue()); 1465 bool Changed = LC.run(); 1466 1467 if (Changed) { 1468 auto PrintConstrainedLoopInfo = [L]() { 1469 dbgs() << "irce: in function "; 1470 dbgs() << L->getHeader()->getParent()->getName() << ": "; 1471 dbgs() << "constrained "; 1472 L->print(dbgs()); 1473 }; 1474 1475 DEBUG(PrintConstrainedLoopInfo()); 1476 1477 if (PrintChangedLoops) 1478 PrintConstrainedLoopInfo(); 1479 1480 // Optimize away the now-redundant range checks. 1481 1482 for (InductiveRangeCheck *IRC : RangeChecksToEliminate) { 1483 ConstantInt *FoldedRangeCheck = IRC->getPassingDirection() 1484 ? 
ConstantInt::getTrue(Context) 1485 : ConstantInt::getFalse(Context); 1486 IRC->getBranch()->setCondition(FoldedRangeCheck); 1487 } 1488 } 1489 1490 return Changed; 1491} 1492 1493Pass *llvm::createInductiveRangeCheckEliminationPass() { 1494 return new InductiveRangeCheckElimination; 1495} 1496