1// Copyright 2014 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "chrome/browser/local_discovery/service_discovery_client_mdns.h" 6 7#include "base/memory/scoped_vector.h" 8#include "base/metrics/histogram.h" 9#include "chrome/common/local_discovery/service_discovery_client_impl.h" 10#include "content/public/browser/browser_thread.h" 11#include "net/dns/mdns_client.h" 12#include "net/udp/datagram_server_socket.h" 13 14namespace local_discovery { 15 16using content::BrowserThread; 17 18// Base class for objects returned by ServiceDiscoveryClient implementation. 19// Handles interaction of client code on UI thread end net code on mdns thread. 20class ServiceDiscoveryClientMdns::Proxy { 21 public: 22 typedef base::WeakPtr<Proxy> WeakPtr; 23 24 explicit Proxy(ServiceDiscoveryClientMdns* client) 25 : client_(client), 26 weak_ptr_factory_(this) { 27 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 28 client_->proxies_.AddObserver(this); 29 } 30 31 virtual ~Proxy() { 32 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 33 client_->proxies_.RemoveObserver(this); 34 } 35 36 // Notify proxies that mDNS layer is going to be destroyed. 37 virtual void OnMdnsDestroy() = 0; 38 39 // Notify proxies that new mDNS instance is ready. 40 virtual void OnNewMdnsReady() {} 41 42 // Run callback using this method to abort callback if instance of |Proxy| 43 // is deleted. 44 void RunCallback(const base::Closure& callback) { 45 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 46 callback.Run(); 47 } 48 49 protected: 50 bool PostToMdnsThread(const base::Closure& task) { 51 return client_->PostToMdnsThread(task); 52 } 53 54 static bool PostToUIThread(const base::Closure& task) { 55 return BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, task); 56 } 57 58 ServiceDiscoveryClient* client() { 59 return client_->client_.get(); 60 } 61 62 WeakPtr GetWeakPtr() { 63 return weak_ptr_factory_.GetWeakPtr(); 64 } 65 66 template<class T> 67 void DeleteOnMdnsThread(T* t) { 68 if (!t) 69 return; 70 if (!client_->mdns_runner_->DeleteSoon(FROM_HERE, t)) 71 delete t; 72 } 73 74 private: 75 scoped_refptr<ServiceDiscoveryClientMdns> client_; 76 base::WeakPtrFactory<Proxy> weak_ptr_factory_; 77 78 DISALLOW_COPY_AND_ASSIGN(Proxy); 79}; 80 81namespace { 82 83const size_t kMaxDelayedTasks = 10000; 84const int kMaxRestartAttempts = 10; 85const int kRestartDelayOnNetworkChangeSeconds = 3; 86 87typedef base::Callback<void(bool)> MdnsInitCallback; 88 89class SocketFactory : public net::MDnsSocketFactory { 90 public: 91 explicit SocketFactory(const net::InterfaceIndexFamilyList& interfaces) 92 : interfaces_(interfaces) {} 93 94 // net::MDnsSocketFactory implementation: 95 virtual void CreateSockets( 96 ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE { 97 for (size_t i = 0; i < interfaces_.size(); ++i) { 98 DCHECK(interfaces_[i].second == net::ADDRESS_FAMILY_IPV4 || 99 interfaces_[i].second == net::ADDRESS_FAMILY_IPV6); 100 scoped_ptr<net::DatagramServerSocket> socket( 101 CreateAndBindMDnsSocket(interfaces_[i].second, interfaces_[i].first)); 102 if (socket) 103 sockets->push_back(socket.release()); 104 } 105 } 106 107 private: 108 net::InterfaceIndexFamilyList interfaces_; 109}; 110 111void InitMdns(const MdnsInitCallback& on_initialized, 112 const net::InterfaceIndexFamilyList& interfaces, 113 net::MDnsClient* mdns) { 114 SocketFactory socket_factory(interfaces); 115 BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, 116 base::Bind(on_initialized, 117 mdns->StartListening(&socket_factory))); 118} 119 120template<class T> 121class ProxyBase : public ServiceDiscoveryClientMdns::Proxy, public T { 122 public: 123 typedef base::WeakPtr<Proxy> WeakPtr; 124 typedef ProxyBase<T> Base; 125 126 explicit ProxyBase(ServiceDiscoveryClientMdns* client) 127 : Proxy(client) { 128 } 129 130 virtual ~ProxyBase() { 131 DeleteOnMdnsThread(implementation_.release()); 132 } 133 134 virtual void OnMdnsDestroy() OVERRIDE { 135 DeleteOnMdnsThread(implementation_.release()); 136 }; 137 138 protected: 139 void set_implementation(scoped_ptr<T> implementation) { 140 implementation_ = implementation.Pass(); 141 } 142 143 T* implementation() const { 144 return implementation_.get(); 145 } 146 147 private: 148 scoped_ptr<T> implementation_; 149 DISALLOW_COPY_AND_ASSIGN(ProxyBase); 150}; 151 152class ServiceWatcherProxy : public ProxyBase<ServiceWatcher> { 153 public: 154 ServiceWatcherProxy(ServiceDiscoveryClientMdns* client_mdns, 155 const std::string& service_type, 156 const ServiceWatcher::UpdatedCallback& callback) 157 : ProxyBase(client_mdns), 158 service_type_(service_type), 159 callback_(callback) { 160 // It's safe to call |CreateServiceWatcher| on UI thread, because 161 // |MDnsClient| is not used there. It's simplify implementation. 162 set_implementation(client()->CreateServiceWatcher( 163 service_type, 164 base::Bind(&ServiceWatcherProxy::OnCallback, GetWeakPtr(), callback))); 165 } 166 167 // ServiceWatcher methods. 168 virtual void Start() OVERRIDE { 169 if (implementation()) 170 PostToMdnsThread(base::Bind(&ServiceWatcher::Start, 171 base::Unretained(implementation()))); 172 } 173 174 virtual void DiscoverNewServices(bool force_update) OVERRIDE { 175 if (implementation()) 176 PostToMdnsThread(base::Bind(&ServiceWatcher::DiscoverNewServices, 177 base::Unretained(implementation()), 178 force_update)); 179 } 180 181 virtual void SetActivelyRefreshServices( 182 bool actively_refresh_services) OVERRIDE { 183 if (implementation()) 184 PostToMdnsThread(base::Bind(&ServiceWatcher::SetActivelyRefreshServices, 185 base::Unretained(implementation()), 186 actively_refresh_services)); 187 } 188 189 virtual std::string GetServiceType() const OVERRIDE { 190 return service_type_; 191 } 192 193 virtual void OnNewMdnsReady() OVERRIDE { 194 if (!implementation()) 195 callback_.Run(ServiceWatcher::UPDATE_INVALIDATED, ""); 196 } 197 198 private: 199 static void OnCallback(const WeakPtr& proxy, 200 const ServiceWatcher::UpdatedCallback& callback, 201 UpdateType a1, 202 const std::string& a2) { 203 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI)); 204 PostToUIThread(base::Bind(&Base::RunCallback, proxy, 205 base::Bind(callback, a1, a2))); 206 } 207 std::string service_type_; 208 ServiceWatcher::UpdatedCallback callback_; 209 DISALLOW_COPY_AND_ASSIGN(ServiceWatcherProxy); 210}; 211 212class ServiceResolverProxy : public ProxyBase<ServiceResolver> { 213 public: 214 ServiceResolverProxy(ServiceDiscoveryClientMdns* client_mdns, 215 const std::string& service_name, 216 const ServiceResolver::ResolveCompleteCallback& callback) 217 : ProxyBase(client_mdns), 218 service_name_(service_name) { 219 // It's safe to call |CreateServiceResolver| on UI thread, because 220 // |MDnsClient| is not used there. It's simplify implementation. 221 set_implementation(client()->CreateServiceResolver( 222 service_name, 223 base::Bind(&ServiceResolverProxy::OnCallback, GetWeakPtr(), callback))); 224 } 225 226 // ServiceResolver methods. 227 virtual void StartResolving() OVERRIDE { 228 if (implementation()) 229 PostToMdnsThread(base::Bind(&ServiceResolver::StartResolving, 230 base::Unretained(implementation()))); 231 }; 232 233 virtual std::string GetName() const OVERRIDE { 234 return service_name_; 235 } 236 237 private: 238 static void OnCallback( 239 const WeakPtr& proxy, 240 const ServiceResolver::ResolveCompleteCallback& callback, 241 RequestStatus a1, 242 const ServiceDescription& a2) { 243 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI)); 244 PostToUIThread(base::Bind(&Base::RunCallback, proxy, 245 base::Bind(callback, a1, a2))); 246 } 247 248 std::string service_name_; 249 DISALLOW_COPY_AND_ASSIGN(ServiceResolverProxy); 250}; 251 252class LocalDomainResolverProxy : public ProxyBase<LocalDomainResolver> { 253 public: 254 LocalDomainResolverProxy( 255 ServiceDiscoveryClientMdns* client_mdns, 256 const std::string& domain, 257 net::AddressFamily address_family, 258 const LocalDomainResolver::IPAddressCallback& callback) 259 : ProxyBase(client_mdns) { 260 // It's safe to call |CreateLocalDomainResolver| on UI thread, because 261 // |MDnsClient| is not used there. It's simplify implementation. 262 set_implementation(client()->CreateLocalDomainResolver( 263 domain, 264 address_family, 265 base::Bind( 266 &LocalDomainResolverProxy::OnCallback, GetWeakPtr(), callback))); 267 } 268 269 // LocalDomainResolver methods. 270 virtual void Start() OVERRIDE { 271 if (implementation()) 272 PostToMdnsThread(base::Bind(&LocalDomainResolver::Start, 273 base::Unretained(implementation()))); 274 }; 275 276 private: 277 static void OnCallback(const WeakPtr& proxy, 278 const LocalDomainResolver::IPAddressCallback& callback, 279 bool a1, 280 const net::IPAddressNumber& a2, 281 const net::IPAddressNumber& a3) { 282 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI)); 283 PostToUIThread(base::Bind(&Base::RunCallback, proxy, 284 base::Bind(callback, a1, a2, a3))); 285 } 286 287 DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverProxy); 288}; 289 290} // namespace 291 292ServiceDiscoveryClientMdns::ServiceDiscoveryClientMdns() 293 : mdns_runner_( 294 BrowserThread::GetMessageLoopProxyForThread(BrowserThread::IO)), 295 restart_attempts_(0), 296 need_dalay_mdns_tasks_(true), 297 weak_ptr_factory_(this) { 298 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 299 net::NetworkChangeNotifier::AddNetworkChangeObserver(this); 300 StartNewClient(); 301} 302 303scoped_ptr<ServiceWatcher> ServiceDiscoveryClientMdns::CreateServiceWatcher( 304 const std::string& service_type, 305 const ServiceWatcher::UpdatedCallback& callback) { 306 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 307 return scoped_ptr<ServiceWatcher>( 308 new ServiceWatcherProxy(this, service_type, callback)); 309} 310 311scoped_ptr<ServiceResolver> ServiceDiscoveryClientMdns::CreateServiceResolver( 312 const std::string& service_name, 313 const ServiceResolver::ResolveCompleteCallback& callback) { 314 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 315 return scoped_ptr<ServiceResolver>( 316 new ServiceResolverProxy(this, service_name, callback)); 317} 318 319scoped_ptr<LocalDomainResolver> 320ServiceDiscoveryClientMdns::CreateLocalDomainResolver( 321 const std::string& domain, 322 net::AddressFamily address_family, 323 const LocalDomainResolver::IPAddressCallback& callback) { 324 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 325 return scoped_ptr<LocalDomainResolver>( 326 new LocalDomainResolverProxy(this, domain, address_family, callback)); 327} 328 329ServiceDiscoveryClientMdns::~ServiceDiscoveryClientMdns() { 330 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 331 net::NetworkChangeNotifier::RemoveNetworkChangeObserver(this); 332 DestroyMdns(); 333} 334 335void ServiceDiscoveryClientMdns::OnNetworkChanged( 336 net::NetworkChangeNotifier::ConnectionType type) { 337 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 338 // Only network changes resets counter. 339 restart_attempts_ = 0; 340 ScheduleStartNewClient(); 341} 342 343void ServiceDiscoveryClientMdns::ScheduleStartNewClient() { 344 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 345 OnBeforeMdnsDestroy(); 346 if (restart_attempts_ < kMaxRestartAttempts) { 347 base::MessageLoop::current()->PostDelayedTask( 348 FROM_HERE, 349 base::Bind(&ServiceDiscoveryClientMdns::StartNewClient, 350 weak_ptr_factory_.GetWeakPtr()), 351 base::TimeDelta::FromSeconds( 352 kRestartDelayOnNetworkChangeSeconds * (1 << restart_attempts_))); 353 } else { 354 ReportSuccess(); 355 } 356} 357 358void ServiceDiscoveryClientMdns::StartNewClient() { 359 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 360 ++restart_attempts_; 361 DestroyMdns(); 362 mdns_.reset(net::MDnsClient::CreateDefault().release()); 363 client_.reset(new ServiceDiscoveryClientImpl(mdns_.get())); 364 BrowserThread::PostTaskAndReplyWithResult( 365 BrowserThread::FILE, 366 FROM_HERE, 367 base::Bind(&net::GetMDnsInterfacesToBind), 368 base::Bind(&ServiceDiscoveryClientMdns::OnInterfaceListReady, 369 weak_ptr_factory_.GetWeakPtr())); 370} 371 372void ServiceDiscoveryClientMdns::OnInterfaceListReady( 373 const net::InterfaceIndexFamilyList& interfaces) { 374 mdns_runner_->PostTask( 375 FROM_HERE, 376 base::Bind(&InitMdns, 377 base::Bind(&ServiceDiscoveryClientMdns::OnMdnsInitialized, 378 weak_ptr_factory_.GetWeakPtr()), 379 interfaces, 380 base::Unretained(mdns_.get()))); 381} 382 383void ServiceDiscoveryClientMdns::OnMdnsInitialized(bool success) { 384 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 385 if (!success) { 386 ScheduleStartNewClient(); 387 return; 388 } 389 ReportSuccess(); 390 391 // Initialization is done, no need to delay tasks. 392 need_dalay_mdns_tasks_ = false; 393 for (size_t i = 0; i < delayed_tasks_.size(); ++i) 394 mdns_runner_->PostTask(FROM_HERE, delayed_tasks_[i]); 395 delayed_tasks_.clear(); 396 397 FOR_EACH_OBSERVER(Proxy, proxies_, OnNewMdnsReady()); 398} 399 400void ServiceDiscoveryClientMdns::ReportSuccess() { 401 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 402 UMA_HISTOGRAM_COUNTS_100("LocalDiscovery.ClientRestartAttempts", 403 restart_attempts_); 404} 405 406void ServiceDiscoveryClientMdns::OnBeforeMdnsDestroy() { 407 need_dalay_mdns_tasks_ = true; 408 delayed_tasks_.clear(); 409 weak_ptr_factory_.InvalidateWeakPtrs(); 410 FOR_EACH_OBSERVER(Proxy, proxies_, OnMdnsDestroy()); 411} 412 413void ServiceDiscoveryClientMdns::DestroyMdns() { 414 OnBeforeMdnsDestroy(); 415 // After calling |Proxy::OnMdnsDestroy| all references to client_ and mdns_ 416 // should be destroyed. 417 if (client_) 418 mdns_runner_->DeleteSoon(FROM_HERE, client_.release()); 419 if (mdns_) 420 mdns_runner_->DeleteSoon(FROM_HERE, mdns_.release()); 421} 422 423bool ServiceDiscoveryClientMdns::PostToMdnsThread(const base::Closure& task) { 424 // The first task on IO thread for each |mdns_| instance must be |InitMdns|. 425 // |OnInterfaceListReady| could be delayed by |GetMDnsInterfacesToBind| 426 // running on FILE thread, so |PostToMdnsThread| could be called to post 427 // task for |mdns_| that is not initialized yet. 428 if (!need_dalay_mdns_tasks_) 429 return mdns_runner_->PostTask(FROM_HERE, task); 430 if (kMaxDelayedTasks > delayed_tasks_.size()) 431 delayed_tasks_.push_back(task); 432 return true; 433} 434 435} // namespace local_discovery 436