1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/utility/local_discovery/service_discovery_message_handler.h"
6
7#include <algorithm>
8
9#include "base/lazy_instance.h"
10#include "chrome/common/local_discovery/local_discovery_messages.h"
11#include "chrome/common/local_discovery/service_discovery_client_impl.h"
12#include "content/public/utility/utility_thread.h"
13#include "net/socket/socket_descriptor.h"
14#include "net/udp/datagram_server_socket.h"
15
16namespace local_discovery {
17
18namespace {
19
20void ClosePlatformSocket(net::SocketDescriptor socket);
21
22// Sets socket factory used by |net::CreatePlatformSocket|. Implemetation
23// keeps single socket that will be returned to the first call to
24// |net::CreatePlatformSocket| during object lifetime.
25class ScopedSocketFactory : public net::PlatformSocketFactory {
26 public:
27  explicit ScopedSocketFactory(net::SocketDescriptor socket) : socket_(socket) {
28    net::PlatformSocketFactory::SetInstance(this);
29  }
30
31  virtual ~ScopedSocketFactory() {
32    net::PlatformSocketFactory::SetInstance(NULL);
33    ClosePlatformSocket(socket_);
34    socket_ = net::kInvalidSocket;
35  }
36
37  virtual net::SocketDescriptor CreateSocket(int family, int type,
38                                             int protocol) OVERRIDE {
39    DCHECK_EQ(type, SOCK_DGRAM);
40    DCHECK(family == AF_INET || family == AF_INET6);
41    net::SocketDescriptor result = net::kInvalidSocket;
42    std::swap(result, socket_);
43    return result;
44  }
45
46 private:
47  net::SocketDescriptor socket_;
48  DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory);
49};
50
51struct SocketInfo {
52  SocketInfo(net::SocketDescriptor socket,
53             net::AddressFamily address_family,
54             uint32 interface_index)
55      : socket(socket),
56        address_family(address_family),
57        interface_index(interface_index) {
58  }
59  net::SocketDescriptor socket;
60  net::AddressFamily address_family;
61  uint32 interface_index;
62};
63
64// Returns list of sockets preallocated before.
65class PreCreatedMDnsSocketFactory : public net::MDnsSocketFactory {
66 public:
67  PreCreatedMDnsSocketFactory() {}
68  virtual ~PreCreatedMDnsSocketFactory() {
69    // Not empty if process exits too fast, before starting mDns code. If
70    // happened, destructors may crash accessing destroyed global objects.
71    sockets_.weak_clear();
72  }
73
74  // net::MDnsSocketFactory implementation:
75  virtual void CreateSockets(
76      ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE {
77    sockets->swap(sockets_);
78    Reset();
79  }
80
81  void AddSocket(const SocketInfo& socket_info) {
82    // Takes ownership of socket_info.socket;
83    ScopedSocketFactory platform_factory(socket_info.socket);
84    scoped_ptr<net::DatagramServerSocket> socket(
85        net::CreateAndBindMDnsSocket(socket_info.address_family,
86                                     socket_info.interface_index));
87    if (socket) {
88      socket->DetachFromThread();
89      sockets_.push_back(socket.release());
90    }
91  }
92
93  void Reset() {
94    sockets_.clear();
95  }
96
97 private:
98  ScopedVector<net::DatagramServerSocket> sockets_;
99
100  DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory);
101};
102
103base::LazyInstance<PreCreatedMDnsSocketFactory>
104    g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER;
105
106#if defined(OS_WIN)
107
108void ClosePlatformSocket(net::SocketDescriptor socket) {
109  ::closesocket(socket);
110}
111
112void StaticInitializeSocketFactory() {
113  net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind());
114  for (size_t i = 0; i < interfaces.size(); ++i) {
115    DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 ||
116           interfaces[i].second == net::ADDRESS_FAMILY_IPV6);
117    net::SocketDescriptor descriptor =
118        net::CreatePlatformSocket(
119            net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM,
120                                      IPPROTO_UDP);
121    g_local_discovery_socket_factory.Get().AddSocket(
122        SocketInfo(descriptor, interfaces[i].second, interfaces[i].first));
123  }
124}
125
126#else  // OS_WIN
127
128void ClosePlatformSocket(net::SocketDescriptor socket) {
129  ::close(socket);
130}
131
// POSIX: nothing to pre-create here. The sockets arrive later over IPC and are
// queued by OnSetSockets().
void StaticInitializeSocketFactory() {}
134
135#endif  // OS_WIN
136
137void SendHostMessageOnUtilityThread(IPC::Message* msg) {
138  content::UtilityThread::Get()->Send(msg);
139}
140
141std::string WatcherUpdateToString(ServiceWatcher::UpdateType update) {
142  switch (update) {
143    case ServiceWatcher::UPDATE_ADDED:
144      return "UPDATE_ADDED";
145    case ServiceWatcher::UPDATE_CHANGED:
146      return "UPDATE_CHANGED";
147    case ServiceWatcher::UPDATE_REMOVED:
148      return "UPDATE_REMOVED";
149    case ServiceWatcher::UPDATE_INVALIDATED:
150      return "UPDATE_INVALIDATED";
151  }
152  return "Unknown Update";
153}
154
155std::string ResolverStatusToString(ServiceResolver::RequestStatus status) {
156  switch (status) {
157    case ServiceResolver::STATUS_SUCCESS:
158      return "STATUS_SUCESS";
159    case ServiceResolver::STATUS_REQUEST_TIMEOUT:
160      return "STATUS_REQUEST_TIMEOUT";
161    case ServiceResolver::STATUS_KNOWN_NONEXISTENT:
162      return "STATUS_KNOWN_NONEXISTENT";
163  }
164  return "Unknown Status";
165}
166
167}  // namespace
168
169ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() {
170}
171
172ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() {
173  DCHECK(!discovery_thread_);
174}
175
176void ServiceDiscoveryMessageHandler::PreSandboxStartup() {
177  StaticInitializeSocketFactory();
178}
179
180void ServiceDiscoveryMessageHandler::InitializeMdns() {
181  if (service_discovery_client_ || mdns_client_)
182    return;
183
184  mdns_client_ = net::MDnsClient::CreateDefault();
185  bool result =
186      mdns_client_->StartListening(g_local_discovery_socket_factory.Pointer());
187  // Close unused sockets.
188  g_local_discovery_socket_factory.Get().Reset();
189  if (!result) {
190    VLOG(1) << "Failed to start MDnsClient";
191    Send(new LocalDiscoveryHostMsg_Error());
192    return;
193  }
194
195  service_discovery_client_.reset(
196      new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get()));
197}
198
199bool ServiceDiscoveryMessageHandler::InitializeThread() {
200  if (discovery_task_runner_.get())
201    return true;
202  if (discovery_thread_)
203    return false;
204  utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy();
205  discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread"));
206  base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0);
207  if (discovery_thread_->StartWithOptions(thread_options)) {
208    discovery_task_runner_ = discovery_thread_->message_loop_proxy();
209    discovery_task_runner_->PostTask(FROM_HERE,
210        base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns,
211                    base::Unretained(this)));
212  }
213  return discovery_task_runner_.get() != NULL;
214}
215
// Dispatches incoming local-discovery IPC messages to the On*() handlers in
// this file (ShutdownLocalDiscovery is handled directly).
// Returns true if |message| was handled here, false otherwise.
bool ServiceDiscoveryMessageHandler::OnMessageReceived(
    const IPC::Message& message) {
  bool handled = true;
  IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message)
#if defined(OS_POSIX)
    // Only POSIX receives pre-created sockets over IPC; Windows creates its
    // own in StaticInitializeSocketFactory().
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets, OnSetSockets)
#endif  // OS_POSIX
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices,
                        OnSetActivelyRefreshServices)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain,
                        OnResolveLocalDomain)
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver,
                        OnDestroyLocalDomainResolver)
    // Handled synchronously on this thread; it joins the discovery thread.
    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery,
                        ShutdownLocalDiscovery)
    IPC_MESSAGE_UNHANDLED(handled = false)
  IPC_END_MESSAGE_MAP()
  return handled;
}
240
241void ServiceDiscoveryMessageHandler::PostTask(
242    const tracked_objects::Location& from_here,
243    const base::Closure& task) {
244  if (!InitializeThread())
245    return;
246  discovery_task_runner_->PostTask(from_here, task);
247}
248
249#if defined(OS_POSIX)
250void ServiceDiscoveryMessageHandler::OnSetSockets(
251    const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) {
252  for (size_t i = 0; i < sockets.size(); ++i) {
253    g_local_discovery_socket_factory.Get().AddSocket(
254        SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family,
255                   sockets[i].interface_index));
256  }
257}
258#endif  // OS_POSIX
259
260void ServiceDiscoveryMessageHandler::OnStartWatcher(
261    uint64 id,
262    const std::string& service_type) {
263  PostTask(FROM_HERE,
264           base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher,
265                      base::Unretained(this), id, service_type));
266}
267
268void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id,
269                                                        bool force_update) {
270  PostTask(FROM_HERE,
271           base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices,
272                      base::Unretained(this), id, force_update));
273}
274
275void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices(
276    uint64 id, bool actively_refresh_services) {
277  PostTask(FROM_HERE,
278           base::Bind(
279               &ServiceDiscoveryMessageHandler::SetActivelyRefreshServices,
280               base::Unretained(this), id, actively_refresh_services));
281}
282
283void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) {
284  PostTask(FROM_HERE,
285           base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher,
286                      base::Unretained(this), id));
287}
288
289void ServiceDiscoveryMessageHandler::OnResolveService(
290    uint64 id,
291    const std::string& service_name) {
292  PostTask(FROM_HERE,
293           base::Bind(&ServiceDiscoveryMessageHandler::ResolveService,
294                      base::Unretained(this), id, service_name));
295}
296
297void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) {
298  PostTask(FROM_HERE,
299           base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver,
300                      base::Unretained(this), id));
301}
302
303void ServiceDiscoveryMessageHandler::OnResolveLocalDomain(
304    uint64 id, const std::string& domain,
305    net::AddressFamily address_family) {
306    PostTask(FROM_HERE,
307           base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain,
308                      base::Unretained(this), id, domain, address_family));
309}
310
311void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) {
312  PostTask(FROM_HERE,
313           base::Bind(
314               &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver,
315               base::Unretained(this), id));
316}
317
318void ServiceDiscoveryMessageHandler::StartWatcher(
319    uint64 id,
320    const std::string& service_type) {
321  VLOG(1) << "StartWatcher, id=" << id << ", type=" << service_type;
322  if (!service_discovery_client_)
323    return;
324  DCHECK(!ContainsKey(service_watchers_, id));
325  scoped_ptr<ServiceWatcher> watcher(
326      service_discovery_client_->CreateServiceWatcher(
327          service_type,
328          base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated,
329                     base::Unretained(this), id)));
330  watcher->Start();
331  service_watchers_[id].reset(watcher.release());
332}
333
334void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id,
335                                                      bool force_update) {
336  VLOG(1) << "DiscoverServices, id=" << id;
337  if (!service_discovery_client_)
338    return;
339  DCHECK(ContainsKey(service_watchers_, id));
340  service_watchers_[id]->DiscoverNewServices(force_update);
341}
342
343void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices(
344    uint64 id,
345    bool actively_refresh_services) {
346  VLOG(1) << "ActivelyRefreshServices, id=" << id;
347  if (!service_discovery_client_)
348    return;
349  DCHECK(ContainsKey(service_watchers_, id));
350  service_watchers_[id]->SetActivelyRefreshServices(actively_refresh_services);
351}
352
353void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) {
354  VLOG(1) << "DestoryWatcher, id=" << id;
355  if (!service_discovery_client_)
356    return;
357  service_watchers_.erase(id);
358}
359
360void ServiceDiscoveryMessageHandler::ResolveService(
361    uint64 id,
362    const std::string& service_name) {
363  VLOG(1) << "ResolveService, id=" << id << ", name=" << service_name;
364  if (!service_discovery_client_)
365    return;
366  DCHECK(!ContainsKey(service_resolvers_, id));
367  scoped_ptr<ServiceResolver> resolver(
368      service_discovery_client_->CreateServiceResolver(
369          service_name,
370          base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved,
371                     base::Unretained(this), id)));
372  resolver->StartResolving();
373  service_resolvers_[id].reset(resolver.release());
374}
375
376void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) {
377  VLOG(1) << "DestroyResolver, id=" << id;
378  if (!service_discovery_client_)
379    return;
380  service_resolvers_.erase(id);
381}
382
383void ServiceDiscoveryMessageHandler::ResolveLocalDomain(
384    uint64 id,
385    const std::string& domain,
386    net::AddressFamily address_family) {
387  VLOG(1) << "ResolveLocalDomain, id=" << id << ", domain=" << domain;
388  if (!service_discovery_client_)
389    return;
390  DCHECK(!ContainsKey(local_domain_resolvers_, id));
391  scoped_ptr<LocalDomainResolver> resolver(
392      service_discovery_client_->CreateLocalDomainResolver(
393          domain, address_family,
394          base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved,
395                     base::Unretained(this), id)));
396  resolver->Start();
397  local_domain_resolvers_[id].reset(resolver.release());
398}
399
400void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) {
401  VLOG(1) << "DestroyLocalDomainResolver, id=" << id;
402  if (!service_discovery_client_)
403    return;
404  local_domain_resolvers_.erase(id);
405}
406
407void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() {
408  if (!discovery_task_runner_.get())
409    return;
410
411  discovery_task_runner_->PostTask(
412      FROM_HERE,
413      base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread,
414                 base::Unretained(this)));
415
416  // This will wait for message loop to drain, so ShutdownOnIOThread will
417  // definitely be called.
418  discovery_thread_.reset();
419}
420
// Posted to the discovery thread by ShutdownLocalDiscovery(). Destroys all
// discovery state there.
void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() {
  VLOG(1) << "ShutdownLocalDiscovery";
  // Destroy the objects created through |service_discovery_client_| first,
  // then the discovery client, and finally the mDNS client it was built on
  // (ServiceDiscoveryClientImpl holds mdns_client_.get()).
  service_watchers_.clear();
  service_resolvers_.clear();
  local_domain_resolvers_.clear();
  service_discovery_client_.reset();
  mdns_client_.reset();
}
429
430void ServiceDiscoveryMessageHandler::OnServiceUpdated(
431    uint64 id,
432    ServiceWatcher::UpdateType update,
433    const std::string& name) {
434  VLOG(1) << "OnServiceUpdated, id=" << id
435          << ", status=" << WatcherUpdateToString(update) << ", name=" << name;
436  DCHECK(service_discovery_client_);
437
438  Send(new LocalDiscoveryHostMsg_WatcherCallback(id, update, name));
439}
440
441void ServiceDiscoveryMessageHandler::OnServiceResolved(
442    uint64 id,
443    ServiceResolver::RequestStatus status,
444    const ServiceDescription& description) {
445  VLOG(1) << "OnServiceResolved, id=" << id
446          << ", status=" << ResolverStatusToString(status)
447          << ", name=" << description.service_name;
448
449  DCHECK(service_discovery_client_);
450  Send(new LocalDiscoveryHostMsg_ResolverCallback(id, status, description));
451}
452
453void ServiceDiscoveryMessageHandler::OnLocalDomainResolved(
454    uint64 id,
455    bool success,
456    const net::IPAddressNumber& address_ipv4,
457    const net::IPAddressNumber& address_ipv6) {
458  VLOG(1) << "OnLocalDomainResolved, id=" << id
459          << ", IPv4=" << (address_ipv4.empty() ? "" :
460                           net::IPAddressToString(address_ipv4))
461          << ", IPv6=" << (address_ipv6.empty() ? "" :
462                           net::IPAddressToString(address_ipv6));
463
464  DCHECK(service_discovery_client_);
465  Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback(
466          id, success, address_ipv4, address_ipv6));
467}
468
469void ServiceDiscoveryMessageHandler::Send(IPC::Message* msg) {
470  utility_task_runner_->PostTask(FROM_HERE,
471                                 base::Bind(&SendHostMessageOnUtilityThread,
472                                            msg));
473}
474
475}  // namespace local_discovery
476