1/*
2 * Copyright (C) 2007 The Guava Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.common.collect;
18
19import static com.google.common.base.Preconditions.checkArgument;
20import static com.google.common.base.Preconditions.checkNotNull;
21import static com.google.common.base.Preconditions.checkState;
22import static com.google.common.collect.BstSide.LEFT;
23import static com.google.common.collect.BstSide.RIGHT;
24
25import java.io.IOException;
26import java.io.ObjectInputStream;
27import java.io.ObjectOutputStream;
28import java.io.Serializable;
29import java.util.Comparator;
30import java.util.ConcurrentModificationException;
31import java.util.Iterator;
32
33import javax.annotation.Nullable;
34
35import com.google.common.annotations.GwtCompatible;
36import com.google.common.annotations.GwtIncompatible;
37import com.google.common.primitives.Ints;
38
39/**
40 * A multiset which maintains the ordering of its elements, according to either
41 * their natural order or an explicit {@link Comparator}. In all cases, this
42 * implementation uses {@link Comparable#compareTo} or {@link
43 * Comparator#compare} instead of {@link Object#equals} to determine
44 * equivalence of instances.
45 *
46 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as
47 * explained by the {@link Comparable} class specification. Otherwise, the
48 * resulting multiset will violate the {@link java.util.Collection} contract,
49 * which is specified in terms of {@link Object#equals}.
50 *
51 * @author Louis Wasserman
52 * @author Jared Levy
53 * @since 2.0 (imported from Google Collections Library)
54 */
55@GwtCompatible(emulated = true)
56public final class TreeMultiset<E> extends AbstractSortedMultiset<E>
57    implements Serializable {
58
59  /**
60   * Creates a new, empty multiset, sorted according to the elements' natural
61   * order. All elements inserted into the multiset must implement the
62   * {@code Comparable} interface. Furthermore, all such elements must be
63   * <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
64   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in
65   * the multiset. If the user attempts to add an element to the multiset that
66   * violates this constraint (for example, the user attempts to add a string
67   * element to a set whose elements are integers), the {@code add(Object)}
68   * call will throw a {@code ClassCastException}.
69   *
70   * <p>The type specification is {@code <E extends Comparable>}, instead of the
71   * more specific {@code <E extends Comparable<? super E>>}, to support
72   * classes defined without generics.
73   */
74  public static <E extends Comparable> TreeMultiset<E> create() {
75    return new TreeMultiset<E>(Ordering.natural());
76  }
77
78  /**
79   * Creates a new, empty multiset, sorted according to the specified
80   * comparator. All elements inserted into the multiset must be <i>mutually
81   * comparable</i> by the specified comparator: {@code comparator.compare(e1,
82   * e2)} must not throw a {@code ClassCastException} for any elements {@code
83   * e1} and {@code e2} in the multiset. If the user attempts to add an element
84   * to the multiset that violates this constraint, the {@code add(Object)} call
85   * will throw a {@code ClassCastException}.
86   *
87   * @param comparator the comparator that will be used to sort this multiset. A
88   *     null value indicates that the elements' <i>natural ordering</i> should
89   *     be used.
90   */
91  @SuppressWarnings("unchecked")
92  public static <E> TreeMultiset<E> create(
93      @Nullable Comparator<? super E> comparator) {
94    return (comparator == null)
95           ? new TreeMultiset<E>((Comparator) Ordering.natural())
96           : new TreeMultiset<E>(comparator);
97  }
98
99  /**
100   * Creates an empty multiset containing the given initial elements, sorted
101   * according to the elements' natural order.
102   *
103   * <p>This implementation is highly efficient when {@code elements} is itself
104   * a {@link Multiset}.
105   *
106   * <p>The type specification is {@code <E extends Comparable>}, instead of the
107   * more specific {@code <E extends Comparable<? super E>>}, to support
108   * classes defined without generics.
109   */
110  public static <E extends Comparable> TreeMultiset<E> create(
111      Iterable<? extends E> elements) {
112    TreeMultiset<E> multiset = create();
113    Iterables.addAll(multiset, elements);
114    return multiset;
115  }
116
117  /**
118   * Returns an iterator over the elements contained in this collection.
119   */
120  @Override
121  public Iterator<E> iterator() {
122    // Needed to avoid Javadoc bug.
123    return super.iterator();
124  }
125
126  private TreeMultiset(Comparator<? super E> comparator) {
127    super(comparator);
128    this.range = GeneralRange.all(comparator);
129    this.rootReference = new Reference<Node<E>>();
130  }
131
132  private TreeMultiset(GeneralRange<E> range, Reference<Node<E>> root) {
133    super(range.comparator());
134    this.range = range;
135    this.rootReference = root;
136  }
137
138  @SuppressWarnings("unchecked")
139  E checkElement(Object o) {
140    return (E) o;
141  }
142
143  private transient final GeneralRange<E> range;
144
145  private transient final Reference<Node<E>> rootReference;
146
147  static final class Reference<T> {
148    T value;
149
150    public Reference() {}
151
152    public T get() {
153      return value;
154    }
155
156    public boolean compareAndSet(T expected, T newValue) {
157      if (value == expected) {
158        value = newValue;
159        return true;
160      }
161      return false;
162    }
163  }
164
165  @Override
166  int distinctElements() {
167    Node<E> root = rootReference.get();
168    return Ints.checkedCast(BstRangeOps.totalInRange(distinctAggregate(), range, root));
169  }
170
171  @Override
172  public int size() {
173    Node<E> root = rootReference.get();
174    return Ints.saturatedCast(BstRangeOps.totalInRange(sizeAggregate(), range, root));
175  }
176
177  @Override
178  public int count(@Nullable Object element) {
179    try {
180      E e = checkElement(element);
181      if (range.contains(e)) {
182        Node<E> node = BstOperations.seek(comparator(), rootReference.get(), e);
183        return countOrZero(node);
184      }
185      return 0;
186    } catch (ClassCastException e) {
187      return 0;
188    } catch (NullPointerException e) {
189      return 0;
190    }
191  }
192
193  private int mutate(@Nullable E e, MultisetModifier modifier) {
194    BstMutationRule<E, Node<E>> mutationRule = BstMutationRule.createRule(
195        modifier,
196        BstCountBasedBalancePolicies.
197          <E, Node<E>>singleRebalancePolicy(distinctAggregate()),
198        nodeFactory());
199    BstMutationResult<E, Node<E>> mutationResult =
200        BstOperations.mutate(comparator(), mutationRule, rootReference.get(), e);
201    if (!rootReference.compareAndSet(
202        mutationResult.getOriginalRoot(), mutationResult.getChangedRoot())) {
203      throw new ConcurrentModificationException();
204    }
205    Node<E> original = mutationResult.getOriginalTarget();
206    return countOrZero(original);
207  }
208
209  @Override
210  public int add(E element, int occurrences) {
211    checkElement(element);
212    if (occurrences == 0) {
213      return count(element);
214    }
215    checkArgument(range.contains(element));
216    return mutate(element, new AddModifier(occurrences));
217  }
218
219  @Override
220  public int remove(@Nullable Object element, int occurrences) {
221    if (element == null) {
222      return 0;
223    } else if (occurrences == 0) {
224      return count(element);
225    }
226    try {
227      E e = checkElement(element);
228      return range.contains(e) ? mutate(e, new RemoveModifier(occurrences)) : 0;
229    } catch (ClassCastException e) {
230      return 0;
231    }
232  }
233
234  @Override
235  public boolean setCount(E element, int oldCount, int newCount) {
236    checkElement(element);
237    checkArgument(range.contains(element));
238    return mutate(element, new ConditionalSetCountModifier(oldCount, newCount))
239        == oldCount;
240  }
241
242  @Override
243  public int setCount(E element, int count) {
244    checkElement(element);
245    checkArgument(range.contains(element));
246    return mutate(element, new SetCountModifier(count));
247  }
248
249  private BstPathFactory<Node<E>, BstInOrderPath<Node<E>>> pathFactory() {
250    return BstInOrderPath.inOrderFactory();
251  }
252
253  @Override
254  Iterator<Entry<E>> entryIterator() {
255    Node<E> root = rootReference.get();
256    final BstInOrderPath<Node<E>> startingPath =
257        BstRangeOps.furthestPath(range, LEFT, pathFactory(), root);
258    return iteratorInDirection(startingPath, RIGHT);
259  }
260
261  @Override
262  Iterator<Entry<E>> descendingEntryIterator() {
263    Node<E> root = rootReference.get();
264    final BstInOrderPath<Node<E>> startingPath =
265        BstRangeOps.furthestPath(range, RIGHT, pathFactory(), root);
266    return iteratorInDirection(startingPath, LEFT);
267  }
268
269  private Iterator<Entry<E>> iteratorInDirection(
270      @Nullable BstInOrderPath<Node<E>> start, final BstSide direction) {
271    final Iterator<BstInOrderPath<Node<E>>> pathIterator =
272        new AbstractLinkedIterator<BstInOrderPath<Node<E>>>(start) {
273          @Override
274          protected BstInOrderPath<Node<E>> computeNext(BstInOrderPath<Node<E>> previous) {
275            if (!previous.hasNext(direction)) {
276              return null;
277            }
278            BstInOrderPath<Node<E>> next = previous.next(direction);
279            // TODO(user): only check against one side
280            return range.contains(next.getTip().getKey()) ? next : null;
281          }
282        };
283    return new Iterator<Entry<E>>() {
284      E toRemove = null;
285
286      @Override
287      public boolean hasNext() {
288        return pathIterator.hasNext();
289      }
290
291      @Override
292      public Entry<E> next() {
293        BstInOrderPath<Node<E>> path = pathIterator.next();
294        return new LiveEntry(
295            toRemove = path.getTip().getKey(), path.getTip().elemCount());
296      }
297
298      @Override
299      public void remove() {
300        checkState(toRemove != null);
301        setCount(toRemove, 0);
302        toRemove = null;
303      }
304    };
305  }
306
307  class LiveEntry extends Multisets.AbstractEntry<E> {
308    private Node<E> expectedRoot;
309    private final E element;
310    private int count;
311
312    private LiveEntry(E element, int count) {
313      this.expectedRoot = rootReference.get();
314      this.element = element;
315      this.count = count;
316    }
317
318    @Override
319    public E getElement() {
320      return element;
321    }
322
323    @Override
324    public int getCount() {
325      if (rootReference.get() == expectedRoot) {
326        return count;
327      } else {
328        // check for updates
329        expectedRoot = rootReference.get();
330        return count = TreeMultiset.this.count(element);
331      }
332    }
333  }
334
335  @Override
336  public void clear() {
337    Node<E> root = rootReference.get();
338    Node<E> cleared = BstRangeOps.minusRange(range,
339        BstCountBasedBalancePolicies.<E, Node<E>>fullRebalancePolicy(distinctAggregate()),
340        nodeFactory(), root);
341    if (!rootReference.compareAndSet(root, cleared)) {
342      throw new ConcurrentModificationException();
343    }
344  }
345
346  @Override
347  public SortedMultiset<E> headMultiset(E upperBound, BoundType boundType) {
348    checkNotNull(upperBound);
349    return new TreeMultiset<E>(
350        range.intersect(GeneralRange.upTo(comparator, upperBound, boundType)), rootReference);
351  }
352
353  @Override
354  public SortedMultiset<E> tailMultiset(E lowerBound, BoundType boundType) {
355    checkNotNull(lowerBound);
356    return new TreeMultiset<E>(
357        range.intersect(GeneralRange.downTo(comparator, lowerBound, boundType)), rootReference);
358  }
359
360  /**
361   * {@inheritDoc}
362   *
363   * @since 11.0
364   */
365  @Override
366  public Comparator<? super E> comparator() {
367    return super.comparator();
368  }
369
370  private static final class Node<E> extends BstNode<E, Node<E>> implements Serializable {
371    private final long size;
372    private final int distinct;
373
374    private Node(E key, int elemCount, @Nullable Node<E> left,
375        @Nullable Node<E> right) {
376      super(key, left, right);
377      checkArgument(elemCount > 0);
378      this.size = (long) elemCount + sizeOrZero(left) + sizeOrZero(right);
379      this.distinct = 1 + distinctOrZero(left) + distinctOrZero(right);
380    }
381
382    int elemCount() {
383      long result = size - sizeOrZero(childOrNull(LEFT))
384          - sizeOrZero(childOrNull(RIGHT));
385      return Ints.checkedCast(result);
386    }
387
388    private Node(E key, int elemCount) {
389      this(key, elemCount, null, null);
390    }
391
392    private static final long serialVersionUID = 0;
393  }
394
395  private static long sizeOrZero(@Nullable Node<?> node) {
396    return (node == null) ? 0 : node.size;
397  }
398
399  private static int distinctOrZero(@Nullable Node<?> node) {
400    return (node == null) ? 0 : node.distinct;
401  }
402
403  private static int countOrZero(@Nullable Node<?> entry) {
404    return (entry == null) ? 0 : entry.elemCount();
405  }
406
407  @SuppressWarnings("unchecked")
408  private BstAggregate<Node<E>> distinctAggregate() {
409    return (BstAggregate) DISTINCT_AGGREGATE;
410  }
411
412  private static final BstAggregate<Node<Object>> DISTINCT_AGGREGATE =
413      new BstAggregate<Node<Object>>() {
414    @Override
415    public int entryValue(Node<Object> entry) {
416      return 1;
417    }
418
419    @Override
420    public long treeValue(@Nullable Node<Object> tree) {
421      return distinctOrZero(tree);
422    }
423  };
424
425  @SuppressWarnings("unchecked")
426  private BstAggregate<Node<E>> sizeAggregate() {
427    return (BstAggregate) SIZE_AGGREGATE;
428  }
429
430  private static final BstAggregate<Node<Object>> SIZE_AGGREGATE =
431      new BstAggregate<Node<Object>>() {
432        @Override
433        public int entryValue(Node<Object> entry) {
434          return entry.elemCount();
435        }
436
437        @Override
438        public long treeValue(@Nullable Node<Object> tree) {
439          return sizeOrZero(tree);
440        }
441      };
442
443  @SuppressWarnings("unchecked")
444  private BstNodeFactory<Node<E>> nodeFactory() {
445    return (BstNodeFactory) NODE_FACTORY;
446  }
447
448  private static final BstNodeFactory<Node<Object>> NODE_FACTORY =
449      new BstNodeFactory<Node<Object>>() {
450        @Override
451        public Node<Object> createNode(Node<Object> source, @Nullable Node<Object> left,
452            @Nullable Node<Object> right) {
453          return new Node<Object>(source.getKey(), source.elemCount(), left, right);
454        }
455      };
456
457  private abstract class MultisetModifier implements BstModifier<E, Node<E>> {
458    abstract int newCount(int oldCount);
459
460    @Nullable
461    @Override
462    public BstModificationResult<Node<E>> modify(E key, @Nullable Node<E> originalEntry) {
463      int oldCount = countOrZero(originalEntry);
464      int newCount = newCount(oldCount);
465      if (oldCount == newCount) {
466        return BstModificationResult.identity(originalEntry);
467      } else if (newCount == 0) {
468        return BstModificationResult.rebalancingChange(originalEntry, null);
469      } else if (oldCount == 0) {
470        return BstModificationResult.rebalancingChange(null, new Node<E>(key, newCount));
471      } else {
472        return BstModificationResult.rebuildingChange(originalEntry,
473            new Node<E>(originalEntry.getKey(), newCount));
474      }
475    }
476  }
477
478  private final class AddModifier extends MultisetModifier {
479    private final int countToAdd;
480
481    private AddModifier(int countToAdd) {
482      checkArgument(countToAdd > 0);
483      this.countToAdd = countToAdd;
484    }
485
486    @Override
487    int newCount(int oldCount) {
488      checkArgument(countToAdd <= Integer.MAX_VALUE - oldCount, "Cannot add this many elements");
489      return oldCount + countToAdd;
490    }
491  }
492
493  private final class RemoveModifier extends MultisetModifier {
494    private final int countToRemove;
495
496    private RemoveModifier(int countToRemove) {
497      checkArgument(countToRemove > 0);
498      this.countToRemove = countToRemove;
499    }
500
501    @Override
502    int newCount(int oldCount) {
503      return Math.max(0, oldCount - countToRemove);
504    }
505  }
506
507  private final class SetCountModifier extends MultisetModifier {
508    private final int countToSet;
509
510    private SetCountModifier(int countToSet) {
511      checkArgument(countToSet >= 0);
512      this.countToSet = countToSet;
513    }
514
515    @Override
516    int newCount(int oldCount) {
517      return countToSet;
518    }
519  }
520
521  private final class ConditionalSetCountModifier extends MultisetModifier {
522    private final int expectedCount;
523    private final int setCount;
524
525    private ConditionalSetCountModifier(int expectedCount, int setCount) {
526      checkArgument(setCount >= 0 & expectedCount >= 0);
527      this.expectedCount = expectedCount;
528      this.setCount = setCount;
529    }
530
531    @Override
532    int newCount(int oldCount) {
533      return (oldCount == expectedCount) ? setCount : oldCount;
534    }
535  }
536
537  /*
538   * TODO(jlevy): Decide whether entrySet() should return entries with an
539   * equals() method that calls the comparator to compare the two keys. If that
540   * change is made, AbstractMultiset.equals() can simply check whether two
541   * multisets have equal entry sets.
542   */
543
544  /**
545   * @serialData the comparator, the number of distinct elements, the first
546   *     element, its count, the second element, its count, and so on
547   */
548  @GwtIncompatible("java.io.ObjectOutputStream")
549  private void writeObject(ObjectOutputStream stream) throws IOException {
550    stream.defaultWriteObject();
551    stream.writeObject(elementSet().comparator());
552    Serialization.writeMultiset(this, stream);
553  }
554
555  @GwtIncompatible("java.io.ObjectInputStream")
556  private void readObject(ObjectInputStream stream)
557      throws IOException, ClassNotFoundException {
558    stream.defaultReadObject();
559    @SuppressWarnings("unchecked") // reading data stored by writeObject
560    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
561    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
562    Serialization.getFieldSetter(TreeMultiset.class, "range").set(this,
563        GeneralRange.all(comparator));
564    Serialization.getFieldSetter(TreeMultiset.class, "rootReference").set(this,
565        new Reference<Node<E>>());
566    Serialization.populateMultiset(this, stream);
567  }
568
569  @GwtIncompatible("not needed in emulated source")
570  private static final long serialVersionUID = 1;
571}
572