ScalarEvolution.cpp revision b7ef72963b2215ca23c27fa8ea777bada06994d0
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file contains the implementation of the scalar evolution analysis
11// engine, which is used primarily to analyze expressions involving induction
12// variables in loops.
13//
14// There are several aspects to this library.  First is the representation of
15// scalar expressions, which are represented as subclasses of the SCEV class.
16// These classes are used to represent certain types of subexpressions that we
17// can handle.  These classes are reference counted, managed by the SCEVHandle
18// class.  We only create one SCEV of a particular shape, so pointer-comparisons
19// for equality are legal.
20//
21// One important aspect of the SCEV objects is that they are never cyclic, even
22// if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
23// the PHI node is one of the idioms that we can represent (e.g., a polynomial
24// recurrence) then we represent it directly as a recurrence node, otherwise we
25// represent it as a SCEVUnknown node.
26//
27// In addition to being able to represent expressions of various types, we also
28// have folders that are used to build the *canonical* representation for a
29// particular expression.  These folders are capable of using a variety of
30// rewrite rules to simplify the expressions.
31//
32// Once the folders are defined, we can implement the more interesting
33// higher-level code, such as the code that recognizes PHI nodes of various
34// types, computes the execution count of a loop, etc.
35//
36// TODO: We should use these routines and value representations to implement
37// dependence analysis!
38//
39//===----------------------------------------------------------------------===//
40//
41// There are several good references for the techniques used in this analysis.
42//
43//  Chains of recurrences -- a method to expedite the evaluation
44//  of closed-form functions
45//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
46//
47//  On computational properties of chains of recurrences
48//  Eugene V. Zima
49//
50//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
51//  Robert A. van Engelen
52//
53//  Efficient Symbolic Analysis for Optimizing Compilers
54//  Robert A. van Engelen
55//
56//  Using the chains of recurrences algebra for data dependence testing and
57//  induction variable substitution
58//  MS Thesis, Johnie Birch
59//
60//===----------------------------------------------------------------------===//
61
62#define DEBUG_TYPE "scalar-evolution"
63#include "llvm/Analysis/ScalarEvolutionExpressions.h"
64#include "llvm/Constants.h"
65#include "llvm/DerivedTypes.h"
66#include "llvm/GlobalVariable.h"
67#include "llvm/Instructions.h"
68#include "llvm/Analysis/ConstantFolding.h"
69#include "llvm/Analysis/Dominators.h"
70#include "llvm/Analysis/LoopInfo.h"
71#include "llvm/Assembly/Writer.h"
72#include "llvm/Target/TargetData.h"
73#include "llvm/Transforms/Scalar.h"
74#include "llvm/Support/CFG.h"
75#include "llvm/Support/CommandLine.h"
76#include "llvm/Support/Compiler.h"
77#include "llvm/Support/ConstantRange.h"
78#include "llvm/Support/GetElementPtrTypeIterator.h"
79#include "llvm/Support/InstIterator.h"
80#include "llvm/Support/ManagedStatic.h"
81#include "llvm/Support/MathExtras.h"
82#include "llvm/Support/raw_ostream.h"
83#include "llvm/ADT/Statistic.h"
84#include "llvm/ADT/STLExtras.h"
85#include <ostream>
86#include <algorithm>
87#include <cmath>
88using namespace llvm;
89
90STATISTIC(NumArrayLenItCounts,
91          "Number of trip counts computed with array length");
92STATISTIC(NumTripCountsComputed,
93          "Number of loops with predictable loop counts");
94STATISTIC(NumTripCountsNotComputed,
95          "Number of loops without predictable loop counts");
96STATISTIC(NumBruteForceTripCountsComputed,
97          "Number of loops with trip counts computed by force");
98
99static cl::opt<unsigned>
100MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
101                        cl::desc("Maximum number of iterations SCEV will "
102                                 "symbolically execute a constant derived loop"),
103                        cl::init(100));
104
105static RegisterPass<ScalarEvolution>
106R("scalar-evolution", "Scalar Evolution Analysis", false, true);
107char ScalarEvolution::ID = 0;
108
109//===----------------------------------------------------------------------===//
110//                           SCEV class definitions
111//===----------------------------------------------------------------------===//
112
113//===----------------------------------------------------------------------===//
114// Implementation of the SCEV class.
115//
116SCEV::~SCEV() {}
117void SCEV::dump() const {
118  print(errs());
119  errs() << '\n';
120}
121
122void SCEV::print(std::ostream &o) const {
123  raw_os_ostream OS(o);
124  print(OS);
125}
126
127bool SCEV::isZero() const {
128  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
129    return SC->getValue()->isZero();
130  return false;
131}
132
133
134SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {}
135
136bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
137  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
138  return false;
139}
140
141const Type *SCEVCouldNotCompute::getType() const {
142  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
143  return 0;
144}
145
146bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
147  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
148  return false;
149}
150
151SCEVHandle SCEVCouldNotCompute::
152replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
153                                  const SCEVHandle &Conc,
154                                  ScalarEvolution &SE) const {
155  return this;
156}
157
158void SCEVCouldNotCompute::print(raw_ostream &OS) const {
159  OS << "***COULDNOTCOMPUTE***";
160}
161
162bool SCEVCouldNotCompute::classof(const SCEV *S) {
163  return S->getSCEVType() == scCouldNotCompute;
164}
165
166
167// SCEVConstants - Only allow the creation of one SCEVConstant for any
168// particular value.  Don't use a SCEVHandle here, or else the object will
169// never be deleted!
170static ManagedStatic<std::map<ConstantInt*, SCEVConstant*> > SCEVConstants;
171
172
173SCEVConstant::~SCEVConstant() {
174  SCEVConstants->erase(V);
175}
176
177SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) {
178  SCEVConstant *&R = (*SCEVConstants)[V];
179  if (R == 0) R = new SCEVConstant(V);
180  return R;
181}
182
183SCEVHandle ScalarEvolution::getConstant(const APInt& Val) {
184  return getConstant(ConstantInt::get(Val));
185}
186
187const Type *SCEVConstant::getType() const { return V->getType(); }
188
189void SCEVConstant::print(raw_ostream &OS) const {
190  WriteAsOperand(OS, V, false);
191}
192
193// SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any
194// particular input.  Don't use a SCEVHandle here, or else the object will
195// never be deleted!
196static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
197                     SCEVTruncateExpr*> > SCEVTruncates;
198
199SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty)
200  : SCEV(scTruncate), Op(op), Ty(ty) {
201  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
202         (Ty->isInteger() || isa<PointerType>(Ty)) &&
203         "Cannot truncate non-integer value!");
204  assert((!Op->getType()->isInteger() || !Ty->isInteger() ||
205          Op->getType()->getPrimitiveSizeInBits() >
206            Ty->getPrimitiveSizeInBits()) &&
207         "This is not a truncating conversion!");
208}
209
210SCEVTruncateExpr::~SCEVTruncateExpr() {
211  SCEVTruncates->erase(std::make_pair(Op, Ty));
212}
213
214bool SCEVTruncateExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
215  return Op->dominates(BB, DT);
216}
217
218void SCEVTruncateExpr::print(raw_ostream &OS) const {
219  OS << "(truncate " << *Op << " to " << *Ty << ")";
220}
221
222// SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any
223// particular input.  Don't use a SCEVHandle here, or else the object will never
224// be deleted!
225static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
226                     SCEVZeroExtendExpr*> > SCEVZeroExtends;
227
228SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty)
229  : SCEV(scZeroExtend), Op(op), Ty(ty) {
230  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
231         (Ty->isInteger() || isa<PointerType>(Ty)) &&
232         "Cannot zero extend non-integer value!");
233}
234
235SCEVZeroExtendExpr::~SCEVZeroExtendExpr() {
236  SCEVZeroExtends->erase(std::make_pair(Op, Ty));
237}
238
239bool SCEVZeroExtendExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
240  return Op->dominates(BB, DT);
241}
242
243void SCEVZeroExtendExpr::print(raw_ostream &OS) const {
244  OS << "(zeroextend " << *Op << " to " << *Ty << ")";
245}
246
247// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any
248// particular input.  Don't use a SCEVHandle here, or else the object will never
249// be deleted!
250static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
251                     SCEVSignExtendExpr*> > SCEVSignExtends;
252
253SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty)
254  : SCEV(scSignExtend), Op(op), Ty(ty) {
255  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
256         (Ty->isInteger() || isa<PointerType>(Ty)) &&
257         "Cannot sign extend non-integer value!");
258  assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
259         && "This is not an extending conversion!");
260}
261
262SCEVSignExtendExpr::~SCEVSignExtendExpr() {
263  SCEVSignExtends->erase(std::make_pair(Op, Ty));
264}
265
266bool SCEVSignExtendExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
267  return Op->dominates(BB, DT);
268}
269
270void SCEVSignExtendExpr::print(raw_ostream &OS) const {
271  OS << "(signextend " << *Op << " to " << *Ty << ")";
272}
273
274// SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any
275// particular input.  Don't use a SCEVHandle here, or else the object will never
276// be deleted!
277static ManagedStatic<std::map<std::pair<unsigned, std::vector<SCEV*> >,
278                     SCEVCommutativeExpr*> > SCEVCommExprs;
279
280SCEVCommutativeExpr::~SCEVCommutativeExpr() {
281  SCEVCommExprs->erase(std::make_pair(getSCEVType(),
282                                      std::vector<SCEV*>(Operands.begin(),
283                                                         Operands.end())));
284}
285
286void SCEVCommutativeExpr::print(raw_ostream &OS) const {
287  assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
288  const char *OpStr = getOperationStr();
289  OS << "(" << *Operands[0];
290  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
291    OS << OpStr << *Operands[i];
292  OS << ")";
293}
294
295SCEVHandle SCEVCommutativeExpr::
296replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
297                                  const SCEVHandle &Conc,
298                                  ScalarEvolution &SE) const {
299  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
300    SCEVHandle H =
301      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
302    if (H != getOperand(i)) {
303      std::vector<SCEVHandle> NewOps;
304      NewOps.reserve(getNumOperands());
305      for (unsigned j = 0; j != i; ++j)
306        NewOps.push_back(getOperand(j));
307      NewOps.push_back(H);
308      for (++i; i != e; ++i)
309        NewOps.push_back(getOperand(i)->
310                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
311
312      if (isa<SCEVAddExpr>(this))
313        return SE.getAddExpr(NewOps);
314      else if (isa<SCEVMulExpr>(this))
315        return SE.getMulExpr(NewOps);
316      else if (isa<SCEVSMaxExpr>(this))
317        return SE.getSMaxExpr(NewOps);
318      else if (isa<SCEVUMaxExpr>(this))
319        return SE.getUMaxExpr(NewOps);
320      else
321        assert(0 && "Unknown commutative expr!");
322    }
323  }
324  return this;
325}
326
327bool SCEVCommutativeExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
328  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
329    if (!getOperand(i)->dominates(BB, DT))
330      return false;
331  }
332  return true;
333}
334
335
336// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular
337// input.  Don't use a SCEVHandle here, or else the object will never be
338// deleted!
339static ManagedStatic<std::map<std::pair<SCEV*, SCEV*>,
340                     SCEVUDivExpr*> > SCEVUDivs;
341
342SCEVUDivExpr::~SCEVUDivExpr() {
343  SCEVUDivs->erase(std::make_pair(LHS, RHS));
344}
345
346bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
347  return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
348}
349
350void SCEVUDivExpr::print(raw_ostream &OS) const {
351  OS << "(" << *LHS << " /u " << *RHS << ")";
352}
353
354const Type *SCEVUDivExpr::getType() const {
355  return LHS->getType();
356}
357
358// SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any
359// particular input.  Don't use a SCEVHandle here, or else the object will never
360// be deleted!
361static ManagedStatic<std::map<std::pair<const Loop *, std::vector<SCEV*> >,
362                     SCEVAddRecExpr*> > SCEVAddRecExprs;
363
364SCEVAddRecExpr::~SCEVAddRecExpr() {
365  SCEVAddRecExprs->erase(std::make_pair(L,
366                                        std::vector<SCEV*>(Operands.begin(),
367                                                           Operands.end())));
368}
369
370bool SCEVAddRecExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
371  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
372    if (!getOperand(i)->dominates(BB, DT))
373      return false;
374  }
375  return true;
376}
377
378
379SCEVHandle SCEVAddRecExpr::
380replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
381                                  const SCEVHandle &Conc,
382                                  ScalarEvolution &SE) const {
383  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
384    SCEVHandle H =
385      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
386    if (H != getOperand(i)) {
387      std::vector<SCEVHandle> NewOps;
388      NewOps.reserve(getNumOperands());
389      for (unsigned j = 0; j != i; ++j)
390        NewOps.push_back(getOperand(j));
391      NewOps.push_back(H);
392      for (++i; i != e; ++i)
393        NewOps.push_back(getOperand(i)->
394                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
395
396      return SE.getAddRecExpr(NewOps, L);
397    }
398  }
399  return this;
400}
401
402
403bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
404  // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't
405  // contain L and if the start is invariant.
406  return !QueryLoop->contains(L->getHeader()) &&
407         getOperand(0)->isLoopInvariant(QueryLoop);
408}
409
410
411void SCEVAddRecExpr::print(raw_ostream &OS) const {
412  OS << "{" << *Operands[0];
413  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
414    OS << ",+," << *Operands[i];
415  OS << "}<" << L->getHeader()->getName() + ">";
416}
417
418// SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular
419// value.  Don't use a SCEVHandle here, or else the object will never be
420// deleted!
421static ManagedStatic<std::map<Value*, SCEVUnknown*> > SCEVUnknowns;
422
423SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); }
424
425bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
426  // All non-instruction values are loop invariant.  All instructions are loop
427  // invariant if they are not contained in the specified loop.
428  if (Instruction *I = dyn_cast<Instruction>(V))
429    return !L->contains(I->getParent());
430  return true;
431}
432
433bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
434  if (Instruction *I = dyn_cast<Instruction>(getValue()))
435    return DT->dominates(I->getParent(), BB);
436  return true;
437}
438
439const Type *SCEVUnknown::getType() const {
440  return V->getType();
441}
442
443void SCEVUnknown::print(raw_ostream &OS) const {
444  if (isa<PointerType>(V->getType()))
445    OS << "(ptrtoint " << *V->getType() << " ";
446  WriteAsOperand(OS, V, false);
447  if (isa<PointerType>(V->getType()))
448    OS << " to iPTR)";
449}
450
451//===----------------------------------------------------------------------===//
452//                               SCEV Utilities
453//===----------------------------------------------------------------------===//
454
455namespace {
456  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
457  /// than the complexity of the RHS.  This comparator is used to canonicalize
458  /// expressions.
459  struct VISIBILITY_HIDDEN SCEVComplexityCompare {
460    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
461      return LHS->getSCEVType() < RHS->getSCEVType();
462    }
463  };
464}
465
466/// GroupByComplexity - Given a list of SCEV objects, order them by their
467/// complexity, and group objects of the same complexity together by value.
468/// When this routine is finished, we know that any duplicates in the vector are
469/// consecutive and that complexity is monotonically increasing.
470///
471/// Note that we go take special precautions to ensure that we get determinstic
472/// results from this routine.  In other words, we don't want the results of
473/// this to depend on where the addresses of various SCEV objects happened to
474/// land in memory.
475///
476static void GroupByComplexity(std::vector<SCEVHandle> &Ops) {
477  if (Ops.size() < 2) return;  // Noop
478  if (Ops.size() == 2) {
479    // This is the common case, which also happens to be trivially simple.
480    // Special case it.
481    if (SCEVComplexityCompare()(Ops[1], Ops[0]))
482      std::swap(Ops[0], Ops[1]);
483    return;
484  }
485
486  // Do the rough sort by complexity.
487  std::sort(Ops.begin(), Ops.end(), SCEVComplexityCompare());
488
489  // Now that we are sorted by complexity, group elements of the same
490  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
491  // be extremely short in practice.  Note that we take this approach because we
492  // do not want to depend on the addresses of the objects we are grouping.
493  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
494    SCEV *S = Ops[i];
495    unsigned Complexity = S->getSCEVType();
496
497    // If there are any objects of the same complexity and same value as this
498    // one, group them.
499    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
500      if (Ops[j] == S) { // Found a duplicate.
501        // Move it to immediately after i'th element.
502        std::swap(Ops[i+1], Ops[j]);
503        ++i;   // no need to rescan it.
504        if (i == e-2) return;  // Done!
505      }
506    }
507  }
508}
509
510
511
512//===----------------------------------------------------------------------===//
513//                      Simple SCEV method implementations
514//===----------------------------------------------------------------------===//
515
516/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
517// Assume, K > 0.
518static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
519                                      ScalarEvolution &SE,
520                                      const Type* ResultTy) {
521  // Handle the simplest case efficiently.
522  if (K == 1)
523    return SE.getTruncateOrZeroExtend(It, ResultTy);
524
525  // We are using the following formula for BC(It, K):
526  //
527  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
528  //
529  // Suppose, W is the bitwidth of the return value.  We must be prepared for
530  // overflow.  Hence, we must assure that the result of our computation is
531  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
532  // safe in modular arithmetic.
533  //
534  // However, this code doesn't use exactly that formula; the formula it uses
535  // is something like the following, where T is the number of factors of 2 in
536  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
537  // exponentiation:
538  //
539  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
540  //
541  // This formula is trivially equivalent to the previous formula.  However,
542  // this formula can be implemented much more efficiently.  The trick is that
543  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
544  // arithmetic.  To do exact division in modular arithmetic, all we have
545  // to do is multiply by the inverse.  Therefore, this step can be done at
546  // width W.
547  //
548  // The next issue is how to safely do the division by 2^T.  The way this
549  // is done is by doing the multiplication step at a width of at least W + T
550  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
551  // when we perform the division by 2^T (which is equivalent to a right shift
552  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
553  // truncated out after the division by 2^T.
554  //
555  // In comparison to just directly using the first formula, this technique
556  // is much more efficient; using the first formula requires W * K bits,
557  // but this formula less than W + K bits. Also, the first formula requires
558  // a division step, whereas this formula only requires multiplies and shifts.
559  //
560  // It doesn't matter whether the subtraction step is done in the calculation
561  // width or the input iteration count's width; if the subtraction overflows,
562  // the result must be zero anyway.  We prefer here to do it in the width of
563  // the induction variable because it helps a lot for certain cases; CodeGen
564  // isn't smart enough to ignore the overflow, which leads to much less
565  // efficient code if the width of the subtraction is wider than the native
566  // register width.
567  //
568  // (It's possible to not widen at all by pulling out factors of 2 before
569  // the multiplication; for example, K=2 can be calculated as
570  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
571  // extra arithmetic, so it's not an obvious win, and it gets
572  // much more complicated for K > 3.)
573
574  // Protection from insane SCEVs; this bound is conservative,
575  // but it probably doesn't matter.
576  if (K > 1000)
577    return SE.getCouldNotCompute();
578
579  unsigned W = SE.getTargetData().getTypeSizeInBits(ResultTy);
580
581  // Calculate K! / 2^T and T; we divide out the factors of two before
582  // multiplying for calculating K! / 2^T to avoid overflow.
583  // Other overflow doesn't matter because we only care about the bottom
584  // W bits of the result.
585  APInt OddFactorial(W, 1);
586  unsigned T = 1;
587  for (unsigned i = 3; i <= K; ++i) {
588    APInt Mult(W, i);
589    unsigned TwoFactors = Mult.countTrailingZeros();
590    T += TwoFactors;
591    Mult = Mult.lshr(TwoFactors);
592    OddFactorial *= Mult;
593  }
594
595  // We need at least W + T bits for the multiplication step
596  unsigned CalculationBits = W + T;
597
598  // Calcuate 2^T, at width T+W.
599  APInt DivFactor = APInt(CalculationBits, 1).shl(T);
600
601  // Calculate the multiplicative inverse of K! / 2^T;
602  // this multiplication factor will perform the exact division by
603  // K! / 2^T.
604  APInt Mod = APInt::getSignedMinValue(W+1);
605  APInt MultiplyFactor = OddFactorial.zext(W+1);
606  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
607  MultiplyFactor = MultiplyFactor.trunc(W);
608
609  // Calculate the product, at width T+W
610  const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
611  SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
612  for (unsigned i = 1; i != K; ++i) {
613    SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
614    Dividend = SE.getMulExpr(Dividend,
615                             SE.getTruncateOrZeroExtend(S, CalculationTy));
616  }
617
618  // Divide by 2^T
619  SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
620
621  // Truncate the result, and divide by K! / 2^T.
622
623  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
624                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
625}
626
627/// evaluateAtIteration - Return the value of this chain of recurrences at
628/// the specified iteration number.  We can evaluate this recurrence by
629/// multiplying each element in the chain by the binomial coefficient
630/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
631///
632///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
633///
634/// where BC(It, k) stands for binomial coefficient.
635///
636SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
637                                               ScalarEvolution &SE) const {
638  SCEVHandle Result = getStart();
639  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
640    // The computation is correct in the face of overflow provided that the
641    // multiplication is performed _after_ the evaluation of the binomial
642    // coefficient.
643    SCEVHandle Coeff = BinomialCoefficient(It, i, SE, getType());
644    if (isa<SCEVCouldNotCompute>(Coeff))
645      return Coeff;
646
647    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
648  }
649  return Result;
650}
651
652//===----------------------------------------------------------------------===//
653//                    SCEV Expression folder implementations
654//===----------------------------------------------------------------------===//
655
656SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, const Type *Ty) {
657  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
658    return getUnknown(
659        ConstantExpr::getTrunc(SC->getValue(), Ty));
660
661  // If the input value is a chrec scev made out of constants, truncate
662  // all of the constants.
663  if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
664    std::vector<SCEVHandle> Operands;
665    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
666      // FIXME: This should allow truncation of other expression types!
667      if (isa<SCEVConstant>(AddRec->getOperand(i)))
668        Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
669      else
670        break;
671    if (Operands.size() == AddRec->getNumOperands())
672      return getAddRecExpr(Operands, AddRec->getLoop());
673  }
674
675  SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)];
676  if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty);
677  return Result;
678}
679
680SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op,
681                                              const Type *Ty) {
682  assert(getTargetData().getTypeSizeInBits(Op->getType()) <
683         getTargetData().getTypeSizeInBits(Ty) &&
684         "This is not an extending conversion!");
685
686  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
687    const Type *IntTy = Ty;
688    if (isa<PointerType>(IntTy)) IntTy = getTargetData().getIntPtrType();
689    Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy);
690    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
691    return getUnknown(C);
692  }
693
694  // FIXME: If the input value is a chrec scev, and we can prove that the value
695  // did not overflow the old, smaller, value, we can zero extend all of the
696  // operands (often constants).  This would allow analysis of something like
697  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
698
699  SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)];
700  if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty);
701  return Result;
702}
703
704SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type *Ty) {
705  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
706    const Type *IntTy = Ty;
707    if (isa<PointerType>(IntTy)) IntTy = getTargetData().getIntPtrType();
708    Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy);
709    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
710    return getUnknown(C);
711  }
712
713  // FIXME: If the input value is a chrec scev, and we can prove that the value
714  // did not overflow the old, smaller, value, we can sign extend all of the
715  // operands (often constants).  This would allow analysis of something like
716  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
717
718  SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)];
719  if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty);
720  return Result;
721}
722
723// get - Get a canonical add expression, or something simpler if possible.
724SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) {
725  assert(!Ops.empty() && "Cannot get empty add!");
726  if (Ops.size() == 1) return Ops[0];
727
728  // Sort by complexity, this groups all similar expression types together.
729  GroupByComplexity(Ops);
730
731  // If there are any constants, fold them together.
732  unsigned Idx = 0;
733  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
734    ++Idx;
735    assert(Idx < Ops.size());
736    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
737      // We found two constants, fold them together!
738      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() +
739                                           RHSC->getValue()->getValue());
740      Ops[0] = getConstant(Fold);
741      Ops.erase(Ops.begin()+1);  // Erase the folded element
742      if (Ops.size() == 1) return Ops[0];
743      LHSC = cast<SCEVConstant>(Ops[0]);
744    }
745
746    // If we are left with a constant zero being added, strip it off.
747    if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
748      Ops.erase(Ops.begin());
749      --Idx;
750    }
751  }
752
753  if (Ops.size() == 1) return Ops[0];
754
755  // Okay, check to see if the same value occurs in the operand list twice.  If
756  // so, merge them together into an multiply expression.  Since we sorted the
757  // list, these values are required to be adjacent.
758  const Type *Ty = Ops[0]->getType();
759  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
760    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
761      // Found a match, merge the two values into a multiply, and add any
762      // remaining values to the result.
763      SCEVHandle Two = getIntegerSCEV(2, Ty);
764      SCEVHandle Mul = getMulExpr(Ops[i], Two);
765      if (Ops.size() == 2)
766        return Mul;
767      Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
768      Ops.push_back(Mul);
769      return getAddExpr(Ops);
770    }
771
772  // Now we know the first non-constant operand.  Skip past any cast SCEVs.
773  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
774    ++Idx;
775
776  // If there are add operands they would be next.
777  if (Idx < Ops.size()) {
778    bool DeletedAdd = false;
779    while (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
780      // If we have an add, expand the add operands onto the end of the operands
781      // list.
782      Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
783      Ops.erase(Ops.begin()+Idx);
784      DeletedAdd = true;
785    }
786
787    // If we deleted at least one add, we added operands to the end of the list,
788    // and they are not necessarily sorted.  Recurse to resort and resimplify
789    // any operands we just aquired.
790    if (DeletedAdd)
791      return getAddExpr(Ops);
792  }
793
794  // Skip over the add expression until we get to a multiply.
795  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
796    ++Idx;
797
798  // If we are adding something to a multiply expression, make sure the
799  // something is not already an operand of the multiply.  If so, merge it into
800  // the multiply.
801  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
802    SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
803    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
804      SCEV *MulOpSCEV = Mul->getOperand(MulOp);
805      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
806        if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(MulOpSCEV)) {
807          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
808          SCEVHandle InnerMul = Mul->getOperand(MulOp == 0);
809          if (Mul->getNumOperands() != 2) {
810            // If the multiply has more than two operands, we must get the
811            // Y*Z term.
812            std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
813            MulOps.erase(MulOps.begin()+MulOp);
814            InnerMul = getMulExpr(MulOps);
815          }
816          SCEVHandle One = getIntegerSCEV(1, Ty);
817          SCEVHandle AddOne = getAddExpr(InnerMul, One);
818          SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]);
819          if (Ops.size() == 2) return OuterMul;
820          if (AddOp < Idx) {
821            Ops.erase(Ops.begin()+AddOp);
822            Ops.erase(Ops.begin()+Idx-1);
823          } else {
824            Ops.erase(Ops.begin()+Idx);
825            Ops.erase(Ops.begin()+AddOp-1);
826          }
827          Ops.push_back(OuterMul);
828          return getAddExpr(Ops);
829        }
830
831      // Check this multiply against other multiplies being added together.
832      for (unsigned OtherMulIdx = Idx+1;
833           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
834           ++OtherMulIdx) {
835        SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
836        // If MulOp occurs in OtherMul, we can fold the two multiplies
837        // together.
838        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
839             OMulOp != e; ++OMulOp)
840          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
841            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
842            SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0);
843            if (Mul->getNumOperands() != 2) {
844              std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
845              MulOps.erase(MulOps.begin()+MulOp);
846              InnerMul1 = getMulExpr(MulOps);
847            }
848            SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0);
849            if (OtherMul->getNumOperands() != 2) {
850              std::vector<SCEVHandle> MulOps(OtherMul->op_begin(),
851                                             OtherMul->op_end());
852              MulOps.erase(MulOps.begin()+OMulOp);
853              InnerMul2 = getMulExpr(MulOps);
854            }
855            SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
856            SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
857            if (Ops.size() == 2) return OuterMul;
858            Ops.erase(Ops.begin()+Idx);
859            Ops.erase(Ops.begin()+OtherMulIdx-1);
860            Ops.push_back(OuterMul);
861            return getAddExpr(Ops);
862          }
863      }
864    }
865  }
866
867  // If there are any add recurrences in the operands list, see if any other
868  // added values are loop invariant.  If so, we can fold them into the
869  // recurrence.
870  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
871    ++Idx;
872
873  // Scan over all recurrences, trying to fold loop invariants into them.
874  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
875    // Scan all of the other operands to this add and add them to the vector if
876    // they are loop invariant w.r.t. the recurrence.
877    std::vector<SCEVHandle> LIOps;
878    SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
879    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
880      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
881        LIOps.push_back(Ops[i]);
882        Ops.erase(Ops.begin()+i);
883        --i; --e;
884      }
885
886    // If we found some loop invariants, fold them into the recurrence.
887    if (!LIOps.empty()) {
888      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
889      LIOps.push_back(AddRec->getStart());
890
891      std::vector<SCEVHandle> AddRecOps(AddRec->op_begin(), AddRec->op_end());
892      AddRecOps[0] = getAddExpr(LIOps);
893
894      SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop());
895      // If all of the other operands were loop invariant, we are done.
896      if (Ops.size() == 1) return NewRec;
897
898      // Otherwise, add the folded AddRec by the non-liv parts.
899      for (unsigned i = 0;; ++i)
900        if (Ops[i] == AddRec) {
901          Ops[i] = NewRec;
902          break;
903        }
904      return getAddExpr(Ops);
905    }
906
907    // Okay, if there weren't any loop invariants to be folded, check to see if
908    // there are multiple AddRec's with the same loop induction variable being
909    // added together.  If so, we can fold them.
910    for (unsigned OtherIdx = Idx+1;
911         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
912      if (OtherIdx != Idx) {
913        SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
914        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
915          // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
916          std::vector<SCEVHandle> NewOps(AddRec->op_begin(), AddRec->op_end());
917          for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
918            if (i >= NewOps.size()) {
919              NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i,
920                            OtherAddRec->op_end());
921              break;
922            }
923            NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
924          }
925          SCEVHandle NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop());
926
927          if (Ops.size() == 2) return NewAddRec;
928
929          Ops.erase(Ops.begin()+Idx);
930          Ops.erase(Ops.begin()+OtherIdx-1);
931          Ops.push_back(NewAddRec);
932          return getAddExpr(Ops);
933        }
934      }
935
936    // Otherwise couldn't fold anything into this recurrence.  Move onto the
937    // next one.
938  }
939
940  // Okay, it looks like we really DO need an add expr.  Check to see if we
941  // already have one, otherwise create a new one.
942  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
943  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr,
944                                                                 SCEVOps)];
945  if (Result == 0) Result = new SCEVAddExpr(Ops);
946  return Result;
947}
948
949
950SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) {
951  assert(!Ops.empty() && "Cannot get empty mul!");
952
953  // Sort by complexity, this groups all similar expression types together.
954  GroupByComplexity(Ops);
955
956  // If there are any constants, fold them together.
957  unsigned Idx = 0;
958  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
959
960    // C1*(C2+V) -> C1*C2 + C1*V
961    if (Ops.size() == 2)
962      if (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
963        if (Add->getNumOperands() == 2 &&
964            isa<SCEVConstant>(Add->getOperand(0)))
965          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
966                            getMulExpr(LHSC, Add->getOperand(1)));
967
968
969    ++Idx;
970    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
971      // We found two constants, fold them together!
972      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() *
973                                           RHSC->getValue()->getValue());
974      Ops[0] = getConstant(Fold);
975      Ops.erase(Ops.begin()+1);  // Erase the folded element
976      if (Ops.size() == 1) return Ops[0];
977      LHSC = cast<SCEVConstant>(Ops[0]);
978    }
979
980    // If we are left with a constant one being multiplied, strip it off.
981    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
982      Ops.erase(Ops.begin());
983      --Idx;
984    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
985      // If we have a multiply of zero, it will always be zero.
986      return Ops[0];
987    }
988  }
989
990  // Skip over the add expression until we get to a multiply.
991  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
992    ++Idx;
993
994  if (Ops.size() == 1)
995    return Ops[0];
996
997  // If there are mul operands inline them all into this expression.
998  if (Idx < Ops.size()) {
999    bool DeletedMul = false;
1000    while (SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1001      // If we have an mul, expand the mul operands onto the end of the operands
1002      // list.
1003      Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end());
1004      Ops.erase(Ops.begin()+Idx);
1005      DeletedMul = true;
1006    }
1007
1008    // If we deleted at least one mul, we added operands to the end of the list,
1009    // and they are not necessarily sorted.  Recurse to resort and resimplify
1010    // any operands we just aquired.
1011    if (DeletedMul)
1012      return getMulExpr(Ops);
1013  }
1014
1015  // If there are any add recurrences in the operands list, see if any other
1016  // added values are loop invariant.  If so, we can fold them into the
1017  // recurrence.
1018  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1019    ++Idx;
1020
1021  // Scan over all recurrences, trying to fold loop invariants into them.
1022  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1023    // Scan all of the other operands to this mul and add them to the vector if
1024    // they are loop invariant w.r.t. the recurrence.
1025    std::vector<SCEVHandle> LIOps;
1026    SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1027    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1028      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1029        LIOps.push_back(Ops[i]);
1030        Ops.erase(Ops.begin()+i);
1031        --i; --e;
1032      }
1033
1034    // If we found some loop invariants, fold them into the recurrence.
1035    if (!LIOps.empty()) {
1036      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
1037      std::vector<SCEVHandle> NewOps;
1038      NewOps.reserve(AddRec->getNumOperands());
1039      if (LIOps.size() == 1) {
1040        SCEV *Scale = LIOps[0];
1041        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
1042          NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
1043      } else {
1044        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
1045          std::vector<SCEVHandle> MulOps(LIOps);
1046          MulOps.push_back(AddRec->getOperand(i));
1047          NewOps.push_back(getMulExpr(MulOps));
1048        }
1049      }
1050
1051      SCEVHandle NewRec = getAddRecExpr(NewOps, AddRec->getLoop());
1052
1053      // If all of the other operands were loop invariant, we are done.
1054      if (Ops.size() == 1) return NewRec;
1055
1056      // Otherwise, multiply the folded AddRec by the non-liv parts.
1057      for (unsigned i = 0;; ++i)
1058        if (Ops[i] == AddRec) {
1059          Ops[i] = NewRec;
1060          break;
1061        }
1062      return getMulExpr(Ops);
1063    }
1064
1065    // Okay, if there weren't any loop invariants to be folded, check to see if
1066    // there are multiple AddRec's with the same loop induction variable being
1067    // multiplied together.  If so, we can fold them.
1068    for (unsigned OtherIdx = Idx+1;
1069         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1070      if (OtherIdx != Idx) {
1071        SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1072        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1073          // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
1074          SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1075          SCEVHandle NewStart = getMulExpr(F->getStart(),
1076                                                 G->getStart());
1077          SCEVHandle B = F->getStepRecurrence(*this);
1078          SCEVHandle D = G->getStepRecurrence(*this);
1079          SCEVHandle NewStep = getAddExpr(getMulExpr(F, D),
1080                                          getMulExpr(G, B),
1081                                          getMulExpr(B, D));
1082          SCEVHandle NewAddRec = getAddRecExpr(NewStart, NewStep,
1083                                               F->getLoop());
1084          if (Ops.size() == 2) return NewAddRec;
1085
1086          Ops.erase(Ops.begin()+Idx);
1087          Ops.erase(Ops.begin()+OtherIdx-1);
1088          Ops.push_back(NewAddRec);
1089          return getMulExpr(Ops);
1090        }
1091      }
1092
1093    // Otherwise couldn't fold anything into this recurrence.  Move onto the
1094    // next one.
1095  }
1096
1097  // Okay, it looks like we really DO need an mul expr.  Check to see if we
1098  // already have one, otherwise create a new one.
1099  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1100  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr,
1101                                                                 SCEVOps)];
1102  if (Result == 0)
1103    Result = new SCEVMulExpr(Ops);
1104  return Result;
1105}
1106
1107SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) {
1108  if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1109    if (RHSC->getValue()->equalsInt(1))
1110      return LHS;                            // X udiv 1 --> x
1111
1112    if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1113      Constant *LHSCV = LHSC->getValue();
1114      Constant *RHSCV = RHSC->getValue();
1115      return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV));
1116    }
1117  }
1118
1119  // FIXME: implement folding of (X*4)/4 when we know X*4 doesn't overflow.
1120
1121  SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)];
1122  if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS);
1123  return Result;
1124}
1125
1126
1127/// SCEVAddRecExpr::get - Get a add recurrence expression for the
1128/// specified loop.  Simplify the expression as much as possible.
1129SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start,
1130                               const SCEVHandle &Step, const Loop *L) {
1131  std::vector<SCEVHandle> Operands;
1132  Operands.push_back(Start);
1133  if (SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
1134    if (StepChrec->getLoop() == L) {
1135      Operands.insert(Operands.end(), StepChrec->op_begin(),
1136                      StepChrec->op_end());
1137      return getAddRecExpr(Operands, L);
1138    }
1139
1140  Operands.push_back(Step);
1141  return getAddRecExpr(Operands, L);
1142}
1143
1144/// SCEVAddRecExpr::get - Get a add recurrence expression for the
1145/// specified loop.  Simplify the expression as much as possible.
1146SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands,
1147                               const Loop *L) {
1148  if (Operands.size() == 1) return Operands[0];
1149
1150  if (Operands.back()->isZero()) {
1151    Operands.pop_back();
1152    return getAddRecExpr(Operands, L);             // {X,+,0}  -->  X
1153  }
1154
1155  // Canonicalize nested AddRecs in by nesting them in order of loop depth.
1156  if (SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
1157    const Loop* NestedLoop = NestedAR->getLoop();
1158    if (L->getLoopDepth() < NestedLoop->getLoopDepth()) {
1159      std::vector<SCEVHandle> NestedOperands(NestedAR->op_begin(),
1160                                             NestedAR->op_end());
1161      SCEVHandle NestedARHandle(NestedAR);
1162      Operands[0] = NestedAR->getStart();
1163      NestedOperands[0] = getAddRecExpr(Operands, L);
1164      return getAddRecExpr(NestedOperands, NestedLoop);
1165    }
1166  }
1167
1168  SCEVAddRecExpr *&Result =
1169    (*SCEVAddRecExprs)[std::make_pair(L, std::vector<SCEV*>(Operands.begin(),
1170                                                            Operands.end()))];
1171  if (Result == 0) Result = new SCEVAddRecExpr(Operands, L);
1172  return Result;
1173}
1174
1175SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS,
1176                                        const SCEVHandle &RHS) {
1177  std::vector<SCEVHandle> Ops;
1178  Ops.push_back(LHS);
1179  Ops.push_back(RHS);
1180  return getSMaxExpr(Ops);
1181}
1182
1183SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
1184  assert(!Ops.empty() && "Cannot get empty smax!");
1185  if (Ops.size() == 1) return Ops[0];
1186
1187  // Sort by complexity, this groups all similar expression types together.
1188  GroupByComplexity(Ops);
1189
1190  // If there are any constants, fold them together.
1191  unsigned Idx = 0;
1192  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1193    ++Idx;
1194    assert(Idx < Ops.size());
1195    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1196      // We found two constants, fold them together!
1197      ConstantInt *Fold = ConstantInt::get(
1198                              APIntOps::smax(LHSC->getValue()->getValue(),
1199                                             RHSC->getValue()->getValue()));
1200      Ops[0] = getConstant(Fold);
1201      Ops.erase(Ops.begin()+1);  // Erase the folded element
1202      if (Ops.size() == 1) return Ops[0];
1203      LHSC = cast<SCEVConstant>(Ops[0]);
1204    }
1205
1206    // If we are left with a constant -inf, strip it off.
1207    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
1208      Ops.erase(Ops.begin());
1209      --Idx;
1210    }
1211  }
1212
1213  if (Ops.size() == 1) return Ops[0];
1214
1215  // Find the first SMax
1216  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
1217    ++Idx;
1218
1219  // Check to see if one of the operands is an SMax. If so, expand its operands
1220  // onto our operand list, and recurse to simplify.
1221  if (Idx < Ops.size()) {
1222    bool DeletedSMax = false;
1223    while (SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
1224      Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
1225      Ops.erase(Ops.begin()+Idx);
1226      DeletedSMax = true;
1227    }
1228
1229    if (DeletedSMax)
1230      return getSMaxExpr(Ops);
1231  }
1232
1233  // Okay, check to see if the same value occurs in the operand list twice.  If
1234  // so, delete one.  Since we sorted the list, these values are required to
1235  // be adjacent.
1236  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1237    if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
1238      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1239      --i; --e;
1240    }
1241
1242  if (Ops.size() == 1) return Ops[0];
1243
1244  assert(!Ops.empty() && "Reduced smax down to nothing!");
1245
1246  // Okay, it looks like we really DO need an smax expr.  Check to see if we
1247  // already have one, otherwise create a new one.
1248  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1249  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
1250                                                                 SCEVOps)];
1251  if (Result == 0) Result = new SCEVSMaxExpr(Ops);
1252  return Result;
1253}
1254
1255SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS,
1256                                        const SCEVHandle &RHS) {
1257  std::vector<SCEVHandle> Ops;
1258  Ops.push_back(LHS);
1259  Ops.push_back(RHS);
1260  return getUMaxExpr(Ops);
1261}
1262
1263SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) {
1264  assert(!Ops.empty() && "Cannot get empty umax!");
1265  if (Ops.size() == 1) return Ops[0];
1266
1267  // Sort by complexity, this groups all similar expression types together.
1268  GroupByComplexity(Ops);
1269
1270  // If there are any constants, fold them together.
1271  unsigned Idx = 0;
1272  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1273    ++Idx;
1274    assert(Idx < Ops.size());
1275    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1276      // We found two constants, fold them together!
1277      ConstantInt *Fold = ConstantInt::get(
1278                              APIntOps::umax(LHSC->getValue()->getValue(),
1279                                             RHSC->getValue()->getValue()));
1280      Ops[0] = getConstant(Fold);
1281      Ops.erase(Ops.begin()+1);  // Erase the folded element
1282      if (Ops.size() == 1) return Ops[0];
1283      LHSC = cast<SCEVConstant>(Ops[0]);
1284    }
1285
1286    // If we are left with a constant zero, strip it off.
1287    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
1288      Ops.erase(Ops.begin());
1289      --Idx;
1290    }
1291  }
1292
1293  if (Ops.size() == 1) return Ops[0];
1294
1295  // Find the first UMax
1296  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
1297    ++Idx;
1298
1299  // Check to see if one of the operands is a UMax. If so, expand its operands
1300  // onto our operand list, and recurse to simplify.
1301  if (Idx < Ops.size()) {
1302    bool DeletedUMax = false;
1303    while (SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
1304      Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
1305      Ops.erase(Ops.begin()+Idx);
1306      DeletedUMax = true;
1307    }
1308
1309    if (DeletedUMax)
1310      return getUMaxExpr(Ops);
1311  }
1312
1313  // Okay, check to see if the same value occurs in the operand list twice.  If
1314  // so, delete one.  Since we sorted the list, these values are required to
1315  // be adjacent.
1316  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1317    if (Ops[i] == Ops[i+1]) {      //  X umax Y umax Y  -->  X umax Y
1318      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1319      --i; --e;
1320    }
1321
1322  if (Ops.size() == 1) return Ops[0];
1323
1324  assert(!Ops.empty() && "Reduced umax down to nothing!");
1325
1326  // Okay, it looks like we really DO need a umax expr.  Check to see if we
1327  // already have one, otherwise create a new one.
1328  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1329  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr,
1330                                                                 SCEVOps)];
1331  if (Result == 0) Result = new SCEVUMaxExpr(Ops);
1332  return Result;
1333}
1334
1335SCEVHandle ScalarEvolution::getUnknown(Value *V) {
1336  if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
1337    return getConstant(CI);
1338  if (isa<ConstantPointerNull>(V))
1339    return getIntegerSCEV(0, V->getType());
1340  SCEVUnknown *&Result = (*SCEVUnknowns)[V];
1341  if (Result == 0) Result = new SCEVUnknown(V);
1342  return Result;
1343}
1344
1345//===----------------------------------------------------------------------===//
1346//             ScalarEvolutionsImpl Definition and Implementation
1347//===----------------------------------------------------------------------===//
1348//
1349/// ScalarEvolutionsImpl - This class implements the main driver for the scalar
1350/// evolution code.
1351///
1352namespace {
1353  struct VISIBILITY_HIDDEN ScalarEvolutionsImpl {
1354    /// SE - A reference to the public ScalarEvolution object.
1355    ScalarEvolution &SE;
1356
1357    /// F - The function we are analyzing.
1358    ///
1359    Function &F;
1360
1361    /// LI - The loop information for the function we are currently analyzing.
1362    ///
1363    LoopInfo &LI;
1364
1365    /// TD - The target data information for the target we are targetting.
1366    ///
1367    TargetData &TD;
1368
1369    /// UnknownValue - This SCEV is used to represent unknown trip counts and
1370    /// things.
1371    SCEVHandle UnknownValue;
1372
1373    /// Scalars - This is a cache of the scalars we have analyzed so far.
1374    ///
1375    std::map<Value*, SCEVHandle> Scalars;
1376
1377    /// BackedgeTakenCounts - Cache the backedge-taken count of the loops for
1378    /// this function as they are computed.
1379    std::map<const Loop*, SCEVHandle> BackedgeTakenCounts;
1380
1381    /// ConstantEvolutionLoopExitValue - This map contains entries for all of
1382    /// the PHI instructions that we attempt to compute constant evolutions for.
1383    /// This allows us to avoid potentially expensive recomputation of these
1384    /// properties.  An instruction maps to null if we are unable to compute its
1385    /// exit value.
1386    std::map<PHINode*, Constant*> ConstantEvolutionLoopExitValue;
1387
1388  public:
1389    ScalarEvolutionsImpl(ScalarEvolution &se, Function &f, LoopInfo &li,
1390                         TargetData &td)
1391      : SE(se), F(f), LI(li), TD(td), UnknownValue(new SCEVCouldNotCompute()) {}
1392
1393    SCEVHandle getCouldNotCompute();
1394
1395    /// getIntegerSCEV - Given an integer or FP type, create a constant for the
1396    /// specified signed integer value and return a SCEV for the constant.
1397    SCEVHandle getIntegerSCEV(int Val, const Type *Ty);
1398
1399    /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
1400    ///
1401    SCEVHandle getNegativeSCEV(const SCEVHandle &V);
1402
1403    /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
1404    ///
1405    SCEVHandle getNotSCEV(const SCEVHandle &V);
1406
1407    /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
1408    ///
1409    SCEVHandle getMinusSCEV(const SCEVHandle &LHS, const SCEVHandle &RHS);
1410
1411    /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion
1412    /// of the input value to the specified type.  If the type must be extended,
1413    /// it is zero extended.
1414    SCEVHandle getTruncateOrZeroExtend(const SCEVHandle &V, const Type *Ty);
1415
1416    /// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion
1417    /// of the input value to the specified type.  If the type must be extended,
1418    /// it is sign extended.
1419    SCEVHandle getTruncateOrSignExtend(const SCEVHandle &V, const Type *Ty);
1420
1421    /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
1422    /// expression and create a new one.
1423    SCEVHandle getSCEV(Value *V);
1424
1425    /// hasSCEV - Return true if the SCEV for this value has already been
1426    /// computed.
1427    bool hasSCEV(Value *V) const {
1428      return Scalars.count(V);
1429    }
1430
1431    /// setSCEV - Insert the specified SCEV into the map of current SCEVs for
1432    /// the specified value.
1433    void setSCEV(Value *V, const SCEVHandle &H) {
1434      bool isNew = Scalars.insert(std::make_pair(V, H)).second;
1435      assert(isNew && "This entry already existed!");
1436      isNew = false;
1437    }
1438
1439
1440    /// getSCEVAtScope - Compute the value of the specified expression within
1441    /// the indicated loop (which may be null to indicate in no loop).  If the
1442    /// expression cannot be evaluated, return UnknownValue itself.
1443    SCEVHandle getSCEVAtScope(SCEV *V, const Loop *L);
1444
1445
1446    /// isLoopGuardedByCond - Test whether entry to the loop is protected by
1447    /// a conditional between LHS and RHS.
1448    bool isLoopGuardedByCond(const Loop *L, ICmpInst::Predicate Pred,
1449                             SCEV *LHS, SCEV *RHS);
1450
1451    /// hasLoopInvariantBackedgeTakenCount - Return true if the specified loop
1452    /// has an analyzable loop-invariant backedge-taken count.
1453    bool hasLoopInvariantBackedgeTakenCount(const Loop *L);
1454
1455    /// forgetLoopBackedgeTakenCount - This method should be called by the
1456    /// client when it has changed a loop in a way that may effect
1457    /// ScalarEvolution's ability to compute a trip count, or if the loop
1458    /// is deleted.
1459    void forgetLoopBackedgeTakenCount(const Loop *L);
1460
1461    /// getBackedgeTakenCount - If the specified loop has a predictable
1462    /// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
1463    /// object. The backedge-taken count is the number of times the loop header
1464    /// will be branched to from within the loop. This is one less than the
1465    /// trip count of the loop, since it doesn't count the first iteration,
1466    /// when the header is branched to from outside the loop.
1467    ///
1468    /// Note that it is not valid to call this method on a loop without a
1469    /// loop-invariant backedge-taken count (see
1470    /// hasLoopInvariantBackedgeTakenCount).
1471    ///
1472    SCEVHandle getBackedgeTakenCount(const Loop *L);
1473
1474    /// deleteValueFromRecords - This method should be called by the
1475    /// client before it removes a value from the program, to make sure
1476    /// that no dangling references are left around.
1477    void deleteValueFromRecords(Value *V);
1478
1479    /// getTargetData - Return the TargetData.
1480    const TargetData &getTargetData() const;
1481
1482  private:
1483    /// createSCEV - We know that there is no SCEV for the specified value.
1484    /// Analyze the expression.
1485    SCEVHandle createSCEV(Value *V);
1486
1487    /// createNodeForPHI - Provide the special handling we need to analyze PHI
1488    /// SCEVs.
1489    SCEVHandle createNodeForPHI(PHINode *PN);
1490
1491    /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value
1492    /// for the specified instruction and replaces any references to the
1493    /// symbolic value SymName with the specified value.  This is used during
1494    /// PHI resolution.
1495    void ReplaceSymbolicValueWithConcrete(Instruction *I,
1496                                          const SCEVHandle &SymName,
1497                                          const SCEVHandle &NewVal);
1498
1499    /// ComputeBackedgeTakenCount - Compute the number of times the specified
1500    /// loop will iterate.
1501    SCEVHandle ComputeBackedgeTakenCount(const Loop *L);
1502
1503    /// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition
1504    /// of 'icmp op load X, cst', try to see if we can compute the trip count.
1505    SCEVHandle
1506      ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI,
1507                                                   Constant *RHS,
1508                                                   const Loop *L,
1509                                                   ICmpInst::Predicate p);
1510
1511    /// ComputeBackedgeTakenCountExhaustively - If the trip is known to execute
1512    /// a constant number of times (the condition evolves only from constants),
1513    /// try to evaluate a few iterations of the loop until we get the exit
1514    /// condition gets a value of ExitWhen (true or false).  If we cannot
1515    /// evaluate the trip count of the loop, return UnknownValue.
1516    SCEVHandle ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond,
1517                                                     bool ExitWhen);
1518
1519    /// HowFarToZero - Return the number of times a backedge comparing the
1520    /// specified value to zero will execute.  If not computable, return
1521    /// UnknownValue.
1522    SCEVHandle HowFarToZero(SCEV *V, const Loop *L);
1523
1524    /// HowFarToNonZero - Return the number of times a backedge checking the
1525    /// specified value for nonzero will execute.  If not computable, return
1526    /// UnknownValue.
1527    SCEVHandle HowFarToNonZero(SCEV *V, const Loop *L);
1528
1529    /// HowManyLessThans - Return the number of times a backedge containing the
1530    /// specified less-than comparison will execute.  If not computable, return
1531    /// UnknownValue. isSigned specifies whether the less-than is signed.
1532    SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L,
1533                                bool isSigned);
1534
1535    /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
1536    /// (which may not be an immediate predecessor) which has exactly one
1537    /// successor from which BB is reachable, or null if no such block is
1538    /// found.
1539    BasicBlock* getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB);
1540
1541    /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
1542    /// in the header of its containing loop, we know the loop executes a
1543    /// constant number of times, and the PHI node is just a recurrence
1544    /// involving constants, fold it.
1545    Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs,
1546                                                const Loop *L);
1547  };
1548}
1549
1550//===----------------------------------------------------------------------===//
1551//            Basic SCEV Analysis and PHI Idiom Recognition Code
1552//
1553
1554/// deleteValueFromRecords - This method should be called by the
1555/// client before it removes an instruction from the program, to make sure
1556/// that no dangling references are left around.
1557void ScalarEvolutionsImpl::deleteValueFromRecords(Value *V) {
1558  SmallVector<Value *, 16> Worklist;
1559
1560  if (Scalars.erase(V)) {
1561    if (PHINode *PN = dyn_cast<PHINode>(V))
1562      ConstantEvolutionLoopExitValue.erase(PN);
1563    Worklist.push_back(V);
1564  }
1565
1566  while (!Worklist.empty()) {
1567    Value *VV = Worklist.back();
1568    Worklist.pop_back();
1569
1570    for (Instruction::use_iterator UI = VV->use_begin(), UE = VV->use_end();
1571         UI != UE; ++UI) {
1572      Instruction *Inst = cast<Instruction>(*UI);
1573      if (Scalars.erase(Inst)) {
1574        if (PHINode *PN = dyn_cast<PHINode>(VV))
1575          ConstantEvolutionLoopExitValue.erase(PN);
1576        Worklist.push_back(Inst);
1577      }
1578    }
1579  }
1580}
1581
1582const TargetData &ScalarEvolutionsImpl::getTargetData() const {
1583  return TD;
1584}
1585
1586SCEVHandle ScalarEvolutionsImpl::getCouldNotCompute() {
1587  return UnknownValue;
1588}
1589
1590/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
1591/// expression and create a new one.
1592SCEVHandle ScalarEvolutionsImpl::getSCEV(Value *V) {
1593  assert(V->getType() != Type::VoidTy && "Can't analyze void expressions!");
1594
1595  std::map<Value*, SCEVHandle>::iterator I = Scalars.find(V);
1596  if (I != Scalars.end()) return I->second;
1597  SCEVHandle S = createSCEV(V);
1598  Scalars.insert(std::make_pair(V, S));
1599  return S;
1600}
1601
1602/// getIntegerSCEV - Given an integer or FP type, create a constant for the
1603/// specified signed integer value and return a SCEV for the constant.
1604SCEVHandle ScalarEvolutionsImpl::getIntegerSCEV(int Val, const Type *Ty) {
1605  if (isa<PointerType>(Ty))
1606    Ty = TD.getIntPtrType();
1607  Constant *C;
1608  if (Val == 0)
1609    C = Constant::getNullValue(Ty);
1610  else if (Ty->isFloatingPoint())
1611    C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle :
1612                                APFloat::IEEEdouble, Val));
1613  else
1614    C = ConstantInt::get(Ty, Val);
1615  return SE.getUnknown(C);
1616}
1617
1618/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
1619///
1620SCEVHandle ScalarEvolutionsImpl::getNegativeSCEV(const SCEVHandle &V) {
1621  if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
1622    return SE.getUnknown(ConstantExpr::getNeg(VC->getValue()));
1623
1624  const Type *Ty = V->getType();
1625  if (isa<PointerType>(Ty))
1626    Ty = TD.getIntPtrType();
1627  return SE.getMulExpr(V, SE.getConstant(ConstantInt::getAllOnesValue(Ty)));
1628}
1629
1630/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
1631SCEVHandle ScalarEvolutionsImpl::getNotSCEV(const SCEVHandle &V) {
1632  if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
1633    return SE.getUnknown(ConstantExpr::getNot(VC->getValue()));
1634
1635  const Type *Ty = V->getType();
1636  if (isa<PointerType>(Ty))
1637    Ty = TD.getIntPtrType();
1638  SCEVHandle AllOnes = SE.getConstant(ConstantInt::getAllOnesValue(Ty));
1639  return getMinusSCEV(AllOnes, V);
1640}
1641
1642/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
1643///
1644SCEVHandle ScalarEvolutionsImpl::getMinusSCEV(const SCEVHandle &LHS,
1645                                              const SCEVHandle &RHS) {
1646  // X - Y --> X + -Y
1647  return SE.getAddExpr(LHS, SE.getNegativeSCEV(RHS));
1648}
1649
1650/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
1651/// input value to the specified type.  If the type must be extended, it is zero
1652/// extended.
1653SCEVHandle
1654ScalarEvolutionsImpl::getTruncateOrZeroExtend(const SCEVHandle &V,
1655                                              const Type *Ty) {
1656  const Type *SrcTy = V->getType();
1657  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
1658         (Ty->isInteger() || isa<PointerType>(Ty)) &&
1659         "Cannot truncate or zero extend with non-integer arguments!");
1660  if (TD.getTypeSizeInBits(SrcTy) == TD.getTypeSizeInBits(Ty))
1661    return V;  // No conversion
1662  if (TD.getTypeSizeInBits(SrcTy) > TD.getTypeSizeInBits(Ty))
1663    return SE.getTruncateExpr(V, Ty);
1664  return SE.getZeroExtendExpr(V, Ty);
1665}
1666
1667/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
1668/// input value to the specified type.  If the type must be extended, it is sign
1669/// extended.
1670SCEVHandle
1671ScalarEvolutionsImpl::getTruncateOrSignExtend(const SCEVHandle &V,
1672                                              const Type *Ty) {
1673  const Type *SrcTy = V->getType();
1674  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
1675         (Ty->isInteger() || isa<PointerType>(Ty)) &&
1676         "Cannot truncate or zero extend with non-integer arguments!");
1677  if (TD.getTypeSizeInBits(SrcTy) == TD.getTypeSizeInBits(Ty))
1678    return V;  // No conversion
1679  if (TD.getTypeSizeInBits(SrcTy) > TD.getTypeSizeInBits(Ty))
1680    return SE.getTruncateExpr(V, Ty);
1681  return SE.getSignExtendExpr(V, Ty);
1682}
1683
1684/// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
1685/// the specified instruction and replaces any references to the symbolic value
1686/// SymName with the specified value.  This is used during PHI resolution.
1687void ScalarEvolutionsImpl::
1688ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName,
1689                                 const SCEVHandle &NewVal) {
1690  std::map<Value*, SCEVHandle>::iterator SI = Scalars.find(I);
1691  if (SI == Scalars.end()) return;
1692
1693  SCEVHandle NV =
1694    SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, SE);
1695  if (NV == SI->second) return;  // No change.
1696
1697  SI->second = NV;       // Update the scalars map!
1698
1699  // Any instruction values that use this instruction might also need to be
1700  // updated!
1701  for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
1702       UI != E; ++UI)
1703    ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal);
1704}
1705
/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
/// a loop header, making it a potential recurrence, or it doesn't.
///
/// For a two-input PHI in the header of its loop, the PHI is first recorded in
/// Scalars as an opaque SCEVUnknown ("symbolic name").  The backedge value is
/// then analyzed in terms of that name.  Two shapes are recognized:
///   - an add expression containing the symbolic name, whose remaining
///     operands (the step) are loop-invariant or an addrec of the same loop;
///   - an affine addrec of the same loop whose start minus step equals the
///     PHI's entry value.
/// In either case the PHI becomes an addrec, and ReplaceSymbolicValueWithConcrete
/// rewrites the cached entries that referenced the symbolic name.  Otherwise
/// the symbolic name is returned.  PHIs that are not of this two-input,
/// loop-header form are returned as SCEVUnknown.
///
SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) {
  if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
    if (const Loop *L = LI.getLoopFor(PN->getParent()))
      if (L->getHeader() == PN->getParent()) {
        // If it lives in the loop header, it has two incoming values, one
        // from outside the loop, and one from inside.
        // contains() gives 0/1: if incoming block 0 is inside the loop, the
        // entry value is operand 1; otherwise it is operand 0.  The backedge
        // is the other operand.
        unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
        unsigned BackEdge     = IncomingEdge^1;

        // While we are analyzing this PHI node, handle its value symbolically.
        // Recording it in Scalars first means any recursive getSCEV(PN)
        // reached from the backedge value sees the symbolic name instead of
        // re-entering createSCEV.
        SCEVHandle SymbolicName = SE.getUnknown(PN);
        assert(Scalars.find(PN) == Scalars.end() &&
               "PHI node already processed?");
        Scalars.insert(std::make_pair(PN, SymbolicName));

        // Using this symbolic name for the PHI, analyze the value coming around
        // the back-edge.
        SCEVHandle BEValue = getSCEV(PN->getIncomingValue(BackEdge));

        // NOTE: If BEValue is loop invariant, we know that the PHI node just
        // has a special value for the first iteration of the loop.

        // If the value coming around the backedge is an add with the symbolic
        // value we just inserted, then we found a simple induction variable!
        if (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
          // If there is a single occurrence of the symbolic value, replace it
          // with a recurrence.
          // NOTE(review): this loop records the first matching operand and
          // breaks; it does not itself check that no second occurrence exists.
          unsigned FoundIndex = Add->getNumOperands();
          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
            if (Add->getOperand(i) == SymbolicName)
              if (FoundIndex == e) {
                FoundIndex = i;
                break;
              }

          if (FoundIndex != Add->getNumOperands()) {
            // Create an add with everything but the specified operand.
            std::vector<SCEVHandle> Ops;
            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
              if (i != FoundIndex)
                Ops.push_back(Add->getOperand(i));
            SCEVHandle Accum = SE.getAddExpr(Ops);

            // This is not a valid addrec if the step amount is varying each
            // loop iteration, but is not itself an addrec in this loop.
            if (Accum->isLoopInvariant(L) ||
                (isa<SCEVAddRecExpr>(Accum) &&
                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
              SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
              SCEVHandle PHISCEV  = SE.getAddRecExpr(StartVal, Accum, L);

              // Okay, for the entire analysis of this edge we assumed the PHI
              // to be symbolic.  We now need to go back and update all of the
              // entries for the scalars that use the PHI (except for the PHI
              // itself) to use the new analyzed value instead of the "symbolic"
              // value.
              // NOTE(review): the walk starts at PN, whose own Scalars entry
              // is SymbolicName, so PN's entry is presumably rewritten to
              // PHISCEV as well; confirm against
              // SCEVUnknown::replaceSymbolicValuesWithConcrete.
              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
              return PHISCEV;
            }
          }
        } else if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(BEValue)) {
          // Otherwise, this could be a loop like this:
          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
          // In this case, j = {1,+,1}  and BEValue is j.
          // Because the other in-value of i (0) fits the evolution of BEValue
          // i really is an addrec evolution.
          if (AddRec->getLoop() == L && AddRec->isAffine()) {
            SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));

            // If StartVal = j.start - j.stride, we can use StartVal as the
            // initial step of the addrec evolution.
            if (StartVal == SE.getMinusSCEV(AddRec->getOperand(0),
                                            AddRec->getOperand(1))) {
              SCEVHandle PHISCEV =
                 SE.getAddRecExpr(StartVal, AddRec->getOperand(1), L);

              // Okay, for the entire analysis of this edge we assumed the PHI
              // to be symbolic.  We now need to go back and update all of the
              // entries for the scalars that use the PHI (except for the PHI
              // itself) to use the new analyzed value instead of the "symbolic"
              // value.
              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
              return PHISCEV;
            }
          }
        }

        // Not a recognized recurrence: PN stays as the opaque symbolic value
        // already recorded in Scalars above.
        return SymbolicName;
      }

  // If it's not a loop phi, we can't handle it yet.
  return SE.getUnknown(PN);
}
1802
1803/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
1804/// guaranteed to end in (at every loop iteration).  It is, at the same time,
1805/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
1806/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
1807static uint32_t GetMinTrailingZeros(SCEVHandle S, const TargetData &TD) {
1808  if (SCEVConstant *C = dyn_cast<SCEVConstant>(S))
1809    return C->getValue()->getValue().countTrailingZeros();
1810
1811  if (SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
1812    return std::min(GetMinTrailingZeros(T->getOperand(), TD),
1813                    (uint32_t)TD.getTypeSizeInBits(T->getType()));
1814
1815  if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
1816    uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), TD);
1817    return OpRes == TD.getTypeSizeInBits(E->getOperand()->getType()) ?
1818             TD.getTypeSizeInBits(E->getOperand()->getType()) : OpRes;
1819  }
1820
1821  if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
1822    uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), TD);
1823    return OpRes == TD.getTypeSizeInBits(E->getOperand()->getType()) ?
1824             TD.getTypeSizeInBits(E->getOperand()->getType()) : OpRes;
1825  }
1826
1827  if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
1828    // The result is the min of all operands results.
1829    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), TD);
1830    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
1831      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), TD));
1832    return MinOpRes;
1833  }
1834
1835  if (SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
1836    // The result is the sum of all operands results.
1837    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0), TD);
1838    uint32_t BitWidth = TD.getTypeSizeInBits(M->getType());
1839    for (unsigned i = 1, e = M->getNumOperands();
1840         SumOpRes != BitWidth && i != e; ++i)
1841      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i), TD),
1842                          BitWidth);
1843    return SumOpRes;
1844  }
1845
1846  if (SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
1847    // The result is the min of all operands results.
1848    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), TD);
1849    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
1850      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), TD));
1851    return MinOpRes;
1852  }
1853
1854  if (SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
1855    // The result is the min of all operands results.
1856    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), TD);
1857    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
1858      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), TD));
1859    return MinOpRes;
1860  }
1861
1862  if (SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
1863    // The result is the min of all operands results.
1864    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), TD);
1865    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
1866      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), TD));
1867    return MinOpRes;
1868  }
1869
1870  // SCEVUDivExpr, SCEVUnknown
1871  return 0;
1872}
1873
1874/// createSCEV - We know that there is no SCEV for the specified value.
1875/// Analyze the expression.
1876///
1877SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
1878  if (!isa<IntegerType>(V->getType()) &&
1879      !isa<PointerType>(V->getType()))
1880    return SE.getUnknown(V);
1881
1882  unsigned Opcode = Instruction::UserOp1;
1883  if (Instruction *I = dyn_cast<Instruction>(V))
1884    Opcode = I->getOpcode();
1885  else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
1886    Opcode = CE->getOpcode();
1887  else
1888    return SE.getUnknown(V);
1889
1890  User *U = cast<User>(V);
1891  switch (Opcode) {
1892  case Instruction::Add:
1893    return SE.getAddExpr(getSCEV(U->getOperand(0)),
1894                         getSCEV(U->getOperand(1)));
1895  case Instruction::Mul:
1896    return SE.getMulExpr(getSCEV(U->getOperand(0)),
1897                         getSCEV(U->getOperand(1)));
1898  case Instruction::UDiv:
1899    return SE.getUDivExpr(getSCEV(U->getOperand(0)),
1900                          getSCEV(U->getOperand(1)));
1901  case Instruction::Sub:
1902    return SE.getMinusSCEV(getSCEV(U->getOperand(0)),
1903                           getSCEV(U->getOperand(1)));
1904  case Instruction::Or:
1905    // If the RHS of the Or is a constant, we may have something like:
1906    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
1907    // optimizations will transparently handle this case.
1908    //
1909    // In order for this transformation to be safe, the LHS must be of the
1910    // form X*(2^n) and the Or constant must be less than 2^n.
1911    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
1912      SCEVHandle LHS = getSCEV(U->getOperand(0));
1913      const APInt &CIVal = CI->getValue();
1914      if (GetMinTrailingZeros(LHS, TD) >=
1915          (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
1916        return SE.getAddExpr(LHS, getSCEV(U->getOperand(1)));
1917    }
1918    break;
1919  case Instruction::Xor:
1920    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
1921      // If the RHS of the xor is a signbit, then this is just an add.
1922      // Instcombine turns add of signbit into xor as a strength reduction step.
1923      if (CI->getValue().isSignBit())
1924        return SE.getAddExpr(getSCEV(U->getOperand(0)),
1925                             getSCEV(U->getOperand(1)));
1926
1927      // If the RHS of xor is -1, then this is a not operation.
1928      else if (CI->isAllOnesValue())
1929        return SE.getNotSCEV(getSCEV(U->getOperand(0)));
1930    }
1931    break;
1932
1933  case Instruction::Shl:
1934    // Turn shift left of a constant amount into a multiply.
1935    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
1936      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
1937      Constant *X = ConstantInt::get(
1938        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
1939      return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
1940    }
1941    break;
1942
1943  case Instruction::LShr:
1944    // Turn logical shift right of a constant into a unsigned divide.
1945    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
1946      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
1947      Constant *X = ConstantInt::get(
1948        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
1949      return SE.getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
1950    }
1951    break;
1952
1953  case Instruction::Trunc:
1954    return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
1955
1956  case Instruction::ZExt:
1957    return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
1958
1959  case Instruction::SExt:
1960    return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
1961
1962  case Instruction::BitCast:
1963    // BitCasts are no-op casts so we just eliminate the cast.
1964    if ((U->getType()->isInteger() ||
1965         isa<PointerType>(U->getType())) &&
1966        (U->getOperand(0)->getType()->isInteger() ||
1967         isa<PointerType>(U->getOperand(0)->getType())))
1968      return getSCEV(U->getOperand(0));
1969    break;
1970
1971  case Instruction::IntToPtr:
1972    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
1973                                   TD.getIntPtrType());
1974
1975  case Instruction::PtrToInt:
1976    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
1977                                   U->getType());
1978
1979  case Instruction::GetElementPtr: {
1980    const Type *IntPtrTy = TD.getIntPtrType();
1981    Value *Base = U->getOperand(0);
1982    SCEVHandle TotalOffset = SE.getIntegerSCEV(0, IntPtrTy);
1983    gep_type_iterator GTI = gep_type_begin(U);
1984    for (GetElementPtrInst::op_iterator I = next(U->op_begin()),
1985                                        E = U->op_end();
1986         I != E; ++I) {
1987      Value *Index = *I;
1988      // Compute the (potentially symbolic) offset in bytes for this index.
1989      if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
1990        // For a struct, add the member offset.
1991        const StructLayout &SL = *TD.getStructLayout(STy);
1992        unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
1993        uint64_t Offset = SL.getElementOffset(FieldNo);
1994        TotalOffset = SE.getAddExpr(TotalOffset,
1995                                    SE.getIntegerSCEV(Offset, IntPtrTy));
1996      } else {
1997        // For an array, add the element offset, explicitly scaled.
1998        SCEVHandle LocalOffset = getSCEV(Index);
1999        if (!isa<PointerType>(LocalOffset->getType()))
2000          // Getelementptr indicies are signed.
2001          LocalOffset = getTruncateOrSignExtend(LocalOffset,
2002                                                IntPtrTy);
2003        LocalOffset =
2004          SE.getMulExpr(LocalOffset,
2005                        SE.getIntegerSCEV(TD.getTypePaddedSize(*GTI),
2006                                          IntPtrTy));
2007        TotalOffset = SE.getAddExpr(TotalOffset, LocalOffset);
2008      }
2009    }
2010    return SE.getAddExpr(getSCEV(Base), TotalOffset);
2011  }
2012
2013  case Instruction::PHI:
2014    return createNodeForPHI(cast<PHINode>(U));
2015
2016  case Instruction::Select:
2017    // This could be a smax or umax that was lowered earlier.
2018    // Try to recover it.
2019    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
2020      Value *LHS = ICI->getOperand(0);
2021      Value *RHS = ICI->getOperand(1);
2022      switch (ICI->getPredicate()) {
2023      case ICmpInst::ICMP_SLT:
2024      case ICmpInst::ICMP_SLE:
2025        std::swap(LHS, RHS);
2026        // fall through
2027      case ICmpInst::ICMP_SGT:
2028      case ICmpInst::ICMP_SGE:
2029        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2030          return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
2031        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2032          // ~smax(~x, ~y) == smin(x, y).
2033          return SE.getNotSCEV(SE.getSMaxExpr(
2034                                   SE.getNotSCEV(getSCEV(LHS)),
2035                                   SE.getNotSCEV(getSCEV(RHS))));
2036        break;
2037      case ICmpInst::ICMP_ULT:
2038      case ICmpInst::ICMP_ULE:
2039        std::swap(LHS, RHS);
2040        // fall through
2041      case ICmpInst::ICMP_UGT:
2042      case ICmpInst::ICMP_UGE:
2043        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2044          return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
2045        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2046          // ~umax(~x, ~y) == umin(x, y)
2047          return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
2048                                              SE.getNotSCEV(getSCEV(RHS))));
2049        break;
2050      default:
2051        break;
2052      }
2053    }
2054
2055  default: // We cannot analyze this expression.
2056    break;
2057  }
2058
2059  return SE.getUnknown(V);
2060}
2061
2062
2063
2064//===----------------------------------------------------------------------===//
2065//                   Iteration Count Computation Code
2066//
2067
2068/// getBackedgeTakenCount - If the specified loop has a predictable
2069/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
2070/// object. The backedge-taken count is the number of times the loop header
2071/// will be branched to from within the loop. This is one less than the
2072/// trip count of the loop, since it doesn't count the first iteration,
2073/// when the header is branched to from outside the loop.
2074///
2075/// Note that it is not valid to call this method on a loop without a
2076/// loop-invariant backedge-taken count (see
2077/// hasLoopInvariantBackedgeTakenCount).
2078///
2079SCEVHandle ScalarEvolutionsImpl::getBackedgeTakenCount(const Loop *L) {
2080  std::map<const Loop*, SCEVHandle>::iterator I = BackedgeTakenCounts.find(L);
2081  if (I == BackedgeTakenCounts.end()) {
2082    SCEVHandle ItCount = ComputeBackedgeTakenCount(L);
2083    I = BackedgeTakenCounts.insert(std::make_pair(L, ItCount)).first;
2084    if (ItCount != UnknownValue) {
2085      assert(ItCount->isLoopInvariant(L) &&
2086             "Computed trip count isn't loop invariant for loop!");
2087      ++NumTripCountsComputed;
2088    } else if (isa<PHINode>(L->getHeader()->begin())) {
2089      // Only count loops that have phi nodes as not being computable.
2090      ++NumTripCountsNotComputed;
2091    }
2092  }
2093  return I->second;
2094}
2095
2096/// forgetLoopBackedgeTakenCount - This method should be called by the
2097/// client when it has changed a loop in a way that may effect
2098/// ScalarEvolution's ability to compute a trip count, or if the loop
2099/// is deleted.
2100void ScalarEvolutionsImpl::forgetLoopBackedgeTakenCount(const Loop *L) {
2101  BackedgeTakenCounts.erase(L);
2102}
2103
2104/// ComputeBackedgeTakenCount - Compute the number of times the backedge
2105/// of the specified loop will execute.
2106SCEVHandle ScalarEvolutionsImpl::ComputeBackedgeTakenCount(const Loop *L) {
2107  // If the loop has a non-one exit block count, we can't analyze it.
2108  SmallVector<BasicBlock*, 8> ExitBlocks;
2109  L->getExitBlocks(ExitBlocks);
2110  if (ExitBlocks.size() != 1) return UnknownValue;
2111
2112  // Okay, there is one exit block.  Try to find the condition that causes the
2113  // loop to be exited.
2114  BasicBlock *ExitBlock = ExitBlocks[0];
2115
2116  BasicBlock *ExitingBlock = 0;
2117  for (pred_iterator PI = pred_begin(ExitBlock), E = pred_end(ExitBlock);
2118       PI != E; ++PI)
2119    if (L->contains(*PI)) {
2120      if (ExitingBlock == 0)
2121        ExitingBlock = *PI;
2122      else
2123        return UnknownValue;   // More than one block exiting!
2124    }
2125  assert(ExitingBlock && "No exits from loop, something is broken!");
2126
2127  // Okay, we've computed the exiting block.  See what condition causes us to
2128  // exit.
2129  //
2130  // FIXME: we should be able to handle switch instructions (with a single exit)
2131  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
2132  if (ExitBr == 0) return UnknownValue;
2133  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
2134
2135  // At this point, we know we have a conditional branch that determines whether
2136  // the loop is exited.  However, we don't know if the branch is executed each
2137  // time through the loop.  If not, then the execution count of the branch will
2138  // not be equal to the trip count of the loop.
2139  //
2140  // Currently we check for this by checking to see if the Exit branch goes to
2141  // the loop header.  If so, we know it will always execute the same number of
2142  // times as the loop.  We also handle the case where the exit block *is* the
2143  // loop header.  This is common for un-rotated loops.  More extensive analysis
2144  // could be done to handle more cases here.
2145  if (ExitBr->getSuccessor(0) != L->getHeader() &&
2146      ExitBr->getSuccessor(1) != L->getHeader() &&
2147      ExitBr->getParent() != L->getHeader())
2148    return UnknownValue;
2149
2150  ICmpInst *ExitCond = dyn_cast<ICmpInst>(ExitBr->getCondition());
2151
2152  // If it's not an integer comparison then compute it the hard way.
2153  // Note that ICmpInst deals with pointer comparisons too so we must check
2154  // the type of the operand.
2155  if (ExitCond == 0 || isa<PointerType>(ExitCond->getOperand(0)->getType()))
2156    return ComputeBackedgeTakenCountExhaustively(L, ExitBr->getCondition(),
2157                                          ExitBr->getSuccessor(0) == ExitBlock);
2158
2159  // If the condition was exit on true, convert the condition to exit on false
2160  ICmpInst::Predicate Cond;
2161  if (ExitBr->getSuccessor(1) == ExitBlock)
2162    Cond = ExitCond->getPredicate();
2163  else
2164    Cond = ExitCond->getInversePredicate();
2165
2166  // Handle common loops like: for (X = "string"; *X; ++X)
2167  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
2168    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
2169      SCEVHandle ItCnt =
2170        ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
2171      if (!isa<SCEVCouldNotCompute>(ItCnt)) return ItCnt;
2172    }
2173
2174  SCEVHandle LHS = getSCEV(ExitCond->getOperand(0));
2175  SCEVHandle RHS = getSCEV(ExitCond->getOperand(1));
2176
2177  // Try to evaluate any dependencies out of the loop.
2178  SCEVHandle Tmp = getSCEVAtScope(LHS, L);
2179  if (!isa<SCEVCouldNotCompute>(Tmp)) LHS = Tmp;
2180  Tmp = getSCEVAtScope(RHS, L);
2181  if (!isa<SCEVCouldNotCompute>(Tmp)) RHS = Tmp;
2182
2183  // At this point, we would like to compute how many iterations of the
2184  // loop the predicate will return true for these inputs.
2185  if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
2186    // If there is a loop-invariant, force it into the RHS.
2187    std::swap(LHS, RHS);
2188    Cond = ICmpInst::getSwappedPredicate(Cond);
2189  }
2190
2191  // FIXME: think about handling pointer comparisons!  i.e.:
2192  // while (P != P+100) ++P;
2193
2194  // If we have a comparison of a chrec against a constant, try to use value
2195  // ranges to answer this query.
2196  if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
2197    if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
2198      if (AddRec->getLoop() == L) {
2199        // Form the comparison range using the constant of the correct type so
2200        // that the ConstantRange class knows to do a signed or unsigned
2201        // comparison.
2202        ConstantInt *CompVal = RHSC->getValue();
2203        const Type *RealTy = ExitCond->getOperand(0)->getType();
2204        CompVal = dyn_cast<ConstantInt>(
2205          ConstantExpr::getBitCast(CompVal, RealTy));
2206        if (CompVal) {
2207          // Form the constant range.
2208          ConstantRange CompRange(
2209              ICmpInst::makeConstantRange(Cond, CompVal->getValue()));
2210
2211          SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, SE);
2212          if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
2213        }
2214      }
2215
2216  switch (Cond) {
2217  case ICmpInst::ICMP_NE: {                     // while (X != Y)
2218    // Convert to: while (X-Y != 0)
2219    SCEVHandle TC = HowFarToZero(SE.getMinusSCEV(LHS, RHS), L);
2220    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2221    break;
2222  }
2223  case ICmpInst::ICMP_EQ: {
2224    // Convert to: while (X-Y == 0)           // while (X == Y)
2225    SCEVHandle TC = HowFarToNonZero(SE.getMinusSCEV(LHS, RHS), L);
2226    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2227    break;
2228  }
2229  case ICmpInst::ICMP_SLT: {
2230    SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true);
2231    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2232    break;
2233  }
2234  case ICmpInst::ICMP_SGT: {
2235    SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS),
2236                                     SE.getNotSCEV(RHS), L, true);
2237    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2238    break;
2239  }
2240  case ICmpInst::ICMP_ULT: {
2241    SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false);
2242    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2243    break;
2244  }
2245  case ICmpInst::ICMP_UGT: {
2246    SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS),
2247                                     SE.getNotSCEV(RHS), L, false);
2248    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2249    break;
2250  }
2251  default:
2252#if 0
2253    errs() << "ComputeBackedgeTakenCount ";
2254    if (ExitCond->getOperand(0)->getType()->isUnsigned())
2255      errs() << "[unsigned] ";
2256    errs() << *LHS << "   "
2257         << Instruction::getOpcodeName(Instruction::ICmp)
2258         << "   " << *RHS << "\n";
2259#endif
2260    break;
2261  }
2262  return
2263    ComputeBackedgeTakenCountExhaustively(L, ExitCond,
2264                                          ExitBr->getSuccessor(0) == ExitBlock);
2265}
2266
2267static ConstantInt *
2268EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
2269                                ScalarEvolution &SE) {
2270  SCEVHandle InVal = SE.getConstant(C);
2271  SCEVHandle Val = AddRec->evaluateAtIteration(InVal, SE);
2272  assert(isa<SCEVConstant>(Val) &&
2273         "Evaluation of SCEV at constant didn't fold correctly?");
2274  return cast<SCEVConstant>(Val)->getValue();
2275}
2276
2277/// GetAddressedElementFromGlobal - Given a global variable with an initializer
2278/// and a GEP expression (missing the pointer index) indexing into it, return
2279/// the addressed element of the initializer or null if the index expression is
2280/// invalid.
2281static Constant *
2282GetAddressedElementFromGlobal(GlobalVariable *GV,
2283                              const std::vector<ConstantInt*> &Indices) {
2284  Constant *Init = GV->getInitializer();
2285  for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
2286    uint64_t Idx = Indices[i]->getZExtValue();
2287    if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
2288      assert(Idx < CS->getNumOperands() && "Bad struct index!");
2289      Init = cast<Constant>(CS->getOperand(Idx));
2290    } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
2291      if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
2292      Init = cast<Constant>(CA->getOperand(Idx));
2293    } else if (isa<ConstantAggregateZero>(Init)) {
2294      if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
2295        assert(Idx < STy->getNumElements() && "Bad struct index!");
2296        Init = Constant::getNullValue(STy->getElementType(Idx));
2297      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
2298        if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
2299        Init = Constant::getNullValue(ATy->getElementType());
2300      } else {
2301        assert(0 && "Unknown constant aggregate type!");
2302      }
2303      return 0;
2304    } else {
2305      return 0; // Unknown initializer type
2306    }
2307  }
2308  return Init;
2309}
2310
2311/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of
2312/// 'icmp op load X, cst', try to see if we can compute the backedge
2313/// execution count.
2314SCEVHandle ScalarEvolutionsImpl::
2315ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS,
2316                                             const Loop *L,
2317                                             ICmpInst::Predicate predicate) {
2318  if (LI->isVolatile()) return UnknownValue;
2319
2320  // Check to see if the loaded pointer is a getelementptr of a global.
2321  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
2322  if (!GEP) return UnknownValue;
2323
2324  // Make sure that it is really a constant global we are gepping, with an
2325  // initializer, and make sure the first IDX is really 0.
2326  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
2327  if (!GV || !GV->isConstant() || !GV->hasInitializer() ||
2328      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
2329      !cast<Constant>(GEP->getOperand(1))->isNullValue())
2330    return UnknownValue;
2331
2332  // Okay, we allow one non-constant index into the GEP instruction.
2333  Value *VarIdx = 0;
2334  std::vector<ConstantInt*> Indexes;
2335  unsigned VarIdxNum = 0;
2336  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
2337    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
2338      Indexes.push_back(CI);
2339    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
2340      if (VarIdx) return UnknownValue;  // Multiple non-constant idx's.
2341      VarIdx = GEP->getOperand(i);
2342      VarIdxNum = i-2;
2343      Indexes.push_back(0);
2344    }
2345
2346  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
2347  // Check to see if X is a loop variant variable value now.
2348  SCEVHandle Idx = getSCEV(VarIdx);
2349  SCEVHandle Tmp = getSCEVAtScope(Idx, L);
2350  if (!isa<SCEVCouldNotCompute>(Tmp)) Idx = Tmp;
2351
2352  // We can only recognize very limited forms of loop index expressions, in
2353  // particular, only affine AddRec's like {C1,+,C2}.
2354  SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
2355  if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
2356      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
2357      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
2358    return UnknownValue;
2359
2360  unsigned MaxSteps = MaxBruteForceIterations;
2361  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
2362    ConstantInt *ItCst =
2363      ConstantInt::get(IdxExpr->getType(), IterationNum);
2364    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, SE);
2365
2366    // Form the GEP offset.
2367    Indexes[VarIdxNum] = Val;
2368
2369    Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
2370    if (Result == 0) break;  // Cannot compute!
2371
2372    // Evaluate the condition for this iteration.
2373    Result = ConstantExpr::getICmp(predicate, Result, RHS);
2374    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
2375    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
2376#if 0
2377      errs() << "\n***\n*** Computed loop count " << *ItCst
2378             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
2379             << "***\n";
2380#endif
2381      ++NumArrayLenItCounts;
2382      return SE.getConstant(ItCst);   // Found terminating iteration!
2383    }
2384  }
2385  return UnknownValue;
2386}
2387
2388
2389/// CanConstantFold - Return true if we can constant fold an instruction of the
2390/// specified type, assuming that all operands were constants.
2391static bool CanConstantFold(const Instruction *I) {
2392  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
2393      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
2394    return true;
2395
2396  if (const CallInst *CI = dyn_cast<CallInst>(I))
2397    if (const Function *F = CI->getCalledFunction())
2398      return canConstantFoldCallTo(F);
2399  return false;
2400}
2401
2402/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
2403/// in the loop that V is derived from.  We allow arbitrary operations along the
2404/// way, but the operands of an operation must either be constants or a value
2405/// derived from a constant PHI.  If this expression does not fit with these
2406/// constraints, return null.
2407static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
2408  // If this is not an instruction, or if this is an instruction outside of the
2409  // loop, it can't be derived from a loop PHI.
2410  Instruction *I = dyn_cast<Instruction>(V);
2411  if (I == 0 || !L->contains(I->getParent())) return 0;
2412
2413  if (PHINode *PN = dyn_cast<PHINode>(I)) {
2414    if (L->getHeader() == I->getParent())
2415      return PN;
2416    else
2417      // We don't currently keep track of the control flow needed to evaluate
2418      // PHIs, so we cannot handle PHIs inside of loops.
2419      return 0;
2420  }
2421
2422  // If we won't be able to constant fold this expression even if the operands
2423  // are constants, return early.
2424  if (!CanConstantFold(I)) return 0;
2425
2426  // Otherwise, we can evaluate this instruction if all of its operands are
2427  // constant or derived from a PHI node themselves.
2428  PHINode *PHI = 0;
2429  for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
2430    if (!(isa<Constant>(I->getOperand(Op)) ||
2431          isa<GlobalValue>(I->getOperand(Op)))) {
2432      PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
2433      if (P == 0) return 0;  // Not evolving from PHI
2434      if (PHI == 0)
2435        PHI = P;
2436      else if (PHI != P)
2437        return 0;  // Evolving from multiple different PHIs.
2438    }
2439
2440  // This is a expression evolving from a constant PHI!
2441  return PHI;
2442}
2443
2444/// EvaluateExpression - Given an expression that passes the
2445/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
2446/// in the loop has the value PHIVal.  If we can't fold this expression for some
2447/// reason, return null.
2448static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
2449  if (isa<PHINode>(V)) return PHIVal;
2450  if (Constant *C = dyn_cast<Constant>(V)) return C;
2451  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) return GV;
2452  Instruction *I = cast<Instruction>(V);
2453
2454  std::vector<Constant*> Operands;
2455  Operands.resize(I->getNumOperands());
2456
2457  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2458    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
2459    if (Operands[i] == 0) return 0;
2460  }
2461
2462  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2463    return ConstantFoldCompareInstOperands(CI->getPredicate(),
2464                                           &Operands[0], Operands.size());
2465  else
2466    return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2467                                    &Operands[0], Operands.size());
2468}
2469
2470/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
2471/// in the header of its containing loop, we know the loop executes a
2472/// constant number of times, and the PHI node is just a recurrence
2473/// involving constants, fold it.
2474Constant *ScalarEvolutionsImpl::
2475getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){
2476  std::map<PHINode*, Constant*>::iterator I =
2477    ConstantEvolutionLoopExitValue.find(PN);
2478  if (I != ConstantEvolutionLoopExitValue.end())
2479    return I->second;
2480
2481  if (BEs.ugt(APInt(BEs.getBitWidth(),MaxBruteForceIterations)))
2482    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
2483
2484  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
2485
2486  // Since the loop is canonicalized, the PHI node must have two entries.  One
2487  // entry must be a constant (coming in from outside of the loop), and the
2488  // second must be derived from the same PHI.
2489  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2490  Constant *StartCST =
2491    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2492  if (StartCST == 0)
2493    return RetVal = 0;  // Must be a constant.
2494
2495  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2496  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2497  if (PN2 != PN)
2498    return RetVal = 0;  // Not derived from same PHI.
2499
2500  // Execute the loop symbolically to determine the exit value.
2501  if (BEs.getActiveBits() >= 32)
2502    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
2503
2504  unsigned NumIterations = BEs.getZExtValue(); // must be in range
2505  unsigned IterationNum = 0;
2506  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
2507    if (IterationNum == NumIterations)
2508      return RetVal = PHIVal;  // Got exit value!
2509
2510    // Compute the value of the PHI node for the next iteration.
2511    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2512    if (NextPHI == PHIVal)
2513      return RetVal = NextPHI;  // Stopped evolving!
2514    if (NextPHI == 0)
2515      return 0;        // Couldn't evaluate!
2516    PHIVal = NextPHI;
2517  }
2518}
2519
2520/// ComputeBackedgeTakenCountExhaustively - If the trip is known to execute a
2521/// constant number of times (the condition evolves only from constants),
2522/// try to evaluate a few iterations of the loop until we get the exit
2523/// condition gets a value of ExitWhen (true or false).  If we cannot
2524/// evaluate the trip count of the loop, return UnknownValue.
2525SCEVHandle ScalarEvolutionsImpl::
2526ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) {
2527  PHINode *PN = getConstantEvolvingPHI(Cond, L);
2528  if (PN == 0) return UnknownValue;
2529
2530  // Since the loop is canonicalized, the PHI node must have two entries.  One
2531  // entry must be a constant (coming in from outside of the loop), and the
2532  // second must be derived from the same PHI.
2533  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2534  Constant *StartCST =
2535    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2536  if (StartCST == 0) return UnknownValue;  // Must be a constant.
2537
2538  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2539  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2540  if (PN2 != PN) return UnknownValue;  // Not derived from same PHI.
2541
2542  // Okay, we find a PHI node that defines the trip count of this loop.  Execute
2543  // the loop symbolically to determine when the condition gets a value of
2544  // "ExitWhen".
2545  unsigned IterationNum = 0;
2546  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
2547  for (Constant *PHIVal = StartCST;
2548       IterationNum != MaxIterations; ++IterationNum) {
2549    ConstantInt *CondVal =
2550      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal));
2551
2552    // Couldn't symbolically evaluate.
2553    if (!CondVal) return UnknownValue;
2554
2555    if (CondVal->getValue() == uint64_t(ExitWhen)) {
2556      ConstantEvolutionLoopExitValue[PN] = PHIVal;
2557      ++NumBruteForceTripCountsComputed;
2558      return SE.getConstant(ConstantInt::get(Type::Int32Ty, IterationNum));
2559    }
2560
2561    // Compute the value of the PHI node for the next iteration.
2562    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2563    if (NextPHI == 0 || NextPHI == PHIVal)
2564      return UnknownValue;  // Couldn't evaluate or not making progress...
2565    PHIVal = NextPHI;
2566  }
2567
2568  // Too many iterations were needed to evaluate.
2569  return UnknownValue;
2570}
2571
2572/// getSCEVAtScope - Compute the value of the specified expression within the
2573/// indicated loop (which may be null to indicate in no loop).  If the
2574/// expression cannot be evaluated, return UnknownValue.
2575SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
2576  // FIXME: this should be turned into a virtual method on SCEV!
2577
2578  if (isa<SCEVConstant>(V)) return V;
2579
2580  // If this instruction is evolved from a constant-evolving PHI, compute the
2581  // exit value from the loop without using SCEVs.
2582  if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
2583    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
2584      const Loop *LI = this->LI[I->getParent()];
2585      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
2586        if (PHINode *PN = dyn_cast<PHINode>(I))
2587          if (PN->getParent() == LI->getHeader()) {
2588            // Okay, there is no closed form solution for the PHI node.  Check
2589            // to see if the loop that contains it has a known backedge-taken
2590            // count.  If so, we may be able to force computation of the exit
2591            // value.
2592            SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(LI);
2593            if (SCEVConstant *BTCC =
2594                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
2595              // Okay, we know how many times the containing loop executes.  If
2596              // this is a constant evolving PHI node, get the final value at
2597              // the specified iteration number.
2598              Constant *RV = getConstantEvolutionLoopExitValue(PN,
2599                                                   BTCC->getValue()->getValue(),
2600                                                               LI);
2601              if (RV) return SE.getUnknown(RV);
2602            }
2603          }
2604
2605      // Okay, this is an expression that we cannot symbolically evaluate
2606      // into a SCEV.  Check to see if it's possible to symbolically evaluate
2607      // the arguments into constants, and if so, try to constant propagate the
2608      // result.  This is particularly useful for computing loop exit values.
2609      if (CanConstantFold(I)) {
2610        std::vector<Constant*> Operands;
2611        Operands.reserve(I->getNumOperands());
2612        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2613          Value *Op = I->getOperand(i);
2614          if (Constant *C = dyn_cast<Constant>(Op)) {
2615            Operands.push_back(C);
2616          } else {
2617            // If any of the operands is non-constant and if they are
2618            // non-integer and non-pointer, don't even try to analyze them
2619            // with scev techniques.
2620            if (!isa<IntegerType>(Op->getType()) &&
2621                !isa<PointerType>(Op->getType()))
2622              return V;
2623
2624            SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L);
2625            if (SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV))
2626              Operands.push_back(ConstantExpr::getIntegerCast(SC->getValue(),
2627                                                              Op->getType(),
2628                                                              false));
2629            else if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) {
2630              if (Constant *C = dyn_cast<Constant>(SU->getValue()))
2631                Operands.push_back(ConstantExpr::getIntegerCast(C,
2632                                                                Op->getType(),
2633                                                                false));
2634              else
2635                return V;
2636            } else {
2637              return V;
2638            }
2639          }
2640        }
2641
2642        Constant *C;
2643        if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2644          C = ConstantFoldCompareInstOperands(CI->getPredicate(),
2645                                              &Operands[0], Operands.size());
2646        else
2647          C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2648                                       &Operands[0], Operands.size());
2649        return SE.getUnknown(C);
2650      }
2651    }
2652
2653    // This is some other type of SCEVUnknown, just return it.
2654    return V;
2655  }
2656
2657  if (SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
2658    // Avoid performing the look-up in the common case where the specified
2659    // expression has no loop-variant portions.
2660    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
2661      SCEVHandle OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
2662      if (OpAtScope != Comm->getOperand(i)) {
2663        if (OpAtScope == UnknownValue) return UnknownValue;
2664        // Okay, at least one of these operands is loop variant but might be
2665        // foldable.  Build a new instance of the folded commutative expression.
2666        std::vector<SCEVHandle> NewOps(Comm->op_begin(), Comm->op_begin()+i);
2667        NewOps.push_back(OpAtScope);
2668
2669        for (++i; i != e; ++i) {
2670          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
2671          if (OpAtScope == UnknownValue) return UnknownValue;
2672          NewOps.push_back(OpAtScope);
2673        }
2674        if (isa<SCEVAddExpr>(Comm))
2675          return SE.getAddExpr(NewOps);
2676        if (isa<SCEVMulExpr>(Comm))
2677          return SE.getMulExpr(NewOps);
2678        if (isa<SCEVSMaxExpr>(Comm))
2679          return SE.getSMaxExpr(NewOps);
2680        if (isa<SCEVUMaxExpr>(Comm))
2681          return SE.getUMaxExpr(NewOps);
2682        assert(0 && "Unknown commutative SCEV type!");
2683      }
2684    }
2685    // If we got here, all operands are loop invariant.
2686    return Comm;
2687  }
2688
2689  if (SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
2690    SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L);
2691    if (LHS == UnknownValue) return LHS;
2692    SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L);
2693    if (RHS == UnknownValue) return RHS;
2694    if (LHS == Div->getLHS() && RHS == Div->getRHS())
2695      return Div;   // must be loop invariant
2696    return SE.getUDivExpr(LHS, RHS);
2697  }
2698
2699  // If this is a loop recurrence for a loop that does not contain L, then we
2700  // are dealing with the final value computed by the loop.
2701  if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
2702    if (!L || !AddRec->getLoop()->contains(L->getHeader())) {
2703      // To evaluate this recurrence, we need to know how many times the AddRec
2704      // loop iterates.  Compute this now.
2705      SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
2706      if (BackedgeTakenCount == UnknownValue) return UnknownValue;
2707
2708      // Then, evaluate the AddRec.
2709      return AddRec->evaluateAtIteration(BackedgeTakenCount, SE);
2710    }
2711    return UnknownValue;
2712  }
2713
2714  //assert(0 && "Unknown SCEV type!");
2715  return UnknownValue;
2716}
2717
2718/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
2719/// following equation:
2720///
2721///     A * X = B (mod N)
2722///
2723/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
2724/// A and B isn't important.
2725///
2726/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
2727static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
2728                                               ScalarEvolution &SE) {
2729  uint32_t BW = A.getBitWidth();
2730  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
2731  assert(A != 0 && "A must be non-zero.");
2732
2733  // 1. D = gcd(A, N)
2734  //
2735  // The gcd of A and N may have only one prime factor: 2. The number of
2736  // trailing zeros in A is its multiplicity
2737  uint32_t Mult2 = A.countTrailingZeros();
2738  // D = 2^Mult2
2739
2740  // 2. Check if B is divisible by D.
2741  //
2742  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
2743  // is not less than multiplicity of this prime factor for D.
2744  if (B.countTrailingZeros() < Mult2)
2745    return SE.getCouldNotCompute();
2746
2747  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
2748  // modulo (N / D).
2749  //
2750  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
2751  // bit width during computations.
2752  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
2753  APInt Mod(BW + 1, 0);
2754  Mod.set(BW - Mult2);  // Mod = N / D
2755  APInt I = AD.multiplicativeInverse(Mod);
2756
2757  // 4. Compute the minimum unsigned root of the equation:
2758  // I * (B / D) mod (N / D)
2759  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
2760
2761  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
2762  // bits.
2763  return SE.getConstant(Result.trunc(BW));
2764}
2765
2766/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
2767/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
2768/// might be the same) or two SCEVCouldNotCompute objects.
2769///
2770static std::pair<SCEVHandle,SCEVHandle>
2771SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
2772  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
2773  SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
2774  SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
2775  SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
2776
2777  // We currently can only solve this if the coefficients are constants.
2778  if (!LC || !MC || !NC) {
2779    SCEV *CNC = SE.getCouldNotCompute();
2780    return std::make_pair(CNC, CNC);
2781  }
2782
2783  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
2784  const APInt &L = LC->getValue()->getValue();
2785  const APInt &M = MC->getValue()->getValue();
2786  const APInt &N = NC->getValue()->getValue();
2787  APInt Two(BitWidth, 2);
2788  APInt Four(BitWidth, 4);
2789
2790  {
2791    using namespace APIntOps;
2792    const APInt& C = L;
2793    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
2794    // The B coefficient is M-N/2
2795    APInt B(M);
2796    B -= sdiv(N,Two);
2797
2798    // The A coefficient is N/2
2799    APInt A(N.sdiv(Two));
2800
2801    // Compute the B^2-4ac term.
2802    APInt SqrtTerm(B);
2803    SqrtTerm *= B;
2804    SqrtTerm -= Four * (A * C);
2805
2806    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
2807    // integer value or else APInt::sqrt() will assert.
2808    APInt SqrtVal(SqrtTerm.sqrt());
2809
2810    // Compute the two solutions for the quadratic formula.
2811    // The divisions must be performed as signed divisions.
2812    APInt NegB(-B);
2813    APInt TwoA( A << 1 );
2814    if (TwoA.isMinValue()) {
2815      SCEV *CNC = SE.getCouldNotCompute();
2816      return std::make_pair(CNC, CNC);
2817    }
2818
2819    ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
2820    ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
2821
2822    return std::make_pair(SE.getConstant(Solution1),
2823                          SE.getConstant(Solution2));
2824    } // end APIntOps namespace
2825}
2826
2827/// HowFarToZero - Return the number of times a backedge comparing the specified
2828/// value to zero will execute.  If not computable, return UnknownValue
2829SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) {
2830  // If the value is a constant
2831  if (SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
2832    // If the value is already zero, the branch will execute zero times.
2833    if (C->getValue()->isZero()) return C;
2834    return UnknownValue;  // Otherwise it will loop infinitely.
2835  }
2836
2837  SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
2838  if (!AddRec || AddRec->getLoop() != L)
2839    return UnknownValue;
2840
2841  if (AddRec->isAffine()) {
2842    // If this is an affine expression, the execution count of this branch is
2843    // the minimum unsigned root of the following equation:
2844    //
2845    //     Start + Step*N = 0 (mod 2^BW)
2846    //
2847    // equivalent to:
2848    //
2849    //             Step*N = -Start (mod 2^BW)
2850    //
2851    // where BW is the common bit width of Start and Step.
2852
2853    // Get the initial value for the loop.
2854    SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
2855    if (isa<SCEVCouldNotCompute>(Start)) return UnknownValue;
2856
2857    SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
2858
2859    if (SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
2860      // For now we handle only constant steps.
2861
2862      // First, handle unitary steps.
2863      if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
2864        return SE.getNegativeSCEV(Start);       //   N = -Start (as unsigned)
2865      if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
2866        return Start;                           //    N = Start (as unsigned)
2867
2868      // Then, try to solve the above equation provided that Start is constant.
2869      if (SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
2870        return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
2871                                            -StartC->getValue()->getValue(),SE);
2872    }
2873  } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
2874    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
2875    // the quadratic equation to solve it.
2876    std::pair<SCEVHandle,SCEVHandle> Roots = SolveQuadraticEquation(AddRec, SE);
2877    SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
2878    SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
2879    if (R1) {
2880#if 0
2881      errs() << "HFTZ: " << *V << " - sol#1: " << *R1
2882             << "  sol#2: " << *R2 << "\n";
2883#endif
2884      // Pick the smallest positive root value.
2885      if (ConstantInt *CB =
2886          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
2887                                   R1->getValue(), R2->getValue()))) {
2888        if (CB->getZExtValue() == false)
2889          std::swap(R1, R2);   // R1 is the minimum root now.
2890
2891        // We can only use this value if the chrec ends up with an exact zero
2892        // value at this index.  When solving for "X*X != 5", for example, we
2893        // should not accept a root of 2.
2894        SCEVHandle Val = AddRec->evaluateAtIteration(R1, SE);
2895        if (Val->isZero())
2896          return R1;  // We found a quadratic root!
2897      }
2898    }
2899  }
2900
2901  return UnknownValue;
2902}
2903
2904/// HowFarToNonZero - Return the number of times a backedge checking the
2905/// specified value for nonzero will execute.  If not computable, return
2906/// UnknownValue
2907SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) {
2908  // Loops that look like: while (X == 0) are very strange indeed.  We don't
2909  // handle them yet except for the trivial case.  This could be expanded in the
2910  // future as needed.
2911
2912  // If the value is a constant, check to see if it is known to be non-zero
2913  // already.  If so, the backedge will execute zero times.
2914  if (SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
2915    if (!C->getValue()->isNullValue())
2916      return SE.getIntegerSCEV(0, C->getType());
2917    return UnknownValue;  // Otherwise it will loop infinitely.
2918  }
2919
2920  // We could implement others, but I really doubt anyone writes loops like
2921  // this, and if they did, they would already be constant folded.
2922  return UnknownValue;
2923}
2924
2925/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
2926/// (which may not be an immediate predecessor) which has exactly one
2927/// successor from which BB is reachable, or null if no such block is
2928/// found.
2929///
2930BasicBlock *
2931ScalarEvolutionsImpl::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
2932  // If the block has a unique predecessor, the predecessor must have
2933  // no other successors from which BB is reachable.
2934  if (BasicBlock *Pred = BB->getSinglePredecessor())
2935    return Pred;
2936
2937  // A loop's header is defined to be a block that dominates the loop.
2938  // If the loop has a preheader, it must be a block that has exactly
2939  // one successor that can reach BB. This is slightly more strict
2940  // than necessary, but works if critical edges are split.
2941  if (Loop *L = LI.getLoopFor(BB))
2942    return L->getLoopPreheader();
2943
2944  return 0;
2945}
2946
2947/// isLoopGuardedByCond - Test whether entry to the loop is protected by
2948/// a conditional between LHS and RHS.
2949bool ScalarEvolutionsImpl::isLoopGuardedByCond(const Loop *L,
2950                                               ICmpInst::Predicate Pred,
2951                                               SCEV *LHS, SCEV *RHS) {
2952  BasicBlock *Preheader = L->getLoopPreheader();
2953  BasicBlock *PreheaderDest = L->getHeader();
2954
2955  // Starting at the preheader, climb up the predecessor chain, as long as
2956  // there are predecessors that can be found that have unique successors
2957  // leading to the original header.
2958  for (; Preheader;
2959       PreheaderDest = Preheader,
2960       Preheader = getPredecessorWithUniqueSuccessorForBB(Preheader)) {
2961
2962    BranchInst *LoopEntryPredicate =
2963      dyn_cast<BranchInst>(Preheader->getTerminator());
2964    if (!LoopEntryPredicate ||
2965        LoopEntryPredicate->isUnconditional())
2966      continue;
2967
2968    ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
2969    if (!ICI) continue;
2970
2971    // Now that we found a conditional branch that dominates the loop, check to
2972    // see if it is the comparison we are looking for.
2973    Value *PreCondLHS = ICI->getOperand(0);
2974    Value *PreCondRHS = ICI->getOperand(1);
2975    ICmpInst::Predicate Cond;
2976    if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest)
2977      Cond = ICI->getPredicate();
2978    else
2979      Cond = ICI->getInversePredicate();
2980
2981    if (Cond == Pred)
2982      ; // An exact match.
2983    else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
2984      ; // The actual condition is beyond sufficient.
2985    else
2986      // Check a few special cases.
2987      switch (Cond) {
2988      case ICmpInst::ICMP_UGT:
2989        if (Pred == ICmpInst::ICMP_ULT) {
2990          std::swap(PreCondLHS, PreCondRHS);
2991          Cond = ICmpInst::ICMP_ULT;
2992          break;
2993        }
2994        continue;
2995      case ICmpInst::ICMP_SGT:
2996        if (Pred == ICmpInst::ICMP_SLT) {
2997          std::swap(PreCondLHS, PreCondRHS);
2998          Cond = ICmpInst::ICMP_SLT;
2999          break;
3000        }
3001        continue;
3002      case ICmpInst::ICMP_NE:
3003        // Expressions like (x >u 0) are often canonicalized to (x != 0),
3004        // so check for this case by checking if the NE is comparing against
3005        // a minimum or maximum constant.
3006        if (!ICmpInst::isTrueWhenEqual(Pred))
3007          if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
3008            const APInt &A = CI->getValue();
3009            switch (Pred) {
3010            case ICmpInst::ICMP_SLT:
3011              if (A.isMaxSignedValue()) break;
3012              continue;
3013            case ICmpInst::ICMP_SGT:
3014              if (A.isMinSignedValue()) break;
3015              continue;
3016            case ICmpInst::ICMP_ULT:
3017              if (A.isMaxValue()) break;
3018              continue;
3019            case ICmpInst::ICMP_UGT:
3020              if (A.isMinValue()) break;
3021              continue;
3022            default:
3023              continue;
3024            }
3025            Cond = ICmpInst::ICMP_NE;
3026            // NE is symmetric but the original comparison may not be. Swap
3027            // the operands if necessary so that they match below.
3028            if (isa<SCEVConstant>(LHS))
3029              std::swap(PreCondLHS, PreCondRHS);
3030            break;
3031          }
3032        continue;
3033      default:
3034        // We weren't able to reconcile the condition.
3035        continue;
3036      }
3037
3038    if (!PreCondLHS->getType()->isInteger()) continue;
3039
3040    SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS);
3041    SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS);
3042    if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) ||
3043        (LHS == SE.getNotSCEV(PreCondRHSSCEV) &&
3044         RHS == SE.getNotSCEV(PreCondLHSSCEV)))
3045      return true;
3046  }
3047
3048  return false;
3049}
3050
3051/// HowManyLessThans - Return the number of times a backedge containing the
3052/// specified less-than comparison will execute.  If not computable, return
3053/// UnknownValue.
3054SCEVHandle ScalarEvolutionsImpl::
3055HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
3056  // Only handle:  "ADDREC < LoopInvariant".
3057  if (!RHS->isLoopInvariant(L)) return UnknownValue;
3058
3059  SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
3060  if (!AddRec || AddRec->getLoop() != L)
3061    return UnknownValue;
3062
3063  if (AddRec->isAffine()) {
3064    // FORNOW: We only support unit strides.
3065    SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType());
3066    if (AddRec->getOperand(1) != One)
3067      return UnknownValue;
3068
3069    // We know the LHS is of the form {n,+,1} and the RHS is some loop-invariant
3070    // m.  So, we count the number of iterations in which {n,+,1} < m is true.
3071    // Note that we cannot simply return max(m-n,0) because it's not safe to
3072    // treat m-n as signed nor unsigned due to overflow possibility.
3073
3074    // First, we get the value of the LHS in the first iteration: n
3075    SCEVHandle Start = AddRec->getOperand(0);
3076
3077    if (isLoopGuardedByCond(L,
3078                            isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
3079                            SE.getMinusSCEV(AddRec->getOperand(0), One), RHS)) {
3080      // Since we know that the condition is true in order to enter the loop,
3081      // we know that it will run exactly m-n times.
3082      return SE.getMinusSCEV(RHS, Start);
3083    } else {
3084      // Then, we get the value of the LHS in the first iteration in which the
3085      // above condition doesn't hold.  This equals to max(m,n).
3086      SCEVHandle End = isSigned ? SE.getSMaxExpr(RHS, Start)
3087                                : SE.getUMaxExpr(RHS, Start);
3088
3089      // Finally, we subtract these two values to get the number of times the
3090      // backedge is executed: max(m,n)-n.
3091      return SE.getMinusSCEV(End, Start);
3092    }
3093  }
3094
3095  return UnknownValue;
3096}
3097
3098/// getNumIterationsInRange - Return the number of iterations of this loop that
3099/// produce values in the specified constant range.  Another way of looking at
3100/// this is that it returns the first iteration number where the value is not in
3101/// the condition, thus computing the exit count. If the iteration count can't
3102/// be computed, an instance of SCEVCouldNotCompute is returned.
3103SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
3104                                                   ScalarEvolution &SE) const {
3105  if (Range.isFullSet())  // Infinite loop.
3106    return SE.getCouldNotCompute();
3107
3108  // If the start is a non-zero constant, shift the range to simplify things.
3109  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
3110    if (!SC->getValue()->isZero()) {
3111      std::vector<SCEVHandle> Operands(op_begin(), op_end());
3112      Operands[0] = SE.getIntegerSCEV(0, SC->getType());
3113      SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop());
3114      if (SCEVAddRecExpr *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
3115        return ShiftedAddRec->getNumIterationsInRange(
3116                           Range.subtract(SC->getValue()->getValue()), SE);
3117      // This is strange and shouldn't happen.
3118      return SE.getCouldNotCompute();
3119    }
3120
3121  // The only time we can solve this is when we have all constant indices.
3122  // Otherwise, we cannot determine the overflow conditions.
3123  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
3124    if (!isa<SCEVConstant>(getOperand(i)))
3125      return SE.getCouldNotCompute();
3126
3127
3128  // Okay at this point we know that all elements of the chrec are constants and
3129  // that the start element is zero.
3130
3131  // First check to see if the range contains zero.  If not, the first
3132  // iteration exits.
3133  unsigned BitWidth = SE.getTargetData().getTypeSizeInBits(getType());
3134  if (!Range.contains(APInt(BitWidth, 0)))
3135    return SE.getConstant(ConstantInt::get(getType(),0));
3136
3137  if (isAffine()) {
3138    // If this is an affine expression then we have this situation:
3139    //   Solve {0,+,A} in Range  ===  Ax in Range
3140
3141    // We know that zero is in the range.  If A is positive then we know that
3142    // the upper value of the range must be the first possible exit value.
3143    // If A is negative then the lower of the range is the last possible loop
3144    // value.  Also note that we already checked for a full range.
3145    APInt One(BitWidth,1);
3146    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
3147    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
3148
3149    // The exit value should be (End+A)/A.
3150    APInt ExitVal = (End + A).udiv(A);
3151    ConstantInt *ExitValue = ConstantInt::get(ExitVal);
3152
3153    // Evaluate at the exit value.  If we really did fall out of the valid
3154    // range, then we computed our trip count, otherwise wrap around or other
3155    // things must have happened.
3156    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
3157    if (Range.contains(Val->getValue()))
3158      return SE.getCouldNotCompute();  // Something strange happened
3159
3160    // Ensure that the previous value is in the range.  This is a sanity check.
3161    assert(Range.contains(
3162           EvaluateConstantChrecAtConstant(this,
3163           ConstantInt::get(ExitVal - One), SE)->getValue()) &&
3164           "Linear scev computation is off in a bad way!");
3165    return SE.getConstant(ExitValue);
3166  } else if (isQuadratic()) {
3167    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
3168    // quadratic equation to solve it.  To do this, we must frame our problem in
3169    // terms of figuring out when zero is crossed, instead of when
3170    // Range.getUpper() is crossed.
3171    std::vector<SCEVHandle> NewOps(op_begin(), op_end());
3172    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
3173    SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop());
3174
3175    // Next, solve the constructed addrec
3176    std::pair<SCEVHandle,SCEVHandle> Roots =
3177      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
3178    SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
3179    SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
3180    if (R1) {
3181      // Pick the smallest positive root value.
3182      if (ConstantInt *CB =
3183          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
3184                                   R1->getValue(), R2->getValue()))) {
3185        if (CB->getZExtValue() == false)
3186          std::swap(R1, R2);   // R1 is the minimum root now.
3187
3188        // Make sure the root is not off by one.  The returned iteration should
3189        // not be in the range, but the previous one should be.  When solving
3190        // for "X*X < 5", for example, we should not return a root of 2.
3191        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
3192                                                             R1->getValue(),
3193                                                             SE);
3194        if (Range.contains(R1Val->getValue())) {
3195          // The next iteration must be out of the range...
3196          ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
3197
3198          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
3199          if (!Range.contains(R1Val->getValue()))
3200            return SE.getConstant(NextVal);
3201          return SE.getCouldNotCompute();  // Something strange happened
3202        }
3203
3204        // If R1 was not in the range, then it is a good return value.  Make
3205        // sure that R1-1 WAS in the range though, just in case.
3206        ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
3207        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
3208        if (Range.contains(R1Val->getValue()))
3209          return R1;
3210        return SE.getCouldNotCompute();  // Something strange happened
3211      }
3212    }
3213  }
3214
3215  return SE.getCouldNotCompute();
3216}
3217
3218
3219
3220//===----------------------------------------------------------------------===//
3221//                   ScalarEvolution Class Implementation
3222//===----------------------------------------------------------------------===//
3223
3224bool ScalarEvolution::runOnFunction(Function &F) {
3225  Impl = new ScalarEvolutionsImpl(*this, F,
3226                                  getAnalysis<LoopInfo>(),
3227                                  getAnalysis<TargetData>());
3228  return false;
3229}
3230
3231void ScalarEvolution::releaseMemory() {
3232  delete (ScalarEvolutionsImpl*)Impl;
3233  Impl = 0;
3234}
3235
3236void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
3237  AU.setPreservesAll();
3238  AU.addRequiredTransitive<LoopInfo>();
3239  AU.addRequiredTransitive<TargetData>();
3240}
3241
3242const TargetData &ScalarEvolution::getTargetData() const {
3243  return ((ScalarEvolutionsImpl*)Impl)->getTargetData();
3244}
3245
3246SCEVHandle ScalarEvolution::getCouldNotCompute() {
3247  return ((ScalarEvolutionsImpl*)Impl)->getCouldNotCompute();
3248}
3249
3250SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
3251  return ((ScalarEvolutionsImpl*)Impl)->getIntegerSCEV(Val, Ty);
3252}
3253
3254SCEVHandle ScalarEvolution::getSCEV(Value *V) const {
3255  return ((ScalarEvolutionsImpl*)Impl)->getSCEV(V);
3256}
3257
3258/// hasSCEV - Return true if the SCEV for this value has already been
3259/// computed.
3260bool ScalarEvolution::hasSCEV(Value *V) const {
3261  return ((ScalarEvolutionsImpl*)Impl)->hasSCEV(V);
3262}
3263
3264
3265/// setSCEV - Insert the specified SCEV into the map of current SCEVs for
3266/// the specified value.
3267void ScalarEvolution::setSCEV(Value *V, const SCEVHandle &H) {
3268  ((ScalarEvolutionsImpl*)Impl)->setSCEV(V, H);
3269}
3270
3271/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
3272///
3273SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) {
3274  return ((ScalarEvolutionsImpl*)Impl)->getNegativeSCEV(V);
3275}
3276
3277/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
3278///
3279SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) {
3280  return ((ScalarEvolutionsImpl*)Impl)->getNotSCEV(V);
3281}
3282
3283/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
3284///
3285SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS,
3286                                         const SCEVHandle &RHS) {
3287  return ((ScalarEvolutionsImpl*)Impl)->getMinusSCEV(LHS, RHS);
3288}
3289
3290/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion
3291/// of the input value to the specified type.  If the type must be
3292/// extended, it is zero extended.
3293SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V,
3294                                                    const Type *Ty) {
3295  return ((ScalarEvolutionsImpl*)Impl)->getTruncateOrZeroExtend(V, Ty);
3296}
3297
3298/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion
3299/// of the input value to the specified type.  If the type must be
3300/// extended, it is sign extended.
3301SCEVHandle ScalarEvolution::getTruncateOrSignExtend(const SCEVHandle &V,
3302                                                    const Type *Ty) {
3303  return ((ScalarEvolutionsImpl*)Impl)->getTruncateOrSignExtend(V, Ty);
3304}
3305
3306
3307bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
3308                                          ICmpInst::Predicate Pred,
3309                                          SCEV *LHS, SCEV *RHS) {
3310  return ((ScalarEvolutionsImpl*)Impl)->isLoopGuardedByCond(L, Pred,
3311                                                            LHS, RHS);
3312}
3313
3314SCEVHandle ScalarEvolution::getBackedgeTakenCount(const Loop *L) const {
3315  return ((ScalarEvolutionsImpl*)Impl)->getBackedgeTakenCount(L);
3316}
3317
3318bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) const {
3319  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
3320}
3321
3322void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) {
3323  return ((ScalarEvolutionsImpl*)Impl)->forgetLoopBackedgeTakenCount(L);
3324}
3325
3326SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const {
3327  return ((ScalarEvolutionsImpl*)Impl)->getSCEVAtScope(getSCEV(V), L);
3328}
3329
3330void ScalarEvolution::deleteValueFromRecords(Value *V) const {
3331  return ((ScalarEvolutionsImpl*)Impl)->deleteValueFromRecords(V);
3332}
3333
3334static void PrintLoopInfo(raw_ostream &OS, const ScalarEvolution *SE,
3335                          const Loop *L) {
3336  // Print all inner loops first
3337  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
3338    PrintLoopInfo(OS, SE, *I);
3339
3340  OS << "Loop " << L->getHeader()->getName() << ": ";
3341
3342  SmallVector<BasicBlock*, 8> ExitBlocks;
3343  L->getExitBlocks(ExitBlocks);
3344  if (ExitBlocks.size() != 1)
3345    OS << "<multiple exits> ";
3346
3347  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
3348    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
3349  } else {
3350    OS << "Unpredictable backedge-taken count. ";
3351  }
3352
3353  OS << "\n";
3354}
3355
3356void ScalarEvolution::print(raw_ostream &OS, const Module* ) const {
3357  Function &F = ((ScalarEvolutionsImpl*)Impl)->F;
3358  LoopInfo &LI = ((ScalarEvolutionsImpl*)Impl)->LI;
3359
3360  OS << "Classifying expressions for: " << F.getName() << "\n";
3361  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
3362    if (I->getType()->isInteger()) {
3363      OS << *I;
3364      OS << "  -->  ";
3365      SCEVHandle SV = getSCEV(&*I);
3366      SV->print(OS);
3367      OS << "\t\t";
3368
3369      if (const Loop *L = LI.getLoopFor((*I).getParent())) {
3370        OS << "Exits: ";
3371        SCEVHandle ExitValue = getSCEVAtScope(&*I, L->getParentLoop());
3372        if (isa<SCEVCouldNotCompute>(ExitValue)) {
3373          OS << "<<Unknown>>";
3374        } else {
3375          OS << *ExitValue;
3376        }
3377      }
3378
3379
3380      OS << "\n";
3381    }
3382
3383  OS << "Determining loop execution counts for: " << F.getName() << "\n";
3384  for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I)
3385    PrintLoopInfo(OS, this, *I);
3386}
3387
3388void ScalarEvolution::print(std::ostream &o, const Module *M) const {
3389  raw_os_ostream OS(o);
3390  print(OS, M);
3391}
3392