1/*
2 * Copyright (C) 2011 The Guava Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.common.math;
18
19import static com.google.common.base.Preconditions.checkArgument;
20import static com.google.common.base.Preconditions.checkNotNull;
21import static com.google.common.math.MathPreconditions.checkNoOverflow;
22import static com.google.common.math.MathPreconditions.checkNonNegative;
23import static com.google.common.math.MathPreconditions.checkPositive;
24import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
25import static java.lang.Math.abs;
26import static java.lang.Math.min;
27import static java.math.RoundingMode.HALF_EVEN;
28import static java.math.RoundingMode.HALF_UP;
29
30import com.google.common.annotations.Beta;
31import com.google.common.annotations.VisibleForTesting;
32
33import java.math.BigInteger;
34import java.math.RoundingMode;
35
36/**
37 * A class for arithmetic on values of type {@code long}. Where possible, methods are defined and
38 * named analogously to their {@code BigInteger} counterparts.
39 *
40 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
41 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
42 *
43 * <p>Similar functionality for {@code int} and for {@link BigInteger} can be found in
44 * {@link IntMath} and {@link BigIntegerMath} respectively.  For other common operations on
45 * {@code long} values, see {@link com.google.common.primitives.Longs}.
46 *
47 * @author Louis Wasserman
48 * @since 11.0
49 */
50@Beta
51public final class LongMath {
52  // NOTE: Whenever both tests are cheap and functional, it's faster to use &, | instead of &&, ||
53
54  /**
55   * Returns {@code true} if {@code x} represents a power of two.
56   *
57   * <p>This differs from {@code Long.bitCount(x) == 1}, because
58   * {@code Long.bitCount(Long.MIN_VALUE) == 1}, but {@link Long#MIN_VALUE} is not a power of two.
59   */
60  public static boolean isPowerOfTwo(long x) {
61    return x > 0 & (x & (x - 1)) == 0;
62  }
63
64  /**
65   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
66   *
67   * @throws IllegalArgumentException if {@code x <= 0}
68   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
69   *         is not a power of two
70   */
71  @SuppressWarnings("fallthrough")
72  public static int log2(long x, RoundingMode mode) {
73    checkPositive("x", x);
74    switch (mode) {
75      case UNNECESSARY:
76        checkRoundingUnnecessary(isPowerOfTwo(x));
77        // fall through
78      case DOWN:
79      case FLOOR:
80        return (Long.SIZE - 1) - Long.numberOfLeadingZeros(x);
81
82      case UP:
83      case CEILING:
84        return Long.SIZE - Long.numberOfLeadingZeros(x - 1);
85
86      case HALF_DOWN:
87      case HALF_UP:
88      case HALF_EVEN:
89        // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
90        int leadingZeros = Long.numberOfLeadingZeros(x);
91        long cmp = MAX_POWER_OF_SQRT2_UNSIGNED >>> leadingZeros;
92        // floor(2^(logFloor + 0.5))
93        int logFloor = (Long.SIZE - 1) - leadingZeros;
94        return (x <= cmp) ? logFloor : logFloor + 1;
95
96      default:
97        throw new AssertionError("impossible");
98    }
99  }
100
101  /** The biggest half power of two that fits into an unsigned long */
102  @VisibleForTesting static final long MAX_POWER_OF_SQRT2_UNSIGNED = 0xB504F333F9DE6484L;
103
104  /**
105   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
106   *
107   * @throws IllegalArgumentException if {@code x <= 0}
108   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
109   *         is not a power of ten
110   */
111  @SuppressWarnings("fallthrough")
112  public static int log10(long x, RoundingMode mode) {
113    checkPositive("x", x);
114    if (fitsInInt(x)) {
115      return IntMath.log10((int) x, mode);
116    }
117    int logFloor = log10Floor(x);
118    long floorPow = POWERS_OF_10[logFloor];
119    switch (mode) {
120      case UNNECESSARY:
121        checkRoundingUnnecessary(x == floorPow);
122        // fall through
123      case FLOOR:
124      case DOWN:
125        return logFloor;
126      case CEILING:
127      case UP:
128        return (x == floorPow) ? logFloor : logFloor + 1;
129      case HALF_DOWN:
130      case HALF_UP:
131      case HALF_EVEN:
132        // sqrt(10) is irrational, so log10(x)-logFloor is never exactly 0.5
133        return (x <= HALF_POWERS_OF_10[logFloor]) ? logFloor : logFloor + 1;
134      default:
135        throw new AssertionError();
136    }
137  }
138
139  static int log10Floor(long x) {
140    for (int i = 1; i < POWERS_OF_10.length; i++) {
141      if (x < POWERS_OF_10[i]) {
142        return i - 1;
143      }
144    }
145    return POWERS_OF_10.length - 1;
146  }
147
148  @VisibleForTesting
149  static final long[] POWERS_OF_10 = {
150    1L,
151    10L,
152    100L,
153    1000L,
154    10000L,
155    100000L,
156    1000000L,
157    10000000L,
158    100000000L,
159    1000000000L,
160    10000000000L,
161    100000000000L,
162    1000000000000L,
163    10000000000000L,
164    100000000000000L,
165    1000000000000000L,
166    10000000000000000L,
167    100000000000000000L,
168    1000000000000000000L
169  };
170
171  // HALF_POWERS_OF_10[i] = largest long less than 10^(i + 0.5)
172  @VisibleForTesting
173  static final long[] HALF_POWERS_OF_10 = {
174    3L,
175    31L,
176    316L,
177    3162L,
178    31622L,
179    316227L,
180    3162277L,
181    31622776L,
182    316227766L,
183    3162277660L,
184    31622776601L,
185    316227766016L,
186    3162277660168L,
187    31622776601683L,
188    316227766016837L,
189    3162277660168379L,
190    31622776601683793L,
191    316227766016837933L,
192    3162277660168379331L
193  };
194
195  /**
196   * Returns {@code b} to the {@code k}th power. Even if the result overflows, it will be equal to
197   * {@code BigInteger.valueOf(b).pow(k).longValue()}. This implementation runs in {@code O(log k)}
198   * time.
199   *
200   * @throws IllegalArgumentException if {@code k < 0}
201   */
202  public static long pow(long b, int k) {
203    checkNonNegative("exponent", k);
204    if (-2 <= b && b <= 2) {
205      switch ((int) b) {
206        case 0:
207          return (k == 0) ? 1 : 0;
208        case 1:
209          return 1;
210        case (-1):
211          return ((k & 1) == 0) ? 1 : -1;
212        case 2:
213          return (k < Long.SIZE) ? 1L << k : 0;
214        case (-2):
215          if (k < Long.SIZE) {
216            return ((k & 1) == 0) ? 1L << k : -(1L << k);
217          } else {
218            return 0;
219          }
220      }
221    }
222    for (long accum = 1;; k >>= 1) {
223      switch (k) {
224        case 0:
225          return accum;
226        case 1:
227          return accum * b;
228        default:
229          accum *= ((k & 1) == 0) ? 1 : b;
230          b *= b;
231      }
232    }
233  }
234
235  /**
236   * Returns the square root of {@code x}, rounded with the specified rounding mode.
237   *
238   * @throws IllegalArgumentException if {@code x < 0}
239   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
240   *         {@code sqrt(x)} is not an integer
241   */
242  @SuppressWarnings("fallthrough")
243  public static long sqrt(long x, RoundingMode mode) {
244    checkNonNegative("x", x);
245    if (fitsInInt(x)) {
246      return IntMath.sqrt((int) x, mode);
247    }
248    long sqrtFloor = sqrtFloor(x);
249    switch (mode) {
250      case UNNECESSARY:
251        checkRoundingUnnecessary(sqrtFloor * sqrtFloor == x); // fall through
252      case FLOOR:
253      case DOWN:
254        return sqrtFloor;
255      case CEILING:
256      case UP:
257        return (sqrtFloor * sqrtFloor == x) ? sqrtFloor : sqrtFloor + 1;
258      case HALF_DOWN:
259      case HALF_UP:
260      case HALF_EVEN:
261        long halfSquare = sqrtFloor * sqrtFloor + sqrtFloor;
262        /*
263         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
264         * x and halfSquare are integers, this is equivalent to testing whether or not x <=
265         * halfSquare. (We have to deal with overflow, though.)
266         */
267        return (halfSquare >= x | halfSquare < 0) ? sqrtFloor : sqrtFloor + 1;
268      default:
269        throw new AssertionError();
270    }
271  }
272
273  private static long sqrtFloor(long x) {
274    // Hackers's Delight, Figure 11-1
275    long sqrt0 = (long) Math.sqrt(x);
276    // Precision can be lost in the cast to double, so we use this as a starting estimate.
277    long sqrt1 = (sqrt0 + (x / sqrt0)) >> 1;
278    if (sqrt1 == sqrt0) {
279      return sqrt0;
280    }
281    do {
282      sqrt0 = sqrt1;
283      sqrt1 = (sqrt0 + (x / sqrt0)) >> 1;
284    } while (sqrt1 < sqrt0);
285    return sqrt0;
286  }
287
288  /**
289   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
290   * {@code RoundingMode}.
291   *
292   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
293   *         is not an integer multiple of {@code b}
294   */
295  @SuppressWarnings("fallthrough")
296  public static long divide(long p, long q, RoundingMode mode) {
297    checkNotNull(mode);
298    long div = p / q; // throws if q == 0
299    long rem = p - q * div; // equals p % q
300
301    if (rem == 0) {
302      return div;
303    }
304
305    /*
306     * Normal Java division rounds towards 0, consistently with RoundingMode.DOWN. We just have to
307     * deal with the cases where rounding towards 0 is wrong, which typically depends on the sign of
308     * p / q.
309     *
310     * signum is 1 if p and q are both nonnegative or both negative, and -1 otherwise.
311     */
312    int signum = 1 | (int) ((p ^ q) >> (Long.SIZE - 1));
313    boolean increment;
314    switch (mode) {
315      case UNNECESSARY:
316        checkRoundingUnnecessary(rem == 0);
317        // fall through
318      case DOWN:
319        increment = false;
320        break;
321      case UP:
322        increment = true;
323        break;
324      case CEILING:
325        increment = signum > 0;
326        break;
327      case FLOOR:
328        increment = signum < 0;
329        break;
330      case HALF_EVEN:
331      case HALF_DOWN:
332      case HALF_UP:
333        long absRem = abs(rem);
334        long cmpRemToHalfDivisor = absRem - (abs(q) - absRem);
335        // subtracting two nonnegative longs can't overflow
336        // cmpRemToHalfDivisor has the same sign as compare(abs(rem), abs(q) / 2).
337        if (cmpRemToHalfDivisor == 0) { // exactly on the half mark
338          increment = (mode == HALF_UP | (mode == HALF_EVEN & (div & 1) != 0));
339        } else {
340          increment = cmpRemToHalfDivisor > 0; // closer to the UP value
341        }
342        break;
343      default:
344        throw new AssertionError();
345    }
346    return increment ? div + signum : div;
347  }
348
349  /**
350   * Returns {@code x mod m}. This differs from {@code x % m} in that it always returns a
351   * non-negative result.
352   *
353   * <p>For example:
354   *
355   * <pre> {@code
356   *
357   * mod(7, 4) == 3
358   * mod(-7, 4) == 1
359   * mod(-1, 4) == 3
360   * mod(-8, 4) == 0
361   * mod(8, 4) == 0}</pre>
362   *
363   * @throws ArithmeticException if {@code m <= 0}
364   */
365  public static int mod(long x, int m) {
366    // Cast is safe because the result is guaranteed in the range [0, m)
367    return (int) mod(x, (long) m);
368  }
369
370  /**
371   * Returns {@code x mod m}. This differs from {@code x % m} in that it always returns a
372   * non-negative result.
373   *
374   * <p>For example:
375   *
376   * <pre> {@code
377   *
378   * mod(7, 4) == 3
379   * mod(-7, 4) == 1
380   * mod(-1, 4) == 3
381   * mod(-8, 4) == 0
382   * mod(8, 4) == 0}</pre>
383   *
384   * @throws ArithmeticException if {@code m <= 0}
385   */
386  public static long mod(long x, long m) {
387    if (m <= 0) {
388      throw new ArithmeticException("Modulus " + m + " must be > 0");
389    }
390    long result = x % m;
391    return (result >= 0) ? result : result + m;
392  }
393
394  /**
395   * Returns the greatest common divisor of {@code a, b}. Returns {@code 0} if
396   * {@code a == 0 && b == 0}.
397   *
398   * @throws IllegalArgumentException if {@code a < 0} or {@code b < 0}
399   */
400  public static long gcd(long a, long b) {
401    /*
402     * The reason we require both arguments to be >= 0 is because otherwise, what do you return on
403     * gcd(0, Long.MIN_VALUE)? BigInteger.gcd would return positive 2^63, but positive 2^63 isn't
404     * an int.
405     */
406    checkNonNegative("a", a);
407    checkNonNegative("b", b);
408    if (a == 0 | b == 0) {
409      return a | b;
410    }
411    /*
412     * Uses the binary GCD algorithm; see http://en.wikipedia.org/wiki/Binary_GCD_algorithm.
413     * This is over 40% faster than the Euclidean algorithm in benchmarks.
414     */
415    int aTwos = Long.numberOfTrailingZeros(a);
416    a >>= aTwos; // divide out all 2s
417    int bTwos = Long.numberOfTrailingZeros(b);
418    b >>= bTwos; // divide out all 2s
419    while (a != b) { // both a, b are odd
420      if (a < b) { // swap a, b
421        long t = b;
422        b = a;
423        a = t;
424      }
425      a -= b; // a is now positive and even
426      a >>= Long.numberOfTrailingZeros(a); // divide out all 2s, since 2 doesn't divide b
427    }
428    return a << min(aTwos, bTwos);
429  }
430
431  /**
432   * Returns the sum of {@code a} and {@code b}, provided it does not overflow.
433   *
434   * @throws ArithmeticException if {@code a + b} overflows in signed {@code long} arithmetic
435   */
436  public static long checkedAdd(long a, long b) {
437    long result = a + b;
438    checkNoOverflow((a ^ b) < 0 | (a ^ result) >= 0);
439    return result;
440  }
441
442  /**
443   * Returns the difference of {@code a} and {@code b}, provided it does not overflow.
444   *
445   * @throws ArithmeticException if {@code a - b} overflows in signed {@code long} arithmetic
446   */
447  public static long checkedSubtract(long a, long b) {
448    long result = a - b;
449    checkNoOverflow((a ^ b) >= 0 | (a ^ result) >= 0);
450    return result;
451  }
452
453  /**
454   * Returns the product of {@code a} and {@code b}, provided it does not overflow.
455   *
456   * @throws ArithmeticException if {@code a * b} overflows in signed {@code long} arithmetic
457   */
458  public static long checkedMultiply(long a, long b) {
459    // Hacker's Delight, Section 2-12
460    int leadingZeros = Long.numberOfLeadingZeros(a) + Long.numberOfLeadingZeros(~a)
461        + Long.numberOfLeadingZeros(b) + Long.numberOfLeadingZeros(~b);
462    /*
463     * If leadingZeros > Long.SIZE + 1 it's definitely fine, if it's < Long.SIZE it's definitely
464     * bad. We do the leadingZeros check to avoid the division below if at all possible.
465     *
466     * Otherwise, if b == Long.MIN_VALUE, then the only allowed values of a are 0 and 1. We take
467     * care of all a < 0 with their own check, because in particular, the case a == -1 will
468     * incorrectly pass the division check below.
469     *
470     * In all other cases, we check that either a is 0 or the result is consistent with division.
471     */
472    if (leadingZeros > Long.SIZE + 1) {
473      return a * b;
474    }
475    checkNoOverflow(leadingZeros >= Long.SIZE);
476    checkNoOverflow(a >= 0 | b != Long.MIN_VALUE);
477    long result = a * b;
478    checkNoOverflow(a == 0 || result / a == b);
479    return result;
480  }
481
482  /**
483   * Returns the {@code b} to the {@code k}th power, provided it does not overflow.
484   *
485   * @throws ArithmeticException if {@code b} to the {@code k}th power overflows in signed
486   *         {@code long} arithmetic
487   */
488  public static long checkedPow(long b, int k) {
489    checkNonNegative("exponent", k);
490    if (b >= -2 & b <= 2) {
491      switch ((int) b) {
492        case 0:
493          return (k == 0) ? 1 : 0;
494        case 1:
495          return 1;
496        case (-1):
497          return ((k & 1) == 0) ? 1 : -1;
498        case 2:
499          checkNoOverflow(k < Long.SIZE - 1);
500          return 1L << k;
501        case (-2):
502          checkNoOverflow(k < Long.SIZE);
503          return ((k & 1) == 0) ? (1L << k) : (-1L << k);
504      }
505    }
506    long accum = 1;
507    while (true) {
508      switch (k) {
509        case 0:
510          return accum;
511        case 1:
512          return checkedMultiply(accum, b);
513        default:
514          if ((k & 1) != 0) {
515            accum = checkedMultiply(accum, b);
516          }
517          k >>= 1;
518          if (k > 0) {
519            checkNoOverflow(b <= FLOOR_SQRT_MAX_LONG);
520            b *= b;
521          }
522      }
523    }
524  }
525
526  @VisibleForTesting static final long FLOOR_SQRT_MAX_LONG = 3037000499L;
527
528  /**
529   * Returns {@code n!}, that is, the product of the first {@code n} positive
530   * integers, {@code 1} if {@code n == 0}, or {@link Long#MAX_VALUE} if the
531   * result does not fit in a {@code long}.
532   *
533   * @throws IllegalArgumentException if {@code n < 0}
534   */
535  public static long factorial(int n) {
536    checkNonNegative("n", n);
537    return (n < FACTORIALS.length) ? FACTORIALS[n] : Long.MAX_VALUE;
538  }
539
540  static final long[] FACTORIALS = {
541      1L,
542      1L,
543      1L * 2,
544      1L * 2 * 3,
545      1L * 2 * 3 * 4,
546      1L * 2 * 3 * 4 * 5,
547      1L * 2 * 3 * 4 * 5 * 6,
548      1L * 2 * 3 * 4 * 5 * 6 * 7,
549      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8,
550      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9,
551      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10,
552      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11,
553      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12,
554      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13,
555      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14,
556      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15,
557      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16,
558      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17,
559      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18,
560      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19,
561      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19 * 20
562  };
563
564  /**
565   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
566   * {@code k}, or {@link Long#MAX_VALUE} if the result does not fit in a {@code long}.
567   *
568   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
569   */
570  public static long binomial(int n, int k) {
571    checkNonNegative("n", n);
572    checkNonNegative("k", k);
573    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
574    if (k > (n >> 1)) {
575      k = n - k;
576    }
577    if (k >= BIGGEST_BINOMIALS.length || n > BIGGEST_BINOMIALS[k]) {
578      return Long.MAX_VALUE;
579    }
580    long result = 1;
581    if (k < BIGGEST_SIMPLE_BINOMIALS.length && n <= BIGGEST_SIMPLE_BINOMIALS[k]) {
582      // guaranteed not to overflow
583      for (int i = 0; i < k; i++) {
584        result *= n - i;
585        result /= i + 1;
586      }
587    } else {
588      // We want to do this in long math for speed, but want to avoid overflow.
589      // Dividing by the GCD suffices to avoid overflow in all the remaining cases.
590      for (int i = 1; i <= k; i++, n--) {
591        int d = IntMath.gcd(n, i);
592        result /= i / d; // (i/d) is guaranteed to divide result
593        result *= n / d;
594      }
595    }
596    return result;
597  }
598
599  /*
600   * binomial(BIGGEST_BINOMIALS[k], k) fits in a long, but not
601   * binomial(BIGGEST_BINOMIALS[k] + 1, k).
602   */
603  static final int[] BIGGEST_BINOMIALS =
604      {Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, 3810779, 121977, 16175, 4337, 1733,
605          887, 534, 361, 265, 206, 169, 143, 125, 111, 101, 94, 88, 83, 79, 76, 74, 72, 70, 69, 68,
606          67, 67, 66, 66, 66, 66};
607
608  /*
609   * binomial(BIGGEST_SIMPLE_BINOMIALS[k], k) doesn't need to use the slower GCD-based impl,
610   * but binomial(BIGGEST_SIMPLE_BINOMIALS[k] + 1, k) does.
611   */
612  @VisibleForTesting static final int[] BIGGEST_SIMPLE_BINOMIALS =
613      {Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, 2642246, 86251, 11724, 3218, 1313,
614          684, 419, 287, 214, 169, 139, 119, 105, 95, 87, 81, 76, 73, 70, 68, 66, 64, 63, 62, 62,
615          61, 61, 61};
616  // These values were generated by using checkedMultiply to see when the simple multiply/divide
617  // algorithm would lead to an overflow.
618
619  static boolean fitsInInt(long x) {
620    return (int) x == x;
621  }
622
623  private LongMath() {}
624}
625