1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/extensions/api/dial/dial_service.h"
6
7#include <algorithm>
8#include <set>
9#include <utility>
10
11#include "base/basictypes.h"
12#include "base/callback.h"
13#include "base/logging.h"
14#include "base/rand_util.h"
15#include "base/strings/string_number_conversions.h"
16#include "base/strings/stringprintf.h"
17#include "base/time/time.h"
18#include "chrome/browser/extensions/api/dial/dial_device_data.h"
19#include "chrome/common/chrome_version_info.h"
20#include "content/public/browser/browser_thread.h"
21#include "net/base/completion_callback.h"
22#include "net/base/io_buffer.h"
23#include "net/base/ip_endpoint.h"
24#include "net/base/net_errors.h"
25#include "net/base/net_util.h"
26#include "net/http/http_response_headers.h"
27#include "net/http/http_util.h"
28#include "url/gurl.h"
29#if defined(OS_CHROMEOS)
30#include "chromeos/network/network_state.h"
31#include "chromeos/network/network_state_handler.h"
32#include "third_party/cros_system_api/dbus/service_constants.h"
33#endif
34
35using base::Time;
36using base::TimeDelta;
37using content::BrowserThread;
38using net::HttpResponseHeaders;
39using net::HttpUtil;
40using net::IOBufferWithSize;
41using net::IPAddressNumber;
42using net::IPEndPoint;
43using net::NetworkInterface;
44using net::NetworkInterfaceList;
45using net::StringIOBuffer;
46using net::UDPSocket;
47
48namespace extensions {
49
50namespace {
51
// The total number of M-SEARCH requests to send per discovery cycle.
const int kDialMaxRequests = 4;

// The interval, in milliseconds, to wait between successive requests.
const int kDialRequestIntervalMillis = 1000;

// The maximum delay, in seconds, a device may wait before responding. Sent as
// the MX header of each search request.
const int kDialMaxResponseDelaySecs = 1;

// How long, in seconds, responses are still expected after the last M-SEARCH
// request. Combined with the request schedule, this determines how long a
// discovery cycle stays active (see |finish_delay_| in the constructor).
const int kDialResponseTimeoutSecs = 2;

// The SSDP multicast group address that search requests are sent to; it also
// appears in the HOST header of the request.
const char kDialRequestAddress[] = "239.255.255.250";

// The UDP port number that search requests are sent to; it also appears in
// the HOST header of the request.
const int kDialRequestPort = 1900;

// The DIAL service type, sent as the ST header of the search request.
const char kDialSearchType[] = "urn:dial-multiscreen-org:service:dial:1";

// SSDP headers parsed from the response. LOCATION and USN are required for a
// response to be accepted, CONFIGID.UPNP.ORG is optional, and CACHE-CONTROL
// is read but not used yet.
const char kSsdpLocationHeader[] = "LOCATION";
const char kSsdpCacheControlHeader[] = "CACHE-CONTROL";
const char kSsdpConfigIdHeader[] = "CONFIGID.UPNP.ORG";
const char kSsdpUsnHeader[] = "USN";

// The receive buffer size, in bytes. A read reported as larger than this is
// dropped by HandleResponse().
const int kDialRecvBufferSize = 1500;
81
82// Gets a specific header from |headers| and puts it in |value|.
83bool GetHeader(HttpResponseHeaders* headers, const char* name,
84               std::string* value) {
85  return headers->EnumerateHeader(NULL, std::string(name), value);
86}
87
88// Returns the request string.
89std::string BuildRequest() {
90  // Extra line at the end to make UPnP lib happy.
91  chrome::VersionInfo version;
92  std::string request(base::StringPrintf(
93      "M-SEARCH * HTTP/1.1\r\n"
94      "HOST: %s:%i\r\n"
95      "MAN: \"ssdp:discover\"\r\n"
96      "MX: %d\r\n"
97      "ST: %s\r\n"
98      "USER-AGENT: %s/%s %s\r\n"
99      "\r\n",
100      kDialRequestAddress,
101      kDialRequestPort,
102      kDialMaxResponseDelaySecs,
103      kDialSearchType,
104      version.Name().c_str(),
105      version.Version().c_str(),
106      version.OSType().c_str()));
107  // 1500 is a good MTU value for most Ethernet LANs.
108  DCHECK(request.size() <= 1500);
109  return request;
110}
111
112#if !defined(OS_CHROMEOS)
113void GetNetworkListOnFileThread(
114    const scoped_refptr<base::MessageLoopProxy>& loop,
115    const base::Callback<void(const NetworkInterfaceList& networks)>& cb) {
116  NetworkInterfaceList list;
117  bool success = net::GetNetworkList(
118      &list, net::INCLUDE_HOST_SCOPE_VIRTUAL_INTERFACES);
119  if (!success)
120    VLOG(1) << "Could not retrieve network list!";
121
122  loop->PostTask(FROM_HERE, base::Bind(cb, list));
123}
124
125#else
126
127// Finds the IP address of the preferred interface of network type |type|
128// to bind the socket and inserts the address into |bind_address_list|. This
129// ChromeOS version can prioritize wifi and ethernet interfaces.
130void InsertBestBindAddressChromeOS(
131    const chromeos::NetworkTypePattern& type,
132    std::vector<IPAddressNumber>* bind_address_list) {
133  const chromeos::NetworkState* state = chromeos::NetworkHandler::Get()
134      ->network_state_handler()->ConnectedNetworkByType(type);
135  IPAddressNumber bind_ip_address;
136  if (state
137      && net::ParseIPLiteralToNumber(state->ip_address(), &bind_ip_address)
138      && bind_ip_address.size() == net::kIPv4AddressSize) {
139    VLOG(2) << "Found " << state->type() << ", " << state->name() << ": "
140            << state->ip_address();
141    bind_address_list->push_back(bind_ip_address);
142  }
143}
144#endif  // !defined(OS_CHROMEOS)
145
146}  // namespace
147
// Stores the callbacks that report socket events to the owner. Each one is
// invoked as follows:
//  - |discovery_request_cb| runs after each successful write, from
//    OnSocketWrite().
//  - |device_discovered_cb| runs for each successfully parsed response, from
//    HandleResponse().
//  - |on_error_cb| runs after a failed operation has closed the socket, from
//    CheckResult().
// No socket is opened here; see CreateAndBindSocket().
DialServiceImpl::DialSocket::DialSocket(
    const base::Closure& discovery_request_cb,
    const base::Callback<void(const DialDeviceData&)>& device_discovered_cb,
    const base::Closure& on_error_cb)
    : discovery_request_cb_(discovery_request_cb),
      device_discovered_cb_(device_discovered_cb),
      on_error_cb_(on_error_cb),
      is_writing_(false),
      is_reading_(false) {
}

// Destruction is subject to the same thread check (DCHECK) as the other
// DialSocket methods.
DialServiceImpl::DialSocket::~DialSocket() {
  DCHECK(thread_checker_.CalledOnValidThread());
}
162
163bool DialServiceImpl::DialSocket::CreateAndBindSocket(
164    const IPAddressNumber& bind_ip_address,
165    net::NetLog* net_log,
166    net::NetLog::Source net_log_source) {
167  DCHECK(thread_checker_.CalledOnValidThread());
168  DCHECK(!socket_.get());
169  DCHECK(bind_ip_address.size() == net::kIPv4AddressSize);
170
171  net::RandIntCallback rand_cb = base::Bind(&base::RandInt);
172  socket_.reset(new UDPSocket(net::DatagramSocket::RANDOM_BIND,
173                              rand_cb,
174                              net_log,
175                              net_log_source));
176  socket_->AllowBroadcast();
177
178  // 0 means bind a random port
179  IPEndPoint address(bind_ip_address, 0);
180
181  if (!CheckResult("Bind", socket_->Bind(address)))
182    return false;
183
184  DCHECK(socket_.get());
185
186  recv_buffer_ = new IOBufferWithSize(kDialRecvBufferSize);
187  return ReadSocket();
188}
189
190void DialServiceImpl::DialSocket::SendOneRequest(
191    const net::IPEndPoint& send_address,
192    const scoped_refptr<net::StringIOBuffer>& send_buffer) {
193  if (!socket_.get()) {
194    VLOG(1) << "Socket not connected.";
195    return;
196  }
197
198  if (is_writing_) {
199    VLOG(1) << "Already writing.";
200    return;
201  }
202
203  is_writing_ = true;
204  int result = socket_->SendTo(
205      send_buffer.get(), send_buffer->size(), send_address,
206      base::Bind(&DialServiceImpl::DialSocket::OnSocketWrite,
207                 base::Unretained(this),
208                 send_buffer->size()));
209  bool result_ok = CheckResult("SendTo", result);
210  if (result_ok && result > 0) {
211    // Synchronous write.
212    OnSocketWrite(send_buffer->size(), result);
213  }
214}
215
216bool DialServiceImpl::DialSocket::IsClosed() {
217  DCHECK(thread_checker_.CalledOnValidThread());
218  return !socket_.get();
219}
220
221bool DialServiceImpl::DialSocket::CheckResult(const char* operation,
222                                              int result) {
223  DCHECK(thread_checker_.CalledOnValidThread());
224  VLOG(2) << "Operation " << operation << " result " << result;
225  if (result < net::OK && result != net::ERR_IO_PENDING) {
226    Close();
227    std::string error_str(net::ErrorToString(result));
228    VLOG(1) << "dial socket error: " << error_str;
229    on_error_cb_.Run();
230    return false;
231  }
232  return true;
233}
234
235void DialServiceImpl::DialSocket::Close() {
236  DCHECK(thread_checker_.CalledOnValidThread());
237  is_reading_ = false;
238  is_writing_ = false;
239  socket_.reset();
240}
241
242void DialServiceImpl::DialSocket::OnSocketWrite(int send_buffer_size,
243                                                int result) {
244  DCHECK(thread_checker_.CalledOnValidThread());
245  is_writing_ = false;
246  if (!CheckResult("OnSocketWrite", result))
247    return;
248  if (result != send_buffer_size) {
249    VLOG(1) << "Sent " << result << " chars, expected "
250            << send_buffer_size << " chars";
251  }
252  discovery_request_cb_.Run();
253}
254
255bool DialServiceImpl::DialSocket::ReadSocket() {
256  DCHECK(thread_checker_.CalledOnValidThread());
257  if (!socket_.get()) {
258    VLOG(1) << "Socket not connected.";
259    return false;
260  }
261
262  if (is_reading_) {
263    VLOG(1) << "Already reading.";
264    return false;
265  }
266
267  int result = net::OK;
268  bool result_ok = true;
269  do {
270    is_reading_ = true;
271    result = socket_->RecvFrom(
272        recv_buffer_.get(),
273        kDialRecvBufferSize, &recv_address_,
274        base::Bind(&DialServiceImpl::DialSocket::OnSocketRead,
275                   base::Unretained(this)));
276    result_ok = CheckResult("RecvFrom", result);
277    if (result != net::ERR_IO_PENDING)
278      is_reading_ = false;
279    if (result_ok && result > 0) {
280      // Synchronous read.
281      HandleResponse(result);
282    }
283  } while (result_ok && result != net::OK && result != net::ERR_IO_PENDING);
284  return result_ok;
285}
286
287void DialServiceImpl::DialSocket::OnSocketRead(int result) {
288  DCHECK(thread_checker_.CalledOnValidThread());
289  is_reading_ = false;
290  if (!CheckResult("OnSocketRead", result))
291    return;
292  if (result > 0)
293    HandleResponse(result);
294
295  // Await next response.
296  ReadSocket();
297}
298
299void DialServiceImpl::DialSocket::HandleResponse(int bytes_read) {
300  DCHECK(thread_checker_.CalledOnValidThread());
301  DCHECK_GT(bytes_read, 0);
302  if (bytes_read > kDialRecvBufferSize) {
303    VLOG(1) << bytes_read << " > " << kDialRecvBufferSize << "!?";
304    return;
305  }
306  VLOG(2) << "Read " << bytes_read << " bytes from "
307          << recv_address_.ToString();
308
309  std::string response(recv_buffer_->data(), bytes_read);
310  Time response_time = Time::Now();
311
312  // Attempt to parse response, notify observers if successful.
313  DialDeviceData parsed_device;
314  if (ParseResponse(response, response_time, &parsed_device))
315    device_discovered_cb_.Run(parsed_device);
316}
317
318// static
319bool DialServiceImpl::DialSocket::ParseResponse(
320    const std::string& response,
321    const base::Time& response_time,
322    DialDeviceData* device) {
323  int headers_end = HttpUtil::LocateEndOfHeaders(response.c_str(),
324                                                 response.size());
325  if (headers_end < 1) {
326    VLOG(1) << "Headers invalid or empty, ignoring: " << response;
327    return false;
328  }
329  std::string raw_headers =
330      HttpUtil::AssembleRawHeaders(response.c_str(), headers_end);
331  VLOG(3) << "raw_headers: " << raw_headers << "\n";
332  scoped_refptr<HttpResponseHeaders> headers =
333      new HttpResponseHeaders(raw_headers);
334
335  std::string device_url_str;
336  if (!GetHeader(headers.get(), kSsdpLocationHeader, &device_url_str) ||
337      device_url_str.empty()) {
338    VLOG(1) << "No LOCATION header found.";
339    return false;
340  }
341
342  GURL device_url(device_url_str);
343  if (!DialDeviceData::IsDeviceDescriptionUrl(device_url)) {
344    VLOG(1) << "URL " << device_url_str << " not valid.";
345    return false;
346  }
347
348  std::string device_id;
349  if (!GetHeader(headers.get(), kSsdpUsnHeader, &device_id) ||
350      device_id.empty()) {
351    VLOG(1) << "No USN header found.";
352    return false;
353  }
354
355  device->set_device_id(device_id);
356  device->set_device_description_url(device_url);
357  device->set_response_time(response_time);
358
359  // TODO(mfoltz): Parse the max-age value from the cache control header.
360  // http://crbug.com/165289
361  std::string cache_control;
362  GetHeader(headers.get(), kSsdpCacheControlHeader, &cache_control);
363
364  std::string config_id;
365  int config_id_int;
366  if (GetHeader(headers.get(), kSsdpConfigIdHeader, &config_id) &&
367      base::StringToInt(config_id, &config_id_int)) {
368    device->set_config_id(config_id_int);
369  } else {
370    VLOG(1) << "Malformed or missing " << kSsdpConfigIdHeader << ": "
371            << config_id;
372  }
373
374  return true;
375}
376
377DialServiceImpl::DialServiceImpl(net::NetLog* net_log)
378    : discovery_active_(false),
379      num_requests_sent_(0),
380      max_requests_(kDialMaxRequests),
381      finish_delay_(TimeDelta::FromMilliseconds((kDialMaxRequests - 1) *
382                                                kDialRequestIntervalMillis) +
383                    TimeDelta::FromSeconds(kDialResponseTimeoutSecs)),
384      request_interval_(
385          TimeDelta::FromMilliseconds(kDialRequestIntervalMillis)) {
386  IPAddressNumber address;
387  bool success = net::ParseIPLiteralToNumber(kDialRequestAddress, &address);
388  DCHECK(success);
389  send_address_ = IPEndPoint(address, kDialRequestPort);
390  send_buffer_ = new StringIOBuffer(BuildRequest());
391  net_log_ = net_log;
392  net_log_source_.type = net::NetLog::SOURCE_UDP_SOCKET;
393  net_log_source_.id = net_log_->NextID();
394}
395
396DialServiceImpl::~DialServiceImpl() {
397  DCHECK(thread_checker_.CalledOnValidThread());
398}
399
400void DialServiceImpl::AddObserver(Observer* observer) {
401  DCHECK(thread_checker_.CalledOnValidThread());
402  observer_list_.AddObserver(observer);
403}
404
405void DialServiceImpl::RemoveObserver(Observer* observer) {
406  DCHECK(thread_checker_.CalledOnValidThread());
407  observer_list_.RemoveObserver(observer);
408}
409
410bool DialServiceImpl::HasObserver(Observer* observer) {
411  DCHECK(thread_checker_.CalledOnValidThread());
412  return observer_list_.HasObserver(observer);
413}
414
415bool DialServiceImpl::Discover() {
416  DCHECK(thread_checker_.CalledOnValidThread());
417  if (discovery_active_) {
418    VLOG(2) << "Discovery is already active - returning.";
419    return false;
420  }
421  discovery_active_ = true;
422
423  VLOG(2) << "Discovery started.";
424
425  StartDiscovery();
426  return true;
427}
428
429void DialServiceImpl::StartDiscovery() {
430  DCHECK(thread_checker_.CalledOnValidThread());
431  DCHECK(discovery_active_);
432  if (HasOpenSockets()) {
433    VLOG(2) << "Calling StartDiscovery() with open sockets. Returning.";
434    return;
435  }
436
437#if defined(OS_CHROMEOS)
438  // The ChromeOS specific version of getting network interfaces does not
439  // require trampolining to another thread, and contains additional interface
440  // information such as interface types (i.e. wifi vs cellular).
441  std::vector<IPAddressNumber> chrome_os_address_list;
442  InsertBestBindAddressChromeOS(chromeos::NetworkTypePattern::Ethernet(),
443                                &chrome_os_address_list);
444  InsertBestBindAddressChromeOS(chromeos::NetworkTypePattern::WiFi(),
445                                &chrome_os_address_list);
446  DiscoverOnAddresses(chrome_os_address_list);
447
448#else
449  BrowserThread::PostTask(BrowserThread::FILE, FROM_HERE, base::Bind(
450      &GetNetworkListOnFileThread,
451      base::MessageLoopProxy::current(), base::Bind(
452          &DialServiceImpl::SendNetworkList, AsWeakPtr())));
453#endif
454}
455
456void DialServiceImpl::SendNetworkList(const NetworkInterfaceList& networks) {
457  DCHECK(thread_checker_.CalledOnValidThread());
458  typedef std::pair<uint32, net::AddressFamily> InterfaceIndexAddressFamily;
459  std::set<InterfaceIndexAddressFamily> interface_index_addr_family_seen;
460  std::vector<IPAddressNumber> ip_addresses;
461
462  // Binds a socket to each IPv4 network interface found. Note that
463  // there may be duplicates in |networks|, so address family + interface index
464  // is used to identify unique interfaces.
465  // TODO(mfoltz): Support IPV6 multicast.  http://crbug.com/165286
466  for (NetworkInterfaceList::const_iterator iter = networks.begin();
467       iter != networks.end(); ++iter) {
468    net::AddressFamily addr_family = net::GetAddressFamily(iter->address);
469    VLOG(2) << "Found " << iter->name << ", "
470            << net::IPAddressToString(iter->address)
471            << ", address family: " << addr_family;
472    if (addr_family == net::ADDRESS_FAMILY_IPV4) {
473      InterfaceIndexAddressFamily interface_index_addr_family =
474          std::make_pair(iter->interface_index, addr_family);
475      bool inserted = interface_index_addr_family_seen
476          .insert(interface_index_addr_family)
477          .second;
478      // We have not seen this interface before, so add its IP address to the
479      // discovery list.
480      if (inserted) {
481        VLOG(2) << "Encountered "
482                << "interface index: " << iter->interface_index << ", "
483                << "address family: " << addr_family << " for the first time, "
484                << "adding IP address " << net::IPAddressToString(iter->address)
485                << " to list.";
486        ip_addresses.push_back(iter->address);
487      } else {
488        VLOG(2) << "Already encountered "
489                << "interface index: " << iter->interface_index << ", "
490                << "address family: " << addr_family << " before, not adding.";
491      }
492    }
493  }
494
495  DiscoverOnAddresses(ip_addresses);
496}
497
498void DialServiceImpl::DiscoverOnAddresses(
499    const std::vector<IPAddressNumber>& ip_addresses) {
500  if (ip_addresses.empty()) {
501    VLOG(1) << "Could not find a valid interface to bind. Finishing discovery";
502    FinishDiscovery();
503    return;
504  }
505
506  // Schedule a timer to finish the discovery process (and close the sockets).
507  if (finish_delay_ > TimeDelta::FromSeconds(0)) {
508    VLOG(2) << "Starting timer to finish discovery.";
509    finish_timer_.Start(FROM_HERE,
510                        finish_delay_,
511                        this,
512                        &DialServiceImpl::FinishDiscovery);
513  }
514
515  for (std::vector<IPAddressNumber>::const_iterator iter = ip_addresses.begin();
516       iter != ip_addresses.end();
517       ++iter)
518    BindAndAddSocket(*iter);
519
520  SendOneRequest();
521}
522
523void DialServiceImpl::BindAndAddSocket(const IPAddressNumber& bind_ip_address) {
524  scoped_ptr<DialServiceImpl::DialSocket> dial_socket(CreateDialSocket());
525  if (dial_socket->CreateAndBindSocket(bind_ip_address, net_log_,
526                                       net_log_source_))
527    dial_sockets_.push_back(dial_socket.release());
528}
529
530scoped_ptr<DialServiceImpl::DialSocket> DialServiceImpl::CreateDialSocket() {
531  scoped_ptr<DialServiceImpl::DialSocket> dial_socket(
532      new DialServiceImpl::DialSocket(
533          base::Bind(&DialServiceImpl::NotifyOnDiscoveryRequest, AsWeakPtr()),
534          base::Bind(&DialServiceImpl::NotifyOnDeviceDiscovered, AsWeakPtr()),
535          base::Bind(&DialServiceImpl::NotifyOnError, AsWeakPtr())));
536  return dial_socket.Pass();
537}
538
539void DialServiceImpl::SendOneRequest() {
540  DCHECK(thread_checker_.CalledOnValidThread());
541  if (num_requests_sent_ == max_requests_) {
542    VLOG(2) << "Reached max requests; stopping request timer.";
543    request_timer_.Stop();
544    return;
545  }
546  num_requests_sent_++;
547  VLOG(2) << "Sending request " << num_requests_sent_ << "/"
548          << max_requests_;
549  for (ScopedVector<DialServiceImpl::DialSocket>::iterator iter =
550           dial_sockets_.begin();
551       iter != dial_sockets_.end();
552       ++iter) {
553    if (!((*iter)->IsClosed()))
554      (*iter)->SendOneRequest(send_address_, send_buffer_);
555  }
556}
557
558void DialServiceImpl::NotifyOnDiscoveryRequest() {
559  DCHECK(thread_checker_.CalledOnValidThread());
560  // If discovery is inactive, no reason to notify observers.
561  if (!discovery_active_) {
562    VLOG(2) << "Request sent after discovery finished.  Ignoring.";
563    return;
564  }
565
566  VLOG(2) << "Notifying observers of discovery request";
567  FOR_EACH_OBSERVER(Observer, observer_list_, OnDiscoveryRequest(this));
568  // If we need to send additional requests, schedule a timer to do so.
569  if (num_requests_sent_ < max_requests_ && num_requests_sent_ == 1) {
570    VLOG(2) << "Scheduling timer to send additional requests";
571    // TODO(imcheng): Move this to SendOneRequest() once the implications are
572    // understood.
573    request_timer_.Start(FROM_HERE,
574                         request_interval_,
575                         this,
576                         &DialServiceImpl::SendOneRequest);
577  }
578}
579
580void DialServiceImpl::NotifyOnDeviceDiscovered(
581    const DialDeviceData& device_data) {
582  DCHECK(thread_checker_.CalledOnValidThread());
583  if (!discovery_active_) {
584    VLOG(2) << "Got response after discovery finished.  Ignoring.";
585    return;
586  }
587  FOR_EACH_OBSERVER(Observer, observer_list_,
588                    OnDeviceDiscovered(this, device_data));
589}
590
// Called by a DialSocket after one of its operations failed. CheckResult()
// closes the failed socket before invoking this. Observers receive
// DIAL_SERVICE_SOCKET_ERROR while some other socket is still open, and
// DIAL_SERVICE_NO_INTERFACES when none are.
// NOTE(review): HasOpenSockets() is written inside the FOR_EACH_OBSERVER
// argument. Depending on the macro's expansion it is presumably re-evaluated
// per observer, so keep it inline rather than hoisting it.
void DialServiceImpl::NotifyOnError() {
  DCHECK(thread_checker_.CalledOnValidThread());
  // TODO(imcheng): Modify upstream so that the device list is not cleared
  // when it could still potentially discover devices on other sockets.
  FOR_EACH_OBSERVER(Observer, observer_list_,
                    OnError(this,
                            HasOpenSockets() ? DIAL_SERVICE_SOCKET_ERROR
                                             : DIAL_SERVICE_NO_INTERFACES));
}
600
601void DialServiceImpl::FinishDiscovery() {
602  DCHECK(thread_checker_.CalledOnValidThread());
603  DCHECK(discovery_active_);
604  VLOG(2) << "Discovery finished.";
605  // Close all open sockets.
606  dial_sockets_.clear();
607  finish_timer_.Stop();
608  request_timer_.Stop();
609  discovery_active_ = false;
610  num_requests_sent_ = 0;
611  FOR_EACH_OBSERVER(Observer, observer_list_, OnDiscoveryFinished(this));
612}
613
614bool DialServiceImpl::HasOpenSockets() {
615  for (ScopedVector<DialSocket>::const_iterator iter = dial_sockets_.begin();
616       iter != dial_sockets_.end();
617       ++iter) {
618    if (!((*iter)->IsClosed()))
619      return true;
620  }
621  return false;
622}
623
624}  // namespace extensions
625