1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/dns/mdns_client_impl.h" 6 7#include "base/bind.h" 8#include "base/message_loop/message_loop_proxy.h" 9#include "base/stl_util.h" 10#include "base/time/default_clock.h" 11#include "net/base/dns_util.h" 12#include "net/base/net_errors.h" 13#include "net/base/net_log.h" 14#include "net/base/rand_callback.h" 15#include "net/dns/dns_protocol.h" 16#include "net/dns/record_rdata.h" 17#include "net/udp/datagram_socket.h" 18 19// TODO(gene): Remove this temporary method of disabling NSEC support once it 20// becomes clear whether this feature should be 21// supported. http://crbug.com/255232 22#define ENABLE_NSEC 23 24namespace net { 25 26namespace { 27const char kMDnsMulticastGroupIPv4[] = "224.0.0.251"; 28const char kMDnsMulticastGroupIPv6[] = "FF02::FB"; 29const unsigned MDnsTransactionTimeoutSeconds = 3; 30} 31 32MDnsConnection::SocketHandler::SocketHandler( 33 MDnsConnection* connection, const IPEndPoint& multicast_addr, 34 MDnsConnection::SocketFactory* socket_factory) 35 : socket_(socket_factory->CreateSocket()), connection_(connection), 36 response_(new DnsResponse(dns_protocol::kMaxMulticastSize)), 37 multicast_addr_(multicast_addr) { 38} 39 40MDnsConnection::SocketHandler::~SocketHandler() { 41} 42 43int MDnsConnection::SocketHandler::Start() { 44 int rv = BindSocket(); 45 if (rv != OK) { 46 return rv; 47 } 48 49 return DoLoop(0); 50} 51 52int MDnsConnection::SocketHandler::DoLoop(int rv) { 53 do { 54 if (rv > 0) 55 connection_->OnDatagramReceived(response_.get(), recv_addr_, rv); 56 57 rv = socket_->RecvFrom( 58 response_->io_buffer(), 59 response_->io_buffer()->size(), 60 &recv_addr_, 61 base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived, 62 base::Unretained(this))); 63 } while (rv > 0); 64 65 if (rv != ERR_IO_PENDING) 66 return rv; 67 68 return OK; 69} 70 71void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) { 72 if (rv >= OK) 73 rv = DoLoop(rv); 74 75 if (rv != OK) 76 connection_->OnError(this, rv); 77} 78 79int MDnsConnection::SocketHandler::Send(IOBuffer* buffer, unsigned size) { 80 return socket_->SendTo( 81 buffer, size, multicast_addr_, 82 base::Bind(&MDnsConnection::SocketHandler::SendDone, 83 base::Unretained(this) )); 84} 85 86void MDnsConnection::SocketHandler::SendDone(int rv) { 87 // TODO(noamsml): Retry logic. 88} 89 90int MDnsConnection::SocketHandler::BindSocket() { 91 IPAddressNumber address_any(multicast_addr_.address().size()); 92 93 IPEndPoint bind_endpoint(address_any, multicast_addr_.port()); 94 95 socket_->AllowAddressReuse(); 96 int rv = socket_->Listen(bind_endpoint); 97 98 if (rv < OK) return rv; 99 100 socket_->SetMulticastLoopbackMode(false); 101 102 return socket_->JoinGroup(multicast_addr_.address()); 103} 104 105MDnsConnection::MDnsConnection(MDnsConnection::SocketFactory* socket_factory, 106 MDnsConnection::Delegate* delegate) 107 : socket_handler_ipv4_(this, 108 GetMDnsIPEndPoint(kMDnsMulticastGroupIPv4), 109 socket_factory), 110 socket_handler_ipv6_(this, 111 GetMDnsIPEndPoint(kMDnsMulticastGroupIPv6), 112 socket_factory), 113 delegate_(delegate) { 114} 115 116MDnsConnection::~MDnsConnection() { 117} 118 119int MDnsConnection::Init() { 120 int rv; 121 122 rv = socket_handler_ipv4_.Start(); 123 if (rv != OK) return rv; 124 rv = socket_handler_ipv6_.Start(); 125 if (rv != OK) return rv; 126 127 return OK; 128} 129 130int MDnsConnection::Send(IOBuffer* buffer, unsigned size) { 131 int rv; 132 133 rv = socket_handler_ipv4_.Send(buffer, size); 134 if (rv < OK && rv != ERR_IO_PENDING) return rv; 135 136 rv = socket_handler_ipv6_.Send(buffer, size); 137 if (rv < OK && rv != ERR_IO_PENDING) return rv; 138 139 return OK; 140} 141 142void MDnsConnection::OnError(SocketHandler* loop, 143 int error) { 144 // TODO(noamsml): Specific handling of intermittent errors that can be handled 145 // in the connection. 146 delegate_->OnConnectionError(error); 147} 148 149IPEndPoint MDnsConnection::GetMDnsIPEndPoint(const char* address) { 150 IPAddressNumber multicast_group_number; 151 bool success = ParseIPLiteralToNumber(address, 152 &multicast_group_number); 153 DCHECK(success); 154 return IPEndPoint(multicast_group_number, 155 dns_protocol::kDefaultPortMulticast); 156} 157 158void MDnsConnection::OnDatagramReceived( 159 DnsResponse* response, 160 const IPEndPoint& recv_addr, 161 int bytes_read) { 162 // TODO(noamsml): More sophisticated error handling. 163 DCHECK_GT(bytes_read, 0); 164 delegate_->HandlePacket(response, bytes_read); 165} 166 167class MDnsConnectionSocketFactoryImpl 168 : public MDnsConnection::SocketFactory { 169 public: 170 MDnsConnectionSocketFactoryImpl(); 171 virtual ~MDnsConnectionSocketFactoryImpl(); 172 173 virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE; 174}; 175 176MDnsConnectionSocketFactoryImpl::MDnsConnectionSocketFactoryImpl() { 177} 178 179MDnsConnectionSocketFactoryImpl::~MDnsConnectionSocketFactoryImpl() { 180} 181 182scoped_ptr<DatagramServerSocket> 183MDnsConnectionSocketFactoryImpl::CreateSocket() { 184 return scoped_ptr<DatagramServerSocket>(new UDPServerSocket( 185 NULL, NetLog::Source())); 186} 187 188// static 189scoped_ptr<MDnsConnection::SocketFactory> 190MDnsConnection::SocketFactory::CreateDefault() { 191 return scoped_ptr<MDnsConnection::SocketFactory>( 192 new MDnsConnectionSocketFactoryImpl); 193} 194 195MDnsClientImpl::Core::Core(MDnsClientImpl* client, 196 MDnsConnection::SocketFactory* socket_factory) 197 : client_(client), connection_(new MDnsConnection(socket_factory, this)) { 198} 199 200MDnsClientImpl::Core::~Core() { 201 STLDeleteValues(&listeners_); 202} 203 204bool MDnsClientImpl::Core::Init() { 205 return connection_->Init() == OK; 206} 207 208bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) { 209 std::string name_dns; 210 if (!DNSDomainFromDot(name, &name_dns)) 211 return false; 212 213 DnsQuery query(0, name_dns, rrtype); 214 query.set_flags(0); // Remove the RD flag from the query. It is unneeded. 215 216 return connection_->Send(query.io_buffer(), query.io_buffer()->size()) == OK; 217} 218 219void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, 220 int bytes_read) { 221 unsigned offset; 222 // Note: We store cache keys rather than record pointers to avoid 223 // erroneous behavior in case a packet contains multiple exclusive 224 // records with the same type and name. 225 std::map<MDnsCache::Key, MDnsListener::UpdateType> update_keys; 226 227 if (!response->InitParseWithoutQuery(bytes_read)) { 228 LOG(WARNING) << "Could not understand an mDNS packet."; 229 return; // Message is unreadable. 230 } 231 232 // TODO(noamsml): duplicate query suppression. 233 if (!(response->flags() & dns_protocol::kFlagResponse)) 234 return; // Message is a query. ignore it. 235 236 DnsRecordParser parser = response->Parser(); 237 unsigned answer_count = response->answer_count() + 238 response->additional_answer_count(); 239 240 for (unsigned i = 0; i < answer_count; i++) { 241 offset = parser.GetOffset(); 242 scoped_ptr<const RecordParsed> record = RecordParsed::CreateFrom( 243 &parser, base::Time::Now()); 244 245 if (!record) { 246 LOG(WARNING) << "Could not understand an mDNS record."; 247 248 if (offset == parser.GetOffset()) { 249 LOG(WARNING) << "Abandoned parsing the rest of the packet."; 250 return; // The parser did not advance, abort reading the packet. 251 } else { 252 continue; // We may be able to extract other records from the packet. 253 } 254 } 255 256 if ((record->klass() & dns_protocol::kMDnsClassMask) != 257 dns_protocol::kClassIN) { 258 LOG(WARNING) << "Received an mDNS record with non-IN class. Ignoring."; 259 continue; // Ignore all records not in the IN class. 260 } 261 262 MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get()); 263 MDnsCache::UpdateType update = cache_.UpdateDnsRecord(record.Pass()); 264 265 // Cleanup time may have changed. 266 ScheduleCleanup(cache_.next_expiration()); 267 268 if (update != MDnsCache::NoChange) { 269 MDnsListener::UpdateType update_external; 270 271 switch (update) { 272 case MDnsCache::RecordAdded: 273 update_external = MDnsListener::RECORD_ADDED; 274 break; 275 case MDnsCache::RecordChanged: 276 update_external = MDnsListener::RECORD_CHANGED; 277 break; 278 case MDnsCache::NoChange: 279 default: 280 NOTREACHED(); 281 // Dummy assignment to suppress compiler warning. 282 update_external = MDnsListener::RECORD_CHANGED; 283 break; 284 } 285 286 update_keys.insert(std::make_pair(update_key, update_external)); 287 } 288 } 289 290 for (std::map<MDnsCache::Key, MDnsListener::UpdateType>::iterator i = 291 update_keys.begin(); i != update_keys.end(); i++) { 292 const RecordParsed* record = cache_.LookupKey(i->first); 293 if (!record) 294 continue; 295 296 if (record->type() == dns_protocol::kTypeNSEC) { 297#if defined(ENABLE_NSEC) 298 NotifyNsecRecord(record); 299#endif 300 } else { 301 AlertListeners(i->second, ListenerKey(record->name(), record->type()), 302 record); 303 } 304 } 305} 306 307void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) { 308 DCHECK_EQ(dns_protocol::kTypeNSEC, record->type()); 309 const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>(); 310 DCHECK(rdata); 311 312 // Remove all cached records matching the nonexistent RR types. 313 std::vector<const RecordParsed*> records_to_remove; 314 315 cache_.FindDnsRecords(0, record->name(), &records_to_remove, 316 base::Time::Now()); 317 318 for (std::vector<const RecordParsed*>::iterator i = records_to_remove.begin(); 319 i != records_to_remove.end(); i++) { 320 if ((*i)->type() == dns_protocol::kTypeNSEC) 321 continue; 322 if (!rdata->GetBit((*i)->type())) { 323 scoped_ptr<const RecordParsed> record_removed = cache_.RemoveRecord((*i)); 324 DCHECK(record_removed); 325 OnRecordRemoved(record_removed.get()); 326 } 327 } 328 329 // Alert all listeners waiting for the nonexistent RR types. 330 ListenerMap::iterator i = 331 listeners_.upper_bound(ListenerKey(record->name(), 0)); 332 for (; i != listeners_.end() && i->first.first == record->name(); i++) { 333 if (!rdata->GetBit(i->first.second)) { 334 FOR_EACH_OBSERVER(MDnsListenerImpl, *i->second, AlertNsecRecord()); 335 } 336 } 337} 338 339void MDnsClientImpl::Core::OnConnectionError(int error) { 340 // TODO(noamsml): On connection error, recreate connection and flush cache. 341} 342 343void MDnsClientImpl::Core::AlertListeners( 344 MDnsListener::UpdateType update_type, 345 const ListenerKey& key, 346 const RecordParsed* record) { 347 ListenerMap::iterator listener_map_iterator = listeners_.find(key); 348 if (listener_map_iterator == listeners_.end()) return; 349 350 FOR_EACH_OBSERVER(MDnsListenerImpl, *listener_map_iterator->second, 351 AlertDelegate(update_type, record)); 352} 353 354void MDnsClientImpl::Core::AddListener( 355 MDnsListenerImpl* listener) { 356 ListenerKey key(listener->GetName(), listener->GetType()); 357 std::pair<ListenerMap::iterator, bool> observer_insert_result = 358 listeners_.insert( 359 make_pair(key, static_cast<ObserverList<MDnsListenerImpl>*>(NULL))); 360 361 // If an equivalent key does not exist, actually create the observer list. 362 if (observer_insert_result.second) 363 observer_insert_result.first->second = new ObserverList<MDnsListenerImpl>(); 364 365 ObserverList<MDnsListenerImpl>* observer_list = 366 observer_insert_result.first->second; 367 368 observer_list->AddObserver(listener); 369} 370 371void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) { 372 ListenerKey key(listener->GetName(), listener->GetType()); 373 ListenerMap::iterator observer_list_iterator = listeners_.find(key); 374 375 DCHECK(observer_list_iterator != listeners_.end()); 376 DCHECK(observer_list_iterator->second->HasObserver(listener)); 377 378 observer_list_iterator->second->RemoveObserver(listener); 379 380 // Remove the observer list from the map if it is empty 381 if (observer_list_iterator->second->size() == 0) { 382 // Schedule the actual removal for later in case the listener removal 383 // happens while iterating over the observer list. 384 base::MessageLoop::current()->PostTask( 385 FROM_HERE, base::Bind( 386 &MDnsClientImpl::Core::CleanupObserverList, AsWeakPtr(), key)); 387 } 388} 389 390void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) { 391 ListenerMap::iterator found = listeners_.find(key); 392 if (found != listeners_.end() && found->second->size() == 0) { 393 delete found->second; 394 listeners_.erase(found); 395 } 396} 397 398void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) { 399 // Cleanup is already scheduled, no need to do anything. 400 if (cleanup == scheduled_cleanup_) return; 401 scheduled_cleanup_ = cleanup; 402 403 // This cancels the previously scheduled cleanup. 404 cleanup_callback_.Reset(base::Bind( 405 &MDnsClientImpl::Core::DoCleanup, base::Unretained(this))); 406 407 // If |cleanup| is empty, then no cleanup necessary. 408 if (cleanup != base::Time()) { 409 base::MessageLoop::current()->PostDelayedTask( 410 FROM_HERE, 411 cleanup_callback_.callback(), 412 cleanup - base::Time::Now()); 413 } 414} 415 416void MDnsClientImpl::Core::DoCleanup() { 417 cache_.CleanupRecords(base::Time::Now(), base::Bind( 418 &MDnsClientImpl::Core::OnRecordRemoved, base::Unretained(this))); 419 420 ScheduleCleanup(cache_.next_expiration()); 421} 422 423void MDnsClientImpl::Core::OnRecordRemoved( 424 const RecordParsed* record) { 425 AlertListeners(MDnsListener::RECORD_REMOVED, 426 ListenerKey(record->name(), record->type()), record); 427} 428 429void MDnsClientImpl::Core::QueryCache( 430 uint16 rrtype, const std::string& name, 431 std::vector<const RecordParsed*>* records) const { 432 cache_.FindDnsRecords(rrtype, name, records, base::Time::Now()); 433} 434 435MDnsClientImpl::MDnsClientImpl( 436 scoped_ptr<MDnsConnection::SocketFactory> socket_factory) 437 : socket_factory_(socket_factory.Pass()) { 438} 439 440MDnsClientImpl::~MDnsClientImpl() { 441} 442 443bool MDnsClientImpl::StartListening() { 444 DCHECK(!core_.get()); 445 core_.reset(new Core(this, socket_factory_.get())); 446 if (!core_->Init()) { 447 core_.reset(); 448 return false; 449 } 450 return true; 451} 452 453void MDnsClientImpl::StopListening() { 454 core_.reset(); 455} 456 457bool MDnsClientImpl::IsListening() const { 458 return core_.get() != NULL; 459} 460 461scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener( 462 uint16 rrtype, 463 const std::string& name, 464 MDnsListener::Delegate* delegate) { 465 return scoped_ptr<net::MDnsListener>( 466 new MDnsListenerImpl(rrtype, name, delegate, this)); 467} 468 469scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction( 470 uint16 rrtype, 471 const std::string& name, 472 int flags, 473 const MDnsTransaction::ResultCallback& callback) { 474 return scoped_ptr<MDnsTransaction>( 475 new MDnsTransactionImpl(rrtype, name, flags, callback, this)); 476} 477 478MDnsListenerImpl::MDnsListenerImpl( 479 uint16 rrtype, 480 const std::string& name, 481 MDnsListener::Delegate* delegate, 482 MDnsClientImpl* client) 483 : rrtype_(rrtype), name_(name), client_(client), delegate_(delegate), 484 started_(false) { 485} 486 487bool MDnsListenerImpl::Start() { 488 DCHECK(!started_); 489 490 started_ = true; 491 492 DCHECK(client_->core()); 493 client_->core()->AddListener(this); 494 495 return true; 496} 497 498MDnsListenerImpl::~MDnsListenerImpl() { 499 if (started_) { 500 DCHECK(client_->core()); 501 client_->core()->RemoveListener(this); 502 } 503} 504 505const std::string& MDnsListenerImpl::GetName() const { 506 return name_; 507} 508 509uint16 MDnsListenerImpl::GetType() const { 510 return rrtype_; 511} 512 513void MDnsListenerImpl::AlertDelegate(MDnsListener::UpdateType update_type, 514 const RecordParsed* record) { 515 DCHECK(started_); 516 delegate_->OnRecordUpdate(update_type, record); 517} 518 519void MDnsListenerImpl::AlertNsecRecord() { 520 DCHECK(started_); 521 delegate_->OnNsecRecord(name_, rrtype_); 522} 523 524MDnsTransactionImpl::MDnsTransactionImpl( 525 uint16 rrtype, 526 const std::string& name, 527 int flags, 528 const MDnsTransaction::ResultCallback& callback, 529 MDnsClientImpl* client) 530 : rrtype_(rrtype), name_(name), callback_(callback), client_(client), 531 started_(false), flags_(flags) { 532 DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_); 533 DCHECK(flags_ & MDnsTransaction::QUERY_CACHE || 534 flags_ & MDnsTransaction::QUERY_NETWORK); 535} 536 537MDnsTransactionImpl::~MDnsTransactionImpl() { 538 timeout_.Cancel(); 539} 540 541bool MDnsTransactionImpl::Start() { 542 DCHECK(!started_); 543 started_ = true; 544 545 base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr(); 546 if (flags_ & MDnsTransaction::QUERY_CACHE) { 547 ServeRecordsFromCache(); 548 549 if (!weak_this || !is_active()) return true; 550 } 551 552 if (flags_ & MDnsTransaction::QUERY_NETWORK) { 553 return QueryAndListen(); 554 } 555 556 // If this is a cache only query, signal that the transaction is over 557 // immediately. 558 SignalTransactionOver(); 559 return true; 560} 561 562const std::string& MDnsTransactionImpl::GetName() const { 563 return name_; 564} 565 566uint16 MDnsTransactionImpl::GetType() const { 567 return rrtype_; 568} 569 570void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) { 571 DCHECK(started_); 572 OnRecordUpdate(MDnsListener::RECORD_ADDED, record); 573} 574 575void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result, 576 const RecordParsed* record) { 577 DCHECK(started_); 578 if (!is_active()) return; 579 580 // Ensure callback is run after touching all class state, so that 581 // the callback can delete the transaction. 582 MDnsTransaction::ResultCallback callback = callback_; 583 584 // Reset the transaction if it expects a single result, or if the result 585 // is a final one (everything except for a record). 586 if (flags_ & MDnsTransaction::SINGLE_RESULT || 587 result != MDnsTransaction::RESULT_RECORD) { 588 Reset(); 589 } 590 591 callback.Run(result, record); 592} 593 594void MDnsTransactionImpl::Reset() { 595 callback_.Reset(); 596 listener_.reset(); 597 timeout_.Cancel(); 598} 599 600void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update, 601 const RecordParsed* record) { 602 DCHECK(started_); 603 if (update == MDnsListener::RECORD_ADDED || 604 update == MDnsListener::RECORD_CHANGED) 605 TriggerCallback(MDnsTransaction::RESULT_RECORD, record); 606} 607 608void MDnsTransactionImpl::SignalTransactionOver() { 609 DCHECK(started_); 610 if (flags_ & MDnsTransaction::SINGLE_RESULT) { 611 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL); 612 } else { 613 TriggerCallback(MDnsTransaction::RESULT_DONE, NULL); 614 } 615} 616 617void MDnsTransactionImpl::ServeRecordsFromCache() { 618 std::vector<const RecordParsed*> records; 619 base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr(); 620 621 if (client_->core()) { 622 client_->core()->QueryCache(rrtype_, name_, &records); 623 for (std::vector<const RecordParsed*>::iterator i = records.begin(); 624 i != records.end() && weak_this; ++i) { 625 weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i); 626 } 627 628#if defined(ENABLE_NSEC) 629 if (records.empty()) { 630 DCHECK(weak_this); 631 client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records); 632 if (!records.empty()) { 633 const NsecRecordRdata* rdata = 634 records.front()->rdata<NsecRecordRdata>(); 635 DCHECK(rdata); 636 if (!rdata->GetBit(rrtype_)) 637 weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, NULL); 638 } 639 } 640#endif 641 } 642} 643 644bool MDnsTransactionImpl::QueryAndListen() { 645 listener_ = client_->CreateListener(rrtype_, name_, this); 646 if (!listener_->Start()) 647 return false; 648 649 DCHECK(client_->core()); 650 if (!client_->core()->SendQuery(rrtype_, name_)) 651 return false; 652 653 timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver, 654 AsWeakPtr())); 655 base::MessageLoop::current()->PostDelayedTask( 656 FROM_HERE, 657 timeout_.callback(), 658 base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds)); 659 660 return true; 661} 662 663void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) { 664 TriggerCallback(RESULT_NSEC, NULL); 665} 666 667void MDnsTransactionImpl::OnCachePurged() { 668 // TODO(noamsml): Cache purge situations not yet implemented 669} 670 671} // namespace net 672