ScalarEvolution.cpp revision 2ceb40f3da2290d37e9a4faa35bd5199e5dc90d5
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file contains the implementation of the scalar evolution analysis
11// engine, which is used primarily to analyze expressions involving induction
12// variables in loops.
13//
14// There are several aspects to this library.  First is the representation of
15// scalar expressions, which are represented as subclasses of the SCEV class.
16// These classes are used to represent certain types of subexpressions that we
17// can handle.  These classes are reference counted, managed by the SCEVHandle
18// class.  We only create one SCEV of a particular shape, so pointer-comparisons
19// for equality are legal.
20//
21// One important aspect of the SCEV objects is that they are never cyclic, even
22// if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
23// the PHI node is one of the idioms that we can represent (e.g., a polynomial
24// recurrence) then we represent it directly as a recurrence node, otherwise we
25// represent it as a SCEVUnknown node.
26//
27// In addition to being able to represent expressions of various types, we also
28// have folders that are used to build the *canonical* representation for a
29// particular expression.  These folders are capable of using a variety of
30// rewrite rules to simplify the expressions.
31//
32// Once the folders are defined, we can implement the more interesting
33// higher-level code, such as the code that recognizes PHI nodes of various
34// types, computes the execution count of a loop, etc.
35//
36// TODO: We should use these routines and value representations to implement
37// dependence analysis!
38//
39//===----------------------------------------------------------------------===//
40//
41// There are several good references for the techniques used in this analysis.
42//
43//  Chains of recurrences -- a method to expedite the evaluation
44//  of closed-form functions
45//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
46//
47//  On computational properties of chains of recurrences
48//  Eugene V. Zima
49//
50//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
51//  Robert A. van Engelen
52//
53//  Efficient Symbolic Analysis for Optimizing Compilers
54//  Robert A. van Engelen
55//
56//  Using the chains of recurrences algebra for data dependence testing and
57//  induction variable substitution
58//  MS Thesis, Johnie Birch
59//
60//===----------------------------------------------------------------------===//
61
62#define DEBUG_TYPE "scalar-evolution"
63#include "llvm/Analysis/ScalarEvolutionExpressions.h"
64#include "llvm/Constants.h"
65#include "llvm/DerivedTypes.h"
66#include "llvm/GlobalVariable.h"
67#include "llvm/Instructions.h"
68#include "llvm/Analysis/ConstantFolding.h"
69#include "llvm/Analysis/LoopInfo.h"
70#include "llvm/Assembly/Writer.h"
71#include "llvm/Transforms/Scalar.h"
72#include "llvm/Support/CFG.h"
73#include "llvm/Support/CommandLine.h"
74#include "llvm/Support/Compiler.h"
75#include "llvm/Support/ConstantRange.h"
76#include "llvm/Support/InstIterator.h"
77#include "llvm/Support/ManagedStatic.h"
78#include "llvm/Support/MathExtras.h"
79#include "llvm/Support/Streams.h"
80#include "llvm/ADT/Statistic.h"
81#include <ostream>
82#include <algorithm>
83#include <cmath>
84using namespace llvm;
85
86STATISTIC(NumBruteForceEvaluations,
87          "Number of brute force evaluations needed to "
88          "calculate high-order polynomial exit values");
89STATISTIC(NumArrayLenItCounts,
90          "Number of trip counts computed with array length");
91STATISTIC(NumTripCountsComputed,
92          "Number of loops with predictable loop counts");
93STATISTIC(NumTripCountsNotComputed,
94          "Number of loops without predictable loop counts");
95STATISTIC(NumBruteForceTripCountsComputed,
96          "Number of loops with trip counts computed by force");
97
98static cl::opt<unsigned>
99MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
100                        cl::desc("Maximum number of iterations SCEV will "
101                                 "symbolically execute a constant derived loop"),
102                        cl::init(100));
103
104static RegisterPass<ScalarEvolution>
105R("scalar-evolution", "Scalar Evolution Analysis", false, true);
106char ScalarEvolution::ID = 0;
107
108//===----------------------------------------------------------------------===//
109//                           SCEV class definitions
110//===----------------------------------------------------------------------===//
111
112//===----------------------------------------------------------------------===//
113// Implementation of the SCEV class.
114//
115SCEV::~SCEV() {}
116void SCEV::dump() const {
117  print(cerr);
118}
119
120uint32_t SCEV::getBitWidth() const {
121  if (const IntegerType* ITy = dyn_cast<IntegerType>(getType()))
122    return ITy->getBitWidth();
123  return 0;
124}
125
126bool SCEV::isZero() const {
127  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
128    return SC->getValue()->isZero();
129  return false;
130}
131
132
133SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {}
134
135bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
136  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
137  return false;
138}
139
140const Type *SCEVCouldNotCompute::getType() const {
141  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
142  return 0;
143}
144
145bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
146  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
147  return false;
148}
149
150SCEVHandle SCEVCouldNotCompute::
151replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
152                                  const SCEVHandle &Conc,
153                                  ScalarEvolution &SE) const {
154  return this;
155}
156
157void SCEVCouldNotCompute::print(std::ostream &OS) const {
158  OS << "***COULDNOTCOMPUTE***";
159}
160
161bool SCEVCouldNotCompute::classof(const SCEV *S) {
162  return S->getSCEVType() == scCouldNotCompute;
163}
164
165
166// SCEVConstants - Only allow the creation of one SCEVConstant for any
167// particular value.  Don't use a SCEVHandle here, or else the object will
168// never be deleted!
169static ManagedStatic<std::map<ConstantInt*, SCEVConstant*> > SCEVConstants;
170
171
172SCEVConstant::~SCEVConstant() {
173  SCEVConstants->erase(V);
174}
175
176SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) {
177  SCEVConstant *&R = (*SCEVConstants)[V];
178  if (R == 0) R = new SCEVConstant(V);
179  return R;
180}
181
182SCEVHandle ScalarEvolution::getConstant(const APInt& Val) {
183  return getConstant(ConstantInt::get(Val));
184}
185
186const Type *SCEVConstant::getType() const { return V->getType(); }
187
188void SCEVConstant::print(std::ostream &OS) const {
189  WriteAsOperand(OS, V, false);
190}
191
192// SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any
193// particular input.  Don't use a SCEVHandle here, or else the object will
194// never be deleted!
195static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
196                     SCEVTruncateExpr*> > SCEVTruncates;
197
198SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty)
199  : SCEV(scTruncate), Op(op), Ty(ty) {
200  assert(Op->getType()->isInteger() && Ty->isInteger() &&
201         "Cannot truncate non-integer value!");
202  assert(Op->getType()->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()
203         && "This is not a truncating conversion!");
204}
205
206SCEVTruncateExpr::~SCEVTruncateExpr() {
207  SCEVTruncates->erase(std::make_pair(Op, Ty));
208}
209
210void SCEVTruncateExpr::print(std::ostream &OS) const {
211  OS << "(truncate " << *Op << " to " << *Ty << ")";
212}
213
214// SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any
215// particular input.  Don't use a SCEVHandle here, or else the object will never
216// be deleted!
217static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
218                     SCEVZeroExtendExpr*> > SCEVZeroExtends;
219
220SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty)
221  : SCEV(scZeroExtend), Op(op), Ty(ty) {
222  assert(Op->getType()->isInteger() && Ty->isInteger() &&
223         "Cannot zero extend non-integer value!");
224  assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
225         && "This is not an extending conversion!");
226}
227
228SCEVZeroExtendExpr::~SCEVZeroExtendExpr() {
229  SCEVZeroExtends->erase(std::make_pair(Op, Ty));
230}
231
232void SCEVZeroExtendExpr::print(std::ostream &OS) const {
233  OS << "(zeroextend " << *Op << " to " << *Ty << ")";
234}
235
236// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any
237// particular input.  Don't use a SCEVHandle here, or else the object will never
238// be deleted!
239static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
240                     SCEVSignExtendExpr*> > SCEVSignExtends;
241
242SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty)
243  : SCEV(scSignExtend), Op(op), Ty(ty) {
244  assert(Op->getType()->isInteger() && Ty->isInteger() &&
245         "Cannot sign extend non-integer value!");
246  assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
247         && "This is not an extending conversion!");
248}
249
250SCEVSignExtendExpr::~SCEVSignExtendExpr() {
251  SCEVSignExtends->erase(std::make_pair(Op, Ty));
252}
253
254void SCEVSignExtendExpr::print(std::ostream &OS) const {
255  OS << "(signextend " << *Op << " to " << *Ty << ")";
256}
257
258// SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any
259// particular input.  Don't use a SCEVHandle here, or else the object will never
260// be deleted!
261static ManagedStatic<std::map<std::pair<unsigned, std::vector<SCEV*> >,
262                     SCEVCommutativeExpr*> > SCEVCommExprs;
263
264SCEVCommutativeExpr::~SCEVCommutativeExpr() {
265  SCEVCommExprs->erase(std::make_pair(getSCEVType(),
266                                      std::vector<SCEV*>(Operands.begin(),
267                                                         Operands.end())));
268}
269
270void SCEVCommutativeExpr::print(std::ostream &OS) const {
271  assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
272  const char *OpStr = getOperationStr();
273  OS << "(" << *Operands[0];
274  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
275    OS << OpStr << *Operands[i];
276  OS << ")";
277}
278
279SCEVHandle SCEVCommutativeExpr::
280replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
281                                  const SCEVHandle &Conc,
282                                  ScalarEvolution &SE) const {
283  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
284    SCEVHandle H =
285      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
286    if (H != getOperand(i)) {
287      std::vector<SCEVHandle> NewOps;
288      NewOps.reserve(getNumOperands());
289      for (unsigned j = 0; j != i; ++j)
290        NewOps.push_back(getOperand(j));
291      NewOps.push_back(H);
292      for (++i; i != e; ++i)
293        NewOps.push_back(getOperand(i)->
294                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
295
296      if (isa<SCEVAddExpr>(this))
297        return SE.getAddExpr(NewOps);
298      else if (isa<SCEVMulExpr>(this))
299        return SE.getMulExpr(NewOps);
300      else if (isa<SCEVSMaxExpr>(this))
301        return SE.getSMaxExpr(NewOps);
302      else if (isa<SCEVUMaxExpr>(this))
303        return SE.getUMaxExpr(NewOps);
304      else
305        assert(0 && "Unknown commutative expr!");
306    }
307  }
308  return this;
309}
310
311
312// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular
313// input.  Don't use a SCEVHandle here, or else the object will never be
314// deleted!
315static ManagedStatic<std::map<std::pair<SCEV*, SCEV*>,
316                     SCEVUDivExpr*> > SCEVUDivs;
317
318SCEVUDivExpr::~SCEVUDivExpr() {
319  SCEVUDivs->erase(std::make_pair(LHS, RHS));
320}
321
322void SCEVUDivExpr::print(std::ostream &OS) const {
323  OS << "(" << *LHS << " /u " << *RHS << ")";
324}
325
326const Type *SCEVUDivExpr::getType() const {
327  return LHS->getType();
328}
329
330// SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any
331// particular input.  Don't use a SCEVHandle here, or else the object will never
332// be deleted!
333static ManagedStatic<std::map<std::pair<const Loop *, std::vector<SCEV*> >,
334                     SCEVAddRecExpr*> > SCEVAddRecExprs;
335
336SCEVAddRecExpr::~SCEVAddRecExpr() {
337  SCEVAddRecExprs->erase(std::make_pair(L,
338                                        std::vector<SCEV*>(Operands.begin(),
339                                                           Operands.end())));
340}
341
342SCEVHandle SCEVAddRecExpr::
343replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
344                                  const SCEVHandle &Conc,
345                                  ScalarEvolution &SE) const {
346  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
347    SCEVHandle H =
348      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
349    if (H != getOperand(i)) {
350      std::vector<SCEVHandle> NewOps;
351      NewOps.reserve(getNumOperands());
352      for (unsigned j = 0; j != i; ++j)
353        NewOps.push_back(getOperand(j));
354      NewOps.push_back(H);
355      for (++i; i != e; ++i)
356        NewOps.push_back(getOperand(i)->
357                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
358
359      return SE.getAddRecExpr(NewOps, L);
360    }
361  }
362  return this;
363}
364
365
366bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
367  // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't
368  // contain L and if the start is invariant.
369  return !QueryLoop->contains(L->getHeader()) &&
370         getOperand(0)->isLoopInvariant(QueryLoop);
371}
372
373
374void SCEVAddRecExpr::print(std::ostream &OS) const {
375  OS << "{" << *Operands[0];
376  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
377    OS << ",+," << *Operands[i];
378  OS << "}<" << L->getHeader()->getName() + ">";
379}
380
381// SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular
382// value.  Don't use a SCEVHandle here, or else the object will never be
383// deleted!
384static ManagedStatic<std::map<Value*, SCEVUnknown*> > SCEVUnknowns;
385
386SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); }
387
388bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
389  // All non-instruction values are loop invariant.  All instructions are loop
390  // invariant if they are not contained in the specified loop.
391  if (Instruction *I = dyn_cast<Instruction>(V))
392    return !L->contains(I->getParent());
393  return true;
394}
395
396const Type *SCEVUnknown::getType() const {
397  return V->getType();
398}
399
400void SCEVUnknown::print(std::ostream &OS) const {
401  WriteAsOperand(OS, V, false);
402}
403
404//===----------------------------------------------------------------------===//
405//                               SCEV Utilities
406//===----------------------------------------------------------------------===//
407
408namespace {
409  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
410  /// than the complexity of the RHS.  This comparator is used to canonicalize
411  /// expressions.
412  struct VISIBILITY_HIDDEN SCEVComplexityCompare {
413    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
414      return LHS->getSCEVType() < RHS->getSCEVType();
415    }
416  };
417}
418
419/// GroupByComplexity - Given a list of SCEV objects, order them by their
420/// complexity, and group objects of the same complexity together by value.
421/// When this routine is finished, we know that any duplicates in the vector are
422/// consecutive and that complexity is monotonically increasing.
423///
424/// Note that we go take special precautions to ensure that we get determinstic
425/// results from this routine.  In other words, we don't want the results of
426/// this to depend on where the addresses of various SCEV objects happened to
427/// land in memory.
428///
429static void GroupByComplexity(std::vector<SCEVHandle> &Ops) {
430  if (Ops.size() < 2) return;  // Noop
431  if (Ops.size() == 2) {
432    // This is the common case, which also happens to be trivially simple.
433    // Special case it.
434    if (SCEVComplexityCompare()(Ops[1], Ops[0]))
435      std::swap(Ops[0], Ops[1]);
436    return;
437  }
438
439  // Do the rough sort by complexity.
440  std::sort(Ops.begin(), Ops.end(), SCEVComplexityCompare());
441
442  // Now that we are sorted by complexity, group elements of the same
443  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
444  // be extremely short in practice.  Note that we take this approach because we
445  // do not want to depend on the addresses of the objects we are grouping.
446  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
447    SCEV *S = Ops[i];
448    unsigned Complexity = S->getSCEVType();
449
450    // If there are any objects of the same complexity and same value as this
451    // one, group them.
452    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
453      if (Ops[j] == S) { // Found a duplicate.
454        // Move it to immediately after i'th element.
455        std::swap(Ops[i+1], Ops[j]);
456        ++i;   // no need to rescan it.
457        if (i == e-2) return;  // Done!
458      }
459    }
460  }
461}
462
463
464
465//===----------------------------------------------------------------------===//
466//                      Simple SCEV method implementations
467//===----------------------------------------------------------------------===//
468
469/// getIntegerSCEV - Given an integer or FP type, create a constant for the
470/// specified signed integer value and return a SCEV for the constant.
471SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
472  Constant *C;
473  if (Val == 0)
474    C = Constant::getNullValue(Ty);
475  else if (Ty->isFloatingPoint())
476    C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle :
477                                APFloat::IEEEdouble, Val));
478  else
479    C = ConstantInt::get(Ty, Val);
480  return getUnknown(C);
481}
482
483/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
484///
485SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) {
486  if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
487    return getUnknown(ConstantExpr::getNeg(VC->getValue()));
488
489  return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(V->getType())));
490}
491
492/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
493SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) {
494  if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
495    return getUnknown(ConstantExpr::getNot(VC->getValue()));
496
497  SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(V->getType()));
498  return getMinusSCEV(AllOnes, V);
499}
500
501/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
502///
503SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS,
504                                         const SCEVHandle &RHS) {
505  // X - Y --> X + -Y
506  return getAddExpr(LHS, getNegativeSCEV(RHS));
507}
508
509
510/// BinomialCoefficient - Compute BC(It, K).  The result is of the same type as
511/// It.  Assume, K > 0.
512static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
513                                      ScalarEvolution &SE) {
514  // We are using the following formula for BC(It, K):
515  //
516  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
517  //
518  // Suppose, W is the bitwidth of It (and of the return value as well).  We
519  // must be prepared for overflow.  Hence, we must assure that the result of
520  // our computation is equal to the accurate one modulo 2^W.  Unfortunately,
521  // division isn't safe in modular arithmetic.  This means we must perform the
522  // whole computation accurately and then truncate the result to W bits.
523  //
524  // The dividend of the formula is a multiplication of K integers of bitwidth
525  // W.  K*W bits suffice to compute it accurately.
526  //
527  // FIXME: We assume the divisor can be accurately computed using 16-bit
528  // unsigned integer type. It is true up to K = 8 (AddRecs of length 9). In
529  // future we may use APInt to use the minimum number of bits necessary to
530  // compute it accurately.
531  //
532  // It is safe to use unsigned division here: the dividend is nonnegative and
533  // the divisor is positive.
534
535  // Handle the simplest case efficiently.
536  if (K == 1)
537    return It;
538
539  assert(K < 9 && "We cannot handle such long AddRecs yet.");
540
541  unsigned DividendBits = K * It->getBitWidth();
542  if (DividendBits > 256)
543    return new SCEVCouldNotCompute();
544
545  const IntegerType *DividendTy = IntegerType::get(DividendBits);
546  const SCEVHandle ExIt = SE.getZeroExtendExpr(It, DividendTy);
547
548  // The final number of bits we need to perform the division is the maximum of
549  // dividend and divisor bitwidths.
550  const IntegerType *DivisionTy =
551    IntegerType::get(std::max(DividendBits, 16U));
552
553  // Compute K!  We know K >= 2 here.
554  unsigned F = 2;
555  for (unsigned i = 3; i <= K; ++i)
556    F *= i;
557  APInt Divisor(DivisionTy->getBitWidth(), F);
558
559  // Handle this case efficiently, it is common to have constant iteration
560  // counts while computing loop exit values.
561  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(ExIt)) {
562    const APInt& N = SC->getValue()->getValue();
563    APInt Dividend(N.getBitWidth(), 1);
564    for (; K; --K)
565      Dividend *= N-(K-1);
566    if (DividendTy != DivisionTy)
567      Dividend = Dividend.zext(DivisionTy->getBitWidth());
568    return SE.getConstant(Dividend.udiv(Divisor).trunc(It->getBitWidth()));
569  }
570
571  SCEVHandle Dividend = ExIt;
572  for (unsigned i = 1; i != K; ++i)
573    Dividend =
574      SE.getMulExpr(Dividend,
575                    SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy)));
576
577  if (DividendTy != DivisionTy)
578    Dividend = SE.getZeroExtendExpr(Dividend, DivisionTy);
579  return SE.getTruncateExpr(SE.getUDivExpr(Dividend, SE.getConstant(Divisor)),
580                            It->getType());
581}
582
583/// evaluateAtIteration - Return the value of this chain of recurrences at
584/// the specified iteration number.  We can evaluate this recurrence by
585/// multiplying each element in the chain by the binomial coefficient
586/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
587///
588///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
589///
590/// where BC(It, k) stands for binomial coefficient.
591///
592SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
593                                               ScalarEvolution &SE) const {
594  SCEVHandle Result = getStart();
595  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
596    // The computation is correct in the face of overflow provided that the
597    // multiplication is performed _after_ the evaluation of the binomial
598    // coefficient.
599    SCEVHandle Val = SE.getMulExpr(getOperand(i),
600                                   BinomialCoefficient(It, i, SE));
601    Result = SE.getAddExpr(Result, Val);
602  }
603  return Result;
604}
605
606//===----------------------------------------------------------------------===//
607//                    SCEV Expression folder implementations
608//===----------------------------------------------------------------------===//
609
610SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, const Type *Ty) {
611  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
612    return getUnknown(
613        ConstantExpr::getTrunc(SC->getValue(), Ty));
614
615  // If the input value is a chrec scev made out of constants, truncate
616  // all of the constants.
617  if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
618    std::vector<SCEVHandle> Operands;
619    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
620      // FIXME: This should allow truncation of other expression types!
621      if (isa<SCEVConstant>(AddRec->getOperand(i)))
622        Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
623      else
624        break;
625    if (Operands.size() == AddRec->getNumOperands())
626      return getAddRecExpr(Operands, AddRec->getLoop());
627  }
628
629  SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)];
630  if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty);
631  return Result;
632}
633
634SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, const Type *Ty) {
635  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
636    return getUnknown(
637        ConstantExpr::getZExt(SC->getValue(), Ty));
638
639  // FIXME: If the input value is a chrec scev, and we can prove that the value
640  // did not overflow the old, smaller, value, we can zero extend all of the
641  // operands (often constants).  This would allow analysis of something like
642  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
643
644  SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)];
645  if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty);
646  return Result;
647}
648
649SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type *Ty) {
650  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
651    return getUnknown(
652        ConstantExpr::getSExt(SC->getValue(), Ty));
653
654  // FIXME: If the input value is a chrec scev, and we can prove that the value
655  // did not overflow the old, smaller, value, we can sign extend all of the
656  // operands (often constants).  This would allow analysis of something like
657  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
658
659  SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)];
660  if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty);
661  return Result;
662}
663
664/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion
665/// of the input value to the specified type.  If the type must be
666/// extended, it is zero extended.
667SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V,
668                                                    const Type *Ty) {
669  const Type *SrcTy = V->getType();
670  assert(SrcTy->isInteger() && Ty->isInteger() &&
671         "Cannot truncate or zero extend with non-integer arguments!");
672  if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits())
673    return V;  // No conversion
674  if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits())
675    return getTruncateExpr(V, Ty);
676  return getZeroExtendExpr(V, Ty);
677}
678
679// get - Get a canonical add expression, or something simpler if possible.
680SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) {
681  assert(!Ops.empty() && "Cannot get empty add!");
682  if (Ops.size() == 1) return Ops[0];
683
684  // Sort by complexity, this groups all similar expression types together.
685  GroupByComplexity(Ops);
686
687  // If there are any constants, fold them together.
688  unsigned Idx = 0;
689  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
690    ++Idx;
691    assert(Idx < Ops.size());
692    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
693      // We found two constants, fold them together!
694      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() +
695                                           RHSC->getValue()->getValue());
696      Ops[0] = getConstant(Fold);
697      Ops.erase(Ops.begin()+1);  // Erase the folded element
698      if (Ops.size() == 1) return Ops[0];
699      LHSC = cast<SCEVConstant>(Ops[0]);
700    }
701
702    // If we are left with a constant zero being added, strip it off.
703    if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
704      Ops.erase(Ops.begin());
705      --Idx;
706    }
707  }
708
709  if (Ops.size() == 1) return Ops[0];
710
711  // Okay, check to see if the same value occurs in the operand list twice.  If
712  // so, merge them together into an multiply expression.  Since we sorted the
713  // list, these values are required to be adjacent.
714  const Type *Ty = Ops[0]->getType();
715  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
716    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
717      // Found a match, merge the two values into a multiply, and add any
718      // remaining values to the result.
719      SCEVHandle Two = getIntegerSCEV(2, Ty);
720      SCEVHandle Mul = getMulExpr(Ops[i], Two);
721      if (Ops.size() == 2)
722        return Mul;
723      Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
724      Ops.push_back(Mul);
725      return getAddExpr(Ops);
726    }
727
728  // Now we know the first non-constant operand.  Skip past any cast SCEVs.
729  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
730    ++Idx;
731
732  // If there are add operands they would be next.
733  if (Idx < Ops.size()) {
734    bool DeletedAdd = false;
735    while (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
736      // If we have an add, expand the add operands onto the end of the operands
737      // list.
738      Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
739      Ops.erase(Ops.begin()+Idx);
740      DeletedAdd = true;
741    }
742
743    // If we deleted at least one add, we added operands to the end of the list,
744    // and they are not necessarily sorted.  Recurse to resort and resimplify
745    // any operands we just aquired.
746    if (DeletedAdd)
747      return getAddExpr(Ops);
748  }
749
750  // Skip over the add expression until we get to a multiply.
751  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
752    ++Idx;
753
754  // If we are adding something to a multiply expression, make sure the
755  // something is not already an operand of the multiply.  If so, merge it into
756  // the multiply.
757  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
758    SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
759    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
760      SCEV *MulOpSCEV = Mul->getOperand(MulOp);
761      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
762        if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(MulOpSCEV)) {
763          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
764          SCEVHandle InnerMul = Mul->getOperand(MulOp == 0);
765          if (Mul->getNumOperands() != 2) {
766            // If the multiply has more than two operands, we must get the
767            // Y*Z term.
768            std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
769            MulOps.erase(MulOps.begin()+MulOp);
770            InnerMul = getMulExpr(MulOps);
771          }
772          SCEVHandle One = getIntegerSCEV(1, Ty);
773          SCEVHandle AddOne = getAddExpr(InnerMul, One);
774          SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]);
775          if (Ops.size() == 2) return OuterMul;
776          if (AddOp < Idx) {
777            Ops.erase(Ops.begin()+AddOp);
778            Ops.erase(Ops.begin()+Idx-1);
779          } else {
780            Ops.erase(Ops.begin()+Idx);
781            Ops.erase(Ops.begin()+AddOp-1);
782          }
783          Ops.push_back(OuterMul);
784          return getAddExpr(Ops);
785        }
786
787      // Check this multiply against other multiplies being added together.
788      for (unsigned OtherMulIdx = Idx+1;
789           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
790           ++OtherMulIdx) {
791        SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
792        // If MulOp occurs in OtherMul, we can fold the two multiplies
793        // together.
794        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
795             OMulOp != e; ++OMulOp)
796          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
797            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
798            SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0);
799            if (Mul->getNumOperands() != 2) {
800              std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
801              MulOps.erase(MulOps.begin()+MulOp);
802              InnerMul1 = getMulExpr(MulOps);
803            }
804            SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0);
805            if (OtherMul->getNumOperands() != 2) {
806              std::vector<SCEVHandle> MulOps(OtherMul->op_begin(),
807                                             OtherMul->op_end());
808              MulOps.erase(MulOps.begin()+OMulOp);
809              InnerMul2 = getMulExpr(MulOps);
810            }
811            SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
812            SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
813            if (Ops.size() == 2) return OuterMul;
814            Ops.erase(Ops.begin()+Idx);
815            Ops.erase(Ops.begin()+OtherMulIdx-1);
816            Ops.push_back(OuterMul);
817            return getAddExpr(Ops);
818          }
819      }
820    }
821  }
822
823  // If there are any add recurrences in the operands list, see if any other
824  // added values are loop invariant.  If so, we can fold them into the
825  // recurrence.
826  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
827    ++Idx;
828
829  // Scan over all recurrences, trying to fold loop invariants into them.
830  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
831    // Scan all of the other operands to this add and add them to the vector if
832    // they are loop invariant w.r.t. the recurrence.
833    std::vector<SCEVHandle> LIOps;
834    SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
835    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
836      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
837        LIOps.push_back(Ops[i]);
838        Ops.erase(Ops.begin()+i);
839        --i; --e;
840      }
841
842    // If we found some loop invariants, fold them into the recurrence.
843    if (!LIOps.empty()) {
844      //  NLI + LI + { Start,+,Step}  -->  NLI + { LI+Start,+,Step }
845      LIOps.push_back(AddRec->getStart());
846
847      std::vector<SCEVHandle> AddRecOps(AddRec->op_begin(), AddRec->op_end());
848      AddRecOps[0] = getAddExpr(LIOps);
849
850      SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop());
851      // If all of the other operands were loop invariant, we are done.
852      if (Ops.size() == 1) return NewRec;
853
854      // Otherwise, add the folded AddRec by the non-liv parts.
855      for (unsigned i = 0;; ++i)
856        if (Ops[i] == AddRec) {
857          Ops[i] = NewRec;
858          break;
859        }
860      return getAddExpr(Ops);
861    }
862
863    // Okay, if there weren't any loop invariants to be folded, check to see if
864    // there are multiple AddRec's with the same loop induction variable being
865    // added together.  If so, we can fold them.
866    for (unsigned OtherIdx = Idx+1;
867         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
868      if (OtherIdx != Idx) {
869        SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
870        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
871          // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
872          std::vector<SCEVHandle> NewOps(AddRec->op_begin(), AddRec->op_end());
873          for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
874            if (i >= NewOps.size()) {
875              NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i,
876                            OtherAddRec->op_end());
877              break;
878            }
879            NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
880          }
881          SCEVHandle NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop());
882
883          if (Ops.size() == 2) return NewAddRec;
884
885          Ops.erase(Ops.begin()+Idx);
886          Ops.erase(Ops.begin()+OtherIdx-1);
887          Ops.push_back(NewAddRec);
888          return getAddExpr(Ops);
889        }
890      }
891
892    // Otherwise couldn't fold anything into this recurrence.  Move onto the
893    // next one.
894  }
895
896  // Okay, it looks like we really DO need an add expr.  Check to see if we
897  // already have one, otherwise create a new one.
898  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
899  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr,
900                                                                 SCEVOps)];
901  if (Result == 0) Result = new SCEVAddExpr(Ops);
902  return Result;
903}
904
905
906SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) {
907  assert(!Ops.empty() && "Cannot get empty mul!");
908
909  // Sort by complexity, this groups all similar expression types together.
910  GroupByComplexity(Ops);
911
912  // If there are any constants, fold them together.
913  unsigned Idx = 0;
914  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
915
916    // C1*(C2+V) -> C1*C2 + C1*V
917    if (Ops.size() == 2)
918      if (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
919        if (Add->getNumOperands() == 2 &&
920            isa<SCEVConstant>(Add->getOperand(0)))
921          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
922                            getMulExpr(LHSC, Add->getOperand(1)));
923
924
925    ++Idx;
926    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
927      // We found two constants, fold them together!
928      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() *
929                                           RHSC->getValue()->getValue());
930      Ops[0] = getConstant(Fold);
931      Ops.erase(Ops.begin()+1);  // Erase the folded element
932      if (Ops.size() == 1) return Ops[0];
933      LHSC = cast<SCEVConstant>(Ops[0]);
934    }
935
936    // If we are left with a constant one being multiplied, strip it off.
937    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
938      Ops.erase(Ops.begin());
939      --Idx;
940    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
941      // If we have a multiply of zero, it will always be zero.
942      return Ops[0];
943    }
944  }
945
946  // Skip over the add expression until we get to a multiply.
947  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
948    ++Idx;
949
950  if (Ops.size() == 1)
951    return Ops[0];
952
953  // If there are mul operands inline them all into this expression.
954  if (Idx < Ops.size()) {
955    bool DeletedMul = false;
956    while (SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
957      // If we have an mul, expand the mul operands onto the end of the operands
958      // list.
959      Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end());
960      Ops.erase(Ops.begin()+Idx);
961      DeletedMul = true;
962    }
963
964    // If we deleted at least one mul, we added operands to the end of the list,
965    // and they are not necessarily sorted.  Recurse to resort and resimplify
966    // any operands we just aquired.
967    if (DeletedMul)
968      return getMulExpr(Ops);
969  }
970
971  // If there are any add recurrences in the operands list, see if any other
972  // added values are loop invariant.  If so, we can fold them into the
973  // recurrence.
974  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
975    ++Idx;
976
977  // Scan over all recurrences, trying to fold loop invariants into them.
978  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
979    // Scan all of the other operands to this mul and add them to the vector if
980    // they are loop invariant w.r.t. the recurrence.
981    std::vector<SCEVHandle> LIOps;
982    SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
983    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
984      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
985        LIOps.push_back(Ops[i]);
986        Ops.erase(Ops.begin()+i);
987        --i; --e;
988      }
989
990    // If we found some loop invariants, fold them into the recurrence.
991    if (!LIOps.empty()) {
992      //  NLI * LI * { Start,+,Step}  -->  NLI * { LI*Start,+,LI*Step }
993      std::vector<SCEVHandle> NewOps;
994      NewOps.reserve(AddRec->getNumOperands());
995      if (LIOps.size() == 1) {
996        SCEV *Scale = LIOps[0];
997        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
998          NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
999      } else {
1000        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
1001          std::vector<SCEVHandle> MulOps(LIOps);
1002          MulOps.push_back(AddRec->getOperand(i));
1003          NewOps.push_back(getMulExpr(MulOps));
1004        }
1005      }
1006
1007      SCEVHandle NewRec = getAddRecExpr(NewOps, AddRec->getLoop());
1008
1009      // If all of the other operands were loop invariant, we are done.
1010      if (Ops.size() == 1) return NewRec;
1011
1012      // Otherwise, multiply the folded AddRec by the non-liv parts.
1013      for (unsigned i = 0;; ++i)
1014        if (Ops[i] == AddRec) {
1015          Ops[i] = NewRec;
1016          break;
1017        }
1018      return getMulExpr(Ops);
1019    }
1020
1021    // Okay, if there weren't any loop invariants to be folded, check to see if
1022    // there are multiple AddRec's with the same loop induction variable being
1023    // multiplied together.  If so, we can fold them.
1024    for (unsigned OtherIdx = Idx+1;
1025         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1026      if (OtherIdx != Idx) {
1027        SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1028        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1029          // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
1030          SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1031          SCEVHandle NewStart = getMulExpr(F->getStart(),
1032                                                 G->getStart());
1033          SCEVHandle B = F->getStepRecurrence(*this);
1034          SCEVHandle D = G->getStepRecurrence(*this);
1035          SCEVHandle NewStep = getAddExpr(getMulExpr(F, D),
1036                                          getMulExpr(G, B),
1037                                          getMulExpr(B, D));
1038          SCEVHandle NewAddRec = getAddRecExpr(NewStart, NewStep,
1039                                               F->getLoop());
1040          if (Ops.size() == 2) return NewAddRec;
1041
1042          Ops.erase(Ops.begin()+Idx);
1043          Ops.erase(Ops.begin()+OtherIdx-1);
1044          Ops.push_back(NewAddRec);
1045          return getMulExpr(Ops);
1046        }
1047      }
1048
1049    // Otherwise couldn't fold anything into this recurrence.  Move onto the
1050    // next one.
1051  }
1052
1053  // Okay, it looks like we really DO need an mul expr.  Check to see if we
1054  // already have one, otherwise create a new one.
1055  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1056  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr,
1057                                                                 SCEVOps)];
1058  if (Result == 0)
1059    Result = new SCEVMulExpr(Ops);
1060  return Result;
1061}
1062
1063SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) {
1064  if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1065    if (RHSC->getValue()->equalsInt(1))
1066      return LHS;                            // X udiv 1 --> x
1067
1068    if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1069      Constant *LHSCV = LHSC->getValue();
1070      Constant *RHSCV = RHSC->getValue();
1071      return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV));
1072    }
1073  }
1074
1075  // FIXME: implement folding of (X*4)/4 when we know X*4 doesn't overflow.
1076
1077  SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)];
1078  if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS);
1079  return Result;
1080}
1081
1082
1083/// SCEVAddRecExpr::get - Get a add recurrence expression for the
1084/// specified loop.  Simplify the expression as much as possible.
1085SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start,
1086                               const SCEVHandle &Step, const Loop *L) {
1087  std::vector<SCEVHandle> Operands;
1088  Operands.push_back(Start);
1089  if (SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
1090    if (StepChrec->getLoop() == L) {
1091      Operands.insert(Operands.end(), StepChrec->op_begin(),
1092                      StepChrec->op_end());
1093      return getAddRecExpr(Operands, L);
1094    }
1095
1096  Operands.push_back(Step);
1097  return getAddRecExpr(Operands, L);
1098}
1099
1100/// SCEVAddRecExpr::get - Get a add recurrence expression for the
1101/// specified loop.  Simplify the expression as much as possible.
1102SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands,
1103                               const Loop *L) {
1104  if (Operands.size() == 1) return Operands[0];
1105
1106  if (Operands.back()->isZero()) {
1107    Operands.pop_back();
1108    return getAddRecExpr(Operands, L);             // { X,+,0 }  -->  X
1109  }
1110
1111  SCEVAddRecExpr *&Result =
1112    (*SCEVAddRecExprs)[std::make_pair(L, std::vector<SCEV*>(Operands.begin(),
1113                                                            Operands.end()))];
1114  if (Result == 0) Result = new SCEVAddRecExpr(Operands, L);
1115  return Result;
1116}
1117
1118SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS,
1119                                        const SCEVHandle &RHS) {
1120  std::vector<SCEVHandle> Ops;
1121  Ops.push_back(LHS);
1122  Ops.push_back(RHS);
1123  return getSMaxExpr(Ops);
1124}
1125
1126SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
1127  assert(!Ops.empty() && "Cannot get empty smax!");
1128  if (Ops.size() == 1) return Ops[0];
1129
1130  // Sort by complexity, this groups all similar expression types together.
1131  GroupByComplexity(Ops);
1132
1133  // If there are any constants, fold them together.
1134  unsigned Idx = 0;
1135  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1136    ++Idx;
1137    assert(Idx < Ops.size());
1138    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1139      // We found two constants, fold them together!
1140      ConstantInt *Fold = ConstantInt::get(
1141                              APIntOps::smax(LHSC->getValue()->getValue(),
1142                                             RHSC->getValue()->getValue()));
1143      Ops[0] = getConstant(Fold);
1144      Ops.erase(Ops.begin()+1);  // Erase the folded element
1145      if (Ops.size() == 1) return Ops[0];
1146      LHSC = cast<SCEVConstant>(Ops[0]);
1147    }
1148
1149    // If we are left with a constant -inf, strip it off.
1150    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
1151      Ops.erase(Ops.begin());
1152      --Idx;
1153    }
1154  }
1155
1156  if (Ops.size() == 1) return Ops[0];
1157
1158  // Find the first SMax
1159  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
1160    ++Idx;
1161
1162  // Check to see if one of the operands is an SMax. If so, expand its operands
1163  // onto our operand list, and recurse to simplify.
1164  if (Idx < Ops.size()) {
1165    bool DeletedSMax = false;
1166    while (SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
1167      Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
1168      Ops.erase(Ops.begin()+Idx);
1169      DeletedSMax = true;
1170    }
1171
1172    if (DeletedSMax)
1173      return getSMaxExpr(Ops);
1174  }
1175
1176  // Okay, check to see if the same value occurs in the operand list twice.  If
1177  // so, delete one.  Since we sorted the list, these values are required to
1178  // be adjacent.
1179  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1180    if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
1181      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1182      --i; --e;
1183    }
1184
1185  if (Ops.size() == 1) return Ops[0];
1186
1187  assert(!Ops.empty() && "Reduced smax down to nothing!");
1188
1189  // Okay, it looks like we really DO need an smax expr.  Check to see if we
1190  // already have one, otherwise create a new one.
1191  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1192  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
1193                                                                 SCEVOps)];
1194  if (Result == 0) Result = new SCEVSMaxExpr(Ops);
1195  return Result;
1196}
1197
1198SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS,
1199                                        const SCEVHandle &RHS) {
1200  std::vector<SCEVHandle> Ops;
1201  Ops.push_back(LHS);
1202  Ops.push_back(RHS);
1203  return getUMaxExpr(Ops);
1204}
1205
1206SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) {
1207  assert(!Ops.empty() && "Cannot get empty umax!");
1208  if (Ops.size() == 1) return Ops[0];
1209
1210  // Sort by complexity, this groups all similar expression types together.
1211  GroupByComplexity(Ops);
1212
1213  // If there are any constants, fold them together.
1214  unsigned Idx = 0;
1215  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1216    ++Idx;
1217    assert(Idx < Ops.size());
1218    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1219      // We found two constants, fold them together!
1220      ConstantInt *Fold = ConstantInt::get(
1221                              APIntOps::umax(LHSC->getValue()->getValue(),
1222                                             RHSC->getValue()->getValue()));
1223      Ops[0] = getConstant(Fold);
1224      Ops.erase(Ops.begin()+1);  // Erase the folded element
1225      if (Ops.size() == 1) return Ops[0];
1226      LHSC = cast<SCEVConstant>(Ops[0]);
1227    }
1228
1229    // If we are left with a constant zero, strip it off.
1230    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
1231      Ops.erase(Ops.begin());
1232      --Idx;
1233    }
1234  }
1235
1236  if (Ops.size() == 1) return Ops[0];
1237
1238  // Find the first UMax
1239  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
1240    ++Idx;
1241
1242  // Check to see if one of the operands is a UMax. If so, expand its operands
1243  // onto our operand list, and recurse to simplify.
1244  if (Idx < Ops.size()) {
1245    bool DeletedUMax = false;
1246    while (SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
1247      Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
1248      Ops.erase(Ops.begin()+Idx);
1249      DeletedUMax = true;
1250    }
1251
1252    if (DeletedUMax)
1253      return getUMaxExpr(Ops);
1254  }
1255
1256  // Okay, check to see if the same value occurs in the operand list twice.  If
1257  // so, delete one.  Since we sorted the list, these values are required to
1258  // be adjacent.
1259  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1260    if (Ops[i] == Ops[i+1]) {      //  X umax Y umax Y  -->  X umax Y
1261      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1262      --i; --e;
1263    }
1264
1265  if (Ops.size() == 1) return Ops[0];
1266
1267  assert(!Ops.empty() && "Reduced umax down to nothing!");
1268
1269  // Okay, it looks like we really DO need a umax expr.  Check to see if we
1270  // already have one, otherwise create a new one.
1271  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1272  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr,
1273                                                                 SCEVOps)];
1274  if (Result == 0) Result = new SCEVUMaxExpr(Ops);
1275  return Result;
1276}
1277
1278SCEVHandle ScalarEvolution::getUnknown(Value *V) {
1279  if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
1280    return getConstant(CI);
1281  SCEVUnknown *&Result = (*SCEVUnknowns)[V];
1282  if (Result == 0) Result = new SCEVUnknown(V);
1283  return Result;
1284}
1285
1286
1287//===----------------------------------------------------------------------===//
1288//             ScalarEvolutionsImpl Definition and Implementation
1289//===----------------------------------------------------------------------===//
1290//
1291/// ScalarEvolutionsImpl - This class implements the main driver for the scalar
1292/// evolution code.
1293///
1294namespace {
1295  struct VISIBILITY_HIDDEN ScalarEvolutionsImpl {
1296    /// SE - A reference to the public ScalarEvolution object.
1297    ScalarEvolution &SE;
1298
1299    /// F - The function we are analyzing.
1300    ///
1301    Function &F;
1302
1303    /// LI - The loop information for the function we are currently analyzing.
1304    ///
1305    LoopInfo &LI;
1306
1307    /// UnknownValue - This SCEV is used to represent unknown trip counts and
1308    /// things.
1309    SCEVHandle UnknownValue;
1310
1311    /// Scalars - This is a cache of the scalars we have analyzed so far.
1312    ///
1313    std::map<Value*, SCEVHandle> Scalars;
1314
1315    /// IterationCounts - Cache the iteration count of the loops for this
1316    /// function as they are computed.
1317    std::map<const Loop*, SCEVHandle> IterationCounts;
1318
1319    /// ConstantEvolutionLoopExitValue - This map contains entries for all of
1320    /// the PHI instructions that we attempt to compute constant evolutions for.
1321    /// This allows us to avoid potentially expensive recomputation of these
1322    /// properties.  An instruction maps to null if we are unable to compute its
1323    /// exit value.
1324    std::map<PHINode*, Constant*> ConstantEvolutionLoopExitValue;
1325
1326  public:
1327    ScalarEvolutionsImpl(ScalarEvolution &se, Function &f, LoopInfo &li)
1328      : SE(se), F(f), LI(li), UnknownValue(new SCEVCouldNotCompute()) {}
1329
1330    /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
1331    /// expression and create a new one.
1332    SCEVHandle getSCEV(Value *V);
1333
1334    /// hasSCEV - Return true if the SCEV for this value has already been
1335    /// computed.
1336    bool hasSCEV(Value *V) const {
1337      return Scalars.count(V);
1338    }
1339
1340    /// setSCEV - Insert the specified SCEV into the map of current SCEVs for
1341    /// the specified value.
1342    void setSCEV(Value *V, const SCEVHandle &H) {
1343      bool isNew = Scalars.insert(std::make_pair(V, H)).second;
1344      assert(isNew && "This entry already existed!");
1345    }
1346
1347
1348    /// getSCEVAtScope - Compute the value of the specified expression within
1349    /// the indicated loop (which may be null to indicate in no loop).  If the
1350    /// expression cannot be evaluated, return UnknownValue itself.
1351    SCEVHandle getSCEVAtScope(SCEV *V, const Loop *L);
1352
1353
1354    /// hasLoopInvariantIterationCount - Return true if the specified loop has
1355    /// an analyzable loop-invariant iteration count.
1356    bool hasLoopInvariantIterationCount(const Loop *L);
1357
1358    /// getIterationCount - If the specified loop has a predictable iteration
1359    /// count, return it.  Note that it is not valid to call this method on a
1360    /// loop without a loop-invariant iteration count.
1361    SCEVHandle getIterationCount(const Loop *L);
1362
1363    /// deleteValueFromRecords - This method should be called by the
1364    /// client before it removes a value from the program, to make sure
1365    /// that no dangling references are left around.
1366    void deleteValueFromRecords(Value *V);
1367
1368  private:
1369    /// createSCEV - We know that there is no SCEV for the specified value.
1370    /// Analyze the expression.
1371    SCEVHandle createSCEV(Value *V);
1372
1373    /// createNodeForPHI - Provide the special handling we need to analyze PHI
1374    /// SCEVs.
1375    SCEVHandle createNodeForPHI(PHINode *PN);
1376
1377    /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value
1378    /// for the specified instruction and replaces any references to the
1379    /// symbolic value SymName with the specified value.  This is used during
1380    /// PHI resolution.
1381    void ReplaceSymbolicValueWithConcrete(Instruction *I,
1382                                          const SCEVHandle &SymName,
1383                                          const SCEVHandle &NewVal);
1384
1385    /// ComputeIterationCount - Compute the number of times the specified loop
1386    /// will iterate.
1387    SCEVHandle ComputeIterationCount(const Loop *L);
1388
1389    /// ComputeLoadConstantCompareIterationCount - Given an exit condition of
1390    /// 'icmp op load X, cst', try to see if we can compute the trip count.
1391    SCEVHandle ComputeLoadConstantCompareIterationCount(LoadInst *LI,
1392                                                        Constant *RHS,
1393                                                        const Loop *L,
1394                                                        ICmpInst::Predicate p);
1395
1396    /// ComputeIterationCountExhaustively - If the trip is known to execute a
1397    /// constant number of times (the condition evolves only from constants),
1398    /// try to evaluate a few iterations of the loop until we get the exit
1399    /// condition gets a value of ExitWhen (true or false).  If we cannot
1400    /// evaluate the trip count of the loop, return UnknownValue.
1401    SCEVHandle ComputeIterationCountExhaustively(const Loop *L, Value *Cond,
1402                                                 bool ExitWhen);
1403
1404    /// HowFarToZero - Return the number of times a backedge comparing the
1405    /// specified value to zero will execute.  If not computable, return
1406    /// UnknownValue.
1407    SCEVHandle HowFarToZero(SCEV *V, const Loop *L);
1408
1409    /// HowFarToNonZero - Return the number of times a backedge checking the
1410    /// specified value for nonzero will execute.  If not computable, return
1411    /// UnknownValue.
1412    SCEVHandle HowFarToNonZero(SCEV *V, const Loop *L);
1413
1414    /// HowManyLessThans - Return the number of times a backedge containing the
1415    /// specified less-than comparison will execute.  If not computable, return
1416    /// UnknownValue. isSigned specifies whether the less-than is signed.
1417    SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L,
1418                                bool isSigned);
1419
1420    /// executesAtLeastOnce - Test whether entry to the loop is protected by
1421    /// a conditional between LHS and RHS.
1422    bool executesAtLeastOnce(const Loop *L, bool isSigned, SCEV *LHS, SCEV *RHS);
1423
1424    /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
1425    /// in the header of its containing loop, we know the loop executes a
1426    /// constant number of times, and the PHI node is just a recurrence
1427    /// involving constants, fold it.
1428    Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its,
1429                                                const Loop *L);
1430  };
1431}
1432
1433//===----------------------------------------------------------------------===//
1434//            Basic SCEV Analysis and PHI Idiom Recognition Code
1435//
1436
1437/// deleteValueFromRecords - This method should be called by the
1438/// client before it removes an instruction from the program, to make sure
1439/// that no dangling references are left around.
1440void ScalarEvolutionsImpl::deleteValueFromRecords(Value *V) {
1441  SmallVector<Value *, 16> Worklist;
1442
1443  if (Scalars.erase(V)) {
1444    if (PHINode *PN = dyn_cast<PHINode>(V))
1445      ConstantEvolutionLoopExitValue.erase(PN);
1446    Worklist.push_back(V);
1447  }
1448
1449  while (!Worklist.empty()) {
1450    Value *VV = Worklist.back();
1451    Worklist.pop_back();
1452
1453    for (Instruction::use_iterator UI = VV->use_begin(), UE = VV->use_end();
1454         UI != UE; ++UI) {
1455      Instruction *Inst = cast<Instruction>(*UI);
1456      if (Scalars.erase(Inst)) {
1457        if (PHINode *PN = dyn_cast<PHINode>(VV))
1458          ConstantEvolutionLoopExitValue.erase(PN);
1459        Worklist.push_back(Inst);
1460      }
1461    }
1462  }
1463}
1464
1465
1466/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
1467/// expression and create a new one.
1468SCEVHandle ScalarEvolutionsImpl::getSCEV(Value *V) {
1469  assert(V->getType() != Type::VoidTy && "Can't analyze void expressions!");
1470
1471  std::map<Value*, SCEVHandle>::iterator I = Scalars.find(V);
1472  if (I != Scalars.end()) return I->second;
1473  SCEVHandle S = createSCEV(V);
1474  Scalars.insert(std::make_pair(V, S));
1475  return S;
1476}
1477
1478/// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
1479/// the specified instruction and replaces any references to the symbolic value
1480/// SymName with the specified value.  This is used during PHI resolution.
1481void ScalarEvolutionsImpl::
1482ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName,
1483                                 const SCEVHandle &NewVal) {
1484  std::map<Value*, SCEVHandle>::iterator SI = Scalars.find(I);
1485  if (SI == Scalars.end()) return;
1486
1487  SCEVHandle NV =
1488    SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, SE);
1489  if (NV == SI->second) return;  // No change.
1490
1491  SI->second = NV;       // Update the scalars map!
1492
1493  // Any instruction values that use this instruction might also need to be
1494  // updated!
1495  for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
1496       UI != E; ++UI)
1497    ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal);
1498}
1499
/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
/// a loop header, making it a potential recurrence, or it doesn't.
///
/// For a PHI with exactly two incoming values in the header of its loop L, the
/// PHI is first recorded in Scalars as a placeholder (SE.getUnknown(PN)) so the
/// value coming around the backedge can be analyzed in terms of it.  Two shapes
/// of backedge value are turned into an add recurrence {Start,+,Step}<L>:
///   - an add containing the placeholder, where the remaining operands sum to
///     a value that is loop-invariant in L or is an addrec of L;
///   - an affine addrec of L whose start minus step equals the PHI's value
///     coming from outside the loop.
/// On success, cached SCEVs are rewritten through
/// ReplaceSymbolicValueWithConcrete and the recurrence is returned.  Otherwise
/// the placeholder is returned and remains PN's cached entry.  Any other PHI
/// becomes a SCEVUnknown.
SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) {
  if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
    if (const Loop *L = LI.getLoopFor(PN->getParent()))
      if (L->getHeader() == PN->getParent()) {
        // If it lives in the loop header, it has two incoming values, one
        // from outside the loop, and one from inside.  contains() yields 1
        // when incoming block 0 is inside the loop, in which case the value
        // from outside the loop is operand 1; otherwise it is operand 0.
        unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
        unsigned BackEdge     = IncomingEdge^1;

        // While we are analyzing this PHI node, handle its value symbolically.
        SCEVHandle SymbolicName = SE.getUnknown(PN);
        assert(Scalars.find(PN) == Scalars.end() &&
               "PHI node already processed?");
        Scalars.insert(std::make_pair(PN, SymbolicName));

        // Using this symbolic name for the PHI, analyze the value coming around
        // the back-edge.
        SCEVHandle BEValue = getSCEV(PN->getIncomingValue(BackEdge));

        // NOTE: If BEValue is loop invariant, we know that the PHI node just
        // has a special value for the first iteration of the loop.

        // If the value coming around the backedge is an add with the symbolic
        // value we just inserted, then we found a simple induction variable!
        if (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
          // If there is a single occurrence of the symbolic value, replace it
          // with a recurrence.  The loop records the index of the first
          // operand equal to SymbolicName (the inner test is always true when
          // reached, because the loop breaks right after setting FoundIndex).
          // FoundIndex stays at getNumOperands() when there is no such operand.
          unsigned FoundIndex = Add->getNumOperands();
          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
            if (Add->getOperand(i) == SymbolicName)
              if (FoundIndex == e) {
                FoundIndex = i;
                break;
              }

          if (FoundIndex != Add->getNumOperands()) {
            // Create an add with everything but the specified operand.  This
            // sum is the per-iteration step of the candidate recurrence.
            std::vector<SCEVHandle> Ops;
            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
              if (i != FoundIndex)
                Ops.push_back(Add->getOperand(i));
            SCEVHandle Accum = SE.getAddExpr(Ops);

            // This is not a valid addrec if the step amount is varying each
            // loop iteration, but is not itself an addrec in this loop.
            if (Accum->isLoopInvariant(L) ||
                (isa<SCEVAddRecExpr>(Accum) &&
                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
              SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
              SCEVHandle PHISCEV  = SE.getAddRecExpr(StartVal, Accum, L);

              // Okay, for the entire analysis of this edge we assumed the PHI
              // to be symbolic.  We now need to go back and update the cached
              // entries, starting at PN and following its users, to use the
              // new analyzed value instead of the "symbolic" value.
              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
              return PHISCEV;
            }
          }
        } else if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(BEValue)) {
          // Otherwise, this could be a loop like this:
          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
          // In this case, j = {1,+,1}  and BEValue is j.
          // Because the other in-value of i (0) fits the evolution of BEValue
          // i really is an addrec evolution.
          if (AddRec->getLoop() == L && AddRec->isAffine()) {
            SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));

            // If StartVal = j.start - j.stride, we can use StartVal as the
            // initial step of the addrec evolution.
            if (StartVal == SE.getMinusSCEV(AddRec->getOperand(0),
                                            AddRec->getOperand(1))) {
              SCEVHandle PHISCEV =
                 SE.getAddRecExpr(StartVal, AddRec->getOperand(1), L);

              // Okay, for the entire analysis of this edge we assumed the PHI
              // to be symbolic.  We now need to go back and update the cached
              // entries, starting at PN and following its users, to use the
              // new analyzed value instead of the "symbolic" value.
              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
              return PHISCEV;
            }
          }
        }

        // No recurrence recognized: the PHI is represented by its placeholder,
        // which is also the entry left in Scalars for PN.
        return SymbolicName;
      }

  // If it's not a loop phi, we can't handle it yet.
  return SE.getUnknown(PN);
}
1596
1597/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
1598/// guaranteed to end in (at every loop iteration).  It is, at the same time,
1599/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
1600/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
1601static uint32_t GetMinTrailingZeros(SCEVHandle S) {
1602  if (SCEVConstant *C = dyn_cast<SCEVConstant>(S))
1603    return C->getValue()->getValue().countTrailingZeros();
1604
1605  if (SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
1606    return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth());
1607
1608  if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
1609    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
1610    return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
1611  }
1612
1613  if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
1614    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
1615    return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
1616  }
1617
1618  if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
1619    // The result is the min of all operands results.
1620    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
1621    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
1622      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
1623    return MinOpRes;
1624  }
1625
1626  if (SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
1627    // The result is the sum of all operands results.
1628    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
1629    uint32_t BitWidth = M->getBitWidth();
1630    for (unsigned i = 1, e = M->getNumOperands();
1631         SumOpRes != BitWidth && i != e; ++i)
1632      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
1633                          BitWidth);
1634    return SumOpRes;
1635  }
1636
1637  if (SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
1638    // The result is the min of all operands results.
1639    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
1640    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
1641      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
1642    return MinOpRes;
1643  }
1644
1645  if (SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
1646    // The result is the min of all operands results.
1647    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
1648    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
1649      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
1650    return MinOpRes;
1651  }
1652
1653  if (SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
1654    // The result is the min of all operands results.
1655    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
1656    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
1657      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
1658    return MinOpRes;
1659  }
1660
1661  // SCEVUDivExpr, SCEVUnknown
1662  return 0;
1663}
1664
1665/// createSCEV - We know that there is no SCEV for the specified value.
1666/// Analyze the expression.
1667///
1668SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
1669  if (!isa<IntegerType>(V->getType()))
1670    return SE.getUnknown(V);
1671
1672  unsigned Opcode = Instruction::UserOp1;
1673  if (Instruction *I = dyn_cast<Instruction>(V))
1674    Opcode = I->getOpcode();
1675  else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
1676    Opcode = CE->getOpcode();
1677  else
1678    return SE.getUnknown(V);
1679
1680  User *U = cast<User>(V);
1681  switch (Opcode) {
1682  case Instruction::Add:
1683    return SE.getAddExpr(getSCEV(U->getOperand(0)),
1684                         getSCEV(U->getOperand(1)));
1685  case Instruction::Mul:
1686    return SE.getMulExpr(getSCEV(U->getOperand(0)),
1687                         getSCEV(U->getOperand(1)));
1688  case Instruction::UDiv:
1689    return SE.getUDivExpr(getSCEV(U->getOperand(0)),
1690                          getSCEV(U->getOperand(1)));
1691  case Instruction::Sub:
1692    return SE.getMinusSCEV(getSCEV(U->getOperand(0)),
1693                           getSCEV(U->getOperand(1)));
1694  case Instruction::Or:
1695    // If the RHS of the Or is a constant, we may have something like:
1696    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
1697    // optimizations will transparently handle this case.
1698    //
1699    // In order for this transformation to be safe, the LHS must be of the
1700    // form X*(2^n) and the Or constant must be less than 2^n.
1701    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
1702      SCEVHandle LHS = getSCEV(U->getOperand(0));
1703      const APInt &CIVal = CI->getValue();
1704      if (GetMinTrailingZeros(LHS) >=
1705          (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
1706        return SE.getAddExpr(LHS, getSCEV(U->getOperand(1)));
1707    }
1708    break;
1709  case Instruction::Xor:
1710    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
1711      // If the RHS of the xor is a signbit, then this is just an add.
1712      // Instcombine turns add of signbit into xor as a strength reduction step.
1713      if (CI->getValue().isSignBit())
1714        return SE.getAddExpr(getSCEV(U->getOperand(0)),
1715                             getSCEV(U->getOperand(1)));
1716
1717      // If the RHS of xor is -1, then this is a not operation.
1718      else if (CI->isAllOnesValue())
1719        return SE.getNotSCEV(getSCEV(U->getOperand(0)));
1720    }
1721    break;
1722
1723  case Instruction::Shl:
1724    // Turn shift left of a constant amount into a multiply.
1725    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
1726      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
1727      Constant *X = ConstantInt::get(
1728        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
1729      return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
1730    }
1731    break;
1732
1733  case Instruction::LShr:
1734    // Turn logical shift right of a constant into a unsigned divide.
1735    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
1736      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
1737      Constant *X = ConstantInt::get(
1738        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
1739      return SE.getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
1740    }
1741    break;
1742
1743  case Instruction::Trunc:
1744    return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
1745
1746  case Instruction::ZExt:
1747    return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
1748
1749  case Instruction::SExt:
1750    return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
1751
1752  case Instruction::BitCast:
1753    // BitCasts are no-op casts so we just eliminate the cast.
1754    if (U->getType()->isInteger() &&
1755        U->getOperand(0)->getType()->isInteger())
1756      return getSCEV(U->getOperand(0));
1757    break;
1758
1759  case Instruction::PHI:
1760    return createNodeForPHI(cast<PHINode>(U));
1761
1762  case Instruction::Select:
1763    // This could be a smax or umax that was lowered earlier.
1764    // Try to recover it.
1765    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
1766      Value *LHS = ICI->getOperand(0);
1767      Value *RHS = ICI->getOperand(1);
1768      switch (ICI->getPredicate()) {
1769      case ICmpInst::ICMP_SLT:
1770      case ICmpInst::ICMP_SLE:
1771        std::swap(LHS, RHS);
1772        // fall through
1773      case ICmpInst::ICMP_SGT:
1774      case ICmpInst::ICMP_SGE:
1775        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
1776          return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
1777        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
1778          // -smax(-x, -y) == smin(x, y).
1779          return SE.getNegativeSCEV(SE.getSMaxExpr(
1780                                        SE.getNegativeSCEV(getSCEV(LHS)),
1781                                        SE.getNegativeSCEV(getSCEV(RHS))));
1782        break;
1783      case ICmpInst::ICMP_ULT:
1784      case ICmpInst::ICMP_ULE:
1785        std::swap(LHS, RHS);
1786        // fall through
1787      case ICmpInst::ICMP_UGT:
1788      case ICmpInst::ICMP_UGE:
1789        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
1790          return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
1791        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
1792          // ~umax(~x, ~y) == umin(x, y)
1793          return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
1794                                              SE.getNotSCEV(getSCEV(RHS))));
1795        break;
1796      default:
1797        break;
1798      }
1799    }
1800
1801  default: // We cannot analyze this expression.
1802    break;
1803  }
1804
1805  return SE.getUnknown(V);
1806}
1807
1808
1809
1810//===----------------------------------------------------------------------===//
1811//                   Iteration Count Computation Code
1812//
1813
1814/// getIterationCount - If the specified loop has a predictable iteration
1815/// count, return it.  Note that it is not valid to call this method on a
1816/// loop without a loop-invariant iteration count.
1817SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) {
1818  std::map<const Loop*, SCEVHandle>::iterator I = IterationCounts.find(L);
1819  if (I == IterationCounts.end()) {
1820    SCEVHandle ItCount = ComputeIterationCount(L);
1821    I = IterationCounts.insert(std::make_pair(L, ItCount)).first;
1822    if (ItCount != UnknownValue) {
1823      assert(ItCount->isLoopInvariant(L) &&
1824             "Computed trip count isn't loop invariant for loop!");
1825      ++NumTripCountsComputed;
1826    } else if (isa<PHINode>(L->getHeader()->begin())) {
1827      // Only count loops that have phi nodes as not being computable.
1828      ++NumTripCountsNotComputed;
1829    }
1830  }
1831  return I->second;
1832}
1833
1834/// ComputeIterationCount - Compute the number of times the specified loop
1835/// will iterate.
1836SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
1837  // If the loop has a non-one exit block count, we can't analyze it.
1838  SmallVector<BasicBlock*, 8> ExitBlocks;
1839  L->getExitBlocks(ExitBlocks);
1840  if (ExitBlocks.size() != 1) return UnknownValue;
1841
1842  // Okay, there is one exit block.  Try to find the condition that causes the
1843  // loop to be exited.
1844  BasicBlock *ExitBlock = ExitBlocks[0];
1845
1846  BasicBlock *ExitingBlock = 0;
1847  for (pred_iterator PI = pred_begin(ExitBlock), E = pred_end(ExitBlock);
1848       PI != E; ++PI)
1849    if (L->contains(*PI)) {
1850      if (ExitingBlock == 0)
1851        ExitingBlock = *PI;
1852      else
1853        return UnknownValue;   // More than one block exiting!
1854    }
1855  assert(ExitingBlock && "No exits from loop, something is broken!");
1856
1857  // Okay, we've computed the exiting block.  See what condition causes us to
1858  // exit.
1859  //
1860  // FIXME: we should be able to handle switch instructions (with a single exit)
1861  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
1862  if (ExitBr == 0) return UnknownValue;
1863  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
1864
1865  // At this point, we know we have a conditional branch that determines whether
1866  // the loop is exited.  However, we don't know if the branch is executed each
1867  // time through the loop.  If not, then the execution count of the branch will
1868  // not be equal to the trip count of the loop.
1869  //
1870  // Currently we check for this by checking to see if the Exit branch goes to
1871  // the loop header.  If so, we know it will always execute the same number of
1872  // times as the loop.  We also handle the case where the exit block *is* the
1873  // loop header.  This is common for un-rotated loops.  More extensive analysis
1874  // could be done to handle more cases here.
1875  if (ExitBr->getSuccessor(0) != L->getHeader() &&
1876      ExitBr->getSuccessor(1) != L->getHeader() &&
1877      ExitBr->getParent() != L->getHeader())
1878    return UnknownValue;
1879
1880  ICmpInst *ExitCond = dyn_cast<ICmpInst>(ExitBr->getCondition());
1881
1882  // If it's not an integer comparison then compute it the hard way.
1883  // Note that ICmpInst deals with pointer comparisons too so we must check
1884  // the type of the operand.
1885  if (ExitCond == 0 || isa<PointerType>(ExitCond->getOperand(0)->getType()))
1886    return ComputeIterationCountExhaustively(L, ExitBr->getCondition(),
1887                                          ExitBr->getSuccessor(0) == ExitBlock);
1888
1889  // If the condition was exit on true, convert the condition to exit on false
1890  ICmpInst::Predicate Cond;
1891  if (ExitBr->getSuccessor(1) == ExitBlock)
1892    Cond = ExitCond->getPredicate();
1893  else
1894    Cond = ExitCond->getInversePredicate();
1895
1896  // Handle common loops like: for (X = "string"; *X; ++X)
1897  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
1898    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
1899      SCEVHandle ItCnt =
1900        ComputeLoadConstantCompareIterationCount(LI, RHS, L, Cond);
1901      if (!isa<SCEVCouldNotCompute>(ItCnt)) return ItCnt;
1902    }
1903
1904  SCEVHandle LHS = getSCEV(ExitCond->getOperand(0));
1905  SCEVHandle RHS = getSCEV(ExitCond->getOperand(1));
1906
1907  // Try to evaluate any dependencies out of the loop.
1908  SCEVHandle Tmp = getSCEVAtScope(LHS, L);
1909  if (!isa<SCEVCouldNotCompute>(Tmp)) LHS = Tmp;
1910  Tmp = getSCEVAtScope(RHS, L);
1911  if (!isa<SCEVCouldNotCompute>(Tmp)) RHS = Tmp;
1912
1913  // At this point, we would like to compute how many iterations of the
1914  // loop the predicate will return true for these inputs.
1915  if (isa<SCEVConstant>(LHS) && !isa<SCEVConstant>(RHS)) {
1916    // If there is a constant, force it into the RHS.
1917    std::swap(LHS, RHS);
1918    Cond = ICmpInst::getSwappedPredicate(Cond);
1919  }
1920
1921  // FIXME: think about handling pointer comparisons!  i.e.:
1922  // while (P != P+100) ++P;
1923
1924  // If we have a comparison of a chrec against a constant, try to use value
1925  // ranges to answer this query.
1926  if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
1927    if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
1928      if (AddRec->getLoop() == L) {
1929        // Form the comparison range using the constant of the correct type so
1930        // that the ConstantRange class knows to do a signed or unsigned
1931        // comparison.
1932        ConstantInt *CompVal = RHSC->getValue();
1933        const Type *RealTy = ExitCond->getOperand(0)->getType();
1934        CompVal = dyn_cast<ConstantInt>(
1935          ConstantExpr::getBitCast(CompVal, RealTy));
1936        if (CompVal) {
1937          // Form the constant range.
1938          ConstantRange CompRange(
1939              ICmpInst::makeConstantRange(Cond, CompVal->getValue()));
1940
1941          SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, SE);
1942          if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
1943        }
1944      }
1945
1946  switch (Cond) {
1947  case ICmpInst::ICMP_NE: {                     // while (X != Y)
1948    // Convert to: while (X-Y != 0)
1949    SCEVHandle TC = HowFarToZero(SE.getMinusSCEV(LHS, RHS), L);
1950    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
1951    break;
1952  }
1953  case ICmpInst::ICMP_EQ: {
1954    // Convert to: while (X-Y == 0)           // while (X == Y)
1955    SCEVHandle TC = HowFarToNonZero(SE.getMinusSCEV(LHS, RHS), L);
1956    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
1957    break;
1958  }
1959  case ICmpInst::ICMP_SLT: {
1960    SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true);
1961    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
1962    break;
1963  }
1964  case ICmpInst::ICMP_SGT: {
1965    SCEVHandle TC = HowManyLessThans(SE.getNegativeSCEV(LHS),
1966                                     SE.getNegativeSCEV(RHS), L, true);
1967    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
1968    break;
1969  }
1970  case ICmpInst::ICMP_ULT: {
1971    SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false);
1972    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
1973    break;
1974  }
1975  case ICmpInst::ICMP_UGT: {
1976    SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS),
1977                                     SE.getNotSCEV(RHS), L, false);
1978    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
1979    break;
1980  }
1981  default:
1982#if 0
1983    cerr << "ComputeIterationCount ";
1984    if (ExitCond->getOperand(0)->getType()->isUnsigned())
1985      cerr << "[unsigned] ";
1986    cerr << *LHS << "   "
1987         << Instruction::getOpcodeName(Instruction::ICmp)
1988         << "   " << *RHS << "\n";
1989#endif
1990    break;
1991  }
1992  return ComputeIterationCountExhaustively(L, ExitCond,
1993                                       ExitBr->getSuccessor(0) == ExitBlock);
1994}
1995
1996static ConstantInt *
1997EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
1998                                ScalarEvolution &SE) {
1999  SCEVHandle InVal = SE.getConstant(C);
2000  SCEVHandle Val = AddRec->evaluateAtIteration(InVal, SE);
2001  assert(isa<SCEVConstant>(Val) &&
2002         "Evaluation of SCEV at constant didn't fold correctly?");
2003  return cast<SCEVConstant>(Val)->getValue();
2004}
2005
2006/// GetAddressedElementFromGlobal - Given a global variable with an initializer
2007/// and a GEP expression (missing the pointer index) indexing into it, return
2008/// the addressed element of the initializer or null if the index expression is
2009/// invalid.
2010static Constant *
2011GetAddressedElementFromGlobal(GlobalVariable *GV,
2012                              const std::vector<ConstantInt*> &Indices) {
2013  Constant *Init = GV->getInitializer();
2014  for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
2015    uint64_t Idx = Indices[i]->getZExtValue();
2016    if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
2017      assert(Idx < CS->getNumOperands() && "Bad struct index!");
2018      Init = cast<Constant>(CS->getOperand(Idx));
2019    } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
2020      if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
2021      Init = cast<Constant>(CA->getOperand(Idx));
2022    } else if (isa<ConstantAggregateZero>(Init)) {
2023      if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
2024        assert(Idx < STy->getNumElements() && "Bad struct index!");
2025        Init = Constant::getNullValue(STy->getElementType(Idx));
2026      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
2027        if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
2028        Init = Constant::getNullValue(ATy->getElementType());
2029      } else {
2030        assert(0 && "Unknown constant aggregate type!");
2031      }
2032      return 0;
2033    } else {
2034      return 0; // Unknown initializer type
2035    }
2036  }
2037  return Init;
2038}
2039
2040/// ComputeLoadConstantCompareIterationCount - Given an exit condition of
2041/// 'icmp op load X, cst', try to see if we can compute the trip count.
2042SCEVHandle ScalarEvolutionsImpl::
2043ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS,
2044                                         const Loop *L,
2045                                         ICmpInst::Predicate predicate) {
2046  if (LI->isVolatile()) return UnknownValue;
2047
2048  // Check to see if the loaded pointer is a getelementptr of a global.
2049  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
2050  if (!GEP) return UnknownValue;
2051
2052  // Make sure that it is really a constant global we are gepping, with an
2053  // initializer, and make sure the first IDX is really 0.
2054  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
2055  if (!GV || !GV->isConstant() || !GV->hasInitializer() ||
2056      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
2057      !cast<Constant>(GEP->getOperand(1))->isNullValue())
2058    return UnknownValue;
2059
2060  // Okay, we allow one non-constant index into the GEP instruction.
2061  Value *VarIdx = 0;
2062  std::vector<ConstantInt*> Indexes;
2063  unsigned VarIdxNum = 0;
2064  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
2065    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
2066      Indexes.push_back(CI);
2067    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
2068      if (VarIdx) return UnknownValue;  // Multiple non-constant idx's.
2069      VarIdx = GEP->getOperand(i);
2070      VarIdxNum = i-2;
2071      Indexes.push_back(0);
2072    }
2073
2074  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
2075  // Check to see if X is a loop variant variable value now.
2076  SCEVHandle Idx = getSCEV(VarIdx);
2077  SCEVHandle Tmp = getSCEVAtScope(Idx, L);
2078  if (!isa<SCEVCouldNotCompute>(Tmp)) Idx = Tmp;
2079
2080  // We can only recognize very limited forms of loop index expressions, in
2081  // particular, only affine AddRec's like {C1,+,C2}.
2082  SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
2083  if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
2084      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
2085      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
2086    return UnknownValue;
2087
2088  unsigned MaxSteps = MaxBruteForceIterations;
2089  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
2090    ConstantInt *ItCst =
2091      ConstantInt::get(IdxExpr->getType(), IterationNum);
2092    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, SE);
2093
2094    // Form the GEP offset.
2095    Indexes[VarIdxNum] = Val;
2096
2097    Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
2098    if (Result == 0) break;  // Cannot compute!
2099
2100    // Evaluate the condition for this iteration.
2101    Result = ConstantExpr::getICmp(predicate, Result, RHS);
2102    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
2103    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
2104#if 0
2105      cerr << "\n***\n*** Computed loop count " << *ItCst
2106           << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
2107           << "***\n";
2108#endif
2109      ++NumArrayLenItCounts;
2110      return SE.getConstant(ItCst);   // Found terminating iteration!
2111    }
2112  }
2113  return UnknownValue;
2114}
2115
2116
2117/// CanConstantFold - Return true if we can constant fold an instruction of the
2118/// specified type, assuming that all operands were constants.
2119static bool CanConstantFold(const Instruction *I) {
2120  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
2121      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
2122    return true;
2123
2124  if (const CallInst *CI = dyn_cast<CallInst>(I))
2125    if (const Function *F = CI->getCalledFunction())
2126      return canConstantFoldCallTo(F);
2127  return false;
2128}
2129
2130/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
2131/// in the loop that V is derived from.  We allow arbitrary operations along the
2132/// way, but the operands of an operation must either be constants or a value
2133/// derived from a constant PHI.  If this expression does not fit with these
2134/// constraints, return null.
2135static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
2136  // If this is not an instruction, or if this is an instruction outside of the
2137  // loop, it can't be derived from a loop PHI.
2138  Instruction *I = dyn_cast<Instruction>(V);
2139  if (I == 0 || !L->contains(I->getParent())) return 0;
2140
2141  if (PHINode *PN = dyn_cast<PHINode>(I)) {
2142    if (L->getHeader() == I->getParent())
2143      return PN;
2144    else
2145      // We don't currently keep track of the control flow needed to evaluate
2146      // PHIs, so we cannot handle PHIs inside of loops.
2147      return 0;
2148  }
2149
2150  // If we won't be able to constant fold this expression even if the operands
2151  // are constants, return early.
2152  if (!CanConstantFold(I)) return 0;
2153
2154  // Otherwise, we can evaluate this instruction if all of its operands are
2155  // constant or derived from a PHI node themselves.
2156  PHINode *PHI = 0;
2157  for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
2158    if (!(isa<Constant>(I->getOperand(Op)) ||
2159          isa<GlobalValue>(I->getOperand(Op)))) {
2160      PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
2161      if (P == 0) return 0;  // Not evolving from PHI
2162      if (PHI == 0)
2163        PHI = P;
2164      else if (PHI != P)
2165        return 0;  // Evolving from multiple different PHIs.
2166    }
2167
2168  // This is a expression evolving from a constant PHI!
2169  return PHI;
2170}
2171
2172/// EvaluateExpression - Given an expression that passes the
2173/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
2174/// in the loop has the value PHIVal.  If we can't fold this expression for some
2175/// reason, return null.
2176static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
2177  if (isa<PHINode>(V)) return PHIVal;
2178  if (Constant *C = dyn_cast<Constant>(V)) return C;
2179  Instruction *I = cast<Instruction>(V);
2180
2181  std::vector<Constant*> Operands;
2182  Operands.resize(I->getNumOperands());
2183
2184  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2185    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
2186    if (Operands[i] == 0) return 0;
2187  }
2188
2189  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2190    return ConstantFoldCompareInstOperands(CI->getPredicate(),
2191                                           &Operands[0], Operands.size());
2192  else
2193    return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2194                                    &Operands[0], Operands.size());
2195}
2196
2197/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
2198/// in the header of its containing loop, we know the loop executes a
2199/// constant number of times, and the PHI node is just a recurrence
2200/// involving constants, fold it.
2201Constant *ScalarEvolutionsImpl::
2202getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L){
2203  std::map<PHINode*, Constant*>::iterator I =
2204    ConstantEvolutionLoopExitValue.find(PN);
2205  if (I != ConstantEvolutionLoopExitValue.end())
2206    return I->second;
2207
2208  if (Its.ugt(APInt(Its.getBitWidth(),MaxBruteForceIterations)))
2209    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
2210
2211  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
2212
2213  // Since the loop is canonicalized, the PHI node must have two entries.  One
2214  // entry must be a constant (coming in from outside of the loop), and the
2215  // second must be derived from the same PHI.
2216  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2217  Constant *StartCST =
2218    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2219  if (StartCST == 0)
2220    return RetVal = 0;  // Must be a constant.
2221
2222  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2223  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2224  if (PN2 != PN)
2225    return RetVal = 0;  // Not derived from same PHI.
2226
2227  // Execute the loop symbolically to determine the exit value.
2228  if (Its.getActiveBits() >= 32)
2229    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
2230
2231  unsigned NumIterations = Its.getZExtValue(); // must be in range
2232  unsigned IterationNum = 0;
2233  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
2234    if (IterationNum == NumIterations)
2235      return RetVal = PHIVal;  // Got exit value!
2236
2237    // Compute the value of the PHI node for the next iteration.
2238    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2239    if (NextPHI == PHIVal)
2240      return RetVal = NextPHI;  // Stopped evolving!
2241    if (NextPHI == 0)
2242      return 0;        // Couldn't evaluate!
2243    PHIVal = NextPHI;
2244  }
2245}
2246
2247/// ComputeIterationCountExhaustively - If the trip is known to execute a
2248/// constant number of times (the condition evolves only from constants),
2249/// try to evaluate a few iterations of the loop until we get the exit
2250/// condition gets a value of ExitWhen (true or false).  If we cannot
2251/// evaluate the trip count of the loop, return UnknownValue.
2252SCEVHandle ScalarEvolutionsImpl::
2253ComputeIterationCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) {
2254  PHINode *PN = getConstantEvolvingPHI(Cond, L);
2255  if (PN == 0) return UnknownValue;
2256
2257  // Since the loop is canonicalized, the PHI node must have two entries.  One
2258  // entry must be a constant (coming in from outside of the loop), and the
2259  // second must be derived from the same PHI.
2260  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2261  Constant *StartCST =
2262    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2263  if (StartCST == 0) return UnknownValue;  // Must be a constant.
2264
2265  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2266  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2267  if (PN2 != PN) return UnknownValue;  // Not derived from same PHI.
2268
2269  // Okay, we find a PHI node that defines the trip count of this loop.  Execute
2270  // the loop symbolically to determine when the condition gets a value of
2271  // "ExitWhen".
2272  unsigned IterationNum = 0;
2273  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
2274  for (Constant *PHIVal = StartCST;
2275       IterationNum != MaxIterations; ++IterationNum) {
2276    ConstantInt *CondVal =
2277      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal));
2278
2279    // Couldn't symbolically evaluate.
2280    if (!CondVal) return UnknownValue;
2281
2282    if (CondVal->getValue() == uint64_t(ExitWhen)) {
2283      ConstantEvolutionLoopExitValue[PN] = PHIVal;
2284      ++NumBruteForceTripCountsComputed;
2285      return SE.getConstant(ConstantInt::get(Type::Int32Ty, IterationNum));
2286    }
2287
2288    // Compute the value of the PHI node for the next iteration.
2289    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2290    if (NextPHI == 0 || NextPHI == PHIVal)
2291      return UnknownValue;  // Couldn't evaluate or not making progress...
2292    PHIVal = NextPHI;
2293  }
2294
2295  // Too many iterations were needed to evaluate.
2296  return UnknownValue;
2297}
2298
2299/// getSCEVAtScope - Compute the value of the specified expression within the
2300/// indicated loop (which may be null to indicate in no loop).  If the
2301/// expression cannot be evaluated, return UnknownValue.
2302SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
2303  // FIXME: this should be turned into a virtual method on SCEV!
2304
2305  if (isa<SCEVConstant>(V)) return V;
2306
2307  // If this instruction is evolved from a constant-evolving PHI, compute the
2308  // exit value from the loop without using SCEVs.
2309  if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
2310    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
2311      const Loop *LI = this->LI[I->getParent()];
2312      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
2313        if (PHINode *PN = dyn_cast<PHINode>(I))
2314          if (PN->getParent() == LI->getHeader()) {
2315            // Okay, there is no closed form solution for the PHI node.  Check
2316            // to see if the loop that contains it has a known iteration count.
2317            // If so, we may be able to force computation of the exit value.
2318            SCEVHandle IterationCount = getIterationCount(LI);
2319            if (SCEVConstant *ICC = dyn_cast<SCEVConstant>(IterationCount)) {
2320              // Okay, we know how many times the containing loop executes.  If
2321              // this is a constant evolving PHI node, get the final value at
2322              // the specified iteration number.
2323              Constant *RV = getConstantEvolutionLoopExitValue(PN,
2324                                                    ICC->getValue()->getValue(),
2325                                                               LI);
2326              if (RV) return SE.getUnknown(RV);
2327            }
2328          }
2329
2330      // Okay, this is an expression that we cannot symbolically evaluate
2331      // into a SCEV.  Check to see if it's possible to symbolically evaluate
2332      // the arguments into constants, and if so, try to constant propagate the
2333      // result.  This is particularly useful for computing loop exit values.
2334      if (CanConstantFold(I)) {
2335        std::vector<Constant*> Operands;
2336        Operands.reserve(I->getNumOperands());
2337        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2338          Value *Op = I->getOperand(i);
2339          if (Constant *C = dyn_cast<Constant>(Op)) {
2340            Operands.push_back(C);
2341          } else {
2342            // If any of the operands is non-constant and if they are
2343            // non-integer, don't even try to analyze them with scev techniques.
2344            if (!isa<IntegerType>(Op->getType()))
2345              return V;
2346
2347            SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L);
2348            if (SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV))
2349              Operands.push_back(ConstantExpr::getIntegerCast(SC->getValue(),
2350                                                              Op->getType(),
2351                                                              false));
2352            else if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) {
2353              if (Constant *C = dyn_cast<Constant>(SU->getValue()))
2354                Operands.push_back(ConstantExpr::getIntegerCast(C,
2355                                                                Op->getType(),
2356                                                                false));
2357              else
2358                return V;
2359            } else {
2360              return V;
2361            }
2362          }
2363        }
2364
2365        Constant *C;
2366        if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2367          C = ConstantFoldCompareInstOperands(CI->getPredicate(),
2368                                              &Operands[0], Operands.size());
2369        else
2370          C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2371                                       &Operands[0], Operands.size());
2372        return SE.getUnknown(C);
2373      }
2374    }
2375
2376    // This is some other type of SCEVUnknown, just return it.
2377    return V;
2378  }
2379
2380  if (SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
2381    // Avoid performing the look-up in the common case where the specified
2382    // expression has no loop-variant portions.
2383    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
2384      SCEVHandle OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
2385      if (OpAtScope != Comm->getOperand(i)) {
2386        if (OpAtScope == UnknownValue) return UnknownValue;
2387        // Okay, at least one of these operands is loop variant but might be
2388        // foldable.  Build a new instance of the folded commutative expression.
2389        std::vector<SCEVHandle> NewOps(Comm->op_begin(), Comm->op_begin()+i);
2390        NewOps.push_back(OpAtScope);
2391
2392        for (++i; i != e; ++i) {
2393          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
2394          if (OpAtScope == UnknownValue) return UnknownValue;
2395          NewOps.push_back(OpAtScope);
2396        }
2397        if (isa<SCEVAddExpr>(Comm))
2398          return SE.getAddExpr(NewOps);
2399        if (isa<SCEVMulExpr>(Comm))
2400          return SE.getMulExpr(NewOps);
2401        if (isa<SCEVSMaxExpr>(Comm))
2402          return SE.getSMaxExpr(NewOps);
2403        if (isa<SCEVUMaxExpr>(Comm))
2404          return SE.getUMaxExpr(NewOps);
2405        assert(0 && "Unknown commutative SCEV type!");
2406      }
2407    }
2408    // If we got here, all operands are loop invariant.
2409    return Comm;
2410  }
2411
2412  if (SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
2413    SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L);
2414    if (LHS == UnknownValue) return LHS;
2415    SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L);
2416    if (RHS == UnknownValue) return RHS;
2417    if (LHS == Div->getLHS() && RHS == Div->getRHS())
2418      return Div;   // must be loop invariant
2419    return SE.getUDivExpr(LHS, RHS);
2420  }
2421
2422  // If this is a loop recurrence for a loop that does not contain L, then we
2423  // are dealing with the final value computed by the loop.
2424  if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
2425    if (!L || !AddRec->getLoop()->contains(L->getHeader())) {
2426      // To evaluate this recurrence, we need to know how many times the AddRec
2427      // loop iterates.  Compute this now.
2428      SCEVHandle IterationCount = getIterationCount(AddRec->getLoop());
2429      if (IterationCount == UnknownValue) return UnknownValue;
2430      IterationCount = SE.getTruncateOrZeroExtend(IterationCount,
2431                                                  AddRec->getType());
2432
2433      // If the value is affine, simplify the expression evaluation to just
2434      // Start + Step*IterationCount.
2435      if (AddRec->isAffine())
2436        return SE.getAddExpr(AddRec->getStart(),
2437                             SE.getMulExpr(IterationCount,
2438                                           AddRec->getOperand(1)));
2439
2440      // Otherwise, evaluate it the hard way.
2441      return AddRec->evaluateAtIteration(IterationCount, SE);
2442    }
2443    return UnknownValue;
2444  }
2445
2446  //assert(0 && "Unknown SCEV type!");
2447  return UnknownValue;
2448}
2449
2450/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
2451/// following equation:
2452///
2453///     A * X = B (mod N)
2454///
2455/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
2456/// A and B isn't important.
2457///
2458/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
2459static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
2460                                               ScalarEvolution &SE) {
2461  uint32_t BW = A.getBitWidth();
2462  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
2463  assert(A != 0 && "A must be non-zero.");
2464
2465  // 1. D = gcd(A, N)
2466  //
2467  // The gcd of A and N may have only one prime factor: 2. The number of
2468  // trailing zeros in A is its multiplicity
2469  uint32_t Mult2 = A.countTrailingZeros();
2470  // D = 2^Mult2
2471
2472  // 2. Check if B is divisible by D.
2473  //
2474  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
2475  // is not less than multiplicity of this prime factor for D.
2476  if (B.countTrailingZeros() < Mult2)
2477    return new SCEVCouldNotCompute();
2478
2479  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
2480  // modulo (N / D).
2481  //
2482  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
2483  // bit width during computations.
2484  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
2485  APInt Mod(BW + 1, 0);
2486  Mod.set(BW - Mult2);  // Mod = N / D
2487  APInt I = AD.multiplicativeInverse(Mod);
2488
2489  // 4. Compute the minimum unsigned root of the equation:
2490  // I * (B / D) mod (N / D)
2491  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
2492
2493  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
2494  // bits.
2495  return SE.getConstant(Result.trunc(BW));
2496}
2497
2498/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
2499/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
2500/// might be the same) or two SCEVCouldNotCompute objects.
2501///
2502static std::pair<SCEVHandle,SCEVHandle>
2503SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
2504  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
2505  SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
2506  SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
2507  SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
2508
2509  // We currently can only solve this if the coefficients are constants.
2510  if (!LC || !MC || !NC) {
2511    SCEV *CNC = new SCEVCouldNotCompute();
2512    return std::make_pair(CNC, CNC);
2513  }
2514
2515  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
2516  const APInt &L = LC->getValue()->getValue();
2517  const APInt &M = MC->getValue()->getValue();
2518  const APInt &N = NC->getValue()->getValue();
2519  APInt Two(BitWidth, 2);
2520  APInt Four(BitWidth, 4);
2521
2522  {
2523    using namespace APIntOps;
2524    const APInt& C = L;
2525    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
2526    // The B coefficient is M-N/2
2527    APInt B(M);
2528    B -= sdiv(N,Two);
2529
2530    // The A coefficient is N/2
2531    APInt A(N.sdiv(Two));
2532
2533    // Compute the B^2-4ac term.
2534    APInt SqrtTerm(B);
2535    SqrtTerm *= B;
2536    SqrtTerm -= Four * (A * C);
2537
2538    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
2539    // integer value or else APInt::sqrt() will assert.
2540    APInt SqrtVal(SqrtTerm.sqrt());
2541
2542    // Compute the two solutions for the quadratic formula.
2543    // The divisions must be performed as signed divisions.
2544    APInt NegB(-B);
2545    APInt TwoA( A << 1 );
2546    ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
2547    ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
2548
2549    return std::make_pair(SE.getConstant(Solution1),
2550                          SE.getConstant(Solution2));
2551    } // end APIntOps namespace
2552}
2553
2554/// HowFarToZero - Return the number of times a backedge comparing the specified
2555/// value to zero will execute.  If not computable, return UnknownValue
2556SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) {
2557  // If the value is a constant
2558  if (SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
2559    // If the value is already zero, the branch will execute zero times.
2560    if (C->getValue()->isZero()) return C;
2561    return UnknownValue;  // Otherwise it will loop infinitely.
2562  }
2563
2564  SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
2565  if (!AddRec || AddRec->getLoop() != L)
2566    return UnknownValue;
2567
2568  if (AddRec->isAffine()) {
2569    // If this is an affine expression, the execution count of this branch is
2570    // the minimum unsigned root of the following equation:
2571    //
2572    //     Start + Step*N = 0 (mod 2^BW)
2573    //
2574    // equivalent to:
2575    //
2576    //             Step*N = -Start (mod 2^BW)
2577    //
2578    // where BW is the common bit width of Start and Step.
2579
2580    // Get the initial value for the loop.
2581    SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
2582    if (isa<SCEVCouldNotCompute>(Start)) return UnknownValue;
2583
2584    SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
2585
2586    if (SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
2587      // For now we handle only constant steps.
2588
2589      // First, handle unitary steps.
2590      if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
2591        return SE.getNegativeSCEV(Start);       //   N = -Start (as unsigned)
2592      if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
2593        return Start;                           //    N = Start (as unsigned)
2594
2595      // Then, try to solve the above equation provided that Start is constant.
2596      if (SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
2597        return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
2598                                            -StartC->getValue()->getValue(),SE);
2599    }
2600  } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
2601    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
2602    // the quadratic equation to solve it.
2603    std::pair<SCEVHandle,SCEVHandle> Roots = SolveQuadraticEquation(AddRec, SE);
2604    SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
2605    SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
2606    if (R1) {
2607#if 0
2608      cerr << "HFTZ: " << *V << " - sol#1: " << *R1
2609           << "  sol#2: " << *R2 << "\n";
2610#endif
2611      // Pick the smallest positive root value.
2612      if (ConstantInt *CB =
2613          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
2614                                   R1->getValue(), R2->getValue()))) {
2615        if (CB->getZExtValue() == false)
2616          std::swap(R1, R2);   // R1 is the minimum root now.
2617
2618        // We can only use this value if the chrec ends up with an exact zero
2619        // value at this index.  When solving for "X*X != 5", for example, we
2620        // should not accept a root of 2.
2621        SCEVHandle Val = AddRec->evaluateAtIteration(R1, SE);
2622        if (Val->isZero())
2623          return R1;  // We found a quadratic root!
2624      }
2625    }
2626  }
2627
2628  return UnknownValue;
2629}
2630
2631/// HowFarToNonZero - Return the number of times a backedge checking the
2632/// specified value for nonzero will execute.  If not computable, return
2633/// UnknownValue
2634SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) {
2635  // Loops that look like: while (X == 0) are very strange indeed.  We don't
2636  // handle them yet except for the trivial case.  This could be expanded in the
2637  // future as needed.
2638
2639  // If the value is a constant, check to see if it is known to be non-zero
2640  // already.  If so, the backedge will execute zero times.
2641  if (SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
2642    if (!C->getValue()->isNullValue())
2643      return SE.getIntegerSCEV(0, C->getType());
2644    return UnknownValue;  // Otherwise it will loop infinitely.
2645  }
2646
2647  // We could implement others, but I really doubt anyone writes loops like
2648  // this, and if they did, they would already be constant folded.
2649  return UnknownValue;
2650}
2651
2652/// executesAtLeastOnce - Test whether entry to the loop is protected by
2653/// a conditional between LHS and RHS.
2654bool ScalarEvolutionsImpl::executesAtLeastOnce(const Loop *L, bool isSigned,
2655                                               SCEV *LHS, SCEV *RHS) {
2656  BasicBlock *Preheader = L->getLoopPreheader();
2657  BasicBlock *PreheaderDest = L->getHeader();
2658  if (Preheader == 0) return false;
2659
2660  BranchInst *LoopEntryPredicate =
2661    dyn_cast<BranchInst>(Preheader->getTerminator());
2662  if (!LoopEntryPredicate) return false;
2663
2664  // This might be a critical edge broken out.  If the loop preheader ends in
2665  // an unconditional branch to the loop, check to see if the preheader has a
2666  // single predecessor, and if so, look for its terminator.
2667  while (LoopEntryPredicate->isUnconditional()) {
2668    PreheaderDest = Preheader;
2669    Preheader = Preheader->getSinglePredecessor();
2670    if (!Preheader) return false;  // Multiple preds.
2671
2672    LoopEntryPredicate =
2673      dyn_cast<BranchInst>(Preheader->getTerminator());
2674    if (!LoopEntryPredicate) return false;
2675  }
2676
2677  ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
2678  if (!ICI) return false;
2679
2680  // Now that we found a conditional branch that dominates the loop, check to
2681  // see if it is the comparison we are looking for.
2682  Value *PreCondLHS = ICI->getOperand(0);
2683  Value *PreCondRHS = ICI->getOperand(1);
2684  ICmpInst::Predicate Cond;
2685  if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest)
2686    Cond = ICI->getPredicate();
2687  else
2688    Cond = ICI->getInversePredicate();
2689
2690  switch (Cond) {
2691  case ICmpInst::ICMP_UGT:
2692    if (isSigned) return false;
2693    std::swap(PreCondLHS, PreCondRHS);
2694    Cond = ICmpInst::ICMP_ULT;
2695    break;
2696  case ICmpInst::ICMP_SGT:
2697    if (!isSigned) return false;
2698    std::swap(PreCondLHS, PreCondRHS);
2699    Cond = ICmpInst::ICMP_SLT;
2700    break;
2701  case ICmpInst::ICMP_ULT:
2702    if (isSigned) return false;
2703    break;
2704  case ICmpInst::ICMP_SLT:
2705    if (!isSigned) return false;
2706    break;
2707  default:
2708    return false;
2709  }
2710
2711  if (!PreCondLHS->getType()->isInteger()) return false;
2712
2713  return LHS == getSCEV(PreCondLHS) && RHS == getSCEV(PreCondRHS);
2714}
2715
2716/// HowManyLessThans - Return the number of times a backedge containing the
2717/// specified less-than comparison will execute.  If not computable, return
2718/// UnknownValue.
2719SCEVHandle ScalarEvolutionsImpl::
2720HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
2721  // Only handle:  "ADDREC < LoopInvariant".
2722  if (!RHS->isLoopInvariant(L)) return UnknownValue;
2723
2724  SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
2725  if (!AddRec || AddRec->getLoop() != L)
2726    return UnknownValue;
2727
2728  if (AddRec->isAffine()) {
2729    // FORNOW: We only support unit strides.
2730    SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType());
2731    if (AddRec->getOperand(1) != One)
2732      return UnknownValue;
2733
2734    // We know the LHS is of the form {n,+,1} and the RHS is some loop-invariant
2735    // m.  So, we count the number of iterations in which {n,+,1} < m is true.
2736    // Note that we cannot simply return max(m-n,0) because it's not safe to
2737    // treat m-n as signed nor unsigned due to overflow possibility.
2738
2739    // First, we get the value of the LHS in the first iteration: n
2740    SCEVHandle Start = AddRec->getOperand(0);
2741
2742    if (executesAtLeastOnce(L, isSigned,
2743                            SE.getMinusSCEV(AddRec->getOperand(0), One), RHS)) {
2744      // Since we know that the condition is true in order to enter the loop,
2745      // we know that it will run exactly m-n times.
2746      return SE.getMinusSCEV(RHS, Start);
2747    } else {
2748      // Then, we get the value of the LHS in the first iteration in which the
2749      // above condition doesn't hold.  This equals to max(m,n).
2750      SCEVHandle End = isSigned ? SE.getSMaxExpr(RHS, Start)
2751                                : SE.getUMaxExpr(RHS, Start);
2752
2753      // Finally, we subtract these two values to get the number of times the
2754      // backedge is executed: max(m,n)-n.
2755      return SE.getMinusSCEV(End, Start);
2756    }
2757  }
2758
2759  return UnknownValue;
2760}
2761
2762/// getNumIterationsInRange - Return the number of iterations of this loop that
2763/// produce values in the specified constant range.  Another way of looking at
2764/// this is that it returns the first iteration number where the value is not in
2765/// the condition, thus computing the exit count. If the iteration count can't
2766/// be computed, an instance of SCEVCouldNotCompute is returned.
2767SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
2768                                                   ScalarEvolution &SE) const {
2769  if (Range.isFullSet())  // Infinite loop.
2770    return new SCEVCouldNotCompute();
2771
2772  // If the start is a non-zero constant, shift the range to simplify things.
2773  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
2774    if (!SC->getValue()->isZero()) {
2775      std::vector<SCEVHandle> Operands(op_begin(), op_end());
2776      Operands[0] = SE.getIntegerSCEV(0, SC->getType());
2777      SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop());
2778      if (SCEVAddRecExpr *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
2779        return ShiftedAddRec->getNumIterationsInRange(
2780                           Range.subtract(SC->getValue()->getValue()), SE);
2781      // This is strange and shouldn't happen.
2782      return new SCEVCouldNotCompute();
2783    }
2784
2785  // The only time we can solve this is when we have all constant indices.
2786  // Otherwise, we cannot determine the overflow conditions.
2787  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
2788    if (!isa<SCEVConstant>(getOperand(i)))
2789      return new SCEVCouldNotCompute();
2790
2791
2792  // Okay at this point we know that all elements of the chrec are constants and
2793  // that the start element is zero.
2794
2795  // First check to see if the range contains zero.  If not, the first
2796  // iteration exits.
2797  if (!Range.contains(APInt(getBitWidth(),0)))
2798    return SE.getConstant(ConstantInt::get(getType(),0));
2799
2800  if (isAffine()) {
2801    // If this is an affine expression then we have this situation:
2802    //   Solve {0,+,A} in Range  ===  Ax in Range
2803
2804    // We know that zero is in the range.  If A is positive then we know that
2805    // the upper value of the range must be the first possible exit value.
2806    // If A is negative then the lower of the range is the last possible loop
2807    // value.  Also note that we already checked for a full range.
2808    APInt One(getBitWidth(),1);
2809    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
2810    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
2811
2812    // The exit value should be (End+A)/A.
2813    APInt ExitVal = (End + A).udiv(A);
2814    ConstantInt *ExitValue = ConstantInt::get(ExitVal);
2815
2816    // Evaluate at the exit value.  If we really did fall out of the valid
2817    // range, then we computed our trip count, otherwise wrap around or other
2818    // things must have happened.
2819    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
2820    if (Range.contains(Val->getValue()))
2821      return new SCEVCouldNotCompute();  // Something strange happened
2822
2823    // Ensure that the previous value is in the range.  This is a sanity check.
2824    assert(Range.contains(
2825           EvaluateConstantChrecAtConstant(this,
2826           ConstantInt::get(ExitVal - One), SE)->getValue()) &&
2827           "Linear scev computation is off in a bad way!");
2828    return SE.getConstant(ExitValue);
2829  } else if (isQuadratic()) {
2830    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
2831    // quadratic equation to solve it.  To do this, we must frame our problem in
2832    // terms of figuring out when zero is crossed, instead of when
2833    // Range.getUpper() is crossed.
2834    std::vector<SCEVHandle> NewOps(op_begin(), op_end());
2835    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
2836    SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop());
2837
2838    // Next, solve the constructed addrec
2839    std::pair<SCEVHandle,SCEVHandle> Roots =
2840      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
2841    SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
2842    SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
2843    if (R1) {
2844      // Pick the smallest positive root value.
2845      if (ConstantInt *CB =
2846          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
2847                                   R1->getValue(), R2->getValue()))) {
2848        if (CB->getZExtValue() == false)
2849          std::swap(R1, R2);   // R1 is the minimum root now.
2850
2851        // Make sure the root is not off by one.  The returned iteration should
2852        // not be in the range, but the previous one should be.  When solving
2853        // for "X*X < 5", for example, we should not return a root of 2.
2854        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
2855                                                             R1->getValue(),
2856                                                             SE);
2857        if (Range.contains(R1Val->getValue())) {
2858          // The next iteration must be out of the range...
2859          ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
2860
2861          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
2862          if (!Range.contains(R1Val->getValue()))
2863            return SE.getConstant(NextVal);
2864          return new SCEVCouldNotCompute();  // Something strange happened
2865        }
2866
2867        // If R1 was not in the range, then it is a good return value.  Make
2868        // sure that R1-1 WAS in the range though, just in case.
2869        ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
2870        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
2871        if (Range.contains(R1Val->getValue()))
2872          return R1;
2873        return new SCEVCouldNotCompute();  // Something strange happened
2874      }
2875    }
2876  }
2877
2878  // Fallback, if this is a general polynomial, figure out the progression
2879  // through brute force: evaluate until we find an iteration that fails the
2880  // test.  This is likely to be slow, but getting an accurate trip count is
2881  // incredibly important, we will be able to simplify the exit test a lot, and
2882  // we are almost guaranteed to get a trip count in this case.
2883  ConstantInt *TestVal = ConstantInt::get(getType(), 0);
2884  ConstantInt *EndVal  = TestVal;  // Stop when we wrap around.
2885  do {
2886    ++NumBruteForceEvaluations;
2887    SCEVHandle Val = evaluateAtIteration(SE.getConstant(TestVal), SE);
2888    if (!isa<SCEVConstant>(Val))  // This shouldn't happen.
2889      return new SCEVCouldNotCompute();
2890
2891    // Check to see if we found the value!
2892    if (!Range.contains(cast<SCEVConstant>(Val)->getValue()->getValue()))
2893      return SE.getConstant(TestVal);
2894
2895    // Increment to test the next index.
2896    TestVal = ConstantInt::get(TestVal->getValue()+1);
2897  } while (TestVal != EndVal);
2898
2899  return new SCEVCouldNotCompute();
2900}
2901
2902
2903
2904//===----------------------------------------------------------------------===//
2905//                   ScalarEvolution Class Implementation
2906//===----------------------------------------------------------------------===//
2907
2908bool ScalarEvolution::runOnFunction(Function &F) {
2909  Impl = new ScalarEvolutionsImpl(*this, F, getAnalysis<LoopInfo>());
2910  return false;
2911}
2912
2913void ScalarEvolution::releaseMemory() {
2914  delete (ScalarEvolutionsImpl*)Impl;
2915  Impl = 0;
2916}
2917
2918void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
2919  AU.setPreservesAll();
2920  AU.addRequiredTransitive<LoopInfo>();
2921}
2922
2923SCEVHandle ScalarEvolution::getSCEV(Value *V) const {
2924  return ((ScalarEvolutionsImpl*)Impl)->getSCEV(V);
2925}
2926
2927/// hasSCEV - Return true if the SCEV for this value has already been
2928/// computed.
2929bool ScalarEvolution::hasSCEV(Value *V) const {
2930  return ((ScalarEvolutionsImpl*)Impl)->hasSCEV(V);
2931}
2932
2933
2934/// setSCEV - Insert the specified SCEV into the map of current SCEVs for
2935/// the specified value.
2936void ScalarEvolution::setSCEV(Value *V, const SCEVHandle &H) {
2937  ((ScalarEvolutionsImpl*)Impl)->setSCEV(V, H);
2938}
2939
2940
2941SCEVHandle ScalarEvolution::getIterationCount(const Loop *L) const {
2942  return ((ScalarEvolutionsImpl*)Impl)->getIterationCount(L);
2943}
2944
2945bool ScalarEvolution::hasLoopInvariantIterationCount(const Loop *L) const {
2946  return !isa<SCEVCouldNotCompute>(getIterationCount(L));
2947}
2948
2949SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const {
2950  return ((ScalarEvolutionsImpl*)Impl)->getSCEVAtScope(getSCEV(V), L);
2951}
2952
2953void ScalarEvolution::deleteValueFromRecords(Value *V) const {
2954  return ((ScalarEvolutionsImpl*)Impl)->deleteValueFromRecords(V);
2955}
2956
2957static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE,
2958                          const Loop *L) {
2959  // Print all inner loops first
2960  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
2961    PrintLoopInfo(OS, SE, *I);
2962
2963  OS << "Loop " << L->getHeader()->getName() << ": ";
2964
2965  SmallVector<BasicBlock*, 8> ExitBlocks;
2966  L->getExitBlocks(ExitBlocks);
2967  if (ExitBlocks.size() != 1)
2968    OS << "<multiple exits> ";
2969
2970  if (SE->hasLoopInvariantIterationCount(L)) {
2971    OS << *SE->getIterationCount(L) << " iterations! ";
2972  } else {
2973    OS << "Unpredictable iteration count. ";
2974  }
2975
2976  OS << "\n";
2977}
2978
2979void ScalarEvolution::print(std::ostream &OS, const Module* ) const {
2980  Function &F = ((ScalarEvolutionsImpl*)Impl)->F;
2981  LoopInfo &LI = ((ScalarEvolutionsImpl*)Impl)->LI;
2982
2983  OS << "Classifying expressions for: " << F.getName() << "\n";
2984  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
2985    if (I->getType()->isInteger()) {
2986      OS << *I;
2987      OS << "  --> ";
2988      SCEVHandle SV = getSCEV(&*I);
2989      SV->print(OS);
2990      OS << "\t\t";
2991
2992      if (const Loop *L = LI.getLoopFor((*I).getParent())) {
2993        OS << "Exits: ";
2994        SCEVHandle ExitValue = getSCEVAtScope(&*I, L->getParentLoop());
2995        if (isa<SCEVCouldNotCompute>(ExitValue)) {
2996          OS << "<<Unknown>>";
2997        } else {
2998          OS << *ExitValue;
2999        }
3000      }
3001
3002
3003      OS << "\n";
3004    }
3005
3006  OS << "Determining loop execution counts for: " << F.getName() << "\n";
3007  for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I)
3008    PrintLoopInfo(OS, this, *I);
3009}
3010