1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_basic_handshake_stream.h"
6
7#include <algorithm>
8#include <iterator>
9#include <set>
10#include <string>
11#include <vector>
12
13#include "base/base64.h"
14#include "base/basictypes.h"
15#include "base/bind.h"
16#include "base/containers/hash_tables.h"
17#include "base/logging.h"
18#include "base/metrics/histogram.h"
19#include "base/metrics/sparse_histogram.h"
20#include "base/stl_util.h"
21#include "base/strings/string_number_conversions.h"
22#include "base/strings/string_piece.h"
23#include "base/strings/string_util.h"
24#include "base/strings/stringprintf.h"
25#include "base/time/time.h"
26#include "crypto/random.h"
27#include "net/http/http_request_headers.h"
28#include "net/http/http_request_info.h"
29#include "net/http/http_response_body_drainer.h"
30#include "net/http/http_response_headers.h"
31#include "net/http/http_status_code.h"
32#include "net/http/http_stream_parser.h"
33#include "net/socket/client_socket_handle.h"
34#include "net/socket/websocket_transport_client_socket_pool.h"
35#include "net/websockets/websocket_basic_stream.h"
36#include "net/websockets/websocket_deflate_predictor.h"
37#include "net/websockets/websocket_deflate_predictor_impl.h"
38#include "net/websockets/websocket_deflate_stream.h"
39#include "net/websockets/websocket_deflater.h"
40#include "net/websockets/websocket_extension_parser.h"
41#include "net/websockets/websocket_handshake_constants.h"
42#include "net/websockets/websocket_handshake_handler.h"
43#include "net/websockets/websocket_handshake_request_info.h"
44#include "net/websockets/websocket_handshake_response_info.h"
45#include "net/websockets/websocket_stream.h"
46
47namespace net {
48
49// TODO(ricea): If more extensions are added, replace this with a more general
50// mechanism.
51struct WebSocketExtensionParams {
52  WebSocketExtensionParams()
53      : deflate_enabled(false),
54        client_window_bits(15),
55        deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
56
57  bool deflate_enabled;
58  int client_window_bits;
59  WebSocketDeflater::ContextTakeOverMode deflate_mode;
60};
61
62namespace {
63
// Outcome of looking up a header that should appear exactly once:
//   GET_HEADER_OK       - exactly one value was present.
//   GET_HEADER_MISSING  - the header was absent.
//   GET_HEADER_MULTIPLE - more than one value was present.
enum GetHeaderResult { GET_HEADER_OK, GET_HEADER_MISSING, GET_HEADER_MULTIPLE };
69
// Returns the failure text for a required response header that is absent,
// e.g. "'Upgrade' header is missing".
std::string MissingHeaderMessage(const std::string& header_name) {
  std::string message = "'";
  message += header_name;
  message += "' header is missing";
  return message;
}
73
// Returns the failure text for a response header that must occur at most once
// but was repeated.
std::string MultipleHeaderValuesMessage(const std::string& header_name) {
  std::string message("'");
  message.append(header_name);
  message.append("' header must not appear more than once in a response");
  return message;
}
80
81std::string GenerateHandshakeChallenge() {
82  std::string raw_challenge(websockets::kRawChallengeLength, '\0');
83  crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
84  std::string encoded_challenge;
85  base::Base64Encode(raw_challenge, &encoded_challenge);
86  return encoded_challenge;
87}
88
89void AddVectorHeaderIfNonEmpty(const char* name,
90                               const std::vector<std::string>& value,
91                               HttpRequestHeaders* headers) {
92  if (value.empty())
93    return;
94  headers->SetHeader(name, JoinString(value, ", "));
95}
96
97GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
98                                     const base::StringPiece& name,
99                                     std::string* value) {
100  void* state = NULL;
101  size_t num_values = 0;
102  std::string temp_value;
103  while (headers->EnumerateHeader(&state, name, &temp_value)) {
104    if (++num_values > 1)
105      return GET_HEADER_MULTIPLE;
106    *value = temp_value;
107  }
108  return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
109}
110
111bool ValidateHeaderHasSingleValue(GetHeaderResult result,
112                                  const std::string& header_name,
113                                  std::string* failure_message) {
114  if (result == GET_HEADER_MISSING) {
115    *failure_message = MissingHeaderMessage(header_name);
116    return false;
117  }
118  if (result == GET_HEADER_MULTIPLE) {
119    *failure_message = MultipleHeaderValuesMessage(header_name);
120    return false;
121  }
122  DCHECK_EQ(result, GET_HEADER_OK);
123  return true;
124}
125
126bool ValidateUpgrade(const HttpResponseHeaders* headers,
127                     std::string* failure_message) {
128  std::string value;
129  GetHeaderResult result =
130      GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
131  if (!ValidateHeaderHasSingleValue(result,
132                                    websockets::kUpgrade,
133                                    failure_message)) {
134    return false;
135  }
136
137  if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
138    *failure_message =
139        "'Upgrade' header value is not 'WebSocket': " + value;
140    return false;
141  }
142  return true;
143}
144
145bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
146                                const std::string& expected,
147                                std::string* failure_message) {
148  std::string actual;
149  GetHeaderResult result =
150      GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
151  if (!ValidateHeaderHasSingleValue(result,
152                                    websockets::kSecWebSocketAccept,
153                                    failure_message)) {
154    return false;
155  }
156
157  if (expected != actual) {
158    *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
159    return false;
160  }
161  return true;
162}
163
164bool ValidateConnection(const HttpResponseHeaders* headers,
165                        std::string* failure_message) {
166  // Connection header is permitted to contain other tokens.
167  if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
168    *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
169    return false;
170  }
171  if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
172                               websockets::kUpgrade)) {
173    *failure_message = "'Connection' header value must contain 'Upgrade'";
174    return false;
175  }
176  return true;
177}
178
179bool ValidateSubProtocol(
180    const HttpResponseHeaders* headers,
181    const std::vector<std::string>& requested_sub_protocols,
182    std::string* sub_protocol,
183    std::string* failure_message) {
184  void* state = NULL;
185  std::string value;
186  base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
187                                            requested_sub_protocols.end());
188  int count = 0;
189  bool has_multiple_protocols = false;
190  bool has_invalid_protocol = false;
191
192  while (!has_invalid_protocol || !has_multiple_protocols) {
193    std::string temp_value;
194    if (!headers->EnumerateHeader(
195            &state, websockets::kSecWebSocketProtocol, &temp_value))
196      break;
197    value = temp_value;
198    if (requested_set.count(value) == 0)
199      has_invalid_protocol = true;
200    if (++count > 1)
201      has_multiple_protocols = true;
202  }
203
204  if (has_multiple_protocols) {
205    *failure_message =
206        MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
207    return false;
208  } else if (count > 0 && requested_sub_protocols.size() == 0) {
209    *failure_message =
210        std::string("Response must not include 'Sec-WebSocket-Protocol' "
211                    "header if not present in request: ")
212        + value;
213    return false;
214  } else if (has_invalid_protocol) {
215    *failure_message =
216        "'Sec-WebSocket-Protocol' header value '" +
217        value +
218        "' in response does not match any of sent values";
219    return false;
220  } else if (requested_sub_protocols.size() > 0 && count == 0) {
221    *failure_message =
222        "Sent non-empty 'Sec-WebSocket-Protocol' header "
223        "but no response was received";
224    return false;
225  }
226  *sub_protocol = value;
227  return true;
228}
229
230bool DeflateError(std::string* message, const base::StringPiece& piece) {
231  *message = "Error in permessage-deflate: ";
232  piece.AppendToString(message);
233  return false;
234}
235
// Validates the parameters of a permessage-deflate extension in the server's
// response and records the negotiated settings in |params|.
//
// Every parameter name must be "client_" or "server_" followed by one of:
//   - "no_context_takeover", which must not carry a value;
//   - "max_window_bits", which must carry a decimal value in [8, 15] with no
//     leading zero and no non-digit characters.
// A name may appear at most once. Only the "client_" variants change |params|.
// The "server_" variants are validated but otherwise ignored here.
//
// On success, sets |params->deflate_enabled| and returns true. On failure,
// sets |failure_message| via DeflateError() and returns false. Note that
// |params| may already have been partly updated by earlier parameters.
bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
                                        std::string* failure_message,
                                        WebSocketExtensionParams* params) {
  static const char kClientPrefix[] = "client_";
  static const char kServerPrefix[] = "server_";
  static const char kNoContextTakeover[] = "no_context_takeover";
  static const char kMaxWindowBits[] = "max_window_bits";
  // Both prefixes have the same length, so a single split point serves both.
  const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
  COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
                 the_strings_server_and_client_must_be_the_same_length);
  typedef std::vector<WebSocketExtension::Parameter> ParameterVector;

  DCHECK_EQ("permessage-deflate", extension.name());
  const ParameterVector& parameters = extension.parameters();
  // Names seen so far, used to reject repeated parameters.
  std::set<std::string> seen_names;
  for (ParameterVector::const_iterator it = parameters.begin();
       it != parameters.end(); ++it) {
    const std::string& name = it->name();
    if (seen_names.count(name) != 0) {
      return DeflateError(
          failure_message,
          "Received duplicate permessage-deflate extension parameter " + name);
    }
    seen_names.insert(name);
    // Split |name| into its "client_"/"server_" prefix and the remainder. A
    // name shorter than the prefix yields a short |client_or_server| that
    // matches neither prefix and is rejected below.
    const std::string client_or_server(name, 0, kPrefixLen);
    const bool is_client = (client_or_server == kClientPrefix);
    if (!is_client && client_or_server != kServerPrefix) {
      return DeflateError(
          failure_message,
          "Received an unexpected permessage-deflate extension parameter");
    }
    const std::string rest(name, kPrefixLen);
    if (rest == kNoContextTakeover) {
      // This parameter is a bare flag; a value makes it invalid.
      if (it->HasValue()) {
        return DeflateError(failure_message,
                            "Received invalid " + name + " parameter");
      }
      if (is_client)
        params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
    } else if (rest == kMaxWindowBits) {
      if (!it->HasValue())
        return DeflateError(failure_message, name + " must have value");
      int bits = 0;
      // Besides the [8, 15] range check, insist on a canonical decimal
      // spelling (no leading '0', nothing but ASCII digits), independent of
      // whatever base::StringToInt() itself tolerates.
      if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
          it->value()[0] == '0' ||
          it->value().find_first_not_of("0123456789") != std::string::npos) {
        return DeflateError(failure_message,
                            "Received invalid " + name + " parameter");
      }
      if (is_client)
        params->client_window_bits = bits;
    } else {
      return DeflateError(
          failure_message,
          "Received an unexpected permessage-deflate extension parameter");
    }
  }
  params->deflate_enabled = true;
  return true;
}
296
297bool ValidateExtensions(const HttpResponseHeaders* headers,
298                        const std::vector<std::string>& requested_extensions,
299                        std::string* extensions,
300                        std::string* failure_message,
301                        WebSocketExtensionParams* params) {
302  void* state = NULL;
303  std::string value;
304  std::vector<std::string> accepted_extensions;
305  // TODO(ricea): If adding support for additional extensions, generalise this
306  // code.
307  bool seen_permessage_deflate = false;
308  while (headers->EnumerateHeader(
309             &state, websockets::kSecWebSocketExtensions, &value)) {
310    WebSocketExtensionParser parser;
311    parser.Parse(value);
312    if (parser.has_error()) {
313      // TODO(yhirano) Set appropriate failure message.
314      *failure_message =
315          "'Sec-WebSocket-Extensions' header value is "
316          "rejected by the parser: " +
317          value;
318      return false;
319    }
320    if (parser.extension().name() == "permessage-deflate") {
321      if (seen_permessage_deflate) {
322        *failure_message = "Received duplicate permessage-deflate response";
323        return false;
324      }
325      seen_permessage_deflate = true;
326      if (!ValidatePerMessageDeflateExtension(
327              parser.extension(), failure_message, params))
328        return false;
329    } else {
330      *failure_message =
331          "Found an unsupported extension '" +
332          parser.extension().name() +
333          "' in 'Sec-WebSocket-Extensions' header";
334      return false;
335    }
336    accepted_extensions.push_back(value);
337  }
338  *extensions = JoinString(accepted_extensions, ", ");
339  return true;
340}
341
342}  // namespace
343
// Takes ownership of |connection|, handing it to |state_| together with
// |using_proxy|. |connect_delegate| and |failure_message| are stored as raw
// pointers and must be non-null. NOTE(review): presumably both must outlive
// this object; confirm with callers. |requested_sub_protocols| and
// |requested_extensions| are the values sent in the request; the response is
// later validated against them. If the handshake fails, a human-readable
// reason is written to |*failure_message|.
WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
    scoped_ptr<ClientSocketHandle> connection,
    WebSocketStream::ConnectDelegate* connect_delegate,
    bool using_proxy,
    std::vector<std::string> requested_sub_protocols,
    std::vector<std::string> requested_extensions,
    std::string* failure_message)
    : state_(connection.release(), using_proxy),
      connect_delegate_(connect_delegate),
      http_response_info_(NULL),
      requested_sub_protocols_(requested_sub_protocols),
      requested_extensions_(requested_extensions),
      failure_message_(failure_message) {
  DCHECK(connect_delegate);
  DCHECK(failure_message);
}
360
// Nothing is released explicitly. |state_| (the HttpBasicState that owns the
// connection and the HttpStreamParser) and the other members clean up through
// their own destructors.
WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
362
// Remembers the request URL, which is later reported to the delegate with the
// handshake request and response, and initializes |state_| with the request
// details. |callback| is passed to |state_|; this method itself always
// completes synchronously with OK.
int WebSocketBasicHandshakeStream::InitializeStream(
    const HttpRequestInfo* request_info,
    RequestPriority priority,
    const BoundNetLog& net_log,
    const CompletionCallback& callback) {
  url_ = request_info->url;
  state_.Initialize(request_info, priority, net_log, callback);
  return OK;
}
372
373int WebSocketBasicHandshakeStream::SendRequest(
374    const HttpRequestHeaders& headers,
375    HttpResponseInfo* response,
376    const CompletionCallback& callback) {
377  DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
378  DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
379  DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
380  DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
381  DCHECK(headers.HasHeader(websockets::kUpgrade));
382  DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
383  DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
384  DCHECK(parser());
385
386  http_response_info_ = response;
387
388  // Create a copy of the headers object, so that we can add the
389  // Sec-WebSockey-Key header.
390  HttpRequestHeaders enriched_headers;
391  enriched_headers.CopyFrom(headers);
392  std::string handshake_challenge;
393  if (handshake_challenge_for_testing_) {
394    handshake_challenge = *handshake_challenge_for_testing_;
395    handshake_challenge_for_testing_.reset();
396  } else {
397    handshake_challenge = GenerateHandshakeChallenge();
398  }
399  enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
400
401  AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
402                            requested_extensions_,
403                            &enriched_headers);
404  AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
405                            requested_sub_protocols_,
406                            &enriched_headers);
407
408  ComputeSecWebSocketAccept(handshake_challenge,
409                            &handshake_challenge_response_);
410
411  DCHECK(connect_delegate_);
412  scoped_ptr<WebSocketHandshakeRequestInfo> request(
413      new WebSocketHandshakeRequestInfo(url_, base::Time::Now()));
414  request->headers.CopyFrom(enriched_headers);
415  connect_delegate_->OnStartOpeningHandshake(request.Pass());
416
417  return parser()->SendRequest(
418      state_.GenerateRequestLine(), enriched_headers, response, callback);
419}
420
421int WebSocketBasicHandshakeStream::ReadResponseHeaders(
422    const CompletionCallback& callback) {
423  // HttpStreamParser uses a weak pointer when reading from the
424  // socket, so it won't be called back after being destroyed. The
425  // HttpStreamParser is owned by HttpBasicState which is owned by this object,
426  // so this use of base::Unretained() is safe.
427  int rv = parser()->ReadResponseHeaders(
428      base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
429                 base::Unretained(this),
430                 callback));
431  if (rv == ERR_IO_PENDING)
432    return rv;
433  return ValidateResponse(rv);
434}
435
// Forwards the body read to the HttpStreamParser. Assumes the parser still
// exists, i.e. Upgrade() (which deletes it) has not been called.
int WebSocketBasicHandshakeStream::ReadResponseBody(
    IOBuffer* buf,
    int buf_len,
    const CompletionCallback& callback) {
  return parser()->ReadResponseBody(buf, buf_len, callback);
}
442
443void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
444  // This class ignores the value of |not_reusable| and never lets the socket be
445  // re-used.
446  if (parser())
447    parser()->Close(true);
448}
449
// Forwards to the HttpStreamParser. Unlike CanFindEndOfResponse(), this does
// not check for a null parser, so it must not be called after Upgrade().
bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
  return parser()->IsResponseBodyComplete();
}
453
454bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
455  return parser() && parser()->CanFindEndOfResponse();
456}
457
// Forwards to the HttpStreamParser. Assumes the parser still exists (i.e.
// Upgrade() has not been called).
bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
  return parser()->IsConnectionReused();
}
461
// Forwards to the HttpStreamParser. Assumes the parser still exists (i.e.
// Upgrade() has not been called).
void WebSocketBasicHandshakeStream::SetConnectionReused() {
  parser()->SetConnectionReused();
}
465
// Always false: a connection used for a WebSocket handshake is never offered
// for reuse.
bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
  return false;
}
469
// Always reports 0; this class does not count received bytes.
// NOTE(review): presumably byte accounting is not needed for the handshake;
// confirm.
int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
  return 0;
}
473
474bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
475    LoadTimingInfo* load_timing_info) const {
476  return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
477                                                load_timing_info);
478}
479
// Forwards to the HttpStreamParser. Assumes the parser still exists.
void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
  parser()->GetSSLInfo(ssl_info);
}
483
// Forwards to the HttpStreamParser. Assumes the parser still exists.
void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
    SSLCertRequestInfo* cert_request_info) {
  parser()->GetSSLCertRequestInfo(cert_request_info);
}
488
// This handshake stream is never a SPDY stream.
bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
490
491void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
492  HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
493  drainer->Start(session);
494  // |drainer| will delete itself.
495}
496
// Currently a no-op; |priority| is ignored.
void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
  // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
  // gone, then copy whatever has happened there over here.
}
501
502scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
503  // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
504  // sure it does not touch it again before it is destroyed.
505  state_.DeleteParser();
506  WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection());
507  scoped_ptr<WebSocketStream> basic_stream(
508      new WebSocketBasicStream(state_.ReleaseConnection(),
509                               state_.read_buf(),
510                               sub_protocol_,
511                               extensions_));
512  DCHECK(extension_params_.get());
513  if (extension_params_->deflate_enabled) {
514    UMA_HISTOGRAM_ENUMERATION(
515        "Net.WebSocket.DeflateMode",
516        extension_params_->deflate_mode,
517        WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES);
518
519    return scoped_ptr<WebSocketStream>(
520        new WebSocketDeflateStream(basic_stream.Pass(),
521                                   extension_params_->deflate_mode,
522                                   extension_params_->client_window_bits,
523                                   scoped_ptr<WebSocketDeflatePredictor>(
524                                       new WebSocketDeflatePredictorImpl)));
525  } else {
526    return basic_stream.Pass();
527  }
528}
529
// Makes the next SendRequest() use |key| as the Sec-WebSocket-Key value
// instead of a random one. SendRequest() clears the stored key, so it applies
// to a single request only.
void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
    const std::string& key) {
  handshake_challenge_for_testing_.reset(new std::string(key));
}
534
535void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
536    const CompletionCallback& callback,
537    int result) {
538  callback.Run(ValidateResponse(result));
539}
540
// Reports the received response headers and response time for |url_| to
// |connect_delegate_| via WebSocketDispatchOnFinishOpeningHandshake().
// Requires SendRequest() to have set |http_response_info_|.
void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
  DCHECK(http_response_info_);
  WebSocketDispatchOnFinishOpeningHandshake(connect_delegate_,
                                            url_,
                                            http_response_info_->headers,
                                            http_response_info_->response_time);
}
548
549int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
550  DCHECK(http_response_info_);
551  // Most net errors happen during connection, so they are not seen by this
552  // method. The histogram for error codes is created in
553  // Delegate::OnResponseStarted in websocket_stream.cc instead.
554  if (rv >= 0) {
555    const HttpResponseHeaders* headers = http_response_info_->headers.get();
556    const int response_code = headers->response_code();
557    UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ResponseCode", response_code);
558    switch (response_code) {
559      case HTTP_SWITCHING_PROTOCOLS:
560        OnFinishOpeningHandshake();
561        return ValidateUpgradeResponse(headers);
562
563      // We need to pass these through for authentication to work.
564      case HTTP_UNAUTHORIZED:
565      case HTTP_PROXY_AUTHENTICATION_REQUIRED:
566        return OK;
567
568      // Other status codes are potentially risky (see the warnings in the
569      // WHATWG WebSocket API spec) and so are dropped by default.
570      default:
571        // A WebSocket server cannot be using HTTP/0.9, so if we see version
572        // 0.9, it means the response was garbage.
573        // Reporting "Unexpected response code: 200" in this case is not
574        // helpful, so use a different error message.
575        if (headers->GetHttpVersion() == HttpVersion(0, 9)) {
576          set_failure_message(
577              "Error during WebSocket handshake: Invalid status line");
578        } else {
579          set_failure_message(base::StringPrintf(
580              "Error during WebSocket handshake: Unexpected response code: %d",
581              headers->response_code()));
582        }
583        OnFinishOpeningHandshake();
584        return ERR_INVALID_RESPONSE;
585    }
586  } else {
587    if (rv == ERR_EMPTY_RESPONSE) {
588      set_failure_message(
589          "Connection closed before receiving a handshake response");
590      return rv;
591    }
592    set_failure_message(std::string("Error during WebSocket handshake: ") +
593                        ErrorToString(rv));
594    OnFinishOpeningHandshake();
595    return rv;
596  }
597}
598
599int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
600    const HttpResponseHeaders* headers) {
601  extension_params_.reset(new WebSocketExtensionParams);
602  std::string failure_message;
603  if (ValidateUpgrade(headers, &failure_message) &&
604      ValidateSecWebSocketAccept(
605          headers, handshake_challenge_response_, &failure_message) &&
606      ValidateConnection(headers, &failure_message) &&
607      ValidateSubProtocol(headers,
608                          requested_sub_protocols_,
609                          &sub_protocol_,
610                          &failure_message) &&
611      ValidateExtensions(headers,
612                         requested_extensions_,
613                         &extensions_,
614                         &failure_message,
615                         extension_params_.get())) {
616    return OK;
617  }
618  set_failure_message("Error during WebSocket handshake: " + failure_message);
619  return ERR_INVALID_RESPONSE;
620}
621
// Writes |failure_message| into the caller-owned string whose address was
// passed to the constructor.
void WebSocketBasicHandshakeStream::set_failure_message(
    const std::string& failure_message) {
  *failure_message_ = failure_message;
}
626
627}  // namespace net
628