1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "chrome/utility/local_discovery/service_discovery_message_handler.h" 6 7#include <algorithm> 8 9#include "base/lazy_instance.h" 10#include "chrome/common/local_discovery/local_discovery_messages.h" 11#include "chrome/common/local_discovery/service_discovery_client_impl.h" 12#include "content/public/utility/utility_thread.h" 13#include "net/socket/socket_descriptor.h" 14#include "net/udp/datagram_server_socket.h" 15 16namespace local_discovery { 17 18namespace { 19 20void ClosePlatformSocket(net::SocketDescriptor socket); 21 22// Sets socket factory used by |net::CreatePlatformSocket|. Implemetation 23// keeps single socket that will be returned to the first call to 24// |net::CreatePlatformSocket| during object lifetime. 25class ScopedSocketFactory : public net::PlatformSocketFactory { 26 public: 27 explicit ScopedSocketFactory(net::SocketDescriptor socket) : socket_(socket) { 28 net::PlatformSocketFactory::SetInstance(this); 29 } 30 31 virtual ~ScopedSocketFactory() { 32 net::PlatformSocketFactory::SetInstance(NULL); 33 ClosePlatformSocket(socket_); 34 socket_ = net::kInvalidSocket; 35 } 36 37 virtual net::SocketDescriptor CreateSocket(int family, int type, 38 int protocol) OVERRIDE { 39 DCHECK_EQ(type, SOCK_DGRAM); 40 DCHECK(family == AF_INET || family == AF_INET6); 41 net::SocketDescriptor result = net::kInvalidSocket; 42 std::swap(result, socket_); 43 return result; 44 } 45 46 private: 47 net::SocketDescriptor socket_; 48 DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory); 49}; 50 51struct SocketInfo { 52 SocketInfo(net::SocketDescriptor socket, 53 net::AddressFamily address_family, 54 uint32 interface_index) 55 : socket(socket), 56 address_family(address_family), 57 interface_index(interface_index) { 58 } 59 net::SocketDescriptor socket; 60 net::AddressFamily address_family; 61 uint32 interface_index; 62}; 63 64// Returns list of sockets preallocated before. 65class PreCreatedMDnsSocketFactory : public net::MDnsSocketFactory { 66 public: 67 PreCreatedMDnsSocketFactory() {} 68 virtual ~PreCreatedMDnsSocketFactory() { 69 // Not empty if process exits too fast, before starting mDns code. If 70 // happened, destructors may crash accessing destroyed global objects. 71 sockets_.weak_clear(); 72 } 73 74 // net::MDnsSocketFactory implementation: 75 virtual void CreateSockets( 76 ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE { 77 sockets->swap(sockets_); 78 Reset(); 79 } 80 81 void AddSocket(const SocketInfo& socket_info) { 82 // Takes ownership of socket_info.socket; 83 ScopedSocketFactory platform_factory(socket_info.socket); 84 scoped_ptr<net::DatagramServerSocket> socket( 85 net::CreateAndBindMDnsSocket(socket_info.address_family, 86 socket_info.interface_index)); 87 if (socket) { 88 socket->DetachFromThread(); 89 sockets_.push_back(socket.release()); 90 } 91 } 92 93 void Reset() { 94 sockets_.clear(); 95 } 96 97 private: 98 ScopedVector<net::DatagramServerSocket> sockets_; 99 100 DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory); 101}; 102 103base::LazyInstance<PreCreatedMDnsSocketFactory> 104 g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER; 105 106#if defined(OS_WIN) 107 108void ClosePlatformSocket(net::SocketDescriptor socket) { 109 ::closesocket(socket); 110} 111 112void StaticInitializeSocketFactory() { 113 net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind()); 114 for (size_t i = 0; i < interfaces.size(); ++i) { 115 DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 || 116 interfaces[i].second == net::ADDRESS_FAMILY_IPV6); 117 net::SocketDescriptor descriptor = 118 net::CreatePlatformSocket( 119 net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM, 120 IPPROTO_UDP); 121 g_local_discovery_socket_factory.Get().AddSocket( 122 SocketInfo(descriptor, interfaces[i].second, interfaces[i].first)); 123 } 124} 125 126#else // OS_WIN 127 128void ClosePlatformSocket(net::SocketDescriptor socket) { 129 ::close(socket); 130} 131 132void StaticInitializeSocketFactory() { 133} 134 135#endif // OS_WIN 136 137void SendHostMessageOnUtilityThread(IPC::Message* msg) { 138 content::UtilityThread::Get()->Send(msg); 139} 140 141std::string WatcherUpdateToString(ServiceWatcher::UpdateType update) { 142 switch (update) { 143 case ServiceWatcher::UPDATE_ADDED: 144 return "UPDATE_ADDED"; 145 case ServiceWatcher::UPDATE_CHANGED: 146 return "UPDATE_CHANGED"; 147 case ServiceWatcher::UPDATE_REMOVED: 148 return "UPDATE_REMOVED"; 149 case ServiceWatcher::UPDATE_INVALIDATED: 150 return "UPDATE_INVALIDATED"; 151 } 152 return "Unknown Update"; 153} 154 155std::string ResolverStatusToString(ServiceResolver::RequestStatus status) { 156 switch (status) { 157 case ServiceResolver::STATUS_SUCCESS: 158 return "STATUS_SUCESS"; 159 case ServiceResolver::STATUS_REQUEST_TIMEOUT: 160 return "STATUS_REQUEST_TIMEOUT"; 161 case ServiceResolver::STATUS_KNOWN_NONEXISTENT: 162 return "STATUS_KNOWN_NONEXISTENT"; 163 } 164 return "Unknown Status"; 165} 166 167} // namespace 168 169ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() { 170} 171 172ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() { 173 DCHECK(!discovery_thread_); 174} 175 176void ServiceDiscoveryMessageHandler::PreSandboxStartup() { 177 StaticInitializeSocketFactory(); 178} 179 180void ServiceDiscoveryMessageHandler::InitializeMdns() { 181 if (service_discovery_client_ || mdns_client_) 182 return; 183 184 mdns_client_ = net::MDnsClient::CreateDefault(); 185 bool result = 186 mdns_client_->StartListening(g_local_discovery_socket_factory.Pointer()); 187 // Close unused sockets. 188 g_local_discovery_socket_factory.Get().Reset(); 189 if (!result) { 190 VLOG(1) << "Failed to start MDnsClient"; 191 Send(new LocalDiscoveryHostMsg_Error()); 192 return; 193 } 194 195 service_discovery_client_.reset( 196 new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get())); 197} 198 199bool ServiceDiscoveryMessageHandler::InitializeThread() { 200 if (discovery_task_runner_.get()) 201 return true; 202 if (discovery_thread_) 203 return false; 204 utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy(); 205 discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread")); 206 base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0); 207 if (discovery_thread_->StartWithOptions(thread_options)) { 208 discovery_task_runner_ = discovery_thread_->message_loop_proxy(); 209 discovery_task_runner_->PostTask(FROM_HERE, 210 base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns, 211 base::Unretained(this))); 212 } 213 return discovery_task_runner_.get() != NULL; 214} 215 216bool ServiceDiscoveryMessageHandler::OnMessageReceived( 217 const IPC::Message& message) { 218 bool handled = true; 219 IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message) 220#if defined(OS_POSIX) 221 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets, OnSetSockets) 222#endif // OS_POSIX 223 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher) 224 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices) 225 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices, 226 OnSetActivelyRefreshServices) 227 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher) 228 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService) 229 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver) 230 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain, 231 OnResolveLocalDomain) 232 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver, 233 OnDestroyLocalDomainResolver) 234 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery, 235 ShutdownLocalDiscovery) 236 IPC_MESSAGE_UNHANDLED(handled = false) 237 IPC_END_MESSAGE_MAP() 238 return handled; 239} 240 241void ServiceDiscoveryMessageHandler::PostTask( 242 const tracked_objects::Location& from_here, 243 const base::Closure& task) { 244 if (!InitializeThread()) 245 return; 246 discovery_task_runner_->PostTask(from_here, task); 247} 248 249#if defined(OS_POSIX) 250void ServiceDiscoveryMessageHandler::OnSetSockets( 251 const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) { 252 for (size_t i = 0; i < sockets.size(); ++i) { 253 g_local_discovery_socket_factory.Get().AddSocket( 254 SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family, 255 sockets[i].interface_index)); 256 } 257} 258#endif // OS_POSIX 259 260void ServiceDiscoveryMessageHandler::OnStartWatcher( 261 uint64 id, 262 const std::string& service_type) { 263 PostTask(FROM_HERE, 264 base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher, 265 base::Unretained(this), id, service_type)); 266} 267 268void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id, 269 bool force_update) { 270 PostTask(FROM_HERE, 271 base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices, 272 base::Unretained(this), id, force_update)); 273} 274 275void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices( 276 uint64 id, bool actively_refresh_services) { 277 PostTask(FROM_HERE, 278 base::Bind( 279 &ServiceDiscoveryMessageHandler::SetActivelyRefreshServices, 280 base::Unretained(this), id, actively_refresh_services)); 281} 282 283void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) { 284 PostTask(FROM_HERE, 285 base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher, 286 base::Unretained(this), id)); 287} 288 289void ServiceDiscoveryMessageHandler::OnResolveService( 290 uint64 id, 291 const std::string& service_name) { 292 PostTask(FROM_HERE, 293 base::Bind(&ServiceDiscoveryMessageHandler::ResolveService, 294 base::Unretained(this), id, service_name)); 295} 296 297void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) { 298 PostTask(FROM_HERE, 299 base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver, 300 base::Unretained(this), id)); 301} 302 303void ServiceDiscoveryMessageHandler::OnResolveLocalDomain( 304 uint64 id, const std::string& domain, 305 net::AddressFamily address_family) { 306 PostTask(FROM_HERE, 307 base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain, 308 base::Unretained(this), id, domain, address_family)); 309} 310 311void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) { 312 PostTask(FROM_HERE, 313 base::Bind( 314 &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver, 315 base::Unretained(this), id)); 316} 317 318void ServiceDiscoveryMessageHandler::StartWatcher( 319 uint64 id, 320 const std::string& service_type) { 321 VLOG(1) << "StartWatcher, id=" << id << ", type=" << service_type; 322 if (!service_discovery_client_) 323 return; 324 DCHECK(!ContainsKey(service_watchers_, id)); 325 scoped_ptr<ServiceWatcher> watcher( 326 service_discovery_client_->CreateServiceWatcher( 327 service_type, 328 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated, 329 base::Unretained(this), id))); 330 watcher->Start(); 331 service_watchers_[id].reset(watcher.release()); 332} 333 334void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id, 335 bool force_update) { 336 VLOG(1) << "DiscoverServices, id=" << id; 337 if (!service_discovery_client_) 338 return; 339 DCHECK(ContainsKey(service_watchers_, id)); 340 service_watchers_[id]->DiscoverNewServices(force_update); 341} 342 343void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices( 344 uint64 id, 345 bool actively_refresh_services) { 346 VLOG(1) << "ActivelyRefreshServices, id=" << id; 347 if (!service_discovery_client_) 348 return; 349 DCHECK(ContainsKey(service_watchers_, id)); 350 service_watchers_[id]->SetActivelyRefreshServices(actively_refresh_services); 351} 352 353void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) { 354 VLOG(1) << "DestoryWatcher, id=" << id; 355 if (!service_discovery_client_) 356 return; 357 service_watchers_.erase(id); 358} 359 360void ServiceDiscoveryMessageHandler::ResolveService( 361 uint64 id, 362 const std::string& service_name) { 363 VLOG(1) << "ResolveService, id=" << id << ", name=" << service_name; 364 if (!service_discovery_client_) 365 return; 366 DCHECK(!ContainsKey(service_resolvers_, id)); 367 scoped_ptr<ServiceResolver> resolver( 368 service_discovery_client_->CreateServiceResolver( 369 service_name, 370 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved, 371 base::Unretained(this), id))); 372 resolver->StartResolving(); 373 service_resolvers_[id].reset(resolver.release()); 374} 375 376void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) { 377 VLOG(1) << "DestroyResolver, id=" << id; 378 if (!service_discovery_client_) 379 return; 380 service_resolvers_.erase(id); 381} 382 383void ServiceDiscoveryMessageHandler::ResolveLocalDomain( 384 uint64 id, 385 const std::string& domain, 386 net::AddressFamily address_family) { 387 VLOG(1) << "ResolveLocalDomain, id=" << id << ", domain=" << domain; 388 if (!service_discovery_client_) 389 return; 390 DCHECK(!ContainsKey(local_domain_resolvers_, id)); 391 scoped_ptr<LocalDomainResolver> resolver( 392 service_discovery_client_->CreateLocalDomainResolver( 393 domain, address_family, 394 base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved, 395 base::Unretained(this), id))); 396 resolver->Start(); 397 local_domain_resolvers_[id].reset(resolver.release()); 398} 399 400void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) { 401 VLOG(1) << "DestroyLocalDomainResolver, id=" << id; 402 if (!service_discovery_client_) 403 return; 404 local_domain_resolvers_.erase(id); 405} 406 407void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() { 408 if (!discovery_task_runner_.get()) 409 return; 410 411 discovery_task_runner_->PostTask( 412 FROM_HERE, 413 base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread, 414 base::Unretained(this))); 415 416 // This will wait for message loop to drain, so ShutdownOnIOThread will 417 // definitely be called. 418 discovery_thread_.reset(); 419} 420 421void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() { 422 VLOG(1) << "ShutdownLocalDiscovery"; 423 service_watchers_.clear(); 424 service_resolvers_.clear(); 425 local_domain_resolvers_.clear(); 426 service_discovery_client_.reset(); 427 mdns_client_.reset(); 428} 429 430void ServiceDiscoveryMessageHandler::OnServiceUpdated( 431 uint64 id, 432 ServiceWatcher::UpdateType update, 433 const std::string& name) { 434 VLOG(1) << "OnServiceUpdated, id=" << id 435 << ", status=" << WatcherUpdateToString(update) << ", name=" << name; 436 DCHECK(service_discovery_client_); 437 438 Send(new LocalDiscoveryHostMsg_WatcherCallback(id, update, name)); 439} 440 441void ServiceDiscoveryMessageHandler::OnServiceResolved( 442 uint64 id, 443 ServiceResolver::RequestStatus status, 444 const ServiceDescription& description) { 445 VLOG(1) << "OnServiceResolved, id=" << id 446 << ", status=" << ResolverStatusToString(status) 447 << ", name=" << description.service_name; 448 449 DCHECK(service_discovery_client_); 450 Send(new LocalDiscoveryHostMsg_ResolverCallback(id, status, description)); 451} 452 453void ServiceDiscoveryMessageHandler::OnLocalDomainResolved( 454 uint64 id, 455 bool success, 456 const net::IPAddressNumber& address_ipv4, 457 const net::IPAddressNumber& address_ipv6) { 458 VLOG(1) << "OnLocalDomainResolved, id=" << id 459 << ", IPv4=" << (address_ipv4.empty() ? "" : 460 net::IPAddressToString(address_ipv4)) 461 << ", IPv6=" << (address_ipv6.empty() ? "" : 462 net::IPAddressToString(address_ipv6)); 463 464 DCHECK(service_discovery_client_); 465 Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback( 466 id, success, address_ipv4, address_ipv6)); 467} 468 469void ServiceDiscoveryMessageHandler::Send(IPC::Message* msg) { 470 utility_task_runner_->PostTask(FROM_HERE, 471 base::Bind(&SendHostMessageOnUtilityThread, 472 msg)); 473} 474 475} // namespace local_discovery 476