ScalarEvolution.cpp revision ebf78f18df84a63295c748148f54f091234ed099
//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
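// (For illustration, a sketch rather than code in this file: building the add
// of {a, b} and the add of {b, a} yields the same uniqued SCEV object, so
// comparing the two resulting SCEV* pointers with == is a structural
// equality test.)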
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "scalar-evolution"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/GlobalVariable.h"
#include "llvm/GlobalAlias.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Operator.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Assembly/Writer.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ConstantRange.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GetElementPtrTypeIterator.h"
#include "llvm/Support/InstIterator.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <algorithm>
using namespace llvm;

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant "
                                 "derived loop"),
                        cl::init(100));
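// For example (hypothetical invocation; only the flag name above is real),
// the default bound of 100 could be raised when running passes through opt:
//   opt -scalar-evolution-max-iterations=1000 ...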

static RegisterPass<ScalarEvolution>
R("scalar-evolution", "Scalar Evolution Analysis", false, true);
char ScalarEvolution::ID = 0;

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

SCEV::~SCEV() {}

void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isAllOnesValue();
  return false;
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeID(), scCouldNotCompute) {}

bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

const Type *SCEVCouldNotCompute::getType() const {
  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  return 0;
}

bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

bool SCEVCouldNotCompute::hasOperand(const SCEV *) const {
  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

void SCEVCouldNotCompute::print(raw_ostream &OS) const {
  OS << "***COULDNOTCOMPUTE***";
}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVConstant>();
  new (S) SCEVConstant(ID, V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) {
  return getConstant(
    ConstantInt::get(cast<IntegerType>(Ty), V, isSigned));
}

const Type *SCEVConstant::getType() const { return V->getType(); }

void SCEVConstant::print(raw_ostream &OS) const {
  WriteAsOperand(OS, V, false);
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeID &ID,
                           unsigned SCEVTy, const SCEV *op, const Type *ty)
  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}

bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return Op->dominates(BB, DT);
}

bool SCEVCastExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  return Op->properlyDominates(BB, DT);
}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeID &ID,
                                   const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scTruncate, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot truncate non-integer value!");
}

void SCEVTruncateExpr::print(raw_ostream &OS) const {
  OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeID &ID,
                                       const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot zero extend non-integer value!");
}

void SCEVZeroExtendExpr::print(raw_ostream &OS) const {
  OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeID &ID,
                                       const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scSignExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot sign extend non-integer value!");
}

void SCEVSignExtendExpr::print(raw_ostream &OS) const {
  OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

void SCEVCommutativeExpr::print(raw_ostream &OS) const {
  assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
  const char *OpStr = getOperationStr();
  OS << "(" << *Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
    OS << OpStr << *Operands[i];
  OS << ")";
}

bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    if (!getOperand(i)->dominates(BB, DT))
      return false;
  }
  return true;
}

bool SCEVNAryExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    if (!getOperand(i)->properlyDominates(BB, DT))
      return false;
  }
  return true;
}

bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
}

bool SCEVUDivExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  return LHS->properlyDominates(BB, DT) && RHS->properlyDominates(BB, DT);
}

void SCEVUDivExpr::print(raw_ostream &OS) const {
  OS << "(" << *LHS << " /u " << *RHS << ")";
}

const Type *SCEVUDivExpr::getType() const {
  // In most cases the types of LHS and RHS will be the same, but in some
  // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
  // depend on the type for correctness, but handling types carefully can
  // avoid extra casts in the SCEVExpander. The LHS is more likely to be
  // a pointer type than the RHS, so use the RHS' type here.
  return RHS->getType();
}

bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
  // Add recurrences are never invariant in the function-body (null loop).
  if (!QueryLoop)
    return false;

  // This recurrence is variant w.r.t. QueryLoop if QueryLoop contains L.
  if (QueryLoop->contains(L))
    return false;

  // This recurrence is variant w.r.t. QueryLoop if any of its operands
  // are variant.
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
    if (!getOperand(i)->isLoopInvariant(QueryLoop))
      return false;

  // Otherwise it's loop-invariant.
  return true;
}

bool
SCEVAddRecExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return DT->dominates(L->getHeader(), BB) &&
         SCEVNAryExpr::dominates(BB, DT);
}

bool
SCEVAddRecExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  // This uses a "dominates" query instead of "properly dominates" query because
  // the instruction which produces the addrec's value is a PHI, and a PHI
  // effectively properly dominates its entire containing block.
  return DT->dominates(L->getHeader(), BB) &&
         SCEVNAryExpr::properlyDominates(BB, DT);
}

void SCEVAddRecExpr::print(raw_ostream &OS) const {
  OS << "{" << *Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
    OS << ",+," << *Operands[i];
  OS << "}<";
  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
  OS << ">";
}

bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
  // All non-instruction values are loop invariant.  All instructions are loop
  // invariant if they are not contained in the specified loop.
  // Instructions are never considered invariant in the function body
  // (null loop) because they are defined within the "loop".
  if (Instruction *I = dyn_cast<Instruction>(V))
    return L && !L->contains(I);
  return true;
}

bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
  if (Instruction *I = dyn_cast<Instruction>(getValue()))
    return DT->dominates(I->getParent(), BB);
  return true;
}

bool SCEVUnknown::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  if (Instruction *I = dyn_cast<Instruction>(getValue()))
    return DT->properlyDominates(I->getParent(), BB);
  return true;
}

const Type *SCEVUnknown::getType() const {
  return V->getType();
}

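// A sketch of the idiom matched below (the usual target-independent way to
// express sizeof as a constant expression; the concrete types here are
// illustrative only):
//   ptrtoint(getelementptr(Ty* null, i32 1))
// i.e. the offset of element 1 from a null Ty pointer, which is sizeof(Ty).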
bool SCEVUnknown::isSizeOf(const Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(V))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
                                 ->getElementType();
              return true;
            }

  return false;
}

bool SCEVUnknown::isAlignOf(const Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(V))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          const Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          if (const StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}

bool SCEVUnknown::isOffsetOf(const Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(V))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          const Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}

void SCEVUnknown::print(raw_ostream &OS) const {
  const Type *AllocTy;
  if (isSizeOf(AllocTy)) {
    OS << "sizeof(" << *AllocTy << ")";
    return;
  }
  if (isAlignOf(AllocTy)) {
    OS << "alignof(" << *AllocTy << ")";
    return;
  }

  const Type *CTy;
  Constant *FieldNo;
  if (isOffsetOf(CTy, FieldNo)) {
    OS << "offsetof(" << *CTy << ", ";
    WriteAsOperand(OS, FieldNo, false);
    OS << ")";
    return;
  }

  // Otherwise just print it normally.
  WriteAsOperand(OS, V, false);
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

static bool CompareTypes(const Type *A, const Type *B) {
  if (A->getTypeID() != B->getTypeID())
    return A->getTypeID() < B->getTypeID();
  if (const IntegerType *AI = dyn_cast<IntegerType>(A)) {
    const IntegerType *BI = cast<IntegerType>(B);
    return AI->getBitWidth() < BI->getBitWidth();
  }
  if (const PointerType *AI = dyn_cast<PointerType>(A)) {
    const PointerType *BI = cast<PointerType>(B);
    return CompareTypes(AI->getElementType(), BI->getElementType());
  }
  if (const ArrayType *AI = dyn_cast<ArrayType>(A)) {
    const ArrayType *BI = cast<ArrayType>(B);
    if (AI->getNumElements() != BI->getNumElements())
      return AI->getNumElements() < BI->getNumElements();
    return CompareTypes(AI->getElementType(), BI->getElementType());
  }
  if (const VectorType *AI = dyn_cast<VectorType>(A)) {
    const VectorType *BI = cast<VectorType>(B);
    if (AI->getNumElements() != BI->getNumElements())
      return AI->getNumElements() < BI->getNumElements();
    return CompareTypes(AI->getElementType(), BI->getElementType());
  }
  if (const StructType *AI = dyn_cast<StructType>(A)) {
    const StructType *BI = cast<StructType>(B);
    if (AI->getNumElements() != BI->getNumElements())
      return AI->getNumElements() < BI->getNumElements();
    for (unsigned i = 0, e = AI->getNumElements(); i != e; ++i)
      if (CompareTypes(AI->getElementType(i), BI->getElementType(i)) ||
          CompareTypes(BI->getElementType(i), AI->getElementType(i)))
        return CompareTypes(AI->getElementType(i), BI->getElementType(i));
  }
  return false;
}

namespace {
  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
  /// than the complexity of the RHS.  This comparator is used to canonicalize
  /// expressions.
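  ///
  /// For example, because scConstant sorts lowest, canonicalization places
  /// constants first, so an add of a constant and a value prints as
  /// (1 + %x) rather than (%x + 1).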
  class SCEVComplexityCompare {
    LoopInfo *LI;
  public:
    explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {}

    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
      // Fast-path: SCEVs are uniqued so we can do a quick equality check.
      if (LHS == RHS)
        return false;

      // Primarily, sort the SCEVs by their getSCEVType().
      if (LHS->getSCEVType() != RHS->getSCEVType())
        return LHS->getSCEVType() < RHS->getSCEVType();

      // Aside from the getSCEVType() ordering, the particular ordering
      // isn't very important except that it's beneficial to be consistent,
      // so that (a + b) and (b + a) don't end up as different expressions.

      // Sort SCEVUnknown values with some loose heuristics. TODO: This is
      // not as complete as it could be.
      if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) {
        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

        // Order pointer values after integer values. This helps SCEVExpander
        // form GEPs.
        if (LU->getType()->isPointerTy() && !RU->getType()->isPointerTy())
          return false;
        if (RU->getType()->isPointerTy() && !LU->getType()->isPointerTy())
          return true;

        // Compare getValueID values.
        if (LU->getValue()->getValueID() != RU->getValue()->getValueID())
          return LU->getValue()->getValueID() < RU->getValue()->getValueID();

        // Sort arguments by their position.
        if (const Argument *LA = dyn_cast<Argument>(LU->getValue())) {
          const Argument *RA = cast<Argument>(RU->getValue());
          return LA->getArgNo() < RA->getArgNo();
        }

        // For instructions, compare their loop depth, and their opcode.
        // This is pretty loose.
        if (Instruction *LV = dyn_cast<Instruction>(LU->getValue())) {
          Instruction *RV = cast<Instruction>(RU->getValue());

          // Compare loop depths.
          if (LI->getLoopDepth(LV->getParent()) !=
              LI->getLoopDepth(RV->getParent()))
            return LI->getLoopDepth(LV->getParent()) <
                   LI->getLoopDepth(RV->getParent());

          // Compare opcodes.
          if (LV->getOpcode() != RV->getOpcode())
            return LV->getOpcode() < RV->getOpcode();

          // Compare the number of operands.
          if (LV->getNumOperands() != RV->getNumOperands())
            return LV->getNumOperands() < RV->getNumOperands();
        }

        return false;
      }

      // Compare constant values.
      if (const SCEVConstant *LC = dyn_cast<SCEVConstant>(LHS)) {
        const SCEVConstant *RC = cast<SCEVConstant>(RHS);
        if (LC->getValue()->getBitWidth() != RC->getValue()->getBitWidth())
          return LC->getValue()->getBitWidth() < RC->getValue()->getBitWidth();
        return LC->getValue()->getValue().ult(RC->getValue()->getValue());
      }

      // Compare addrec loop depths.
      if (const SCEVAddRecExpr *LA = dyn_cast<SCEVAddRecExpr>(LHS)) {
        const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
        if (LA->getLoop()->getLoopDepth() != RA->getLoop()->getLoopDepth())
          return LA->getLoop()->getLoopDepth() < RA->getLoop()->getLoopDepth();
      }

      // Lexicographically compare n-ary expressions.
      if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) {
        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
        for (unsigned i = 0, e = LC->getNumOperands(); i != e; ++i) {
          if (i >= RC->getNumOperands())
            return false;
          if (operator()(LC->getOperand(i), RC->getOperand(i)))
            return true;
          if (operator()(RC->getOperand(i), LC->getOperand(i)))
            return false;
        }
        return LC->getNumOperands() < RC->getNumOperands();
      }

      // Lexicographically compare udiv expressions.
      if (const SCEVUDivExpr *LC = dyn_cast<SCEVUDivExpr>(LHS)) {
        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
        if (operator()(LC->getLHS(), RC->getLHS()))
          return true;
        if (operator()(RC->getLHS(), LC->getLHS()))
          return false;
        if (operator()(LC->getRHS(), RC->getRHS()))
          return true;
        if (operator()(RC->getRHS(), LC->getRHS()))
          return false;
        return false;
      }

      // Compare cast expressions by operand.
      if (const SCEVCastExpr *LC = dyn_cast<SCEVCastExpr>(LHS)) {
        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
        return operator()(LC->getOperand(), RC->getOperand());
      }

      llvm_unreachable("Unknown SCEV kind!");
      return false;
    }
  };
}

/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
/// When this routine is finished, we know that any duplicates in the vector are
/// consecutive and that complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
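///
/// For example (a sketch): given the operand list {%x, 2, %x}, sorting moves
/// the constant to the front and makes the duplicate %x values adjacent,
/// letting getAddExpr fold the pair into a multiply: (2 + (2 * %x)).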
///
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI) {
  if (Ops.size() < 2) return;  // Noop
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    if (SCEVComplexityCompare(LI)(Ops[1], Ops[0]))
      std::swap(Ops[0], Ops[1]);
    return;
  }

  // Do the rough sort by complexity.
  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}


//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// BinomialCoefficient - Compute BC(It, K).  The result has width W,
/// the bit width of ResultTy.  Assumes K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       const Type* ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose, W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula requires less than W + K bits. Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
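  //
  // As a worked illustration (numbers checked by hand, not taken from this
  // code): for K = 3 and W = 32, K! = 6 = 2^1 * 3, so T = 1 and
  // OddFactorial = 3.  The product It*(It-1)*(It-2) is evaluated at
  // W + T = 33 bits, divided by 2^T = 2, truncated to 32 bits, and then
  // multiplied by the inverse of 3 mod 2^32, which is 0xAAAAAAAB
  // (3 * 0xAAAAAAAB = 2^33 + 1 == 1 (mod 2^32)).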

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two as we go,
  // while computing K! / 2^T, to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult = Mult.lshr(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step.
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt(CalculationBits, 1).shl(T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  const IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                      CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// evaluateAtIteration - Return the value of this chain of recurrences at
/// the specified iteration number.  We can evaluate this recurrence by
/// multiplying each element in the chain by the binomial coefficient
/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
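///
/// For example (a sketch): the recurrence {0,+,1,+,1} evaluates at
/// iteration It to 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2)
/// = It + It*(It-1)/2 = It*(It+1)/2.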
///
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  const SCEV *Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
                                             const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
    return getAddRecExpr(Operands, AddRec->getLoop());
  }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVTruncateExpr>();
  new (S) SCEVTruncateExpr(ID, Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
                                               const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
    const Type *IntTy = getEffectiveSCEVType(Ty);
    Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy);
    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
    return getConstant(cast<ConstantInt>(C));
  }

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoUnsignedWrap())
        return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                             getZeroExtendExpr(Step, Ty),
                             L);

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
          const SCEV *Add = getAddExpr(Start, ZMul);
          const SCEV *OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getZeroExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 L);

          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
          Add = getAddExpr(Start, SMul);
          OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);
        }

        // If the backedge is guarded by a comparison with the pre-inc value
        // the addrec is safe. Also, if the entry is guarded by a comparison
        // with the start value and the backedge is guarded by a comparison
        // with the post-inc value, the addrec is safe.
        if (isKnownPositive(Step)) {
          const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                      getUnsignedRange(Step).getUnsignedMax());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
              (isLoopGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
                                           AR->getPostIncExpr(*this), N)))
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 L);
        } else if (isKnownNegative(Step)) {
          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
                                      getSignedRange(Step).getSignedMin());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
              (isLoopGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
                                           AR->getPostIncExpr(*this), N)))
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);
        }
      }
    }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVZeroExtendExpr>();
  new (S) SCEVZeroExtendExpr(ID, Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
                                               const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
    const Type *IntTy = getEffectiveSCEVType(Ty);
    Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy);
    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
    return getConstant(cast<ConstantInt>(C));
  }

  // sext(sext(x)) --> sext(x)
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getSignExtendExpr(SS->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scSignExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can sign extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoSignedWrap())
        return getAddRecExpr(getSignExtendExpr(Start, Ty),
                             getSignExtendExpr(Step, Ty),
                             L);

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
          const SCEV *Add = getAddExpr(Start, SMul);
          const SCEV *OperandExtendedAdd =
            getAddExpr(getSignExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);

          // Similar to above, only this time treat the step value as unsigned.
          // This covers loops that count up with an unsigned step.
          const SCEV *UMul = getMulExpr(CastedMaxBECount, Step);
          Add = getAddExpr(Start, UMul);
          OperandExtendedAdd =
            getAddExpr(getSignExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getZeroExtendExpr(Step, WideTy)));
          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 L);
        }

        // If the backedge is guarded by a comparison with the pre-inc value
        // the addrec is safe. Also, if the entry is guarded by a comparison
        // with the start value and the backedge is guarded by a comparison
        // with the post-inc value, the addrec is safe.
        if (isKnownPositive(Step)) {
          const SCEV *N = getConstant(APInt::getSignedMinValue(BitWidth) -
                                      getSignedRange(Step).getSignedMax());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, AR, N) ||
              (isLoopGuardedByCond(L, ICmpInst::ICMP_SLT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT,
                                           AR->getPostIncExpr(*this), N)))
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);
        } else if (isKnownNegative(Step)) {
          const SCEV *N = getConstant(APInt::getSignedMaxValue(BitWidth) -
                                      getSignedRange(Step).getSignedMin());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, AR, N) ||
              (isLoopGuardedByCond(L, ICmpInst::ICMP_SGT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT,
                                           AR->getPostIncExpr(*this), N)))
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);
        }
      }
    }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVSignExtendExpr>();
  new (S) SCEVSignExtendExpr(ID, Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

/// getAnyExtendExpr - Return a SCEV for the given operand extended with
/// unspecified bits out to the given type.
///
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                              const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Sign-extend negative constants.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    if (SC->getValue()->getValue().isNegative())
      return getSignExtendExpr(Op, Ty);

  // Peel off a truncate cast.
  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
    const SCEV *NewOp = T->getOperand();
    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
      return getAnyExtendExpr(NewOp, Ty);
    return getTruncateOrNoop(NewOp, Ty);
  }

  // Next try a zext cast. If the cast is folded, use it.
  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;

  // Next try a sext cast. If the cast is folded, use it.
  const SCEV *SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;

  // Force the cast to be folded into the operands of an addrec.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Ops;
    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
         I != E; ++I)
      Ops.push_back(getAnyExtendExpr(*I, Ty));
    return getAddRecExpr(Ops, AR->getLoop());
  }

  // If the expression is obviously signed, use the sext cast value.
  if (isa<SCEVSMaxExpr>(Op))
    return SExt;

  // Absent any other information, use the zext cast value.
  return ZExt;
}

/// CollectAddOperandsWithScales - Process the given Ops list, which is
/// a list of operands to be added under the given scale, update the given
/// map. This is a helper function for getAddExpr. As an example of
/// what it does, given a sequence of operands that would form an add
/// expression like this:
///
///    m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r)
///
/// where A and B are constants, update the map with these values:
///
///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
///
/// and add 13 + A*B*29 to AccumulatedConstant.
/// This will allow getAddExpr to produce this:
///
///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
///
/// This form often exposes folding opportunities that are hidden in
/// the original operand list.
///
/// Return true iff it appears that any interesting folding opportunities
/// may be exposed. This helps getAddExpr short-circuit extra work in
/// the common case where no interesting opportunities are present, and
/// is also used as a check to avoid infinite recursion.
///
static bool
CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
                             SmallVector<const SCEV *, 8> &NewOps,
                             APInt &AccumulatedConstant,
                             const SmallVectorImpl<const SCEV *> &Ops,
                             const APInt &Scale,
                             ScalarEvolution &SE) {
  bool Interesting = false;

  // Iterate over the add operands.
  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
      APInt NewScale =
        Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
        // A multiplication of a constant with another add; recurse.
        Interesting |=
          CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                       cast<SCEVAddExpr>(Mul->getOperand(1))
                                         ->getOperands(),
                                       NewScale, SE);
      } else {
        // A multiplication of a constant with some other value. Update
        // the map.
        SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
        const SCEV *Key = SE.getMulExpr(MulOps);
        std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
          M.insert(std::make_pair(Key, NewScale));
        if (Pair.second) {
          NewOps.push_back(Pair.first->first);
        } else {
          Pair.first->second += NewScale;
          // The map already had an entry for this value, which may indicate
          // a folding opportunity.
          Interesting = true;
        }
      }
    } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
      // Pull a buried constant out to the outside.
      if (Scale != 1 || AccumulatedConstant != 0 || C->isZero())
        Interesting = true;
      AccumulatedConstant += Scale * C->getValue()->getValue();
    } else {
      // An ordinary operand. Update the map.
      std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
        M.insert(std::make_pair(Ops[i], Scale));
      if (Pair.second) {
        NewOps.push_back(Pair.first->first);
      } else {
        Pair.first->second += Scale;
        // The map already had an entry for this value, which may indicate
        // a folding opportunity.
        Interesting = true;
      }
    }
  }

  return Interesting;
}
1261
1262namespace {
1263  struct APIntCompare {
1264    bool operator()(const APInt &LHS, const APInt &RHS) const {
1265      return LHS.ult(RHS);
1266    }
1267  };
1268}
1269
1270/// getAddExpr - Get a canonical add expression, or something simpler if
1271/// possible.
1272const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
1273                                        bool HasNUW, bool HasNSW) {
1274  assert(!Ops.empty() && "Cannot get empty add!");
1275  if (Ops.size() == 1) return Ops[0];
1276#ifndef NDEBUG
1277  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1278    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1279           getEffectiveSCEVType(Ops[0]->getType()) &&
1280           "SCEVAddExpr operand types don't match!");
1281#endif
1282
1283  // If HasNSW is true and all the operands are non-negative, infer HasNUW.
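  // (Rationale: with no signed wrap and all operands non-negative, the true
  // sum lies in [0, signed max], which is within the unsigned range, so no
  // unsigned wrap is possible either.)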
1284  if (!HasNUW && HasNSW) {
1285    bool All = true;
1286    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1287      if (!isKnownNonNegative(Ops[i])) {
1288        All = false;
1289        break;
1290      }
1291    if (All) HasNUW = true;
1292  }
1293
1294  // Sort by complexity; this groups all similar expression types together.
1295  GroupByComplexity(Ops, LI);
1296
1297  // If there are any constants, fold them together.
1298  unsigned Idx = 0;
1299  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1300    ++Idx;
1301    assert(Idx < Ops.size());
1302    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1303      // We found two constants, fold them together!
1304      Ops[0] = getConstant(LHSC->getValue()->getValue() +
1305                           RHSC->getValue()->getValue());
1306      if (Ops.size() == 2) return Ops[0];
1307      Ops.erase(Ops.begin()+1);  // Erase the folded element
1308      LHSC = cast<SCEVConstant>(Ops[0]);
1309    }
1310
1311    // If we are left with a constant zero being added, strip it off.
1312    if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1313      Ops.erase(Ops.begin());
1314      --Idx;
1315    }
1316  }
1317
1318  if (Ops.size() == 1) return Ops[0];
1319
1320  // Okay, check to see if the same value occurs in the operand list twice.  If
1321  // so, merge them together into a multiply expression.  Since we sorted the
1322  // list, these values are required to be adjacent.
1323  const Type *Ty = Ops[0]->getType();
1324  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1325    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
1326      // Found a match, merge the two values into a multiply, and add any
1327      // remaining values to the result.
1328      const SCEV *Two = getIntegerSCEV(2, Ty);
1329      const SCEV *Mul = getMulExpr(Ops[i], Two);
1330      if (Ops.size() == 2)
1331        return Mul;
1332      Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
1333      Ops.push_back(Mul);
1334      return getAddExpr(Ops, HasNUW, HasNSW);
1335    }
1336
1337  // Check for truncates. If all the operands are truncated from the same
1338  // type, see if factoring out the truncate would permit the result to be
1339  // folded. e.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
1340  // if the contents of the resulting outer trunc fold to something simple.
1341  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
1342    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
1343    const Type *DstType = Trunc->getType();
1344    const Type *SrcType = Trunc->getOperand()->getType();
1345    SmallVector<const SCEV *, 8> LargeOps;
1346    bool Ok = true;
1347    // Check all the operands to see if they can be represented in the
1348    // source type of the truncate.
1349    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
1350      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
1351        if (T->getOperand()->getType() != SrcType) {
1352          Ok = false;
1353          break;
1354        }
1355        LargeOps.push_back(T->getOperand());
1356      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1357        // This could be either sign or zero extension, but sign extension
1358        // is much more likely to be foldable here.
1359        LargeOps.push_back(getSignExtendExpr(C, SrcType));
1360      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
1361        SmallVector<const SCEV *, 8> LargeMulOps;
1362        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
1363          if (const SCEVTruncateExpr *T =
1364                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
1365            if (T->getOperand()->getType() != SrcType) {
1366              Ok = false;
1367              break;
1368            }
1369            LargeMulOps.push_back(T->getOperand());
1370          } else if (const SCEVConstant *C =
1371                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
1372            // This could be either sign or zero extension, but sign extension
1373            // is much more likely to be foldable here.
1374            LargeMulOps.push_back(getSignExtendExpr(C, SrcType));
1375          } else {
1376            Ok = false;
1377            break;
1378          }
1379        }
1380        if (Ok)
1381          LargeOps.push_back(getMulExpr(LargeMulOps));
1382      } else {
1383        Ok = false;
1384        break;
1385      }
1386    }
1387    if (Ok) {
1388      // Evaluate the expression in the larger type.
1389      const SCEV *Fold = getAddExpr(LargeOps, HasNUW, HasNSW);
1390      // If it folds to something simple, use it. Otherwise, don't.
1391      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
1392        return getTruncateExpr(Fold, DstType);
1393    }
1394  }
1395
1396  // Skip past any other cast SCEVs.
1397  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
1398    ++Idx;
1399
1400  // If there are add operands, they would be next.
1401  if (Idx < Ops.size()) {
1402    bool DeletedAdd = false;
1403    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
1404      // If we have an add, expand the add operands onto the end of the operands
1405      // list.
1406      Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
1407      Ops.erase(Ops.begin()+Idx);
1408      DeletedAdd = true;
1409    }
1410
1411    // If we deleted at least one add, we added operands to the end of the list,
1412    // and they are not necessarily sorted.  Recurse to resort and resimplify
1413    // any operands we just acquired.
1414    if (DeletedAdd)
1415      return getAddExpr(Ops);
1416  }
1417
1418  // Skip over the add expression until we get to a multiply.
1419  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1420    ++Idx;
1421
1422  // Check to see if there are any folding opportunities present with
1423  // operands multiplied by constant values.
1424  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
1425    uint64_t BitWidth = getTypeSizeInBits(Ty);
1426    DenseMap<const SCEV *, APInt> M;
1427    SmallVector<const SCEV *, 8> NewOps;
1428    APInt AccumulatedConstant(BitWidth, 0);
1429    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1430                                     Ops, APInt(BitWidth, 1), *this)) {
1431      // Some interesting folding opportunity is present, so it's worthwhile to
1432      // re-generate the operands list. Group the operands by constant scale,
1433      // to avoid multiplying by the same constant scale multiple times.
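      // Illustrative example (not from the original source):
      // 4*%x + 2*%y + 4*%z + 2 regroups to 2 + 2*%y + 4*(%x + %z),
      // emitting a single multiply for each distinct scale.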
1434      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
1435      for (SmallVector<const SCEV *, 8>::iterator I = NewOps.begin(),
1436           E = NewOps.end(); I != E; ++I)
1437        MulOpLists[M.find(*I)->second].push_back(*I);
1438      // Re-generate the operands list.
1439      Ops.clear();
1440      if (AccumulatedConstant != 0)
1441        Ops.push_back(getConstant(AccumulatedConstant));
1442      for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
1443           I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
1444        if (I->first != 0)
1445          Ops.push_back(getMulExpr(getConstant(I->first),
1446                                   getAddExpr(I->second)));
1447      if (Ops.empty())
1448        return getIntegerSCEV(0, Ty);
1449      if (Ops.size() == 1)
1450        return Ops[0];
1451      return getAddExpr(Ops);
1452    }
1453  }
1454
1455  // If we are adding something to a multiply expression, make sure the
1456  // something is not already an operand of the multiply.  If so, merge it into
1457  // the multiply.
1458  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1459    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1460    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1461      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1462      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1463        if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(Ops[AddOp])) {
1464          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
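          // (getOperand(MulOp == 0) selects index 1 when MulOp is 0 and
          // index 0 otherwise, i.e. the other operand of a two-operand
          // multiply; the multi-operand case is rebuilt just below.)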
1465          const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
1466          if (Mul->getNumOperands() != 2) {
1467            // If the multiply has more than two operands, we must get the
1468            // Y*Z term.
1469            SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), Mul->op_end());
1470            MulOps.erase(MulOps.begin()+MulOp);
1471            InnerMul = getMulExpr(MulOps);
1472          }
1473          const SCEV *One = getIntegerSCEV(1, Ty);
1474          const SCEV *AddOne = getAddExpr(InnerMul, One);
1475          const SCEV *OuterMul = getMulExpr(AddOne, Ops[AddOp]);
1476          if (Ops.size() == 2) return OuterMul;
1477          if (AddOp < Idx) {
1478            Ops.erase(Ops.begin()+AddOp);
1479            Ops.erase(Ops.begin()+Idx-1);
1480          } else {
1481            Ops.erase(Ops.begin()+Idx);
1482            Ops.erase(Ops.begin()+AddOp-1);
1483          }
1484          Ops.push_back(OuterMul);
1485          return getAddExpr(Ops);
1486        }
1487
1488      // Check this multiply against other multiplies being added together.
1489      for (unsigned OtherMulIdx = Idx+1;
1490           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1491           ++OtherMulIdx) {
1492        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1493        // If MulOp occurs in OtherMul, we can fold the two multiplies
1494        // together.
1495        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1496             OMulOp != e; ++OMulOp)
1497          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1498            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1499            const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
1500            if (Mul->getNumOperands() != 2) {
1501              SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1502                                                  Mul->op_end());
1503              MulOps.erase(MulOps.begin()+MulOp);
1504              InnerMul1 = getMulExpr(MulOps);
1505            }
1506            const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1507            if (OtherMul->getNumOperands() != 2) {
1508              SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
1509                                                  OtherMul->op_end());
1510              MulOps.erase(MulOps.begin()+OMulOp);
1511              InnerMul2 = getMulExpr(MulOps);
1512            }
1513            const SCEV *InnerMulSum = getAddExpr(InnerMul1, InnerMul2);
1514            const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1515            if (Ops.size() == 2) return OuterMul;
1516            Ops.erase(Ops.begin()+Idx);
1517            Ops.erase(Ops.begin()+OtherMulIdx-1);
1518            Ops.push_back(OuterMul);
1519            return getAddExpr(Ops);
1520          }
1521      }
1522    }
1523  }
1524
1525  // If there are any add recurrences in the operands list, see if any other
1526  // added values are loop invariant.  If so, we can fold them into the
1527  // recurrence.
1528  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1529    ++Idx;
1530
1531  // Scan over all recurrences, trying to fold loop invariants into them.
1532  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1533    // Scan all of the other operands to this add and add them to the vector if
1534    // they are loop invariant w.r.t. the recurrence.
1535    SmallVector<const SCEV *, 8> LIOps;
1536    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1537    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1538      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1539        LIOps.push_back(Ops[i]);
1540        Ops.erase(Ops.begin()+i);
1541        --i; --e;
1542      }
1543
1544    // If we found some loop invariants, fold them into the recurrence.
1545    if (!LIOps.empty()) {
1546      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
1547      LIOps.push_back(AddRec->getStart());
1548
1549      SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1550                                             AddRec->op_end());
1551      AddRecOps[0] = getAddExpr(LIOps);
1552
1553      // It's tempting to propagate NUW/NSW flags here, but nuw/nsw addition
1554      // is not associative so this isn't necessarily safe.
1555      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop());
1556
1557      // If all of the other operands were loop invariant, we are done.
1558      if (Ops.size() == 1) return NewRec;
1559
1560      // Otherwise, add the folded AddRec to the non-loop-invariant parts.
1561      for (unsigned i = 0;; ++i)
1562        if (Ops[i] == AddRec) {
1563          Ops[i] = NewRec;
1564          break;
1565        }
1566      return getAddExpr(Ops);
1567    }
1568
1569    // Okay, if there weren't any loop invariants to be folded, check to see if
1570    // there are multiple AddRec's with the same loop induction variable being
1571    // added together.  If so, we can fold them.
1572    for (unsigned OtherIdx = Idx+1;
1573         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1574      if (OtherIdx != Idx) {
1575        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1576        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1577          // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
1578          SmallVector<const SCEV *, 4> NewOps(AddRec->op_begin(),
1579                                              AddRec->op_end());
1580          for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
1581            if (i >= NewOps.size()) {
1582              NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i,
1583                            OtherAddRec->op_end());
1584              break;
1585            }
1586            NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
1587          }
1588          const SCEV *NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop());
1589
1590          if (Ops.size() == 2) return NewAddRec;
1591
1592          Ops.erase(Ops.begin()+Idx);
1593          Ops.erase(Ops.begin()+OtherIdx-1);
1594          Ops.push_back(NewAddRec);
1595          return getAddExpr(Ops);
1596        }
1597      }
1598
1599    // Otherwise couldn't fold anything into this recurrence.  Move on to the
1600    // next one.
1601  }
1602
1603  // Okay, it looks like we really DO need an add expr.  Check to see if we
1604  // already have one, otherwise create a new one.
1605  FoldingSetNodeID ID;
1606  ID.AddInteger(scAddExpr);
1607  ID.AddInteger(Ops.size());
1608  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1609    ID.AddPointer(Ops[i]);
1610  void *IP = 0;
1611  SCEVAddExpr *S =
1612    static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1613  if (!S) {
1614    S = SCEVAllocator.Allocate<SCEVAddExpr>();
1615    new (S) SCEVAddExpr(ID, Ops);
1616    UniqueSCEVs.InsertNode(S, IP);
1617  }
1618  if (HasNUW) S->setHasNoUnsignedWrap(true);
1619  if (HasNSW) S->setHasNoSignedWrap(true);
1620  return S;
1621}
1622
1623/// getMulExpr - Get a canonical multiply expression, or something simpler if
1624/// possible.
1625const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
1626                                        bool HasNUW, bool HasNSW) {
1627  assert(!Ops.empty() && "Cannot get empty mul!");
1628  if (Ops.size() == 1) return Ops[0];
1629#ifndef NDEBUG
1630  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1631    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1632           getEffectiveSCEVType(Ops[0]->getType()) &&
1633           "SCEVMulExpr operand types don't match!");
1634#endif
1635
1636  // If HasNSW is true and all the operands are non-negative, infer HasNUW.
1637  if (!HasNUW && HasNSW) {
1638    bool All = true;
1639    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1640      if (!isKnownNonNegative(Ops[i])) {
1641        All = false;
1642        break;
1643      }
1644    if (All) HasNUW = true;
1645  }
1646
1647  // Sort by complexity; this groups all similar expression types together.
1648  GroupByComplexity(Ops, LI);
1649
1650  // If there are any constants, fold them together.
1651  unsigned Idx = 0;
1652  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1653
1654    // C1*(C2+V) -> C1*C2 + C1*V
1655    if (Ops.size() == 2)
1656      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1657        if (Add->getNumOperands() == 2 &&
1658            isa<SCEVConstant>(Add->getOperand(0)))
1659          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1660                            getMulExpr(LHSC, Add->getOperand(1)));
1661
1662    ++Idx;
1663    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1664      // We found two constants, fold them together!
1665      ConstantInt *Fold = ConstantInt::get(getContext(),
1666                                           LHSC->getValue()->getValue() *
1667                                           RHSC->getValue()->getValue());
1668      Ops[0] = getConstant(Fold);
1669      Ops.erase(Ops.begin()+1);  // Erase the folded element
1670      if (Ops.size() == 1) return Ops[0];
1671      LHSC = cast<SCEVConstant>(Ops[0]);
1672    }
1673
1674    // If we are left with a constant one being multiplied, strip it off.
1675    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1676      Ops.erase(Ops.begin());
1677      --Idx;
1678    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1679      // If we have a multiply of zero, it will always be zero.
1680      return Ops[0];
1681    } else if (Ops[0]->isAllOnesValue()) {
1682      // If we have a mul by -1 of an add, try distributing the -1 among the
1683      // add operands.
1684      if (Ops.size() == 2)
1685        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
1686          SmallVector<const SCEV *, 4> NewOps;
1687          bool AnyFolded = false;
1688          for (SCEVAddExpr::op_iterator I = Add->op_begin(), E = Add->op_end();
1689               I != E; ++I) {
1690            const SCEV *Mul = getMulExpr(Ops[0], *I);
1691            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
1692            NewOps.push_back(Mul);
1693          }
1694          if (AnyFolded)
1695            return getAddExpr(NewOps);
1696        }
1697    }
1698  }
1699
1700  // Skip over the add expression until we get to a multiply.
1701  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1702    ++Idx;
1703
1704  if (Ops.size() == 1)
1705    return Ops[0];
1706
1707  // If there are mul operands, inline them all into this expression.
1708  if (Idx < Ops.size()) {
1709    bool DeletedMul = false;
1710    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1711      // If we have a mul, expand the mul operands onto the end of the operands
1712      // list.
1713      Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end());
1714      Ops.erase(Ops.begin()+Idx);
1715      DeletedMul = true;
1716    }
1717
1718    // If we deleted at least one mul, we added operands to the end of the list,
1719    // and they are not necessarily sorted.  Recurse to resort and resimplify
1720    // any operands we just acquired.
1721    if (DeletedMul)
1722      return getMulExpr(Ops);
1723  }
1724
1725  // If there are any add recurrences in the operands list, see if any other
1726  // multiplied values are loop invariant.  If so, we can fold them into the
1727  // recurrence.
1728  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1729    ++Idx;
1730
1731  // Scan over all recurrences, trying to fold loop invariants into them.
1732  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1733    // Scan all of the other operands to this mul and add them to the vector if
1734    // they are loop invariant w.r.t. the recurrence.
1735    SmallVector<const SCEV *, 8> LIOps;
1736    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1737    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1738      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1739        LIOps.push_back(Ops[i]);
1740        Ops.erase(Ops.begin()+i);
1741        --i; --e;
1742      }
1743
1744    // If we found some loop invariants, fold them into the recurrence.
1745    if (!LIOps.empty()) {
1746      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
1747      SmallVector<const SCEV *, 4> NewOps;
1748      NewOps.reserve(AddRec->getNumOperands());
1749      if (LIOps.size() == 1) {
1750        const SCEV *Scale = LIOps[0];
1751        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
1752          NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
1753      } else {
1754        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
1755          SmallVector<const SCEV *, 4> MulOps(LIOps.begin(), LIOps.end());
1756          MulOps.push_back(AddRec->getOperand(i));
1757          NewOps.push_back(getMulExpr(MulOps));
1758        }
1759      }
1760
1761      // It's tempting to propagate the NSW flag here, but nsw multiplication
1762      // is not associative so this isn't necessarily safe.
1763      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(),
1764                                         HasNUW && AddRec->hasNoUnsignedWrap(),
1765                                         /*HasNSW=*/false);
1766
1767      // If all of the other operands were loop invariant, we are done.
1768      if (Ops.size() == 1) return NewRec;
1769
1770      // Otherwise, multiply the folded AddRec by the non-loop-invariant parts.
1771      for (unsigned i = 0;; ++i)
1772        if (Ops[i] == AddRec) {
1773          Ops[i] = NewRec;
1774          break;
1775        }
1776      return getMulExpr(Ops);
1777    }
1778
1779    // Okay, if there weren't any loop invariants to be folded, check to see if
1780    // there are multiple AddRec's with the same loop induction variable being
1781    // multiplied together.  If so, we can fold them.
1782    for (unsigned OtherIdx = Idx+1;
1783         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1784      if (OtherIdx != Idx) {
1785        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1786        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1787          // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
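          // (Derivation: with H(i) = F(i)*G(i), the per-iteration step is
          // H(i+1) - H(i) = (F+B)*(G+D) - F*G = F*D + G*B + B*D, which is
          // why F and G themselves appear in the new step expression.)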
1788          const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1789          const SCEV *NewStart = getMulExpr(F->getStart(),
1790                                            G->getStart());
1791          const SCEV *B = F->getStepRecurrence(*this);
1792          const SCEV *D = G->getStepRecurrence(*this);
1793          const SCEV *NewStep = getAddExpr(getMulExpr(F, D),
1794                                           getMulExpr(G, B),
1795                                           getMulExpr(B, D));
1796          const SCEV *NewAddRec = getAddRecExpr(NewStart, NewStep,
1797                                                F->getLoop());
1798          if (Ops.size() == 2) return NewAddRec;
1799
1800          Ops.erase(Ops.begin()+Idx);
1801          Ops.erase(Ops.begin()+OtherIdx-1);
1802          Ops.push_back(NewAddRec);
1803          return getMulExpr(Ops);
1804        }
1805      }
1806
1807    // Otherwise couldn't fold anything into this recurrence.  Move on to the
1808    // next one.
1809  }
1810
1811  // Okay, it looks like we really DO need a mul expr.  Check to see if we
1812  // already have one, otherwise create a new one.
1813  FoldingSetNodeID ID;
1814  ID.AddInteger(scMulExpr);
1815  ID.AddInteger(Ops.size());
1816  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1817    ID.AddPointer(Ops[i]);
1818  void *IP = 0;
1819  SCEVMulExpr *S =
1820    static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1821  if (!S) {
1822    S = SCEVAllocator.Allocate<SCEVMulExpr>();
1823    new (S) SCEVMulExpr(ID, Ops);
1824    UniqueSCEVs.InsertNode(S, IP);
1825  }
1826  if (HasNUW) S->setHasNoUnsignedWrap(true);
1827  if (HasNSW) S->setHasNoSignedWrap(true);
1828  return S;
1829}
1830
1831/// getUDivExpr - Get a canonical unsigned division expression, or something
1832/// simpler if possible.
1833const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
1834                                         const SCEV *RHS) {
1835  assert(getEffectiveSCEVType(LHS->getType()) ==
1836         getEffectiveSCEVType(RHS->getType()) &&
1837         "SCEVUDivExpr operand types don't match!");
1838
1839  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1840    if (RHSC->getValue()->equalsInt(1))
1841      return LHS;                               // X udiv 1 --> X
1842    if (RHSC->isZero())
1843      return getIntegerSCEV(0, LHS->getType()); // value is undefined
1844
1845    // Determine if the division can be folded into the operands of
1846    // its left-hand operand.
1847    // TODO: Generalize this to non-constants by using known-bits information.
1848    const Type *Ty = LHS->getType();
1849    unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
1850    unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
1851    // For non-power-of-two values, effectively round the value up to the
1852    // nearest power of two.
1853    if (!RHSC->getValue()->getValue().isPowerOf2())
1854      ++MaxShiftAmt;
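    // (Illustrative: a udiv of an i32 value by 6 has LZ = 29, so
    // MaxShiftAmt = 32 - 29 + 1 = 4 and ExtTy below becomes i36.)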
1855    const IntegerType *ExtTy =
1856      IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
1857    // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
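    // Illustrative example (not from the original source): {0,+,8}<%L> udiv 4
    // becomes {0,+,2}<%L>, assuming the zero-extend comparison below proves
    // the recurrence cannot wrap.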
1858    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
1859      if (const SCEVConstant *Step =
1860            dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
1861        if (!Step->getValue()->getValue()
1862              .urem(RHSC->getValue()->getValue()) &&
1863            getZeroExtendExpr(AR, ExtTy) ==
1864            getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
1865                          getZeroExtendExpr(Step, ExtTy),
1866                          AR->getLoop())) {
1867          SmallVector<const SCEV *, 4> Operands;
1868          for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
1869            Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
1870          return getAddRecExpr(Operands, AR->getLoop());
1871        }
1872    // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
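    // Illustrative example (not from the original source): (6 * %x) udiv 3
    // becomes 2 * %x once the zero-extend check succeeds, since 6 udiv 3
    // multiplies back to 6 exactly.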
1873    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
1874      SmallVector<const SCEV *, 4> Operands;
1875      for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
1876        Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
1877      if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
1878        // Find an operand that's safely divisible.
1879        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
1880          const SCEV *Op = M->getOperand(i);
1881          const SCEV *Div = getUDivExpr(Op, RHSC);
1882          if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
1883            const SmallVectorImpl<const SCEV *> &MOperands = M->getOperands();
1884            Operands = SmallVector<const SCEV *, 4>(MOperands.begin(),
1885                                                    MOperands.end());
1886            Operands[i] = Div;
1887            return getMulExpr(Operands);
1888          }
1889        }
1890    }
1891    // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
1892    if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
1893      SmallVector<const SCEV *, 4> Operands;
1894      for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
1895        Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
1896      if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
1897        Operands.clear();
1898        for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
1899          const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
1900          if (isa<SCEVUDivExpr>(Op) || getMulExpr(Op, RHS) != A->getOperand(i))
1901            break;
1902          Operands.push_back(Op);
1903        }
1904        if (Operands.size() == A->getNumOperands())
1905          return getAddExpr(Operands);
1906      }
1907    }
1908
1909    // Fold if both operands are constant.
1910    if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1911      Constant *LHSCV = LHSC->getValue();
1912      Constant *RHSCV = RHSC->getValue();
1913      return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
1914                                                                 RHSCV)));
1915    }
1916  }
1917
1918  FoldingSetNodeID ID;
1919  ID.AddInteger(scUDivExpr);
1920  ID.AddPointer(LHS);
1921  ID.AddPointer(RHS);
1922  void *IP = 0;
1923  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1924  SCEV *S = SCEVAllocator.Allocate<SCEVUDivExpr>();
1925  new (S) SCEVUDivExpr(ID, LHS, RHS);
1926  UniqueSCEVs.InsertNode(S, IP);
1927  return S;
1928}
1929
1930
1931/// getAddRecExpr - Get an add recurrence expression for the specified loop.
1932/// Simplify the expression as much as possible.
1933const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start,
1934                                           const SCEV *Step, const Loop *L,
1935                                           bool HasNUW, bool HasNSW) {
1936  SmallVector<const SCEV *, 4> Operands;
1937  Operands.push_back(Start);
1938  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
1939    if (StepChrec->getLoop() == L) {
1940      Operands.insert(Operands.end(), StepChrec->op_begin(),
1941                      StepChrec->op_end());
1942      return getAddRecExpr(Operands, L);
1943    }
1944
1945  Operands.push_back(Step);
1946  return getAddRecExpr(Operands, L, HasNUW, HasNSW);
1947}
1948
1949/// getAddRecExpr - Get an add recurrence expression for the specified loop.
1950/// Simplify the expression as much as possible.
1951const SCEV *
1952ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
1953                               const Loop *L,
1954                               bool HasNUW, bool HasNSW) {
1955  if (Operands.size() == 1) return Operands[0];
1956#ifndef NDEBUG
1957  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
1958    assert(getEffectiveSCEVType(Operands[i]->getType()) ==
1959           getEffectiveSCEVType(Operands[0]->getType()) &&
1960           "SCEVAddRecExpr operand types don't match!");
1961#endif
1962
1963  if (Operands.back()->isZero()) {
1964    Operands.pop_back();
1965    return getAddRecExpr(Operands, L, HasNUW, HasNSW); // {X,+,0}  -->  X
1966  }
1967
1968  // It's tempting to call getMaxBackedgeTakenCount here and
1969  // use that information to infer NUW and NSW flags. However, computing a
1970  // BE count requires calling getAddRecExpr, so we may not yet have a
1971  // meaningful BE count at this point (and if we don't, we'd be stuck
1972  // with a SCEVCouldNotCompute as the cached BE count).
1973
1974  // If HasNSW is true and all the operands are non-negative, infer HasNUW.
1975  if (!HasNUW && HasNSW) {
1976    bool All = true;
1977    for (unsigned i = 0, e = Operands.size(); i != e; ++i)
1978      if (!isKnownNonNegative(Operands[i])) {
1979        All = false;
1980        break;
1981      }
1982    if (All) HasNUW = true;
1983  }
1984
1985  // Canonicalize nested AddRecs by nesting them in order of loop depth.
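  // For example (illustrative): {{A,+,B}<%inner>,+,C}<%outer> is rewritten as
  // {{A,+,C}<%outer>,+,B}<%inner>, provided each operand is invariant in the
  // loop it ends up attached to.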
1986  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
1987    const Loop *NestedLoop = NestedAR->getLoop();
1988    if (L->contains(NestedLoop->getHeader()) ?
1989        (L->getLoopDepth() < NestedLoop->getLoopDepth()) :
1990        (!NestedLoop->contains(L->getHeader()) &&
1991         DT->dominates(L->getHeader(), NestedLoop->getHeader()))) {
1992      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
1993                                                  NestedAR->op_end());
1994      Operands[0] = NestedAR->getStart();
1995      // AddRecs require their operands be loop-invariant with respect to their
1996      // loops. Don't perform this transformation if it would break this
1997      // requirement.
1998      bool AllInvariant = true;
1999      for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2000        if (!Operands[i]->isLoopInvariant(L)) {
2001          AllInvariant = false;
2002          break;
2003        }
2004      if (AllInvariant) {
2005        NestedOperands[0] = getAddRecExpr(Operands, L);
2006        AllInvariant = true;
2007        for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
2008          if (!NestedOperands[i]->isLoopInvariant(NestedLoop)) {
2009            AllInvariant = false;
2010            break;
2011          }
2012        if (AllInvariant)
2013          // Ok, both add recurrences are valid after the transformation.
2014          return getAddRecExpr(NestedOperands, NestedLoop, HasNUW, HasNSW);
2015      }
2016      // Reset Operands to its original state.
2017      Operands[0] = NestedAR;
2018    }
2019  }
2020
2021  // Okay, it looks like we really DO need an addrec expr.  Check to see if we
2022  // already have one, otherwise create a new one.
2023  FoldingSetNodeID ID;
2024  ID.AddInteger(scAddRecExpr);
2025  ID.AddInteger(Operands.size());
2026  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2027    ID.AddPointer(Operands[i]);
2028  ID.AddPointer(L);
2029  void *IP = 0;
2030  SCEVAddRecExpr *S =
2031    static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2032  if (!S) {
2033    S = SCEVAllocator.Allocate<SCEVAddRecExpr>();
2034    new (S) SCEVAddRecExpr(ID, Operands, L);
2035    UniqueSCEVs.InsertNode(S, IP);
2036  }
2037  if (HasNUW) S->setHasNoUnsignedWrap(true);
2038  if (HasNSW) S->setHasNoSignedWrap(true);
2039  return S;
2040}
2041
2042const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
2043                                         const SCEV *RHS) {
2044  SmallVector<const SCEV *, 2> Ops;
2045  Ops.push_back(LHS);
2046  Ops.push_back(RHS);
2047  return getSMaxExpr(Ops);
2048}
2049
2050const SCEV *
2051ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2052  assert(!Ops.empty() && "Cannot get empty smax!");
2053  if (Ops.size() == 1) return Ops[0];
2054#ifndef NDEBUG
2055  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2056    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
2057           getEffectiveSCEVType(Ops[0]->getType()) &&
2058           "SCEVSMaxExpr operand types don't match!");
2059#endif
2060
2061  // Sort by complexity; this groups all similar expression types together.
2062  GroupByComplexity(Ops, LI);
2063
2064  // If there are any constants, fold them together.
2065  unsigned Idx = 0;
2066  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2067    ++Idx;
2068    assert(Idx < Ops.size());
2069    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2070      // We found two constants, fold them together!
2071      ConstantInt *Fold = ConstantInt::get(getContext(),
2072                              APIntOps::smax(LHSC->getValue()->getValue(),
2073                                             RHSC->getValue()->getValue()));
2074      Ops[0] = getConstant(Fold);
2075      Ops.erase(Ops.begin()+1);  // Erase the folded element
2076      if (Ops.size() == 1) return Ops[0];
2077      LHSC = cast<SCEVConstant>(Ops[0]);
2078    }
2079
2080    // If we are left with a constant minimum-int, strip it off.
2081    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
2082      Ops.erase(Ops.begin());
2083      --Idx;
2084    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
2085      // If we have an smax with a constant maximum-int, it will always be
2086      // maximum-int.
2087      return Ops[0];
2088    }
2089  }
2090
2091  if (Ops.size() == 1) return Ops[0];
2092
2093  // Find the first SMax
2094  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
2095    ++Idx;
2096
2097  // Check to see if one of the operands is an SMax. If so, expand its operands
2098  // onto our operand list, and recurse to simplify.
2099  if (Idx < Ops.size()) {
2100    bool DeletedSMax = false;
2101    while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
2102      Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
2103      Ops.erase(Ops.begin()+Idx);
2104      DeletedSMax = true;
2105    }
2106
2107    if (DeletedSMax)
2108      return getSMaxExpr(Ops);
2109  }
2110
2111  // Okay, check to see if the same value occurs in the operand list twice.  If
2112  // so, delete one.  Since we sorted the list, these values are required to
2113  // be adjacent.
2114  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2115    if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
2116      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2117      --i; --e;
2118    }
2119
2120  if (Ops.size() == 1) return Ops[0];
2121
2122  assert(!Ops.empty() && "Reduced smax down to nothing!");
2123
2124  // Okay, it looks like we really DO need an smax expr.  Check to see if we
2125  // already have one, otherwise create a new one.
2126  FoldingSetNodeID ID;
2127  ID.AddInteger(scSMaxExpr);
2128  ID.AddInteger(Ops.size());
2129  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2130    ID.AddPointer(Ops[i]);
2131  void *IP = 0;
2132  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2133  SCEV *S = SCEVAllocator.Allocate<SCEVSMaxExpr>();
2134  new (S) SCEVSMaxExpr(ID, Ops);
2135  UniqueSCEVs.InsertNode(S, IP);
2136  return S;
2137}
2138
2139const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
2140                                         const SCEV *RHS) {
2141  SmallVector<const SCEV *, 2> Ops;
2142  Ops.push_back(LHS);
2143  Ops.push_back(RHS);
2144  return getUMaxExpr(Ops);
2145}
2146
2147const SCEV *
2148ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2149  assert(!Ops.empty() && "Cannot get empty umax!");
2150  if (Ops.size() == 1) return Ops[0];
2151#ifndef NDEBUG
2152  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2153    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
2154           getEffectiveSCEVType(Ops[0]->getType()) &&
2155           "SCEVUMaxExpr operand types don't match!");
2156#endif
2157
2158  // Sort by complexity; this groups all similar expression types together.
2159  GroupByComplexity(Ops, LI);
2160
2161  // If there are any constants, fold them together.
2162  unsigned Idx = 0;
2163  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2164    ++Idx;
2165    assert(Idx < Ops.size());
2166    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2167      // We found two constants, fold them together!
2168      ConstantInt *Fold = ConstantInt::get(getContext(),
2169                              APIntOps::umax(LHSC->getValue()->getValue(),
2170                                             RHSC->getValue()->getValue()));
2171      Ops[0] = getConstant(Fold);
2172      Ops.erase(Ops.begin()+1);  // Erase the folded element
2173      if (Ops.size() == 1) return Ops[0];
2174      LHSC = cast<SCEVConstant>(Ops[0]);
2175    }
2176
2177    // If we are left with a constant minimum-int, strip it off.
2178    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
2179      Ops.erase(Ops.begin());
2180      --Idx;
2181    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
2182      // If we have an umax with a constant maximum-int, it will always be
2183      // maximum-int.
2184      return Ops[0];
2185    }
2186  }
2187
2188  if (Ops.size() == 1) return Ops[0];
2189
2190  // Find the first UMax
2191  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
2192    ++Idx;
2193
2194  // Check to see if one of the operands is a UMax. If so, expand its operands
2195  // onto our operand list, and recurse to simplify.
2196  if (Idx < Ops.size()) {
2197    bool DeletedUMax = false;
2198    while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
2199      Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
2200      Ops.erase(Ops.begin()+Idx);
2201      DeletedUMax = true;
2202    }
2203
2204    if (DeletedUMax)
2205      return getUMaxExpr(Ops);
2206  }
2207
2208  // Okay, check to see if the same value occurs in the operand list twice.  If
2209  // so, delete one.  Since we sorted the list, these values are required to
2210  // be adjacent.
2211  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2212    if (Ops[i] == Ops[i+1]) {      //  X umax Y umax Y  -->  X umax Y
2213      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2214      --i; --e;
2215    }
2216
2217  if (Ops.size() == 1) return Ops[0];
2218
2219  assert(!Ops.empty() && "Reduced umax down to nothing!");
2220
2221  // Okay, it looks like we really DO need a umax expr.  Check to see if we
2222  // already have one, otherwise create a new one.
2223  FoldingSetNodeID ID;
2224  ID.AddInteger(scUMaxExpr);
2225  ID.AddInteger(Ops.size());
2226  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2227    ID.AddPointer(Ops[i]);
2228  void *IP = 0;
2229  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2230  SCEV *S = SCEVAllocator.Allocate<SCEVUMaxExpr>();
2231  new (S) SCEVUMaxExpr(ID, Ops);
2232  UniqueSCEVs.InsertNode(S, IP);
2233  return S;
2234}
2235
2236const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
2237                                         const SCEV *RHS) {
2238  // ~smax(~x, ~y) == smin(x, y).
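  // (~x = -1-x reverses the signed order, so taking the max of the
  // complements and complementing the result yields the minimum; the same
  // trick with the unsigned order gives getUMinExpr below.)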
2239  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2240}
2241
2242const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
2243                                         const SCEV *RHS) {
2244  // ~umax(~x, ~y) == umin(x, y).
2245  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2246}
2247
2248const SCEV *ScalarEvolution::getSizeOfExpr(const Type *AllocTy) {
2249  Constant *C = ConstantExpr::getSizeOf(AllocTy);
2250  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2251    C = ConstantFoldConstantExpression(CE, TD);
2252  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2253  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2254}
2255
2256const SCEV *ScalarEvolution::getAlignOfExpr(const Type *AllocTy) {
2257  Constant *C = ConstantExpr::getAlignOf(AllocTy);
2258  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2259    C = ConstantFoldConstantExpression(CE, TD);
2260  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2261  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2262}
2263
2264const SCEV *ScalarEvolution::getOffsetOfExpr(const StructType *STy,
2265                                             unsigned FieldNo) {
2266  Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo);
2267  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2268    C = ConstantFoldConstantExpression(CE, TD);
2269  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
2270  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2271}
2272
2273const SCEV *ScalarEvolution::getOffsetOfExpr(const Type *CTy,
2274                                             Constant *FieldNo) {
2275  Constant *C = ConstantExpr::getOffsetOf(CTy, FieldNo);
2276  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2277    C = ConstantFoldConstantExpression(CE, TD);
2278  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy));
2279  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2280}
2281
2282const SCEV *ScalarEvolution::getUnknown(Value *V) {
2283  // Don't attempt to do anything other than create a SCEVUnknown object
2284  // here.  createSCEV only calls getUnknown after checking for all other
2285  // interesting possibilities, and any other code that calls getUnknown
2286  // is doing so in order to hide a value from SCEV canonicalization.
2287
2288  FoldingSetNodeID ID;
2289  ID.AddInteger(scUnknown);
2290  ID.AddPointer(V);
2291  void *IP = 0;
2292  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2293  SCEV *S = SCEVAllocator.Allocate<SCEVUnknown>();
2294  new (S) SCEVUnknown(ID, V);
2295  UniqueSCEVs.InsertNode(S, IP);
2296  return S;
2297}
2298
2299//===----------------------------------------------------------------------===//
2300//            Basic SCEV Analysis and PHI Idiom Recognition Code
2301//
2302
2303/// isSCEVable - Test if values of the given type are analyzable within
2304/// the SCEV framework. This primarily includes integer types, and it
2305/// can optionally include pointer types if the ScalarEvolution class
2306/// has access to target-specific information.
2307bool ScalarEvolution::isSCEVable(const Type *Ty) const {
2308  // Integers and pointers are always SCEVable.
2309  return Ty->isIntegerTy() || Ty->isPointerTy();
2310}
2311
2312/// getTypeSizeInBits - Return the size in bits of the specified type,
2313/// for which isSCEVable must return true.
2314uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
2315  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2316
2317  // If we have a TargetData, use it!
2318  if (TD)
2319    return TD->getTypeSizeInBits(Ty);
2320
2321  // Integer types have fixed sizes.
2322  if (Ty->isIntegerTy())
2323    return Ty->getPrimitiveSizeInBits();
2324
2325  // The only other supported type is pointer. Without TargetData,
2326  // conservatively assume pointers are 64-bit.
2327  assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
2328  return 64;
2329}
2330
2331/// getEffectiveSCEVType - Return a type with the same bitwidth as
2332/// the given type and which represents how SCEV will treat the given
2333/// type, for which isSCEVable must return true. For pointer types,
2334/// this is the pointer-sized integer type.
2335const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
2336  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2337
2338  if (Ty->isIntegerTy())
2339    return Ty;
2340
2341  // The only other supported type is pointer.
2342  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
2343  if (TD) return TD->getIntPtrType(getContext());
2344
2345  // Without TargetData, conservatively assume pointers are 64-bit.
2346  return Type::getInt64Ty(getContext());
2347}
2348
2349const SCEV *ScalarEvolution::getCouldNotCompute() {
2350  return &CouldNotCompute;
2351}
2352
2353/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
2354/// expression and create a new one.
2355const SCEV *ScalarEvolution::getSCEV(Value *V) {
2356  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
2357
2358  std::map<SCEVCallbackVH, const SCEV *>::iterator I = Scalars.find(V);
2359  if (I != Scalars.end()) return I->second;
2360  const SCEV *S = createSCEV(V);
2361  Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S));
2362  return S;
2363}
2364
2365/// getIntegerSCEV - Given a SCEVable type, create a constant for the
2366/// specified signed integer value and return a SCEV for the constant.
2367const SCEV *ScalarEvolution::getIntegerSCEV(int64_t Val, const Type *Ty) {
2368  const IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
2369  return getConstant(ConstantInt::get(ITy, Val));
2370}
2371
2372/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
2373///
2374const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
2375  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2376    return getConstant(
2377               cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
2378
2379  const Type *Ty = V->getType();
2380  Ty = getEffectiveSCEVType(Ty);
2381  return getMulExpr(V,
2382                  getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
2383}
2384
2385/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
2386const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
2387  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2388    return getConstant(
2389                cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
2390
2391  const Type *Ty = V->getType();
2392  Ty = getEffectiveSCEVType(Ty);
2393  const SCEV *AllOnes =
2394                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
2395  return getMinusSCEV(AllOnes, V);
2396}
2397
2398/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
2399///
2400const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS,
2401                                          const SCEV *RHS) {
2402  // X - Y --> X + -Y
2403  return getAddExpr(LHS, getNegativeSCEV(RHS));
2404}
2405
2406/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
2407/// input value to the specified type.  If the type must be extended, it is zero
2408/// extended.
2409const SCEV *
2410ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V,
2411                                         const Type *Ty) {
2412  const Type *SrcTy = V->getType();
2413  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2414         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2415         "Cannot truncate or zero extend with non-integer arguments!");
2416  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2417    return V;  // No conversion
2418  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2419    return getTruncateExpr(V, Ty);
2420  return getZeroExtendExpr(V, Ty);
2421}
2422
2423/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
2424/// input value to the specified type.  If the type must be extended, it is sign
2425/// extended.
2426const SCEV *
2427ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
2428                                         const Type *Ty) {
2429  const Type *SrcTy = V->getType();
2430  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2431         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2432         "Cannot truncate or zero extend with non-integer arguments!");
2433  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2434    return V;  // No conversion
2435  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2436    return getTruncateExpr(V, Ty);
2437  return getSignExtendExpr(V, Ty);
2438}
2439
2440/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
2441/// input value to the specified type.  If the type must be extended, it is zero
2442/// extended.  The conversion must not be narrowing.
2443const SCEV *
2444ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) {
2445  const Type *SrcTy = V->getType();
2446  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2447         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2448         "Cannot noop or zero extend with non-integer arguments!");
2449  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2450         "getNoopOrZeroExtend cannot truncate!");
2451  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2452    return V;  // No conversion
2453  return getZeroExtendExpr(V, Ty);
2454}
2455
2456/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
2457/// input value to the specified type.  If the type must be extended, it is sign
2458/// extended.  The conversion must not be narrowing.
2459const SCEV *
2460ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) {
2461  const Type *SrcTy = V->getType();
2462  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2463         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2464         "Cannot noop or sign extend with non-integer arguments!");
2465  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2466         "getNoopOrSignExtend cannot truncate!");
2467  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2468    return V;  // No conversion
2469  return getSignExtendExpr(V, Ty);
2470}
2471
2472/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
2473/// the input value to the specified type. If the type must be extended,
2474/// it is extended with unspecified bits. The conversion must not be
2475/// narrowing.
2476const SCEV *
2477ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) {
2478  const Type *SrcTy = V->getType();
2479  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2480         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2481         "Cannot noop or any extend with non-integer arguments!");
2482  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2483         "getNoopOrAnyExtend cannot truncate!");
2484  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2485    return V;  // No conversion
2486  return getAnyExtendExpr(V, Ty);
2487}
2488
2489/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
2490/// input value to the specified type.  The conversion must not be widening.
2491const SCEV *
2492ScalarEvolution::getTruncateOrNoop(const SCEV *V, const Type *Ty) {
2493  const Type *SrcTy = V->getType();
2494  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2495         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2496         "Cannot truncate or noop with non-integer arguments!");
2497  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
2498         "getTruncateOrNoop cannot extend!");
2499  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2500    return V;  // No conversion
2501  return getTruncateExpr(V, Ty);
2502}
2503
2504/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
2505/// the types using zero-extension, and then perform a umax operation
2506/// with them.
2507const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
2508                                                        const SCEV *RHS) {
2509  const SCEV *PromotedLHS = LHS;
2510  const SCEV *PromotedRHS = RHS;
2511
2512  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2513    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2514  else
2515    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2516
2517  return getUMaxExpr(PromotedLHS, PromotedRHS);
2518}
2519
2520/// getUMinFromMismatchedTypes - Promote the operands to the wider of
2521/// the types using zero-extension, and then perform a umin operation
2522/// with them.
2523const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
2524                                                        const SCEV *RHS) {
2525  const SCEV *PromotedLHS = LHS;
2526  const SCEV *PromotedRHS = RHS;
2527
2528  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2529    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2530  else
2531    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2532
2533  return getUMinExpr(PromotedLHS, PromotedRHS);
2534}
2535
2536/// PushDefUseChildren - Push users of the given Instruction
2537/// onto the given Worklist.
2538static void
2539PushDefUseChildren(Instruction *I,
2540                   SmallVectorImpl<Instruction *> &Worklist) {
2541  // Push the def-use children onto the Worklist stack.
2542  for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2543       UI != UE; ++UI)
2544    Worklist.push_back(cast<Instruction>(UI));
2545}
2546
2547/// ForgetSymbolicName - This looks up computed SCEV values for all
2548/// instructions that depend on the given instruction and removes them from
2549/// the Scalars map if they reference SymName. This is used during PHI
2550/// resolution.
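/// (For instance, once createNodeForPHI resolves a PHI's placeholder
/// SCEVUnknown into a real recurrence, any cached expression still built on
/// the placeholder must be discarded so it is recomputed on demand.)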
2551void
2552ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
2553  SmallVector<Instruction *, 16> Worklist;
2554  PushDefUseChildren(PN, Worklist);
2555
2556  SmallPtrSet<Instruction *, 8> Visited;
2557  Visited.insert(PN);
2558  while (!Worklist.empty()) {
2559    Instruction *I = Worklist.pop_back_val();
2560    if (!Visited.insert(I)) continue;
2561
2562    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
2563      Scalars.find(static_cast<Value *>(I));
2564    if (It != Scalars.end()) {
2565      // Short-circuit the def-use traversal if the symbolic name
2566      // ceases to appear in expressions.
2567      if (It->second != SymName && !It->second->hasOperand(SymName))
2568        continue;
2569
2570      // SCEVUnknown for a PHI either means that it has an unrecognized
2571      // structure, it's a PHI that's in the process of being computed
2572      // by createNodeForPHI, or it's a single-value PHI. In the first case,
2573      // additional loop trip count information isn't going to change anything.
2574      // In the second case, createNodeForPHI will perform the necessary
2575      // updates on its own when it gets to that point. In the third, we do
2576      // want to forget the SCEVUnknown.
2577      if (!isa<PHINode>(I) ||
2578          !isa<SCEVUnknown>(It->second) ||
2579          (I != PN && It->second == SymName)) {
2580        ValuesAtScopes.erase(It->second);
2581        Scalars.erase(It);
2582      }
2583    }
2584
2585    PushDefUseChildren(I, Worklist);
2586  }
2587}
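// An illustrative scenario (not from the original source): while
// createNodeForPHI is part-way through analyzing %i, a user such as
//   %i.next = add i32 %i, 1
// may have had its SCEV computed in terms of the temporary SCEVUnknown
// standing in for %i; once the real recurrence for %i is known, those
// interim entries are stale, and this routine is what drops them.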
2588
2589/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
2590/// a loop header, making it a potential recurrence, or it doesn't.
2591///
2592const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
2593  if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
2594    if (const Loop *L = LI->getLoopFor(PN->getParent()))
2595      if (L->getHeader() == PN->getParent()) {
2596        // If it lives in the loop header, it has two incoming values, one
2597        // from outside the loop, and one from inside.
2598        unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
2599        unsigned BackEdge     = IncomingEdge^1;
2600
2601        // While we are analyzing this PHI node, handle its value symbolically.
2602        const SCEV *SymbolicName = getUnknown(PN);
2603        assert(Scalars.find(PN) == Scalars.end() &&
2604               "PHI node already processed?");
2605        Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
2606
2607        // Using this symbolic name for the PHI, analyze the value coming around
2608        // the back-edge.
2609        Value *BEValueV = PN->getIncomingValue(BackEdge);
2610        const SCEV *BEValue = getSCEV(BEValueV);
2611
2612        // NOTE: If BEValue is loop invariant, we know that the PHI node just
2613        // has a special value for the first iteration of the loop.
2614
2615        // If the value coming around the backedge is an add with the symbolic
2616        // value we just inserted, then we found a simple induction variable!
2617        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
2618          // If there is a single occurrence of the symbolic value, replace it
2619          // with a recurrence.
2620          unsigned FoundIndex = Add->getNumOperands();
2621          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2622            if (Add->getOperand(i) == SymbolicName)
2623              if (FoundIndex == e) {
2624                FoundIndex = i;
2625                break;
2626              }
2627
2628          if (FoundIndex != Add->getNumOperands()) {
2629            // Create an add with everything but the specified operand.
2630            SmallVector<const SCEV *, 8> Ops;
2631            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2632              if (i != FoundIndex)
2633                Ops.push_back(Add->getOperand(i));
2634            const SCEV *Accum = getAddExpr(Ops);
2635
2636            // This is not a valid addrec if the step amount is varying each
2637            // loop iteration, but is not itself an addrec in this loop.
2638            if (Accum->isLoopInvariant(L) ||
2639                (isa<SCEVAddRecExpr>(Accum) &&
2640                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
2641              bool HasNUW = false;
2642              bool HasNSW = false;
2643
2644              // If the increment doesn't overflow, then neither the addrec nor
2645              // the post-increment will overflow.
2646              if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
2647                if (OBO->hasNoUnsignedWrap())
2648                  HasNUW = true;
2649                if (OBO->hasNoSignedWrap())
2650                  HasNSW = true;
2651              }
2652
2653              const SCEV *StartVal =
2654                getSCEV(PN->getIncomingValue(IncomingEdge));
2655              const SCEV *PHISCEV =
2656                getAddRecExpr(StartVal, Accum, L, HasNUW, HasNSW);
2657
2658              // Since the no-wrap flags are on the increment, they apply to the
2659              // post-incremented value as well.
2660              if (Accum->isLoopInvariant(L))
2661                (void)getAddRecExpr(getAddExpr(StartVal, Accum),
2662                                    Accum, L, HasNUW, HasNSW);
2663
2664              // Okay, for the entire analysis of this edge we assumed the PHI
2665              // to be symbolic.  We now need to go back and purge all of the
2666              // entries for the scalars that use the symbolic expression.
2667              ForgetSymbolicName(PN, SymbolicName);
2668              Scalars[SCEVCallbackVH(PN, this)] = PHISCEV;
2669              return PHISCEV;
2670            }
2671          }
2672        } else if (const SCEVAddRecExpr *AddRec =
2673                     dyn_cast<SCEVAddRecExpr>(BEValue)) {
2674          // Otherwise, this could be a loop like this:
2675          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
2676          // In this case, j = {1,+,1}  and BEValue is j.
2677          // Because the other in-value of i (0) fits the evolution of BEValue,
2678          // i really is an addrec evolution.
2679          if (AddRec->getLoop() == L && AddRec->isAffine()) {
2680            const SCEV *StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
2681
2682            // If StartVal = j.start - j.stride, we can use StartVal as the
2683            // initial value of the addrec evolution.
2684            if (StartVal == getMinusSCEV(AddRec->getOperand(0),
2685                                            AddRec->getOperand(1))) {
2686              const SCEV *PHISCEV =
2687                 getAddRecExpr(StartVal, AddRec->getOperand(1), L);
2688
2689              // Okay, for the entire analysis of this edge we assumed the PHI
2690              // to be symbolic.  We now need to go back and purge all of the
2691              // entries for the scalars that use the symbolic expression.
2692              ForgetSymbolicName(PN, SymbolicName);
2693              Scalars[SCEVCallbackVH(PN, this)] = PHISCEV;
2694              return PHISCEV;
2695            }
2696          }
2697        }
2698
2699        return SymbolicName;
2700      }
2701
2702  // If the PHI has a single incoming value, follow that value, unless the
2703  // PHI's incoming blocks are in a different loop, in which case doing so
2704  // risks breaking LCSSA form. Instcombine would normally zap these, but
2705  // it doesn't have DominatorTree information, so it may miss cases.
2706  if (Value *V = PN->hasConstantValue(DT)) {
2707    bool AllSameLoop = true;
2708    Loop *PNLoop = LI->getLoopFor(PN->getParent());
2709    for (size_t i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
2710      if (LI->getLoopFor(PN->getIncomingBlock(i)) != PNLoop) {
2711        AllSameLoop = false;
2712        break;
2713      }
2714    if (AllSameLoop)
2715      return getSCEV(V);
2716  }
2717
2718  // If it's not a loop phi, we can't handle it yet.
2719  return getUnknown(PN);
2720}
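// A worked example (illustrative only): for the canonical loop
//   header:  %i = phi i32 [ 0, %preheader ], [ %i.next, %latch ]
//   latch:   %i.next = add i32 %i, 1
// the value coming around the backedge is SymbolicName + 1, so the code
// above strips the symbolic operand, finds the loop-invariant step 1, and
// produces the affine recurrence {0,+,1} for this loop.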
2721
2722/// createNodeForGEP - Expand GEP instructions into add and multiply
2723/// operations. This allows them to be analyzed by regular SCEV code.
2724///
2725const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
2726
2727  bool InBounds = GEP->isInBounds();
2728  const Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
2729  Value *Base = GEP->getOperand(0);
2730  // Don't attempt to analyze GEPs over unsized objects.
2731  if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
2732    return getUnknown(GEP);
2733  const SCEV *TotalOffset = getIntegerSCEV(0, IntPtrTy);
2734  gep_type_iterator GTI = gep_type_begin(GEP);
2735  for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()),
2736                                      E = GEP->op_end();
2737       I != E; ++I) {
2738    Value *Index = *I;
2739    // Compute the (potentially symbolic) offset in bytes for this index.
2740    if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
2741      // For a struct, add the member offset.
2742      unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
2743      TotalOffset = getAddExpr(TotalOffset,
2744                               getOffsetOfExpr(STy, FieldNo),
2745                               /*HasNUW=*/false, /*HasNSW=*/InBounds);
2746    } else {
2747      // For an array, add the element offset, explicitly scaled.
2748      const SCEV *LocalOffset = getSCEV(Index);
2749      // Getelementptr indices are signed.
2750      LocalOffset = getTruncateOrSignExtend(LocalOffset, IntPtrTy);
2751      // Lower "inbounds" GEPs to NSW arithmetic.
2752      LocalOffset = getMulExpr(LocalOffset, getSizeOfExpr(*GTI),
2753                               /*HasNUW=*/false, /*HasNSW=*/InBounds);
2754      TotalOffset = getAddExpr(TotalOffset, LocalOffset,
2755                               /*HasNUW=*/false, /*HasNSW=*/InBounds);
2756    }
2757  }
2758  return getAddExpr(getSCEV(Base), TotalOffset,
2759                    /*HasNUW=*/false, /*HasNSW=*/InBounds);
2760}
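// A sketch of the lowering (illustrative; the offsets assume i32 elements,
// which have an allocation size of 4 bytes):
//   %p = getelementptr inbounds [10 x i32]* %a, i64 0, i64 %i
// becomes the SCEV (%a + (4 * %i)) over the effective pointer type, with
// the inbounds flag surfacing as the NSW bits passed above.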
2761
2762/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
2763/// guaranteed to end in (at every loop iteration).  It is, at the same time,
2764/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
2765/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
2766uint32_t
2767ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
2768  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2769    return C->getValue()->getValue().countTrailingZeros();
2770
2771  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
2772    return std::min(GetMinTrailingZeros(T->getOperand()),
2773                    (uint32_t)getTypeSizeInBits(T->getType()));
2774
2775  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
2776    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
2777    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
2778             getTypeSizeInBits(E->getType()) : OpRes;
2779  }
2780
2781  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
2782    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
2783    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
2784             getTypeSizeInBits(E->getType()) : OpRes;
2785  }
2786
2787  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
2788    // The result is the min of all operands' results.
2789    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
2790    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2791      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
2792    return MinOpRes;
2793  }
2794
2795  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
2796    // The result is the sum of all operands' results.
2797    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
2798    uint32_t BitWidth = getTypeSizeInBits(M->getType());
2799    for (unsigned i = 1, e = M->getNumOperands();
2800         SumOpRes != BitWidth && i != e; ++i)
2801      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
2802                          BitWidth);
2803    return SumOpRes;
2804  }
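  // For example (illustrative): if one factor ends in at least 2 zero bits
  // and another in at least 3, their product ends in at least 5, with the
  // sum capped at the expression's bit width.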
2805
2806  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
2807    // The result is the min of all operands' results.
2808    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
2809    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2810      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
2811    return MinOpRes;
2812  }
2813
2814  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
2815    // The result is the min of all operands' results.
2816    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
2817    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2818      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
2819    return MinOpRes;
2820  }
2821
2822  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
2823    // The result is the min of all operands' results.
2824    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
2825    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2826      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
2827    return MinOpRes;
2828  }
2829
2830  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2831    // For a SCEVUnknown, ask ValueTracking.
2832    unsigned BitWidth = getTypeSizeInBits(U->getType());
2833    APInt Mask = APInt::getAllOnesValue(BitWidth);
2834    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
2835    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones);
2836    return Zeros.countTrailingOnes();
2837  }
2838
2839  // SCEVUDivExpr
2840  return 0;
2841}
2842
2843/// getUnsignedRange - Determine the unsigned range for a particular SCEV.
2844///
2845ConstantRange
2846ScalarEvolution::getUnsignedRange(const SCEV *S) {
2847
2848  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2849    return ConstantRange(C->getValue()->getValue());
2850
2851  unsigned BitWidth = getTypeSizeInBits(S->getType());
2852  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
2853
2854  // If the value has known zeros, the maximum unsigned value will have those
2855  // known zeros as well.
2856  uint32_t TZ = GetMinTrailingZeros(S);
2857  if (TZ != 0)
2858    ConservativeResult =
2859      ConstantRange(APInt::getMinValue(BitWidth),
2860                    APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
2861
2862  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
2863    ConstantRange X = getUnsignedRange(Add->getOperand(0));
2864    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
2865      X = X.add(getUnsignedRange(Add->getOperand(i)));
2866    return ConservativeResult.intersectWith(X);
2867  }
2868
2869  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
2870    ConstantRange X = getUnsignedRange(Mul->getOperand(0));
2871    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
2872      X = X.multiply(getUnsignedRange(Mul->getOperand(i)));
2873    return ConservativeResult.intersectWith(X);
2874  }
2875
2876  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
2877    ConstantRange X = getUnsignedRange(SMax->getOperand(0));
2878    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
2879      X = X.smax(getUnsignedRange(SMax->getOperand(i)));
2880    return ConservativeResult.intersectWith(X);
2881  }
2882
2883  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
2884    ConstantRange X = getUnsignedRange(UMax->getOperand(0));
2885    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
2886      X = X.umax(getUnsignedRange(UMax->getOperand(i)));
2887    return ConservativeResult.intersectWith(X);
2888  }
2889
2890  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
2891    ConstantRange X = getUnsignedRange(UDiv->getLHS());
2892    ConstantRange Y = getUnsignedRange(UDiv->getRHS());
2893    return ConservativeResult.intersectWith(X.udiv(Y));
2894  }
2895
2896  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
2897    ConstantRange X = getUnsignedRange(ZExt->getOperand());
2898    return ConservativeResult.intersectWith(X.zeroExtend(BitWidth));
2899  }
2900
2901  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
2902    ConstantRange X = getUnsignedRange(SExt->getOperand());
2903    return ConservativeResult.intersectWith(X.signExtend(BitWidth));
2904  }
2905
2906  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
2907    ConstantRange X = getUnsignedRange(Trunc->getOperand());
2908    return ConservativeResult.intersectWith(X.truncate(BitWidth));
2909  }
2910
2911  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
2912    // If there's no unsigned wrap, the value will never be less than its
2913    // initial value.
2914    if (AddRec->hasNoUnsignedWrap())
2915      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
2916        ConservativeResult =
2917          ConstantRange(C->getValue()->getValue(),
2918                        APInt(getTypeSizeInBits(C->getType()), 0));
2919
2920    // TODO: non-affine addrec
2921    if (AddRec->isAffine()) {
2922      const Type *Ty = AddRec->getType();
2923      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
2924      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
2925          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
2926        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
2927
2928        const SCEV *Start = AddRec->getStart();
2929        const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
2930
2931        // Check for overflow.
2932        if (!AddRec->hasNoUnsignedWrap())
2933          return ConservativeResult;
2934
2935        ConstantRange StartRange = getUnsignedRange(Start);
2936        ConstantRange EndRange = getUnsignedRange(End);
2937        APInt Min = APIntOps::umin(StartRange.getUnsignedMin(),
2938                                   EndRange.getUnsignedMin());
2939        APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
2940                                   EndRange.getUnsignedMax());
2941        if (Min.isMinValue() && Max.isMaxValue())
2942          return ConservativeResult;
2943        return ConservativeResult.intersectWith(ConstantRange(Min, Max+1));
2944      }
2945    }
2946
2947    return ConservativeResult;
2948  }
2949
2950  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2951    // For a SCEVUnknown, ask ValueTracking.
2952    APInt Mask = APInt::getAllOnesValue(BitWidth);
2953    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
2954    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
2955    if (Ones == ~Zeros + 1)
2956      return ConservativeResult;
2957    return ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1));
2958  }
2959
2960  return ConservativeResult;
2961}
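// A small worked case (illustrative only): for {0,+,1} in a loop whose max
// backedge-taken count is 9, with the no-unsigned-wrap flag set, Start is 0
// and End is 9, so the code above intersects the conservative result with
// the range [0,10).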
2962
2963/// getSignedRange - Determine the signed range for a particular SCEV.
2964///
2965ConstantRange
2966ScalarEvolution::getSignedRange(const SCEV *S) {
2967
2968  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2969    return ConstantRange(C->getValue()->getValue());
2970
2971  unsigned BitWidth = getTypeSizeInBits(S->getType());
2972  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
2973
2974  // If the value has known zeros, the maximum signed value will have those
2975  // known zeros as well.
2976  uint32_t TZ = GetMinTrailingZeros(S);
2977  if (TZ != 0)
2978    ConservativeResult =
2979      ConstantRange(APInt::getSignedMinValue(BitWidth),
2980                    APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
2981
2982  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
2983    ConstantRange X = getSignedRange(Add->getOperand(0));
2984    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
2985      X = X.add(getSignedRange(Add->getOperand(i)));
2986    return ConservativeResult.intersectWith(X);
2987  }
2988
2989  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
2990    ConstantRange X = getSignedRange(Mul->getOperand(0));
2991    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
2992      X = X.multiply(getSignedRange(Mul->getOperand(i)));
2993    return ConservativeResult.intersectWith(X);
2994  }
2995
2996  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
2997    ConstantRange X = getSignedRange(SMax->getOperand(0));
2998    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
2999      X = X.smax(getSignedRange(SMax->getOperand(i)));
3000    return ConservativeResult.intersectWith(X);
3001  }
3002
3003  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3004    ConstantRange X = getSignedRange(UMax->getOperand(0));
3005    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3006      X = X.umax(getSignedRange(UMax->getOperand(i)));
3007    return ConservativeResult.intersectWith(X);
3008  }
3009
3010  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3011    ConstantRange X = getSignedRange(UDiv->getLHS());
3012    ConstantRange Y = getSignedRange(UDiv->getRHS());
3013    return ConservativeResult.intersectWith(X.udiv(Y));
3014  }
3015
3016  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3017    ConstantRange X = getSignedRange(ZExt->getOperand());
3018    return ConservativeResult.intersectWith(X.zeroExtend(BitWidth));
3019  }
3020
3021  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3022    ConstantRange X = getSignedRange(SExt->getOperand());
3023    return ConservativeResult.intersectWith(X.signExtend(BitWidth));
3024  }
3025
3026  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3027    ConstantRange X = getSignedRange(Trunc->getOperand());
3028    return ConservativeResult.intersectWith(X.truncate(BitWidth));
3029  }
3030
3031  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3032    // If there's no signed wrap, and all the operands have the same sign or
3033    // zero, the value won't ever change sign.
3034    if (AddRec->hasNoSignedWrap()) {
3035      bool AllNonNeg = true;
3036      bool AllNonPos = true;
3037      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3038        if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
3039        if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
3040      }
3041      if (AllNonNeg)
3042        ConservativeResult = ConservativeResult.intersectWith(
3043          ConstantRange(APInt(BitWidth, 0),
3044                        APInt::getSignedMinValue(BitWidth)));
3045      else if (AllNonPos)
3046        ConservativeResult = ConservativeResult.intersectWith(
3047          ConstantRange(APInt::getSignedMinValue(BitWidth),
3048                        APInt(BitWidth, 1)));
3049    }
3050
3051    // TODO: non-affine addrec
3052    if (AddRec->isAffine()) {
3053      const Type *Ty = AddRec->getType();
3054      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3055      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3056          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3057        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3058
3059        const SCEV *Start = AddRec->getStart();
3060        const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
3061
3062        // Check for overflow.
3063        if (!AddRec->hasNoSignedWrap())
3064          return ConservativeResult;
3065
3066        ConstantRange StartRange = getSignedRange(Start);
3067        ConstantRange EndRange = getSignedRange(End);
3068        APInt Min = APIntOps::smin(StartRange.getSignedMin(),
3069                                   EndRange.getSignedMin());
3070        APInt Max = APIntOps::smax(StartRange.getSignedMax(),
3071                                   EndRange.getSignedMax());
3072        if (Min.isMinSignedValue() && Max.isMaxSignedValue())
3073          return ConservativeResult;
3074        return ConservativeResult.intersectWith(ConstantRange(Min, Max+1));
3075      }
3076    }
3077
3078    return ConservativeResult;
3079  }
3080
3081  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3082    // For a SCEVUnknown, ask ValueTracking.
3083    if (!U->getValue()->getType()->isIntegerTy() && !TD)
3084      return ConservativeResult;
3085    unsigned NS = ComputeNumSignBits(U->getValue(), TD);
3086    if (NS == 1)
3087      return ConservativeResult;
3088    return ConservativeResult.intersectWith(
3089      ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
3090                    APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1));
3091  }
3092
3093  return ConservativeResult;
3094}
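// As a worked instance of the SCEVUnknown case above (illustrative): with a
// bit width of 32 and NS == 25 known sign bits, the value fits in 8
// significant bits, and the range returned is [-128, 128).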
3095
3096/// createSCEV - We know that there is no SCEV for the specified value.
3097/// Analyze the expression.
3098///
3099const SCEV *ScalarEvolution::createSCEV(Value *V) {
3100  if (!isSCEVable(V->getType()))
3101    return getUnknown(V);
3102
3103  unsigned Opcode = Instruction::UserOp1;
3104  if (Instruction *I = dyn_cast<Instruction>(V)) {
3105    Opcode = I->getOpcode();
3106
3107    // Don't attempt to analyze instructions in blocks that aren't
3108    // reachable. Such instructions don't matter, and they aren't required
3109    // to obey basic rules for definitions dominating uses which this
3110    // analysis depends on.
3111    if (!DT->isReachableFromEntry(I->getParent()))
3112      return getUnknown(V);
3113  } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
3114    Opcode = CE->getOpcode();
3115  else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
3116    return getConstant(CI);
3117  else if (isa<ConstantPointerNull>(V))
3118    return getIntegerSCEV(0, V->getType());
3119  else if (isa<UndefValue>(V))
3120    return getIntegerSCEV(0, V->getType());
3121  else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
3122    return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
3123  else
3124    return getUnknown(V);
3125
3126  Operator *U = cast<Operator>(V);
3127  switch (Opcode) {
3128  case Instruction::Add:
3129    // Don't transfer the NSW and NUW bits from the Add instruction to the
3130    // Add expression, because the Instruction may be guarded by control
3131    // flow and the no-overflow bits may not be valid for the expression in
3132    // any context.
3133    return getAddExpr(getSCEV(U->getOperand(0)),
3134                      getSCEV(U->getOperand(1)));
3135  case Instruction::Mul:
3136    // Don't transfer the NSW and NUW bits from the Mul instruction to the
3137    // Mul expression, as with Add.
3138    return getMulExpr(getSCEV(U->getOperand(0)),
3139                      getSCEV(U->getOperand(1)));
3140  case Instruction::UDiv:
3141    return getUDivExpr(getSCEV(U->getOperand(0)),
3142                       getSCEV(U->getOperand(1)));
3143  case Instruction::Sub:
3144    return getMinusSCEV(getSCEV(U->getOperand(0)),
3145                        getSCEV(U->getOperand(1)));
3146  case Instruction::And:
3147    // For an expression like x&255 that merely masks off the high bits,
3148    // use zext(trunc(x)) as the SCEV expression.
3149    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3150      if (CI->isNullValue())
3151        return getSCEV(U->getOperand(1));
3152      if (CI->isAllOnesValue())
3153        return getSCEV(U->getOperand(0));
3154      const APInt &A = CI->getValue();
3155
3156      // Instcombine's ShrinkDemandedConstant may strip bits out of
3157      // constants, obscuring what would otherwise be a low-bits mask.
3158      // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
3159      // knew about to reconstruct a low-bits mask value.
3160      unsigned LZ = A.countLeadingZeros();
3161      unsigned BitWidth = A.getBitWidth();
3162      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
3163      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
3164      ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD);
3165
3166      APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
3167
3168      if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
3169        return
3170          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
3171                                IntegerType::get(getContext(), BitWidth - LZ)),
3172                            U->getType());
3173    }
3174    break;
3175
3176  case Instruction::Or:
3177    // If the RHS of the Or is a constant, we may have something like:
3178    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
3179    // optimizations will transparently handle this case.
3180    //
3181    // In order for this transformation to be safe, the LHS must be of the
3182    // form X*(2^n) and the Or constant must be less than 2^n.
3183    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3184      const SCEV *LHS = getSCEV(U->getOperand(0));
3185      const APInt &CIVal = CI->getValue();
3186      if (GetMinTrailingZeros(LHS) >=
3187          (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
3188        // Build a plain add SCEV.
3189        const SCEV *S = getAddExpr(LHS, getSCEV(CI));
3190        // If the LHS of the add was an addrec and it has no-wrap flags,
3191        // transfer the no-wrap flags, since an or won't introduce a wrap.
3192        if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
3193          const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
3194          if (OldAR->hasNoUnsignedWrap())
3195            const_cast<SCEVAddRecExpr *>(NewAR)->setHasNoUnsignedWrap(true);
3196          if (OldAR->hasNoSignedWrap())
3197            const_cast<SCEVAddRecExpr *>(NewAR)->setHasNoSignedWrap(true);
3198        }
3199        return S;
3200      }
3201    }
3202    break;
3203  case Instruction::Xor:
3204    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3205      // If the RHS of the xor is a signbit, then this is just an add.
3206      // Instcombine turns add of signbit into xor as a strength reduction step.
3207      if (CI->getValue().isSignBit())
3208        return getAddExpr(getSCEV(U->getOperand(0)),
3209                          getSCEV(U->getOperand(1)));
3210
3211      // If the RHS of xor is -1, then this is a not operation.
3212      if (CI->isAllOnesValue())
3213        return getNotSCEV(getSCEV(U->getOperand(0)));
3214
3215      // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
3216      // This is a variant of the check for xor with -1, and it handles
3217      // the case where instcombine has trimmed non-demanded bits out
3218      // of an xor with -1.
3219      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
3220        if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
3221          if (BO->getOpcode() == Instruction::And &&
3222              LCI->getValue() == CI->getValue())
3223            if (const SCEVZeroExtendExpr *Z =
3224                  dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
3225              const Type *UTy = U->getType();
3226              const SCEV *Z0 = Z->getOperand();
3227              const Type *Z0Ty = Z0->getType();
3228              unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
3229
3230              // If C is a low-bits mask, the zero extend is serving to
3231              // mask off the high bits. Complement the operand and
3232              // re-apply the zext.
3233              if (APIntOps::isMask(Z0TySize, CI->getValue()))
3234                return getZeroExtendExpr(getNotSCEV(Z0), UTy);
3235
3236              // If C is a single bit, it may be in the sign-bit position
3237              // before the zero-extend. In this case, represent the xor
3238              // using an add, which is equivalent, and re-apply the zext.
3239              APInt Trunc = APInt(CI->getValue()).trunc(Z0TySize);
3240              if (APInt(Trunc).zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
3241                  Trunc.isSignBit())
3242                return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
3243                                         UTy);
3244            }
3245    }
3246    break;
3247
3248  case Instruction::Shl:
3249    // Turn shift left of a constant amount into a multiply.
3250    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3251      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3252      Constant *X = ConstantInt::get(getContext(),
3253        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
3254      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3255    }
3256    break;
3257
3258  case Instruction::LShr:
3259    // Turn logical shift right of a constant into an unsigned divide.
3260    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3261      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3262      Constant *X = ConstantInt::get(getContext(),
3263        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
3264      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3265    }
3266    break;
3267
3268  case Instruction::AShr:
3269    // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
3270    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
3271      if (Instruction *L = dyn_cast<Instruction>(U->getOperand(0)))
3272        if (L->getOpcode() == Instruction::Shl &&
3273            L->getOperand(1) == U->getOperand(1)) {
3274          unsigned BitWidth = getTypeSizeInBits(U->getType());
3275          uint64_t Amt = BitWidth - CI->getZExtValue();
3276          if (Amt == BitWidth)
3277            return getSCEV(L->getOperand(0));       // shift by zero --> noop
3278          if (Amt > BitWidth)
3279            return getIntegerSCEV(0, U->getType()); // value is undefined
3280          return
3281            getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
3282                                           IntegerType::get(getContext(), Amt)),
3283                                 U->getType());
3284        }
3285    break;
3286
3287  case Instruction::Trunc:
3288    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
3289
3290  case Instruction::ZExt:
3291    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3292
3293  case Instruction::SExt:
3294    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3295
3296  case Instruction::BitCast:
3297    // BitCasts are no-op casts so we just eliminate the cast.
3298    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
3299      return getSCEV(U->getOperand(0));
3300    break;
3301
3302  // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can
3303  // lead to pointer expressions which cannot safely be expanded to GEPs,
3304  // because ScalarEvolution doesn't respect the GEP aliasing rules when
3305  // simplifying integer expressions.
3306
3307  case Instruction::GetElementPtr:
3308    return createNodeForGEP(cast<GEPOperator>(U));
3309
3310  case Instruction::PHI:
3311    return createNodeForPHI(cast<PHINode>(U));
3312
3313  case Instruction::Select:
3314    // This could be a smax or umax that was lowered earlier.
3315    // Try to recover it.
3316    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
3317      Value *LHS = ICI->getOperand(0);
3318      Value *RHS = ICI->getOperand(1);
3319      switch (ICI->getPredicate()) {
3320      case ICmpInst::ICMP_SLT:
3321      case ICmpInst::ICMP_SLE:
3322        std::swap(LHS, RHS);
3323        // fall through
3324      case ICmpInst::ICMP_SGT:
3325      case ICmpInst::ICMP_SGE:
3326        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
3327          return getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
3328        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
3329          return getSMinExpr(getSCEV(LHS), getSCEV(RHS));
3330        break;
3331      case ICmpInst::ICMP_ULT:
3332      case ICmpInst::ICMP_ULE:
3333        std::swap(LHS, RHS);
3334        // fall through
3335      case ICmpInst::ICMP_UGT:
3336      case ICmpInst::ICMP_UGE:
3337        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
3338          return getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
3339        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
3340          return getUMinExpr(getSCEV(LHS), getSCEV(RHS));
3341        break;
3342      case ICmpInst::ICMP_NE:
3343        // n != 0 ? n : 1  ->  umax(n, 1)
3344        if (LHS == U->getOperand(1) &&
3345            isa<ConstantInt>(U->getOperand(2)) &&
3346            cast<ConstantInt>(U->getOperand(2))->isOne() &&
3347            isa<ConstantInt>(RHS) &&
3348            cast<ConstantInt>(RHS)->isZero())
3349          return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(2)));
3350        break;
3351      case ICmpInst::ICMP_EQ:
3352        // n == 0 ? 1 : n  ->  umax(n, 1)
3353        if (LHS == U->getOperand(2) &&
3354            isa<ConstantInt>(U->getOperand(1)) &&
3355            cast<ConstantInt>(U->getOperand(1))->isOne() &&
3356            isa<ConstantInt>(RHS) &&
3357            cast<ConstantInt>(RHS)->isZero())
3358          return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(1)));
3359        break;
3360      default:
3361        break;
3362      }
3363    }
3364
3365  default: // We cannot analyze this expression.
3366    break;
3367  }
3368
3369  return getUnknown(V);
3370}
3371
3372
3373
3374//===----------------------------------------------------------------------===//
3375//                   Iteration Count Computation Code
3376//
3377
3378/// getBackedgeTakenCount - If the specified loop has a predictable
3379/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
3380/// object. The backedge-taken count is the number of times the loop header
3381/// will be branched to from within the loop. This is one less than the
3382/// trip count of the loop, since it doesn't count the first iteration,
3383/// when the header is branched to from outside the loop.
3384///
3385/// Note that it is not valid to call this method on a loop without a
3386/// loop-invariant backedge-taken count (see
3387/// hasLoopInvariantBackedgeTakenCount).
3388///
3389const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
3390  return getBackedgeTakenInfo(L).Exact;
3391}
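// For example (illustrative): a bottom-tested loop whose body executes 10
// times branches back to the header 9 times, so its backedge-taken count is
// 9 while its trip count is 10.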
3392
3393/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
3394/// return the least SCEV value that is known never to be less than the
3395/// actual backedge taken count.
3396const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
3397  return getBackedgeTakenInfo(L).Max;
3398}
3399
3400/// PushLoopPHIs - Push PHI nodes in the header of the given loop
3401/// onto the given Worklist.
3402static void
3403PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
3404  BasicBlock *Header = L->getHeader();
3405
3406  // Push all Loop-header PHIs onto the Worklist stack.
3407  for (BasicBlock::iterator I = Header->begin();
3408       PHINode *PN = dyn_cast<PHINode>(I); ++I)
3409    Worklist.push_back(PN);
3410}
3411
3412const ScalarEvolution::BackedgeTakenInfo &
3413ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
3414  // Initially insert a CouldNotCompute for this loop. If the insertion
3415  // succeeds, proceed to actually compute a backedge-taken count and
3416  // update the value. The temporary CouldNotCompute value tells SCEV
3417  // code elsewhere that it shouldn't attempt to request a new
3418  // backedge-taken count, which could result in infinite recursion.
3419  std::pair<std::map<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
3420    BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
3421  if (Pair.second) {
3422    BackedgeTakenInfo BECount = ComputeBackedgeTakenCount(L);
3423    if (BECount.Exact != getCouldNotCompute()) {
3424      assert(BECount.Exact->isLoopInvariant(L) &&
3425             BECount.Max->isLoopInvariant(L) &&
3426             "Computed backedge-taken count isn't loop invariant for loop!");
3427      ++NumTripCountsComputed;
3428
3429      // Update the value in the map.
3430      Pair.first->second = BECount;
3431    } else {
3432      if (BECount.Max != getCouldNotCompute())
3433        // Update the value in the map.
3434        Pair.first->second = BECount;
3435      if (isa<PHINode>(L->getHeader()->begin()))
3436        // Only count loops that have phi nodes as not being computable.
3437        ++NumTripCountsNotComputed;
3438    }
3439
3440    // Now that we know more about the trip count for this loop, forget any
3441    // existing SCEV values for PHI nodes in this loop since they are only
3442    // conservative estimates made without the benefit of trip count
3443    // information. This is similar to the code in forgetLoop, except that
3444    // it handles SCEVUnknown PHI nodes specially.
3445    if (BECount.hasAnyInfo()) {
3446      SmallVector<Instruction *, 16> Worklist;
3447      PushLoopPHIs(L, Worklist);
3448
3449      SmallPtrSet<Instruction *, 8> Visited;
3450      while (!Worklist.empty()) {
3451        Instruction *I = Worklist.pop_back_val();
3452        if (!Visited.insert(I)) continue;
3453
3454        std::map<SCEVCallbackVH, const SCEV *>::iterator It =
3455          Scalars.find(static_cast<Value *>(I));
3456        if (It != Scalars.end()) {
3457          // SCEVUnknown for a PHI either means that it has an unrecognized
3458          // structure, or it's a PHI that's in the process of being computed
3459          // by createNodeForPHI.  In the former case, additional loop trip
3460          // count information isn't going to change anything. In the latter
3461          // case, createNodeForPHI will perform the necessary updates on its
3462          // own when it gets to that point.
3463          if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second)) {
3464            ValuesAtScopes.erase(It->second);
3465            Scalars.erase(It);
3466          }
3467          if (PHINode *PN = dyn_cast<PHINode>(I))
3468            ConstantEvolutionLoopExitValue.erase(PN);
3469        }
3470
3471        PushDefUseChildren(I, Worklist);
3472      }
3473    }
3474  }
3475  return Pair.first->second;
3476}
3477
3478/// forgetLoop - This method should be called by the client when it has
3479/// changed a loop in a way that may affect ScalarEvolution's ability to
3480/// compute a trip count, or if the loop is deleted.
3481void ScalarEvolution::forgetLoop(const Loop *L) {
3482  // Drop any stored trip count value.
3483  BackedgeTakenCounts.erase(L);
3484
3485  // Drop information about expressions based on loop-header PHIs.
3486  SmallVector<Instruction *, 16> Worklist;
3487  PushLoopPHIs(L, Worklist);
3488
3489  SmallPtrSet<Instruction *, 8> Visited;
3490  while (!Worklist.empty()) {
3491    Instruction *I = Worklist.pop_back_val();
3492    if (!Visited.insert(I)) continue;
3493
3494    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
3495      Scalars.find(static_cast<Value *>(I));
3496    if (It != Scalars.end()) {
3497      ValuesAtScopes.erase(It->second);
3498      Scalars.erase(It);
3499      if (PHINode *PN = dyn_cast<PHINode>(I))
3500        ConstantEvolutionLoopExitValue.erase(PN);
3501    }
3502
3503    PushDefUseChildren(I, Worklist);
3504  }
3505}
3506
3507/// forgetValue - This method should be called by the client when it has
3508/// changed a value in a way that may affect the value it computes, or which may
3509/// disconnect it from a def-use chain linking it to a loop.
3510void ScalarEvolution::forgetValue(Value *V) {
3511  Instruction *I = dyn_cast<Instruction>(V);
3512  if (!I) return;
3513
3514  // Drop information about expressions based on loop-header PHIs.
3515  SmallVector<Instruction *, 16> Worklist;
3516  Worklist.push_back(I);
3517
3518  SmallPtrSet<Instruction *, 8> Visited;
3519  while (!Worklist.empty()) {
3520    I = Worklist.pop_back_val();
3521    if (!Visited.insert(I)) continue;
3522
3523    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
3524      Scalars.find(static_cast<Value *>(I));
3525    if (It != Scalars.end()) {
3526      ValuesAtScopes.erase(It->second);
3527      Scalars.erase(It);
3528      if (PHINode *PN = dyn_cast<PHINode>(I))
3529        ConstantEvolutionLoopExitValue.erase(PN);
3530    }
3531
3532    PushDefUseChildren(I, Worklist);
3533  }
3534}
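// Illustrative client usage (a sketch; 'SE', 'Inst', and 'NewOp' are
// assumed names, not part of this file):
//   Inst->setOperand(0, NewOp); // the client rewrites the instruction...
//   SE.forgetValue(Inst);       // ...then drops the now-stale SCEVs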
3535
3536/// ComputeBackedgeTakenCount - Compute the number of times the backedge
3537/// of the specified loop will execute.
3538ScalarEvolution::BackedgeTakenInfo
3539ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
3540  SmallVector<BasicBlock *, 8> ExitingBlocks;
3541  L->getExitingBlocks(ExitingBlocks);
3542
3543  // Examine all exits and pick the most conservative values.
3544  const SCEV *BECount = getCouldNotCompute();
3545  const SCEV *MaxBECount = getCouldNotCompute();
3546  bool CouldNotComputeBECount = false;
3547  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
3548    BackedgeTakenInfo NewBTI =
3549      ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]);
3550
3551    if (NewBTI.Exact == getCouldNotCompute()) {
3552      // We couldn't compute an exact value for this exit, so
3553      // we won't be able to compute an exact value for the loop.
3554      CouldNotComputeBECount = true;
3555      BECount = getCouldNotCompute();
3556    } else if (!CouldNotComputeBECount) {
3557      if (BECount == getCouldNotCompute())
3558        BECount = NewBTI.Exact;
3559      else
3560        BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact);
3561    }
3562    if (MaxBECount == getCouldNotCompute())
3563      MaxBECount = NewBTI.Max;
3564    else if (NewBTI.Max != getCouldNotCompute())
3565      MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max);
3566  }
3567
3568  return BackedgeTakenInfo(BECount, MaxBECount);
3569}
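// For example (illustrative): in a loop with two exiting blocks whose exact
// counts are 10 and 20, the loop must leave through the earlier exit, so the
// code above reports umin(10, 20) == 10; if any single exit count is not
// computable, only the Max field can retain information.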
3570
3571/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge
3572/// of the specified loop will execute if it exits via the specified block.
3573ScalarEvolution::BackedgeTakenInfo
3574ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
3575                                                   BasicBlock *ExitingBlock) {
3576
3577  // Okay, we've chosen an exiting block.  See what condition causes us to
3578  // exit at this block.
3579  //
3580  // FIXME: we should be able to handle switch instructions (with a single exit)
3581  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
3582  if (ExitBr == 0) return getCouldNotCompute();
3583  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
3584
3585  // At this point, we know we have a conditional branch that determines whether
3586  // the loop is exited.  However, we don't know if the branch is executed each
3587  // time through the loop.  If not, then the execution count of the branch will
3588  // not be equal to the trip count of the loop.
3589  //
3590  // Currently we check for this by checking to see if the Exit branch goes to
3591  // the loop header.  If so, we know it will always execute the same number of
3592  // times as the loop.  We also handle the case where the exit block *is* the
3593  // loop header.  This is common for un-rotated loops.
3594  //
3595  // If both of those tests fail, walk up the unique predecessor chain to the
3596  // header, stopping if there is an edge that doesn't exit the loop. If the
3597  // header is reached, the execution count of the branch will be equal to the
3598  // trip count of the loop.
3599  //
3600  //  More extensive analysis could be done to handle more cases here.
3601  //
3602  if (ExitBr->getSuccessor(0) != L->getHeader() &&
3603      ExitBr->getSuccessor(1) != L->getHeader() &&
3604      ExitBr->getParent() != L->getHeader()) {
3605    // The simple checks failed, try climbing the unique predecessor chain
3606    // up to the header.
3607    bool Ok = false;
3608    for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
3609      BasicBlock *Pred = BB->getUniquePredecessor();
3610      if (!Pred)
3611        return getCouldNotCompute();
3612      TerminatorInst *PredTerm = Pred->getTerminator();
3613      for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
3614        BasicBlock *PredSucc = PredTerm->getSuccessor(i);
3615        if (PredSucc == BB)
3616          continue;
3617        // If the predecessor has a successor that isn't BB and isn't
3618        // outside the loop, assume the worst.
3619        if (L->contains(PredSucc))
3620          return getCouldNotCompute();
3621      }
3622      if (Pred == L->getHeader()) {
3623        Ok = true;
3624        break;
3625      }
3626      BB = Pred;
3627    }
3628    if (!Ok)
3629      return getCouldNotCompute();
3630  }
3631
3632  // Proceed to the next level to examine the exit condition expression.
3633  return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(),
3634                                               ExitBr->getSuccessor(0),
3635                                               ExitBr->getSuccessor(1));
3636}
3637
3638/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the
3639/// backedge of the specified loop will execute if its exit condition
3640/// were a conditional branch of ExitCond, TBB, and FBB.
3641ScalarEvolution::BackedgeTakenInfo
3642ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
3643                                                       Value *ExitCond,
3644                                                       BasicBlock *TBB,
3645                                                       BasicBlock *FBB) {
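  // For example (illustrative): for an exit condition 'i != n && c' whose
  // true edge stays in the loop, the loop runs only while both operands are
  // true, so the exact count is the umin of the two operand counts when both
  // are computable; the And case below implements exactly this.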
3646  // Check if the controlling expression for this loop is an And or Or.
3647  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
3648    if (BO->getOpcode() == Instruction::And) {
3649      // Recurse on the operands of the and.
3650      BackedgeTakenInfo BTI0 =
3651        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
3652      BackedgeTakenInfo BTI1 =
3653        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
3654      const SCEV *BECount = getCouldNotCompute();
3655      const SCEV *MaxBECount = getCouldNotCompute();
3656      if (L->contains(TBB)) {
3657        // Both conditions must be true for the loop to continue executing.
3658        // Choose the less conservative count.
3659        if (BTI0.Exact == getCouldNotCompute() ||
3660            BTI1.Exact == getCouldNotCompute())
3661          BECount = getCouldNotCompute();
3662        else
3663          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3664        if (BTI0.Max == getCouldNotCompute())
3665          MaxBECount = BTI1.Max;
3666        else if (BTI1.Max == getCouldNotCompute())
3667          MaxBECount = BTI0.Max;
3668        else
3669          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
3670      } else {
3671        // Both conditions must be true for the loop to exit.
3672        assert(L->contains(FBB) && "Loop block has no successor in loop!");
3673        if (BTI0.Exact != getCouldNotCompute() &&
3674            BTI1.Exact != getCouldNotCompute())
3675          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3676        if (BTI0.Max != getCouldNotCompute() &&
3677            BTI1.Max != getCouldNotCompute())
3678          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
3679      }
3680
3681      return BackedgeTakenInfo(BECount, MaxBECount);
3682    }
3683    if (BO->getOpcode() == Instruction::Or) {
3684      // Recurse on the operands of the or.
3685      BackedgeTakenInfo BTI0 =
3686        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
3687      BackedgeTakenInfo BTI1 =
3688        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
3689      const SCEV *BECount = getCouldNotCompute();
3690      const SCEV *MaxBECount = getCouldNotCompute();
3691      if (L->contains(FBB)) {
3692        // Both conditions must be false for the loop to continue executing.
3693        // Choose the less conservative count.
3694        if (BTI0.Exact == getCouldNotCompute() ||
3695            BTI1.Exact == getCouldNotCompute())
3696          BECount = getCouldNotCompute();
3697        else
3698          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3699        if (BTI0.Max == getCouldNotCompute())
3700          MaxBECount = BTI1.Max;
3701        else if (BTI1.Max == getCouldNotCompute())
3702          MaxBECount = BTI0.Max;
3703        else
3704          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
3705      } else {
3706        // Both conditions must be false for the loop to exit.
3707        assert(L->contains(TBB) && "Loop block has no successor in loop!");
3708        if (BTI0.Exact != getCouldNotCompute() &&
3709            BTI1.Exact != getCouldNotCompute())
3710          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3711        if (BTI0.Max != getCouldNotCompute() &&
3712            BTI1.Max != getCouldNotCompute())
3713          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
3714      }
3715
3716      return BackedgeTakenInfo(BECount, MaxBECount);
3717    }
3718  }
3719
3720  // With an icmp, it may be feasible to compute an exact backedge-taken count.
3721  // Proceed to the next level to examine the icmp.
3722  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
3723    return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB);
3724
3725  // Check for a constant condition. These are normally stripped out by
3726  // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
3727  // preserve the CFG and is temporarily leaving constant conditions
3728  // in place.
3729  if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
3730    if (L->contains(FBB) == !CI->getZExtValue())
3731      // The backedge is always taken.
3732      return getCouldNotCompute();
3733    else
3734      // The backedge is never taken.
3735      return getIntegerSCEV(0, CI->getType());
3736  }
3737
3738  // If it's not an integer or pointer comparison then compute it the hard way.
3739  return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
3740}
3741
3742/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the
3743/// backedge of the specified loop will execute if its exit condition
3744/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
3745ScalarEvolution::BackedgeTakenInfo
3746ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
3747                                                           ICmpInst *ExitCond,
3748                                                           BasicBlock *TBB,
3749                                                           BasicBlock *FBB) {
3750
3751  // If the condition was exit on true, convert the condition to exit on false
3752  ICmpInst::Predicate Cond;
3753  if (!L->contains(FBB))
3754    Cond = ExitCond->getPredicate();
3755  else
3756    Cond = ExitCond->getInversePredicate();
3757
3758  // Handle common loops like: for (X = "string"; *X; ++X)
3759  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
3760    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
3761      BackedgeTakenInfo ItCnt =
3762        ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
3763      if (ItCnt.hasAnyInfo())
3764        return ItCnt;
3765    }
3766
3767  const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
3768  const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
3769
3770  // Try to evaluate any dependencies out of the loop.
3771  LHS = getSCEVAtScope(LHS, L);
3772  RHS = getSCEVAtScope(RHS, L);
3773
3774  // At this point, we would like to compute for how many iterations of
3775  // the loop the predicate will return true for these inputs.
3776  if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
3777    // If there is a loop-invariant, force it into the RHS.
3778    std::swap(LHS, RHS);
3779    Cond = ICmpInst::getSwappedPredicate(Cond);
3780  }
3781
3782  // If we have a comparison of a chrec against a constant, try to use value
3783  // ranges to answer this query.
3784  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
3785    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
3786      if (AddRec->getLoop() == L) {
3787        // Form the constant range.
3788        ConstantRange CompRange(
3789            ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
3790
3791        const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
3792        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
3793      }
3794
3795  switch (Cond) {
3796  case ICmpInst::ICMP_NE: {                     // while (X != Y)
3797    // Convert to: while (X-Y != 0)
3798    BackedgeTakenInfo BTI = HowFarToZero(getMinusSCEV(LHS, RHS), L);
3799    if (BTI.hasAnyInfo()) return BTI;
3800    break;
3801  }
3802  case ICmpInst::ICMP_EQ: {                     // while (X == Y)
3803    // Convert to: while (X-Y == 0)
3804    BackedgeTakenInfo BTI = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
3805    if (BTI.hasAnyInfo()) return BTI;
3806    break;
3807  }
3808  case ICmpInst::ICMP_SLT: {
3809    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true);
3810    if (BTI.hasAnyInfo()) return BTI;
3811    break;
3812  }
3813  case ICmpInst::ICMP_SGT: {
3814    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
3815                                             getNotSCEV(RHS), L, true);
3816    if (BTI.hasAnyInfo()) return BTI;
3817    break;
3818  }
3819  case ICmpInst::ICMP_ULT: {
3820    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false);
3821    if (BTI.hasAnyInfo()) return BTI;
3822    break;
3823  }
3824  case ICmpInst::ICMP_UGT: {
3825    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
3826                                             getNotSCEV(RHS), L, false);
3827    if (BTI.hasAnyInfo()) return BTI;
3828    break;
3829  }
3830  default:
3831#if 0
3832    dbgs() << "ComputeBackedgeTakenCount ";
3833    if (ExitCond->getOperand(0)->getType()->isUnsigned())
3834      dbgs() << "[unsigned] ";
3835    dbgs() << *LHS << "   "
3836         << Instruction::getOpcodeName(Instruction::ICmp)
3837         << "   " << *RHS << "\n";
3838#endif
3839    break;
3840  }
3841  return
3842    ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
3843}
3844
3845static ConstantInt *
3846EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
3847                                ScalarEvolution &SE) {
3848  const SCEV *InVal = SE.getConstant(C);
3849  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
3850  assert(isa<SCEVConstant>(Val) &&
3851         "Evaluation of SCEV at constant didn't fold correctly?");
3852  return cast<SCEVConstant>(Val)->getValue();
3853}
3854
3855/// GetAddressedElementFromGlobal - Given a global variable with an initializer
3856/// and a GEP expression (missing the pointer index) indexing into it, return
3857/// the addressed element of the initializer or null if the index expression is
3858/// invalid.
3859static Constant *
3860GetAddressedElementFromGlobal(GlobalVariable *GV,
3861                              const std::vector<ConstantInt*> &Indices) {
3862  Constant *Init = GV->getInitializer();
3863  for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
3864    uint64_t Idx = Indices[i]->getZExtValue();
3865    if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
3866      assert(Idx < CS->getNumOperands() && "Bad struct index!");
3867      Init = cast<Constant>(CS->getOperand(Idx));
3868    } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
3869      if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
3870      Init = cast<Constant>(CA->getOperand(Idx));
    } else if (isa<ConstantAggregateZero>(Init)) {
      // Each element of a zero aggregate is the null value of its element
      // type, so we can keep indexing into the initializer.
      if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
        assert(Idx < STy->getNumElements() && "Bad struct index!");
        Init = Constant::getNullValue(STy->getElementType(Idx));
      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
        if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
        Init = Constant::getNullValue(ATy->getElementType());
      } else {
        llvm_unreachable("Unknown constant aggregate type!");
      }
3882    } else {
3883      return 0; // Unknown initializer type
3884    }
3885  }
3886  return Init;
3887}
3888
3889/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of
3890/// 'icmp op load X, cst', try to see if we can compute the backedge
3891/// execution count.
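///
/// For example, for the loop "for (X = "abc"; *X; ++X)", the backedge is
/// taken while *X != 0; iterations 0-2 load 'a', 'b', and 'c', and
/// iteration 3 loads the terminating NUL, so the backedge-taken count is 3.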
3892ScalarEvolution::BackedgeTakenInfo
3893ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount(
3894                                                LoadInst *LI,
3895                                                Constant *RHS,
3896                                                const Loop *L,
3897                                                ICmpInst::Predicate predicate) {
3898  if (LI->isVolatile()) return getCouldNotCompute();
3899
3900  // Check to see if the loaded pointer is a getelementptr of a global.
3901  // TODO: Use SCEV instead of manually grubbing with GEPs.
3902  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
3903  if (!GEP) return getCouldNotCompute();
3904
3905  // Make sure that it is really a constant global we are gepping, with an
3906  // initializer, and make sure the first IDX is really 0.
3907  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
3908  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
3909      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
3910      !cast<Constant>(GEP->getOperand(1))->isNullValue())
3911    return getCouldNotCompute();
3912
3913  // Okay, we allow one non-constant index into the GEP instruction.
3914  Value *VarIdx = 0;
3915  std::vector<ConstantInt*> Indexes;
3916  unsigned VarIdxNum = 0;
3917  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
3918    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
3919      Indexes.push_back(CI);
    } else {
3921      if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
3922      VarIdx = GEP->getOperand(i);
3923      VarIdxNum = i-2;
3924      Indexes.push_back(0);
3925    }
3926
3927  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
3928  // Check to see if X is a loop variant variable value now.
3929  const SCEV *Idx = getSCEV(VarIdx);
3930  Idx = getSCEVAtScope(Idx, L);
3931
3932  // We can only recognize very limited forms of loop index expressions, in
3933  // particular, only affine AddRec's like {C1,+,C2}.
3934  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
3935  if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
3936      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
3937      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
3938    return getCouldNotCompute();
3939
3940  unsigned MaxSteps = MaxBruteForceIterations;
3941  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
3942    ConstantInt *ItCst = ConstantInt::get(
3943                           cast<IntegerType>(IdxExpr->getType()), IterationNum);
3944    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
3945
3946    // Form the GEP offset.
3947    Indexes[VarIdxNum] = Val;
3948
3949    Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
3950    if (Result == 0) break;  // Cannot compute!
3951
3952    // Evaluate the condition for this iteration.
3953    Result = ConstantExpr::getICmp(predicate, Result, RHS);
3954    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
3955    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
3956#if 0
3957      dbgs() << "\n***\n*** Computed loop count " << *ItCst
3958             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
3959             << "***\n";
3960#endif
3961      ++NumArrayLenItCounts;
3962      return getConstant(ItCst);   // Found terminating iteration!
3963    }
3964  }
3965  return getCouldNotCompute();
3966}
3967
3968
3969/// CanConstantFold - Return true if we can constant fold an instruction of the
3970/// specified type, assuming that all operands were constants.
3971static bool CanConstantFold(const Instruction *I) {
3972  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
3973      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
3974    return true;
3975
3976  if (const CallInst *CI = dyn_cast<CallInst>(I))
3977    if (const Function *F = CI->getCalledFunction())
3978      return canConstantFoldCallTo(F);
3979  return false;
3980}
3981
3982/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
3983/// in the loop that V is derived from.  We allow arbitrary operations along the
3984/// way, but the operands of an operation must either be constants or a value
3985/// derived from a constant PHI.  If this expression does not fit with these
3986/// constraints, return null.
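///
/// For example, in "for (i = 3; ...; i = i*4 + 1)", the backedge value
/// "i*4 + 1" is derived solely from the header PHI for i through
/// constant-foldable operations, so that PHI is returned.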
3987static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
3988  // If this is not an instruction, or if this is an instruction outside of the
3989  // loop, it can't be derived from a loop PHI.
3990  Instruction *I = dyn_cast<Instruction>(V);
3991  if (I == 0 || !L->contains(I)) return 0;
3992
3993  if (PHINode *PN = dyn_cast<PHINode>(I)) {
3994    if (L->getHeader() == I->getParent())
3995      return PN;
3996    else
3997      // We don't currently keep track of the control flow needed to evaluate
3998      // PHIs, so we cannot handle PHIs inside of loops.
3999      return 0;
4000  }
4001
4002  // If we won't be able to constant fold this expression even if the operands
4003  // are constants, return early.
4004  if (!CanConstantFold(I)) return 0;
4005
4006  // Otherwise, we can evaluate this instruction if all of its operands are
4007  // constant or derived from a PHI node themselves.
4008  PHINode *PHI = 0;
4009  for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
4010    if (!(isa<Constant>(I->getOperand(Op)) ||
4011          isa<GlobalValue>(I->getOperand(Op)))) {
4012      PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
4013      if (P == 0) return 0;  // Not evolving from PHI
4014      if (PHI == 0)
4015        PHI = P;
4016      else if (PHI != P)
4017        return 0;  // Evolving from multiple different PHIs.
4018    }
4019
  // This is an expression evolving from a constant PHI!
4021  return PHI;
4022}
4023
4024/// EvaluateExpression - Given an expression that passes the
4025/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
4026/// in the loop has the value PHIVal.  If we can't fold this expression for some
4027/// reason, return null.
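///
/// For example, evaluating the backedge value "i*4 + 1" with PHIVal = 3
/// constant-folds the multiply and the add down to the ConstantInt 13.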
4028static Constant *EvaluateExpression(Value *V, Constant *PHIVal,
4029                                    const TargetData *TD) {
4030  if (isa<PHINode>(V)) return PHIVal;
4031  if (Constant *C = dyn_cast<Constant>(V)) return C;
4032  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) return GV;
4033  Instruction *I = cast<Instruction>(V);
4034
4035  std::vector<Constant*> Operands;
4036  Operands.resize(I->getNumOperands());
4037
4038  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4039    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal, TD);
4040    if (Operands[i] == 0) return 0;
4041  }
4042
4043  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
4044    return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
4045                                           Operands[1], TD);
4046  return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
4047                                  &Operands[0], Operands.size(), TD);
4048}
4049
/// getConstantEvolutionLoopExitValue - If the specified Phi is in the header
/// of its containing loop, the loop is known to execute a constant number of
/// times, and the PHI node is just a recurrence involving constants, fold it.
4054Constant *
4055ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
4056                                                   const APInt &BEs,
4057                                                   const Loop *L) {
4058  std::map<PHINode*, Constant*>::iterator I =
4059    ConstantEvolutionLoopExitValue.find(PN);
4060  if (I != ConstantEvolutionLoopExitValue.end())
4061    return I->second;
4062
4063  if (BEs.ugt(APInt(BEs.getBitWidth(),MaxBruteForceIterations)))
4064    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
4065
4066  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
4067
4068  // Since the loop is canonicalized, the PHI node must have two entries.  One
4069  // entry must be a constant (coming in from outside of the loop), and the
4070  // second must be derived from the same PHI.
4071  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4072  Constant *StartCST =
4073    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
4074  if (StartCST == 0)
4075    return RetVal = 0;  // Must be a constant.
4076
4077  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4078  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
4079  if (PN2 != PN)
4080    return RetVal = 0;  // Not derived from same PHI.
4081
4082  // Execute the loop symbolically to determine the exit value.
4083  if (BEs.getActiveBits() >= 32)
4084    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
4085
4086  unsigned NumIterations = BEs.getZExtValue(); // must be in range
4087  unsigned IterationNum = 0;
4088  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
4089    if (IterationNum == NumIterations)
4090      return RetVal = PHIVal;  // Got exit value!
4091
4092    // Compute the value of the PHI node for the next iteration.
4093    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD);
4094    if (NextPHI == PHIVal)
4095      return RetVal = NextPHI;  // Stopped evolving!
4096    if (NextPHI == 0)
4097      return 0;        // Couldn't evaluate!
4098    PHIVal = NextPHI;
4099  }
4100}
4101
/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a
/// constant number of times (the condition evolves only from constants),
/// try to evaluate a few iterations of the loop until the exit condition
/// gets a value of ExitWhen (true or false).  If we cannot evaluate the
/// trip count of the loop, return getCouldNotCompute().
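///
/// For example, for "for (i = 0; i != 13; i = i*3 + 1)", the PHI takes the
/// values 0, 1, 4, 13; the exit condition i != 13 first becomes false at
/// iteration 3, so the backedge-taken count is the constant 3.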
4107const SCEV *
4108ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L,
4109                                                       Value *Cond,
4110                                                       bool ExitWhen) {
4111  PHINode *PN = getConstantEvolvingPHI(Cond, L);
4112  if (PN == 0) return getCouldNotCompute();
4113
4114  // Since the loop is canonicalized, the PHI node must have two entries.  One
4115  // entry must be a constant (coming in from outside of the loop), and the
4116  // second must be derived from the same PHI.
4117  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4118  Constant *StartCST =
4119    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
4120  if (StartCST == 0) return getCouldNotCompute();  // Must be a constant.
4121
4122  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4123  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
4124  if (PN2 != PN) return getCouldNotCompute();  // Not derived from same PHI.
4125
  // Okay, we found a PHI node that defines the trip count of this loop.
  // Execute the loop symbolically to determine when the condition gets a
  // value of "ExitWhen".
4129  unsigned IterationNum = 0;
4130  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
4131  for (Constant *PHIVal = StartCST;
4132       IterationNum != MaxIterations; ++IterationNum) {
4133    ConstantInt *CondVal =
4134      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal, TD));
4135
4136    // Couldn't symbolically evaluate.
4137    if (!CondVal) return getCouldNotCompute();
4138
4139    if (CondVal->getValue() == uint64_t(ExitWhen)) {
4140      ++NumBruteForceTripCountsComputed;
4141      return getConstant(Type::getInt32Ty(getContext()), IterationNum);
4142    }
4143
4144    // Compute the value of the PHI node for the next iteration.
4145    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD);
4146    if (NextPHI == 0 || NextPHI == PHIVal)
4147      return getCouldNotCompute();// Couldn't evaluate or not making progress...
4148    PHIVal = NextPHI;
4149  }
4150
4151  // Too many iterations were needed to evaluate.
4152  return getCouldNotCompute();
4153}
4154
/// getSCEVAtScope - Return a SCEV expression for the specified value
/// at the specified scope in the program.  The L value specifies the loop
/// nest in which to evaluate the expression, where null means the top
/// level (outside all loops).  Expressions defined in loops that do not
/// enclose the requested scope are folded to their exit values where
/// possible.
4159///
4160/// This method can be used to compute the exit value for a variable defined
4161/// in a loop by querying what the value will hold in the parent loop.
4162///
4163/// In the case that a relevant loop exit value cannot be computed, the
4164/// original value V is returned.
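///
/// For example, if V is the AddRec {0,+,2}<L> and L's backedge-taken count
/// is known to be 9, then at the scope of L's parent loop V folds to the
/// constant 18, the value the expression has in L's final iteration.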
4165const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
4166  // Check to see if we've folded this expression at this loop before.
4167  std::map<const Loop *, const SCEV *> &Values = ValuesAtScopes[V];
4168  std::pair<std::map<const Loop *, const SCEV *>::iterator, bool> Pair =
4169    Values.insert(std::make_pair(L, static_cast<const SCEV *>(0)));
4170  if (!Pair.second)
4171    return Pair.first->second ? Pair.first->second : V;
4172
4173  // Otherwise compute it.
4174  const SCEV *C = computeSCEVAtScope(V, L);
4175  ValuesAtScopes[V][L] = C;
4176  return C;
4177}
4178
4179const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
4180  if (isa<SCEVConstant>(V)) return V;
4181
4182  // If this instruction is evolved from a constant-evolving PHI, compute the
4183  // exit value from the loop without using SCEVs.
4184  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
4185    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
4186      const Loop *LI = (*this->LI)[I->getParent()];
4187      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
4188        if (PHINode *PN = dyn_cast<PHINode>(I))
4189          if (PN->getParent() == LI->getHeader()) {
4190            // Okay, there is no closed form solution for the PHI node.  Check
4191            // to see if the loop that contains it has a known backedge-taken
4192            // count.  If so, we may be able to force computation of the exit
4193            // value.
4194            const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
4195            if (const SCEVConstant *BTCC =
4196                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
4197              // Okay, we know how many times the containing loop executes.  If
4198              // this is a constant evolving PHI node, get the final value at
4199              // the specified iteration number.
4200              Constant *RV = getConstantEvolutionLoopExitValue(PN,
4201                                                   BTCC->getValue()->getValue(),
4202                                                               LI);
4203              if (RV) return getSCEV(RV);
4204            }
4205          }
4206
4207      // Okay, this is an expression that we cannot symbolically evaluate
4208      // into a SCEV.  Check to see if it's possible to symbolically evaluate
4209      // the arguments into constants, and if so, try to constant propagate the
4210      // result.  This is particularly useful for computing loop exit values.
4211      if (CanConstantFold(I)) {
4212        std::vector<Constant*> Operands;
4213        Operands.reserve(I->getNumOperands());
4214        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4215          Value *Op = I->getOperand(i);
4216          if (Constant *C = dyn_cast<Constant>(Op)) {
4217            Operands.push_back(C);
4218          } else {
            // If an operand is non-constant and its type is neither integer
            // nor pointer, don't even try to analyze it with scev
            // techniques.
4222            if (!isSCEVable(Op->getType()))
4223              return V;
4224
4225            const SCEV *OpV = getSCEVAtScope(Op, L);
4226            if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV)) {
4227              Constant *C = SC->getValue();
4228              if (C->getType() != Op->getType())
4229                C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
4230                                                                  Op->getType(),
4231                                                                  false),
4232                                          C, Op->getType());
4233              Operands.push_back(C);
4234            } else if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) {
4235              if (Constant *C = dyn_cast<Constant>(SU->getValue())) {
4236                if (C->getType() != Op->getType())
4237                  C =
4238                    ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
4239                                                                  Op->getType(),
4240                                                                  false),
4241                                          C, Op->getType());
4242                Operands.push_back(C);
4243              } else
4244                return V;
4245            } else {
4246              return V;
4247            }
4248          }
4249        }
4250
4251        Constant *C = 0;
4252        if (const CmpInst *CI = dyn_cast<CmpInst>(I))
4253          C = ConstantFoldCompareInstOperands(CI->getPredicate(),
4254                                              Operands[0], Operands[1], TD);
4255        else
4256          C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
4257                                       &Operands[0], Operands.size(), TD);
4258        if (C)
4259          return getSCEV(C);
4260      }
4261    }
4262
4263    // This is some other type of SCEVUnknown, just return it.
4264    return V;
4265  }
4266
4267  if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
4268    // Avoid performing the look-up in the common case where the specified
4269    // expression has no loop-variant portions.
4270    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
4271      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
4272      if (OpAtScope != Comm->getOperand(i)) {
4273        // Okay, at least one of these operands is loop variant but might be
4274        // foldable.  Build a new instance of the folded commutative expression.
4275        SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
4276                                            Comm->op_begin()+i);
4277        NewOps.push_back(OpAtScope);
4278
4279        for (++i; i != e; ++i) {
4280          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
4281          NewOps.push_back(OpAtScope);
4282        }
4283        if (isa<SCEVAddExpr>(Comm))
4284          return getAddExpr(NewOps);
4285        if (isa<SCEVMulExpr>(Comm))
4286          return getMulExpr(NewOps);
4287        if (isa<SCEVSMaxExpr>(Comm))
4288          return getSMaxExpr(NewOps);
4289        if (isa<SCEVUMaxExpr>(Comm))
4290          return getUMaxExpr(NewOps);
4291        llvm_unreachable("Unknown commutative SCEV type!");
4292      }
4293    }
4294    // If we got here, all operands are loop invariant.
4295    return Comm;
4296  }
4297
4298  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
4299    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
4300    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
4301    if (LHS == Div->getLHS() && RHS == Div->getRHS())
4302      return Div;   // must be loop invariant
4303    return getUDivExpr(LHS, RHS);
4304  }
4305
4306  // If this is a loop recurrence for a loop that does not contain L, then we
4307  // are dealing with the final value computed by the loop.
4308  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4309    if (!L || !AddRec->getLoop()->contains(L)) {
4310      // To evaluate this recurrence, we need to know how many times the AddRec
4311      // loop iterates.  Compute this now.
4312      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
4313      if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
4314
4315      // Then, evaluate the AddRec.
4316      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
4317    }
4318    return AddRec;
4319  }
4320
4321  if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
4322    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4323    if (Op == Cast->getOperand())
4324      return Cast;  // must be loop invariant
4325    return getZeroExtendExpr(Op, Cast->getType());
4326  }
4327
4328  if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
4329    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4330    if (Op == Cast->getOperand())
4331      return Cast;  // must be loop invariant
4332    return getSignExtendExpr(Op, Cast->getType());
4333  }
4334
4335  if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
4336    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4337    if (Op == Cast->getOperand())
4338      return Cast;  // must be loop invariant
4339    return getTruncateExpr(Op, Cast->getType());
4340  }
4341
4342  llvm_unreachable("Unknown SCEV type!");
4343  return 0;
4344}
4345
4346/// getSCEVAtScope - This is a convenience function which does
4347/// getSCEVAtScope(getSCEV(V), L).
4348const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
4349  return getSCEVAtScope(getSCEV(V), L);
4350}
4351
4352/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
4353/// following equation:
4354///
4355///     A * X = B (mod N)
4356///
4357/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
4358/// A and B isn't important.
4359///
4360/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
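///
/// For example, with BW = 8 (so N = 256), A = 6, and B = 10: D = gcd(6, 256)
/// = 2, A/D = 3 has multiplicative inverse I = 43 modulo N/D = 128, and the
/// minimum root is X = (I * B/D) mod 128 = (43 * 5) mod 128 = 87; indeed
/// 6 * 87 = 522, which is congruent to 10 (mod 256).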
4361static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
4362                                               ScalarEvolution &SE) {
4363  uint32_t BW = A.getBitWidth();
4364  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
4365  assert(A != 0 && "A must be non-zero.");
4366
4367  // 1. D = gcd(A, N)
4368  //
4369  // The gcd of A and N may have only one prime factor: 2. The number of
  // trailing zeros in A is its multiplicity.
4371  uint32_t Mult2 = A.countTrailingZeros();
4372  // D = 2^Mult2
4373
4374  // 2. Check if B is divisible by D.
4375  //
4376  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
4377  // is not less than multiplicity of this prime factor for D.
4378  if (B.countTrailingZeros() < Mult2)
4379    return SE.getCouldNotCompute();
4380
4381  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
4382  // modulo (N / D).
4383  //
4384  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
4385  // bit width during computations.
4386  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
4387  APInt Mod(BW + 1, 0);
4388  Mod.set(BW - Mult2);  // Mod = N / D
4389  APInt I = AD.multiplicativeInverse(Mod);
4390
4391  // 4. Compute the minimum unsigned root of the equation:
4392  // I * (B / D) mod (N / D)
4393  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
4394
4395  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
4396  // bits.
4397  return SE.getConstant(Result.trunc(BW));
4398}
4399
4400/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
4401/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
4402/// might be the same) or two SCEVCouldNotCompute objects.
4403///
4404static std::pair<const SCEV *,const SCEV *>
4405SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
4406  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
4407  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
4408  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
4409  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
4410
4411  // We currently can only solve this if the coefficients are constants.
4412  if (!LC || !MC || !NC) {
4413    const SCEV *CNC = SE.getCouldNotCompute();
4414    return std::make_pair(CNC, CNC);
4415  }
4416
4417  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
4418  const APInt &L = LC->getValue()->getValue();
4419  const APInt &M = MC->getValue()->getValue();
4420  const APInt &N = NC->getValue()->getValue();
4421  APInt Two(BitWidth, 2);
4422  APInt Four(BitWidth, 4);
4423
4424  {
4425    using namespace APIntOps;
4426    const APInt& C = L;
4427    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
4428    // The B coefficient is M-N/2
4429    APInt B(M);
4430    B -= sdiv(N,Two);
4431
4432    // The A coefficient is N/2
4433    APInt A(N.sdiv(Two));
4434
4435    // Compute the B^2-4ac term.
4436    APInt SqrtTerm(B);
4437    SqrtTerm *= B;
4438    SqrtTerm -= Four * (A * C);
4439
4440    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
4441    // integer value or else APInt::sqrt() will assert.
4442    APInt SqrtVal(SqrtTerm.sqrt());
4443
4444    // Compute the two solutions for the quadratic formula.
4445    // The divisions must be performed as signed divisions.
4446    APInt NegB(-B);
4447    APInt TwoA( A << 1 );
    // If A is zero, the chrec isn't really quadratic and TwoA would be a
    // zero divisor; bail out.
    if (TwoA.isMinValue()) {
4449      const SCEV *CNC = SE.getCouldNotCompute();
4450      return std::make_pair(CNC, CNC);
4451    }
4452
4453    LLVMContext &Context = SE.getContext();
4454
4455    ConstantInt *Solution1 =
4456      ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
4457    ConstantInt *Solution2 =
4458      ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
4459
4460    return std::make_pair(SE.getConstant(Solution1),
4461                          SE.getConstant(Solution2));
  } // end of scope using APIntOps
4463}
4464
4465/// HowFarToZero - Return the number of times a backedge comparing the specified
4466/// value to zero will execute.  If not computable, return CouldNotCompute.
4467ScalarEvolution::BackedgeTakenInfo
4468ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
4469  // If the value is a constant
4470  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
4471    // If the value is already zero, the branch will execute zero times.
4472    if (C->getValue()->isZero()) return C;
4473    return getCouldNotCompute();  // Otherwise it will loop infinitely.
4474  }
4475
4476  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
4477  if (!AddRec || AddRec->getLoop() != L)
4478    return getCouldNotCompute();
4479
4480  if (AddRec->isAffine()) {
4481    // If this is an affine expression, the execution count of this branch is
4482    // the minimum unsigned root of the following equation:
4483    //
4484    //     Start + Step*N = 0 (mod 2^BW)
4485    //
4486    // equivalent to:
4487    //
4488    //             Step*N = -Start (mod 2^BW)
4489    //
4490    // where BW is the common bit width of Start and Step.
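    //
    // For example, the affine chrec {10,+,-1} reaches zero when
    // -1 * N = -10 (mod 2^BW), i.e. after N = 10 backedges; this is the
    // isAllOnesValue case below, which returns Start directly.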
4491
4492    // Get the initial value for the loop.
4493    const SCEV *Start = getSCEVAtScope(AddRec->getStart(),
4494                                       L->getParentLoop());
4495    const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1),
4496                                      L->getParentLoop());
4497
4498    if (const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
4499      // For now we handle only constant steps.
4500
4501      // First, handle unitary steps.
4502      if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
4503        return getNegativeSCEV(Start);          //   N = -Start (as unsigned)
4504      if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
4505        return Start;                           //    N = Start (as unsigned)
4506
4507      // Then, try to solve the above equation provided that Start is constant.
4508      if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
4509        return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
4510                                            -StartC->getValue()->getValue(),
4511                                            *this);
4512    }
4513  } else if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
4514    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
4515    // the quadratic equation to solve it.
4516    std::pair<const SCEV *,const SCEV *> Roots = SolveQuadraticEquation(AddRec,
4517                                                                    *this);
4518    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
4519    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
4520    if (R1) {
4521#if 0
4522      dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
4523             << "  sol#2: " << *R2 << "\n";
4524#endif
4525      // Pick the smallest positive root value.
4526      if (ConstantInt *CB =
4527          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
4528                                   R1->getValue(), R2->getValue()))) {
        if (!CB->getZExtValue())
4530          std::swap(R1, R2);   // R1 is the minimum root now.
4531
4532        // We can only use this value if the chrec ends up with an exact zero
4533        // value at this index.  When solving for "X*X != 5", for example, we
4534        // should not accept a root of 2.
4535        const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
4536        if (Val->isZero())
4537          return R1;  // We found a quadratic root!
4538      }
4539    }
4540  }
4541
4542  return getCouldNotCompute();
4543}
4544
4545/// HowFarToNonZero - Return the number of times a backedge checking the
4546/// specified value for nonzero will execute.  If not computable, return
4547/// CouldNotCompute
4548ScalarEvolution::BackedgeTakenInfo
4549ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
4550  // Loops that look like: while (X == 0) are very strange indeed.  We don't
4551  // handle them yet except for the trivial case.  This could be expanded in the
4552  // future as needed.
4553
4554  // If the value is a constant, check to see if it is known to be non-zero
4555  // already.  If so, the backedge will execute zero times.
4556  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
4557    if (!C->getValue()->isNullValue())
4558      return getIntegerSCEV(0, C->getType());
4559    return getCouldNotCompute();  // Otherwise it will loop infinitely.
4560  }
4561
4562  // We could implement others, but I really doubt anyone writes loops like
4563  // this, and if they did, they would already be constant folded.
4564  return getCouldNotCompute();
4565}
4566
4567/// getLoopPredecessor - If the given loop's header has exactly one unique
4568/// predecessor outside the loop, return it. Otherwise return null.
4569///
4570BasicBlock *ScalarEvolution::getLoopPredecessor(const Loop *L) {
4571  BasicBlock *Header = L->getHeader();
4572  BasicBlock *Pred = 0;
4573  for (pred_iterator PI = pred_begin(Header), E = pred_end(Header);
4574       PI != E; ++PI)
4575    if (!L->contains(*PI)) {
4576      if (Pred && Pred != *PI) return 0; // Multiple predecessors.
4577      Pred = *PI;
4578    }
4579  return Pred;
4580}
4581
4582/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
4583/// (which may not be an immediate predecessor) which has exactly one
4584/// successor from which BB is reachable, or null if no such block is
4585/// found.
4586///
4587BasicBlock *
4588ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
4589  // If the block has a unique predecessor, then there is no path from the
4590  // predecessor to the block that does not go through the direct edge
4591  // from the predecessor to the block.
4592  if (BasicBlock *Pred = BB->getSinglePredecessor())
4593    return Pred;
4594
4595  // A loop's header is defined to be a block that dominates the loop.
4596  // If the header has a unique predecessor outside the loop, it must be
4597  // a block that has exactly one successor that can reach the loop.
4598  if (Loop *L = LI->getLoopFor(BB))
4599    return getLoopPredecessor(L);
4600
4601  return 0;
4602}
4603
4604/// HasSameValue - SCEV structural equivalence is usually sufficient for
4605/// testing whether two expressions are equal, however for the purposes of
4606/// looking for a condition guarding a loop, it can be useful to be a little
4607/// more general, since a front-end may have replicated the controlling
4608/// expression.
4609///
4610static bool HasSameValue(const SCEV *A, const SCEV *B) {
4611  // Quick check to see if they are the same SCEV.
4612  if (A == B) return true;
4613
4614  // Otherwise, if they're both SCEVUnknown, it's possible that they hold
4615  // two different instructions with the same value. Check for this case.
4616  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
4617    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
4618      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
4619        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
4620          if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory())
4621            return true;
4622
4623  // Otherwise assume they may have a different value.
4624  return false;
4625}
4626
4627bool ScalarEvolution::isKnownNegative(const SCEV *S) {
4628  return getSignedRange(S).getSignedMax().isNegative();
4629}
4630
4631bool ScalarEvolution::isKnownPositive(const SCEV *S) {
4632  return getSignedRange(S).getSignedMin().isStrictlyPositive();
4633}
4634
4635bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
4636  return !getSignedRange(S).getSignedMin().isNegative();
4637}
4638
4639bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
4640  return !getSignedRange(S).getSignedMax().isStrictlyPositive();
4641}
4642
4643bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
4644  return isKnownNegative(S) || isKnownPositive(S);
4645}
4646
4647bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
4648                                       const SCEV *LHS, const SCEV *RHS) {
4649
4650  if (HasSameValue(LHS, RHS))
4651    return ICmpInst::isTrueWhenEqual(Pred);
4652
4653  switch (Pred) {
4654  default:
4655    llvm_unreachable("Unexpected ICmpInst::Predicate value!");
4656    break;
4657  case ICmpInst::ICMP_SGT:
4658    Pred = ICmpInst::ICMP_SLT;
    std::swap(LHS, RHS);
    // FALL THROUGH.
4660  case ICmpInst::ICMP_SLT: {
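    // For example, if every value LHS can take is at most 5 and every value
    // RHS can take is at least 10, the signed ranges prove LHS s< RHS;
    // conversely, if LHS's minimum is at least RHS's maximum, the predicate
    // can never hold and we stop early.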
4661    ConstantRange LHSRange = getSignedRange(LHS);
4662    ConstantRange RHSRange = getSignedRange(RHS);
4663    if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
4664      return true;
4665    if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
4666      return false;
4667    break;
4668  }
4669  case ICmpInst::ICMP_SGE:
4670    Pred = ICmpInst::ICMP_SLE;
    std::swap(LHS, RHS);
    // FALL THROUGH.
4672  case ICmpInst::ICMP_SLE: {
4673    ConstantRange LHSRange = getSignedRange(LHS);
4674    ConstantRange RHSRange = getSignedRange(RHS);
4675    if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
4676      return true;
4677    if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
4678      return false;
4679    break;
4680  }
4681  case ICmpInst::ICMP_UGT:
4682    Pred = ICmpInst::ICMP_ULT;
    std::swap(LHS, RHS);
    // FALL THROUGH.
4684  case ICmpInst::ICMP_ULT: {
4685    ConstantRange LHSRange = getUnsignedRange(LHS);
4686    ConstantRange RHSRange = getUnsignedRange(RHS);
4687    if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
4688      return true;
4689    if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
4690      return false;
4691    break;
4692  }
4693  case ICmpInst::ICMP_UGE:
4694    Pred = ICmpInst::ICMP_ULE;
    std::swap(LHS, RHS);
    // FALL THROUGH.
4696  case ICmpInst::ICMP_ULE: {
4697    ConstantRange LHSRange = getUnsignedRange(LHS);
4698    ConstantRange RHSRange = getUnsignedRange(RHS);
4699    if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
4700      return true;
4701    if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
4702      return false;
4703    break;
4704  }
4705  case ICmpInst::ICMP_NE: {
4706    if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
4707      return true;
4708    if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
4709      return true;
4710
4711    const SCEV *Diff = getMinusSCEV(LHS, RHS);
4712    if (isKnownNonZero(Diff))
4713      return true;
4714    break;
4715  }
4716  case ICmpInst::ICMP_EQ:
4717    // The check at the top of the function catches the case where
4718    // the values are known to be equal.
4719    break;
4720  }
4721  return false;
4722}
4723
4724/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
4725/// protected by a conditional between LHS and RHS.  This is used to
/// eliminate casts.
4727bool
4728ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
4729                                             ICmpInst::Predicate Pred,
4730                                             const SCEV *LHS, const SCEV *RHS) {
4731  // Interpret a null as meaning no loop, where there is obviously no guard
4732  // (interprocedural conditions notwithstanding).
4733  if (!L) return true;
4734
4735  BasicBlock *Latch = L->getLoopLatch();
4736  if (!Latch)
4737    return false;
4738
4739  BranchInst *LoopContinuePredicate =
4740    dyn_cast<BranchInst>(Latch->getTerminator());
4741  if (!LoopContinuePredicate ||
4742      LoopContinuePredicate->isUnconditional())
4743    return false;
4744
4745  return isImpliedCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS,
4746                       LoopContinuePredicate->getSuccessor(0) != L->getHeader());
4747}
4748
4749/// isLoopGuardedByCond - Test whether entry to the loop is protected
4750/// by a conditional between LHS and RHS.  This is used to help avoid max
4751/// expressions in loop trip counts, and to eliminate casts.
4752bool
4753ScalarEvolution::isLoopGuardedByCond(const Loop *L,
4754                                     ICmpInst::Predicate Pred,
4755                                     const SCEV *LHS, const SCEV *RHS) {
4756  // Interpret a null as meaning no loop, where there is obviously no guard
4757  // (interprocedural conditions notwithstanding).
4758  if (!L) return false;
4759
4760  BasicBlock *Predecessor = getLoopPredecessor(L);
4761  BasicBlock *PredecessorDest = L->getHeader();
4762
4763  // Starting at the loop predecessor, climb up the predecessor chain, as long
4764  // as there are predecessors that can be found that have unique successors
4765  // leading to the original header.
4766  for (; Predecessor;
4767       PredecessorDest = Predecessor,
4768       Predecessor = getPredecessorWithUniqueSuccessorForBB(Predecessor)) {
4769
4770    BranchInst *LoopEntryPredicate =
4771      dyn_cast<BranchInst>(Predecessor->getTerminator());
4772    if (!LoopEntryPredicate ||
4773        LoopEntryPredicate->isUnconditional())
4774      continue;
4775
4776    if (isImpliedCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS,
4777                      LoopEntryPredicate->getSuccessor(0) != PredecessorDest))
4778      return true;
4779  }
4780
4781  return false;
4782}
4783
4784/// isImpliedCond - Test whether the condition described by Pred, LHS,
4785/// and RHS is true whenever the given Cond value evaluates to true.
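///
/// For example, if a guard known to be true is "(x != 0) && (y u< x)" and
/// we ask whether x != 0, the first conjunct alone implies it; And and Or
/// conditions are therefore handled by recursing into their operands.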
4786bool ScalarEvolution::isImpliedCond(Value *CondValue,
4787                                    ICmpInst::Predicate Pred,
4788                                    const SCEV *LHS, const SCEV *RHS,
4789                                    bool Inverse) {
4790  // Recursively handle And and Or conditions.
4791  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CondValue)) {
4792    if (BO->getOpcode() == Instruction::And) {
4793      if (!Inverse)
4794        return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
4795               isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
4796    } else if (BO->getOpcode() == Instruction::Or) {
4797      if (Inverse)
4798        return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
4799               isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
4800    }
4801  }
4802
4803  ICmpInst *ICI = dyn_cast<ICmpInst>(CondValue);
4804  if (!ICI) return false;
4805
4806  // Bail if the ICmp's operands' types are wider than the needed type
4807  // before attempting to call getSCEV on them. This avoids infinite
4808  // recursion, since the analysis of widening casts can require loop
4809  // exit condition information for overflow checking, which would
4810  // lead back here.
4811  if (getTypeSizeInBits(LHS->getType()) <
4812      getTypeSizeInBits(ICI->getOperand(0)->getType()))
4813    return false;
4814
4815  // Now that we found a conditional branch that dominates the loop, check to
4816  // see if it is the comparison we are looking for.
4817  ICmpInst::Predicate FoundPred;
4818  if (Inverse)
4819    FoundPred = ICI->getInversePredicate();
4820  else
4821    FoundPred = ICI->getPredicate();
4822
4823  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
4824  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
4825
4826  // Balance the types. The case where FoundLHS' type is wider than
4827  // LHS' type is checked for above.
4828  if (getTypeSizeInBits(LHS->getType()) >
4829      getTypeSizeInBits(FoundLHS->getType())) {
4830    if (CmpInst::isSigned(Pred)) {
4831      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
4832      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
4833    } else {
4834      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
4835      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
4836    }
4837  }
4838
4839  // Canonicalize the query to match the way instcombine will have
4840  // canonicalized the comparison.
4841  // First, put a constant operand on the right.
4842  if (isa<SCEVConstant>(LHS)) {
4843    std::swap(LHS, RHS);
4844    Pred = ICmpInst::getSwappedPredicate(Pred);
4845  }
4846  // Then, canonicalize comparisons with boundary cases.
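  // For example, "x u>= 1" becomes "x != 0", "x u>= 0" is trivially true,
  // and "x u>= UINT_MAX" becomes "x == UINT_MAX".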
4847  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
4848    const APInt &RA = RC->getValue()->getValue();
4849    switch (Pred) {
4850    default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
4851    case ICmpInst::ICMP_EQ:
4852    case ICmpInst::ICMP_NE:
4853      break;
4854    case ICmpInst::ICMP_UGE:
4855      if ((RA - 1).isMinValue()) {
4856        Pred = ICmpInst::ICMP_NE;
4857        RHS = getConstant(RA - 1);
4858        break;
4859      }
4860      if (RA.isMaxValue()) {
4861        Pred = ICmpInst::ICMP_EQ;
4862        break;
4863      }
4864      if (RA.isMinValue()) return true;
4865      break;
4866    case ICmpInst::ICMP_ULE:
4867      if ((RA + 1).isMaxValue()) {
4868        Pred = ICmpInst::ICMP_NE;
4869        RHS = getConstant(RA + 1);
4870        break;
4871      }
4872      if (RA.isMinValue()) {
4873        Pred = ICmpInst::ICMP_EQ;
4874        break;
4875      }
4876      if (RA.isMaxValue()) return true;
4877      break;
4878    case ICmpInst::ICMP_SGE:
4879      if ((RA - 1).isMinSignedValue()) {
4880        Pred = ICmpInst::ICMP_NE;
4881        RHS = getConstant(RA - 1);
4882        break;
4883      }
4884      if (RA.isMaxSignedValue()) {
4885        Pred = ICmpInst::ICMP_EQ;
4886        break;
4887      }
4888      if (RA.isMinSignedValue()) return true;
4889      break;
4890    case ICmpInst::ICMP_SLE:
4891      if ((RA + 1).isMaxSignedValue()) {
4892        Pred = ICmpInst::ICMP_NE;
4893        RHS = getConstant(RA + 1);
4894        break;
4895      }
4896      if (RA.isMinSignedValue()) {
4897        Pred = ICmpInst::ICMP_EQ;
4898        break;
4899      }
4900      if (RA.isMaxSignedValue()) return true;
4901      break;
4902    case ICmpInst::ICMP_UGT:
4903      if (RA.isMinValue()) {
4904        Pred = ICmpInst::ICMP_NE;
4905        break;
4906      }
4907      if ((RA + 1).isMaxValue()) {
4908        Pred = ICmpInst::ICMP_EQ;
4909        RHS = getConstant(RA + 1);
4910        break;
4911      }
4912      if (RA.isMaxValue()) return false;
4913      break;
4914    case ICmpInst::ICMP_ULT:
4915      if (RA.isMaxValue()) {
4916        Pred = ICmpInst::ICMP_NE;
4917        break;
4918      }
4919      if ((RA - 1).isMinValue()) {
4920        Pred = ICmpInst::ICMP_EQ;
4921        RHS = getConstant(RA - 1);
4922        break;
4923      }
4924      if (RA.isMinValue()) return false;
4925      break;
4926    case ICmpInst::ICMP_SGT:
4927      if (RA.isMinSignedValue()) {
4928        Pred = ICmpInst::ICMP_NE;
4929        break;
4930      }
4931      if ((RA + 1).isMaxSignedValue()) {
4932        Pred = ICmpInst::ICMP_EQ;
4933        RHS = getConstant(RA + 1);
4934        break;
4935      }
4936      if (RA.isMaxSignedValue()) return false;
4937      break;
4938    case ICmpInst::ICMP_SLT:
4939      if (RA.isMaxSignedValue()) {
4940        Pred = ICmpInst::ICMP_NE;
4941        break;
4942      }
4943      if ((RA - 1).isMinSignedValue()) {
        Pred = ICmpInst::ICMP_EQ;
        RHS = getConstant(RA - 1);
        break;
4947      }
4948      if (RA.isMinSignedValue()) return false;
4949      break;
4950    }
4951  }
4952
4953  // Check to see if we can make the LHS or RHS match.
4954  if (LHS == FoundRHS || RHS == FoundLHS) {
4955    if (isa<SCEVConstant>(RHS)) {
4956      std::swap(FoundLHS, FoundRHS);
4957      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
4958    } else {
4959      std::swap(LHS, RHS);
4960      Pred = ICmpInst::getSwappedPredicate(Pred);
4961    }
4962  }
4963
4964  // Check whether the found predicate is the same as the desired predicate.
4965  if (FoundPred == Pred)
4966    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
4967
4968  // Check whether swapping the found predicate makes it the same as the
4969  // desired predicate.
4970  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
4971    if (isa<SCEVConstant>(RHS))
4972      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
4973    else
4974      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
4975                                   RHS, LHS, FoundLHS, FoundRHS);
4976  }
4977
4978  // Check whether the actual condition is beyond sufficient.
4979  if (FoundPred == ICmpInst::ICMP_EQ)
4980    if (ICmpInst::isTrueWhenEqual(Pred))
4981      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
4982        return true;
4983  if (Pred == ICmpInst::ICMP_NE)
4984    if (!ICmpInst::isTrueWhenEqual(FoundPred))
4985      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
4986        return true;
4987
4988  // Otherwise assume the worst.
4989  return false;
4990}
4991
4992/// isImpliedCondOperands - Test whether the condition described by Pred,
4993/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
4994/// and FoundRHS is true.
4995bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
4996                                            const SCEV *LHS, const SCEV *RHS,
4997                                            const SCEV *FoundLHS,
4998                                            const SCEV *FoundRHS) {
4999  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
5000                                     FoundLHS, FoundRHS) ||
5001         // ~x < ~y --> x > y
5002         isImpliedCondOperandsHelper(Pred, LHS, RHS,
5003                                     getNotSCEV(FoundRHS),
5004                                     getNotSCEV(FoundLHS));
5005}
5006
5007/// isImpliedCondOperandsHelper - Test whether the condition described by
5008/// Pred, LHS, and RHS is true whenever the condition described by Pred,
5009/// FoundLHS, and FoundRHS is true.
5010bool
5011ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
5012                                             const SCEV *LHS, const SCEV *RHS,
5013                                             const SCEV *FoundLHS,
5014                                             const SCEV *FoundRHS) {
5015  switch (Pred) {
5016  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5017  case ICmpInst::ICMP_EQ:
5018  case ICmpInst::ICMP_NE:
5019    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
5020      return true;
5021    break;
5022  case ICmpInst::ICMP_SLT:
5023  case ICmpInst::ICMP_SLE:
5024    if (isKnownPredicate(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
5025        isKnownPredicate(ICmpInst::ICMP_SGE, RHS, FoundRHS))
5026      return true;
5027    break;
5028  case ICmpInst::ICMP_SGT:
5029  case ICmpInst::ICMP_SGE:
5030    if (isKnownPredicate(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
5031        isKnownPredicate(ICmpInst::ICMP_SLE, RHS, FoundRHS))
5032      return true;
5033    break;
5034  case ICmpInst::ICMP_ULT:
5035  case ICmpInst::ICMP_ULE:
5036    if (isKnownPredicate(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
5037        isKnownPredicate(ICmpInst::ICMP_UGE, RHS, FoundRHS))
5038      return true;
5039    break;
5040  case ICmpInst::ICMP_UGT:
5041  case ICmpInst::ICMP_UGE:
5042    if (isKnownPredicate(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
5043        isKnownPredicate(ICmpInst::ICMP_ULE, RHS, FoundRHS))
5044      return true;
5045    break;
5046  }
5047
5048  return false;
5049}
5050
5051/// getBECount - Subtract the end and start values and divide by the step,
5052/// rounding up, to get the number of times the backedge is executed. Return
5053/// CouldNotCompute if an intermediate computation overflows.
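///
/// For example, with Start = 0, End = 10, and Step = 3, the count is
/// (10 - 0 + (3 - 1)) /u 3 = 4, matching the four values 0, 3, 6, and 9
/// that satisfy "< 10".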
5054const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
5055                                        const SCEV *End,
5056                                        const SCEV *Step,
5057                                        bool NoWrap) {
5058  assert(!isKnownNegative(Step) &&
5059         "This code doesn't handle negative strides yet!");
5060
5061  const Type *Ty = Start->getType();
5062  const SCEV *NegOne = getIntegerSCEV(-1, Ty);
5063  const SCEV *Diff = getMinusSCEV(End, Start);
5064  const SCEV *RoundUp = getAddExpr(Step, NegOne);
5065
5066  // Add an adjustment to the difference between End and Start so that
5067  // the division will effectively round up.
5068  const SCEV *Add = getAddExpr(Diff, RoundUp);
5069
5070  if (!NoWrap) {
5071    // Check Add for unsigned overflow.
5072    // TODO: More sophisticated things could be done here.
5073    const Type *WideTy = IntegerType::get(getContext(),
5074                                          getTypeSizeInBits(Ty) + 1);
5075    const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
5076    const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
5077    const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
5078    if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
5079      return getCouldNotCompute();
5080  }
5081
5082  return getUDivExpr(Add, Step);
5083}
5084
5085/// HowManyLessThans - Return the number of times a backedge containing the
5086/// specified less-than comparison will execute.  If not computable, return
5087/// CouldNotCompute.
5088ScalarEvolution::BackedgeTakenInfo
5089ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
5090                                  const Loop *L, bool isSigned) {
5091  // Only handle:  "ADDREC < LoopInvariant".
5092  if (!RHS->isLoopInvariant(L)) return getCouldNotCompute();
5093
5094  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
5095  if (!AddRec || AddRec->getLoop() != L)
5096    return getCouldNotCompute();
5097
5098  // Check to see if we have a flag which makes analysis easy.
5099  bool NoWrap = isSigned ? AddRec->hasNoSignedWrap() :
5100                           AddRec->hasNoUnsignedWrap();
5101
5102  if (AddRec->isAffine()) {
5103    unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
5104    const SCEV *Step = AddRec->getStepRecurrence(*this);
5105
5106    if (Step->isZero())
5107      return getCouldNotCompute();
5108    if (Step->isOne()) {
5109      // With unit stride, the iteration never steps past the limit value.
5110    } else if (isKnownPositive(Step)) {
5111      // Test whether a positive iteration can step past the limit
5112      // value and past the maximum value for its type in a single step.
5113      // Note that it's not sufficient to check NoWrap here, because even
5114      // though the value after a wrap is undefined, it's not undefined
5115      // behavior, so if wrap does occur, the loop could either terminate or
5116      // loop infinitely, but in either case, the loop is guaranteed to
5117      // iterate at least until the iteration where the wrapping occurs.
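      // For example, for an i8 IV with step 10 and a limit that may be as
      // large as 125, an iteration at the value 120 (which is < 125) would
      // step to -126 via signed wrap, so the division formula below cannot
      // be used and we bail out.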
      const SCEV *One = getIntegerSCEV(1, Step->getType());
      if (isSigned) {
        APInt Max = APInt::getSignedMaxValue(BitWidth);
        if ((Max - getSignedRange(getMinusSCEV(Step, One)).getSignedMax())
              .slt(getSignedRange(RHS).getSignedMax()))
          return getCouldNotCompute();
      } else {
        APInt Max = APInt::getMaxValue(BitWidth);
        if ((Max - getUnsignedRange(getMinusSCEV(Step, One)).getUnsignedMax())
              .ult(getUnsignedRange(RHS).getUnsignedMax()))
          return getCouldNotCompute();
      }
    } else
      // TODO: Handle negative strides here and below.
      return getCouldNotCompute();

    // We know the LHS is of the form {n,+,s} and the RHS is some
    // loop-invariant m.  So, we count the number of iterations in which
    // {n,+,s} < m is true.  Note that we cannot simply return max(m-n,0)/s
    // because it's not safe to treat m-n as either signed or unsigned due
    // to the possibility of overflow.

    // First, we get the value of the LHS in the first iteration: n
    const SCEV *Start = AddRec->getOperand(0);

    // Determine the minimum constant start value.
    const SCEV *MinStart = getConstant(isSigned ?
      getSignedRange(Start).getSignedMin() :
      getUnsignedRange(Start).getUnsignedMin());

    // If we know that the condition must be true in order to enter the
    // loop, then we know that it will run exactly (m-n)/s times. Otherwise,
    // we only know that it will execute (max(m,n)-n)/s times. In both
    // cases, the division must round up.
    const SCEV *End = RHS;
    if (!isLoopGuardedByCond(L,
                             isSigned ? ICmpInst::ICMP_SLT :
                                        ICmpInst::ICMP_ULT,
                             getMinusSCEV(Start, Step), RHS))
      End = isSigned ? getSMaxExpr(RHS, Start)
                     : getUMaxExpr(RHS, Start);
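
    // For instance (illustrative values): if n = 20, m = 10, and s = 1, the
    // guard fails, End becomes max(10, 20) = 20, and the count is
    // (20 - 20) / 1 = 0, whereas m - n alone would have wrapped.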

    // Determine the maximum constant end value.
    const SCEV *MaxEnd = getConstant(isSigned ?
      getSignedRange(End).getSignedMax() :
      getUnsignedRange(End).getUnsignedMax());

    // If MaxEnd is within a step of the maximum integer value in its type,
    // adjust it down to the minimum value which would produce the same effect.
    // This allows the subsequent ceiling division of (N+(step-1))/step to
    // compute the correct value.
    const SCEV *StepMinusOne = getMinusSCEV(Step,
                                            getIntegerSCEV(1, Step->getType()));
    MaxEnd = isSigned ?
      getSMinExpr(MaxEnd,
                  getMinusSCEV(getConstant(APInt::getSignedMaxValue(BitWidth)),
                               StepMinusOne)) :
      getUMinExpr(MaxEnd,
                  getMinusSCEV(getConstant(APInt::getMaxValue(BitWidth)),
                               StepMinusOne));
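
    // For example (illustrative i8 unsigned values): with Step = 4, MaxEnd
    // is clamped to at most 255 - 3 = 252, so the rounding adjustment
    // MaxEnd + (Step - 1) formed in getBECount still fits in i8.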

    // Finally, we subtract these two values and divide, rounding up, to get
    // the number of times the backedge is executed.
    const SCEV *BECount = getBECount(Start, End, Step, NoWrap);

    // The maximum backedge count is similar, except using the minimum start
    // value and the maximum end value.
    const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step, NoWrap);

    return BackedgeTakenInfo(BECount, MaxBECount);
  }

  return getCouldNotCompute();
}

/// getNumIterationsInRange - Return the number of iterations of this loop
/// that produce values in the specified constant range.  Another way of
/// looking at this is that it returns the first iteration number where the
/// value is no longer in the range, thus computing the exit count.  If the
/// iteration count can't be computed, an instance of SCEVCouldNotCompute is
/// returned.
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                                    ScalarEvolution &SE) const {
  if (Range.isFullSet())  // Infinite loop.
    return SE.getCouldNotCompute();

  // If the start is a non-zero constant, shift the range to simplify things.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
    if (!SC->getValue()->isZero()) {
      SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
      Operands[0] = SE.getIntegerSCEV(0, SC->getType());
      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop());
      if (const SCEVAddRecExpr *ShiftedAddRec =
            dyn_cast<SCEVAddRecExpr>(Shifted))
        return ShiftedAddRec->getNumIterationsInRange(
                           Range.subtract(SC->getValue()->getValue()), SE);
      // This is strange and shouldn't happen.
      return SE.getCouldNotCompute();
    }
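
  // For example (illustrative values): solving {5,+,3} in the range [10,20)
  // reduces, via the shift above, to solving {0,+,3} in [5,15).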

  // The only time we can solve this is when we have all constant indices.
  // Otherwise, we cannot determine the overflow conditions.
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
    if (!isa<SCEVConstant>(getOperand(i)))
      return SE.getCouldNotCompute();

  // Okay, at this point we know that all elements of the chrec are constants
  // and that the start element is zero.

  // First check to see if the range contains zero.  If not, the first
  // iteration exits.
  unsigned BitWidth = SE.getTypeSizeInBits(getType());
  if (!Range.contains(APInt(BitWidth, 0)))
    return SE.getIntegerSCEV(0, getType());
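
  // For instance (illustrative values): if Range = [5,10), the value 0
  // produced on iteration zero is already outside the range, so zero
  // iterations produce in-range values.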

  if (isAffine()) {
    // If this is an affine expression then we have this situation:
    //   Solve {0,+,A} in Range  ===  Ax in Range

    // We know that zero is in the range.  If A is positive then we know that
    // the upper value of the range must be the first possible exit value.
    // If A is negative then the lower of the range is the last possible loop
    // value.  Also note that we already checked for a full range.
    APInt One(BitWidth, 1);
    APInt A = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();

    // The exit value should be (End+A)/A.
    APInt ExitVal = (End + A).udiv(A);
    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
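
    // Worked example (illustrative values): for Range = [0,100) and A = 3,
    // End = 99 and ExitVal = (99 + 3) /u 3 = 34.  Iteration 34 produces
    // 3 * 34 = 102, which is outside the range, while iteration 33 produces
    // 99, which is inside, so 34 is the exit iteration.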

    // Evaluate at the exit value.  If we really did fall out of the valid
    // range, then we computed our trip count; otherwise wrap-around or other
    // strange things must have happened.
    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
    if (Range.contains(Val->getValue()))
      return SE.getCouldNotCompute();  // Something strange happened.

    // Ensure that the previous value is in the range.  This is a sanity check.
    assert(Range.contains(
           EvaluateConstantChrecAtConstant(this,
           ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
           "Linear scev computation is off in a bad way!");
    return SE.getConstant(ExitValue);
  } else if (isQuadratic()) {
    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
    // the quadratic equation to solve it.  To do this, we must frame our
    // problem in terms of figuring out when zero is crossed, instead of when
    // Range.getUpper() is crossed.
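    //
    // For example (illustrative values): to solve {0,+,3,+,2} in [0,50), we
    // build {-50,+,3,+,2} below and ask SolveQuadraticEquation where it
    // crosses zero, i.e. where the original chrec reaches 50.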
    SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
    const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop());

    // Next, solve the constructed addrec.
    std::pair<const SCEV *, const SCEV *> Roots =
      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
    if (R1 && R2) {
      // Pick the smallest positive root value.
      if (ConstantInt *CB =
          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
                                                      R1->getValue(),
                                                      R2->getValue()))) {
        if (!CB->getZExtValue())
          std::swap(R1, R2);   // R1 is the minimum root now.

        // Make sure the root is not off by one.  The returned iteration should
        // not be in the range, but the previous one should be.  When solving
        // for "X*X < 5", for example, we should not return a root of 2.
        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
                                                             R1->getValue(),
                                                             SE);
        if (Range.contains(R1Val->getValue())) {
          // The next iteration must be out of the range...
          ConstantInt *NextVal =
                ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);

          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
          if (!Range.contains(R1Val->getValue()))
            return SE.getConstant(NextVal);
          return SE.getCouldNotCompute();  // Something strange happened.
        }

        // If R1 was not in the range, then it is a good return value.  Make
        // sure that R1-1 WAS in the range though, just in case.
        ConstantInt *NextVal =
               ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
        if (Range.contains(R1Val->getValue()))
          return R1;
        return SE.getCouldNotCompute();  // Something strange happened.
      }
    }
  }

  return SE.getCouldNotCompute();
}

//===----------------------------------------------------------------------===//
//                   SCEVCallbackVH Class Implementation
//===----------------------------------------------------------------------===//

void ScalarEvolution::SCEVCallbackVH::deleted() {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  SE->Scalars.erase(getValPtr());
  // this now dangles!
}

void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");

  // Forget all the expressions associated with users of the old value,
  // so that future queries will recompute the expressions using the new
  // value.
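  //
  // For example, if %old is replaced by %new, a user such as
  // "%add = add %old, 1" has a cached SCEV computed in terms of %old, as do
  // %add's own users, so the walk below must visit the use graph
  // transitively.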
  SmallVector<User *, 16> Worklist;
  SmallPtrSet<User *, 8> Visited;
  Value *Old = getValPtr();
  bool DeleteOld = false;
  for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
       UI != UE; ++UI)
    Worklist.push_back(*UI);
  while (!Worklist.empty()) {
    User *U = Worklist.pop_back_val();
    // Deleting the Old value will cause this to dangle. Postpone
    // that until everything else is done.
    if (U == Old) {
      DeleteOld = true;
      continue;
    }
    if (!Visited.insert(U))
      continue;
    if (PHINode *PN = dyn_cast<PHINode>(U))
      SE->ConstantEvolutionLoopExitValue.erase(PN);
    SE->Scalars.erase(U);
    for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
         UI != UE; ++UI)
      Worklist.push_back(*UI);
  }
  // Delete the Old value if it (indirectly) references itself.
  if (DeleteOld) {
    if (PHINode *PN = dyn_cast<PHINode>(Old))
      SE->ConstantEvolutionLoopExitValue.erase(PN);
    SE->Scalars.erase(Old);
    // this now dangles!
  }
  // this may dangle!
}

ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
  : CallbackVH(V), SE(se) {}

//===----------------------------------------------------------------------===//
//                   ScalarEvolution Class Implementation
//===----------------------------------------------------------------------===//

ScalarEvolution::ScalarEvolution()
  : FunctionPass(&ID) {
}

bool ScalarEvolution::runOnFunction(Function &F) {
  this->F = &F;
  LI = &getAnalysis<LoopInfo>();
  TD = getAnalysisIfAvailable<TargetData>();
  DT = &getAnalysis<DominatorTree>();
  return false;
}

void ScalarEvolution::releaseMemory() {
  Scalars.clear();
  BackedgeTakenCounts.clear();
  ConstantEvolutionLoopExitValue.clear();
  ValuesAtScopes.clear();
  UniqueSCEVs.clear();
  SCEVAllocator.Reset();
}

void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequiredTransitive<LoopInfo>();
  AU.addRequiredTransitive<DominatorTree>();
}

bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
}

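// Example client usage (an illustrative sketch, not code from this file): a
// pass that has declared AU.addRequired<ScalarEvolution>() can do
//   ScalarEvolution &SE = getAnalysis<ScalarEvolution>();
//   if (SE.hasLoopInvariantBackedgeTakenCount(L))
//     errs() << *SE.getBackedgeTakenCount(L) << "\n";

/// PrintLoopInfo - Print the backedge-taken counts for the given loop and,
/// recursively, for each loop nested within it.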
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
                          const Loop *L) {
  // Print all inner loops first.
  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
    PrintLoopInfo(OS, SE, *I);

  OS << "Loop ";
  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
  OS << ": ";

  SmallVector<BasicBlock *, 8> ExitBlocks;
  L->getExitBlocks(ExitBlocks);
  if (ExitBlocks.size() != 1)
    OS << "<multiple exits> ";

  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
  } else {
    OS << "Unpredictable backedge-taken count. ";
  }

  OS << "\n"
        "Loop ";
  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
  OS << ": ";

  if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
    OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
  } else {
    OS << "Unpredictable max backedge-taken count. ";
  }

  OS << "\n";
}

void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
  // ScalarEvolution's implementation of the print method is to print
  // out SCEV values of all instructions that are interesting. Doing
  // this potentially causes it to create new SCEV objects, which
  // technically conflicts with the const qualifier. However, this isn't
  // observable from outside the class, so casting away the const isn't
  // dangerous.
  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);

  OS << "Classifying expressions for: ";
  WriteAsOperand(OS, F, /*PrintType=*/false);
  OS << "\n";
  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
    if (isSCEVable(I->getType())) {
      OS << *I << '\n';
      OS << "  -->  ";
      const SCEV *SV = SE.getSCEV(&*I);
      SV->print(OS);

      const Loop *L = LI->getLoopFor((*I).getParent());

      const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
      if (AtUse != SV) {
        OS << "  -->  ";
        AtUse->print(OS);
      }

      if (L) {
        OS << "\t\t" "Exits: ";
        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
        if (!ExitValue->isLoopInvariant(L)) {
          OS << "<<Unknown>>";
        } else {
          OS << *ExitValue;
        }
      }

      OS << "\n";
    }

  OS << "Determining loop execution counts for: ";
  WriteAsOperand(OS, F, /*PrintType=*/false);
  OS << "\n";
  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
    PrintLoopInfo(OS, &SE, *I);
}