1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef BASE_WIN_SCOPED_COMPTR_H_
6#define BASE_WIN_SCOPED_COMPTR_H_
7
8#include <unknwn.h>
9
10#include "base/logging.h"
11#include "base/memory/ref_counted.h"
12
13namespace base {
14namespace win {
15
16// A fairly minimalistic smart class for COM interface pointers.
17// Uses scoped_refptr for the basic smart pointer functionality
18// and adds a few IUnknown specific services.
19template <class Interface, const IID* interface_id = &__uuidof(Interface)>
20class ScopedComPtr : public scoped_refptr<Interface> {
21 public:
22  // Utility template to prevent users of ScopedComPtr from calling AddRef
23  // and/or Release() without going through the ScopedComPtr class.
24  class BlockIUnknownMethods : public Interface {
25   private:
26    STDMETHOD(QueryInterface)(REFIID iid, void** object) = 0;
27    STDMETHOD_(ULONG, AddRef)() = 0;
28    STDMETHOD_(ULONG, Release)() = 0;
29  };
30
31  typedef scoped_refptr<Interface> ParentClass;
32
33  ScopedComPtr() {
34  }
35
36  explicit ScopedComPtr(Interface* p) : ParentClass(p) {
37  }
38
39  ScopedComPtr(const ScopedComPtr<Interface, interface_id>& p)
40      : ParentClass(p) {
41  }
42
43  ~ScopedComPtr() {
44    // We don't want the smart pointer class to be bigger than the pointer
45    // it wraps.
46    COMPILE_ASSERT(sizeof(ScopedComPtr<Interface, interface_id>) ==
47                   sizeof(Interface*), ScopedComPtrSize);
48  }
49
50  // Explicit Release() of the held object.  Useful for reuse of the
51  // ScopedComPtr instance.
52  // Note that this function equates to IUnknown::Release and should not
53  // be confused with e.g. scoped_ptr::release().
54  void Release() {
55    if (ptr_ != NULL) {
56      ptr_->Release();
57      ptr_ = NULL;
58    }
59  }
60
61  // Sets the internal pointer to NULL and returns the held object without
62  // releasing the reference.
63  Interface* Detach() {
64    Interface* p = ptr_;
65    ptr_ = NULL;
66    return p;
67  }
68
69  // Accepts an interface pointer that has already been addref-ed.
70  void Attach(Interface* p) {
71    DCHECK(!ptr_);
72    ptr_ = p;
73  }
74
75  // Retrieves the pointer address.
76  // Used to receive object pointers as out arguments (and take ownership).
77  // The function DCHECKs on the current value being NULL.
78  // Usage: Foo(p.Receive());
79  Interface** Receive() {
80    DCHECK(!ptr_) << "Object leak. Pointer must be NULL";
81    return &ptr_;
82  }
83
84  // A convenience for whenever a void pointer is needed as an out argument.
85  void** ReceiveVoid() {
86    return reinterpret_cast<void**>(Receive());
87  }
88
89  template <class Query>
90  HRESULT QueryInterface(Query** p) {
91    DCHECK(p != NULL);
92    DCHECK(ptr_ != NULL);
93    // IUnknown already has a template version of QueryInterface
94    // so the iid parameter is implicit here. The only thing this
95    // function adds are the DCHECKs.
96    return ptr_->QueryInterface(p);
97  }
98
99  // QI for times when the IID is not associated with the type.
100  HRESULT QueryInterface(const IID& iid, void** obj) {
101    DCHECK(obj != NULL);
102    DCHECK(ptr_ != NULL);
103    return ptr_->QueryInterface(iid, obj);
104  }
105
106  // Queries |other| for the interface this object wraps and returns the
107  // error code from the other->QueryInterface operation.
108  HRESULT QueryFrom(IUnknown* object) {
109    DCHECK(object != NULL);
110    return object->QueryInterface(Receive());
111  }
112
113  // Convenience wrapper around CoCreateInstance
114  HRESULT CreateInstance(const CLSID& clsid, IUnknown* outer = NULL,
115                         DWORD context = CLSCTX_ALL) {
116    DCHECK(!ptr_);
117    HRESULT hr = ::CoCreateInstance(clsid, outer, context, *interface_id,
118                                    reinterpret_cast<void**>(&ptr_));
119    return hr;
120  }
121
122  // Checks if the identity of |other| and this object is the same.
123  bool IsSameObject(IUnknown* other) {
124    if (!other && !ptr_)
125      return true;
126
127    if (!other || !ptr_)
128      return false;
129
130    ScopedComPtr<IUnknown> my_identity;
131    QueryInterface(my_identity.Receive());
132
133    ScopedComPtr<IUnknown> other_identity;
134    other->QueryInterface(other_identity.Receive());
135
136    return static_cast<IUnknown*>(my_identity) ==
137           static_cast<IUnknown*>(other_identity);
138  }
139
140  // Provides direct access to the interface.
141  // Here we use a well known trick to make sure we block access to
142  // IUnknown methods so that something bad like this doesn't happen:
143  //    ScopedComPtr<IUnknown> p(Foo());
144  //    p->Release();
145  //    ... later the destructor runs, which will Release() again.
146  // and to get the benefit of the DCHECKs we add to QueryInterface.
147  // There's still a way to call these methods if you absolutely must
148  // by statically casting the ScopedComPtr instance to the wrapped interface
149  // and then making the call... but generally that shouldn't be necessary.
150  BlockIUnknownMethods* operator->() const {
151    DCHECK(ptr_ != NULL);
152    return reinterpret_cast<BlockIUnknownMethods*>(ptr_);
153  }
154
155  // Pull in operator=() from the parent class.
156  using scoped_refptr<Interface>::operator=;
157
158  // static methods
159
160  static const IID& iid() {
161    return *interface_id;
162  }
163};
164
165}  // namespace win
166}  // namespace base
167
168#endif  // BASE_WIN_SCOPED_COMPTR_H_
169