1/*
2 * Copyright (C) 2010 The Guava Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.common.collect.testing;
18
19import java.io.Serializable;
20import java.util.Collection;
21import java.util.Comparator;
22import java.util.Iterator;
23import java.util.NavigableSet;
24import java.util.SortedSet;
25import java.util.TreeSet;
26
/**
 * A wrapper around {@code TreeSet} that aggressively checks to see if elements
 * are mutually comparable. This implementation passes the navigable set test
 * suites.
 *
 * <p>Every method that takes an element argument first compares that element
 * with itself using this set's comparator, so an argument of the wrong type
 * is rejected regardless of what the backing set currently contains. Views
 * returned by {@code headSet}, {@code subSet}, {@code tailSet} and
 * {@code descendingSet} are themselves wrapped in a {@code SafeTreeSet}.
 *
 * @author Louis Wasserman
 */
public final class SafeTreeSet<E> implements Serializable, NavigableSet<E> {
  /**
   * Comparator used for validation when the delegate has no explicit
   * comparator; it simply calls {@code compareTo} on the first argument.
   */
  @SuppressWarnings("unchecked")
  private static final Comparator NATURAL_ORDER = new Comparator<Comparable>() {
    @Override public int compare(Comparable o1, Comparable o2) {
      return o1.compareTo(o2);
    }
  };

  /** The backing set (or view of one) that all operations are forwarded to. */
  private final NavigableSet<E> delegate;

  /** Creates an empty set backed by a new {@code TreeSet}. */
  public SafeTreeSet() {
    this(new TreeSet<E>());
  }

  /**
   * Creates a set backed by a new {@code TreeSet} holding the elements of
   * {@code collection}.
   */
  public SafeTreeSet(Collection<? extends E> collection) {
    this(new TreeSet<E>(collection));
  }

  /**
   * Creates an empty set backed by a new {@code TreeSet} that is ordered by
   * {@code comparator}.
   */
  public SafeTreeSet(Comparator<? super E> comparator) {
    this(new TreeSet<E>(comparator));
  }

  /** Creates a set backed by a new {@code TreeSet} copy of {@code set}. */
  public SafeTreeSet(SortedSet<E> set) {
    this(new TreeSet<E>(set));
  }

  /**
   * Wraps {@code delegate} directly (no copy) and validates every element it
   * currently contains.
   */
  private SafeTreeSet(NavigableSet<E> delegate) {
    this.delegate = delegate;
    for (E e : this) {
      checkValid(e);
    }
  }

  /** Validates {@code element}, then adds it to the backing set. */
  @Override public boolean add(E element) {
    return delegate.add(checkValid(element));
  }

  /**
   * Validates every element of {@code collection} before any of them is
   * added to the backing set.
   */
  @Override public boolean addAll(Collection<? extends E> collection) {
    for (E e : collection) {
      checkValid(e);
    }
    return delegate.addAll(collection);
  }

  @Override public E ceiling(E e) {
    return delegate.ceiling(checkValid(e));
  }

  @Override public void clear() {
    delegate.clear();
  }

  /**
   * Returns the delegate's comparator, or a natural-ordering comparator when
   * the delegate reports none. Unlike the delegate's own accessor, this
   * method therefore never returns {@code null}.
   */
  @SuppressWarnings("unchecked")
  @Override public Comparator<? super E> comparator() {
    Comparator<? super E> comparator = delegate.comparator();
    if (comparator == null) {
      comparator = NATURAL_ORDER;
    }
    return comparator;
  }

  /** Validates {@code object}, then queries the backing set. */
  @Override public boolean contains(Object object) {
    return delegate.contains(checkValid(object));
  }

  @Override public boolean containsAll(Collection<?> c) {
    return delegate.containsAll(c);
  }

  @Override public Iterator<E> descendingIterator() {
    return delegate.descendingIterator();
  }

  /** Returns the delegate's descending view, wrapped so it stays checked. */
  @Override public NavigableSet<E> descendingSet() {
    return new SafeTreeSet<E>(delegate.descendingSet());
  }

  @Override public E first() {
    return delegate.first();
  }

  @Override public E floor(E e) {
    return delegate.floor(checkValid(e));
  }

  /** Equivalent to {@code headSet(toElement, false)}. */
  @Override public SortedSet<E> headSet(E toElement) {
    return headSet(toElement, false);
  }

  /**
   * Validates {@code toElement} and returns the delegate's head view, wrapped
   * so it stays checked.
   */
  @Override public NavigableSet<E> headSet(E toElement, boolean inclusive) {
    return new SafeTreeSet<E>(
        delegate.headSet(checkValid(toElement), inclusive));
  }

  @Override public E higher(E e) {
    return delegate.higher(checkValid(e));
  }

  @Override public boolean isEmpty() {
    return delegate.isEmpty();
  }

  @Override public Iterator<E> iterator() {
    return delegate.iterator();
  }

  @Override public E last() {
    return delegate.last();
  }

  @Override public E lower(E e) {
    return delegate.lower(checkValid(e));
  }

  @Override public E pollFirst() {
    return delegate.pollFirst();
  }

  @Override public E pollLast() {
    return delegate.pollLast();
  }

  /** Validates {@code object}, then removes it from the backing set. */
  @Override public boolean remove(Object object) {
    return delegate.remove(checkValid(object));
  }

  @Override public boolean removeAll(Collection<?> c) {
    return delegate.removeAll(c);
  }

  @Override public boolean retainAll(Collection<?> c) {
    return delegate.retainAll(c);
  }

  @Override public int size() {
    return delegate.size();
  }

  /**
   * Validates both bounds and returns the delegate's range view, wrapped so
   * it stays checked.
   */
  @Override public NavigableSet<E> subSet(
      E fromElement, boolean fromInclusive, E toElement, boolean toInclusive) {
    return new SafeTreeSet<E>(
        delegate.subSet(checkValid(fromElement), fromInclusive,
            checkValid(toElement), toInclusive));
  }

  /** Equivalent to {@code subSet(fromElement, true, toElement, false)}. */
  @Override public SortedSet<E> subSet(E fromElement, E toElement) {
    return subSet(fromElement, true, toElement, false);
  }

  /** Equivalent to {@code tailSet(fromElement, true)}. */
  @Override public SortedSet<E> tailSet(E fromElement) {
    return tailSet(fromElement, true);
  }

  /**
   * Validates {@code fromElement} and returns the delegate's tail view,
   * wrapped so it stays checked.
   */
  @Override public NavigableSet<E> tailSet(E fromElement, boolean inclusive) {
    // Wrap the view, as headSet and subSet do; returning the raw delegate
    // view would let callers bypass the element checks.
    return new SafeTreeSet<E>(
        delegate.tailSet(checkValid(fromElement), inclusive));
  }

  @Override public Object[] toArray() {
    return delegate.toArray();
  }

  @Override public <T> T[] toArray(T[] a) {
    return delegate.toArray(a);
  }

  /**
   * Compares {@code t} with itself using this set's comparator and returns
   * it unchanged. The comparison result is discarded; the call exists only
   * so that an unsuitable argument makes the comparator throw.
   */
  private <T> T checkValid(T t) {
    // a ClassCastException is what's supposed to happen!
    @SuppressWarnings("unchecked")
    E e = (E) t;
    comparator().compare(e, e);
    return t;
  }

  @Override public boolean equals(Object obj) {
    return delegate.equals(obj);
  }

  @Override public int hashCode() {
    return delegate.hashCode();
  }

  @Override public String toString() {
    return delegate.toString();
  }

  private static final long serialVersionUID = 0L;
}
220