1/**
2 * Copyright (C) 2008 Google Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.inject.util;
18
19import com.google.common.collect.ImmutableList;
20import com.google.common.collect.ImmutableSet;
21import com.google.common.collect.Iterables;
22import com.google.common.collect.Lists;
23import com.google.common.collect.Maps;
24import com.google.common.collect.Sets;
25import com.google.inject.AbstractModule;
26import com.google.inject.Binder;
27import com.google.inject.Binding;
28import com.google.inject.Key;
29import com.google.inject.Module;
30import com.google.inject.PrivateBinder;
31import com.google.inject.PrivateModule;
32import com.google.inject.Scope;
33import com.google.inject.internal.Errors;
34import com.google.inject.spi.DefaultBindingScopingVisitor;
35import com.google.inject.spi.DefaultElementVisitor;
36import com.google.inject.spi.Element;
37import com.google.inject.spi.ElementVisitor;
38import com.google.inject.spi.Elements;
39import com.google.inject.spi.ModuleAnnotatedMethodScannerBinding;
40import com.google.inject.spi.PrivateElements;
41import com.google.inject.spi.ScopeBinding;
42
43import java.lang.annotation.Annotation;
44import java.util.Arrays;
45import java.util.LinkedHashSet;
46import java.util.List;
47import java.util.Map;
48import java.util.Set;
49
50/**
51 * Static utility methods for creating and working with instances of {@link Module}.
52 *
53 * @author jessewilson@google.com (Jesse Wilson)
54 * @since 2.0
55 */
56public final class Modules {
57  private Modules() {}
58
59  public static final Module EMPTY_MODULE = new EmptyModule();
60  private static class EmptyModule implements Module {
61    public void configure(Binder binder) {}
62  }
63
64  /**
65   * Returns a builder that creates a module that overlays override modules over the given
66   * modules. If a key is bound in both sets of modules, only the binding from the override modules
67   * is kept. If a single {@link PrivateModule} is supplied or all elements are from
68   * a single {@link PrivateBinder}, then this will overwrite the private bindings.
69   * Otherwise, private bindings will not be overwritten unless they are exposed.
70   * This can be used to replace the bindings of a production module with test bindings:
71   * <pre>
72   * Module functionalTestModule
73   *     = Modules.override(new ProductionModule()).with(new TestModule());
74   * </pre>
75   *
76   * <p>Prefer to write smaller modules that can be reused and tested without overrides.
77   *
78   * @param modules the modules whose bindings are open to be overridden
79   */
80  public static OverriddenModuleBuilder override(Module... modules) {
81    return new RealOverriddenModuleBuilder(Arrays.asList(modules));
82  }
83
84  /**
85   * Returns a builder that creates a module that overlays override modules over the given
86   * modules. If a key is bound in both sets of modules, only the binding from the override modules
87   * is kept. If a single {@link PrivateModule} is supplied or all elements are from
88   * a single {@link PrivateBinder}, then this will overwrite the private bindings.
89   * Otherwise, private bindings will not be overwritten unless they are exposed.
90   * This can be used to replace the bindings of a production module with test bindings:
91   * <pre>
92   * Module functionalTestModule
93   *     = Modules.override(getProductionModules()).with(getTestModules());
94   * </pre>
95   *
96   * <p>Prefer to write smaller modules that can be reused and tested without overrides.
97   *
98   * @param modules the modules whose bindings are open to be overridden
99   */
100  public static OverriddenModuleBuilder override(Iterable<? extends Module> modules) {
101    return new RealOverriddenModuleBuilder(modules);
102  }
103
104  /**
105   * Returns a new module that installs all of {@code modules}.
106   */
107  public static Module combine(Module... modules) {
108    return combine(ImmutableSet.copyOf(modules));
109  }
110
111  /**
112   * Returns a new module that installs all of {@code modules}.
113   */
114  public static Module combine(Iterable<? extends Module> modules) {
115    return new CombinedModule(modules);
116  }
117
118  private static class CombinedModule implements Module {
119    final Set<Module> modulesSet;
120
121    CombinedModule(Iterable<? extends Module> modules) {
122      this.modulesSet = ImmutableSet.copyOf(modules);
123    }
124
125    public void configure(Binder binder) {
126      binder = binder.skipSources(getClass());
127      for (Module module : modulesSet) {
128        binder.install(module);
129      }
130    }
131  }
132
133  /**
134   * See the EDSL example at {@link Modules#override(Module[]) override()}.
135   */
136  public interface OverriddenModuleBuilder {
137
138    /**
139     * See the EDSL example at {@link Modules#override(Module[]) override()}.
140     */
141    Module with(Module... overrides);
142
143    /**
144     * See the EDSL example at {@link Modules#override(Module[]) override()}.
145     */
146    Module with(Iterable<? extends Module> overrides);
147  }
148
149  private static final class RealOverriddenModuleBuilder implements OverriddenModuleBuilder {
150    private final ImmutableSet<Module> baseModules;
151
152    private RealOverriddenModuleBuilder(Iterable<? extends Module> baseModules) {
153      this.baseModules = ImmutableSet.copyOf(baseModules);
154    }
155
156    public Module with(Module... overrides) {
157      return with(Arrays.asList(overrides));
158    }
159
160    public Module with(Iterable<? extends Module> overrides) {
161      return new OverrideModule(overrides, baseModules);
162    }
163  }
164
165  static class OverrideModule extends AbstractModule {
166    private final ImmutableSet<Module> overrides;
167    private final ImmutableSet<Module> baseModules;
168
169    OverrideModule(Iterable<? extends Module> overrides, ImmutableSet<Module> baseModules) {
170      this.overrides = ImmutableSet.copyOf(overrides);
171      this.baseModules = baseModules;
172    }
173
174    @Override
175    public void configure() {
176      Binder baseBinder = binder();
177      List<Element> baseElements = Elements.getElements(currentStage(), baseModules);
178
179      // If the sole element was a PrivateElements, we want to override
180      // the private elements within that -- so refocus our elements
181      // and binder.
182      if(baseElements.size() == 1) {
183        Element element = Iterables.getOnlyElement(baseElements);
184        if(element instanceof PrivateElements) {
185          PrivateElements privateElements = (PrivateElements)element;
186          PrivateBinder privateBinder = baseBinder.newPrivateBinder().withSource(privateElements.getSource());
187          for(Key exposed : privateElements.getExposedKeys()) {
188            privateBinder.withSource(privateElements.getExposedSource(exposed)).expose(exposed);
189          }
190          baseBinder = privateBinder;
191          baseElements = privateElements.getElements();
192        }
193      }
194
195      final Binder binder = baseBinder.skipSources(this.getClass());
196      final LinkedHashSet<Element> elements = new LinkedHashSet<Element>(baseElements);
197      final Module scannersModule = extractScanners(elements);
198      final List<Element> overrideElements = Elements.getElements(currentStage(),
199          ImmutableList.<Module>builder().addAll(overrides).add(scannersModule).build());
200
201      final Set<Key<?>> overriddenKeys = Sets.newHashSet();
202      final Map<Class<? extends Annotation>, ScopeBinding> overridesScopeAnnotations =
203          Maps.newHashMap();
204
205      // execute the overrides module, keeping track of which keys and scopes are bound
206      new ModuleWriter(binder) {
207        @Override public <T> Void visit(Binding<T> binding) {
208          overriddenKeys.add(binding.getKey());
209          return super.visit(binding);
210        }
211
212        @Override public Void visit(ScopeBinding scopeBinding) {
213          overridesScopeAnnotations.put(scopeBinding.getAnnotationType(), scopeBinding);
214          return super.visit(scopeBinding);
215        }
216
217        @Override public Void visit(PrivateElements privateElements) {
218          overriddenKeys.addAll(privateElements.getExposedKeys());
219          return super.visit(privateElements);
220        }
221      }.writeAll(overrideElements);
222
223      // execute the original module, skipping all scopes and overridden keys. We only skip each
224      // overridden binding once so things still blow up if the module binds the same thing
225      // multiple times.
226      final Map<Scope, List<Object>> scopeInstancesInUse = Maps.newHashMap();
227      final List<ScopeBinding> scopeBindings = Lists.newArrayList();
228      new ModuleWriter(binder) {
229        @Override public <T> Void visit(Binding<T> binding) {
230          if (!overriddenKeys.remove(binding.getKey())) {
231            super.visit(binding);
232
233            // Record when a scope instance is used in a binding
234            Scope scope = getScopeInstanceOrNull(binding);
235            if (scope != null) {
236              List<Object> existing = scopeInstancesInUse.get(scope);
237              if (existing == null) {
238                existing = Lists.newArrayList();
239                scopeInstancesInUse.put(scope, existing);
240              }
241              existing.add(binding.getSource());
242            }
243          }
244
245          return null;
246        }
247
248        void rewrite(Binder binder, PrivateElements privateElements, Set<Key<?>> keysToSkip) {
249          PrivateBinder privateBinder = binder.withSource(privateElements.getSource())
250              .newPrivateBinder();
251
252          Set<Key<?>> skippedExposes = Sets.newHashSet();
253
254          for (Key<?> key : privateElements.getExposedKeys()) {
255            if (keysToSkip.remove(key)) {
256              skippedExposes.add(key);
257            } else {
258              privateBinder.withSource(privateElements.getExposedSource(key)).expose(key);
259            }
260          }
261
262          for (Element element : privateElements.getElements()) {
263            if (element instanceof Binding
264                && skippedExposes.remove(((Binding) element).getKey())) {
265              continue;
266            }
267            if (element instanceof PrivateElements) {
268              rewrite(privateBinder, (PrivateElements) element, skippedExposes);
269              continue;
270            }
271            element.applyTo(privateBinder);
272          }
273        }
274
275        @Override public Void visit(PrivateElements privateElements) {
276          rewrite(binder, privateElements, overriddenKeys);
277          return null;
278        }
279
280        @Override public Void visit(ScopeBinding scopeBinding) {
281          scopeBindings.add(scopeBinding);
282          return null;
283        }
284      }.writeAll(elements);
285
286      // execute the scope bindings, skipping scopes that have been overridden. Any scope that
287      // is overridden and in active use will prompt an error
288      new ModuleWriter(binder) {
289        @Override public Void visit(ScopeBinding scopeBinding) {
290          ScopeBinding overideBinding =
291              overridesScopeAnnotations.remove(scopeBinding.getAnnotationType());
292          if (overideBinding == null) {
293            super.visit(scopeBinding);
294          } else {
295            List<Object> usedSources = scopeInstancesInUse.get(scopeBinding.getScope());
296            if (usedSources != null) {
297              StringBuilder sb = new StringBuilder(
298                  "The scope for @%s is bound directly and cannot be overridden.");
299              sb.append("%n     original binding at " + Errors.convert(scopeBinding.getSource()));
300              for (Object usedSource : usedSources) {
301                sb.append("%n     bound directly at " + Errors.convert(usedSource) + "");
302              }
303              binder.withSource(overideBinding.getSource())
304                  .addError(sb.toString(), scopeBinding.getAnnotationType().getSimpleName());
305            }
306          }
307          return null;
308        }
309      }.writeAll(scopeBindings);
310    }
311
312    private Scope getScopeInstanceOrNull(Binding<?> binding) {
313      return binding.acceptScopingVisitor(new DefaultBindingScopingVisitor<Scope>() {
314        @Override public Scope visitScope(Scope scope) {
315          return scope;
316        }
317      });
318    }
319  }
320
321  private static class ModuleWriter extends DefaultElementVisitor<Void> {
322    protected final Binder binder;
323
324    ModuleWriter(Binder binder) {
325      this.binder = binder.skipSources(this.getClass());
326    }
327
328    @Override protected Void visitOther(Element element) {
329      element.applyTo(binder);
330      return null;
331    }
332
333    void writeAll(Iterable<? extends Element> elements) {
334      for (Element element : elements) {
335        element.acceptVisitor(this);
336      }
337    }
338  }
339
340  private static Module extractScanners(Iterable<Element> elements) {
341    final List<ModuleAnnotatedMethodScannerBinding> scanners = Lists.newArrayList();
342    ElementVisitor<Void> visitor = new DefaultElementVisitor<Void>() {
343      @Override public Void visit(ModuleAnnotatedMethodScannerBinding binding) {
344        scanners.add(binding);
345        return null;
346      }
347    };
348    for (Element element : elements) {
349      element.acceptVisitor(visitor);
350    }
351    return new AbstractModule() {
352      @Override protected void configure() {
353        for (ModuleAnnotatedMethodScannerBinding scanner : scanners) {
354          scanner.applyTo(binder());
355        }
356      }
357    };
358  }
359}
360