ScalarEvolution.cpp revision 73c6b7127aff4499e4d6a2edb219685aee178ee1
//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle.  These objects are uniqued and referred to by const SCEV*
// pointers.  We only create one SCEV of a particular shape, so pointer
// comparisons for equality are legal.
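// For example, given a ScalarEvolution &SE and two SCEVs X and Y already in
// hand (hypothetical names, for illustration), repeated queries return the
// identical object:
//
//   const SCEV *A = SE.getAddExpr(X, Y);
//   const SCEV *B = SE.getAddExpr(X, Y);
//   assert(A == B && "equal expressions are uniqued");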
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
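// For example, the canonical induction variable of a loop,
//
//   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//   %iv.next = add i32 %iv, 1
//
// is represented by the polynomial recurrence {0,+,1}<%loop>.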
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "scalar-evolution"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/GlobalVariable.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Assembly/Writer.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ConstantRange.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GetElementPtrTypeIterator.h"
#include "llvm/Support/InstIterator.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <algorithm>
using namespace llvm;

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant "
                                 "derived loop"),
                        cl::init(100));
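// (The limit can be raised on the command line, e.g. with an invocation like
//  "opt -analyze -scalar-evolution -scalar-evolution-max-iterations=200".)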

static RegisterPass<ScalarEvolution>
R("scalar-evolution", "Scalar Evolution Analysis", false, true);
char ScalarEvolution::ID = 0;

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

SCEV::~SCEV() {}

void SCEV::dump() const {
  print(errs());
  errs() << '\n';
}

void SCEV::print(std::ostream &o) const {
  raw_os_ostream OS(o);
  print(OS);
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isAllOnesValue();
  return false;
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeID(), scCouldNotCompute) {}

bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
  LLVM_UNREACHABLE("Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

const Type *SCEVCouldNotCompute::getType() const {
  LLVM_UNREACHABLE("Attempt to use a SCEVCouldNotCompute object!");
  return 0;
}

bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
  LLVM_UNREACHABLE("Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

const SCEV *
SCEVCouldNotCompute::replaceSymbolicValuesWithConcrete(
                                                    const SCEV *Sym,
                                                    const SCEV *Conc,
                                                    ScalarEvolution &SE) const {
  return this;
}

void SCEVCouldNotCompute::print(raw_ostream &OS) const {
  OS << "***COULDNOTCOMPUTE***";
}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVConstant>();
  new (S) SCEVConstant(ID, V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
  return getConstant(ConstantInt::get(Val));
}

const SCEV *
ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) {
  return getConstant(ConstantInt::get(cast<IntegerType>(Ty), V, isSigned));
}

const Type *SCEVConstant::getType() const { return V->getType(); }

void SCEVConstant::print(raw_ostream &OS) const {
  WriteAsOperand(OS, V, false);
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeID &ID,
                           unsigned SCEVTy, const SCEV *op, const Type *ty)
  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}

bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return Op->dominates(BB, DT);
}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeID &ID,
                                   const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scTruncate, op, ty) {
  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot truncate non-integer value!");
}

void SCEVTruncateExpr::print(raw_ostream &OS) const {
  OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeID &ID,
                                       const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot zero extend non-integer value!");
}

void SCEVZeroExtendExpr::print(raw_ostream &OS) const {
  OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeID &ID,
                                       const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scSignExtend, op, ty) {
  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot sign extend non-integer value!");
}

void SCEVSignExtendExpr::print(raw_ostream &OS) const {
  OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

void SCEVCommutativeExpr::print(raw_ostream &OS) const {
  assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
  const char *OpStr = getOperationStr();
  OS << "(" << *Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
    OS << OpStr << *Operands[i];
  OS << ")";
}

const SCEV *
SCEVCommutativeExpr::replaceSymbolicValuesWithConcrete(
                                                    const SCEV *Sym,
                                                    const SCEV *Conc,
                                                    ScalarEvolution &SE) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    const SCEV *H =
      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
    if (H != getOperand(i)) {
      SmallVector<const SCEV *, 8> NewOps;
      NewOps.reserve(getNumOperands());
      for (unsigned j = 0; j != i; ++j)
        NewOps.push_back(getOperand(j));
      NewOps.push_back(H);
      for (++i; i != e; ++i)
        NewOps.push_back(getOperand(i)->
                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));

      if (isa<SCEVAddExpr>(this))
        return SE.getAddExpr(NewOps);
      else if (isa<SCEVMulExpr>(this))
        return SE.getMulExpr(NewOps);
      else if (isa<SCEVSMaxExpr>(this))
        return SE.getSMaxExpr(NewOps);
      else if (isa<SCEVUMaxExpr>(this))
        return SE.getUMaxExpr(NewOps);
      else
        LLVM_UNREACHABLE("Unknown commutative expr!");
    }
  }
  return this;
}

bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    if (!getOperand(i)->dominates(BB, DT))
      return false;
  }
  return true;
}

bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
}

void SCEVUDivExpr::print(raw_ostream &OS) const {
  OS << "(" << *LHS << " /u " << *RHS << ")";
}

const Type *SCEVUDivExpr::getType() const {
  // In most cases the types of LHS and RHS will be the same, but in some
  // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
  // depend on the type for correctness, but handling types carefully can
  // avoid extra casts in the SCEVExpander. The LHS is more likely to be
  // a pointer type than the RHS, so use the RHS' type here.
  return RHS->getType();
}

const SCEV *
SCEVAddRecExpr::replaceSymbolicValuesWithConcrete(const SCEV *Sym,
                                                  const SCEV *Conc,
                                                  ScalarEvolution &SE) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    const SCEV *H =
      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
    if (H != getOperand(i)) {
      SmallVector<const SCEV *, 8> NewOps;
      NewOps.reserve(getNumOperands());
      for (unsigned j = 0; j != i; ++j)
        NewOps.push_back(getOperand(j));
      NewOps.push_back(H);
      for (++i; i != e; ++i)
        NewOps.push_back(getOperand(i)->
                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));

      return SE.getAddRecExpr(NewOps, L);
    }
  }
  return this;
}


bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
  // Add recurrences are never invariant in the function-body (null loop).
  if (!QueryLoop)
    return false;

  // This recurrence is variant w.r.t. QueryLoop if QueryLoop contains L.
  if (QueryLoop->contains(L->getHeader()))
    return false;

  // This recurrence is variant w.r.t. QueryLoop if any of its operands
  // are variant.
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
    if (!getOperand(i)->isLoopInvariant(QueryLoop))
      return false;

  // Otherwise it's loop-invariant.
  return true;
}

void SCEVAddRecExpr::print(raw_ostream &OS) const {
  OS << "{" << *Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
    OS << ",+," << *Operands[i];
  OS << "}<" << L->getHeader()->getName() << ">";
}

bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
  // All non-instruction values are loop invariant.  All instructions are loop
  // invariant if they are not contained in the specified loop.
  // Instructions are never considered invariant in the function body
  // (null loop) because they are defined within the "loop".
  if (Instruction *I = dyn_cast<Instruction>(V))
    return L && !L->contains(I->getParent());
  return true;
}

bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
  if (Instruction *I = dyn_cast<Instruction>(getValue()))
    return DT->dominates(I->getParent(), BB);
  return true;
}

const Type *SCEVUnknown::getType() const {
  return V->getType();
}

void SCEVUnknown::print(raw_ostream &OS) const {
  WriteAsOperand(OS, V, false);
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

namespace {
  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
  /// than the complexity of the RHS.  This comparator is used to canonicalize
  /// expressions.
  class VISIBILITY_HIDDEN SCEVComplexityCompare {
    LoopInfo *LI;
  public:
    explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {}

    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
      // Primarily, sort the SCEVs by their getSCEVType().
      if (LHS->getSCEVType() != RHS->getSCEVType())
        return LHS->getSCEVType() < RHS->getSCEVType();

      // Aside from the getSCEVType() ordering, the particular ordering
      // isn't very important except that it's beneficial to be consistent,
      // so that (a + b) and (b + a) don't end up as different expressions.

      // Sort SCEVUnknown values with some loose heuristics. TODO: This is
      // not as complete as it could be.
      if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) {
        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

        // Order pointer values after integer values. This helps SCEVExpander
        // form GEPs.
        if (isa<PointerType>(LU->getType()) && !isa<PointerType>(RU->getType()))
          return false;
        if (isa<PointerType>(RU->getType()) && !isa<PointerType>(LU->getType()))
          return true;

        // Compare getValueID values.
        if (LU->getValue()->getValueID() != RU->getValue()->getValueID())
          return LU->getValue()->getValueID() < RU->getValue()->getValueID();

        // Sort arguments by their position.
        if (const Argument *LA = dyn_cast<Argument>(LU->getValue())) {
          const Argument *RA = cast<Argument>(RU->getValue());
          return LA->getArgNo() < RA->getArgNo();
        }

        // For instructions, compare their loop depth, and their opcode.
        // This is pretty loose.
        if (Instruction *LV = dyn_cast<Instruction>(LU->getValue())) {
          Instruction *RV = cast<Instruction>(RU->getValue());

          // Compare loop depths.
          if (LI->getLoopDepth(LV->getParent()) !=
              LI->getLoopDepth(RV->getParent()))
            return LI->getLoopDepth(LV->getParent()) <
                   LI->getLoopDepth(RV->getParent());

          // Compare opcodes.
          if (LV->getOpcode() != RV->getOpcode())
            return LV->getOpcode() < RV->getOpcode();

          // Compare the number of operands.
          if (LV->getNumOperands() != RV->getNumOperands())
            return LV->getNumOperands() < RV->getNumOperands();
        }

        return false;
      }

      // Compare constant values.
      if (const SCEVConstant *LC = dyn_cast<SCEVConstant>(LHS)) {
        const SCEVConstant *RC = cast<SCEVConstant>(RHS);
        if (LC->getValue()->getBitWidth() != RC->getValue()->getBitWidth())
          return LC->getValue()->getBitWidth() < RC->getValue()->getBitWidth();
        return LC->getValue()->getValue().ult(RC->getValue()->getValue());
      }

      // Compare addrec loop depths.
      if (const SCEVAddRecExpr *LA = dyn_cast<SCEVAddRecExpr>(LHS)) {
        const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
        if (LA->getLoop()->getLoopDepth() != RA->getLoop()->getLoopDepth())
          return LA->getLoop()->getLoopDepth() < RA->getLoop()->getLoopDepth();
      }

      // Lexicographically compare n-ary expressions.
      if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) {
        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
        for (unsigned i = 0, e = LC->getNumOperands(); i != e; ++i) {
          if (i >= RC->getNumOperands())
            return false;
          if (operator()(LC->getOperand(i), RC->getOperand(i)))
            return true;
          if (operator()(RC->getOperand(i), LC->getOperand(i)))
            return false;
        }
        return LC->getNumOperands() < RC->getNumOperands();
      }

      // Lexicographically compare udiv expressions.
      if (const SCEVUDivExpr *LC = dyn_cast<SCEVUDivExpr>(LHS)) {
        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
        if (operator()(LC->getLHS(), RC->getLHS()))
          return true;
        if (operator()(RC->getLHS(), LC->getLHS()))
          return false;
        if (operator()(LC->getRHS(), RC->getRHS()))
          return true;
        if (operator()(RC->getRHS(), LC->getRHS()))
          return false;
        return false;
      }

      // Compare cast expressions by operand.
      if (const SCEVCastExpr *LC = dyn_cast<SCEVCastExpr>(LHS)) {
        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
        return operator()(LC->getOperand(), RC->getOperand());
      }

      LLVM_UNREACHABLE("Unknown SCEV kind!");
      return false;
    }
  };
}

/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
/// When this routine is finished, we know that any duplicates in the vector are
/// consecutive and that complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
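/// For example, if %b happens to sort after %a, the operand list
/// (%b, %a, %b) becomes (%a, %b, %b): the duplicate %b values end up
/// adjacent, where a caller such as getAddExpr can fold them into %b * 2.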
///
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI) {
  if (Ops.size() < 2) return;  // Noop
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    if (SCEVComplexityCompare(LI)(Ops[1], Ops[0]))
      std::swap(Ops[0], Ops[1]);
    return;
  }

  // Do the rough sort by complexity.
  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}



//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
/// Assume K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                      ScalarEvolution &SE,
                                      const Type* ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must ensure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
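  // (For instance, modulo 2^3 there is no unique "6 divided by 2": both 3
  // and 7 work, since 2*3 = 6 and 2*7 = 14 = 6 (mod 8).  Division by an even
  // number is not well-defined in modular arithmetic.)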
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula uses less than W + K bits. Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
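  //
  // As a concrete illustration, take K = 3: K! = 6 = 2^1 * 3, so T = 1 and
  // the odd factor is 3.  With W = 32, the product It*(It-1)*(It-2) is
  // evaluated at W + T = 33 bits, divided by 2^T = 2, truncated to 32 bits,
  // and multiplied by the inverse of 3 modulo 2^32, which is 0xAAAAAAAB
  // (3 * 0xAAAAAAAB = 0x200000001 = 1 (mod 2^32)).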

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult = Mult.lshr(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt(CalculationBits, 1).shl(T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// evaluateAtIteration - Return the value of this chain of recurrences at
/// the specified iteration number.  We can evaluate this recurrence by
/// multiplying each element in the chain by the binomial coefficient
/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
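/// For example, {5,+,3} evaluates to 5 + 3*BC(It, 1) = 5 + 3*It, and
/// {0,+,1,+,2} evaluates to It + 2*BC(It, 2) = It + It*(It-1) = It*It.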
///
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                               ScalarEvolution &SE) const {
  const SCEV *Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
                                            const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
    return getAddRecExpr(Operands, AddRec->getLoop());
  }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVTruncateExpr>();
  new (S) SCEVTruncateExpr(ID, Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
                                              const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
    const Type *IntTy = getEffectiveSCEVType(Ty);
    Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy);
    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
    return getConstant(cast<ConstantInt>(C));
  }

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
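  // Here the i8 addrec {0,+,1}<L> stays in [0,100), so it never wraps its
  // 8-bit range, and zext({0,+,1}<L>) can be rewritten as the same addrec
  // at the wider i32 type.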
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AR->getLoop());
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.
        const SCEV *Start = AR->getStart();
        const SCEV *Step = AR->getStepRecurrence(*this);

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy =
            IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          const SCEV *ZMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrZeroExtend(Step, Start->getType()));
          const SCEV *Add = getAddExpr(Start, ZMul);
          const SCEV *OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getZeroExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 AR->getLoop());

          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          const SCEV *SMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrSignExtend(Step, Start->getType()));
          Add = getAddExpr(Start, SMul);
          OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 AR->getLoop());
        }
      }
    }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVZeroExtendExpr>();
  new (S) SCEVZeroExtendExpr(ID, Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
                                              const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
    const Type *IntTy = getEffectiveSCEVType(Ty);
    Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy);
    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
    return getConstant(cast<ConstantInt>(C));
  }

  // sext(sext(x)) --> sext(x)
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getSignExtendExpr(SS->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scSignExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can sign extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
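  // As above, the i8 addrec {0,+,1}<L> stays in [0,100) and never overflows
  // signed char, so the sign extension can likewise be pushed inside the
  // addrec.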
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AR->getLoop());
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.
        const SCEV *Start = AR->getStart();
        const SCEV *Step = AR->getStepRecurrence(*this);

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy =
            IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
          const SCEV *SMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrSignExtend(Step, Start->getType()));
          const SCEV *Add = getAddExpr(Start, SMul);
          const SCEV *OperandExtendedAdd =
            getAddExpr(getSignExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 AR->getLoop());
        }
      }
    }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVSignExtendExpr>();
  new (S) SCEVSignExtendExpr(ID, Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

/// getAnyExtendExpr - Return a SCEV for the given operand extended with
/// unspecified bits out to the given type.
///
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                             const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Sign-extend negative constants.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    if (SC->getValue()->getValue().isNegative())
      return getSignExtendExpr(Op, Ty);

  // Peel off a truncate cast.
  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
    const SCEV *NewOp = T->getOperand();
    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
      return getAnyExtendExpr(NewOp, Ty);
    return getTruncateOrNoop(NewOp, Ty);
  }

  // Next try a zext cast. If the cast is folded, use it.
  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;

  // Next try a sext cast. If the cast is folded, use it.
  const SCEV *SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;

  // If the expression is obviously signed, use the sext cast value.
  if (isa<SCEVSMaxExpr>(Op))
    return SExt;

  // Absent any other information, use the zext cast value.
  return ZExt;
}

/// CollectAddOperandsWithScales - Process the given Ops list, which is
/// a list of operands to be added under the given scale, update the given
/// map. This is a helper function for getAddExpr. As an example of
/// what it does, given a sequence of operands that would form an add
/// expression like this:
///
///    m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r)
///
/// where A and B are constants, update the map with these values:
///
///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
///
/// and add 13 + A*B*29 to AccumulatedConstant.
/// This will allow getAddExpr to produce this:
///
///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
///
/// This form often exposes folding opportunities that are hidden in
/// the original operand list.
///
/// Return true iff it appears that any interesting folding opportunities
/// may be exposed. This helps getAddExpr short-circuit extra work in
/// the common case where no interesting opportunities are present, and
/// is also used as a check to avoid infinite recursion.
///
static bool
CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
                             SmallVector<const SCEV *, 8> &NewOps,
                             APInt &AccumulatedConstant,
                             const SmallVectorImpl<const SCEV *> &Ops,
                             const APInt &Scale,
                             ScalarEvolution &SE) {
  bool Interesting = false;

  // Iterate over the add operands.
  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
      APInt NewScale =
        Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
        // A multiplication of a constant with another add; recurse.
        Interesting |=
          CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                       cast<SCEVAddExpr>(Mul->getOperand(1))
                                         ->getOperands(),
                                       NewScale, SE);
      } else {
        // A multiplication of a constant with some other value. Update
        // the map.
        SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
        const SCEV *Key = SE.getMulExpr(MulOps);
        std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
          M.insert(std::make_pair(Key, NewScale));
        if (Pair.second) {
          NewOps.push_back(Pair.first->first);
        } else {
          Pair.first->second += NewScale;
          // The map already had an entry for this value, which may indicate
          // a folding opportunity.
          Interesting = true;
        }
      }
    } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
      // Pull a buried constant out to the outside.
      if (Scale != 1 || AccumulatedConstant != 0 || C->isZero())
        Interesting = true;
      AccumulatedConstant += Scale * C->getValue()->getValue();
    } else {
      // An ordinary operand. Update the map.
      std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
        M.insert(std::make_pair(Ops[i], Scale));
      if (Pair.second) {
        NewOps.push_back(Pair.first->first);
      } else {
        Pair.first->second += Scale;
        // The map already had an entry for this value, which may indicate
        // a folding opportunity.
        Interesting = true;
      }
    }
  }

  return Interesting;
}

namespace {
  struct APIntCompare {
    bool operator()(const APInt &LHS, const APInt &RHS) const {
      return LHS.ult(RHS);
    }
  };
}

/// getAddExpr - Get a canonical add expression, or something simpler if
/// possible.
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops) {
  assert(!Ops.empty() && "Cannot get empty add!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
           getEffectiveSCEVType(Ops[0]->getType()) &&
           "SCEVAddExpr operand types don't match!");
#endif

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      Ops[0] = getConstant(LHSC->getValue()->getValue() +
                           RHSC->getValue()->getValue());
      if (Ops.size() == 2) return Ops[0];
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant zero being added, strip it off.
    if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
      Ops.erase(Ops.begin());
      --Idx;
    }
  }

  if (Ops.size() == 1) return Ops[0];

  // Okay, check to see if the same value occurs in the operand list twice.  If
  // so, merge them together into a multiply expression.  Since we sorted the
  // list, these values are required to be adjacent.
  const Type *Ty = Ops[0]->getType();
  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
      // Found a match, merge the two values into a multiply, and add any
      // remaining values to the result.
      const SCEV *Two = getIntegerSCEV(2, Ty);
      const SCEV *Mul = getMulExpr(Ops[i], Two);
      if (Ops.size() == 2)
        return Mul;
      Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
      Ops.push_back(Mul);
      return getAddExpr(Ops);
    }

  // Check for truncates. If all the operands are truncated from the same
  // type, see if factoring out the truncate would permit the result to be
  // folded. e.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
  // if the contents of the resulting outer trunc fold to something simple.
  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
    const Type *DstType = Trunc->getType();
    const Type *SrcType = Trunc->getOperand()->getType();
    SmallVector<const SCEV *, 8> LargeOps;
    bool Ok = true;
    // Check all the operands to see if they can be represented in the
    // source type of the truncate.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
        if (T->getOperand()->getType() != SrcType) {
          Ok = false;
          break;
        }
        LargeOps.push_back(T->getOperand());
      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
        // This could be either sign or zero extension, but sign extension
        // is much more likely to be foldable here.
        LargeOps.push_back(getSignExtendExpr(C, SrcType));
      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
        SmallVector<const SCEV *, 8> LargeMulOps;
        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
          if (const SCEVTruncateExpr *T =
                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
            if (T->getOperand()->getType() != SrcType) {
              Ok = false;
              break;
            }
            LargeMulOps.push_back(T->getOperand());
          } else if (const SCEVConstant *C =
                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
            // This could be either sign or zero extension, but sign extension
            // is much more likely to be foldable here.
            LargeMulOps.push_back(getSignExtendExpr(C, SrcType));
          } else {
            Ok = false;
            break;
          }
        }
        if (Ok)
          LargeOps.push_back(getMulExpr(LargeMulOps));
      } else {
        Ok = false;
        break;
      }
    }
    if (Ok) {
      // Evaluate the expression in the larger type.
      const SCEV *Fold = getAddExpr(LargeOps);
      // If it folds to something simple, use it. Otherwise, don't.
      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
        return getTruncateExpr(Fold, DstType);
    }
  }

  // Skip past any other cast SCEVs.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
    ++Idx;

  // If there are add operands they would be next.
  if (Idx < Ops.size()) {
    bool DeletedAdd = false;
    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
      // If we have an add, expand the add operands onto the end of the operands
      // list.
      Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
      Ops.erase(Ops.begin()+Idx);
      DeletedAdd = true;
    }

    // If we deleted at least one add, we added operands to the end of the list,
    // and they are not necessarily sorted.  Recurse to resort and resimplify
    // any operands we just acquired.
    if (DeletedAdd)
      return getAddExpr(Ops);
  }

  // Skip over the add expression until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // Check to see if there are any folding opportunities present with
  // operands multiplied by constant values.
  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
    uint64_t BitWidth = getTypeSizeInBits(Ty);
    DenseMap<const SCEV *, APInt> M;
    SmallVector<const SCEV *, 8> NewOps;
    APInt AccumulatedConstant(BitWidth, 0);
    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                     Ops, APInt(BitWidth, 1), *this)) {
      // Some interesting folding opportunity is present, so it's worthwhile to
1229      // re-generate the operands list. Group the operands by constant scale,
1230      // to avoid multiplying by the same constant scale multiple times.
1231      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
1232      for (SmallVector<const SCEV *, 8>::iterator I = NewOps.begin(),
1233           E = NewOps.end(); I != E; ++I)
1234        MulOpLists[M.find(*I)->second].push_back(*I);
1235      // Re-generate the operands list.
1236      Ops.clear();
1237      if (AccumulatedConstant != 0)
1238        Ops.push_back(getConstant(AccumulatedConstant));
1239      for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
1240           I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
1241        if (I->first != 0)
1242          Ops.push_back(getMulExpr(getConstant(I->first),
1243                                   getAddExpr(I->second)));
1244      if (Ops.empty())
1245        return getIntegerSCEV(0, Ty);
1246      if (Ops.size() == 1)
1247        return Ops[0];
1248      return getAddExpr(Ops);
1249    }
1250  }
1251
1252  // If we are adding something to a multiply expression, make sure the
1253  // something is not already an operand of the multiply.  If so, merge it into
1254  // the multiply.
1255  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1256    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1257    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1258      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1259      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1260        if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(Ops[AddOp])) {
1261          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
1262          const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
1263          if (Mul->getNumOperands() != 2) {
1264            // If the multiply has more than two operands, we must get the
1265            // Y*Z term.
1266            SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), Mul->op_end());
1267            MulOps.erase(MulOps.begin()+MulOp);
1268            InnerMul = getMulExpr(MulOps);
1269          }
1270          const SCEV *One = getIntegerSCEV(1, Ty);
1271          const SCEV *AddOne = getAddExpr(InnerMul, One);
1272          const SCEV *OuterMul = getMulExpr(AddOne, Ops[AddOp]);
1273          if (Ops.size() == 2) return OuterMul;
1274          if (AddOp < Idx) {
1275            Ops.erase(Ops.begin()+AddOp);
1276            Ops.erase(Ops.begin()+Idx-1);
1277          } else {
1278            Ops.erase(Ops.begin()+Idx);
1279            Ops.erase(Ops.begin()+AddOp-1);
1280          }
1281          Ops.push_back(OuterMul);
1282          return getAddExpr(Ops);
1283        }
1284
1285      // Check this multiply against other multiplies being added together.
1286      for (unsigned OtherMulIdx = Idx+1;
1287           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1288           ++OtherMulIdx) {
1289        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1290        // If MulOp occurs in OtherMul, we can fold the two multiplies
1291        // together.
1292        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1293             OMulOp != e; ++OMulOp)
1294          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1295            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1296            const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
1297            if (Mul->getNumOperands() != 2) {
1298              SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1299                                                  Mul->op_end());
1300              MulOps.erase(MulOps.begin()+MulOp);
1301              InnerMul1 = getMulExpr(MulOps);
1302            }
1303            const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1304            if (OtherMul->getNumOperands() != 2) {
1305              SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
1306                                                  OtherMul->op_end());
1307              MulOps.erase(MulOps.begin()+OMulOp);
1308              InnerMul2 = getMulExpr(MulOps);
1309            }
1310            const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
1311            const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1312            if (Ops.size() == 2) return OuterMul;
1313            Ops.erase(Ops.begin()+Idx);
1314            Ops.erase(Ops.begin()+OtherMulIdx-1);
1315            Ops.push_back(OuterMul);
1316            return getAddExpr(Ops);
1317          }
1318      }
1319    }
1320  }
1321
1322  // If there are any add recurrences in the operands list, see if any other
1323  // added values are loop invariant.  If so, we can fold them into the
1324  // recurrence.
1325  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1326    ++Idx;
1327
1328  // Scan over all recurrences, trying to fold loop invariants into them.
1329  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1330    // Scan all of the other operands to this add and add them to the vector if
1331    // they are loop invariant w.r.t. the recurrence.
1332    SmallVector<const SCEV *, 8> LIOps;
1333    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1334    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1335      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1336        LIOps.push_back(Ops[i]);
1337        Ops.erase(Ops.begin()+i);
1338        --i; --e;
1339      }
1340
1341    // If we found some loop invariants, fold them into the recurrence.
1342    if (!LIOps.empty()) {
1343      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
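      //  For instance, with n invariant in the loop, n + {3,+,1} becomes
      //  {n+3,+,1}: only the start of the recurrence changes.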
1344      LIOps.push_back(AddRec->getStart());
1345
1346      SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1347                                           AddRec->op_end());
1348      AddRecOps[0] = getAddExpr(LIOps);
1349
1350      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop());
1351      // If all of the other operands were loop invariant, we are done.
1352      if (Ops.size() == 1) return NewRec;
1353
1354      // Otherwise, add the folded AddRec to the non-loop-invariant parts.
1355      for (unsigned i = 0;; ++i)
1356        if (Ops[i] == AddRec) {
1357          Ops[i] = NewRec;
1358          break;
1359        }
1360      return getAddExpr(Ops);
1361    }
1362
1363    // Okay, if there weren't any loop invariants to be folded, check to see if
1364    // there are multiple AddRec's with the same loop induction variable being
1365    // added together.  If so, we can fold them.
1366    for (unsigned OtherIdx = Idx+1;
1367         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1368      if (OtherIdx != Idx) {
1369        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1370        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1371          // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
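          // (e.g., {1,+,2} + {3,+,4} over the same loop folds to {4,+,6}.)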
1372          SmallVector<const SCEV *, 4> NewOps(AddRec->op_begin(),
1373                                              AddRec->op_end());
1374          for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
1375            if (i >= NewOps.size()) {
1376              NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i,
1377                            OtherAddRec->op_end());
1378              break;
1379            }
1380            NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
1381          }
1382          const SCEV *NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop());
1383
1384          if (Ops.size() == 2) return NewAddRec;
1385
1386          Ops.erase(Ops.begin()+Idx);
1387          Ops.erase(Ops.begin()+OtherIdx-1);
1388          Ops.push_back(NewAddRec);
1389          return getAddExpr(Ops);
1390        }
1391      }
1392
1393    // Otherwise we couldn't fold anything into this recurrence.  Move on to
1394    // the next one.
1395  }
1396
1397  // Okay, it looks like we really DO need an add expr.  Check to see if we
1398  // already have one, otherwise create a new one.
1399  FoldingSetNodeID ID;
1400  ID.AddInteger(scAddExpr);
1401  ID.AddInteger(Ops.size());
1402  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1403    ID.AddPointer(Ops[i]);
1404  void *IP = 0;
1405  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1406  SCEV *S = SCEVAllocator.Allocate<SCEVAddExpr>();
1407  new (S) SCEVAddExpr(ID, Ops);
1408  UniqueSCEVs.InsertNode(S, IP);
1409  return S;
1410}
1411
1412
1413/// getMulExpr - Get a canonical multiply expression, or something simpler if
1414/// possible.
1415const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops) {
1416  assert(!Ops.empty() && "Cannot get empty mul!");
1417#ifndef NDEBUG
1418  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1419    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1420           getEffectiveSCEVType(Ops[0]->getType()) &&
1421           "SCEVMulExpr operand types don't match!");
1422#endif
1423
1424  // Sort by complexity; this groups all similar expression types together.
1425  GroupByComplexity(Ops, LI);
1426
1427  // If there are any constants, fold them together.
1428  unsigned Idx = 0;
1429  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1430
1431    // C1*(C2+V) -> C1*C2 + C1*V
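    // (e.g., 3*(4+x) becomes 12 + 3*x, exposing the constant to further
    // folding.)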
1432    if (Ops.size() == 2)
1433      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1434        if (Add->getNumOperands() == 2 &&
1435            isa<SCEVConstant>(Add->getOperand(0)))
1436          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1437                            getMulExpr(LHSC, Add->getOperand(1)));
1438
1439
1440    ++Idx;
1441    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1442      // We found two constants, fold them together!
1443      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() *
1444                                           RHSC->getValue()->getValue());
1445      Ops[0] = getConstant(Fold);
1446      Ops.erase(Ops.begin()+1);  // Erase the folded element
1447      if (Ops.size() == 1) return Ops[0];
1448      LHSC = cast<SCEVConstant>(Ops[0]);
1449    }
1450
1451    // If we are left with a constant one being multiplied, strip it off.
1452    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1453      Ops.erase(Ops.begin());
1454      --Idx;
1455    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1456      // If we have a multiply of zero, it will always be zero.
1457      return Ops[0];
1458    }
1459  }
1460
1461  // Skip over the add expressions until we get to a multiply.
1462  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1463    ++Idx;
1464
1465  if (Ops.size() == 1)
1466    return Ops[0];
1467
1468  // If there are mul operands, inline them all into this expression.
1469  if (Idx < Ops.size()) {
1470    bool DeletedMul = false;
1471    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1472      // If we have a mul, expand the mul operands onto the end of the operands
1473      // list.
1474      Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end());
1475      Ops.erase(Ops.begin()+Idx);
1476      DeletedMul = true;
1477    }
1478
1479    // If we deleted at least one mul, we added operands to the end of the list,
1480    // and they are not necessarily sorted.  Recurse to resort and resimplify
1481    // any operands we just acquired.
1482    if (DeletedMul)
1483      return getMulExpr(Ops);
1484  }
1485
1486  // If there are any add recurrences in the operands list, see if any other
1487  // added values are loop invariant.  If so, we can fold them into the
1488  // recurrence.
1489  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1490    ++Idx;
1491
1492  // Scan over all recurrences, trying to fold loop invariants into them.
1493  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1494    // Scan all of the other operands to this mul and add them to the vector if
1495    // they are loop invariant w.r.t. the recurrence.
1496    SmallVector<const SCEV *, 8> LIOps;
1497    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1498    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1499      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1500        LIOps.push_back(Ops[i]);
1501        Ops.erase(Ops.begin()+i);
1502        --i; --e;
1503      }
1504
1505    // If we found some loop invariants, fold them into the recurrence.
1506    if (!LIOps.empty()) {
1507      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
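      //  (e.g., 3 * m * {2,+,4}, with m not invariant in the loop, becomes
      //  m * {6,+,12}: the invariant factor 3 is distributed over every
      //  operand of the recurrence.)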
1508      SmallVector<const SCEV *, 4> NewOps;
1509      NewOps.reserve(AddRec->getNumOperands());
1510      if (LIOps.size() == 1) {
1511        const SCEV *Scale = LIOps[0];
1512        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
1513          NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
1514      } else {
1515        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
1516          SmallVector<const SCEV *, 4> MulOps(LIOps.begin(), LIOps.end());
1517          MulOps.push_back(AddRec->getOperand(i));
1518          NewOps.push_back(getMulExpr(MulOps));
1519        }
1520      }
1521
1522      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop());
1523
1524      // If all of the other operands were loop invariant, we are done.
1525      if (Ops.size() == 1) return NewRec;
1526
1527      // Otherwise, multiply the folded AddRec by the non-liv parts.
1528      for (unsigned i = 0;; ++i)
1529        if (Ops[i] == AddRec) {
1530          Ops[i] = NewRec;
1531          break;
1532        }
1533      return getMulExpr(Ops);
1534    }
1535
1536    // Okay, if there weren't any loop invariants to be folded, check to see if
1537    // there are multiple AddRec's with the same loop induction variable being
1538    // multiplied together.  If so, we can fold them.
1539    for (unsigned OtherIdx = Idx+1;
1540         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1541      if (OtherIdx != Idx) {
1542        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1543        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1544          // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
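          // (This follows from (F+B)*(G+D) - F*G = F*D + G*B + B*D: with F
          // and G the current values of the recurrences and B and D their
          // steps, that difference is the step of the product recurrence.)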
1545          const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1546          const SCEV *NewStart = getMulExpr(F->getStart(),
1547                                            G->getStart());
1548          const SCEV *B = F->getStepRecurrence(*this);
1549          const SCEV *D = G->getStepRecurrence(*this);
1550          const SCEV *NewStep = getAddExpr(getMulExpr(F, D),
1551                                           getMulExpr(G, B),
1552                                           getMulExpr(B, D));
1553          const SCEV *NewAddRec = getAddRecExpr(NewStart, NewStep,
1554                                                F->getLoop());
1555          if (Ops.size() == 2) return NewAddRec;
1556
1557          Ops.erase(Ops.begin()+Idx);
1558          Ops.erase(Ops.begin()+OtherIdx-1);
1559          Ops.push_back(NewAddRec);
1560          return getMulExpr(Ops);
1561        }
1562      }
1563
1564    // Otherwise we couldn't fold anything into this recurrence.  Move on to
1565    // the next one.
1566  }
1567
1568  // Okay, it looks like we really DO need a mul expr.  Check to see if we
1569  // already have one, otherwise create a new one.
1570  FoldingSetNodeID ID;
1571  ID.AddInteger(scMulExpr);
1572  ID.AddInteger(Ops.size());
1573  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1574    ID.AddPointer(Ops[i]);
1575  void *IP = 0;
1576  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1577  SCEV *S = SCEVAllocator.Allocate<SCEVMulExpr>();
1578  new (S) SCEVMulExpr(ID, Ops);
1579  UniqueSCEVs.InsertNode(S, IP);
1580  return S;
1581}
1582
1583/// getUDivExpr - Get a canonical unsigned division expression, or something
1584/// simpler if possible.
1585const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
1586                                         const SCEV *RHS) {
1587  assert(getEffectiveSCEVType(LHS->getType()) ==
1588         getEffectiveSCEVType(RHS->getType()) &&
1589         "SCEVUDivExpr operand types don't match!");
1590
1591  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1592    if (RHSC->getValue()->equalsInt(1))
1593      return LHS;                            // X udiv 1 --> X
1594    if (RHSC->isZero())
1595      return getIntegerSCEV(0, LHS->getType()); // value is undefined
1596
1597    // Determine whether the division can be folded into the operands of
1598    // its left-hand side.
1599    // TODO: Generalize this to non-constants by using known-bits information.
1600    const Type *Ty = LHS->getType();
1601    unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
1602    unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
1603    // For non-power-of-two values, effectively round the value up to the
1604    // nearest power of two.
1605    if (!RHSC->getValue()->getValue().isPowerOf2())
1606      ++MaxShiftAmt;
1607    const IntegerType *ExtTy =
1608      IntegerType::get(getTypeSizeInBits(Ty) + MaxShiftAmt);
1609    // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
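    // (e.g., {8,+,4}/2 --> {4,+,2}.  The zero-extend comparison below
    // verifies that the recurrence cannot wrap, since wrapping would make
    // distributing the division unsafe.)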
1610    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
1611      if (const SCEVConstant *Step =
1612            dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
1613        if (!Step->getValue()->getValue()
1614              .urem(RHSC->getValue()->getValue()) &&
1615            getZeroExtendExpr(AR, ExtTy) ==
1616            getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
1617                          getZeroExtendExpr(Step, ExtTy),
1618                          AR->getLoop())) {
1619          SmallVector<const SCEV *, 4> Operands;
1620          for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
1621            Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
1622          return getAddRecExpr(Operands, AR->getLoop());
1623        }
1624    // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
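    // (e.g., (x*6)/2 becomes x*3 once the zero-extend check shows the
    // multiply cannot wrap.)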
1625    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
1626      SmallVector<const SCEV *, 4> Operands;
1627      for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
1628        Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
1629      if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
1630        // Find an operand that's safely divisible.
1631        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
1632          const SCEV *Op = M->getOperand(i);
1633          const SCEV *Div = getUDivExpr(Op, RHSC);
1634          if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
1635            const SmallVectorImpl<const SCEV *> &MOperands = M->getOperands();
1636            Operands = SmallVector<const SCEV *, 4>(MOperands.begin(),
1637                                                  MOperands.end());
1638            Operands[i] = Div;
1639            return getMulExpr(Operands);
1640          }
1641        }
1642    }
1643    // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
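    // (e.g., (x*4 + 8)/4 becomes x + 2 when each operand divides evenly and
    // the zero-extend check shows the sum cannot wrap.)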
1644    if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
1645      SmallVector<const SCEV *, 4> Operands;
1646      for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
1647        Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
1648      if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
1649        Operands.clear();
1650        for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
1651          const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
1652          if (isa<SCEVUDivExpr>(Op) || getMulExpr(Op, RHS) != A->getOperand(i))
1653            break;
1654          Operands.push_back(Op);
1655        }
1656        if (Operands.size() == A->getNumOperands())
1657          return getAddExpr(Operands);
1658      }
1659    }
1660
1661    // Fold if both operands are constant.
1662    if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1663      Constant *LHSCV = LHSC->getValue();
1664      Constant *RHSCV = RHSC->getValue();
1665      return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
1666                                                                 RHSCV)));
1667    }
1668  }
1669
1670  FoldingSetNodeID ID;
1671  ID.AddInteger(scUDivExpr);
1672  ID.AddPointer(LHS);
1673  ID.AddPointer(RHS);
1674  void *IP = 0;
1675  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1676  SCEV *S = SCEVAllocator.Allocate<SCEVUDivExpr>();
1677  new (S) SCEVUDivExpr(ID, LHS, RHS);
1678  UniqueSCEVs.InsertNode(S, IP);
1679  return S;
1680}
1681
1682
1683/// getAddRecExpr - Get an add recurrence expression for the specified loop.
1684/// Simplify the expression as much as possible.
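///
/// For example (a hypothetical usage sketch, where SE names a ScalarEvolution
/// instance, L the loop being analyzed, and Ty the desired integer type):
///   // The canonical induction variable i = {0,+,1}<L>:
///   const SCEV *IV = SE.getAddRecExpr(SE.getIntegerSCEV(0, Ty),
///                                     SE.getIntegerSCEV(1, Ty), L);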
1685const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start,
1686                               const SCEV *Step, const Loop *L) {
1687  SmallVector<const SCEV *, 4> Operands;
1688  Operands.push_back(Start);
1689  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
1690    if (StepChrec->getLoop() == L) {
1691      Operands.insert(Operands.end(), StepChrec->op_begin(),
1692                      StepChrec->op_end());
1693      return getAddRecExpr(Operands, L);
1694    }
1695
1696  Operands.push_back(Step);
1697  return getAddRecExpr(Operands, L);
1698}
1699
1700/// getAddRecExpr - Get an add recurrence expression for the specified loop.
1701/// Simplify the expression as much as possible.
1702const SCEV *
1703ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
1704                               const Loop *L) {
1705  if (Operands.size() == 1) return Operands[0];
1706#ifndef NDEBUG
1707  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
1708    assert(getEffectiveSCEVType(Operands[i]->getType()) ==
1709           getEffectiveSCEVType(Operands[0]->getType()) &&
1710           "SCEVAddRecExpr operand types don't match!");
1711#endif
1712
1713  if (Operands.back()->isZero()) {
1714    Operands.pop_back();
1715    return getAddRecExpr(Operands, L);             // {X,+,0}  -->  X
1716  }
1717
1718  // Canonicalize nested AddRecs by nesting them in order of loop depth.
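  // (e.g., a request for {{S,+,X}<L2>,+,Y}<L1>, where L1 is shallower than
  // L2, is rebuilt as {{S,+,Y}<L1>,+,X}<L2> when the invariance checks
  // below allow it, so the outer recurrence belongs to the outer loop.)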
1719  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
1720    const Loop* NestedLoop = NestedAR->getLoop();
1721    if (L->getLoopDepth() < NestedLoop->getLoopDepth()) {
1722      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
1723                                                NestedAR->op_end());
1724      Operands[0] = NestedAR->getStart();
1725      // AddRecs require their operands be loop-invariant with respect to their
1726      // loops. Don't perform this transformation if it would break this
1727      // requirement.
1728      bool AllInvariant = true;
1729      for (unsigned i = 0, e = Operands.size(); i != e; ++i)
1730        if (!Operands[i]->isLoopInvariant(L)) {
1731          AllInvariant = false;
1732          break;
1733        }
1734      if (AllInvariant) {
1735        NestedOperands[0] = getAddRecExpr(Operands, L);
1736        AllInvariant = true;
1737        for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
1738          if (!NestedOperands[i]->isLoopInvariant(NestedLoop)) {
1739            AllInvariant = false;
1740            break;
1741          }
1742        if (AllInvariant)
1743          // Ok, both add recurrences are valid after the transformation.
1744          return getAddRecExpr(NestedOperands, NestedLoop);
1745      }
1746      // Reset Operands to its original state.
1747      Operands[0] = NestedAR;
1748    }
1749  }
1750
1751  FoldingSetNodeID ID;
1752  ID.AddInteger(scAddRecExpr);
1753  ID.AddInteger(Operands.size());
1754  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
1755    ID.AddPointer(Operands[i]);
1756  ID.AddPointer(L);
1757  void *IP = 0;
1758  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1759  SCEV *S = SCEVAllocator.Allocate<SCEVAddRecExpr>();
1760  new (S) SCEVAddRecExpr(ID, Operands, L);
1761  UniqueSCEVs.InsertNode(S, IP);
1762  return S;
1763}
1764
1765const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
1766                                         const SCEV *RHS) {
1767  SmallVector<const SCEV *, 2> Ops;
1768  Ops.push_back(LHS);
1769  Ops.push_back(RHS);
1770  return getSMaxExpr(Ops);
1771}
1772
1773const SCEV *
1774ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
1775  assert(!Ops.empty() && "Cannot get empty smax!");
1776  if (Ops.size() == 1) return Ops[0];
1777#ifndef NDEBUG
1778  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1779    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1780           getEffectiveSCEVType(Ops[0]->getType()) &&
1781           "SCEVSMaxExpr operand types don't match!");
1782#endif
1783
1784  // Sort by complexity; this groups all similar expression types together.
1785  GroupByComplexity(Ops, LI);
1786
1787  // If there are any constants, fold them together.
1788  unsigned Idx = 0;
1789  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1790    ++Idx;
1791    assert(Idx < Ops.size());
1792    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1793      // We found two constants, fold them together!
1794      ConstantInt *Fold = ConstantInt::get(
1795                              APIntOps::smax(LHSC->getValue()->getValue(),
1796                                             RHSC->getValue()->getValue()));
1797      Ops[0] = getConstant(Fold);
1798      Ops.erase(Ops.begin()+1);  // Erase the folded element
1799      if (Ops.size() == 1) return Ops[0];
1800      LHSC = cast<SCEVConstant>(Ops[0]);
1801    }
1802
1803    // If we are left with a constant minimum-int, strip it off.
1804    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
1805      Ops.erase(Ops.begin());
1806      --Idx;
1807    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
1808      // If we have an smax with a constant maximum-int, it will always be
1809      // maximum-int.
1810      return Ops[0];
1811    }
1812  }
1813
1814  if (Ops.size() == 1) return Ops[0];
1815
1816  // Find the first SMax
1817  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
1818    ++Idx;
1819
1820  // Check to see if one of the operands is an SMax. If so, expand its operands
1821  // onto our operand list, and recurse to simplify.
1822  if (Idx < Ops.size()) {
1823    bool DeletedSMax = false;
1824    while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
1825      Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
1826      Ops.erase(Ops.begin()+Idx);
1827      DeletedSMax = true;
1828    }
1829
1830    if (DeletedSMax)
1831      return getSMaxExpr(Ops);
1832  }
1833
1834  // Okay, check to see if the same value occurs in the operand list twice.  If
1835  // so, delete one.  Since we sorted the list, these values are required to
1836  // be adjacent.
1837  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1838    if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
1839      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1840      --i; --e;
1841    }
1842
1843  if (Ops.size() == 1) return Ops[0];
1844
1845  assert(!Ops.empty() && "Reduced smax down to nothing!");
1846
1847  // Okay, it looks like we really DO need an smax expr.  Check to see if we
1848  // already have one, otherwise create a new one.
1849  FoldingSetNodeID ID;
1850  ID.AddInteger(scSMaxExpr);
1851  ID.AddInteger(Ops.size());
1852  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1853    ID.AddPointer(Ops[i]);
1854  void *IP = 0;
1855  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1856  SCEV *S = SCEVAllocator.Allocate<SCEVSMaxExpr>();
1857  new (S) SCEVSMaxExpr(ID, Ops);
1858  UniqueSCEVs.InsertNode(S, IP);
1859  return S;
1860}
1861
1862const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
1863                                         const SCEV *RHS) {
1864  SmallVector<const SCEV *, 2> Ops;
1865  Ops.push_back(LHS);
1866  Ops.push_back(RHS);
1867  return getUMaxExpr(Ops);
1868}
1869
1870const SCEV *
1871ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
1872  assert(!Ops.empty() && "Cannot get empty umax!");
1873  if (Ops.size() == 1) return Ops[0];
1874#ifndef NDEBUG
1875  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1876    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1877           getEffectiveSCEVType(Ops[0]->getType()) &&
1878           "SCEVUMaxExpr operand types don't match!");
1879#endif
1880
1881  // Sort by complexity; this groups all similar expression types together.
1882  GroupByComplexity(Ops, LI);
1883
1884  // If there are any constants, fold them together.
1885  unsigned Idx = 0;
1886  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1887    ++Idx;
1888    assert(Idx < Ops.size());
1889    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1890      // We found two constants, fold them together!
1891      ConstantInt *Fold = ConstantInt::get(
1892                              APIntOps::umax(LHSC->getValue()->getValue(),
1893                                             RHSC->getValue()->getValue()));
1894      Ops[0] = getConstant(Fold);
1895      Ops.erase(Ops.begin()+1);  // Erase the folded element
1896      if (Ops.size() == 1) return Ops[0];
1897      LHSC = cast<SCEVConstant>(Ops[0]);
1898    }
1899
1900    // If we are left with a constant minimum-int, strip it off.
1901    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
1902      Ops.erase(Ops.begin());
1903      --Idx;
1904    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
1905      // If we have an umax with a constant maximum-int, it will always be
1906      // maximum-int.
1907      return Ops[0];
1908    }
1909  }
1910
1911  if (Ops.size() == 1) return Ops[0];
1912
1913  // Find the first UMax
1914  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
1915    ++Idx;
1916
1917  // Check to see if one of the operands is a UMax. If so, expand its operands
1918  // onto our operand list, and recurse to simplify.
1919  if (Idx < Ops.size()) {
1920    bool DeletedUMax = false;
1921    while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
1922      Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
1923      Ops.erase(Ops.begin()+Idx);
1924      DeletedUMax = true;
1925    }
1926
1927    if (DeletedUMax)
1928      return getUMaxExpr(Ops);
1929  }
1930
1931  // Okay, check to see if the same value occurs in the operand list twice.  If
1932  // so, delete one.  Since we sorted the list, these values are required to
1933  // be adjacent.
1934  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1935    if (Ops[i] == Ops[i+1]) {      //  X umax Y umax Y  -->  X umax Y
1936      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1937      --i; --e;
1938    }
1939
1940  if (Ops.size() == 1) return Ops[0];
1941
1942  assert(!Ops.empty() && "Reduced umax down to nothing!");
1943
1944  // Okay, it looks like we really DO need a umax expr.  Check to see if we
1945  // already have one, otherwise create a new one.
1946  FoldingSetNodeID ID;
1947  ID.AddInteger(scUMaxExpr);
1948  ID.AddInteger(Ops.size());
1949  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1950    ID.AddPointer(Ops[i]);
1951  void *IP = 0;
1952  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1953  SCEV *S = SCEVAllocator.Allocate<SCEVUMaxExpr>();
1954  new (S) SCEVUMaxExpr(ID, Ops);
1955  UniqueSCEVs.InsertNode(S, IP);
1956  return S;
1957}
1958
1959const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
1960                                         const SCEV *RHS) {
1961  // ~smax(~x, ~y) == smin(x, y).
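  // (This holds because ~z = -1-z reverses the signed order, so negating
  // the smax of the negations yields the smin of the originals.)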
1962  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
1963}
1964
1965const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
1966                                         const SCEV *RHS) {
1967  // ~umax(~x, ~y) == umin(x, y)
1968  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
1969}
1970
1971const SCEV *ScalarEvolution::getUnknown(Value *V) {
1972  // Don't attempt to do anything other than create a SCEVUnknown object
1973  // here.  createSCEV only calls getUnknown after checking for all other
1974  // interesting possibilities, and any other code that calls getUnknown
1975  // is doing so in order to hide a value from SCEV canonicalization.
1976
1977  FoldingSetNodeID ID;
1978  ID.AddInteger(scUnknown);
1979  ID.AddPointer(V);
1980  void *IP = 0;
1981  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1982  SCEV *S = SCEVAllocator.Allocate<SCEVUnknown>();
1983  new (S) SCEVUnknown(ID, V);
1984  UniqueSCEVs.InsertNode(S, IP);
1985  return S;
1986}
1987
1988//===----------------------------------------------------------------------===//
1989//            Basic SCEV Analysis and PHI Idiom Recognition Code
1990//
1991
1992/// isSCEVable - Test if values of the given type are analyzable within
1993/// the SCEV framework. This primarily includes integer types, and it
1994/// can optionally include pointer types if the ScalarEvolution class
1995/// has access to target-specific information.
1996bool ScalarEvolution::isSCEVable(const Type *Ty) const {
1997  // Integers are always SCEVable.
1998  if (Ty->isInteger())
1999    return true;
2000
2001  // Pointers are SCEVable if TargetData information is available
2002  // to provide pointer size information.
2003  if (isa<PointerType>(Ty))
2004    return TD != NULL;
2005
2006  // Otherwise it's not SCEVable.
2007  return false;
2008}
2009
2010/// getTypeSizeInBits - Return the size in bits of the specified type,
2011/// for which isSCEVable must return true.
2012uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
2013  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2014
2015  // If we have a TargetData, use it!
2016  if (TD)
2017    return TD->getTypeSizeInBits(Ty);
2018
2019  // Otherwise, we support only integer types.
2020  assert(Ty->isInteger() && "isSCEVable permitted a non-SCEVable type!");
2021  return Ty->getPrimitiveSizeInBits();
2022}
2023
2024/// getEffectiveSCEVType - Return a type with the same bitwidth as
2025/// the given type and which represents how SCEV will treat the given
2026/// type, for which isSCEVable must return true. For pointer types,
2027/// this is the pointer-sized integer type.
2028const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
2029  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2030
2031  if (Ty->isInteger())
2032    return Ty;
2033
2034  assert(isa<PointerType>(Ty) && "Unexpected non-pointer non-integer type!");
2035  return TD->getIntPtrType();
2036}
2037
2038const SCEV *ScalarEvolution::getCouldNotCompute() {
2039  return &CouldNotCompute;
2040}
2041
2042/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
2043/// expression and create a new one.
2044const SCEV *ScalarEvolution::getSCEV(Value *V) {
2045  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
2046
2047  std::map<SCEVCallbackVH, const SCEV *>::iterator I = Scalars.find(V);
2048  if (I != Scalars.end()) return I->second;
2049  const SCEV *S = createSCEV(V);
2050  Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S));
2051  return S;
2052}
2053
2054/// getIntegerSCEV - Given a SCEVable type, create a constant for the
2055/// specified signed integer value and return a SCEV for the constant.
2056const SCEV *ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
2057  const IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
2058  return getConstant(ConstantInt::get(ITy, Val));
2059}
2060
2061/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
2062///
2063const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
2064  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2065    return getConstant(
2066               cast<ConstantInt>(Context->getConstantExprNeg(VC->getValue())));
2067
2068  const Type *Ty = V->getType();
2069  Ty = getEffectiveSCEVType(Ty);
2070  return getMulExpr(V,
2071                  getConstant(cast<ConstantInt>(Context->getAllOnesValue(Ty))));
2072}
2073
2074/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
2075const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
2076  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2077    return getConstant(
2078                cast<ConstantInt>(Context->getConstantExprNot(VC->getValue())));
2079
2080  const Type *Ty = V->getType();
2081  Ty = getEffectiveSCEVType(Ty);
2082  const SCEV *AllOnes =
2083                   getConstant(cast<ConstantInt>(Context->getAllOnesValue(Ty)));
2084  return getMinusSCEV(AllOnes, V);
2085}
2086
2087/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
2088///
2089const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS,
2090                                          const SCEV *RHS) {
2091  // X - Y --> X + -Y
2092  return getAddExpr(LHS, getNegativeSCEV(RHS));
2093}
2094
2095/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
2096/// input value to the specified type.  If the type must be extended, it is zero
2097/// extended.
2098const SCEV *
2099ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V,
2100                                         const Type *Ty) {
2101  const Type *SrcTy = V->getType();
2102  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2103         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2104         "Cannot truncate or zero extend with non-integer arguments!");
2105  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2106    return V;  // No conversion
2107  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2108    return getTruncateExpr(V, Ty);
2109  return getZeroExtendExpr(V, Ty);
2110}
2111
2112/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
2113/// input value to the specified type.  If the type must be extended, it is sign
2114/// extended.
2115const SCEV *
2116ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
2117                                         const Type *Ty) {
2118  const Type *SrcTy = V->getType();
2119  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2120         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2121         "Cannot truncate or sign extend with non-integer arguments!");
2122  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2123    return V;  // No conversion
2124  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2125    return getTruncateExpr(V, Ty);
2126  return getSignExtendExpr(V, Ty);
2127}
2128
2129/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
2130/// input value to the specified type.  If the type must be extended, it is zero
2131/// extended.  The conversion must not be narrowing.
2132const SCEV *
2133ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) {
2134  const Type *SrcTy = V->getType();
2135  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2136         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2137         "Cannot noop or zero extend with non-integer arguments!");
2138  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2139         "getNoopOrZeroExtend cannot truncate!");
2140  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2141    return V;  // No conversion
2142  return getZeroExtendExpr(V, Ty);
2143}
2144
2145/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
2146/// input value to the specified type.  If the type must be extended, it is sign
2147/// extended.  The conversion must not be narrowing.
2148const SCEV *
2149ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) {
2150  const Type *SrcTy = V->getType();
2151  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2152         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2153         "Cannot noop or sign extend with non-integer arguments!");
2154  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2155         "getNoopOrSignExtend cannot truncate!");
2156  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2157    return V;  // No conversion
2158  return getSignExtendExpr(V, Ty);
2159}
2160
2161/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
2162/// the input value to the specified type. If the type must be extended,
2163/// it is extended with unspecified bits. The conversion must not be
2164/// narrowing.
2165const SCEV *
2166ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) {
2167  const Type *SrcTy = V->getType();
2168  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2169         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2170         "Cannot noop or any extend with non-integer arguments!");
2171  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2172         "getNoopOrAnyExtend cannot truncate!");
2173  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2174    return V;  // No conversion
2175  return getAnyExtendExpr(V, Ty);
2176}
2177
2178/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
2179/// input value to the specified type.  The conversion must not be widening.
2180const SCEV *
2181ScalarEvolution::getTruncateOrNoop(const SCEV *V, const Type *Ty) {
2182  const Type *SrcTy = V->getType();
2183  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2184         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2185         "Cannot truncate or noop with non-integer arguments!");
2186  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
2187         "getTruncateOrNoop cannot extend!");
2188  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2189    return V;  // No conversion
2190  return getTruncateExpr(V, Ty);
2191}
2192
2193/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
2194/// the types using zero-extension, and then perform a umax operation
2195/// with them.
2196const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
2197                                                        const SCEV *RHS) {
2198  const SCEV *PromotedLHS = LHS;
2199  const SCEV *PromotedRHS = RHS;
2200
2201  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2202    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2203  else
2204    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2205
2206  return getUMaxExpr(PromotedLHS, PromotedRHS);
2207}
2208
2209/// getUMinFromMismatchedTypes - Promote the operands to the wider of
2210/// the types using zero-extension, and then perform a umin operation
2211/// with them.
2212const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
2213                                                        const SCEV *RHS) {
2214  const SCEV *PromotedLHS = LHS;
2215  const SCEV *PromotedRHS = RHS;
2216
2217  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2218    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2219  else
2220    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2221
2222  return getUMinExpr(PromotedLHS, PromotedRHS);
2223}
2224
2225/// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
2226/// the specified instruction and replaces any references to the symbolic value
2227/// SymName with the specified value.  This is used during PHI resolution.
2228void
2229ScalarEvolution::ReplaceSymbolicValueWithConcrete(Instruction *I,
2230                                                  const SCEV *SymName,
2231                                                  const SCEV *NewVal) {
2232  std::map<SCEVCallbackVH, const SCEV *>::iterator SI =
2233    Scalars.find(SCEVCallbackVH(I, this));
2234  if (SI == Scalars.end()) return;
2235
2236  const SCEV *NV =
2237    SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, *this);
2238  if (NV == SI->second) return;  // No change.
2239
2240  SI->second = NV;       // Update the scalars map!
2241
2242  // Any instruction values that use this instruction might also need to be
2243  // updated!
2244  for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
2245       UI != E; ++UI)
2246    ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal);
2247}
2248
2249/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
2250/// a loop header, making it a potential recurrence, or it doesn't.
2251///
2252const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
2253  if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
2254    if (const Loop *L = LI->getLoopFor(PN->getParent()))
2255      if (L->getHeader() == PN->getParent()) {
2256        // If it lives in the loop header, it has two incoming values, one
2257        // from outside the loop, and one from inside.
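        // If incoming block 0 is inside the loop it is the back-edge, so the
        // entry edge is edge 1, and vice versa.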
2258        unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
2259        unsigned BackEdge     = IncomingEdge^1;
2260
2261        // While we are analyzing this PHI node, handle its value symbolically.
2262        const SCEV *SymbolicName = getUnknown(PN);
2263        assert(Scalars.find(PN) == Scalars.end() &&
2264               "PHI node already processed?");
2265        Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
2266
2267        // Using this symbolic name for the PHI, analyze the value coming around
2268        // the back-edge.
2269        const SCEV *BEValue = getSCEV(PN->getIncomingValue(BackEdge));
2270
2271        // NOTE: If BEValue is loop invariant, we know that the PHI node just
2272        // has a special value for the first iteration of the loop.
2273
2274        // If the value coming around the backedge is an add with the symbolic
2275        // value we just inserted, then we found a simple induction variable!
2276        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
2277          // If there is a single occurrence of the symbolic value, replace it
2278          // with a recurrence.
2279          unsigned FoundIndex = Add->getNumOperands();
2280          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2281            if (Add->getOperand(i) == SymbolicName)
2282              if (FoundIndex == e) {
2283                FoundIndex = i;
2284                break;
2285              }
2286
2287          if (FoundIndex != Add->getNumOperands()) {
2288            // Create an add with everything but the specified operand.
2289            SmallVector<const SCEV *, 8> Ops;
2290            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2291              if (i != FoundIndex)
2292                Ops.push_back(Add->getOperand(i));
2293            const SCEV *Accum = getAddExpr(Ops);
2294
2295            // This is not a valid addrec if the step amount varies on each
2296            // loop iteration and is not itself an addrec in this loop.
2297            if (Accum->isLoopInvariant(L) ||
2298                (isa<SCEVAddRecExpr>(Accum) &&
2299                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
2300              const SCEV *StartVal =
2301                getSCEV(PN->getIncomingValue(IncomingEdge));
2302              const SCEV *PHISCEV =
2303                getAddRecExpr(StartVal, Accum, L);
2304
2305              // Okay, for the entire analysis of this edge we assumed the PHI
2306              // to be symbolic.  We now need to go back and update all of the
2307              // entries for the scalars that use the PHI (except for the PHI
2308              // itself) to use the new analyzed value instead of the "symbolic"
2309              // value.
2310              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
2311              return PHISCEV;
2312            }
2313          }
2314        } else if (const SCEVAddRecExpr *AddRec =
2315                     dyn_cast<SCEVAddRecExpr>(BEValue)) {
2316          // Otherwise, this could be a loop like this:
2317          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
2318          // In this case, j = {1,+,1}  and BEValue is j.
2319          // Because the other in-value of i (0) fits the evolution of BEValue,
2320          // i really is an addrec evolution.
2321          if (AddRec->getLoop() == L && AddRec->isAffine()) {
2322            const SCEV *StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
2323
2324            // If StartVal = j.start - j.stride, we can use StartVal as the
2325            // initial value of the addrec evolution.
2326            if (StartVal == getMinusSCEV(AddRec->getOperand(0),
2327                                            AddRec->getOperand(1))) {
2328              const SCEV *PHISCEV =
2329                 getAddRecExpr(StartVal, AddRec->getOperand(1), L);
2330
2331              // Okay, for the entire analysis of this edge we assumed the PHI
2332              // to be symbolic.  We now need to go back and update all of the
2333              // entries for the scalars that use the PHI (except for the PHI
2334              // itself) to use the new analyzed value instead of the "symbolic"
2335              // value.
2336              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
2337              return PHISCEV;
2338            }
2339          }
2340        }
2341
2342        return SymbolicName;
2343      }
2344
2345  // If it's not a loop phi, we can't handle it yet.
2346  return getUnknown(PN);
2347}
2348
2349/// createNodeForGEP - Expand GEP instructions into add and multiply
2350/// operations. This allows them to be analyzed by regular SCEV code.
2351///
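/// For example, informally, a GEP computing &p->f[i] is modeled as
/// p + offsetof(f) + i * sizeof(element): struct field indices contribute
/// a constant member offset, and array indices a sign-extended offset
/// scaled by the element's allocation size.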
2352const SCEV *ScalarEvolution::createNodeForGEP(User *GEP) {
2353
2354  const Type *IntPtrTy = TD->getIntPtrType();
2355  Value *Base = GEP->getOperand(0);
2356  // Don't attempt to analyze GEPs over unsized objects.
2357  if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
2358    return getUnknown(GEP);
2359  const SCEV *TotalOffset = getIntegerSCEV(0, IntPtrTy);
2360  gep_type_iterator GTI = gep_type_begin(GEP);
2361  for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()),
2362                                      E = GEP->op_end();
2363       I != E; ++I) {
2364    Value *Index = *I;
2365    // Compute the (potentially symbolic) offset in bytes for this index.
2366    if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
2367      // For a struct, add the member offset.
2368      const StructLayout &SL = *TD->getStructLayout(STy);
2369      unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
2370      uint64_t Offset = SL.getElementOffset(FieldNo);
2371      TotalOffset = getAddExpr(TotalOffset,
2372                                  getIntegerSCEV(Offset, IntPtrTy));
2373    } else {
2374      // For an array, add the element offset, explicitly scaled.
2375      const SCEV *LocalOffset = getSCEV(Index);
2376      if (!isa<PointerType>(LocalOffset->getType()))
2377        // Getelementptr indices are signed.
2378        LocalOffset = getTruncateOrSignExtend(LocalOffset,
2379                                              IntPtrTy);
2380      LocalOffset =
2381        getMulExpr(LocalOffset,
2382                   getIntegerSCEV(TD->getTypeAllocSize(*GTI),
2383                                  IntPtrTy));
2384      TotalOffset = getAddExpr(TotalOffset, LocalOffset);
2385    }
2386  }
2387  return getAddExpr(getSCEV(Base), TotalOffset);
2388}
2389
2390/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
2391/// guaranteed to end in (at every loop iteration).  It is, at the same time,
2392/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
2393/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
2394uint32_t
2395ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
2396  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2397    return C->getValue()->getValue().countTrailingZeros();
2398
2399  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
2400    return std::min(GetMinTrailingZeros(T->getOperand()),
2401                    (uint32_t)getTypeSizeInBits(T->getType()));
2402
2403  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
2404    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
2405    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
2406             getTypeSizeInBits(E->getType()) : OpRes;
2407  }
2408
2409  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
2410    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
2411    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
2412             getTypeSizeInBits(E->getType()) : OpRes;
2413  }
2414
2415  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
2416    // The result is the min of all the operands' results.
2417    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
2418    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2419      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
2420    return MinOpRes;
2421  }
2422
2423  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
2424    // The result is the sum of all the operands' results.
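    // (e.g., a*b is divisible by 8 whenever a is divisible by 4 and b by 2:
    // the factors' trailing zero counts add, capped at the bitwidth.)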
2425    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
2426    uint32_t BitWidth = getTypeSizeInBits(M->getType());
2427    for (unsigned i = 1, e = M->getNumOperands();
2428         SumOpRes != BitWidth && i != e; ++i)
2429      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
2430                          BitWidth);
2431    return SumOpRes;
2432  }
2433
2434  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
2435    // The result is the min of all the operands' results.
2436    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
2437    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2438      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
2439    return MinOpRes;
2440  }
2441
2442  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
2443    // The result is the min of all the operands' results.
2444    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
2445    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2446      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
2447    return MinOpRes;
2448  }
2449
2450  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
2451    // The result is the min of all the operands' results.
2452    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
2453    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2454      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
2455    return MinOpRes;
2456  }
2457
2458  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2459    // For a SCEVUnknown, ask ValueTracking.
2460    unsigned BitWidth = getTypeSizeInBits(U->getType());
2461    APInt Mask = APInt::getAllOnesValue(BitWidth);
2462    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
2463    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
2464    return Zeros.countTrailingOnes();
2465  }
2466
2467  // SCEVUDivExpr and anything else: we don't know anything about the low bits.
2468  return 0;
2469}
2470
2471uint32_t
2472ScalarEvolution::GetMinLeadingZeros(const SCEV *S) {
2473  // TODO: Handle other SCEV expression types here.
2474
2475  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2476    return C->getValue()->getValue().countLeadingZeros();
2477
2478  if (const SCEVZeroExtendExpr *C = dyn_cast<SCEVZeroExtendExpr>(S)) {
2479    // A zero-extension cast adds zero bits.
2480    return GetMinLeadingZeros(C->getOperand()) +
2481           (getTypeSizeInBits(C->getType()) -
2482            getTypeSizeInBits(C->getOperand()->getType()));
2483  }
2484
2485  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2486    // For a SCEVUnknown, ask ValueTracking.
2487    unsigned BitWidth = getTypeSizeInBits(U->getType());
2488    APInt Mask = APInt::getAllOnesValue(BitWidth);
2489    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
2490    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
2491    return Zeros.countLeadingOnes();
2492  }
2493
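  // Anything else (e.g., an add or mul of unknowns) may be negative or
  // all-ones, so conservatively no leading zero bits are guaranteed.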
2494  return 0;
2495}
2496
2497uint32_t
2498ScalarEvolution::GetMinSignBits(const SCEV *S) {
2499  // TODO: Handle other SCEV expression types here.
2500
2501  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
2502    const APInt &A = C->getValue()->getValue();
2503    return A.isNegative() ? A.countLeadingOnes() :
2504                            A.countLeadingZeros();
2505  }
2506
2507  if (const SCEVSignExtendExpr *C = dyn_cast<SCEVSignExtendExpr>(S)) {
2508    // A sign-extension cast adds sign bits.
2509    return GetMinSignBits(C->getOperand()) +
2510           (getTypeSizeInBits(C->getType()) -
2511            getTypeSizeInBits(C->getOperand()->getType()));
2512  }
2513
2514  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
2515    unsigned BitWidth = getTypeSizeInBits(A->getType());
2516
2517    // Special case decrementing a value (ADD X, -1):
2518    if (const SCEVConstant *CRHS = dyn_cast<SCEVConstant>(A->getOperand(0)))
2519      if (CRHS->isAllOnesValue()) {
2520        SmallVector<const SCEV *, 4> OtherOps(A->op_begin() + 1, A->op_end());
2521        const SCEV *OtherOpsAdd = getAddExpr(OtherOps);
2522        unsigned LZ = GetMinLeadingZeros(OtherOpsAdd);
2523
2524        // If the input is known to be 0 or 1, the output is 0/-1, which is all
2525        // sign bits set.
2526        if (LZ == BitWidth - 1)
2527          return BitWidth;
2528
2529        // If we are subtracting one from a positive number, there is no carry
2530        // out of the result.
2531        if (LZ > 0)
2532          return GetMinSignBits(OtherOpsAdd);
2533      }
2534
2535    // Add can have at most one carry bit.  Thus we know that the output
2536    // is, at worst, one more bit than the inputs.
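    // (Folding e operands performs e-1 additions, so the loop below
    // conservatively retires one sign bit per operand processed.)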
2537    unsigned Min = BitWidth;
2538    for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
2539      unsigned N = GetMinSignBits(A->getOperand(i));
2540      Min = std::min(Min, N) - 1;
2541      if (Min == 0) return 1;
2542    }
2543    return Min;
2544  }
2545
2546  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2547    // For a SCEVUnknown, ask ValueTracking.
2548    return ComputeNumSignBits(U->getValue(), TD);
2549  }
2550
2551  return 1;
2552}
2553
2554/// createSCEV - We know that there is no SCEV for the specified value.
2555/// Analyze the expression.
2556///
2557const SCEV *ScalarEvolution::createSCEV(Value *V) {
2558  if (!isSCEVable(V->getType()))
2559    return getUnknown(V);
2560
2561  unsigned Opcode = Instruction::UserOp1;
2562  if (Instruction *I = dyn_cast<Instruction>(V))
2563    Opcode = I->getOpcode();
2564  else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
2565    Opcode = CE->getOpcode();
2566  else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
2567    return getConstant(CI);
2568  else if (isa<ConstantPointerNull>(V))
2569    return getIntegerSCEV(0, V->getType());
2570  else if (isa<UndefValue>(V))
2571    return getIntegerSCEV(0, V->getType());
2572  else
2573    return getUnknown(V);
2574
2575  User *U = cast<User>(V);
2576  switch (Opcode) {
2577  case Instruction::Add:
2578    return getAddExpr(getSCEV(U->getOperand(0)),
2579                      getSCEV(U->getOperand(1)));
2580  case Instruction::Mul:
2581    return getMulExpr(getSCEV(U->getOperand(0)),
2582                      getSCEV(U->getOperand(1)));
2583  case Instruction::UDiv:
2584    return getUDivExpr(getSCEV(U->getOperand(0)),
2585                       getSCEV(U->getOperand(1)));
2586  case Instruction::Sub:
2587    return getMinusSCEV(getSCEV(U->getOperand(0)),
2588                        getSCEV(U->getOperand(1)));
2589  case Instruction::And:
2590    // For an expression like x&255 that merely masks off the high bits,
2591    // use zext(trunc(x)) as the SCEV expression.
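    // E.g., for an i32 value x, "x & 255" is modeled as a truncate of x
    // to i8 followed by a zero-extend back to i32.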
2592    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2593      if (CI->isNullValue())
2594        return getSCEV(U->getOperand(1));
2595      if (CI->isAllOnesValue())
2596        return getSCEV(U->getOperand(0));
2597      const APInt &A = CI->getValue();
2598
2599      // Instcombine's ShrinkDemandedConstant may strip bits out of
2600      // constants, obscuring what would otherwise be a low-bits mask.
2601      // Use ComputeMaskedBits to recover what ShrinkDemandedConstant
2602      // knew, and reconstruct a low-bits mask value.
2603      unsigned LZ = A.countLeadingZeros();
2604      unsigned BitWidth = A.getBitWidth();
2605      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
2606      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
2607      ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD);
2608
2609      APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
2610
2611      if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
2612        return
2613          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
2614                                            IntegerType::get(BitWidth - LZ)),
2615                            U->getType());
2616    }
2617    break;
2618
2619  case Instruction::Or:
2620    // If the RHS of the Or is a constant, we may have something like:
2621    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
2622    // optimizations will transparently handle this case.
2623    //
2624    // In order for this transformation to be safe, the LHS must be of the
2625    // form X*(2^n) and the Or constant must be less than 2^n.
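    // E.g., "(X*4) | 1" can be handled as "(X*4) + 1", because the low two
    // bits of X*4 are known to be zero.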
2626    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2627      const SCEV *LHS = getSCEV(U->getOperand(0));
2628      const APInt &CIVal = CI->getValue();
2629      if (GetMinTrailingZeros(LHS) >=
2630          (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
2631        return getAddExpr(LHS, getSCEV(U->getOperand(1)));
2632    }
2633    break;
2634  case Instruction::Xor:
2635    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2636      // If the RHS of the xor is a signbit, then this is just an add.
2637      // Instcombine turns add of signbit into xor as a strength reduction step.
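      // E.g., for an i32 value x, "x ^ 0x80000000" equals "x + 0x80000000":
      // adding the sign bit flips it, and any carry falls off the top.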
2638      if (CI->getValue().isSignBit())
2639        return getAddExpr(getSCEV(U->getOperand(0)),
2640                          getSCEV(U->getOperand(1)));
2641
2642      // If the RHS of xor is -1, then this is a not operation.
2643      if (CI->isAllOnesValue())
2644        return getNotSCEV(getSCEV(U->getOperand(0)));
2645
2646      // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
2647      // This is a variant of the check for xor with -1, and it handles
2648      // the case where instcombine has trimmed non-demanded bits out
2649      // of an xor with -1.
2650      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
2651        if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
2652          if (BO->getOpcode() == Instruction::And &&
2653              LCI->getValue() == CI->getValue())
2654            if (const SCEVZeroExtendExpr *Z =
2655                  dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
2656              const Type *UTy = U->getType();
2657              const SCEV *Z0 = Z->getOperand();
2658              const Type *Z0Ty = Z0->getType();
2659              unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
2660
2661              // If C is a low-bits mask, the zero extend is serving to
2662              // mask off the high bits. Complement the operand and
2663              // re-apply the zext.
2664              if (APIntOps::isMask(Z0TySize, CI->getValue()))
2665                return getZeroExtendExpr(getNotSCEV(Z0), UTy);
2666
2667              // If C is a single bit, it may be in the sign-bit position
2668              // before the zero-extend. In this case, represent the xor
2669              // using an add, which is equivalent, and re-apply the zext.
2670              APInt Trunc = APInt(CI->getValue()).trunc(Z0TySize);
2671              if (APInt(Trunc).zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
2672                  Trunc.isSignBit())
2673                return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
2674                                         UTy);
2675            }
2676    }
2677    break;
2678
2679  case Instruction::Shl:
2680    // Turn shift left of a constant amount into a multiply.
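    // E.g., "X << 3" becomes "X * 8".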
2681    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
2682      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
2683      Constant *X = ConstantInt::get(
2684        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
2685      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
2686    }
2687    break;
2688
2689  case Instruction::LShr:
2690    // Turn logical shift right of a constant into an unsigned divide.
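    // E.g., "X >>u 3" becomes "X /u 8".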
2691    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
2692      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
2693      Constant *X = ConstantInt::get(
2694        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
2695      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
2696    }
2697    break;
2698
2699  case Instruction::AShr:
2700    // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
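    // E.g., for an i32 value x, "(x << 24) >>s 24" is modeled as a truncate
    // of x to i8 followed by a sign-extend back to i32.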
2701    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
2702      if (Instruction *L = dyn_cast<Instruction>(U->getOperand(0)))
2703        if (L->getOpcode() == Instruction::Shl &&
2704            L->getOperand(1) == U->getOperand(1)) {
2705          unsigned BitWidth = getTypeSizeInBits(U->getType());
2706          uint64_t Amt = BitWidth - CI->getZExtValue();
2707          if (Amt == BitWidth)
2708            return getSCEV(L->getOperand(0));       // shift by zero --> noop
2709          if (Amt > BitWidth)
2710            return getIntegerSCEV(0, U->getType()); // value is undefined
2711          return
2712            getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
2713                                              IntegerType::get(Amt)),
2714                              U->getType());
2715        }
2716    break;
2717
2718  case Instruction::Trunc:
2719    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
2720
2721  case Instruction::ZExt:
2722    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
2723
2724  case Instruction::SExt:
2725    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
2726
2727  case Instruction::BitCast:
2728    // BitCasts are no-op casts so we just eliminate the cast.
2729    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
2730      return getSCEV(U->getOperand(0));
2731    break;
2732
2733  case Instruction::IntToPtr:
2734    if (!TD) break; // Without TD we can't analyze pointers.
2735    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
2736                                   TD->getIntPtrType());
2737
2738  case Instruction::PtrToInt:
2739    if (!TD) break; // Without TD we can't analyze pointers.
2740    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
2741                                   U->getType());
2742
2743  case Instruction::GetElementPtr:
2744    if (!TD) break; // Without TD we can't analyze pointers.
2745    return createNodeForGEP(U);
2746
2747  case Instruction::PHI:
2748    return createNodeForPHI(cast<PHINode>(U));
2749
2750  case Instruction::Select:
2751    // This could be a smax or umax that was lowered earlier.
2752    // Try to recover it.
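    // E.g., "x > y ? x : y" is recovered as smax(x, y), and
    // "n != 0 ? n : 1" as umax(n, 1).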
2753    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
2754      Value *LHS = ICI->getOperand(0);
2755      Value *RHS = ICI->getOperand(1);
2756      switch (ICI->getPredicate()) {
2757      case ICmpInst::ICMP_SLT:
2758      case ICmpInst::ICMP_SLE:
2759        std::swap(LHS, RHS);
2760        // fall through
2761      case ICmpInst::ICMP_SGT:
2762      case ICmpInst::ICMP_SGE:
2763        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2764          return getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
2765        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2766          return getSMinExpr(getSCEV(LHS), getSCEV(RHS));
2767        break;
2768      case ICmpInst::ICMP_ULT:
2769      case ICmpInst::ICMP_ULE:
2770        std::swap(LHS, RHS);
2771        // fall through
2772      case ICmpInst::ICMP_UGT:
2773      case ICmpInst::ICMP_UGE:
2774        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2775          return getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
2776        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2777          return getUMinExpr(getSCEV(LHS), getSCEV(RHS));
2778        break;
2779      case ICmpInst::ICMP_NE:
2780        // n != 0 ? n : 1  ->  umax(n, 1)
2781        if (LHS == U->getOperand(1) &&
2782            isa<ConstantInt>(U->getOperand(2)) &&
2783            cast<ConstantInt>(U->getOperand(2))->isOne() &&
2784            isa<ConstantInt>(RHS) &&
2785            cast<ConstantInt>(RHS)->isZero())
2786          return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(2)));
2787        break;
2788      case ICmpInst::ICMP_EQ:
2789        // n == 0 ? 1 : n  ->  umax(n, 1)
2790        if (LHS == U->getOperand(2) &&
2791            isa<ConstantInt>(U->getOperand(1)) &&
2792            cast<ConstantInt>(U->getOperand(1))->isOne() &&
2793            isa<ConstantInt>(RHS) &&
2794            cast<ConstantInt>(RHS)->isZero())
2795          return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(1)));
2796        break;
2797      default:
2798        break;
2799      }
2800    }
2801    break;
2802  default: // We cannot analyze this expression.
2803    break;
2804  }
2805
2806  return getUnknown(V);
2807}
2808
2809
2810
2811//===----------------------------------------------------------------------===//
2812//                   Iteration Count Computation Code
2813//
2814
2815/// getBackedgeTakenCount - If the specified loop has a predictable
2816/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
2817/// object. The backedge-taken count is the number of times the loop header
2818/// will be branched to from within the loop. This is one less than the
2819/// trip count of the loop, since it doesn't count the first iteration,
2820/// when the header is branched to from outside the loop.
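///
/// For example, for "for (i = 0; i != n; ++i)" with n >= 1, the trip
/// count is n and the backedge-taken count is n-1.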
2821///
2822/// Note that it is not valid to call this method on a loop without a
2823/// loop-invariant backedge-taken count (see
2824/// hasLoopInvariantBackedgeTakenCount).
2825///
2826const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
2827  return getBackedgeTakenInfo(L).Exact;
2828}
2829
2830/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
2831/// return the least SCEV value that is known never to be less than the
2832/// actual backedge taken count.
2833const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
2834  return getBackedgeTakenInfo(L).Max;
2835}
2836
2837/// PushLoopPHIs - Push PHI nodes in the header of the given loop
2838/// onto the given Worklist.
2839static void
2840PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
2841  BasicBlock *Header = L->getHeader();
2842
2843  // Push all Loop-header PHIs onto the Worklist stack.
2844  for (BasicBlock::iterator I = Header->begin();
2845       PHINode *PN = dyn_cast<PHINode>(I); ++I)
2846    Worklist.push_back(PN);
2847}
2848
2849/// PushDefUseChildren - Push users of the given Instruction
2850/// onto the given Worklist.
2851static void
2852PushDefUseChildren(Instruction *I,
2853                   SmallVectorImpl<Instruction *> &Worklist) {
2854  // Push the def-use children onto the Worklist stack.
2855  for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2856       UI != UE; ++UI)
2857    Worklist.push_back(cast<Instruction>(UI));
2858}
2859
2860const ScalarEvolution::BackedgeTakenInfo &
2861ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
2862  // Initially insert a CouldNotCompute for this loop. If the insertion
2863  // succeeds, proceed to actually compute a backedge-taken count and
2864  // update the value. The temporary CouldNotCompute value tells SCEV
2865  // code elsewhere that it shouldn't attempt to request a new
2866  // backedge-taken count, which could result in infinite recursion.
2867  std::pair<std::map<const Loop*, BackedgeTakenInfo>::iterator, bool> Pair =
2868    BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
2869  if (Pair.second) {
2870    BackedgeTakenInfo ItCount = ComputeBackedgeTakenCount(L);
2871    if (ItCount.Exact != getCouldNotCompute()) {
2872      assert(ItCount.Exact->isLoopInvariant(L) &&
2873             ItCount.Max->isLoopInvariant(L) &&
2874             "Computed trip count isn't loop invariant for loop!");
2875      ++NumTripCountsComputed;
2876
2877      // Update the value in the map.
2878      Pair.first->second = ItCount;
2879    } else {
2880      if (ItCount.Max != getCouldNotCompute())
2881        // Update the value in the map.
2882        Pair.first->second = ItCount;
2883      if (isa<PHINode>(L->getHeader()->begin()))
2884        // Only count loops that have phi nodes as not being computable.
2885        ++NumTripCountsNotComputed;
2886    }
2887
2888    // Now that we know more about the trip count for this loop, forget any
2889    // existing SCEV values for PHI nodes in this loop since they are only
2890    // conservative estimates made without the benefit of trip count
2891    // information. This is similar to the code in
2892    // forgetLoopBackedgeTakenCount, except that it handles SCEVUnknown PHI
2893    // nodes specially.
2894    if (ItCount.hasAnyInfo()) {
2895      SmallVector<Instruction *, 16> Worklist;
2896      PushLoopPHIs(L, Worklist);
2897
2898      SmallPtrSet<Instruction *, 8> Visited;
2899      while (!Worklist.empty()) {
2900        Instruction *I = Worklist.pop_back_val();
2901        if (!Visited.insert(I)) continue;
2902
2903        std::map<SCEVCallbackVH, const SCEV*>::iterator It =
2904          Scalars.find(static_cast<Value *>(I));
2905        if (It != Scalars.end()) {
2906          // SCEVUnknown for a PHI either means that it has an unrecognized
2907    // structure, or it's a PHI that's in the process of being computed
2908    // by createNodeForPHI.  In the former case, additional loop trip count
2909    // information isn't going to change anything. In the latter case,
2910          // createNodeForPHI will perform the necessary updates on its own when
2911          // it gets to that point.
2912          if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second))
2913            Scalars.erase(It);
2914          ValuesAtScopes.erase(I);
2915          if (PHINode *PN = dyn_cast<PHINode>(I))
2916            ConstantEvolutionLoopExitValue.erase(PN);
2917        }
2918
2919        PushDefUseChildren(I, Worklist);
2920      }
2921    }
2922  }
2923  return Pair.first->second;
2924}
2925
2926/// forgetLoopBackedgeTakenCount - This method should be called by the
2927/// client when it has changed a loop in a way that may affect
2928/// ScalarEvolution's ability to compute a trip count, or if the loop
2929/// is deleted.
2930void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) {
2931  BackedgeTakenCounts.erase(L);
2932
2933  SmallVector<Instruction *, 16> Worklist;
2934  PushLoopPHIs(L, Worklist);
2935
2936  SmallPtrSet<Instruction *, 8> Visited;
2937  while (!Worklist.empty()) {
2938    Instruction *I = Worklist.pop_back_val();
2939    if (!Visited.insert(I)) continue;
2940
2941    std::map<SCEVCallbackVH, const SCEV*>::iterator It =
2942      Scalars.find(static_cast<Value *>(I));
2943    if (It != Scalars.end()) {
2944      Scalars.erase(It);
2945      ValuesAtScopes.erase(I);
2946      if (PHINode *PN = dyn_cast<PHINode>(I))
2947        ConstantEvolutionLoopExitValue.erase(PN);
2948    }
2949
2950    PushDefUseChildren(I, Worklist);
2951  }
2952}
2953
2954/// ComputeBackedgeTakenCount - Compute the number of times the backedge
2955/// of the specified loop will execute.
2956ScalarEvolution::BackedgeTakenInfo
2957ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
2958  SmallVector<BasicBlock*, 8> ExitingBlocks;
2959  L->getExitingBlocks(ExitingBlocks);
2960
2961  // Examine all exits and pick the most conservative values.
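  // E.g., if one exit would be taken after 7 backedges and another after
  // 10, the loop stops at whichever exit triggers first, so the
  // backedge-taken count is umin(7, 10) = 7.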
2962  const SCEV *BECount = getCouldNotCompute();
2963  const SCEV *MaxBECount = getCouldNotCompute();
2964  bool CouldNotComputeBECount = false;
2965  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
2966    BackedgeTakenInfo NewBTI =
2967      ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]);
2968
2969    if (NewBTI.Exact == getCouldNotCompute()) {
2970      // We couldn't compute an exact value for this exit, so
2971      // we won't be able to compute an exact value for the loop.
2972      CouldNotComputeBECount = true;
2973      BECount = getCouldNotCompute();
2974    } else if (!CouldNotComputeBECount) {
2975      if (BECount == getCouldNotCompute())
2976        BECount = NewBTI.Exact;
2977      else
2978        BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact);
2979    }
2980    if (MaxBECount == getCouldNotCompute())
2981      MaxBECount = NewBTI.Max;
2982    else if (NewBTI.Max != getCouldNotCompute())
2983      MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max);
2984  }
2985
2986  return BackedgeTakenInfo(BECount, MaxBECount);
2987}
2988
2989/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge
2990/// of the specified loop will execute if it exits via the specified block.
2991ScalarEvolution::BackedgeTakenInfo
2992ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
2993                                                   BasicBlock *ExitingBlock) {
2994
2995  // Okay, we've chosen an exiting block.  See what condition causes us to
2996  // exit at this block.
2997  //
2998  // FIXME: we should be able to handle switch instructions (with a single exit)
2999  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
3000  if (ExitBr == 0) return getCouldNotCompute();
3001  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
3002
3003  // At this point, we know we have a conditional branch that determines whether
3004  // the loop is exited.  However, we don't know if the branch is executed each
3005  // time through the loop.  If not, then the execution count of the branch will
3006  // not be equal to the trip count of the loop.
3007  //
3008  // Currently we check for this by checking to see if the Exit branch goes to
3009  // the loop header.  If so, we know it will always execute the same number of
3010  // times as the loop.  We also handle the case where the exit block *is* the
3011  // loop header.  This is common for un-rotated loops.
3012  //
3013  // If both of those tests fail, walk up the unique predecessor chain to the
3014  // header, stopping if there is an edge that doesn't exit the loop. If the
3015  // header is reached, the execution count of the branch will be equal to the
3016  // trip count of the loop.
3017  //
3018  //  More extensive analysis could be done to handle more cases here.
3019  //
3020  if (ExitBr->getSuccessor(0) != L->getHeader() &&
3021      ExitBr->getSuccessor(1) != L->getHeader() &&
3022      ExitBr->getParent() != L->getHeader()) {
3023    // The simple checks failed, try climbing the unique predecessor chain
3024    // up to the header.
3025    bool Ok = false;
3026    for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
3027      BasicBlock *Pred = BB->getUniquePredecessor();
3028      if (!Pred)
3029        return getCouldNotCompute();
3030      TerminatorInst *PredTerm = Pred->getTerminator();
3031      for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
3032        BasicBlock *PredSucc = PredTerm->getSuccessor(i);
3033        if (PredSucc == BB)
3034          continue;
3035        // If the predecessor has a successor that isn't BB and isn't
3036        // outside the loop, assume the worst.
3037        if (L->contains(PredSucc))
3038          return getCouldNotCompute();
3039      }
3040      if (Pred == L->getHeader()) {
3041        Ok = true;
3042        break;
3043      }
3044      BB = Pred;
3045    }
3046    if (!Ok)
3047      return getCouldNotCompute();
3048  }
3049
3050  // Proceed to the next level to examine the exit condition expression.
3051  return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(),
3052                                               ExitBr->getSuccessor(0),
3053                                               ExitBr->getSuccessor(1));
3054}
3055
3056/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the
3057/// backedge of the specified loop will execute if its exit condition
3058/// were a conditional branch of ExitCond, TBB, and FBB.
3059ScalarEvolution::BackedgeTakenInfo
3060ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
3061                                                       Value *ExitCond,
3062                                                       BasicBlock *TBB,
3063                                                       BasicBlock *FBB) {
3064  // Check if the controlling expression for this loop is an And or Or.
3065  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
3066    if (BO->getOpcode() == Instruction::And) {
3067      // Recurse on the operands of the and.
3068      BackedgeTakenInfo BTI0 =
3069        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
3070      BackedgeTakenInfo BTI1 =
3071        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
3072      const SCEV *BECount = getCouldNotCompute();
3073      const SCEV *MaxBECount = getCouldNotCompute();
3074      if (L->contains(TBB)) {
3075        // Both conditions must be true for the loop to continue executing.
3076        // Choose the less conservative count.
3077        if (BTI0.Exact == getCouldNotCompute() ||
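        // E.g., for "while (c0 && c1)" the loop exits as soon as either
        // condition fails, so the backedge is taken umin(BTI0, BTI1) times.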
3078            BTI1.Exact == getCouldNotCompute())
3079          BECount = getCouldNotCompute();
3080        else
3081          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3082        if (BTI0.Max == getCouldNotCompute())
3083          MaxBECount = BTI1.Max;
3084        else if (BTI1.Max == getCouldNotCompute())
3085          MaxBECount = BTI0.Max;
3086        else
3087          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
3088      } else {
3089        // Both conditions must be true for the loop to exit.
3090        assert(L->contains(FBB) && "Loop block has no successor in loop!");
3091        if (BTI0.Exact != getCouldNotCompute() &&
3092            BTI1.Exact != getCouldNotCompute())
3093          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3094        if (BTI0.Max != getCouldNotCompute() &&
3095            BTI1.Max != getCouldNotCompute())
3096          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
3097      }
3098
3099      return BackedgeTakenInfo(BECount, MaxBECount);
3100    }
3101    if (BO->getOpcode() == Instruction::Or) {
3102      // Recurse on the operands of the or.
3103      BackedgeTakenInfo BTI0 =
3104        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
3105      BackedgeTakenInfo BTI1 =
3106        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
3107      const SCEV *BECount = getCouldNotCompute();
3108      const SCEV *MaxBECount = getCouldNotCompute();
3109      if (L->contains(FBB)) {
3110        // Both conditions must be false for the loop to continue executing.
3111        // Choose the less conservative count.
3112        if (BTI0.Exact == getCouldNotCompute() ||
3113            BTI1.Exact == getCouldNotCompute())
3114          BECount = getCouldNotCompute();
3115        else
3116          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3117        if (BTI0.Max == getCouldNotCompute())
3118          MaxBECount = BTI1.Max;
3119        else if (BTI1.Max == getCouldNotCompute())
3120          MaxBECount = BTI0.Max;
3121        else
3122          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
3123      } else {
3124        // Both conditions must be false for the loop to exit.
3125        assert(L->contains(TBB) && "Loop block has no successor in loop!");
3126        if (BTI0.Exact != getCouldNotCompute() &&
3127            BTI1.Exact != getCouldNotCompute())
3128          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3129        if (BTI0.Max != getCouldNotCompute() &&
3130            BTI1.Max != getCouldNotCompute())
3131          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
3132      }
3133
3134      return BackedgeTakenInfo(BECount, MaxBECount);
3135    }
3136  }
3137
3138  // With an icmp, it may be feasible to compute an exact backedge-taken count.
3139  // Proceed to the next level to examine the icmp.
3140  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
3141    return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB);
3142
3143  // If it's not an integer or pointer comparison then compute it the hard way.
3144  return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
3145}
3146
3147/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the
3148/// backedge of the specified loop will execute if its exit condition
3149/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
3150ScalarEvolution::BackedgeTakenInfo
3151ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
3152                                                           ICmpInst *ExitCond,
3153                                                           BasicBlock *TBB,
3154                                                           BasicBlock *FBB) {
3155
3156  // If the condition was exit on true, convert the condition to exit on false
3157  ICmpInst::Predicate Cond;
3158  if (!L->contains(FBB))
3159    Cond = ExitCond->getPredicate();
3160  else
3161    Cond = ExitCond->getInversePredicate();
3162
3163  // Handle common loops like: for (X = "string"; *X; ++X)
3164  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
3165    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
3166      const SCEV *ItCnt =
3167        ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
3168      if (!isa<SCEVCouldNotCompute>(ItCnt)) {
3169        unsigned BitWidth = getTypeSizeInBits(ItCnt->getType());
3170        return BackedgeTakenInfo(ItCnt,
3171                                 isa<SCEVConstant>(ItCnt) ? ItCnt :
3172                                   getConstant(APInt::getMaxValue(BitWidth)-1));
3173      }
3174    }
3175
3176  const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
3177  const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
3178
3179  // Try to evaluate any dependencies out of the loop.
3180  LHS = getSCEVAtScope(LHS, L);
3181  RHS = getSCEVAtScope(RHS, L);
3182
3183  // At this point, we would like to compute for how many iterations of
3184  // the loop the predicate will return true for these inputs.
3185  if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
3186    // If there is a loop-invariant, force it into the RHS.
3187    std::swap(LHS, RHS);
3188    Cond = ICmpInst::getSwappedPredicate(Cond);
3189  }
3190
3191  // If we have a comparison of a chrec against a constant, try to use value
3192  // ranges to answer this query.
3193  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
3194    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
3195      if (AddRec->getLoop() == L) {
3196        // Form the constant range.
3197        ConstantRange CompRange(
3198            ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
3199
3200        const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
3201        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
3202      }
3203
3204  switch (Cond) {
3205  case ICmpInst::ICMP_NE: {                     // while (X != Y)
3206    // Convert to: while (X-Y != 0)
3207    const SCEV *TC = HowFarToZero(getMinusSCEV(LHS, RHS), L);
3208    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
3209    break;
3210  }
3211  case ICmpInst::ICMP_EQ: {                     // while (X == Y)
3212    // Convert to: while (X-Y == 0)
3213    const SCEV *TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
3214    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
3215    break;
3216  }
3217  case ICmpInst::ICMP_SLT: {
3218    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true);
3219    if (BTI.hasAnyInfo()) return BTI;
3220    break;
3221  }
3222  case ICmpInst::ICMP_SGT: {
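    // X >s Y is equivalent to ~X <s ~Y, so reuse HowManyLessThans.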
3223    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
3224                                             getNotSCEV(RHS), L, true);
3225    if (BTI.hasAnyInfo()) return BTI;
3226    break;
3227  }
3228  case ICmpInst::ICMP_ULT: {
3229    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false);
3230    if (BTI.hasAnyInfo()) return BTI;
3231    break;
3232  }
3233  case ICmpInst::ICMP_UGT: {
3234    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
3235                                             getNotSCEV(RHS), L, false);
3236    if (BTI.hasAnyInfo()) return BTI;
3237    break;
3238  }
3239  default:
3240#if 0
3241    errs() << "ComputeBackedgeTakenCount ";
3242    if (ExitCond->getOperand(0)->getType()->isUnsigned())
3243      errs() << "[unsigned] ";
3244    errs() << *LHS << "   "
3245         << Instruction::getOpcodeName(Instruction::ICmp)
3246         << "   " << *RHS << "\n";
3247#endif
3248    break;
3249  }
3250  return
3251    ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
3252}
3253
3254static ConstantInt *
3255EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
3256                                ScalarEvolution &SE) {
3257  const SCEV *InVal = SE.getConstant(C);
3258  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
3259  assert(isa<SCEVConstant>(Val) &&
3260         "Evaluation of SCEV at constant didn't fold correctly?");
3261  return cast<SCEVConstant>(Val)->getValue();
3262}
3263
3264/// GetAddressedElementFromGlobal - Given a global variable with an initializer
3265/// and a GEP expression (missing the pointer index) indexing into it, return
3266/// the addressed element of the initializer or null if the index expression is
3267/// invalid.
3268static Constant *
3269GetAddressedElementFromGlobal(LLVMContext *Context, GlobalVariable *GV,
3270                              const std::vector<ConstantInt*> &Indices) {
3271  Constant *Init = GV->getInitializer();
3272  for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
3273    uint64_t Idx = Indices[i]->getZExtValue();
3274    if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
3275      assert(Idx < CS->getNumOperands() && "Bad struct index!");
3276      Init = cast<Constant>(CS->getOperand(Idx));
3277    } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
3278      if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
3279      Init = cast<Constant>(CA->getOperand(Idx));
3280    } else if (isa<ConstantAggregateZero>(Init)) {
3281      if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
3282        assert(Idx < STy->getNumElements() && "Bad struct index!");
3283        Init = Context->getNullValue(STy->getElementType(Idx));
3284      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
3285        if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
3286        Init = Context->getNullValue(ATy->getElementType());
3287      } else {
3288        LLVM_UNREACHABLE("Unknown constant aggregate type!");
3289      }
3290      return 0;
3291    } else {
3292      return 0; // Unknown initializer type
3293    }
3294  }
3295  return Init;
3296}
3297
3298/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of
3299/// 'icmp op load X, cst', try to see if we can compute the backedge
3300/// execution count.
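/// This handles loops like "for (i = 0; table[i] != 7; ++i)", where "table"
/// is a constant global array: the load can be constant-folded for each
/// candidate iteration and compared against the constant.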
3301const SCEV *
3302ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount(
3303                                                LoadInst *LI,
3304                                                Constant *RHS,
3305                                                const Loop *L,
3306                                                ICmpInst::Predicate predicate) {
3307  if (LI->isVolatile()) return getCouldNotCompute();
3308
3309  // Check to see if the loaded pointer is a getelementptr of a global.
3310  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
3311  if (!GEP) return getCouldNotCompute();
3312
3313  // Make sure that it is really a constant global we are gepping, with an
3314  // initializer, and make sure the first IDX is really 0.
3315  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
3316  if (!GV || !GV->isConstant() || !GV->hasInitializer() ||
3317      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
3318      !cast<Constant>(GEP->getOperand(1))->isNullValue())
3319    return getCouldNotCompute();
3320
3321  // Okay, we allow one non-constant index into the GEP instruction.
3322  Value *VarIdx = 0;
3323  std::vector<ConstantInt*> Indexes;
3324  unsigned VarIdxNum = 0;
3325  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
3326    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
3327      Indexes.push_back(CI);
3328    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
3329      if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
3330      VarIdx = GEP->getOperand(i);
3331      VarIdxNum = i-2;
3332      Indexes.push_back(0);
3333    }
3334
3335  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
3336  // Check to see if X is a loop-variant value now.
3337  const SCEV *Idx = getSCEV(VarIdx);
3338  Idx = getSCEVAtScope(Idx, L);
3339
3340  // We can only recognize very limited forms of loop index expressions, in
3341  // particular, only affine AddRec's like {C1,+,C2}.
3342  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
3343  if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
3344      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
3345      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
3346    return getCouldNotCompute();
3347
3348  unsigned MaxSteps = MaxBruteForceIterations;
3349  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
3350    ConstantInt *ItCst =
3351      ConstantInt::get(cast<IntegerType>(IdxExpr->getType()), IterationNum);
3352    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
3353
3354    // Form the GEP offset.
3355    Indexes[VarIdxNum] = Val;
3356
3357    Constant *Result = GetAddressedElementFromGlobal(Context, GV, Indexes);
3358    if (Result == 0) break;  // Cannot compute!
3359
3360    // Evaluate the condition for this iteration.
3361    Result = ConstantExpr::getICmp(predicate, Result, RHS);
3362    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
3363    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
3364#if 0
3365      errs() << "\n***\n*** Computed loop count " << *ItCst
3366             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
3367             << "***\n";
3368#endif
3369      ++NumArrayLenItCounts;
3370      return getConstant(ItCst);   // Found terminating iteration!
3371    }
3372  }
3373  return getCouldNotCompute();
3374}
3375
3376
3377/// CanConstantFold - Return true if we can constant fold an instruction of the
3378/// specified type, assuming that all operands were constants.
3379static bool CanConstantFold(const Instruction *I) {
3380  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
3381      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
3382    return true;
3383
3384  if (const CallInst *CI = dyn_cast<CallInst>(I))
3385    if (const Function *F = CI->getCalledFunction())
3386      return canConstantFoldCallTo(F);
3387  return false;
3388}
3389
3390/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
3391/// in the loop that V is derived from.  We allow arbitrary operations along the
3392/// way, but the operands of an operation must either be constants or a value
3393/// derived from a constant PHI.  If this expression does not fit with these
3394/// constraints, return null.
3395static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
3396  // If this is not an instruction, or if this is an instruction outside of the
3397  // loop, it can't be derived from a loop PHI.
3398  Instruction *I = dyn_cast<Instruction>(V);
3399  if (I == 0 || !L->contains(I->getParent())) return 0;
3400
3401  if (PHINode *PN = dyn_cast<PHINode>(I)) {
3402    if (L->getHeader() == I->getParent())
3403      return PN;
3404    else
3405      // We don't currently keep track of the control flow needed to evaluate
3406      // PHIs, so we cannot handle PHIs inside of loops.
3407      return 0;
3408  }
3409
3410  // If we won't be able to constant fold this expression even if the operands
3411  // are constants, return early.
3412  if (!CanConstantFold(I)) return 0;
3413
3414  // Otherwise, we can evaluate this instruction if all of its operands are
3415  // constant or derived from a PHI node themselves.
3416  PHINode *PHI = 0;
3417  for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
3418    if (!(isa<Constant>(I->getOperand(Op)) ||
3419          isa<GlobalValue>(I->getOperand(Op)))) {
3420      PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
3421      if (P == 0) return 0;  // Not evolving from PHI
3422      if (PHI == 0)
3423        PHI = P;
3424      else if (PHI != P)
3425        return 0;  // Evolving from multiple different PHIs.
3426    }
3427
3428  // This is an expression evolving from a constant PHI!
3429  return PHI;
3430}
3431
3432/// EvaluateExpression - Given an expression that passes the
3433/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
3434/// in the loop has the value PHIVal.  If we can't fold this expression for some
3435/// reason, return null.
3436static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
3437  if (isa<PHINode>(V)) return PHIVal;
3438  if (Constant *C = dyn_cast<Constant>(V)) return C;
3439  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) return GV;
3440  Instruction *I = cast<Instruction>(V);
3441  LLVMContext *Context = I->getParent()->getContext();
3442
3443  std::vector<Constant*> Operands;
3444  Operands.resize(I->getNumOperands());
3445
3446  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
3447    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
3448    if (Operands[i] == 0) return 0;
3449  }
3450
3451  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
3452    return ConstantFoldCompareInstOperands(CI->getPredicate(),
3453                                           &Operands[0], Operands.size(),
3454                                           Context);
3455  else
3456    return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
3457                                    &Operands[0], Operands.size(),
3458                                    Context);
3459}
3460
3461/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
3462/// in the header of its containing loop, that the loop executes a
3463/// constant number of times, and that the PHI node is just a recurrence
3464/// involving constants, fold it.
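///
/// For example, for "i = phi [1, entry], [i*2, backedge]" with a
/// backedge-taken count of 3, the exit value folds to 1*2*2*2 = 8.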
3465Constant *
3466ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
3467                                                   const APInt& BEs,
3468                                                   const Loop *L) {
3469  std::map<PHINode*, Constant*>::iterator I =
3470    ConstantEvolutionLoopExitValue.find(PN);
3471  if (I != ConstantEvolutionLoopExitValue.end())
3472    return I->second;
3473
3474  if (BEs.ugt(APInt(BEs.getBitWidth(),MaxBruteForceIterations)))
3475    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
3476
3477  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
3478
3479  // Since the loop is canonicalized, the PHI node must have two entries.  One
3480  // entry must be a constant (coming in from outside of the loop), and the
3481  // second must be derived from the same PHI.
3482  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
3483  Constant *StartCST =
3484    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
3485  if (StartCST == 0)
3486    return RetVal = 0;  // Must be a constant.
3487
3488  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
3489  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
3490  if (PN2 != PN)
3491    return RetVal = 0;  // Not derived from same PHI.
3492
3493  // Execute the loop symbolically to determine the exit value.
3494  if (BEs.getActiveBits() >= 32)
3495    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
3496
3497  unsigned NumIterations = BEs.getZExtValue(); // must be in range
3498  unsigned IterationNum = 0;
3499  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
3500    if (IterationNum == NumIterations)
3501      return RetVal = PHIVal;  // Got exit value!
3502
3503    // Compute the value of the PHI node for the next iteration.
3504    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
3505    if (NextPHI == PHIVal)
3506      return RetVal = NextPHI;  // Stopped evolving!
3507    if (NextPHI == 0)
3508      return 0;        // Couldn't evaluate!
3509    PHIVal = NextPHI;
3510  }
3511}
3512
3513/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a
3514/// constant number of times (the condition evolves only from constants),
3515/// try to evaluate a few iterations of the loop until the exit condition
3516/// gets a value of ExitWhen (true or false).  If we cannot evaluate the
3517/// trip count of the loop, return getCouldNotCompute().
3518const SCEV *
3519ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L,
3520                                                       Value *Cond,
3521                                                       bool ExitWhen) {
3522  PHINode *PN = getConstantEvolvingPHI(Cond, L);
3523  if (PN == 0) return getCouldNotCompute();
3524
3525  // Since the loop is canonicalized, the PHI node must have two entries.  One
3526  // entry must be a constant (coming in from outside of the loop), and the
3527  // second must be derived from the same PHI.
3528  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
3529  Constant *StartCST =
3530    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
3531  if (StartCST == 0) return getCouldNotCompute();  // Must be a constant.
3532
3533  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
3534  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
3535  if (PN2 != PN) return getCouldNotCompute();  // Not derived from same PHI.
3536
3537  // Okay, we have found a PHI node that defines the trip count of this
3538  // loop.  Execute the loop symbolically to determine when the condition
3539  // gets a value of "ExitWhen".
3540  unsigned IterationNum = 0;
3541  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
3542  for (Constant *PHIVal = StartCST;
3543       IterationNum != MaxIterations; ++IterationNum) {
3544    ConstantInt *CondVal =
3545      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal));
3546
3547    // Couldn't symbolically evaluate.
3548    if (!CondVal) return getCouldNotCompute();
3549
3550    if (CondVal->getValue() == uint64_t(ExitWhen)) {
3551      ++NumBruteForceTripCountsComputed;
3552      return getConstant(Type::Int32Ty, IterationNum);
3553    }
3554
3555    // Compute the value of the PHI node for the next iteration.
3556    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
3557    if (NextPHI == 0 || NextPHI == PHIVal)
3558      return getCouldNotCompute();// Couldn't evaluate or not making progress...
3559    PHIVal = NextPHI;
3560  }
3561
3562  // Too many iterations were needed to evaluate.
3563  return getCouldNotCompute();
3564}
3565
3566/// getSCEVAtScope - Return a SCEV expression handle for the specified value
3567/// at the specified scope in the program.  The L value specifies the loop
3568/// nest in which to evaluate the expression; null means the top level,
3569/// outside all loops.
3570///
3571/// This method can be used to compute the exit value for a variable defined
3572/// in a loop by querying what the value will hold in the parent loop.
3573///
3574/// In the case that a relevant loop exit value cannot be computed, the
3575/// original value V is returned.
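///
/// For example, an AddRec {0,+,1}<L> in a loop L with backedge-taken
/// count n evaluates, at any scope outside L, to its final value n.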
3576const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
3577  // FIXME: this should be turned into a virtual method on SCEV!
3578
3579  if (isa<SCEVConstant>(V)) return V;
3580
3581  // If this instruction is evolved from a constant-evolving PHI, compute the
3582  // exit value from the loop without using SCEVs.
3583  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
3584    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
3585      const Loop *LI = (*this->LI)[I->getParent()];
3586      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
3587        if (PHINode *PN = dyn_cast<PHINode>(I))
3588          if (PN->getParent() == LI->getHeader()) {
3589            // Okay, there is no closed form solution for the PHI node.  Check
3590            // to see if the loop that contains it has a known backedge-taken
3591            // count.  If so, we may be able to force computation of the exit
3592            // value.
3593            const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
3594            if (const SCEVConstant *BTCC =
3595                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
3596              // Okay, we know how many times the containing loop executes.  If
3597              // this is a constant evolving PHI node, get the final value at
3598              // the specified iteration number.
3599              Constant *RV = getConstantEvolutionLoopExitValue(PN,
3600                                                   BTCC->getValue()->getValue(),
3601                                                               LI);
3602              if (RV) return getSCEV(RV);
3603            }
3604          }
3605
3606      // Okay, this is an expression that we cannot symbolically evaluate
3607      // into a SCEV.  Check to see if it's possible to symbolically evaluate
3608      // the arguments into constants, and if so, try to constant propagate the
3609      // result.  This is particularly useful for computing loop exit values.
3610      if (CanConstantFold(I)) {
3611        // Check to see if we've folded this instruction at this loop before.
3612        std::map<const Loop *, Constant *> &Values = ValuesAtScopes[I];
3613        std::pair<std::map<const Loop *, Constant *>::iterator, bool> Pair =
3614          Values.insert(std::make_pair(L, static_cast<Constant *>(0)));
3615        if (!Pair.second)
3616          return Pair.first->second ? &*getSCEV(Pair.first->second) : V;
3617
3618        std::vector<Constant*> Operands;
3619        Operands.reserve(I->getNumOperands());
3620        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
3621          Value *Op = I->getOperand(i);
3622          if (Constant *C = dyn_cast<Constant>(Op)) {
3623            Operands.push_back(C);
3624          } else {
3625            // If an operand is both non-constant and non-SCEVable (neither
3626            // integer nor pointer), don't even try to analyze it with
3627            // SCEV techniques.
3628            if (!isSCEVable(Op->getType()))
3629              return V;
3630
3631            const SCEV *OpV = getSCEVAtScope(getSCEV(Op), L);
3632            if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV)) {
3633              Constant *C = SC->getValue();
3634              if (C->getType() != Op->getType())
3635                C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
3636                                                                  Op->getType(),
3637                                                                  false),
3638                                          C, Op->getType());
3639              Operands.push_back(C);
3640            } else if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) {
3641              if (Constant *C = dyn_cast<Constant>(SU->getValue())) {
3642                if (C->getType() != Op->getType())
3643                  C =
3644                    ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
3645                                                                  Op->getType(),
3646                                                                  false),
3647                                          C, Op->getType());
3648                Operands.push_back(C);
3649              } else
3650                return V;
3651            } else {
3652              return V;
3653            }
3654          }
3655        }
3656
3657        Constant *C;
3658        if (const CmpInst *CI = dyn_cast<CmpInst>(I))
3659          C = ConstantFoldCompareInstOperands(CI->getPredicate(),
3660                                              &Operands[0], Operands.size(),
3661                                              Context);
3662        else
3663          C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
3664                                       &Operands[0], Operands.size(), Context);
3665        Pair.first->second = C;
3666        return getSCEV(C);
3667      }
3668    }
3669
3670    // This is some other type of SCEVUnknown, just return it.
3671    return V;
3672  }
3673
3674  if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
3675    // Avoid performing the look-up in the common case where the specified
3676    // expression has no loop-variant portions.
3677    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
3678      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
3679      if (OpAtScope != Comm->getOperand(i)) {
3680        // Okay, at least one of these operands is loop variant but might be
3681        // foldable.  Build a new instance of the folded commutative expression.
3682        SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
3683                                            Comm->op_begin()+i);
3684        NewOps.push_back(OpAtScope);
3685
3686        for (++i; i != e; ++i) {
3687          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
3688          NewOps.push_back(OpAtScope);
3689        }
3690        if (isa<SCEVAddExpr>(Comm))
3691          return getAddExpr(NewOps);
3692        if (isa<SCEVMulExpr>(Comm))
3693          return getMulExpr(NewOps);
3694        if (isa<SCEVSMaxExpr>(Comm))
3695          return getSMaxExpr(NewOps);
3696        if (isa<SCEVUMaxExpr>(Comm))
3697          return getUMaxExpr(NewOps);
3698        LLVM_UNREACHABLE("Unknown commutative SCEV type!");
3699      }
3700    }
3701    // If we got here, all operands are loop invariant.
3702    return Comm;
3703  }
3704
3705  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
3706    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
3707    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
3708    if (LHS == Div->getLHS() && RHS == Div->getRHS())
3709      return Div;   // must be loop invariant
3710    return getUDivExpr(LHS, RHS);
3711  }
3712
3713  // If this is a loop recurrence for a loop that does not contain L, then we
3714  // are dealing with the final value computed by the loop.
3715  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
3716    if (!L || !AddRec->getLoop()->contains(L->getHeader())) {
3717      // To evaluate this recurrence, we need to know how many times the AddRec
3718      // loop iterates.  Compute this now.
3719      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
3720      if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
3721
3722      // Then, evaluate the AddRec.
3723      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
3724    }
3725    return AddRec;
3726  }
3727
3728  if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
3729    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
3730    if (Op == Cast->getOperand())
3731      return Cast;  // must be loop invariant
3732    return getZeroExtendExpr(Op, Cast->getType());
3733  }
3734
3735  if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
3736    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
3737    if (Op == Cast->getOperand())
3738      return Cast;  // must be loop invariant
3739    return getSignExtendExpr(Op, Cast->getType());
3740  }
3741
3742  if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
3743    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
3744    if (Op == Cast->getOperand())
3745      return Cast;  // must be loop invariant
3746    return getTruncateExpr(Op, Cast->getType());
3747  }
3748
3749  LLVM_UNREACHABLE("Unknown SCEV type!");
3750  return 0;
3751}
3752
3753/// getSCEVAtScope - This is a convenience function which does
3754/// getSCEVAtScope(getSCEV(V), L).
3755const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
3756  return getSCEVAtScope(getSCEV(V), L);
3757}
3758
3759/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
3760/// following equation:
3761///
3762///     A * X = B (mod N)
3763///
3764/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
3765/// A and B isn't important.
3766///
3767/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
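///
/// For example, with BW = 4: 4*X = 8 (mod 16) has minimum solution X = 2,
/// while 4*X = 6 (mod 16) has no solution, because 6 has fewer trailing
/// zero bits than 4.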
3768static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
3769                                               ScalarEvolution &SE) {
3770  uint32_t BW = A.getBitWidth();
3771  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
3772  assert(A != 0 && "A must be non-zero.");
3773
3774  // 1. D = gcd(A, N)
3775  //
3776  // The gcd of A and N may have only one prime factor: 2. The number of
3777  // trailing zeros in A is its multiplicity.
3778  uint32_t Mult2 = A.countTrailingZeros();
3779  // D = 2^Mult2
3780
3781  // 2. Check if B is divisible by D.
3782  //
3783  // B is divisible by D if and only if the multiplicity of prime factor 2
3784  // for B is not less than the multiplicity of this prime factor for D.
3785  if (B.countTrailingZeros() < Mult2)
3786    return SE.getCouldNotCompute();
3787
3788  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
3789  // modulo (N / D).
3790  //
3791  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
3792  // bit width during computations.
3793  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
3794  APInt Mod(BW + 1, 0);
3795  Mod.set(BW - Mult2);  // Mod = N / D
3796  APInt I = AD.multiplicativeInverse(Mod);
3797
3798  // 4. Compute the minimum unsigned root of the equation:
3799  // I * (B / D) mod (N / D)
3800  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
3801
3802  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
3803  // bits.
3804  return SE.getConstant(Result.trunc(BW));
3805}
3806
3807/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
3808/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
3809/// might be the same) or two SCEVCouldNotCompute objects.
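///
/// At iteration x the chrec evaluates to L + M*x + N*x*(x-1)/2, so the
/// roots computed below are those of (N/2)*x^2 + (M - N/2)*x + L = 0.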
///
static std::pair<const SCEV *,const SCEV *>
SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));

  // We currently can only solve this if the coefficients are constants.
  if (!LC || !MC || !NC) {
    const SCEV *CNC = SE.getCouldNotCompute();
    return std::make_pair(CNC, CNC);
  }

  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
  const APInt &L = LC->getValue()->getValue();
  const APInt &M = MC->getValue()->getValue();
  const APInt &N = NC->getValue()->getValue();
  APInt Two(BitWidth, 2);
  APInt Four(BitWidth, 4);

  {
    using namespace APIntOps;
    const APInt& C = L;
    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C.
    // The B coefficient is M-N/2.
    APInt B(M);
    B -= sdiv(N,Two);

    // The A coefficient is N/2.
    APInt A(N.sdiv(Two));

    // Compute the B^2-4ac term.
    APInt SqrtTerm(B);
    SqrtTerm *= B;
    SqrtTerm -= Four * (A * C);

    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
    // integer value or else APInt::sqrt() will assert.
    APInt SqrtVal(SqrtTerm.sqrt());

    // Compute the two solutions for the quadratic formula.
    // The divisions must be performed as signed divisions.
    APInt NegB(-B);
    APInt TwoA(A << 1);
    if (TwoA.isMinValue()) {
      const SCEV *CNC = SE.getCouldNotCompute();
      return std::make_pair(CNC, CNC);
    }

    LLVMContext *Context = SE.getContext();

    ConstantInt *Solution1 =
      Context->getConstantInt((NegB + SqrtVal).sdiv(TwoA));
    ConstantInt *Solution2 =
      Context->getConstantInt((NegB - SqrtVal).sdiv(TwoA));

    return std::make_pair(SE.getConstant(Solution1),
                          SE.getConstant(Solution2));
  } // end block using APIntOps
}

/// HowFarToZero - Return the number of times a backedge comparing the specified
/// value to zero will execute.  If not computable, return CouldNotCompute.
const SCEV *ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
  // If the value is a constant...
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    // If the value is already zero, the branch will execute zero times.
    if (C->getValue()->isZero()) return C;
    return getCouldNotCompute();  // Otherwise it will loop infinitely.
  }

  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
  if (!AddRec || AddRec->getLoop() != L)
    return getCouldNotCompute();

  if (AddRec->isAffine()) {
    // If this is an affine expression, the execution count of this branch is
    // the minimum unsigned root of the following equation:
    //
    //     Start + Step*N = 0 (mod 2^BW)
    //
    // equivalent to:
    //
    //             Step*N = -Start (mod 2^BW)
    //
    // where BW is the common bit width of Start and Step.
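
    // For instance (an illustrative case): for the chrec {10,+,-1} the
    // equation is -1*N = -10 (mod 2^BW), so N = 10; the all-ones-step
    // special case below returns Start directly for this pattern.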

    // Get the initial value for the loop.
    const SCEV *Start = getSCEVAtScope(AddRec->getStart(),
                                       L->getParentLoop());
    const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1),
                                      L->getParentLoop());

    if (const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
      // For now we handle only constant steps.

      // First, handle unitary steps.
      if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
        return getNegativeSCEV(Start);          //   N = -Start (as unsigned)
      if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
        return Start;                           //    N = Start (as unsigned)

      // Then, try to solve the above equation provided that Start is constant.
      if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
        return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
                                            -StartC->getValue()->getValue(),
                                            *this);
    }
  } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
    // the quadratic equation to solve it.
    std::pair<const SCEV *,const SCEV *> Roots =
      SolveQuadraticEquation(AddRec, *this);
    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
    if (R1) {
#if 0
      errs() << "HFTZ: " << *V << " - sol#1: " << *R1
             << "  sol#2: " << *R2 << "\n";
#endif
      // Pick the smallest positive root value.
      if (ConstantInt *CB =
          dyn_cast<ConstantInt>(Context->getConstantExprICmp(ICmpInst::ICMP_ULT,
                                   R1->getValue(), R2->getValue()))) {
        if (CB->getZExtValue() == false)
          std::swap(R1, R2);   // R1 is the minimum root now.

        // We can only use this value if the chrec ends up with an exact zero
        // value at this index.  When solving for "X*X != 5", for example, we
        // should not accept a root of 2.
        const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
        if (Val->isZero())
          return R1;  // We found a quadratic root!
      }
    }
  }

  return getCouldNotCompute();
}

/// HowFarToNonZero - Return the number of times a backedge checking the
/// specified value for nonzero will execute.  If not computable, return
/// CouldNotCompute.
const SCEV *ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
  // Loops that look like: while (X == 0) are very strange indeed.  We don't
  // handle them yet except for the trivial case.  This could be expanded in the
  // future as needed.

  // If the value is a constant, check to see if it is known to be non-zero
  // already.  If so, the backedge will execute zero times.
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    if (!C->getValue()->isNullValue())
      return getIntegerSCEV(0, C->getType());
    return getCouldNotCompute();  // Otherwise it will loop infinitely.
  }

  // We could implement others, but I really doubt anyone writes loops like
  // this, and if they did, they would already be constant folded.
  return getCouldNotCompute();
}

/// getLoopPredecessor - If the given loop's header has exactly one unique
/// predecessor outside the loop, return it. Otherwise return null.
///
BasicBlock *ScalarEvolution::getLoopPredecessor(const Loop *L) {
  BasicBlock *Header = L->getHeader();
  BasicBlock *Pred = 0;
  for (pred_iterator PI = pred_begin(Header), E = pred_end(Header);
       PI != E; ++PI)
    if (!L->contains(*PI)) {
      if (Pred && Pred != *PI) return 0; // Multiple predecessors.
      Pred = *PI;
    }
  return Pred;
}

/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
/// (which may not be an immediate predecessor) which has exactly one
/// successor from which BB is reachable, or null if no such block is
/// found.
///
BasicBlock *
ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
  // If the block has a unique predecessor, then there is no path from the
  // predecessor to the block that does not go through the direct edge
  // from the predecessor to the block.
  if (BasicBlock *Pred = BB->getSinglePredecessor())
    return Pred;

  // A loop's header is defined to be a block that dominates the loop.
  // If the header has a unique predecessor outside the loop, it must be
  // a block that has exactly one successor that can reach the loop.
  if (Loop *L = LI->getLoopFor(BB))
    return getLoopPredecessor(L);

  return 0;
}

/// HasSameValue - SCEV structural equivalence is usually sufficient for
/// testing whether two expressions are equal; however, for the purposes of
/// looking for a condition guarding a loop, it can be useful to be a little
/// more general, since a front-end may have replicated the controlling
/// expression.
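///
/// For example (a hypothetical scenario): a front-end may emit two identical
/// load instructions of the same pointer, one feeding the loop guard and one
/// feeding the exit test; both become distinct SCEVUnknowns, but
/// Instruction::isIdenticalTo lets us treat them as the same value.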
///
static bool HasSameValue(const SCEV *A, const SCEV *B) {
  // Quick check to see if they are the same SCEV.
  if (A == B) return true;

  // Otherwise, if they're both SCEVUnknown, it's possible that they hold
  // two different instructions with the same value. Check for this case.
  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
          if (AI->isIdenticalTo(BI))
            return true;

  // Otherwise assume they may have a different value.
  return false;
}

/// isLoopGuardedByCond - Test whether entry to the loop is protected by
/// a conditional between LHS and RHS.  This is used to help avoid max
/// expressions in loop trip counts.
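///
/// For example (an illustrative sketch of the intent): a loop like
///   if (n > 0) for (i = 0; i < n; ++i) ...
/// entered under a guard that implies the loop condition can have its
/// backedge-taken count computed as n rather than smax(n, 0).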
bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
                                          ICmpInst::Predicate Pred,
                                          const SCEV *LHS, const SCEV *RHS) {
  // Interpret a null L as meaning no loop, in which case there is obviously
  // no guard (interprocedural conditions notwithstanding).
  if (!L) return false;

  BasicBlock *Predecessor = getLoopPredecessor(L);
  BasicBlock *PredecessorDest = L->getHeader();

  // Starting at the loop predecessor, climb up the predecessor chain as long
  // as we can find predecessors that have unique successors leading to the
  // original header.
  for (; Predecessor;
       PredecessorDest = Predecessor,
       Predecessor = getPredecessorWithUniqueSuccessorForBB(Predecessor)) {

    BranchInst *LoopEntryPredicate =
      dyn_cast<BranchInst>(Predecessor->getTerminator());
    if (!LoopEntryPredicate ||
        LoopEntryPredicate->isUnconditional())
      continue;

    if (isNecessaryCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS,
                        LoopEntryPredicate->getSuccessor(0) != PredecessorDest))
      return true;
  }

  return false;
}

/// isNecessaryCond - Test whether the given CondValue is a condition which
/// is at least as strict as the one described by Pred, LHS, and RHS.
bool ScalarEvolution::isNecessaryCond(Value *CondValue,
                                      ICmpInst::Predicate Pred,
                                      const SCEV *LHS, const SCEV *RHS,
                                      bool Inverse) {
  // Recursively handle And and Or conditions.
  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CondValue)) {
    if (BO->getOpcode() == Instruction::And) {
      if (!Inverse)
        return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
               isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
    } else if (BO->getOpcode() == Instruction::Or) {
      if (Inverse)
        return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
               isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
    }
  }

  ICmpInst *ICI = dyn_cast<ICmpInst>(CondValue);
  if (!ICI) return false;

  // Now that we have an icmp, check to see if it is the comparison we are
  // looking for.
  Value *PreCondLHS = ICI->getOperand(0);
  Value *PreCondRHS = ICI->getOperand(1);
  ICmpInst::Predicate Cond;
  if (Inverse)
    Cond = ICI->getInversePredicate();
  else
    Cond = ICI->getPredicate();

  if (Cond == Pred)
    ; // An exact match.
  else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
    ; // The actual condition is beyond sufficient.
  else
    // Check a few special cases.
    switch (Cond) {
    case ICmpInst::ICMP_UGT:
      if (Pred == ICmpInst::ICMP_ULT) {
        std::swap(PreCondLHS, PreCondRHS);
        Cond = ICmpInst::ICMP_ULT;
        break;
      }
      return false;
    case ICmpInst::ICMP_SGT:
      if (Pred == ICmpInst::ICMP_SLT) {
        std::swap(PreCondLHS, PreCondRHS);
        Cond = ICmpInst::ICMP_SLT;
        break;
      }
      return false;
    case ICmpInst::ICMP_NE:
      // Expressions like (x >u 0) are often canonicalized to (x != 0),
      // so check for this case by checking if the NE is comparing against
      // a minimum or maximum constant.
      if (!ICmpInst::isTrueWhenEqual(Pred))
        if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
          const APInt &A = CI->getValue();
          switch (Pred) {
          case ICmpInst::ICMP_SLT:
            if (A.isMaxSignedValue()) break;
            return false;
          case ICmpInst::ICMP_SGT:
            if (A.isMinSignedValue()) break;
            return false;
          case ICmpInst::ICMP_ULT:
            if (A.isMaxValue()) break;
            return false;
          case ICmpInst::ICMP_UGT:
            if (A.isMinValue()) break;
            return false;
          default:
            return false;
          }
          Cond = ICmpInst::ICMP_NE;
          // NE is symmetric but the original comparison may not be. Swap
          // the operands if necessary so that they match below.
          if (isa<SCEVConstant>(LHS))
            std::swap(PreCondLHS, PreCondRHS);
          break;
        }
      return false;
    default:
      // We weren't able to reconcile the condition.
      return false;
    }

  if (!PreCondLHS->getType()->isInteger()) return false;

  const SCEV *PreCondLHSSCEV = getSCEV(PreCondLHS);
  const SCEV *PreCondRHSSCEV = getSCEV(PreCondRHS);
  return (HasSameValue(LHS, PreCondLHSSCEV) &&
          HasSameValue(RHS, PreCondRHSSCEV)) ||
         (HasSameValue(LHS, getNotSCEV(PreCondRHSSCEV)) &&
          HasSameValue(RHS, getNotSCEV(PreCondLHSSCEV)));
}

/// getBECount - Subtract Start from End, then divide by Step, rounding up,
/// to get the number of times the backedge is executed. Return
/// CouldNotCompute if an intermediate computation overflows.
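///
/// For example (numbers chosen purely for illustration): with Start = 0,
/// End = 10 and Step = 3, the count is (10 - 0 + (3 - 1)) /u 3 = 4, matching
/// the four in-range values 0, 3, 6 and 9 of the induction variable.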
const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
                                        const SCEV *End,
                                        const SCEV *Step) {
  const Type *Ty = Start->getType();
  const SCEV *NegOne = getIntegerSCEV(-1, Ty);
  const SCEV *Diff = getMinusSCEV(End, Start);
  const SCEV *RoundUp = getAddExpr(Step, NegOne);

  // Add an adjustment to the difference between End and Start so that
  // the division will effectively round up.
  const SCEV *Add = getAddExpr(Diff, RoundUp);

  // Check Add for unsigned overflow.
  // TODO: More sophisticated things could be done here.
  const Type *WideTy = Context->getIntegerType(getTypeSizeInBits(Ty) + 1);
  const SCEV *OperandExtendedAdd =
    getAddExpr(getZeroExtendExpr(Diff, WideTy),
               getZeroExtendExpr(RoundUp, WideTy));
  if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
    return getCouldNotCompute();

  return getUDivExpr(Add, Step);
}

/// HowManyLessThans - Return the number of times a backedge containing the
/// specified less-than comparison will execute.  If not computable, return
/// CouldNotCompute.
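///
/// For example (a sketch): for {0,+,1} <s n, if the loop is recognized as
/// guarded by the entry condition, the backedge-taken count comes out as
/// exactly n; otherwise the max expression below yields smax(n, 0).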
ScalarEvolution::BackedgeTakenInfo
ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool isSigned) {
  // Only handle:  "ADDREC < LoopInvariant".
  if (!RHS->isLoopInvariant(L)) return getCouldNotCompute();

  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!AddRec || AddRec->getLoop() != L)
    return getCouldNotCompute();

  if (AddRec->isAffine()) {
    // FORNOW: We only support constant, positive strides.
    unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
    const SCEV *Step = AddRec->getStepRecurrence(*this);

    // TODO: handle non-constant strides.
    const SCEVConstant *CStep = dyn_cast<SCEVConstant>(Step);
    if (!CStep || CStep->isZero())
      return getCouldNotCompute();
    if (CStep->isOne()) {
      // With unit stride, the iteration never steps past the limit value.
    } else if (CStep->getValue()->getValue().isStrictlyPositive()) {
      if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) {
        // Test whether a positive stride can step past the limit
        // value and past the maximum value for its type in a single step.
        if (isSigned) {
          APInt Max = APInt::getSignedMaxValue(BitWidth);
          if ((Max - CStep->getValue()->getValue())
                .slt(CLimit->getValue()->getValue()))
            return getCouldNotCompute();
        } else {
          APInt Max = APInt::getMaxValue(BitWidth);
          if ((Max - CStep->getValue()->getValue())
                .ult(CLimit->getValue()->getValue()))
            return getCouldNotCompute();
        }
      } else
        // TODO: handle non-constant limit values below.
        return getCouldNotCompute();
    } else
      // TODO: handle negative strides below.
      return getCouldNotCompute();

    // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
    // m.  So, we count the number of iterations in which {n,+,s} < m is true.
    // Note that we cannot simply return max(m-n,0)/s because it's not safe to
    // treat m-n as signed nor unsigned due to overflow possibility.
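
    // E.g. (an illustrative 8-bit unsigned case): with n = 200 and m = 100
    // the loop body never runs, yet m-n wraps around to 156; the max(m,n)
    // adjustment below avoids this trap when the entry condition is not
    // known to hold.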

    // First, we get the value of the LHS in the first iteration: n
    const SCEV *Start = AddRec->getOperand(0);

    // Determine the minimum constant start value.
    const SCEV *MinStart = isa<SCEVConstant>(Start) ? Start :
      getConstant(isSigned ? APInt::getSignedMinValue(BitWidth) :
                             APInt::getMinValue(BitWidth));

    // If we know that the condition is true in order to enter the loop,
    // then we know that it will run exactly (m-n)/s times. Otherwise, we
    // only know that it will execute (max(m,n)-n)/s times. In both cases,
    // the division must round up.
    const SCEV *End = RHS;
    if (!isLoopGuardedByCond(L,
                             isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
                             getMinusSCEV(Start, Step), RHS))
      End = isSigned ? getSMaxExpr(RHS, Start)
                     : getUMaxExpr(RHS, Start);

    // Determine the maximum constant end value.
    const SCEV *MaxEnd =
      isa<SCEVConstant>(End) ? End :
      getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth)
                               .ashr(GetMinSignBits(End) - 1) :
                             APInt::getMaxValue(BitWidth)
                               .lshr(GetMinLeadingZeros(End)));

    // Finally, we subtract these two values and divide, rounding up, to get
    // the number of times the backedge is executed.
    const SCEV *BECount = getBECount(Start, End, Step);

    // The maximum backedge count is similar, except using the minimum start
    // value and the maximum end value.
    const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step);

    return BackedgeTakenInfo(BECount, MaxBECount);
  }

  return getCouldNotCompute();
}

/// getNumIterationsInRange - Return the number of iterations of this loop that
/// produce values in the specified constant range.  Another way of looking at
/// this is that it returns the first iteration number where the value is not
/// in the range, thus computing the exit count. If the iteration count can't
/// be computed, an instance of SCEVCouldNotCompute is returned.
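///
/// For example (illustrative): for {0,+,1} with Range = [0,10), iterations
/// 0 through 9 produce in-range values and iteration 10 is the first one
/// whose value, 10, falls outside, so the result is 10.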
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                                    ScalarEvolution &SE) const {
  if (Range.isFullSet())  // Infinite loop.
    return SE.getCouldNotCompute();

  // If the start is a non-zero constant, shift the range to simplify things.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
    if (!SC->getValue()->isZero()) {
      SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
      Operands[0] = SE.getIntegerSCEV(0, SC->getType());
      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop());
      if (const SCEVAddRecExpr *ShiftedAddRec =
            dyn_cast<SCEVAddRecExpr>(Shifted))
        return ShiftedAddRec->getNumIterationsInRange(
                           Range.subtract(SC->getValue()->getValue()), SE);
      // This is strange and shouldn't happen.
      return SE.getCouldNotCompute();
    }

  // The only time we can solve this is when we have all constant indices.
  // Otherwise, we cannot determine the overflow conditions.
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
    if (!isa<SCEVConstant>(getOperand(i)))
      return SE.getCouldNotCompute();

  // Okay at this point we know that all elements of the chrec are constants and
  // that the start element is zero.

  // First check to see if the range contains zero.  If not, the first
  // iteration exits.
  unsigned BitWidth = SE.getTypeSizeInBits(getType());
  if (!Range.contains(APInt(BitWidth, 0)))
    return SE.getIntegerSCEV(0, getType());

  if (isAffine()) {
    // If this is an affine expression then we have this situation:
    //   Solve {0,+,A} in Range  ===  Ax in Range
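
    // For instance (a sketch): with A = 3 and Range = [0,10), the values are
    // 0, 3, 6, 9, 12, ...; End = 10-1 = 9 and ExitVal = (9+3) /u 3 = 4, the
    // first iteration whose value, 12, is outside the range.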

    // We know that zero is in the range.  If A is positive then we know that
    // the upper value of the range must be the first possible exit value.
    // If A is negative then the lower of the range is the last possible loop
    // value.  Also note that we already checked for a full range.
    APInt One(BitWidth, 1);
    APInt A   = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();

    // The exit value should be (End+A)/A.
    APInt ExitVal = (End + A).udiv(A);
    ConstantInt *ExitValue = SE.getContext()->getConstantInt(ExitVal);

    // Evaluate at the exit value.  If we really did fall out of the valid
    // range, then we computed our trip count, otherwise wrap around or other
    // things must have happened.
    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
    if (Range.contains(Val->getValue()))
      return SE.getCouldNotCompute();  // Something strange happened

    // Ensure that the previous value is in the range.  This is a sanity check.
    assert(Range.contains(
           EvaluateConstantChrecAtConstant(this,
           SE.getContext()->getConstantInt(ExitVal - One), SE)->getValue()) &&
           "Linear scev computation is off in a bad way!");
    return SE.getConstant(ExitValue);
  } else if (isQuadratic()) {
    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
    // quadratic equation to solve it.  To do this, we must frame our problem in
    // terms of figuring out when zero is crossed, instead of when
    // Range.getUpper() is crossed.
    SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
    const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop());

    // Next, solve the constructed addrec.
    std::pair<const SCEV *,const SCEV *> Roots =
      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
    if (R1) {
      // Pick the smallest positive root value.
      if (ConstantInt *CB =
          dyn_cast<ConstantInt>(
                       SE.getContext()->getConstantExprICmp(ICmpInst::ICMP_ULT,
                         R1->getValue(), R2->getValue()))) {
        if (CB->getZExtValue() == false)
          std::swap(R1, R2);   // R1 is the minimum root now.

        // Make sure the root is not off by one.  The returned iteration should
        // not be in the range, but the previous one should be.  When solving
        // for "X*X < 5", for example, we should not return a root of 2.
        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
                                                             R1->getValue(),
                                                             SE);
        if (Range.contains(R1Val->getValue())) {
          // The next iteration must be out of the range...
          ConstantInt *NextVal =
                 SE.getContext()->getConstantInt(R1->getValue()->getValue()+1);

          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
          if (!Range.contains(R1Val->getValue()))
            return SE.getConstant(NextVal);
          return SE.getCouldNotCompute();  // Something strange happened
        }

        // If R1 was not in the range, then it is a good return value.  Make
        // sure that R1-1 WAS in the range though, just in case.
        ConstantInt *NextVal =
                 SE.getContext()->getConstantInt(R1->getValue()->getValue()-1);
        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
        if (Range.contains(R1Val->getValue()))
          return R1;
        return SE.getCouldNotCompute();  // Something strange happened
      }
    }
  }

  return SE.getCouldNotCompute();
}

//===----------------------------------------------------------------------===//
//                   SCEVCallbackVH Class Implementation
//===----------------------------------------------------------------------===//

void ScalarEvolution::SCEVCallbackVH::deleted() {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  if (Instruction *I = dyn_cast<Instruction>(getValPtr()))
    SE->ValuesAtScopes.erase(I);
  SE->Scalars.erase(getValPtr());
  // this now dangles!
}

void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");

  // Forget all the expressions associated with users of the old value,
  // so that future queries will recompute the expressions using the new
  // value.
  SmallVector<User *, 16> Worklist;
  Value *Old = getValPtr();
  bool DeleteOld = false;
  for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
       UI != UE; ++UI)
    Worklist.push_back(*UI);
  while (!Worklist.empty()) {
    User *U = Worklist.pop_back_val();
    // Deleting the Old value will cause this to dangle. Postpone
    // that until everything else is done.
    if (U == Old) {
      DeleteOld = true;
      continue;
    }
    if (PHINode *PN = dyn_cast<PHINode>(U))
      SE->ConstantEvolutionLoopExitValue.erase(PN);
    if (Instruction *I = dyn_cast<Instruction>(U))
      SE->ValuesAtScopes.erase(I);
    if (SE->Scalars.erase(U))
      for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
           UI != UE; ++UI)
        Worklist.push_back(*UI);
  }
  if (DeleteOld) {
    if (PHINode *PN = dyn_cast<PHINode>(Old))
      SE->ConstantEvolutionLoopExitValue.erase(PN);
    if (Instruction *I = dyn_cast<Instruction>(Old))
      SE->ValuesAtScopes.erase(I);
    SE->Scalars.erase(Old);
    // this now dangles!
  }
  // this may dangle!
}

ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
  : CallbackVH(V), SE(se) {}

//===----------------------------------------------------------------------===//
//                   ScalarEvolution Class Implementation
//===----------------------------------------------------------------------===//

ScalarEvolution::ScalarEvolution()
  : FunctionPass(&ID) {
}

bool ScalarEvolution::runOnFunction(Function &F) {
  this->F = &F;
  LI = &getAnalysis<LoopInfo>();
  TD = getAnalysisIfAvailable<TargetData>();
  return false;
}

void ScalarEvolution::releaseMemory() {
  Scalars.clear();
  BackedgeTakenCounts.clear();
  ConstantEvolutionLoopExitValue.clear();
  ValuesAtScopes.clear();
  UniqueSCEVs.clear();
  SCEVAllocator.Reset();
}

void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequiredTransitive<LoopInfo>();
}

bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
}

static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
                          const Loop *L) {
  // Print all inner loops first.
  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
    PrintLoopInfo(OS, SE, *I);

  OS << "Loop " << L->getHeader()->getName() << ": ";

  SmallVector<BasicBlock*, 8> ExitBlocks;
  L->getExitBlocks(ExitBlocks);
  if (ExitBlocks.size() != 1)
    OS << "<multiple exits> ";

  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
  } else {
    OS << "Unpredictable backedge-taken count. ";
  }

  OS << "\n";
  OS << "Loop " << L->getHeader()->getName() << ": ";

  if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
    OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
  } else {
    OS << "Unpredictable max backedge-taken count. ";
  }

  OS << "\n";
}

void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
  // ScalarEvolution's implementation of the print method is to print
  // out SCEV values of all instructions that are interesting. Doing
  // this potentially causes it to create new SCEV objects though,
  // which technically conflicts with the const qualifier. This isn't
  // observable from outside the class though, so casting away the
  // const isn't dangerous.
  ScalarEvolution &SE = *const_cast<ScalarEvolution*>(this);

  OS << "Classifying expressions for: " << F->getName() << "\n";
  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
    if (isSCEVable(I->getType())) {
      OS << *I;
      OS << "  -->  ";
      const SCEV *SV = SE.getSCEV(&*I);
      SV->print(OS);

      const Loop *L = LI->getLoopFor((*I).getParent());

      const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
      if (AtUse != SV) {
        OS << "  -->  ";
        AtUse->print(OS);
      }

      if (L) {
        OS << "\t\t" "Exits: ";
        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
        if (!ExitValue->isLoopInvariant(L)) {
          OS << "<<Unknown>>";
        } else {
          OS << *ExitValue;
        }
      }

      OS << "\n";
    }

  OS << "Determining loop execution counts for: " << F->getName() << "\n";
  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
    PrintLoopInfo(OS, &SE, *I);
}

void ScalarEvolution::print(std::ostream &o, const Module *M) const {
  raw_os_ostream OS(o);
  print(OS, M);
}