service_discovery_message_handler.cc revision 3551c9c881056c480085172ff9840cab31610854
1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "chrome/utility/local_discovery/service_discovery_message_handler.h" 6 7#include <algorithm> 8 9#include "base/command_line.h" 10#include "chrome/common/local_discovery/local_discovery_messages.h" 11#include "chrome/utility/local_discovery/service_discovery_client_impl.h" 12#include "content/public/common/content_switches.h" 13#include "content/public/utility/utility_thread.h" 14 15#if defined(OS_WIN) 16 17#include "base/lazy_instance.h" 18#include "net/base/winsock_init.h" 19#include "net/base/winsock_util.h" 20 21#endif // OS_WIN 22 23namespace local_discovery { 24 25namespace { 26 27bool NeedsSockets() { 28 return !CommandLine::ForCurrentProcess()->HasSwitch(switches::kNoSandbox) && 29 CommandLine::ForCurrentProcess()->HasSwitch( 30 switches::kUtilityProcessEnableMDns); 31} 32 33#if defined(OS_WIN) 34 35class SocketFactory : public net::PlatformSocketFactory { 36 public: 37 SocketFactory() 38 : socket_v4_(NULL), 39 socket_v6_(NULL) { 40 net::EnsureWinsockInit(); 41 socket_v4_ = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, 42 WSA_FLAG_OVERLAPPED); 43 socket_v6_ = WSASocket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, 44 WSA_FLAG_OVERLAPPED); 45 } 46 47 void Reset() { 48 if (socket_v4_ != INVALID_SOCKET) { 49 closesocket(socket_v4_); 50 socket_v4_ = INVALID_SOCKET; 51 } 52 if (socket_v6_ != INVALID_SOCKET) { 53 closesocket(socket_v6_); 54 socket_v6_ = INVALID_SOCKET; 55 } 56 } 57 58 virtual ~SocketFactory() { 59 Reset(); 60 } 61 62 virtual SOCKET CreateSocket(int family, int type, int protocol) OVERRIDE { 63 SOCKET result = INVALID_SOCKET; 64 if (type != SOCK_DGRAM && protocol != IPPROTO_UDP) { 65 NOTREACHED(); 66 } else if (family == AF_INET) { 67 std::swap(result, socket_v4_); 68 } else if (family == AF_INET6) { 69 std::swap(result, socket_v6_); 70 } 71 return result; 72 } 73 74 SOCKET socket_v4_; 75 SOCKET socket_v6_; 76 77 private: 78 DISALLOW_COPY_AND_ASSIGN(SocketFactory); 79}; 80 81base::LazyInstance<SocketFactory> 82 g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER; 83 84class ScopedSocketFactorySetter { 85 public: 86 ScopedSocketFactorySetter() { 87 if (NeedsSockets()) { 88 net::PlatformSocketFactory::SetInstance( 89 &g_local_discovery_socket_factory.Get()); 90 } 91 } 92 93 ~ScopedSocketFactorySetter() { 94 if (NeedsSockets()) { 95 net::PlatformSocketFactory::SetInstance(NULL); 96 g_local_discovery_socket_factory.Get().Reset(); 97 } 98 } 99 100 static void Initialize() { 101 if (NeedsSockets()) { 102 g_local_discovery_socket_factory.Get(); 103 } 104 } 105 106 private: 107 DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactorySetter); 108}; 109 110#else // OS_WIN 111 112class ScopedSocketFactorySetter { 113 public: 114 ScopedSocketFactorySetter() {} 115 116 static void Initialize() { 117 // TODO(vitalybuka) : implement socket access from sandbox for other 118 // platforms. 119 DCHECK(!NeedsSockets()); 120 } 121}; 122 123#endif // OS_WIN 124 125void SendServiceResolved(uint64 id, ServiceResolver::RequestStatus status, 126 const ServiceDescription& description) { 127 content::UtilityThread::Get()->Send( 128 new LocalDiscoveryHostMsg_ResolverCallback(id, status, description)); 129} 130 131void SendServiceUpdated(uint64 id, ServiceWatcher::UpdateType update, 132 const std::string& name) { 133 content::UtilityThread::Get()->Send( 134 new LocalDiscoveryHostMsg_WatcherCallback(id, update, name)); 135} 136 137void SendLocalDomainResolved(uint64 id, bool success, 138 const net::IPAddressNumber& address) { 139 content::UtilityThread::Get()->Send( 140 new LocalDiscoveryHostMsg_LocalDomainResolverCallback( 141 id, success, address)); 142} 143 144} // namespace 145 146ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() { 147} 148 149ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() { 150 DCHECK(!discovery_thread_); 151} 152 153void ServiceDiscoveryMessageHandler::PreSandboxStartup() { 154 ScopedSocketFactorySetter::Initialize(); 155} 156 157void ServiceDiscoveryMessageHandler::InitializeMdns() { 158 if (service_discovery_client_ || mdns_client_) 159 return; 160 161 mdns_client_ = net::MDnsClient::CreateDefault(); 162 { 163 // Temporarily redirect network code to use pre-created sockets. 164 ScopedSocketFactorySetter socket_factory_setter; 165 if (!mdns_client_->StartListening()) 166 return; 167 } 168 169 service_discovery_client_.reset( 170 new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get())); 171} 172 173bool ServiceDiscoveryMessageHandler::InitializeThread() { 174 if (discovery_task_runner_) 175 return true; 176 if (discovery_thread_) 177 return false; 178 utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy(); 179 discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread")); 180 base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0); 181 if (discovery_thread_->StartWithOptions(thread_options)) { 182 discovery_task_runner_ = discovery_thread_->message_loop_proxy(); 183 discovery_task_runner_->PostTask(FROM_HERE, 184 base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns, 185 base::Unretained(this))); 186 } 187 return discovery_task_runner_ != NULL; 188} 189 190bool ServiceDiscoveryMessageHandler::OnMessageReceived( 191 const IPC::Message& message) { 192 bool handled = true; 193 IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message) 194 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher) 195 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices) 196 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher) 197 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService) 198 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver) 199 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain, 200 OnResolveLocalDomain) 201 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver, 202 OnDestroyLocalDomainResolver) 203 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery, 204 ShutdownLocalDiscovery) 205 IPC_MESSAGE_UNHANDLED(handled = false) 206 IPC_END_MESSAGE_MAP() 207 return handled; 208} 209 210void ServiceDiscoveryMessageHandler::PostTask( 211 const tracked_objects::Location& from_here, 212 const base::Closure& task) { 213 if (!InitializeThread()) 214 return; 215 discovery_task_runner_->PostTask(from_here, task); 216} 217 218void ServiceDiscoveryMessageHandler::OnStartWatcher( 219 uint64 id, 220 const std::string& service_type) { 221 PostTask(FROM_HERE, 222 base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher, 223 base::Unretained(this), id, service_type)); 224} 225 226void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id, 227 bool force_update) { 228 PostTask(FROM_HERE, 229 base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices, 230 base::Unretained(this), id, force_update)); 231} 232 233void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) { 234 PostTask(FROM_HERE, 235 base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher, 236 base::Unretained(this), id)); 237} 238 239void ServiceDiscoveryMessageHandler::OnResolveService( 240 uint64 id, 241 const std::string& service_name) { 242 PostTask(FROM_HERE, 243 base::Bind(&ServiceDiscoveryMessageHandler::ResolveService, 244 base::Unretained(this), id, service_name)); 245} 246 247void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) { 248 PostTask(FROM_HERE, 249 base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver, 250 base::Unretained(this), id)); 251} 252 253void ServiceDiscoveryMessageHandler::OnResolveLocalDomain( 254 uint64 id, const std::string& domain, 255 net::AddressFamily address_family) { 256 PostTask(FROM_HERE, 257 base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain, 258 base::Unretained(this), id, domain, address_family)); 259} 260 261void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) { 262 PostTask(FROM_HERE, 263 base::Bind( 264 &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver, 265 base::Unretained(this), id)); 266} 267 268void ServiceDiscoveryMessageHandler::StartWatcher( 269 uint64 id, 270 const std::string& service_type) { 271 if (!service_discovery_client_) 272 return; 273 DCHECK(!ContainsKey(service_watchers_, id)); 274 scoped_ptr<ServiceWatcher> watcher( 275 service_discovery_client_->CreateServiceWatcher( 276 service_type, 277 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated, 278 base::Unretained(this), id))); 279 watcher->Start(); 280 service_watchers_[id].reset(watcher.release()); 281} 282 283void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id, 284 bool force_update) { 285 if (!service_discovery_client_) 286 return; 287 DCHECK(ContainsKey(service_watchers_, id)); 288 service_watchers_[id]->DiscoverNewServices(force_update); 289} 290 291void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) { 292 if (!service_discovery_client_) 293 return; 294 DCHECK(ContainsKey(service_watchers_, id)); 295 service_watchers_.erase(id); 296} 297 298void ServiceDiscoveryMessageHandler::ResolveService( 299 uint64 id, 300 const std::string& service_name) { 301 if (!service_discovery_client_) 302 return; 303 DCHECK(!ContainsKey(service_resolvers_, id)); 304 scoped_ptr<ServiceResolver> resolver( 305 service_discovery_client_->CreateServiceResolver( 306 service_name, 307 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved, 308 base::Unretained(this), id))); 309 resolver->StartResolving(); 310 service_resolvers_[id].reset(resolver.release()); 311} 312 313void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) { 314 if (!service_discovery_client_) 315 return; 316 DCHECK(ContainsKey(service_resolvers_, id)); 317 service_resolvers_.erase(id); 318} 319 320void ServiceDiscoveryMessageHandler::ResolveLocalDomain( 321 uint64 id, 322 const std::string& domain, 323 net::AddressFamily address_family) { 324 if (!service_discovery_client_) 325 return; 326 DCHECK(!ContainsKey(local_domain_resolvers_, id)); 327 scoped_ptr<LocalDomainResolver> resolver( 328 service_discovery_client_->CreateLocalDomainResolver( 329 domain, address_family, 330 base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved, 331 base::Unretained(this), id))); 332 resolver->Start(); 333 local_domain_resolvers_[id].reset(resolver.release()); 334} 335 336void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) { 337 if (!service_discovery_client_) 338 return; 339 DCHECK(ContainsKey(local_domain_resolvers_, id)); 340 local_domain_resolvers_.erase(id); 341} 342 343void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() { 344 discovery_task_runner_->PostTask( 345 FROM_HERE, 346 base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread, 347 base::Unretained(this))); 348 349 // This will wait for message loop to drain, so ShutdownOnIOThread will 350 // definitely be called. 351 discovery_thread_.reset(); 352} 353 354void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() { 355 service_watchers_.clear(); 356 service_resolvers_.clear(); 357 local_domain_resolvers_.clear(); 358 359 service_discovery_client_.reset(); 360 mdns_client_.reset(); 361} 362 363void ServiceDiscoveryMessageHandler::OnServiceUpdated( 364 uint64 id, 365 ServiceWatcher::UpdateType update, 366 const std::string& name) { 367 DCHECK(service_discovery_client_); 368 utility_task_runner_->PostTask(FROM_HERE, 369 base::Bind(&SendServiceUpdated, id, update, name)); 370} 371 372void ServiceDiscoveryMessageHandler::OnServiceResolved( 373 uint64 id, 374 ServiceResolver::RequestStatus status, 375 const ServiceDescription& description) { 376 DCHECK(service_discovery_client_); 377 utility_task_runner_->PostTask(FROM_HERE, 378 base::Bind(&SendServiceResolved, id, status, description)); 379} 380 381void ServiceDiscoveryMessageHandler::OnLocalDomainResolved( 382 uint64 id, 383 bool success, 384 const net::IPAddressNumber& address) { 385 DCHECK(service_discovery_client_); 386 utility_task_runner_->PostTask(FROM_HERE, base::Bind(&SendLocalDomainResolved, 387 id, success, address)); 388} 389 390 391} // namespace local_discovery 392