ScalarEvolution.cpp revision 59846aced20b23882a97b92da5d653dc3f3e8526
//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
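//
// For example, the canonical induction variable of a loop such as
//   for (i = 0; i != n; ++i) { ... }
// is typically represented as the recurrence {0,+,1}<loop>: it starts at 0
// and advances by 1 on each iteration of the loop.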
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
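//
// For example, the folders fold constants together and sort operands into a
// canonical order, so that (a + b) and (b + a), or (2 + x) + 3 and (x + 5),
// end up as the same uniqued SCEV node.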
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "scalar-evolution"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/GlobalVariable.h"
#include "llvm/GlobalAlias.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Operator.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Assembly/Writer.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ConstantRange.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GetElementPtrTypeIterator.h"
#include "llvm/Support/InstIterator.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <algorithm>
using namespace llvm;

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant "
                                 "derived loop"),
                        cl::init(100));

INITIALIZE_PASS(ScalarEvolution, "scalar-evolution",
                "Scalar Evolution Analysis", false, true);
char ScalarEvolution::ID = 0;

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

SCEV::~SCEV() {}

void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isAllOnesValue();
  return false;
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}

bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

const Type *SCEVCouldNotCompute::getType() const {
  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  return 0;
}

bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

bool SCEVCouldNotCompute::hasOperand(const SCEV *) const {
  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

void SCEVCouldNotCompute::print(raw_ostream &OS) const {
  OS << "***COULDNOTCOMPUTE***";
}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) {
  const IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}

const Type *SCEVConstant::getType() const { return V->getType(); }

void SCEVConstant::print(raw_ostream &OS) const {
  WriteAsOperand(OS, V, false);
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
                           unsigned SCEVTy, const SCEV *op, const Type *ty)
  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}

bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return Op->dominates(BB, DT);
}

bool SCEVCastExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  return Op->properlyDominates(BB, DT);
}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
                                   const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scTruncate, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot truncate non-integer value!");
}

void SCEVTruncateExpr::print(raw_ostream &OS) const {
  OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot zero extend non-integer value!");
}

void SCEVZeroExtendExpr::print(raw_ostream &OS) const {
  OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scSignExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot sign extend non-integer value!");
}

void SCEVSignExtendExpr::print(raw_ostream &OS) const {
  OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

void SCEVCommutativeExpr::print(raw_ostream &OS) const {
  const char *OpStr = getOperationStr();
  OS << "(";
  for (op_iterator I = op_begin(), E = op_end(); I != E; ++I) {
    OS << **I;
    if (next(I) != E)
      OS << OpStr;
  }
  OS << ")";
}

bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    if (!getOperand(i)->dominates(BB, DT))
      return false;
  }
  return true;
}

bool SCEVNAryExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    if (!getOperand(i)->properlyDominates(BB, DT))
      return false;
  }
  return true;
}

bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
}

bool SCEVUDivExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  return LHS->properlyDominates(BB, DT) && RHS->properlyDominates(BB, DT);
}

void SCEVUDivExpr::print(raw_ostream &OS) const {
  OS << "(" << *LHS << " /u " << *RHS << ")";
}

const Type *SCEVUDivExpr::getType() const {
  // In most cases the types of LHS and RHS will be the same, but in some
  // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
  // depend on the type for correctness, but handling types carefully can
  // avoid extra casts in the SCEVExpander. The LHS is more likely to be
  // a pointer type than the RHS, so use the RHS' type here.
  return RHS->getType();
}

bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
  // Add recurrences are never invariant in the function-body (null loop).
  if (!QueryLoop)
    return false;

  // This recurrence is variant w.r.t. QueryLoop if QueryLoop contains L.
  if (QueryLoop->contains(L))
    return false;

  // This recurrence is variant w.r.t. QueryLoop if any of its operands
  // are variant.
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
    if (!getOperand(i)->isLoopInvariant(QueryLoop))
      return false;

  // Otherwise it's loop-invariant.
  return true;
}

bool
SCEVAddRecExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return DT->dominates(L->getHeader(), BB) &&
         SCEVNAryExpr::dominates(BB, DT);
}

bool
SCEVAddRecExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  // This uses a "dominates" query instead of "properly dominates" query because
  // the instruction which produces the addrec's value is a PHI, and a PHI
  // effectively properly dominates its entire containing block.
  return DT->dominates(L->getHeader(), BB) &&
         SCEVNAryExpr::properlyDominates(BB, DT);
}

void SCEVAddRecExpr::print(raw_ostream &OS) const {
  OS << "{" << *Operands[0];
  for (unsigned i = 1, e = NumOperands; i != e; ++i)
    OS << ",+," << *Operands[i];
  OS << "}<";
  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
  OS << ">";
}

bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
  // All non-instruction values are loop invariant.  All instructions are loop
  // invariant if they are not contained in the specified loop.
  // Instructions are never considered invariant in the function body
  // (null loop) because they are defined within the "loop".
  if (Instruction *I = dyn_cast<Instruction>(V))
    return L && !L->contains(I);
  return true;
}

bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
  if (Instruction *I = dyn_cast<Instruction>(getValue()))
    return DT->dominates(I->getParent(), BB);
  return true;
}

bool SCEVUnknown::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
  if (Instruction *I = dyn_cast<Instruction>(getValue()))
    return DT->properlyDominates(I->getParent(), BB);
  return true;
}

const Type *SCEVUnknown::getType() const {
  return V->getType();
}

bool SCEVUnknown::isSizeOf(const Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(V))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
                                 ->getElementType();
              return true;
            }

  return false;
}
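// (The pattern matched above is the standard sizeof idiom: ptrtoint of a
// gep of a null T* with index 1 yields the allocation size of T, because
// the gep computes the address one element past null.)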

bool SCEVUnknown::isAlignOf(const Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(V))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          const Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          if (const StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}
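// (The pattern matched above is the standard alignof idiom: in the struct
// {i1, T}, the offset of the second field is exactly the alignment of T,
// so a gep of a null pointer to that struct's second field yields
// alignof(T).)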

bool SCEVUnknown::isOffsetOf(const Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(V))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          const Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}

void SCEVUnknown::print(raw_ostream &OS) const {
  const Type *AllocTy;
  if (isSizeOf(AllocTy)) {
    OS << "sizeof(" << *AllocTy << ")";
    return;
  }
  if (isAlignOf(AllocTy)) {
    OS << "alignof(" << *AllocTy << ")";
    return;
  }

  const Type *CTy;
  Constant *FieldNo;
  if (isOffsetOf(CTy, FieldNo)) {
    OS << "offsetof(" << *CTy << ", ";
    WriteAsOperand(OS, FieldNo, false);
    OS << ")";
    return;
  }

  // Otherwise just print it normally.
  WriteAsOperand(OS, V, false);
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

static bool CompareTypes(const Type *A, const Type *B) {
  if (A->getTypeID() != B->getTypeID())
    return A->getTypeID() < B->getTypeID();
  if (const IntegerType *AI = dyn_cast<IntegerType>(A)) {
    const IntegerType *BI = cast<IntegerType>(B);
    return AI->getBitWidth() < BI->getBitWidth();
  }
  if (const PointerType *AI = dyn_cast<PointerType>(A)) {
    const PointerType *BI = cast<PointerType>(B);
    return CompareTypes(AI->getElementType(), BI->getElementType());
  }
  if (const ArrayType *AI = dyn_cast<ArrayType>(A)) {
    const ArrayType *BI = cast<ArrayType>(B);
    if (AI->getNumElements() != BI->getNumElements())
      return AI->getNumElements() < BI->getNumElements();
    return CompareTypes(AI->getElementType(), BI->getElementType());
  }
  if (const VectorType *AI = dyn_cast<VectorType>(A)) {
    const VectorType *BI = cast<VectorType>(B);
    if (AI->getNumElements() != BI->getNumElements())
      return AI->getNumElements() < BI->getNumElements();
    return CompareTypes(AI->getElementType(), BI->getElementType());
  }
  if (const StructType *AI = dyn_cast<StructType>(A)) {
    const StructType *BI = cast<StructType>(B);
    if (AI->getNumElements() != BI->getNumElements())
      return AI->getNumElements() < BI->getNumElements();
    for (unsigned i = 0, e = AI->getNumElements(); i != e; ++i)
      if (CompareTypes(AI->getElementType(i), BI->getElementType(i)) ||
          CompareTypes(BI->getElementType(i), AI->getElementType(i)))
        return CompareTypes(AI->getElementType(i), BI->getElementType(i));
  }
  return false;
}

namespace {
  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
  /// than the complexity of the RHS.  This comparator is used to canonicalize
  /// expressions.
  class SCEVComplexityCompare {
    const LoopInfo *LI;
  public:
    explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}

    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
      // Fast-path: SCEVs are uniqued so we can do a quick equality check.
      if (LHS == RHS)
        return false;

      // Primarily, sort the SCEVs by their getSCEVType().
      unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
      if (LType != RType)
        return LType < RType;

      // Aside from the getSCEVType() ordering, the particular ordering
      // isn't very important except that it's beneficial to be consistent,
      // so that (a + b) and (b + a) don't end up as different expressions.

      // Sort SCEVUnknown values with some loose heuristics. TODO: This is
      // not as complete as it could be.
      if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) {
        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

        // Order pointer values after integer values. This helps SCEVExpander
        // form GEPs.
        bool LIsPointer = LU->getType()->isPointerTy(),
             RIsPointer = RU->getType()->isPointerTy();
        if (LIsPointer != RIsPointer)
          return RIsPointer;

        // Compare getValueID values.
        unsigned LID = LU->getValue()->getValueID(),
                 RID = RU->getValue()->getValueID();
        if (LID != RID)
          return LID < RID;

        // Sort arguments by their position.
        if (const Argument *LA = dyn_cast<Argument>(LU->getValue())) {
          const Argument *RA = cast<Argument>(RU->getValue());
          return LA->getArgNo() < RA->getArgNo();
        }

        // For instructions, compare their loop depth, and their opcode.
        // This is pretty loose.
        if (const Instruction *LV = dyn_cast<Instruction>(LU->getValue())) {
          const Instruction *RV = cast<Instruction>(RU->getValue());

          // Compare loop depths.
          unsigned LDepth = LI->getLoopDepth(LV->getParent()),
                   RDepth = LI->getLoopDepth(RV->getParent());
          if (LDepth != RDepth)
            return LDepth < RDepth;

          // Compare the number of operands.
          unsigned LNumOps = LV->getNumOperands(),
                   RNumOps = RV->getNumOperands();
          if (LNumOps != RNumOps)
            return LNumOps < RNumOps;
        }

        return false;
      }

      // Compare constant values.
      if (const SCEVConstant *LC = dyn_cast<SCEVConstant>(LHS)) {
        const SCEVConstant *RC = cast<SCEVConstant>(RHS);
        const ConstantInt *LCC = LC->getValue();
        const ConstantInt *RCC = RC->getValue();
        unsigned LBitWidth = LCC->getBitWidth(), RBitWidth = RCC->getBitWidth();
        if (LBitWidth != RBitWidth)
          return LBitWidth < RBitWidth;
        return LCC->getValue().ult(RCC->getValue());
      }

      // Compare addrec loop depths.
      if (const SCEVAddRecExpr *LA = dyn_cast<SCEVAddRecExpr>(LHS)) {
        const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
        unsigned LDepth = LA->getLoop()->getLoopDepth(),
                 RDepth = RA->getLoop()->getLoopDepth();
        if (LDepth != RDepth)
          return LDepth < RDepth;
      }

      // Lexicographically compare n-ary expressions.
      if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) {
        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
        unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
        for (unsigned i = 0; i != LNumOps; ++i) {
          if (i >= RNumOps)
            return false;
          const SCEV *LOp = LC->getOperand(i), *ROp = RC->getOperand(i);
          if (operator()(LOp, ROp))
            return true;
          if (operator()(ROp, LOp))
            return false;
        }
        return LNumOps < RNumOps;
      }

      // Lexicographically compare udiv expressions.
      if (const SCEVUDivExpr *LC = dyn_cast<SCEVUDivExpr>(LHS)) {
        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
        const SCEV *LL = LC->getLHS(), *LR = LC->getRHS(),
                   *RL = RC->getLHS(), *RR = RC->getRHS();
        if (operator()(LL, RL))
          return true;
        if (operator()(RL, LL))
          return false;
        if (operator()(LR, RR))
          return true;
        if (operator()(RR, LR))
          return false;
        return false;
      }

      // Compare cast expressions by operand.
      if (const SCEVCastExpr *LC = dyn_cast<SCEVCastExpr>(LHS)) {
        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
        return operator()(LC->getOperand(), RC->getOperand());
      }

      llvm_unreachable("Unknown SCEV kind!");
      return false;
    }
  };
}

/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
/// When this routine is finished, we know that any duplicates in the vector are
/// consecutive and that complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
///
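/// For example, a list such as (%x, 2, %y, %x) is regrouped so the constant
/// comes first and the duplicate %x operands become adjacent, e.g. as
/// (2, %x, %x, %y); the relative order of distinct unknowns is decided by
/// the heuristics in SCEVComplexityCompare above.
///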
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI) {
  if (Ops.size() < 2) return;  // Noop
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    if (SCEVComplexityCompare(LI)(Ops[1], Ops[0]))
      std::swap(Ops[0], Ops[1]);
    return;
  }

  // Do the rough sort by complexity.
  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}



//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
/// Assumes K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       const Type* ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula needs less than W + K bits. Also, the first formula requires
  // a division step, whereas this formula only requires multiplies and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
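  //
  // As a small worked example, take K = 3: K! = 6 = 2^1 * 3, so T = 1 and
  // K! / 2^T = 3 (odd, as required).  BC(It, 3) is then computed as
  // ((It * (It-1) * (It-2)) /u 2) * (3^-1 mod 2^W), with the product
  // evaluated at width W + 1 so the bottom W bits survive the shift intact.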

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult = Mult.lshr(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt(CalculationBits, 1).shl(T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  const IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                      CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// evaluateAtIteration - Return the value of this chain of recurrences at
/// the specified iteration number.  We can evaluate this recurrence by
/// multiplying each element in the chain by the binomial coefficient
/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
///
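/// For example, the quadratic recurrence {0,+,1,+,2} evaluates at iteration
/// It to 0*BC(It,0) + 1*BC(It,1) + 2*BC(It,2) = It + It*(It-1) = It*It.
///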
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  const SCEV *Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
                                             const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(),
                                               getEffectiveSCEVType(Ty))));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
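  // For example, trunc(sext i8 %x to i64) to i32 folds to sext i8 %x to i32
  // (widening relative to i8), while trunc(sext i32 %x to i64) to i16 folds
  // to trunc i32 %x to i16 (narrowing); the zext case folds the same way.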

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
    return getAddRecExpr(Operands, AddRec->getLoop());
  }

  // As a special case, fold trunc(undef) to undef. We don't want to
  // know too much about SCEVUnknowns, but this special case is handy
  // and harmless.
  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
    if (isa<UndefValue>(U->getValue()))
      return getSCEV(UndefValue::get(Ty));

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
                                               const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(),
                                              getEffectiveSCEVType(Ty))));

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoUnsignedWrap())
        return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                             getZeroExtendExpr(Step, Ty),
                             L);

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
          const SCEV *Add = getAddExpr(Start, ZMul);
          const SCEV *OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getZeroExtendExpr(Step, WideTy)));
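          // If the two computations agree, the original add never wrapped:
          // arithmetic on the zero-extended operands at twice the bit width
          // is exact, so equality here shows the zero extension can be
          // pushed through the addrec safely.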
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 L);

          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
          Add = getAddExpr(Start, SMul);
          OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);
        }

        // If the backedge is guarded by a comparison with the pre-inc value
        // the addrec is safe. Also, if the entry is guarded by a comparison
        // with the start value and the backedge is guarded by a comparison
        // with the post-inc value, the addrec is safe.
        if (isKnownPositive(Step)) {
          const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                      getUnsignedRange(Step).getUnsignedMax());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
                                           AR->getPostIncExpr(*this), N)))
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 L);
        } else if (isKnownNegative(Step)) {
          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
                                      getSignedRange(Step).getSignedMin());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
                                           AR->getPostIncExpr(*this), N)))
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);
        }
      }
    }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
                                               const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(),
                                              getEffectiveSCEVType(Ty))));

  // sext(sext(x)) --> sext(x)
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getSignExtendExpr(SS->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scSignExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can sign extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoSignedWrap())
        return getAddRecExpr(getSignExtendExpr(Start, Ty),
                             getSignExtendExpr(Step, Ty),
                             L);

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
          const SCEV *Add = getAddExpr(Start, SMul);
          const SCEV *OperandExtendedAdd =
            getAddExpr(getSignExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);

          // Similar to above, only this time treat the step value as unsigned.
          // This covers loops that count up with an unsigned step.
          const SCEV *UMul = getMulExpr(CastedMaxBECount, Step);
          Add = getAddExpr(Start, UMul);
          OperandExtendedAdd =
            getAddExpr(getSignExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getZeroExtendExpr(Step, WideTy)));
          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 L);
        }

        // If the backedge is guarded by a comparison with the pre-inc value
        // the addrec is safe. Also, if the entry is guarded by a comparison
        // with the start value and the backedge is guarded by a comparison
        // with the post-inc value, the addrec is safe.
        if (isKnownPositive(Step)) {
          const SCEV *N = getConstant(APInt::getSignedMinValue(BitWidth) -
                                      getSignedRange(Step).getSignedMax());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, AR, N) ||
              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT,
                                           AR->getPostIncExpr(*this), N)))
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);
        } else if (isKnownNegative(Step)) {
          const SCEV *N = getConstant(APInt::getSignedMaxValue(BitWidth) -
                                      getSignedRange(Step).getSignedMin());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, AR, N) ||
              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT,
                                           AR->getPostIncExpr(*this), N)))
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L);
        }
      }
    }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

/// getAnyExtendExpr - Return a SCEV for the given operand extended with
/// unspecified bits out to the given type.
///
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                              const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Sign-extend negative constants.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    if (SC->getValue()->getValue().isNegative())
      return getSignExtendExpr(Op, Ty);

  // Peel off a truncate cast.
  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
    const SCEV *NewOp = T->getOperand();
    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
      return getAnyExtendExpr(NewOp, Ty);
    return getTruncateOrNoop(NewOp, Ty);
  }

  // Next try a zext cast. If the cast is folded, use it.
  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;

  // Next try a sext cast. If the cast is folded, use it.
  const SCEV *SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;

  // Force the cast to be folded into the operands of an addrec.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Ops;
    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
         I != E; ++I)
      Ops.push_back(getAnyExtendExpr(*I, Ty));
    return getAddRecExpr(Ops, AR->getLoop());
  }

  // As a special case, fold anyext(undef) to undef. We don't want to
  // know too much about SCEVUnknowns, but this special case is handy
  // and harmless.
  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
    if (isa<UndefValue>(U->getValue()))
      return getSCEV(UndefValue::get(Ty));

  // If the expression is obviously signed, use the sext cast value.
  if (isa<SCEVSMaxExpr>(Op))
    return SExt;

  // Absent any other information, use the zext cast value.
  return ZExt;
}

/// CollectAddOperandsWithScales - Process the given Ops list, which is
/// a list of operands to be added under the given scale, update the given
/// map. This is a helper function for getAddExpr. As an example of
/// what it does, given a sequence of operands that would form an add
/// expression like this:
///
///    m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r)
///
/// where A and B are constants, update the map with these values:
///
///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
///
/// and add 13 + A*B*29 to AccumulatedConstant.
/// This will allow getAddExpr to produce this:
///
///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
///
/// This form often exposes folding opportunities that are hidden in
/// the original operand list.
///
/// Return true iff it appears that any interesting folding opportunities
/// may be exposed. This helps getAddExpr short-circuit extra work in
1220/// the common case where no interesting opportunities are present, and
1221/// is also used as a check to avoid infinite recursion.
1222///
1223static bool
1224CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
1225                             SmallVector<const SCEV *, 8> &NewOps,
1226                             APInt &AccumulatedConstant,
1227                             const SCEV *const *Ops, size_t NumOperands,
1228                             const APInt &Scale,
1229                             ScalarEvolution &SE) {
1230  bool Interesting = false;
1231
1232  // Iterate over the add operands. They are sorted, with constants first.
1233  unsigned i = 0;
1234  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1235    ++i;
1236    // Pull a buried constant out to the outside.
1237    if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
1238      Interesting = true;
1239    AccumulatedConstant += Scale * C->getValue()->getValue();
1240  }
1241
1242  // Next comes everything else. We're especially interested in multiplies
1243  // here, but they're in the middle, so just visit the rest with one loop.
1244  for (; i != NumOperands; ++i) {
1245    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
1246    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
1247      APInt NewScale =
1248        Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
1249      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
1250        // A multiplication of a constant with another add; recurse.
1251        const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
1252        Interesting |=
1253          CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1254                                       Add->op_begin(), Add->getNumOperands(),
1255                                       NewScale, SE);
1256      } else {
1257        // A multiplication of a constant with some other value. Update
1258        // the map.
1259        SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
1260        const SCEV *Key = SE.getMulExpr(MulOps);
1261        std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1262          M.insert(std::make_pair(Key, NewScale));
1263        if (Pair.second) {
1264          NewOps.push_back(Pair.first->first);
1265        } else {
1266          Pair.first->second += NewScale;
1267          // The map already had an entry for this value, which may indicate
1268          // a folding opportunity.
1269          Interesting = true;
1270        }
1271      }
1272    } else {
1273      // An ordinary operand. Update the map.
1274      std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1275        M.insert(std::make_pair(Ops[i], Scale));
1276      if (Pair.second) {
1277        NewOps.push_back(Pair.first->first);
1278      } else {
1279        Pair.first->second += Scale;
1280        // The map already had an entry for this value, which may indicate
1281        // a folding opportunity.
1282        Interesting = true;
1283      }
1284    }
1285  }
1286
1287  return Interesting;
1288}
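
// Worked instance of the example above, assuming A = 2 and B = 3: the operand
// list m + n + 13 + (2 * (o + p + (3 * (q + m + 29)))) + r + (-1 * r) yields
// the map (m, 7), (n, 1), (o, 2), (p, 2), (q, 6), (r, 0) and
// AccumulatedConstant = 13 + 6*29 = 187, so the sum can be re-emitted as
// 187 + n + (m * 7) + ((o + p) * 2) + (q * 6).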

namespace {
  struct APIntCompare {
    bool operator()(const APInt &LHS, const APInt &RHS) const {
      return LHS.ult(RHS);
    }
  };
}

/// getAddExpr - Get a canonical add expression, or something simpler if
/// possible.
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        bool HasNUW, bool HasNSW) {
  assert(!Ops.empty() && "Cannot get empty add!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  const Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "SCEVAddExpr operand types don't match!");
#endif

  // If HasNSW is true and all the operands are non-negative, infer HasNUW.
  if (!HasNUW && HasNSW) {
    bool All = true;
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (!isKnownNonNegative(Ops[i])) {
        All = false;
        break;
      }
    if (All) HasNUW = true;
  }
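  // This inference is sound because a signed-no-wrap sum of non-negative
  // values stays below 2^(BitWidth-1); e.g. in i8 the largest such sum is
  // 127, so the addition cannot wrap around the unsigned range either.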

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      Ops[0] = getConstant(LHSC->getValue()->getValue() +
                           RHSC->getValue()->getValue());
      if (Ops.size() == 2) return Ops[0];
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant zero being added, strip it off.
    if (LHSC->getValue()->isZero()) {
      Ops.erase(Ops.begin());
      --Idx;
    }

    if (Ops.size() == 1) return Ops[0];
  }

  // Okay, check to see if the same value occurs in the operand list twice.  If
  // so, merge them together into a multiply expression.  Since we sorted the
  // list, these values are required to be adjacent.
  const Type *Ty = Ops[0]->getType();
  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
      // Found a match, merge the two values into a multiply, and add any
      // remaining values to the result.
      const SCEV *Two = getConstant(Ty, 2);
      const SCEV *Mul = getMulExpr(Ops[i], Two);
      if (Ops.size() == 2)
        return Mul;
      Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
      Ops.push_back(Mul);
      return getAddExpr(Ops, HasNUW, HasNSW);
    }

  // Check for truncates. If all the operands are truncated from the same
  // type, see if factoring out the truncate would permit the result to be
  // folded. e.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
  // if the contents of the resulting outer trunc fold to something simple.
  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
    const Type *DstType = Trunc->getType();
    const Type *SrcType = Trunc->getOperand()->getType();
    SmallVector<const SCEV *, 8> LargeOps;
    bool Ok = true;
    // Check all the operands to see if they can be represented in the
    // source type of the truncate.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
        if (T->getOperand()->getType() != SrcType) {
          Ok = false;
          break;
        }
        LargeOps.push_back(T->getOperand());
      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
        LargeOps.push_back(getAnyExtendExpr(C, SrcType));
      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
        SmallVector<const SCEV *, 8> LargeMulOps;
        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
          if (const SCEVTruncateExpr *T =
                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
            if (T->getOperand()->getType() != SrcType) {
              Ok = false;
              break;
            }
            LargeMulOps.push_back(T->getOperand());
          } else if (const SCEVConstant *C =
                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
            LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
          } else {
            Ok = false;
            break;
          }
        }
        if (Ok)
          LargeOps.push_back(getMulExpr(LargeMulOps));
      } else {
        Ok = false;
        break;
      }
    }
    if (Ok) {
      // Evaluate the expression in the larger type.
      const SCEV *Fold = getAddExpr(LargeOps, HasNUW, HasNSW);
      // If it folds to something simple, use it. Otherwise, don't.
      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
        return getTruncateExpr(Fold, DstType);
    }
  }
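  // e.g., if the i32 operands are trunc(x) and trunc(42 - x) for some i64 x,
  // the sum is evaluated in i64 as x + (42 - x) = 42, so the whole expression
  // folds to the i32 constant 42.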

  // Skip past any other cast SCEVs.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
    ++Idx;

  // If there are add operands they would be next.
  if (Idx < Ops.size()) {
    bool DeletedAdd = false;
    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
      // If we have an add, expand the add operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      Ops.append(Add->op_begin(), Add->op_end());
      DeletedAdd = true;
    }

    // If we deleted at least one add, we added operands to the end of the
    // list, and they are not necessarily sorted.  Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedAdd)
      return getAddExpr(Ops);
  }

  // Skip over the add expressions until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // Check to see if there are any folding opportunities present with
  // operands multiplied by constant values.
  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
    uint64_t BitWidth = getTypeSizeInBits(Ty);
    DenseMap<const SCEV *, APInt> M;
    SmallVector<const SCEV *, 8> NewOps;
    APInt AccumulatedConstant(BitWidth, 0);
    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                     Ops.data(), Ops.size(),
                                     APInt(BitWidth, 1), *this)) {
      // Some interesting folding opportunity is present, so it's worthwhile to
      // re-generate the operands list. Group the operands by constant scale,
      // to avoid multiplying by the same constant scale multiple times.
      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
      for (SmallVector<const SCEV *, 8>::iterator I = NewOps.begin(),
           E = NewOps.end(); I != E; ++I)
        MulOpLists[M.find(*I)->second].push_back(*I);
      // Re-generate the operands list.
      Ops.clear();
      if (AccumulatedConstant != 0)
        Ops.push_back(getConstant(AccumulatedConstant));
      for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
           I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
        if (I->first != 0)
          Ops.push_back(getMulExpr(getConstant(I->first),
                                   getAddExpr(I->second)));
      if (Ops.empty())
        return getConstant(Ty, 0);
      if (Ops.size() == 1)
        return Ops[0];
      return getAddExpr(Ops);
    }
  }

  // If we are adding something to a multiply expression, make sure the
  // something is not already an operand of the multiply.  If so, merge it into
  // the multiply.
  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
        if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(Ops[AddOp])) {
          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
          // Start with the other mul operand (operand 1 if MulOp is 0,
          // operand 0 otherwise).
          const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
          if (Mul->getNumOperands() != 2) {
            // If the multiply has more than two operands, we must get the
            // Y*Z term.
            SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), Mul->op_end());
            MulOps.erase(MulOps.begin()+MulOp);
            InnerMul = getMulExpr(MulOps);
          }
          const SCEV *One = getConstant(Ty, 1);
          const SCEV *AddOne = getAddExpr(InnerMul, One);
          const SCEV *OuterMul = getMulExpr(AddOne, Ops[AddOp]);
          if (Ops.size() == 2) return OuterMul;
          if (AddOp < Idx) {
            Ops.erase(Ops.begin()+AddOp);
            Ops.erase(Ops.begin()+Idx-1);
          } else {
            Ops.erase(Ops.begin()+Idx);
            Ops.erase(Ops.begin()+AddOp-1);
          }
          Ops.push_back(OuterMul);
          return getAddExpr(Ops);
        }

      // Check this multiply against other multiplies being added together.
      for (unsigned OtherMulIdx = Idx+1;
           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
           ++OtherMulIdx) {
        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
        // If MulOp occurs in OtherMul, we can fold the two multiplies
        // together.
        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
             OMulOp != e; ++OMulOp)
          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
            const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
            if (Mul->getNumOperands() != 2) {
              SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
                                                  Mul->op_end());
              MulOps.erase(MulOps.begin()+MulOp);
              InnerMul1 = getMulExpr(MulOps);
            }
            const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
            if (OtherMul->getNumOperands() != 2) {
              SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
                                                  OtherMul->op_end());
              MulOps.erase(MulOps.begin()+OMulOp);
              InnerMul2 = getMulExpr(MulOps);
            }
            const SCEV *InnerMulSum = getAddExpr(InnerMul1, InnerMul2);
            const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
            if (Ops.size() == 2) return OuterMul;
            Ops.erase(Ops.begin()+Idx);
            Ops.erase(Ops.begin()+OtherMulIdx-1);
            Ops.push_back(OuterMul);
            return getAddExpr(Ops);
          }
      }
    }
  }
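  // Both folds above in concrete form: w + x + (x * y) becomes
  // w + (x * (y + 1)), and x + (a * b) + (a * c) becomes x + (a * (b + c)).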

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant.  If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this add and add them to the vector if
    // they are loop invariant w.r.t. the recurrence.
    SmallVector<const SCEV *, 8> LIOps;
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    const Loop *AddRecLoop = AddRec->getLoop();
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (Ops[i]->isLoopInvariant(AddRecLoop)) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
      LIOps.push_back(AddRec->getStart());

      SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
                                             AddRec->op_end());
      AddRecOps[0] = getAddExpr(LIOps);

      // Build the new addrec. Propagate the NUW and NSW flags if both the
      // outer add and the inner addrec are guaranteed to have no overflow.
      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop,
                                         HasNUW && AddRec->hasNoUnsignedWrap(),
                                         HasNSW && AddRec->hasNoSignedWrap());

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, add the folded AddRec to the non-loop-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getAddExpr(Ops);
    }
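    // For example, with n invariant in the loop L of {0,+,1}<L>, the sum
    // n + {0,+,1}<L> folds to {n,+,1}<L>: the invariant addend is absorbed
    // into the recurrence start.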

    // Okay, if there weren't any loop invariants to be folded, check to see if
    // there are multiple AddRec's with the same loop induction variable being
    // added together.  If so, we can fold them.
    for (unsigned OtherIdx = Idx+1;
         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
      if (OtherIdx != Idx) {
        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
        if (AddRecLoop == OtherAddRec->getLoop()) {
          // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
          SmallVector<const SCEV *, 4> NewOps(AddRec->op_begin(),
                                              AddRec->op_end());
          for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
            if (i >= NewOps.size()) {
              NewOps.append(OtherAddRec->op_begin()+i,
                            OtherAddRec->op_end());
              break;
            }
            NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
          }
          const SCEV *NewAddRec = getAddRecExpr(NewOps, AddRecLoop);

          if (Ops.size() == 2) return NewAddRec;

          Ops.erase(Ops.begin()+Idx);
          Ops.erase(Ops.begin()+OtherIdx-1);
          Ops.push_back(NewAddRec);
          return getAddExpr(Ops);
        }
      }

    // Otherwise we couldn't fold anything into this recurrence.  Move on to
    // the next one.
  }

  // Okay, it looks like we really DO need an add expr.  Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(scAddExpr);
  ID.AddInteger(Ops.size());
  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
    ID.AddPointer(Ops[i]);
  void *IP = 0;
  SCEVAddExpr *S =
    static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
    S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
  }
  if (HasNUW) S->setHasNoUnsignedWrap(true);
  if (HasNSW) S->setHasNoSignedWrap(true);
  return S;
}
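
// A minimal usage sketch (hypothetical client code, with SE being this
// analysis and X, Y two SCEVs of the same effective type):
//   SmallVector<const SCEV *, 2> Sum;
//   Sum.push_back(X);
//   Sum.push_back(Y);
//   const SCEV *S = SE.getAddExpr(Sum);
// or simply SE.getAddExpr(X, Y); either form returns the canonical,
// fully folded sum.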

/// getMulExpr - Get a canonical multiply expression, or something simpler if
/// possible.
const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                                        bool HasNUW, bool HasNSW) {
  assert(!Ops.empty() && "Cannot get empty mul!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
           getEffectiveSCEVType(Ops[0]->getType()) &&
           "SCEVMulExpr operand types don't match!");
#endif

  // If HasNSW is true and all the operands are non-negative, infer HasNUW.
  if (!HasNUW && HasNSW) {
    bool All = true;
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (!isKnownNonNegative(Ops[i])) {
        All = false;
        break;
      }
    if (All) HasNUW = true;
  }

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {

    // C1*(C2+V) -> C1*C2 + C1*V
    if (Ops.size() == 2)
      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
        if (Add->getNumOperands() == 2 &&
            isa<SCEVConstant>(Add->getOperand(0)))
          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
                            getMulExpr(LHSC, Add->getOperand(1)));
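    // For instance, 2 * (3 + x) is rewritten by the fold above to 6 + 2 * x,
    // exposing the constant term to any enclosing add expression.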

    ++Idx;
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      ConstantInt *Fold = ConstantInt::get(getContext(),
                                           LHSC->getValue()->getValue() *
                                           RHSC->getValue()->getValue());
      Ops[0] = getConstant(Fold);
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      if (Ops.size() == 1) return Ops[0];
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant one being multiplied, strip it off.
    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
      Ops.erase(Ops.begin());
      --Idx;
    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
      // If we have a multiply of zero, it will always be zero.
      return Ops[0];
    } else if (Ops[0]->isAllOnesValue()) {
      // If we have a mul by -1 of an add, try distributing the -1 among the
      // add operands.
      if (Ops.size() == 2)
        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
          SmallVector<const SCEV *, 4> NewOps;
          bool AnyFolded = false;
          for (SCEVAddExpr::op_iterator I = Add->op_begin(), E = Add->op_end();
               I != E; ++I) {
            const SCEV *Mul = getMulExpr(Ops[0], *I);
            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
            NewOps.push_back(Mul);
          }
          if (AnyFolded)
            return getAddExpr(NewOps);
        }
    }

    if (Ops.size() == 1)
      return Ops[0];
  }

  // Skip over the add expressions until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // If there are mul operands, inline them all into this expression.
  if (Idx < Ops.size()) {
    bool DeletedMul = false;
    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      // If we have a mul, expand the mul operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      Ops.append(Mul->op_begin(), Mul->op_end());
      DeletedMul = true;
    }

    // If we deleted at least one mul, we added operands to the end of the
    // list, and they are not necessarily sorted.  Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedMul)
      return getMulExpr(Ops);
  }

  // If there are any add recurrences in the operands list, see if any other
  // multiplied values are loop invariant.  If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this mul and add them to the vector if
    // they are loop invariant w.r.t. the recurrence.
    SmallVector<const SCEV *, 8> LIOps;
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
      SmallVector<const SCEV *, 4> NewOps;
      NewOps.reserve(AddRec->getNumOperands());
      const SCEV *Scale = getMulExpr(LIOps);
      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));

      // Build the new addrec. Propagate the NUW and NSW flags if both the
      // outer mul and the inner addrec are guaranteed to have no overflow.
      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(),
                                         HasNUW && AddRec->hasNoUnsignedWrap(),
                                         HasNSW && AddRec->hasNoSignedWrap());

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, multiply the folded AddRec by the non-loop-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getMulExpr(Ops);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see if
    // there are multiple AddRec's with the same loop induction variable being
    // multiplied together.  If so, we can fold them.
    for (unsigned OtherIdx = Idx+1;
         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
      if (OtherIdx != Idx) {
        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
          // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
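          // (Derivation sketch for the affine case: with F(n) = A + B*n and
          // G(n) = C + D*n, F(n+1)*G(n+1) - F(n)*G(n) = F(n)*D + G(n)*B + B*D,
          // which is exactly the step recurrence constructed below.)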
          const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
          const SCEV *NewStart = getMulExpr(F->getStart(), G->getStart());
          const SCEV *B = F->getStepRecurrence(*this);
          const SCEV *D = G->getStepRecurrence(*this);
          const SCEV *NewStep = getAddExpr(getMulExpr(F, D),
                                           getMulExpr(G, B),
                                           getMulExpr(B, D));
          const SCEV *NewAddRec = getAddRecExpr(NewStart, NewStep,
                                                F->getLoop());
          if (Ops.size() == 2) return NewAddRec;

          Ops.erase(Ops.begin()+Idx);
          Ops.erase(Ops.begin()+OtherIdx-1);
          Ops.push_back(NewAddRec);
          return getMulExpr(Ops);
        }
      }

    // Otherwise we couldn't fold anything into this recurrence.  Move on to
    // the next one.
  }

  // Okay, it looks like we really DO need a mul expr.  Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(scMulExpr);
  ID.AddInteger(Ops.size());
  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
    ID.AddPointer(Ops[i]);
  void *IP = 0;
  SCEVMulExpr *S =
    static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
  }
  if (HasNUW) S->setHasNoUnsignedWrap(true);
  if (HasNSW) S->setHasNoSignedWrap(true);
  return S;
}

/// getUDivExpr - Get a canonical unsigned division expression, or something
/// simpler if possible.
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
  assert(getEffectiveSCEVType(LHS->getType()) ==
         getEffectiveSCEVType(RHS->getType()) &&
         "SCEVUDivExpr operand types don't match!");

  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
    if (RHSC->getValue()->equalsInt(1))
      return LHS;                               // X udiv 1 --> X
    // If the denominator is zero, the result of the udiv is undefined. Don't
    // try to analyze it, because the resolution chosen here may differ from
    // the resolution chosen in other parts of the compiler.
    if (!RHSC->getValue()->isZero()) {
      // Determine if the division can be folded into the operands of
      // the dividend.
      // TODO: Generalize this to non-constants by using known-bits information.
      const Type *Ty = LHS->getType();
      unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
      unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
      // For non-power-of-two values, effectively round the value up to the
      // nearest power of two.
      if (!RHSC->getValue()->getValue().isPowerOf2())
        ++MaxShiftAmt;
      const IntegerType *ExtTy =
        IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
      // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
        if (const SCEVConstant *Step =
              dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
          if (!Step->getValue()->getValue()
                .urem(RHSC->getValue()->getValue()) &&
              getZeroExtendExpr(AR, ExtTy) ==
              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
                            getZeroExtendExpr(Step, ExtTy),
                            AR->getLoop())) {
            SmallVector<const SCEV *, 4> Operands;
            for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
              Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
            return getAddRecExpr(Operands, AR->getLoop());
          }
      // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
      if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
        SmallVector<const SCEV *, 4> Operands;
        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
          Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
        if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
          // Find an operand that's safely divisible.
          for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
            const SCEV *Op = M->getOperand(i);
            const SCEV *Div = getUDivExpr(Op, RHSC);
            if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
              Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
                                                      M->op_end());
              Operands[i] = Div;
              return getMulExpr(Operands);
            }
          }
      }
      // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
      if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
        SmallVector<const SCEV *, 4> Operands;
        for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
          Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
        if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
          Operands.clear();
          for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
            const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
            if (isa<SCEVUDivExpr>(Op) ||
                getMulExpr(Op, RHS) != A->getOperand(i))
              break;
            Operands.push_back(Op);
          }
          if (Operands.size() == A->getNumOperands())
            return getAddExpr(Operands);
        }
      }

      // Fold if both operands are constant.
      if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
        Constant *LHSCV = LHSC->getValue();
        Constant *RHSCV = RHSC->getValue();
        return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
                                                                   RHSCV)));
      }
    }
  }

  FoldingSetNodeID ID;
  ID.AddInteger(scUDivExpr);
  ID.AddPointer(LHS);
  ID.AddPointer(RHS);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                             LHS, RHS);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}
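
// Illustrative instance of the {X,+,N}/C rule above: {x,+,4}<L> udiv 2
// becomes {x/2,+,2}<L>, provided the zero-extension comparison shows that
// the original recurrence never wraps.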


/// getAddRecExpr - Get an add recurrence expression for the specified loop.
/// Simplify the expression as much as possible.
const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start,
                                           const SCEV *Step, const Loop *L,
                                           bool HasNUW, bool HasNSW) {
  SmallVector<const SCEV *, 4> Operands;
  Operands.push_back(Start);
  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
    if (StepChrec->getLoop() == L) {
      Operands.append(StepChrec->op_begin(), StepChrec->op_end());
      return getAddRecExpr(Operands, L);
    }

  Operands.push_back(Step);
  return getAddRecExpr(Operands, L, HasNUW, HasNSW);
}
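
// e.g. getAddRecExpr(X, {Y,+,Z}<L>, L) flattens the step into a single
// higher-order recurrence {X,+,Y,+,Z}<L> rather than nesting one addrec as
// the step of another.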

/// getAddRecExpr - Get an add recurrence expression for the specified loop.
/// Simplify the expression as much as possible.
const SCEV *
ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
                               const Loop *L,
                               bool HasNUW, bool HasNSW) {
  if (Operands.size() == 1) return Operands[0];
#ifndef NDEBUG
  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Operands[i]->getType()) ==
           getEffectiveSCEVType(Operands[0]->getType()) &&
           "SCEVAddRecExpr operand types don't match!");
#endif

  if (Operands.back()->isZero()) {
    Operands.pop_back();
    return getAddRecExpr(Operands, L, HasNUW, HasNSW); // {X,+,0}  -->  X
  }

  // It's tempting to want to call getMaxBackedgeTakenCount here and
  // use that information to infer NUW and NSW flags. However, computing a
  // BE count requires calling getAddRecExpr, so we may not yet have a
  // meaningful BE count at this point (and if we don't, we'd be stuck
  // with a SCEVCouldNotCompute as the cached BE count).

  // If HasNSW is true and all the operands are non-negative, infer HasNUW.
  if (!HasNUW && HasNSW) {
    bool All = true;
    for (unsigned i = 0, e = Operands.size(); i != e; ++i)
      if (!isKnownNonNegative(Operands[i])) {
        All = false;
        break;
      }
    if (All) HasNUW = true;
  }

  // Canonicalize nested AddRecs by nesting them in order of loop depth.
  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
    const Loop *NestedLoop = NestedAR->getLoop();
    if (L->contains(NestedLoop->getHeader()) ?
        (L->getLoopDepth() < NestedLoop->getLoopDepth()) :
        (!NestedLoop->contains(L->getHeader()) &&
         DT->dominates(L->getHeader(), NestedLoop->getHeader()))) {
      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
                                                  NestedAR->op_end());
      Operands[0] = NestedAR->getStart();
      // AddRecs require their operands be loop-invariant with respect to their
      // loops. Don't perform this transformation if it would break this
      // requirement.
      bool AllInvariant = true;
      for (unsigned i = 0, e = Operands.size(); i != e; ++i)
        if (!Operands[i]->isLoopInvariant(L)) {
          AllInvariant = false;
          break;
        }
      if (AllInvariant) {
        NestedOperands[0] = getAddRecExpr(Operands, L);
        AllInvariant = true;
        for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
          if (!NestedOperands[i]->isLoopInvariant(NestedLoop)) {
            AllInvariant = false;
            break;
          }
        if (AllInvariant)
          // Ok, both add recurrences are valid after the transformation.
          return getAddRecExpr(NestedOperands, NestedLoop, HasNUW, HasNSW);
      }
      // Reset Operands to its original state.
      Operands[0] = NestedAR;
    }
  }

  // Okay, it looks like we really DO need an addrec expr.  Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(scAddRecExpr);
  ID.AddInteger(Operands.size());
  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
    ID.AddPointer(Operands[i]);
  ID.AddPointer(L);
  void *IP = 0;
  SCEVAddRecExpr *S =
    static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
    std::uninitialized_copy(Operands.begin(), Operands.end(), O);
    S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
                                           O, Operands.size(), L);
    UniqueSCEVs.InsertNode(S, IP);
  }
  if (HasNUW) S->setHasNoUnsignedWrap(true);
  if (HasNSW) S->setHasNoSignedWrap(true);
  return S;
}

const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
  SmallVector<const SCEV *, 2> Ops;
  Ops.push_back(LHS);
  Ops.push_back(RHS);
  return getSMaxExpr(Ops);
}

const SCEV *
ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
  assert(!Ops.empty() && "Cannot get empty smax!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
           getEffectiveSCEVType(Ops[0]->getType()) &&
           "SCEVSMaxExpr operand types don't match!");
#endif

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      ConstantInt *Fold = ConstantInt::get(getContext(),
                              APIntOps::smax(LHSC->getValue()->getValue(),
                                             RHSC->getValue()->getValue()));
      Ops[0] = getConstant(Fold);
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      if (Ops.size() == 1) return Ops[0];
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant minimum-int, strip it off.
    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
      Ops.erase(Ops.begin());
      --Idx;
    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
      // If we have an smax with a constant maximum-int, it will always be
      // maximum-int.
      return Ops[0];
    }

    if (Ops.size() == 1) return Ops[0];
  }

  // Find the first SMax.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
    ++Idx;

  // Check to see if one of the operands is an SMax. If so, expand its operands
  // onto our operand list, and recurse to simplify.
  if (Idx < Ops.size()) {
    bool DeletedSMax = false;
    while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
      Ops.erase(Ops.begin()+Idx);
      Ops.append(SMax->op_begin(), SMax->op_end());
      DeletedSMax = true;
    }

    if (DeletedSMax)
      return getSMaxExpr(Ops);
  }

  // Okay, check to see if the same value occurs in the operand list twice.  If
  // so, delete one.  Since we sorted the list, these values are required to
  // be adjacent.
  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
    //  X smax Y smax Y  -->  X smax Y
    //  X smax Y         -->  X, if X is always greater than or equal to Y
    if (Ops[i] == Ops[i+1] ||
        isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) {
      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
      --i; --e;
    } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) {
      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
      --i; --e;
    }

  if (Ops.size() == 1) return Ops[0];

  assert(!Ops.empty() && "Reduced smax down to nothing!");

  // Okay, it looks like we really DO need an smax expr.  Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(scSMaxExpr);
  ID.AddInteger(Ops.size());
  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
    ID.AddPointer(Ops[i]);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
  SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
                                             O, Ops.size());
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
  SmallVector<const SCEV *, 2> Ops;
  Ops.push_back(LHS);
  Ops.push_back(RHS);
  return getUMaxExpr(Ops);
}

const SCEV *
ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
  assert(!Ops.empty() && "Cannot get empty umax!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
           getEffectiveSCEVType(Ops[0]->getType()) &&
           "SCEVUMaxExpr operand types don't match!");
#endif

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      ConstantInt *Fold = ConstantInt::get(getContext(),
                              APIntOps::umax(LHSC->getValue()->getValue(),
                                             RHSC->getValue()->getValue()));
      Ops[0] = getConstant(Fold);
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      if (Ops.size() == 1) return Ops[0];
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant minimum-int, strip it off.
    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
      Ops.erase(Ops.begin());
      --Idx;
    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
      // If we have a umax with a constant maximum-int, it will always be
      // maximum-int.
      return Ops[0];
    }

    if (Ops.size() == 1) return Ops[0];
  }

  // Find the first UMax.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
    ++Idx;

  // Check to see if one of the operands is a UMax. If so, expand its operands
  // onto our operand list, and recurse to simplify.
  if (Idx < Ops.size()) {
    bool DeletedUMax = false;
    while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
      Ops.erase(Ops.begin()+Idx);
      Ops.append(UMax->op_begin(), UMax->op_end());
      DeletedUMax = true;
    }

    if (DeletedUMax)
      return getUMaxExpr(Ops);
  }

  // Okay, check to see if the same value occurs in the operand list twice.  If
  // so, delete one.  Since we sorted the list, these values are required to
  // be adjacent.
  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
    //  X umax Y umax Y  -->  X umax Y
    //  X umax Y         -->  X, if X is always greater than or equal to Y
    if (Ops[i] == Ops[i+1] ||
        isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
      --i; --e;
    } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
      --i; --e;
    }

  if (Ops.size() == 1) return Ops[0];

  assert(!Ops.empty() && "Reduced umax down to nothing!");

  // Okay, it looks like we really DO need a umax expr.  Check to see if we
  // already have one, otherwise create a new one.
  FoldingSetNodeID ID;
  ID.AddInteger(scUMaxExpr);
  ID.AddInteger(Ops.size());
  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
    ID.AddPointer(Ops[i]);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
  SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
                                             O, Ops.size());
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
  // ~smax(~x, ~y) == smin(x, y).
  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
}
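
// Sanity check of the identity in i8: for x = 3, y = 5, ~3 = -4 and
// ~5 = -6, smax(-4, -6) = -4, and ~(-4) = 3 = smin(3, 5).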

const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
                                         const SCEV *RHS) {
  // ~umax(~x, ~y) == umin(x, y)
  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
}

const SCEV *ScalarEvolution::getSizeOfExpr(const Type *AllocTy) {
  // If we have TargetData, we can bypass creating a target-independent
  // constant expression and then folding it back into a ConstantInt.
  // This is just a compile-time optimization.
  if (TD)
    return getConstant(TD->getIntPtrType(getContext()),
                       TD->getTypeAllocSize(AllocTy));

  Constant *C = ConstantExpr::getSizeOf(AllocTy);
  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD))
      C = Folded;
  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
  return getTruncateOrZeroExtend(getSCEV(C), Ty);
}

const SCEV *ScalarEvolution::getAlignOfExpr(const Type *AllocTy) {
  Constant *C = ConstantExpr::getAlignOf(AllocTy);
  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD))
      C = Folded;
  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
  return getTruncateOrZeroExtend(getSCEV(C), Ty);
}

const SCEV *ScalarEvolution::getOffsetOfExpr(const StructType *STy,
                                             unsigned FieldNo) {
  // If we have TargetData, we can bypass creating a target-independent
  // constant expression and then folding it back into a ConstantInt.
  // This is just a compile-time optimization.
  if (TD)
    return getConstant(TD->getIntPtrType(getContext()),
                       TD->getStructLayout(STy)->getElementOffset(FieldNo));

  Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo);
  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD))
      C = Folded;
  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
  return getTruncateOrZeroExtend(getSCEV(C), Ty);
}

const SCEV *ScalarEvolution::getOffsetOfExpr(const Type *CTy,
                                             Constant *FieldNo) {
  Constant *C = ConstantExpr::getOffsetOf(CTy, FieldNo);
  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD))
      C = Folded;
  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy));
  return getTruncateOrZeroExtend(getSCEV(C), Ty);
}

const SCEV *ScalarEvolution::getUnknown(Value *V) {
  // Don't attempt to do anything other than create a SCEVUnknown object
  // here.  createSCEV only calls getUnknown after checking for all other
  // interesting possibilities, and any other code that calls getUnknown
  // is doing so in order to hide a value from SCEV canonicalization.

  FoldingSetNodeID ID;
  ID.AddInteger(scUnknown);
  ID.AddPointer(V);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

//===----------------------------------------------------------------------===//
//            Basic SCEV Analysis and PHI Idiom Recognition Code
//

/// isSCEVable - Test if values of the given type are analyzable within
/// the SCEV framework. This primarily includes integer types, and it
/// can optionally include pointer types if the ScalarEvolution class
/// has access to target-specific information.
bool ScalarEvolution::isSCEVable(const Type *Ty) const {
  // Integers and pointers are always SCEVable.
  return Ty->isIntegerTy() || Ty->isPointerTy();
}

/// getTypeSizeInBits - Return the size in bits of the specified type,
/// for which isSCEVable must return true.
uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
  assert(isSCEVable(Ty) && "Type is not SCEVable!");

  // If we have a TargetData, use it!
  if (TD)
    return TD->getTypeSizeInBits(Ty);

  // Integer types have fixed sizes.
  if (Ty->isIntegerTy())
    return Ty->getPrimitiveSizeInBits();

  // The only other supported type is pointer. Without TargetData,
  // conservatively assume pointers are 64-bit.
  assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
  return 64;
}

/// getEffectiveSCEVType - Return a type with the same bitwidth as
/// the given type and which represents how SCEV will treat the given
/// type, for which isSCEVable must return true. For pointer types,
/// this is the pointer-sized integer type.
const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
  assert(isSCEVable(Ty) && "Type is not SCEVable!");

  if (Ty->isIntegerTy())
    return Ty;

  // The only other supported type is pointer.
  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
  if (TD) return TD->getIntPtrType(getContext());

  // Without TargetData, conservatively assume pointers are 64-bit.
  return Type::getInt64Ty(getContext());
}

const SCEV *ScalarEvolution::getCouldNotCompute() {
  return &CouldNotCompute;
}

/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
/// expression and create a new one.
const SCEV *ScalarEvolution::getSCEV(Value *V) {
  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");

  std::map<SCEVCallbackVH, const SCEV *>::iterator I = Scalars.find(V);
  if (I != Scalars.end()) return I->second;
  const SCEV *S = createSCEV(V);
  Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S));
  return S;
}

/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
///
const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
    return getConstant(
               cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));

  const Type *Ty = V->getType();
  Ty = getEffectiveSCEVType(Ty);
  return getMulExpr(V,
                  getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
}

/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
    return getConstant(
                cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));

  const Type *Ty = V->getType();
  Ty = getEffectiveSCEVType(Ty);
  const SCEV *AllOnes =
                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
  return getMinusSCEV(AllOnes, V);
}
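
// The same two's complement identity in numbers: for V = 5 in i8,
// ~V = -1 - 5 = -6, whose bit pattern 0xFA is the complement of 0x05.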

/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
///
const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS,
                                          const SCEV *RHS) {
  // Fast path: X - X --> 0.
  if (LHS == RHS)
    return getConstant(LHS->getType(), 0);

  // X - Y --> X + -Y
  return getAddExpr(LHS, getNegativeSCEV(RHS));
}

/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
/// input value to the specified type.  If the type must be extended, it is zero
/// extended.
const SCEV *
ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V,
                                         const Type *Ty) {
  const Type *SrcTy = V->getType();
  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot truncate or zero extend with non-integer arguments!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V;  // No conversion
  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
    return getTruncateExpr(V, Ty);
  return getZeroExtendExpr(V, Ty);
}

/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
/// input value to the specified type.  If the type must be extended, it is sign
/// extended.
const SCEV *
ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
                                         const Type *Ty) {
  const Type *SrcTy = V->getType();
  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot truncate or sign extend with non-integer arguments!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V;  // No conversion
  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
    return getTruncateExpr(V, Ty);
  return getSignExtendExpr(V, Ty);
}
2496
2497/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
2498/// input value to the specified type.  If the type must be extended, it is zero
2499/// extended.  The conversion must not be narrowing.
2500const SCEV *
2501ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) {
2502  const Type *SrcTy = V->getType();
2503  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2504         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2505         "Cannot noop or zero extend with non-integer arguments!");
2506  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2507         "getNoopOrZeroExtend cannot truncate!");
2508  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2509    return V;  // No conversion
2510  return getZeroExtendExpr(V, Ty);
2511}
2512
2513/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
2514/// input value to the specified type.  If the type must be extended, it is sign
2515/// extended.  The conversion must not be narrowing.
2516const SCEV *
2517ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) {
2518  const Type *SrcTy = V->getType();
2519  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2520         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2521         "Cannot noop or sign extend with non-integer arguments!");
2522  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2523         "getNoopOrSignExtend cannot truncate!");
2524  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2525    return V;  // No conversion
2526  return getSignExtendExpr(V, Ty);
2527}
2528
2529/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
2530/// the input value to the specified type. If the type must be extended,
2531/// it is extended with unspecified bits. The conversion must not be
2532/// narrowing.
2533const SCEV *
2534ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) {
2535  const Type *SrcTy = V->getType();
2536  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2537         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2538         "Cannot noop or any extend with non-integer arguments!");
2539  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2540         "getNoopOrAnyExtend cannot truncate!");
2541  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2542    return V;  // No conversion
2543  return getAnyExtendExpr(V, Ty);
2544}
2545
2546/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
2547/// input value to the specified type.  The conversion must not be widening.
2548const SCEV *
2549ScalarEvolution::getTruncateOrNoop(const SCEV *V, const Type *Ty) {
2550  const Type *SrcTy = V->getType();
2551  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2552         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2553         "Cannot truncate or noop with non-integer arguments!");
2554  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
2555         "getTruncateOrNoop cannot extend!");
2556  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2557    return V;  // No conversion
2558  return getTruncateExpr(V, Ty);
2559}
2560
2561/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
2562/// the types using zero-extension, and then perform a umax operation
2563/// with them.
2564const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
2565                                                        const SCEV *RHS) {
2566  const SCEV *PromotedLHS = LHS;
2567  const SCEV *PromotedRHS = RHS;
2568
2569  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2570    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2571  else
2572    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2573
2574  return getUMaxExpr(PromotedLHS, PromotedRHS);
2575}
2576
2577/// getUMinFromMismatchedTypes - Promote the operands to the wider of
2578/// the types using zero-extension, and then perform a umin operation
2579/// with them.
2580const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
2581                                                        const SCEV *RHS) {
2582  const SCEV *PromotedLHS = LHS;
2583  const SCEV *PromotedRHS = RHS;
2584
2585  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2586    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2587  else
2588    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2589
2590  return getUMinExpr(PromotedLHS, PromotedRHS);
2591}
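
// For example (a sketch with hypothetical values %a of type i8 and %b of
// type i32): the i8 operand is first promoted with (zext i8 %a to i32),
// and the umin is then formed on the two i32 expressions.
// getUMaxFromMismatchedTypes above promotes its operands the same way.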
2592
2593/// PushDefUseChildren - Push users of the given Instruction
2594/// onto the given Worklist.
2595static void
2596PushDefUseChildren(Instruction *I,
2597                   SmallVectorImpl<Instruction *> &Worklist) {
2598  // Push the def-use children onto the Worklist stack.
2599  for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2600       UI != UE; ++UI)
2601    Worklist.push_back(cast<Instruction>(*UI));
2602}
2603
2604/// ForgetSymbolicName - This looks up computed SCEV values for all
2605/// instructions that depend on the given instruction and removes them from
2606/// the Scalars map if they reference SymName. This is used during PHI
2607/// resolution.
2608void
2609ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
2610  SmallVector<Instruction *, 16> Worklist;
2611  PushDefUseChildren(PN, Worklist);
2612
2613  SmallPtrSet<Instruction *, 8> Visited;
2614  Visited.insert(PN);
2615  while (!Worklist.empty()) {
2616    Instruction *I = Worklist.pop_back_val();
2617    if (!Visited.insert(I)) continue;
2618
2619    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
2620      Scalars.find(static_cast<Value *>(I));
2621    if (It != Scalars.end()) {
2622      // Short-circuit the def-use traversal if the symbolic name
2623      // ceases to appear in expressions.
2624      if (It->second != SymName && !It->second->hasOperand(SymName))
2625        continue;
2626
2627      // SCEVUnknown for a PHI either means that it has an unrecognized
2628      // structure, it's a PHI that's in the process of being computed
2629      // by createNodeForPHI, or it's a single-value PHI. In the first case,
2630      // additional loop trip count information isn't going to change anything.
2631      // In the second case, createNodeForPHI will perform the necessary
2632      // updates on its own when it gets to that point. In the third, we do
2633      // want to forget the SCEVUnknown.
2634      if (!isa<PHINode>(I) ||
2635          !isa<SCEVUnknown>(It->second) ||
2636          (I != PN && It->second == SymName)) {
2637        ValuesAtScopes.erase(It->second);
2638        Scalars.erase(It);
2639      }
2640    }
2641
2642    PushDefUseChildren(I, Worklist);
2643  }
2644}
2645
2646/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
2647/// a loop header, making it a potential recurrence, or it doesn't.
2648///
2649const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
2650  if (const Loop *L = LI->getLoopFor(PN->getParent()))
2651    if (L->getHeader() == PN->getParent()) {
2652      // The loop may have multiple entrances or multiple exits; we can analyze
2653      // this phi as an addrec if it has a unique entry value and a unique
2654      // backedge value.
2655      Value *BEValueV = 0, *StartValueV = 0;
2656      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
2657        Value *V = PN->getIncomingValue(i);
2658        if (L->contains(PN->getIncomingBlock(i))) {
2659          if (!BEValueV) {
2660            BEValueV = V;
2661          } else if (BEValueV != V) {
2662            BEValueV = 0;
2663            break;
2664          }
2665        } else if (!StartValueV) {
2666          StartValueV = V;
2667        } else if (StartValueV != V) {
2668          StartValueV = 0;
2669          break;
2670        }
2671      }
2672      if (BEValueV && StartValueV) {
2673        // While we are analyzing this PHI node, handle its value symbolically.
2674        const SCEV *SymbolicName = getUnknown(PN);
2675        assert(Scalars.find(PN) == Scalars.end() &&
2676               "PHI node already processed?");
2677        Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
2678
2679        // Using this symbolic name for the PHI, analyze the value coming around
2680        // the back-edge.
2681        const SCEV *BEValue = getSCEV(BEValueV);
2682
2683        // NOTE: If BEValue is loop invariant, we know that the PHI node just
2684        // has a special value for the first iteration of the loop.
2685
2686        // If the value coming around the backedge is an add with the symbolic
2687        // value we just inserted, then we found a simple induction variable!
2688        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
2689          // If there is a single occurrence of the symbolic value, replace it
2690          // with a recurrence.
2691          unsigned FoundIndex = Add->getNumOperands();
2692          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2693            if (Add->getOperand(i) == SymbolicName)
2694              if (FoundIndex == e) {
2695                FoundIndex = i;
2696                break;
2697              }
2698
2699          if (FoundIndex != Add->getNumOperands()) {
2700            // Create an add with everything but the specified operand.
2701            SmallVector<const SCEV *, 8> Ops;
2702            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2703              if (i != FoundIndex)
2704                Ops.push_back(Add->getOperand(i));
2705            const SCEV *Accum = getAddExpr(Ops);
2706
2707            // This is not a valid addrec if the step amount varies on each
2708            // loop iteration and is not itself an addrec in this loop.
2709            if (Accum->isLoopInvariant(L) ||
2710                (isa<SCEVAddRecExpr>(Accum) &&
2711                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
2712              bool HasNUW = false;
2713              bool HasNSW = false;
2714
2715              // If the increment doesn't overflow, then neither the addrec nor
2716              // the post-increment will overflow.
2717              if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
2718                if (OBO->hasNoUnsignedWrap())
2719                  HasNUW = true;
2720                if (OBO->hasNoSignedWrap())
2721                  HasNSW = true;
2722              }
2723
2724              const SCEV *StartVal = getSCEV(StartValueV);
2725              const SCEV *PHISCEV =
2726                getAddRecExpr(StartVal, Accum, L, HasNUW, HasNSW);
2727
2728              // Since the no-wrap flags are on the increment, they apply to the
2729              // post-incremented value as well.
2730              if (Accum->isLoopInvariant(L))
2731                (void)getAddRecExpr(getAddExpr(StartVal, Accum),
2732                                    Accum, L, HasNUW, HasNSW);
2733
2734              // Okay, for the entire analysis of this edge we assumed the PHI
2735              // to be symbolic.  We now need to go back and purge all of the
2736              // entries for the scalars that use the symbolic expression.
2737              ForgetSymbolicName(PN, SymbolicName);
2738              Scalars[SCEVCallbackVH(PN, this)] = PHISCEV;
2739              return PHISCEV;
2740            }
2741          }
2742        } else if (const SCEVAddRecExpr *AddRec =
2743                     dyn_cast<SCEVAddRecExpr>(BEValue)) {
2744          // Otherwise, this could be a loop like this:
2745          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
2746          // In this case, j = {1,+,1}  and BEValue is j.
2747          // Because the other in-value of i (0) fits the evolution of BEValue,
2748          // i really is an addrec evolution.
2749          if (AddRec->getLoop() == L && AddRec->isAffine()) {
2750            const SCEV *StartVal = getSCEV(StartValueV);
2751
2752            // If StartVal = j.start - j.stride, we can use StartVal as the
2753            // initial value of the addrec evolution.
2754            if (StartVal == getMinusSCEV(AddRec->getOperand(0),
2755                                         AddRec->getOperand(1))) {
2756              const SCEV *PHISCEV =
2757                 getAddRecExpr(StartVal, AddRec->getOperand(1), L);
2758
2759              // Okay, for the entire analysis of this edge we assumed the PHI
2760              // to be symbolic.  We now need to go back and purge all of the
2761              // entries for the scalars that use the symbolic expression.
2762              ForgetSymbolicName(PN, SymbolicName);
2763              Scalars[SCEVCallbackVH(PN, this)] = PHISCEV;
2764              return PHISCEV;
2765            }
2766          }
2767        }
2768      }
2769    }
2770
2771  // If the PHI has a single incoming value, follow that value, unless the
2772  // PHI's incoming blocks are in a different loop, in which case doing so
2773  // risks breaking LCSSA form. Instcombine would normally zap these, but
2774  // it doesn't have DominatorTree information, so it may miss cases.
2775  if (Value *V = PN->hasConstantValue(DT)) {
2776    bool AllSameLoop = true;
2777    Loop *PNLoop = LI->getLoopFor(PN->getParent());
2778    for (size_t i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
2779      if (LI->getLoopFor(PN->getIncomingBlock(i)) != PNLoop) {
2780        AllSameLoop = false;
2781        break;
2782      }
2783    if (AllSameLoop)
2784      return getSCEV(V);
2785  }
2786
2787  // If it's not a loop phi, we can't handle it yet.
2788  return getUnknown(PN);
2789}
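
// A worked example (a sketch; the IR names are hypothetical).  For the
// loop-header PHI
//   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//   %iv.next = add nsw i32 %iv, 4
// BEValue is the add (4 + %iv), the accumulated step Accum is 4, and the
// PHI becomes the recurrence {0,+,4}<%loop>, with the nsw flag on the
// increment carried over as described above.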
2790
2791/// createNodeForGEP - Expand GEP instructions into add and multiply
2792/// operations. This allows them to be analyzed by regular SCEV code.
2793///
2794const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
2795
2796  // Don't blindly transfer the inbounds flag from the GEP instruction to the
2797  // Add expression, because the Instruction may be guarded by control flow
2798  // and the no-overflow bits may not be valid for the expression in any
2799  // context.
2800
2801  const Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
2802  Value *Base = GEP->getOperand(0);
2803  // Don't attempt to analyze GEPs over unsized objects.
2804  if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
2805    return getUnknown(GEP);
2806  const SCEV *TotalOffset = getConstant(IntPtrTy, 0);
2807  gep_type_iterator GTI = gep_type_begin(GEP);
2808  for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()),
2809                                      E = GEP->op_end();
2810       I != E; ++I) {
2811    Value *Index = *I;
2812    // Compute the (potentially symbolic) offset in bytes for this index.
2813    if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
2814      // For a struct, add the member offset.
2815      unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
2816      const SCEV *FieldOffset = getOffsetOfExpr(STy, FieldNo);
2817
2818      // Add the field offset to the running total offset.
2819      TotalOffset = getAddExpr(TotalOffset, FieldOffset);
2820    } else {
2821      // For an array, add the element offset, explicitly scaled.
2822      const SCEV *ElementSize = getSizeOfExpr(*GTI);
2823      const SCEV *IndexS = getSCEV(Index);
2824      // Getelementptr indices are signed.
2825      IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy);
2826
2827      // Multiply the index by the element size to compute the element offset.
2828      const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize);
2829
2830      // Add the element offset to the running total offset.
2831      TotalOffset = getAddExpr(TotalOffset, LocalOffset);
2832    }
2833  }
2834
2835  // Get the SCEV for the GEP base.
2836  const SCEV *BaseS = getSCEV(Base);
2837
2838  // Add the total offset from all the GEP indices to the base.
2839  return getAddExpr(BaseS, TotalOffset);
2840}
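
// For example (a sketch, assuming a target where i32 occupies 4 bytes and
// %a and %i are hypothetical values):
//   %p = getelementptr [10 x i32]* %a, i64 0, i64 %i
// expands to roughly (%a + 4 * %i): the leading zero index contributes
// nothing, and the array index is scaled by the element size.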
2841
2842/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
2843/// guaranteed to end in (at every loop iteration).  It is, at the same time,
2844/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
2845/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
2846uint32_t
2847ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
2848  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2849    return C->getValue()->getValue().countTrailingZeros();
2850
2851  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
2852    return std::min(GetMinTrailingZeros(T->getOperand()),
2853                    (uint32_t)getTypeSizeInBits(T->getType()));
2854
2855  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
2856    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
2857    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
2858             getTypeSizeInBits(E->getType()) : OpRes;
2859  }
2860
2861  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
2862    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
2863    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
2864             getTypeSizeInBits(E->getType()) : OpRes;
2865  }
2866
2867  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
2868    // The result is the min of all the operands' results.
2869    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
2870    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2871      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
2872    return MinOpRes;
2873  }
2874
2875  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
2876    // The result is the sum of all the operands' results.
2877    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
2878    uint32_t BitWidth = getTypeSizeInBits(M->getType());
2879    for (unsigned i = 1, e = M->getNumOperands();
2880         SumOpRes != BitWidth && i != e; ++i)
2881      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
2882                          BitWidth);
2883    return SumOpRes;
2884  }
2885
2886  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
2887    // The result is the min of all operands results.
2888    // The result is the min of all the operands' results.
2889    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2890      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
2891    return MinOpRes;
2892  }
2893
2894  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
2895    // The result is the min of all operands results.
2896    // The result is the min of all the operands' results.
2897    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2898      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
2899    return MinOpRes;
2900  }
2901
2902  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
2903    // The result is the min of all operands results.
2904    // The result is the min of all the operands' results.
2905    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2906      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
2907    return MinOpRes;
2908  }
2909
2910  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2911    // For a SCEVUnknown, ask ValueTracking.
2912    unsigned BitWidth = getTypeSizeInBits(U->getType());
2913    APInt Mask = APInt::getAllOnesValue(BitWidth);
2914    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
2915    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones);
2916    return Zeros.countTrailingOnes();
2917  }
2918
2919  // SCEVUDivExpr: conservatively assume no trailing zeros.
2920  return 0;
2921}
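
// For example, for the expression (4 + (8 * %x)) with an arbitrary %x:
// the multiply case yields tz(8) + tz(%x) >= 3 trailing zeros, and the
// add case takes the minimum of that with tz(4) == 2, so the result is 2.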
2922
2923/// getUnsignedRange - Determine the unsigned range for a particular SCEV.
2924///
2925ConstantRange
2926ScalarEvolution::getUnsignedRange(const SCEV *S) {
2927
2928  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2929    return ConstantRange(C->getValue()->getValue());
2930
2931  unsigned BitWidth = getTypeSizeInBits(S->getType());
2932  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
2933
2934  // If the value has known zeros, the maximum unsigned value will have those
2935  // known zeros as well.
2936  uint32_t TZ = GetMinTrailingZeros(S);
2937  if (TZ != 0)
2938    ConservativeResult =
2939      ConstantRange(APInt::getMinValue(BitWidth),
2940                    APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
2941
2942  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
2943    ConstantRange X = getUnsignedRange(Add->getOperand(0));
2944    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
2945      X = X.add(getUnsignedRange(Add->getOperand(i)));
2946    return ConservativeResult.intersectWith(X);
2947  }
2948
2949  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
2950    ConstantRange X = getUnsignedRange(Mul->getOperand(0));
2951    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
2952      X = X.multiply(getUnsignedRange(Mul->getOperand(i)));
2953    return ConservativeResult.intersectWith(X);
2954  }
2955
2956  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
2957    ConstantRange X = getUnsignedRange(SMax->getOperand(0));
2958    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
2959      X = X.smax(getUnsignedRange(SMax->getOperand(i)));
2960    return ConservativeResult.intersectWith(X);
2961  }
2962
2963  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
2964    ConstantRange X = getUnsignedRange(UMax->getOperand(0));
2965    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
2966      X = X.umax(getUnsignedRange(UMax->getOperand(i)));
2967    return ConservativeResult.intersectWith(X);
2968  }
2969
2970  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
2971    ConstantRange X = getUnsignedRange(UDiv->getLHS());
2972    ConstantRange Y = getUnsignedRange(UDiv->getRHS());
2973    return ConservativeResult.intersectWith(X.udiv(Y));
2974  }
2975
2976  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
2977    ConstantRange X = getUnsignedRange(ZExt->getOperand());
2978    return ConservativeResult.intersectWith(X.zeroExtend(BitWidth));
2979  }
2980
2981  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
2982    ConstantRange X = getUnsignedRange(SExt->getOperand());
2983    return ConservativeResult.intersectWith(X.signExtend(BitWidth));
2984  }
2985
2986  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
2987    ConstantRange X = getUnsignedRange(Trunc->getOperand());
2988    return ConservativeResult.intersectWith(X.truncate(BitWidth));
2989  }
2990
2991  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
2992    // If there's no unsigned wrap, the value will never be less than its
2993    // initial value.
2994    if (AddRec->hasNoUnsignedWrap())
2995      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
2996        if (!C->getValue()->isZero())
2997          ConservativeResult =
2998            ConservativeResult.intersectWith(
2999              ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
3000
3001    // TODO: non-affine addrec
3002    if (AddRec->isAffine()) {
3003      const Type *Ty = AddRec->getType();
3004      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3005      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3006          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3007        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3008
3009        const SCEV *Start = AddRec->getStart();
3010        const SCEV *Step = AddRec->getStepRecurrence(*this);
3011
3012        ConstantRange StartRange = getUnsignedRange(Start);
3013        ConstantRange StepRange = getSignedRange(Step);
3014        ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3015        ConstantRange EndRange =
3016          StartRange.add(MaxBECountRange.multiply(StepRange));
3017
3018        // Check for overflow. This must be done with ConstantRange arithmetic
3019        // because we could be called from within the ScalarEvolution overflow
3020        // checking code.
3021        ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1);
3022        ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3023        ConstantRange ExtMaxBECountRange =
3024          MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3025        ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1);
3026        if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3027            ExtEndRange)
3028          return ConservativeResult;
3029
3030        APInt Min = APIntOps::umin(StartRange.getUnsignedMin(),
3031                                   EndRange.getUnsignedMin());
3032        APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
3033                                   EndRange.getUnsignedMax());
3034        if (Min.isMinValue() && Max.isMaxValue())
3035          return ConservativeResult;
3036        return ConservativeResult.intersectWith(ConstantRange(Min, Max+1));
3037      }
3038    }
3039
3040    return ConservativeResult;
3041  }
3042
3043  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3044    // For a SCEVUnknown, ask ValueTracking.
3045    APInt Mask = APInt::getAllOnesValue(BitWidth);
3046    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3047    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
3048    if (Ones == ~Zeros + 1)
3049      return ConservativeResult;
3050    return ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1));
3051  }
3052
3053  return ConservativeResult;
3054}
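
// For example, the unsigned range of (zext i8 %x to i32) is [0, 256).
// And for an addrec {0,+,1}<L> whose max backedge-taken count is the
// constant 9 (a sketch), the affine case above computes EndRange = [9,10)
// and returns the range [0, 10).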
3055
3056/// getSignedRange - Determine the signed range for a particular SCEV.
3057///
3058ConstantRange
3059ScalarEvolution::getSignedRange(const SCEV *S) {
3060
3061  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3062    return ConstantRange(C->getValue()->getValue());
3063
3064  unsigned BitWidth = getTypeSizeInBits(S->getType());
3065  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3066
3067  // If the value has known zeros, the maximum signed value will have those
3068  // known zeros as well.
3069  uint32_t TZ = GetMinTrailingZeros(S);
3070  if (TZ != 0)
3071    ConservativeResult =
3072      ConstantRange(APInt::getSignedMinValue(BitWidth),
3073                    APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
3074
3075  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3076    ConstantRange X = getSignedRange(Add->getOperand(0));
3077    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3078      X = X.add(getSignedRange(Add->getOperand(i)));
3079    return ConservativeResult.intersectWith(X);
3080  }
3081
3082  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3083    ConstantRange X = getSignedRange(Mul->getOperand(0));
3084    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3085      X = X.multiply(getSignedRange(Mul->getOperand(i)));
3086    return ConservativeResult.intersectWith(X);
3087  }
3088
3089  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3090    ConstantRange X = getSignedRange(SMax->getOperand(0));
3091    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3092      X = X.smax(getSignedRange(SMax->getOperand(i)));
3093    return ConservativeResult.intersectWith(X);
3094  }
3095
3096  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3097    ConstantRange X = getSignedRange(UMax->getOperand(0));
3098    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3099      X = X.umax(getSignedRange(UMax->getOperand(i)));
3100    return ConservativeResult.intersectWith(X);
3101  }
3102
3103  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3104    ConstantRange X = getSignedRange(UDiv->getLHS());
3105    ConstantRange Y = getSignedRange(UDiv->getRHS());
3106    return ConservativeResult.intersectWith(X.udiv(Y));
3107  }
3108
3109  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3110    ConstantRange X = getSignedRange(ZExt->getOperand());
3111    return ConservativeResult.intersectWith(X.zeroExtend(BitWidth));
3112  }
3113
3114  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3115    ConstantRange X = getSignedRange(SExt->getOperand());
3116    return ConservativeResult.intersectWith(X.signExtend(BitWidth));
3117  }
3118
3119  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3120    ConstantRange X = getSignedRange(Trunc->getOperand());
3121    return ConservativeResult.intersectWith(X.truncate(BitWidth));
3122  }
3123
3124  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3125    // If there's no signed wrap, and all the operands have the same sign or
3126    // zero, the value won't ever change sign.
3127    if (AddRec->hasNoSignedWrap()) {
3128      bool AllNonNeg = true;
3129      bool AllNonPos = true;
3130      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3131        if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
3132        if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
3133      }
3134      if (AllNonNeg)
3135        ConservativeResult = ConservativeResult.intersectWith(
3136          ConstantRange(APInt(BitWidth, 0),
3137                        APInt::getSignedMinValue(BitWidth)));
3138      else if (AllNonPos)
3139        ConservativeResult = ConservativeResult.intersectWith(
3140          ConstantRange(APInt::getSignedMinValue(BitWidth),
3141                        APInt(BitWidth, 1)));
3142    }
3143
3144    // TODO: non-affine addrec
3145    if (AddRec->isAffine()) {
3146      const Type *Ty = AddRec->getType();
3147      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3148      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3149          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3150        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3151
3152        const SCEV *Start = AddRec->getStart();
3153        const SCEV *Step = AddRec->getStepRecurrence(*this);
3154
3155        ConstantRange StartRange = getSignedRange(Start);
3156        ConstantRange StepRange = getSignedRange(Step);
3157        ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3158        ConstantRange EndRange =
3159          StartRange.add(MaxBECountRange.multiply(StepRange));
3160
3161        // Check for overflow. This must be done with ConstantRange arithmetic
3162        // because we could be called from within the ScalarEvolution overflow
3163        // checking code.
3164        ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1);
3165        ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3166        ConstantRange ExtMaxBECountRange =
3167          MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3168        ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1);
3169        if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3170            ExtEndRange)
3171          return ConservativeResult;
3172
3173        APInt Min = APIntOps::smin(StartRange.getSignedMin(),
3174                                   EndRange.getSignedMin());
3175        APInt Max = APIntOps::smax(StartRange.getSignedMax(),
3176                                   EndRange.getSignedMax());
3177        if (Min.isMinSignedValue() && Max.isMaxSignedValue())
3178          return ConservativeResult;
3179        return ConservativeResult.intersectWith(ConstantRange(Min, Max+1));
3180      }
3181    }
3182
3183    return ConservativeResult;
3184  }
3185
3186  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3187    // For a SCEVUnknown, ask ValueTracking.
3188    if (!U->getValue()->getType()->isIntegerTy() && !TD)
3189      return ConservativeResult;
3190    unsigned NS = ComputeNumSignBits(U->getValue(), TD);
3191    if (NS == 1)
3192      return ConservativeResult;
3193    return ConservativeResult.intersectWith(
3194      ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
3195                    APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1));
3196  }
3197
3198  return ConservativeResult;
3199}
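
// For example, the signed range of (sext i8 %x to i32) is [-128, 128),
// since the full i8 range sign-extends to exactly that interval.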
3200
3201/// createSCEV - We know that there is no SCEV for the specified value.
3202/// Analyze the expression.
3203///
3204const SCEV *ScalarEvolution::createSCEV(Value *V) {
3205  if (!isSCEVable(V->getType()))
3206    return getUnknown(V);
3207
3208  unsigned Opcode = Instruction::UserOp1;
3209  if (Instruction *I = dyn_cast<Instruction>(V)) {
3210    Opcode = I->getOpcode();
3211
3212    // Don't attempt to analyze instructions in blocks that aren't
3213    // reachable. Such instructions don't matter, and they aren't required
3214    // to obey the basic rule that definitions dominate uses, which this
3215    // analysis depends on.
3216    if (!DT->isReachableFromEntry(I->getParent()))
3217      return getUnknown(V);
3218  } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
3219    Opcode = CE->getOpcode();
3220  else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
3221    return getConstant(CI);
3222  else if (isa<ConstantPointerNull>(V))
3223    return getConstant(V->getType(), 0);
3224  else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
3225    return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
3226  else
3227    return getUnknown(V);
3228
3229  Operator *U = cast<Operator>(V);
3230  switch (Opcode) {
3231  case Instruction::Add:
3232    return getAddExpr(getSCEV(U->getOperand(0)),
3233                      getSCEV(U->getOperand(1)));
3234  case Instruction::Mul:
3235    return getMulExpr(getSCEV(U->getOperand(0)),
3236                      getSCEV(U->getOperand(1)));
3237  case Instruction::UDiv:
3238    return getUDivExpr(getSCEV(U->getOperand(0)),
3239                       getSCEV(U->getOperand(1)));
3240  case Instruction::Sub:
3241    return getMinusSCEV(getSCEV(U->getOperand(0)),
3242                        getSCEV(U->getOperand(1)));
3243  case Instruction::And:
3244    // For an expression like x&255 that merely masks off the high bits,
3245    // use zext(trunc(x)) as the SCEV expression.
3246    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3247      if (CI->isNullValue())
3248        return getSCEV(U->getOperand(1));
3249      if (CI->isAllOnesValue())
3250        return getSCEV(U->getOperand(0));
3251      const APInt &A = CI->getValue();
3252
3253      // Instcombine's ShrinkDemandedConstant may strip bits out of
3254      // constants, obscuring what would otherwise be a low-bits mask.
3255      // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
3256      // knew about to reconstruct a low-bits mask value.
3257      unsigned LZ = A.countLeadingZeros();
3258      unsigned BitWidth = A.getBitWidth();
3259      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
3260      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
3261      ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD);
3262
3263      APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
3264
3265      if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
3266        return
3267          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
3268                                IntegerType::get(getContext(), BitWidth - LZ)),
3269                            U->getType());
3270    }
3271    break;
3272
3273  case Instruction::Or:
3274    // If the RHS of the Or is a constant, we may have something like:
3275    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
3276    // optimizations will transparently handle this case.
3277    //
3278    // In order for this transformation to be safe, the LHS must be of the
3279    // form X*(2^n) and the Or constant must be less than 2^n.
3280    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3281      const SCEV *LHS = getSCEV(U->getOperand(0));
3282      const APInt &CIVal = CI->getValue();
3283      if (GetMinTrailingZeros(LHS) >=
3284          (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
3285        // Build a plain add SCEV.
3286        const SCEV *S = getAddExpr(LHS, getSCEV(CI));
3287        // If the LHS of the add was an addrec and it has no-wrap flags,
3288        // transfer the no-wrap flags, since an or won't introduce a wrap.
3289        if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
3290          const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
3291          if (OldAR->hasNoUnsignedWrap())
3292            const_cast<SCEVAddRecExpr *>(NewAR)->setHasNoUnsignedWrap(true);
3293          if (OldAR->hasNoSignedWrap())
3294            const_cast<SCEVAddRecExpr *>(NewAR)->setHasNoSignedWrap(true);
3295        }
3296        return S;
3297      }
3298    }
3299    break;
3300  case Instruction::Xor:
3301    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3302      // If the RHS of the xor is a signbit, then this is just an add.
3303      // Instcombine turns add of signbit into xor as a strength reduction step.
3304      if (CI->getValue().isSignBit())
3305        return getAddExpr(getSCEV(U->getOperand(0)),
3306                          getSCEV(U->getOperand(1)));
3307
3308      // If the RHS of xor is -1, then this is a not operation.
3309      if (CI->isAllOnesValue())
3310        return getNotSCEV(getSCEV(U->getOperand(0)));
3311
3312      // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
3313      // This is a variant of the check for xor with -1, and it handles
3314      // the case where instcombine has trimmed non-demanded bits out
3315      // of an xor with -1.
3316      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
3317        if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
3318          if (BO->getOpcode() == Instruction::And &&
3319              LCI->getValue() == CI->getValue())
3320            if (const SCEVZeroExtendExpr *Z =
3321                  dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
3322              const Type *UTy = U->getType();
3323              const SCEV *Z0 = Z->getOperand();
3324              const Type *Z0Ty = Z0->getType();
3325              unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
3326
3327              // If C is a low-bits mask, the zero extend is serving to
3328              // mask off the high bits. Complement the operand and
3329              // re-apply the zext.
3330              if (APIntOps::isMask(Z0TySize, CI->getValue()))
3331                return getZeroExtendExpr(getNotSCEV(Z0), UTy);
3332
3333              // If C is a single bit, it may be in the sign-bit position
3334              // before the zero-extend. In this case, represent the xor
3335              // using an add, which is equivalent, and re-apply the zext.
3336              APInt Trunc = APInt(CI->getValue()).trunc(Z0TySize);
3337              if (APInt(Trunc).zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
3338                  Trunc.isSignBit())
3339                return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
3340                                         UTy);
3341            }
3342    }
3343    break;
3344
3345  case Instruction::Shl:
3346    // Turn shift left of a constant amount into a multiply.
3347    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3348      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3349
3350      // If the shift count is not less than the bitwidth, the result of
3351      // the shift is undefined. Don't try to analyze it, because the
3352      // resolution chosen here may differ from the resolution chosen in
3353      // other parts of the compiler.
3354      if (SA->getValue().uge(BitWidth))
3355        break;
3356
3357      Constant *X = ConstantInt::get(getContext(),
3358        APInt(BitWidth, 1).shl(SA->getZExtValue()));
3359      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3360    }
3361    break;
3362
3363  case Instruction::LShr:
3364    // Turn logical shift right of a constant into an unsigned divide.
3365    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3366      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3367
3368      // If the shift count is not less than the bitwidth, the result of
3369      // the shift is undefined. Don't try to analyze it, because the
3370      // resolution chosen here may differ from the resolution chosen in
3371      // other parts of the compiler.
3372      if (SA->getValue().uge(BitWidth))
3373        break;
3374
3375      Constant *X = ConstantInt::get(getContext(),
3376        APInt(BitWidth, 1).shl(SA->getZExtValue()));
3377      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3378    }
3379    break;
3380
3381  case Instruction::AShr:
3382    // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
3383    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
3384      if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
3385        if (L->getOpcode() == Instruction::Shl &&
3386            L->getOperand(1) == U->getOperand(1)) {
3387          uint64_t BitWidth = getTypeSizeInBits(U->getType());
3388
3389          // If the shift count is not less than the bitwidth, the result of
3390          // the shift is undefined. Don't try to analyze it, because the
3391          // resolution chosen here may differ from the resolution chosen in
3392          // other parts of the compiler.
3393          if (CI->getValue().uge(BitWidth))
3394            break;
3395
3396          uint64_t Amt = BitWidth - CI->getZExtValue();
3397          if (Amt == BitWidth)
3398            return getSCEV(L->getOperand(0));       // shift by zero --> noop
3399          return
3400            getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
3401                                              IntegerType::get(getContext(),
3402                                                               Amt)),
3403                              U->getType());
3404        }
3405    break;
3406
3407  case Instruction::Trunc:
3408    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
3409
3410  case Instruction::ZExt:
3411    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3412
3413  case Instruction::SExt:
3414    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3415
3416  case Instruction::BitCast:
3417    // BitCasts are no-op casts so we just eliminate the cast.
3418    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
3419      return getSCEV(U->getOperand(0));
3420    break;
3421
3422  // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can
3423  // lead to pointer expressions which cannot safely be expanded to GEPs,
3424  // because ScalarEvolution doesn't respect the GEP aliasing rules when
3425  // simplifying integer expressions.
3426
3427  case Instruction::GetElementPtr:
3428    return createNodeForGEP(cast<GEPOperator>(U));
3429
3430  case Instruction::PHI:
3431    return createNodeForPHI(cast<PHINode>(U));
3432
3433  case Instruction::Select:
3434    // This could be a smax or umax that was lowered earlier.
3435    // Try to recover it.
3436    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
3437      Value *LHS = ICI->getOperand(0);
3438      Value *RHS = ICI->getOperand(1);
3439      switch (ICI->getPredicate()) {
3440      case ICmpInst::ICMP_SLT:
3441      case ICmpInst::ICMP_SLE:
3442        std::swap(LHS, RHS);
3443        // fall through
3444      case ICmpInst::ICMP_SGT:
3445      case ICmpInst::ICMP_SGE:
3446        // a >s b ? a+x : b+x  ->  smax(a, b)+x
3447        // a >s b ? b+x : a+x  ->  smin(a, b)+x
3448        if (LHS->getType() == U->getType()) {
3449          const SCEV *LS = getSCEV(LHS);
3450          const SCEV *RS = getSCEV(RHS);
3451          const SCEV *LA = getSCEV(U->getOperand(1));
3452          const SCEV *RA = getSCEV(U->getOperand(2));
3453          const SCEV *LDiff = getMinusSCEV(LA, LS);
3454          const SCEV *RDiff = getMinusSCEV(RA, RS);
3455          if (LDiff == RDiff)
3456            return getAddExpr(getSMaxExpr(LS, RS), LDiff);
3457          LDiff = getMinusSCEV(LA, RS);
3458          RDiff = getMinusSCEV(RA, LS);
3459          if (LDiff == RDiff)
3460            return getAddExpr(getSMinExpr(LS, RS), LDiff);
3461        }
3462        break;
3463      case ICmpInst::ICMP_ULT:
3464      case ICmpInst::ICMP_ULE:
3465        std::swap(LHS, RHS);
3466        // fall through
3467      case ICmpInst::ICMP_UGT:
3468      case ICmpInst::ICMP_UGE:
3469        // a >u b ? a+x : b+x  ->  umax(a, b)+x
3470        // a >u b ? b+x : a+x  ->  umin(a, b)+x
3471        if (LHS->getType() == U->getType()) {
3472          const SCEV *LS = getSCEV(LHS);
3473          const SCEV *RS = getSCEV(RHS);
3474          const SCEV *LA = getSCEV(U->getOperand(1));
3475          const SCEV *RA = getSCEV(U->getOperand(2));
3476          const SCEV *LDiff = getMinusSCEV(LA, LS);
3477          const SCEV *RDiff = getMinusSCEV(RA, RS);
3478          if (LDiff == RDiff)
3479            return getAddExpr(getUMaxExpr(LS, RS), LDiff);
3480          LDiff = getMinusSCEV(LA, RS);
3481          RDiff = getMinusSCEV(RA, LS);
3482          if (LDiff == RDiff)
3483            return getAddExpr(getUMinExpr(LS, RS), LDiff);
3484        }
3485        break;
3486      case ICmpInst::ICMP_NE:
3487        // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
3488        if (LHS->getType() == U->getType() &&
3489            isa<ConstantInt>(RHS) &&
3490            cast<ConstantInt>(RHS)->isZero()) {
3491          const SCEV *One = getConstant(LHS->getType(), 1);
3492          const SCEV *LS = getSCEV(LHS);
3493          const SCEV *LA = getSCEV(U->getOperand(1));
3494          const SCEV *RA = getSCEV(U->getOperand(2));
3495          const SCEV *LDiff = getMinusSCEV(LA, LS);
3496          const SCEV *RDiff = getMinusSCEV(RA, One);
3497          if (LDiff == RDiff)
3498            return getAddExpr(getUMaxExpr(LS, One), LDiff);
3499        }
3500        break;
3501      case ICmpInst::ICMP_EQ:
3502        // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
3503        if (LHS->getType() == U->getType() &&
3504            isa<ConstantInt>(RHS) &&
3505            cast<ConstantInt>(RHS)->isZero()) {
3506          const SCEV *One = getConstant(LHS->getType(), 1);
3507          const SCEV *LS = getSCEV(LHS);
3508          const SCEV *LA = getSCEV(U->getOperand(1));
3509          const SCEV *RA = getSCEV(U->getOperand(2));
3510          const SCEV *LDiff = getMinusSCEV(LA, One);
3511          const SCEV *RDiff = getMinusSCEV(RA, LS);
3512          if (LDiff == RDiff)
3513            return getAddExpr(getUMaxExpr(LS, One), LDiff);
3514        }
3515        break;
3516      default:
3517        break;
3518      }
3519    }
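    // If no select pattern matched, control falls through to the default
    // case below.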
3520
3521  default: // We cannot analyze this expression.
3522    break;
3523  }
3524
3525  return getUnknown(V);
3526}
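
// For example, given the hypothetical sequence
//   %t = mul i32 %x, 4
//   %r = or i32 %t, 1
// the Or case above sees that %t has at least two trailing zero bits, so
// or-ing in 1 cannot carry, and %r is modeled as the add (1 + (4 * %x)).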
3527
3528
3529
3530//===----------------------------------------------------------------------===//
3531//                   Iteration Count Computation Code
3532//
3533
3534/// getBackedgeTakenCount - If the specified loop has a predictable
3535/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
3536/// object. The backedge-taken count is the number of times the loop header
3537/// will be branched to from within the loop. This is one less than the
3538/// trip count of the loop, since it doesn't count the first iteration,
3539/// when the header is branched to from outside the loop.
3540///
3541/// Note that it is not valid to call this method on a loop without a
3542/// loop-invariant backedge-taken count (see
3543/// hasLoopInvariantBackedgeTakenCount).
3544///
3545const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
3546  return getBackedgeTakenInfo(L).Exact;
3547}
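
// For example (a sketch), for a loop equivalent to
//   for (i = 0; i != n; ++i) { ... }
// the backedge-taken count is %n, and the trip count, which also counts
// the initial entry to the header, is %n + 1.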
3548
3549/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
3550/// return the least SCEV value that is known never to be less than the
3551/// actual backedge taken count.
3552const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
3553  return getBackedgeTakenInfo(L).Max;
3554}
3555
3556/// PushLoopPHIs - Push PHI nodes in the header of the given loop
3557/// onto the given Worklist.
3558static void
3559PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
3560  BasicBlock *Header = L->getHeader();
3561
3562  // Push all Loop-header PHIs onto the Worklist stack.
3563  for (BasicBlock::iterator I = Header->begin();
3564       PHINode *PN = dyn_cast<PHINode>(I); ++I)
3565    Worklist.push_back(PN);
3566}
3567
3568const ScalarEvolution::BackedgeTakenInfo &
3569ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
3570  // Initially insert a CouldNotCompute for this loop. If the insertion
3571  // succeeds, proceed to actually compute a backedge-taken count and
3572  // update the value. The temporary CouldNotCompute value tells SCEV
3573  // code elsewhere that it shouldn't attempt to request a new
3574  // backedge-taken count, which could result in infinite recursion.
3575  std::pair<std::map<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
3576    BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
3577  if (Pair.second) {
3578    BackedgeTakenInfo BECount = ComputeBackedgeTakenCount(L);
3579    if (BECount.Exact != getCouldNotCompute()) {
3580      assert(BECount.Exact->isLoopInvariant(L) &&
3581             BECount.Max->isLoopInvariant(L) &&
3582             "Computed backedge-taken count isn't loop invariant for loop!");
3583      ++NumTripCountsComputed;
3584
3585      // Update the value in the map.
3586      Pair.first->second = BECount;
3587    } else {
3588      if (BECount.Max != getCouldNotCompute())
3589        // Update the value in the map.
3590        Pair.first->second = BECount;
3591      if (isa<PHINode>(L->getHeader()->begin()))
3592        // Only count loops that have phi nodes as not being computable.
3593        ++NumTripCountsNotComputed;
3594    }
3595
3596    // Now that we know more about the trip count for this loop, forget any
3597    // existing SCEV values for PHI nodes in this loop since they are only
3598    // conservative estimates made without the benefit of trip count
3599    // information. This is similar to the code in forgetLoop, except that
3600    // it handles SCEVUnknown PHI nodes specially.
3601    if (BECount.hasAnyInfo()) {
3602      SmallVector<Instruction *, 16> Worklist;
3603      PushLoopPHIs(L, Worklist);
3604
3605      SmallPtrSet<Instruction *, 8> Visited;
3606      while (!Worklist.empty()) {
3607        Instruction *I = Worklist.pop_back_val();
3608        if (!Visited.insert(I)) continue;
3609
3610        std::map<SCEVCallbackVH, const SCEV *>::iterator It =
3611          Scalars.find(static_cast<Value *>(I));
3612        if (It != Scalars.end()) {
3613          // SCEVUnknown for a PHI either means that it has an unrecognized
3614          // structure, or it's a PHI that's in the process of being computed
3615          // by createNodeForPHI.  In the former case, additional loop trip
3616          // count information isn't going to change anything. In the latter
3617          // case, createNodeForPHI will perform the necessary updates on its
3618          // own when it gets to that point.
3619          if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second)) {
3620            ValuesAtScopes.erase(It->second);
3621            Scalars.erase(It);
3622          }
3623          if (PHINode *PN = dyn_cast<PHINode>(I))
3624            ConstantEvolutionLoopExitValue.erase(PN);
3625        }
3626
3627        PushDefUseChildren(I, Worklist);
3628      }
3629    }
3630  }
3631  return Pair.first->second;
3632}
3633
3634/// forgetLoop - This method should be called by the client when it has
3635/// changed a loop in a way that may affect ScalarEvolution's ability to
3636/// compute a trip count, or if the loop is deleted.
3637void ScalarEvolution::forgetLoop(const Loop *L) {
3638  // Drop any stored trip count value.
3639  BackedgeTakenCounts.erase(L);
3640
3641  // Drop information about expressions based on loop-header PHIs.
3642  SmallVector<Instruction *, 16> Worklist;
3643  PushLoopPHIs(L, Worklist);
3644
3645  SmallPtrSet<Instruction *, 8> Visited;
3646  while (!Worklist.empty()) {
3647    Instruction *I = Worklist.pop_back_val();
3648    if (!Visited.insert(I)) continue;
3649
3650    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
3651      Scalars.find(static_cast<Value *>(I));
3652    if (It != Scalars.end()) {
3653      ValuesAtScopes.erase(It->second);
3654      Scalars.erase(It);
3655      if (PHINode *PN = dyn_cast<PHINode>(I))
3656        ConstantEvolutionLoopExitValue.erase(PN);
3657    }
3658
3659    PushDefUseChildren(I, Worklist);
3660  }
3661}
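
// For example, a transformation pass that unrolls or deletes a loop L
// should call forgetLoop(L), so that stale trip counts and loop-PHI
// expressions are recomputed on the next query.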
3662
3663/// forgetValue - This method should be called by the client when it has
3664/// changed a value in a way that may affect its computed SCEV, or which may
3665/// disconnect it from a def-use chain linking it to a loop.
3666void ScalarEvolution::forgetValue(Value *V) {
3667  Instruction *I = dyn_cast<Instruction>(V);
3668  if (!I) return;
3669
3670  // Drop information about expressions based on loop-header PHIs.
3671  SmallVector<Instruction *, 16> Worklist;
3672  Worklist.push_back(I);
3673
3674  SmallPtrSet<Instruction *, 8> Visited;
3675  while (!Worklist.empty()) {
3676    I = Worklist.pop_back_val();
3677    if (!Visited.insert(I)) continue;
3678
3679    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
3680      Scalars.find(static_cast<Value *>(I));
3681    if (It != Scalars.end()) {
3682      ValuesAtScopes.erase(It->second);
3683      Scalars.erase(It);
3684      if (PHINode *PN = dyn_cast<PHINode>(I))
3685        ConstantEvolutionLoopExitValue.erase(PN);
3686    }
3687
3688    // If there's a SCEVUnknown tying this value into the SCEV
3689    // space, remove it from the folding set map. The SCEVUnknown
3690    // object and any other SCEV objects which reference it
3691    // (transitively) remain allocated, effectively leaked until
3692    // the underlying BumpPtrAllocator is freed.
3693    //
3694    // This permits SCEV pointers to be used as keys in maps
3695    // such as the ValuesAtScopes map.
3696    FoldingSetNodeID ID;
3697    ID.AddInteger(scUnknown);
3698    ID.AddPointer(I);
3699    void *IP;
3700    if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3701      UniqueSCEVs.RemoveNode(S);
3702
3703      // This isn't necessary, but we might as well remove the
3704      // value from the ValuesAtScopes map too.
3705      ValuesAtScopes.erase(S);
3706    }
3707
3708    PushDefUseChildren(I, Worklist);
3709  }
3710}
3711
3712/// ComputeBackedgeTakenCount - Compute the number of times the backedge
3713/// of the specified loop will execute.
3714ScalarEvolution::BackedgeTakenInfo
3715ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
3716  SmallVector<BasicBlock *, 8> ExitingBlocks;
3717  L->getExitingBlocks(ExitingBlocks);
3718
3719  // Examine all exits and pick the most conservative values.
3720  const SCEV *BECount = getCouldNotCompute();
3721  const SCEV *MaxBECount = getCouldNotCompute();
3722  bool CouldNotComputeBECount = false;
3723  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
3724    BackedgeTakenInfo NewBTI =
3725      ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]);
3726
3727    if (NewBTI.Exact == getCouldNotCompute()) {
3728      // We couldn't compute an exact value for this exit, so
3729      // we won't be able to compute an exact value for the loop.
3730      CouldNotComputeBECount = true;
3731      BECount = getCouldNotCompute();
3732    } else if (!CouldNotComputeBECount) {
3733      if (BECount == getCouldNotCompute())
3734        BECount = NewBTI.Exact;
3735      else
3736        BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact);
3737    }
3738    if (MaxBECount == getCouldNotCompute())
3739      MaxBECount = NewBTI.Max;
3740    else if (NewBTI.Max != getCouldNotCompute())
3741      MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max);
3742  }
3743
3744  return BackedgeTakenInfo(BECount, MaxBECount);
3745}
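
// For example, if a loop has two exiting blocks whose individual exact
// counts are 10 and %n, the loop exits at whichever is reached first, so
// the combined exact count is umin(10, %n), computed above with
// getUMinFromMismatchedTypes.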
3746
3747/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge
3748/// of the specified loop will execute if it exits via the specified block.
3749ScalarEvolution::BackedgeTakenInfo
3750ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
3751                                                   BasicBlock *ExitingBlock) {
3752
3753  // Okay, we've chosen an exiting block.  See what condition causes us to
3754  // exit at this block.
3755  //
3756  // FIXME: we should be able to handle switch instructions (with a single exit)
3757  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
3758  if (ExitBr == 0) return getCouldNotCompute();
3759  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
3760
3761  // At this point, we know we have a conditional branch that determines whether
3762  // the loop is exited.  However, we don't know if the branch is executed each
3763  // time through the loop.  If not, then the execution count of the branch will
3764  // not be equal to the trip count of the loop.
3765  //
3766  // Currently we check this by seeing whether the exit branch goes to
3767  // the loop header.  If so, we know it will always execute the same number of
3768  // times as the loop.  We also handle the case where the exit block *is* the
3769  // loop header.  This is common for un-rotated loops.
3770  //
3771  // If both of those tests fail, walk up the unique predecessor chain to the
3772  // header, stopping if there is an edge that doesn't exit the loop. If the
3773  // header is reached, the execution count of the branch will be equal to the
3774  // trip count of the loop.
3775  //
3776  //  More extensive analysis could be done to handle more cases here.
3777  //
3778  if (ExitBr->getSuccessor(0) != L->getHeader() &&
3779      ExitBr->getSuccessor(1) != L->getHeader() &&
3780      ExitBr->getParent() != L->getHeader()) {
3781    // The simple checks failed, try climbing the unique predecessor chain
3782    // up to the header.
3783    bool Ok = false;
3784    for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
3785      BasicBlock *Pred = BB->getUniquePredecessor();
3786      if (!Pred)
3787        return getCouldNotCompute();
3788      TerminatorInst *PredTerm = Pred->getTerminator();
3789      for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
3790        BasicBlock *PredSucc = PredTerm->getSuccessor(i);
3791        if (PredSucc == BB)
3792          continue;
3793        // If the predecessor has a successor that isn't BB and isn't
3794        // outside the loop, assume the worst.
3795        if (L->contains(PredSucc))
3796          return getCouldNotCompute();
3797      }
3798      if (Pred == L->getHeader()) {
3799        Ok = true;
3800        break;
3801      }
3802      BB = Pred;
3803    }
3804    if (!Ok)
3805      return getCouldNotCompute();
3806  }
3807
3808  // Proceed to the next level to examine the exit condition expression.
3809  return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(),
3810                                               ExitBr->getSuccessor(0),
3811                                               ExitBr->getSuccessor(1));
3812}
3813
3814/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the
/// backedge of the specified loop will execute if its exit test were a
/// conditional branch on ExitCond, with true successor TBB and false
/// successor FBB.
3817ScalarEvolution::BackedgeTakenInfo
3818ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
3819                                                       Value *ExitCond,
3820                                                       BasicBlock *TBB,
3821                                                       BasicBlock *FBB) {
3822  // Check if the controlling expression for this loop is an And or Or.
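  //
  // For example, if the branch exits on false and the condition is
  // "and i1 %a, %b", the loop keeps running only while both %a and %b are
  // true, so its backedge-taken count is the umin of the counts computed
  // for %a and %b individually; the Max counts combine the same way.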
3823  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
3824    if (BO->getOpcode() == Instruction::And) {
3825      // Recurse on the operands of the and.
3826      BackedgeTakenInfo BTI0 =
3827        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
3828      BackedgeTakenInfo BTI1 =
3829        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
3830      const SCEV *BECount = getCouldNotCompute();
3831      const SCEV *MaxBECount = getCouldNotCompute();
3832      if (L->contains(TBB)) {
3833        // Both conditions must be true for the loop to continue executing.
3834        // Choose the less conservative count.
3835        if (BTI0.Exact == getCouldNotCompute() ||
3836            BTI1.Exact == getCouldNotCompute())
3837          BECount = getCouldNotCompute();
3838        else
3839          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3840        if (BTI0.Max == getCouldNotCompute())
3841          MaxBECount = BTI1.Max;
3842        else if (BTI1.Max == getCouldNotCompute())
3843          MaxBECount = BTI0.Max;
3844        else
3845          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
3846      } else {
3847        // Both conditions must be true for the loop to exit.
3848        assert(L->contains(FBB) && "Loop block has no successor in loop!");
3849        if (BTI0.Exact != getCouldNotCompute() &&
3850            BTI1.Exact != getCouldNotCompute())
3851          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3852        if (BTI0.Max != getCouldNotCompute() &&
3853            BTI1.Max != getCouldNotCompute())
3854          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
3855      }
3856
3857      return BackedgeTakenInfo(BECount, MaxBECount);
3858    }
3859    if (BO->getOpcode() == Instruction::Or) {
3860      // Recurse on the operands of the or.
3861      BackedgeTakenInfo BTI0 =
3862        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
3863      BackedgeTakenInfo BTI1 =
3864        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
3865      const SCEV *BECount = getCouldNotCompute();
3866      const SCEV *MaxBECount = getCouldNotCompute();
3867      if (L->contains(FBB)) {
3868        // Both conditions must be false for the loop to continue executing.
3869        // Choose the less conservative count.
3870        if (BTI0.Exact == getCouldNotCompute() ||
3871            BTI1.Exact == getCouldNotCompute())
3872          BECount = getCouldNotCompute();
3873        else
3874          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3875        if (BTI0.Max == getCouldNotCompute())
3876          MaxBECount = BTI1.Max;
3877        else if (BTI1.Max == getCouldNotCompute())
3878          MaxBECount = BTI0.Max;
3879        else
3880          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
3881      } else {
3882        // Both conditions must be false for the loop to exit.
3883        assert(L->contains(TBB) && "Loop block has no successor in loop!");
3884        if (BTI0.Exact != getCouldNotCompute() &&
3885            BTI1.Exact != getCouldNotCompute())
3886          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3887        if (BTI0.Max != getCouldNotCompute() &&
3888            BTI1.Max != getCouldNotCompute())
3889          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
3890      }
3891
3892      return BackedgeTakenInfo(BECount, MaxBECount);
3893    }
3894  }
3895
3896  // With an icmp, it may be feasible to compute an exact backedge-taken count.
3897  // Proceed to the next level to examine the icmp.
3898  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
3899    return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB);
3900
3901  // Check for a constant condition. These are normally stripped out by
3902  // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
3903  // preserve the CFG and is temporarily leaving constant conditions
3904  // in place.
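  // If the constant always steers the branch to the in-loop successor
  // (a true condition with TBB inside the loop, or a false condition with
  // FBB inside), the loop never exits through this branch.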
3905  if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
3906    if (L->contains(FBB) == !CI->getZExtValue())
3907      // The backedge is always taken.
3908      return getCouldNotCompute();
3909    else
3910      // The backedge is never taken.
3911      return getConstant(CI->getType(), 0);
3912  }
3913
3914  // If it's not an integer or pointer comparison then compute it the hard way.
3915  return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
3916}
3917
3918/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the
/// backedge of the specified loop will execute if its exit test were a
/// conditional branch on the ICmpInst ExitCond, with true successor TBB and
/// false successor FBB.
3921ScalarEvolution::BackedgeTakenInfo
3922ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
3923                                                           ICmpInst *ExitCond,
3924                                                           BasicBlock *TBB,
3925                                                           BasicBlock *FBB) {
3926
  // If the condition was exit on true, convert the condition to exit on false.
3928  ICmpInst::Predicate Cond;
3929  if (!L->contains(FBB))
3930    Cond = ExitCond->getPredicate();
3931  else
3932    Cond = ExitCond->getInversePredicate();
3933
3934  // Handle common loops like: for (X = "string"; *X; ++X)
3935  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
3936    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
3937      BackedgeTakenInfo ItCnt =
3938        ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
3939      if (ItCnt.hasAnyInfo())
3940        return ItCnt;
3941    }
3942
3943  const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
3944  const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
3945
3946  // Try to evaluate any dependencies out of the loop.
3947  LHS = getSCEVAtScope(LHS, L);
3948  RHS = getSCEVAtScope(RHS, L);
3949
3950  // At this point, we would like to compute how many iterations of the
3951  // loop the predicate will return true for these inputs.
3952  if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
    // If LHS is loop-invariant and RHS is not, swap them so that the
    // loop-invariant operand lands on the RHS.
3954    std::swap(LHS, RHS);
3955    Cond = ICmpInst::getSwappedPredicate(Cond);
3956  }
3957
3958  // Simplify the operands before analyzing them.
3959  (void)SimplifyICmpOperands(Cond, LHS, RHS);
3960
3961  // If we have a comparison of a chrec against a constant, try to use value
3962  // ranges to answer this query.
3963  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
3964    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
3965      if (AddRec->getLoop() == L) {
3966        // Form the constant range.
3967        ConstantRange CompRange(
3968            ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
3969
3970        const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
3971        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
3972      }
3973
3974  switch (Cond) {
3975  case ICmpInst::ICMP_NE: {                     // while (X != Y)
3976    // Convert to: while (X-Y != 0)
3977    BackedgeTakenInfo BTI = HowFarToZero(getMinusSCEV(LHS, RHS), L);
3978    if (BTI.hasAnyInfo()) return BTI;
3979    break;
3980  }
3981  case ICmpInst::ICMP_EQ: {                     // while (X == Y)
3982    // Convert to: while (X-Y == 0)
3983    BackedgeTakenInfo BTI = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
3984    if (BTI.hasAnyInfo()) return BTI;
3985    break;
3986  }
3987  case ICmpInst::ICMP_SLT: {
3988    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true);
3989    if (BTI.hasAnyInfo()) return BTI;
3990    break;
3991  }
3992  case ICmpInst::ICMP_SGT: {
3993    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
3994                                             getNotSCEV(RHS), L, true);
3995    if (BTI.hasAnyInfo()) return BTI;
3996    break;
3997  }
3998  case ICmpInst::ICMP_ULT: {
3999    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false);
4000    if (BTI.hasAnyInfo()) return BTI;
4001    break;
4002  }
4003  case ICmpInst::ICMP_UGT: {
4004    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
4005                                             getNotSCEV(RHS), L, false);
4006    if (BTI.hasAnyInfo()) return BTI;
4007    break;
4008  }
4009  default:
4010#if 0
4011    dbgs() << "ComputeBackedgeTakenCount ";
4012    if (ExitCond->getOperand(0)->getType()->isUnsigned())
4013      dbgs() << "[unsigned] ";
4014    dbgs() << *LHS << "   "
4015         << Instruction::getOpcodeName(Instruction::ICmp)
4016         << "   " << *RHS << "\n";
4017#endif
4018    break;
4019  }
4020  return
4021    ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
4022}
4023
4024static ConstantInt *
4025EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
4026                                ScalarEvolution &SE) {
4027  const SCEV *InVal = SE.getConstant(C);
4028  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
4029  assert(isa<SCEVConstant>(Val) &&
4030         "Evaluation of SCEV at constant didn't fold correctly?");
4031  return cast<SCEVConstant>(Val)->getValue();
4032}
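
// For example, EvaluateConstantChrecAtConstant applied to the chrec
// {1,+,2} and the constant 3 folds to 1 + 2*3 = 7.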
4033
4034/// GetAddressedElementFromGlobal - Given a global variable with an initializer
4035/// and a GEP expression (missing the pointer index) indexing into it, return
4036/// the addressed element of the initializer or null if the index expression is
4037/// invalid.
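///
/// For example, indexing the initializer of
///   @a = global [4 x i32] [i32 10, i32 20, i32 30, i32 40]
/// with the single index 2 yields the constant i32 30.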
4038static Constant *
4039GetAddressedElementFromGlobal(GlobalVariable *GV,
4040                              const std::vector<ConstantInt*> &Indices) {
4041  Constant *Init = GV->getInitializer();
4042  for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
4043    uint64_t Idx = Indices[i]->getZExtValue();
4044    if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
4045      assert(Idx < CS->getNumOperands() && "Bad struct index!");
4046      Init = cast<Constant>(CS->getOperand(Idx));
4047    } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
4048      if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
4049      Init = cast<Constant>(CA->getOperand(Idx));
4050    } else if (isa<ConstantAggregateZero>(Init)) {
4051      if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
4052        assert(Idx < STy->getNumElements() && "Bad struct index!");
4053        Init = Constant::getNullValue(STy->getElementType(Idx));
4054      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
4055        if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
4056        Init = Constant::getNullValue(ATy->getElementType());
4057      } else {
4058        llvm_unreachable("Unknown constant aggregate type!");
4059      }
4061    } else {
4062      return 0; // Unknown initializer type
4063    }
4064  }
4065  return Init;
4066}
4067
4068/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of
4069/// 'icmp op load X, cst', try to see if we can compute the backedge
4070/// execution count.
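///
/// For example, given
///   @s = constant [6 x i8] c"hello\00"
/// and an exit test comparing "load (gep @s, 0, i)" with 0, brute-force
/// evaluation of the index recurrence finds that the loaded character
/// first becomes 0 when i reaches 5.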
4071ScalarEvolution::BackedgeTakenInfo
4072ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount(
4073                                                LoadInst *LI,
4074                                                Constant *RHS,
4075                                                const Loop *L,
4076                                                ICmpInst::Predicate predicate) {
4077  if (LI->isVolatile()) return getCouldNotCompute();
4078
4079  // Check to see if the loaded pointer is a getelementptr of a global.
4080  // TODO: Use SCEV instead of manually grubbing with GEPs.
4081  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
4082  if (!GEP) return getCouldNotCompute();
4083
4084  // Make sure that it is really a constant global we are gepping, with an
4085  // initializer, and make sure the first IDX is really 0.
4086  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
4087  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
4088      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
4089      !cast<Constant>(GEP->getOperand(1))->isNullValue())
4090    return getCouldNotCompute();
4091
4092  // Okay, we allow one non-constant index into the GEP instruction.
4093  Value *VarIdx = 0;
4094  std::vector<ConstantInt*> Indexes;
4095  unsigned VarIdxNum = 0;
4096  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
4097    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
4098      Indexes.push_back(CI);
4099    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
4100      if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
4101      VarIdx = GEP->getOperand(i);
4102      VarIdxNum = i-2;
4103      Indexes.push_back(0);
4104    }
4105
  // If every index was constant, there is no loop-variant index to step
  // through, so give up rather than dereference a null VarIdx below.
  if (!VarIdx)
    return getCouldNotCompute();

  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
  // Check to see if X is a loop-variant value now.
  const SCEV *Idx = getSCEV(VarIdx);
4109  Idx = getSCEVAtScope(Idx, L);
4110
4111  // We can only recognize very limited forms of loop index expressions, in
4112  // particular, only affine AddRec's like {C1,+,C2}.
4113  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
4114  if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
4115      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
4116      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
4117    return getCouldNotCompute();
4118
4119  unsigned MaxSteps = MaxBruteForceIterations;
4120  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
4121    ConstantInt *ItCst = ConstantInt::get(
4122                           cast<IntegerType>(IdxExpr->getType()), IterationNum);
4123    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
4124
4125    // Form the GEP offset.
4126    Indexes[VarIdxNum] = Val;
4127
4128    Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
4129    if (Result == 0) break;  // Cannot compute!
4130
4131    // Evaluate the condition for this iteration.
4132    Result = ConstantExpr::getICmp(predicate, Result, RHS);
4133    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
4134    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
4135#if 0
4136      dbgs() << "\n***\n*** Computed loop count " << *ItCst
4137             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
4138             << "***\n";
4139#endif
4140      ++NumArrayLenItCounts;
4141      return getConstant(ItCst);   // Found terminating iteration!
4142    }
4143  }
4144  return getCouldNotCompute();
4145}
4146
4148/// CanConstantFold - Return true if we can constant fold an instruction of the
4149/// specified type, assuming that all operands were constants.
4150static bool CanConstantFold(const Instruction *I) {
4151  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
4152      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
4153    return true;
4154
4155  if (const CallInst *CI = dyn_cast<CallInst>(I))
4156    if (const Function *F = CI->getCalledFunction())
4157      return canConstantFoldCallTo(F);
4158  return false;
4159}
4160
4161/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
4162/// in the loop that V is derived from.  We allow arbitrary operations along the
4163/// way, but the operands of an operation must either be constants or a value
4164/// derived from a constant PHI.  If this expression does not fit with these
4165/// constraints, return null.
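///
/// For example, given
///   %i = phi i32 [ 0, %entry ], [ %i.next, %loop ]
///   %i.next = add i32 %i, 1
/// the value %i.next evolves from the single header PHI %i, so this
/// returns %i.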
4166static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
4167  // If this is not an instruction, or if this is an instruction outside of the
4168  // loop, it can't be derived from a loop PHI.
4169  Instruction *I = dyn_cast<Instruction>(V);
4170  if (I == 0 || !L->contains(I)) return 0;
4171
4172  if (PHINode *PN = dyn_cast<PHINode>(I)) {
4173    if (L->getHeader() == I->getParent())
4174      return PN;
4175    else
4176      // We don't currently keep track of the control flow needed to evaluate
4177      // PHIs, so we cannot handle PHIs inside of loops.
4178      return 0;
4179  }
4180
4181  // If we won't be able to constant fold this expression even if the operands
4182  // are constants, return early.
4183  if (!CanConstantFold(I)) return 0;
4184
4185  // Otherwise, we can evaluate this instruction if all of its operands are
4186  // constant or derived from a PHI node themselves.
4187  PHINode *PHI = 0;
4188  for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
4189    if (!isa<Constant>(I->getOperand(Op))) {
4190      PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
4191      if (P == 0) return 0;  // Not evolving from PHI
4192      if (PHI == 0)
4193        PHI = P;
4194      else if (PHI != P)
4195        return 0;  // Evolving from multiple different PHIs.
4196    }
4197
  // This is an expression evolving from a constant PHI!
4199  return PHI;
4200}
4201
4202/// EvaluateExpression - Given an expression that passes the
4203/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
4204/// in the loop has the value PHIVal.  If we can't fold this expression for some
4205/// reason, return null.
4206static Constant *EvaluateExpression(Value *V, Constant *PHIVal,
4207                                    const TargetData *TD) {
4208  if (isa<PHINode>(V)) return PHIVal;
4209  if (Constant *C = dyn_cast<Constant>(V)) return C;
4210  Instruction *I = cast<Instruction>(V);
4211
4212  std::vector<Constant*> Operands(I->getNumOperands());
4213
4214  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4215    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal, TD);
4216    if (Operands[i] == 0) return 0;
4217  }
4218
4219  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
4220    return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
4221                                           Operands[1], TD);
4222  return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
4223                                  &Operands[0], Operands.size(), TD);
4224}
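
// Continuing the example above, EvaluateExpression(%i.next, 3, TD)
// substitutes 3 for the PHI and constant-folds the add down to 4.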
4225
/// getConstantEvolutionLoopExitValue - If we know that the specified PHI is
/// in the header of its containing loop, that the loop executes a constant
/// number of times, and that the PHI node is just a recurrence involving
/// constants, fold it to its exit value.
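///
/// For example, for
///   %i = phi i32 [ 0, %entry ], [ %i.next, %loop ]
///   %i.next = add i32 %i, 3
/// and a backedge-taken count of 10, the exit value folds to i32 30.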
4230Constant *
4231ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
4232                                                   const APInt &BEs,
4233                                                   const Loop *L) {
4234  std::map<PHINode*, Constant*>::iterator I =
4235    ConstantEvolutionLoopExitValue.find(PN);
4236  if (I != ConstantEvolutionLoopExitValue.end())
4237    return I->second;
4238
4239  if (BEs.ugt(MaxBruteForceIterations))
4240    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
4241
4242  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
4243
4244  // Since the loop is canonicalized, the PHI node must have two entries.  One
4245  // entry must be a constant (coming in from outside of the loop), and the
4246  // second must be derived from the same PHI.
4247  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4248  Constant *StartCST =
4249    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
4250  if (StartCST == 0)
4251    return RetVal = 0;  // Must be a constant.
4252
4253  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4254  if (getConstantEvolvingPHI(BEValue, L) != PN &&
4255      !isa<Constant>(BEValue))
4256    return RetVal = 0;  // Not derived from same PHI.
4257
4258  // Execute the loop symbolically to determine the exit value.
4259  if (BEs.getActiveBits() >= 32)
4260    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
4261
4262  unsigned NumIterations = BEs.getZExtValue(); // must be in range
4263  unsigned IterationNum = 0;
4264  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
4265    if (IterationNum == NumIterations)
4266      return RetVal = PHIVal;  // Got exit value!
4267
4268    // Compute the value of the PHI node for the next iteration.
4269    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD);
4270    if (NextPHI == PHIVal)
4271      return RetVal = NextPHI;  // Stopped evolving!
4272    if (NextPHI == 0)
4273      return 0;        // Couldn't evaluate!
4274    PHIVal = NextPHI;
4275  }
4276}
4277
4278/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a
4279/// constant number of times (the condition evolves only from constants),
/// try to evaluate a few iterations of the loop until the exit condition
/// takes on the value ExitWhen (true or false).  If we cannot evaluate the
/// trip count of the loop, return getCouldNotCompute().
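///
/// For example, with %i = phi i32 [ 1, %entry ], [ %i3, %loop ],
/// %i3 = mul i32 %i, 3, and an exit condition of "icmp ugt i32 %i, 20",
/// brute-force evaluation sees 1, 3, 9, 27 and returns a backedge-taken
/// count of 3.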
4283const SCEV *
4284ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L,
4285                                                       Value *Cond,
4286                                                       bool ExitWhen) {
4287  PHINode *PN = getConstantEvolvingPHI(Cond, L);
4288  if (PN == 0) return getCouldNotCompute();
4289
4290  // If the loop is canonicalized, the PHI will have exactly two entries.
4291  // That's the only form we support here.
4292  if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
4293
4294  // One entry must be a constant (coming in from outside of the loop), and the
4295  // second must be derived from the same PHI.
4296  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4297  Constant *StartCST =
4298    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
4299  if (StartCST == 0) return getCouldNotCompute();  // Must be a constant.
4300
4301  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4302  if (getConstantEvolvingPHI(BEValue, L) != PN &&
4303      !isa<Constant>(BEValue))
4304    return getCouldNotCompute();  // Not derived from same PHI.
4305
  // Okay, we found a PHI node that defines the trip count of this loop.  Execute
4307  // the loop symbolically to determine when the condition gets a value of
4308  // "ExitWhen".
4309  unsigned IterationNum = 0;
4310  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
4311  for (Constant *PHIVal = StartCST;
4312       IterationNum != MaxIterations; ++IterationNum) {
4313    ConstantInt *CondVal =
4314      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal, TD));
4315
4316    // Couldn't symbolically evaluate.
4317    if (!CondVal) return getCouldNotCompute();
4318
4319    if (CondVal->getValue() == uint64_t(ExitWhen)) {
4320      ++NumBruteForceTripCountsComputed;
4321      return getConstant(Type::getInt32Ty(getContext()), IterationNum);
4322    }
4323
4324    // Compute the value of the PHI node for the next iteration.
4325    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD);
4326    if (NextPHI == 0 || NextPHI == PHIVal)
4327      return getCouldNotCompute();// Couldn't evaluate or not making progress...
4328    PHIVal = NextPHI;
4329  }
4330
4331  // Too many iterations were needed to evaluate.
4332  return getCouldNotCompute();
4333}
4334
/// getSCEVAtScope - Return a SCEV expression for the specified value
/// at the specified scope in the program.  The L value specifies the loop
/// nest in which to evaluate the expression: null means the top level, and
/// a non-null loop means the scope immediately inside that loop.
4339///
4340/// This method can be used to compute the exit value for a variable defined
4341/// in a loop by querying what the value will hold in the parent loop.
4342///
4343/// In the case that a relevant loop exit value cannot be computed, the
4344/// original value V is returned.
4345const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
4346  // Check to see if we've folded this expression at this loop before.
4347  std::map<const Loop *, const SCEV *> &Values = ValuesAtScopes[V];
4348  std::pair<std::map<const Loop *, const SCEV *>::iterator, bool> Pair =
4349    Values.insert(std::make_pair(L, static_cast<const SCEV *>(0)));
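  // A null mapped value doubles as an in-progress sentinel: it is inserted
  // before computing, so a recursive query for the same (V, L) pair sees it
  // and conservatively returns V instead of recursing forever.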
4350  if (!Pair.second)
4351    return Pair.first->second ? Pair.first->second : V;
4352
4353  // Otherwise compute it.
4354  const SCEV *C = computeSCEVAtScope(V, L);
4355  ValuesAtScopes[V][L] = C;
4356  return C;
4357}
4358
4359const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
4360  if (isa<SCEVConstant>(V)) return V;
4361
4362  // If this instruction is evolved from a constant-evolving PHI, compute the
4363  // exit value from the loop without using SCEVs.
4364  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
4365    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
4366      const Loop *LI = (*this->LI)[I->getParent()];
4367      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
4368        if (PHINode *PN = dyn_cast<PHINode>(I))
4369          if (PN->getParent() == LI->getHeader()) {
4370            // Okay, there is no closed form solution for the PHI node.  Check
4371            // to see if the loop that contains it has a known backedge-taken
4372            // count.  If so, we may be able to force computation of the exit
4373            // value.
4374            const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
4375            if (const SCEVConstant *BTCC =
4376                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
4377              // Okay, we know how many times the containing loop executes.  If
4378              // this is a constant evolving PHI node, get the final value at
4379              // the specified iteration number.
4380              Constant *RV = getConstantEvolutionLoopExitValue(PN,
4381                                                   BTCC->getValue()->getValue(),
4382                                                               LI);
4383              if (RV) return getSCEV(RV);
4384            }
4385          }
4386
4387      // Okay, this is an expression that we cannot symbolically evaluate
4388      // into a SCEV.  Check to see if it's possible to symbolically evaluate
4389      // the arguments into constants, and if so, try to constant propagate the
4390      // result.  This is particularly useful for computing loop exit values.
4391      if (CanConstantFold(I)) {
4392        SmallVector<Constant *, 4> Operands;
4393        bool MadeImprovement = false;
4394        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4395          Value *Op = I->getOperand(i);
4396          if (Constant *C = dyn_cast<Constant>(Op)) {
4397            Operands.push_back(C);
4398            continue;
4399          }
4400
          // If the operand is non-constant and its type is neither integer
          // nor pointer, it is not SCEVable, so don't even try to analyze it
          // with SCEV techniques.
4404          if (!isSCEVable(Op->getType()))
4405            return V;
4406
4407          const SCEV *OrigV = getSCEV(Op);
4408          const SCEV *OpV = getSCEVAtScope(OrigV, L);
4409          MadeImprovement |= OrigV != OpV;
4410
4411          Constant *C = 0;
4412          if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV))
4413            C = SC->getValue();
4414          if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV))
4415            C = dyn_cast<Constant>(SU->getValue());
4416          if (!C) return V;
4417          if (C->getType() != Op->getType())
4418            C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
4419                                                              Op->getType(),
4420                                                              false),
4421                                      C, Op->getType());
4422          Operands.push_back(C);
4423        }
4424
4425        // Check to see if getSCEVAtScope actually made an improvement.
4426        if (MadeImprovement) {
4427          Constant *C = 0;
4428          if (const CmpInst *CI = dyn_cast<CmpInst>(I))
4429            C = ConstantFoldCompareInstOperands(CI->getPredicate(),
4430                                                Operands[0], Operands[1], TD);
4431          else
4432            C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
4433                                         &Operands[0], Operands.size(), TD);
4434          if (!C) return V;
4435          return getSCEV(C);
4436        }
4437      }
4438    }
4439
4440    // This is some other type of SCEVUnknown, just return it.
4441    return V;
4442  }
4443
4444  if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
4445    // Avoid performing the look-up in the common case where the specified
4446    // expression has no loop-variant portions.
4447    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
4448      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
4449      if (OpAtScope != Comm->getOperand(i)) {
4450        // Okay, at least one of these operands is loop variant but might be
4451        // foldable.  Build a new instance of the folded commutative expression.
4452        SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
4453                                            Comm->op_begin()+i);
4454        NewOps.push_back(OpAtScope);
4455
4456        for (++i; i != e; ++i) {
4457          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
4458          NewOps.push_back(OpAtScope);
4459        }
4460        if (isa<SCEVAddExpr>(Comm))
4461          return getAddExpr(NewOps);
4462        if (isa<SCEVMulExpr>(Comm))
4463          return getMulExpr(NewOps);
4464        if (isa<SCEVSMaxExpr>(Comm))
4465          return getSMaxExpr(NewOps);
4466        if (isa<SCEVUMaxExpr>(Comm))
4467          return getUMaxExpr(NewOps);
4468        llvm_unreachable("Unknown commutative SCEV type!");
4469      }
4470    }
4471    // If we got here, all operands are loop invariant.
4472    return Comm;
4473  }
4474
4475  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
4476    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
4477    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
4478    if (LHS == Div->getLHS() && RHS == Div->getRHS())
4479      return Div;   // must be loop invariant
4480    return getUDivExpr(LHS, RHS);
4481  }
4482
4483  // If this is a loop recurrence for a loop that does not contain L, then we
4484  // are dealing with the final value computed by the loop.
4485  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4486    // First, attempt to evaluate each operand.
4487    // Avoid performing the look-up in the common case where the specified
4488    // expression has no loop-variant portions.
4489    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
4490      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
4491      if (OpAtScope == AddRec->getOperand(i))
4492        continue;
4493
4494      // Okay, at least one of these operands is loop variant but might be
4495      // foldable.  Build a new instance of the folded commutative expression.
4496      SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
4497                                          AddRec->op_begin()+i);
4498      NewOps.push_back(OpAtScope);
4499      for (++i; i != e; ++i)
4500        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
4501
4502      AddRec = cast<SCEVAddRecExpr>(getAddRecExpr(NewOps, AddRec->getLoop()));
4503      break;
4504    }
4505
4506    // If the scope is outside the addrec's loop, evaluate it by using the
4507    // loop exit value of the addrec.
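    //
    // For example, asking for {0,+,2}<L> at the scope of L's parent loop
    // yields 2 * BackedgeTakenCount(L), the value the recurrence holds
    // when L finishes.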
4508    if (!AddRec->getLoop()->contains(L)) {
4509      // To evaluate this recurrence, we need to know how many times the AddRec
4510      // loop iterates.  Compute this now.
4511      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
4512      if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
4513
4514      // Then, evaluate the AddRec.
4515      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
4516    }
4517
4518    return AddRec;
4519  }
4520
4521  if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
4522    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4523    if (Op == Cast->getOperand())
4524      return Cast;  // must be loop invariant
4525    return getZeroExtendExpr(Op, Cast->getType());
4526  }
4527
4528  if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
4529    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4530    if (Op == Cast->getOperand())
4531      return Cast;  // must be loop invariant
4532    return getSignExtendExpr(Op, Cast->getType());
4533  }
4534
4535  if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
4536    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4537    if (Op == Cast->getOperand())
4538      return Cast;  // must be loop invariant
4539    return getTruncateExpr(Op, Cast->getType());
4540  }
4541
4542  llvm_unreachable("Unknown SCEV type!");
4543  return 0;
4544}
4545
4546/// getSCEVAtScope - This is a convenience function which does
4547/// getSCEVAtScope(getSCEV(V), L).
4548const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
4549  return getSCEVAtScope(getSCEV(V), L);
4550}
4551
4552/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
4553/// following equation:
4554///
4555///     A * X = B (mod N)
4556///
4557/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
4558/// A and B isn't important.
4559///
4560/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
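///
/// For example, with A = 4, B = 8, and BW = 4 (so N = 16): D = gcd(4, 16)
/// = 4, B is divisible by D, and the minimum unsigned root is X = 2 (the
/// full solution set being {2, 6, 10, 14}).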
4561static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
4562                                               ScalarEvolution &SE) {
4563  uint32_t BW = A.getBitWidth();
4564  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
4565  assert(A != 0 && "A must be non-zero.");
4566
4567  // 1. D = gcd(A, N)
4568  //
4569  // The gcd of A and N may have only one prime factor: 2. The number of
  // trailing zeros in A is its multiplicity.
4571  uint32_t Mult2 = A.countTrailingZeros();
4572  // D = 2^Mult2
4573
4574  // 2. Check if B is divisible by D.
4575  //
4576  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
4577  // is not less than multiplicity of this prime factor for D.
4578  if (B.countTrailingZeros() < Mult2)
4579    return SE.getCouldNotCompute();
4580
4581  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
4582  // modulo (N / D).
4583  //
4584  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
4585  // bit width during computations.
4586  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
4587  APInt Mod(BW + 1, 0);
4588  Mod.set(BW - Mult2);  // Mod = N / D
4589  APInt I = AD.multiplicativeInverse(Mod);
4590
4591  // 4. Compute the minimum unsigned root of the equation:
4592  // I * (B / D) mod (N / D)
4593  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
4594
4595  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
4596  // bits.
4597  return SE.getConstant(Result.trunc(BW));
4598}
4599
4600/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
4601/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
4602/// might be the same) or two SCEVCouldNotCompute objects.
4603///
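/// The chrec {L,+,M,+,N} evaluated at iteration x equals
/// L + M*x + N*x*(x-1)/2, which is why the code below takes the polynomial
/// coefficients to be A = N/2, B = M - N/2, and C = L.
///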
4604static std::pair<const SCEV *,const SCEV *>
4605SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
4606  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
4607  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
4608  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
4609  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
4610
4611  // We currently can only solve this if the coefficients are constants.
4612  if (!LC || !MC || !NC) {
4613    const SCEV *CNC = SE.getCouldNotCompute();
4614    return std::make_pair(CNC, CNC);
4615  }
4616
4617  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
4618  const APInt &L = LC->getValue()->getValue();
4619  const APInt &M = MC->getValue()->getValue();
4620  const APInt &N = NC->getValue()->getValue();
4621  APInt Two(BitWidth, 2);
4622  APInt Four(BitWidth, 4);
4623
4624  {
4625    using namespace APIntOps;
4626    const APInt& C = L;
4627    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
4628    // The B coefficient is M-N/2
4629    APInt B(M);
4630    B -= sdiv(N,Two);
4631
4632    // The A coefficient is N/2
4633    APInt A(N.sdiv(Two));
4634
4635    // Compute the B^2-4ac term.
4636    APInt SqrtTerm(B);
4637    SqrtTerm *= B;
4638    SqrtTerm -= Four * (A * C);
4639
4640    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
4641    // integer value or else APInt::sqrt() will assert.
4642    APInt SqrtVal(SqrtTerm.sqrt());
4643
4644    // Compute the two solutions for the quadratic formula.
4645    // The divisions must be performed as signed divisions.
4646    APInt NegB(-B);
4647    APInt TwoA( A << 1 );
4648    if (TwoA.isMinValue()) {
4649      const SCEV *CNC = SE.getCouldNotCompute();
4650      return std::make_pair(CNC, CNC);
4651    }
4652
4653    LLVMContext &Context = SE.getContext();
4654
4655    ConstantInt *Solution1 =
4656      ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
4657    ConstantInt *Solution2 =
4658      ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
4659
4660    return std::make_pair(SE.getConstant(Solution1),
4661                          SE.getConstant(Solution2));
  } // end APIntOps block
4663}
4664
4665/// HowFarToZero - Return the number of times a backedge comparing the specified
4666/// value to zero will execute.  If not computable, return CouldNotCompute.
4667ScalarEvolution::BackedgeTakenInfo
4668ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
  // If the value is a constant, the answer is immediate.
4670  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
4671    // If the value is already zero, the branch will execute zero times.
4672    if (C->getValue()->isZero()) return C;
4673    return getCouldNotCompute();  // Otherwise it will loop infinitely.
4674  }
4675
4676  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
4677  if (!AddRec || AddRec->getLoop() != L)
4678    return getCouldNotCompute();
4679
4680  if (AddRec->isAffine()) {
4681    // If this is an affine expression, the execution count of this branch is
4682    // the minimum unsigned root of the following equation:
4683    //
4684    //     Start + Step*N = 0 (mod 2^BW)
4685    //
4686    // equivalent to:
4687    //
4688    //             Step*N = -Start (mod 2^BW)
4689    //
4690    // where BW is the common bit width of Start and Step.
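    //
    // For example, {8,+,-2} reaches zero after four backedges, since
    // 8 + (-2)*4 = 0, so N = 4 is the answer.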
4691
4692    // Get the initial value for the loop.
4693    const SCEV *Start = getSCEVAtScope(AddRec->getStart(),
4694                                       L->getParentLoop());
4695    const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1),
4696                                      L->getParentLoop());
4697
4698    if (const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
4699      // For now we handle only constant steps.
4700
4701      // First, handle unitary steps.
4702      if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
4703        return getNegativeSCEV(Start);          //   N = -Start (as unsigned)
4704      if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
4705        return Start;                           //    N = Start (as unsigned)
4706
4707      // Then, try to solve the above equation provided that Start is constant.
4708      if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
4709        return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
4710                                            -StartC->getValue()->getValue(),
4711                                            *this);
4712    }
4713  } else if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
4714    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
4715    // the quadratic equation to solve it.
4716    std::pair<const SCEV *,const SCEV *> Roots = SolveQuadraticEquation(AddRec,
4717                                                                    *this);
4718    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
4719    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
4720    if (R1) {
4721#if 0
4722      dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
4723             << "  sol#2: " << *R2 << "\n";
4724#endif
4725      // Pick the smallest positive root value.
4726      if (ConstantInt *CB =
4727          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
4728                                   R1->getValue(), R2->getValue()))) {
        if (!CB->getZExtValue())
4730          std::swap(R1, R2);   // R1 is the minimum root now.
4731
4732        // We can only use this value if the chrec ends up with an exact zero
4733        // value at this index.  When solving for "X*X != 5", for example, we
4734        // should not accept a root of 2.
4735        const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
4736        if (Val->isZero())
4737          return R1;  // We found a quadratic root!
4738      }
4739    }
4740  }
4741
4742  return getCouldNotCompute();
4743}
4744
4745/// HowFarToNonZero - Return the number of times a backedge checking the
4746/// specified value for nonzero will execute.  If not computable, return
4747/// CouldNotCompute
4748ScalarEvolution::BackedgeTakenInfo
4749ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
4750  // Loops that look like: while (X == 0) are very strange indeed.  We don't
4751  // handle them yet except for the trivial case.  This could be expanded in the
4752  // future as needed.
4753
4754  // If the value is a constant, check to see if it is known to be non-zero
4755  // already.  If so, the backedge will execute zero times.
4756  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
4757    if (!C->getValue()->isNullValue())
4758      return getConstant(C->getType(), 0);
4759    return getCouldNotCompute();  // Otherwise it will loop infinitely.
4760  }
4761
4762  // We could implement others, but I really doubt anyone writes loops like
4763  // this, and if they did, they would already be constant folded.
4764  return getCouldNotCompute();
4765}
4766
/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
/// (which may not be an immediate predecessor) which has exactly one
/// successor from which BB is reachable, paired with that successor; if no
/// such block is found, return a pair of null pointers.
4771///
4772std::pair<BasicBlock *, BasicBlock *>
4773ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
4774  // If the block has a unique predecessor, then there is no path from the
4775  // predecessor to the block that does not go through the direct edge
4776  // from the predecessor to the block.
4777  if (BasicBlock *Pred = BB->getSinglePredecessor())
4778    return std::make_pair(Pred, BB);
4779
4780  // A loop's header is defined to be a block that dominates the loop.
4781  // If the header has a unique predecessor outside the loop, it must be
4782  // a block that has exactly one successor that can reach the loop.
4783  if (Loop *L = LI->getLoopFor(BB))
4784    return std::make_pair(L->getLoopPredecessor(), L->getHeader());
4785
4786  return std::pair<BasicBlock *, BasicBlock *>();
4787}
4788
4789/// HasSameValue - SCEV structural equivalence is usually sufficient for
4790/// testing whether two expressions are equal, however for the purposes of
4791/// looking for a condition guarding a loop, it can be useful to be a little
4792/// more general, since a front-end may have replicated the controlling
4793/// expression.
4794///
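/// For example, two identical readnone calls to the same function are
/// distinct SCEVUnknowns, yet isIdenticalTo recognizes that they compute
/// the same value.
///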
4795static bool HasSameValue(const SCEV *A, const SCEV *B) {
4796  // Quick check to see if they are the same SCEV.
4797  if (A == B) return true;
4798
4799  // Otherwise, if they're both SCEVUnknown, it's possible that they hold
4800  // two different instructions with the same value. Check for this case.
4801  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
4802    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
4803      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
4804        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
4805          if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory())
4806            return true;
4807
4808  // Otherwise assume they may have a different value.
4809  return false;
4810}
4811
4812/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
4813/// predicate Pred. Return true iff any changes were made.
4814///
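/// For example, "icmp ule %x, 7" canonicalizes to "icmp ult %x, 8", and a
/// comparison that can never hold, such as "icmp ult %x, 0", is replaced
/// by the trivially false "0 != 0" form.
///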
4815bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
4816                                           const SCEV *&LHS, const SCEV *&RHS) {
4817  bool Changed = false;
4818
4819  // Canonicalize a constant to the right side.
4820  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
4821    // Check for both operands constant.
4822    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
4823      if (ConstantExpr::getICmp(Pred,
4824                                LHSC->getValue(),
4825                                RHSC->getValue())->isNullValue())
4826        goto trivially_false;
4827      else
4828        goto trivially_true;
4829    }
4830    // Otherwise swap the operands to put the constant on the right.
4831    std::swap(LHS, RHS);
4832    Pred = ICmpInst::getSwappedPredicate(Pred);
4833    Changed = true;
4834  }
4835
4836  // If we're comparing an addrec with a value which is loop-invariant in the
4837  // addrec's loop, put the addrec on the left. Also make a dominance check,
4838  // as both operands could be addrecs loop-invariant in each other's loop.
4839  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
4840    const Loop *L = AR->getLoop();
4841    if (LHS->isLoopInvariant(L) && LHS->properlyDominates(L->getHeader(), DT)) {
4842      std::swap(LHS, RHS);
4843      Pred = ICmpInst::getSwappedPredicate(Pred);
4844      Changed = true;
4845    }
4846  }
4847
4848  // If there's a constant operand, canonicalize comparisons with boundary
4849  // cases, and canonicalize *-or-equal comparisons to regular comparisons.
4850  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
4851    const APInt &RA = RC->getValue()->getValue();
4852    switch (Pred) {
4853    default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
4854    case ICmpInst::ICMP_EQ:
4855    case ICmpInst::ICMP_NE:
4856      break;
4857    case ICmpInst::ICMP_UGE:
4858      if ((RA - 1).isMinValue()) {
4859        Pred = ICmpInst::ICMP_NE;
4860        RHS = getConstant(RA - 1);
4861        Changed = true;
4862        break;
4863      }
4864      if (RA.isMaxValue()) {
4865        Pred = ICmpInst::ICMP_EQ;
4866        Changed = true;
4867        break;
4868      }
4869      if (RA.isMinValue()) goto trivially_true;
4870
4871      Pred = ICmpInst::ICMP_UGT;
4872      RHS = getConstant(RA - 1);
4873      Changed = true;
4874      break;
4875    case ICmpInst::ICMP_ULE:
4876      if ((RA + 1).isMaxValue()) {
4877        Pred = ICmpInst::ICMP_NE;
4878        RHS = getConstant(RA + 1);
4879        Changed = true;
4880        break;
4881      }
4882      if (RA.isMinValue()) {
4883        Pred = ICmpInst::ICMP_EQ;
4884        Changed = true;
4885        break;
4886      }
4887      if (RA.isMaxValue()) goto trivially_true;
4888
4889      Pred = ICmpInst::ICMP_ULT;
4890      RHS = getConstant(RA + 1);
4891      Changed = true;
4892      break;
4893    case ICmpInst::ICMP_SGE:
4894      if ((RA - 1).isMinSignedValue()) {
4895        Pred = ICmpInst::ICMP_NE;
4896        RHS = getConstant(RA - 1);
4897        Changed = true;
4898        break;
4899      }
4900      if (RA.isMaxSignedValue()) {
4901        Pred = ICmpInst::ICMP_EQ;
4902        Changed = true;
4903        break;
4904      }
4905      if (RA.isMinSignedValue()) goto trivially_true;
4906
4907      Pred = ICmpInst::ICMP_SGT;
4908      RHS = getConstant(RA - 1);
4909      Changed = true;
4910      break;
4911    case ICmpInst::ICMP_SLE:
4912      if ((RA + 1).isMaxSignedValue()) {
4913        Pred = ICmpInst::ICMP_NE;
4914        RHS = getConstant(RA + 1);
4915        Changed = true;
4916        break;
4917      }
4918      if (RA.isMinSignedValue()) {
4919        Pred = ICmpInst::ICMP_EQ;
4920        Changed = true;
4921        break;
4922      }
4923      if (RA.isMaxSignedValue()) goto trivially_true;
4924
4925      Pred = ICmpInst::ICMP_SLT;
4926      RHS = getConstant(RA + 1);
4927      Changed = true;
4928      break;
4929    case ICmpInst::ICMP_UGT:
4930      if (RA.isMinValue()) {
4931        Pred = ICmpInst::ICMP_NE;
4932        Changed = true;
4933        break;
4934      }
4935      if ((RA + 1).isMaxValue()) {
4936        Pred = ICmpInst::ICMP_EQ;
4937        RHS = getConstant(RA + 1);
4938        Changed = true;
4939        break;
4940      }
4941      if (RA.isMaxValue()) goto trivially_false;
4942      break;
4943    case ICmpInst::ICMP_ULT:
4944      if (RA.isMaxValue()) {
4945        Pred = ICmpInst::ICMP_NE;
4946        Changed = true;
4947        break;
4948      }
4949      if ((RA - 1).isMinValue()) {
4950        Pred = ICmpInst::ICMP_EQ;
4951        RHS = getConstant(RA - 1);
4952        Changed = true;
4953        break;
4954      }
4955      if (RA.isMinValue()) goto trivially_false;
4956      break;
4957    case ICmpInst::ICMP_SGT:
4958      if (RA.isMinSignedValue()) {
4959        Pred = ICmpInst::ICMP_NE;
4960        Changed = true;
4961        break;
4962      }
4963      if ((RA + 1).isMaxSignedValue()) {
4964        Pred = ICmpInst::ICMP_EQ;
4965        RHS = getConstant(RA + 1);
4966        Changed = true;
4967        break;
4968      }
4969      if (RA.isMaxSignedValue()) goto trivially_false;
4970      break;
4971    case ICmpInst::ICMP_SLT:
4972      if (RA.isMaxSignedValue()) {
4973        Pred = ICmpInst::ICMP_NE;
4974        Changed = true;
4975        break;
4976      }
      if ((RA - 1).isMinSignedValue()) {
        Pred = ICmpInst::ICMP_EQ;
        RHS = getConstant(RA - 1);
        Changed = true;
        break;
      }
4983      if (RA.isMinSignedValue()) goto trivially_false;
4984      break;
4985    }
4986  }
4987
4988  // Check for obvious equality.
4989  if (HasSameValue(LHS, RHS)) {
4990    if (ICmpInst::isTrueWhenEqual(Pred))
4991      goto trivially_true;
4992    if (ICmpInst::isFalseWhenEqual(Pred))
4993      goto trivially_false;
4994  }
4995
4996  // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
4997  // adding or subtracting 1 from one of the operands.
4998  switch (Pred) {
4999  case ICmpInst::ICMP_SLE:
5000    if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
5001      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5002                       /*HasNUW=*/false, /*HasNSW=*/true);
5003      Pred = ICmpInst::ICMP_SLT;
5004      Changed = true;
5005    } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
5006      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5007                       /*HasNUW=*/false, /*HasNSW=*/true);
5008      Pred = ICmpInst::ICMP_SLT;
5009      Changed = true;
5010    }
5011    break;
5012  case ICmpInst::ICMP_SGE:
5013    if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
5014      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5015                       /*HasNUW=*/false, /*HasNSW=*/true);
5016      Pred = ICmpInst::ICMP_SGT;
5017      Changed = true;
5018    } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
5019      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5020                       /*HasNUW=*/false, /*HasNSW=*/true);
5021      Pred = ICmpInst::ICMP_SGT;
5022      Changed = true;
5023    }
5024    break;
5025  case ICmpInst::ICMP_ULE:
5026    if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
5027      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5028                       /*HasNUW=*/true, /*HasNSW=*/false);
5029      Pred = ICmpInst::ICMP_ULT;
5030      Changed = true;
5031    } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
5032      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5033                       /*HasNUW=*/true, /*HasNSW=*/false);
5034      Pred = ICmpInst::ICMP_ULT;
5035      Changed = true;
5036    }
5037    break;
5038  case ICmpInst::ICMP_UGE:
5039    if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
5040      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5041                       /*HasNUW=*/true, /*HasNSW=*/false);
5042      Pred = ICmpInst::ICMP_UGT;
5043      Changed = true;
5044    } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
5045      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5046                       /*HasNUW=*/true, /*HasNSW=*/false);
5047      Pred = ICmpInst::ICMP_UGT;
5048      Changed = true;
5049    }
5050    break;
5051  default:
5052    break;
5053  }
5054
5055  // TODO: More simplifications are possible here.
5056
5057  return Changed;
5058
5059trivially_true:
5060  // Return 0 == 0.
5061  LHS = RHS = getConstant(Type::getInt1Ty(getContext()), 0);
5062  Pred = ICmpInst::ICMP_EQ;
5063  return true;
5064
5065trivially_false:
5066  // Return 0 != 0.
5067  LHS = RHS = getConstant(Type::getInt1Ty(getContext()), 0);
5068  Pred = ICmpInst::ICMP_NE;
5069  return true;
5070}
5071
5072bool ScalarEvolution::isKnownNegative(const SCEV *S) {
5073  return getSignedRange(S).getSignedMax().isNegative();
5074}
5075
5076bool ScalarEvolution::isKnownPositive(const SCEV *S) {
5077  return getSignedRange(S).getSignedMin().isStrictlyPositive();
5078}
5079
5080bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
5081  return !getSignedRange(S).getSignedMin().isNegative();
5082}
5083
5084bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
5085  return !getSignedRange(S).getSignedMax().isStrictlyPositive();
5086}
5087
5088bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
5089  return isKnownNegative(S) || isKnownPositive(S);
5090}
5091
5092bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
5093                                       const SCEV *LHS, const SCEV *RHS) {
5094  // Canonicalize the inputs first.
5095  (void)SimplifyICmpOperands(Pred, LHS, RHS);
5096
5097  // If LHS or RHS is an addrec, check to see if the condition is true in
5098  // every iteration of the loop.
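  //
  // For example, proving "{1,+,1}<L> >s 0" reduces to showing that the
  // start value satisfies "1 >s 0" on entry to L and that the
  // post-increment expression "{2,+,1}<L>" satisfies ">s 0" whenever the
  // backedge is taken.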
5099  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
5100    if (isLoopEntryGuardedByCond(
5101          AR->getLoop(), Pred, AR->getStart(), RHS) &&
5102        isLoopBackedgeGuardedByCond(
5103          AR->getLoop(), Pred, AR->getPostIncExpr(*this), RHS))
5104      return true;
5105  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS))
5106    if (isLoopEntryGuardedByCond(
5107          AR->getLoop(), Pred, LHS, AR->getStart()) &&
5108        isLoopBackedgeGuardedByCond(
5109          AR->getLoop(), Pred, LHS, AR->getPostIncExpr(*this)))
5110      return true;
5111
5112  // Otherwise see what can be done with known constant ranges.
5113  return isKnownPredicateWithRanges(Pred, LHS, RHS);
5114}
5115
5116bool
5117ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
5118                                            const SCEV *LHS, const SCEV *RHS) {
5119  if (HasSameValue(LHS, RHS))
5120    return ICmpInst::isTrueWhenEqual(Pred);
5121
5122  // This code is split out from isKnownPredicate because it is called from
5123  // within isLoopEntryGuardedByCond.
5124  switch (Pred) {
5125  default:
5126    llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5127    break;
5128  case ICmpInst::ICMP_SGT:
5129    Pred = ICmpInst::ICMP_SLT;
5130    std::swap(LHS, RHS);  // FALL THROUGH
5131  case ICmpInst::ICMP_SLT: {
5132    ConstantRange LHSRange = getSignedRange(LHS);
5133    ConstantRange RHSRange = getSignedRange(RHS);
5134    if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
5135      return true;
5136    if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
5137      return false;
5138    break;
5139  }
5140  case ICmpInst::ICMP_SGE:
5141    Pred = ICmpInst::ICMP_SLE;
5142    std::swap(LHS, RHS);  // FALL THROUGH
5143  case ICmpInst::ICMP_SLE: {
5144    ConstantRange LHSRange = getSignedRange(LHS);
5145    ConstantRange RHSRange = getSignedRange(RHS);
5146    if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
5147      return true;
5148    if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
5149      return false;
5150    break;
5151  }
5152  case ICmpInst::ICMP_UGT:
5153    Pred = ICmpInst::ICMP_ULT;
5154    std::swap(LHS, RHS);  // FALL THROUGH
5155  case ICmpInst::ICMP_ULT: {
5156    ConstantRange LHSRange = getUnsignedRange(LHS);
5157    ConstantRange RHSRange = getUnsignedRange(RHS);
5158    if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
5159      return true;
5160    if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
5161      return false;
5162    break;
5163  }
5164  case ICmpInst::ICMP_UGE:
5165    Pred = ICmpInst::ICMP_ULE;
5166    std::swap(LHS, RHS);  // FALL THROUGH
5167  case ICmpInst::ICMP_ULE: {
5168    ConstantRange LHSRange = getUnsignedRange(LHS);
5169    ConstantRange RHSRange = getUnsignedRange(RHS);
5170    if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
5171      return true;
5172    if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
5173      return false;
5174    break;
5175  }
5176  case ICmpInst::ICMP_NE: {
5177    if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
5178      return true;
5179    if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
5180      return true;
5181
5182    const SCEV *Diff = getMinusSCEV(LHS, RHS);
5183    if (isKnownNonZero(Diff))
5184      return true;
5185    break;
5186  }
5187  case ICmpInst::ICMP_EQ:
5188    // The check at the top of the function catches the case where
5189    // the values are known to be equal.
5190    break;
5191  }
5192  return false;
5193}
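
// For example, if getUnsignedRange(LHS) is [0, 8) and
// getUnsignedRange(RHS) is [8, 16), then LHS's unsigned maximum (7) is
// ult RHS's unsigned minimum (8), so ICMP_ULT is known true. The same
// ranges prove ICMP_UGE false: after canonicalizing to ICMP_ULE and
// swapping, the new LHS's minimum (8) is ugt the new RHS's maximum (7).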
5194
5195/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
5196/// protected by a conditional between LHS and RHS.  This is used to
5197/// eliminate casts.
5198bool
5199ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
5200                                             ICmpInst::Predicate Pred,
5201                                             const SCEV *LHS, const SCEV *RHS) {
5202  // Interpret a null as meaning no loop: with no backedge to guard, the
5203  // condition holds trivially (interprocedural conditions notwithstanding).
5204  if (!L) return true;
5205
5206  BasicBlock *Latch = L->getLoopLatch();
5207  if (!Latch)
5208    return false;
5209
5210  BranchInst *LoopContinuePredicate =
5211    dyn_cast<BranchInst>(Latch->getTerminator());
5212  if (!LoopContinuePredicate ||
5213      LoopContinuePredicate->isUnconditional())
5214    return false;
5215
5216  return isImpliedCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS,
5217                       LoopContinuePredicate->getSuccessor(0) != L->getHeader());
5218}
5219
5220/// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
5221/// by a conditional between LHS and RHS.  This is used to help avoid max
5222/// expressions in loop trip counts, and to eliminate casts.
5223bool
5224ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
5225                                          ICmpInst::Predicate Pred,
5226                                          const SCEV *LHS, const SCEV *RHS) {
5227  // Interpret a null as meaning no loop, where there is obviously no guard
5228  // (interprocedural conditions notwithstanding).
5229  if (!L) return false;
5230
5231  // Starting at the loop predecessor, climb up the predecessor chain, as long
5232  // as we can find predecessors that have unique successors leading to the
5233  // original header.
5234  for (std::pair<BasicBlock *, BasicBlock *>
5235         Pair(L->getLoopPredecessor(), L->getHeader());
5236       Pair.first;
5237       Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
5238
5239    BranchInst *LoopEntryPredicate =
5240      dyn_cast<BranchInst>(Pair.first->getTerminator());
5241    if (!LoopEntryPredicate ||
5242        LoopEntryPredicate->isUnconditional())
5243      continue;
5244
5245    if (isImpliedCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS,
5246                      LoopEntryPredicate->getSuccessor(0) != Pair.second))
5247      return true;
5248  }
5249
5250  return false;
5251}
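
// For example, in code like
//
//   if (n > 0)
//     for (i = 0; i < n; ++i)
//       ...
//
// the branch on "n > 0" is found by walking up from the loop
// predecessor, so a query such as (ICMP_SGT, n, 0) can be answered
// affirmatively; this is what lets trip-count computation use n rather
// than smax(n, 0) as the end value.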
5252
5253/// isImpliedCond - Test whether the condition described by Pred, LHS,
5254/// and RHS is true whenever the given Cond value evaluates to true.
5255bool ScalarEvolution::isImpliedCond(Value *CondValue,
5256                                    ICmpInst::Predicate Pred,
5257                                    const SCEV *LHS, const SCEV *RHS,
5258                                    bool Inverse) {
5259  // Recursively handle And and Or conditions.
5260  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CondValue)) {
5261    if (BO->getOpcode() == Instruction::And) {
5262      if (!Inverse)
5263        return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
5264               isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
5265    } else if (BO->getOpcode() == Instruction::Or) {
5266      if (Inverse)
5267        return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
5268               isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
5269    }
5270  }
5271
5272  ICmpInst *ICI = dyn_cast<ICmpInst>(CondValue);
5273  if (!ICI) return false;
5274
5275  // Bail if the ICmp's operands' types are wider than the needed type
5276  // before attempting to call getSCEV on them. This avoids infinite
5277  // recursion, since the analysis of widening casts can require loop
5278  // exit condition information for overflow checking, which would
5279  // lead back here.
5280  if (getTypeSizeInBits(LHS->getType()) <
5281      getTypeSizeInBits(ICI->getOperand(0)->getType()))
5282    return false;
5283
5284  // Now that we found a conditional branch that dominates the loop or controls
5285  // the loop latch, check to see if it is the comparison we are looking for.
5286  ICmpInst::Predicate FoundPred;
5287  if (Inverse)
5288    FoundPred = ICI->getInversePredicate();
5289  else
5290    FoundPred = ICI->getPredicate();
5291
5292  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
5293  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
5294
5295  // Balance the types. The case where FoundLHS' type is wider than
5296  // LHS' type is checked for above.
5297  if (getTypeSizeInBits(LHS->getType()) >
5298      getTypeSizeInBits(FoundLHS->getType())) {
5299    if (CmpInst::isSigned(Pred)) {
5300      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
5301      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
5302    } else {
5303      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
5304      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
5305    }
5306  }
5307
5308  // Canonicalize the query to match the way instcombine will have
5309  // canonicalized the comparison.
5310  if (SimplifyICmpOperands(Pred, LHS, RHS))
5311    if (LHS == RHS)
5312      return CmpInst::isTrueWhenEqual(Pred);
5313  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
5314    if (FoundLHS == FoundRHS)
5315      return CmpInst::isFalseWhenEqual(Pred);
5316
5317  // Check to see if we can make the LHS or RHS match.
5318  if (LHS == FoundRHS || RHS == FoundLHS) {
5319    if (isa<SCEVConstant>(RHS)) {
5320      std::swap(FoundLHS, FoundRHS);
5321      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
5322    } else {
5323      std::swap(LHS, RHS);
5324      Pred = ICmpInst::getSwappedPredicate(Pred);
5325    }
5326  }
5327
5328  // Check whether the found predicate is the same as the desired predicate.
5329  if (FoundPred == Pred)
5330    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
5331
5332  // Check whether swapping the found predicate makes it the same as the
5333  // desired predicate.
5334  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
5335    if (isa<SCEVConstant>(RHS))
5336      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
5337    else
5338      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
5339                                   RHS, LHS, FoundLHS, FoundRHS);
5340  }
5341
5342  // Check whether the found condition is stronger than the one we need.
5343  if (FoundPred == ICmpInst::ICMP_EQ)
5344    if (ICmpInst::isTrueWhenEqual(Pred))
5345      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
5346        return true;
5347  if (Pred == ICmpInst::ICMP_NE)
5348    if (!ICmpInst::isTrueWhenEqual(FoundPred))
5349      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
5350        return true;
5351
5352  // Otherwise assume the worst.
5353  return false;
5354}
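
// For example, if the guarding branch tests "x < n" and the query asks
// about "x < n", the FoundPred == Pred case succeeds directly; if the
// branch instead tests "n > x", the operand and predicate swapping above
// first brings the two comparisons into the same form. A guard of the
// form "a < b && c" is handled by the And recursion, since either
// conjunct alone is known true and may establish the desired condition.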
5355
5356/// isImpliedCondOperands - Test whether the condition described by Pred,
5357/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
5358/// and FoundRHS is true.
5359bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
5360                                            const SCEV *LHS, const SCEV *RHS,
5361                                            const SCEV *FoundLHS,
5362                                            const SCEV *FoundRHS) {
5363  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
5364                                     FoundLHS, FoundRHS) ||
5365         // ~x < ~y --> x > y
5366         isImpliedCondOperandsHelper(Pred, LHS, RHS,
5367                                     getNotSCEV(FoundRHS),
5368                                     getNotSCEV(FoundLHS));
5369}
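
// The complemented form works because bitwise-not (~z == -1 - z)
// reverses both the signed and unsigned orderings: knowing
// "FoundLHS < FoundRHS" is equivalent to knowing
// "~FoundRHS < ~FoundLHS", and some queries only match the latter
// phrasing.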
5370
5371/// isImpliedCondOperandsHelper - Test whether the condition described by
5372/// Pred, LHS, and RHS is true whenever the condition described by Pred,
5373/// FoundLHS, and FoundRHS is true.
5374bool
5375ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
5376                                             const SCEV *LHS, const SCEV *RHS,
5377                                             const SCEV *FoundLHS,
5378                                             const SCEV *FoundRHS) {
5379  switch (Pred) {
5380  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5381  case ICmpInst::ICMP_EQ:
5382  case ICmpInst::ICMP_NE:
5383    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
5384      return true;
5385    break;
5386  case ICmpInst::ICMP_SLT:
5387  case ICmpInst::ICMP_SLE:
5388    if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
5389        isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS))
5390      return true;
5391    break;
5392  case ICmpInst::ICMP_SGT:
5393  case ICmpInst::ICMP_SGE:
5394    if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
5395        isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS))
5396      return true;
5397    break;
5398  case ICmpInst::ICMP_ULT:
5399  case ICmpInst::ICMP_ULE:
5400    if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
5401        isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS))
5402      return true;
5403    break;
5404  case ICmpInst::ICMP_UGT:
5405  case ICmpInst::ICMP_UGE:
5406    if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
5407        isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS))
5408      return true;
5409    break;
5410  }
5411
5412  return false;
5413}
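
// For example, to conclude "LHS slt RHS" from a known
// "FoundLHS slt FoundRHS", it suffices that the constant ranges prove
// LHS sle FoundLHS and RHS sge FoundRHS, which gives the chain
// LHS <= FoundLHS < FoundRHS <= RHS.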
5414
5415/// getBECount - Subtract the end and start values and divide by the step,
5416/// rounding up, to get the number of times the backedge is executed. Return
5417/// CouldNotCompute if an intermediate computation overflows.
5418const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
5419                                        const SCEV *End,
5420                                        const SCEV *Step,
5421                                        bool NoWrap) {
5422  assert(!isKnownNegative(Step) &&
5423         "This code doesn't handle negative strides yet!");
5424
5425  const Type *Ty = Start->getType();
5426  const SCEV *NegOne = getConstant(Ty, (uint64_t)-1);
5427  const SCEV *Diff = getMinusSCEV(End, Start);
5428  const SCEV *RoundUp = getAddExpr(Step, NegOne);
5429
5430  // Add an adjustment to the difference between End and Start so that
5431  // the division will effectively round up.
5432  const SCEV *Add = getAddExpr(Diff, RoundUp);
5433
5434  if (!NoWrap) {
5435    // Check Add for unsigned overflow.
5436    // TODO: More sophisticated things could be done here.
5437    const Type *WideTy = IntegerType::get(getContext(),
5438                                          getTypeSizeInBits(Ty) + 1);
5439    const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
5440    const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
5441    const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
5442    if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
5443      return getCouldNotCompute();
5444  }
5445
5446  return getUDivExpr(Add, Step);
5447}
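
// For example, with Start == 0, End == 10, and Step == 3, the result is
// (10 - 0 + (3 - 1)) /u 3 == 4, the number of values 0, 3, 6, 9 that are
// less than 10. The widening check above rejects the case where
// Diff + RoundUp wraps in the original type, which would make the
// rounded-up division meaningless.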
5448
5449/// HowManyLessThans - Return the number of times a backedge containing the
5450/// specified less-than comparison will execute.  If not computable, return
5451/// CouldNotCompute.
5452ScalarEvolution::BackedgeTakenInfo
5453ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
5454                                  const Loop *L, bool isSigned) {
5455  // Only handle:  "ADDREC < LoopInvariant".
5456  if (!RHS->isLoopInvariant(L)) return getCouldNotCompute();
5457
5458  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
5459  if (!AddRec || AddRec->getLoop() != L)
5460    return getCouldNotCompute();
5461
5462  // Check to see if we have a flag which makes analysis easy.
5463  bool NoWrap = isSigned ? AddRec->hasNoSignedWrap() :
5464                           AddRec->hasNoUnsignedWrap();
5465
5466  if (AddRec->isAffine()) {
5467    unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
5468    const SCEV *Step = AddRec->getStepRecurrence(*this);
5469
5470    if (Step->isZero())
5471      return getCouldNotCompute();
5472    if (Step->isOne()) {
5473      // With unit stride, the iteration never steps past the limit value.
5474    } else if (isKnownPositive(Step)) {
5475      // Test whether a positive iteration can step past the limit
5476      // value and past the maximum value for its type in a single step.
5477      // Note that checking NoWrap is not sufficient here: even though the
5478      // value after a wrap is undefined, wrapping is not itself undefined
5479      // behavior. If a wrap does occur, the loop may either terminate or
5480      // run forever, but either way it is guaranteed to iterate at least
5481      // until the iteration where the wrapping occurs.
5482      const SCEV *One = getConstant(Step->getType(), 1);
5483      if (isSigned) {
5484        APInt Max = APInt::getSignedMaxValue(BitWidth);
5485        if ((Max - getSignedRange(getMinusSCEV(Step, One)).getSignedMax())
5486              .slt(getSignedRange(RHS).getSignedMax()))
5487          return getCouldNotCompute();
5488      } else {
5489        APInt Max = APInt::getMaxValue(BitWidth);
5490        if ((Max - getUnsignedRange(getMinusSCEV(Step, One)).getUnsignedMax())
5491              .ult(getUnsignedRange(RHS).getUnsignedMax()))
5492          return getCouldNotCompute();
5493      }
5494    } else
5495      // TODO: Handle negative strides here and below.
5496      return getCouldNotCompute();
5497
5498    // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
5499    // m.  So, we count the number of iterations in which {n,+,s} < m is true.
5500    // Note that we cannot simply return max(m-n,0)/s because it's not safe to
5501    // treat m-n as either signed or unsigned due to the possibility of overflow.
5502
5503    // First, we get the value of the LHS in the first iteration: n
5504    const SCEV *Start = AddRec->getOperand(0);
5505
5506    // Determine the minimum constant start value.
5507    const SCEV *MinStart = getConstant(isSigned ?
5508      getSignedRange(Start).getSignedMin() :
5509      getUnsignedRange(Start).getUnsignedMin());
5510
5511    // If we know that the condition must be true in order to enter the loop,
5512    // then we know that the loop will run exactly (m-n)/s times. Otherwise, we
5513    // only know that it will execute (max(m,n)-n)/s times. In both cases, the
5514    // division must round up.
5515    const SCEV *End = RHS;
5516    if (!isLoopEntryGuardedByCond(L,
5517                                  isSigned ? ICmpInst::ICMP_SLT :
5518                                             ICmpInst::ICMP_ULT,
5519                                  getMinusSCEV(Start, Step), RHS))
5520      End = isSigned ? getSMaxExpr(RHS, Start)
5521                     : getUMaxExpr(RHS, Start);
5522
5523    // Determine the maximum constant end value.
5524    const SCEV *MaxEnd = getConstant(isSigned ?
5525      getSignedRange(End).getSignedMax() :
5526      getUnsignedRange(End).getUnsignedMax());
5527
5528    // If MaxEnd is within a step of the maximum integer value in its type,
5529    // adjust it down to the minimum value which would produce the same effect.
5530    // This allows the subsequent ceiling division of (N+(step-1))/step to
5531    // compute the correct value.
5532    const SCEV *StepMinusOne = getMinusSCEV(Step,
5533                                            getConstant(Step->getType(), 1));
5534    MaxEnd = isSigned ?
5535      getSMinExpr(MaxEnd,
5536                  getMinusSCEV(getConstant(APInt::getSignedMaxValue(BitWidth)),
5537                               StepMinusOne)) :
5538      getUMinExpr(MaxEnd,
5539                  getMinusSCEV(getConstant(APInt::getMaxValue(BitWidth)),
5540                               StepMinusOne));
5541
5542    // Finally, we subtract these two values and divide, rounding up, to get
5543    // the number of times the backedge is executed.
5544    const SCEV *BECount = getBECount(Start, End, Step, NoWrap);
5545
5546    // The maximum backedge count is similar, except using the minimum start
5547    // value and the maximum end value.
5548    const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step, NoWrap);
5549
5550    return BackedgeTakenInfo(BECount, MaxBECount);
5551  }
5552
5553  return getCouldNotCompute();
5554}
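
// For example, for {n,+,4}<L> slt m with no wrap flags, if the loop
// entry is guarded by "n - 4 slt m" then End is simply m and the
// backedge-taken count comes out as (m - n + 3) /u 4; without that
// guard, End becomes smax(m, n), so the count works out to zero
// whenever m sle n.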
5555
5556/// getNumIterationsInRange - Return the number of iterations of this loop that
5557/// produce values in the specified constant range.  Another way of looking at
5558/// this is that it returns the first iteration number where the value is not
5559/// in the range, thus computing the exit count. If the iteration count can't
5560/// be computed, an instance of SCEVCouldNotCompute is returned.
5561const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
5562                                                    ScalarEvolution &SE) const {
5563  if (Range.isFullSet())  // Infinite loop.
5564    return SE.getCouldNotCompute();
5565
5566  // If the start is a non-zero constant, shift the range to simplify things.
5567  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
5568    if (!SC->getValue()->isZero()) {
5569      SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
5570      Operands[0] = SE.getConstant(SC->getType(), 0);
5571      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop());
5572      if (const SCEVAddRecExpr *ShiftedAddRec =
5573            dyn_cast<SCEVAddRecExpr>(Shifted))
5574        return ShiftedAddRec->getNumIterationsInRange(
5575                           Range.subtract(SC->getValue()->getValue()), SE);
5576      // This is strange and shouldn't happen.
5577      return SE.getCouldNotCompute();
5578    }
5579
5580  // The only time we can solve this is when we have all constant indices.
5581  // Otherwise, we cannot determine the overflow conditions.
5582  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
5583    if (!isa<SCEVConstant>(getOperand(i)))
5584      return SE.getCouldNotCompute();
5585
5586
5587  // Okay, at this point we know that all elements of the chrec are constants
5588  // and that the start element is zero.
5589
5590  // First check to see if the range contains zero.  If not, the first
5591  // iteration exits.
5592  unsigned BitWidth = SE.getTypeSizeInBits(getType());
5593  if (!Range.contains(APInt(BitWidth, 0)))
5594    return SE.getConstant(getType(), 0);
5595
5596  if (isAffine()) {
5597    // If this is an affine expression then we have this situation:
5598    //   Solve {0,+,A} in Range  ===  Ax in Range
5599
5600    // We know that zero is in the range.  If A is positive then we know that
5601    // the upper value of the range must be the first possible exit value.
5602    // If A is negative then the lower of the range is the last possible loop
5603    // value.  Also note that we already checked for a full range.
5604    APInt One(BitWidth, 1);
5605    APInt A = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
5606    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
5607
5608    // The exit value should be (End+A)/A.
5609    APInt ExitVal = (End + A).udiv(A);
5610    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
5611
5612    // Evaluate at the exit value.  If we really did fall out of the valid
5613    // range, then we computed our trip count, otherwise wrap around or other
5614    // things must have happened.
5615    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
5616    if (Range.contains(Val->getValue()))
5617      return SE.getCouldNotCompute();  // Something strange happened
5618
5619    // Ensure that the previous value is in the range.  This is a sanity check.
5620    assert(Range.contains(
5621           EvaluateConstantChrecAtConstant(this,
5622           ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
5623           "Linear scev computation is off in a bad way!");
5624    return SE.getConstant(ExitValue);
5625  } else if (isQuadratic()) {
5626    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
5627    // quadratic equation to solve it.  To do this, we must frame our problem in
5628    // terms of figuring out when zero is crossed, instead of when
5629    // Range.getUpper() is crossed.
5630    SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
5631    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
5632    const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop());
5633
5634    // Next, solve the constructed addrec
5635    std::pair<const SCEV *,const SCEV *> Roots =
5636      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
5637    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
5638    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
5639    if (R1) {
5640      // Pick the smallest positive root value.
5641      if (ConstantInt *CB =
5642          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
5643                         R1->getValue(), R2->getValue()))) {
5644        if (!CB->getZExtValue())
5645          std::swap(R1, R2);   // R1 is the minimum root now.
5646
5647        // Make sure the root is not off by one.  The returned iteration should
5648        // not be in the range, but the previous one should be.  When solving
5649        // for "X*X < 5", for example, we should not return a root of 2.
5650        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
5651                                                             R1->getValue(),
5652                                                             SE);
5653        if (Range.contains(R1Val->getValue())) {
5654          // The next iteration must be out of the range...
5655          ConstantInt *NextVal =
5656                ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
5657
5658          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
5659          if (!Range.contains(R1Val->getValue()))
5660            return SE.getConstant(NextVal);
5661          return SE.getCouldNotCompute();  // Something strange happened
5662        }
5663
5664        // If R1 was not in the range, then it is a good return value.  Make
5665        // sure that R1-1 WAS in the range though, just in case.
5666        ConstantInt *NextVal =
5667               ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
5668        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
5669        if (Range.contains(R1Val->getValue()))
5670          return R1;
5671        return SE.getCouldNotCompute();  // Something strange happened
5672      }
5673    }
5674  }
5675
5676  return SE.getCouldNotCompute();
5677}
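
// For example, for the affine chrec {0,+,2} with Range == [0, 10), the
// code above computes End == 9 and ExitVal == (9 + 2) /u 2 == 5:
// iterations 0 through 4 produce the in-range values 0, 2, 4, 6, and 8,
// and iteration 5 produces 10, the first value outside the range.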
5678
5679
5680
5681//===----------------------------------------------------------------------===//
5682//                   SCEVCallbackVH Class Implementation
5683//===----------------------------------------------------------------------===//
5684
5685void ScalarEvolution::SCEVCallbackVH::deleted() {
5686  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
5687  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
5688    SE->ConstantEvolutionLoopExitValue.erase(PN);
5689  SE->Scalars.erase(getValPtr());
5690  // this now dangles!
5691}
5692
5693void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) {
5694  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
5695
5696  // Forget all the expressions associated with users of the old value,
5697  // so that future queries will recompute the expressions using the new
5698  // value.
5699  SmallVector<User *, 16> Worklist;
5700  SmallPtrSet<User *, 8> Visited;
5701  Value *Old = getValPtr();
5702  for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
5703       UI != UE; ++UI)
5704    Worklist.push_back(*UI);
5705  while (!Worklist.empty()) {
5706    User *U = Worklist.pop_back_val();
5707    // Deleting the Old value will cause this to dangle. Postpone
5708    // that until everything else is done.
5709    if (U == Old)
5710      continue;
5711    if (!Visited.insert(U))
5712      continue;
5713    if (PHINode *PN = dyn_cast<PHINode>(U))
5714      SE->ConstantEvolutionLoopExitValue.erase(PN);
5715    SE->Scalars.erase(U);
5716    for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
5717         UI != UE; ++UI)
5718      Worklist.push_back(*UI);
5719  }
5720  // Delete the Old value.
5721  if (PHINode *PN = dyn_cast<PHINode>(Old))
5722    SE->ConstantEvolutionLoopExitValue.erase(PN);
5723  SE->Scalars.erase(Old);
5724  // this now dangles!
5725}
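
// For example, if a transformation replaces all uses of "%x = add %a, 0"
// with "%a", every expression computed in terms of %x may now simplify
// differently, so the cached SCEVs for %x's transitive users are dropped
// here and recomputed lazily on the next query.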
5726
5727ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
5728  : CallbackVH(V), SE(se) {}
5729
5730//===----------------------------------------------------------------------===//
5731//                   ScalarEvolution Class Implementation
5732//===----------------------------------------------------------------------===//
5733
5734ScalarEvolution::ScalarEvolution()
5735  : FunctionPass(&ID) {
5736}
5737
5738bool ScalarEvolution::runOnFunction(Function &F) {
5739  this->F = &F;
5740  LI = &getAnalysis<LoopInfo>();
5741  TD = getAnalysisIfAvailable<TargetData>();
5742  DT = &getAnalysis<DominatorTree>();
5743  return false;
5744}
5745
5746void ScalarEvolution::releaseMemory() {
5747  Scalars.clear();
5748  BackedgeTakenCounts.clear();
5749  ConstantEvolutionLoopExitValue.clear();
5750  ValuesAtScopes.clear();
5751  UniqueSCEVs.clear();
5752  SCEVAllocator.Reset();
5753}
5754
5755void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
5756  AU.setPreservesAll();
5757  AU.addRequiredTransitive<LoopInfo>();
5758  AU.addRequiredTransitive<DominatorTree>();
5759}
5760
5761bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
5762  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
5763}
5764
5765static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
5766                          const Loop *L) {
5767  // Print all inner loops first
5768  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
5769    PrintLoopInfo(OS, SE, *I);
5770
5771  OS << "Loop ";
5772  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
5773  OS << ": ";
5774
5775  SmallVector<BasicBlock *, 8> ExitBlocks;
5776  L->getExitBlocks(ExitBlocks);
5777  if (ExitBlocks.size() != 1)
5778    OS << "<multiple exits> ";
5779
5780  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
5781    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
5782  } else {
5783    OS << "Unpredictable backedge-taken count. ";
5784  }
5785
5786  OS << "\n"
5787        "Loop ";
5788  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
5789  OS << ": ";
5790
5791  if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
5792    OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
5793  } else {
5794    OS << "Unpredictable max backedge-taken count. ";
5795  }
5796
5797  OS << "\n";
5798}
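
// With a computable count, the output produced above looks roughly like
//
//   Loop %bb: backedge-taken count is 100
//   Loop %bb: max backedge-taken count is 100
//
// where %bb is the loop header and the count is printed as a SCEV
// expression (the constant 100 here is purely illustrative).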
5799
5800void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
5801  // ScalarEvolution's implementation of the print method is to print
5802  // out SCEV values of all instructions that are interesting. Doing
5803  // this potentially causes it to create new SCEV objects though,
5804  // which technically conflicts with the const qualifier. This isn't
5805  // observable from outside the class though, so casting away the
5806  // const isn't dangerous.
5807  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
5808
5809  OS << "Classifying expressions for: ";
5810  WriteAsOperand(OS, F, /*PrintType=*/false);
5811  OS << "\n";
5812  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
5813    if (isSCEVable(I->getType()) && !isa<CmpInst>(*I)) {
5814      OS << *I << '\n';
5815      OS << "  -->  ";
5816      const SCEV *SV = SE.getSCEV(&*I);
5817      SV->print(OS);
5818
5819      const Loop *L = LI->getLoopFor((*I).getParent());
5820
5821      const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
5822      if (AtUse != SV) {
5823        OS << "  -->  ";
5824        AtUse->print(OS);
5825      }
5826
5827      if (L) {
5828        OS << "\t\t" "Exits: ";
5829        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
5830        if (!ExitValue->isLoopInvariant(L)) {
5831          OS << "<<Unknown>>";
5832        } else {
5833          OS << *ExitValue;
5834        }
5835      }
5836
5837      OS << "\n";
5838    }
5839
5840  OS << "Determining loop execution counts for: ";
5841  WriteAsOperand(OS, F, /*PrintType=*/false);
5842  OS << "\n";
5843  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
5844    PrintLoopInfo(OS, &SE, *I);
5845}
5846
5847