1//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file contains the implementation of the scalar evolution analysis
11// engine, which is used primarily to analyze expressions involving induction
12// variables in loops.
13//
14// There are several aspects to this library.  First is the representation of
15// scalar expressions, which are represented as subclasses of the SCEV class.
16// These classes are used to represent certain types of subexpressions that we
17// can handle. We only create one SCEV of a particular shape, so
18// pointer-comparisons for equality are legal.
19//
20// One important aspect of the SCEV objects is that they are never cyclic, even
21// if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
22// the PHI node is one of the idioms that we can represent (e.g., a polynomial
23// recurrence) then we represent it directly as a recurrence node, otherwise we
24// represent it as a SCEVUnknown node.
25//
26// In addition to being able to represent expressions of various types, we also
27// have folders that are used to build the *canonical* representation for a
28// particular expression.  These folders are capable of using a variety of
29// rewrite rules to simplify the expressions.
30//
31// Once the folders are defined, we can implement the more interesting
32// higher-level code, such as the code that recognizes PHI nodes of various
33// types, computes the execution count of a loop, etc.
34//
35// TODO: We should use these routines and value representations to implement
36// dependence analysis!
37//
38//===----------------------------------------------------------------------===//
39//
40// There are several good references for the techniques used in this analysis.
41//
42//  Chains of recurrences -- a method to expedite the evaluation
43//  of closed-form functions
44//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
45//
46//  On computational properties of chains of recurrences
47//  Eugene V. Zima
48//
49//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
50//  Robert A. van Engelen
51//
52//  Efficient Symbolic Analysis for Optimizing Compilers
53//  Robert A. van Engelen
54//
55//  Using the chains of recurrences algebra for data dependence testing and
56//  induction variable substitution
57//  MS Thesis, Johnie Birch
58//
59//===----------------------------------------------------------------------===//
60
61#define DEBUG_TYPE "scalar-evolution"
62#include "llvm/Analysis/ScalarEvolutionExpressions.h"
63#include "llvm/Constants.h"
64#include "llvm/DerivedTypes.h"
65#include "llvm/GlobalVariable.h"
66#include "llvm/GlobalAlias.h"
67#include "llvm/Instructions.h"
68#include "llvm/LLVMContext.h"
69#include "llvm/Operator.h"
70#include "llvm/Analysis/ConstantFolding.h"
71#include "llvm/Analysis/Dominators.h"
72#include "llvm/Analysis/InstructionSimplify.h"
73#include "llvm/Analysis/LoopInfo.h"
74#include "llvm/Analysis/ValueTracking.h"
75#include "llvm/Assembly/Writer.h"
76#include "llvm/Target/TargetData.h"
77#include "llvm/Target/TargetLibraryInfo.h"
78#include "llvm/Support/CommandLine.h"
79#include "llvm/Support/ConstantRange.h"
80#include "llvm/Support/Debug.h"
81#include "llvm/Support/ErrorHandling.h"
82#include "llvm/Support/GetElementPtrTypeIterator.h"
83#include "llvm/Support/InstIterator.h"
84#include "llvm/Support/MathExtras.h"
85#include "llvm/Support/raw_ostream.h"
86#include "llvm/ADT/Statistic.h"
87#include "llvm/ADT/STLExtras.h"
88#include "llvm/ADT/SmallPtrSet.h"
89#include <algorithm>
90using namespace llvm;
91
92STATISTIC(NumArrayLenItCounts,
93          "Number of trip counts computed with array length");
94STATISTIC(NumTripCountsComputed,
95          "Number of loops with predictable loop counts");
96STATISTIC(NumTripCountsNotComputed,
97          "Number of loops without predictable loop counts");
98STATISTIC(NumBruteForceTripCountsComputed,
99          "Number of loops with trip counts computed by force");
100
101static cl::opt<unsigned>
102MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
103                        cl::desc("Maximum number of iterations SCEV will "
104                                 "symbolically execute a constant "
105                                 "derived loop"),
106                        cl::init(100));
107
108INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
109                "Scalar Evolution Analysis", false, true)
110INITIALIZE_PASS_DEPENDENCY(LoopInfo)
111INITIALIZE_PASS_DEPENDENCY(DominatorTree)
112INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo)
113INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
114                "Scalar Evolution Analysis", false, true)
115char ScalarEvolution::ID = 0;
116
117//===----------------------------------------------------------------------===//
118//                           SCEV class definitions
119//===----------------------------------------------------------------------===//
120
121//===----------------------------------------------------------------------===//
122// Implementation of the SCEV class.
123//
124
125void SCEV::dump() const {
126  print(dbgs());
127  dbgs() << '\n';
128}
129
130void SCEV::print(raw_ostream &OS) const {
131  switch (getSCEVType()) {
132  case scConstant:
133    WriteAsOperand(OS, cast<SCEVConstant>(this)->getValue(), false);
134    return;
135  case scTruncate: {
136    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
137    const SCEV *Op = Trunc->getOperand();
138    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
139       << *Trunc->getType() << ")";
140    return;
141  }
142  case scZeroExtend: {
143    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
144    const SCEV *Op = ZExt->getOperand();
145    OS << "(zext " << *Op->getType() << " " << *Op << " to "
146       << *ZExt->getType() << ")";
147    return;
148  }
149  case scSignExtend: {
150    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
151    const SCEV *Op = SExt->getOperand();
152    OS << "(sext " << *Op->getType() << " " << *Op << " to "
153       << *SExt->getType() << ")";
154    return;
155  }
156  case scAddRecExpr: {
157    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
158    OS << "{" << *AR->getOperand(0);
159    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
160      OS << ",+," << *AR->getOperand(i);
161    OS << "}<";
162    if (AR->getNoWrapFlags(FlagNUW))
163      OS << "nuw><";
164    if (AR->getNoWrapFlags(FlagNSW))
165      OS << "nsw><";
166    if (AR->getNoWrapFlags(FlagNW) &&
167        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
168      OS << "nw><";
169    WriteAsOperand(OS, AR->getLoop()->getHeader(), /*PrintType=*/false);
170    OS << ">";
171    return;
172  }
173  case scAddExpr:
174  case scMulExpr:
175  case scUMaxExpr:
176  case scSMaxExpr: {
177    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
178    const char *OpStr = 0;
179    switch (NAry->getSCEVType()) {
180    case scAddExpr: OpStr = " + "; break;
181    case scMulExpr: OpStr = " * "; break;
182    case scUMaxExpr: OpStr = " umax "; break;
183    case scSMaxExpr: OpStr = " smax "; break;
184    }
185    OS << "(";
186    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
187         I != E; ++I) {
188      OS << **I;
189      if (llvm::next(I) != E)
190        OS << OpStr;
191    }
192    OS << ")";
193    switch (NAry->getSCEVType()) {
194    case scAddExpr:
195    case scMulExpr:
196      if (NAry->getNoWrapFlags(FlagNUW))
197        OS << "<nuw>";
198      if (NAry->getNoWrapFlags(FlagNSW))
199        OS << "<nsw>";
200    }
201    return;
202  }
203  case scUDivExpr: {
204    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
205    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
206    return;
207  }
208  case scUnknown: {
209    const SCEVUnknown *U = cast<SCEVUnknown>(this);
210    Type *AllocTy;
211    if (U->isSizeOf(AllocTy)) {
212      OS << "sizeof(" << *AllocTy << ")";
213      return;
214    }
215    if (U->isAlignOf(AllocTy)) {
216      OS << "alignof(" << *AllocTy << ")";
217      return;
218    }
219
220    Type *CTy;
221    Constant *FieldNo;
222    if (U->isOffsetOf(CTy, FieldNo)) {
223      OS << "offsetof(" << *CTy << ", ";
224      WriteAsOperand(OS, FieldNo, false);
225      OS << ")";
226      return;
227    }
228
229    // Otherwise just print it normally.
230    WriteAsOperand(OS, U->getValue(), false);
231    return;
232  }
233  case scCouldNotCompute:
234    OS << "***COULDNOTCOMPUTE***";
235    return;
236  default: break;
237  }
238  llvm_unreachable("Unknown SCEV kind!");
239}
240
241Type *SCEV::getType() const {
242  switch (getSCEVType()) {
243  case scConstant:
244    return cast<SCEVConstant>(this)->getType();
245  case scTruncate:
246  case scZeroExtend:
247  case scSignExtend:
248    return cast<SCEVCastExpr>(this)->getType();
249  case scAddRecExpr:
250  case scMulExpr:
251  case scUMaxExpr:
252  case scSMaxExpr:
253    return cast<SCEVNAryExpr>(this)->getType();
254  case scAddExpr:
255    return cast<SCEVAddExpr>(this)->getType();
256  case scUDivExpr:
257    return cast<SCEVUDivExpr>(this)->getType();
258  case scUnknown:
259    return cast<SCEVUnknown>(this)->getType();
260  case scCouldNotCompute:
261    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
262  default:
263    llvm_unreachable("Unknown SCEV kind!");
264  }
265}
266
267bool SCEV::isZero() const {
268  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
269    return SC->getValue()->isZero();
270  return false;
271}
272
273bool SCEV::isOne() const {
274  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
275    return SC->getValue()->isOne();
276  return false;
277}
278
279bool SCEV::isAllOnesValue() const {
280  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
281    return SC->getValue()->isAllOnesValue();
282  return false;
283}
284
285/// isNonConstantNegative - Return true if the specified scev is negated, but
286/// not a constant.
287bool SCEV::isNonConstantNegative() const {
288  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
289  if (!Mul) return false;
290
291  // If there is a constant factor, it will be first.
292  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
293  if (!SC) return false;
294
295  // Return true if the value is negative, this matches things like (-42 * V).
296  return SC->getValue()->getValue().isNegative();
297}
298
299SCEVCouldNotCompute::SCEVCouldNotCompute() :
300  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}
301
302bool SCEVCouldNotCompute::classof(const SCEV *S) {
303  return S->getSCEVType() == scCouldNotCompute;
304}
305
306const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
307  FoldingSetNodeID ID;
308  ID.AddInteger(scConstant);
309  ID.AddPointer(V);
310  void *IP = 0;
311  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
312  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
313  UniqueSCEVs.InsertNode(S, IP);
314  return S;
315}
316
317const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
318  return getConstant(ConstantInt::get(getContext(), Val));
319}
320
321const SCEV *
322ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
323  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
324  return getConstant(ConstantInt::get(ITy, V, isSigned));
325}
326
327SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
328                           unsigned SCEVTy, const SCEV *op, Type *ty)
329  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}
330
331SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
332                                   const SCEV *op, Type *ty)
333  : SCEVCastExpr(ID, scTruncate, op, ty) {
334  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
335         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
336         "Cannot truncate non-integer value!");
337}
338
339SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
340                                       const SCEV *op, Type *ty)
341  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
342  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
343         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
344         "Cannot zero extend non-integer value!");
345}
346
347SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
348                                       const SCEV *op, Type *ty)
349  : SCEVCastExpr(ID, scSignExtend, op, ty) {
350  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
351         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
352         "Cannot sign extend non-integer value!");
353}
354
355void SCEVUnknown::deleted() {
356  // Clear this SCEVUnknown from various maps.
357  SE->forgetMemoizedResults(this);
358
359  // Remove this SCEVUnknown from the uniquing map.
360  SE->UniqueSCEVs.RemoveNode(this);
361
362  // Release the value.
363  setValPtr(0);
364}
365
366void SCEVUnknown::allUsesReplacedWith(Value *New) {
367  // Clear this SCEVUnknown from various maps.
368  SE->forgetMemoizedResults(this);
369
370  // Remove this SCEVUnknown from the uniquing map.
371  SE->UniqueSCEVs.RemoveNode(this);
372
373  // Update this SCEVUnknown to point to the new value. This is needed
374  // because there may still be outstanding SCEVs which still point to
375  // this SCEVUnknown.
376  setValPtr(New);
377}
378
379bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
380  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
381    if (VCE->getOpcode() == Instruction::PtrToInt)
382      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
383        if (CE->getOpcode() == Instruction::GetElementPtr &&
384            CE->getOperand(0)->isNullValue() &&
385            CE->getNumOperands() == 2)
386          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
387            if (CI->isOne()) {
388              AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
389                                 ->getElementType();
390              return true;
391            }
392
393  return false;
394}
395
396bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
397  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
398    if (VCE->getOpcode() == Instruction::PtrToInt)
399      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
400        if (CE->getOpcode() == Instruction::GetElementPtr &&
401            CE->getOperand(0)->isNullValue()) {
402          Type *Ty =
403            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
404          if (StructType *STy = dyn_cast<StructType>(Ty))
405            if (!STy->isPacked() &&
406                CE->getNumOperands() == 3 &&
407                CE->getOperand(1)->isNullValue()) {
408              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
409                if (CI->isOne() &&
410                    STy->getNumElements() == 2 &&
411                    STy->getElementType(0)->isIntegerTy(1)) {
412                  AllocTy = STy->getElementType(1);
413                  return true;
414                }
415            }
416        }
417
418  return false;
419}
420
421bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
422  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
423    if (VCE->getOpcode() == Instruction::PtrToInt)
424      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
425        if (CE->getOpcode() == Instruction::GetElementPtr &&
426            CE->getNumOperands() == 3 &&
427            CE->getOperand(0)->isNullValue() &&
428            CE->getOperand(1)->isNullValue()) {
429          Type *Ty =
430            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
431          // Ignore vector types here so that ScalarEvolutionExpander doesn't
432          // emit getelementptrs that index into vectors.
433          if (Ty->isStructTy() || Ty->isArrayTy()) {
434            CTy = Ty;
435            FieldNo = CE->getOperand(2);
436            return true;
437          }
438        }
439
440  return false;
441}
442
443//===----------------------------------------------------------------------===//
444//                               SCEV Utilities
445//===----------------------------------------------------------------------===//
446
447namespace {
448  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
449  /// than the complexity of the RHS.  This comparator is used to canonicalize
450  /// expressions.
451  class SCEVComplexityCompare {
452    const LoopInfo *const LI;
453  public:
454    explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
455
456    // Return true or false if LHS is less than, or at least RHS, respectively.
457    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
458      return compare(LHS, RHS) < 0;
459    }
460
461    // Return negative, zero, or positive, if LHS is less than, equal to, or
462    // greater than RHS, respectively. A three-way result allows recursive
463    // comparisons to be more efficient.
464    int compare(const SCEV *LHS, const SCEV *RHS) const {
465      // Fast-path: SCEVs are uniqued so we can do a quick equality check.
466      if (LHS == RHS)
467        return 0;
468
469      // Primarily, sort the SCEVs by their getSCEVType().
470      unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
471      if (LType != RType)
472        return (int)LType - (int)RType;
473
474      // Aside from the getSCEVType() ordering, the particular ordering
475      // isn't very important except that it's beneficial to be consistent,
476      // so that (a + b) and (b + a) don't end up as different expressions.
477      switch (LType) {
478      case scUnknown: {
479        const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
480        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
481
482        // Sort SCEVUnknown values with some loose heuristics. TODO: This is
483        // not as complete as it could be.
484        const Value *LV = LU->getValue(), *RV = RU->getValue();
485
486        // Order pointer values after integer values. This helps SCEVExpander
487        // form GEPs.
488        bool LIsPointer = LV->getType()->isPointerTy(),
489             RIsPointer = RV->getType()->isPointerTy();
490        if (LIsPointer != RIsPointer)
491          return (int)LIsPointer - (int)RIsPointer;
492
493        // Compare getValueID values.
494        unsigned LID = LV->getValueID(),
495                 RID = RV->getValueID();
496        if (LID != RID)
497          return (int)LID - (int)RID;
498
499        // Sort arguments by their position.
500        if (const Argument *LA = dyn_cast<Argument>(LV)) {
501          const Argument *RA = cast<Argument>(RV);
502          unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
503          return (int)LArgNo - (int)RArgNo;
504        }
505
506        // For instructions, compare their loop depth, and their operand
507        // count.  This is pretty loose.
508        if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
509          const Instruction *RInst = cast<Instruction>(RV);
510
511          // Compare loop depths.
512          const BasicBlock *LParent = LInst->getParent(),
513                           *RParent = RInst->getParent();
514          if (LParent != RParent) {
515            unsigned LDepth = LI->getLoopDepth(LParent),
516                     RDepth = LI->getLoopDepth(RParent);
517            if (LDepth != RDepth)
518              return (int)LDepth - (int)RDepth;
519          }
520
521          // Compare the number of operands.
522          unsigned LNumOps = LInst->getNumOperands(),
523                   RNumOps = RInst->getNumOperands();
524          return (int)LNumOps - (int)RNumOps;
525        }
526
527        return 0;
528      }
529
530      case scConstant: {
531        const SCEVConstant *LC = cast<SCEVConstant>(LHS);
532        const SCEVConstant *RC = cast<SCEVConstant>(RHS);
533
534        // Compare constant values.
535        const APInt &LA = LC->getValue()->getValue();
536        const APInt &RA = RC->getValue()->getValue();
537        unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
538        if (LBitWidth != RBitWidth)
539          return (int)LBitWidth - (int)RBitWidth;
540        return LA.ult(RA) ? -1 : 1;
541      }
542
543      case scAddRecExpr: {
544        const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
545        const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
546
547        // Compare addrec loop depths.
548        const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
549        if (LLoop != RLoop) {
550          unsigned LDepth = LLoop->getLoopDepth(),
551                   RDepth = RLoop->getLoopDepth();
552          if (LDepth != RDepth)
553            return (int)LDepth - (int)RDepth;
554        }
555
556        // Addrec complexity grows with operand count.
557        unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
558        if (LNumOps != RNumOps)
559          return (int)LNumOps - (int)RNumOps;
560
561        // Lexicographically compare.
562        for (unsigned i = 0; i != LNumOps; ++i) {
563          long X = compare(LA->getOperand(i), RA->getOperand(i));
564          if (X != 0)
565            return X;
566        }
567
568        return 0;
569      }
570
571      case scAddExpr:
572      case scMulExpr:
573      case scSMaxExpr:
574      case scUMaxExpr: {
575        const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
576        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
577
578        // Lexicographically compare n-ary expressions.
579        unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
580        for (unsigned i = 0; i != LNumOps; ++i) {
581          if (i >= RNumOps)
582            return 1;
583          long X = compare(LC->getOperand(i), RC->getOperand(i));
584          if (X != 0)
585            return X;
586        }
587        return (int)LNumOps - (int)RNumOps;
588      }
589
590      case scUDivExpr: {
591        const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
592        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
593
594        // Lexicographically compare udiv expressions.
595        long X = compare(LC->getLHS(), RC->getLHS());
596        if (X != 0)
597          return X;
598        return compare(LC->getRHS(), RC->getRHS());
599      }
600
601      case scTruncate:
602      case scZeroExtend:
603      case scSignExtend: {
604        const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
605        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
606
607        // Compare cast expressions by operand.
608        return compare(LC->getOperand(), RC->getOperand());
609      }
610
611      default:
612        llvm_unreachable("Unknown SCEV kind!");
613      }
614    }
615  };
616}
617
618/// GroupByComplexity - Given a list of SCEV objects, order them by their
619/// complexity, and group objects of the same complexity together by value.
620/// When this routine is finished, we know that any duplicates in the vector are
621/// consecutive and that complexity is monotonically increasing.
622///
623/// Note that we go take special precautions to ensure that we get deterministic
624/// results from this routine.  In other words, we don't want the results of
625/// this to depend on where the addresses of various SCEV objects happened to
626/// land in memory.
627///
628static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
629                              LoopInfo *LI) {
630  if (Ops.size() < 2) return;  // Noop
631  if (Ops.size() == 2) {
632    // This is the common case, which also happens to be trivially simple.
633    // Special case it.
634    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
635    if (SCEVComplexityCompare(LI)(RHS, LHS))
636      std::swap(LHS, RHS);
637    return;
638  }
639
640  // Do the rough sort by complexity.
641  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));
642
643  // Now that we are sorted by complexity, group elements of the same
644  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
645  // be extremely short in practice.  Note that we take this approach because we
646  // do not want to depend on the addresses of the objects we are grouping.
647  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
648    const SCEV *S = Ops[i];
649    unsigned Complexity = S->getSCEVType();
650
651    // If there are any objects of the same complexity and same value as this
652    // one, group them.
653    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
654      if (Ops[j] == S) { // Found a duplicate.
655        // Move it to immediately after i'th element.
656        std::swap(Ops[i+1], Ops[j]);
657        ++i;   // no need to rescan it.
658        if (i == e-2) return;  // Done!
659      }
660    }
661  }
662}
663
664
665
666//===----------------------------------------------------------------------===//
667//                      Simple SCEV method implementations
668//===----------------------------------------------------------------------===//
669
670/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
671/// Assume, K > 0.
672static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
673                                       ScalarEvolution &SE,
674                                       Type *ResultTy) {
675  // Handle the simplest case efficiently.
676  if (K == 1)
677    return SE.getTruncateOrZeroExtend(It, ResultTy);
678
679  // We are using the following formula for BC(It, K):
680  //
681  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
682  //
683  // Suppose, W is the bitwidth of the return value.  We must be prepared for
684  // overflow.  Hence, we must assure that the result of our computation is
685  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
686  // safe in modular arithmetic.
687  //
688  // However, this code doesn't use exactly that formula; the formula it uses
689  // is something like the following, where T is the number of factors of 2 in
690  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
691  // exponentiation:
692  //
693  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
694  //
695  // This formula is trivially equivalent to the previous formula.  However,
696  // this formula can be implemented much more efficiently.  The trick is that
697  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
698  // arithmetic.  To do exact division in modular arithmetic, all we have
699  // to do is multiply by the inverse.  Therefore, this step can be done at
700  // width W.
701  //
702  // The next issue is how to safely do the division by 2^T.  The way this
703  // is done is by doing the multiplication step at a width of at least W + T
704  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
705  // when we perform the division by 2^T (which is equivalent to a right shift
706  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
707  // truncated out after the division by 2^T.
708  //
709  // In comparison to just directly using the first formula, this technique
710  // is much more efficient; using the first formula requires W * K bits,
711  // but this formula less than W + K bits. Also, the first formula requires
712  // a division step, whereas this formula only requires multiplies and shifts.
713  //
714  // It doesn't matter whether the subtraction step is done in the calculation
715  // width or the input iteration count's width; if the subtraction overflows,
716  // the result must be zero anyway.  We prefer here to do it in the width of
717  // the induction variable because it helps a lot for certain cases; CodeGen
718  // isn't smart enough to ignore the overflow, which leads to much less
719  // efficient code if the width of the subtraction is wider than the native
720  // register width.
721  //
722  // (It's possible to not widen at all by pulling out factors of 2 before
723  // the multiplication; for example, K=2 can be calculated as
724  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
725  // extra arithmetic, so it's not an obvious win, and it gets
726  // much more complicated for K > 3.)
727
728  // Protection from insane SCEVs; this bound is conservative,
729  // but it probably doesn't matter.
730  if (K > 1000)
731    return SE.getCouldNotCompute();
732
733  unsigned W = SE.getTypeSizeInBits(ResultTy);
734
735  // Calculate K! / 2^T and T; we divide out the factors of two before
736  // multiplying for calculating K! / 2^T to avoid overflow.
737  // Other overflow doesn't matter because we only care about the bottom
738  // W bits of the result.
739  APInt OddFactorial(W, 1);
740  unsigned T = 1;
741  for (unsigned i = 3; i <= K; ++i) {
742    APInt Mult(W, i);
743    unsigned TwoFactors = Mult.countTrailingZeros();
744    T += TwoFactors;
745    Mult = Mult.lshr(TwoFactors);
746    OddFactorial *= Mult;
747  }
748
749  // We need at least W + T bits for the multiplication step
750  unsigned CalculationBits = W + T;
751
752  // Calculate 2^T, at width T+W.
753  APInt DivFactor = APInt(CalculationBits, 1).shl(T);
754
755  // Calculate the multiplicative inverse of K! / 2^T;
756  // this multiplication factor will perform the exact division by
757  // K! / 2^T.
758  APInt Mod = APInt::getSignedMinValue(W+1);
759  APInt MultiplyFactor = OddFactorial.zext(W+1);
760  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
761  MultiplyFactor = MultiplyFactor.trunc(W);
762
763  // Calculate the product, at width T+W
764  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
765                                                      CalculationBits);
766  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
767  for (unsigned i = 1; i != K; ++i) {
768    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
769    Dividend = SE.getMulExpr(Dividend,
770                             SE.getTruncateOrZeroExtend(S, CalculationTy));
771  }
772
773  // Divide by 2^T
774  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
775
776  // Truncate the result, and divide by K! / 2^T.
777
778  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
779                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
780}
781
782/// evaluateAtIteration - Return the value of this chain of recurrences at
783/// the specified iteration number.  We can evaluate this recurrence by
784/// multiplying each element in the chain by the binomial coefficient
785/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
786///
787///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
788///
789/// where BC(It, k) stands for binomial coefficient.
790///
791const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
792                                                ScalarEvolution &SE) const {
793  const SCEV *Result = getStart();
794  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
795    // The computation is correct in the face of overflow provided that the
796    // multiplication is performed _after_ the evaluation of the binomial
797    // coefficient.
798    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
799    if (isa<SCEVCouldNotCompute>(Coeff))
800      return Coeff;
801
802    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
803  }
804  return Result;
805}
806
807//===----------------------------------------------------------------------===//
808//                    SCEV Expression folder implementations
809//===----------------------------------------------------------------------===//
810
811const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
812                                             Type *Ty) {
813  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
814         "This is not a truncating conversion!");
815  assert(isSCEVable(Ty) &&
816         "This is not a conversion to a SCEVable type!");
817  Ty = getEffectiveSCEVType(Ty);
818
819  FoldingSetNodeID ID;
820  ID.AddInteger(scTruncate);
821  ID.AddPointer(Op);
822  ID.AddPointer(Ty);
823  void *IP = 0;
824  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
825
826  // Fold if the operand is constant.
827  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
828    return getConstant(
829      cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(),
830                                               getEffectiveSCEVType(Ty))));
831
832  // trunc(trunc(x)) --> trunc(x)
833  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
834    return getTruncateExpr(ST->getOperand(), Ty);
835
836  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
837  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
838    return getTruncateOrSignExtend(SS->getOperand(), Ty);
839
840  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
841  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
842    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
843
844  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
845  // eliminate all the truncates.
846  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
847    SmallVector<const SCEV *, 4> Operands;
848    bool hasTrunc = false;
849    for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
850      const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
851      hasTrunc = isa<SCEVTruncateExpr>(S);
852      Operands.push_back(S);
853    }
854    if (!hasTrunc)
855      return getAddExpr(Operands);
856    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
857  }
858
859  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
860  // eliminate all the truncates.
861  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
862    SmallVector<const SCEV *, 4> Operands;
863    bool hasTrunc = false;
864    for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
865      const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
866      hasTrunc = isa<SCEVTruncateExpr>(S);
867      Operands.push_back(S);
868    }
869    if (!hasTrunc)
870      return getMulExpr(Operands);
871    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
872  }
873
874  // If the input value is a chrec scev, truncate the chrec's operands.
875  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
876    SmallVector<const SCEV *, 4> Operands;
877    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
878      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
879    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
880  }
881
882  // As a special case, fold trunc(undef) to undef. We don't want to
883  // know too much about SCEVUnknowns, but this special case is handy
884  // and harmless.
885  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
886    if (isa<UndefValue>(U->getValue()))
887      return getSCEV(UndefValue::get(Ty));
888
889  // The cast wasn't folded; create an explicit cast node. We can reuse
890  // the existing insert position since if we get here, we won't have
891  // made any changes which would invalidate it.
892  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
893                                                 Op, Ty);
894  UniqueSCEVs.InsertNode(S, IP);
895  return S;
896}
897
898const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
899                                               Type *Ty) {
900  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
901         "This is not an extending conversion!");
902  assert(isSCEVable(Ty) &&
903         "This is not a conversion to a SCEVable type!");
904  Ty = getEffectiveSCEVType(Ty);
905
906  // Fold if the operand is constant.
907  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
908    return getConstant(
909      cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(),
910                                              getEffectiveSCEVType(Ty))));
911
912  // zext(zext(x)) --> zext(x)
913  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
914    return getZeroExtendExpr(SZ->getOperand(), Ty);
915
916  // Before doing any expensive analysis, check to see if we've already
917  // computed a SCEV for this Op and Ty.
918  FoldingSetNodeID ID;
919  ID.AddInteger(scZeroExtend);
920  ID.AddPointer(Op);
921  ID.AddPointer(Ty);
922  void *IP = 0;
923  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
924
925  // zext(trunc(x)) --> zext(x) or x or trunc(x)
926  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
927    // It's possible the bits taken off by the truncate were all zero bits. If
928    // so, we should be able to simplify this further.
929    const SCEV *X = ST->getOperand();
930    ConstantRange CR = getUnsignedRange(X);
931    unsigned TruncBits = getTypeSizeInBits(ST->getType());
932    unsigned NewBits = getTypeSizeInBits(Ty);
933    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
934            CR.zextOrTrunc(NewBits)))
935      return getTruncateOrZeroExtend(X, Ty);
936  }
937
938  // If the input value is a chrec scev, and we can prove that the value
939  // did not overflow the old, smaller, value, we can zero extend all of the
940  // operands (often constants).  This allows analysis of something like
941  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
942  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
943    if (AR->isAffine()) {
944      const SCEV *Start = AR->getStart();
945      const SCEV *Step = AR->getStepRecurrence(*this);
946      unsigned BitWidth = getTypeSizeInBits(AR->getType());
947      const Loop *L = AR->getLoop();
948
949      // If we have special knowledge that this addrec won't overflow,
950      // we don't need to do any further analysis.
951      if (AR->getNoWrapFlags(SCEV::FlagNUW))
952        return getAddRecExpr(getZeroExtendExpr(Start, Ty),
953                             getZeroExtendExpr(Step, Ty),
954                             L, AR->getNoWrapFlags());
955
956      // Check whether the backedge-taken count is SCEVCouldNotCompute.
957      // Note that this serves two purposes: It filters out loops that are
958      // simply not analyzable, and it covers the case where this code is
959      // being called from within backedge-taken count analysis, such that
960      // attempting to ask for the backedge-taken count would likely result
961      // in infinite recursion. In the later case, the analysis code will
962      // cope with a conservative value, and it will take care to purge
963      // that value once it has finished.
964      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
965      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
966        // Manually compute the final value for AR, checking for
967        // overflow.
968
969        // Check whether the backedge-taken count can be losslessly casted to
970        // the addrec's type. The count is always unsigned.
971        const SCEV *CastedMaxBECount =
972          getTruncateOrZeroExtend(MaxBECount, Start->getType());
973        const SCEV *RecastedMaxBECount =
974          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
975        if (MaxBECount == RecastedMaxBECount) {
976          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
977          // Check whether Start+Step*MaxBECount has no unsigned overflow.
978          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
979          const SCEV *Add = getAddExpr(Start, ZMul);
980          const SCEV *OperandExtendedAdd =
981            getAddExpr(getZeroExtendExpr(Start, WideTy),
982                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
983                                  getZeroExtendExpr(Step, WideTy)));
984          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) {
985            // Cache knowledge of AR NUW, which is propagated to this AddRec.
986            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
987            // Return the expression with the addrec on the outside.
988            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
989                                 getZeroExtendExpr(Step, Ty),
990                                 L, AR->getNoWrapFlags());
991          }
992          // Similar to above, only this time treat the step value as signed.
993          // This covers loops that count down.
994          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
995          Add = getAddExpr(Start, SMul);
996          OperandExtendedAdd =
997            getAddExpr(getZeroExtendExpr(Start, WideTy),
998                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
999                                  getSignExtendExpr(Step, WideTy)));
1000          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1001            // Cache knowledge of AR NW, which is propagated to this AddRec.
1002            // Negative step causes unsigned wrap, but it still can't self-wrap.
1003            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1004            // Return the expression with the addrec on the outside.
1005            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1006                                 getSignExtendExpr(Step, Ty),
1007                                 L, AR->getNoWrapFlags());
1008          }
1009        }
1010
1011        // If the backedge is guarded by a comparison with the pre-inc value
1012        // the addrec is safe. Also, if the entry is guarded by a comparison
1013        // with the start value and the backedge is guarded by a comparison
1014        // with the post-inc value, the addrec is safe.
1015        if (isKnownPositive(Step)) {
1016          const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
1017                                      getUnsignedRange(Step).getUnsignedMax());
1018          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
1019              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
1020               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
1021                                           AR->getPostIncExpr(*this), N))) {
1022            // Cache knowledge of AR NUW, which is propagated to this AddRec.
1023            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1024            // Return the expression with the addrec on the outside.
1025            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1026                                 getZeroExtendExpr(Step, Ty),
1027                                 L, AR->getNoWrapFlags());
1028          }
1029        } else if (isKnownNegative(Step)) {
1030          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1031                                      getSignedRange(Step).getSignedMin());
1032          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1033              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
1034               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
1035                                           AR->getPostIncExpr(*this), N))) {
1036            // Cache knowledge of AR NW, which is propagated to this AddRec.
1037            // Negative step causes unsigned wrap, but it still can't self-wrap.
1038            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1039            // Return the expression with the addrec on the outside.
1040            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1041                                 getSignExtendExpr(Step, Ty),
1042                                 L, AR->getNoWrapFlags());
1043          }
1044        }
1045      }
1046    }
1047
1048  // The cast wasn't folded; create an explicit cast node.
1049  // Recompute the insert position, as it may have been invalidated.
1050  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1051  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1052                                                   Op, Ty);
1053  UniqueSCEVs.InsertNode(S, IP);
1054  return S;
1055}
1056
1057// Get the limit of a recurrence such that incrementing by Step cannot cause
1058// signed overflow as long as the value of the recurrence within the loop does
1059// not exceed this limit before incrementing.
1060static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1061                                           ICmpInst::Predicate *Pred,
1062                                           ScalarEvolution *SE) {
1063  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1064  if (SE->isKnownPositive(Step)) {
1065    *Pred = ICmpInst::ICMP_SLT;
1066    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1067                           SE->getSignedRange(Step).getSignedMax());
1068  }
1069  if (SE->isKnownNegative(Step)) {
1070    *Pred = ICmpInst::ICMP_SGT;
1071    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1072                       SE->getSignedRange(Step).getSignedMin());
1073  }
1074  return 0;
1075}
1076
1077// The recurrence AR has been shown to have no signed wrap. Typically, if we can
1078// prove NSW for AR, then we can just as easily prove NSW for its preincrement
1079// or postincrement sibling. This allows normalizing a sign extended AddRec as
1080// such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a
1081// result, the expression "Step + sext(PreIncAR)" is congruent with
1082// "sext(PostIncAR)"
1083static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR,
1084                                            Type *Ty,
1085                                            ScalarEvolution *SE) {
1086  const Loop *L = AR->getLoop();
1087  const SCEV *Start = AR->getStart();
1088  const SCEV *Step = AR->getStepRecurrence(*SE);
1089
1090  // Check for a simple looking step prior to loop entry.
1091  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1092  if (!SA)
1093    return 0;
1094
1095  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1096  // subtraction is expensive. For this purpose, perform a quick and dirty
1097  // difference, by checking for Step in the operand list.
1098  SmallVector<const SCEV *, 4> DiffOps;
1099  for (SCEVAddExpr::op_iterator I = SA->op_begin(), E = SA->op_end();
1100       I != E; ++I) {
1101    if (*I != Step)
1102      DiffOps.push_back(*I);
1103  }
1104  if (DiffOps.size() == SA->getNumOperands())
1105    return 0;
1106
1107  // This is a postinc AR. Check for overflow on the preinc recurrence using the
1108  // same three conditions that getSignExtendedExpr checks.
1109
1110  // 1. NSW flags on the step increment.
1111  const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags());
1112  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1113    SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1114
1115  if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW))
1116    return PreStart;
1117
1118  // 2. Direct overflow check on the step operation's expression.
1119  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1120  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1121  const SCEV *OperandExtendedStart =
1122    SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy),
1123                   SE->getSignExtendExpr(Step, WideTy));
1124  if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) {
1125    // Cache knowledge of PreAR NSW.
1126    if (PreAR)
1127      const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(SCEV::FlagNSW);
1128    // FIXME: this optimization needs a unit test
1129    DEBUG(dbgs() << "SCEV: untested prestart overflow check\n");
1130    return PreStart;
1131  }
1132
1133  // 3. Loop precondition.
1134  ICmpInst::Predicate Pred;
1135  const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE);
1136
1137  if (OverflowLimit &&
1138      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
1139    return PreStart;
1140  }
1141  return 0;
1142}
1143
1144// Get the normalized sign-extended expression for this AddRec's Start.
1145static const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR,
1146                                            Type *Ty,
1147                                            ScalarEvolution *SE) {
1148  const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE);
1149  if (!PreStart)
1150    return SE->getSignExtendExpr(AR->getStart(), Ty);
1151
1152  return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty),
1153                        SE->getSignExtendExpr(PreStart, Ty));
1154}
1155
1156const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
1157                                               Type *Ty) {
1158  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1159         "This is not an extending conversion!");
1160  assert(isSCEVable(Ty) &&
1161         "This is not a conversion to a SCEVable type!");
1162  Ty = getEffectiveSCEVType(Ty);
1163
1164  // Fold if the operand is constant.
1165  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1166    return getConstant(
1167      cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(),
1168                                              getEffectiveSCEVType(Ty))));
1169
1170  // sext(sext(x)) --> sext(x)
1171  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1172    return getSignExtendExpr(SS->getOperand(), Ty);
1173
1174  // sext(zext(x)) --> zext(x)
1175  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1176    return getZeroExtendExpr(SZ->getOperand(), Ty);
1177
1178  // Before doing any expensive analysis, check to see if we've already
1179  // computed a SCEV for this Op and Ty.
1180  FoldingSetNodeID ID;
1181  ID.AddInteger(scSignExtend);
1182  ID.AddPointer(Op);
1183  ID.AddPointer(Ty);
1184  void *IP = 0;
1185  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1186
1187  // If the input value is provably positive, build a zext instead.
1188  if (isKnownNonNegative(Op))
1189    return getZeroExtendExpr(Op, Ty);
1190
1191  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1192  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1193    // It's possible the bits taken off by the truncate were all sign bits. If
1194    // so, we should be able to simplify this further.
1195    const SCEV *X = ST->getOperand();
1196    ConstantRange CR = getSignedRange(X);
1197    unsigned TruncBits = getTypeSizeInBits(ST->getType());
1198    unsigned NewBits = getTypeSizeInBits(Ty);
1199    if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1200            CR.sextOrTrunc(NewBits)))
1201      return getTruncateOrSignExtend(X, Ty);
1202  }
1203
1204  // If the input value is a chrec scev, and we can prove that the value
1205  // did not overflow the old, smaller, value, we can sign extend all of the
1206  // operands (often constants).  This allows analysis of something like
1207  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
1208  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1209    if (AR->isAffine()) {
1210      const SCEV *Start = AR->getStart();
1211      const SCEV *Step = AR->getStepRecurrence(*this);
1212      unsigned BitWidth = getTypeSizeInBits(AR->getType());
1213      const Loop *L = AR->getLoop();
1214
1215      // If we have special knowledge that this addrec won't overflow,
1216      // we don't need to do any further analysis.
1217      if (AR->getNoWrapFlags(SCEV::FlagNSW))
1218        return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1219                             getSignExtendExpr(Step, Ty),
1220                             L, SCEV::FlagNSW);
1221
1222      // Check whether the backedge-taken count is SCEVCouldNotCompute.
1223      // Note that this serves two purposes: It filters out loops that are
1224      // simply not analyzable, and it covers the case where this code is
1225      // being called from within backedge-taken count analysis, such that
1226      // attempting to ask for the backedge-taken count would likely result
1227      // in infinite recursion. In the later case, the analysis code will
1228      // cope with a conservative value, and it will take care to purge
1229      // that value once it has finished.
1230      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1231      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1232        // Manually compute the final value for AR, checking for
1233        // overflow.
1234
1235        // Check whether the backedge-taken count can be losslessly casted to
1236        // the addrec's type. The count is always unsigned.
1237        const SCEV *CastedMaxBECount =
1238          getTruncateOrZeroExtend(MaxBECount, Start->getType());
1239        const SCEV *RecastedMaxBECount =
1240          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1241        if (MaxBECount == RecastedMaxBECount) {
1242          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1243          // Check whether Start+Step*MaxBECount has no signed overflow.
1244          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
1245          const SCEV *Add = getAddExpr(Start, SMul);
1246          const SCEV *OperandExtendedAdd =
1247            getAddExpr(getSignExtendExpr(Start, WideTy),
1248                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1249                                  getSignExtendExpr(Step, WideTy)));
1250          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1251            // Cache knowledge of AR NSW, which is propagated to this AddRec.
1252            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1253            // Return the expression with the addrec on the outside.
1254            return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1255                                 getSignExtendExpr(Step, Ty),
1256                                 L, AR->getNoWrapFlags());
1257          }
1258          // Similar to above, only this time treat the step value as unsigned.
1259          // This covers loops that count up with an unsigned step.
1260          const SCEV *UMul = getMulExpr(CastedMaxBECount, Step);
1261          Add = getAddExpr(Start, UMul);
1262          OperandExtendedAdd =
1263            getAddExpr(getSignExtendExpr(Start, WideTy),
1264                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1265                                  getZeroExtendExpr(Step, WideTy)));
1266          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1267            // Cache knowledge of AR NSW, which is propagated to this AddRec.
1268            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1269            // Return the expression with the addrec on the outside.
1270            return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1271                                 getZeroExtendExpr(Step, Ty),
1272                                 L, AR->getNoWrapFlags());
1273          }
1274        }
1275
1276        // If the backedge is guarded by a comparison with the pre-inc value
1277        // the addrec is safe. Also, if the entry is guarded by a comparison
1278        // with the start value and the backedge is guarded by a comparison
1279        // with the post-inc value, the addrec is safe.
1280        ICmpInst::Predicate Pred;
1281        const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this);
1282        if (OverflowLimit &&
1283            (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
1284             (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
1285              isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
1286                                          OverflowLimit)))) {
1287          // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
1288          const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1289          return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1290                               getSignExtendExpr(Step, Ty),
1291                               L, AR->getNoWrapFlags());
1292        }
1293      }
1294    }
1295
1296  // The cast wasn't folded; create an explicit cast node.
1297  // Recompute the insert position, as it may have been invalidated.
1298  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1299  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1300                                                   Op, Ty);
1301  UniqueSCEVs.InsertNode(S, IP);
1302  return S;
1303}
1304
1305/// getAnyExtendExpr - Return a SCEV for the given operand extended with
1306/// unspecified bits out to the given type.
1307///
1308const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
1309                                              Type *Ty) {
1310  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1311         "This is not an extending conversion!");
1312  assert(isSCEVable(Ty) &&
1313         "This is not a conversion to a SCEVable type!");
1314  Ty = getEffectiveSCEVType(Ty);
1315
1316  // Sign-extend negative constants.
1317  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1318    if (SC->getValue()->getValue().isNegative())
1319      return getSignExtendExpr(Op, Ty);
1320
1321  // Peel off a truncate cast.
1322  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
1323    const SCEV *NewOp = T->getOperand();
1324    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
1325      return getAnyExtendExpr(NewOp, Ty);
1326    return getTruncateOrNoop(NewOp, Ty);
1327  }
1328
1329  // Next try a zext cast. If the cast is folded, use it.
1330  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
1331  if (!isa<SCEVZeroExtendExpr>(ZExt))
1332    return ZExt;
1333
1334  // Next try a sext cast. If the cast is folded, use it.
1335  const SCEV *SExt = getSignExtendExpr(Op, Ty);
1336  if (!isa<SCEVSignExtendExpr>(SExt))
1337    return SExt;
1338
1339  // Force the cast to be folded into the operands of an addrec.
1340  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
1341    SmallVector<const SCEV *, 4> Ops;
1342    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
1343         I != E; ++I)
1344      Ops.push_back(getAnyExtendExpr(*I, Ty));
1345    return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
1346  }
1347
1348  // As a special case, fold anyext(undef) to undef. We don't want to
1349  // know too much about SCEVUnknowns, but this special case is handy
1350  // and harmless.
1351  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
1352    if (isa<UndefValue>(U->getValue()))
1353      return getSCEV(UndefValue::get(Ty));
1354
1355  // If the expression is obviously signed, use the sext cast value.
1356  if (isa<SCEVSMaxExpr>(Op))
1357    return SExt;
1358
1359  // Absent any other information, use the zext cast value.
1360  return ZExt;
1361}
1362
1363/// CollectAddOperandsWithScales - Process the given Ops list, which is
1364/// a list of operands to be added under the given scale, update the given
1365/// map. This is a helper function for getAddRecExpr. As an example of
1366/// what it does, given a sequence of operands that would form an add
1367/// expression like this:
1368///
1369///    m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r)
1370///
1371/// where A and B are constants, update the map with these values:
1372///
1373///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
1374///
1375/// and add 13 + A*B*29 to AccumulatedConstant.
1376/// This will allow getAddRecExpr to produce this:
1377///
1378///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
1379///
1380/// This form often exposes folding opportunities that are hidden in
1381/// the original operand list.
1382///
1383/// Return true iff it appears that any interesting folding opportunities
1384/// may be exposed. This helps getAddRecExpr short-circuit extra work in
1385/// the common case where no interesting opportunities are present, and
1386/// is also used as a check to avoid infinite recursion.
1387///
1388static bool
1389CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
1390                             SmallVector<const SCEV *, 8> &NewOps,
1391                             APInt &AccumulatedConstant,
1392                             const SCEV *const *Ops, size_t NumOperands,
1393                             const APInt &Scale,
1394                             ScalarEvolution &SE) {
1395  bool Interesting = false;
1396
1397  // Iterate over the add operands. They are sorted, with constants first.
1398  unsigned i = 0;
1399  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1400    ++i;
1401    // Pull a buried constant out to the outside.
1402    if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
1403      Interesting = true;
1404    AccumulatedConstant += Scale * C->getValue()->getValue();
1405  }
1406
1407  // Next comes everything else. We're especially interested in multiplies
1408  // here, but they're in the middle, so just visit the rest with one loop.
1409  for (; i != NumOperands; ++i) {
1410    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
1411    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
1412      APInt NewScale =
1413        Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
1414      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
1415        // A multiplication of a constant with another add; recurse.
1416        const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
1417        Interesting |=
1418          CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1419                                       Add->op_begin(), Add->getNumOperands(),
1420                                       NewScale, SE);
1421      } else {
1422        // A multiplication of a constant with some other value. Update
1423        // the map.
1424        SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
1425        const SCEV *Key = SE.getMulExpr(MulOps);
1426        std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1427          M.insert(std::make_pair(Key, NewScale));
1428        if (Pair.second) {
1429          NewOps.push_back(Pair.first->first);
1430        } else {
1431          Pair.first->second += NewScale;
1432          // The map already had an entry for this value, which may indicate
1433          // a folding opportunity.
1434          Interesting = true;
1435        }
1436      }
1437    } else {
1438      // An ordinary operand. Update the map.
1439      std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1440        M.insert(std::make_pair(Ops[i], Scale));
1441      if (Pair.second) {
1442        NewOps.push_back(Pair.first->first);
1443      } else {
1444        Pair.first->second += Scale;
1445        // The map already had an entry for this value, which may indicate
1446        // a folding opportunity.
1447        Interesting = true;
1448      }
1449    }
1450  }
1451
1452  return Interesting;
1453}
1454
1455namespace {
1456  struct APIntCompare {
1457    bool operator()(const APInt &LHS, const APInt &RHS) const {
1458      return LHS.ult(RHS);
1459    }
1460  };
1461}
1462
1463/// getAddExpr - Get a canonical add expression, or something simpler if
1464/// possible.
1465const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
1466                                        SCEV::NoWrapFlags Flags) {
1467  assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
1468         "only nuw or nsw allowed");
1469  assert(!Ops.empty() && "Cannot get empty add!");
1470  if (Ops.size() == 1) return Ops[0];
1471#ifndef NDEBUG
1472  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1473  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1474    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1475           "SCEVAddExpr operand types don't match!");
1476#endif
1477
1478  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
1479  // And vice-versa.
1480  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1481  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
1482  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
1483    bool All = true;
1484    for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
1485         E = Ops.end(); I != E; ++I)
1486      if (!isKnownNonNegative(*I)) {
1487        All = false;
1488        break;
1489      }
1490    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
1491  }
1492
1493  // Sort by complexity, this groups all similar expression types together.
1494  GroupByComplexity(Ops, LI);
1495
1496  // If there are any constants, fold them together.
1497  unsigned Idx = 0;
1498  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1499    ++Idx;
1500    assert(Idx < Ops.size());
1501    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1502      // We found two constants, fold them together!
1503      Ops[0] = getConstant(LHSC->getValue()->getValue() +
1504                           RHSC->getValue()->getValue());
1505      if (Ops.size() == 2) return Ops[0];
1506      Ops.erase(Ops.begin()+1);  // Erase the folded element
1507      LHSC = cast<SCEVConstant>(Ops[0]);
1508    }
1509
1510    // If we are left with a constant zero being added, strip it off.
1511    if (LHSC->getValue()->isZero()) {
1512      Ops.erase(Ops.begin());
1513      --Idx;
1514    }
1515
1516    if (Ops.size() == 1) return Ops[0];
1517  }
1518
1519  // Okay, check to see if the same value occurs in the operand list more than
1520  // once.  If so, merge them together into an multiply expression.  Since we
1521  // sorted the list, these values are required to be adjacent.
1522  Type *Ty = Ops[0]->getType();
1523  bool FoundMatch = false;
1524  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
1525    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
1526      // Scan ahead to count how many equal operands there are.
1527      unsigned Count = 2;
1528      while (i+Count != e && Ops[i+Count] == Ops[i])
1529        ++Count;
1530      // Merge the values into a multiply.
1531      const SCEV *Scale = getConstant(Ty, Count);
1532      const SCEV *Mul = getMulExpr(Scale, Ops[i]);
1533      if (Ops.size() == Count)
1534        return Mul;
1535      Ops[i] = Mul;
1536      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
1537      --i; e -= Count - 1;
1538      FoundMatch = true;
1539    }
1540  if (FoundMatch)
1541    return getAddExpr(Ops, Flags);
1542
1543  // Check for truncates. If all the operands are truncated from the same
1544  // type, see if factoring out the truncate would permit the result to be
1545  // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
1546  // if the contents of the resulting outer trunc fold to something simple.
1547  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
1548    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
1549    Type *DstType = Trunc->getType();
1550    Type *SrcType = Trunc->getOperand()->getType();
1551    SmallVector<const SCEV *, 8> LargeOps;
1552    bool Ok = true;
1553    // Check all the operands to see if they can be represented in the
1554    // source type of the truncate.
1555    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
1556      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
1557        if (T->getOperand()->getType() != SrcType) {
1558          Ok = false;
1559          break;
1560        }
1561        LargeOps.push_back(T->getOperand());
1562      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1563        LargeOps.push_back(getAnyExtendExpr(C, SrcType));
1564      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
1565        SmallVector<const SCEV *, 8> LargeMulOps;
1566        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
1567          if (const SCEVTruncateExpr *T =
1568                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
1569            if (T->getOperand()->getType() != SrcType) {
1570              Ok = false;
1571              break;
1572            }
1573            LargeMulOps.push_back(T->getOperand());
1574          } else if (const SCEVConstant *C =
1575                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
1576            LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
1577          } else {
1578            Ok = false;
1579            break;
1580          }
1581        }
1582        if (Ok)
1583          LargeOps.push_back(getMulExpr(LargeMulOps));
1584      } else {
1585        Ok = false;
1586        break;
1587      }
1588    }
1589    if (Ok) {
1590      // Evaluate the expression in the larger type.
1591      const SCEV *Fold = getAddExpr(LargeOps, Flags);
1592      // If it folds to something simple, use it. Otherwise, don't.
1593      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
1594        return getTruncateExpr(Fold, DstType);
1595    }
1596  }
1597
1598  // Skip past any other cast SCEVs.
1599  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
1600    ++Idx;
1601
1602  // If there are add operands they would be next.
1603  if (Idx < Ops.size()) {
1604    bool DeletedAdd = false;
1605    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
1606      // If we have an add, expand the add operands onto the end of the operands
1607      // list.
1608      Ops.erase(Ops.begin()+Idx);
1609      Ops.append(Add->op_begin(), Add->op_end());
1610      DeletedAdd = true;
1611    }
1612
1613    // If we deleted at least one add, we added operands to the end of the list,
1614    // and they are not necessarily sorted.  Recurse to resort and resimplify
1615    // any operands we just acquired.
1616    if (DeletedAdd)
1617      return getAddExpr(Ops);
1618  }
1619
1620  // Skip over the add expression until we get to a multiply.
1621  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1622    ++Idx;
1623
1624  // Check to see if there are any folding opportunities present with
1625  // operands multiplied by constant values.
1626  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
1627    uint64_t BitWidth = getTypeSizeInBits(Ty);
1628    DenseMap<const SCEV *, APInt> M;
1629    SmallVector<const SCEV *, 8> NewOps;
1630    APInt AccumulatedConstant(BitWidth, 0);
1631    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1632                                     Ops.data(), Ops.size(),
1633                                     APInt(BitWidth, 1), *this)) {
1634      // Some interesting folding opportunity is present, so its worthwhile to
1635      // re-generate the operands list. Group the operands by constant scale,
1636      // to avoid multiplying by the same constant scale multiple times.
1637      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
1638      for (SmallVector<const SCEV *, 8>::const_iterator I = NewOps.begin(),
1639           E = NewOps.end(); I != E; ++I)
1640        MulOpLists[M.find(*I)->second].push_back(*I);
1641      // Re-generate the operands list.
1642      Ops.clear();
1643      if (AccumulatedConstant != 0)
1644        Ops.push_back(getConstant(AccumulatedConstant));
1645      for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
1646           I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
1647        if (I->first != 0)
1648          Ops.push_back(getMulExpr(getConstant(I->first),
1649                                   getAddExpr(I->second)));
1650      if (Ops.empty())
1651        return getConstant(Ty, 0);
1652      if (Ops.size() == 1)
1653        return Ops[0];
1654      return getAddExpr(Ops);
1655    }
1656  }
1657
1658  // If we are adding something to a multiply expression, make sure the
1659  // something is not already an operand of the multiply.  If so, merge it into
1660  // the multiply.
1661  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1662    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1663    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1664      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1665      if (isa<SCEVConstant>(MulOpSCEV))
1666        continue;
1667      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1668        if (MulOpSCEV == Ops[AddOp]) {
1669          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
1670          const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
1671          if (Mul->getNumOperands() != 2) {
1672            // If the multiply has more than two operands, we must get the
1673            // Y*Z term.
1674            SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1675                                                Mul->op_begin()+MulOp);
1676            MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
1677            InnerMul = getMulExpr(MulOps);
1678          }
1679          const SCEV *One = getConstant(Ty, 1);
1680          const SCEV *AddOne = getAddExpr(One, InnerMul);
1681          const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
1682          if (Ops.size() == 2) return OuterMul;
1683          if (AddOp < Idx) {
1684            Ops.erase(Ops.begin()+AddOp);
1685            Ops.erase(Ops.begin()+Idx-1);
1686          } else {
1687            Ops.erase(Ops.begin()+Idx);
1688            Ops.erase(Ops.begin()+AddOp-1);
1689          }
1690          Ops.push_back(OuterMul);
1691          return getAddExpr(Ops);
1692        }
1693
1694      // Check this multiply against other multiplies being added together.
1695      for (unsigned OtherMulIdx = Idx+1;
1696           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1697           ++OtherMulIdx) {
1698        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1699        // If MulOp occurs in OtherMul, we can fold the two multiplies
1700        // together.
1701        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1702             OMulOp != e; ++OMulOp)
1703          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1704            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1705            const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
1706            if (Mul->getNumOperands() != 2) {
1707              SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1708                                                  Mul->op_begin()+MulOp);
1709              MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
1710              InnerMul1 = getMulExpr(MulOps);
1711            }
1712            const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1713            if (OtherMul->getNumOperands() != 2) {
1714              SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
1715                                                  OtherMul->op_begin()+OMulOp);
1716              MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
1717              InnerMul2 = getMulExpr(MulOps);
1718            }
1719            const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
1720            const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1721            if (Ops.size() == 2) return OuterMul;
1722            Ops.erase(Ops.begin()+Idx);
1723            Ops.erase(Ops.begin()+OtherMulIdx-1);
1724            Ops.push_back(OuterMul);
1725            return getAddExpr(Ops);
1726          }
1727      }
1728    }
1729  }
1730
1731  // If there are any add recurrences in the operands list, see if any other
1732  // added values are loop invariant.  If so, we can fold them into the
1733  // recurrence.
1734  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1735    ++Idx;
1736
1737  // Scan over all recurrences, trying to fold loop invariants into them.
1738  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1739    // Scan all of the other operands to this add and add them to the vector if
1740    // they are loop invariant w.r.t. the recurrence.
1741    SmallVector<const SCEV *, 8> LIOps;
1742    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1743    const Loop *AddRecLoop = AddRec->getLoop();
1744    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1745      if (isLoopInvariant(Ops[i], AddRecLoop)) {
1746        LIOps.push_back(Ops[i]);
1747        Ops.erase(Ops.begin()+i);
1748        --i; --e;
1749      }
1750
1751    // If we found some loop invariants, fold them into the recurrence.
1752    if (!LIOps.empty()) {
1753      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
1754      LIOps.push_back(AddRec->getStart());
1755
1756      SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1757                                             AddRec->op_end());
1758      AddRecOps[0] = getAddExpr(LIOps);
1759
1760      // Build the new addrec. Propagate the NUW and NSW flags if both the
1761      // outer add and the inner addrec are guaranteed to have no overflow.
1762      // Always propagate NW.
1763      Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
1764      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
1765
1766      // If all of the other operands were loop invariant, we are done.
1767      if (Ops.size() == 1) return NewRec;
1768
1769      // Otherwise, add the folded AddRec by the non-invariant parts.
1770      for (unsigned i = 0;; ++i)
1771        if (Ops[i] == AddRec) {
1772          Ops[i] = NewRec;
1773          break;
1774        }
1775      return getAddExpr(Ops);
1776    }
1777
1778    // Okay, if there weren't any loop invariants to be folded, check to see if
1779    // there are multiple AddRec's with the same loop induction variable being
1780    // added together.  If so, we can fold them.
1781    for (unsigned OtherIdx = Idx+1;
1782         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1783         ++OtherIdx)
1784      if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
1785        // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
1786        SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1787                                               AddRec->op_end());
1788        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1789             ++OtherIdx)
1790          if (const SCEVAddRecExpr *OtherAddRec =
1791                dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
1792            if (OtherAddRec->getLoop() == AddRecLoop) {
1793              for (unsigned i = 0, e = OtherAddRec->getNumOperands();
1794                   i != e; ++i) {
1795                if (i >= AddRecOps.size()) {
1796                  AddRecOps.append(OtherAddRec->op_begin()+i,
1797                                   OtherAddRec->op_end());
1798                  break;
1799                }
1800                AddRecOps[i] = getAddExpr(AddRecOps[i],
1801                                          OtherAddRec->getOperand(i));
1802              }
1803              Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
1804            }
1805        // Step size has changed, so we cannot guarantee no self-wraparound.
1806        Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
1807        return getAddExpr(Ops);
1808      }
1809
1810    // Otherwise couldn't fold anything into this recurrence.  Move onto the
1811    // next one.
1812  }
1813
1814  // Okay, it looks like we really DO need an add expr.  Check to see if we
1815  // already have one, otherwise create a new one.
1816  FoldingSetNodeID ID;
1817  ID.AddInteger(scAddExpr);
1818  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1819    ID.AddPointer(Ops[i]);
1820  void *IP = 0;
1821  SCEVAddExpr *S =
1822    static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1823  if (!S) {
1824    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
1825    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
1826    S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
1827                                        O, Ops.size());
1828    UniqueSCEVs.InsertNode(S, IP);
1829  }
1830  S->setNoWrapFlags(Flags);
1831  return S;
1832}
1833
1834static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
1835  uint64_t k = i*j;
1836  if (j > 1 && k / j != i) Overflow = true;
1837  return k;
1838}
1839
1840/// Compute the result of "n choose k", the binomial coefficient.  If an
1841/// intermediate computation overflows, Overflow will be set and the return will
1842/// be garbage. Overflow is not cleared on absense of overflow.
1843static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
1844  // We use the multiplicative formula:
1845  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
1846  // At each iteration, we take the n-th term of the numeral and divide by the
1847  // (k-n)th term of the denominator.  This division will always produce an
1848  // integral result, and helps reduce the chance of overflow in the
1849  // intermediate computations. However, we can still overflow even when the
1850  // final result would fit.
1851
1852  if (n == 0 || n == k) return 1;
1853  if (k > n) return 0;
1854
1855  if (k > n/2)
1856    k = n-k;
1857
1858  uint64_t r = 1;
1859  for (uint64_t i = 1; i <= k; ++i) {
1860    r = umul_ov(r, n-(i-1), Overflow);
1861    r /= i;
1862  }
1863  return r;
1864}
1865
1866/// getMulExpr - Get a canonical multiply expression, or something simpler if
1867/// possible.
1868const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
1869                                        SCEV::NoWrapFlags Flags) {
1870  assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
1871         "only nuw or nsw allowed");
1872  assert(!Ops.empty() && "Cannot get empty mul!");
1873  if (Ops.size() == 1) return Ops[0];
1874#ifndef NDEBUG
1875  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1876  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1877    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1878           "SCEVMulExpr operand types don't match!");
1879#endif
1880
1881  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
1882  // And vice-versa.
1883  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1884  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
1885  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
1886    bool All = true;
1887    for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
1888         E = Ops.end(); I != E; ++I)
1889      if (!isKnownNonNegative(*I)) {
1890        All = false;
1891        break;
1892      }
1893    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
1894  }
1895
1896  // Sort by complexity, this groups all similar expression types together.
1897  GroupByComplexity(Ops, LI);
1898
1899  // If there are any constants, fold them together.
1900  unsigned Idx = 0;
1901  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1902
1903    // C1*(C2+V) -> C1*C2 + C1*V
1904    if (Ops.size() == 2)
1905      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1906        if (Add->getNumOperands() == 2 &&
1907            isa<SCEVConstant>(Add->getOperand(0)))
1908          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1909                            getMulExpr(LHSC, Add->getOperand(1)));
1910
1911    ++Idx;
1912    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1913      // We found two constants, fold them together!
1914      ConstantInt *Fold = ConstantInt::get(getContext(),
1915                                           LHSC->getValue()->getValue() *
1916                                           RHSC->getValue()->getValue());
1917      Ops[0] = getConstant(Fold);
1918      Ops.erase(Ops.begin()+1);  // Erase the folded element
1919      if (Ops.size() == 1) return Ops[0];
1920      LHSC = cast<SCEVConstant>(Ops[0]);
1921    }
1922
1923    // If we are left with a constant one being multiplied, strip it off.
1924    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1925      Ops.erase(Ops.begin());
1926      --Idx;
1927    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1928      // If we have a multiply of zero, it will always be zero.
1929      return Ops[0];
1930    } else if (Ops[0]->isAllOnesValue()) {
1931      // If we have a mul by -1 of an add, try distributing the -1 among the
1932      // add operands.
1933      if (Ops.size() == 2) {
1934        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
1935          SmallVector<const SCEV *, 4> NewOps;
1936          bool AnyFolded = false;
1937          for (SCEVAddRecExpr::op_iterator I = Add->op_begin(),
1938                 E = Add->op_end(); I != E; ++I) {
1939            const SCEV *Mul = getMulExpr(Ops[0], *I);
1940            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
1941            NewOps.push_back(Mul);
1942          }
1943          if (AnyFolded)
1944            return getAddExpr(NewOps);
1945        }
1946        else if (const SCEVAddRecExpr *
1947                 AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
1948          // Negation preserves a recurrence's no self-wrap property.
1949          SmallVector<const SCEV *, 4> Operands;
1950          for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
1951                 E = AddRec->op_end(); I != E; ++I) {
1952            Operands.push_back(getMulExpr(Ops[0], *I));
1953          }
1954          return getAddRecExpr(Operands, AddRec->getLoop(),
1955                               AddRec->getNoWrapFlags(SCEV::FlagNW));
1956        }
1957      }
1958    }
1959
1960    if (Ops.size() == 1)
1961      return Ops[0];
1962  }
1963
1964  // Skip over the add expression until we get to a multiply.
1965  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1966    ++Idx;
1967
1968  // If there are mul operands inline them all into this expression.
1969  if (Idx < Ops.size()) {
1970    bool DeletedMul = false;
1971    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1972      // If we have an mul, expand the mul operands onto the end of the operands
1973      // list.
1974      Ops.erase(Ops.begin()+Idx);
1975      Ops.append(Mul->op_begin(), Mul->op_end());
1976      DeletedMul = true;
1977    }
1978
1979    // If we deleted at least one mul, we added operands to the end of the list,
1980    // and they are not necessarily sorted.  Recurse to resort and resimplify
1981    // any operands we just acquired.
1982    if (DeletedMul)
1983      return getMulExpr(Ops);
1984  }
1985
1986  // If there are any add recurrences in the operands list, see if any other
1987  // added values are loop invariant.  If so, we can fold them into the
1988  // recurrence.
1989  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1990    ++Idx;
1991
1992  // Scan over all recurrences, trying to fold loop invariants into them.
1993  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1994    // Scan all of the other operands to this mul and add them to the vector if
1995    // they are loop invariant w.r.t. the recurrence.
1996    SmallVector<const SCEV *, 8> LIOps;
1997    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1998    const Loop *AddRecLoop = AddRec->getLoop();
1999    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2000      if (isLoopInvariant(Ops[i], AddRecLoop)) {
2001        LIOps.push_back(Ops[i]);
2002        Ops.erase(Ops.begin()+i);
2003        --i; --e;
2004      }
2005
2006    // If we found some loop invariants, fold them into the recurrence.
2007    if (!LIOps.empty()) {
2008      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
2009      SmallVector<const SCEV *, 4> NewOps;
2010      NewOps.reserve(AddRec->getNumOperands());
2011      const SCEV *Scale = getMulExpr(LIOps);
2012      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
2013        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
2014
2015      // Build the new addrec. Propagate the NUW and NSW flags if both the
2016      // outer mul and the inner addrec are guaranteed to have no overflow.
2017      //
2018      // No self-wrap cannot be guaranteed after changing the step size, but
2019      // will be inferred if either NUW or NSW is true.
2020      Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
2021      const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
2022
2023      // If all of the other operands were loop invariant, we are done.
2024      if (Ops.size() == 1) return NewRec;
2025
2026      // Otherwise, multiply the folded AddRec by the non-invariant parts.
2027      for (unsigned i = 0;; ++i)
2028        if (Ops[i] == AddRec) {
2029          Ops[i] = NewRec;
2030          break;
2031        }
2032      return getMulExpr(Ops);
2033    }
2034
2035    // Okay, if there weren't any loop invariants to be folded, check to see if
2036    // there are multiple AddRec's with the same loop induction variable being
2037    // multiplied together.  If so, we can fold them.
2038    for (unsigned OtherIdx = Idx+1;
2039         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2040         ++OtherIdx) {
2041      if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2042        // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
2043        // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
2044        //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
2045        //   ]]],+,...up to x=2n}.
2046        // Note that the arguments to choose() are always integers with values
2047        // known at compile time, never SCEV objects.
2048        //
2049        // The implementation avoids pointless extra computations when the two
2050        // addrec's are of different length (mathematically, it's equivalent to
2051        // an infinite stream of zeros on the right).
2052        bool OpsModified = false;
2053        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2054             ++OtherIdx)
2055          if (const SCEVAddRecExpr *OtherAddRec =
2056                dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
2057            if (OtherAddRec->getLoop() == AddRecLoop) {
2058              bool Overflow = false;
2059              Type *Ty = AddRec->getType();
2060              bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
2061              SmallVector<const SCEV*, 7> AddRecOps;
2062              for (int x = 0, xe = AddRec->getNumOperands() +
2063                     OtherAddRec->getNumOperands() - 1;
2064                   x != xe && !Overflow; ++x) {
2065                const SCEV *Term = getConstant(Ty, 0);
2066                for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
2067                  uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
2068                  for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
2069                         ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
2070                       z < ze && !Overflow; ++z) {
2071                    uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
2072                    uint64_t Coeff;
2073                    if (LargerThan64Bits)
2074                      Coeff = umul_ov(Coeff1, Coeff2, Overflow);
2075                    else
2076                      Coeff = Coeff1*Coeff2;
2077                    const SCEV *CoeffTerm = getConstant(Ty, Coeff);
2078                    const SCEV *Term1 = AddRec->getOperand(y-z);
2079                    const SCEV *Term2 = OtherAddRec->getOperand(z);
2080                    Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2));
2081                  }
2082                }
2083                AddRecOps.push_back(Term);
2084              }
2085              if (!Overflow) {
2086                const SCEV *NewAddRec = getAddRecExpr(AddRecOps,
2087                                                      AddRec->getLoop(),
2088                                                      SCEV::FlagAnyWrap);
2089                if (Ops.size() == 2) return NewAddRec;
2090                Ops[Idx] = AddRec = cast<SCEVAddRecExpr>(NewAddRec);
2091                Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2092                OpsModified = true;
2093              }
2094            }
2095        if (OpsModified)
2096          return getMulExpr(Ops);
2097      }
2098    }
2099
2100    // Otherwise couldn't fold anything into this recurrence.  Move onto the
2101    // next one.
2102  }
2103
2104  // Okay, it looks like we really DO need an mul expr.  Check to see if we
2105  // already have one, otherwise create a new one.
2106  FoldingSetNodeID ID;
2107  ID.AddInteger(scMulExpr);
2108  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2109    ID.AddPointer(Ops[i]);
2110  void *IP = 0;
2111  SCEVMulExpr *S =
2112    static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2113  if (!S) {
2114    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2115    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2116    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2117                                        O, Ops.size());
2118    UniqueSCEVs.InsertNode(S, IP);
2119  }
2120  S->setNoWrapFlags(Flags);
2121  return S;
2122}
2123
2124/// getUDivExpr - Get a canonical unsigned division expression, or something
2125/// simpler if possible.
2126const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
2127                                         const SCEV *RHS) {
2128  assert(getEffectiveSCEVType(LHS->getType()) ==
2129         getEffectiveSCEVType(RHS->getType()) &&
2130         "SCEVUDivExpr operand types don't match!");
2131
2132  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
2133    if (RHSC->getValue()->equalsInt(1))
2134      return LHS;                               // X udiv 1 --> x
2135    // If the denominator is zero, the result of the udiv is undefined. Don't
2136    // try to analyze it, because the resolution chosen here may differ from
2137    // the resolution chosen in other parts of the compiler.
2138    if (!RHSC->getValue()->isZero()) {
2139      // Determine if the division can be folded into the operands of
2140      // its operands.
2141      // TODO: Generalize this to non-constants by using known-bits information.
2142      Type *Ty = LHS->getType();
2143      unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
2144      unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
2145      // For non-power-of-two values, effectively round the value up to the
2146      // nearest power of two.
2147      if (!RHSC->getValue()->getValue().isPowerOf2())
2148        ++MaxShiftAmt;
2149      IntegerType *ExtTy =
2150        IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
2151      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
2152        if (const SCEVConstant *Step =
2153            dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
2154          // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
2155          const APInt &StepInt = Step->getValue()->getValue();
2156          const APInt &DivInt = RHSC->getValue()->getValue();
2157          if (!StepInt.urem(DivInt) &&
2158              getZeroExtendExpr(AR, ExtTy) ==
2159              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2160                            getZeroExtendExpr(Step, ExtTy),
2161                            AR->getLoop(), SCEV::FlagAnyWrap)) {
2162            SmallVector<const SCEV *, 4> Operands;
2163            for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
2164              Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
2165            return getAddRecExpr(Operands, AR->getLoop(),
2166                                 SCEV::FlagNW);
2167          }
2168          /// Get a canonical UDivExpr for a recurrence.
2169          /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
2170          // We can currently only fold X%N if X is constant.
2171          const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
2172          if (StartC && !DivInt.urem(StepInt) &&
2173              getZeroExtendExpr(AR, ExtTy) ==
2174              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2175                            getZeroExtendExpr(Step, ExtTy),
2176                            AR->getLoop(), SCEV::FlagAnyWrap)) {
2177            const APInt &StartInt = StartC->getValue()->getValue();
2178            const APInt &StartRem = StartInt.urem(StepInt);
2179            if (StartRem != 0)
2180              LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
2181                                  AR->getLoop(), SCEV::FlagNW);
2182          }
2183        }
2184      // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
2185      if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
2186        SmallVector<const SCEV *, 4> Operands;
2187        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
2188          Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
2189        if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
2190          // Find an operand that's safely divisible.
2191          for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
2192            const SCEV *Op = M->getOperand(i);
2193            const SCEV *Div = getUDivExpr(Op, RHSC);
2194            if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
2195              Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
2196                                                      M->op_end());
2197              Operands[i] = Div;
2198              return getMulExpr(Operands);
2199            }
2200          }
2201      }
2202      // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
2203      if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
2204        SmallVector<const SCEV *, 4> Operands;
2205        for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
2206          Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
2207        if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
2208          Operands.clear();
2209          for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
2210            const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
2211            if (isa<SCEVUDivExpr>(Op) ||
2212                getMulExpr(Op, RHS) != A->getOperand(i))
2213              break;
2214            Operands.push_back(Op);
2215          }
2216          if (Operands.size() == A->getNumOperands())
2217            return getAddExpr(Operands);
2218        }
2219      }
2220
2221      // Fold if both operands are constant.
2222      if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
2223        Constant *LHSCV = LHSC->getValue();
2224        Constant *RHSCV = RHSC->getValue();
2225        return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
2226                                                                   RHSCV)));
2227      }
2228    }
2229  }
2230
2231  FoldingSetNodeID ID;
2232  ID.AddInteger(scUDivExpr);
2233  ID.AddPointer(LHS);
2234  ID.AddPointer(RHS);
2235  void *IP = 0;
2236  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2237  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
2238                                             LHS, RHS);
2239  UniqueSCEVs.InsertNode(S, IP);
2240  return S;
2241}
2242
2243
2244/// getAddRecExpr - Get an add recurrence expression for the specified loop.
2245/// Simplify the expression as much as possible.
2246const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
2247                                           const Loop *L,
2248                                           SCEV::NoWrapFlags Flags) {
2249  SmallVector<const SCEV *, 4> Operands;
2250  Operands.push_back(Start);
2251  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
2252    if (StepChrec->getLoop() == L) {
2253      Operands.append(StepChrec->op_begin(), StepChrec->op_end());
2254      return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
2255    }
2256
2257  Operands.push_back(Step);
2258  return getAddRecExpr(Operands, L, Flags);
2259}
2260
2261/// getAddRecExpr - Get an add recurrence expression for the specified loop.
2262/// Simplify the expression as much as possible.
2263const SCEV *
2264ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
2265                               const Loop *L, SCEV::NoWrapFlags Flags) {
2266  if (Operands.size() == 1) return Operands[0];
2267#ifndef NDEBUG
2268  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
2269  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
2270    assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
2271           "SCEVAddRecExpr operand types don't match!");
2272  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2273    assert(isLoopInvariant(Operands[i], L) &&
2274           "SCEVAddRecExpr operand is not loop-invariant!");
2275#endif
2276
2277  if (Operands.back()->isZero()) {
2278    Operands.pop_back();
2279    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
2280  }
2281
2282  // It's tempting to want to call getMaxBackedgeTakenCount count here and
2283  // use that information to infer NUW and NSW flags. However, computing a
2284  // BE count requires calling getAddRecExpr, so we may not yet have a
2285  // meaningful BE count at this point (and if we don't, we'd be stuck
2286  // with a SCEVCouldNotCompute as the cached BE count).
2287
2288  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2289  // And vice-versa.
2290  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2291  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
2292  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
2293    bool All = true;
2294    for (SmallVectorImpl<const SCEV *>::const_iterator I = Operands.begin(),
2295         E = Operands.end(); I != E; ++I)
2296      if (!isKnownNonNegative(*I)) {
2297        All = false;
2298        break;
2299      }
2300    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2301  }
2302
2303  // Canonicalize nested AddRecs in by nesting them in order of loop depth.
2304  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
2305    const Loop *NestedLoop = NestedAR->getLoop();
2306    if (L->contains(NestedLoop) ?
2307        (L->getLoopDepth() < NestedLoop->getLoopDepth()) :
2308        (!NestedLoop->contains(L) &&
2309         DT->dominates(L->getHeader(), NestedLoop->getHeader()))) {
2310      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
2311                                                  NestedAR->op_end());
2312      Operands[0] = NestedAR->getStart();
2313      // AddRecs require their operands be loop-invariant with respect to their
2314      // loops. Don't perform this transformation if it would break this
2315      // requirement.
2316      bool AllInvariant = true;
2317      for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2318        if (!isLoopInvariant(Operands[i], L)) {
2319          AllInvariant = false;
2320          break;
2321        }
2322      if (AllInvariant) {
2323        // Create a recurrence for the outer loop with the same step size.
2324        //
2325        // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
2326        // inner recurrence has the same property.
2327        SCEV::NoWrapFlags OuterFlags =
2328          maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
2329
2330        NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
2331        AllInvariant = true;
2332        for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
2333          if (!isLoopInvariant(NestedOperands[i], NestedLoop)) {
2334            AllInvariant = false;
2335            break;
2336          }
2337        if (AllInvariant) {
2338          // Ok, both add recurrences are valid after the transformation.
2339          //
2340          // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
2341          // the outer recurrence has the same property.
2342          SCEV::NoWrapFlags InnerFlags =
2343            maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
2344          return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
2345        }
2346      }
2347      // Reset Operands to its original state.
2348      Operands[0] = NestedAR;
2349    }
2350  }
2351
2352  // Okay, it looks like we really DO need an addrec expr.  Check to see if we
2353  // already have one, otherwise create a new one.
2354  FoldingSetNodeID ID;
2355  ID.AddInteger(scAddRecExpr);
2356  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2357    ID.AddPointer(Operands[i]);
2358  ID.AddPointer(L);
2359  void *IP = 0;
2360  SCEVAddRecExpr *S =
2361    static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2362  if (!S) {
2363    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
2364    std::uninitialized_copy(Operands.begin(), Operands.end(), O);
2365    S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
2366                                           O, Operands.size(), L);
2367    UniqueSCEVs.InsertNode(S, IP);
2368  }
2369  S->setNoWrapFlags(Flags);
2370  return S;
2371}
2372
2373const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
2374                                         const SCEV *RHS) {
2375  SmallVector<const SCEV *, 2> Ops;
2376  Ops.push_back(LHS);
2377  Ops.push_back(RHS);
2378  return getSMaxExpr(Ops);
2379}
2380
2381const SCEV *
2382ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2383  assert(!Ops.empty() && "Cannot get empty smax!");
2384  if (Ops.size() == 1) return Ops[0];
2385#ifndef NDEBUG
2386  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2387  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2388    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2389           "SCEVSMaxExpr operand types don't match!");
2390#endif
2391
2392  // Sort by complexity, this groups all similar expression types together.
2393  GroupByComplexity(Ops, LI);
2394
2395  // If there are any constants, fold them together.
2396  unsigned Idx = 0;
2397  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2398    ++Idx;
2399    assert(Idx < Ops.size());
2400    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2401      // We found two constants, fold them together!
2402      ConstantInt *Fold = ConstantInt::get(getContext(),
2403                              APIntOps::smax(LHSC->getValue()->getValue(),
2404                                             RHSC->getValue()->getValue()));
2405      Ops[0] = getConstant(Fold);
2406      Ops.erase(Ops.begin()+1);  // Erase the folded element
2407      if (Ops.size() == 1) return Ops[0];
2408      LHSC = cast<SCEVConstant>(Ops[0]);
2409    }
2410
2411    // If we are left with a constant minimum-int, strip it off.
2412    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
2413      Ops.erase(Ops.begin());
2414      --Idx;
2415    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
2416      // If we have an smax with a constant maximum-int, it will always be
2417      // maximum-int.
2418      return Ops[0];
2419    }
2420
2421    if (Ops.size() == 1) return Ops[0];
2422  }
2423
2424  // Find the first SMax
2425  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
2426    ++Idx;
2427
2428  // Check to see if one of the operands is an SMax. If so, expand its operands
2429  // onto our operand list, and recurse to simplify.
2430  if (Idx < Ops.size()) {
2431    bool DeletedSMax = false;
2432    while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
2433      Ops.erase(Ops.begin()+Idx);
2434      Ops.append(SMax->op_begin(), SMax->op_end());
2435      DeletedSMax = true;
2436    }
2437
2438    if (DeletedSMax)
2439      return getSMaxExpr(Ops);
2440  }
2441
2442  // Okay, check to see if the same value occurs in the operand list twice.  If
2443  // so, delete one.  Since we sorted the list, these values are required to
2444  // be adjacent.
2445  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2446    //  X smax Y smax Y  -->  X smax Y
2447    //  X smax Y         -->  X, if X is always greater than Y
2448    if (Ops[i] == Ops[i+1] ||
2449        isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) {
2450      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2451      --i; --e;
2452    } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) {
2453      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2454      --i; --e;
2455    }
2456
2457  if (Ops.size() == 1) return Ops[0];
2458
2459  assert(!Ops.empty() && "Reduced smax down to nothing!");
2460
2461  // Okay, it looks like we really DO need an smax expr.  Check to see if we
2462  // already have one, otherwise create a new one.
2463  FoldingSetNodeID ID;
2464  ID.AddInteger(scSMaxExpr);
2465  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2466    ID.AddPointer(Ops[i]);
2467  void *IP = 0;
2468  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2469  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2470  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2471  SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
2472                                             O, Ops.size());
2473  UniqueSCEVs.InsertNode(S, IP);
2474  return S;
2475}
2476
2477const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
2478                                         const SCEV *RHS) {
2479  SmallVector<const SCEV *, 2> Ops;
2480  Ops.push_back(LHS);
2481  Ops.push_back(RHS);
2482  return getUMaxExpr(Ops);
2483}
2484
2485const SCEV *
2486ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2487  assert(!Ops.empty() && "Cannot get empty umax!");
2488  if (Ops.size() == 1) return Ops[0];
2489#ifndef NDEBUG
2490  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2491  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2492    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2493           "SCEVUMaxExpr operand types don't match!");
2494#endif
2495
2496  // Sort by complexity, this groups all similar expression types together.
2497  GroupByComplexity(Ops, LI);
2498
2499  // If there are any constants, fold them together.
2500  unsigned Idx = 0;
2501  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2502    ++Idx;
2503    assert(Idx < Ops.size());
2504    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2505      // We found two constants, fold them together!
2506      ConstantInt *Fold = ConstantInt::get(getContext(),
2507                              APIntOps::umax(LHSC->getValue()->getValue(),
2508                                             RHSC->getValue()->getValue()));
2509      Ops[0] = getConstant(Fold);
2510      Ops.erase(Ops.begin()+1);  // Erase the folded element
2511      if (Ops.size() == 1) return Ops[0];
2512      LHSC = cast<SCEVConstant>(Ops[0]);
2513    }
2514
2515    // If we are left with a constant minimum-int, strip it off.
2516    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
2517      Ops.erase(Ops.begin());
2518      --Idx;
2519    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
2520      // If we have an umax with a constant maximum-int, it will always be
2521      // maximum-int.
2522      return Ops[0];
2523    }
2524
2525    if (Ops.size() == 1) return Ops[0];
2526  }
2527
2528  // Find the first UMax
2529  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
2530    ++Idx;
2531
2532  // Check to see if one of the operands is a UMax. If so, expand its operands
2533  // onto our operand list, and recurse to simplify.
2534  if (Idx < Ops.size()) {
2535    bool DeletedUMax = false;
2536    while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
2537      Ops.erase(Ops.begin()+Idx);
2538      Ops.append(UMax->op_begin(), UMax->op_end());
2539      DeletedUMax = true;
2540    }
2541
2542    if (DeletedUMax)
2543      return getUMaxExpr(Ops);
2544  }
2545
2546  // Okay, check to see if the same value occurs in the operand list twice.  If
2547  // so, delete one.  Since we sorted the list, these values are required to
2548  // be adjacent.
2549  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2550    //  X umax Y umax Y  -->  X umax Y
2551    //  X umax Y         -->  X, if X is always greater than Y
2552    if (Ops[i] == Ops[i+1] ||
2553        isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
2554      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2555      --i; --e;
2556    } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
2557      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2558      --i; --e;
2559    }
2560
2561  if (Ops.size() == 1) return Ops[0];
2562
2563  assert(!Ops.empty() && "Reduced umax down to nothing!");
2564
2565  // Okay, it looks like we really DO need a umax expr.  Check to see if we
2566  // already have one, otherwise create a new one.
2567  FoldingSetNodeID ID;
2568  ID.AddInteger(scUMaxExpr);
2569  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2570    ID.AddPointer(Ops[i]);
2571  void *IP = 0;
2572  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2573  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2574  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2575  SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
2576                                             O, Ops.size());
2577  UniqueSCEVs.InsertNode(S, IP);
2578  return S;
2579}
2580
2581const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
2582                                         const SCEV *RHS) {
2583  // ~smax(~x, ~y) == smin(x, y).
2584  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2585}
2586
2587const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
2588                                         const SCEV *RHS) {
2589  // ~umax(~x, ~y) == umin(x, y)
2590  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2591}
2592
2593const SCEV *ScalarEvolution::getSizeOfExpr(Type *AllocTy) {
2594  // If we have TargetData, we can bypass creating a target-independent
2595  // constant expression and then folding it back into a ConstantInt.
2596  // This is just a compile-time optimization.
2597  if (TD)
2598    return getConstant(TD->getIntPtrType(getContext()),
2599                       TD->getTypeAllocSize(AllocTy));
2600
2601  Constant *C = ConstantExpr::getSizeOf(AllocTy);
2602  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2603    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2604      C = Folded;
2605  Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2606  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2607}
2608
2609const SCEV *ScalarEvolution::getAlignOfExpr(Type *AllocTy) {
2610  Constant *C = ConstantExpr::getAlignOf(AllocTy);
2611  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2612    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2613      C = Folded;
2614  Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2615  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2616}
2617
2618const SCEV *ScalarEvolution::getOffsetOfExpr(StructType *STy,
2619                                             unsigned FieldNo) {
2620  // If we have TargetData, we can bypass creating a target-independent
2621  // constant expression and then folding it back into a ConstantInt.
2622  // This is just a compile-time optimization.
2623  if (TD)
2624    return getConstant(TD->getIntPtrType(getContext()),
2625                       TD->getStructLayout(STy)->getElementOffset(FieldNo));
2626
2627  Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo);
2628  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2629    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2630      C = Folded;
2631  Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
2632  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2633}
2634
2635const SCEV *ScalarEvolution::getOffsetOfExpr(Type *CTy,
2636                                             Constant *FieldNo) {
2637  Constant *C = ConstantExpr::getOffsetOf(CTy, FieldNo);
2638  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2639    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2640      C = Folded;
2641  Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy));
2642  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2643}
2644
2645const SCEV *ScalarEvolution::getUnknown(Value *V) {
2646  // Don't attempt to do anything other than create a SCEVUnknown object
2647  // here.  createSCEV only calls getUnknown after checking for all other
2648  // interesting possibilities, and any other code that calls getUnknown
2649  // is doing so in order to hide a value from SCEV canonicalization.
2650
2651  FoldingSetNodeID ID;
2652  ID.AddInteger(scUnknown);
2653  ID.AddPointer(V);
2654  void *IP = 0;
2655  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
2656    assert(cast<SCEVUnknown>(S)->getValue() == V &&
2657           "Stale SCEVUnknown in uniquing map!");
2658    return S;
2659  }
2660  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
2661                                            FirstUnknown);
2662  FirstUnknown = cast<SCEVUnknown>(S);
2663  UniqueSCEVs.InsertNode(S, IP);
2664  return S;
2665}
2666
2667//===----------------------------------------------------------------------===//
2668//            Basic SCEV Analysis and PHI Idiom Recognition Code
2669//
2670
2671/// isSCEVable - Test if values of the given type are analyzable within
2672/// the SCEV framework. This primarily includes integer types, and it
2673/// can optionally include pointer types if the ScalarEvolution class
2674/// has access to target-specific information.
2675bool ScalarEvolution::isSCEVable(Type *Ty) const {
2676  // Integers and pointers are always SCEVable.
2677  return Ty->isIntegerTy() || Ty->isPointerTy();
2678}
2679
2680/// getTypeSizeInBits - Return the size in bits of the specified type,
2681/// for which isSCEVable must return true.
2682uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
2683  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2684
2685  // If we have a TargetData, use it!
2686  if (TD)
2687    return TD->getTypeSizeInBits(Ty);
2688
2689  // Integer types have fixed sizes.
2690  if (Ty->isIntegerTy())
2691    return Ty->getPrimitiveSizeInBits();
2692
2693  // The only other support type is pointer. Without TargetData, conservatively
2694  // assume pointers are 64-bit.
2695  assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
2696  return 64;
2697}
2698
2699/// getEffectiveSCEVType - Return a type with the same bitwidth as
2700/// the given type and which represents how SCEV will treat the given
2701/// type, for which isSCEVable must return true. For pointer types,
2702/// this is the pointer-sized integer type.
2703Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
2704  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2705
2706  if (Ty->isIntegerTy())
2707    return Ty;
2708
2709  // The only other support type is pointer.
2710  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
2711  if (TD) return TD->getIntPtrType(getContext());
2712
2713  // Without TargetData, conservatively assume pointers are 64-bit.
2714  return Type::getInt64Ty(getContext());
2715}
2716
2717const SCEV *ScalarEvolution::getCouldNotCompute() {
2718  return &CouldNotCompute;
2719}
2720
2721/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
2722/// expression and create a new one.
2723const SCEV *ScalarEvolution::getSCEV(Value *V) {
2724  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
2725
2726  ValueExprMapType::const_iterator I = ValueExprMap.find(V);
2727  if (I != ValueExprMap.end()) return I->second;
2728  const SCEV *S = createSCEV(V);
2729
2730  // The process of creating a SCEV for V may have caused other SCEVs
2731  // to have been created, so it's necessary to insert the new entry
2732  // from scratch, rather than trying to remember the insert position
2733  // above.
2734  ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
2735  return S;
2736}
2737
2738/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
2739///
2740const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
2741  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2742    return getConstant(
2743               cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
2744
2745  Type *Ty = V->getType();
2746  Ty = getEffectiveSCEVType(Ty);
2747  return getMulExpr(V,
2748                  getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
2749}
2750
2751/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
2752const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
2753  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2754    return getConstant(
2755                cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
2756
2757  Type *Ty = V->getType();
2758  Ty = getEffectiveSCEVType(Ty);
2759  const SCEV *AllOnes =
2760                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
2761  return getMinusSCEV(AllOnes, V);
2762}
2763
2764/// getMinusSCEV - Return LHS-RHS.  Minus is represented in SCEV as A+B*-1.
2765const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
2766                                          SCEV::NoWrapFlags Flags) {
2767  assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW");
2768
2769  // Fast path: X - X --> 0.
2770  if (LHS == RHS)
2771    return getConstant(LHS->getType(), 0);
2772
2773  // X - Y --> X + -Y
2774  return getAddExpr(LHS, getNegativeSCEV(RHS), Flags);
2775}
2776
2777/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
2778/// input value to the specified type.  If the type must be extended, it is zero
2779/// extended.
2780const SCEV *
2781ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
2782  Type *SrcTy = V->getType();
2783  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2784         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2785         "Cannot truncate or zero extend with non-integer arguments!");
2786  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2787    return V;  // No conversion
2788  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2789    return getTruncateExpr(V, Ty);
2790  return getZeroExtendExpr(V, Ty);
2791}
2792
2793/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
2794/// input value to the specified type.  If the type must be extended, it is sign
2795/// extended.
2796const SCEV *
2797ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
2798                                         Type *Ty) {
2799  Type *SrcTy = V->getType();
2800  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2801         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2802         "Cannot truncate or zero extend with non-integer arguments!");
2803  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2804    return V;  // No conversion
2805  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2806    return getTruncateExpr(V, Ty);
2807  return getSignExtendExpr(V, Ty);
2808}
2809
2810/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
2811/// input value to the specified type.  If the type must be extended, it is zero
2812/// extended.  The conversion must not be narrowing.
2813const SCEV *
2814ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
2815  Type *SrcTy = V->getType();
2816  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2817         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2818         "Cannot noop or zero extend with non-integer arguments!");
2819  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2820         "getNoopOrZeroExtend cannot truncate!");
2821  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2822    return V;  // No conversion
2823  return getZeroExtendExpr(V, Ty);
2824}
2825
2826/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
2827/// input value to the specified type.  If the type must be extended, it is sign
2828/// extended.  The conversion must not be narrowing.
2829const SCEV *
2830ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
2831  Type *SrcTy = V->getType();
2832  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2833         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2834         "Cannot noop or sign extend with non-integer arguments!");
2835  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2836         "getNoopOrSignExtend cannot truncate!");
2837  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2838    return V;  // No conversion
2839  return getSignExtendExpr(V, Ty);
2840}
2841
2842/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
2843/// the input value to the specified type. If the type must be extended,
2844/// it is extended with unspecified bits. The conversion must not be
2845/// narrowing.
2846const SCEV *
2847ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
2848  Type *SrcTy = V->getType();
2849  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2850         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2851         "Cannot noop or any extend with non-integer arguments!");
2852  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2853         "getNoopOrAnyExtend cannot truncate!");
2854  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2855    return V;  // No conversion
2856  return getAnyExtendExpr(V, Ty);
2857}
2858
2859/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
2860/// input value to the specified type.  The conversion must not be widening.
2861const SCEV *
2862ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
2863  Type *SrcTy = V->getType();
2864  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2865         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2866         "Cannot truncate or noop with non-integer arguments!");
2867  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
2868         "getTruncateOrNoop cannot extend!");
2869  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2870    return V;  // No conversion
2871  return getTruncateExpr(V, Ty);
2872}
2873
2874/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
2875/// the types using zero-extension, and then perform a umax operation
2876/// with them.
2877const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
2878                                                        const SCEV *RHS) {
2879  const SCEV *PromotedLHS = LHS;
2880  const SCEV *PromotedRHS = RHS;
2881
2882  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2883    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2884  else
2885    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2886
2887  return getUMaxExpr(PromotedLHS, PromotedRHS);
2888}
2889
2890/// getUMinFromMismatchedTypes - Promote the operands to the wider of
2891/// the types using zero-extension, and then perform a umin operation
2892/// with them.
2893const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
2894                                                        const SCEV *RHS) {
2895  const SCEV *PromotedLHS = LHS;
2896  const SCEV *PromotedRHS = RHS;
2897
2898  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2899    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2900  else
2901    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2902
2903  return getUMinExpr(PromotedLHS, PromotedRHS);
2904}
2905
2906/// getPointerBase - Transitively follow the chain of pointer-type operands
2907/// until reaching a SCEV that does not have a single pointer operand. This
2908/// returns a SCEVUnknown pointer for well-formed pointer-type expressions,
2909/// but corner cases do exist.
2910const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
2911  // A pointer operand may evaluate to a nonpointer expression, such as null.
2912  if (!V->getType()->isPointerTy())
2913    return V;
2914
2915  if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
2916    return getPointerBase(Cast->getOperand());
2917  }
2918  else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
2919    const SCEV *PtrOp = 0;
2920    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
2921         I != E; ++I) {
2922      if ((*I)->getType()->isPointerTy()) {
2923        // Cannot find the base of an expression with multiple pointer operands.
2924        if (PtrOp)
2925          return V;
2926        PtrOp = *I;
2927      }
2928    }
2929    if (!PtrOp)
2930      return V;
2931    return getPointerBase(PtrOp);
2932  }
2933  return V;
2934}
2935
2936/// PushDefUseChildren - Push users of the given Instruction
2937/// onto the given Worklist.
2938static void
2939PushDefUseChildren(Instruction *I,
2940                   SmallVectorImpl<Instruction *> &Worklist) {
2941  // Push the def-use children onto the Worklist stack.
2942  for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2943       UI != UE; ++UI)
2944    Worklist.push_back(cast<Instruction>(*UI));
2945}
2946
2947/// ForgetSymbolicValue - This looks up computed SCEV values for all
2948/// instructions that depend on the given instruction and removes them from
2949/// the ValueExprMapType map if they reference SymName. This is used during PHI
2950/// resolution.
2951void
2952ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
2953  SmallVector<Instruction *, 16> Worklist;
2954  PushDefUseChildren(PN, Worklist);
2955
2956  SmallPtrSet<Instruction *, 8> Visited;
2957  Visited.insert(PN);
2958  while (!Worklist.empty()) {
2959    Instruction *I = Worklist.pop_back_val();
2960    if (!Visited.insert(I)) continue;
2961
2962    ValueExprMapType::iterator It =
2963      ValueExprMap.find(static_cast<Value *>(I));
2964    if (It != ValueExprMap.end()) {
2965      const SCEV *Old = It->second;
2966
2967      // Short-circuit the def-use traversal if the symbolic name
2968      // ceases to appear in expressions.
2969      if (Old != SymName && !hasOperand(Old, SymName))
2970        continue;
2971
2972      // SCEVUnknown for a PHI either means that it has an unrecognized
2973      // structure, it's a PHI that's in the progress of being computed
2974      // by createNodeForPHI, or it's a single-value PHI. In the first case,
2975      // additional loop trip count information isn't going to change anything.
2976      // In the second case, createNodeForPHI will perform the necessary
2977      // updates on its own when it gets to that point. In the third, we do
2978      // want to forget the SCEVUnknown.
2979      if (!isa<PHINode>(I) ||
2980          !isa<SCEVUnknown>(Old) ||
2981          (I != PN && Old == SymName)) {
2982        forgetMemoizedResults(Old);
2983        ValueExprMap.erase(It);
2984      }
2985    }
2986
2987    PushDefUseChildren(I, Worklist);
2988  }
2989}
2990
2991/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
2992/// a loop header, making it a potential recurrence, or it doesn't.
2993///
2994const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
2995  if (const Loop *L = LI->getLoopFor(PN->getParent()))
2996    if (L->getHeader() == PN->getParent()) {
2997      // The loop may have multiple entrances or multiple exits; we can analyze
2998      // this phi as an addrec if it has a unique entry value and a unique
2999      // backedge value.
3000      Value *BEValueV = 0, *StartValueV = 0;
3001      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
3002        Value *V = PN->getIncomingValue(i);
3003        if (L->contains(PN->getIncomingBlock(i))) {
3004          if (!BEValueV) {
3005            BEValueV = V;
3006          } else if (BEValueV != V) {
3007            BEValueV = 0;
3008            break;
3009          }
3010        } else if (!StartValueV) {
3011          StartValueV = V;
3012        } else if (StartValueV != V) {
3013          StartValueV = 0;
3014          break;
3015        }
3016      }
3017      if (BEValueV && StartValueV) {
3018        // While we are analyzing this PHI node, handle its value symbolically.
3019        const SCEV *SymbolicName = getUnknown(PN);
3020        assert(ValueExprMap.find(PN) == ValueExprMap.end() &&
3021               "PHI node already processed?");
3022        ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
3023
3024        // Using this symbolic name for the PHI, analyze the value coming around
3025        // the back-edge.
3026        const SCEV *BEValue = getSCEV(BEValueV);
3027
3028        // NOTE: If BEValue is loop invariant, we know that the PHI node just
3029        // has a special value for the first iteration of the loop.
3030
3031        // If the value coming around the backedge is an add with the symbolic
3032        // value we just inserted, then we found a simple induction variable!
3033        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
3034          // If there is a single occurrence of the symbolic value, replace it
3035          // with a recurrence.
3036          unsigned FoundIndex = Add->getNumOperands();
3037          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
3038            if (Add->getOperand(i) == SymbolicName)
3039              if (FoundIndex == e) {
3040                FoundIndex = i;
3041                break;
3042              }
3043
3044          if (FoundIndex != Add->getNumOperands()) {
3045            // Create an add with everything but the specified operand.
3046            SmallVector<const SCEV *, 8> Ops;
3047            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
3048              if (i != FoundIndex)
3049                Ops.push_back(Add->getOperand(i));
3050            const SCEV *Accum = getAddExpr(Ops);
3051
3052            // This is not a valid addrec if the step amount is varying each
3053            // loop iteration, but is not itself an addrec in this loop.
3054            if (isLoopInvariant(Accum, L) ||
3055                (isa<SCEVAddRecExpr>(Accum) &&
3056                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
3057              SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
3058
3059              // If the increment doesn't overflow, then neither the addrec nor
3060              // the post-increment will overflow.
3061              if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
3062                if (OBO->hasNoUnsignedWrap())
3063                  Flags = setFlags(Flags, SCEV::FlagNUW);
3064                if (OBO->hasNoSignedWrap())
3065                  Flags = setFlags(Flags, SCEV::FlagNSW);
3066              } else if (const GEPOperator *GEP =
3067                         dyn_cast<GEPOperator>(BEValueV)) {
3068                // If the increment is an inbounds GEP, then we know the address
3069                // space cannot be wrapped around. We cannot make any guarantee
3070                // about signed or unsigned overflow because pointers are
3071                // unsigned but we may have a negative index from the base
3072                // pointer.
3073                if (GEP->isInBounds())
3074                  Flags = setFlags(Flags, SCEV::FlagNW);
3075              }
3076
3077              const SCEV *StartVal = getSCEV(StartValueV);
3078              const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
3079
3080              // Since the no-wrap flags are on the increment, they apply to the
3081              // post-incremented value as well.
3082              if (isLoopInvariant(Accum, L))
3083                (void)getAddRecExpr(getAddExpr(StartVal, Accum),
3084                                    Accum, L, Flags);
3085
3086              // Okay, for the entire analysis of this edge we assumed the PHI
3087              // to be symbolic.  We now need to go back and purge all of the
3088              // entries for the scalars that use the symbolic expression.
3089              ForgetSymbolicName(PN, SymbolicName);
3090              ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3091              return PHISCEV;
3092            }
3093          }
3094        } else if (const SCEVAddRecExpr *AddRec =
3095                     dyn_cast<SCEVAddRecExpr>(BEValue)) {
3096          // Otherwise, this could be a loop like this:
3097          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
3098          // In this case, j = {1,+,1}  and BEValue is j.
3099          // Because the other in-value of i (0) fits the evolution of BEValue
3100          // i really is an addrec evolution.
3101          if (AddRec->getLoop() == L && AddRec->isAffine()) {
3102            const SCEV *StartVal = getSCEV(StartValueV);
3103
3104            // If StartVal = j.start - j.stride, we can use StartVal as the
3105            // initial step of the addrec evolution.
3106            if (StartVal == getMinusSCEV(AddRec->getOperand(0),
3107                                         AddRec->getOperand(1))) {
3108              // FIXME: For constant StartVal, we should be able to infer
3109              // no-wrap flags.
3110              const SCEV *PHISCEV =
3111                getAddRecExpr(StartVal, AddRec->getOperand(1), L,
3112                              SCEV::FlagAnyWrap);
3113
3114              // Okay, for the entire analysis of this edge we assumed the PHI
3115              // to be symbolic.  We now need to go back and purge all of the
3116              // entries for the scalars that use the symbolic expression.
3117              ForgetSymbolicName(PN, SymbolicName);
3118              ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3119              return PHISCEV;
3120            }
3121          }
3122        }
3123      }
3124    }
3125
3126  // If the PHI has a single incoming value, follow that value, unless the
3127  // PHI's incoming blocks are in a different loop, in which case doing so
3128  // risks breaking LCSSA form. Instcombine would normally zap these, but
3129  // it doesn't have DominatorTree information, so it may miss cases.
3130  if (Value *V = SimplifyInstruction(PN, TD, TLI, DT))
3131    if (LI->replacementPreservesLCSSAForm(PN, V))
3132      return getSCEV(V);
3133
3134  // If it's not a loop phi, we can't handle it yet.
3135  return getUnknown(PN);
3136}
3137
3138/// createNodeForGEP - Expand GEP instructions into add and multiply
3139/// operations. This allows them to be analyzed by regular SCEV code.
3140///
3141const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
3142
3143  // Don't blindly transfer the inbounds flag from the GEP instruction to the
3144  // Add expression, because the Instruction may be guarded by control flow
3145  // and the no-overflow bits may not be valid for the expression in any
3146  // context.
3147  bool isInBounds = GEP->isInBounds();
3148
3149  Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
3150  Value *Base = GEP->getOperand(0);
3151  // Don't attempt to analyze GEPs over unsized objects.
3152  if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
3153    return getUnknown(GEP);
3154  const SCEV *TotalOffset = getConstant(IntPtrTy, 0);
3155  gep_type_iterator GTI = gep_type_begin(GEP);
3156  for (GetElementPtrInst::op_iterator I = llvm::next(GEP->op_begin()),
3157                                      E = GEP->op_end();
3158       I != E; ++I) {
3159    Value *Index = *I;
3160    // Compute the (potentially symbolic) offset in bytes for this index.
3161    if (StructType *STy = dyn_cast<StructType>(*GTI++)) {
3162      // For a struct, add the member offset.
3163      unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
3164      const SCEV *FieldOffset = getOffsetOfExpr(STy, FieldNo);
3165
3166      // Add the field offset to the running total offset.
3167      TotalOffset = getAddExpr(TotalOffset, FieldOffset);
3168    } else {
3169      // For an array, add the element offset, explicitly scaled.
3170      const SCEV *ElementSize = getSizeOfExpr(*GTI);
3171      const SCEV *IndexS = getSCEV(Index);
3172      // Getelementptr indices are signed.
3173      IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy);
3174
3175      // Multiply the index by the element size to compute the element offset.
3176      const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize,
3177                                           isInBounds ? SCEV::FlagNSW :
3178                                           SCEV::FlagAnyWrap);
3179
3180      // Add the element offset to the running total offset.
3181      TotalOffset = getAddExpr(TotalOffset, LocalOffset);
3182    }
3183  }
3184
3185  // Get the SCEV for the GEP base.
3186  const SCEV *BaseS = getSCEV(Base);
3187
3188  // Add the total offset from all the GEP indices to the base.
3189  return getAddExpr(BaseS, TotalOffset,
3190                    isInBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap);
3191}
3192
3193/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
3194/// guaranteed to end in (at every loop iteration).  It is, at the same time,
3195/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
3196/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
3197uint32_t
3198ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
3199  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3200    return C->getValue()->getValue().countTrailingZeros();
3201
3202  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
3203    return std::min(GetMinTrailingZeros(T->getOperand()),
3204                    (uint32_t)getTypeSizeInBits(T->getType()));
3205
3206  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
3207    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3208    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3209             getTypeSizeInBits(E->getType()) : OpRes;
3210  }
3211
3212  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
3213    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3214    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3215             getTypeSizeInBits(E->getType()) : OpRes;
3216  }
3217
3218  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
3219    // The result is the min of all operands results.
3220    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3221    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3222      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3223    return MinOpRes;
3224  }
3225
3226  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
3227    // The result is the sum of all operands results.
3228    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
3229    uint32_t BitWidth = getTypeSizeInBits(M->getType());
3230    for (unsigned i = 1, e = M->getNumOperands();
3231         SumOpRes != BitWidth && i != e; ++i)
3232      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
3233                          BitWidth);
3234    return SumOpRes;
3235  }
3236
3237  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
3238    // The result is the min of all operands results.
3239    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3240    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3241      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3242    return MinOpRes;
3243  }
3244
3245  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
3246    // The result is the min of all operands results.
3247    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3248    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3249      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3250    return MinOpRes;
3251  }
3252
3253  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
3254    // The result is the min of all operands results.
3255    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3256    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3257      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3258    return MinOpRes;
3259  }
3260
3261  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3262    // For a SCEVUnknown, ask ValueTracking.
3263    unsigned BitWidth = getTypeSizeInBits(U->getType());
3264    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3265    ComputeMaskedBits(U->getValue(), Zeros, Ones);
3266    return Zeros.countTrailingOnes();
3267  }
3268
3269  // SCEVUDivExpr
3270  return 0;
3271}
3272
3273/// getUnsignedRange - Determine the unsigned range for a particular SCEV.
3274///
3275ConstantRange
3276ScalarEvolution::getUnsignedRange(const SCEV *S) {
3277  // See if we've computed this range already.
3278  DenseMap<const SCEV *, ConstantRange>::iterator I = UnsignedRanges.find(S);
3279  if (I != UnsignedRanges.end())
3280    return I->second;
3281
3282  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3283    return setUnsignedRange(C, ConstantRange(C->getValue()->getValue()));
3284
3285  unsigned BitWidth = getTypeSizeInBits(S->getType());
3286  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3287
3288  // If the value has known zeros, the maximum unsigned value will have those
3289  // known zeros as well.
3290  uint32_t TZ = GetMinTrailingZeros(S);
3291  if (TZ != 0)
3292    ConservativeResult =
3293      ConstantRange(APInt::getMinValue(BitWidth),
3294                    APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
3295
3296  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3297    ConstantRange X = getUnsignedRange(Add->getOperand(0));
3298    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3299      X = X.add(getUnsignedRange(Add->getOperand(i)));
3300    return setUnsignedRange(Add, ConservativeResult.intersectWith(X));
3301  }
3302
3303  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3304    ConstantRange X = getUnsignedRange(Mul->getOperand(0));
3305    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3306      X = X.multiply(getUnsignedRange(Mul->getOperand(i)));
3307    return setUnsignedRange(Mul, ConservativeResult.intersectWith(X));
3308  }
3309
3310  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3311    ConstantRange X = getUnsignedRange(SMax->getOperand(0));
3312    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3313      X = X.smax(getUnsignedRange(SMax->getOperand(i)));
3314    return setUnsignedRange(SMax, ConservativeResult.intersectWith(X));
3315  }
3316
3317  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3318    ConstantRange X = getUnsignedRange(UMax->getOperand(0));
3319    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3320      X = X.umax(getUnsignedRange(UMax->getOperand(i)));
3321    return setUnsignedRange(UMax, ConservativeResult.intersectWith(X));
3322  }
3323
3324  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3325    ConstantRange X = getUnsignedRange(UDiv->getLHS());
3326    ConstantRange Y = getUnsignedRange(UDiv->getRHS());
3327    return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
3328  }
3329
3330  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3331    ConstantRange X = getUnsignedRange(ZExt->getOperand());
3332    return setUnsignedRange(ZExt,
3333      ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3334  }
3335
3336  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3337    ConstantRange X = getUnsignedRange(SExt->getOperand());
3338    return setUnsignedRange(SExt,
3339      ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3340  }
3341
3342  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3343    ConstantRange X = getUnsignedRange(Trunc->getOperand());
3344    return setUnsignedRange(Trunc,
3345      ConservativeResult.intersectWith(X.truncate(BitWidth)));
3346  }
3347
3348  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3349    // If there's no unsigned wrap, the value will never be less than its
3350    // initial value.
3351    if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
3352      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
3353        if (!C->getValue()->isZero())
3354          ConservativeResult =
3355            ConservativeResult.intersectWith(
3356              ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
3357
3358    // TODO: non-affine addrec
3359    if (AddRec->isAffine()) {
3360      Type *Ty = AddRec->getType();
3361      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3362      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3363          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3364        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3365
3366        const SCEV *Start = AddRec->getStart();
3367        const SCEV *Step = AddRec->getStepRecurrence(*this);
3368
3369        ConstantRange StartRange = getUnsignedRange(Start);
3370        ConstantRange StepRange = getSignedRange(Step);
3371        ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3372        ConstantRange EndRange =
3373          StartRange.add(MaxBECountRange.multiply(StepRange));
3374
3375        // Check for overflow. This must be done with ConstantRange arithmetic
3376        // because we could be called from within the ScalarEvolution overflow
3377        // checking code.
3378        ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1);
3379        ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3380        ConstantRange ExtMaxBECountRange =
3381          MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3382        ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1);
3383        if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3384            ExtEndRange)
3385          return setUnsignedRange(AddRec, ConservativeResult);
3386
3387        APInt Min = APIntOps::umin(StartRange.getUnsignedMin(),
3388                                   EndRange.getUnsignedMin());
3389        APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
3390                                   EndRange.getUnsignedMax());
3391        if (Min.isMinValue() && Max.isMaxValue())
3392          return setUnsignedRange(AddRec, ConservativeResult);
3393        return setUnsignedRange(AddRec,
3394          ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
3395      }
3396    }
3397
3398    return setUnsignedRange(AddRec, ConservativeResult);
3399  }
3400
3401  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3402    // For a SCEVUnknown, ask ValueTracking.
3403    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3404    ComputeMaskedBits(U->getValue(), Zeros, Ones, TD);
3405    if (Ones == ~Zeros + 1)
3406      return setUnsignedRange(U, ConservativeResult);
3407    return setUnsignedRange(U,
3408      ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)));
3409  }
3410
3411  return setUnsignedRange(S, ConservativeResult);
3412}
3413
3414/// getSignedRange - Determine the signed range for a particular SCEV.
3415///
3416ConstantRange
3417ScalarEvolution::getSignedRange(const SCEV *S) {
3418  // See if we've computed this range already.
3419  DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S);
3420  if (I != SignedRanges.end())
3421    return I->second;
3422
3423  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3424    return setSignedRange(C, ConstantRange(C->getValue()->getValue()));
3425
3426  unsigned BitWidth = getTypeSizeInBits(S->getType());
3427  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3428
3429  // If the value has known zeros, the maximum signed value will have those
3430  // known zeros as well.
3431  uint32_t TZ = GetMinTrailingZeros(S);
3432  if (TZ != 0)
3433    ConservativeResult =
3434      ConstantRange(APInt::getSignedMinValue(BitWidth),
3435                    APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
3436
3437  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3438    ConstantRange X = getSignedRange(Add->getOperand(0));
3439    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3440      X = X.add(getSignedRange(Add->getOperand(i)));
3441    return setSignedRange(Add, ConservativeResult.intersectWith(X));
3442  }
3443
3444  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3445    ConstantRange X = getSignedRange(Mul->getOperand(0));
3446    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3447      X = X.multiply(getSignedRange(Mul->getOperand(i)));
3448    return setSignedRange(Mul, ConservativeResult.intersectWith(X));
3449  }
3450
3451  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3452    ConstantRange X = getSignedRange(SMax->getOperand(0));
3453    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3454      X = X.smax(getSignedRange(SMax->getOperand(i)));
3455    return setSignedRange(SMax, ConservativeResult.intersectWith(X));
3456  }
3457
3458  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3459    ConstantRange X = getSignedRange(UMax->getOperand(0));
3460    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3461      X = X.umax(getSignedRange(UMax->getOperand(i)));
3462    return setSignedRange(UMax, ConservativeResult.intersectWith(X));
3463  }
3464
3465  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3466    ConstantRange X = getSignedRange(UDiv->getLHS());
3467    ConstantRange Y = getSignedRange(UDiv->getRHS());
3468    return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
3469  }
3470
3471  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3472    ConstantRange X = getSignedRange(ZExt->getOperand());
3473    return setSignedRange(ZExt,
3474      ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3475  }
3476
3477  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3478    ConstantRange X = getSignedRange(SExt->getOperand());
3479    return setSignedRange(SExt,
3480      ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3481  }
3482
3483  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3484    ConstantRange X = getSignedRange(Trunc->getOperand());
3485    return setSignedRange(Trunc,
3486      ConservativeResult.intersectWith(X.truncate(BitWidth)));
3487  }
3488
3489  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3490    // If there's no signed wrap, and all the operands have the same sign or
3491    // zero, the value won't ever change sign.
3492    if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) {
3493      bool AllNonNeg = true;
3494      bool AllNonPos = true;
3495      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3496        if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
3497        if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
3498      }
3499      if (AllNonNeg)
3500        ConservativeResult = ConservativeResult.intersectWith(
3501          ConstantRange(APInt(BitWidth, 0),
3502                        APInt::getSignedMinValue(BitWidth)));
3503      else if (AllNonPos)
3504        ConservativeResult = ConservativeResult.intersectWith(
3505          ConstantRange(APInt::getSignedMinValue(BitWidth),
3506                        APInt(BitWidth, 1)));
3507    }
3508
3509    // TODO: non-affine addrec
3510    if (AddRec->isAffine()) {
3511      Type *Ty = AddRec->getType();
3512      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3513      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3514          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3515        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3516
3517        const SCEV *Start = AddRec->getStart();
3518        const SCEV *Step = AddRec->getStepRecurrence(*this);
3519
3520        ConstantRange StartRange = getSignedRange(Start);
3521        ConstantRange StepRange = getSignedRange(Step);
3522        ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3523        ConstantRange EndRange =
3524          StartRange.add(MaxBECountRange.multiply(StepRange));
3525
3526        // Check for overflow. This must be done with ConstantRange arithmetic
3527        // because we could be called from within the ScalarEvolution overflow
3528        // checking code.
3529        ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1);
3530        ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3531        ConstantRange ExtMaxBECountRange =
3532          MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3533        ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1);
3534        if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3535            ExtEndRange)
3536          return setSignedRange(AddRec, ConservativeResult);
3537
3538        APInt Min = APIntOps::smin(StartRange.getSignedMin(),
3539                                   EndRange.getSignedMin());
3540        APInt Max = APIntOps::smax(StartRange.getSignedMax(),
3541                                   EndRange.getSignedMax());
3542        if (Min.isMinSignedValue() && Max.isMaxSignedValue())
3543          return setSignedRange(AddRec, ConservativeResult);
3544        return setSignedRange(AddRec,
3545          ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
3546      }
3547    }
3548
3549    return setSignedRange(AddRec, ConservativeResult);
3550  }
3551
3552  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3553    // For a SCEVUnknown, ask ValueTracking.
3554    if (!U->getValue()->getType()->isIntegerTy() && !TD)
3555      return setSignedRange(U, ConservativeResult);
3556    unsigned NS = ComputeNumSignBits(U->getValue(), TD);
3557    if (NS == 1)
3558      return setSignedRange(U, ConservativeResult);
3559    return setSignedRange(U, ConservativeResult.intersectWith(
3560      ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
3561                    APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1)));
3562  }
3563
3564  return setSignedRange(S, ConservativeResult);
3565}
3566
3567/// createSCEV - We know that there is no SCEV for the specified value.
3568/// Analyze the expression.
3569///
3570const SCEV *ScalarEvolution::createSCEV(Value *V) {
3571  if (!isSCEVable(V->getType()))
3572    return getUnknown(V);
3573
3574  unsigned Opcode = Instruction::UserOp1;
3575  if (Instruction *I = dyn_cast<Instruction>(V)) {
3576    Opcode = I->getOpcode();
3577
3578    // Don't attempt to analyze instructions in blocks that aren't
3579    // reachable. Such instructions don't matter, and they aren't required
3580    // to obey basic rules for definitions dominating uses which this
3581    // analysis depends on.
3582    if (!DT->isReachableFromEntry(I->getParent()))
3583      return getUnknown(V);
3584  } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
3585    Opcode = CE->getOpcode();
3586  else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
3587    return getConstant(CI);
3588  else if (isa<ConstantPointerNull>(V))
3589    return getConstant(V->getType(), 0);
3590  else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
3591    return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
3592  else
3593    return getUnknown(V);
3594
3595  Operator *U = cast<Operator>(V);
3596  switch (Opcode) {
3597  case Instruction::Add: {
3598    // The simple thing to do would be to just call getSCEV on both operands
3599    // and call getAddExpr with the result. However if we're looking at a
3600    // bunch of things all added together, this can be quite inefficient,
3601    // because it leads to N-1 getAddExpr calls for N ultimate operands.
3602    // Instead, gather up all the operands and make a single getAddExpr call.
3603    // LLVM IR canonical form means we need only traverse the left operands.
3604    //
3605    // Don't apply this instruction's NSW or NUW flags to the new
3606    // expression. The instruction may be guarded by control flow that the
3607    // no-wrap behavior depends on. Non-control-equivalent instructions can be
3608    // mapped to the same SCEV expression, and it would be incorrect to transfer
3609    // NSW/NUW semantics to those operations.
3610    SmallVector<const SCEV *, 4> AddOps;
3611    AddOps.push_back(getSCEV(U->getOperand(1)));
3612    for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) {
3613      unsigned Opcode = Op->getValueID() - Value::InstructionVal;
3614      if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
3615        break;
3616      U = cast<Operator>(Op);
3617      const SCEV *Op1 = getSCEV(U->getOperand(1));
3618      if (Opcode == Instruction::Sub)
3619        AddOps.push_back(getNegativeSCEV(Op1));
3620      else
3621        AddOps.push_back(Op1);
3622    }
3623    AddOps.push_back(getSCEV(U->getOperand(0)));
3624    return getAddExpr(AddOps);
3625  }
3626  case Instruction::Mul: {
3627    // Don't transfer NSW/NUW for the same reason as AddExpr.
3628    SmallVector<const SCEV *, 4> MulOps;
3629    MulOps.push_back(getSCEV(U->getOperand(1)));
3630    for (Value *Op = U->getOperand(0);
3631         Op->getValueID() == Instruction::Mul + Value::InstructionVal;
3632         Op = U->getOperand(0)) {
3633      U = cast<Operator>(Op);
3634      MulOps.push_back(getSCEV(U->getOperand(1)));
3635    }
3636    MulOps.push_back(getSCEV(U->getOperand(0)));
3637    return getMulExpr(MulOps);
3638  }
3639  case Instruction::UDiv:
3640    return getUDivExpr(getSCEV(U->getOperand(0)),
3641                       getSCEV(U->getOperand(1)));
3642  case Instruction::Sub:
3643    return getMinusSCEV(getSCEV(U->getOperand(0)),
3644                        getSCEV(U->getOperand(1)));
3645  case Instruction::And:
3646    // For an expression like x&255 that merely masks off the high bits,
3647    // use zext(trunc(x)) as the SCEV expression.
3648    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3649      if (CI->isNullValue())
3650        return getSCEV(U->getOperand(1));
3651      if (CI->isAllOnesValue())
3652        return getSCEV(U->getOperand(0));
3653      const APInt &A = CI->getValue();
3654
3655      // Instcombine's ShrinkDemandedConstant may strip bits out of
3656      // constants, obscuring what would otherwise be a low-bits mask.
3657      // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
3658      // knew about to reconstruct a low-bits mask value.
3659      unsigned LZ = A.countLeadingZeros();
3660      unsigned BitWidth = A.getBitWidth();
3661      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
3662      ComputeMaskedBits(U->getOperand(0), KnownZero, KnownOne, TD);
3663
3664      APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
3665
3666      if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
3667        return
3668          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
3669                                IntegerType::get(getContext(), BitWidth - LZ)),
3670                            U->getType());
3671    }
3672    break;
3673
3674  case Instruction::Or:
3675    // If the RHS of the Or is a constant, we may have something like:
3676    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
3677    // optimizations will transparently handle this case.
3678    //
3679    // In order for this transformation to be safe, the LHS must be of the
3680    // form X*(2^n) and the Or constant must be less than 2^n.
3681    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3682      const SCEV *LHS = getSCEV(U->getOperand(0));
3683      const APInt &CIVal = CI->getValue();
3684      if (GetMinTrailingZeros(LHS) >=
3685          (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
3686        // Build a plain add SCEV.
3687        const SCEV *S = getAddExpr(LHS, getSCEV(CI));
3688        // If the LHS of the add was an addrec and it has no-wrap flags,
3689        // transfer the no-wrap flags, since an or won't introduce a wrap.
3690        if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
3691          const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
3692          const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
3693            OldAR->getNoWrapFlags());
3694        }
3695        return S;
3696      }
3697    }
3698    break;
3699  case Instruction::Xor:
3700    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3701      // If the RHS of the xor is a signbit, then this is just an add.
3702      // Instcombine turns add of signbit into xor as a strength reduction step.
3703      if (CI->getValue().isSignBit())
3704        return getAddExpr(getSCEV(U->getOperand(0)),
3705                          getSCEV(U->getOperand(1)));
3706
3707      // If the RHS of xor is -1, then this is a not operation.
3708      if (CI->isAllOnesValue())
3709        return getNotSCEV(getSCEV(U->getOperand(0)));
3710
3711      // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
3712      // This is a variant of the check for xor with -1, and it handles
3713      // the case where instcombine has trimmed non-demanded bits out
3714      // of an xor with -1.
3715      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
3716        if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
3717          if (BO->getOpcode() == Instruction::And &&
3718              LCI->getValue() == CI->getValue())
3719            if (const SCEVZeroExtendExpr *Z =
3720                  dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
3721              Type *UTy = U->getType();
3722              const SCEV *Z0 = Z->getOperand();
3723              Type *Z0Ty = Z0->getType();
3724              unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
3725
3726              // If C is a low-bits mask, the zero extend is serving to
3727              // mask off the high bits. Complement the operand and
3728              // re-apply the zext.
3729              if (APIntOps::isMask(Z0TySize, CI->getValue()))
3730                return getZeroExtendExpr(getNotSCEV(Z0), UTy);
3731
3732              // If C is a single bit, it may be in the sign-bit position
3733              // before the zero-extend. In this case, represent the xor
3734              // using an add, which is equivalent, and re-apply the zext.
3735              APInt Trunc = CI->getValue().trunc(Z0TySize);
3736              if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
3737                  Trunc.isSignBit())
3738                return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
3739                                         UTy);
3740            }
3741    }
3742    break;
3743
3744  case Instruction::Shl:
3745    // Turn shift left of a constant amount into a multiply.
3746    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3747      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3748
3749      // If the shift count is not less than the bitwidth, the result of
3750      // the shift is undefined. Don't try to analyze it, because the
3751      // resolution chosen here may differ from the resolution chosen in
3752      // other parts of the compiler.
3753      if (SA->getValue().uge(BitWidth))
3754        break;
3755
3756      Constant *X = ConstantInt::get(getContext(),
3757        APInt(BitWidth, 1).shl(SA->getZExtValue()));
3758      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3759    }
3760    break;
3761
3762  case Instruction::LShr:
3763    // Turn logical shift right of a constant into a unsigned divide.
3764    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3765      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3766
3767      // If the shift count is not less than the bitwidth, the result of
3768      // the shift is undefined. Don't try to analyze it, because the
3769      // resolution chosen here may differ from the resolution chosen in
3770      // other parts of the compiler.
3771      if (SA->getValue().uge(BitWidth))
3772        break;
3773
3774      Constant *X = ConstantInt::get(getContext(),
3775        APInt(BitWidth, 1).shl(SA->getZExtValue()));
3776      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3777    }
3778    break;
3779
3780  case Instruction::AShr:
3781    // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
3782    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
3783      if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
3784        if (L->getOpcode() == Instruction::Shl &&
3785            L->getOperand(1) == U->getOperand(1)) {
3786          uint64_t BitWidth = getTypeSizeInBits(U->getType());
3787
3788          // If the shift count is not less than the bitwidth, the result of
3789          // the shift is undefined. Don't try to analyze it, because the
3790          // resolution chosen here may differ from the resolution chosen in
3791          // other parts of the compiler.
3792          if (CI->getValue().uge(BitWidth))
3793            break;
3794
3795          uint64_t Amt = BitWidth - CI->getZExtValue();
3796          if (Amt == BitWidth)
3797            return getSCEV(L->getOperand(0));       // shift by zero --> noop
3798          return
3799            getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
3800                                              IntegerType::get(getContext(),
3801                                                               Amt)),
3802                              U->getType());
3803        }
3804    break;
3805
3806  case Instruction::Trunc:
3807    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
3808
3809  case Instruction::ZExt:
3810    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3811
3812  case Instruction::SExt:
3813    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3814
3815  case Instruction::BitCast:
3816    // BitCasts are no-op casts so we just eliminate the cast.
3817    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
3818      return getSCEV(U->getOperand(0));
3819    break;
3820
3821  // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can
3822  // lead to pointer expressions which cannot safely be expanded to GEPs,
3823  // because ScalarEvolution doesn't respect the GEP aliasing rules when
3824  // simplifying integer expressions.
3825
3826  case Instruction::GetElementPtr:
3827    return createNodeForGEP(cast<GEPOperator>(U));
3828
3829  case Instruction::PHI:
3830    return createNodeForPHI(cast<PHINode>(U));
3831
3832  case Instruction::Select:
3833    // This could be a smax or umax that was lowered earlier.
3834    // Try to recover it.
3835    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
3836      Value *LHS = ICI->getOperand(0);
3837      Value *RHS = ICI->getOperand(1);
3838      switch (ICI->getPredicate()) {
3839      case ICmpInst::ICMP_SLT:
3840      case ICmpInst::ICMP_SLE:
3841        std::swap(LHS, RHS);
3842        // fall through
3843      case ICmpInst::ICMP_SGT:
3844      case ICmpInst::ICMP_SGE:
3845        // a >s b ? a+x : b+x  ->  smax(a, b)+x
3846        // a >s b ? b+x : a+x  ->  smin(a, b)+x
3847        if (LHS->getType() == U->getType()) {
3848          const SCEV *LS = getSCEV(LHS);
3849          const SCEV *RS = getSCEV(RHS);
3850          const SCEV *LA = getSCEV(U->getOperand(1));
3851          const SCEV *RA = getSCEV(U->getOperand(2));
3852          const SCEV *LDiff = getMinusSCEV(LA, LS);
3853          const SCEV *RDiff = getMinusSCEV(RA, RS);
3854          if (LDiff == RDiff)
3855            return getAddExpr(getSMaxExpr(LS, RS), LDiff);
3856          LDiff = getMinusSCEV(LA, RS);
3857          RDiff = getMinusSCEV(RA, LS);
3858          if (LDiff == RDiff)
3859            return getAddExpr(getSMinExpr(LS, RS), LDiff);
3860        }
3861        break;
3862      case ICmpInst::ICMP_ULT:
3863      case ICmpInst::ICMP_ULE:
3864        std::swap(LHS, RHS);
3865        // fall through
3866      case ICmpInst::ICMP_UGT:
3867      case ICmpInst::ICMP_UGE:
3868        // a >u b ? a+x : b+x  ->  umax(a, b)+x
3869        // a >u b ? b+x : a+x  ->  umin(a, b)+x
3870        if (LHS->getType() == U->getType()) {
3871          const SCEV *LS = getSCEV(LHS);
3872          const SCEV *RS = getSCEV(RHS);
3873          const SCEV *LA = getSCEV(U->getOperand(1));
3874          const SCEV *RA = getSCEV(U->getOperand(2));
3875          const SCEV *LDiff = getMinusSCEV(LA, LS);
3876          const SCEV *RDiff = getMinusSCEV(RA, RS);
3877          if (LDiff == RDiff)
3878            return getAddExpr(getUMaxExpr(LS, RS), LDiff);
3879          LDiff = getMinusSCEV(LA, RS);
3880          RDiff = getMinusSCEV(RA, LS);
3881          if (LDiff == RDiff)
3882            return getAddExpr(getUMinExpr(LS, RS), LDiff);
3883        }
3884        break;
3885      case ICmpInst::ICMP_NE:
3886        // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
3887        if (LHS->getType() == U->getType() &&
3888            isa<ConstantInt>(RHS) &&
3889            cast<ConstantInt>(RHS)->isZero()) {
3890          const SCEV *One = getConstant(LHS->getType(), 1);
3891          const SCEV *LS = getSCEV(LHS);
3892          const SCEV *LA = getSCEV(U->getOperand(1));
3893          const SCEV *RA = getSCEV(U->getOperand(2));
3894          const SCEV *LDiff = getMinusSCEV(LA, LS);
3895          const SCEV *RDiff = getMinusSCEV(RA, One);
3896          if (LDiff == RDiff)
3897            return getAddExpr(getUMaxExpr(One, LS), LDiff);
3898        }
3899        break;
3900      case ICmpInst::ICMP_EQ:
3901        // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
3902        if (LHS->getType() == U->getType() &&
3903            isa<ConstantInt>(RHS) &&
3904            cast<ConstantInt>(RHS)->isZero()) {
3905          const SCEV *One = getConstant(LHS->getType(), 1);
3906          const SCEV *LS = getSCEV(LHS);
3907          const SCEV *LA = getSCEV(U->getOperand(1));
3908          const SCEV *RA = getSCEV(U->getOperand(2));
3909          const SCEV *LDiff = getMinusSCEV(LA, One);
3910          const SCEV *RDiff = getMinusSCEV(RA, LS);
3911          if (LDiff == RDiff)
3912            return getAddExpr(getUMaxExpr(One, LS), LDiff);
3913        }
3914        break;
3915      default:
3916        break;
3917      }
3918    }
3919
3920  default: // We cannot analyze this expression.
3921    break;
3922  }
3923
3924  return getUnknown(V);
3925}
3926
3927
3928
3929//===----------------------------------------------------------------------===//
3930//                   Iteration Count Computation Code
3931//
3932
3933/// getSmallConstantTripCount - Returns the maximum trip count of this loop as a
3934/// normal unsigned value. Returns 0 if the trip count is unknown or not
3935/// constant. Will also return 0 if the maximum trip count is very large (>=
3936/// 2^32).
3937///
3938/// This "trip count" assumes that control exits via ExitingBlock. More
3939/// precisely, it is the number of times that control may reach ExitingBlock
3940/// before taking the branch. For loops with multiple exits, it may not be the
3941/// number times that the loop header executes because the loop may exit
3942/// prematurely via another branch.
3943unsigned ScalarEvolution::
3944getSmallConstantTripCount(Loop *L, BasicBlock *ExitingBlock) {
3945  const SCEVConstant *ExitCount =
3946    dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
3947  if (!ExitCount)
3948    return 0;
3949
3950  ConstantInt *ExitConst = ExitCount->getValue();
3951
3952  // Guard against huge trip counts.
3953  if (ExitConst->getValue().getActiveBits() > 32)
3954    return 0;
3955
3956  // In case of integer overflow, this returns 0, which is correct.
3957  return ((unsigned)ExitConst->getZExtValue()) + 1;
3958}
3959
3960/// getSmallConstantTripMultiple - Returns the largest constant divisor of the
3961/// trip count of this loop as a normal unsigned value, if possible. This
3962/// means that the actual trip count is always a multiple of the returned
3963/// value (don't forget the trip count could very well be zero as well!).
3964///
3965/// Returns 1 if the trip count is unknown or not guaranteed to be the
3966/// multiple of a constant (which is also the case if the trip count is simply
3967/// constant, use getSmallConstantTripCount for that case), Will also return 1
3968/// if the trip count is very large (>= 2^32).
3969///
3970/// As explained in the comments for getSmallConstantTripCount, this assumes
3971/// that control exits the loop via ExitingBlock.
3972unsigned ScalarEvolution::
3973getSmallConstantTripMultiple(Loop *L, BasicBlock *ExitingBlock) {
3974  const SCEV *ExitCount = getExitCount(L, ExitingBlock);
3975  if (ExitCount == getCouldNotCompute())
3976    return 1;
3977
3978  // Get the trip count from the BE count by adding 1.
3979  const SCEV *TCMul = getAddExpr(ExitCount,
3980                                 getConstant(ExitCount->getType(), 1));
3981  // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt
3982  // to factor simple cases.
3983  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul))
3984    TCMul = Mul->getOperand(0);
3985
3986  const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul);
3987  if (!MulC)
3988    return 1;
3989
3990  ConstantInt *Result = MulC->getValue();
3991
3992  // Guard against huge trip counts.
3993  if (!Result || Result->getValue().getActiveBits() > 32)
3994    return 1;
3995
3996  return (unsigned)Result->getZExtValue();
3997}
3998
3999// getExitCount - Get the expression for the number of loop iterations for which
4000// this loop is guaranteed not to exit via ExitintBlock. Otherwise return
4001// SCEVCouldNotCompute.
4002const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) {
4003  return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
4004}
4005
4006/// getBackedgeTakenCount - If the specified loop has a predictable
4007/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
4008/// object. The backedge-taken count is the number of times the loop header
4009/// will be branched to from within the loop. This is one less than the
4010/// trip count of the loop, since it doesn't count the first iteration,
4011/// when the header is branched to from outside the loop.
4012///
4013/// Note that it is not valid to call this method on a loop without a
4014/// loop-invariant backedge-taken count (see
4015/// hasLoopInvariantBackedgeTakenCount).
4016///
4017const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
4018  return getBackedgeTakenInfo(L).getExact(this);
4019}
4020
4021/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
4022/// return the least SCEV value that is known never to be less than the
4023/// actual backedge taken count.
4024const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
4025  return getBackedgeTakenInfo(L).getMax(this);
4026}
4027
4028/// PushLoopPHIs - Push PHI nodes in the header of the given loop
4029/// onto the given Worklist.
4030static void
4031PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
4032  BasicBlock *Header = L->getHeader();
4033
4034  // Push all Loop-header PHIs onto the Worklist stack.
4035  for (BasicBlock::iterator I = Header->begin();
4036       PHINode *PN = dyn_cast<PHINode>(I); ++I)
4037    Worklist.push_back(PN);
4038}
4039
4040const ScalarEvolution::BackedgeTakenInfo &
4041ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
4042  // Initially insert an invalid entry for this loop. If the insertion
4043  // succeeds, proceed to actually compute a backedge-taken count and
4044  // update the value. The temporary CouldNotCompute value tells SCEV
4045  // code elsewhere that it shouldn't attempt to request a new
4046  // backedge-taken count, which could result in infinite recursion.
4047  std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
4048    BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo()));
4049  if (!Pair.second)
4050    return Pair.first->second;
4051
4052  // ComputeBackedgeTakenCount may allocate memory for its result. Inserting it
4053  // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
4054  // must be cleared in this scope.
4055  BackedgeTakenInfo Result = ComputeBackedgeTakenCount(L);
4056
4057  if (Result.getExact(this) != getCouldNotCompute()) {
4058    assert(isLoopInvariant(Result.getExact(this), L) &&
4059           isLoopInvariant(Result.getMax(this), L) &&
4060           "Computed backedge-taken count isn't loop invariant for loop!");
4061    ++NumTripCountsComputed;
4062  }
4063  else if (Result.getMax(this) == getCouldNotCompute() &&
4064           isa<PHINode>(L->getHeader()->begin())) {
4065    // Only count loops that have phi nodes as not being computable.
4066    ++NumTripCountsNotComputed;
4067  }
4068
4069  // Now that we know more about the trip count for this loop, forget any
4070  // existing SCEV values for PHI nodes in this loop since they are only
4071  // conservative estimates made without the benefit of trip count
4072  // information. This is similar to the code in forgetLoop, except that
4073  // it handles SCEVUnknown PHI nodes specially.
4074  if (Result.hasAnyInfo()) {
4075    SmallVector<Instruction *, 16> Worklist;
4076    PushLoopPHIs(L, Worklist);
4077
4078    SmallPtrSet<Instruction *, 8> Visited;
4079    while (!Worklist.empty()) {
4080      Instruction *I = Worklist.pop_back_val();
4081      if (!Visited.insert(I)) continue;
4082
4083      ValueExprMapType::iterator It =
4084        ValueExprMap.find(static_cast<Value *>(I));
4085      if (It != ValueExprMap.end()) {
4086        const SCEV *Old = It->second;
4087
4088        // SCEVUnknown for a PHI either means that it has an unrecognized
4089        // structure, or it's a PHI that's in the progress of being computed
4090        // by createNodeForPHI.  In the former case, additional loop trip
4091        // count information isn't going to change anything. In the later
4092        // case, createNodeForPHI will perform the necessary updates on its
4093        // own when it gets to that point.
4094        if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
4095          forgetMemoizedResults(Old);
4096          ValueExprMap.erase(It);
4097        }
4098        if (PHINode *PN = dyn_cast<PHINode>(I))
4099          ConstantEvolutionLoopExitValue.erase(PN);
4100      }
4101
4102      PushDefUseChildren(I, Worklist);
4103    }
4104  }
4105
4106  // Re-lookup the insert position, since the call to
4107  // ComputeBackedgeTakenCount above could result in a
4108  // recusive call to getBackedgeTakenInfo (on a different
4109  // loop), which would invalidate the iterator computed
4110  // earlier.
4111  return BackedgeTakenCounts.find(L)->second = Result;
4112}
4113
4114/// forgetLoop - This method should be called by the client when it has
4115/// changed a loop in a way that may effect ScalarEvolution's ability to
4116/// compute a trip count, or if the loop is deleted.
4117void ScalarEvolution::forgetLoop(const Loop *L) {
4118  // Drop any stored trip count value.
4119  DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos =
4120    BackedgeTakenCounts.find(L);
4121  if (BTCPos != BackedgeTakenCounts.end()) {
4122    BTCPos->second.clear();
4123    BackedgeTakenCounts.erase(BTCPos);
4124  }
4125
4126  // Drop information about expressions based on loop-header PHIs.
4127  SmallVector<Instruction *, 16> Worklist;
4128  PushLoopPHIs(L, Worklist);
4129
4130  SmallPtrSet<Instruction *, 8> Visited;
4131  while (!Worklist.empty()) {
4132    Instruction *I = Worklist.pop_back_val();
4133    if (!Visited.insert(I)) continue;
4134
4135    ValueExprMapType::iterator It = ValueExprMap.find(static_cast<Value *>(I));
4136    if (It != ValueExprMap.end()) {
4137      forgetMemoizedResults(It->second);
4138      ValueExprMap.erase(It);
4139      if (PHINode *PN = dyn_cast<PHINode>(I))
4140        ConstantEvolutionLoopExitValue.erase(PN);
4141    }
4142
4143    PushDefUseChildren(I, Worklist);
4144  }
4145
4146  // Forget all contained loops too, to avoid dangling entries in the
4147  // ValuesAtScopes map.
4148  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
4149    forgetLoop(*I);
4150}
4151
4152/// forgetValue - This method should be called by the client when it has
4153/// changed a value in a way that may effect its value, or which may
4154/// disconnect it from a def-use chain linking it to a loop.
4155void ScalarEvolution::forgetValue(Value *V) {
4156  Instruction *I = dyn_cast<Instruction>(V);
4157  if (!I) return;
4158
4159  // Drop information about expressions based on loop-header PHIs.
4160  SmallVector<Instruction *, 16> Worklist;
4161  Worklist.push_back(I);
4162
4163  SmallPtrSet<Instruction *, 8> Visited;
4164  while (!Worklist.empty()) {
4165    I = Worklist.pop_back_val();
4166    if (!Visited.insert(I)) continue;
4167
4168    ValueExprMapType::iterator It = ValueExprMap.find(static_cast<Value *>(I));
4169    if (It != ValueExprMap.end()) {
4170      forgetMemoizedResults(It->second);
4171      ValueExprMap.erase(It);
4172      if (PHINode *PN = dyn_cast<PHINode>(I))
4173        ConstantEvolutionLoopExitValue.erase(PN);
4174    }
4175
4176    PushDefUseChildren(I, Worklist);
4177  }
4178}
4179
4180/// getExact - Get the exact loop backedge taken count considering all loop
4181/// exits. A computable result can only be return for loops with a single exit.
4182/// Returning the minimum taken count among all exits is incorrect because one
4183/// of the loop's exit limit's may have been skipped. HowFarToZero assumes that
4184/// the limit of each loop test is never skipped. This is a valid assumption as
4185/// long as the loop exits via that test. For precise results, it is the
4186/// caller's responsibility to specify the relevant loop exit using
4187/// getExact(ExitingBlock, SE).
4188const SCEV *
4189ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
4190  // If any exits were not computable, the loop is not computable.
4191  if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute();
4192
4193  // We need exactly one computable exit.
4194  if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute();
4195  assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info");
4196
4197  const SCEV *BECount = 0;
4198  for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4199       ENT != 0; ENT = ENT->getNextExit()) {
4200
4201    assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
4202
4203    if (!BECount)
4204      BECount = ENT->ExactNotTaken;
4205    else if (BECount != ENT->ExactNotTaken)
4206      return SE->getCouldNotCompute();
4207  }
4208  assert(BECount && "Invalid not taken count for loop exit");
4209  return BECount;
4210}
4211
4212/// getExact - Get the exact not taken count for this loop exit.
4213const SCEV *
4214ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
4215                                             ScalarEvolution *SE) const {
4216  for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4217       ENT != 0; ENT = ENT->getNextExit()) {
4218
4219    if (ENT->ExitingBlock == ExitingBlock)
4220      return ENT->ExactNotTaken;
4221  }
4222  return SE->getCouldNotCompute();
4223}
4224
4225/// getMax - Get the max backedge taken count for the loop.
4226const SCEV *
4227ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
4228  return Max ? Max : SE->getCouldNotCompute();
4229}
4230
4231/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
4232/// computable exit into a persistent ExitNotTakenInfo array.
4233ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
4234  SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts,
4235  bool Complete, const SCEV *MaxCount) : Max(MaxCount) {
4236
4237  if (!Complete)
4238    ExitNotTaken.setIncomplete();
4239
4240  unsigned NumExits = ExitCounts.size();
4241  if (NumExits == 0) return;
4242
4243  ExitNotTaken.ExitingBlock = ExitCounts[0].first;
4244  ExitNotTaken.ExactNotTaken = ExitCounts[0].second;
4245  if (NumExits == 1) return;
4246
4247  // Handle the rare case of multiple computable exits.
4248  ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1];
4249
4250  ExitNotTakenInfo *PrevENT = &ExitNotTaken;
4251  for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) {
4252    PrevENT->setNextExit(ENT);
4253    ENT->ExitingBlock = ExitCounts[i].first;
4254    ENT->ExactNotTaken = ExitCounts[i].second;
4255  }
4256}
4257
4258/// clear - Invalidate this result and free the ExitNotTakenInfo array.
4259void ScalarEvolution::BackedgeTakenInfo::clear() {
4260  ExitNotTaken.ExitingBlock = 0;
4261  ExitNotTaken.ExactNotTaken = 0;
4262  delete[] ExitNotTaken.getNextExit();
4263}
4264
4265/// ComputeBackedgeTakenCount - Compute the number of times the backedge
4266/// of the specified loop will execute.
4267ScalarEvolution::BackedgeTakenInfo
4268ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
4269  SmallVector<BasicBlock *, 8> ExitingBlocks;
4270  L->getExitingBlocks(ExitingBlocks);
4271
4272  // Examine all exits and pick the most conservative values.
4273  const SCEV *MaxBECount = getCouldNotCompute();
4274  bool CouldComputeBECount = true;
4275  SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts;
4276  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
4277    ExitLimit EL = ComputeExitLimit(L, ExitingBlocks[i]);
4278    if (EL.Exact == getCouldNotCompute())
4279      // We couldn't compute an exact value for this exit, so
4280      // we won't be able to compute an exact value for the loop.
4281      CouldComputeBECount = false;
4282    else
4283      ExitCounts.push_back(std::make_pair(ExitingBlocks[i], EL.Exact));
4284
4285    if (MaxBECount == getCouldNotCompute())
4286      MaxBECount = EL.Max;
4287    else if (EL.Max != getCouldNotCompute()) {
4288      // We cannot take the "min" MaxBECount, because non-unit stride loops may
4289      // skip some loop tests. Taking the max over the exits is sufficiently
4290      // conservative.  TODO: We could do better taking into consideration
4291      // that (1) the loop has unit stride (2) the last loop test is
4292      // less-than/greater-than (3) any loop test is less-than/greater-than AND
4293      // falls-through some constant times less then the other tests.
4294      MaxBECount = getUMaxFromMismatchedTypes(MaxBECount, EL.Max);
4295    }
4296  }
4297
4298  return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount);
4299}
4300
4301/// ComputeExitLimit - Compute the number of times the backedge of the specified
4302/// loop will execute if it exits via the specified block.
4303ScalarEvolution::ExitLimit
4304ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
4305
4306  // Okay, we've chosen an exiting block.  See what condition causes us to
4307  // exit at this block.
4308  //
4309  // FIXME: we should be able to handle switch instructions (with a single exit)
4310  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
4311  if (ExitBr == 0) return getCouldNotCompute();
4312  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
4313
4314  // At this point, we know we have a conditional branch that determines whether
4315  // the loop is exited.  However, we don't know if the branch is executed each
4316  // time through the loop.  If not, then the execution count of the branch will
4317  // not be equal to the trip count of the loop.
4318  //
4319  // Currently we check for this by checking to see if the Exit branch goes to
4320  // the loop header.  If so, we know it will always execute the same number of
4321  // times as the loop.  We also handle the case where the exit block *is* the
4322  // loop header.  This is common for un-rotated loops.
4323  //
4324  // If both of those tests fail, walk up the unique predecessor chain to the
4325  // header, stopping if there is an edge that doesn't exit the loop. If the
4326  // header is reached, the execution count of the branch will be equal to the
4327  // trip count of the loop.
4328  //
4329  //  More extensive analysis could be done to handle more cases here.
4330  //
4331  if (ExitBr->getSuccessor(0) != L->getHeader() &&
4332      ExitBr->getSuccessor(1) != L->getHeader() &&
4333      ExitBr->getParent() != L->getHeader()) {
4334    // The simple checks failed, try climbing the unique predecessor chain
4335    // up to the header.
4336    bool Ok = false;
4337    for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
4338      BasicBlock *Pred = BB->getUniquePredecessor();
4339      if (!Pred)
4340        return getCouldNotCompute();
4341      TerminatorInst *PredTerm = Pred->getTerminator();
4342      for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
4343        BasicBlock *PredSucc = PredTerm->getSuccessor(i);
4344        if (PredSucc == BB)
4345          continue;
4346        // If the predecessor has a successor that isn't BB and isn't
4347        // outside the loop, assume the worst.
4348        if (L->contains(PredSucc))
4349          return getCouldNotCompute();
4350      }
4351      if (Pred == L->getHeader()) {
4352        Ok = true;
4353        break;
4354      }
4355      BB = Pred;
4356    }
4357    if (!Ok)
4358      return getCouldNotCompute();
4359  }
4360
4361  // Proceed to the next level to examine the exit condition expression.
4362  return ComputeExitLimitFromCond(L, ExitBr->getCondition(),
4363                                  ExitBr->getSuccessor(0),
4364                                  ExitBr->getSuccessor(1));
4365}
4366
4367/// ComputeExitLimitFromCond - Compute the number of times the
4368/// backedge of the specified loop will execute if its exit condition
4369/// were a conditional branch of ExitCond, TBB, and FBB.
4370ScalarEvolution::ExitLimit
4371ScalarEvolution::ComputeExitLimitFromCond(const Loop *L,
4372                                          Value *ExitCond,
4373                                          BasicBlock *TBB,
4374                                          BasicBlock *FBB) {
4375  // Check if the controlling expression for this loop is an And or Or.
4376  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
4377    if (BO->getOpcode() == Instruction::And) {
4378      // Recurse on the operands of the and.
4379      ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB);
4380      ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB);
4381      const SCEV *BECount = getCouldNotCompute();
4382      const SCEV *MaxBECount = getCouldNotCompute();
4383      if (L->contains(TBB)) {
4384        // Both conditions must be true for the loop to continue executing.
4385        // Choose the less conservative count.
4386        if (EL0.Exact == getCouldNotCompute() ||
4387            EL1.Exact == getCouldNotCompute())
4388          BECount = getCouldNotCompute();
4389        else
4390          BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
4391        if (EL0.Max == getCouldNotCompute())
4392          MaxBECount = EL1.Max;
4393        else if (EL1.Max == getCouldNotCompute())
4394          MaxBECount = EL0.Max;
4395        else
4396          MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
4397      } else {
4398        // Both conditions must be true at the same time for the loop to exit.
4399        // For now, be conservative.
4400        assert(L->contains(FBB) && "Loop block has no successor in loop!");
4401        if (EL0.Max == EL1.Max)
4402          MaxBECount = EL0.Max;
4403        if (EL0.Exact == EL1.Exact)
4404          BECount = EL0.Exact;
4405      }
4406
4407      return ExitLimit(BECount, MaxBECount);
4408    }
4409    if (BO->getOpcode() == Instruction::Or) {
4410      // Recurse on the operands of the or.
4411      ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB);
4412      ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB);
4413      const SCEV *BECount = getCouldNotCompute();
4414      const SCEV *MaxBECount = getCouldNotCompute();
4415      if (L->contains(FBB)) {
4416        // Both conditions must be false for the loop to continue executing.
4417        // Choose the less conservative count.
4418        if (EL0.Exact == getCouldNotCompute() ||
4419            EL1.Exact == getCouldNotCompute())
4420          BECount = getCouldNotCompute();
4421        else
4422          BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
4423        if (EL0.Max == getCouldNotCompute())
4424          MaxBECount = EL1.Max;
4425        else if (EL1.Max == getCouldNotCompute())
4426          MaxBECount = EL0.Max;
4427        else
4428          MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
4429      } else {
4430        // Both conditions must be false at the same time for the loop to exit.
4431        // For now, be conservative.
4432        assert(L->contains(TBB) && "Loop block has no successor in loop!");
4433        if (EL0.Max == EL1.Max)
4434          MaxBECount = EL0.Max;
4435        if (EL0.Exact == EL1.Exact)
4436          BECount = EL0.Exact;
4437      }
4438
4439      return ExitLimit(BECount, MaxBECount);
4440    }
4441  }
4442
4443  // With an icmp, it may be feasible to compute an exact backedge-taken count.
4444  // Proceed to the next level to examine the icmp.
4445  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
4446    return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB);
4447
4448  // Check for a constant condition. These are normally stripped out by
4449  // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
4450  // preserve the CFG and is temporarily leaving constant conditions
4451  // in place.
4452  if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
4453    if (L->contains(FBB) == !CI->getZExtValue())
4454      // The backedge is always taken.
4455      return getCouldNotCompute();
4456    else
4457      // The backedge is never taken.
4458      return getConstant(CI->getType(), 0);
4459  }
4460
4461  // If it's not an integer or pointer comparison then compute it the hard way.
4462  return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
4463}
4464
4465/// ComputeExitLimitFromICmp - Compute the number of times the
4466/// backedge of the specified loop will execute if its exit condition
4467/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
4468ScalarEvolution::ExitLimit
4469ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L,
4470                                          ICmpInst *ExitCond,
4471                                          BasicBlock *TBB,
4472                                          BasicBlock *FBB) {
4473
4474  // If the condition was exit on true, convert the condition to exit on false
4475  ICmpInst::Predicate Cond;
4476  if (!L->contains(FBB))
4477    Cond = ExitCond->getPredicate();
4478  else
4479    Cond = ExitCond->getInversePredicate();
4480
4481  // Handle common loops like: for (X = "string"; *X; ++X)
4482  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
4483    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
4484      ExitLimit ItCnt =
4485        ComputeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
4486      if (ItCnt.hasAnyInfo())
4487        return ItCnt;
4488    }
4489
4490  const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
4491  const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
4492
4493  // Try to evaluate any dependencies out of the loop.
4494  LHS = getSCEVAtScope(LHS, L);
4495  RHS = getSCEVAtScope(RHS, L);
4496
4497  // At this point, we would like to compute how many iterations of the
4498  // loop the predicate will return true for these inputs.
4499  if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
4500    // If there is a loop-invariant, force it into the RHS.
4501    std::swap(LHS, RHS);
4502    Cond = ICmpInst::getSwappedPredicate(Cond);
4503  }
4504
4505  // Simplify the operands before analyzing them.
4506  (void)SimplifyICmpOperands(Cond, LHS, RHS);
4507
4508  // If we have a comparison of a chrec against a constant, try to use value
4509  // ranges to answer this query.
4510  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
4511    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
4512      if (AddRec->getLoop() == L) {
4513        // Form the constant range.
4514        ConstantRange CompRange(
4515            ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
4516
4517        const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
4518        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
4519      }
4520
4521  switch (Cond) {
4522  case ICmpInst::ICMP_NE: {                     // while (X != Y)
4523    // Convert to: while (X-Y != 0)
4524    ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L);
4525    if (EL.hasAnyInfo()) return EL;
4526    break;
4527  }
4528  case ICmpInst::ICMP_EQ: {                     // while (X == Y)
4529    // Convert to: while (X-Y == 0)
4530    ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
4531    if (EL.hasAnyInfo()) return EL;
4532    break;
4533  }
4534  case ICmpInst::ICMP_SLT: {
4535    ExitLimit EL = HowManyLessThans(LHS, RHS, L, true);
4536    if (EL.hasAnyInfo()) return EL;
4537    break;
4538  }
4539  case ICmpInst::ICMP_SGT: {
4540    ExitLimit EL = HowManyLessThans(getNotSCEV(LHS),
4541                                             getNotSCEV(RHS), L, true);
4542    if (EL.hasAnyInfo()) return EL;
4543    break;
4544  }
4545  case ICmpInst::ICMP_ULT: {
4546    ExitLimit EL = HowManyLessThans(LHS, RHS, L, false);
4547    if (EL.hasAnyInfo()) return EL;
4548    break;
4549  }
4550  case ICmpInst::ICMP_UGT: {
4551    ExitLimit EL = HowManyLessThans(getNotSCEV(LHS),
4552                                             getNotSCEV(RHS), L, false);
4553    if (EL.hasAnyInfo()) return EL;
4554    break;
4555  }
4556  default:
4557#if 0
4558    dbgs() << "ComputeBackedgeTakenCount ";
4559    if (ExitCond->getOperand(0)->getType()->isUnsigned())
4560      dbgs() << "[unsigned] ";
4561    dbgs() << *LHS << "   "
4562         << Instruction::getOpcodeName(Instruction::ICmp)
4563         << "   " << *RHS << "\n";
4564#endif
4565    break;
4566  }
4567  return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
4568}
4569
4570static ConstantInt *
4571EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
4572                                ScalarEvolution &SE) {
4573  const SCEV *InVal = SE.getConstant(C);
4574  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
4575  assert(isa<SCEVConstant>(Val) &&
4576         "Evaluation of SCEV at constant didn't fold correctly?");
4577  return cast<SCEVConstant>(Val)->getValue();
4578}
4579
4580/// ComputeLoadConstantCompareExitLimit - Given an exit condition of
4581/// 'icmp op load X, cst', try to see if we can compute the backedge
4582/// execution count.
4583ScalarEvolution::ExitLimit
4584ScalarEvolution::ComputeLoadConstantCompareExitLimit(
4585  LoadInst *LI,
4586  Constant *RHS,
4587  const Loop *L,
4588  ICmpInst::Predicate predicate) {
4589
4590  if (LI->isVolatile()) return getCouldNotCompute();
4591
4592  // Check to see if the loaded pointer is a getelementptr of a global.
4593  // TODO: Use SCEV instead of manually grubbing with GEPs.
4594  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
4595  if (!GEP) return getCouldNotCompute();
4596
4597  // Make sure that it is really a constant global we are gepping, with an
4598  // initializer, and make sure the first IDX is really 0.
4599  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
4600  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
4601      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
4602      !cast<Constant>(GEP->getOperand(1))->isNullValue())
4603    return getCouldNotCompute();
4604
4605  // Okay, we allow one non-constant index into the GEP instruction.
4606  Value *VarIdx = 0;
4607  std::vector<Constant*> Indexes;
4608  unsigned VarIdxNum = 0;
4609  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
4610    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
4611      Indexes.push_back(CI);
4612    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
4613      if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
4614      VarIdx = GEP->getOperand(i);
4615      VarIdxNum = i-2;
4616      Indexes.push_back(0);
4617    }
4618
4619  // Loop-invariant loads may be a byproduct of loop optimization. Skip them.
4620  if (!VarIdx)
4621    return getCouldNotCompute();
4622
4623  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
4624  // Check to see if X is a loop variant variable value now.
4625  const SCEV *Idx = getSCEV(VarIdx);
4626  Idx = getSCEVAtScope(Idx, L);
4627
4628  // We can only recognize very limited forms of loop index expressions, in
4629  // particular, only affine AddRec's like {C1,+,C2}.
4630  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
4631  if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
4632      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
4633      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
4634    return getCouldNotCompute();
4635
4636  unsigned MaxSteps = MaxBruteForceIterations;
4637  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
4638    ConstantInt *ItCst = ConstantInt::get(
4639                           cast<IntegerType>(IdxExpr->getType()), IterationNum);
4640    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
4641
4642    // Form the GEP offset.
4643    Indexes[VarIdxNum] = Val;
4644
4645    Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
4646                                                         Indexes);
4647    if (Result == 0) break;  // Cannot compute!
4648
4649    // Evaluate the condition for this iteration.
4650    Result = ConstantExpr::getICmp(predicate, Result, RHS);
4651    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
4652    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
4653#if 0
4654      dbgs() << "\n***\n*** Computed loop count " << *ItCst
4655             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
4656             << "***\n";
4657#endif
4658      ++NumArrayLenItCounts;
4659      return getConstant(ItCst);   // Found terminating iteration!
4660    }
4661  }
4662  return getCouldNotCompute();
4663}
4664
4665
4666/// CanConstantFold - Return true if we can constant fold an instruction of the
4667/// specified type, assuming that all operands were constants.
4668static bool CanConstantFold(const Instruction *I) {
4669  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
4670      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
4671      isa<LoadInst>(I))
4672    return true;
4673
4674  if (const CallInst *CI = dyn_cast<CallInst>(I))
4675    if (const Function *F = CI->getCalledFunction())
4676      return canConstantFoldCallTo(F);
4677  return false;
4678}
4679
4680/// Determine whether this instruction can constant evolve within this loop
4681/// assuming its operands can all constant evolve.
4682static bool canConstantEvolve(Instruction *I, const Loop *L) {
4683  // An instruction outside of the loop can't be derived from a loop PHI.
4684  if (!L->contains(I)) return false;
4685
4686  if (isa<PHINode>(I)) {
4687    if (L->getHeader() == I->getParent())
4688      return true;
4689    else
4690      // We don't currently keep track of the control flow needed to evaluate
4691      // PHIs, so we cannot handle PHIs inside of loops.
4692      return false;
4693  }
4694
4695  // If we won't be able to constant fold this expression even if the operands
4696  // are constants, bail early.
4697  return CanConstantFold(I);
4698}
4699
4700/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
4701/// recursing through each instruction operand until reaching a loop header phi.
4702static PHINode *
4703getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
4704                               DenseMap<Instruction *, PHINode *> &PHIMap) {
4705
4706  // Otherwise, we can evaluate this instruction if all of its operands are
4707  // constant or derived from a PHI node themselves.
4708  PHINode *PHI = 0;
4709  for (Instruction::op_iterator OpI = UseInst->op_begin(),
4710         OpE = UseInst->op_end(); OpI != OpE; ++OpI) {
4711
4712    if (isa<Constant>(*OpI)) continue;
4713
4714    Instruction *OpInst = dyn_cast<Instruction>(*OpI);
4715    if (!OpInst || !canConstantEvolve(OpInst, L)) return 0;
4716
4717    PHINode *P = dyn_cast<PHINode>(OpInst);
4718    if (!P)
4719      // If this operand is already visited, reuse the prior result.
4720      // We may have P != PHI if this is the deepest point at which the
4721      // inconsistent paths meet.
4722      P = PHIMap.lookup(OpInst);
4723    if (!P) {
4724      // Recurse and memoize the results, whether a phi is found or not.
4725      // This recursive call invalidates pointers into PHIMap.
4726      P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap);
4727      PHIMap[OpInst] = P;
4728    }
4729    if (P == 0) return 0;        // Not evolving from PHI
4730    if (PHI && PHI != P) return 0;  // Evolving from multiple different PHIs.
4731    PHI = P;
4732  }
4733  // This is a expression evolving from a constant PHI!
4734  return PHI;
4735}
4736
4737/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
4738/// in the loop that V is derived from.  We allow arbitrary operations along the
4739/// way, but the operands of an operation must either be constants or a value
4740/// derived from a constant PHI.  If this expression does not fit with these
4741/// constraints, return null.
4742static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
4743  Instruction *I = dyn_cast<Instruction>(V);
4744  if (I == 0 || !canConstantEvolve(I, L)) return 0;
4745
4746  if (PHINode *PN = dyn_cast<PHINode>(I)) {
4747    return PN;
4748  }
4749
4750  // Record non-constant instructions contained by the loop.
4751  DenseMap<Instruction *, PHINode *> PHIMap;
4752  return getConstantEvolvingPHIOperands(I, L, PHIMap);
4753}
4754
4755/// EvaluateExpression - Given an expression that passes the
4756/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
4757/// in the loop has the value PHIVal.  If we can't fold this expression for some
4758/// reason, return null.
4759static Constant *EvaluateExpression(Value *V, const Loop *L,
4760                                    DenseMap<Instruction *, Constant *> &Vals,
4761                                    const TargetData *TD,
4762                                    const TargetLibraryInfo *TLI) {
4763  // Convenient constant check, but redundant for recursive calls.
4764  if (Constant *C = dyn_cast<Constant>(V)) return C;
4765  Instruction *I = dyn_cast<Instruction>(V);
4766  if (!I) return 0;
4767
4768  if (Constant *C = Vals.lookup(I)) return C;
4769
4770  // An instruction inside the loop depends on a value outside the loop that we
4771  // weren't given a mapping for, or a value such as a call inside the loop.
4772  if (!canConstantEvolve(I, L)) return 0;
4773
4774  // An unmapped PHI can be due to a branch or another loop inside this loop,
4775  // or due to this not being the initial iteration through a loop where we
4776  // couldn't compute the evolution of this particular PHI last time.
4777  if (isa<PHINode>(I)) return 0;
4778
4779  std::vector<Constant*> Operands(I->getNumOperands());
4780
4781  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4782    Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
4783    if (!Operand) {
4784      Operands[i] = dyn_cast<Constant>(I->getOperand(i));
4785      if (!Operands[i]) return 0;
4786      continue;
4787    }
4788    Constant *C = EvaluateExpression(Operand, L, Vals, TD, TLI);
4789    Vals[Operand] = C;
4790    if (!C) return 0;
4791    Operands[i] = C;
4792  }
4793
4794  if (CmpInst *CI = dyn_cast<CmpInst>(I))
4795    return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
4796                                           Operands[1], TD, TLI);
4797  if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
4798    if (!LI->isVolatile())
4799      return ConstantFoldLoadFromConstPtr(Operands[0], TD);
4800  }
4801  return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, TD,
4802                                  TLI);
4803}
4804
4805/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
4806/// in the header of its containing loop, we know the loop executes a
4807/// constant number of times, and the PHI node is just a recurrence
4808/// involving constants, fold it.
4809Constant *
4810ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
4811                                                   const APInt &BEs,
4812                                                   const Loop *L) {
4813  DenseMap<PHINode*, Constant*>::const_iterator I =
4814    ConstantEvolutionLoopExitValue.find(PN);
4815  if (I != ConstantEvolutionLoopExitValue.end())
4816    return I->second;
4817
4818  if (BEs.ugt(MaxBruteForceIterations))
4819    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
4820
4821  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
4822
4823  DenseMap<Instruction *, Constant *> CurrentIterVals;
4824  BasicBlock *Header = L->getHeader();
4825  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
4826
4827  // Since the loop is canonicalized, the PHI node must have two entries.  One
4828  // entry must be a constant (coming in from outside of the loop), and the
4829  // second must be derived from the same PHI.
4830  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4831  PHINode *PHI = 0;
4832  for (BasicBlock::iterator I = Header->begin();
4833       (PHI = dyn_cast<PHINode>(I)); ++I) {
4834    Constant *StartCST =
4835      dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
4836    if (StartCST == 0) continue;
4837    CurrentIterVals[PHI] = StartCST;
4838  }
4839  if (!CurrentIterVals.count(PN))
4840    return RetVal = 0;
4841
4842  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4843
4844  // Execute the loop symbolically to determine the exit value.
4845  if (BEs.getActiveBits() >= 32)
4846    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
4847
4848  unsigned NumIterations = BEs.getZExtValue(); // must be in range
4849  unsigned IterationNum = 0;
4850  for (; ; ++IterationNum) {
4851    if (IterationNum == NumIterations)
4852      return RetVal = CurrentIterVals[PN];  // Got exit value!
4853
4854    // Compute the value of the PHIs for the next iteration.
4855    // EvaluateExpression adds non-phi values to the CurrentIterVals map.
4856    DenseMap<Instruction *, Constant *> NextIterVals;
4857    Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD,
4858                                           TLI);
4859    if (NextPHI == 0)
4860      return 0;        // Couldn't evaluate!
4861    NextIterVals[PN] = NextPHI;
4862
4863    bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
4864
4865    // Also evaluate the other PHI nodes.  However, we don't get to stop if we
4866    // cease to be able to evaluate one of them or if they stop evolving,
4867    // because that doesn't necessarily prevent us from computing PN.
4868    SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
4869    for (DenseMap<Instruction *, Constant *>::const_iterator
4870           I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
4871      PHINode *PHI = dyn_cast<PHINode>(I->first);
4872      if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
4873      PHIsToCompute.push_back(std::make_pair(PHI, I->second));
4874    }
4875    // We use two distinct loops because EvaluateExpression may invalidate any
4876    // iterators into CurrentIterVals.
4877    for (SmallVectorImpl<std::pair<PHINode *, Constant*> >::const_iterator
4878             I = PHIsToCompute.begin(), E = PHIsToCompute.end(); I != E; ++I) {
4879      PHINode *PHI = I->first;
4880      Constant *&NextPHI = NextIterVals[PHI];
4881      if (!NextPHI) {   // Not already computed.
4882        Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
4883        NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD, TLI);
4884      }
4885      if (NextPHI != I->second)
4886        StoppedEvolving = false;
4887    }
4888
4889    // If all entries in CurrentIterVals == NextIterVals then we can stop
4890    // iterating, the loop can't continue to change.
4891    if (StoppedEvolving)
4892      return RetVal = CurrentIterVals[PN];
4893
4894    CurrentIterVals.swap(NextIterVals);
4895  }
4896}
4897
4898/// ComputeExitCountExhaustively - If the loop is known to execute a
4899/// constant number of times (the condition evolves only from constants),
4900/// try to evaluate a few iterations of the loop until we get the exit
4901/// condition gets a value of ExitWhen (true or false).  If we cannot
4902/// evaluate the trip count of the loop, return getCouldNotCompute().
4903const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L,
4904                                                          Value *Cond,
4905                                                          bool ExitWhen) {
4906  PHINode *PN = getConstantEvolvingPHI(Cond, L);
4907  if (PN == 0) return getCouldNotCompute();
4908
4909  // If the loop is canonicalized, the PHI will have exactly two entries.
4910  // That's the only form we support here.
4911  if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
4912
4913  DenseMap<Instruction *, Constant *> CurrentIterVals;
4914  BasicBlock *Header = L->getHeader();
4915  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
4916
4917  // One entry must be a constant (coming in from outside of the loop), and the
4918  // second must be derived from the same PHI.
4919  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4920  PHINode *PHI = 0;
4921  for (BasicBlock::iterator I = Header->begin();
4922       (PHI = dyn_cast<PHINode>(I)); ++I) {
4923    Constant *StartCST =
4924      dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
4925    if (StartCST == 0) continue;
4926    CurrentIterVals[PHI] = StartCST;
4927  }
4928  if (!CurrentIterVals.count(PN))
4929    return getCouldNotCompute();
4930
4931  // Okay, we find a PHI node that defines the trip count of this loop.  Execute
4932  // the loop symbolically to determine when the condition gets a value of
4933  // "ExitWhen".
4934
4935  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
4936  for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
4937    ConstantInt *CondVal =
4938      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, L, CurrentIterVals,
4939                                                       TD, TLI));
4940
4941    // Couldn't symbolically evaluate.
4942    if (!CondVal) return getCouldNotCompute();
4943
4944    if (CondVal->getValue() == uint64_t(ExitWhen)) {
4945      ++NumBruteForceTripCountsComputed;
4946      return getConstant(Type::getInt32Ty(getContext()), IterationNum);
4947    }
4948
4949    // Update all the PHI nodes for the next iteration.
4950    DenseMap<Instruction *, Constant *> NextIterVals;
4951
4952    // Create a list of which PHIs we need to compute. We want to do this before
4953    // calling EvaluateExpression on them because that may invalidate iterators
4954    // into CurrentIterVals.
4955    SmallVector<PHINode *, 8> PHIsToCompute;
4956    for (DenseMap<Instruction *, Constant *>::const_iterator
4957           I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
4958      PHINode *PHI = dyn_cast<PHINode>(I->first);
4959      if (!PHI || PHI->getParent() != Header) continue;
4960      PHIsToCompute.push_back(PHI);
4961    }
4962    for (SmallVectorImpl<PHINode *>::const_iterator I = PHIsToCompute.begin(),
4963             E = PHIsToCompute.end(); I != E; ++I) {
4964      PHINode *PHI = *I;
4965      Constant *&NextPHI = NextIterVals[PHI];
4966      if (NextPHI) continue;    // Already computed!
4967
4968      Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
4969      NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD, TLI);
4970    }
4971    CurrentIterVals.swap(NextIterVals);
4972  }
4973
4974  // Too many iterations were needed to evaluate.
4975  return getCouldNotCompute();
4976}
4977
4978/// getSCEVAtScope - Return a SCEV expression for the specified value
4979/// at the specified scope in the program.  The L value specifies a loop
4980/// nest to evaluate the expression at, where null is the top-level or a
4981/// specified loop is immediately inside of the loop.
4982///
4983/// This method can be used to compute the exit value for a variable defined
4984/// in a loop by querying what the value will hold in the parent loop.
4985///
4986/// In the case that a relevant loop exit value cannot be computed, the
4987/// original value V is returned.
4988const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
4989  // Check to see if we've folded this expression at this loop before.
4990  std::map<const Loop *, const SCEV *> &Values = ValuesAtScopes[V];
4991  std::pair<std::map<const Loop *, const SCEV *>::iterator, bool> Pair =
4992    Values.insert(std::make_pair(L, static_cast<const SCEV *>(0)));
4993  if (!Pair.second)
4994    return Pair.first->second ? Pair.first->second : V;
4995
4996  // Otherwise compute it.
4997  const SCEV *C = computeSCEVAtScope(V, L);
4998  ValuesAtScopes[V][L] = C;
4999  return C;
5000}
5001
5002/// This builds up a Constant using the ConstantExpr interface.  That way, we
5003/// will return Constants for objects which aren't represented by a
5004/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
5005/// Returns NULL if the SCEV isn't representable as a Constant.
5006static Constant *BuildConstantFromSCEV(const SCEV *V) {
5007  switch (V->getSCEVType()) {
5008    default:  // TODO: smax, umax.
5009    case scCouldNotCompute:
5010    case scAddRecExpr:
5011      break;
5012    case scConstant:
5013      return cast<SCEVConstant>(V)->getValue();
5014    case scUnknown:
5015      return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
5016    case scSignExtend: {
5017      const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
5018      if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
5019        return ConstantExpr::getSExt(CastOp, SS->getType());
5020      break;
5021    }
5022    case scZeroExtend: {
5023      const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
5024      if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
5025        return ConstantExpr::getZExt(CastOp, SZ->getType());
5026      break;
5027    }
5028    case scTruncate: {
5029      const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
5030      if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
5031        return ConstantExpr::getTrunc(CastOp, ST->getType());
5032      break;
5033    }
5034    case scAddExpr: {
5035      const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
5036      if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
5037        if (C->getType()->isPointerTy())
5038          C = ConstantExpr::getBitCast(C, Type::getInt8PtrTy(C->getContext()));
5039        for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
5040          Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
5041          if (!C2) return 0;
5042
5043          // First pointer!
5044          if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
5045            std::swap(C, C2);
5046            // The offsets have been converted to bytes.  We can add bytes to an
5047            // i8* by GEP with the byte count in the first index.
5048            C = ConstantExpr::getBitCast(C,Type::getInt8PtrTy(C->getContext()));
5049          }
5050
5051          // Don't bother trying to sum two pointers. We probably can't
5052          // statically compute a load that results from it anyway.
5053          if (C2->getType()->isPointerTy())
5054            return 0;
5055
5056          if (C->getType()->isPointerTy()) {
5057            if (cast<PointerType>(C->getType())->getElementType()->isStructTy())
5058              C2 = ConstantExpr::getIntegerCast(
5059                  C2, Type::getInt32Ty(C->getContext()), true);
5060            C = ConstantExpr::getGetElementPtr(C, C2);
5061          } else
5062            C = ConstantExpr::getAdd(C, C2);
5063        }
5064        return C;
5065      }
5066      break;
5067    }
5068    case scMulExpr: {
5069      const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
5070      if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
5071        // Don't bother with pointers at all.
5072        if (C->getType()->isPointerTy()) return 0;
5073        for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
5074          Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
5075          if (!C2 || C2->getType()->isPointerTy()) return 0;
5076          C = ConstantExpr::getMul(C, C2);
5077        }
5078        return C;
5079      }
5080      break;
5081    }
5082    case scUDivExpr: {
5083      const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
5084      if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
5085        if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
5086          if (LHS->getType() == RHS->getType())
5087            return ConstantExpr::getUDiv(LHS, RHS);
5088      break;
5089    }
5090  }
5091  return 0;
5092}
5093
5094const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
5095  if (isa<SCEVConstant>(V)) return V;
5096
5097  // If this instruction is evolved from a constant-evolving PHI, compute the
5098  // exit value from the loop without using SCEVs.
5099  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
5100    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
5101      const Loop *LI = (*this->LI)[I->getParent()];
5102      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
5103        if (PHINode *PN = dyn_cast<PHINode>(I))
5104          if (PN->getParent() == LI->getHeader()) {
5105            // Okay, there is no closed form solution for the PHI node.  Check
5106            // to see if the loop that contains it has a known backedge-taken
5107            // count.  If so, we may be able to force computation of the exit
5108            // value.
5109            const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
5110            if (const SCEVConstant *BTCC =
5111                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
5112              // Okay, we know how many times the containing loop executes.  If
5113              // this is a constant evolving PHI node, get the final value at
5114              // the specified iteration number.
5115              Constant *RV = getConstantEvolutionLoopExitValue(PN,
5116                                                   BTCC->getValue()->getValue(),
5117                                                               LI);
5118              if (RV) return getSCEV(RV);
5119            }
5120          }
5121
5122      // Okay, this is an expression that we cannot symbolically evaluate
5123      // into a SCEV.  Check to see if it's possible to symbolically evaluate
5124      // the arguments into constants, and if so, try to constant propagate the
5125      // result.  This is particularly useful for computing loop exit values.
5126      if (CanConstantFold(I)) {
5127        SmallVector<Constant *, 4> Operands;
5128        bool MadeImprovement = false;
5129        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
5130          Value *Op = I->getOperand(i);
5131          if (Constant *C = dyn_cast<Constant>(Op)) {
5132            Operands.push_back(C);
5133            continue;
5134          }
5135
5136          // If any of the operands is non-constant and if they are
5137          // non-integer and non-pointer, don't even try to analyze them
5138          // with scev techniques.
5139          if (!isSCEVable(Op->getType()))
5140            return V;
5141
5142          const SCEV *OrigV = getSCEV(Op);
5143          const SCEV *OpV = getSCEVAtScope(OrigV, L);
5144          MadeImprovement |= OrigV != OpV;
5145
5146          Constant *C = BuildConstantFromSCEV(OpV);
5147          if (!C) return V;
5148          if (C->getType() != Op->getType())
5149            C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
5150                                                              Op->getType(),
5151                                                              false),
5152                                      C, Op->getType());
5153          Operands.push_back(C);
5154        }
5155
5156        // Check to see if getSCEVAtScope actually made an improvement.
5157        if (MadeImprovement) {
5158          Constant *C = 0;
5159          if (const CmpInst *CI = dyn_cast<CmpInst>(I))
5160            C = ConstantFoldCompareInstOperands(CI->getPredicate(),
5161                                                Operands[0], Operands[1], TD,
5162                                                TLI);
5163          else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
5164            if (!LI->isVolatile())
5165              C = ConstantFoldLoadFromConstPtr(Operands[0], TD);
5166          } else
5167            C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
5168                                         Operands, TD, TLI);
5169          if (!C) return V;
5170          return getSCEV(C);
5171        }
5172      }
5173    }
5174
5175    // This is some other type of SCEVUnknown, just return it.
5176    return V;
5177  }
5178
5179  if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
5180    // Avoid performing the look-up in the common case where the specified
5181    // expression has no loop-variant portions.
5182    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
5183      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
5184      if (OpAtScope != Comm->getOperand(i)) {
5185        // Okay, at least one of these operands is loop variant but might be
5186        // foldable.  Build a new instance of the folded commutative expression.
5187        SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
5188                                            Comm->op_begin()+i);
5189        NewOps.push_back(OpAtScope);
5190
5191        for (++i; i != e; ++i) {
5192          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
5193          NewOps.push_back(OpAtScope);
5194        }
5195        if (isa<SCEVAddExpr>(Comm))
5196          return getAddExpr(NewOps);
5197        if (isa<SCEVMulExpr>(Comm))
5198          return getMulExpr(NewOps);
5199        if (isa<SCEVSMaxExpr>(Comm))
5200          return getSMaxExpr(NewOps);
5201        if (isa<SCEVUMaxExpr>(Comm))
5202          return getUMaxExpr(NewOps);
5203        llvm_unreachable("Unknown commutative SCEV type!");
5204      }
5205    }
5206    // If we got here, all operands are loop invariant.
5207    return Comm;
5208  }
5209
5210  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
5211    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
5212    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
5213    if (LHS == Div->getLHS() && RHS == Div->getRHS())
5214      return Div;   // must be loop invariant
5215    return getUDivExpr(LHS, RHS);
5216  }
5217
5218  // If this is a loop recurrence for a loop that does not contain L, then we
5219  // are dealing with the final value computed by the loop.
5220  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5221    // First, attempt to evaluate each operand.
5222    // Avoid performing the look-up in the common case where the specified
5223    // expression has no loop-variant portions.
5224    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
5225      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
5226      if (OpAtScope == AddRec->getOperand(i))
5227        continue;
5228
5229      // Okay, at least one of these operands is loop variant but might be
5230      // foldable.  Build a new instance of the folded commutative expression.
5231      SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
5232                                          AddRec->op_begin()+i);
5233      NewOps.push_back(OpAtScope);
5234      for (++i; i != e; ++i)
5235        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
5236
5237      const SCEV *FoldedRec =
5238        getAddRecExpr(NewOps, AddRec->getLoop(),
5239                      AddRec->getNoWrapFlags(SCEV::FlagNW));
5240      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
5241      // The addrec may be folded to a nonrecurrence, for example, if the
5242      // induction variable is multiplied by zero after constant folding. Go
5243      // ahead and return the folded value.
5244      if (!AddRec)
5245        return FoldedRec;
5246      break;
5247    }
5248
5249    // If the scope is outside the addrec's loop, evaluate it by using the
5250    // loop exit value of the addrec.
5251    if (!AddRec->getLoop()->contains(L)) {
5252      // To evaluate this recurrence, we need to know how many times the AddRec
5253      // loop iterates.  Compute this now.
5254      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
5255      if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
5256
5257      // Then, evaluate the AddRec.
5258      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
5259    }
5260
5261    return AddRec;
5262  }
5263
5264  if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
5265    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5266    if (Op == Cast->getOperand())
5267      return Cast;  // must be loop invariant
5268    return getZeroExtendExpr(Op, Cast->getType());
5269  }
5270
5271  if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
5272    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5273    if (Op == Cast->getOperand())
5274      return Cast;  // must be loop invariant
5275    return getSignExtendExpr(Op, Cast->getType());
5276  }
5277
5278  if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
5279    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5280    if (Op == Cast->getOperand())
5281      return Cast;  // must be loop invariant
5282    return getTruncateExpr(Op, Cast->getType());
5283  }
5284
5285  llvm_unreachable("Unknown SCEV type!");
5286}
5287
5288/// getSCEVAtScope - This is a convenience function which does
5289/// getSCEVAtScope(getSCEV(V), L).
5290const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
5291  return getSCEVAtScope(getSCEV(V), L);
5292}
5293
5294/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
5295/// following equation:
5296///
5297///     A * X = B (mod N)
5298///
5299/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
5300/// A and B isn't important.
5301///
5302/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
5303static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
5304                                               ScalarEvolution &SE) {
5305  uint32_t BW = A.getBitWidth();
5306  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
5307  assert(A != 0 && "A must be non-zero.");
5308
5309  // 1. D = gcd(A, N)
5310  //
5311  // The gcd of A and N may have only one prime factor: 2. The number of
5312  // trailing zeros in A is its multiplicity
5313  uint32_t Mult2 = A.countTrailingZeros();
5314  // D = 2^Mult2
5315
5316  // 2. Check if B is divisible by D.
5317  //
5318  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
5319  // is not less than multiplicity of this prime factor for D.
5320  if (B.countTrailingZeros() < Mult2)
5321    return SE.getCouldNotCompute();
5322
5323  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
5324  // modulo (N / D).
5325  //
5326  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
5327  // bit width during computations.
5328  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
5329  APInt Mod(BW + 1, 0);
5330  Mod.setBit(BW - Mult2);  // Mod = N / D
5331  APInt I = AD.multiplicativeInverse(Mod);
5332
5333  // 4. Compute the minimum unsigned root of the equation:
5334  // I * (B / D) mod (N / D)
5335  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
5336
5337  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
5338  // bits.
5339  return SE.getConstant(Result.trunc(BW));
5340}
5341
5342/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
5343/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
5344/// might be the same) or two SCEVCouldNotCompute objects.
5345///
5346static std::pair<const SCEV *,const SCEV *>
5347SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
5348  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
5349  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
5350  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
5351  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
5352
5353  // We currently can only solve this if the coefficients are constants.
5354  if (!LC || !MC || !NC) {
5355    const SCEV *CNC = SE.getCouldNotCompute();
5356    return std::make_pair(CNC, CNC);
5357  }
5358
5359  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
5360  const APInt &L = LC->getValue()->getValue();
5361  const APInt &M = MC->getValue()->getValue();
5362  const APInt &N = NC->getValue()->getValue();
5363  APInt Two(BitWidth, 2);
5364  APInt Four(BitWidth, 4);
5365
5366  {
5367    using namespace APIntOps;
5368    const APInt& C = L;
5369    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
5370    // The B coefficient is M-N/2
5371    APInt B(M);
5372    B -= sdiv(N,Two);
5373
5374    // The A coefficient is N/2
5375    APInt A(N.sdiv(Two));
5376
5377    // Compute the B^2-4ac term.
5378    APInt SqrtTerm(B);
5379    SqrtTerm *= B;
5380    SqrtTerm -= Four * (A * C);
5381
5382    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
5383    // integer value or else APInt::sqrt() will assert.
5384    APInt SqrtVal(SqrtTerm.sqrt());
5385
5386    // Compute the two solutions for the quadratic formula.
5387    // The divisions must be performed as signed divisions.
5388    APInt NegB(-B);
5389    APInt TwoA(A << 1);
5390    if (TwoA.isMinValue()) {
5391      const SCEV *CNC = SE.getCouldNotCompute();
5392      return std::make_pair(CNC, CNC);
5393    }
5394
5395    LLVMContext &Context = SE.getContext();
5396
5397    ConstantInt *Solution1 =
5398      ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
5399    ConstantInt *Solution2 =
5400      ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
5401
5402    return std::make_pair(SE.getConstant(Solution1),
5403                          SE.getConstant(Solution2));
5404  } // end APIntOps namespace
5405}
5406
5407/// HowFarToZero - Return the number of times a backedge comparing the specified
5408/// value to zero will execute.  If not computable, return CouldNotCompute.
5409///
5410/// This is only used for loops with a "x != y" exit test. The exit condition is
5411/// now expressed as a single expression, V = x-y. So the exit test is
5412/// effectively V != 0.  We know and take advantage of the fact that this
5413/// expression only being used in a comparison by zero context.
5414ScalarEvolution::ExitLimit
5415ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
5416  // If the value is a constant
5417  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
5418    // If the value is already zero, the branch will execute zero times.
5419    if (C->getValue()->isZero()) return C;
5420    return getCouldNotCompute();  // Otherwise it will loop infinitely.
5421  }
5422
5423  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
5424  if (!AddRec || AddRec->getLoop() != L)
5425    return getCouldNotCompute();
5426
5427  // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
5428  // the quadratic equation to solve it.
5429  if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
5430    std::pair<const SCEV *,const SCEV *> Roots =
5431      SolveQuadraticEquation(AddRec, *this);
5432    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
5433    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
5434    if (R1 && R2) {
5435#if 0
5436      dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
5437             << "  sol#2: " << *R2 << "\n";
5438#endif
5439      // Pick the smallest positive root value.
5440      if (ConstantInt *CB =
5441          dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
5442                                                      R1->getValue(),
5443                                                      R2->getValue()))) {
5444        if (CB->getZExtValue() == false)
5445          std::swap(R1, R2);   // R1 is the minimum root now.
5446
5447        // We can only use this value if the chrec ends up with an exact zero
5448        // value at this index.  When solving for "X*X != 5", for example, we
5449        // should not accept a root of 2.
5450        const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
5451        if (Val->isZero())
5452          return R1;  // We found a quadratic root!
5453      }
5454    }
5455    return getCouldNotCompute();
5456  }
5457
5458  // Otherwise we can only handle this if it is affine.
5459  if (!AddRec->isAffine())
5460    return getCouldNotCompute();
5461
5462  // If this is an affine expression, the execution count of this branch is
5463  // the minimum unsigned root of the following equation:
5464  //
5465  //     Start + Step*N = 0 (mod 2^BW)
5466  //
5467  // equivalent to:
5468  //
5469  //             Step*N = -Start (mod 2^BW)
5470  //
5471  // where BW is the common bit width of Start and Step.
5472
5473  // Get the initial value for the loop.
5474  const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
5475  const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
5476
5477  // For now we handle only constant steps.
5478  //
5479  // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
5480  // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
5481  // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
5482  // We have not yet seen any such cases.
5483  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
5484  if (StepC == 0)
5485    return getCouldNotCompute();
5486
5487  // For positive steps (counting up until unsigned overflow):
5488  //   N = -Start/Step (as unsigned)
5489  // For negative steps (counting down to zero):
5490  //   N = Start/-Step
5491  // First compute the unsigned distance from zero in the direction of Step.
5492  bool CountDown = StepC->getValue()->getValue().isNegative();
5493  const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
5494
5495  // Handle unitary steps, which cannot wraparound.
5496  // 1*N = -Start; -1*N = Start (mod 2^BW), so:
5497  //   N = Distance (as unsigned)
5498  if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) {
5499    ConstantRange CR = getUnsignedRange(Start);
5500    const SCEV *MaxBECount;
5501    if (!CountDown && CR.getUnsignedMin().isMinValue())
5502      // When counting up, the worst starting value is 1, not 0.
5503      MaxBECount = CR.getUnsignedMax().isMinValue()
5504        ? getConstant(APInt::getMinValue(CR.getBitWidth()))
5505        : getConstant(APInt::getMaxValue(CR.getBitWidth()));
5506    else
5507      MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
5508                                         : -CR.getUnsignedMin());
5509    return ExitLimit(Distance, MaxBECount);
5510  }
5511
5512  // If the recurrence is known not to wraparound, unsigned divide computes the
5513  // back edge count. We know that the value will either become zero (and thus
5514  // the loop terminates), that the loop will terminate through some other exit
5515  // condition first, or that the loop has undefined behavior.  This means
5516  // we can't "miss" the exit value, even with nonunit stride.
5517  //
5518  // FIXME: Prove that loops always exhibits *acceptable* undefined
5519  // behavior. Loops must exhibit defined behavior until a wrapped value is
5520  // actually used. So the trip count computed by udiv could be smaller than the
5521  // number of well-defined iterations.
5522  if (AddRec->getNoWrapFlags(SCEV::FlagNW)) {
5523    // FIXME: We really want an "isexact" bit for udiv.
5524    return getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
5525  }
5526  // Then, try to solve the above equation provided that Start is constant.
5527  if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
5528    return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
5529                                        -StartC->getValue()->getValue(),
5530                                        *this);
5531  return getCouldNotCompute();
5532}
5533
5534/// HowFarToNonZero - Return the number of times a backedge checking the
5535/// specified value for nonzero will execute.  If not computable, return
5536/// CouldNotCompute
5537ScalarEvolution::ExitLimit
5538ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
5539  // Loops that look like: while (X == 0) are very strange indeed.  We don't
5540  // handle them yet except for the trivial case.  This could be expanded in the
5541  // future as needed.
5542
5543  // If the value is a constant, check to see if it is known to be non-zero
5544  // already.  If so, the backedge will execute zero times.
5545  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
5546    if (!C->getValue()->isNullValue())
5547      return getConstant(C->getType(), 0);
5548    return getCouldNotCompute();  // Otherwise it will loop infinitely.
5549  }
5550
5551  // We could implement others, but I really doubt anyone writes loops like
5552  // this, and if they did, they would already be constant folded.
5553  return getCouldNotCompute();
5554}
5555
5556/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
5557/// (which may not be an immediate predecessor) which has exactly one
5558/// successor from which BB is reachable, or null if no such block is
5559/// found.
5560///
5561std::pair<BasicBlock *, BasicBlock *>
5562ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
5563  // If the block has a unique predecessor, then there is no path from the
5564  // predecessor to the block that does not go through the direct edge
5565  // from the predecessor to the block.
5566  if (BasicBlock *Pred = BB->getSinglePredecessor())
5567    return std::make_pair(Pred, BB);
5568
5569  // A loop's header is defined to be a block that dominates the loop.
5570  // If the header has a unique predecessor outside the loop, it must be
5571  // a block that has exactly one successor that can reach the loop.
5572  if (Loop *L = LI->getLoopFor(BB))
5573    return std::make_pair(L->getLoopPredecessor(), L->getHeader());
5574
5575  return std::pair<BasicBlock *, BasicBlock *>();
5576}
5577
5578/// HasSameValue - SCEV structural equivalence is usually sufficient for
5579/// testing whether two expressions are equal, however for the purposes of
5580/// looking for a condition guarding a loop, it can be useful to be a little
5581/// more general, since a front-end may have replicated the controlling
5582/// expression.
5583///
5584static bool HasSameValue(const SCEV *A, const SCEV *B) {
5585  // Quick check to see if they are the same SCEV.
5586  if (A == B) return true;
5587
5588  // Otherwise, if they're both SCEVUnknown, it's possible that they hold
5589  // two different instructions with the same value. Check for this case.
5590  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
5591    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
5592      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
5593        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
5594          if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory())
5595            return true;
5596
5597  // Otherwise assume they may have a different value.
5598  return false;
5599}
5600
5601/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
5602/// predicate Pred. Return true iff any changes were made.
5603///
5604bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
5605                                           const SCEV *&LHS, const SCEV *&RHS) {
5606  bool Changed = false;
5607
5608  // Canonicalize a constant to the right side.
5609  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
5610    // Check for both operands constant.
5611    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
5612      if (ConstantExpr::getICmp(Pred,
5613                                LHSC->getValue(),
5614                                RHSC->getValue())->isNullValue())
5615        goto trivially_false;
5616      else
5617        goto trivially_true;
5618    }
5619    // Otherwise swap the operands to put the constant on the right.
5620    std::swap(LHS, RHS);
5621    Pred = ICmpInst::getSwappedPredicate(Pred);
5622    Changed = true;
5623  }
5624
5625  // If we're comparing an addrec with a value which is loop-invariant in the
5626  // addrec's loop, put the addrec on the left. Also make a dominance check,
5627  // as both operands could be addrecs loop-invariant in each other's loop.
5628  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
5629    const Loop *L = AR->getLoop();
5630    if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
5631      std::swap(LHS, RHS);
5632      Pred = ICmpInst::getSwappedPredicate(Pred);
5633      Changed = true;
5634    }
5635  }
5636
5637  // If there's a constant operand, canonicalize comparisons with boundary
5638  // cases, and canonicalize *-or-equal comparisons to regular comparisons.
5639  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
5640    const APInt &RA = RC->getValue()->getValue();
5641    switch (Pred) {
5642    default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5643    case ICmpInst::ICMP_EQ:
5644    case ICmpInst::ICMP_NE:
5645      break;
5646    case ICmpInst::ICMP_UGE:
5647      if ((RA - 1).isMinValue()) {
5648        Pred = ICmpInst::ICMP_NE;
5649        RHS = getConstant(RA - 1);
5650        Changed = true;
5651        break;
5652      }
5653      if (RA.isMaxValue()) {
5654        Pred = ICmpInst::ICMP_EQ;
5655        Changed = true;
5656        break;
5657      }
5658      if (RA.isMinValue()) goto trivially_true;
5659
5660      Pred = ICmpInst::ICMP_UGT;
5661      RHS = getConstant(RA - 1);
5662      Changed = true;
5663      break;
5664    case ICmpInst::ICMP_ULE:
5665      if ((RA + 1).isMaxValue()) {
5666        Pred = ICmpInst::ICMP_NE;
5667        RHS = getConstant(RA + 1);
5668        Changed = true;
5669        break;
5670      }
5671      if (RA.isMinValue()) {
5672        Pred = ICmpInst::ICMP_EQ;
5673        Changed = true;
5674        break;
5675      }
5676      if (RA.isMaxValue()) goto trivially_true;
5677
5678      Pred = ICmpInst::ICMP_ULT;
5679      RHS = getConstant(RA + 1);
5680      Changed = true;
5681      break;
5682    case ICmpInst::ICMP_SGE:
5683      if ((RA - 1).isMinSignedValue()) {
5684        Pred = ICmpInst::ICMP_NE;
5685        RHS = getConstant(RA - 1);
5686        Changed = true;
5687        break;
5688      }
5689      if (RA.isMaxSignedValue()) {
5690        Pred = ICmpInst::ICMP_EQ;
5691        Changed = true;
5692        break;
5693      }
5694      if (RA.isMinSignedValue()) goto trivially_true;
5695
5696      Pred = ICmpInst::ICMP_SGT;
5697      RHS = getConstant(RA - 1);
5698      Changed = true;
5699      break;
5700    case ICmpInst::ICMP_SLE:
5701      if ((RA + 1).isMaxSignedValue()) {
5702        Pred = ICmpInst::ICMP_NE;
5703        RHS = getConstant(RA + 1);
5704        Changed = true;
5705        break;
5706      }
5707      if (RA.isMinSignedValue()) {
5708        Pred = ICmpInst::ICMP_EQ;
5709        Changed = true;
5710        break;
5711      }
5712      if (RA.isMaxSignedValue()) goto trivially_true;
5713
5714      Pred = ICmpInst::ICMP_SLT;
5715      RHS = getConstant(RA + 1);
5716      Changed = true;
5717      break;
5718    case ICmpInst::ICMP_UGT:
5719      if (RA.isMinValue()) {
5720        Pred = ICmpInst::ICMP_NE;
5721        Changed = true;
5722        break;
5723      }
5724      if ((RA + 1).isMaxValue()) {
5725        Pred = ICmpInst::ICMP_EQ;
5726        RHS = getConstant(RA + 1);
5727        Changed = true;
5728        break;
5729      }
5730      if (RA.isMaxValue()) goto trivially_false;
5731      break;
5732    case ICmpInst::ICMP_ULT:
5733      if (RA.isMaxValue()) {
5734        Pred = ICmpInst::ICMP_NE;
5735        Changed = true;
5736        break;
5737      }
5738      if ((RA - 1).isMinValue()) {
5739        Pred = ICmpInst::ICMP_EQ;
5740        RHS = getConstant(RA - 1);
5741        Changed = true;
5742        break;
5743      }
5744      if (RA.isMinValue()) goto trivially_false;
5745      break;
5746    case ICmpInst::ICMP_SGT:
5747      if (RA.isMinSignedValue()) {
5748        Pred = ICmpInst::ICMP_NE;
5749        Changed = true;
5750        break;
5751      }
5752      if ((RA + 1).isMaxSignedValue()) {
5753        Pred = ICmpInst::ICMP_EQ;
5754        RHS = getConstant(RA + 1);
5755        Changed = true;
5756        break;
5757      }
5758      if (RA.isMaxSignedValue()) goto trivially_false;
5759      break;
5760    case ICmpInst::ICMP_SLT:
5761      if (RA.isMaxSignedValue()) {
5762        Pred = ICmpInst::ICMP_NE;
5763        Changed = true;
5764        break;
5765      }
5766      if ((RA - 1).isMinSignedValue()) {
5767       Pred = ICmpInst::ICMP_EQ;
5768       RHS = getConstant(RA - 1);
5769        Changed = true;
5770       break;
5771      }
5772      if (RA.isMinSignedValue()) goto trivially_false;
5773      break;
5774    }
5775  }
5776
5777  // Check for obvious equality.
5778  if (HasSameValue(LHS, RHS)) {
5779    if (ICmpInst::isTrueWhenEqual(Pred))
5780      goto trivially_true;
5781    if (ICmpInst::isFalseWhenEqual(Pred))
5782      goto trivially_false;
5783  }
5784
5785  // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
5786  // adding or subtracting 1 from one of the operands.
5787  switch (Pred) {
5788  case ICmpInst::ICMP_SLE:
5789    if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
5790      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5791                       SCEV::FlagNSW);
5792      Pred = ICmpInst::ICMP_SLT;
5793      Changed = true;
5794    } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
5795      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5796                       SCEV::FlagNSW);
5797      Pred = ICmpInst::ICMP_SLT;
5798      Changed = true;
5799    }
5800    break;
5801  case ICmpInst::ICMP_SGE:
5802    if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
5803      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5804                       SCEV::FlagNSW);
5805      Pred = ICmpInst::ICMP_SGT;
5806      Changed = true;
5807    } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
5808      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5809                       SCEV::FlagNSW);
5810      Pred = ICmpInst::ICMP_SGT;
5811      Changed = true;
5812    }
5813    break;
5814  case ICmpInst::ICMP_ULE:
5815    if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
5816      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5817                       SCEV::FlagNUW);
5818      Pred = ICmpInst::ICMP_ULT;
5819      Changed = true;
5820    } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
5821      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5822                       SCEV::FlagNUW);
5823      Pred = ICmpInst::ICMP_ULT;
5824      Changed = true;
5825    }
5826    break;
5827  case ICmpInst::ICMP_UGE:
5828    if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
5829      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5830                       SCEV::FlagNUW);
5831      Pred = ICmpInst::ICMP_UGT;
5832      Changed = true;
5833    } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
5834      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5835                       SCEV::FlagNUW);
5836      Pred = ICmpInst::ICMP_UGT;
5837      Changed = true;
5838    }
5839    break;
5840  default:
5841    break;
5842  }
5843
5844  // TODO: More simplifications are possible here.
5845
5846  return Changed;
5847
5848trivially_true:
5849  // Return 0 == 0.
5850  LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
5851  Pred = ICmpInst::ICMP_EQ;
5852  return true;
5853
5854trivially_false:
5855  // Return 0 != 0.
5856  LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
5857  Pred = ICmpInst::ICMP_NE;
5858  return true;
5859}
5860
5861bool ScalarEvolution::isKnownNegative(const SCEV *S) {
5862  return getSignedRange(S).getSignedMax().isNegative();
5863}
5864
5865bool ScalarEvolution::isKnownPositive(const SCEV *S) {
5866  return getSignedRange(S).getSignedMin().isStrictlyPositive();
5867}
5868
5869bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
5870  return !getSignedRange(S).getSignedMin().isNegative();
5871}
5872
5873bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
5874  return !getSignedRange(S).getSignedMax().isStrictlyPositive();
5875}
5876
5877bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
5878  return isKnownNegative(S) || isKnownPositive(S);
5879}
5880
5881bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
5882                                       const SCEV *LHS, const SCEV *RHS) {
5883  // Canonicalize the inputs first.
5884  (void)SimplifyICmpOperands(Pred, LHS, RHS);
5885
5886  // If LHS or RHS is an addrec, check to see if the condition is true in
5887  // every iteration of the loop.
5888  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
5889    if (isLoopEntryGuardedByCond(
5890          AR->getLoop(), Pred, AR->getStart(), RHS) &&
5891        isLoopBackedgeGuardedByCond(
5892          AR->getLoop(), Pred, AR->getPostIncExpr(*this), RHS))
5893      return true;
5894  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS))
5895    if (isLoopEntryGuardedByCond(
5896          AR->getLoop(), Pred, LHS, AR->getStart()) &&
5897        isLoopBackedgeGuardedByCond(
5898          AR->getLoop(), Pred, LHS, AR->getPostIncExpr(*this)))
5899      return true;
5900
5901  // Otherwise see what can be done with known constant ranges.
5902  return isKnownPredicateWithRanges(Pred, LHS, RHS);
5903}
5904
5905bool
5906ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
5907                                            const SCEV *LHS, const SCEV *RHS) {
5908  if (HasSameValue(LHS, RHS))
5909    return ICmpInst::isTrueWhenEqual(Pred);
5910
5911  // This code is split out from isKnownPredicate because it is called from
5912  // within isLoopEntryGuardedByCond.
5913  switch (Pred) {
5914  default:
5915    llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5916  case ICmpInst::ICMP_SGT:
5917    Pred = ICmpInst::ICMP_SLT;
5918    std::swap(LHS, RHS);
5919  case ICmpInst::ICMP_SLT: {
5920    ConstantRange LHSRange = getSignedRange(LHS);
5921    ConstantRange RHSRange = getSignedRange(RHS);
5922    if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
5923      return true;
5924    if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
5925      return false;
5926    break;
5927  }
5928  case ICmpInst::ICMP_SGE:
5929    Pred = ICmpInst::ICMP_SLE;
5930    std::swap(LHS, RHS);
5931  case ICmpInst::ICMP_SLE: {
5932    ConstantRange LHSRange = getSignedRange(LHS);
5933    ConstantRange RHSRange = getSignedRange(RHS);
5934    if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
5935      return true;
5936    if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
5937      return false;
5938    break;
5939  }
5940  case ICmpInst::ICMP_UGT:
5941    Pred = ICmpInst::ICMP_ULT;
5942    std::swap(LHS, RHS);
5943  case ICmpInst::ICMP_ULT: {
5944    ConstantRange LHSRange = getUnsignedRange(LHS);
5945    ConstantRange RHSRange = getUnsignedRange(RHS);
5946    if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
5947      return true;
5948    if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
5949      return false;
5950    break;
5951  }
5952  case ICmpInst::ICMP_UGE:
5953    Pred = ICmpInst::ICMP_ULE;
5954    std::swap(LHS, RHS);
5955  case ICmpInst::ICMP_ULE: {
5956    ConstantRange LHSRange = getUnsignedRange(LHS);
5957    ConstantRange RHSRange = getUnsignedRange(RHS);
5958    if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
5959      return true;
5960    if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
5961      return false;
5962    break;
5963  }
5964  case ICmpInst::ICMP_NE: {
5965    if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
5966      return true;
5967    if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
5968      return true;
5969
5970    const SCEV *Diff = getMinusSCEV(LHS, RHS);
5971    if (isKnownNonZero(Diff))
5972      return true;
5973    break;
5974  }
5975  case ICmpInst::ICMP_EQ:
5976    // The check at the top of the function catches the case where
5977    // the values are known to be equal.
5978    break;
5979  }
5980  return false;
5981}
5982
5983/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
5984/// protected by a conditional between LHS and RHS.  This is used to
5985/// to eliminate casts.
5986bool
5987ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
5988                                             ICmpInst::Predicate Pred,
5989                                             const SCEV *LHS, const SCEV *RHS) {
5990  // Interpret a null as meaning no loop, where there is obviously no guard
5991  // (interprocedural conditions notwithstanding).
5992  if (!L) return true;
5993
5994  BasicBlock *Latch = L->getLoopLatch();
5995  if (!Latch)
5996    return false;
5997
5998  BranchInst *LoopContinuePredicate =
5999    dyn_cast<BranchInst>(Latch->getTerminator());
6000  if (!LoopContinuePredicate ||
6001      LoopContinuePredicate->isUnconditional())
6002    return false;
6003
6004  return isImpliedCond(Pred, LHS, RHS,
6005                       LoopContinuePredicate->getCondition(),
6006                       LoopContinuePredicate->getSuccessor(0) != L->getHeader());
6007}
6008
6009/// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
6010/// by a conditional between LHS and RHS.  This is used to help avoid max
6011/// expressions in loop trip counts, and to eliminate casts.
6012bool
6013ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
6014                                          ICmpInst::Predicate Pred,
6015                                          const SCEV *LHS, const SCEV *RHS) {
6016  // Interpret a null as meaning no loop, where there is obviously no guard
6017  // (interprocedural conditions notwithstanding).
6018  if (!L) return false;
6019
6020  // Starting at the loop predecessor, climb up the predecessor chain, as long
6021  // as there are predecessors that can be found that have unique successors
6022  // leading to the original header.
6023  for (std::pair<BasicBlock *, BasicBlock *>
6024         Pair(L->getLoopPredecessor(), L->getHeader());
6025       Pair.first;
6026       Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
6027
6028    BranchInst *LoopEntryPredicate =
6029      dyn_cast<BranchInst>(Pair.first->getTerminator());
6030    if (!LoopEntryPredicate ||
6031        LoopEntryPredicate->isUnconditional())
6032      continue;
6033
6034    if (isImpliedCond(Pred, LHS, RHS,
6035                      LoopEntryPredicate->getCondition(),
6036                      LoopEntryPredicate->getSuccessor(0) != Pair.second))
6037      return true;
6038  }
6039
6040  return false;
6041}
6042
6043/// isImpliedCond - Test whether the condition described by Pred, LHS,
6044/// and RHS is true whenever the given Cond value evaluates to true.
6045bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
6046                                    const SCEV *LHS, const SCEV *RHS,
6047                                    Value *FoundCondValue,
6048                                    bool Inverse) {
6049  // Recursively handle And and Or conditions.
6050  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
6051    if (BO->getOpcode() == Instruction::And) {
6052      if (!Inverse)
6053        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6054               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6055    } else if (BO->getOpcode() == Instruction::Or) {
6056      if (Inverse)
6057        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6058               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6059    }
6060  }
6061
6062  ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
6063  if (!ICI) return false;
6064
6065  // Bail if the ICmp's operands' types are wider than the needed type
6066  // before attempting to call getSCEV on them. This avoids infinite
6067  // recursion, since the analysis of widening casts can require loop
6068  // exit condition information for overflow checking, which would
6069  // lead back here.
6070  if (getTypeSizeInBits(LHS->getType()) <
6071      getTypeSizeInBits(ICI->getOperand(0)->getType()))
6072    return false;
6073
6074  // Now that we found a conditional branch that dominates the loop, check to
6075  // see if it is the comparison we are looking for.
6076  ICmpInst::Predicate FoundPred;
6077  if (Inverse)
6078    FoundPred = ICI->getInversePredicate();
6079  else
6080    FoundPred = ICI->getPredicate();
6081
6082  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
6083  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
6084
6085  // Balance the types. The case where FoundLHS' type is wider than
6086  // LHS' type is checked for above.
6087  if (getTypeSizeInBits(LHS->getType()) >
6088      getTypeSizeInBits(FoundLHS->getType())) {
6089    if (CmpInst::isSigned(Pred)) {
6090      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
6091      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
6092    } else {
6093      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
6094      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
6095    }
6096  }
6097
6098  // Canonicalize the query to match the way instcombine will have
6099  // canonicalized the comparison.
6100  if (SimplifyICmpOperands(Pred, LHS, RHS))
6101    if (LHS == RHS)
6102      return CmpInst::isTrueWhenEqual(Pred);
6103  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
6104    if (FoundLHS == FoundRHS)
6105      return CmpInst::isFalseWhenEqual(Pred);
6106
6107  // Check to see if we can make the LHS or RHS match.
6108  if (LHS == FoundRHS || RHS == FoundLHS) {
6109    if (isa<SCEVConstant>(RHS)) {
6110      std::swap(FoundLHS, FoundRHS);
6111      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
6112    } else {
6113      std::swap(LHS, RHS);
6114      Pred = ICmpInst::getSwappedPredicate(Pred);
6115    }
6116  }
6117
6118  // Check whether the found predicate is the same as the desired predicate.
6119  if (FoundPred == Pred)
6120    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
6121
6122  // Check whether swapping the found predicate makes it the same as the
6123  // desired predicate.
6124  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
6125    if (isa<SCEVConstant>(RHS))
6126      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
6127    else
6128      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
6129                                   RHS, LHS, FoundLHS, FoundRHS);
6130  }
6131
6132  // Check whether the actual condition is beyond sufficient.
6133  if (FoundPred == ICmpInst::ICMP_EQ)
6134    if (ICmpInst::isTrueWhenEqual(Pred))
6135      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
6136        return true;
6137  if (Pred == ICmpInst::ICMP_NE)
6138    if (!ICmpInst::isTrueWhenEqual(FoundPred))
6139      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
6140        return true;
6141
6142  // Otherwise assume the worst.
6143  return false;
6144}
6145
6146/// isImpliedCondOperands - Test whether the condition described by Pred,
6147/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
6148/// and FoundRHS is true.
6149bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
6150                                            const SCEV *LHS, const SCEV *RHS,
6151                                            const SCEV *FoundLHS,
6152                                            const SCEV *FoundRHS) {
6153  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
6154                                     FoundLHS, FoundRHS) ||
6155         // ~x < ~y --> x > y
6156         isImpliedCondOperandsHelper(Pred, LHS, RHS,
6157                                     getNotSCEV(FoundRHS),
6158                                     getNotSCEV(FoundLHS));
6159}
6160
6161/// isImpliedCondOperandsHelper - Test whether the condition described by
6162/// Pred, LHS, and RHS is true whenever the condition described by Pred,
6163/// FoundLHS, and FoundRHS is true.
6164bool
6165ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
6166                                             const SCEV *LHS, const SCEV *RHS,
6167                                             const SCEV *FoundLHS,
6168                                             const SCEV *FoundRHS) {
6169  switch (Pred) {
6170  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
6171  case ICmpInst::ICMP_EQ:
6172  case ICmpInst::ICMP_NE:
6173    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
6174      return true;
6175    break;
6176  case ICmpInst::ICMP_SLT:
6177  case ICmpInst::ICMP_SLE:
6178    if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
6179        isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS))
6180      return true;
6181    break;
6182  case ICmpInst::ICMP_SGT:
6183  case ICmpInst::ICMP_SGE:
6184    if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
6185        isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS))
6186      return true;
6187    break;
6188  case ICmpInst::ICMP_ULT:
6189  case ICmpInst::ICMP_ULE:
6190    if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
6191        isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS))
6192      return true;
6193    break;
6194  case ICmpInst::ICMP_UGT:
6195  case ICmpInst::ICMP_UGE:
6196    if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
6197        isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS))
6198      return true;
6199    break;
6200  }
6201
6202  return false;
6203}
6204
6205/// getBECount - Subtract the end and start values and divide by the step,
6206/// rounding up, to get the number of times the backedge is executed. Return
6207/// CouldNotCompute if an intermediate computation overflows.
6208const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
6209                                        const SCEV *End,
6210                                        const SCEV *Step,
6211                                        bool NoWrap) {
6212  assert(!isKnownNegative(Step) &&
6213         "This code doesn't handle negative strides yet!");
6214
6215  Type *Ty = Start->getType();
6216
6217  // When Start == End, we have an exact BECount == 0. Short-circuit this case
6218  // here because SCEV may not be able to determine that the unsigned division
6219  // after rounding is zero.
6220  if (Start == End)
6221    return getConstant(Ty, 0);
6222
6223  const SCEV *NegOne = getConstant(Ty, (uint64_t)-1);
6224  const SCEV *Diff = getMinusSCEV(End, Start);
6225  const SCEV *RoundUp = getAddExpr(Step, NegOne);
6226
6227  // Add an adjustment to the difference between End and Start so that
6228  // the division will effectively round up.
6229  const SCEV *Add = getAddExpr(Diff, RoundUp);
6230
6231  if (!NoWrap) {
6232    // Check Add for unsigned overflow.
6233    // TODO: More sophisticated things could be done here.
6234    Type *WideTy = IntegerType::get(getContext(),
6235                                          getTypeSizeInBits(Ty) + 1);
6236    const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
6237    const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
6238    const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
6239    if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
6240      return getCouldNotCompute();
6241  }
6242
6243  return getUDivExpr(Add, Step);
6244}
6245
6246/// HowManyLessThans - Return the number of times a backedge containing the
6247/// specified less-than comparison will execute.  If not computable, return
6248/// CouldNotCompute.
6249ScalarEvolution::ExitLimit
6250ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
6251                                  const Loop *L, bool isSigned) {
6252  // Only handle:  "ADDREC < LoopInvariant".
6253  if (!isLoopInvariant(RHS, L)) return getCouldNotCompute();
6254
6255  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
6256  if (!AddRec || AddRec->getLoop() != L)
6257    return getCouldNotCompute();
6258
6259  // Check to see if we have a flag which makes analysis easy.
6260  bool NoWrap = isSigned ?
6261    AddRec->getNoWrapFlags((SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNW)) :
6262    AddRec->getNoWrapFlags((SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNW));
6263
6264  if (AddRec->isAffine()) {
6265    unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
6266    const SCEV *Step = AddRec->getStepRecurrence(*this);
6267
6268    if (Step->isZero())
6269      return getCouldNotCompute();
6270    if (Step->isOne()) {
6271      // With unit stride, the iteration never steps past the limit value.
6272    } else if (isKnownPositive(Step)) {
6273      // Test whether a positive iteration can step past the limit
6274      // value and past the maximum value for its type in a single step.
6275      // Note that it's not sufficient to check NoWrap here, because even
6276      // though the value after a wrap is undefined, it's not undefined
6277      // behavior, so if wrap does occur, the loop could either terminate or
6278      // loop infinitely, but in either case, the loop is guaranteed to
6279      // iterate at least until the iteration where the wrapping occurs.
6280      const SCEV *One = getConstant(Step->getType(), 1);
6281      if (isSigned) {
6282        APInt Max = APInt::getSignedMaxValue(BitWidth);
6283        if ((Max - getSignedRange(getMinusSCEV(Step, One)).getSignedMax())
6284              .slt(getSignedRange(RHS).getSignedMax()))
6285          return getCouldNotCompute();
6286      } else {
6287        APInt Max = APInt::getMaxValue(BitWidth);
6288        if ((Max - getUnsignedRange(getMinusSCEV(Step, One)).getUnsignedMax())
6289              .ult(getUnsignedRange(RHS).getUnsignedMax()))
6290          return getCouldNotCompute();
6291      }
6292    } else
6293      // TODO: Handle negative strides here and below.
6294      return getCouldNotCompute();
6295
6296    // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
6297    // m.  So, we count the number of iterations in which {n,+,s} < m is true.
6298    // Note that we cannot simply return max(m-n,0)/s because it's not safe to
6299    // treat m-n as signed nor unsigned due to overflow possibility.
6300
6301    // First, we get the value of the LHS in the first iteration: n
6302    const SCEV *Start = AddRec->getOperand(0);
6303
6304    // Determine the minimum constant start value.
6305    const SCEV *MinStart = getConstant(isSigned ?
6306      getSignedRange(Start).getSignedMin() :
6307      getUnsignedRange(Start).getUnsignedMin());
6308
6309    // If we know that the condition is true in order to enter the loop,
6310    // then we know that it will run exactly (m-n)/s times. Otherwise, we
6311    // only know that it will execute (max(m,n)-n)/s times. In both cases,
6312    // the division must round up.
6313    const SCEV *End = RHS;
6314    if (!isLoopEntryGuardedByCond(L,
6315                                  isSigned ? ICmpInst::ICMP_SLT :
6316                                             ICmpInst::ICMP_ULT,
6317                                  getMinusSCEV(Start, Step), RHS))
6318      End = isSigned ? getSMaxExpr(RHS, Start)
6319                     : getUMaxExpr(RHS, Start);
6320
6321    // Determine the maximum constant end value.
6322    const SCEV *MaxEnd = getConstant(isSigned ?
6323      getSignedRange(End).getSignedMax() :
6324      getUnsignedRange(End).getUnsignedMax());
6325
6326    // If MaxEnd is within a step of the maximum integer value in its type,
6327    // adjust it down to the minimum value which would produce the same effect.
6328    // This allows the subsequent ceiling division of (N+(step-1))/step to
6329    // compute the correct value.
6330    const SCEV *StepMinusOne = getMinusSCEV(Step,
6331                                            getConstant(Step->getType(), 1));
6332    MaxEnd = isSigned ?
6333      getSMinExpr(MaxEnd,
6334                  getMinusSCEV(getConstant(APInt::getSignedMaxValue(BitWidth)),
6335                               StepMinusOne)) :
6336      getUMinExpr(MaxEnd,
6337                  getMinusSCEV(getConstant(APInt::getMaxValue(BitWidth)),
6338                               StepMinusOne));
6339
6340    // Finally, we subtract these two values and divide, rounding up, to get
6341    // the number of times the backedge is executed.
6342    const SCEV *BECount = getBECount(Start, End, Step, NoWrap);
6343
6344    // The maximum backedge count is similar, except using the minimum start
6345    // value and the maximum end value.
6346    // If we already have an exact constant BECount, use it instead.
6347    const SCEV *MaxBECount = isa<SCEVConstant>(BECount) ? BECount
6348      : getBECount(MinStart, MaxEnd, Step, NoWrap);
6349
6350    // If the stride is nonconstant, and NoWrap == true, then
6351    // getBECount(MinStart, MaxEnd) may not compute. This would result in an
6352    // exact BECount and invalid MaxBECount, which should be avoided to catch
6353    // more optimization opportunities.
6354    if (isa<SCEVCouldNotCompute>(MaxBECount))
6355      MaxBECount = BECount;
6356
6357    return ExitLimit(BECount, MaxBECount);
6358  }
6359
6360  return getCouldNotCompute();
6361}
6362
6363/// getNumIterationsInRange - Return the number of iterations of this loop that
6364/// produce values in the specified constant range.  Another way of looking at
6365/// this is that it returns the first iteration number where the value is not in
6366/// the condition, thus computing the exit count. If the iteration count can't
6367/// be computed, an instance of SCEVCouldNotCompute is returned.
6368const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
6369                                                    ScalarEvolution &SE) const {
6370  if (Range.isFullSet())  // Infinite loop.
6371    return SE.getCouldNotCompute();
6372
6373  // If the start is a non-zero constant, shift the range to simplify things.
6374  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
6375    if (!SC->getValue()->isZero()) {
6376      SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
6377      Operands[0] = SE.getConstant(SC->getType(), 0);
6378      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
6379                                             getNoWrapFlags(FlagNW));
6380      if (const SCEVAddRecExpr *ShiftedAddRec =
6381            dyn_cast<SCEVAddRecExpr>(Shifted))
6382        return ShiftedAddRec->getNumIterationsInRange(
6383                           Range.subtract(SC->getValue()->getValue()), SE);
6384      // This is strange and shouldn't happen.
6385      return SE.getCouldNotCompute();
6386    }
6387
6388  // The only time we can solve this is when we have all constant indices.
6389  // Otherwise, we cannot determine the overflow conditions.
6390  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
6391    if (!isa<SCEVConstant>(getOperand(i)))
6392      return SE.getCouldNotCompute();
6393
6394
6395  // Okay at this point we know that all elements of the chrec are constants and
6396  // that the start element is zero.
6397
6398  // First check to see if the range contains zero.  If not, the first
6399  // iteration exits.
6400  unsigned BitWidth = SE.getTypeSizeInBits(getType());
6401  if (!Range.contains(APInt(BitWidth, 0)))
6402    return SE.getConstant(getType(), 0);
6403
6404  if (isAffine()) {
6405    // If this is an affine expression then we have this situation:
6406    //   Solve {0,+,A} in Range  ===  Ax in Range
6407
6408    // We know that zero is in the range.  If A is positive then we know that
6409    // the upper value of the range must be the first possible exit value.
6410    // If A is negative then the lower of the range is the last possible loop
6411    // value.  Also note that we already checked for a full range.
6412    APInt One(BitWidth,1);
6413    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
6414    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
6415
6416    // The exit value should be (End+A)/A.
6417    APInt ExitVal = (End + A).udiv(A);
6418    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
6419
6420    // Evaluate at the exit value.  If we really did fall out of the valid
6421    // range, then we computed our trip count, otherwise wrap around or other
6422    // things must have happened.
6423    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
6424    if (Range.contains(Val->getValue()))
6425      return SE.getCouldNotCompute();  // Something strange happened
6426
6427    // Ensure that the previous value is in the range.  This is a sanity check.
6428    assert(Range.contains(
6429           EvaluateConstantChrecAtConstant(this,
6430           ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
6431           "Linear scev computation is off in a bad way!");
6432    return SE.getConstant(ExitValue);
6433  } else if (isQuadratic()) {
6434    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
6435    // quadratic equation to solve it.  To do this, we must frame our problem in
6436    // terms of figuring out when zero is crossed, instead of when
6437    // Range.getUpper() is crossed.
6438    SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
6439    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
6440    const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(),
6441                                             // getNoWrapFlags(FlagNW)
6442                                             FlagAnyWrap);
6443
6444    // Next, solve the constructed addrec
6445    std::pair<const SCEV *,const SCEV *> Roots =
6446      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
6447    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
6448    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
6449    if (R1) {
6450      // Pick the smallest positive root value.
6451      if (ConstantInt *CB =
6452          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
6453                         R1->getValue(), R2->getValue()))) {
6454        if (CB->getZExtValue() == false)
6455          std::swap(R1, R2);   // R1 is the minimum root now.
6456
6457        // Make sure the root is not off by one.  The returned iteration should
6458        // not be in the range, but the previous one should be.  When solving
6459        // for "X*X < 5", for example, we should not return a root of 2.
6460        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
6461                                                             R1->getValue(),
6462                                                             SE);
6463        if (Range.contains(R1Val->getValue())) {
6464          // The next iteration must be out of the range...
6465          ConstantInt *NextVal =
6466                ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
6467
6468          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
6469          if (!Range.contains(R1Val->getValue()))
6470            return SE.getConstant(NextVal);
6471          return SE.getCouldNotCompute();  // Something strange happened
6472        }
6473
6474        // If R1 was not in the range, then it is a good return value.  Make
6475        // sure that R1-1 WAS in the range though, just in case.
6476        ConstantInt *NextVal =
6477               ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
6478        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
6479        if (Range.contains(R1Val->getValue()))
6480          return R1;
6481        return SE.getCouldNotCompute();  // Something strange happened
6482      }
6483    }
6484  }
6485
6486  return SE.getCouldNotCompute();
6487}
6488
6489
6490
6491//===----------------------------------------------------------------------===//
6492//                   SCEVCallbackVH Class Implementation
6493//===----------------------------------------------------------------------===//
6494
6495void ScalarEvolution::SCEVCallbackVH::deleted() {
6496  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
6497  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
6498    SE->ConstantEvolutionLoopExitValue.erase(PN);
6499  SE->ValueExprMap.erase(getValPtr());
6500  // this now dangles!
6501}
6502
6503void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
6504  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
6505
6506  // Forget all the expressions associated with users of the old value,
6507  // so that future queries will recompute the expressions using the new
6508  // value.
6509  Value *Old = getValPtr();
6510  SmallVector<User *, 16> Worklist;
6511  SmallPtrSet<User *, 8> Visited;
6512  for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
6513       UI != UE; ++UI)
6514    Worklist.push_back(*UI);
6515  while (!Worklist.empty()) {
6516    User *U = Worklist.pop_back_val();
6517    // Deleting the Old value will cause this to dangle. Postpone
6518    // that until everything else is done.
6519    if (U == Old)
6520      continue;
6521    if (!Visited.insert(U))
6522      continue;
6523    if (PHINode *PN = dyn_cast<PHINode>(U))
6524      SE->ConstantEvolutionLoopExitValue.erase(PN);
6525    SE->ValueExprMap.erase(U);
6526    for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
6527         UI != UE; ++UI)
6528      Worklist.push_back(*UI);
6529  }
6530  // Delete the Old value.
6531  if (PHINode *PN = dyn_cast<PHINode>(Old))
6532    SE->ConstantEvolutionLoopExitValue.erase(PN);
6533  SE->ValueExprMap.erase(Old);
6534  // this now dangles!
6535}
6536
6537ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
6538  : CallbackVH(V), SE(se) {}
6539
6540//===----------------------------------------------------------------------===//
6541//                   ScalarEvolution Class Implementation
6542//===----------------------------------------------------------------------===//
6543
6544ScalarEvolution::ScalarEvolution()
6545  : FunctionPass(ID), FirstUnknown(0) {
6546  initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
6547}
6548
6549bool ScalarEvolution::runOnFunction(Function &F) {
6550  this->F = &F;
6551  LI = &getAnalysis<LoopInfo>();
6552  TD = getAnalysisIfAvailable<TargetData>();
6553  TLI = &getAnalysis<TargetLibraryInfo>();
6554  DT = &getAnalysis<DominatorTree>();
6555  return false;
6556}
6557
6558void ScalarEvolution::releaseMemory() {
6559  // Iterate through all the SCEVUnknown instances and call their
6560  // destructors, so that they release their references to their values.
6561  for (SCEVUnknown *U = FirstUnknown; U; U = U->Next)
6562    U->~SCEVUnknown();
6563  FirstUnknown = 0;
6564
6565  ValueExprMap.clear();
6566
6567  // Free any extra memory created for ExitNotTakenInfo in the unlikely event
6568  // that a loop had multiple computable exits.
6569  for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
6570         BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end();
6571       I != E; ++I) {
6572    I->second.clear();
6573  }
6574
6575  BackedgeTakenCounts.clear();
6576  ConstantEvolutionLoopExitValue.clear();
6577  ValuesAtScopes.clear();
6578  LoopDispositions.clear();
6579  BlockDispositions.clear();
6580  UnsignedRanges.clear();
6581  SignedRanges.clear();
6582  UniqueSCEVs.clear();
6583  SCEVAllocator.Reset();
6584}
6585
6586void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
6587  AU.setPreservesAll();
6588  AU.addRequiredTransitive<LoopInfo>();
6589  AU.addRequiredTransitive<DominatorTree>();
6590  AU.addRequired<TargetLibraryInfo>();
6591}
6592
6593bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
6594  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
6595}
6596
6597static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
6598                          const Loop *L) {
6599  // Print all inner loops first
6600  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
6601    PrintLoopInfo(OS, SE, *I);
6602
6603  OS << "Loop ";
6604  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
6605  OS << ": ";
6606
6607  SmallVector<BasicBlock *, 8> ExitBlocks;
6608  L->getExitBlocks(ExitBlocks);
6609  if (ExitBlocks.size() != 1)
6610    OS << "<multiple exits> ";
6611
6612  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
6613    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
6614  } else {
6615    OS << "Unpredictable backedge-taken count. ";
6616  }
6617
6618  OS << "\n"
6619        "Loop ";
6620  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
6621  OS << ": ";
6622
6623  if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
6624    OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
6625  } else {
6626    OS << "Unpredictable max backedge-taken count. ";
6627  }
6628
6629  OS << "\n";
6630}
6631
6632void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
6633  // ScalarEvolution's implementation of the print method is to print
6634  // out SCEV values of all instructions that are interesting. Doing
6635  // this potentially causes it to create new SCEV objects though,
6636  // which technically conflicts with the const qualifier. This isn't
6637  // observable from outside the class though, so casting away the
6638  // const isn't dangerous.
6639  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
6640
6641  OS << "Classifying expressions for: ";
6642  WriteAsOperand(OS, F, /*PrintType=*/false);
6643  OS << "\n";
6644  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
6645    if (isSCEVable(I->getType()) && !isa<CmpInst>(*I)) {
6646      OS << *I << '\n';
6647      OS << "  -->  ";
6648      const SCEV *SV = SE.getSCEV(&*I);
6649      SV->print(OS);
6650
6651      const Loop *L = LI->getLoopFor((*I).getParent());
6652
6653      const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
6654      if (AtUse != SV) {
6655        OS << "  -->  ";
6656        AtUse->print(OS);
6657      }
6658
6659      if (L) {
6660        OS << "\t\t" "Exits: ";
6661        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
6662        if (!SE.isLoopInvariant(ExitValue, L)) {
6663          OS << "<<Unknown>>";
6664        } else {
6665          OS << *ExitValue;
6666        }
6667      }
6668
6669      OS << "\n";
6670    }
6671
6672  OS << "Determining loop execution counts for: ";
6673  WriteAsOperand(OS, F, /*PrintType=*/false);
6674  OS << "\n";
6675  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
6676    PrintLoopInfo(OS, &SE, *I);
6677}
6678
6679ScalarEvolution::LoopDisposition
6680ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
6681  std::map<const Loop *, LoopDisposition> &Values = LoopDispositions[S];
6682  std::pair<std::map<const Loop *, LoopDisposition>::iterator, bool> Pair =
6683    Values.insert(std::make_pair(L, LoopVariant));
6684  if (!Pair.second)
6685    return Pair.first->second;
6686
6687  LoopDisposition D = computeLoopDisposition(S, L);
6688  return LoopDispositions[S][L] = D;
6689}
6690
6691ScalarEvolution::LoopDisposition
6692ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
6693  switch (S->getSCEVType()) {
6694  case scConstant:
6695    return LoopInvariant;
6696  case scTruncate:
6697  case scZeroExtend:
6698  case scSignExtend:
6699    return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
6700  case scAddRecExpr: {
6701    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
6702
6703    // If L is the addrec's loop, it's computable.
6704    if (AR->getLoop() == L)
6705      return LoopComputable;
6706
6707    // Add recurrences are never invariant in the function-body (null loop).
6708    if (!L)
6709      return LoopVariant;
6710
6711    // This recurrence is variant w.r.t. L if L contains AR's loop.
6712    if (L->contains(AR->getLoop()))
6713      return LoopVariant;
6714
6715    // This recurrence is invariant w.r.t. L if AR's loop contains L.
6716    if (AR->getLoop()->contains(L))
6717      return LoopInvariant;
6718
6719    // This recurrence is variant w.r.t. L if any of its operands
6720    // are variant.
6721    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
6722         I != E; ++I)
6723      if (!isLoopInvariant(*I, L))
6724        return LoopVariant;
6725
6726    // Otherwise it's loop-invariant.
6727    return LoopInvariant;
6728  }
6729  case scAddExpr:
6730  case scMulExpr:
6731  case scUMaxExpr:
6732  case scSMaxExpr: {
6733    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6734    bool HasVarying = false;
6735    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6736         I != E; ++I) {
6737      LoopDisposition D = getLoopDisposition(*I, L);
6738      if (D == LoopVariant)
6739        return LoopVariant;
6740      if (D == LoopComputable)
6741        HasVarying = true;
6742    }
6743    return HasVarying ? LoopComputable : LoopInvariant;
6744  }
6745  case scUDivExpr: {
6746    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6747    LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
6748    if (LD == LoopVariant)
6749      return LoopVariant;
6750    LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
6751    if (RD == LoopVariant)
6752      return LoopVariant;
6753    return (LD == LoopInvariant && RD == LoopInvariant) ?
6754           LoopInvariant : LoopComputable;
6755  }
6756  case scUnknown:
6757    // All non-instruction values are loop invariant.  All instructions are loop
6758    // invariant if they are not contained in the specified loop.
6759    // Instructions are never considered invariant in the function body
6760    // (null loop) because they are defined within the "loop".
6761    if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
6762      return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
6763    return LoopInvariant;
6764  case scCouldNotCompute:
6765    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6766  default: llvm_unreachable("Unknown SCEV kind!");
6767  }
6768}
6769
6770bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
6771  return getLoopDisposition(S, L) == LoopInvariant;
6772}
6773
6774bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
6775  return getLoopDisposition(S, L) == LoopComputable;
6776}
6777
6778ScalarEvolution::BlockDisposition
6779ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
6780  std::map<const BasicBlock *, BlockDisposition> &Values = BlockDispositions[S];
6781  std::pair<std::map<const BasicBlock *, BlockDisposition>::iterator, bool>
6782    Pair = Values.insert(std::make_pair(BB, DoesNotDominateBlock));
6783  if (!Pair.second)
6784    return Pair.first->second;
6785
6786  BlockDisposition D = computeBlockDisposition(S, BB);
6787  return BlockDispositions[S][BB] = D;
6788}
6789
6790ScalarEvolution::BlockDisposition
6791ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
6792  switch (S->getSCEVType()) {
6793  case scConstant:
6794    return ProperlyDominatesBlock;
6795  case scTruncate:
6796  case scZeroExtend:
6797  case scSignExtend:
6798    return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
6799  case scAddRecExpr: {
6800    // This uses a "dominates" query instead of "properly dominates" query
6801    // to test for proper dominance too, because the instruction which
6802    // produces the addrec's value is a PHI, and a PHI effectively properly
6803    // dominates its entire containing block.
6804    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
6805    if (!DT->dominates(AR->getLoop()->getHeader(), BB))
6806      return DoesNotDominateBlock;
6807  }
6808  // FALL THROUGH into SCEVNAryExpr handling.
6809  case scAddExpr:
6810  case scMulExpr:
6811  case scUMaxExpr:
6812  case scSMaxExpr: {
6813    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6814    bool Proper = true;
6815    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6816         I != E; ++I) {
6817      BlockDisposition D = getBlockDisposition(*I, BB);
6818      if (D == DoesNotDominateBlock)
6819        return DoesNotDominateBlock;
6820      if (D == DominatesBlock)
6821        Proper = false;
6822    }
6823    return Proper ? ProperlyDominatesBlock : DominatesBlock;
6824  }
6825  case scUDivExpr: {
6826    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6827    const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
6828    BlockDisposition LD = getBlockDisposition(LHS, BB);
6829    if (LD == DoesNotDominateBlock)
6830      return DoesNotDominateBlock;
6831    BlockDisposition RD = getBlockDisposition(RHS, BB);
6832    if (RD == DoesNotDominateBlock)
6833      return DoesNotDominateBlock;
6834    return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
6835      ProperlyDominatesBlock : DominatesBlock;
6836  }
6837  case scUnknown:
6838    if (Instruction *I =
6839          dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
6840      if (I->getParent() == BB)
6841        return DominatesBlock;
6842      if (DT->properlyDominates(I->getParent(), BB))
6843        return ProperlyDominatesBlock;
6844      return DoesNotDominateBlock;
6845    }
6846    return ProperlyDominatesBlock;
6847  case scCouldNotCompute:
6848    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6849  default:
6850    llvm_unreachable("Unknown SCEV kind!");
6851  }
6852}
6853
6854bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
6855  return getBlockDisposition(S, BB) >= DominatesBlock;
6856}
6857
6858bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
6859  return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
6860}
6861
6862bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
6863  switch (S->getSCEVType()) {
6864  case scConstant:
6865    return false;
6866  case scTruncate:
6867  case scZeroExtend:
6868  case scSignExtend: {
6869    const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6870    const SCEV *CastOp = Cast->getOperand();
6871    return Op == CastOp || hasOperand(CastOp, Op);
6872  }
6873  case scAddRecExpr:
6874  case scAddExpr:
6875  case scMulExpr:
6876  case scUMaxExpr:
6877  case scSMaxExpr: {
6878    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6879    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6880         I != E; ++I) {
6881      const SCEV *NAryOp = *I;
6882      if (NAryOp == Op || hasOperand(NAryOp, Op))
6883        return true;
6884    }
6885    return false;
6886  }
6887  case scUDivExpr: {
6888    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6889    const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
6890    return LHS == Op || hasOperand(LHS, Op) ||
6891           RHS == Op || hasOperand(RHS, Op);
6892  }
6893  case scUnknown:
6894    return false;
6895  case scCouldNotCompute:
6896    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6897  default:
6898    llvm_unreachable("Unknown SCEV kind!");
6899  }
6900}
6901
6902void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
6903  ValuesAtScopes.erase(S);
6904  LoopDispositions.erase(S);
6905  BlockDispositions.erase(S);
6906  UnsignedRanges.erase(S);
6907  SignedRanges.erase(S);
6908}
6909