1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <utility>
6
7#include "base/logging.h"
8#include "base/memory/singleton.h"
9#include "base/message_loop/message_loop_proxy.h"
10#include "base/stl_util.h"
11#include "chrome/utility/local_discovery/service_discovery_client_impl.h"
12#include "net/dns/dns_protocol.h"
13#include "net/dns/record_rdata.h"
14
15namespace local_discovery {
16
17ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
18    net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
19}
20
21ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
22}
23
24scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
25    const std::string& service_type,
26    const ServiceWatcher::UpdatedCallback& callback) {
27  return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
28      service_type,  callback, mdns_client_));
29}
30
31scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
32    const std::string& service_name,
33    const ServiceResolver::ResolveCompleteCallback& callback) {
34  return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
35      service_name, callback, mdns_client_));
36}
37
38scoped_ptr<LocalDomainResolver>
39ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
40      const std::string& domain,
41      net::AddressFamily address_family,
42      const LocalDomainResolver::IPAddressCallback& callback) {
43  return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
44      domain, address_family, callback, mdns_client_));
45}
46
47ServiceWatcherImpl::ServiceWatcherImpl(
48    const std::string& service_type,
49    const ServiceWatcher::UpdatedCallback& callback,
50    net::MDnsClient* mdns_client)
51    : service_type_(service_type), callback_(callback), started_(false),
52      mdns_client_(mdns_client) {
53}
54
55void ServiceWatcherImpl::Start() {
56  DCHECK(!started_);
57  listener_ = mdns_client_->CreateListener(
58      net::dns_protocol::kTypePTR, service_type_, this);
59  started_ = listener_->Start();
60  if (started_)
61    ReadCachedServices();
62}
63
64ServiceWatcherImpl::~ServiceWatcherImpl() {
65}
66
67void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
68  DCHECK(started_);
69  if (force_update)
70    services_.clear();
71  CreateTransaction(true /*network*/, false /*cache*/, force_update,
72                    &transaction_network_);
73}
74
75void ServiceWatcherImpl::ReadCachedServices() {
76  DCHECK(started_);
77  CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
78                    &transaction_cache_);
79}
80
81bool ServiceWatcherImpl::CreateTransaction(
82    bool network, bool cache, bool force_refresh,
83    scoped_ptr<net::MDnsTransaction>* transaction) {
84  int transaction_flags = 0;
85  if (network)
86    transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
87
88  if (cache)
89    transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
90
91  // TODO(noamsml): Add flag for force_refresh when supported.
92
93  if (transaction_flags) {
94    *transaction = mdns_client_->CreateTransaction(
95        net::dns_protocol::kTypePTR, service_type_, transaction_flags,
96        base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
97                   base::Unretained(this), transaction));
98    return (*transaction)->Start();
99  }
100
101  return true;
102}
103
104std::string ServiceWatcherImpl::GetServiceType() const {
105  return listener_->GetName();
106}
107
108void ServiceWatcherImpl::OnRecordUpdate(
109    net::MDnsListener::UpdateType update,
110    const net::RecordParsed* record) {
111  DCHECK(started_);
112  if (record->type() == net::dns_protocol::kTypePTR) {
113    DCHECK(record->name() == GetServiceType());
114    const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
115
116    switch (update) {
117      case net::MDnsListener::RECORD_ADDED:
118        AddService(rdata->ptrdomain());
119        break;
120      case net::MDnsListener::RECORD_CHANGED:
121        NOTREACHED();
122        break;
123      case net::MDnsListener::RECORD_REMOVED:
124        RemoveService(rdata->ptrdomain());
125        break;
126    }
127  } else {
128    DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
129           record->type() == net::dns_protocol::kTypeTXT);
130    DCHECK(services_.find(record->name()) != services_.end());
131
132    DeferUpdate(UPDATE_CHANGED, record->name());
133  }
134}
135
136void ServiceWatcherImpl::OnCachePurged() {
137  // Not yet implemented.
138}
139
140void ServiceWatcherImpl::OnTransactionResponse(
141    scoped_ptr<net::MDnsTransaction>* transaction,
142    net::MDnsTransaction::Result result,
143    const net::RecordParsed* record) {
144  DCHECK(started_);
145  if (result == net::MDnsTransaction::RESULT_RECORD) {
146    const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
147    DCHECK(rdata);
148    AddService(rdata->ptrdomain());
149  } else if (result == net::MDnsTransaction::RESULT_DONE) {
150    transaction->reset();
151  }
152
153  // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
154  // record for PTR records on any name.
155}
156
157ServiceWatcherImpl::ServiceListeners::ServiceListeners(
158    const std::string& service_name,
159    ServiceWatcherImpl* watcher,
160    net::MDnsClient* mdns_client) : update_pending_(false) {
161  srv_listener_ = mdns_client->CreateListener(
162      net::dns_protocol::kTypeSRV, service_name, watcher);
163  txt_listener_ = mdns_client->CreateListener(
164      net::dns_protocol::kTypeTXT, service_name, watcher);
165}
166
167ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
168}
169
170bool ServiceWatcherImpl::ServiceListeners::Start() {
171  if (!srv_listener_->Start())
172    return false;
173  return txt_listener_->Start();
174}
175
176void ServiceWatcherImpl::AddService(const std::string& service) {
177  DCHECK(started_);
178  std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
179      make_pair(service, linked_ptr<ServiceListeners>(NULL)));
180  if (found.second) {  // Newly inserted.
181    found.first->second = linked_ptr<ServiceListeners>(
182        new ServiceListeners(service, this, mdns_client_));
183    bool success = found.first->second->Start();
184
185    DeferUpdate(UPDATE_ADDED, service);
186
187    DCHECK(success);
188  }
189}
190
191void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
192                                     const std::string& service_name) {
193  ServiceListenersMap::iterator found = services_.find(service_name);
194
195  if (found != services_.end() && !found->second->update_pending()) {
196    found->second->set_update_pending(true);
197    base::MessageLoop::current()->PostTask(
198        FROM_HERE,
199        base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
200                   update_type, service_name));
201  }
202}
203
204void ServiceWatcherImpl::DeliverDeferredUpdate(
205    ServiceWatcher::UpdateType update_type, const std::string& service_name) {
206  ServiceListenersMap::iterator found = services_.find(service_name);
207
208  if (found != services_.end()) {
209    found->second->set_update_pending(false);
210    if (!callback_.is_null())
211      callback_.Run(update_type, service_name);
212  }
213}
214
215void ServiceWatcherImpl::RemoveService(const std::string& service) {
216  DCHECK(started_);
217  ServiceListenersMap::iterator found = services_.find(service);
218  if (found != services_.end()) {
219    services_.erase(found);
220    if (!callback_.is_null())
221      callback_.Run(UPDATE_REMOVED, service);
222  }
223}
224
225void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
226                                      unsigned rrtype) {
227  // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
228  // on any name.
229}
230
231ServiceResolverImpl::ServiceResolverImpl(
232    const std::string& service_name,
233    const ResolveCompleteCallback& callback,
234    net::MDnsClient* mdns_client)
235    : service_name_(service_name), callback_(callback),
236      metadata_resolved_(false), address_resolved_(false),
237      mdns_client_(mdns_client) {
238  service_staging_.service_name = service_name_;
239}
240
241void ServiceResolverImpl::StartResolving() {
242  address_resolved_ = false;
243  metadata_resolved_ = false;
244
245  if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
246    ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
247  }
248}
249
250ServiceResolverImpl::~ServiceResolverImpl() {
251}
252
253bool ServiceResolverImpl::CreateTxtTransaction() {
254  txt_transaction_ = mdns_client_->CreateTransaction(
255      net::dns_protocol::kTypeTXT, service_name_,
256      net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
257      net::MDnsTransaction::QUERY_NETWORK,
258      base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
259                 AsWeakPtr()));
260  return txt_transaction_->Start();
261}
262
263// TODO(noamsml): quick-resolve for AAAA records.  Since A records tend to be in
264void ServiceResolverImpl::CreateATransaction() {
265  a_transaction_ = mdns_client_->CreateTransaction(
266      net::dns_protocol::kTypeA,
267      service_staging_.address.host(),
268      net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
269      base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
270                 AsWeakPtr()));
271  a_transaction_->Start();
272}
273
274bool ServiceResolverImpl::CreateSrvTransaction() {
275  srv_transaction_ = mdns_client_->CreateTransaction(
276      net::dns_protocol::kTypeSRV, service_name_,
277      net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
278      net::MDnsTransaction::QUERY_NETWORK,
279      base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
280                 AsWeakPtr()));
281  return srv_transaction_->Start();
282}
283
284std::string ServiceResolverImpl::GetName() const {
285  return service_name_;
286}
287
288void ServiceResolverImpl::SrvRecordTransactionResponse(
289    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
290  srv_transaction_.reset();
291  if (status == net::MDnsTransaction::RESULT_RECORD) {
292    DCHECK(record);
293    service_staging_.address = RecordToAddress(record);
294    service_staging_.last_seen = record->time_created();
295    CreateATransaction();
296  } else {
297    ServiceNotFound(MDnsStatusToRequestStatus(status));
298  }
299}
300
301void ServiceResolverImpl::TxtRecordTransactionResponse(
302    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
303  txt_transaction_.reset();
304  if (status == net::MDnsTransaction::RESULT_RECORD) {
305    DCHECK(record);
306    service_staging_.metadata = RecordToMetadata(record);
307  } else {
308    service_staging_.metadata = std::vector<std::string>();
309  }
310
311  metadata_resolved_ = true;
312  AlertCallbackIfReady();
313}
314
315void ServiceResolverImpl::ARecordTransactionResponse(
316    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
317  a_transaction_.reset();
318
319  if (status == net::MDnsTransaction::RESULT_RECORD) {
320    DCHECK(record);
321    service_staging_.ip_address = RecordToIPAddress(record);
322  } else {
323    service_staging_.ip_address = net::IPAddressNumber();
324  }
325
326  address_resolved_ = true;
327  AlertCallbackIfReady();
328}
329
330void ServiceResolverImpl::AlertCallbackIfReady() {
331  if (metadata_resolved_ && address_resolved_) {
332    txt_transaction_.reset();
333    srv_transaction_.reset();
334    a_transaction_.reset();
335    if (!callback_.is_null())
336      callback_.Run(STATUS_SUCCESS, service_staging_);
337    service_staging_ = ServiceDescription();
338  }
339}
340
341void ServiceResolverImpl::ServiceNotFound(
342    ServiceResolver::RequestStatus status) {
343  txt_transaction_.reset();
344  srv_transaction_.reset();
345  a_transaction_.reset();
346  if (!callback_.is_null())
347    callback_.Run(status, ServiceDescription());
348}
349
350ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
351    net::MDnsTransaction::Result status) const {
352  switch (status) {
353    case net::MDnsTransaction::RESULT_RECORD:
354      return ServiceResolver::STATUS_SUCCESS;
355    case net::MDnsTransaction::RESULT_NO_RESULTS:
356      return ServiceResolver::STATUS_REQUEST_TIMEOUT;
357    case net::MDnsTransaction::RESULT_NSEC:
358      return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
359    case net::MDnsTransaction::RESULT_DONE:  // Pass through.
360    default:
361      NOTREACHED();
362      return ServiceResolver::STATUS_REQUEST_TIMEOUT;
363  }
364}
365
366const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
367    const net::RecordParsed* record) const {
368  DCHECK(record->type() == net::dns_protocol::kTypeTXT);
369  const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
370  DCHECK(txt_rdata);
371  return txt_rdata->texts();
372}
373
374net::HostPortPair ServiceResolverImpl::RecordToAddress(
375    const net::RecordParsed* record) const {
376  DCHECK(record->type() == net::dns_protocol::kTypeSRV);
377  const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
378  DCHECK(srv_rdata);
379  return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
380}
381
382const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
383    const net::RecordParsed* record) const {
384  DCHECK(record->type() == net::dns_protocol::kTypeA);
385  const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
386  DCHECK(a_rdata);
387  return a_rdata->address();
388}
389
390LocalDomainResolverImpl::LocalDomainResolverImpl(
391    const std::string& domain,
392    net::AddressFamily address_family,
393    const IPAddressCallback& callback,
394    net::MDnsClient* mdns_client)
395    : domain_(domain), address_family_(address_family), callback_(callback),
396      transaction_failures_(0), mdns_client_(mdns_client) {
397}
398
399LocalDomainResolverImpl::~LocalDomainResolverImpl() {
400}
401
402void LocalDomainResolverImpl::Start() {
403  if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
404      address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
405    transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
406    transaction_a_->Start();
407  }
408
409  if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
410      address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
411    transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
412    transaction_aaaa_->Start();
413  }
414}
415
416scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
417    uint16 type) {
418  return mdns_client_->CreateTransaction(
419      type, domain_, net::MDnsTransaction::SINGLE_RESULT |
420                     net::MDnsTransaction::QUERY_CACHE |
421                     net::MDnsTransaction::QUERY_NETWORK,
422      base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
423                 base::Unretained(this)));
424}
425
426void LocalDomainResolverImpl::OnTransactionComplete(
427    net::MDnsTransaction::Result result, const net::RecordParsed* record) {
428  if (result != net::MDnsTransaction::RESULT_RECORD &&
429      address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
430    transaction_failures_++;
431
432    if (transaction_failures_ < 2) {
433      return;
434    }
435  }
436
437  transaction_a_.reset();
438  transaction_aaaa_.reset();
439
440  net::IPAddressNumber address;
441  if (result == net::MDnsTransaction::RESULT_RECORD) {
442    if (record->type() == net::dns_protocol::kTypeA) {
443      const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
444      address = rdata->address();
445    } else {
446      DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
447      const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
448      address = rdata->address();
449    }
450  }
451
452  callback_.Run(result == net::MDnsTransaction::RESULT_RECORD, address);
453}
454
455}  // namespace local_discovery
456