service_discovery_host_client.cc revision d0247b1b59f9c528cb6df88b4f2b9afaf80d181e
1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/local_discovery/service_discovery_host_client.h"
6
7#if defined(OS_POSIX)
8#include "base/file_descriptor_posix.h"
9#endif  // OS_POSIX
10
11#include "chrome/common/local_discovery/local_discovery_messages.h"
12#include "content/public/browser/browser_thread.h"
13#include "content/public/browser/utility_process_host.h"
14#include "net/socket/socket_descriptor.h"
15
16namespace local_discovery {
17
18namespace {
19ServiceDiscoverySharedClient* g_service_discovery_client = NULL;
20}  // namespace
21
22using content::BrowserThread;
23using content::UtilityProcessHost;
24
25class ServiceDiscoveryHostClient::ServiceWatcherProxy : public ServiceWatcher {
26 public:
27  ServiceWatcherProxy(ServiceDiscoveryHostClient* host,
28                      const std::string& service_type,
29                      const ServiceWatcher::UpdatedCallback& callback)
30      : host_(host),
31        service_type_(service_type),
32        id_(host_->RegisterWatcherCallback(callback)),
33        started_(false) {
34  }
35
36  virtual ~ServiceWatcherProxy() {
37    DVLOG(1) << "~ServiceWatcherProxy with id " << id_;
38    host_->UnregisterWatcherCallback(id_);
39    if (started_)
40      host_->Send(new LocalDiscoveryMsg_DestroyWatcher(id_));
41  }
42
43  virtual void Start() OVERRIDE {
44    DVLOG(1) << "ServiceWatcher::Start with id " << id_;
45    DCHECK(!started_);
46    host_->Send(new LocalDiscoveryMsg_StartWatcher(id_, service_type_));
47    started_ = true;
48  }
49
50  virtual void DiscoverNewServices(bool force_update) OVERRIDE {
51    DVLOG(1) << "ServiceWatcher::DiscoverNewServices with id " << id_;
52    DCHECK(started_);
53    host_->Send(new LocalDiscoveryMsg_DiscoverServices(id_, force_update));
54  }
55
56  virtual std::string GetServiceType() const OVERRIDE {
57    return service_type_;
58  }
59
60 private:
61  scoped_refptr<ServiceDiscoveryHostClient> host_;
62  const std::string service_type_;
63  const uint64 id_;
64  bool started_;
65};
66
67class ServiceDiscoveryHostClient::ServiceResolverProxy
68    : public ServiceResolver {
69 public:
70  ServiceResolverProxy(ServiceDiscoveryHostClient* host,
71                       const std::string& service_name,
72                       const ServiceResolver::ResolveCompleteCallback& callback)
73      : host_(host),
74        service_name_(service_name),
75        id_(host->RegisterResolverCallback(callback)),
76        started_(false) {
77  }
78
79  virtual ~ServiceResolverProxy() {
80    DVLOG(1) << "~ServiceResolverProxy with id " << id_;
81    host_->UnregisterResolverCallback(id_);
82    if (started_)
83      host_->Send(new LocalDiscoveryMsg_DestroyResolver(id_));
84  }
85
86  virtual void StartResolving() OVERRIDE {
87    DVLOG(1) << "ServiceResolverProxy::StartResolving with id " << id_;
88    DCHECK(!started_);
89    host_->Send(new LocalDiscoveryMsg_ResolveService(id_, service_name_));
90    started_ = true;
91  }
92
93  virtual std::string GetName() const OVERRIDE {
94    return service_name_;
95  }
96
97 private:
98  scoped_refptr<ServiceDiscoveryHostClient> host_;
99  const std::string service_name_;
100  const uint64 id_;
101  bool started_;
102};
103
104class ServiceDiscoveryHostClient::LocalDomainResolverProxy
105    : public LocalDomainResolver {
106 public:
107  LocalDomainResolverProxy(ServiceDiscoveryHostClient* host,
108                       const std::string& domain,
109                       net::AddressFamily address_family,
110                       const LocalDomainResolver::IPAddressCallback& callback)
111      : host_(host),
112        domain_(domain),
113        address_family_(address_family),
114        id_(host->RegisterLocalDomainResolverCallback(callback)),
115        started_(false) {
116  }
117
118  virtual ~LocalDomainResolverProxy() {
119    DVLOG(1) << "~LocalDomainResolverProxy with id " << id_;
120    host_->UnregisterLocalDomainResolverCallback(id_);
121    if (started_)
122      host_->Send(new LocalDiscoveryMsg_DestroyLocalDomainResolver(id_));
123  }
124
125  virtual void Start() OVERRIDE {
126    DVLOG(1) << "LocalDomainResolverProxy::Start with id " << id_;
127    DCHECK(!started_);
128    host_->Send(new LocalDiscoveryMsg_ResolveLocalDomain(id_, domain_,
129                                                         address_family_));
130    started_ = true;
131  }
132
133 private:
134  scoped_refptr<ServiceDiscoveryHostClient> host_;
135  std::string domain_;
136  net::AddressFamily address_family_;
137  const uint64 id_;
138  bool started_;
139};
140
141ServiceDiscoveryHostClient::ServiceDiscoveryHostClient() : current_id_(0) {
142  callback_runner_ = base::MessageLoop::current()->message_loop_proxy();
143  io_runner_ = BrowserThread::GetMessageLoopProxyForThread(BrowserThread::IO);
144}
145
146ServiceDiscoveryHostClient::~ServiceDiscoveryHostClient() {
147  // The ServiceDiscoveryHostClient may be destroyed from the IO thread or the
148  // owning thread.
149  DetachFromThread();
150  DCHECK(service_watcher_callbacks_.empty());
151  DCHECK(service_resolver_callbacks_.empty());
152  DCHECK(domain_resolver_callbacks_.empty());
153}
154
155scoped_ptr<ServiceWatcher> ServiceDiscoveryHostClient::CreateServiceWatcher(
156    const std::string& service_type,
157    const ServiceWatcher::UpdatedCallback& callback) {
158  DCHECK(CalledOnValidThread());
159  return scoped_ptr<ServiceWatcher>(
160      new ServiceWatcherProxy(this, service_type, callback));
161}
162
163scoped_ptr<ServiceResolver> ServiceDiscoveryHostClient::CreateServiceResolver(
164    const std::string& service_name,
165    const ServiceResolver::ResolveCompleteCallback& callback) {
166  DCHECK(CalledOnValidThread());
167  return scoped_ptr<ServiceResolver>(
168      new ServiceResolverProxy(this, service_name, callback));
169}
170
171scoped_ptr<LocalDomainResolver>
172ServiceDiscoveryHostClient::CreateLocalDomainResolver(
173    const std::string& domain,
174    net::AddressFamily address_family,
175    const LocalDomainResolver::IPAddressCallback& callback) {
176  DCHECK(CalledOnValidThread());
177  return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverProxy(
178      this, domain, address_family, callback));
179}
180
181uint64 ServiceDiscoveryHostClient::RegisterWatcherCallback(
182    const ServiceWatcher::UpdatedCallback& callback) {
183  DCHECK(CalledOnValidThread());
184  DCHECK(!ContainsKey(service_watcher_callbacks_, current_id_ + 1));
185  service_watcher_callbacks_[++current_id_] = callback;
186  return current_id_;
187}
188
189uint64 ServiceDiscoveryHostClient::RegisterResolverCallback(
190    const ServiceResolver::ResolveCompleteCallback& callback) {
191  DCHECK(CalledOnValidThread());
192  DCHECK(!ContainsKey(service_resolver_callbacks_, current_id_ + 1));
193  service_resolver_callbacks_[++current_id_] = callback;
194  return current_id_;
195}
196
197uint64 ServiceDiscoveryHostClient::RegisterLocalDomainResolverCallback(
198    const LocalDomainResolver::IPAddressCallback& callback) {
199  DCHECK(CalledOnValidThread());
200  DCHECK(!ContainsKey(domain_resolver_callbacks_, current_id_ + 1));
201  domain_resolver_callbacks_[++current_id_] = callback;
202  return current_id_;
203}
204
205void ServiceDiscoveryHostClient::UnregisterWatcherCallback(uint64 id) {
206  DCHECK(CalledOnValidThread());
207  service_watcher_callbacks_.erase(id);
208}
209
210void ServiceDiscoveryHostClient::UnregisterResolverCallback(uint64 id) {
211  DCHECK(CalledOnValidThread());
212  service_resolver_callbacks_.erase(id);
213}
214
215void ServiceDiscoveryHostClient::UnregisterLocalDomainResolverCallback(
216    uint64 id) {
217  DCHECK(CalledOnValidThread());
218  domain_resolver_callbacks_.erase(id);
219}
220
221void ServiceDiscoveryHostClient::Start() {
222  DCHECK(CalledOnValidThread());
223  io_runner_->PostTask(
224      FROM_HERE,
225      base::Bind(&ServiceDiscoveryHostClient::StartOnIOThread, this));
226}
227
228void ServiceDiscoveryHostClient::Shutdown() {
229  DCHECK(CalledOnValidThread());
230  io_runner_->PostTask(
231      FROM_HERE,
232      base::Bind(&ServiceDiscoveryHostClient::ShutdownOnIOThread, this));
233}
234
235void ServiceDiscoveryHostClient::StartOnIOThread() {
236  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
237  DCHECK(!utility_host_);
238  utility_host_ = UtilityProcessHost::Create(
239      this, base::MessageLoopProxy::current().get())->AsWeakPtr();
240  if (utility_host_) {
241    utility_host_->EnableZygote();
242    utility_host_->EnableMDns();
243    utility_host_->StartBatchMode();
244
245#if defined(OS_POSIX)
246    base::FileDescriptor v4(net::CreatePlatformSocket(AF_INET, SOCK_DGRAM, 0),
247                            true);
248    base::FileDescriptor v6(net::CreatePlatformSocket(AF_INET6, SOCK_DGRAM, 0),
249                            true);
250    LOG_IF(ERROR, v4.fd == net::kInvalidSocket) << "Can't create IPv4 socket.";
251    LOG_IF(ERROR, v6.fd == net::kInvalidSocket) << "Can't create IPv6 socket.";
252    if (v4.fd == net::kInvalidSocket &&
253        v6.fd == net::kInvalidSocket) {
254      ShutdownOnIOThread();
255    } else {
256      utility_host_->Send(new LocalDiscoveryMsg_SetSockets(v4, v6));
257    }
258#endif  // OS_POSIX
259  }
260}
261
262void ServiceDiscoveryHostClient::ShutdownOnIOThread() {
263  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
264  if (utility_host_) {
265    utility_host_->Send(new LocalDiscoveryMsg_ShutdownLocalDiscovery);
266    utility_host_->EndBatchMode();
267  }
268}
269
270void ServiceDiscoveryHostClient::Send(IPC::Message* msg) {
271  DCHECK(CalledOnValidThread());
272  io_runner_->PostTask(
273      FROM_HERE,
274      base::Bind(&ServiceDiscoveryHostClient::SendOnIOThread, this, msg));
275}
276
277void ServiceDiscoveryHostClient::SendOnIOThread(IPC::Message* msg) {
278  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
279  if (utility_host_)
280    utility_host_->Send(msg);
281}
282
283bool ServiceDiscoveryHostClient::OnMessageReceived(
284    const IPC::Message& message) {
285  bool handled = true;
286  IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryHostClient, message)
287    IPC_MESSAGE_HANDLER(LocalDiscoveryHostMsg_WatcherCallback,
288                        OnWatcherCallback)
289    IPC_MESSAGE_HANDLER(LocalDiscoveryHostMsg_ResolverCallback,
290                        OnResolverCallback)
291    IPC_MESSAGE_HANDLER(LocalDiscoveryHostMsg_LocalDomainResolverCallback,
292                        OnLocalDomainResolverCallback)
293    IPC_MESSAGE_UNHANDLED(handled = false)
294  IPC_END_MESSAGE_MAP()
295  return handled;
296}
297
298void ServiceDiscoveryHostClient::InvalidateWatchers() {
299  WatcherCallbacks service_watcher_callbacks;
300  service_watcher_callbacks_.swap(service_watcher_callbacks);
301  service_resolver_callbacks_.clear();
302  domain_resolver_callbacks_.clear();
303
304  for (WatcherCallbacks::iterator i = service_watcher_callbacks.begin();
305       i != service_watcher_callbacks.end(); i++) {
306    if (!i->second.is_null()) {
307      i->second.Run(ServiceWatcher::UPDATE_INVALIDATED, "");
308    }
309  }
310}
311
312void ServiceDiscoveryHostClient::OnWatcherCallback(
313    uint64 id,
314    ServiceWatcher::UpdateType update,
315    const std::string& service_name) {
316  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
317  callback_runner_->PostTask(
318      FROM_HERE,
319      base::Bind(&ServiceDiscoveryHostClient::RunWatcherCallback, this, id,
320                 update, service_name));
321}
322
323void ServiceDiscoveryHostClient::OnResolverCallback(
324    uint64 id,
325    ServiceResolver::RequestStatus status,
326    const ServiceDescription& description) {
327  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
328  callback_runner_->PostTask(
329      FROM_HERE,
330      base::Bind(&ServiceDiscoveryHostClient::RunResolverCallback, this, id,
331                 status, description));
332}
333
334void ServiceDiscoveryHostClient::OnLocalDomainResolverCallback(
335    uint64 id,
336    bool success,
337    const net::IPAddressNumber& ip_address_ipv4,
338    const net::IPAddressNumber& ip_address_ipv6) {
339  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
340  callback_runner_->PostTask(
341      FROM_HERE,
342      base::Bind(&ServiceDiscoveryHostClient::RunLocalDomainResolverCallback,
343                 this, id, success, ip_address_ipv4, ip_address_ipv6));
344}
345
346void ServiceDiscoveryHostClient::RunWatcherCallback(
347    uint64 id,
348    ServiceWatcher::UpdateType update,
349    const std::string& service_name) {
350  DCHECK(CalledOnValidThread());
351  WatcherCallbacks::iterator it = service_watcher_callbacks_.find(id);
352  if (it != service_watcher_callbacks_.end() && !it->second.is_null())
353    it->second.Run(update, service_name);
354}
355
356void ServiceDiscoveryHostClient::RunResolverCallback(
357    uint64 id,
358    ServiceResolver::RequestStatus status,
359    const ServiceDescription& description) {
360  DCHECK(CalledOnValidThread());
361  ResolverCallbacks::iterator it = service_resolver_callbacks_.find(id);
362  if (it != service_resolver_callbacks_.end() && !it->second.is_null())
363    it->second.Run(status, description);
364}
365
366void ServiceDiscoveryHostClient::RunLocalDomainResolverCallback(
367    uint64 id,
368    bool success,
369    const net::IPAddressNumber& ip_address_ipv4,
370    const net::IPAddressNumber& ip_address_ipv6) {
371  DCHECK(CalledOnValidThread());
372  DomainResolverCallbacks::iterator it = domain_resolver_callbacks_.find(id);
373  if (it != domain_resolver_callbacks_.end() && !it->second.is_null())
374    it->second.Run(success, ip_address_ipv4, ip_address_ipv6);
375}
376
377scoped_ptr<ServiceWatcher> ServiceDiscoverySharedClient::CreateServiceWatcher(
378    const std::string& service_type,
379    const ServiceWatcher::UpdatedCallback& callback) {
380  DCHECK(CalledOnValidThread());
381  return host_client_->CreateServiceWatcher(service_type, callback);
382}
383
384scoped_ptr<ServiceResolver> ServiceDiscoverySharedClient::CreateServiceResolver(
385    const std::string& service_name,
386    const ServiceResolver::ResolveCompleteCallback& callback) {
387  DCHECK(CalledOnValidThread());
388  return host_client_->CreateServiceResolver(service_name, callback);
389}
390
391scoped_ptr<LocalDomainResolver>
392ServiceDiscoverySharedClient::CreateLocalDomainResolver(
393    const std::string& domain,
394    net::AddressFamily address_family,
395    const LocalDomainResolver::IPAddressCallback& callback) {
396  DCHECK(CalledOnValidThread());
397  return host_client_->CreateLocalDomainResolver(domain, address_family,
398                                                 callback);
399}
400
401ServiceDiscoverySharedClient::ServiceDiscoverySharedClient() {
402  net::NetworkChangeNotifier::AddNetworkChangeObserver(this);
403  DCHECK(!g_service_discovery_client);
404  g_service_discovery_client = this;
405  host_client_ = new ServiceDiscoveryHostClient();
406  host_client_->Start();
407}
408
409ServiceDiscoverySharedClient::~ServiceDiscoverySharedClient() {
410  net::NetworkChangeNotifier::RemoveNetworkChangeObserver(this);
411  DCHECK_EQ(g_service_discovery_client, this);
412  g_service_discovery_client = NULL;
413  host_client_->Shutdown();
414}
415
416
417
418void ServiceDiscoverySharedClient::OnNetworkChanged(
419    net::NetworkChangeNotifier::ConnectionType type) {
420  DCHECK(CalledOnValidThread());
421  host_client_->Shutdown();
422  network_change_callback_.Reset(
423      base::Bind(&ServiceDiscoverySharedClient::StartNewClient,
424                 base::Unretained(this)));  // Unretained to avoid ref cycle.
425  base::MessageLoop::current()->PostDelayedTask(
426      FROM_HERE,
427      network_change_callback_.callback(),
428      base::TimeDelta::FromSeconds(3));
429}
430
431void ServiceDiscoverySharedClient::StartNewClient() {
432  DCHECK(CalledOnValidThread());
433  scoped_refptr<ServiceDiscoveryHostClient> old_client = host_client_;
434  host_client_ = new ServiceDiscoveryHostClient();
435  host_client_->Start();
436  // Run when host_client_ is created. Callbacks created by InvalidateWatchers
437  // may create new watchers.
438  old_client->InvalidateWatchers();
439}
440
441scoped_refptr<ServiceDiscoverySharedClient>
442    ServiceDiscoverySharedClient::GetInstance() {
443  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
444
445  if (g_service_discovery_client)
446    return g_service_discovery_client;
447
448  return new ServiceDiscoverySharedClient();
449}
450
451}  // namespace local_discovery
452