1/*
2 * Copyright (C) 2010 The Guava Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
5 * in compliance with the License. You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software distributed under the License
10 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
11 * or implied. See the License for the specific language governing permissions and limitations under
12 * the License.
13 */
14
15package com.google.common.collect;
16
17import static com.google.common.base.Preconditions.checkNotNull;
18import static com.google.common.base.Preconditions.checkState;
19
20import com.google.common.base.Equivalence;
21import com.google.common.base.Function;
22import com.google.common.collect.MapMaker.RemovalCause;
23import com.google.common.collect.MapMaker.RemovalListener;
24
25import java.io.IOException;
26import java.io.ObjectInputStream;
27import java.io.ObjectOutputStream;
28import java.lang.ref.ReferenceQueue;
29import java.util.concurrent.ConcurrentMap;
30import java.util.concurrent.ExecutionException;
31import java.util.concurrent.atomic.AtomicReferenceArray;
32
33import javax.annotation.Nullable;
34import javax.annotation.concurrent.GuardedBy;
35
36/**
37 * Adds computing functionality to {@link MapMakerInternalMap}.
38 *
39 * @author Bob Lee
40 * @author Charles Fry
41 */
42class ComputingConcurrentHashMap<K, V> extends MapMakerInternalMap<K, V> {
43  final Function<? super K, ? extends V> computingFunction;
44
45  /**
46   * Creates a new, empty map with the specified strategy, initial capacity, load factor and
47   * concurrency level.
48   */
49  ComputingConcurrentHashMap(MapMaker builder,
50      Function<? super K, ? extends V> computingFunction) {
51    super(builder);
52    this.computingFunction = checkNotNull(computingFunction);
53  }
54
55  @Override
56  Segment<K, V> createSegment(int initialCapacity, int maxSegmentSize) {
57    return new ComputingSegment<K, V>(this, initialCapacity, maxSegmentSize);
58  }
59
60  @Override
61  ComputingSegment<K, V> segmentFor(int hash) {
62    return (ComputingSegment<K, V>) super.segmentFor(hash);
63  }
64
65  V getOrCompute(K key) throws ExecutionException {
66    int hash = hash(checkNotNull(key));
67    return segmentFor(hash).getOrCompute(key, hash, computingFunction);
68  }
69
70  @SuppressWarnings("serial") // This class is never serialized.
71  static final class ComputingSegment<K, V> extends Segment<K, V> {
72    ComputingSegment(MapMakerInternalMap<K, V> map, int initialCapacity, int maxSegmentSize) {
73      super(map, initialCapacity, maxSegmentSize);
74    }
75
76    V getOrCompute(K key, int hash, Function<? super K, ? extends V> computingFunction)
77        throws ExecutionException {
78      try {
79        outer: while (true) {
80          // don't call getLiveEntry, which would ignore computing values
81          ReferenceEntry<K, V> e = getEntry(key, hash);
82          if (e != null) {
83            V value = getLiveValue(e);
84            if (value != null) {
85              recordRead(e);
86              return value;
87            }
88          }
89
90          // at this point e is either null, computing, or expired;
91          // avoid locking if it's already computing
92          if (e == null || !e.getValueReference().isComputingReference()) {
93            boolean createNewEntry = true;
94            ComputingValueReference<K, V> computingValueReference = null;
95            lock();
96            try {
97              preWriteCleanup();
98
99              int newCount = this.count - 1;
100              AtomicReferenceArray<ReferenceEntry<K, V>> table = this.table;
101              int index = hash & (table.length() - 1);
102              ReferenceEntry<K, V> first = table.get(index);
103
104              for (e = first; e != null; e = e.getNext()) {
105                K entryKey = e.getKey();
106                if (e.getHash() == hash && entryKey != null
107                    && map.keyEquivalence.equivalent(key, entryKey)) {
108                  ValueReference<K, V> valueReference = e.getValueReference();
109                  if (valueReference.isComputingReference()) {
110                    createNewEntry = false;
111                  } else {
112                    V value = e.getValueReference().get();
113                    if (value == null) {
114                      enqueueNotification(entryKey, hash, value, RemovalCause.COLLECTED);
115                    } else if (map.expires() && map.isExpired(e)) {
116                      // This is a duplicate check, as preWriteCleanup already purged expired
117                      // entries, but let's accomodate an incorrect expiration queue.
118                      enqueueNotification(entryKey, hash, value, RemovalCause.EXPIRED);
119                    } else {
120                      recordLockedRead(e);
121                      return value;
122                    }
123
124                    // immediately reuse invalid entries
125                    evictionQueue.remove(e);
126                    expirationQueue.remove(e);
127                    this.count = newCount; // write-volatile
128                  }
129                  break;
130                }
131              }
132
133              if (createNewEntry) {
134                computingValueReference = new ComputingValueReference<K, V>(computingFunction);
135
136                if (e == null) {
137                  e = newEntry(key, hash, first);
138                  e.setValueReference(computingValueReference);
139                  table.set(index, e);
140                } else {
141                  e.setValueReference(computingValueReference);
142                }
143              }
144            } finally {
145              unlock();
146              postWriteCleanup();
147            }
148
149            if (createNewEntry) {
150              // This thread solely created the entry.
151              return compute(key, hash, e, computingValueReference);
152            }
153          }
154
155          // The entry already exists. Wait for the computation.
156          checkState(!Thread.holdsLock(e), "Recursive computation");
157          // don't consider expiration as we're concurrent with computation
158          V value = e.getValueReference().waitForValue();
159          if (value != null) {
160            recordRead(e);
161            return value;
162          }
163          // else computing thread will clearValue
164          continue outer;
165        }
166      } finally {
167        postReadCleanup();
168      }
169    }
170
171    V compute(K key, int hash, ReferenceEntry<K, V> e,
172        ComputingValueReference<K, V> computingValueReference)
173        throws ExecutionException {
174      V value = null;
175      long start = System.nanoTime();
176      long end = 0;
177      try {
178        // Synchronizes on the entry to allow failing fast when a recursive computation is
179        // detected. This is not fool-proof since the entry may be copied when the segment
180        // is written to.
181        synchronized (e) {
182          value = computingValueReference.compute(key, hash);
183          end = System.nanoTime();
184        }
185        if (value != null) {
186          // putIfAbsent
187          V oldValue = put(key, hash, value, true);
188          if (oldValue != null) {
189            // the computed value was already clobbered
190            enqueueNotification(key, hash, value, RemovalCause.REPLACED);
191          }
192        }
193        return value;
194      } finally {
195        if (end == 0) {
196          end = System.nanoTime();
197        }
198        if (value == null) {
199          clearValue(key, hash, computingValueReference);
200        }
201      }
202    }
203  }
204
205  /**
206   * Used to provide computation exceptions to other threads.
207   */
208  private static final class ComputationExceptionReference<K, V> implements ValueReference<K, V> {
209    final Throwable t;
210
211    ComputationExceptionReference(Throwable t) {
212      this.t = t;
213    }
214
215    @Override
216    public V get() {
217      return null;
218    }
219
220    @Override
221    public ReferenceEntry<K, V> getEntry() {
222      return null;
223    }
224
225    @Override
226    public ValueReference<K, V> copyFor(
227        ReferenceQueue<V> queue, V value, ReferenceEntry<K, V> entry) {
228      return this;
229    }
230
231    @Override
232    public boolean isComputingReference() {
233      return false;
234    }
235
236    @Override
237    public V waitForValue() throws ExecutionException {
238      throw new ExecutionException(t);
239    }
240
241    @Override
242    public void clear(ValueReference<K, V> newValue) {}
243  }
244
245  /**
246   * Used to provide computation result to other threads.
247   */
248  private static final class ComputedReference<K, V> implements ValueReference<K, V> {
249    final V value;
250
251    ComputedReference(@Nullable V value) {
252      this.value = value;
253    }
254
255    @Override
256    public V get() {
257      return value;
258    }
259
260    @Override
261    public ReferenceEntry<K, V> getEntry() {
262      return null;
263    }
264
265    @Override
266    public ValueReference<K, V> copyFor(
267        ReferenceQueue<V> queue, V value, ReferenceEntry<K, V> entry) {
268      return this;
269    }
270
271    @Override
272    public boolean isComputingReference() {
273      return false;
274    }
275
276    @Override
277    public V waitForValue() {
278      return get();
279    }
280
281    @Override
282    public void clear(ValueReference<K, V> newValue) {}
283  }
284
285  private static final class ComputingValueReference<K, V> implements ValueReference<K, V> {
286    final Function<? super K, ? extends V> computingFunction;
287
288    @GuardedBy("ComputingValueReference.this") // writes
289    volatile ValueReference<K, V> computedReference = unset();
290
291    public ComputingValueReference(Function<? super K, ? extends V> computingFunction) {
292      this.computingFunction = computingFunction;
293    }
294
295    @Override
296    public V get() {
297      // All computation lookups go through waitForValue. This method thus is
298      // only used by put, to whom we always want to appear absent.
299      return null;
300    }
301
302    @Override
303    public ReferenceEntry<K, V> getEntry() {
304      return null;
305    }
306
307    @Override
308    public ValueReference<K, V> copyFor(
309        ReferenceQueue<V> queue, @Nullable V value, ReferenceEntry<K, V> entry) {
310      return this;
311    }
312
313    @Override
314    public boolean isComputingReference() {
315      return true;
316    }
317
318    /**
319     * Waits for a computation to complete. Returns the result of the computation.
320     */
321    @Override
322    public V waitForValue() throws ExecutionException {
323      if (computedReference == UNSET) {
324        boolean interrupted = false;
325        try {
326          synchronized (this) {
327            while (computedReference == UNSET) {
328              try {
329                wait();
330              } catch (InterruptedException ie) {
331                interrupted = true;
332              }
333            }
334          }
335        } finally {
336          if (interrupted) {
337            Thread.currentThread().interrupt();
338          }
339        }
340      }
341      return computedReference.waitForValue();
342    }
343
344    @Override
345    public void clear(ValueReference<K, V> newValue) {
346      // The pending computation was clobbered by a manual write. Unblock all
347      // pending gets, and have them return the new value.
348      setValueReference(newValue);
349
350      // TODO(fry): could also cancel computation if we had a thread handle
351    }
352
353    V compute(K key, int hash) throws ExecutionException {
354      V value;
355      try {
356        value = computingFunction.apply(key);
357      } catch (Throwable t) {
358        setValueReference(new ComputationExceptionReference<K, V>(t));
359        throw new ExecutionException(t);
360      }
361
362      setValueReference(new ComputedReference<K, V>(value));
363      return value;
364    }
365
366    void setValueReference(ValueReference<K, V> valueReference) {
367      synchronized (this) {
368        if (computedReference == UNSET) {
369          computedReference = valueReference;
370          notifyAll();
371        }
372      }
373    }
374  }
375
376  // Serialization Support
377
378  private static final long serialVersionUID = 4;
379
380  @Override
381  Object writeReplace() {
382    return new ComputingSerializationProxy<K, V>(keyStrength, valueStrength, keyEquivalence,
383        valueEquivalence, expireAfterWriteNanos, expireAfterAccessNanos, maximumSize,
384        concurrencyLevel, removalListener, this, computingFunction);
385  }
386
387  static final class ComputingSerializationProxy<K, V> extends AbstractSerializationProxy<K, V> {
388
389    final Function<? super K, ? extends V> computingFunction;
390
391    ComputingSerializationProxy(Strength keyStrength, Strength valueStrength,
392        Equivalence<Object> keyEquivalence, Equivalence<Object> valueEquivalence,
393        long expireAfterWriteNanos, long expireAfterAccessNanos, int maximumSize,
394        int concurrencyLevel, RemovalListener<? super K, ? super V> removalListener,
395        ConcurrentMap<K, V> delegate, Function<? super K, ? extends V> computingFunction) {
396      super(keyStrength, valueStrength, keyEquivalence, valueEquivalence, expireAfterWriteNanos,
397          expireAfterAccessNanos, maximumSize, concurrencyLevel, removalListener, delegate);
398      this.computingFunction = computingFunction;
399    }
400
401    private void writeObject(ObjectOutputStream out) throws IOException {
402      out.defaultWriteObject();
403      writeMapTo(out);
404    }
405
406    @SuppressWarnings("deprecation") // self-use
407    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
408      in.defaultReadObject();
409      MapMaker mapMaker = readMapMaker(in);
410      delegate = mapMaker.makeComputingMap(computingFunction);
411      readEntries(in);
412    }
413
414    Object readResolve() {
415      return delegate;
416    }
417
418    private static final long serialVersionUID = 4;
419  }
420}
421