1/*
2 * Copyright (C) 2007 The Guava Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.common.collect;
18
19import static com.google.common.base.Preconditions.checkArgument;
20import static com.google.common.base.Preconditions.checkNotNull;
21import static com.google.common.base.Preconditions.checkState;
22import static com.google.common.collect.BstSide.LEFT;
23import static com.google.common.collect.BstSide.RIGHT;
24
25import java.io.Serializable;
26import java.util.Comparator;
27import java.util.ConcurrentModificationException;
28import java.util.Iterator;
29
30import javax.annotation.Nullable;
31
32import com.google.common.annotations.GwtCompatible;
33import com.google.common.primitives.Ints;
34
35/**
36 * A multiset which maintains the ordering of its elements, according to either
37 * their natural order or an explicit {@link Comparator}. In all cases, this
38 * implementation uses {@link Comparable#compareTo} or {@link
39 * Comparator#compare} instead of {@link Object#equals} to determine
40 * equivalence of instances.
41 *
42 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as
43 * explained by the {@link Comparable} class specification. Otherwise, the
44 * resulting multiset will violate the {@link java.util.Collection} contract,
45 * which is specified in terms of {@link Object#equals}.
46 *
47 * @author Louis Wasserman
48 * @author Jared Levy
49 * @since 2.0 (imported from Google Collections Library)
50 */
51@GwtCompatible(emulated = true)
52public final class TreeMultiset<E> extends AbstractSortedMultiset<E>
53    implements Serializable {
54
55  /**
56   * Creates a new, empty multiset, sorted according to the elements' natural
57   * order. All elements inserted into the multiset must implement the
58   * {@code Comparable} interface. Furthermore, all such elements must be
59   * <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
60   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in
61   * the multiset. If the user attempts to add an element to the multiset that
62   * violates this constraint (for example, the user attempts to add a string
63   * element to a set whose elements are integers), the {@code add(Object)}
64   * call will throw a {@code ClassCastException}.
65   *
66   * <p>The type specification is {@code <E extends Comparable>}, instead of the
67   * more specific {@code <E extends Comparable<? super E>>}, to support
68   * classes defined without generics.
69   */
70  public static <E extends Comparable> TreeMultiset<E> create() {
71    return new TreeMultiset<E>(Ordering.natural());
72  }
73
74  /**
75   * Creates a new, empty multiset, sorted according to the specified
76   * comparator. All elements inserted into the multiset must be <i>mutually
77   * comparable</i> by the specified comparator: {@code comparator.compare(e1,
78   * e2)} must not throw a {@code ClassCastException} for any elements {@code
79   * e1} and {@code e2} in the multiset. If the user attempts to add an element
80   * to the multiset that violates this constraint, the {@code add(Object)} call
81   * will throw a {@code ClassCastException}.
82   *
83   * @param comparator the comparator that will be used to sort this multiset. A
84   *     null value indicates that the elements' <i>natural ordering</i> should
85   *     be used.
86   */
87  @SuppressWarnings("unchecked")
88  public static <E> TreeMultiset<E> create(
89      @Nullable Comparator<? super E> comparator) {
90    return (comparator == null)
91           ? new TreeMultiset<E>((Comparator) Ordering.natural())
92           : new TreeMultiset<E>(comparator);
93  }
94
95  /**
96   * Creates an empty multiset containing the given initial elements, sorted
97   * according to the elements' natural order.
98   *
99   * <p>This implementation is highly efficient when {@code elements} is itself
100   * a {@link Multiset}.
101   *
102   * <p>The type specification is {@code <E extends Comparable>}, instead of the
103   * more specific {@code <E extends Comparable<? super E>>}, to support
104   * classes defined without generics.
105   */
106  public static <E extends Comparable> TreeMultiset<E> create(
107      Iterable<? extends E> elements) {
108    TreeMultiset<E> multiset = create();
109    Iterables.addAll(multiset, elements);
110    return multiset;
111  }
112
113  /**
114   * Returns an iterator over the elements contained in this collection.
115   */
116  @Override
117  public Iterator<E> iterator() {
118    // Needed to avoid Javadoc bug.
119    return super.iterator();
120  }
121
122  private TreeMultiset(Comparator<? super E> comparator) {
123    super(comparator);
124    this.range = GeneralRange.all(comparator);
125    this.rootReference = new Reference<Node<E>>();
126  }
127
128  private TreeMultiset(GeneralRange<E> range, Reference<Node<E>> root) {
129    super(range.comparator());
130    this.range = range;
131    this.rootReference = root;
132  }
133
134  @SuppressWarnings("unchecked")
135  E checkElement(Object o) {
136    return (E) o;
137  }
138
139  private transient final GeneralRange<E> range;
140
141  private transient final Reference<Node<E>> rootReference;
142
143  static final class Reference<T> {
144    T value;
145
146    public Reference() {}
147
148    public T get() {
149      return value;
150    }
151
152    public boolean compareAndSet(T expected, T newValue) {
153      if (value == expected) {
154        value = newValue;
155        return true;
156      }
157      return false;
158    }
159  }
160
161  @Override
162  int distinctElements() {
163    Node<E> root = rootReference.get();
164    return Ints.checkedCast(BstRangeOps.totalInRange(distinctAggregate(), range, root));
165  }
166
167  @Override
168  public int size() {
169    Node<E> root = rootReference.get();
170    return Ints.saturatedCast(BstRangeOps.totalInRange(sizeAggregate(), range, root));
171  }
172
173  @Override
174  public int count(@Nullable Object element) {
175    try {
176      E e = checkElement(element);
177      if (range.contains(e)) {
178        Node<E> node = BstOperations.seek(comparator(), rootReference.get(), e);
179        return countOrZero(node);
180      }
181      return 0;
182    } catch (ClassCastException e) {
183      return 0;
184    } catch (NullPointerException e) {
185      return 0;
186    }
187  }
188
189  private int mutate(@Nullable E e, MultisetModifier modifier) {
190    BstMutationRule<E, Node<E>> mutationRule = BstMutationRule.createRule(
191        modifier,
192        BstCountBasedBalancePolicies.
193          <E, Node<E>>singleRebalancePolicy(distinctAggregate()),
194        nodeFactory());
195    BstMutationResult<E, Node<E>> mutationResult =
196        BstOperations.mutate(comparator(), mutationRule, rootReference.get(), e);
197    if (!rootReference.compareAndSet(
198        mutationResult.getOriginalRoot(), mutationResult.getChangedRoot())) {
199      throw new ConcurrentModificationException();
200    }
201    Node<E> original = mutationResult.getOriginalTarget();
202    return countOrZero(original);
203  }
204
205  @Override
206  public int add(E element, int occurrences) {
207    checkElement(element);
208    if (occurrences == 0) {
209      return count(element);
210    }
211    checkArgument(range.contains(element));
212    return mutate(element, new AddModifier(occurrences));
213  }
214
215  @Override
216  public int remove(@Nullable Object element, int occurrences) {
217    if (element == null) {
218      return 0;
219    } else if (occurrences == 0) {
220      return count(element);
221    }
222    try {
223      E e = checkElement(element);
224      return range.contains(e) ? mutate(e, new RemoveModifier(occurrences)) : 0;
225    } catch (ClassCastException e) {
226      return 0;
227    }
228  }
229
230  @Override
231  public boolean setCount(E element, int oldCount, int newCount) {
232    checkElement(element);
233    checkArgument(range.contains(element));
234    return mutate(element, new ConditionalSetCountModifier(oldCount, newCount))
235        == oldCount;
236  }
237
238  @Override
239  public int setCount(E element, int count) {
240    checkElement(element);
241    checkArgument(range.contains(element));
242    return mutate(element, new SetCountModifier(count));
243  }
244
245  private BstPathFactory<Node<E>, BstInOrderPath<Node<E>>> pathFactory() {
246    return BstInOrderPath.inOrderFactory();
247  }
248
249  @Override
250  Iterator<Entry<E>> entryIterator() {
251    Node<E> root = rootReference.get();
252    final BstInOrderPath<Node<E>> startingPath =
253        BstRangeOps.furthestPath(range, LEFT, pathFactory(), root);
254    return iteratorInDirection(startingPath, RIGHT);
255  }
256
257  @Override
258  Iterator<Entry<E>> descendingEntryIterator() {
259    Node<E> root = rootReference.get();
260    final BstInOrderPath<Node<E>> startingPath =
261        BstRangeOps.furthestPath(range, RIGHT, pathFactory(), root);
262    return iteratorInDirection(startingPath, LEFT);
263  }
264
265  private Iterator<Entry<E>> iteratorInDirection(
266      @Nullable BstInOrderPath<Node<E>> start, final BstSide direction) {
267    final Iterator<BstInOrderPath<Node<E>>> pathIterator =
268        new AbstractLinkedIterator<BstInOrderPath<Node<E>>>(start) {
269          @Override
270          protected BstInOrderPath<Node<E>> computeNext(BstInOrderPath<Node<E>> previous) {
271            if (!previous.hasNext(direction)) {
272              return null;
273            }
274            BstInOrderPath<Node<E>> next = previous.next(direction);
275            // TODO(user): only check against one side
276            return range.contains(next.getTip().getKey()) ? next : null;
277          }
278        };
279    return new Iterator<Entry<E>>() {
280      E toRemove = null;
281
282      @Override
283      public boolean hasNext() {
284        return pathIterator.hasNext();
285      }
286
287      @Override
288      public Entry<E> next() {
289        BstInOrderPath<Node<E>> path = pathIterator.next();
290        return new LiveEntry(
291            toRemove = path.getTip().getKey(), path.getTip().elemCount());
292      }
293
294      @Override
295      public void remove() {
296        checkState(toRemove != null);
297        setCount(toRemove, 0);
298        toRemove = null;
299      }
300    };
301  }
302
303  class LiveEntry extends Multisets.AbstractEntry<E> {
304    private Node<E> expectedRoot;
305    private final E element;
306    private int count;
307
308    private LiveEntry(E element, int count) {
309      this.expectedRoot = rootReference.get();
310      this.element = element;
311      this.count = count;
312    }
313
314    @Override
315    public E getElement() {
316      return element;
317    }
318
319    @Override
320    public int getCount() {
321      if (rootReference.get() == expectedRoot) {
322        return count;
323      } else {
324        // check for updates
325        expectedRoot = rootReference.get();
326        return count = TreeMultiset.this.count(element);
327      }
328    }
329  }
330
331  @Override
332  public void clear() {
333    Node<E> root = rootReference.get();
334    Node<E> cleared = BstRangeOps.minusRange(range,
335        BstCountBasedBalancePolicies.<E, Node<E>>fullRebalancePolicy(distinctAggregate()),
336        nodeFactory(), root);
337    if (!rootReference.compareAndSet(root, cleared)) {
338      throw new ConcurrentModificationException();
339    }
340  }
341
342  @Override
343  public SortedMultiset<E> headMultiset(E upperBound, BoundType boundType) {
344    checkNotNull(upperBound);
345    return new TreeMultiset<E>(
346        range.intersect(GeneralRange.upTo(comparator, upperBound, boundType)), rootReference);
347  }
348
349  @Override
350  public SortedMultiset<E> tailMultiset(E lowerBound, BoundType boundType) {
351    checkNotNull(lowerBound);
352    return new TreeMultiset<E>(
353        range.intersect(GeneralRange.downTo(comparator, lowerBound, boundType)), rootReference);
354  }
355
356  /**
357   * {@inheritDoc}
358   *
359   * @since 11.0
360   */
361  @Override
362  public Comparator<? super E> comparator() {
363    return super.comparator();
364  }
365
366  private static final class Node<E> extends BstNode<E, Node<E>> implements Serializable {
367    private final long size;
368    private final int distinct;
369
370    private Node(E key, int elemCount, @Nullable Node<E> left,
371        @Nullable Node<E> right) {
372      super(key, left, right);
373      checkArgument(elemCount > 0);
374      this.size = (long) elemCount + sizeOrZero(left) + sizeOrZero(right);
375      this.distinct = 1 + distinctOrZero(left) + distinctOrZero(right);
376    }
377
378    int elemCount() {
379      long result = size - sizeOrZero(childOrNull(LEFT))
380          - sizeOrZero(childOrNull(RIGHT));
381      return Ints.checkedCast(result);
382    }
383
384    private Node(E key, int elemCount) {
385      this(key, elemCount, null, null);
386    }
387
388    private static final long serialVersionUID = 0;
389  }
390
391  private static long sizeOrZero(@Nullable Node<?> node) {
392    return (node == null) ? 0 : node.size;
393  }
394
395  private static int distinctOrZero(@Nullable Node<?> node) {
396    return (node == null) ? 0 : node.distinct;
397  }
398
399  private static int countOrZero(@Nullable Node<?> entry) {
400    return (entry == null) ? 0 : entry.elemCount();
401  }
402
403  @SuppressWarnings("unchecked")
404  private BstAggregate<Node<E>> distinctAggregate() {
405    return (BstAggregate) DISTINCT_AGGREGATE;
406  }
407
408  private static final BstAggregate<Node<Object>> DISTINCT_AGGREGATE =
409      new BstAggregate<Node<Object>>() {
410    @Override
411    public int entryValue(Node<Object> entry) {
412      return 1;
413    }
414
415    @Override
416    public long treeValue(@Nullable Node<Object> tree) {
417      return distinctOrZero(tree);
418    }
419  };
420
421  @SuppressWarnings("unchecked")
422  private BstAggregate<Node<E>> sizeAggregate() {
423    return (BstAggregate) SIZE_AGGREGATE;
424  }
425
426  private static final BstAggregate<Node<Object>> SIZE_AGGREGATE =
427      new BstAggregate<Node<Object>>() {
428        @Override
429        public int entryValue(Node<Object> entry) {
430          return entry.elemCount();
431        }
432
433        @Override
434        public long treeValue(@Nullable Node<Object> tree) {
435          return sizeOrZero(tree);
436        }
437      };
438
439  @SuppressWarnings("unchecked")
440  private BstNodeFactory<Node<E>> nodeFactory() {
441    return (BstNodeFactory) NODE_FACTORY;
442  }
443
444  private static final BstNodeFactory<Node<Object>> NODE_FACTORY =
445      new BstNodeFactory<Node<Object>>() {
446        @Override
447        public Node<Object> createNode(Node<Object> source, @Nullable Node<Object> left,
448            @Nullable Node<Object> right) {
449          return new Node<Object>(source.getKey(), source.elemCount(), left, right);
450        }
451      };
452
453  private abstract class MultisetModifier implements BstModifier<E, Node<E>> {
454    abstract int newCount(int oldCount);
455
456    @Nullable
457    @Override
458    public BstModificationResult<Node<E>> modify(E key, @Nullable Node<E> originalEntry) {
459      int oldCount = countOrZero(originalEntry);
460      int newCount = newCount(oldCount);
461      if (oldCount == newCount) {
462        return BstModificationResult.identity(originalEntry);
463      } else if (newCount == 0) {
464        return BstModificationResult.rebalancingChange(originalEntry, null);
465      } else if (oldCount == 0) {
466        return BstModificationResult.rebalancingChange(null, new Node<E>(key, newCount));
467      } else {
468        return BstModificationResult.rebuildingChange(originalEntry,
469            new Node<E>(originalEntry.getKey(), newCount));
470      }
471    }
472  }
473
474  private final class AddModifier extends MultisetModifier {
475    private final int countToAdd;
476
477    private AddModifier(int countToAdd) {
478      checkArgument(countToAdd > 0);
479      this.countToAdd = countToAdd;
480    }
481
482    @Override
483    int newCount(int oldCount) {
484      checkArgument(countToAdd <= Integer.MAX_VALUE - oldCount, "Cannot add this many elements");
485      return oldCount + countToAdd;
486    }
487  }
488
489  private final class RemoveModifier extends MultisetModifier {
490    private final int countToRemove;
491
492    private RemoveModifier(int countToRemove) {
493      checkArgument(countToRemove > 0);
494      this.countToRemove = countToRemove;
495    }
496
497    @Override
498    int newCount(int oldCount) {
499      return Math.max(0, oldCount - countToRemove);
500    }
501  }
502
503  private final class SetCountModifier extends MultisetModifier {
504    private final int countToSet;
505
506    private SetCountModifier(int countToSet) {
507      checkArgument(countToSet >= 0);
508      this.countToSet = countToSet;
509    }
510
511    @Override
512    int newCount(int oldCount) {
513      return countToSet;
514    }
515  }
516
517  private final class ConditionalSetCountModifier extends MultisetModifier {
518    private final int expectedCount;
519    private final int setCount;
520
521    private ConditionalSetCountModifier(int expectedCount, int setCount) {
522      checkArgument(setCount >= 0 & expectedCount >= 0);
523      this.expectedCount = expectedCount;
524      this.setCount = setCount;
525    }
526
527    @Override
528    int newCount(int oldCount) {
529      return (oldCount == expectedCount) ? setCount : oldCount;
530    }
531  }
532
533  /*
534   * TODO(jlevy): Decide whether entrySet() should return entries with an
535   * equals() method that calls the comparator to compare the two keys. If that
536   * change is made, AbstractMultiset.equals() can simply check whether two
537   * multisets have equal entry sets.
538   */
539}
540