1/*
2 * Copyright (C) 2010 The Guava Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
5 * in compliance with the License. You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software distributed under the License
10 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
11 * or implied. See the License for the specific language governing permissions and limitations under
12 * the License.
13 */
14
15package com.google.common.collect;
16
17import static com.google.common.base.Preconditions.checkNotNull;
18import static com.google.common.base.Preconditions.checkState;
19
20import com.google.common.base.Equivalence;
21import com.google.common.base.Function;
22import com.google.common.collect.MapMaker.RemovalCause;
23import com.google.common.collect.MapMaker.RemovalListener;
24
25import java.io.IOException;
26import java.io.ObjectInputStream;
27import java.io.ObjectOutputStream;
28import java.lang.ref.ReferenceQueue;
29import java.util.concurrent.ConcurrentMap;
30import java.util.concurrent.ExecutionException;
31import java.util.concurrent.atomic.AtomicReferenceArray;
32
33import javax.annotation.Nullable;
34import javax.annotation.concurrent.GuardedBy;
35
36/**
37 * Adds computing functionality to {@link MapMakerInternalMap}.
38 *
39 * @author Bob Lee
40 * @author Charles Fry
41 */
42class ComputingConcurrentHashMap<K, V> extends MapMakerInternalMap<K, V> {
43  final Function<? super K, ? extends V> computingFunction;
44
45  /**
46   * Creates a new, empty map with the specified strategy, initial capacity, load factor and
47   * concurrency level.
48   */
49  ComputingConcurrentHashMap(MapMaker builder,
50      Function<? super K, ? extends V> computingFunction) {
51    super(builder);
52    this.computingFunction = checkNotNull(computingFunction);
53  }
54
55  @Override
56  Segment<K, V> createSegment(int initialCapacity, int maxSegmentSize) {
57    return new ComputingSegment<K, V>(this, initialCapacity, maxSegmentSize);
58  }
59
60  @Override
61  ComputingSegment<K, V> segmentFor(int hash) {
62    return (ComputingSegment<K, V>) super.segmentFor(hash);
63  }
64
65  V getOrCompute(K key) throws ExecutionException {
66    int hash = hash(checkNotNull(key));
67    return segmentFor(hash).getOrCompute(key, hash, computingFunction);
68  }
69
70  @SuppressWarnings("serial") // This class is never serialized.
71  static final class ComputingSegment<K, V> extends Segment<K, V> {
72    ComputingSegment(MapMakerInternalMap<K, V> map, int initialCapacity, int maxSegmentSize) {
73      super(map, initialCapacity, maxSegmentSize);
74    }
75
76    V getOrCompute(K key, int hash, Function<? super K, ? extends V> computingFunction)
77        throws ExecutionException {
78      try {
79        outer: while (true) {
80          // don't call getLiveEntry, which would ignore computing values
81          ReferenceEntry<K, V> e = getEntry(key, hash);
82          if (e != null) {
83            V value = getLiveValue(e);
84            if (value != null) {
85              recordRead(e);
86              return value;
87            }
88          }
89
90          // at this point e is either null, computing, or expired;
91          // avoid locking if it's already computing
92          if (e == null || !e.getValueReference().isComputingReference()) {
93            boolean createNewEntry = true;
94            ComputingValueReference<K, V> computingValueReference = null;
95            lock();
96            try {
97              preWriteCleanup();
98
99              int newCount = this.count - 1;
100              AtomicReferenceArray<ReferenceEntry<K, V>> table = this.table;
101              int index = hash & (table.length() - 1);
102              ReferenceEntry<K, V> first = table.get(index);
103
104              for (e = first; e != null; e = e.getNext()) {
105                K entryKey = e.getKey();
106                if (e.getHash() == hash && entryKey != null
107                    && map.keyEquivalence.equivalent(key, entryKey)) {
108                  ValueReference<K, V> valueReference = e.getValueReference();
109                  if (valueReference.isComputingReference()) {
110                    createNewEntry = false;
111                  } else {
112                    V value = e.getValueReference().get();
113                    if (value == null) {
114                      enqueueNotification(entryKey, hash, value, RemovalCause.COLLECTED);
115                    } else if (map.expires() && map.isExpired(e)) {
116                      // This is a duplicate check, as preWriteCleanup already purged expired
117                      // entries, but let's accomodate an incorrect expiration queue.
118                      enqueueNotification(entryKey, hash, value, RemovalCause.EXPIRED);
119                    } else {
120                      recordLockedRead(e);
121                      return value;
122                    }
123
124                    // immediately reuse invalid entries
125                    evictionQueue.remove(e);
126                    expirationQueue.remove(e);
127                    this.count = newCount; // write-volatile
128                  }
129                  break;
130                }
131              }
132
133              if (createNewEntry) {
134                computingValueReference = new ComputingValueReference<K, V>(computingFunction);
135
136                if (e == null) {
137                  e = newEntry(key, hash, first);
138                  e.setValueReference(computingValueReference);
139                  table.set(index, e);
140                } else {
141                  e.setValueReference(computingValueReference);
142                }
143              }
144            } finally {
145              unlock();
146              postWriteCleanup();
147            }
148
149            if (createNewEntry) {
150              // This thread solely created the entry.
151              return compute(key, hash, e, computingValueReference);
152            }
153          }
154
155          // The entry already exists. Wait for the computation.
156          checkState(!Thread.holdsLock(e), "Recursive computation");
157          // don't consider expiration as we're concurrent with computation
158          V value = e.getValueReference().waitForValue();
159          if (value != null) {
160            recordRead(e);
161            return value;
162          }
163          // else computing thread will clearValue
164          continue outer;
165        }
166      } finally {
167        postReadCleanup();
168      }
169    }
170
171    V compute(K key, int hash, ReferenceEntry<K, V> e,
172        ComputingValueReference<K, V> computingValueReference)
173        throws ExecutionException {
174      V value = null;
175      long start = System.nanoTime();
176      long end = 0;
177      try {
178        // Synchronizes on the entry to allow failing fast when a recursive computation is
179        // detected. This is not fool-proof since the entry may be copied when the segment
180        // is written to.
181        synchronized (e) {
182          value = computingValueReference.compute(key, hash);
183          end = System.nanoTime();
184        }
185        if (value != null) {
186          // putIfAbsent
187          V oldValue = put(key, hash, value, true);
188          if (oldValue != null) {
189            // the computed value was already clobbered
190            enqueueNotification(key, hash, value, RemovalCause.REPLACED);
191          }
192        }
193        return value;
194      } finally {
195        if (end == 0) {
196          end = System.nanoTime();
197        }
198        if (value == null) {
199          clearValue(key, hash, computingValueReference);
200        }
201      }
202    }
203  }
204
205  /**
206   * Used to provide computation exceptions to other threads.
207   */
208  private static final class ComputationExceptionReference<K, V> implements ValueReference<K, V> {
209    final Throwable t;
210
211    ComputationExceptionReference(Throwable t) {
212      this.t = t;
213    }
214
215    @Override
216    public V get() {
217      return null;
218    }
219
220    @Override
221    public ReferenceEntry<K, V> getEntry() {
222      return null;
223    }
224
225    @Override
226    public ValueReference<K, V> copyFor(ReferenceQueue<V> queue, ReferenceEntry<K, V> entry) {
227      return this;
228    }
229
230    @Override
231    public boolean isComputingReference() {
232      return false;
233    }
234
235    @Override
236    public V waitForValue() throws ExecutionException {
237      throw new ExecutionException(t);
238    }
239
240    @Override
241    public void clear(ValueReference<K, V> newValue) {}
242  }
243
244  /**
245   * Used to provide computation result to other threads.
246   */
247  private static final class ComputedReference<K, V> implements ValueReference<K, V> {
248    final V value;
249
250    ComputedReference(@Nullable V value) {
251      this.value = value;
252    }
253
254    @Override
255    public V get() {
256      return value;
257    }
258
259    @Override
260    public ReferenceEntry<K, V> getEntry() {
261      return null;
262    }
263
264    @Override
265    public ValueReference<K, V> copyFor(ReferenceQueue<V> queue, ReferenceEntry<K, V> entry) {
266      return this;
267    }
268
269    @Override
270    public boolean isComputingReference() {
271      return false;
272    }
273
274    @Override
275    public V waitForValue() {
276      return get();
277    }
278
279    @Override
280    public void clear(ValueReference<K, V> newValue) {}
281  }
282
283  private static final class ComputingValueReference<K, V> implements ValueReference<K, V> {
284    final Function<? super K, ? extends V> computingFunction;
285
286    @GuardedBy("ComputingValueReference.this") // writes
287    volatile ValueReference<K, V> computedReference = unset();
288
289    public ComputingValueReference(Function<? super K, ? extends V> computingFunction) {
290      this.computingFunction = computingFunction;
291    }
292
293    @Override
294    public V get() {
295      // All computation lookups go through waitForValue. This method thus is
296      // only used by put, to whom we always want to appear absent.
297      return null;
298    }
299
300    @Override
301    public ReferenceEntry<K, V> getEntry() {
302      return null;
303    }
304
305    @Override
306    public ValueReference<K, V> copyFor(ReferenceQueue<V> queue, ReferenceEntry<K, V> entry) {
307      return this;
308    }
309
310    @Override
311    public boolean isComputingReference() {
312      return true;
313    }
314
315    /**
316     * Waits for a computation to complete. Returns the result of the computation.
317     */
318    @Override
319    public V waitForValue() throws ExecutionException {
320      if (computedReference == UNSET) {
321        boolean interrupted = false;
322        try {
323          synchronized (this) {
324            while (computedReference == UNSET) {
325              try {
326                wait();
327              } catch (InterruptedException ie) {
328                interrupted = true;
329              }
330            }
331          }
332        } finally {
333          if (interrupted) {
334            Thread.currentThread().interrupt();
335          }
336        }
337      }
338      return computedReference.waitForValue();
339    }
340
341    @Override
342    public void clear(ValueReference<K, V> newValue) {
343      // The pending computation was clobbered by a manual write. Unblock all
344      // pending gets, and have them return the new value.
345      setValueReference(newValue);
346
347      // TODO(fry): could also cancel computation if we had a thread handle
348    }
349
350    V compute(K key, int hash) throws ExecutionException {
351      V value;
352      try {
353        value = computingFunction.apply(key);
354      } catch (Throwable t) {
355        setValueReference(new ComputationExceptionReference<K, V>(t));
356        throw new ExecutionException(t);
357      }
358
359      setValueReference(new ComputedReference<K, V>(value));
360      return value;
361    }
362
363    void setValueReference(ValueReference<K, V> valueReference) {
364      synchronized (this) {
365        if (computedReference == UNSET) {
366          computedReference = valueReference;
367          notifyAll();
368        }
369      }
370    }
371  }
372
373  // Serialization Support
374
375  private static final long serialVersionUID = 4;
376
377  @Override
378  Object writeReplace() {
379    return new ComputingSerializationProxy<K, V>(keyStrength, valueStrength, keyEquivalence,
380        valueEquivalence, expireAfterWriteNanos, expireAfterAccessNanos, maximumSize,
381        concurrencyLevel, removalListener, this, computingFunction);
382  }
383
384  static final class ComputingSerializationProxy<K, V> extends AbstractSerializationProxy<K, V> {
385
386    final Function<? super K, ? extends V> computingFunction;
387
388    ComputingSerializationProxy(Strength keyStrength, Strength valueStrength,
389        Equivalence<Object> keyEquivalence, Equivalence<Object> valueEquivalence,
390        long expireAfterWriteNanos, long expireAfterAccessNanos, int maximumSize,
391        int concurrencyLevel, RemovalListener<? super K, ? super V> removalListener,
392        ConcurrentMap<K, V> delegate, Function<? super K, ? extends V> computingFunction) {
393      super(keyStrength, valueStrength, keyEquivalence, valueEquivalence, expireAfterWriteNanos,
394          expireAfterAccessNanos, maximumSize, concurrencyLevel, removalListener, delegate);
395      this.computingFunction = computingFunction;
396    }
397
398    private void writeObject(ObjectOutputStream out) throws IOException {
399      out.defaultWriteObject();
400      writeMapTo(out);
401    }
402
403    @SuppressWarnings("deprecation") // self-use
404    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
405      in.defaultReadObject();
406      MapMaker mapMaker = readMapMaker(in);
407      delegate = mapMaker.makeComputingMap(computingFunction);
408      readEntries(in);
409    }
410
411    Object readResolve() {
412      return delegate;
413    }
414
415    private static final long serialVersionUID = 4;
416  }
417}
418