1//===- InstCombineAddSub.cpp ----------------------------------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements the visit functions for add, fadd, sub, and fsub.
11//
12//===----------------------------------------------------------------------===//
13
14#include "InstCombine.h"
15#include "llvm/Analysis/InstructionSimplify.h"
16#include "llvm/IR/DataLayout.h"
17#include "llvm/Support/GetElementPtrTypeIterator.h"
18#include "llvm/Support/PatternMatch.h"
19using namespace llvm;
20using namespace PatternMatch;
21
22namespace {
23
24  /// Class representing coefficient of floating-point addend.
25  /// This class needs to be highly efficient, which is especially true for
26  /// the constructor. As of I write this comment, the cost of the default
27  /// constructor is merely 4-byte-store-zero (Assuming compiler is able to
28  /// perform write-merging).
29  ///
30  class FAddendCoef {
31  public:
32    // The constructor has to initialize a APFloat, which is uncessary for
33    // most addends which have coefficient either 1 or -1. So, the constructor
34    // is expensive. In order to avoid the cost of the constructor, we should
35    // reuse some instances whenever possible. The pre-created instances
36    // FAddCombine::Add[0-5] embodies this idea.
37    //
38    FAddendCoef() : IsFp(false), BufHasFpVal(false), IntVal(0) {}
39    ~FAddendCoef();
40
41    void set(short C) {
42      assert(!insaneIntVal(C) && "Insane coefficient");
43      IsFp = false; IntVal = C;
44    }
45
46    void set(const APFloat& C);
47
48    void negate();
49
50    bool isZero() const { return isInt() ? !IntVal : getFpVal().isZero(); }
51    Value *getValue(Type *) const;
52
53    // If possible, don't define operator+/operator- etc because these
54    // operators inevitably call FAddendCoef's constructor which is not cheap.
55    void operator=(const FAddendCoef &A);
56    void operator+=(const FAddendCoef &A);
57    void operator-=(const FAddendCoef &A);
58    void operator*=(const FAddendCoef &S);
59
60    bool isOne() const { return isInt() && IntVal == 1; }
61    bool isTwo() const { return isInt() && IntVal == 2; }
62    bool isMinusOne() const { return isInt() && IntVal == -1; }
63    bool isMinusTwo() const { return isInt() && IntVal == -2; }
64
65  private:
66    bool insaneIntVal(int V) { return V > 4 || V < -4; }
67    APFloat *getFpValPtr(void)
68      { return reinterpret_cast<APFloat*>(&FpValBuf.buffer[0]); }
69    const APFloat *getFpValPtr(void) const
70      { return reinterpret_cast<const APFloat*>(&FpValBuf.buffer[0]); }
71
72    const APFloat &getFpVal(void) const {
73      assert(IsFp && BufHasFpVal && "Incorret state");
74      return *getFpValPtr();
75    }
76
77    APFloat &getFpVal(void)
78      { assert(IsFp && BufHasFpVal && "Incorret state"); return *getFpValPtr(); }
79
80    bool isInt() const { return !IsFp; }
81
82  private:
83
84    bool IsFp;
85
86    // True iff FpValBuf contains an instance of APFloat.
87    bool BufHasFpVal;
88
89    // The integer coefficient of an individual addend is either 1 or -1,
90    // and we try to simplify at most 4 addends from neighboring at most
91    // two instructions. So the range of <IntVal> falls in [-4, 4]. APInt
92    // is overkill of this end.
93    short IntVal;
94
95    AlignedCharArrayUnion<APFloat> FpValBuf;
96  };
97
98  /// FAddend is used to represent floating-point addend. An addend is
99  /// represented as <C, V>, where the V is a symbolic value, and C is a
100  /// constant coefficient. A constant addend is represented as <C, 0>.
101  ///
102  class FAddend {
103  public:
104    FAddend() { Val = 0; }
105
106    Value *getSymVal (void) const { return Val; }
107    const FAddendCoef &getCoef(void) const { return Coeff; }
108
109    bool isConstant() const { return Val == 0; }
110    bool isZero() const { return Coeff.isZero(); }
111
112    void set(short Coefficient, Value *V) { Coeff.set(Coefficient), Val = V; }
113    void set(const APFloat& Coefficient, Value *V)
114      { Coeff.set(Coefficient); Val = V; }
115    void set(const ConstantFP* Coefficient, Value *V)
116      { Coeff.set(Coefficient->getValueAPF()); Val = V; }
117
118    void negate() { Coeff.negate(); }
119
120    /// Drill down the U-D chain one step to find the definition of V, and
121    /// try to break the definition into one or two addends.
122    static unsigned drillValueDownOneStep(Value* V, FAddend &A0, FAddend &A1);
123
124    /// Similar to FAddend::drillDownOneStep() except that the value being
125    /// splitted is the addend itself.
126    unsigned drillAddendDownOneStep(FAddend &Addend0, FAddend &Addend1) const;
127
128    void operator+=(const FAddend &T) {
129      assert((Val == T.Val) && "Symbolic-values disagree");
130      Coeff += T.Coeff;
131    }
132
133  private:
134    void Scale(const FAddendCoef& ScaleAmt) { Coeff *= ScaleAmt; }
135
136    // This addend has the value of "Coeff * Val".
137    Value *Val;
138    FAddendCoef Coeff;
139  };
140
141  /// FAddCombine is the class for optimizing an unsafe fadd/fsub along
142  /// with its neighboring at most two instructions.
143  ///
144  class FAddCombine {
145  public:
146    FAddCombine(InstCombiner::BuilderTy *B) : Builder(B), Instr(0) {}
147    Value *simplify(Instruction *FAdd);
148
149  private:
150    typedef SmallVector<const FAddend*, 4> AddendVect;
151
152    Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota);
153
154    Value *performFactorization(Instruction *I);
155
156    /// Convert given addend to a Value
157    Value *createAddendVal(const FAddend &A, bool& NeedNeg);
158
159    /// Return the number of instructions needed to emit the N-ary addition.
160    unsigned calcInstrNumber(const AddendVect& Vect);
161    Value *createFSub(Value *Opnd0, Value *Opnd1);
162    Value *createFAdd(Value *Opnd0, Value *Opnd1);
163    Value *createFMul(Value *Opnd0, Value *Opnd1);
164    Value *createFDiv(Value *Opnd0, Value *Opnd1);
165    Value *createFNeg(Value *V);
166    Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota);
167    void createInstPostProc(Instruction *NewInst);
168
169    InstCombiner::BuilderTy *Builder;
170    Instruction *Instr;
171
172  private:
173     // Debugging stuff are clustered here.
174    #ifndef NDEBUG
175      unsigned CreateInstrNum;
176      void initCreateInstNum() { CreateInstrNum = 0; }
177      void incCreateInstNum() { CreateInstrNum++; }
178    #else
179      void initCreateInstNum() {}
180      void incCreateInstNum() {}
181    #endif
182  };
183}
184
185//===----------------------------------------------------------------------===//
186//
187// Implementation of
//    {FAddendCoef, FAddend, FAddCombine}.
189//
190//===----------------------------------------------------------------------===//
191FAddendCoef::~FAddendCoef() {
192  if (BufHasFpVal)
193    getFpValPtr()->~APFloat();
194}
195
196void FAddendCoef::set(const APFloat& C) {
197  APFloat *P = getFpValPtr();
198
199  if (isInt()) {
200    // As the buffer is meanless byte stream, we cannot call
201    // APFloat::operator=().
202    new(P) APFloat(C);
203  } else
204    *P = C;
205
206  IsFp = BufHasFpVal = true;
207}
208
209void FAddendCoef::operator=(const FAddendCoef& That) {
210  if (That.isInt())
211    set(That.IntVal);
212  else
213    set(That.getFpVal());
214}
215
216void FAddendCoef::operator+=(const FAddendCoef &That) {
217  enum APFloat::roundingMode RndMode = APFloat::rmNearestTiesToEven;
218  if (isInt() == That.isInt()) {
219    if (isInt())
220      IntVal += That.IntVal;
221    else
222      getFpVal().add(That.getFpVal(), RndMode);
223    return;
224  }
225
226  if (isInt()) {
227    const APFloat &T = That.getFpVal();
228    set(T);
229    getFpVal().add(APFloat(T.getSemantics(), IntVal), RndMode);
230    return;
231  }
232
233  APFloat &T = getFpVal();
234  T.add(APFloat(T.getSemantics(), That.IntVal), RndMode);
235}
236
237void FAddendCoef::operator-=(const FAddendCoef &That) {
238  enum APFloat::roundingMode RndMode = APFloat::rmNearestTiesToEven;
239  if (isInt() == That.isInt()) {
240    if (isInt())
241      IntVal -= That.IntVal;
242    else
243      getFpVal().subtract(That.getFpVal(), RndMode);
244    return;
245  }
246
247  if (isInt()) {
248    const APFloat &T = That.getFpVal();
249    set(T);
250    getFpVal().subtract(APFloat(T.getSemantics(), IntVal), RndMode);
251    return;
252  }
253
254  APFloat &T = getFpVal();
255  T.subtract(APFloat(T.getSemantics(), IntVal), RndMode);
256}
257
258void FAddendCoef::operator*=(const FAddendCoef &That) {
259  if (That.isOne())
260    return;
261
262  if (That.isMinusOne()) {
263    negate();
264    return;
265  }
266
267  if (isInt() && That.isInt()) {
268    int Res = IntVal * (int)That.IntVal;
269    assert(!insaneIntVal(Res) && "Insane int value");
270    IntVal = Res;
271    return;
272  }
273
274  const fltSemantics &Semantic =
275    isInt() ? That.getFpVal().getSemantics() : getFpVal().getSemantics();
276
277  if (isInt())
278    set(APFloat(Semantic, IntVal));
279  APFloat &F0 = getFpVal();
280
281  if (That.isInt())
282    F0.multiply(APFloat(Semantic, That.IntVal), APFloat::rmNearestTiesToEven);
283  else
284    F0.multiply(That.getFpVal(), APFloat::rmNearestTiesToEven);
285
286  return;
287}
288
289void FAddendCoef::negate() {
290  if (isInt())
291    IntVal = 0 - IntVal;
292  else
293    getFpVal().changeSign();
294}
295
296Value *FAddendCoef::getValue(Type *Ty) const {
297  return isInt() ?
298    ConstantFP::get(Ty, float(IntVal)) :
299    ConstantFP::get(Ty->getContext(), getFpVal());
300}
301
302// The definition of <Val>     Addends
303// =========================================
304//  A + B                     <1, A>, <1,B>
305//  A - B                     <1, A>, <1,B>
306//  0 - B                     <-1, B>
307//  C * A,                    <C, A>
308//  A + C                     <1, A> <C, NULL>
309//  0 +/- 0                   <0, NULL> (corner case)
310//
311// Legend: A and B are not constant, C is constant
312//
313unsigned FAddend::drillValueDownOneStep
314  (Value *Val, FAddend &Addend0, FAddend &Addend1) {
315  Instruction *I = 0;
316  if (Val == 0 || !(I = dyn_cast<Instruction>(Val)))
317    return 0;
318
319  unsigned Opcode = I->getOpcode();
320
321  if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub) {
322    ConstantFP *C0, *C1;
323    Value *Opnd0 = I->getOperand(0);
324    Value *Opnd1 = I->getOperand(1);
325    if ((C0 = dyn_cast<ConstantFP>(Opnd0)) && C0->isZero())
326      Opnd0 = 0;
327
328    if ((C1 = dyn_cast<ConstantFP>(Opnd1)) && C1->isZero())
329      Opnd1 = 0;
330
331    if (Opnd0) {
332      if (!C0)
333        Addend0.set(1, Opnd0);
334      else
335        Addend0.set(C0, 0);
336    }
337
338    if (Opnd1) {
339      FAddend &Addend = Opnd0 ? Addend1 : Addend0;
340      if (!C1)
341        Addend.set(1, Opnd1);
342      else
343        Addend.set(C1, 0);
344      if (Opcode == Instruction::FSub)
345        Addend.negate();
346    }
347
348    if (Opnd0 || Opnd1)
349      return Opnd0 && Opnd1 ? 2 : 1;
350
351    // Both operands are zero. Weird!
352    Addend0.set(APFloat(C0->getValueAPF().getSemantics()), 0);
353    return 1;
354  }
355
356  if (I->getOpcode() == Instruction::FMul) {
357    Value *V0 = I->getOperand(0);
358    Value *V1 = I->getOperand(1);
359    if (ConstantFP *C = dyn_cast<ConstantFP>(V0)) {
360      Addend0.set(C, V1);
361      return 1;
362    }
363
364    if (ConstantFP *C = dyn_cast<ConstantFP>(V1)) {
365      Addend0.set(C, V0);
366      return 1;
367    }
368  }
369
370  return 0;
371}
372
373// Try to break *this* addend into two addends. e.g. Suppose this addend is
374// <2.3, V>, and V = X + Y, by calling this function, we obtain two addends,
375// i.e. <2.3, X> and <2.3, Y>.
376//
377unsigned FAddend::drillAddendDownOneStep
378  (FAddend &Addend0, FAddend &Addend1) const {
379  if (isConstant())
380    return 0;
381
382  unsigned BreakNum = FAddend::drillValueDownOneStep(Val, Addend0, Addend1);
383  if (!BreakNum || Coeff.isOne())
384    return BreakNum;
385
386  Addend0.Scale(Coeff);
387
388  if (BreakNum == 2)
389    Addend1.Scale(Coeff);
390
391  return BreakNum;
392}
393
394// Try to perform following optimization on the input instruction I. Return the
395// simplified expression if was successful; otherwise, return 0.
396//
397//   Instruction "I" is                Simplified into
398// -------------------------------------------------------
399//   (x * y) +/- (x * z)               x * (y +/- z)
400//   (y / x) +/- (z / x)               (y +/- z) / x
401//
402Value *FAddCombine::performFactorization(Instruction *I) {
403  assert((I->getOpcode() == Instruction::FAdd ||
404          I->getOpcode() == Instruction::FSub) && "Expect add/sub");
405
406  Instruction *I0 = dyn_cast<Instruction>(I->getOperand(0));
407  Instruction *I1 = dyn_cast<Instruction>(I->getOperand(1));
408
409  if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode())
410    return 0;
411
412  bool isMpy = false;
413  if (I0->getOpcode() == Instruction::FMul)
414    isMpy = true;
415  else if (I0->getOpcode() != Instruction::FDiv)
416    return 0;
417
418  Value *Opnd0_0 = I0->getOperand(0);
419  Value *Opnd0_1 = I0->getOperand(1);
420  Value *Opnd1_0 = I1->getOperand(0);
421  Value *Opnd1_1 = I1->getOperand(1);
422
423  //  Input Instr I       Factor   AddSub0  AddSub1
424  //  ----------------------------------------------
425  // (x*y) +/- (x*z)        x        y         z
426  // (y/x) +/- (z/x)        x        y         z
427  //
428  Value *Factor = 0;
429  Value *AddSub0 = 0, *AddSub1 = 0;
430
431  if (isMpy) {
432    if (Opnd0_0 == Opnd1_0 || Opnd0_0 == Opnd1_1)
433      Factor = Opnd0_0;
434    else if (Opnd0_1 == Opnd1_0 || Opnd0_1 == Opnd1_1)
435      Factor = Opnd0_1;
436
437    if (Factor) {
438      AddSub0 = (Factor == Opnd0_0) ? Opnd0_1 : Opnd0_0;
439      AddSub1 = (Factor == Opnd1_0) ? Opnd1_1 : Opnd1_0;
440    }
441  } else if (Opnd0_1 == Opnd1_1) {
442    Factor = Opnd0_1;
443    AddSub0 = Opnd0_0;
444    AddSub1 = Opnd1_0;
445  }
446
447  if (!Factor)
448    return 0;
449
450  // Create expression "NewAddSub = AddSub0 +/- AddsSub1"
451  Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ?
452                      createFAdd(AddSub0, AddSub1) :
453                      createFSub(AddSub0, AddSub1);
454  if (ConstantFP *CFP = dyn_cast<ConstantFP>(NewAddSub)) {
455    const APFloat &F = CFP->getValueAPF();
456    if (!F.isNormal() || F.isDenormal())
457      return 0;
458  }
459
460  if (isMpy)
461    return createFMul(Factor, NewAddSub);
462
463  return createFDiv(NewAddSub, Factor);
464}
465
466Value *FAddCombine::simplify(Instruction *I) {
467  assert(I->hasUnsafeAlgebra() && "Should be in unsafe mode");
468
469  // Currently we are not able to handle vector type.
470  if (I->getType()->isVectorTy())
471    return 0;
472
473  assert((I->getOpcode() == Instruction::FAdd ||
474          I->getOpcode() == Instruction::FSub) && "Expect add/sub");
475
476  // Save the instruction before calling other member-functions.
477  Instr = I;
478
479  FAddend Opnd0, Opnd1, Opnd0_0, Opnd0_1, Opnd1_0, Opnd1_1;
480
481  unsigned OpndNum = FAddend::drillValueDownOneStep(I, Opnd0, Opnd1);
482
483  // Step 1: Expand the 1st addend into Opnd0_0 and Opnd0_1.
484  unsigned Opnd0_ExpNum = 0;
485  unsigned Opnd1_ExpNum = 0;
486
487  if (!Opnd0.isConstant())
488    Opnd0_ExpNum = Opnd0.drillAddendDownOneStep(Opnd0_0, Opnd0_1);
489
490  // Step 2: Expand the 2nd addend into Opnd1_0 and Opnd1_1.
491  if (OpndNum == 2 && !Opnd1.isConstant())
492    Opnd1_ExpNum = Opnd1.drillAddendDownOneStep(Opnd1_0, Opnd1_1);
493
494  // Step 3: Try to optimize Opnd0_0 + Opnd0_1 + Opnd1_0 + Opnd1_1
495  if (Opnd0_ExpNum && Opnd1_ExpNum) {
496    AddendVect AllOpnds;
497    AllOpnds.push_back(&Opnd0_0);
498    AllOpnds.push_back(&Opnd1_0);
499    if (Opnd0_ExpNum == 2)
500      AllOpnds.push_back(&Opnd0_1);
501    if (Opnd1_ExpNum == 2)
502      AllOpnds.push_back(&Opnd1_1);
503
504    // Compute instruction quota. We should save at least one instruction.
505    unsigned InstQuota = 0;
506
507    Value *V0 = I->getOperand(0);
508    Value *V1 = I->getOperand(1);
509    InstQuota = ((!isa<Constant>(V0) && V0->hasOneUse()) &&
510                 (!isa<Constant>(V1) && V1->hasOneUse())) ? 2 : 1;
511
512    if (Value *R = simplifyFAdd(AllOpnds, InstQuota))
513      return R;
514  }
515
516  if (OpndNum != 2) {
517    // The input instruction is : "I=0.0 +/- V". If the "V" were able to be
518    // splitted into two addends, say "V = X - Y", the instruction would have
519    // been optimized into "I = Y - X" in the previous steps.
520    //
521    const FAddendCoef &CE = Opnd0.getCoef();
522    return CE.isOne() ? Opnd0.getSymVal() : 0;
523  }
524
525  // step 4: Try to optimize Opnd0 + Opnd1_0 [+ Opnd1_1]
526  if (Opnd1_ExpNum) {
527    AddendVect AllOpnds;
528    AllOpnds.push_back(&Opnd0);
529    AllOpnds.push_back(&Opnd1_0);
530    if (Opnd1_ExpNum == 2)
531      AllOpnds.push_back(&Opnd1_1);
532
533    if (Value *R = simplifyFAdd(AllOpnds, 1))
534      return R;
535  }
536
537  // step 5: Try to optimize Opnd1 + Opnd0_0 [+ Opnd0_1]
538  if (Opnd0_ExpNum) {
539    AddendVect AllOpnds;
540    AllOpnds.push_back(&Opnd1);
541    AllOpnds.push_back(&Opnd0_0);
542    if (Opnd0_ExpNum == 2)
543      AllOpnds.push_back(&Opnd0_1);
544
545    if (Value *R = simplifyFAdd(AllOpnds, 1))
546      return R;
547  }
548
549  // step 6: Try factorization as the last resort,
550  return performFactorization(I);
551}
552
553Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
554
555  unsigned AddendNum = Addends.size();
556  assert(AddendNum <= 4 && "Too many addends");
557
558  // For saving intermediate results;
559  unsigned NextTmpIdx = 0;
560  FAddend TmpResult[3];
561
562  // Points to the constant addend of the resulting simplified expression.
563  // If the resulting expr has constant-addend, this constant-addend is
564  // desirable to reside at the top of the resulting expression tree. Placing
565  // constant close to supper-expr(s) will potentially reveal some optimization
566  // opportunities in super-expr(s).
567  //
568  const FAddend *ConstAdd = 0;
569
570  // Simplified addends are placed <SimpVect>.
571  AddendVect SimpVect;
572
573  // The outer loop works on one symbolic-value at a time. Suppose the input
574  // addends are : <a1, x>, <b1, y>, <a2, x>, <c1, z>, <b2, y>, ...
575  // The symbolic-values will be processed in this order: x, y, z.
576  //
577  for (unsigned SymIdx = 0; SymIdx < AddendNum; SymIdx++) {
578
579    const FAddend *ThisAddend = Addends[SymIdx];
580    if (!ThisAddend) {
581      // This addend was processed before.
582      continue;
583    }
584
585    Value *Val = ThisAddend->getSymVal();
586    unsigned StartIdx = SimpVect.size();
587    SimpVect.push_back(ThisAddend);
588
589    // The inner loop collects addends sharing same symbolic-value, and these
590    // addends will be later on folded into a single addend. Following above
591    // example, if the symbolic value "y" is being processed, the inner loop
592    // will collect two addends "<b1,y>" and "<b2,Y>". These two addends will
593    // be later on folded into "<b1+b2, y>".
594    //
595    for (unsigned SameSymIdx = SymIdx + 1;
596         SameSymIdx < AddendNum; SameSymIdx++) {
597      const FAddend *T = Addends[SameSymIdx];
598      if (T && T->getSymVal() == Val) {
599        // Set null such that next iteration of the outer loop will not process
600        // this addend again.
601        Addends[SameSymIdx] = 0;
602        SimpVect.push_back(T);
603      }
604    }
605
606    // If multiple addends share same symbolic value, fold them together.
607    if (StartIdx + 1 != SimpVect.size()) {
608      FAddend &R = TmpResult[NextTmpIdx ++];
609      R = *SimpVect[StartIdx];
610      for (unsigned Idx = StartIdx + 1; Idx < SimpVect.size(); Idx++)
611        R += *SimpVect[Idx];
612
613      // Pop all addends being folded and push the resulting folded addend.
614      SimpVect.resize(StartIdx);
615      if (Val != 0) {
616        if (!R.isZero()) {
617          SimpVect.push_back(&R);
618        }
619      } else {
620        // Don't push constant addend at this time. It will be the last element
621        // of <SimpVect>.
622        ConstAdd = &R;
623      }
624    }
625  }
626
627  assert((NextTmpIdx <= sizeof(TmpResult)/sizeof(TmpResult[0]) + 1) &&
628         "out-of-bound access");
629
630  if (ConstAdd)
631    SimpVect.push_back(ConstAdd);
632
633  Value *Result;
634  if (!SimpVect.empty())
635    Result = createNaryFAdd(SimpVect, InstrQuota);
636  else {
637    // The addition is folded to 0.0.
638    Result = ConstantFP::get(Instr->getType(), 0.0);
639  }
640
641  return Result;
642}
643
644Value *FAddCombine::createNaryFAdd
645  (const AddendVect &Opnds, unsigned InstrQuota) {
646  assert(!Opnds.empty() && "Expect at least one addend");
647
648  // Step 1: Check if the # of instructions needed exceeds the quota.
649  //
650  unsigned InstrNeeded = calcInstrNumber(Opnds);
651  if (InstrNeeded > InstrQuota)
652    return 0;
653
654  initCreateInstNum();
655
656  // step 2: Emit the N-ary addition.
657  // Note that at most three instructions are involved in Fadd-InstCombine: the
658  // addition in question, and at most two neighboring instructions.
659  // The resulting optimized addition should have at least one less instruction
660  // than the original addition expression tree. This implies that the resulting
661  // N-ary addition has at most two instructions, and we don't need to worry
662  // about tree-height when constructing the N-ary addition.
663
664  Value *LastVal = 0;
665  bool LastValNeedNeg = false;
666
667  // Iterate the addends, creating fadd/fsub using adjacent two addends.
668  for (AddendVect::const_iterator I = Opnds.begin(), E = Opnds.end();
669       I != E; I++) {
670    bool NeedNeg;
671    Value *V = createAddendVal(**I, NeedNeg);
672    if (!LastVal) {
673      LastVal = V;
674      LastValNeedNeg = NeedNeg;
675      continue;
676    }
677
678    if (LastValNeedNeg == NeedNeg) {
679      LastVal = createFAdd(LastVal, V);
680      continue;
681    }
682
683    if (LastValNeedNeg)
684      LastVal = createFSub(V, LastVal);
685    else
686      LastVal = createFSub(LastVal, V);
687
688    LastValNeedNeg = false;
689  }
690
691  if (LastValNeedNeg) {
692    LastVal = createFNeg(LastVal);
693  }
694
695  #ifndef NDEBUG
696    assert(CreateInstrNum == InstrNeeded &&
697           "Inconsistent in instruction numbers");
698  #endif
699
700  return LastVal;
701}
702
703Value *FAddCombine::createFSub
704  (Value *Opnd0, Value *Opnd1) {
705  Value *V = Builder->CreateFSub(Opnd0, Opnd1);
706  if (Instruction *I = dyn_cast<Instruction>(V))
707    createInstPostProc(I);
708  return V;
709}
710
711Value *FAddCombine::createFNeg(Value *V) {
712  Value *Zero = cast<Value>(ConstantFP::get(V->getType(), 0.0));
713  return createFSub(Zero, V);
714}
715
716Value *FAddCombine::createFAdd
717  (Value *Opnd0, Value *Opnd1) {
718  Value *V = Builder->CreateFAdd(Opnd0, Opnd1);
719  if (Instruction *I = dyn_cast<Instruction>(V))
720    createInstPostProc(I);
721  return V;
722}
723
724Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) {
725  Value *V = Builder->CreateFMul(Opnd0, Opnd1);
726  if (Instruction *I = dyn_cast<Instruction>(V))
727    createInstPostProc(I);
728  return V;
729}
730
731Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) {
732  Value *V = Builder->CreateFDiv(Opnd0, Opnd1);
733  if (Instruction *I = dyn_cast<Instruction>(V))
734    createInstPostProc(I);
735  return V;
736}
737
738void FAddCombine::createInstPostProc(Instruction *NewInstr) {
739  NewInstr->setDebugLoc(Instr->getDebugLoc());
740
741  // Keep track of the number of instruction created.
742  incCreateInstNum();
743
744  // Propagate fast-math flags
745  NewInstr->setFastMathFlags(Instr->getFastMathFlags());
746}
747
748// Return the number of instruction needed to emit the N-ary addition.
749// NOTE: Keep this function in sync with createAddendVal().
750unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) {
751  unsigned OpndNum = Opnds.size();
752  unsigned InstrNeeded = OpndNum - 1;
753
754  // The number of addends in the form of "(-1)*x".
755  unsigned NegOpndNum = 0;
756
757  // Adjust the number of instructions needed to emit the N-ary add.
758  for (AddendVect::const_iterator I = Opnds.begin(), E = Opnds.end();
759       I != E; I++) {
760    const FAddend *Opnd = *I;
761    if (Opnd->isConstant())
762      continue;
763
764    const FAddendCoef &CE = Opnd->getCoef();
765    if (CE.isMinusOne() || CE.isMinusTwo())
766      NegOpndNum++;
767
768    // Let the addend be "c * x". If "c == +/-1", the value of the addend
769    // is immediately available; otherwise, it needs exactly one instruction
770    // to evaluate the value.
771    if (!CE.isMinusOne() && !CE.isOne())
772      InstrNeeded++;
773  }
774  if (NegOpndNum == OpndNum)
775    InstrNeeded++;
776  return InstrNeeded;
777}
778
779// Input Addend        Value           NeedNeg(output)
780// ================================================================
781// Constant C          C               false
782// <+/-1, V>           V               coefficient is -1
783// <2/-2, V>          "fadd V, V"      coefficient is -2
784// <C, V>             "fmul V, C"      false
785//
786// NOTE: Keep this function in sync with FAddCombine::calcInstrNumber.
787Value *FAddCombine::createAddendVal
788  (const FAddend &Opnd, bool &NeedNeg) {
789  const FAddendCoef &Coeff = Opnd.getCoef();
790
791  if (Opnd.isConstant()) {
792    NeedNeg = false;
793    return Coeff.getValue(Instr->getType());
794  }
795
796  Value *OpndVal = Opnd.getSymVal();
797
798  if (Coeff.isMinusOne() || Coeff.isOne()) {
799    NeedNeg = Coeff.isMinusOne();
800    return OpndVal;
801  }
802
803  if (Coeff.isTwo() || Coeff.isMinusTwo()) {
804    NeedNeg = Coeff.isMinusTwo();
805    return createFAdd(OpndVal, OpndVal);
806  }
807
808  NeedNeg = false;
809  return createFMul(OpndVal, Coeff.getValue(Instr->getType()));
810}
811
812/// AddOne - Add one to a ConstantInt.
813static Constant *AddOne(Constant *C) {
814  return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1));
815}
816
817/// SubOne - Subtract one from a ConstantInt.
818static Constant *SubOne(ConstantInt *C) {
819  return ConstantInt::get(C->getContext(), C->getValue()-1);
820}
821
822
823// dyn_castFoldableMul - If this value is a multiply that can be folded into
824// other computations (because it has a constant operand), return the
825// non-constant operand of the multiply, and set CST to point to the multiplier.
826// Otherwise, return null.
827//
828static inline Value *dyn_castFoldableMul(Value *V, ConstantInt *&CST) {
829  if (!V->hasOneUse() || !V->getType()->isIntegerTy())
830    return 0;
831
832  Instruction *I = dyn_cast<Instruction>(V);
833  if (I == 0) return 0;
834
835  if (I->getOpcode() == Instruction::Mul)
836    if ((CST = dyn_cast<ConstantInt>(I->getOperand(1))))
837      return I->getOperand(0);
838  if (I->getOpcode() == Instruction::Shl)
839    if ((CST = dyn_cast<ConstantInt>(I->getOperand(1)))) {
840      // The multiplier is really 1 << CST.
841      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
842      uint32_t CSTVal = CST->getLimitedValue(BitWidth);
843      CST = ConstantInt::get(V->getType()->getContext(),
844                             APInt(BitWidth, 1).shl(CSTVal));
845      return I->getOperand(0);
846    }
847  return 0;
848}
849
850
851/// WillNotOverflowSignedAdd - Return true if we can prove that:
852///    (sext (add LHS, RHS))  === (add (sext LHS), (sext RHS))
853/// This basically requires proving that the add in the original type would not
854/// overflow to change the sign bit or have a carry out.
855bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
856  // There are different heuristics we can use for this.  Here are some simple
857  // ones.
858
859  // Add has the property that adding any two 2's complement numbers can only
860  // have one carry bit which can change a sign.  As such, if LHS and RHS each
861  // have at least two sign bits, we know that the addition of the two values
862  // will sign extend fine.
863  if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1)
864    return true;
865
866
867  // If one of the operands only has one non-zero bit, and if the other operand
868  // has a known-zero bit in a more significant place than it (not including the
869  // sign bit) the ripple may go up to and fill the zero, but won't change the
870  // sign.  For example, (X & ~4) + 1.
871
872  // TODO: Implement.
873
874  return false;
875}
876
877Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
878  bool Changed = SimplifyAssociativeOrCommutative(I);
879  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
880
881  if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
882                                 I.hasNoUnsignedWrap(), TD))
883    return ReplaceInstUsesWith(I, V);
884
885  // (A*B)+(A*C) -> A*(B+C) etc
886  if (Value *V = SimplifyUsingDistributiveLaws(I))
887    return ReplaceInstUsesWith(I, V);
888
889  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
890    // X + (signbit) --> X ^ signbit
891    const APInt &Val = CI->getValue();
892    if (Val.isSignBit())
893      return BinaryOperator::CreateXor(LHS, RHS);
894
895    // See if SimplifyDemandedBits can simplify this.  This handles stuff like
896    // (X & 254)+1 -> (X&254)|1
897    if (SimplifyDemandedInstructionBits(I))
898      return &I;
899
900    // zext(bool) + C -> bool ? C + 1 : C
901    if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))
902      if (ZI->getSrcTy()->isIntegerTy(1))
903        return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI);
904
905    Value *XorLHS = 0; ConstantInt *XorRHS = 0;
906    if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) {
907      uint32_t TySizeBits = I.getType()->getScalarSizeInBits();
908      const APInt &RHSVal = CI->getValue();
909      unsigned ExtendAmt = 0;
910      // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext.
911      // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext.
912      if (XorRHS->getValue() == -RHSVal) {
913        if (RHSVal.isPowerOf2())
914          ExtendAmt = TySizeBits - RHSVal.logBase2() - 1;
915        else if (XorRHS->getValue().isPowerOf2())
916          ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1;
917      }
918
919      if (ExtendAmt) {
920        APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
921        if (!MaskedValueIsZero(XorLHS, Mask))
922          ExtendAmt = 0;
923      }
924
925      if (ExtendAmt) {
926        Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt);
927        Value *NewShl = Builder->CreateShl(XorLHS, ShAmt, "sext");
928        return BinaryOperator::CreateAShr(NewShl, ShAmt);
929      }
930
931      // If this is a xor that was canonicalized from a sub, turn it back into
932      // a sub and fuse this add with it.
933      if (LHS->hasOneUse() && (XorRHS->getValue()+1).isPowerOf2()) {
934        IntegerType *IT = cast<IntegerType>(I.getType());
935        APInt LHSKnownOne(IT->getBitWidth(), 0);
936        APInt LHSKnownZero(IT->getBitWidth(), 0);
937        ComputeMaskedBits(XorLHS, LHSKnownZero, LHSKnownOne);
938        if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue())
939          return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI),
940                                           XorLHS);
941      }
942    }
943  }
944
945  if (isa<Constant>(RHS) && isa<PHINode>(LHS))
946    if (Instruction *NV = FoldOpIntoPhi(I))
947      return NV;
948
949  if (I.getType()->isIntegerTy(1))
950    return BinaryOperator::CreateXor(LHS, RHS);
951
952  // X + X --> X << 1
953  if (LHS == RHS) {
954    BinaryOperator *New =
955      BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1));
956    New->setHasNoSignedWrap(I.hasNoSignedWrap());
957    New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
958    return New;
959  }
960
961  // -A + B  -->  B - A
962  // -A + -B  -->  -(A + B)
963  if (Value *LHSV = dyn_castNegVal(LHS)) {
964    if (!isa<Constant>(RHS))
965      if (Value *RHSV = dyn_castNegVal(RHS)) {
966        Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum");
967        return BinaryOperator::CreateNeg(NewAdd);
968      }
969
970    return BinaryOperator::CreateSub(RHS, LHSV);
971  }
972
973  // A + -B  -->  A - B
974  if (!isa<Constant>(RHS))
975    if (Value *V = dyn_castNegVal(RHS))
976      return BinaryOperator::CreateSub(LHS, V);
977
978
979  ConstantInt *C2;
980  if (Value *X = dyn_castFoldableMul(LHS, C2)) {
981    if (X == RHS)   // X*C + X --> X * (C+1)
982      return BinaryOperator::CreateMul(RHS, AddOne(C2));
983
984    // X*C1 + X*C2 --> X * (C1+C2)
985    ConstantInt *C1;
986    if (X == dyn_castFoldableMul(RHS, C1))
987      return BinaryOperator::CreateMul(X, ConstantExpr::getAdd(C1, C2));
988  }
989
990  // X + X*C --> X * (C+1)
991  if (dyn_castFoldableMul(RHS, C2) == LHS)
992    return BinaryOperator::CreateMul(LHS, AddOne(C2));
993
994  // A+B --> A|B iff A and B have no bits set in common.
995  if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
996    APInt LHSKnownOne(IT->getBitWidth(), 0);
997    APInt LHSKnownZero(IT->getBitWidth(), 0);
998    ComputeMaskedBits(LHS, LHSKnownZero, LHSKnownOne);
999    if (LHSKnownZero != 0) {
1000      APInt RHSKnownOne(IT->getBitWidth(), 0);
1001      APInt RHSKnownZero(IT->getBitWidth(), 0);
1002      ComputeMaskedBits(RHS, RHSKnownZero, RHSKnownOne);
1003
1004      // No bits in common -> bitwise or.
1005      if ((LHSKnownZero|RHSKnownZero).isAllOnesValue())
1006        return BinaryOperator::CreateOr(LHS, RHS);
1007    }
1008  }
1009
1010  // W*X + Y*Z --> W * (X+Z)  iff W == Y
1011  {
1012    Value *W, *X, *Y, *Z;
1013    if (match(LHS, m_Mul(m_Value(W), m_Value(X))) &&
1014        match(RHS, m_Mul(m_Value(Y), m_Value(Z)))) {
1015      if (W != Y) {
1016        if (W == Z) {
1017          std::swap(Y, Z);
1018        } else if (Y == X) {
1019          std::swap(W, X);
1020        } else if (X == Z) {
1021          std::swap(Y, Z);
1022          std::swap(W, X);
1023        }
1024      }
1025
1026      if (W == Y) {
1027        Value *NewAdd = Builder->CreateAdd(X, Z, LHS->getName());
1028        return BinaryOperator::CreateMul(W, NewAdd);
1029      }
1030    }
1031  }
1032
1033  if (ConstantInt *CRHS = dyn_cast<ConstantInt>(RHS)) {
1034    Value *X = 0;
1035    if (match(LHS, m_Not(m_Value(X))))    // ~X + C --> (C-1) - X
1036      return BinaryOperator::CreateSub(SubOne(CRHS), X);
1037
1038    // (X & FF00) + xx00  -> (X+xx00) & FF00
1039    if (LHS->hasOneUse() &&
1040        match(LHS, m_And(m_Value(X), m_ConstantInt(C2))) &&
1041        CRHS->getValue() == (CRHS->getValue() & C2->getValue())) {
1042      // See if all bits from the first bit set in the Add RHS up are included
1043      // in the mask.  First, get the rightmost bit.
1044      const APInt &AddRHSV = CRHS->getValue();
1045
1046      // Form a mask of all bits from the lowest bit added through the top.
1047      APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1));
1048
1049      // See if the and mask includes all of these bits.
1050      APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue());
1051
1052      if (AddRHSHighBits == AddRHSHighBitsAnd) {
1053        // Okay, the xform is safe.  Insert the new add pronto.
1054        Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName());
1055        return BinaryOperator::CreateAnd(NewAdd, C2);
1056      }
1057    }
1058
1059    // Try to fold constant add into select arguments.
1060    if (SelectInst *SI = dyn_cast<SelectInst>(LHS))
1061      if (Instruction *R = FoldOpIntoSelect(I, SI))
1062        return R;
1063  }
1064
1065  // add (select X 0 (sub n A)) A  -->  select X A n
1066  {
1067    SelectInst *SI = dyn_cast<SelectInst>(LHS);
1068    Value *A = RHS;
1069    if (!SI) {
1070      SI = dyn_cast<SelectInst>(RHS);
1071      A = LHS;
1072    }
1073    if (SI && SI->hasOneUse()) {
1074      Value *TV = SI->getTrueValue();
1075      Value *FV = SI->getFalseValue();
1076      Value *N;
1077
1078      // Can we fold the add into the argument of the select?
1079      // We check both true and false select arguments for a matching subtract.
1080      if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A))))
1081        // Fold the add into the true select value.
1082        return SelectInst::Create(SI->getCondition(), N, A);
1083
1084      if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A))))
1085        // Fold the add into the false select value.
1086        return SelectInst::Create(SI->getCondition(), A, N);
1087    }
1088  }
1089
1090  // Check for (add (sext x), y), see if we can merge this into an
1091  // integer add followed by a sext.
1092  if (SExtInst *LHSConv = dyn_cast<SExtInst>(LHS)) {
1093    // (add (sext x), cst) --> (sext (add x, cst'))
1094    if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) {
1095      Constant *CI =
1096        ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
1097      if (LHSConv->hasOneUse() &&
1098          ConstantExpr::getSExt(CI, I.getType()) == RHSC &&
1099          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
1100        // Insert the new, smaller add.
1101        Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
1102                                              CI, "addconv");
1103        return new SExtInst(NewAdd, I.getType());
1104      }
1105    }
1106
1107    // (add (sext x), (sext y)) --> (sext (add int x, y))
1108    if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) {
1109      // Only do this if x/y have the same type, if at last one of them has a
1110      // single use (so we don't increase the number of sexts), and if the
1111      // integer add will not overflow.
1112      if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
1113          (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
1114          WillNotOverflowSignedAdd(LHSConv->getOperand(0),
1115                                   RHSConv->getOperand(0))) {
1116        // Insert the new integer add.
1117        Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
1118                                             RHSConv->getOperand(0), "addconv");
1119        return new SExtInst(NewAdd, I.getType());
1120      }
1121    }
1122  }
1123
1124  // Check for (x & y) + (x ^ y)
1125  {
1126    Value *A = 0, *B = 0;
1127    if (match(RHS, m_Xor(m_Value(A), m_Value(B))) &&
1128        (match(LHS, m_And(m_Specific(A), m_Specific(B))) ||
1129         match(LHS, m_And(m_Specific(B), m_Specific(A)))))
1130      return BinaryOperator::CreateOr(A, B);
1131
1132    if (match(LHS, m_Xor(m_Value(A), m_Value(B))) &&
1133        (match(RHS, m_And(m_Specific(A), m_Specific(B))) ||
1134         match(RHS, m_And(m_Specific(B), m_Specific(A)))))
1135      return BinaryOperator::CreateOr(A, B);
1136  }
1137
1138  return Changed ? &I : 0;
1139}
1140
1141Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
1142  bool Changed = SimplifyAssociativeOrCommutative(I);
1143  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
1144
1145  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), TD))
1146    return ReplaceInstUsesWith(I, V);
1147
1148  if (isa<Constant>(RHS) && isa<PHINode>(LHS))
1149    if (Instruction *NV = FoldOpIntoPhi(I))
1150      return NV;
1151
1152  // -A + B  -->  B - A
1153  // -A + -B  -->  -(A + B)
1154  if (Value *LHSV = dyn_castFNegVal(LHS))
1155    return BinaryOperator::CreateFSub(RHS, LHSV);
1156
1157  // A + -B  -->  A - B
1158  if (!isa<Constant>(RHS))
1159    if (Value *V = dyn_castFNegVal(RHS))
1160      return BinaryOperator::CreateFSub(LHS, V);
1161
1162  // Check for (fadd double (sitofp x), y), see if we can merge this into an
1163  // integer add followed by a promotion.
1164  if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
1165    // (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
1166    // ... if the constant fits in the integer value.  This is useful for things
1167    // like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer
1168    // requires a constant pool load, and generally allows the add to be better
1169    // instcombined.
1170    if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) {
1171      Constant *CI =
1172      ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
1173      if (LHSConv->hasOneUse() &&
1174          ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
1175          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
1176        // Insert the new integer add.
1177        Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
1178                                              CI, "addconv");
1179        return new SIToFPInst(NewAdd, I.getType());
1180      }
1181    }
1182
1183    // (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
1184    if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
1185      // Only do this if x/y have the same type, if at last one of them has a
1186      // single use (so we don't increase the number of int->fp conversions),
1187      // and if the integer add will not overflow.
1188      if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
1189          (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
1190          WillNotOverflowSignedAdd(LHSConv->getOperand(0),
1191                                   RHSConv->getOperand(0))) {
1192        // Insert the new integer add.
1193        Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
1194                                              RHSConv->getOperand(0),"addconv");
1195        return new SIToFPInst(NewAdd, I.getType());
1196      }
1197    }
1198  }
1199
1200  if (I.hasUnsafeAlgebra()) {
1201    if (Value *V = FAddCombine(Builder).simplify(&I))
1202      return ReplaceInstUsesWith(I, V);
1203  }
1204
1205  return Changed ? &I : 0;
1206}
1207
1208
1209/// Optimize pointer differences into the same array into a size.  Consider:
1210///  &A[10] - &A[0]: we should compile this to "10".  LHS/RHS are the pointer
1211/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
1212///
1213Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
1214                                               Type *Ty) {
1215  assert(TD && "Must have target data info for this");
1216
1217  // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
1218  // this.
1219  bool Swapped = false;
1220  GEPOperator *GEP1 = 0, *GEP2 = 0;
1221
1222  // For now we require one side to be the base pointer "A" or a constant
1223  // GEP derived from it.
1224  if (GEPOperator *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
1225    // (gep X, ...) - X
1226    if (LHSGEP->getOperand(0) == RHS) {
1227      GEP1 = LHSGEP;
1228      Swapped = false;
1229    } else if (GEPOperator *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
1230      // (gep X, ...) - (gep X, ...)
1231      if (LHSGEP->getOperand(0)->stripPointerCasts() ==
1232            RHSGEP->getOperand(0)->stripPointerCasts()) {
1233        GEP2 = RHSGEP;
1234        GEP1 = LHSGEP;
1235        Swapped = false;
1236      }
1237    }
1238  }
1239
1240  if (GEPOperator *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
1241    // X - (gep X, ...)
1242    if (RHSGEP->getOperand(0) == LHS) {
1243      GEP1 = RHSGEP;
1244      Swapped = true;
1245    } else if (GEPOperator *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
1246      // (gep X, ...) - (gep X, ...)
1247      if (RHSGEP->getOperand(0)->stripPointerCasts() ==
1248            LHSGEP->getOperand(0)->stripPointerCasts()) {
1249        GEP2 = LHSGEP;
1250        GEP1 = RHSGEP;
1251        Swapped = true;
1252      }
1253    }
1254  }
1255
1256  // Avoid duplicating the arithmetic if GEP2 has non-constant indices and
1257  // multiple users.
1258  if (GEP1 == 0 ||
1259      (GEP2 != 0 && !GEP2->hasAllConstantIndices() && !GEP2->hasOneUse()))
1260    return 0;
1261
1262  // Emit the offset of the GEP and an intptr_t.
1263  Value *Result = EmitGEPOffset(GEP1);
1264
1265  // If we had a constant expression GEP on the other side offsetting the
1266  // pointer, subtract it from the offset we have.
1267  if (GEP2) {
1268    Value *Offset = EmitGEPOffset(GEP2);
1269    Result = Builder->CreateSub(Result, Offset);
1270  }
1271
1272  // If we have p - gep(p, ...)  then we have to negate the result.
1273  if (Swapped)
1274    Result = Builder->CreateNeg(Result, "diff.neg");
1275
1276  return Builder->CreateIntCast(Result, Ty, true);
1277}
1278
1279
1280Instruction *InstCombiner::visitSub(BinaryOperator &I) {
1281  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1282
1283  if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(),
1284                                 I.hasNoUnsignedWrap(), TD))
1285    return ReplaceInstUsesWith(I, V);
1286
1287  // (A*B)-(A*C) -> A*(B-C) etc
1288  if (Value *V = SimplifyUsingDistributiveLaws(I))
1289    return ReplaceInstUsesWith(I, V);
1290
1291  // If this is a 'B = x-(-A)', change to B = x+A.  This preserves NSW/NUW.
1292  if (Value *V = dyn_castNegVal(Op1)) {
1293    BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V);
1294    Res->setHasNoSignedWrap(I.hasNoSignedWrap());
1295    Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
1296    return Res;
1297  }
1298
1299  if (I.getType()->isIntegerTy(1))
1300    return BinaryOperator::CreateXor(Op0, Op1);
1301
1302  // Replace (-1 - A) with (~A).
1303  if (match(Op0, m_AllOnes()))
1304    return BinaryOperator::CreateNot(Op1);
1305
1306  if (ConstantInt *C = dyn_cast<ConstantInt>(Op0)) {
1307    // C - ~X == X + (1+C)
1308    Value *X = 0;
1309    if (match(Op1, m_Not(m_Value(X))))
1310      return BinaryOperator::CreateAdd(X, AddOne(C));
1311
1312    // -(X >>u 31) -> (X >>s 31)
1313    // -(X >>s 31) -> (X >>u 31)
1314    if (C->isZero()) {
1315      Value *X; ConstantInt *CI;
1316      if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) &&
1317          // Verify we are shifting out everything but the sign bit.
1318          CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1)
1319        return BinaryOperator::CreateAShr(X, CI);
1320
1321      if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) &&
1322          // Verify we are shifting out everything but the sign bit.
1323          CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1)
1324        return BinaryOperator::CreateLShr(X, CI);
1325    }
1326
1327    // Try to fold constant sub into select arguments.
1328    if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
1329      if (Instruction *R = FoldOpIntoSelect(I, SI))
1330        return R;
1331
1332    // C-(X+C2) --> (C-C2)-X
1333    ConstantInt *C2;
1334    if (match(Op1, m_Add(m_Value(X), m_ConstantInt(C2))))
1335      return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
1336
1337    if (SimplifyDemandedInstructionBits(I))
1338      return &I;
1339
1340    // Fold (sub 0, (zext bool to B)) --> (sext bool to B)
1341    if (C->isZero() && match(Op1, m_ZExt(m_Value(X))))
1342      if (X->getType()->isIntegerTy(1))
1343        return CastInst::CreateSExtOrBitCast(X, Op1->getType());
1344
1345    // Fold (sub 0, (sext bool to B)) --> (zext bool to B)
1346    if (C->isZero() && match(Op1, m_SExt(m_Value(X))))
1347      if (X->getType()->isIntegerTy(1))
1348        return CastInst::CreateZExtOrBitCast(X, Op1->getType());
1349  }
1350
1351
1352  { Value *Y;
1353    // X-(X+Y) == -Y    X-(Y+X) == -Y
1354    if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) ||
1355        match(Op1, m_Add(m_Value(Y), m_Specific(Op0))))
1356      return BinaryOperator::CreateNeg(Y);
1357
1358    // (X-Y)-X == -Y
1359    if (match(Op0, m_Sub(m_Specific(Op1), m_Value(Y))))
1360      return BinaryOperator::CreateNeg(Y);
1361  }
1362
1363  if (Op1->hasOneUse()) {
1364    Value *X = 0, *Y = 0, *Z = 0;
1365    Constant *C = 0;
1366    ConstantInt *CI = 0;
1367
1368    // (X - (Y - Z))  -->  (X + (Z - Y)).
1369    if (match(Op1, m_Sub(m_Value(Y), m_Value(Z))))
1370      return BinaryOperator::CreateAdd(Op0,
1371                                      Builder->CreateSub(Z, Y, Op1->getName()));
1372
1373    // (X - (X & Y))   -->   (X & ~Y)
1374    //
1375    if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) ||
1376        match(Op1, m_And(m_Specific(Op0), m_Value(Y))))
1377      return BinaryOperator::CreateAnd(Op0,
1378                                  Builder->CreateNot(Y, Y->getName() + ".not"));
1379
1380    // 0 - (X sdiv C)  -> (X sdiv -C)
1381    if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) &&
1382        match(Op0, m_Zero()))
1383      return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C));
1384
1385    // 0 - (X << Y)  -> (-X << Y)   when X is freely negatable.
1386    if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero()))
1387      if (Value *XNeg = dyn_castNegVal(X))
1388        return BinaryOperator::CreateShl(XNeg, Y);
1389
1390    // X - X*C --> X * (1-C)
1391    if (match(Op1, m_Mul(m_Specific(Op0), m_ConstantInt(CI)))) {
1392      Constant *CP1 = ConstantExpr::getSub(ConstantInt::get(I.getType(),1), CI);
1393      return BinaryOperator::CreateMul(Op0, CP1);
1394    }
1395
1396    // X - X<<C --> X * (1-(1<<C))
1397    if (match(Op1, m_Shl(m_Specific(Op0), m_ConstantInt(CI)))) {
1398      Constant *One = ConstantInt::get(I.getType(), 1);
1399      C = ConstantExpr::getSub(One, ConstantExpr::getShl(One, CI));
1400      return BinaryOperator::CreateMul(Op0, C);
1401    }
1402
1403    // X - A*-B -> X + A*B
1404    // X - -A*B -> X + A*B
1405    Value *A, *B;
1406    if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) ||
1407        match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B))))
1408      return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B));
1409
1410    // X - A*CI -> X + A*-CI
1411    // X - CI*A -> X + A*-CI
1412    if (match(Op1, m_Mul(m_Value(A), m_ConstantInt(CI))) ||
1413        match(Op1, m_Mul(m_ConstantInt(CI), m_Value(A)))) {
1414      Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI));
1415      return BinaryOperator::CreateAdd(Op0, NewMul);
1416    }
1417  }
1418
1419  ConstantInt *C1;
1420  if (Value *X = dyn_castFoldableMul(Op0, C1)) {
1421    if (X == Op1)  // X*C - X --> X * (C-1)
1422      return BinaryOperator::CreateMul(Op1, SubOne(C1));
1423
1424    ConstantInt *C2;   // X*C1 - X*C2 -> X * (C1-C2)
1425    if (X == dyn_castFoldableMul(Op1, C2))
1426      return BinaryOperator::CreateMul(X, ConstantExpr::getSub(C1, C2));
1427  }
1428
1429  // Optimize pointer differences into the same array into a size.  Consider:
1430  //  &A[10] - &A[0]: we should compile this to "10".
1431  if (TD) {
1432    Value *LHSOp, *RHSOp;
1433    if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
1434        match(Op1, m_PtrToInt(m_Value(RHSOp))))
1435      if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
1436        return ReplaceInstUsesWith(I, Res);
1437
1438    // trunc(p)-trunc(q) -> trunc(p-q)
1439    if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
1440        match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
1441      if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
1442        return ReplaceInstUsesWith(I, Res);
1443  }
1444
1445  return 0;
1446}
1447
1448Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
1449  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1450
1451  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), TD))
1452    return ReplaceInstUsesWith(I, V);
1453
1454  // If this is a 'B = x-(-A)', change to B = x+A...
1455  if (Value *V = dyn_castFNegVal(Op1))
1456    return BinaryOperator::CreateFAdd(Op0, V);
1457
1458  if (I.hasUnsafeAlgebra()) {
1459    if (Value *V = FAddCombine(Builder).simplify(&I))
1460      return ReplaceInstUsesWith(I, V);
1461  }
1462
1463  return 0;
1464}
1465