1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/local_discovery/service_discovery_client_mdns.h"
6
7#include "base/memory/scoped_vector.h"
8#include "base/metrics/histogram.h"
9#include "chrome/common/local_discovery/service_discovery_client_impl.h"
10#include "content/public/browser/browser_thread.h"
11#include "net/dns/mdns_client.h"
12#include "net/udp/datagram_server_socket.h"
13
14namespace local_discovery {
15
16using content::BrowserThread;
17
18// Base class for objects returned by ServiceDiscoveryClient implementation.
19// Handles interaction of client code on UI thread end net code on mdns thread.
20class ServiceDiscoveryClientMdns::Proxy {
21 public:
22  typedef base::WeakPtr<Proxy> WeakPtr;
23
24  explicit Proxy(ServiceDiscoveryClientMdns* client)
25      : client_(client),
26        weak_ptr_factory_(this) {
27    DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
28    client_->proxies_.AddObserver(this);
29  }
30
31  virtual ~Proxy() {
32    DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
33    client_->proxies_.RemoveObserver(this);
34  }
35
36  // Returns true if object is not yet shutdown.
37  virtual bool IsValid() = 0;
38
39  // Notifies proxies that mDNS layer is going to be destroyed.
40  virtual void OnMdnsDestroy() = 0;
41
42  // Notifies proxies that new mDNS instance is ready.
43  virtual void OnNewMdnsReady() {
44    DCHECK(!client_->need_dalay_mdns_tasks_);
45    if (IsValid()) {
46      for (size_t i = 0; i < delayed_tasks_.size(); ++i)
47        client_->mdns_runner_->PostTask(FROM_HERE, delayed_tasks_[i]);
48    }
49    delayed_tasks_.clear();
50  }
51
52  // Runs callback using this method to abort callback if instance of |Proxy|
53  // is deleted.
54  void RunCallback(const base::Closure& callback) {
55    DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
56    callback.Run();
57  }
58
59 protected:
60  void PostToMdnsThread(const base::Closure& task) {
61    DCHECK(IsValid());
62    // The first task on IO thread for each |mdns_| instance must be |InitMdns|.
63    // |OnInterfaceListReady| could be delayed by |GetMDnsInterfacesToBind|
64    // running on FILE thread, so |PostToMdnsThread| could be called to post
65    // task for |mdns_| that is not initialized yet.
66    if (!client_->need_dalay_mdns_tasks_) {
67      client_->mdns_runner_->PostTask(FROM_HERE, task);
68      return;
69    }
70    delayed_tasks_.push_back(task);
71  }
72
73  static bool PostToUIThread(const base::Closure& task) {
74    return BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, task);
75  }
76
77  ServiceDiscoveryClient* client() {
78    return client_->client_.get();
79  }
80
81  WeakPtr GetWeakPtr() {
82    return weak_ptr_factory_.GetWeakPtr();
83  }
84
85  template<class T>
86  void DeleteOnMdnsThread(T* t) {
87    if (!t)
88      return;
89    if (!client_->mdns_runner_->DeleteSoon(FROM_HERE, t))
90      delete t;
91  }
92
93 private:
94  scoped_refptr<ServiceDiscoveryClientMdns> client_;
95  base::WeakPtrFactory<Proxy> weak_ptr_factory_;
96  // Delayed |mdns_runner_| tasks.
97  std::vector<base::Closure> delayed_tasks_;
98  DISALLOW_COPY_AND_ASSIGN(Proxy);
99};
100
101namespace {
102
103const int kMaxRestartAttempts = 10;
104const int kRestartDelayOnNetworkChangeSeconds = 3;
105
106typedef base::Callback<void(bool)> MdnsInitCallback;
107
108class SocketFactory : public net::MDnsSocketFactory {
109 public:
110  explicit SocketFactory(const net::InterfaceIndexFamilyList& interfaces)
111      : interfaces_(interfaces) {}
112
113  // net::MDnsSocketFactory implementation:
114  virtual void CreateSockets(
115      ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE {
116    for (size_t i = 0; i < interfaces_.size(); ++i) {
117      DCHECK(interfaces_[i].second == net::ADDRESS_FAMILY_IPV4 ||
118             interfaces_[i].second == net::ADDRESS_FAMILY_IPV6);
119      scoped_ptr<net::DatagramServerSocket> socket(
120          CreateAndBindMDnsSocket(interfaces_[i].second, interfaces_[i].first));
121      if (socket)
122        sockets->push_back(socket.release());
123    }
124  }
125
126 private:
127  net::InterfaceIndexFamilyList interfaces_;
128};
129
130void InitMdns(const MdnsInitCallback& on_initialized,
131              const net::InterfaceIndexFamilyList& interfaces,
132              net::MDnsClient* mdns) {
133  SocketFactory socket_factory(interfaces);
134  BrowserThread::PostTask(BrowserThread::UI, FROM_HERE,
135                          base::Bind(on_initialized,
136                                     mdns->StartListening(&socket_factory)));
137}
138
139template<class T>
140class ProxyBase : public ServiceDiscoveryClientMdns::Proxy, public T {
141 public:
142  typedef ProxyBase<T> Base;
143
144  explicit ProxyBase(ServiceDiscoveryClientMdns* client)
145      : Proxy(client) {
146  }
147
148  virtual ~ProxyBase() {
149    DeleteOnMdnsThread(implementation_.release());
150  }
151
152  virtual bool IsValid() OVERRIDE {
153    return !!implementation();
154  }
155
156  virtual void OnMdnsDestroy() OVERRIDE {
157    DeleteOnMdnsThread(implementation_.release());
158  };
159
160 protected:
161  void set_implementation(scoped_ptr<T> implementation) {
162    implementation_ = implementation.Pass();
163  }
164
165  T* implementation()  const {
166    return implementation_.get();
167  }
168
169 private:
170  scoped_ptr<T> implementation_;
171  DISALLOW_COPY_AND_ASSIGN(ProxyBase);
172};
173
174class ServiceWatcherProxy : public ProxyBase<ServiceWatcher> {
175 public:
176  ServiceWatcherProxy(ServiceDiscoveryClientMdns* client_mdns,
177                      const std::string& service_type,
178                      const ServiceWatcher::UpdatedCallback& callback)
179      : ProxyBase(client_mdns),
180        service_type_(service_type),
181        callback_(callback) {
182    // It's safe to call |CreateServiceWatcher| on UI thread, because
183    // |MDnsClient| is not used there. It's simplify implementation.
184    set_implementation(client()->CreateServiceWatcher(
185        service_type,
186        base::Bind(&ServiceWatcherProxy::OnCallback, GetWeakPtr(), callback)));
187  }
188
189  // ServiceWatcher methods.
190  virtual void Start() OVERRIDE {
191    if (implementation()) {
192      PostToMdnsThread(base::Bind(&ServiceWatcher::Start,
193                                  base::Unretained(implementation())));
194    }
195  }
196
197  virtual void DiscoverNewServices(bool force_update) OVERRIDE {
198    if (implementation()) {
199      PostToMdnsThread(base::Bind(&ServiceWatcher::DiscoverNewServices,
200                                  base::Unretained(implementation()),
201                                  force_update));
202    }
203  }
204
205  virtual void SetActivelyRefreshServices(
206      bool actively_refresh_services) OVERRIDE {
207    if (implementation()) {
208      PostToMdnsThread(base::Bind(&ServiceWatcher::SetActivelyRefreshServices,
209                                  base::Unretained(implementation()),
210                                  actively_refresh_services));
211    }
212  }
213
214  virtual std::string GetServiceType() const OVERRIDE {
215    return service_type_;
216  }
217
218  virtual void OnNewMdnsReady() OVERRIDE {
219    ProxyBase<ServiceWatcher>::OnNewMdnsReady();
220    if (!implementation())
221      callback_.Run(ServiceWatcher::UPDATE_INVALIDATED, "");
222  }
223
224 private:
225  static void OnCallback(const WeakPtr& proxy,
226                         const ServiceWatcher::UpdatedCallback& callback,
227                         UpdateType a1,
228                         const std::string& a2) {
229    DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
230    PostToUIThread(base::Bind(&Base::RunCallback, proxy,
231                              base::Bind(callback, a1, a2)));
232  }
233  std::string service_type_;
234  ServiceWatcher::UpdatedCallback callback_;
235  DISALLOW_COPY_AND_ASSIGN(ServiceWatcherProxy);
236};
237
238class ServiceResolverProxy : public ProxyBase<ServiceResolver> {
239 public:
240  ServiceResolverProxy(ServiceDiscoveryClientMdns* client_mdns,
241                       const std::string& service_name,
242                       const ServiceResolver::ResolveCompleteCallback& callback)
243      : ProxyBase(client_mdns),
244        service_name_(service_name) {
245    // It's safe to call |CreateServiceResolver| on UI thread, because
246    // |MDnsClient| is not used there. It's simplify implementation.
247    set_implementation(client()->CreateServiceResolver(
248        service_name,
249        base::Bind(&ServiceResolverProxy::OnCallback, GetWeakPtr(), callback)));
250  }
251
252  // ServiceResolver methods.
253  virtual void StartResolving() OVERRIDE {
254    if (implementation()) {
255      PostToMdnsThread(base::Bind(&ServiceResolver::StartResolving,
256                                  base::Unretained(implementation())));
257    }
258  };
259
260  virtual std::string GetName() const OVERRIDE {
261    return service_name_;
262  }
263
264 private:
265  static void OnCallback(
266      const WeakPtr& proxy,
267      const ServiceResolver::ResolveCompleteCallback& callback,
268      RequestStatus a1,
269      const ServiceDescription& a2) {
270    DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
271    PostToUIThread(base::Bind(&Base::RunCallback, proxy,
272                              base::Bind(callback, a1, a2)));
273  }
274
275  std::string service_name_;
276  DISALLOW_COPY_AND_ASSIGN(ServiceResolverProxy);
277};
278
279class LocalDomainResolverProxy : public ProxyBase<LocalDomainResolver> {
280 public:
281  LocalDomainResolverProxy(
282      ServiceDiscoveryClientMdns* client_mdns,
283      const std::string& domain,
284      net::AddressFamily address_family,
285      const LocalDomainResolver::IPAddressCallback& callback)
286      : ProxyBase(client_mdns) {
287    // It's safe to call |CreateLocalDomainResolver| on UI thread, because
288    // |MDnsClient| is not used there. It's simplify implementation.
289    set_implementation(client()->CreateLocalDomainResolver(
290        domain,
291        address_family,
292        base::Bind(
293            &LocalDomainResolverProxy::OnCallback, GetWeakPtr(), callback)));
294  }
295
296  // LocalDomainResolver methods.
297  virtual void Start() OVERRIDE {
298    if (implementation()) {
299      PostToMdnsThread(base::Bind(&LocalDomainResolver::Start,
300                                  base::Unretained(implementation())));
301    }
302  };
303
304 private:
305  static void OnCallback(const WeakPtr& proxy,
306                         const LocalDomainResolver::IPAddressCallback& callback,
307                         bool a1,
308                         const net::IPAddressNumber& a2,
309                         const net::IPAddressNumber& a3) {
310    DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
311    PostToUIThread(base::Bind(&Base::RunCallback, proxy,
312                              base::Bind(callback, a1, a2, a3)));
313  }
314
315  DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverProxy);
316};
317
318}  // namespace
319
// Registers for network-change notifications and immediately starts the first
// mDNS client. |need_dalay_mdns_tasks_| starts out true, so proxy tasks are
// queued until that client reports successful initialization.
ServiceDiscoveryClientMdns::ServiceDiscoveryClientMdns()
    // All mDNS work is posted to the IO thread through |mdns_runner_|.
    : mdns_runner_(
          BrowserThread::GetMessageLoopProxyForThread(BrowserThread::IO)),
      restart_attempts_(0),
      need_dalay_mdns_tasks_(true),
      weak_ptr_factory_(this) {
  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
  net::NetworkChangeNotifier::AddNetworkChangeObserver(this);
  StartNewClient();
}
330
331scoped_ptr<ServiceWatcher> ServiceDiscoveryClientMdns::CreateServiceWatcher(
332    const std::string& service_type,
333    const ServiceWatcher::UpdatedCallback& callback) {
334  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
335  return scoped_ptr<ServiceWatcher>(
336      new ServiceWatcherProxy(this, service_type, callback));
337}
338
339scoped_ptr<ServiceResolver> ServiceDiscoveryClientMdns::CreateServiceResolver(
340    const std::string& service_name,
341    const ServiceResolver::ResolveCompleteCallback& callback) {
342  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
343  return scoped_ptr<ServiceResolver>(
344      new ServiceResolverProxy(this, service_name, callback));
345}
346
347scoped_ptr<LocalDomainResolver>
348ServiceDiscoveryClientMdns::CreateLocalDomainResolver(
349    const std::string& domain,
350    net::AddressFamily address_family,
351    const LocalDomainResolver::IPAddressCallback& callback) {
352  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
353  return scoped_ptr<LocalDomainResolver>(
354      new LocalDomainResolverProxy(this, domain, address_family, callback));
355}
356
// Unregisters from network-change notifications first, so OnNetworkChanged
// can no longer schedule a restart. Then hands |client_| and |mdns_| to the
// mDNS thread for deletion via DestroyMdns().
ServiceDiscoveryClientMdns::~ServiceDiscoveryClientMdns() {
  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
  net::NetworkChangeNotifier::RemoveNetworkChangeObserver(this);
  DestroyMdns();
}
362
// NetworkChangeNotifier observer callback. Any connection change (|type| is
// not inspected) restarts mDNS with a fresh retry budget.
void ServiceDiscoveryClientMdns::OnNetworkChanged(
    net::NetworkChangeNotifier::ConnectionType type) {
  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
  // Only network changes resets counter.
  // Must happen before ScheduleStartNewClient(), which reads the counter to
  // decide whether to retry and how long to wait.
  restart_attempts_ = 0;
  ScheduleStartNewClient();
}
370
371void ServiceDiscoveryClientMdns::ScheduleStartNewClient() {
372  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
373  OnBeforeMdnsDestroy();
374  if (restart_attempts_ < kMaxRestartAttempts) {
375    base::MessageLoop::current()->PostDelayedTask(
376        FROM_HERE,
377        base::Bind(&ServiceDiscoveryClientMdns::StartNewClient,
378                   weak_ptr_factory_.GetWeakPtr()),
379        base::TimeDelta::FromSeconds(
380            kRestartDelayOnNetworkChangeSeconds * (1 << restart_attempts_)));
381  } else {
382    ReportSuccess();
383  }
384}
385
// Replaces the mDNS stack with a fresh MDnsClient/ServiceDiscoveryClientImpl
// pair, consuming one restart attempt. The interface list is computed on the
// FILE thread, and OnInterfaceListReady continues initialization.
void ServiceDiscoveryClientMdns::StartNewClient() {
  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
  ++restart_attempts_;
  // Queues the previous |client_|/|mdns_| for deletion on the mDNS thread and
  // invalidates |weak_ptr_factory_|, cancelling replies from older attempts.
  DestroyMdns();
  mdns_.reset(net::MDnsClient::CreateDefault().release());
  // |client_| is built on top of |mdns_|.
  client_.reset(new ServiceDiscoveryClientImpl(mdns_.get()));
  // The weak pointer is taken after DestroyMdns(), so only the reply for this
  // attempt can reach OnInterfaceListReady.
  BrowserThread::PostTaskAndReplyWithResult(
      BrowserThread::FILE,
      FROM_HERE,
      base::Bind(&net::GetMDnsInterfacesToBind),
      base::Bind(&ServiceDiscoveryClientMdns::OnInterfaceListReady,
                 weak_ptr_factory_.GetWeakPtr()));
}
399
400void ServiceDiscoveryClientMdns::OnInterfaceListReady(
401    const net::InterfaceIndexFamilyList& interfaces) {
402  mdns_runner_->PostTask(
403      FROM_HERE,
404      base::Bind(&InitMdns,
405                 base::Bind(&ServiceDiscoveryClientMdns::OnMdnsInitialized,
406                            weak_ptr_factory_.GetWeakPtr()),
407                 interfaces,
408                 base::Unretained(mdns_.get())));
409}
410
411void ServiceDiscoveryClientMdns::OnMdnsInitialized(bool success) {
412  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
413  if (!success) {
414    ScheduleStartNewClient();
415    return;
416  }
417  ReportSuccess();
418
419  // Initialization is done, no need to delay tasks.
420  need_dalay_mdns_tasks_ = false;
421  FOR_EACH_OBSERVER(Proxy, proxies_, OnNewMdnsReady());
422}
423
// Records the current |restart_attempts_| in UMA. Despite the name, this is
// called both after a successful initialization (OnMdnsInitialized) and when
// retries are exhausted (ScheduleStartNewClient).
void ServiceDiscoveryClientMdns::ReportSuccess() {
  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
  UMA_HISTOGRAM_COUNTS_100("LocalDiscovery.ClientRestartAttempts",
                           restart_attempts_);
}
429
430void ServiceDiscoveryClientMdns::OnBeforeMdnsDestroy() {
431  need_dalay_mdns_tasks_ = true;
432  weak_ptr_factory_.InvalidateWeakPtrs();
433  FOR_EACH_OBSERVER(Proxy, proxies_, OnMdnsDestroy());
434}
435
// Cancels pending work via OnBeforeMdnsDestroy() and queues |client_| and then
// |mdns_| for deletion on |mdns_runner_|. |client_| is queued first because it
// was built on top of |mdns_|. This assumes tasks on that runner run in
// posting order.
// NOTE(review): unlike Proxy::DeleteOnMdnsThread, the DeleteSoon result is
// ignored here, so the objects are leaked if posting fails.
void ServiceDiscoveryClientMdns::DestroyMdns() {
  OnBeforeMdnsDestroy();
  // After calling |Proxy::OnMdnsDestroy| all references to client_ and mdns_
  // should be destroyed.
  if (client_)
    mdns_runner_->DeleteSoon(FROM_HERE, client_.release());
  if (mdns_)
    mdns_runner_->DeleteSoon(FROM_HERE, mdns_.release());
}
445
446}  // namespace local_discovery
447