1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <utility>
6
7#include "base/logging.h"
8#include "base/memory/singleton.h"
9#include "base/message_loop/message_loop_proxy.h"
10#include "base/stl_util.h"
11#include "chrome/common/local_discovery/service_discovery_client_impl.h"
12#include "net/dns/dns_protocol.h"
13#include "net/dns/record_rdata.h"
14
15namespace local_discovery {
16
namespace {

// How long LocalDomainResolverImpl waits for the other address family once the
// first of its A/AAAA transactions has completed (ADDRESS_FAMILY_UNSPECIFIED).
// TODO(noamsml): Make this configurable through the LocalDomainResolver
// interface.
const int kLocalDomainSecondAddressTimeoutMs = 100;

// ServiceWatcherImpl re-sends its network PTR query with a doubling delay that
// starts at kInitialRequeryTimeSeconds; no further requery is scheduled once
// the delay would exceed kMaxRequeryTimeSeconds.
const int kInitialRequeryTimeSeconds = 1;
const int kMaxRequeryTimeSeconds = 2;  // Delay before the final requery.

}  // namespace
25
26ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
27    net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
28}
29
30ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
31}
32
33scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
34    const std::string& service_type,
35    const ServiceWatcher::UpdatedCallback& callback) {
36  return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
37      service_type, callback, mdns_client_));
38}
39
40scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
41    const std::string& service_name,
42    const ServiceResolver::ResolveCompleteCallback& callback) {
43  return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
44      service_name, callback, mdns_client_));
45}
46
47scoped_ptr<LocalDomainResolver>
48ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
49      const std::string& domain,
50      net::AddressFamily address_family,
51      const LocalDomainResolver::IPAddressCallback& callback) {
52  return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
53      domain, address_family, callback, mdns_client_));
54}
55
56ServiceWatcherImpl::ServiceWatcherImpl(
57    const std::string& service_type,
58    const ServiceWatcher::UpdatedCallback& callback,
59    net::MDnsClient* mdns_client)
60    : service_type_(service_type), callback_(callback), started_(false),
61      actively_refresh_services_(false), mdns_client_(mdns_client) {
62}
63
64void ServiceWatcherImpl::Start() {
65  DCHECK(!started_);
66  listener_ = mdns_client_->CreateListener(
67      net::dns_protocol::kTypePTR, service_type_, this);
68  started_ = listener_->Start();
69  if (started_)
70    ReadCachedServices();
71}
72
73ServiceWatcherImpl::~ServiceWatcherImpl() {
74}
75
76void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
77  DCHECK(started_);
78  if (force_update)
79    services_.clear();
80  SendQuery(kInitialRequeryTimeSeconds, force_update);
81}
82
83void ServiceWatcherImpl::SetActivelyRefreshServices(
84    bool actively_refresh_services) {
85  DCHECK(started_);
86  actively_refresh_services_ = actively_refresh_services;
87
88  for (ServiceListenersMap::iterator i = services_.begin();
89       i != services_.end(); i++) {
90    i->second->SetActiveRefresh(actively_refresh_services);
91  }
92}
93
94void ServiceWatcherImpl::ReadCachedServices() {
95  DCHECK(started_);
96  CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
97                    &transaction_cache_);
98}
99
100bool ServiceWatcherImpl::CreateTransaction(
101    bool network, bool cache, bool force_refresh,
102    scoped_ptr<net::MDnsTransaction>* transaction) {
103  int transaction_flags = 0;
104  if (network)
105    transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
106
107  if (cache)
108    transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
109
110  // TODO(noamsml): Add flag for force_refresh when supported.
111
112  if (transaction_flags) {
113    *transaction = mdns_client_->CreateTransaction(
114        net::dns_protocol::kTypePTR, service_type_, transaction_flags,
115        base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
116                   base::Unretained(this), transaction));
117    return (*transaction)->Start();
118  }
119
120  return true;
121}
122
123std::string ServiceWatcherImpl::GetServiceType() const {
124  return listener_->GetName();
125}
126
127void ServiceWatcherImpl::OnRecordUpdate(
128    net::MDnsListener::UpdateType update,
129    const net::RecordParsed* record) {
130  DCHECK(started_);
131  if (record->type() == net::dns_protocol::kTypePTR) {
132    DCHECK(record->name() == GetServiceType());
133    const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
134
135    switch (update) {
136      case net::MDnsListener::RECORD_ADDED:
137        AddService(rdata->ptrdomain());
138        break;
139      case net::MDnsListener::RECORD_CHANGED:
140        NOTREACHED();
141        break;
142      case net::MDnsListener::RECORD_REMOVED:
143        RemovePTR(rdata->ptrdomain());
144        break;
145    }
146  } else {
147    DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
148           record->type() == net::dns_protocol::kTypeTXT);
149    DCHECK(services_.find(record->name()) != services_.end());
150
151    if (record->type() == net::dns_protocol::kTypeSRV) {
152      if (update == net::MDnsListener::RECORD_REMOVED) {
153        RemoveSRV(record->name());
154      } else if (update == net::MDnsListener::RECORD_ADDED) {
155        AddSRV(record->name());
156      }
157    }
158
159    // If this is the first time we see an SRV record, do not send
160    // an UPDATE_CHANGED.
161    if (record->type() != net::dns_protocol::kTypeSRV ||
162        update != net::MDnsListener::RECORD_ADDED) {
163      DeferUpdate(UPDATE_CHANGED, record->name());
164    }
165  }
166}
167
168void ServiceWatcherImpl::OnCachePurged() {
169  // Not yet implemented.
170}
171
172void ServiceWatcherImpl::OnTransactionResponse(
173    scoped_ptr<net::MDnsTransaction>* transaction,
174    net::MDnsTransaction::Result result,
175    const net::RecordParsed* record) {
176  DCHECK(started_);
177  if (result == net::MDnsTransaction::RESULT_RECORD) {
178    const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
179    DCHECK(rdata);
180    AddService(rdata->ptrdomain());
181  } else if (result == net::MDnsTransaction::RESULT_DONE) {
182    transaction->reset();
183  }
184
185  // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
186  // record for PTR records on any name.
187}
188
189ServiceWatcherImpl::ServiceListeners::ServiceListeners(
190    const std::string& service_name,
191    ServiceWatcherImpl* watcher,
192    net::MDnsClient* mdns_client)
193    : service_name_(service_name), mdns_client_(mdns_client),
194      update_pending_(false), has_ptr_(true), has_srv_(false) {
195  srv_listener_ = mdns_client->CreateListener(
196      net::dns_protocol::kTypeSRV, service_name, watcher);
197  txt_listener_ = mdns_client->CreateListener(
198      net::dns_protocol::kTypeTXT, service_name, watcher);
199}
200
201ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
202}
203
204bool ServiceWatcherImpl::ServiceListeners::Start() {
205  if (!srv_listener_->Start())
206    return false;
207  return txt_listener_->Start();
208}
209
210void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh(
211    bool active_refresh) {
212  srv_listener_->SetActiveRefresh(active_refresh);
213
214  if (active_refresh && !has_srv_) {
215    DCHECK(has_ptr_);
216    srv_transaction_ = mdns_client_->CreateTransaction(
217        net::dns_protocol::kTypeSRV, service_name_,
218        net::MDnsTransaction::SINGLE_RESULT |
219        net::MDnsTransaction::QUERY_CACHE | net::MDnsTransaction::QUERY_NETWORK,
220        base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord,
221                   base::Unretained(this)));
222    srv_transaction_->Start();
223  } else if (!active_refresh) {
224    srv_transaction_.reset();
225  }
226}
227
228void ServiceWatcherImpl::ServiceListeners::OnSRVRecord(
229    net::MDnsTransaction::Result result,
230    const net::RecordParsed* record) {
231  set_has_srv(record != NULL);
232}
233
234void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv) {
235  has_srv_ = has_srv;
236
237  srv_transaction_.reset();
238}
239
240void ServiceWatcherImpl::AddService(const std::string& service) {
241  DCHECK(started_);
242  std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
243      make_pair(service, linked_ptr<ServiceListeners>(NULL)));
244
245  if (found.second) {  // Newly inserted.
246    found.first->second = linked_ptr<ServiceListeners>(
247        new ServiceListeners(service, this, mdns_client_));
248    bool success = found.first->second->Start();
249    found.first->second->SetActiveRefresh(actively_refresh_services_);
250    DeferUpdate(UPDATE_ADDED, service);
251
252    DCHECK(success);
253  }
254
255  found.first->second->set_has_ptr(true);
256}
257
258void ServiceWatcherImpl::AddSRV(const std::string& service) {
259  DCHECK(started_);
260
261  ServiceListenersMap::iterator found = services_.find(service);
262  if (found != services_.end()) {
263    found->second->set_has_srv(true);
264  }
265}
266
267void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
268                                     const std::string& service_name) {
269  ServiceListenersMap::iterator found = services_.find(service_name);
270
271  if (found != services_.end() && !found->second->update_pending()) {
272    found->second->set_update_pending(true);
273    base::MessageLoop::current()->PostTask(
274        FROM_HERE,
275        base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
276                   update_type, service_name));
277  }
278}
279
280void ServiceWatcherImpl::DeliverDeferredUpdate(
281    ServiceWatcher::UpdateType update_type, const std::string& service_name) {
282  ServiceListenersMap::iterator found = services_.find(service_name);
283
284  if (found != services_.end()) {
285    found->second->set_update_pending(false);
286    if (!callback_.is_null())
287      callback_.Run(update_type, service_name);
288  }
289}
290
291void ServiceWatcherImpl::RemovePTR(const std::string& service) {
292  DCHECK(started_);
293
294  ServiceListenersMap::iterator found = services_.find(service);
295  if (found != services_.end()) {
296    found->second->set_has_ptr(false);
297
298    if (!found->second->has_ptr_or_srv()) {
299      services_.erase(found);
300      if (!callback_.is_null())
301        callback_.Run(UPDATE_REMOVED, service);
302    }
303  }
304}
305
306void ServiceWatcherImpl::RemoveSRV(const std::string& service) {
307  DCHECK(started_);
308
309  ServiceListenersMap::iterator found = services_.find(service);
310  if (found != services_.end()) {
311    found->second->set_has_srv(false);
312
313    if (!found->second->has_ptr_or_srv()) {
314      services_.erase(found);
315      if (!callback_.is_null())
316        callback_.Run(UPDATE_REMOVED, service);
317    }
318  }
319}
320
321void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
322                                      unsigned rrtype) {
323  // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
324  // on any name.
325}
326
327void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) {
328  if (timeout_seconds <= kMaxRequeryTimeSeconds) {
329    base::MessageLoop::current()->PostDelayedTask(
330        FROM_HERE,
331        base::Bind(&ServiceWatcherImpl::SendQuery,
332                   AsWeakPtr(),
333                   timeout_seconds * 2 /*next_timeout_seconds*/,
334                   false /*force_update*/),
335        base::TimeDelta::FromSeconds(timeout_seconds));
336  }
337}
338
339void ServiceWatcherImpl::SendQuery(int next_timeout_seconds,
340                                   bool force_update) {
341  CreateTransaction(true /*network*/, false /*cache*/, force_update,
342                    &transaction_network_);
343  ScheduleQuery(next_timeout_seconds);
344}
345
346ServiceResolverImpl::ServiceResolverImpl(
347    const std::string& service_name,
348    const ResolveCompleteCallback& callback,
349    net::MDnsClient* mdns_client)
350    : service_name_(service_name), callback_(callback),
351      metadata_resolved_(false), address_resolved_(false),
352      mdns_client_(mdns_client) {
353}
354
355void ServiceResolverImpl::StartResolving() {
356  address_resolved_ = false;
357  metadata_resolved_ = false;
358  service_staging_ = ServiceDescription();
359  service_staging_.service_name = service_name_;
360
361  if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
362    ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
363  }
364}
365
366ServiceResolverImpl::~ServiceResolverImpl() {
367}
368
369bool ServiceResolverImpl::CreateTxtTransaction() {
370  txt_transaction_ = mdns_client_->CreateTransaction(
371      net::dns_protocol::kTypeTXT, service_name_,
372      net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
373      net::MDnsTransaction::QUERY_NETWORK,
374      base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
375                 AsWeakPtr()));
376  return txt_transaction_->Start();
377}
378
379// TODO(noamsml): quick-resolve for AAAA records.  Since A records tend to be in
380void ServiceResolverImpl::CreateATransaction() {
381  a_transaction_ = mdns_client_->CreateTransaction(
382      net::dns_protocol::kTypeA,
383      service_staging_.address.host(),
384      net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
385      base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
386                 AsWeakPtr()));
387  a_transaction_->Start();
388}
389
390bool ServiceResolverImpl::CreateSrvTransaction() {
391  srv_transaction_ = mdns_client_->CreateTransaction(
392      net::dns_protocol::kTypeSRV, service_name_,
393      net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
394      net::MDnsTransaction::QUERY_NETWORK,
395      base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
396                 AsWeakPtr()));
397  return srv_transaction_->Start();
398}
399
400std::string ServiceResolverImpl::GetName() const {
401  return service_name_;
402}
403
404void ServiceResolverImpl::SrvRecordTransactionResponse(
405    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
406  srv_transaction_.reset();
407  if (status == net::MDnsTransaction::RESULT_RECORD) {
408    DCHECK(record);
409    service_staging_.address = RecordToAddress(record);
410    service_staging_.last_seen = record->time_created();
411    CreateATransaction();
412  } else {
413    ServiceNotFound(MDnsStatusToRequestStatus(status));
414  }
415}
416
417void ServiceResolverImpl::TxtRecordTransactionResponse(
418    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
419  txt_transaction_.reset();
420  if (status == net::MDnsTransaction::RESULT_RECORD) {
421    DCHECK(record);
422    service_staging_.metadata = RecordToMetadata(record);
423  } else {
424    service_staging_.metadata = std::vector<std::string>();
425  }
426
427  metadata_resolved_ = true;
428  AlertCallbackIfReady();
429}
430
431void ServiceResolverImpl::ARecordTransactionResponse(
432    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
433  a_transaction_.reset();
434
435  if (status == net::MDnsTransaction::RESULT_RECORD) {
436    DCHECK(record);
437    service_staging_.ip_address = RecordToIPAddress(record);
438  } else {
439    service_staging_.ip_address = net::IPAddressNumber();
440  }
441
442  address_resolved_ = true;
443  AlertCallbackIfReady();
444}
445
446void ServiceResolverImpl::AlertCallbackIfReady() {
447  if (metadata_resolved_ && address_resolved_) {
448    txt_transaction_.reset();
449    srv_transaction_.reset();
450    a_transaction_.reset();
451    if (!callback_.is_null())
452      callback_.Run(STATUS_SUCCESS, service_staging_);
453  }
454}
455
456void ServiceResolverImpl::ServiceNotFound(
457    ServiceResolver::RequestStatus status) {
458  txt_transaction_.reset();
459  srv_transaction_.reset();
460  a_transaction_.reset();
461  if (!callback_.is_null())
462    callback_.Run(status, ServiceDescription());
463}
464
465ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
466    net::MDnsTransaction::Result status) const {
467  switch (status) {
468    case net::MDnsTransaction::RESULT_RECORD:
469      return ServiceResolver::STATUS_SUCCESS;
470    case net::MDnsTransaction::RESULT_NO_RESULTS:
471      return ServiceResolver::STATUS_REQUEST_TIMEOUT;
472    case net::MDnsTransaction::RESULT_NSEC:
473      return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
474    case net::MDnsTransaction::RESULT_DONE:  // Pass through.
475    default:
476      NOTREACHED();
477      return ServiceResolver::STATUS_REQUEST_TIMEOUT;
478  }
479}
480
481const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
482    const net::RecordParsed* record) const {
483  DCHECK(record->type() == net::dns_protocol::kTypeTXT);
484  const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
485  DCHECK(txt_rdata);
486  return txt_rdata->texts();
487}
488
489net::HostPortPair ServiceResolverImpl::RecordToAddress(
490    const net::RecordParsed* record) const {
491  DCHECK(record->type() == net::dns_protocol::kTypeSRV);
492  const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
493  DCHECK(srv_rdata);
494  return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
495}
496
497const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
498    const net::RecordParsed* record) const {
499  DCHECK(record->type() == net::dns_protocol::kTypeA);
500  const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
501  DCHECK(a_rdata);
502  return a_rdata->address();
503}
504
505LocalDomainResolverImpl::LocalDomainResolverImpl(
506    const std::string& domain,
507    net::AddressFamily address_family,
508    const IPAddressCallback& callback,
509    net::MDnsClient* mdns_client)
510    : domain_(domain), address_family_(address_family), callback_(callback),
511      transactions_finished_(0), mdns_client_(mdns_client) {
512}
513
514LocalDomainResolverImpl::~LocalDomainResolverImpl() {
515  timeout_callback_.Cancel();
516}
517
518void LocalDomainResolverImpl::Start() {
519  if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
520      address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
521    transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
522    transaction_a_->Start();
523  }
524
525  if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
526      address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
527    transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
528    transaction_aaaa_->Start();
529  }
530}
531
532scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
533    uint16 type) {
534  return mdns_client_->CreateTransaction(
535      type, domain_, net::MDnsTransaction::SINGLE_RESULT |
536                     net::MDnsTransaction::QUERY_CACHE |
537                     net::MDnsTransaction::QUERY_NETWORK,
538      base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
539                 base::Unretained(this)));
540}
541
542void LocalDomainResolverImpl::OnTransactionComplete(
543    net::MDnsTransaction::Result result, const net::RecordParsed* record) {
544  transactions_finished_++;
545
546  if (result == net::MDnsTransaction::RESULT_RECORD) {
547    if (record->type() == net::dns_protocol::kTypeA) {
548      const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
549      address_ipv4_ = rdata->address();
550    } else {
551      DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
552      const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
553      address_ipv6_ = rdata->address();
554    }
555  }
556
557  if (transactions_finished_ == 1 &&
558      address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
559    timeout_callback_.Reset(base::Bind(
560        &LocalDomainResolverImpl::SendResolvedAddresses,
561        base::Unretained(this)));
562
563    base::MessageLoop::current()->PostDelayedTask(
564        FROM_HERE,
565        timeout_callback_.callback(),
566        base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs));
567  } else if (transactions_finished_ == 2
568      || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) {
569    SendResolvedAddresses();
570  }
571}
572
573bool LocalDomainResolverImpl::IsSuccess() {
574  return !address_ipv4_.empty() || !address_ipv6_.empty();
575}
576
577void LocalDomainResolverImpl::SendResolvedAddresses() {
578  transaction_a_.reset();
579  transaction_aaaa_.reset();
580  timeout_callback_.Cancel();
581  callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_);
582}
583
584}  // namespace local_discovery
585