1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/local_discovery/service_discovery_host_client.h"
6
7#include "chrome/common/local_discovery/local_discovery_messages.h"
8#include "content/public/browser/browser_thread.h"
9#include "content/public/browser/utility_process_host.h"
10#include "net/dns/mdns_client.h"
11#include "net/socket/socket_descriptor.h"
12
13#if defined(OS_POSIX)
14#include <netinet/in.h>
15#include "base/file_descriptor_posix.h"
16#endif  // OS_POSIX
17
18namespace local_discovery {
19
20using content::BrowserThread;
21using content::UtilityProcessHost;
22
23namespace {
24
25#if defined(OS_POSIX)
26SocketInfoList GetSocketsOnFileThread() {
27  net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind());
28  SocketInfoList sockets;
29  for (size_t i = 0; i < interfaces.size(); ++i) {
30    DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 ||
31           interfaces[i].second == net::ADDRESS_FAMILY_IPV6);
32    base::FileDescriptor socket_descriptor(
33        net::CreatePlatformSocket(
34            net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM,
35                                      IPPROTO_UDP),
36        true);
37    LOG_IF(ERROR, socket_descriptor.fd == net::kInvalidSocket)
38        << "Can't create socket, family=" << interfaces[i].second;
39    if (socket_descriptor.fd != net::kInvalidSocket) {
40      LocalDiscoveryMsg_SocketInfo socket;
41      socket.descriptor = socket_descriptor;
42      socket.interface_index = interfaces[i].first;
43      socket.address_family = interfaces[i].second;
44      sockets.push_back(socket);
45    }
46  }
47
48  return sockets;
49}
50#endif  // OS_POSIX
51
52}  // namespace
53
54class ServiceDiscoveryHostClient::ServiceWatcherProxy : public ServiceWatcher {
55 public:
56  ServiceWatcherProxy(ServiceDiscoveryHostClient* host,
57                      const std::string& service_type,
58                      const ServiceWatcher::UpdatedCallback& callback)
59      : host_(host),
60        service_type_(service_type),
61        id_(host_->RegisterWatcherCallback(callback)),
62        started_(false) {
63  }
64
65  virtual ~ServiceWatcherProxy() {
66    DVLOG(1) << "~ServiceWatcherProxy with id " << id_;
67    host_->UnregisterWatcherCallback(id_);
68    if (started_)
69      host_->Send(new LocalDiscoveryMsg_DestroyWatcher(id_));
70  }
71
72  virtual void Start() OVERRIDE {
73    DVLOG(1) << "ServiceWatcher::Start with id " << id_;
74    DCHECK(!started_);
75    host_->Send(new LocalDiscoveryMsg_StartWatcher(id_, service_type_));
76    started_ = true;
77  }
78
79  virtual void DiscoverNewServices(bool force_update) OVERRIDE {
80    DVLOG(1) << "ServiceWatcher::DiscoverNewServices with id " << id_;
81    DCHECK(started_);
82    host_->Send(new LocalDiscoveryMsg_DiscoverServices(id_, force_update));
83  }
84
85  virtual void SetActivelyRefreshServices(
86      bool actively_refresh_services) OVERRIDE {
87    DVLOG(1) << "ServiceWatcher::SetActivelyRefreshServices with id " << id_;
88    DCHECK(started_);
89    host_->Send(new LocalDiscoveryMsg_SetActivelyRefreshServices(
90        id_, actively_refresh_services));
91  }
92
93  virtual std::string GetServiceType() const OVERRIDE {
94    return service_type_;
95  }
96
97 private:
98  scoped_refptr<ServiceDiscoveryHostClient> host_;
99  const std::string service_type_;
100  const uint64 id_;
101  bool started_;
102};
103
104class ServiceDiscoveryHostClient::ServiceResolverProxy
105    : public ServiceResolver {
106 public:
107  ServiceResolverProxy(ServiceDiscoveryHostClient* host,
108                       const std::string& service_name,
109                       const ServiceResolver::ResolveCompleteCallback& callback)
110      : host_(host),
111        service_name_(service_name),
112        id_(host->RegisterResolverCallback(callback)),
113        started_(false) {
114  }
115
116  virtual ~ServiceResolverProxy() {
117    DVLOG(1) << "~ServiceResolverProxy with id " << id_;
118    host_->UnregisterResolverCallback(id_);
119    if (started_)
120      host_->Send(new LocalDiscoveryMsg_DestroyResolver(id_));
121  }
122
123  virtual void StartResolving() OVERRIDE {
124    DVLOG(1) << "ServiceResolverProxy::StartResolving with id " << id_;
125    DCHECK(!started_);
126    host_->Send(new LocalDiscoveryMsg_ResolveService(id_, service_name_));
127    started_ = true;
128  }
129
130  virtual std::string GetName() const OVERRIDE {
131    return service_name_;
132  }
133
134 private:
135  scoped_refptr<ServiceDiscoveryHostClient> host_;
136  const std::string service_name_;
137  const uint64 id_;
138  bool started_;
139};
140
141class ServiceDiscoveryHostClient::LocalDomainResolverProxy
142    : public LocalDomainResolver {
143 public:
144  LocalDomainResolverProxy(ServiceDiscoveryHostClient* host,
145                       const std::string& domain,
146                       net::AddressFamily address_family,
147                       const LocalDomainResolver::IPAddressCallback& callback)
148      : host_(host),
149        domain_(domain),
150        address_family_(address_family),
151        id_(host->RegisterLocalDomainResolverCallback(callback)),
152        started_(false) {
153  }
154
155  virtual ~LocalDomainResolverProxy() {
156    DVLOG(1) << "~LocalDomainResolverProxy with id " << id_;
157    host_->UnregisterLocalDomainResolverCallback(id_);
158    if (started_)
159      host_->Send(new LocalDiscoveryMsg_DestroyLocalDomainResolver(id_));
160  }
161
162  virtual void Start() OVERRIDE {
163    DVLOG(1) << "LocalDomainResolverProxy::Start with id " << id_;
164    DCHECK(!started_);
165    host_->Send(new LocalDiscoveryMsg_ResolveLocalDomain(id_, domain_,
166                                                         address_family_));
167    started_ = true;
168  }
169
170 private:
171  scoped_refptr<ServiceDiscoveryHostClient> host_;
172  std::string domain_;
173  net::AddressFamily address_family_;
174  const uint64 id_;
175  bool started_;
176};
177
178ServiceDiscoveryHostClient::ServiceDiscoveryHostClient() : current_id_(0) {
179  callback_runner_ = base::MessageLoop::current()->message_loop_proxy();
180  io_runner_ = BrowserThread::GetMessageLoopProxyForThread(BrowserThread::IO);
181}
182
183ServiceDiscoveryHostClient::~ServiceDiscoveryHostClient() {
184  DCHECK(service_watcher_callbacks_.empty());
185  DCHECK(service_resolver_callbacks_.empty());
186  DCHECK(domain_resolver_callbacks_.empty());
187}
188
189scoped_ptr<ServiceWatcher> ServiceDiscoveryHostClient::CreateServiceWatcher(
190    const std::string& service_type,
191    const ServiceWatcher::UpdatedCallback& callback) {
192  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
193  return scoped_ptr<ServiceWatcher>(
194      new ServiceWatcherProxy(this, service_type, callback));
195}
196
197scoped_ptr<ServiceResolver> ServiceDiscoveryHostClient::CreateServiceResolver(
198    const std::string& service_name,
199    const ServiceResolver::ResolveCompleteCallback& callback) {
200  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
201  return scoped_ptr<ServiceResolver>(
202      new ServiceResolverProxy(this, service_name, callback));
203}
204
205scoped_ptr<LocalDomainResolver>
206ServiceDiscoveryHostClient::CreateLocalDomainResolver(
207    const std::string& domain,
208    net::AddressFamily address_family,
209    const LocalDomainResolver::IPAddressCallback& callback) {
210  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
211  return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverProxy(
212      this, domain, address_family, callback));
213}
214
215uint64 ServiceDiscoveryHostClient::RegisterWatcherCallback(
216    const ServiceWatcher::UpdatedCallback& callback) {
217  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
218  DCHECK(!ContainsKey(service_watcher_callbacks_, current_id_ + 1));
219  service_watcher_callbacks_[++current_id_] = callback;
220  return current_id_;
221}
222
223uint64 ServiceDiscoveryHostClient::RegisterResolverCallback(
224    const ServiceResolver::ResolveCompleteCallback& callback) {
225  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
226  DCHECK(!ContainsKey(service_resolver_callbacks_, current_id_ + 1));
227  service_resolver_callbacks_[++current_id_] = callback;
228  return current_id_;
229}
230
231uint64 ServiceDiscoveryHostClient::RegisterLocalDomainResolverCallback(
232    const LocalDomainResolver::IPAddressCallback& callback) {
233  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
234  DCHECK(!ContainsKey(domain_resolver_callbacks_, current_id_ + 1));
235  domain_resolver_callbacks_[++current_id_] = callback;
236  return current_id_;
237}
238
239void ServiceDiscoveryHostClient::UnregisterWatcherCallback(uint64 id) {
240  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
241  service_watcher_callbacks_.erase(id);
242}
243
244void ServiceDiscoveryHostClient::UnregisterResolverCallback(uint64 id) {
245  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
246  service_resolver_callbacks_.erase(id);
247}
248
249void ServiceDiscoveryHostClient::UnregisterLocalDomainResolverCallback(
250    uint64 id) {
251  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
252  domain_resolver_callbacks_.erase(id);
253}
254
255void ServiceDiscoveryHostClient::Start(
256    const base::Closure& error_callback) {
257  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
258  DCHECK(!utility_host_);
259  DCHECK(error_callback_.is_null());
260  error_callback_ = error_callback;
261  io_runner_->PostTask(
262      FROM_HERE,
263      base::Bind(&ServiceDiscoveryHostClient::StartOnIOThread, this));
264}
265
266void ServiceDiscoveryHostClient::Shutdown() {
267  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
268  io_runner_->PostTask(
269      FROM_HERE,
270      base::Bind(&ServiceDiscoveryHostClient::ShutdownOnIOThread, this));
271}
272
273#if defined(OS_POSIX)
274
275void ServiceDiscoveryHostClient::StartOnIOThread() {
276  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
277  DCHECK(!utility_host_);
278  BrowserThread::PostTaskAndReplyWithResult(
279      BrowserThread::FILE,
280      FROM_HERE,
281      base::Bind(&GetSocketsOnFileThread),
282      base::Bind(&ServiceDiscoveryHostClient::OnSocketsReady, this));
283}
284
285void ServiceDiscoveryHostClient::OnSocketsReady(const SocketInfoList& sockets) {
286  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
287  DCHECK(!utility_host_);
288  utility_host_ = UtilityProcessHost::Create(
289      this, base::MessageLoopProxy::current().get())->AsWeakPtr();
290  if (!utility_host_)
291    return;
292  utility_host_->EnableMDns();
293  utility_host_->StartBatchMode();
294  if (sockets.empty()) {
295    ShutdownOnIOThread();
296    return;
297  }
298  utility_host_->Send(new LocalDiscoveryMsg_SetSockets(sockets));
299  // Send messages for requests made during network enumeration.
300  for (size_t i = 0; i < delayed_messages_.size(); ++i)
301    utility_host_->Send(delayed_messages_[i]);
302  delayed_messages_.weak_clear();
303}
304
305#else  // OS_POSIX
306
307void ServiceDiscoveryHostClient::StartOnIOThread() {
308  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
309  DCHECK(!utility_host_);
310  utility_host_ = UtilityProcessHost::Create(
311      this, base::MessageLoopProxy::current().get())->AsWeakPtr();
312  if (!utility_host_)
313    return;
314  utility_host_->EnableMDns();
315  utility_host_->StartBatchMode();
316  // Windows does not enumerate networks here.
317  DCHECK(delayed_messages_.empty());
318}
319
320#endif  // OS_POSIX
321
322void ServiceDiscoveryHostClient::ShutdownOnIOThread() {
323  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
324  if (utility_host_) {
325    utility_host_->Send(new LocalDiscoveryMsg_ShutdownLocalDiscovery);
326    utility_host_->EndBatchMode();
327    utility_host_.reset();
328  }
329  error_callback_ = base::Closure();
330}
331
332void ServiceDiscoveryHostClient::Send(IPC::Message* msg) {
333  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
334  io_runner_->PostTask(
335      FROM_HERE,
336      base::Bind(&ServiceDiscoveryHostClient::SendOnIOThread, this, msg));
337}
338
339void ServiceDiscoveryHostClient::SendOnIOThread(IPC::Message* msg) {
340  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
341  if (utility_host_) {
342    utility_host_->Send(msg);
343  } else {
344    delayed_messages_.push_back(msg);
345  }
346}
347
348void ServiceDiscoveryHostClient::OnProcessCrashed(int exit_code) {
349  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
350  DCHECK(!utility_host_);
351  OnError();
352}
353
354bool ServiceDiscoveryHostClient::OnMessageReceived(
355    const IPC::Message& message) {
356  bool handled = true;
357  IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryHostClient, message)
358    IPC_MESSAGE_HANDLER(LocalDiscoveryHostMsg_Error, OnError)
359    IPC_MESSAGE_HANDLER(LocalDiscoveryHostMsg_WatcherCallback,
360                        OnWatcherCallback)
361    IPC_MESSAGE_HANDLER(LocalDiscoveryHostMsg_ResolverCallback,
362                        OnResolverCallback)
363    IPC_MESSAGE_HANDLER(LocalDiscoveryHostMsg_LocalDomainResolverCallback,
364                        OnLocalDomainResolverCallback)
365    IPC_MESSAGE_UNHANDLED(handled = false)
366  IPC_END_MESSAGE_MAP()
367  return handled;
368}
369
370void ServiceDiscoveryHostClient::InvalidateWatchers() {
371  WatcherCallbacks service_watcher_callbacks;
372  service_watcher_callbacks_.swap(service_watcher_callbacks);
373  service_resolver_callbacks_.clear();
374  domain_resolver_callbacks_.clear();
375
376  for (WatcherCallbacks::iterator i = service_watcher_callbacks.begin();
377       i != service_watcher_callbacks.end(); i++) {
378    if (!i->second.is_null()) {
379      i->second.Run(ServiceWatcher::UPDATE_INVALIDATED, "");
380    }
381  }
382}
383
384void ServiceDiscoveryHostClient::OnError() {
385  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
386  if (!error_callback_.is_null())
387    callback_runner_->PostTask(FROM_HERE, error_callback_);
388}
389
390void ServiceDiscoveryHostClient::OnWatcherCallback(
391    uint64 id,
392    ServiceWatcher::UpdateType update,
393    const std::string& service_name) {
394  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
395  callback_runner_->PostTask(
396      FROM_HERE,
397      base::Bind(&ServiceDiscoveryHostClient::RunWatcherCallback, this, id,
398                 update, service_name));
399}
400
401void ServiceDiscoveryHostClient::OnResolverCallback(
402    uint64 id,
403    ServiceResolver::RequestStatus status,
404    const ServiceDescription& description) {
405  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
406  callback_runner_->PostTask(
407      FROM_HERE,
408      base::Bind(&ServiceDiscoveryHostClient::RunResolverCallback, this, id,
409                 status, description));
410}
411
412void ServiceDiscoveryHostClient::OnLocalDomainResolverCallback(
413    uint64 id,
414    bool success,
415    const net::IPAddressNumber& ip_address_ipv4,
416    const net::IPAddressNumber& ip_address_ipv6) {
417  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
418  callback_runner_->PostTask(
419      FROM_HERE,
420      base::Bind(&ServiceDiscoveryHostClient::RunLocalDomainResolverCallback,
421                 this, id, success, ip_address_ipv4, ip_address_ipv6));
422}
423
424void ServiceDiscoveryHostClient::RunWatcherCallback(
425    uint64 id,
426    ServiceWatcher::UpdateType update,
427    const std::string& service_name) {
428  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
429  WatcherCallbacks::iterator it = service_watcher_callbacks_.find(id);
430  if (it != service_watcher_callbacks_.end() && !it->second.is_null())
431    it->second.Run(update, service_name);
432}
433
434void ServiceDiscoveryHostClient::RunResolverCallback(
435    uint64 id,
436    ServiceResolver::RequestStatus status,
437    const ServiceDescription& description) {
438  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
439  ResolverCallbacks::iterator it = service_resolver_callbacks_.find(id);
440  if (it != service_resolver_callbacks_.end() && !it->second.is_null())
441    it->second.Run(status, description);
442}
443
444void ServiceDiscoveryHostClient::RunLocalDomainResolverCallback(
445    uint64 id,
446    bool success,
447    const net::IPAddressNumber& ip_address_ipv4,
448    const net::IPAddressNumber& ip_address_ipv6) {
449  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
450  DomainResolverCallbacks::iterator it = domain_resolver_callbacks_.find(id);
451  if (it != domain_resolver_callbacks_.end() && !it->second.is_null())
452    it->second.Run(success, ip_address_ipv4, ip_address_ipv6);
453}
454
455}  // namespace local_discovery
456