ScalarEvolution.cpp revision 5be18e84766fb495b0bde3c8244c1df459a18683
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file contains the implementation of the scalar evolution analysis
11// engine, which is used primarily to analyze expressions involving induction
12// variables in loops.
13//
14// There are several aspects to this library.  First is the representation of
15// scalar expressions, which are represented as subclasses of the SCEV class.
16// These classes are used to represent certain types of subexpressions that we
17// can handle.  These classes are reference counted, managed by the SCEVHandle
18// class.  We only create one SCEV of a particular shape, so pointer-comparisons
19// for equality are legal.
20//
21// One important aspect of the SCEV objects is that they are never cyclic, even
22// if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
23// the PHI node is one of the idioms that we can represent (e.g., a polynomial
24// recurrence) then we represent it directly as a recurrence node, otherwise we
25// represent it as a SCEVUnknown node.
26//
27// In addition to being able to represent expressions of various types, we also
28// have folders that are used to build the *canonical* representation for a
29// particular expression.  These folders are capable of using a variety of
30// rewrite rules to simplify the expressions.
31//
32// Once the folders are defined, we can implement the more interesting
33// higher-level code, such as the code that recognizes PHI nodes of various
34// types, computes the execution count of a loop, etc.
35//
36// TODO: We should use these routines and value representations to implement
37// dependence analysis!
38//
39//===----------------------------------------------------------------------===//
40//
41// There are several good references for the techniques used in this analysis.
42//
43//  Chains of recurrences -- a method to expedite the evaluation
44//  of closed-form functions
45//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
46//
47//  On computational properties of chains of recurrences
48//  Eugene V. Zima
49//
50//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
51//  Robert A. van Engelen
52//
53//  Efficient Symbolic Analysis for Optimizing Compilers
54//  Robert A. van Engelen
55//
56//  Using the chains of recurrences algebra for data dependence testing and
57//  induction variable substitution
58//  MS Thesis, Johnie Birch
59//
60//===----------------------------------------------------------------------===//
61
62#define DEBUG_TYPE "scalar-evolution"
63#include "llvm/Analysis/ScalarEvolutionExpressions.h"
64#include "llvm/Constants.h"
65#include "llvm/DerivedTypes.h"
66#include "llvm/GlobalVariable.h"
67#include "llvm/Instructions.h"
68#include "llvm/Analysis/ConstantFolding.h"
69#include "llvm/Analysis/Dominators.h"
70#include "llvm/Analysis/LoopInfo.h"
71#include "llvm/Assembly/Writer.h"
72#include "llvm/Target/TargetData.h"
73#include "llvm/Support/CommandLine.h"
74#include "llvm/Support/Compiler.h"
75#include "llvm/Support/ConstantRange.h"
76#include "llvm/Support/GetElementPtrTypeIterator.h"
77#include "llvm/Support/InstIterator.h"
78#include "llvm/Support/ManagedStatic.h"
79#include "llvm/Support/MathExtras.h"
80#include "llvm/Support/raw_ostream.h"
81#include "llvm/ADT/Statistic.h"
82#include "llvm/ADT/STLExtras.h"
83#include <ostream>
84#include <algorithm>
85using namespace llvm;
86
87STATISTIC(NumArrayLenItCounts,
88          "Number of trip counts computed with array length");
89STATISTIC(NumTripCountsComputed,
90          "Number of loops with predictable loop counts");
91STATISTIC(NumTripCountsNotComputed,
92          "Number of loops without predictable loop counts");
93STATISTIC(NumBruteForceTripCountsComputed,
94          "Number of loops with trip counts computed by force");
95
96static cl::opt<unsigned>
97MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
98                        cl::desc("Maximum number of iterations SCEV will "
99                                 "symbolically execute a constant derived loop"),
100                        cl::init(100));
101
102static RegisterPass<ScalarEvolution>
103R("scalar-evolution", "Scalar Evolution Analysis", false, true);
104char ScalarEvolution::ID = 0;
105
106//===----------------------------------------------------------------------===//
107//                           SCEV class definitions
108//===----------------------------------------------------------------------===//
109
110//===----------------------------------------------------------------------===//
111// Implementation of the SCEV class.
112//
113SCEV::~SCEV() {}
114void SCEV::dump() const {
115  print(errs());
116  errs() << '\n';
117}
118
119void SCEV::print(std::ostream &o) const {
120  raw_os_ostream OS(o);
121  print(OS);
122}
123
124bool SCEV::isZero() const {
125  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
126    return SC->getValue()->isZero();
127  return false;
128}
129
130bool SCEV::isOne() const {
131  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
132    return SC->getValue()->isOne();
133  return false;
134}
135
136SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {}
137SCEVCouldNotCompute::~SCEVCouldNotCompute() {}
138
139bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
140  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
141  return false;
142}
143
144const Type *SCEVCouldNotCompute::getType() const {
145  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
146  return 0;
147}
148
149bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
150  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
151  return false;
152}
153
154SCEVHandle SCEVCouldNotCompute::
155replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
156                                  const SCEVHandle &Conc,
157                                  ScalarEvolution &SE) const {
158  return this;
159}
160
161void SCEVCouldNotCompute::print(raw_ostream &OS) const {
162  OS << "***COULDNOTCOMPUTE***";
163}
164
165bool SCEVCouldNotCompute::classof(const SCEV *S) {
166  return S->getSCEVType() == scCouldNotCompute;
167}
168
169
170// SCEVConstants - Only allow the creation of one SCEVConstant for any
171// particular value.  Don't use a SCEVHandle here, or else the object will
172// never be deleted!
173static ManagedStatic<std::map<ConstantInt*, SCEVConstant*> > SCEVConstants;
174
175
176SCEVConstant::~SCEVConstant() {
177  SCEVConstants->erase(V);
178}
179
180SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) {
181  SCEVConstant *&R = (*SCEVConstants)[V];
182  if (R == 0) R = new SCEVConstant(V);
183  return R;
184}
185
186SCEVHandle ScalarEvolution::getConstant(const APInt& Val) {
187  return getConstant(ConstantInt::get(Val));
188}
189
190const Type *SCEVConstant::getType() const { return V->getType(); }
191
192void SCEVConstant::print(raw_ostream &OS) const {
193  WriteAsOperand(OS, V, false);
194}
195
196SCEVCastExpr::SCEVCastExpr(unsigned SCEVTy,
197                           const SCEVHandle &op, const Type *ty)
198  : SCEV(SCEVTy), Op(op), Ty(ty) {}
199
200SCEVCastExpr::~SCEVCastExpr() {}
201
202bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
203  return Op->dominates(BB, DT);
204}
205
206// SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any
207// particular input.  Don't use a SCEVHandle here, or else the object will
208// never be deleted!
209static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>,
210                     SCEVTruncateExpr*> > SCEVTruncates;
211
212SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty)
213  : SCEVCastExpr(scTruncate, op, ty) {
214  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
215         (Ty->isInteger() || isa<PointerType>(Ty)) &&
216         "Cannot truncate non-integer value!");
217}
218
219SCEVTruncateExpr::~SCEVTruncateExpr() {
220  SCEVTruncates->erase(std::make_pair(Op, Ty));
221}
222
223void SCEVTruncateExpr::print(raw_ostream &OS) const {
224  OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
225}
226
227// SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any
228// particular input.  Don't use a SCEVHandle here, or else the object will never
229// be deleted!
230static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>,
231                     SCEVZeroExtendExpr*> > SCEVZeroExtends;
232
233SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty)
234  : SCEVCastExpr(scZeroExtend, op, ty) {
235  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
236         (Ty->isInteger() || isa<PointerType>(Ty)) &&
237         "Cannot zero extend non-integer value!");
238}
239
240SCEVZeroExtendExpr::~SCEVZeroExtendExpr() {
241  SCEVZeroExtends->erase(std::make_pair(Op, Ty));
242}
243
244void SCEVZeroExtendExpr::print(raw_ostream &OS) const {
245  OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
246}
247
248// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any
249// particular input.  Don't use a SCEVHandle here, or else the object will never
250// be deleted!
251static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>,
252                     SCEVSignExtendExpr*> > SCEVSignExtends;
253
254SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty)
255  : SCEVCastExpr(scSignExtend, op, ty) {
256  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
257         (Ty->isInteger() || isa<PointerType>(Ty)) &&
258         "Cannot sign extend non-integer value!");
259}
260
261SCEVSignExtendExpr::~SCEVSignExtendExpr() {
262  SCEVSignExtends->erase(std::make_pair(Op, Ty));
263}
264
265void SCEVSignExtendExpr::print(raw_ostream &OS) const {
266  OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
267}
268
269// SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any
270// particular input.  Don't use a SCEVHandle here, or else the object will never
271// be deleted!
272static ManagedStatic<std::map<std::pair<unsigned, std::vector<const SCEV*> >,
273                     SCEVCommutativeExpr*> > SCEVCommExprs;
274
275SCEVCommutativeExpr::~SCEVCommutativeExpr() {
276  std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end());
277  SCEVCommExprs->erase(std::make_pair(getSCEVType(), SCEVOps));
278}
279
280void SCEVCommutativeExpr::print(raw_ostream &OS) const {
281  assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
282  const char *OpStr = getOperationStr();
283  OS << "(" << *Operands[0];
284  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
285    OS << OpStr << *Operands[i];
286  OS << ")";
287}
288
289SCEVHandle SCEVCommutativeExpr::
290replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
291                                  const SCEVHandle &Conc,
292                                  ScalarEvolution &SE) const {
293  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
294    SCEVHandle H =
295      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
296    if (H != getOperand(i)) {
297      std::vector<SCEVHandle> NewOps;
298      NewOps.reserve(getNumOperands());
299      for (unsigned j = 0; j != i; ++j)
300        NewOps.push_back(getOperand(j));
301      NewOps.push_back(H);
302      for (++i; i != e; ++i)
303        NewOps.push_back(getOperand(i)->
304                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
305
306      if (isa<SCEVAddExpr>(this))
307        return SE.getAddExpr(NewOps);
308      else if (isa<SCEVMulExpr>(this))
309        return SE.getMulExpr(NewOps);
310      else if (isa<SCEVSMaxExpr>(this))
311        return SE.getSMaxExpr(NewOps);
312      else if (isa<SCEVUMaxExpr>(this))
313        return SE.getUMaxExpr(NewOps);
314      else
315        assert(0 && "Unknown commutative expr!");
316    }
317  }
318  return this;
319}
320
321bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
322  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
323    if (!getOperand(i)->dominates(BB, DT))
324      return false;
325  }
326  return true;
327}
328
329
330// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular
331// input.  Don't use a SCEVHandle here, or else the object will never be
332// deleted!
333static ManagedStatic<std::map<std::pair<const SCEV*, const SCEV*>,
334                     SCEVUDivExpr*> > SCEVUDivs;
335
336SCEVUDivExpr::~SCEVUDivExpr() {
337  SCEVUDivs->erase(std::make_pair(LHS, RHS));
338}
339
340bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
341  return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
342}
343
344void SCEVUDivExpr::print(raw_ostream &OS) const {
345  OS << "(" << *LHS << " /u " << *RHS << ")";
346}
347
348const Type *SCEVUDivExpr::getType() const {
349  return LHS->getType();
350}
351
352// SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any
353// particular input.  Don't use a SCEVHandle here, or else the object will never
354// be deleted!
355static ManagedStatic<std::map<std::pair<const Loop *,
356                                        std::vector<const SCEV*> >,
357                     SCEVAddRecExpr*> > SCEVAddRecExprs;
358
359SCEVAddRecExpr::~SCEVAddRecExpr() {
360  std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end());
361  SCEVAddRecExprs->erase(std::make_pair(L, SCEVOps));
362}
363
364SCEVHandle SCEVAddRecExpr::
365replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
366                                  const SCEVHandle &Conc,
367                                  ScalarEvolution &SE) const {
368  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
369    SCEVHandle H =
370      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
371    if (H != getOperand(i)) {
372      std::vector<SCEVHandle> NewOps;
373      NewOps.reserve(getNumOperands());
374      for (unsigned j = 0; j != i; ++j)
375        NewOps.push_back(getOperand(j));
376      NewOps.push_back(H);
377      for (++i; i != e; ++i)
378        NewOps.push_back(getOperand(i)->
379                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
380
381      return SE.getAddRecExpr(NewOps, L);
382    }
383  }
384  return this;
385}
386
387
388bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
389  // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't
390  // contain L and if the start is invariant.
391  return !QueryLoop->contains(L->getHeader()) &&
392         getOperand(0)->isLoopInvariant(QueryLoop);
393}
394
395
396void SCEVAddRecExpr::print(raw_ostream &OS) const {
397  OS << "{" << *Operands[0];
398  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
399    OS << ",+," << *Operands[i];
400  OS << "}<" << L->getHeader()->getName() + ">";
401}
402
403// SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular
404// value.  Don't use a SCEVHandle here, or else the object will never be
405// deleted!
406static ManagedStatic<std::map<Value*, SCEVUnknown*> > SCEVUnknowns;
407
408SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); }
409
410bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
411  // All non-instruction values are loop invariant.  All instructions are loop
412  // invariant if they are not contained in the specified loop.
413  if (Instruction *I = dyn_cast<Instruction>(V))
414    return !L->contains(I->getParent());
415  return true;
416}
417
418bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
419  if (Instruction *I = dyn_cast<Instruction>(getValue()))
420    return DT->dominates(I->getParent(), BB);
421  return true;
422}
423
424const Type *SCEVUnknown::getType() const {
425  return V->getType();
426}
427
428void SCEVUnknown::print(raw_ostream &OS) const {
429  WriteAsOperand(OS, V, false);
430}
431
432//===----------------------------------------------------------------------===//
433//                               SCEV Utilities
434//===----------------------------------------------------------------------===//
435
436namespace {
437  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
438  /// than the complexity of the RHS.  This comparator is used to canonicalize
439  /// expressions.
440  class VISIBILITY_HIDDEN SCEVComplexityCompare {
441    LoopInfo *LI;
442  public:
443    explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {}
444
445    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
446      // Primarily, sort the SCEVs by their getSCEVType().
447      if (LHS->getSCEVType() != RHS->getSCEVType())
448        return LHS->getSCEVType() < RHS->getSCEVType();
449
450      // Aside from the getSCEVType() ordering, the particular ordering
451      // isn't very important except that it's beneficial to be consistent,
452      // so that (a + b) and (b + a) don't end up as different expressions.
453
454      // Sort SCEVUnknown values with some loose heuristics. TODO: This is
455      // not as complete as it could be.
456      if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) {
457        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
458
459        // Order pointer values after integer values. This helps SCEVExpander
460        // form GEPs.
461        if (isa<PointerType>(LU->getType()) && !isa<PointerType>(RU->getType()))
462          return false;
463        if (isa<PointerType>(RU->getType()) && !isa<PointerType>(LU->getType()))
464          return true;
465
466        // Compare getValueID values.
467        if (LU->getValue()->getValueID() != RU->getValue()->getValueID())
468          return LU->getValue()->getValueID() < RU->getValue()->getValueID();
469
470        // Sort arguments by their position.
471        if (const Argument *LA = dyn_cast<Argument>(LU->getValue())) {
472          const Argument *RA = cast<Argument>(RU->getValue());
473          return LA->getArgNo() < RA->getArgNo();
474        }
475
476        // For instructions, compare their loop depth, and their opcode.
477        // This is pretty loose.
478        if (Instruction *LV = dyn_cast<Instruction>(LU->getValue())) {
479          Instruction *RV = cast<Instruction>(RU->getValue());
480
481          // Compare loop depths.
482          if (LI->getLoopDepth(LV->getParent()) !=
483              LI->getLoopDepth(RV->getParent()))
484            return LI->getLoopDepth(LV->getParent()) <
485                   LI->getLoopDepth(RV->getParent());
486
487          // Compare opcodes.
488          if (LV->getOpcode() != RV->getOpcode())
489            return LV->getOpcode() < RV->getOpcode();
490
491          // Compare the number of operands.
492          if (LV->getNumOperands() != RV->getNumOperands())
493            return LV->getNumOperands() < RV->getNumOperands();
494        }
495
496        return false;
497      }
498
499      // Constant sorting doesn't matter since they'll be folded.
500      if (isa<SCEVConstant>(LHS))
501        return false;
502
503      // Lexicographically compare n-ary expressions.
504      if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) {
505        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
506        for (unsigned i = 0, e = LC->getNumOperands(); i != e; ++i) {
507          if (i >= RC->getNumOperands())
508            return false;
509          if (operator()(LC->getOperand(i), RC->getOperand(i)))
510            return true;
511          if (operator()(RC->getOperand(i), LC->getOperand(i)))
512            return false;
513        }
514        return LC->getNumOperands() < RC->getNumOperands();
515      }
516
517      // Lexicographically compare udiv expressions.
518      if (const SCEVUDivExpr *LC = dyn_cast<SCEVUDivExpr>(LHS)) {
519        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
520        if (operator()(LC->getLHS(), RC->getLHS()))
521          return true;
522        if (operator()(RC->getLHS(), LC->getLHS()))
523          return false;
524        if (operator()(LC->getRHS(), RC->getRHS()))
525          return true;
526        if (operator()(RC->getRHS(), LC->getRHS()))
527          return false;
528        return false;
529      }
530
531      // Compare cast expressions by operand.
532      if (const SCEVCastExpr *LC = dyn_cast<SCEVCastExpr>(LHS)) {
533        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
534        return operator()(LC->getOperand(), RC->getOperand());
535      }
536
537      assert(0 && "Unknown SCEV kind!");
538      return false;
539    }
540  };
541}
542
543/// GroupByComplexity - Given a list of SCEV objects, order them by their
544/// complexity, and group objects of the same complexity together by value.
545/// When this routine is finished, we know that any duplicates in the vector are
546/// consecutive and that complexity is monotonically increasing.
547///
548/// Note that we go take special precautions to ensure that we get determinstic
549/// results from this routine.  In other words, we don't want the results of
550/// this to depend on where the addresses of various SCEV objects happened to
551/// land in memory.
552///
553static void GroupByComplexity(std::vector<SCEVHandle> &Ops,
554                              LoopInfo *LI) {
555  if (Ops.size() < 2) return;  // Noop
556  if (Ops.size() == 2) {
557    // This is the common case, which also happens to be trivially simple.
558    // Special case it.
559    if (SCEVComplexityCompare(LI)(Ops[1], Ops[0]))
560      std::swap(Ops[0], Ops[1]);
561    return;
562  }
563
564  // Do the rough sort by complexity.
565  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));
566
567  // Now that we are sorted by complexity, group elements of the same
568  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
569  // be extremely short in practice.  Note that we take this approach because we
570  // do not want to depend on the addresses of the objects we are grouping.
571  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
572    const SCEV *S = Ops[i];
573    unsigned Complexity = S->getSCEVType();
574
575    // If there are any objects of the same complexity and same value as this
576    // one, group them.
577    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
578      if (Ops[j] == S) { // Found a duplicate.
579        // Move it to immediately after i'th element.
580        std::swap(Ops[i+1], Ops[j]);
581        ++i;   // no need to rescan it.
582        if (i == e-2) return;  // Done!
583      }
584    }
585  }
586}
587
588
589
590//===----------------------------------------------------------------------===//
591//                      Simple SCEV method implementations
592//===----------------------------------------------------------------------===//
593
594/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
595// Assume, K > 0.
596static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
597                                      ScalarEvolution &SE,
598                                      const Type* ResultTy) {
599  // Handle the simplest case efficiently.
600  if (K == 1)
601    return SE.getTruncateOrZeroExtend(It, ResultTy);
602
603  // We are using the following formula for BC(It, K):
604  //
605  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
606  //
607  // Suppose, W is the bitwidth of the return value.  We must be prepared for
608  // overflow.  Hence, we must assure that the result of our computation is
609  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
610  // safe in modular arithmetic.
611  //
612  // However, this code doesn't use exactly that formula; the formula it uses
613  // is something like the following, where T is the number of factors of 2 in
614  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
615  // exponentiation:
616  //
617  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
618  //
619  // This formula is trivially equivalent to the previous formula.  However,
620  // this formula can be implemented much more efficiently.  The trick is that
621  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
622  // arithmetic.  To do exact division in modular arithmetic, all we have
623  // to do is multiply by the inverse.  Therefore, this step can be done at
624  // width W.
625  //
626  // The next issue is how to safely do the division by 2^T.  The way this
627  // is done is by doing the multiplication step at a width of at least W + T
628  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
629  // when we perform the division by 2^T (which is equivalent to a right shift
630  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
631  // truncated out after the division by 2^T.
632  //
633  // In comparison to just directly using the first formula, this technique
634  // is much more efficient; using the first formula requires W * K bits,
635  // but this formula less than W + K bits. Also, the first formula requires
636  // a division step, whereas this formula only requires multiplies and shifts.
637  //
638  // It doesn't matter whether the subtraction step is done in the calculation
639  // width or the input iteration count's width; if the subtraction overflows,
640  // the result must be zero anyway.  We prefer here to do it in the width of
641  // the induction variable because it helps a lot for certain cases; CodeGen
642  // isn't smart enough to ignore the overflow, which leads to much less
643  // efficient code if the width of the subtraction is wider than the native
644  // register width.
645  //
646  // (It's possible to not widen at all by pulling out factors of 2 before
647  // the multiplication; for example, K=2 can be calculated as
648  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
649  // extra arithmetic, so it's not an obvious win, and it gets
650  // much more complicated for K > 3.)
651
652  // Protection from insane SCEVs; this bound is conservative,
653  // but it probably doesn't matter.
654  if (K > 1000)
655    return SE.getCouldNotCompute();
656
657  unsigned W = SE.getTypeSizeInBits(ResultTy);
658
659  // Calculate K! / 2^T and T; we divide out the factors of two before
660  // multiplying for calculating K! / 2^T to avoid overflow.
661  // Other overflow doesn't matter because we only care about the bottom
662  // W bits of the result.
663  APInt OddFactorial(W, 1);
664  unsigned T = 1;
665  for (unsigned i = 3; i <= K; ++i) {
666    APInt Mult(W, i);
667    unsigned TwoFactors = Mult.countTrailingZeros();
668    T += TwoFactors;
669    Mult = Mult.lshr(TwoFactors);
670    OddFactorial *= Mult;
671  }
672
673  // We need at least W + T bits for the multiplication step
674  unsigned CalculationBits = W + T;
675
676  // Calcuate 2^T, at width T+W.
677  APInt DivFactor = APInt(CalculationBits, 1).shl(T);
678
679  // Calculate the multiplicative inverse of K! / 2^T;
680  // this multiplication factor will perform the exact division by
681  // K! / 2^T.
682  APInt Mod = APInt::getSignedMinValue(W+1);
683  APInt MultiplyFactor = OddFactorial.zext(W+1);
684  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
685  MultiplyFactor = MultiplyFactor.trunc(W);
686
687  // Calculate the product, at width T+W
688  const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
689  SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
690  for (unsigned i = 1; i != K; ++i) {
691    SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
692    Dividend = SE.getMulExpr(Dividend,
693                             SE.getTruncateOrZeroExtend(S, CalculationTy));
694  }
695
696  // Divide by 2^T
697  SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
698
699  // Truncate the result, and divide by K! / 2^T.
700
701  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
702                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
703}
704
705/// evaluateAtIteration - Return the value of this chain of recurrences at
706/// the specified iteration number.  We can evaluate this recurrence by
707/// multiplying each element in the chain by the binomial coefficient
708/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
709///
710///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
711///
712/// where BC(It, k) stands for binomial coefficient.
713///
714SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
715                                               ScalarEvolution &SE) const {
716  SCEVHandle Result = getStart();
717  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
718    // The computation is correct in the face of overflow provided that the
719    // multiplication is performed _after_ the evaluation of the binomial
720    // coefficient.
721    SCEVHandle Coeff = BinomialCoefficient(It, i, SE, getType());
722    if (isa<SCEVCouldNotCompute>(Coeff))
723      return Coeff;
724
725    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
726  }
727  return Result;
728}
729
730//===----------------------------------------------------------------------===//
731//                    SCEV Expression folder implementations
732//===----------------------------------------------------------------------===//
733
734SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op,
735                                            const Type *Ty) {
736  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
737         "This is not a truncating conversion!");
738  assert(isSCEVable(Ty) &&
739         "This is not a conversion to a SCEVable type!");
740  Ty = getEffectiveSCEVType(Ty);
741
742  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
743    return getUnknown(
744        ConstantExpr::getTrunc(SC->getValue(), Ty));
745
746  // trunc(trunc(x)) --> trunc(x)
747  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
748    return getTruncateExpr(ST->getOperand(), Ty);
749
750  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
751  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
752    return getTruncateOrSignExtend(SS->getOperand(), Ty);
753
754  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
755  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
756    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
757
758  // If the input value is a chrec scev made out of constants, truncate
759  // all of the constants.
760  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
761    std::vector<SCEVHandle> Operands;
762    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
763      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
764    return getAddRecExpr(Operands, AddRec->getLoop());
765  }
766
767  SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)];
768  if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty);
769  return Result;
770}
771
/// getZeroExtendExpr - Return a SCEV for Op zero-extended to the wider type
/// Ty. Constants are folded, nested zero-extends are collapsed, and an affine
/// add recurrence is rewritten as a recurrence of extended operands when the
/// checks below show its final value cannot wrap. Otherwise the uniqued
/// SCEVZeroExtendExpr for (Op, Ty) is returned, created on first request.
SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op,
                                              const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  // Normalize the destination type (presumably pointer -> integer; see
  // getEffectiveSCEVType).
  Ty = getEffectiveSCEVType(Ty);

  // Fold constant operands directly.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
    // NOTE(review): Ty was already normalized above, so unless
    // getEffectiveSCEVType is not idempotent, IntTy == Ty and the IntToPtr
    // branch never fires.
    const Type *IntTy = getEffectiveSCEVType(Ty);
    Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy);
    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
    return getUnknown(C);
  }

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      SCEVHandle MaxBECount = getMaxBackedgeTakenCount(AR->getLoop());
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.
        SCEVHandle Start = AR->getStart();
        SCEVHandle Step = AR->getStepRecurrence(*this);

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        SCEVHandle CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        SCEVHandle RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          // WideTy has twice the bit width of the addrec's type. The final
          // value Start + MaxBECount*Step is computed twice: once in the
          // narrow type and then zero-extended, and once from operands
          // extended to WideTy. If both give the same SCEV, the narrow
          // computation did not wrap.
          const Type *WideTy =
            IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          SCEVHandle ZMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrZeroExtend(Step, Start->getType()));
          SCEVHandle Add = getAddExpr(Start, ZMul);
          SCEVHandle OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getZeroExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 AR->getLoop());

          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down. On success the step is
          // sign-extended while the start is still zero-extended.
          SCEVHandle SMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrSignExtend(Step, Start->getType()));
          Add = getAddExpr(Start, SMul);
          OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 AR->getLoop());
        }
      }
    }

  // Nothing could be folded: return the uniqued zero-extend node.
  SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)];
  if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty);
  return Result;
}
859
/// getSignExtendExpr - Return a SCEV for Op sign-extended to the wider type
/// Ty. Constants are folded, nested sign-extends are collapsed, and an affine
/// add recurrence is rewritten as a recurrence of sign-extended operands
/// when the check below shows its final value cannot overflow in the signed
/// sense. Otherwise the uniqued SCEVSignExtendExpr for (Op, Ty) is returned,
/// created on first request.
SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op,
                                              const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  // Normalize the destination type (presumably pointer -> integer; see
  // getEffectiveSCEVType).
  Ty = getEffectiveSCEVType(Ty);

  // Fold constant operands directly.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
    // NOTE(review): Ty was already normalized above, so unless
    // getEffectiveSCEVType is not idempotent, IntTy == Ty and the IntToPtr
    // branch never fires.
    const Type *IntTy = getEffectiveSCEVType(Ty);
    Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy);
    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
    return getUnknown(C);
  }

  // sext(sext(x)) --> sext(x)
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getSignExtendExpr(SS->getOperand(), Ty);

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can sign extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      SCEVHandle MaxBECount = getMaxBackedgeTakenCount(AR->getLoop());
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.
        SCEVHandle Start = AR->getStart();
        SCEVHandle Step = AR->getStepRecurrence(*this);

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        SCEVHandle CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        SCEVHandle RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          // WideTy has twice the bit width of the addrec's type. The final
          // value is computed in the narrow type and sign-extended, and also
          // built from operands extended to WideTy (the count is zero-extended
          // because it is unsigned). Equal SCEVs mean no signed wrap occurred.
          const Type *WideTy =
            IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
          SCEVHandle SMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrSignExtend(Step, Start->getType()));
          SCEVHandle Add = getAddExpr(Start, SMul);
          SCEVHandle OperandExtendedAdd =
            getAddExpr(getSignExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 AR->getLoop());
        }
      }
    }

  // Nothing could be folded: return the uniqued sign-extend node.
  SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)];
  if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty);
  return Result;
}
931
932// get - Get a canonical add expression, or something simpler if possible.
933SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) {
934  assert(!Ops.empty() && "Cannot get empty add!");
935  if (Ops.size() == 1) return Ops[0];
936#ifndef NDEBUG
937  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
938    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
939           getEffectiveSCEVType(Ops[0]->getType()) &&
940           "SCEVAddExpr operand types don't match!");
941#endif
942
943  // Sort by complexity, this groups all similar expression types together.
944  GroupByComplexity(Ops, LI);
945
946  // If there are any constants, fold them together.
947  unsigned Idx = 0;
948  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
949    ++Idx;
950    assert(Idx < Ops.size());
951    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
952      // We found two constants, fold them together!
953      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() +
954                                           RHSC->getValue()->getValue());
955      Ops[0] = getConstant(Fold);
956      Ops.erase(Ops.begin()+1);  // Erase the folded element
957      if (Ops.size() == 1) return Ops[0];
958      LHSC = cast<SCEVConstant>(Ops[0]);
959    }
960
961    // If we are left with a constant zero being added, strip it off.
962    if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
963      Ops.erase(Ops.begin());
964      --Idx;
965    }
966  }
967
968  if (Ops.size() == 1) return Ops[0];
969
970  // Okay, check to see if the same value occurs in the operand list twice.  If
971  // so, merge them together into an multiply expression.  Since we sorted the
972  // list, these values are required to be adjacent.
973  const Type *Ty = Ops[0]->getType();
974  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
975    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
976      // Found a match, merge the two values into a multiply, and add any
977      // remaining values to the result.
978      SCEVHandle Two = getIntegerSCEV(2, Ty);
979      SCEVHandle Mul = getMulExpr(Ops[i], Two);
980      if (Ops.size() == 2)
981        return Mul;
982      Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
983      Ops.push_back(Mul);
984      return getAddExpr(Ops);
985    }
986
987  // Check for truncates. If all the operands are truncated from the same
988  // type, see if factoring out the truncate would permit the result to be
989  // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
990  // if the contents of the resulting outer trunc fold to something simple.
991  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
992    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
993    const Type *DstType = Trunc->getType();
994    const Type *SrcType = Trunc->getOperand()->getType();
995    std::vector<SCEVHandle> LargeOps;
996    bool Ok = true;
997    // Check all the operands to see if they can be represented in the
998    // source type of the truncate.
999    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
1000      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
1001        if (T->getOperand()->getType() != SrcType) {
1002          Ok = false;
1003          break;
1004        }
1005        LargeOps.push_back(T->getOperand());
1006      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1007        // This could be either sign or zero extension, but sign extension
1008        // is much more likely to be foldable here.
1009        LargeOps.push_back(getSignExtendExpr(C, SrcType));
1010      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
1011        std::vector<SCEVHandle> LargeMulOps;
1012        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
1013          if (const SCEVTruncateExpr *T =
1014                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
1015            if (T->getOperand()->getType() != SrcType) {
1016              Ok = false;
1017              break;
1018            }
1019            LargeMulOps.push_back(T->getOperand());
1020          } else if (const SCEVConstant *C =
1021                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
1022            // This could be either sign or zero extension, but sign extension
1023            // is much more likely to be foldable here.
1024            LargeMulOps.push_back(getSignExtendExpr(C, SrcType));
1025          } else {
1026            Ok = false;
1027            break;
1028          }
1029        }
1030        if (Ok)
1031          LargeOps.push_back(getMulExpr(LargeMulOps));
1032      } else {
1033        Ok = false;
1034        break;
1035      }
1036    }
1037    if (Ok) {
1038      // Evaluate the expression in the larger type.
1039      SCEVHandle Fold = getAddExpr(LargeOps);
1040      // If it folds to something simple, use it. Otherwise, don't.
1041      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
1042        return getTruncateExpr(Fold, DstType);
1043    }
1044  }
1045
1046  // Skip past any other cast SCEVs.
1047  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
1048    ++Idx;
1049
1050  // If there are add operands they would be next.
1051  if (Idx < Ops.size()) {
1052    bool DeletedAdd = false;
1053    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
1054      // If we have an add, expand the add operands onto the end of the operands
1055      // list.
1056      Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
1057      Ops.erase(Ops.begin()+Idx);
1058      DeletedAdd = true;
1059    }
1060
1061    // If we deleted at least one add, we added operands to the end of the list,
1062    // and they are not necessarily sorted.  Recurse to resort and resimplify
1063    // any operands we just aquired.
1064    if (DeletedAdd)
1065      return getAddExpr(Ops);
1066  }
1067
1068  // Skip over the add expression until we get to a multiply.
1069  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1070    ++Idx;
1071
1072  // If we are adding something to a multiply expression, make sure the
1073  // something is not already an operand of the multiply.  If so, merge it into
1074  // the multiply.
1075  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1076    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1077    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1078      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1079      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1080        if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(MulOpSCEV)) {
1081          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
1082          SCEVHandle InnerMul = Mul->getOperand(MulOp == 0);
1083          if (Mul->getNumOperands() != 2) {
1084            // If the multiply has more than two operands, we must get the
1085            // Y*Z term.
1086            std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
1087            MulOps.erase(MulOps.begin()+MulOp);
1088            InnerMul = getMulExpr(MulOps);
1089          }
1090          SCEVHandle One = getIntegerSCEV(1, Ty);
1091          SCEVHandle AddOne = getAddExpr(InnerMul, One);
1092          SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]);
1093          if (Ops.size() == 2) return OuterMul;
1094          if (AddOp < Idx) {
1095            Ops.erase(Ops.begin()+AddOp);
1096            Ops.erase(Ops.begin()+Idx-1);
1097          } else {
1098            Ops.erase(Ops.begin()+Idx);
1099            Ops.erase(Ops.begin()+AddOp-1);
1100          }
1101          Ops.push_back(OuterMul);
1102          return getAddExpr(Ops);
1103        }
1104
1105      // Check this multiply against other multiplies being added together.
1106      for (unsigned OtherMulIdx = Idx+1;
1107           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1108           ++OtherMulIdx) {
1109        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1110        // If MulOp occurs in OtherMul, we can fold the two multiplies
1111        // together.
1112        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1113             OMulOp != e; ++OMulOp)
1114          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1115            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1116            SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0);
1117            if (Mul->getNumOperands() != 2) {
1118              std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
1119              MulOps.erase(MulOps.begin()+MulOp);
1120              InnerMul1 = getMulExpr(MulOps);
1121            }
1122            SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1123            if (OtherMul->getNumOperands() != 2) {
1124              std::vector<SCEVHandle> MulOps(OtherMul->op_begin(),
1125                                             OtherMul->op_end());
1126              MulOps.erase(MulOps.begin()+OMulOp);
1127              InnerMul2 = getMulExpr(MulOps);
1128            }
1129            SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
1130            SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1131            if (Ops.size() == 2) return OuterMul;
1132            Ops.erase(Ops.begin()+Idx);
1133            Ops.erase(Ops.begin()+OtherMulIdx-1);
1134            Ops.push_back(OuterMul);
1135            return getAddExpr(Ops);
1136          }
1137      }
1138    }
1139  }
1140
1141  // If there are any add recurrences in the operands list, see if any other
1142  // added values are loop invariant.  If so, we can fold them into the
1143  // recurrence.
1144  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1145    ++Idx;
1146
1147  // Scan over all recurrences, trying to fold loop invariants into them.
1148  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1149    // Scan all of the other operands to this add and add them to the vector if
1150    // they are loop invariant w.r.t. the recurrence.
1151    std::vector<SCEVHandle> LIOps;
1152    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1153    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1154      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1155        LIOps.push_back(Ops[i]);
1156        Ops.erase(Ops.begin()+i);
1157        --i; --e;
1158      }
1159
1160    // If we found some loop invariants, fold them into the recurrence.
1161    if (!LIOps.empty()) {
1162      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
1163      LIOps.push_back(AddRec->getStart());
1164
1165      std::vector<SCEVHandle> AddRecOps(AddRec->op_begin(), AddRec->op_end());
1166      AddRecOps[0] = getAddExpr(LIOps);
1167
1168      SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop());
1169      // If all of the other operands were loop invariant, we are done.
1170      if (Ops.size() == 1) return NewRec;
1171
1172      // Otherwise, add the folded AddRec by the non-liv parts.
1173      for (unsigned i = 0;; ++i)
1174        if (Ops[i] == AddRec) {
1175          Ops[i] = NewRec;
1176          break;
1177        }
1178      return getAddExpr(Ops);
1179    }
1180
1181    // Okay, if there weren't any loop invariants to be folded, check to see if
1182    // there are multiple AddRec's with the same loop induction variable being
1183    // added together.  If so, we can fold them.
1184    for (unsigned OtherIdx = Idx+1;
1185         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1186      if (OtherIdx != Idx) {
1187        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1188        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1189          // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
1190          std::vector<SCEVHandle> NewOps(AddRec->op_begin(), AddRec->op_end());
1191          for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
1192            if (i >= NewOps.size()) {
1193              NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i,
1194                            OtherAddRec->op_end());
1195              break;
1196            }
1197            NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
1198          }
1199          SCEVHandle NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop());
1200
1201          if (Ops.size() == 2) return NewAddRec;
1202
1203          Ops.erase(Ops.begin()+Idx);
1204          Ops.erase(Ops.begin()+OtherIdx-1);
1205          Ops.push_back(NewAddRec);
1206          return getAddExpr(Ops);
1207        }
1208      }
1209
1210    // Otherwise couldn't fold anything into this recurrence.  Move onto the
1211    // next one.
1212  }
1213
1214  // Okay, it looks like we really DO need an add expr.  Check to see if we
1215  // already have one, otherwise create a new one.
1216  std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
1217  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr,
1218                                                                 SCEVOps)];
1219  if (Result == 0) Result = new SCEVAddExpr(Ops);
1220  return Result;
1221}
1222
1223
1224SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) {
1225  assert(!Ops.empty() && "Cannot get empty mul!");
1226#ifndef NDEBUG
1227  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1228    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1229           getEffectiveSCEVType(Ops[0]->getType()) &&
1230           "SCEVMulExpr operand types don't match!");
1231#endif
1232
1233  // Sort by complexity, this groups all similar expression types together.
1234  GroupByComplexity(Ops, LI);
1235
1236  // If there are any constants, fold them together.
1237  unsigned Idx = 0;
1238  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1239
1240    // C1*(C2+V) -> C1*C2 + C1*V
1241    if (Ops.size() == 2)
1242      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1243        if (Add->getNumOperands() == 2 &&
1244            isa<SCEVConstant>(Add->getOperand(0)))
1245          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1246                            getMulExpr(LHSC, Add->getOperand(1)));
1247
1248
1249    ++Idx;
1250    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1251      // We found two constants, fold them together!
1252      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() *
1253                                           RHSC->getValue()->getValue());
1254      Ops[0] = getConstant(Fold);
1255      Ops.erase(Ops.begin()+1);  // Erase the folded element
1256      if (Ops.size() == 1) return Ops[0];
1257      LHSC = cast<SCEVConstant>(Ops[0]);
1258    }
1259
1260    // If we are left with a constant one being multiplied, strip it off.
1261    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1262      Ops.erase(Ops.begin());
1263      --Idx;
1264    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1265      // If we have a multiply of zero, it will always be zero.
1266      return Ops[0];
1267    }
1268  }
1269
1270  // Skip over the add expression until we get to a multiply.
1271  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1272    ++Idx;
1273
1274  if (Ops.size() == 1)
1275    return Ops[0];
1276
1277  // If there are mul operands inline them all into this expression.
1278  if (Idx < Ops.size()) {
1279    bool DeletedMul = false;
1280    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1281      // If we have an mul, expand the mul operands onto the end of the operands
1282      // list.
1283      Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end());
1284      Ops.erase(Ops.begin()+Idx);
1285      DeletedMul = true;
1286    }
1287
1288    // If we deleted at least one mul, we added operands to the end of the list,
1289    // and they are not necessarily sorted.  Recurse to resort and resimplify
1290    // any operands we just aquired.
1291    if (DeletedMul)
1292      return getMulExpr(Ops);
1293  }
1294
1295  // If there are any add recurrences in the operands list, see if any other
1296  // added values are loop invariant.  If so, we can fold them into the
1297  // recurrence.
1298  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1299    ++Idx;
1300
1301  // Scan over all recurrences, trying to fold loop invariants into them.
1302  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1303    // Scan all of the other operands to this mul and add them to the vector if
1304    // they are loop invariant w.r.t. the recurrence.
1305    std::vector<SCEVHandle> LIOps;
1306    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1307    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1308      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1309        LIOps.push_back(Ops[i]);
1310        Ops.erase(Ops.begin()+i);
1311        --i; --e;
1312      }
1313
1314    // If we found some loop invariants, fold them into the recurrence.
1315    if (!LIOps.empty()) {
1316      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
1317      std::vector<SCEVHandle> NewOps;
1318      NewOps.reserve(AddRec->getNumOperands());
1319      if (LIOps.size() == 1) {
1320        const SCEV *Scale = LIOps[0];
1321        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
1322          NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
1323      } else {
1324        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
1325          std::vector<SCEVHandle> MulOps(LIOps);
1326          MulOps.push_back(AddRec->getOperand(i));
1327          NewOps.push_back(getMulExpr(MulOps));
1328        }
1329      }
1330
1331      SCEVHandle NewRec = getAddRecExpr(NewOps, AddRec->getLoop());
1332
1333      // If all of the other operands were loop invariant, we are done.
1334      if (Ops.size() == 1) return NewRec;
1335
1336      // Otherwise, multiply the folded AddRec by the non-liv parts.
1337      for (unsigned i = 0;; ++i)
1338        if (Ops[i] == AddRec) {
1339          Ops[i] = NewRec;
1340          break;
1341        }
1342      return getMulExpr(Ops);
1343    }
1344
1345    // Okay, if there weren't any loop invariants to be folded, check to see if
1346    // there are multiple AddRec's with the same loop induction variable being
1347    // multiplied together.  If so, we can fold them.
1348    for (unsigned OtherIdx = Idx+1;
1349         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1350      if (OtherIdx != Idx) {
1351        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1352        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1353          // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
1354          const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1355          SCEVHandle NewStart = getMulExpr(F->getStart(),
1356                                                 G->getStart());
1357          SCEVHandle B = F->getStepRecurrence(*this);
1358          SCEVHandle D = G->getStepRecurrence(*this);
1359          SCEVHandle NewStep = getAddExpr(getMulExpr(F, D),
1360                                          getMulExpr(G, B),
1361                                          getMulExpr(B, D));
1362          SCEVHandle NewAddRec = getAddRecExpr(NewStart, NewStep,
1363                                               F->getLoop());
1364          if (Ops.size() == 2) return NewAddRec;
1365
1366          Ops.erase(Ops.begin()+Idx);
1367          Ops.erase(Ops.begin()+OtherIdx-1);
1368          Ops.push_back(NewAddRec);
1369          return getMulExpr(Ops);
1370        }
1371      }
1372
1373    // Otherwise couldn't fold anything into this recurrence.  Move onto the
1374    // next one.
1375  }
1376
1377  // Okay, it looks like we really DO need an mul expr.  Check to see if we
1378  // already have one, otherwise create a new one.
1379  std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
1380  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr,
1381                                                                 SCEVOps)];
1382  if (Result == 0)
1383    Result = new SCEVMulExpr(Ops);
1384  return Result;
1385}
1386
1387SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS,
1388                                        const SCEVHandle &RHS) {
1389  assert(getEffectiveSCEVType(LHS->getType()) ==
1390         getEffectiveSCEVType(RHS->getType()) &&
1391         "SCEVUDivExpr operand types don't match!");
1392
1393  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1394    if (RHSC->getValue()->equalsInt(1))
1395      return LHS;                            // X udiv 1 --> x
1396    if (RHSC->isZero())
1397      return getIntegerSCEV(0, LHS->getType()); // value is undefined
1398
1399    // Determine if the division can be folded into the operands of
1400    // its operands.
1401    // TODO: Generalize this to non-constants by using known-bits information.
1402    const Type *Ty = LHS->getType();
1403    unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
1404    unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
1405    // For non-power-of-two values, effectively round the value up to the
1406    // nearest power of two.
1407    if (!RHSC->getValue()->getValue().isPowerOf2())
1408      ++MaxShiftAmt;
1409    const IntegerType *ExtTy =
1410      IntegerType::get(getTypeSizeInBits(Ty) + MaxShiftAmt);
1411    // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
1412    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
1413      if (const SCEVConstant *Step =
1414            dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
1415        if (!Step->getValue()->getValue()
1416              .urem(RHSC->getValue()->getValue()) &&
1417            getZeroExtendExpr(AR, ExtTy) ==
1418            getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
1419                          getZeroExtendExpr(Step, ExtTy),
1420                          AR->getLoop())) {
1421          std::vector<SCEVHandle> Operands;
1422          for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
1423            Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
1424          return getAddRecExpr(Operands, AR->getLoop());
1425        }
1426    // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
1427    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
1428      std::vector<SCEVHandle> Operands;
1429      for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
1430        Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
1431      if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
1432        // Find an operand that's safely divisible.
1433        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
1434          SCEVHandle Op = M->getOperand(i);
1435          SCEVHandle Div = getUDivExpr(Op, RHSC);
1436          if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
1437            Operands = M->getOperands();
1438            Operands[i] = Div;
1439            return getMulExpr(Operands);
1440          }
1441        }
1442    }
1443    // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
1444    if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(LHS)) {
1445      std::vector<SCEVHandle> Operands;
1446      for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
1447        Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
1448      if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
1449        Operands.clear();
1450        for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
1451          SCEVHandle Op = getUDivExpr(A->getOperand(i), RHS);
1452          if (isa<SCEVUDivExpr>(Op) || getMulExpr(Op, RHS) != A->getOperand(i))
1453            break;
1454          Operands.push_back(Op);
1455        }
1456        if (Operands.size() == A->getNumOperands())
1457          return getAddExpr(Operands);
1458      }
1459    }
1460
1461    // Fold if both operands are constant.
1462    if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1463      Constant *LHSCV = LHSC->getValue();
1464      Constant *RHSCV = RHSC->getValue();
1465      return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV));
1466    }
1467  }
1468
1469  SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)];
1470  if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS);
1471  return Result;
1472}
1473
1474
1475/// SCEVAddRecExpr::get - Get a add recurrence expression for the
1476/// specified loop.  Simplify the expression as much as possible.
1477SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start,
1478                               const SCEVHandle &Step, const Loop *L) {
1479  std::vector<SCEVHandle> Operands;
1480  Operands.push_back(Start);
1481  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
1482    if (StepChrec->getLoop() == L) {
1483      Operands.insert(Operands.end(), StepChrec->op_begin(),
1484                      StepChrec->op_end());
1485      return getAddRecExpr(Operands, L);
1486    }
1487
1488  Operands.push_back(Step);
1489  return getAddRecExpr(Operands, L);
1490}
1491
1492/// SCEVAddRecExpr::get - Get a add recurrence expression for the
1493/// specified loop.  Simplify the expression as much as possible.
1494SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands,
1495                                          const Loop *L) {
1496  if (Operands.size() == 1) return Operands[0];
1497#ifndef NDEBUG
1498  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
1499    assert(getEffectiveSCEVType(Operands[i]->getType()) ==
1500           getEffectiveSCEVType(Operands[0]->getType()) &&
1501           "SCEVAddRecExpr operand types don't match!");
1502#endif
1503
1504  if (Operands.back()->isZero()) {
1505    Operands.pop_back();
1506    return getAddRecExpr(Operands, L);             // {X,+,0}  -->  X
1507  }
1508
1509  // Canonicalize nested AddRecs in by nesting them in order of loop depth.
1510  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
1511    const Loop* NestedLoop = NestedAR->getLoop();
1512    if (L->getLoopDepth() < NestedLoop->getLoopDepth()) {
1513      std::vector<SCEVHandle> NestedOperands(NestedAR->op_begin(),
1514                                             NestedAR->op_end());
1515      SCEVHandle NestedARHandle(NestedAR);
1516      Operands[0] = NestedAR->getStart();
1517      NestedOperands[0] = getAddRecExpr(Operands, L);
1518      return getAddRecExpr(NestedOperands, NestedLoop);
1519    }
1520  }
1521
1522  std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end());
1523  SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, SCEVOps)];
1524  if (Result == 0) Result = new SCEVAddRecExpr(Operands, L);
1525  return Result;
1526}
1527
1528SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS,
1529                                        const SCEVHandle &RHS) {
1530  std::vector<SCEVHandle> Ops;
1531  Ops.push_back(LHS);
1532  Ops.push_back(RHS);
1533  return getSMaxExpr(Ops);
1534}
1535
/// getSMaxExpr - Return a SCEV for the signed maximum of all expressions in
/// Ops (which must be non-empty and of matching effective type).  The
/// operand list is taken by value and simplified in place:
///   - constant operands are folded together with APIntOps::smax, and a
///     resulting signed-minimum constant is dropped;
///   - nested SCEVSMaxExpr operands are flattened into this one;
///   - adjacent duplicate operands (adjacent after GroupByComplexity) are
///     removed.
/// A single remaining operand is returned directly; otherwise a uniqued
/// SCEVSMaxExpr is returned from SCEVCommExprs.
SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
  assert(!Ops.empty() && "Cannot get empty smax!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
           getEffectiveSCEVType(Ops[0]->getType()) &&
           "SCEVSMaxExpr operand types don't match!");
#endif

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  // After sorting, any constants are expected to sit at the front of Ops; Idx
  // tracks the first operand that has not been folded.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      ConstantInt *Fold = ConstantInt::get(
                              APIntOps::smax(LHSC->getValue()->getValue(),
                                             RHSC->getValue()->getValue()));
      Ops[0] = getConstant(Fold);
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      // Everything was a constant: the folded value is the answer.
      if (Ops.size() == 1) return Ops[0];
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant -inf, strip it off.
    // (The signed minimum is the identity element of smax.)
    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
      Ops.erase(Ops.begin());
      --Idx;  // Keep Idx pointing at the first non-constant operand.
    }
  }

  if (Ops.size() == 1) return Ops[0];

  // Find the first SMax
  // (operands are ordered by SCEV type, so skip everything ranked lower).
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
    ++Idx;

  // Check to see if one of the operands is an SMax. If so, expand its operands
  // onto our operand list, and recurse to simplify.
  if (Idx < Ops.size()) {
    bool DeletedSMax = false;
    while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
      // Append the nested operands before erasing, so SMax is still
      // referenced by Ops while its operands are copied.
      Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
      Ops.erase(Ops.begin()+Idx);
      DeletedSMax = true;
    }

    if (DeletedSMax)
      return getSMaxExpr(Ops);
  }

  // Okay, check to see if the same value occurs in the operand list twice.  If
  // so, delete one.  Since we sorted the list, these values are required to
  // be adjacent.
  // Note: when i == 0, --i wraps around (unsigned), and the loop's ++i brings
  // it back to 0, so the same position is re-examined after an erase.
  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
    if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
      --i; --e;
    }

  if (Ops.size() == 1) return Ops[0];

  assert(!Ops.empty() && "Reduced smax down to nothing!");

  // Okay, it looks like we really DO need an smax expr.  Check to see if we
  // already have one, otherwise create a new one.
  std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
                                                                 SCEVOps)];
  if (Result == 0) Result = new SCEVSMaxExpr(Ops);
  return Result;
}
1613
1614SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS,
1615                                        const SCEVHandle &RHS) {
1616  std::vector<SCEVHandle> Ops;
1617  Ops.push_back(LHS);
1618  Ops.push_back(RHS);
1619  return getUMaxExpr(Ops);
1620}
1621
/// getUMaxExpr - Return a SCEV for the unsigned maximum of all expressions in
/// Ops (which must be non-empty and of matching effective type).  The
/// operand list is taken by value and simplified in place:
///   - constant operands are folded together with APIntOps::umax, and a
///     resulting zero constant is dropped;
///   - nested SCEVUMaxExpr operands are flattened into this one;
///   - adjacent duplicate operands (adjacent after GroupByComplexity) are
///     removed.
/// A single remaining operand is returned directly; otherwise a uniqued
/// SCEVUMaxExpr is returned from SCEVCommExprs.
SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) {
  assert(!Ops.empty() && "Cannot get empty umax!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
           getEffectiveSCEVType(Ops[0]->getType()) &&
           "SCEVUMaxExpr operand types don't match!");
#endif

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  // After sorting, any constants are expected to sit at the front of Ops; Idx
  // tracks the first operand that has not been folded.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      ConstantInt *Fold = ConstantInt::get(
                              APIntOps::umax(LHSC->getValue()->getValue(),
                                             RHSC->getValue()->getValue()));
      Ops[0] = getConstant(Fold);
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      // Everything was a constant: the folded value is the answer.
      if (Ops.size() == 1) return Ops[0];
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant zero, strip it off.
    // (The unsigned minimum, zero, is the identity element of umax.)
    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
      Ops.erase(Ops.begin());
      --Idx;  // Keep Idx pointing at the first non-constant operand.
    }
  }

  if (Ops.size() == 1) return Ops[0];

  // Find the first UMax
  // (operands are ordered by SCEV type, so skip everything ranked lower).
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
    ++Idx;

  // Check to see if one of the operands is a UMax. If so, expand its operands
  // onto our operand list, and recurse to simplify.
  if (Idx < Ops.size()) {
    bool DeletedUMax = false;
    while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
      // Append the nested operands before erasing, so UMax is still
      // referenced by Ops while its operands are copied.
      Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
      Ops.erase(Ops.begin()+Idx);
      DeletedUMax = true;
    }

    if (DeletedUMax)
      return getUMaxExpr(Ops);
  }

  // Okay, check to see if the same value occurs in the operand list twice.  If
  // so, delete one.  Since we sorted the list, these values are required to
  // be adjacent.
  // Note: when i == 0, --i wraps around (unsigned), and the loop's ++i brings
  // it back to 0, so the same position is re-examined after an erase.
  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
    if (Ops[i] == Ops[i+1]) {      //  X umax Y umax Y  -->  X umax Y
      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
      --i; --e;
    }

  if (Ops.size() == 1) return Ops[0];

  assert(!Ops.empty() && "Reduced umax down to nothing!");

  // Okay, it looks like we really DO need a umax expr.  Check to see if we
  // already have one, otherwise create a new one.
  std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr,
                                                                 SCEVOps)];
  if (Result == 0) Result = new SCEVUMaxExpr(Ops);
  return Result;
}
1699
1700SCEVHandle ScalarEvolution::getUnknown(Value *V) {
1701  if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
1702    return getConstant(CI);
1703  if (isa<ConstantPointerNull>(V))
1704    return getIntegerSCEV(0, V->getType());
1705  SCEVUnknown *&Result = (*SCEVUnknowns)[V];
1706  if (Result == 0) Result = new SCEVUnknown(V);
1707  return Result;
1708}
1709
1710//===----------------------------------------------------------------------===//
1711//            Basic SCEV Analysis and PHI Idiom Recognition Code
1712//
1713
1714/// isSCEVable - Test if values of the given type are analyzable within
1715/// the SCEV framework. This primarily includes integer types, and it
1716/// can optionally include pointer types if the ScalarEvolution class
1717/// has access to target-specific information.
1718bool ScalarEvolution::isSCEVable(const Type *Ty) const {
1719  // Integers are always SCEVable.
1720  if (Ty->isInteger())
1721    return true;
1722
1723  // Pointers are SCEVable if TargetData information is available
1724  // to provide pointer size information.
1725  if (isa<PointerType>(Ty))
1726    return TD != NULL;
1727
1728  // Otherwise it's not SCEVable.
1729  return false;
1730}
1731
1732/// getTypeSizeInBits - Return the size in bits of the specified type,
1733/// for which isSCEVable must return true.
1734uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
1735  assert(isSCEVable(Ty) && "Type is not SCEVable!");
1736
1737  // If we have a TargetData, use it!
1738  if (TD)
1739    return TD->getTypeSizeInBits(Ty);
1740
1741  // Otherwise, we support only integer types.
1742  assert(Ty->isInteger() && "isSCEVable permitted a non-SCEVable type!");
1743  return Ty->getPrimitiveSizeInBits();
1744}
1745
1746/// getEffectiveSCEVType - Return a type with the same bitwidth as
1747/// the given type and which represents how SCEV will treat the given
1748/// type, for which isSCEVable must return true. For pointer types,
1749/// this is the pointer-sized integer type.
1750const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
1751  assert(isSCEVable(Ty) && "Type is not SCEVable!");
1752
1753  if (Ty->isInteger())
1754    return Ty;
1755
1756  assert(isa<PointerType>(Ty) && "Unexpected non-pointer non-integer type!");
1757  return TD->getIntPtrType();
1758}
1759
1760SCEVHandle ScalarEvolution::getCouldNotCompute() {
1761  return UnknownValue;
1762}
1763
1764/// hasSCEV - Return true if the SCEV for this value has already been
1765/// computed.
1766bool ScalarEvolution::hasSCEV(Value *V) const {
1767  return Scalars.count(V);
1768}
1769
1770/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
1771/// expression and create a new one.
1772SCEVHandle ScalarEvolution::getSCEV(Value *V) {
1773  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
1774
1775  std::map<SCEVCallbackVH, SCEVHandle>::iterator I = Scalars.find(V);
1776  if (I != Scalars.end()) return I->second;
1777  SCEVHandle S = createSCEV(V);
1778  Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S));
1779  return S;
1780}
1781
1782/// getIntegerSCEV - Given an integer or FP type, create a constant for the
1783/// specified signed integer value and return a SCEV for the constant.
1784SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
1785  Ty = getEffectiveSCEVType(Ty);
1786  Constant *C;
1787  if (Val == 0)
1788    C = Constant::getNullValue(Ty);
1789  else if (Ty->isFloatingPoint())
1790    C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle :
1791                                APFloat::IEEEdouble, Val));
1792  else
1793    C = ConstantInt::get(Ty, Val);
1794  return getUnknown(C);
1795}
1796
1797/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
1798///
1799SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) {
1800  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
1801    return getUnknown(ConstantExpr::getNeg(VC->getValue()));
1802
1803  const Type *Ty = V->getType();
1804  Ty = getEffectiveSCEVType(Ty);
1805  return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(Ty)));
1806}
1807
1808/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
1809SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) {
1810  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
1811    return getUnknown(ConstantExpr::getNot(VC->getValue()));
1812
1813  const Type *Ty = V->getType();
1814  Ty = getEffectiveSCEVType(Ty);
1815  SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(Ty));
1816  return getMinusSCEV(AllOnes, V);
1817}
1818
1819/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
1820///
1821SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS,
1822                                         const SCEVHandle &RHS) {
1823  // X - Y --> X + -Y
1824  return getAddExpr(LHS, getNegativeSCEV(RHS));
1825}
1826
1827/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
1828/// input value to the specified type.  If the type must be extended, it is zero
1829/// extended.
1830SCEVHandle
1831ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V,
1832                                         const Type *Ty) {
1833  const Type *SrcTy = V->getType();
1834  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1835         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1836         "Cannot truncate or zero extend with non-integer arguments!");
1837  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1838    return V;  // No conversion
1839  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
1840    return getTruncateExpr(V, Ty);
1841  return getZeroExtendExpr(V, Ty);
1842}
1843
1844/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
1845/// input value to the specified type.  If the type must be extended, it is sign
1846/// extended.
1847SCEVHandle
1848ScalarEvolution::getTruncateOrSignExtend(const SCEVHandle &V,
1849                                         const Type *Ty) {
1850  const Type *SrcTy = V->getType();
1851  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1852         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1853         "Cannot truncate or zero extend with non-integer arguments!");
1854  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1855    return V;  // No conversion
1856  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
1857    return getTruncateExpr(V, Ty);
1858  return getSignExtendExpr(V, Ty);
1859}
1860
1861/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
1862/// input value to the specified type.  If the type must be extended, it is zero
1863/// extended.  The conversion must not be narrowing.
1864SCEVHandle
1865ScalarEvolution::getNoopOrZeroExtend(const SCEVHandle &V, const Type *Ty) {
1866  const Type *SrcTy = V->getType();
1867  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1868         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1869         "Cannot noop or zero extend with non-integer arguments!");
1870  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
1871         "getNoopOrZeroExtend cannot truncate!");
1872  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1873    return V;  // No conversion
1874  return getZeroExtendExpr(V, Ty);
1875}
1876
1877/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
1878/// input value to the specified type.  If the type must be extended, it is sign
1879/// extended.  The conversion must not be narrowing.
1880SCEVHandle
1881ScalarEvolution::getNoopOrSignExtend(const SCEVHandle &V, const Type *Ty) {
1882  const Type *SrcTy = V->getType();
1883  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1884         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1885         "Cannot noop or sign extend with non-integer arguments!");
1886  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
1887         "getNoopOrSignExtend cannot truncate!");
1888  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1889    return V;  // No conversion
1890  return getSignExtendExpr(V, Ty);
1891}
1892
1893/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
1894/// input value to the specified type.  The conversion must not be widening.
1895SCEVHandle
1896ScalarEvolution::getTruncateOrNoop(const SCEVHandle &V, const Type *Ty) {
1897  const Type *SrcTy = V->getType();
1898  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1899         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1900         "Cannot truncate or noop with non-integer arguments!");
1901  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
1902         "getTruncateOrNoop cannot extend!");
1903  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1904    return V;  // No conversion
1905  return getTruncateExpr(V, Ty);
1906}
1907
1908/// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
1909/// the specified instruction and replaces any references to the symbolic value
1910/// SymName with the specified value.  This is used during PHI resolution.
1911void ScalarEvolution::
1912ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName,
1913                                 const SCEVHandle &NewVal) {
1914  std::map<SCEVCallbackVH, SCEVHandle>::iterator SI =
1915    Scalars.find(SCEVCallbackVH(I, this));
1916  if (SI == Scalars.end()) return;
1917
1918  SCEVHandle NV =
1919    SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, *this);
1920  if (NV == SI->second) return;  // No change.
1921
1922  SI->second = NV;       // Update the scalars map!
1923
1924  // Any instruction values that use this instruction might also need to be
1925  // updated!
1926  for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
1927       UI != E; ++UI)
1928    ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal);
1929}
1930
/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
/// a loop header, making it a potential recurrence, or it doesn't.
///
/// For a two-input PHI in the header of its loop L, the PHI is first recorded
/// in Scalars as a symbolic SCEVUnknown.  The back-edge value is then
/// analyzed in terms of that symbol, and two shapes are recognized:
///   - BEValue = PHI + Accum, with Accum loop-invariant or an addrec of L:
///     the PHI becomes {Start,+,Accum}<L>.
///   - BEValue = {A,+,S}<L> affine, with Start == A - S: the PHI becomes
///     {Start,+,S}<L>.
/// On success, the symbolic placeholder is replaced throughout the cached
/// SCEVs of the PHI and its users.  If the PHI is in its loop header but
/// neither shape matches, the symbolic name is returned.  Any other PHI is
/// returned as a SCEVUnknown.
SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) {
  if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
    if (const Loop *L = LI->getLoopFor(PN->getParent()))
      if (L->getHeader() == PN->getParent()) {
        // If it lives in the loop header, it has two incoming values, one
        // from outside the loop, and one from inside.
        // If incoming block 0 is inside the loop, then index 1 is the entry
        // edge (IncomingEdge == 1); otherwise index 0 is the entry edge.
        unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
        unsigned BackEdge     = IncomingEdge^1;

        // While we are analyzing this PHI node, handle its value symbolically.
        // Registering it in Scalars first makes recursive getSCEV calls on
        // the back-edge value see the placeholder instead of recursing.
        SCEVHandle SymbolicName = getUnknown(PN);
        assert(Scalars.find(PN) == Scalars.end() &&
               "PHI node already processed?");
        Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));

        // Using this symbolic name for the PHI, analyze the value coming around
        // the back-edge.
        SCEVHandle BEValue = getSCEV(PN->getIncomingValue(BackEdge));

        // NOTE: If BEValue is loop invariant, we know that the PHI node just
        // has a special value for the first iteration of the loop.

        // If the value coming around the backedge is an add with the symbolic
        // value we just inserted, then we found a simple induction variable!
        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
          // Find the first occurrence of the symbolic value among the add's
          // operands and replace it with a recurrence.
          // NOTE(review): this does not verify that the occurrence is unique;
          // the inner `FoundIndex == e` test is always true when reached,
          // because the loop breaks on the first match.
          unsigned FoundIndex = Add->getNumOperands();
          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
            if (Add->getOperand(i) == SymbolicName)
              if (FoundIndex == e) {
                FoundIndex = i;
                break;
              }

          if (FoundIndex != Add->getNumOperands()) {
            // Create an add with everything but the specified operand.
            // This sum is the per-iteration step (Accum).
            std::vector<SCEVHandle> Ops;
            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
              if (i != FoundIndex)
                Ops.push_back(Add->getOperand(i));
            SCEVHandle Accum = getAddExpr(Ops);

            // This is not a valid addrec if the step amount is varying each
            // loop iteration, but is not itself an addrec in this loop.
            if (Accum->isLoopInvariant(L) ||
                (isa<SCEVAddRecExpr>(Accum) &&
                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
              SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
              SCEVHandle PHISCEV  = getAddRecExpr(StartVal, Accum, L);

              // Okay, for the entire analysis of this edge we assumed the PHI
              // to be symbolic.  We now need to go back and update all of the
              // cached entries that refer to the symbolic value (starting with
              // the PHI's own entry and propagating to its users) to use the
              // newly analyzed value instead.
              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
              return PHISCEV;
            }
          }
        } else if (const SCEVAddRecExpr *AddRec =
                     dyn_cast<SCEVAddRecExpr>(BEValue)) {
          // Otherwise, this could be a loop like this:
          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
          // In this case, j = {1,+,1}  and BEValue is j.
          // Because the other in-value of i (0) fits the evolution of BEValue
          // i really is an addrec evolution.
          if (AddRec->getLoop() == L && AddRec->isAffine()) {
            SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));

            // If StartVal = j.start - j.stride, we can use StartVal as the
            // start value of the addrec evolution, with j's stride as step.
            if (StartVal == getMinusSCEV(AddRec->getOperand(0),
                                            AddRec->getOperand(1))) {
              SCEVHandle PHISCEV =
                 getAddRecExpr(StartVal, AddRec->getOperand(1), L);

              // Okay, for the entire analysis of this edge we assumed the PHI
              // to be symbolic.  We now need to go back and update all of the
              // cached entries that refer to the symbolic value (starting with
              // the PHI's own entry and propagating to its users) to use the
              // newly analyzed value instead.
              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
              return PHISCEV;
            }
          }
        }

        // Not a recognized recurrence: the PHI stays opaque.  Its Scalars
        // entry remains the symbolic SCEVUnknown inserted above.
        return SymbolicName;
      }

  // If it's not a loop phi, we can't handle it yet.
  return getUnknown(PN);
}
2028
2029/// createNodeForGEP - Expand GEP instructions into add and multiply
2030/// operations. This allows them to be analyzed by regular SCEV code.
2031///
2032SCEVHandle ScalarEvolution::createNodeForGEP(User *GEP) {
2033
2034  const Type *IntPtrTy = TD->getIntPtrType();
2035  Value *Base = GEP->getOperand(0);
2036  // Don't attempt to analyze GEPs over unsized objects.
2037  if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
2038    return getUnknown(GEP);
2039  SCEVHandle TotalOffset = getIntegerSCEV(0, IntPtrTy);
2040  gep_type_iterator GTI = gep_type_begin(GEP);
2041  for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()),
2042                                      E = GEP->op_end();
2043       I != E; ++I) {
2044    Value *Index = *I;
2045    // Compute the (potentially symbolic) offset in bytes for this index.
2046    if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
2047      // For a struct, add the member offset.
2048      const StructLayout &SL = *TD->getStructLayout(STy);
2049      unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
2050      uint64_t Offset = SL.getElementOffset(FieldNo);
2051      TotalOffset = getAddExpr(TotalOffset,
2052                                  getIntegerSCEV(Offset, IntPtrTy));
2053    } else {
2054      // For an array, add the element offset, explicitly scaled.
2055      SCEVHandle LocalOffset = getSCEV(Index);
2056      if (!isa<PointerType>(LocalOffset->getType()))
2057        // Getelementptr indicies are signed.
2058        LocalOffset = getTruncateOrSignExtend(LocalOffset,
2059                                              IntPtrTy);
2060      LocalOffset =
2061        getMulExpr(LocalOffset,
2062                   getIntegerSCEV(TD->getTypeAllocSize(*GTI),
2063                                  IntPtrTy));
2064      TotalOffset = getAddExpr(TotalOffset, LocalOffset);
2065    }
2066  }
2067  return getAddExpr(getSCEV(Base), TotalOffset);
2068}
2069
2070/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
2071/// guaranteed to end in (at every loop iteration).  It is, at the same time,
2072/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
2073/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
2074static uint32_t GetMinTrailingZeros(SCEVHandle S, const ScalarEvolution &SE) {
2075  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2076    return C->getValue()->getValue().countTrailingZeros();
2077
2078  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
2079    return std::min(GetMinTrailingZeros(T->getOperand(), SE),
2080                    (uint32_t)SE.getTypeSizeInBits(T->getType()));
2081
2082  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
2083    uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE);
2084    return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ?
2085             SE.getTypeSizeInBits(E->getType()) : OpRes;
2086  }
2087
2088  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
2089    uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE);
2090    return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ?
2091             SE.getTypeSizeInBits(E->getType()) : OpRes;
2092  }
2093
2094  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
2095    // The result is the min of all operands results.
2096    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE);
2097    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2098      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE));
2099    return MinOpRes;
2100  }
2101
2102  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
2103    // The result is the sum of all operands results.
2104    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
2105    uint32_t BitWidth = SE.getTypeSizeInBits(M->getType());
2106    for (unsigned i = 1, e = M->getNumOperands();
2107         SumOpRes != BitWidth && i != e; ++i)
2108      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i), SE),
2109                          BitWidth);
2110    return SumOpRes;
2111  }
2112
2113  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
2114    // The result is the min of all operands results.
2115    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE);
2116    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2117      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE));
2118    return MinOpRes;
2119  }
2120
2121  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
2122    // The result is the min of all operands results.
2123    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
2124    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2125      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE));
2126    return MinOpRes;
2127  }
2128
2129  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
2130    // The result is the min of all operands results.
2131    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
2132    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2133      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE));
2134    return MinOpRes;
2135  }
2136
2137  // SCEVUDivExpr, SCEVUnknown
2138  return 0;
2139}
2140
2141/// createSCEV - We know that there is no SCEV for the specified value.
2142/// Analyze the expression.
2143///
2144SCEVHandle ScalarEvolution::createSCEV(Value *V) {
2145  if (!isSCEVable(V->getType()))
2146    return getUnknown(V);
2147
2148  unsigned Opcode = Instruction::UserOp1;
2149  if (Instruction *I = dyn_cast<Instruction>(V))
2150    Opcode = I->getOpcode();
2151  else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
2152    Opcode = CE->getOpcode();
2153  else
2154    return getUnknown(V);
2155
2156  User *U = cast<User>(V);
2157  switch (Opcode) {
2158  case Instruction::Add:
2159    return getAddExpr(getSCEV(U->getOperand(0)),
2160                      getSCEV(U->getOperand(1)));
2161  case Instruction::Mul:
2162    return getMulExpr(getSCEV(U->getOperand(0)),
2163                      getSCEV(U->getOperand(1)));
2164  case Instruction::UDiv:
2165    return getUDivExpr(getSCEV(U->getOperand(0)),
2166                       getSCEV(U->getOperand(1)));
2167  case Instruction::Sub:
2168    return getMinusSCEV(getSCEV(U->getOperand(0)),
2169                        getSCEV(U->getOperand(1)));
2170  case Instruction::And:
2171    // For an expression like x&255 that merely masks off the high bits,
2172    // use zext(trunc(x)) as the SCEV expression.
2173    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2174      if (CI->isNullValue())
2175        return getSCEV(U->getOperand(1));
2176      if (CI->isAllOnesValue())
2177        return getSCEV(U->getOperand(0));
2178      const APInt &A = CI->getValue();
2179      unsigned Ones = A.countTrailingOnes();
2180      if (APIntOps::isMask(Ones, A))
2181        return
2182          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
2183                                            IntegerType::get(Ones)),
2184                            U->getType());
2185    }
2186    break;
2187  case Instruction::Or:
2188    // If the RHS of the Or is a constant, we may have something like:
2189    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
2190    // optimizations will transparently handle this case.
2191    //
2192    // In order for this transformation to be safe, the LHS must be of the
2193    // form X*(2^n) and the Or constant must be less than 2^n.
2194    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2195      SCEVHandle LHS = getSCEV(U->getOperand(0));
2196      const APInt &CIVal = CI->getValue();
2197      if (GetMinTrailingZeros(LHS, *this) >=
2198          (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
2199        return getAddExpr(LHS, getSCEV(U->getOperand(1)));
2200    }
2201    break;
2202  case Instruction::Xor:
2203    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2204      // If the RHS of the xor is a signbit, then this is just an add.
2205      // Instcombine turns add of signbit into xor as a strength reduction step.
2206      if (CI->getValue().isSignBit())
2207        return getAddExpr(getSCEV(U->getOperand(0)),
2208                          getSCEV(U->getOperand(1)));
2209
2210      // If the RHS of xor is -1, then this is a not operation.
2211      if (CI->isAllOnesValue())
2212        return getNotSCEV(getSCEV(U->getOperand(0)));
2213
2214      // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
2215      // This is a variant of the check for xor with -1, and it handles
2216      // the case where instcombine has trimmed non-demanded bits out
2217      // of an xor with -1.
2218      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
2219        if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
2220          if (BO->getOpcode() == Instruction::And &&
2221              LCI->getValue() == CI->getValue())
2222            if (const SCEVZeroExtendExpr *Z =
2223                  dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0))))
2224              return getZeroExtendExpr(getNotSCEV(Z->getOperand()),
2225                                       U->getType());
2226    }
2227    break;
2228
2229  case Instruction::Shl:
2230    // Turn shift left of a constant amount into a multiply.
2231    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
2232      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
2233      Constant *X = ConstantInt::get(
2234        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
2235      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
2236    }
2237    break;
2238
2239  case Instruction::LShr:
2240    // Turn logical shift right of a constant into a unsigned divide.
2241    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
2242      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
2243      Constant *X = ConstantInt::get(
2244        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
2245      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
2246    }
2247    break;
2248
2249  case Instruction::AShr:
2250    // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
2251    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
2252      if (Instruction *L = dyn_cast<Instruction>(U->getOperand(0)))
2253        if (L->getOpcode() == Instruction::Shl &&
2254            L->getOperand(1) == U->getOperand(1)) {
2255          unsigned BitWidth = getTypeSizeInBits(U->getType());
2256          uint64_t Amt = BitWidth - CI->getZExtValue();
2257          if (Amt == BitWidth)
2258            return getSCEV(L->getOperand(0));       // shift by zero --> noop
2259          if (Amt > BitWidth)
2260            return getIntegerSCEV(0, U->getType()); // value is undefined
2261          return
2262            getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
2263                                                      IntegerType::get(Amt)),
2264                                 U->getType());
2265        }
2266    break;
2267
2268  case Instruction::Trunc:
2269    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
2270
2271  case Instruction::ZExt:
2272    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
2273
2274  case Instruction::SExt:
2275    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
2276
2277  case Instruction::BitCast:
2278    // BitCasts are no-op casts so we just eliminate the cast.
2279    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
2280      return getSCEV(U->getOperand(0));
2281    break;
2282
2283  case Instruction::IntToPtr:
2284    if (!TD) break; // Without TD we can't analyze pointers.
2285    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
2286                                   TD->getIntPtrType());
2287
2288  case Instruction::PtrToInt:
2289    if (!TD) break; // Without TD we can't analyze pointers.
2290    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
2291                                   U->getType());
2292
2293  case Instruction::GetElementPtr:
2294    if (!TD) break; // Without TD we can't analyze pointers.
2295    return createNodeForGEP(U);
2296
2297  case Instruction::PHI:
2298    return createNodeForPHI(cast<PHINode>(U));
2299
2300  case Instruction::Select:
2301    // This could be a smax or umax that was lowered earlier.
2302    // Try to recover it.
2303    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
2304      Value *LHS = ICI->getOperand(0);
2305      Value *RHS = ICI->getOperand(1);
2306      switch (ICI->getPredicate()) {
2307      case ICmpInst::ICMP_SLT:
2308      case ICmpInst::ICMP_SLE:
2309        std::swap(LHS, RHS);
2310        // fall through
2311      case ICmpInst::ICMP_SGT:
2312      case ICmpInst::ICMP_SGE:
2313        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2314          return getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
2315        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2316          // ~smax(~x, ~y) == smin(x, y).
2317          return getNotSCEV(getSMaxExpr(
2318                                   getNotSCEV(getSCEV(LHS)),
2319                                   getNotSCEV(getSCEV(RHS))));
2320        break;
2321      case ICmpInst::ICMP_ULT:
2322      case ICmpInst::ICMP_ULE:
2323        std::swap(LHS, RHS);
2324        // fall through
2325      case ICmpInst::ICMP_UGT:
2326      case ICmpInst::ICMP_UGE:
2327        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2328          return getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
2329        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2330          // ~umax(~x, ~y) == umin(x, y)
2331          return getNotSCEV(getUMaxExpr(getNotSCEV(getSCEV(LHS)),
2332                                        getNotSCEV(getSCEV(RHS))));
2333        break;
2334      default:
2335        break;
2336      }
2337    }
2338
2339  default: // We cannot analyze this expression.
2340    break;
2341  }
2342
2343  return getUnknown(V);
2344}
2345
2346
2347
2348//===----------------------------------------------------------------------===//
2349//                   Iteration Count Computation Code
2350//
2351
2352/// getBackedgeTakenCount - If the specified loop has a predictable
2353/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
2354/// object. The backedge-taken count is the number of times the loop header
2355/// will be branched to from within the loop. This is one less than the
2356/// trip count of the loop, since it doesn't count the first iteration,
2357/// when the header is branched to from outside the loop.
2358///
2359/// Note that it is not valid to call this method on a loop without a
2360/// loop-invariant backedge-taken count (see
2361/// hasLoopInvariantBackedgeTakenCount).
2362///
2363SCEVHandle ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
2364  return getBackedgeTakenInfo(L).Exact;
2365}
2366
2367/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
2368/// return the least SCEV value that is known never to be less than the
2369/// actual backedge taken count.
2370SCEVHandle ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
2371  return getBackedgeTakenInfo(L).Max;
2372}
2373
2374const ScalarEvolution::BackedgeTakenInfo &
2375ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
2376  // Initially insert a CouldNotCompute for this loop. If the insertion
2377  // succeeds, procede to actually compute a backedge-taken count and
2378  // update the value. The temporary CouldNotCompute value tells SCEV
2379  // code elsewhere that it shouldn't attempt to request a new
2380  // backedge-taken count, which could result in infinite recursion.
2381  std::pair<std::map<const Loop*, BackedgeTakenInfo>::iterator, bool> Pair =
2382    BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
2383  if (Pair.second) {
2384    BackedgeTakenInfo ItCount = ComputeBackedgeTakenCount(L);
2385    if (ItCount.Exact != UnknownValue) {
2386      assert(ItCount.Exact->isLoopInvariant(L) &&
2387             ItCount.Max->isLoopInvariant(L) &&
2388             "Computed trip count isn't loop invariant for loop!");
2389      ++NumTripCountsComputed;
2390
2391      // Update the value in the map.
2392      Pair.first->second = ItCount;
2393    } else if (isa<PHINode>(L->getHeader()->begin())) {
2394      // Only count loops that have phi nodes as not being computable.
2395      ++NumTripCountsNotComputed;
2396    }
2397
2398    // Now that we know more about the trip count for this loop, forget any
2399    // existing SCEV values for PHI nodes in this loop since they are only
2400    // conservative estimates made without the benefit
2401    // of trip count information.
2402    if (ItCount.hasAnyInfo())
2403      forgetLoopPHIs(L);
2404  }
2405  return Pair.first->second;
2406}
2407
2408/// forgetLoopBackedgeTakenCount - This method should be called by the
2409/// client when it has changed a loop in a way that may effect
2410/// ScalarEvolution's ability to compute a trip count, or if the loop
2411/// is deleted.
2412void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) {
2413  BackedgeTakenCounts.erase(L);
2414  forgetLoopPHIs(L);
2415}
2416
2417/// forgetLoopPHIs - Delete the memoized SCEVs associated with the
2418/// PHI nodes in the given loop. This is used when the trip count of
2419/// the loop may have changed.
2420void ScalarEvolution::forgetLoopPHIs(const Loop *L) {
2421  BasicBlock *Header = L->getHeader();
2422
2423  // Push all Loop-header PHIs onto the Worklist stack, except those
2424  // that are presently represented via a SCEVUnknown. SCEVUnknown for
2425  // a PHI either means that it has an unrecognized structure, or it's
2426  // a PHI that's in the progress of being computed by createNodeForPHI.
2427  // In the former case, additional loop trip count information isn't
2428  // going to change anything. In the later case, createNodeForPHI will
2429  // perform the necessary updates on its own when it gets to that point.
2430  SmallVector<Instruction *, 16> Worklist;
2431  for (BasicBlock::iterator I = Header->begin();
2432       PHINode *PN = dyn_cast<PHINode>(I); ++I) {
2433    std::map<SCEVCallbackVH, SCEVHandle>::iterator It = Scalars.find((Value*)I);
2434    if (It != Scalars.end() && !isa<SCEVUnknown>(It->second))
2435      Worklist.push_back(PN);
2436  }
2437
2438  while (!Worklist.empty()) {
2439    Instruction *I = Worklist.pop_back_val();
2440    if (Scalars.erase(I))
2441      for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2442           UI != UE; ++UI)
2443        Worklist.push_back(cast<Instruction>(UI));
2444  }
2445}
2446
/// ComputeBackedgeTakenCount - Compute the number of times the backedge
/// of the specified loop will execute.
///
/// Only loops with exactly one exit block, reached from exactly one exiting
/// block that ends in a conditional branch, are analyzed. That branch must
/// either target the header or live in the header. The exit condition is
/// normalized to "stay in the loop while Cond holds". It is then handled by
/// several strategies, tried in order:
///  - a load from a constant global array compared against a constant;
///  - a constant-range query for an affine recurrence of L compared against
///    a constant;
///  - predicate-specific solvers (HowFarToZero, HowFarToNonZero,
///    HowManyLessThans);
///  - finally, brute-force evaluation
///    (ComputeBackedgeTakenCountExhaustively).
/// UnknownValue (implicitly converted to a BackedgeTakenInfo) is returned
/// when none of these succeed.
ScalarEvolution::BackedgeTakenInfo
ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
  // If the loop does not have exactly one exit block, we can't analyze it.
  SmallVector<BasicBlock*, 8> ExitBlocks;
  L->getExitBlocks(ExitBlocks);
  if (ExitBlocks.size() != 1) return UnknownValue;

  // Okay, there is one exit block.  Try to find the condition that causes the
  // loop to be exited.
  BasicBlock *ExitBlock = ExitBlocks[0];

  // Find the unique in-loop predecessor of the exit block (the exiting
  // block); bail out if several blocks inside the loop branch to it.
  BasicBlock *ExitingBlock = 0;
  for (pred_iterator PI = pred_begin(ExitBlock), E = pred_end(ExitBlock);
       PI != E; ++PI)
    if (L->contains(*PI)) {
      if (ExitingBlock == 0)
        ExitingBlock = *PI;
      else
        return UnknownValue;   // More than one block exiting!
    }
  assert(ExitingBlock && "No exits from loop, something is broken!");

  // Okay, we've computed the exiting block.  See what condition causes us to
  // exit.
  //
  // FIXME: we should be able to handle switch instructions (with a single exit)
  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
  if (ExitBr == 0) return UnknownValue;
  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");

  // At this point, we know we have a conditional branch that determines whether
  // the loop is exited.  However, we don't know if the branch is executed each
  // time through the loop.  If not, then the execution count of the branch will
  // not be equal to the trip count of the loop.
  //
  // Currently we check for this by checking to see if the Exit branch goes to
  // the loop header.  If so, we know it will always execute the same number of
  // times as the loop.  We also handle the case where the exiting block (the
  // block containing the branch) *is* the loop header.  This is common for
  // un-rotated loops.  More extensive analysis could be done to handle more
  // cases here.
  if (ExitBr->getSuccessor(0) != L->getHeader() &&
      ExitBr->getSuccessor(1) != L->getHeader() &&
      ExitBr->getParent() != L->getHeader())
    return UnknownValue;

  ICmpInst *ExitCond = dyn_cast<ICmpInst>(ExitBr->getCondition());

  // If it's not an integer or pointer comparison then compute it the hard way.
  // The last argument is true when the loop exits on a true condition.
  if (ExitCond == 0)
    return ComputeBackedgeTakenCountExhaustively(L, ExitBr->getCondition(),
                                          ExitBr->getSuccessor(0) == ExitBlock);

  // If the condition was exit on true, convert the condition to exit on false.
  // From here on, Cond is the predicate under which the loop keeps running.
  ICmpInst::Predicate Cond;
  if (ExitBr->getSuccessor(1) == ExitBlock)
    Cond = ExitCond->getPredicate();
  else
    Cond = ExitCond->getInversePredicate();

  // Handle common loops like: for (X = "string"; *X; ++X)
  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
      SCEVHandle ItCnt =
        ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
      if (!isa<SCEVCouldNotCompute>(ItCnt)) return ItCnt;
    }

  SCEVHandle LHS = getSCEV(ExitCond->getOperand(0));
  SCEVHandle RHS = getSCEV(ExitCond->getOperand(1));

  // Try to evaluate any dependencies out of the loop.
  SCEVHandle Tmp = getSCEVAtScope(LHS, L);
  if (!isa<SCEVCouldNotCompute>(Tmp)) LHS = Tmp;
  Tmp = getSCEVAtScope(RHS, L);
  if (!isa<SCEVCouldNotCompute>(Tmp)) RHS = Tmp;

  // At this point, we would like to compute how many iterations of the
  // loop the predicate will return true for these inputs.
  if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
    // If there is a loop-invariant, force it into the RHS.
    std::swap(LHS, RHS);
    Cond = ICmpInst::getSwappedPredicate(Cond);
  }

  // If we have a comparison of a chrec against a constant, try to use value
  // ranges to answer this query.
  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
      if (AddRec->getLoop() == L) {
        // Form the constant range.
        ConstantRange CompRange(
            ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));

        SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, *this);
        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
      }

  // Predicate-specific solvers. The greater-than forms are rewritten as
  // less-than on the complemented operands (x > y  <=>  ~x < ~y). The
  // HowManyLessThans results are returned if they carry either an exact or
  // a maximum count.
  switch (Cond) {
  case ICmpInst::ICMP_NE: {                     // while (X != Y)
    // Convert to: while (X-Y != 0)
    SCEVHandle TC = HowFarToZero(getMinusSCEV(LHS, RHS), L);
    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
    break;
  }
  case ICmpInst::ICMP_EQ: {
    // Convert to: while (X-Y == 0)           // while (X == Y)
    SCEVHandle TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
    break;
  }
  case ICmpInst::ICMP_SLT: {
    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true);
    if (BTI.hasAnyInfo()) return BTI;
    break;
  }
  case ICmpInst::ICMP_SGT: {
    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
                                             getNotSCEV(RHS), L, true);
    if (BTI.hasAnyInfo()) return BTI;
    break;
  }
  case ICmpInst::ICMP_ULT: {
    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false);
    if (BTI.hasAnyInfo()) return BTI;
    break;
  }
  case ICmpInst::ICMP_UGT: {
    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
                                             getNotSCEV(RHS), L, false);
    if (BTI.hasAnyInfo()) return BTI;
    break;
  }
  default:
    // Disabled debugging output for unhandled predicates.
    // NOTE(review): this dead block may not compile if re-enabled.
#if 0
    errs() << "ComputeBackedgeTakenCount ";
    if (ExitCond->getOperand(0)->getType()->isUnsigned())
      errs() << "[unsigned] ";
    errs() << *LHS << "   "
         << Instruction::getOpcodeName(Instruction::ICmp)
         << "   " << *RHS << "\n";
#endif
    break;
  }
  // Last resort: simulate the loop with constant folding.
  return
    ComputeBackedgeTakenCountExhaustively(L, ExitCond,
                                          ExitBr->getSuccessor(0) == ExitBlock);
}
2596
2597static ConstantInt *
2598EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
2599                                ScalarEvolution &SE) {
2600  SCEVHandle InVal = SE.getConstant(C);
2601  SCEVHandle Val = AddRec->evaluateAtIteration(InVal, SE);
2602  assert(isa<SCEVConstant>(Val) &&
2603         "Evaluation of SCEV at constant didn't fold correctly?");
2604  return cast<SCEVConstant>(Val)->getValue();
2605}
2606
2607/// GetAddressedElementFromGlobal - Given a global variable with an initializer
2608/// and a GEP expression (missing the pointer index) indexing into it, return
2609/// the addressed element of the initializer or null if the index expression is
2610/// invalid.
2611static Constant *
2612GetAddressedElementFromGlobal(GlobalVariable *GV,
2613                              const std::vector<ConstantInt*> &Indices) {
2614  Constant *Init = GV->getInitializer();
2615  for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
2616    uint64_t Idx = Indices[i]->getZExtValue();
2617    if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
2618      assert(Idx < CS->getNumOperands() && "Bad struct index!");
2619      Init = cast<Constant>(CS->getOperand(Idx));
2620    } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
2621      if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
2622      Init = cast<Constant>(CA->getOperand(Idx));
2623    } else if (isa<ConstantAggregateZero>(Init)) {
2624      if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
2625        assert(Idx < STy->getNumElements() && "Bad struct index!");
2626        Init = Constant::getNullValue(STy->getElementType(Idx));
2627      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
2628        if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
2629        Init = Constant::getNullValue(ATy->getElementType());
2630      } else {
2631        assert(0 && "Unknown constant aggregate type!");
2632      }
2633      return 0;
2634    } else {
2635      return 0; // Unknown initializer type
2636    }
2637  }
2638  return Init;
2639}
2640
2641/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of
2642/// 'icmp op load X, cst', try to see if we can compute the backedge
2643/// execution count.
2644SCEVHandle ScalarEvolution::
2645ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS,
2646                                             const Loop *L,
2647                                             ICmpInst::Predicate predicate) {
2648  if (LI->isVolatile()) return UnknownValue;
2649
2650  // Check to see if the loaded pointer is a getelementptr of a global.
2651  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
2652  if (!GEP) return UnknownValue;
2653
2654  // Make sure that it is really a constant global we are gepping, with an
2655  // initializer, and make sure the first IDX is really 0.
2656  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
2657  if (!GV || !GV->isConstant() || !GV->hasInitializer() ||
2658      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
2659      !cast<Constant>(GEP->getOperand(1))->isNullValue())
2660    return UnknownValue;
2661
2662  // Okay, we allow one non-constant index into the GEP instruction.
2663  Value *VarIdx = 0;
2664  std::vector<ConstantInt*> Indexes;
2665  unsigned VarIdxNum = 0;
2666  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
2667    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
2668      Indexes.push_back(CI);
2669    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
2670      if (VarIdx) return UnknownValue;  // Multiple non-constant idx's.
2671      VarIdx = GEP->getOperand(i);
2672      VarIdxNum = i-2;
2673      Indexes.push_back(0);
2674    }
2675
2676  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
2677  // Check to see if X is a loop variant variable value now.
2678  SCEVHandle Idx = getSCEV(VarIdx);
2679  SCEVHandle Tmp = getSCEVAtScope(Idx, L);
2680  if (!isa<SCEVCouldNotCompute>(Tmp)) Idx = Tmp;
2681
2682  // We can only recognize very limited forms of loop index expressions, in
2683  // particular, only affine AddRec's like {C1,+,C2}.
2684  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
2685  if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
2686      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
2687      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
2688    return UnknownValue;
2689
2690  unsigned MaxSteps = MaxBruteForceIterations;
2691  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
2692    ConstantInt *ItCst =
2693      ConstantInt::get(IdxExpr->getType(), IterationNum);
2694    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
2695
2696    // Form the GEP offset.
2697    Indexes[VarIdxNum] = Val;
2698
2699    Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
2700    if (Result == 0) break;  // Cannot compute!
2701
2702    // Evaluate the condition for this iteration.
2703    Result = ConstantExpr::getICmp(predicate, Result, RHS);
2704    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
2705    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
2706#if 0
2707      errs() << "\n***\n*** Computed loop count " << *ItCst
2708             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
2709             << "***\n";
2710#endif
2711      ++NumArrayLenItCounts;
2712      return getConstant(ItCst);   // Found terminating iteration!
2713    }
2714  }
2715  return UnknownValue;
2716}
2717
2718
2719/// CanConstantFold - Return true if we can constant fold an instruction of the
2720/// specified type, assuming that all operands were constants.
2721static bool CanConstantFold(const Instruction *I) {
2722  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
2723      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
2724    return true;
2725
2726  if (const CallInst *CI = dyn_cast<CallInst>(I))
2727    if (const Function *F = CI->getCalledFunction())
2728      return canConstantFoldCallTo(F);
2729  return false;
2730}
2731
2732/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
2733/// in the loop that V is derived from.  We allow arbitrary operations along the
2734/// way, but the operands of an operation must either be constants or a value
2735/// derived from a constant PHI.  If this expression does not fit with these
2736/// constraints, return null.
2737static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
2738  // If this is not an instruction, or if this is an instruction outside of the
2739  // loop, it can't be derived from a loop PHI.
2740  Instruction *I = dyn_cast<Instruction>(V);
2741  if (I == 0 || !L->contains(I->getParent())) return 0;
2742
2743  if (PHINode *PN = dyn_cast<PHINode>(I)) {
2744    if (L->getHeader() == I->getParent())
2745      return PN;
2746    else
2747      // We don't currently keep track of the control flow needed to evaluate
2748      // PHIs, so we cannot handle PHIs inside of loops.
2749      return 0;
2750  }
2751
2752  // If we won't be able to constant fold this expression even if the operands
2753  // are constants, return early.
2754  if (!CanConstantFold(I)) return 0;
2755
2756  // Otherwise, we can evaluate this instruction if all of its operands are
2757  // constant or derived from a PHI node themselves.
2758  PHINode *PHI = 0;
2759  for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
2760    if (!(isa<Constant>(I->getOperand(Op)) ||
2761          isa<GlobalValue>(I->getOperand(Op)))) {
2762      PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
2763      if (P == 0) return 0;  // Not evolving from PHI
2764      if (PHI == 0)
2765        PHI = P;
2766      else if (PHI != P)
2767        return 0;  // Evolving from multiple different PHIs.
2768    }
2769
2770  // This is a expression evolving from a constant PHI!
2771  return PHI;
2772}
2773
2774/// EvaluateExpression - Given an expression that passes the
2775/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
2776/// in the loop has the value PHIVal.  If we can't fold this expression for some
2777/// reason, return null.
2778static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
2779  if (isa<PHINode>(V)) return PHIVal;
2780  if (Constant *C = dyn_cast<Constant>(V)) return C;
2781  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) return GV;
2782  Instruction *I = cast<Instruction>(V);
2783
2784  std::vector<Constant*> Operands;
2785  Operands.resize(I->getNumOperands());
2786
2787  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2788    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
2789    if (Operands[i] == 0) return 0;
2790  }
2791
2792  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2793    return ConstantFoldCompareInstOperands(CI->getPredicate(),
2794                                           &Operands[0], Operands.size());
2795  else
2796    return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2797                                    &Operands[0], Operands.size());
2798}
2799
/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
/// in the header of its containing loop, we know the loop executes a
/// constant number of times, and the PHI node is just a recurrence
/// involving constants, fold it.
///
/// BEs is the backedge-taken count. The PHI is stepped BEs times from its
/// constant start value (or until it stops changing). The resulting constant
/// is returned, or null if the PHI does not fit this pattern or BEs is too
/// large to simulate. Results, including null, are memoized per PHI in
/// ConstantEvolutionLoopExitValue.
/// NOTE(review): the cache is keyed on PN alone, so a later call with a
/// different BEs returns whatever was cached first.
Constant *ScalarEvolution::
getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){
  std::map<PHINode*, Constant*>::iterator I =
    ConstantEvolutionLoopExitValue.find(PN);
  if (I != ConstantEvolutionLoopExitValue.end())
    return I->second;

  // NOTE(review): for very narrow BEs types the APInt built here may not hold
  // the full limit; that only makes the check bail out earlier.
  if (BEs.ugt(APInt(BEs.getBitWidth(),MaxBruteForceIterations)))
    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.

  // RetVal aliases the map slot for PN (default-inserted as null by
  // operator[]). Every "return RetVal = ..." below also caches the result.
  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];

  // Since the loop is canonicalized, the PHI node must have two entries.  One
  // entry must be a constant (coming in from outside of the loop), and the
  // second must be derived from the same PHI.
  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
  Constant *StartCST =
    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
  if (StartCST == 0)
    return RetVal = 0;  // Must be a constant.

  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
  if (PN2 != PN)
    return RetVal = 0;  // Not derived from same PHI.

  // Execute the loop symbolically to determine the exit value.
  if (BEs.getActiveBits() >= 32)
    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!

  unsigned NumIterations = BEs.getZExtValue(); // must be in range
  unsigned IterationNum = 0;
  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
    if (IterationNum == NumIterations)
      return RetVal = PHIVal;  // Got exit value!

    // Compute the value of the PHI node for the next iteration.
    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
    if (NextPHI == PHIVal)
      return RetVal = NextPHI;  // Stopped evolving!
    // The slot already holds null from operator[], so this failure is
    // cached as well.
    if (NextPHI == 0)
      return 0;        // Couldn't evaluate!
    PHIVal = NextPHI;
  }
}
2849
2850/// ComputeBackedgeTakenCountExhaustively - If the trip is known to execute a
2851/// constant number of times (the condition evolves only from constants),
2852/// try to evaluate a few iterations of the loop until we get the exit
2853/// condition gets a value of ExitWhen (true or false).  If we cannot
2854/// evaluate the trip count of the loop, return UnknownValue.
2855SCEVHandle ScalarEvolution::
2856ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) {
2857  PHINode *PN = getConstantEvolvingPHI(Cond, L);
2858  if (PN == 0) return UnknownValue;
2859
2860  // Since the loop is canonicalized, the PHI node must have two entries.  One
2861  // entry must be a constant (coming in from outside of the loop), and the
2862  // second must be derived from the same PHI.
2863  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2864  Constant *StartCST =
2865    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2866  if (StartCST == 0) return UnknownValue;  // Must be a constant.
2867
2868  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2869  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2870  if (PN2 != PN) return UnknownValue;  // Not derived from same PHI.
2871
2872  // Okay, we find a PHI node that defines the trip count of this loop.  Execute
2873  // the loop symbolically to determine when the condition gets a value of
2874  // "ExitWhen".
2875  unsigned IterationNum = 0;
2876  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
2877  for (Constant *PHIVal = StartCST;
2878       IterationNum != MaxIterations; ++IterationNum) {
2879    ConstantInt *CondVal =
2880      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal));
2881
2882    // Couldn't symbolically evaluate.
2883    if (!CondVal) return UnknownValue;
2884
2885    if (CondVal->getValue() == uint64_t(ExitWhen)) {
2886      ConstantEvolutionLoopExitValue[PN] = PHIVal;
2887      ++NumBruteForceTripCountsComputed;
2888      return getConstant(ConstantInt::get(Type::Int32Ty, IterationNum));
2889    }
2890
2891    // Compute the value of the PHI node for the next iteration.
2892    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2893    if (NextPHI == 0 || NextPHI == PHIVal)
2894      return UnknownValue;  // Couldn't evaluate or not making progress...
2895    PHIVal = NextPHI;
2896  }
2897
2898  // Too many iterations were needed to evaluate.
2899  return UnknownValue;
2900}
2901
2902/// getSCEVAtScope - Return a SCEV expression handle for the specified value
2903/// at the specified scope in the program.  The L value specifies a loop
2904/// nest to evaluate the expression at, where null is the top-level or a
2905/// specified loop is immediately inside of the loop.
2906///
2907/// This method can be used to compute the exit value for a variable defined
2908/// in a loop by querying what the value will hold in the parent loop.
2909///
2910/// If this value is not computable at this scope, a SCEVCouldNotCompute
2911/// object is returned.
2912SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
2913  // FIXME: this should be turned into a virtual method on SCEV!
2914
2915  if (isa<SCEVConstant>(V)) return V;
2916
2917  // If this instruction is evolved from a constant-evolving PHI, compute the
2918  // exit value from the loop without using SCEVs.
2919  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
2920    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
2921      const Loop *LI = (*this->LI)[I->getParent()];
2922      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
2923        if (PHINode *PN = dyn_cast<PHINode>(I))
2924          if (PN->getParent() == LI->getHeader()) {
2925            // Okay, there is no closed form solution for the PHI node.  Check
2926            // to see if the loop that contains it has a known backedge-taken
2927            // count.  If so, we may be able to force computation of the exit
2928            // value.
2929            SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(LI);
2930            if (const SCEVConstant *BTCC =
2931                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
2932              // Okay, we know how many times the containing loop executes.  If
2933              // this is a constant evolving PHI node, get the final value at
2934              // the specified iteration number.
2935              Constant *RV = getConstantEvolutionLoopExitValue(PN,
2936                                                   BTCC->getValue()->getValue(),
2937                                                               LI);
2938              if (RV) return getUnknown(RV);
2939            }
2940          }
2941
2942      // Okay, this is an expression that we cannot symbolically evaluate
2943      // into a SCEV.  Check to see if it's possible to symbolically evaluate
2944      // the arguments into constants, and if so, try to constant propagate the
2945      // result.  This is particularly useful for computing loop exit values.
2946      if (CanConstantFold(I)) {
2947        // Check to see if we've folded this instruction at this loop before.
2948        std::map<const Loop *, Constant *> &Values = ValuesAtScopes[I];
2949        std::pair<std::map<const Loop *, Constant *>::iterator, bool> Pair =
2950          Values.insert(std::make_pair(L, static_cast<Constant *>(0)));
2951        if (!Pair.second)
2952          return Pair.first->second ? &*getUnknown(Pair.first->second) : V;
2953
2954        std::vector<Constant*> Operands;
2955        Operands.reserve(I->getNumOperands());
2956        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2957          Value *Op = I->getOperand(i);
2958          if (Constant *C = dyn_cast<Constant>(Op)) {
2959            Operands.push_back(C);
2960          } else {
2961            // If any of the operands is non-constant and if they are
2962            // non-integer and non-pointer, don't even try to analyze them
2963            // with scev techniques.
2964            if (!isSCEVable(Op->getType()))
2965              return V;
2966
2967            SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L);
2968            if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV)) {
2969              Constant *C = SC->getValue();
2970              if (C->getType() != Op->getType())
2971                C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
2972                                                                  Op->getType(),
2973                                                                  false),
2974                                          C, Op->getType());
2975              Operands.push_back(C);
2976            } else if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) {
2977              if (Constant *C = dyn_cast<Constant>(SU->getValue())) {
2978                if (C->getType() != Op->getType())
2979                  C =
2980                    ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
2981                                                                  Op->getType(),
2982                                                                  false),
2983                                          C, Op->getType());
2984                Operands.push_back(C);
2985              } else
2986                return V;
2987            } else {
2988              return V;
2989            }
2990          }
2991        }
2992
2993        Constant *C;
2994        if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2995          C = ConstantFoldCompareInstOperands(CI->getPredicate(),
2996                                              &Operands[0], Operands.size());
2997        else
2998          C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2999                                       &Operands[0], Operands.size());
3000        Pair.first->second = C;
3001        return getUnknown(C);
3002      }
3003    }
3004
3005    // This is some other type of SCEVUnknown, just return it.
3006    return V;
3007  }
3008
3009  if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
3010    // Avoid performing the look-up in the common case where the specified
3011    // expression has no loop-variant portions.
3012    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
3013      SCEVHandle OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
3014      if (OpAtScope != Comm->getOperand(i)) {
3015        if (OpAtScope == UnknownValue) return UnknownValue;
3016        // Okay, at least one of these operands is loop variant but might be
3017        // foldable.  Build a new instance of the folded commutative expression.
3018        std::vector<SCEVHandle> NewOps(Comm->op_begin(), Comm->op_begin()+i);
3019        NewOps.push_back(OpAtScope);
3020
3021        for (++i; i != e; ++i) {
3022          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
3023          if (OpAtScope == UnknownValue) return UnknownValue;
3024          NewOps.push_back(OpAtScope);
3025        }
3026        if (isa<SCEVAddExpr>(Comm))
3027          return getAddExpr(NewOps);
3028        if (isa<SCEVMulExpr>(Comm))
3029          return getMulExpr(NewOps);
3030        if (isa<SCEVSMaxExpr>(Comm))
3031          return getSMaxExpr(NewOps);
3032        if (isa<SCEVUMaxExpr>(Comm))
3033          return getUMaxExpr(NewOps);
3034        assert(0 && "Unknown commutative SCEV type!");
3035      }
3036    }
3037    // If we got here, all operands are loop invariant.
3038    return Comm;
3039  }
3040
3041  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
3042    SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L);
3043    if (LHS == UnknownValue) return LHS;
3044    SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L);
3045    if (RHS == UnknownValue) return RHS;
3046    if (LHS == Div->getLHS() && RHS == Div->getRHS())
3047      return Div;   // must be loop invariant
3048    return getUDivExpr(LHS, RHS);
3049  }
3050
3051  // If this is a loop recurrence for a loop that does not contain L, then we
3052  // are dealing with the final value computed by the loop.
3053  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
3054    if (!L || !AddRec->getLoop()->contains(L->getHeader())) {
3055      // To evaluate this recurrence, we need to know how many times the AddRec
3056      // loop iterates.  Compute this now.
3057      SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
3058      if (BackedgeTakenCount == UnknownValue) return UnknownValue;
3059
3060      // Then, evaluate the AddRec.
3061      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
3062    }
3063    return UnknownValue;
3064  }
3065
3066  if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
3067    SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L);
3068    if (Op == UnknownValue) return Op;
3069    if (Op == Cast->getOperand())
3070      return Cast;  // must be loop invariant
3071    return getZeroExtendExpr(Op, Cast->getType());
3072  }
3073
3074  if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
3075    SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L);
3076    if (Op == UnknownValue) return Op;
3077    if (Op == Cast->getOperand())
3078      return Cast;  // must be loop invariant
3079    return getSignExtendExpr(Op, Cast->getType());
3080  }
3081
3082  if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
3083    SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L);
3084    if (Op == UnknownValue) return Op;
3085    if (Op == Cast->getOperand())
3086      return Cast;  // must be loop invariant
3087    return getTruncateExpr(Op, Cast->getType());
3088  }
3089
3090  assert(0 && "Unknown SCEV type!");
3091  return 0;
3092}
3093
3094/// getSCEVAtScope - This is a convenience function which does
3095/// getSCEVAtScope(getSCEV(V), L).
3096SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
3097  return getSCEVAtScope(getSCEV(V), L);
3098}
3099
3100/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
3101/// following equation:
3102///
3103///     A * X = B (mod N)
3104///
3105/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
3106/// A and B isn't important.
3107///
3108/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
3109static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
3110                                               ScalarEvolution &SE) {
3111  uint32_t BW = A.getBitWidth();
3112  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
3113  assert(A != 0 && "A must be non-zero.");
3114
3115  // 1. D = gcd(A, N)
3116  //
3117  // The gcd of A and N may have only one prime factor: 2. The number of
3118  // trailing zeros in A is its multiplicity
3119  uint32_t Mult2 = A.countTrailingZeros();
3120  // D = 2^Mult2
3121
3122  // 2. Check if B is divisible by D.
3123  //
3124  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
3125  // is not less than multiplicity of this prime factor for D.
3126  if (B.countTrailingZeros() < Mult2)
3127    return SE.getCouldNotCompute();
3128
3129  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
3130  // modulo (N / D).
3131  //
3132  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
3133  // bit width during computations.
3134  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
3135  APInt Mod(BW + 1, 0);
3136  Mod.set(BW - Mult2);  // Mod = N / D
3137  APInt I = AD.multiplicativeInverse(Mod);
3138
3139  // 4. Compute the minimum unsigned root of the equation:
3140  // I * (B / D) mod (N / D)
3141  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
3142
3143  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
3144  // bits.
3145  return SE.getConstant(Result.trunc(BW));
3146}
3147
3148/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
3149/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
3150/// might be the same) or two SCEVCouldNotCompute objects.
3151///
3152static std::pair<SCEVHandle,SCEVHandle>
3153SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
3154  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
3155  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
3156  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
3157  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
3158
3159  // We currently can only solve this if the coefficients are constants.
3160  if (!LC || !MC || !NC) {
3161    const SCEV *CNC = SE.getCouldNotCompute();
3162    return std::make_pair(CNC, CNC);
3163  }
3164
3165  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
3166  const APInt &L = LC->getValue()->getValue();
3167  const APInt &M = MC->getValue()->getValue();
3168  const APInt &N = NC->getValue()->getValue();
3169  APInt Two(BitWidth, 2);
3170  APInt Four(BitWidth, 4);
3171
3172  {
3173    using namespace APIntOps;
3174    const APInt& C = L;
3175    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
3176    // The B coefficient is M-N/2
3177    APInt B(M);
3178    B -= sdiv(N,Two);
3179
3180    // The A coefficient is N/2
3181    APInt A(N.sdiv(Two));
3182
3183    // Compute the B^2-4ac term.
3184    APInt SqrtTerm(B);
3185    SqrtTerm *= B;
3186    SqrtTerm -= Four * (A * C);
3187
3188    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
3189    // integer value or else APInt::sqrt() will assert.
3190    APInt SqrtVal(SqrtTerm.sqrt());
3191
3192    // Compute the two solutions for the quadratic formula.
3193    // The divisions must be performed as signed divisions.
3194    APInt NegB(-B);
3195    APInt TwoA( A << 1 );
3196    if (TwoA.isMinValue()) {
3197      const SCEV *CNC = SE.getCouldNotCompute();
3198      return std::make_pair(CNC, CNC);
3199    }
3200
3201    ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
3202    ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
3203
3204    return std::make_pair(SE.getConstant(Solution1),
3205                          SE.getConstant(Solution2));
3206    } // end APIntOps namespace
3207}
3208
3209/// HowFarToZero - Return the number of times a backedge comparing the specified
3210/// value to zero will execute.  If not computable, return UnknownValue
3211SCEVHandle ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
3212  // If the value is a constant
3213  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
3214    // If the value is already zero, the branch will execute zero times.
3215    if (C->getValue()->isZero()) return C;
3216    return UnknownValue;  // Otherwise it will loop infinitely.
3217  }
3218
3219  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
3220  if (!AddRec || AddRec->getLoop() != L)
3221    return UnknownValue;
3222
3223  if (AddRec->isAffine()) {
3224    // If this is an affine expression, the execution count of this branch is
3225    // the minimum unsigned root of the following equation:
3226    //
3227    //     Start + Step*N = 0 (mod 2^BW)
3228    //
3229    // equivalent to:
3230    //
3231    //             Step*N = -Start (mod 2^BW)
3232    //
3233    // where BW is the common bit width of Start and Step.
3234
3235    // Get the initial value for the loop.
3236    SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
3237    if (isa<SCEVCouldNotCompute>(Start)) return UnknownValue;
3238
3239    SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
3240
3241    if (const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
3242      // For now we handle only constant steps.
3243
3244      // First, handle unitary steps.
3245      if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
3246        return getNegativeSCEV(Start);       //   N = -Start (as unsigned)
3247      if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
3248        return Start;                           //    N = Start (as unsigned)
3249
3250      // Then, try to solve the above equation provided that Start is constant.
3251      if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
3252        return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
3253                                            -StartC->getValue()->getValue(),
3254                                            *this);
3255    }
3256  } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
3257    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
3258    // the quadratic equation to solve it.
3259    std::pair<SCEVHandle,SCEVHandle> Roots = SolveQuadraticEquation(AddRec,
3260                                                                    *this);
3261    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
3262    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
3263    if (R1) {
3264#if 0
3265      errs() << "HFTZ: " << *V << " - sol#1: " << *R1
3266             << "  sol#2: " << *R2 << "\n";
3267#endif
3268      // Pick the smallest positive root value.
3269      if (ConstantInt *CB =
3270          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
3271                                   R1->getValue(), R2->getValue()))) {
3272        if (CB->getZExtValue() == false)
3273          std::swap(R1, R2);   // R1 is the minimum root now.
3274
3275        // We can only use this value if the chrec ends up with an exact zero
3276        // value at this index.  When solving for "X*X != 5", for example, we
3277        // should not accept a root of 2.
3278        SCEVHandle Val = AddRec->evaluateAtIteration(R1, *this);
3279        if (Val->isZero())
3280          return R1;  // We found a quadratic root!
3281      }
3282    }
3283  }
3284
3285  return UnknownValue;
3286}
3287
3288/// HowFarToNonZero - Return the number of times a backedge checking the
3289/// specified value for nonzero will execute.  If not computable, return
3290/// UnknownValue
3291SCEVHandle ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
3292  // Loops that look like: while (X == 0) are very strange indeed.  We don't
3293  // handle them yet except for the trivial case.  This could be expanded in the
3294  // future as needed.
3295
3296  // If the value is a constant, check to see if it is known to be non-zero
3297  // already.  If so, the backedge will execute zero times.
3298  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
3299    if (!C->getValue()->isNullValue())
3300      return getIntegerSCEV(0, C->getType());
3301    return UnknownValue;  // Otherwise it will loop infinitely.
3302  }
3303
3304  // We could implement others, but I really doubt anyone writes loops like
3305  // this, and if they did, they would already be constant folded.
3306  return UnknownValue;
3307}
3308
3309/// getLoopPredecessor - If the given loop's header has exactly one unique
3310/// predecessor outside the loop, return it. Otherwise return null.
3311///
3312BasicBlock *ScalarEvolution::getLoopPredecessor(const Loop *L) {
3313  BasicBlock *Header = L->getHeader();
3314  BasicBlock *Pred = 0;
3315  for (pred_iterator PI = pred_begin(Header), E = pred_end(Header);
3316       PI != E; ++PI)
3317    if (!L->contains(*PI)) {
3318      if (Pred && Pred != *PI) return 0; // Multiple predecessors.
3319      Pred = *PI;
3320    }
3321  return Pred;
3322}
3323
3324/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
3325/// (which may not be an immediate predecessor) which has exactly one
3326/// successor from which BB is reachable, or null if no such block is
3327/// found.
3328///
3329BasicBlock *
3330ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
3331  // If the block has a unique predecessor, then there is no path from the
3332  // predecessor to the block that does not go through the direct edge
3333  // from the predecessor to the block.
3334  if (BasicBlock *Pred = BB->getSinglePredecessor())
3335    return Pred;
3336
3337  // A loop's header is defined to be a block that dominates the loop.
3338  // If the header has a unique predecessor outside the loop, it must be
3339  // a block that has exactly one successor that can reach the loop.
3340  if (Loop *L = LI->getLoopFor(BB))
3341    return getLoopPredecessor(L);
3342
3343  return 0;
3344}
3345
3346/// isLoopGuardedByCond - Test whether entry to the loop is protected by
3347/// a conditional between LHS and RHS.  This is used to help avoid max
3348/// expressions in loop trip counts.
3349bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
3350                                          ICmpInst::Predicate Pred,
3351                                          const SCEV *LHS, const SCEV *RHS) {
3352  // Interpret a null as meaning no loop, where there is obviously no guard
3353  // (interprocedural conditions notwithstanding).
3354  if (!L) return false;
3355
3356  BasicBlock *Predecessor = getLoopPredecessor(L);
3357  BasicBlock *PredecessorDest = L->getHeader();
3358
3359  // Starting at the loop predecessor, climb up the predecessor chain, as long
3360  // as there are predecessors that can be found that have unique successors
3361  // leading to the original header.
3362  for (; Predecessor;
3363       PredecessorDest = Predecessor,
3364       Predecessor = getPredecessorWithUniqueSuccessorForBB(Predecessor)) {
3365
3366    BranchInst *LoopEntryPredicate =
3367      dyn_cast<BranchInst>(Predecessor->getTerminator());
3368    if (!LoopEntryPredicate ||
3369        LoopEntryPredicate->isUnconditional())
3370      continue;
3371
3372    ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
3373    if (!ICI) continue;
3374
3375    // Now that we found a conditional branch that dominates the loop, check to
3376    // see if it is the comparison we are looking for.
3377    Value *PreCondLHS = ICI->getOperand(0);
3378    Value *PreCondRHS = ICI->getOperand(1);
3379    ICmpInst::Predicate Cond;
3380    if (LoopEntryPredicate->getSuccessor(0) == PredecessorDest)
3381      Cond = ICI->getPredicate();
3382    else
3383      Cond = ICI->getInversePredicate();
3384
3385    if (Cond == Pred)
3386      ; // An exact match.
3387    else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
3388      ; // The actual condition is beyond sufficient.
3389    else
3390      // Check a few special cases.
3391      switch (Cond) {
3392      case ICmpInst::ICMP_UGT:
3393        if (Pred == ICmpInst::ICMP_ULT) {
3394          std::swap(PreCondLHS, PreCondRHS);
3395          Cond = ICmpInst::ICMP_ULT;
3396          break;
3397        }
3398        continue;
3399      case ICmpInst::ICMP_SGT:
3400        if (Pred == ICmpInst::ICMP_SLT) {
3401          std::swap(PreCondLHS, PreCondRHS);
3402          Cond = ICmpInst::ICMP_SLT;
3403          break;
3404        }
3405        continue;
3406      case ICmpInst::ICMP_NE:
3407        // Expressions like (x >u 0) are often canonicalized to (x != 0),
3408        // so check for this case by checking if the NE is comparing against
3409        // a minimum or maximum constant.
3410        if (!ICmpInst::isTrueWhenEqual(Pred))
3411          if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
3412            const APInt &A = CI->getValue();
3413            switch (Pred) {
3414            case ICmpInst::ICMP_SLT:
3415              if (A.isMaxSignedValue()) break;
3416              continue;
3417            case ICmpInst::ICMP_SGT:
3418              if (A.isMinSignedValue()) break;
3419              continue;
3420            case ICmpInst::ICMP_ULT:
3421              if (A.isMaxValue()) break;
3422              continue;
3423            case ICmpInst::ICMP_UGT:
3424              if (A.isMinValue()) break;
3425              continue;
3426            default:
3427              continue;
3428            }
3429            Cond = ICmpInst::ICMP_NE;
3430            // NE is symmetric but the original comparison may not be. Swap
3431            // the operands if necessary so that they match below.
3432            if (isa<SCEVConstant>(LHS))
3433              std::swap(PreCondLHS, PreCondRHS);
3434            break;
3435          }
3436        continue;
3437      default:
3438        // We weren't able to reconcile the condition.
3439        continue;
3440      }
3441
3442    if (!PreCondLHS->getType()->isInteger()) continue;
3443
3444    SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS);
3445    SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS);
3446    if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) ||
3447        (LHS == getNotSCEV(PreCondRHSSCEV) &&
3448         RHS == getNotSCEV(PreCondLHSSCEV)))
3449      return true;
3450  }
3451
3452  return false;
3453}
3454
3455/// HowManyLessThans - Return the number of times a backedge containing the
3456/// specified less-than comparison will execute.  If not computable, return
3457/// UnknownValue.
3458ScalarEvolution::BackedgeTakenInfo ScalarEvolution::
3459HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
3460                 const Loop *L, bool isSigned) {
3461  // Only handle:  "ADDREC < LoopInvariant".
3462  if (!RHS->isLoopInvariant(L)) return UnknownValue;
3463
3464  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
3465  if (!AddRec || AddRec->getLoop() != L)
3466    return UnknownValue;
3467
3468  if (AddRec->isAffine()) {
3469    // FORNOW: We only support unit strides.
3470    unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
3471    SCEVHandle Step = AddRec->getStepRecurrence(*this);
3472    SCEVHandle NegOne = getIntegerSCEV(-1, AddRec->getType());
3473
3474    // TODO: handle non-constant strides.
3475    const SCEVConstant *CStep = dyn_cast<SCEVConstant>(Step);
3476    if (!CStep || CStep->isZero())
3477      return UnknownValue;
3478    if (CStep->isOne()) {
3479      // With unit stride, the iteration never steps past the limit value.
3480    } else if (CStep->getValue()->getValue().isStrictlyPositive()) {
3481      if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) {
3482        // Test whether a positive iteration iteration can step past the limit
3483        // value and past the maximum value for its type in a single step.
3484        if (isSigned) {
3485          APInt Max = APInt::getSignedMaxValue(BitWidth);
3486          if ((Max - CStep->getValue()->getValue())
3487                .slt(CLimit->getValue()->getValue()))
3488            return UnknownValue;
3489        } else {
3490          APInt Max = APInt::getMaxValue(BitWidth);
3491          if ((Max - CStep->getValue()->getValue())
3492                .ult(CLimit->getValue()->getValue()))
3493            return UnknownValue;
3494        }
3495      } else
3496        // TODO: handle non-constant limit values below.
3497        return UnknownValue;
3498    } else
3499      // TODO: handle negative strides below.
3500      return UnknownValue;
3501
3502    // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
3503    // m.  So, we count the number of iterations in which {n,+,s} < m is true.
3504    // Note that we cannot simply return max(m-n,0)/s because it's not safe to
3505    // treat m-n as signed nor unsigned due to overflow possibility.
3506
3507    // First, we get the value of the LHS in the first iteration: n
3508    SCEVHandle Start = AddRec->getOperand(0);
3509
3510    // Determine the minimum constant start value.
3511    SCEVHandle MinStart = isa<SCEVConstant>(Start) ? Start :
3512      getConstant(isSigned ? APInt::getSignedMinValue(BitWidth) :
3513                             APInt::getMinValue(BitWidth));
3514
3515    // If we know that the condition is true in order to enter the loop,
3516    // then we know that it will run exactly (m-n)/s times. Otherwise, we
3517    // only know if will execute (max(m,n)-n)/s times. In both cases, the
3518    // division must round up.
3519    SCEVHandle End = RHS;
3520    if (!isLoopGuardedByCond(L,
3521                             isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
3522                             getMinusSCEV(Start, Step), RHS))
3523      End = isSigned ? getSMaxExpr(RHS, Start)
3524                     : getUMaxExpr(RHS, Start);
3525
3526    // Determine the maximum constant end value.
3527    SCEVHandle MaxEnd = isa<SCEVConstant>(End) ? End :
3528      getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) :
3529                             APInt::getMaxValue(BitWidth));
3530
3531    // Finally, we subtract these two values and divide, rounding up, to get
3532    // the number of times the backedge is executed.
3533    SCEVHandle BECount = getUDivExpr(getAddExpr(getMinusSCEV(End, Start),
3534                                                getAddExpr(Step, NegOne)),
3535                                     Step);
3536
3537    // The maximum backedge count is similar, except using the minimum start
3538    // value and the maximum end value.
3539    SCEVHandle MaxBECount = getUDivExpr(getAddExpr(getMinusSCEV(MaxEnd,
3540                                                                MinStart),
3541                                                   getAddExpr(Step, NegOne)),
3542                                        Step);
3543
3544    return BackedgeTakenInfo(BECount, MaxBECount);
3545  }
3546
3547  return UnknownValue;
3548}
3549
3550/// getNumIterationsInRange - Return the number of iterations of this loop that
3551/// produce values in the specified constant range.  Another way of looking at
3552/// this is that it returns the first iteration number where the value is not in
3553/// the condition, thus computing the exit count. If the iteration count can't
3554/// be computed, an instance of SCEVCouldNotCompute is returned.
3555SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
3556                                                   ScalarEvolution &SE) const {
3557  if (Range.isFullSet())  // Infinite loop.
3558    return SE.getCouldNotCompute();
3559
3560  // If the start is a non-zero constant, shift the range to simplify things.
3561  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
3562    if (!SC->getValue()->isZero()) {
3563      std::vector<SCEVHandle> Operands(op_begin(), op_end());
3564      Operands[0] = SE.getIntegerSCEV(0, SC->getType());
3565      SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop());
3566      if (const SCEVAddRecExpr *ShiftedAddRec =
3567            dyn_cast<SCEVAddRecExpr>(Shifted))
3568        return ShiftedAddRec->getNumIterationsInRange(
3569                           Range.subtract(SC->getValue()->getValue()), SE);
3570      // This is strange and shouldn't happen.
3571      return SE.getCouldNotCompute();
3572    }
3573
3574  // The only time we can solve this is when we have all constant indices.
3575  // Otherwise, we cannot determine the overflow conditions.
3576  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
3577    if (!isa<SCEVConstant>(getOperand(i)))
3578      return SE.getCouldNotCompute();
3579
3580
3581  // Okay at this point we know that all elements of the chrec are constants and
3582  // that the start element is zero.
3583
3584  // First check to see if the range contains zero.  If not, the first
3585  // iteration exits.
3586  unsigned BitWidth = SE.getTypeSizeInBits(getType());
3587  if (!Range.contains(APInt(BitWidth, 0)))
3588    return SE.getConstant(ConstantInt::get(getType(),0));
3589
3590  if (isAffine()) {
3591    // If this is an affine expression then we have this situation:
3592    //   Solve {0,+,A} in Range  ===  Ax in Range
3593
3594    // We know that zero is in the range.  If A is positive then we know that
3595    // the upper value of the range must be the first possible exit value.
3596    // If A is negative then the lower of the range is the last possible loop
3597    // value.  Also note that we already checked for a full range.
3598    APInt One(BitWidth,1);
3599    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
3600    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
3601
3602    // The exit value should be (End+A)/A.
3603    APInt ExitVal = (End + A).udiv(A);
3604    ConstantInt *ExitValue = ConstantInt::get(ExitVal);
3605
3606    // Evaluate at the exit value.  If we really did fall out of the valid
3607    // range, then we computed our trip count, otherwise wrap around or other
3608    // things must have happened.
3609    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
3610    if (Range.contains(Val->getValue()))
3611      return SE.getCouldNotCompute();  // Something strange happened
3612
3613    // Ensure that the previous value is in the range.  This is a sanity check.
3614    assert(Range.contains(
3615           EvaluateConstantChrecAtConstant(this,
3616           ConstantInt::get(ExitVal - One), SE)->getValue()) &&
3617           "Linear scev computation is off in a bad way!");
3618    return SE.getConstant(ExitValue);
3619  } else if (isQuadratic()) {
3620    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
3621    // quadratic equation to solve it.  To do this, we must frame our problem in
3622    // terms of figuring out when zero is crossed, instead of when
3623    // Range.getUpper() is crossed.
3624    std::vector<SCEVHandle> NewOps(op_begin(), op_end());
3625    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
3626    SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop());
3627
3628    // Next, solve the constructed addrec
3629    std::pair<SCEVHandle,SCEVHandle> Roots =
3630      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
3631    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
3632    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
3633    if (R1) {
3634      // Pick the smallest positive root value.
3635      if (ConstantInt *CB =
3636          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
3637                                   R1->getValue(), R2->getValue()))) {
3638        if (CB->getZExtValue() == false)
3639          std::swap(R1, R2);   // R1 is the minimum root now.
3640
3641        // Make sure the root is not off by one.  The returned iteration should
3642        // not be in the range, but the previous one should be.  When solving
3643        // for "X*X < 5", for example, we should not return a root of 2.
3644        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
3645                                                             R1->getValue(),
3646                                                             SE);
3647        if (Range.contains(R1Val->getValue())) {
3648          // The next iteration must be out of the range...
3649          ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
3650
3651          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
3652          if (!Range.contains(R1Val->getValue()))
3653            return SE.getConstant(NextVal);
3654          return SE.getCouldNotCompute();  // Something strange happened
3655        }
3656
3657        // If R1 was not in the range, then it is a good return value.  Make
3658        // sure that R1-1 WAS in the range though, just in case.
3659        ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
3660        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
3661        if (Range.contains(R1Val->getValue()))
3662          return R1;
3663        return SE.getCouldNotCompute();  // Something strange happened
3664      }
3665    }
3666  }
3667
3668  return SE.getCouldNotCompute();
3669}
3670
3671
3672
3673//===----------------------------------------------------------------------===//
3674//                   SCEVCallbackVH Class Implementation
3675//===----------------------------------------------------------------------===//
3676
3677void SCEVCallbackVH::deleted() {
3678  assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!");
3679  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
3680    SE->ConstantEvolutionLoopExitValue.erase(PN);
3681  if (Instruction *I = dyn_cast<Instruction>(getValPtr()))
3682    SE->ValuesAtScopes.erase(I);
3683  SE->Scalars.erase(getValPtr());
3684  // this now dangles!
3685}
3686
3687void SCEVCallbackVH::allUsesReplacedWith(Value *) {
3688  assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!");
3689
3690  // Forget all the expressions associated with users of the old value,
3691  // so that future queries will recompute the expressions using the new
3692  // value.
3693  SmallVector<User *, 16> Worklist;
3694  Value *Old = getValPtr();
3695  bool DeleteOld = false;
3696  for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
3697       UI != UE; ++UI)
3698    Worklist.push_back(*UI);
3699  while (!Worklist.empty()) {
3700    User *U = Worklist.pop_back_val();
3701    // Deleting the Old value will cause this to dangle. Postpone
3702    // that until everything else is done.
3703    if (U == Old) {
3704      DeleteOld = true;
3705      continue;
3706    }
3707    if (PHINode *PN = dyn_cast<PHINode>(U))
3708      SE->ConstantEvolutionLoopExitValue.erase(PN);
3709    if (Instruction *I = dyn_cast<Instruction>(U))
3710      SE->ValuesAtScopes.erase(I);
3711    if (SE->Scalars.erase(U))
3712      for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
3713           UI != UE; ++UI)
3714        Worklist.push_back(*UI);
3715  }
3716  if (DeleteOld) {
3717    if (PHINode *PN = dyn_cast<PHINode>(Old))
3718      SE->ConstantEvolutionLoopExitValue.erase(PN);
3719    if (Instruction *I = dyn_cast<Instruction>(Old))
3720      SE->ValuesAtScopes.erase(I);
3721    SE->Scalars.erase(Old);
3722    // this now dangles!
3723  }
3724  // this may dangle!
3725}
3726
3727SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
3728  : CallbackVH(V), SE(se) {}
3729
3730//===----------------------------------------------------------------------===//
3731//                   ScalarEvolution Class Implementation
3732//===----------------------------------------------------------------------===//
3733
3734ScalarEvolution::ScalarEvolution()
3735  : FunctionPass(&ID), UnknownValue(new SCEVCouldNotCompute()) {
3736}
3737
3738bool ScalarEvolution::runOnFunction(Function &F) {
3739  this->F = &F;
3740  LI = &getAnalysis<LoopInfo>();
3741  TD = getAnalysisIfAvailable<TargetData>();
3742  return false;
3743}
3744
3745void ScalarEvolution::releaseMemory() {
3746  Scalars.clear();
3747  BackedgeTakenCounts.clear();
3748  ConstantEvolutionLoopExitValue.clear();
3749  ValuesAtScopes.clear();
3750}
3751
3752void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
3753  AU.setPreservesAll();
3754  AU.addRequiredTransitive<LoopInfo>();
3755}
3756
3757bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
3758  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
3759}
3760
3761static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
3762                          const Loop *L) {
3763  // Print all inner loops first
3764  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
3765    PrintLoopInfo(OS, SE, *I);
3766
3767  OS << "Loop " << L->getHeader()->getName() << ": ";
3768
3769  SmallVector<BasicBlock*, 8> ExitBlocks;
3770  L->getExitBlocks(ExitBlocks);
3771  if (ExitBlocks.size() != 1)
3772    OS << "<multiple exits> ";
3773
3774  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
3775    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
3776  } else {
3777    OS << "Unpredictable backedge-taken count. ";
3778  }
3779
3780  OS << "\n";
3781}
3782
3783void ScalarEvolution::print(raw_ostream &OS, const Module* ) const {
3784  // ScalarEvolution's implementaiton of the print method is to print
3785  // out SCEV values of all instructions that are interesting. Doing
3786  // this potentially causes it to create new SCEV objects though,
3787  // which technically conflicts with the const qualifier. This isn't
3788  // observable from outside the class though (the hasSCEV function
3789  // notwithstanding), so casting away the const isn't dangerous.
3790  ScalarEvolution &SE = *const_cast<ScalarEvolution*>(this);
3791
3792  OS << "Classifying expressions for: " << F->getName() << "\n";
3793  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
3794    if (isSCEVable(I->getType())) {
3795      OS << *I;
3796      OS << "  -->  ";
3797      SCEVHandle SV = SE.getSCEV(&*I);
3798      SV->print(OS);
3799      OS << "\t\t";
3800
3801      if (const Loop *L = LI->getLoopFor((*I).getParent())) {
3802        OS << "Exits: ";
3803        SCEVHandle ExitValue = SE.getSCEVAtScope(&*I, L->getParentLoop());
3804        if (isa<SCEVCouldNotCompute>(ExitValue)) {
3805          OS << "<<Unknown>>";
3806        } else {
3807          OS << *ExitValue;
3808        }
3809      }
3810
3811
3812      OS << "\n";
3813    }
3814
3815  OS << "Determining loop execution counts for: " << F->getName() << "\n";
3816  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
3817    PrintLoopInfo(OS, &SE, *I);
3818}
3819
3820void ScalarEvolution::print(std::ostream &o, const Module *M) const {
3821  raw_os_ostream OS(o);
3822  print(OS, M);
3823}
3824