1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/local_discovery/service_discovery_client_mdns.h"
6
7#include "base/memory/scoped_vector.h"
8#include "base/metrics/histogram.h"
9#include "chrome/common/local_discovery/service_discovery_client_impl.h"
10#include "content/public/browser/browser_thread.h"
11#include "net/dns/mdns_client.h"
12#include "net/udp/datagram_server_socket.h"
13
14namespace local_discovery {
15
16using content::BrowserThread;
17
// Base class for objects returned by ServiceDiscoveryClient implementation.
// Handles interaction of client code on UI thread and net code on mdns thread.
20class ServiceDiscoveryClientMdns::Proxy {
21 public:
22  typedef base::WeakPtr<Proxy> WeakPtr;
23
24  explicit Proxy(ServiceDiscoveryClientMdns* client)
25      : client_(client),
26        weak_ptr_factory_(this) {
27    DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
28    client_->proxies_.AddObserver(this);
29  }
30
31  virtual ~Proxy() {
32    DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
33    client_->proxies_.RemoveObserver(this);
34  }
35
36  // Notify proxies that mDNS layer is going to be destroyed.
37  virtual void OnMdnsDestroy() = 0;
38
39  // Notify proxies that new mDNS instance is ready.
40  virtual void OnNewMdnsReady() {}
41
42  // Run callback using this method to abort callback if instance of |Proxy|
43  // is deleted.
44  void RunCallback(const base::Closure& callback) {
45    DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
46    callback.Run();
47  }
48
49 protected:
50  bool PostToMdnsThread(const base::Closure& task) {
51    return client_->PostToMdnsThread(task);
52  }
53
54  static bool PostToUIThread(const base::Closure& task) {
55    return BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, task);
56  }
57
58  ServiceDiscoveryClient* client() {
59    return client_->client_.get();
60  }
61
62  WeakPtr GetWeakPtr() {
63    return weak_ptr_factory_.GetWeakPtr();
64  }
65
66  template<class T>
67  void DeleteOnMdnsThread(T* t) {
68    if (!t)
69      return;
70    if (!client_->mdns_runner_->DeleteSoon(FROM_HERE, t))
71      delete t;
72  }
73
74 private:
75  scoped_refptr<ServiceDiscoveryClientMdns> client_;
76  base::WeakPtrFactory<Proxy> weak_ptr_factory_;
77
78  DISALLOW_COPY_AND_ASSIGN(Proxy);
79};
80
81namespace {
82
// Number of restart attempts after which |ScheduleStartNewClient| stops
// scheduling new clients.
const int kMaxRestartAttempts = 10;

// Base restart delay in seconds; |ScheduleStartNewClient| multiplies it by
// 2^|restart_attempts_|.
const int kRestartDelayOnNetworkChangeSeconds = 3;

// Upper bound on the number of tasks queued by |PostToMdnsThread| while tasks
// are being delayed.
const size_t kMaxDelayedTasks = 10000;
86
87typedef base::Callback<void(bool)> MdnsInitCallback;
88
89class SocketFactory : public net::MDnsSocketFactory {
90 public:
91  explicit SocketFactory(const net::InterfaceIndexFamilyList& interfaces)
92      : interfaces_(interfaces) {}
93
94  // net::MDnsSocketFactory implementation:
95  virtual void CreateSockets(
96      ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE {
97    for (size_t i = 0; i < interfaces_.size(); ++i) {
98      DCHECK(interfaces_[i].second == net::ADDRESS_FAMILY_IPV4 ||
99             interfaces_[i].second == net::ADDRESS_FAMILY_IPV6);
100      scoped_ptr<net::DatagramServerSocket> socket(
101          CreateAndBindMDnsSocket(interfaces_[i].second, interfaces_[i].first));
102      if (socket)
103        sockets->push_back(socket.release());
104    }
105  }
106
107 private:
108  net::InterfaceIndexFamilyList interfaces_;
109};
110
111void InitMdns(const MdnsInitCallback& on_initialized,
112              const net::InterfaceIndexFamilyList& interfaces,
113              net::MDnsClient* mdns) {
114  SocketFactory socket_factory(interfaces);
115  BrowserThread::PostTask(BrowserThread::UI, FROM_HERE,
116                          base::Bind(on_initialized,
117                                     mdns->StartListening(&socket_factory)));
118}
119
120template<class T>
121class ProxyBase : public ServiceDiscoveryClientMdns::Proxy, public T {
122 public:
123  typedef base::WeakPtr<Proxy> WeakPtr;
124  typedef ProxyBase<T> Base;
125
126  explicit ProxyBase(ServiceDiscoveryClientMdns* client)
127      : Proxy(client) {
128  }
129
130  virtual ~ProxyBase() {
131    DeleteOnMdnsThread(implementation_.release());
132  }
133
134  virtual void OnMdnsDestroy() OVERRIDE {
135    DeleteOnMdnsThread(implementation_.release());
136  };
137
138 protected:
139  void set_implementation(scoped_ptr<T> implementation) {
140    implementation_ = implementation.Pass();
141  }
142
143  T* implementation()  const {
144    return implementation_.get();
145  }
146
147 private:
148  scoped_ptr<T> implementation_;
149  DISALLOW_COPY_AND_ASSIGN(ProxyBase);
150};
151
152class ServiceWatcherProxy : public ProxyBase<ServiceWatcher> {
153 public:
154  ServiceWatcherProxy(ServiceDiscoveryClientMdns* client_mdns,
155                      const std::string& service_type,
156                      const ServiceWatcher::UpdatedCallback& callback)
157      : ProxyBase(client_mdns),
158        service_type_(service_type),
159        callback_(callback) {
160    // It's safe to call |CreateServiceWatcher| on UI thread, because
161    // |MDnsClient| is not used there. It's simplify implementation.
162    set_implementation(client()->CreateServiceWatcher(
163        service_type,
164        base::Bind(&ServiceWatcherProxy::OnCallback, GetWeakPtr(), callback)));
165  }
166
167  // ServiceWatcher methods.
168  virtual void Start() OVERRIDE {
169    if (implementation())
170      PostToMdnsThread(base::Bind(&ServiceWatcher::Start,
171                                  base::Unretained(implementation())));
172  }
173
174  virtual void DiscoverNewServices(bool force_update) OVERRIDE {
175    if (implementation())
176      PostToMdnsThread(base::Bind(&ServiceWatcher::DiscoverNewServices,
177                                  base::Unretained(implementation()),
178                                  force_update));
179  }
180
181  virtual void SetActivelyRefreshServices(
182      bool actively_refresh_services) OVERRIDE {
183    if (implementation())
184      PostToMdnsThread(base::Bind(&ServiceWatcher::SetActivelyRefreshServices,
185                                  base::Unretained(implementation()),
186                                  actively_refresh_services));
187  }
188
189  virtual std::string GetServiceType() const OVERRIDE {
190    return service_type_;
191  }
192
193  virtual void OnNewMdnsReady() OVERRIDE {
194    if (!implementation())
195      callback_.Run(ServiceWatcher::UPDATE_INVALIDATED, "");
196  }
197
198 private:
199  static void OnCallback(const WeakPtr& proxy,
200                         const ServiceWatcher::UpdatedCallback& callback,
201                         UpdateType a1,
202                         const std::string& a2) {
203    DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
204    PostToUIThread(base::Bind(&Base::RunCallback, proxy,
205                              base::Bind(callback, a1, a2)));
206  }
207  std::string service_type_;
208  ServiceWatcher::UpdatedCallback callback_;
209  DISALLOW_COPY_AND_ASSIGN(ServiceWatcherProxy);
210};
211
212class ServiceResolverProxy : public ProxyBase<ServiceResolver> {
213 public:
214  ServiceResolverProxy(ServiceDiscoveryClientMdns* client_mdns,
215                       const std::string& service_name,
216                       const ServiceResolver::ResolveCompleteCallback& callback)
217      : ProxyBase(client_mdns),
218        service_name_(service_name) {
219    // It's safe to call |CreateServiceResolver| on UI thread, because
220    // |MDnsClient| is not used there. It's simplify implementation.
221    set_implementation(client()->CreateServiceResolver(
222        service_name,
223        base::Bind(&ServiceResolverProxy::OnCallback, GetWeakPtr(), callback)));
224  }
225
226  // ServiceResolver methods.
227  virtual void StartResolving() OVERRIDE {
228    if (implementation())
229      PostToMdnsThread(base::Bind(&ServiceResolver::StartResolving,
230                                  base::Unretained(implementation())));
231  };
232
233  virtual std::string GetName() const OVERRIDE {
234    return service_name_;
235  }
236
237 private:
238  static void OnCallback(
239      const WeakPtr& proxy,
240      const ServiceResolver::ResolveCompleteCallback& callback,
241      RequestStatus a1,
242      const ServiceDescription& a2) {
243    DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
244    PostToUIThread(base::Bind(&Base::RunCallback, proxy,
245                              base::Bind(callback, a1, a2)));
246  }
247
248  std::string service_name_;
249  DISALLOW_COPY_AND_ASSIGN(ServiceResolverProxy);
250};
251
252class LocalDomainResolverProxy : public ProxyBase<LocalDomainResolver> {
253 public:
254  LocalDomainResolverProxy(
255      ServiceDiscoveryClientMdns* client_mdns,
256      const std::string& domain,
257      net::AddressFamily address_family,
258      const LocalDomainResolver::IPAddressCallback& callback)
259      : ProxyBase(client_mdns) {
260    // It's safe to call |CreateLocalDomainResolver| on UI thread, because
261    // |MDnsClient| is not used there. It's simplify implementation.
262    set_implementation(client()->CreateLocalDomainResolver(
263        domain,
264        address_family,
265        base::Bind(
266            &LocalDomainResolverProxy::OnCallback, GetWeakPtr(), callback)));
267  }
268
269  // LocalDomainResolver methods.
270  virtual void Start() OVERRIDE {
271    if (implementation())
272      PostToMdnsThread(base::Bind(&LocalDomainResolver::Start,
273                                  base::Unretained(implementation())));
274  };
275
276 private:
277  static void OnCallback(const WeakPtr& proxy,
278                         const LocalDomainResolver::IPAddressCallback& callback,
279                         bool a1,
280                         const net::IPAddressNumber& a2,
281                         const net::IPAddressNumber& a3) {
282    DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
283    PostToUIThread(base::Bind(&Base::RunCallback, proxy,
284                              base::Bind(callback, a1, a2, a3)));
285  }
286
287  DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverProxy);
288};
289
290}  // namespace
291
292ServiceDiscoveryClientMdns::ServiceDiscoveryClientMdns()
293    : mdns_runner_(
294          BrowserThread::GetMessageLoopProxyForThread(BrowserThread::IO)),
295      restart_attempts_(0),
296      need_dalay_mdns_tasks_(true),
297      weak_ptr_factory_(this) {
298  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
299  net::NetworkChangeNotifier::AddNetworkChangeObserver(this);
300  StartNewClient();
301}
302
303scoped_ptr<ServiceWatcher> ServiceDiscoveryClientMdns::CreateServiceWatcher(
304    const std::string& service_type,
305    const ServiceWatcher::UpdatedCallback& callback) {
306  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
307  return scoped_ptr<ServiceWatcher>(
308      new ServiceWatcherProxy(this, service_type, callback));
309}
310
311scoped_ptr<ServiceResolver> ServiceDiscoveryClientMdns::CreateServiceResolver(
312    const std::string& service_name,
313    const ServiceResolver::ResolveCompleteCallback& callback) {
314  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
315  return scoped_ptr<ServiceResolver>(
316      new ServiceResolverProxy(this, service_name, callback));
317}
318
319scoped_ptr<LocalDomainResolver>
320ServiceDiscoveryClientMdns::CreateLocalDomainResolver(
321    const std::string& domain,
322    net::AddressFamily address_family,
323    const LocalDomainResolver::IPAddressCallback& callback) {
324  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
325  return scoped_ptr<LocalDomainResolver>(
326      new LocalDomainResolverProxy(this, domain, address_family, callback));
327}
328
329ServiceDiscoveryClientMdns::~ServiceDiscoveryClientMdns() {
330  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
331  net::NetworkChangeNotifier::RemoveNetworkChangeObserver(this);
332  DestroyMdns();
333}
334
335void ServiceDiscoveryClientMdns::OnNetworkChanged(
336    net::NetworkChangeNotifier::ConnectionType type) {
337  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
338  // Only network changes resets counter.
339  restart_attempts_ = 0;
340  ScheduleStartNewClient();
341}
342
343void ServiceDiscoveryClientMdns::ScheduleStartNewClient() {
344  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
345  OnBeforeMdnsDestroy();
346  if (restart_attempts_ < kMaxRestartAttempts) {
347    base::MessageLoop::current()->PostDelayedTask(
348        FROM_HERE,
349        base::Bind(&ServiceDiscoveryClientMdns::StartNewClient,
350                   weak_ptr_factory_.GetWeakPtr()),
351        base::TimeDelta::FromSeconds(
352            kRestartDelayOnNetworkChangeSeconds * (1 << restart_attempts_)));
353  } else {
354    ReportSuccess();
355  }
356}
357
358void ServiceDiscoveryClientMdns::StartNewClient() {
359  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
360  ++restart_attempts_;
361  DestroyMdns();
362  mdns_.reset(net::MDnsClient::CreateDefault().release());
363  client_.reset(new ServiceDiscoveryClientImpl(mdns_.get()));
364  BrowserThread::PostTaskAndReplyWithResult(
365      BrowserThread::FILE,
366      FROM_HERE,
367      base::Bind(&net::GetMDnsInterfacesToBind),
368      base::Bind(&ServiceDiscoveryClientMdns::OnInterfaceListReady,
369                 weak_ptr_factory_.GetWeakPtr()));
370}
371
372void ServiceDiscoveryClientMdns::OnInterfaceListReady(
373    const net::InterfaceIndexFamilyList& interfaces) {
374  mdns_runner_->PostTask(
375      FROM_HERE,
376      base::Bind(&InitMdns,
377                 base::Bind(&ServiceDiscoveryClientMdns::OnMdnsInitialized,
378                            weak_ptr_factory_.GetWeakPtr()),
379                 interfaces,
380                 base::Unretained(mdns_.get())));
381}
382
383void ServiceDiscoveryClientMdns::OnMdnsInitialized(bool success) {
384  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
385  if (!success) {
386    ScheduleStartNewClient();
387    return;
388  }
389  ReportSuccess();
390
391  // Initialization is done, no need to delay tasks.
392  need_dalay_mdns_tasks_ = false;
393  for (size_t i = 0; i < delayed_tasks_.size(); ++i)
394    mdns_runner_->PostTask(FROM_HERE, delayed_tasks_[i]);
395  delayed_tasks_.clear();
396
397  FOR_EACH_OBSERVER(Proxy, proxies_, OnNewMdnsReady());
398}
399
400void ServiceDiscoveryClientMdns::ReportSuccess() {
401  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
402  UMA_HISTOGRAM_COUNTS_100("LocalDiscovery.ClientRestartAttempts",
403                           restart_attempts_);
404}
405
406void ServiceDiscoveryClientMdns::OnBeforeMdnsDestroy() {
407  need_dalay_mdns_tasks_ = true;
408  delayed_tasks_.clear();
409  weak_ptr_factory_.InvalidateWeakPtrs();
410  FOR_EACH_OBSERVER(Proxy, proxies_, OnMdnsDestroy());
411}
412
413void ServiceDiscoveryClientMdns::DestroyMdns() {
414  OnBeforeMdnsDestroy();
415  // After calling |Proxy::OnMdnsDestroy| all references to client_ and mdns_
416  // should be destroyed.
417  if (client_)
418    mdns_runner_->DeleteSoon(FROM_HERE, client_.release());
419  if (mdns_)
420    mdns_runner_->DeleteSoon(FROM_HERE, mdns_.release());
421}
422
423bool ServiceDiscoveryClientMdns::PostToMdnsThread(const base::Closure& task) {
424  // The first task on IO thread for each |mdns_| instance must be |InitMdns|.
425  // |OnInterfaceListReady| could be delayed by |GetMDnsInterfacesToBind|
426  // running on FILE thread, so |PostToMdnsThread| could be called to post
427  // task for |mdns_| that is not initialized yet.
428  if (!need_dalay_mdns_tasks_)
429    return mdns_runner_->PostTask(FROM_HERE, task);
430  if (kMaxDelayedTasks > delayed_tasks_.size())
431    delayed_tasks_.push_back(task);
432  return true;
433}
434
435}  // namespace local_discovery
436