service_discovery_message_handler.cc revision 3551c9c881056c480085172ff9840cab31610854
1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/utility/local_discovery/service_discovery_message_handler.h"
6
7#include <algorithm>
8
9#include "base/command_line.h"
10#include "chrome/common/local_discovery/local_discovery_messages.h"
11#include "chrome/utility/local_discovery/service_discovery_client_impl.h"
12#include "content/public/common/content_switches.h"
13#include "content/public/utility/utility_thread.h"
14
15#if defined(OS_WIN)
16
17#include "base/lazy_instance.h"
18#include "net/base/winsock_init.h"
19#include "net/base/winsock_util.h"
20
21#endif  // OS_WIN
22
23namespace local_discovery {
24
25namespace {
26
27bool NeedsSockets() {
28  return !CommandLine::ForCurrentProcess()->HasSwitch(switches::kNoSandbox) &&
29         CommandLine::ForCurrentProcess()->HasSwitch(
30             switches::kUtilityProcessEnableMDns);
31}
32
33#if defined(OS_WIN)
34
35class SocketFactory : public net::PlatformSocketFactory {
36 public:
37  SocketFactory()
38      : socket_v4_(NULL),
39        socket_v6_(NULL) {
40    net::EnsureWinsockInit();
41    socket_v4_ = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0,
42                           WSA_FLAG_OVERLAPPED);
43    socket_v6_ = WSASocket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP, NULL, 0,
44                           WSA_FLAG_OVERLAPPED);
45  }
46
47  void Reset() {
48    if (socket_v4_ != INVALID_SOCKET) {
49      closesocket(socket_v4_);
50      socket_v4_ = INVALID_SOCKET;
51    }
52    if (socket_v6_ != INVALID_SOCKET) {
53      closesocket(socket_v6_);
54      socket_v6_ = INVALID_SOCKET;
55    }
56  }
57
58  virtual ~SocketFactory() {
59    Reset();
60  }
61
62  virtual SOCKET CreateSocket(int family, int type, int protocol) OVERRIDE {
63    SOCKET result = INVALID_SOCKET;
64    if (type != SOCK_DGRAM && protocol != IPPROTO_UDP) {
65      NOTREACHED();
66    } else if (family == AF_INET) {
67      std::swap(result, socket_v4_);
68    } else if (family == AF_INET6) {
69      std::swap(result, socket_v6_);
70    }
71    return result;
72  }
73
74  SOCKET socket_v4_;
75  SOCKET socket_v6_;
76
77 private:
78  DISALLOW_COPY_AND_ASSIGN(SocketFactory);
79};
80
81base::LazyInstance<SocketFactory>
82    g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER;
83
84class ScopedSocketFactorySetter {
85 public:
86  ScopedSocketFactorySetter() {
87    if (NeedsSockets()) {
88      net::PlatformSocketFactory::SetInstance(
89          &g_local_discovery_socket_factory.Get());
90    }
91  }
92
93  ~ScopedSocketFactorySetter() {
94    if (NeedsSockets()) {
95      net::PlatformSocketFactory::SetInstance(NULL);
96      g_local_discovery_socket_factory.Get().Reset();
97    }
98  }
99
100  static void Initialize() {
101    if (NeedsSockets()) {
102      g_local_discovery_socket_factory.Get();
103    }
104  }
105
106 private:
107  DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactorySetter);
108};
109
110#else  // OS_WIN
111
112class ScopedSocketFactorySetter {
113 public:
114  ScopedSocketFactorySetter() {}
115
116  static void Initialize() {
117    // TODO(vitalybuka) : implement socket access from sandbox for other
118    // platforms.
119    DCHECK(!NeedsSockets());
120  }
121};
122
123#endif  // OS_WIN
124
125void SendServiceResolved(uint64 id, ServiceResolver::RequestStatus status,
126                         const ServiceDescription& description) {
127  content::UtilityThread::Get()->Send(
128      new LocalDiscoveryHostMsg_ResolverCallback(id, status, description));
129}
130
131void SendServiceUpdated(uint64 id, ServiceWatcher::UpdateType update,
132                        const std::string& name) {
133  content::UtilityThread::Get()->Send(
134      new LocalDiscoveryHostMsg_WatcherCallback(id, update, name));
135}
136
137void SendLocalDomainResolved(uint64 id, bool success,
138                             const net::IPAddressNumber& address) {
139  content::UtilityThread::Get()->Send(
140      new LocalDiscoveryHostMsg_LocalDomainResolverCallback(
141          id, success, address));
142}
143
144}  // namespace
145
146ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() {
147}
148
149ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() {
150  DCHECK(!discovery_thread_);
151}
152
153void ServiceDiscoveryMessageHandler::PreSandboxStartup() {
154  ScopedSocketFactorySetter::Initialize();
155}
156
157void ServiceDiscoveryMessageHandler::InitializeMdns() {
158  if (service_discovery_client_ || mdns_client_)
159    return;
160
161  mdns_client_ = net::MDnsClient::CreateDefault();
162  {
163    // Temporarily redirect network code to use pre-created sockets.
164    ScopedSocketFactorySetter socket_factory_setter;
165    if (!mdns_client_->StartListening())
166      return;
167  }
168
169  service_discovery_client_.reset(
170      new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get()));
171}
172
173bool ServiceDiscoveryMessageHandler::InitializeThread() {
174  if (discovery_task_runner_)
175    return true;
176  if (discovery_thread_)
177    return false;
178  utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy();
179  discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread"));
180  base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0);
181  if (discovery_thread_->StartWithOptions(thread_options)) {
182    discovery_task_runner_ = discovery_thread_->message_loop_proxy();
183    discovery_task_runner_->PostTask(FROM_HERE,
184        base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns,
185                    base::Unretained(this)));
186  }
187  return discovery_task_runner_ != NULL;
188}
189
190bool ServiceDiscoveryMessageHandler::OnMessageReceived(
191    const IPC::Message& message) {
192  bool handled = true;
193  IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message)
194    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher)
195    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices)
196    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher)
197    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService)
198    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver)
199    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain,
200                        OnResolveLocalDomain)
201    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver,
202                        OnDestroyLocalDomainResolver)
203    IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery,
204                        ShutdownLocalDiscovery)
205    IPC_MESSAGE_UNHANDLED(handled = false)
206  IPC_END_MESSAGE_MAP()
207  return handled;
208}
209
210void ServiceDiscoveryMessageHandler::PostTask(
211    const tracked_objects::Location& from_here,
212    const base::Closure& task) {
213  if (!InitializeThread())
214    return;
215  discovery_task_runner_->PostTask(from_here, task);
216}
217
218void ServiceDiscoveryMessageHandler::OnStartWatcher(
219    uint64 id,
220    const std::string& service_type) {
221  PostTask(FROM_HERE,
222           base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher,
223                      base::Unretained(this), id, service_type));
224}
225
226void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id,
227                                                        bool force_update) {
228  PostTask(FROM_HERE,
229           base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices,
230                      base::Unretained(this), id, force_update));
231}
232
233void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) {
234  PostTask(FROM_HERE,
235           base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher,
236                      base::Unretained(this), id));
237}
238
239void ServiceDiscoveryMessageHandler::OnResolveService(
240    uint64 id,
241    const std::string& service_name) {
242  PostTask(FROM_HERE,
243           base::Bind(&ServiceDiscoveryMessageHandler::ResolveService,
244                      base::Unretained(this), id, service_name));
245}
246
247void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) {
248  PostTask(FROM_HERE,
249           base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver,
250                      base::Unretained(this), id));
251}
252
253void ServiceDiscoveryMessageHandler::OnResolveLocalDomain(
254    uint64 id, const std::string& domain,
255    net::AddressFamily address_family) {
256    PostTask(FROM_HERE,
257           base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain,
258                      base::Unretained(this), id, domain, address_family));
259}
260
261void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) {
262  PostTask(FROM_HERE,
263           base::Bind(
264               &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver,
265               base::Unretained(this), id));
266}
267
268void ServiceDiscoveryMessageHandler::StartWatcher(
269    uint64 id,
270    const std::string& service_type) {
271  if (!service_discovery_client_)
272    return;
273  DCHECK(!ContainsKey(service_watchers_, id));
274  scoped_ptr<ServiceWatcher> watcher(
275      service_discovery_client_->CreateServiceWatcher(
276          service_type,
277          base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated,
278                     base::Unretained(this), id)));
279  watcher->Start();
280  service_watchers_[id].reset(watcher.release());
281}
282
283void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id,
284                                                      bool force_update) {
285  if (!service_discovery_client_)
286    return;
287  DCHECK(ContainsKey(service_watchers_, id));
288  service_watchers_[id]->DiscoverNewServices(force_update);
289}
290
291void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) {
292  if (!service_discovery_client_)
293    return;
294  DCHECK(ContainsKey(service_watchers_, id));
295  service_watchers_.erase(id);
296}
297
298void ServiceDiscoveryMessageHandler::ResolveService(
299    uint64 id,
300    const std::string& service_name) {
301  if (!service_discovery_client_)
302    return;
303  DCHECK(!ContainsKey(service_resolvers_, id));
304  scoped_ptr<ServiceResolver> resolver(
305      service_discovery_client_->CreateServiceResolver(
306          service_name,
307          base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved,
308                     base::Unretained(this), id)));
309  resolver->StartResolving();
310  service_resolvers_[id].reset(resolver.release());
311}
312
313void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) {
314  if (!service_discovery_client_)
315    return;
316  DCHECK(ContainsKey(service_resolvers_, id));
317  service_resolvers_.erase(id);
318}
319
320void ServiceDiscoveryMessageHandler::ResolveLocalDomain(
321    uint64 id,
322    const std::string& domain,
323    net::AddressFamily address_family) {
324  if (!service_discovery_client_)
325    return;
326  DCHECK(!ContainsKey(local_domain_resolvers_, id));
327  scoped_ptr<LocalDomainResolver> resolver(
328      service_discovery_client_->CreateLocalDomainResolver(
329          domain, address_family,
330          base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved,
331                     base::Unretained(this), id)));
332  resolver->Start();
333  local_domain_resolvers_[id].reset(resolver.release());
334}
335
336void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) {
337  if (!service_discovery_client_)
338    return;
339  DCHECK(ContainsKey(local_domain_resolvers_, id));
340  local_domain_resolvers_.erase(id);
341}
342
343void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() {
344  discovery_task_runner_->PostTask(
345      FROM_HERE,
346      base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread,
347                 base::Unretained(this)));
348
349  // This will wait for message loop to drain, so ShutdownOnIOThread will
350  // definitely be called.
351  discovery_thread_.reset();
352}
353
354void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() {
355  service_watchers_.clear();
356  service_resolvers_.clear();
357  local_domain_resolvers_.clear();
358
359  service_discovery_client_.reset();
360  mdns_client_.reset();
361}
362
363void ServiceDiscoveryMessageHandler::OnServiceUpdated(
364    uint64 id,
365    ServiceWatcher::UpdateType update,
366    const std::string& name) {
367  DCHECK(service_discovery_client_);
368  utility_task_runner_->PostTask(FROM_HERE,
369      base::Bind(&SendServiceUpdated, id, update, name));
370}
371
372void ServiceDiscoveryMessageHandler::OnServiceResolved(
373    uint64 id,
374    ServiceResolver::RequestStatus status,
375    const ServiceDescription& description) {
376  DCHECK(service_discovery_client_);
377  utility_task_runner_->PostTask(FROM_HERE,
378      base::Bind(&SendServiceResolved, id, status, description));
379}
380
381void ServiceDiscoveryMessageHandler::OnLocalDomainResolved(
382    uint64 id,
383    bool success,
384    const net::IPAddressNumber& address) {
385  DCHECK(service_discovery_client_);
386  utility_task_runner_->PostTask(FROM_HERE, base::Bind(&SendLocalDomainResolved,
387                                                       id, success, address));
388}
389
390
391}  // namespace local_discovery
392