websocket_basic_handshake_stream.cc revision cedac228d2dd51db4b79ea1e72c7f249408ee061
1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/websockets/websocket_basic_handshake_stream.h" 6 7#include <algorithm> 8#include <iterator> 9#include <set> 10#include <string> 11#include <vector> 12 13#include "base/base64.h" 14#include "base/basictypes.h" 15#include "base/bind.h" 16#include "base/containers/hash_tables.h" 17#include "base/metrics/histogram.h" 18#include "base/stl_util.h" 19#include "base/strings/string_number_conversions.h" 20#include "base/strings/string_piece.h" 21#include "base/strings/string_util.h" 22#include "base/strings/stringprintf.h" 23#include "base/time/time.h" 24#include "crypto/random.h" 25#include "net/http/http_request_headers.h" 26#include "net/http/http_request_info.h" 27#include "net/http/http_response_body_drainer.h" 28#include "net/http/http_response_headers.h" 29#include "net/http/http_status_code.h" 30#include "net/http/http_stream_parser.h" 31#include "net/socket/client_socket_handle.h" 32#include "net/websockets/websocket_basic_stream.h" 33#include "net/websockets/websocket_deflate_predictor.h" 34#include "net/websockets/websocket_deflate_predictor_impl.h" 35#include "net/websockets/websocket_deflate_stream.h" 36#include "net/websockets/websocket_deflater.h" 37#include "net/websockets/websocket_extension_parser.h" 38#include "net/websockets/websocket_handshake_constants.h" 39#include "net/websockets/websocket_handshake_handler.h" 40#include "net/websockets/websocket_handshake_request_info.h" 41#include "net/websockets/websocket_handshake_response_info.h" 42#include "net/websockets/websocket_stream.h" 43 44namespace net { 45 46// TODO(ricea): If more extensions are added, replace this with a more general 47// mechanism. 48struct WebSocketExtensionParams { 49 WebSocketExtensionParams() 50 : deflate_enabled(false), 51 client_window_bits(15), 52 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} 53 54 bool deflate_enabled; 55 int client_window_bits; 56 WebSocketDeflater::ContextTakeOverMode deflate_mode; 57}; 58 59namespace { 60 61enum GetHeaderResult { 62 GET_HEADER_OK, 63 GET_HEADER_MISSING, 64 GET_HEADER_MULTIPLE, 65}; 66 67std::string MissingHeaderMessage(const std::string& header_name) { 68 return std::string("'") + header_name + "' header is missing"; 69} 70 71std::string MultipleHeaderValuesMessage(const std::string& header_name) { 72 return 73 std::string("'") + 74 header_name + 75 "' header must not appear more than once in a response"; 76} 77 78std::string GenerateHandshakeChallenge() { 79 std::string raw_challenge(websockets::kRawChallengeLength, '\0'); 80 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); 81 std::string encoded_challenge; 82 base::Base64Encode(raw_challenge, &encoded_challenge); 83 return encoded_challenge; 84} 85 86void AddVectorHeaderIfNonEmpty(const char* name, 87 const std::vector<std::string>& value, 88 HttpRequestHeaders* headers) { 89 if (value.empty()) 90 return; 91 headers->SetHeader(name, JoinString(value, ", ")); 92} 93 94GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers, 95 const base::StringPiece& name, 96 std::string* value) { 97 void* state = NULL; 98 size_t num_values = 0; 99 std::string temp_value; 100 while (headers->EnumerateHeader(&state, name, &temp_value)) { 101 if (++num_values > 1) 102 return GET_HEADER_MULTIPLE; 103 *value = temp_value; 104 } 105 return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING; 106} 107 108bool ValidateHeaderHasSingleValue(GetHeaderResult result, 109 const std::string& header_name, 110 std::string* failure_message) { 111 if (result == GET_HEADER_MISSING) { 112 *failure_message = MissingHeaderMessage(header_name); 113 return false; 114 } 115 if (result == GET_HEADER_MULTIPLE) { 116 *failure_message = MultipleHeaderValuesMessage(header_name); 117 return false; 118 } 119 DCHECK_EQ(result, GET_HEADER_OK); 120 return true; 121} 122 123bool ValidateUpgrade(const HttpResponseHeaders* headers, 124 std::string* failure_message) { 125 std::string value; 126 GetHeaderResult result = 127 GetSingleHeaderValue(headers, websockets::kUpgrade, &value); 128 if (!ValidateHeaderHasSingleValue(result, 129 websockets::kUpgrade, 130 failure_message)) { 131 return false; 132 } 133 134 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { 135 *failure_message = 136 "'Upgrade' header value is not 'WebSocket': " + value; 137 return false; 138 } 139 return true; 140} 141 142bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, 143 const std::string& expected, 144 std::string* failure_message) { 145 std::string actual; 146 GetHeaderResult result = 147 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); 148 if (!ValidateHeaderHasSingleValue(result, 149 websockets::kSecWebSocketAccept, 150 failure_message)) { 151 return false; 152 } 153 154 if (expected != actual) { 155 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; 156 return false; 157 } 158 return true; 159} 160 161bool ValidateConnection(const HttpResponseHeaders* headers, 162 std::string* failure_message) { 163 // Connection header is permitted to contain other tokens. 164 if (!headers->HasHeader(HttpRequestHeaders::kConnection)) { 165 *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection); 166 return false; 167 } 168 if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection, 169 websockets::kUpgrade)) { 170 *failure_message = "'Connection' header value must contain 'Upgrade'"; 171 return false; 172 } 173 return true; 174} 175 176bool ValidateSubProtocol( 177 const HttpResponseHeaders* headers, 178 const std::vector<std::string>& requested_sub_protocols, 179 std::string* sub_protocol, 180 std::string* failure_message) { 181 void* state = NULL; 182 std::string value; 183 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(), 184 requested_sub_protocols.end()); 185 int count = 0; 186 bool has_multiple_protocols = false; 187 bool has_invalid_protocol = false; 188 189 while (!has_invalid_protocol || !has_multiple_protocols) { 190 std::string temp_value; 191 if (!headers->EnumerateHeader( 192 &state, websockets::kSecWebSocketProtocol, &temp_value)) 193 break; 194 value = temp_value; 195 if (requested_set.count(value) == 0) 196 has_invalid_protocol = true; 197 if (++count > 1) 198 has_multiple_protocols = true; 199 } 200 201 if (has_multiple_protocols) { 202 *failure_message = 203 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); 204 return false; 205 } else if (count > 0 && requested_sub_protocols.size() == 0) { 206 *failure_message = 207 std::string("Response must not include 'Sec-WebSocket-Protocol' " 208 "header if not present in request: ") 209 + value; 210 return false; 211 } else if (has_invalid_protocol) { 212 *failure_message = 213 "'Sec-WebSocket-Protocol' header value '" + 214 value + 215 "' in response does not match any of sent values"; 216 return false; 217 } else if (requested_sub_protocols.size() > 0 && count == 0) { 218 *failure_message = 219 "Sent non-empty 'Sec-WebSocket-Protocol' header " 220 "but no response was received"; 221 return false; 222 } 223 *sub_protocol = value; 224 return true; 225} 226 227bool DeflateError(std::string* message, const base::StringPiece& piece) { 228 *message = "Error in permessage-deflate: "; 229 piece.AppendToString(message); 230 return false; 231} 232 233bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, 234 std::string* failure_message, 235 WebSocketExtensionParams* params) { 236 static const char kClientPrefix[] = "client_"; 237 static const char kServerPrefix[] = "server_"; 238 static const char kNoContextTakeover[] = "no_context_takeover"; 239 static const char kMaxWindowBits[] = "max_window_bits"; 240 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; 241 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, 242 the_strings_server_and_client_must_be_the_same_length); 243 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; 244 245 DCHECK_EQ("permessage-deflate", extension.name()); 246 const ParameterVector& parameters = extension.parameters(); 247 std::set<std::string> seen_names; 248 for (ParameterVector::const_iterator it = parameters.begin(); 249 it != parameters.end(); ++it) { 250 const std::string& name = it->name(); 251 if (seen_names.count(name) != 0) { 252 return DeflateError( 253 failure_message, 254 "Received duplicate permessage-deflate extension parameter " + name); 255 } 256 seen_names.insert(name); 257 const std::string client_or_server(name, 0, kPrefixLen); 258 const bool is_client = (client_or_server == kClientPrefix); 259 if (!is_client && client_or_server != kServerPrefix) { 260 return DeflateError( 261 failure_message, 262 "Received an unexpected permessage-deflate extension parameter"); 263 } 264 const std::string rest(name, kPrefixLen); 265 if (rest == kNoContextTakeover) { 266 if (it->HasValue()) { 267 return DeflateError(failure_message, 268 "Received invalid " + name + " parameter"); 269 } 270 if (is_client) 271 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; 272 } else if (rest == kMaxWindowBits) { 273 if (!it->HasValue()) 274 return DeflateError(failure_message, name + " must have value"); 275 int bits = 0; 276 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || 277 it->value()[0] == '0' || 278 it->value().find_first_not_of("0123456789") != std::string::npos) { 279 return DeflateError(failure_message, 280 "Received invalid " + name + " parameter"); 281 } 282 if (is_client) 283 params->client_window_bits = bits; 284 } else { 285 return DeflateError( 286 failure_message, 287 "Received an unexpected permessage-deflate extension parameter"); 288 } 289 } 290 params->deflate_enabled = true; 291 return true; 292} 293 294bool ValidateExtensions(const HttpResponseHeaders* headers, 295 const std::vector<std::string>& requested_extensions, 296 std::string* extensions, 297 std::string* failure_message, 298 WebSocketExtensionParams* params) { 299 void* state = NULL; 300 std::string value; 301 std::vector<std::string> accepted_extensions; 302 // TODO(ricea): If adding support for additional extensions, generalise this 303 // code. 304 bool seen_permessage_deflate = false; 305 while (headers->EnumerateHeader( 306 &state, websockets::kSecWebSocketExtensions, &value)) { 307 WebSocketExtensionParser parser; 308 parser.Parse(value); 309 if (parser.has_error()) { 310 // TODO(yhirano) Set appropriate failure message. 311 *failure_message = 312 "'Sec-WebSocket-Extensions' header value is " 313 "rejected by the parser: " + 314 value; 315 return false; 316 } 317 if (parser.extension().name() == "permessage-deflate") { 318 if (seen_permessage_deflate) { 319 *failure_message = "Received duplicate permessage-deflate response"; 320 return false; 321 } 322 seen_permessage_deflate = true; 323 if (!ValidatePerMessageDeflateExtension( 324 parser.extension(), failure_message, params)) 325 return false; 326 } else { 327 *failure_message = 328 "Found an unsupported extension '" + 329 parser.extension().name() + 330 "' in 'Sec-WebSocket-Extensions' header"; 331 return false; 332 } 333 accepted_extensions.push_back(value); 334 } 335 *extensions = JoinString(accepted_extensions, ", "); 336 return true; 337} 338 339} // namespace 340 341WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( 342 scoped_ptr<ClientSocketHandle> connection, 343 WebSocketStream::ConnectDelegate* connect_delegate, 344 bool using_proxy, 345 std::vector<std::string> requested_sub_protocols, 346 std::vector<std::string> requested_extensions) 347 : state_(connection.release(), using_proxy), 348 connect_delegate_(connect_delegate), 349 http_response_info_(NULL), 350 requested_sub_protocols_(requested_sub_protocols), 351 requested_extensions_(requested_extensions) {} 352 353WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {} 354 355int WebSocketBasicHandshakeStream::InitializeStream( 356 const HttpRequestInfo* request_info, 357 RequestPriority priority, 358 const BoundNetLog& net_log, 359 const CompletionCallback& callback) { 360 url_ = request_info->url; 361 state_.Initialize(request_info, priority, net_log, callback); 362 return OK; 363} 364 365int WebSocketBasicHandshakeStream::SendRequest( 366 const HttpRequestHeaders& headers, 367 HttpResponseInfo* response, 368 const CompletionCallback& callback) { 369 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey)); 370 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol)); 371 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions)); 372 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin)); 373 DCHECK(headers.HasHeader(websockets::kUpgrade)); 374 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection)); 375 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion)); 376 DCHECK(parser()); 377 378 http_response_info_ = response; 379 380 // Create a copy of the headers object, so that we can add the 381 // Sec-WebSockey-Key header. 382 HttpRequestHeaders enriched_headers; 383 enriched_headers.CopyFrom(headers); 384 std::string handshake_challenge; 385 if (handshake_challenge_for_testing_) { 386 handshake_challenge = *handshake_challenge_for_testing_; 387 handshake_challenge_for_testing_.reset(); 388 } else { 389 handshake_challenge = GenerateHandshakeChallenge(); 390 } 391 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); 392 393 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, 394 requested_extensions_, 395 &enriched_headers); 396 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, 397 requested_sub_protocols_, 398 &enriched_headers); 399 400 ComputeSecWebSocketAccept(handshake_challenge, 401 &handshake_challenge_response_); 402 403 DCHECK(connect_delegate_); 404 scoped_ptr<WebSocketHandshakeRequestInfo> request( 405 new WebSocketHandshakeRequestInfo(url_, base::Time::Now())); 406 request->headers.CopyFrom(enriched_headers); 407 connect_delegate_->OnStartOpeningHandshake(request.Pass()); 408 409 return parser()->SendRequest( 410 state_.GenerateRequestLine(), enriched_headers, response, callback); 411} 412 413int WebSocketBasicHandshakeStream::ReadResponseHeaders( 414 const CompletionCallback& callback) { 415 // HttpStreamParser uses a weak pointer when reading from the 416 // socket, so it won't be called back after being destroyed. The 417 // HttpStreamParser is owned by HttpBasicState which is owned by this object, 418 // so this use of base::Unretained() is safe. 419 int rv = parser()->ReadResponseHeaders( 420 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback, 421 base::Unretained(this), 422 callback)); 423 if (rv == ERR_IO_PENDING) 424 return rv; 425 return ValidateResponse(rv); 426} 427 428int WebSocketBasicHandshakeStream::ReadResponseBody( 429 IOBuffer* buf, 430 int buf_len, 431 const CompletionCallback& callback) { 432 return parser()->ReadResponseBody(buf, buf_len, callback); 433} 434 435void WebSocketBasicHandshakeStream::Close(bool not_reusable) { 436 // This class ignores the value of |not_reusable| and never lets the socket be 437 // re-used. 438 if (parser()) 439 parser()->Close(true); 440} 441 442bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const { 443 return parser()->IsResponseBodyComplete(); 444} 445 446bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const { 447 return parser() && parser()->CanFindEndOfResponse(); 448} 449 450bool WebSocketBasicHandshakeStream::IsConnectionReused() const { 451 return parser()->IsConnectionReused(); 452} 453 454void WebSocketBasicHandshakeStream::SetConnectionReused() { 455 parser()->SetConnectionReused(); 456} 457 458bool WebSocketBasicHandshakeStream::IsConnectionReusable() const { 459 return false; 460} 461 462int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const { 463 return 0; 464} 465 466bool WebSocketBasicHandshakeStream::GetLoadTimingInfo( 467 LoadTimingInfo* load_timing_info) const { 468 return state_.connection()->GetLoadTimingInfo(IsConnectionReused(), 469 load_timing_info); 470} 471 472void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) { 473 parser()->GetSSLInfo(ssl_info); 474} 475 476void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo( 477 SSLCertRequestInfo* cert_request_info) { 478 parser()->GetSSLCertRequestInfo(cert_request_info); 479} 480 481bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; } 482 483void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) { 484 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this); 485 drainer->Start(session); 486 // |drainer| will delete itself. 487} 488 489void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { 490 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is 491 // gone, then copy whatever has happened there over here. 492} 493 494scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { 495 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make 496 // sure it does not touch it again before it is destroyed. 497 state_.DeleteParser(); 498 scoped_ptr<WebSocketStream> basic_stream( 499 new WebSocketBasicStream(state_.ReleaseConnection(), 500 state_.read_buf(), 501 sub_protocol_, 502 extensions_)); 503 DCHECK(extension_params_.get()); 504 if (extension_params_->deflate_enabled) { 505 UMA_HISTOGRAM_ENUMERATION( 506 "Net.WebSocket.DeflateMode", 507 extension_params_->deflate_mode, 508 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES); 509 510 return scoped_ptr<WebSocketStream>( 511 new WebSocketDeflateStream(basic_stream.Pass(), 512 extension_params_->deflate_mode, 513 extension_params_->client_window_bits, 514 scoped_ptr<WebSocketDeflatePredictor>( 515 new WebSocketDeflatePredictorImpl))); 516 } else { 517 return basic_stream.Pass(); 518 } 519} 520 521void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( 522 const std::string& key) { 523 handshake_challenge_for_testing_.reset(new std::string(key)); 524} 525 526std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { 527 return failure_message_; 528} 529 530void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( 531 const CompletionCallback& callback, 532 int result) { 533 callback.Run(ValidateResponse(result)); 534} 535 536void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() { 537 DCHECK(connect_delegate_); 538 DCHECK(http_response_info_); 539 scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers; 540 // If the headers are too large, HttpStreamParser will just not parse them at 541 // all. 542 if (headers) { 543 scoped_ptr<WebSocketHandshakeResponseInfo> response( 544 new WebSocketHandshakeResponseInfo(url_, 545 headers->response_code(), 546 headers->GetStatusText(), 547 headers, 548 http_response_info_->response_time)); 549 connect_delegate_->OnFinishOpeningHandshake(response.Pass()); 550 } 551} 552 553int WebSocketBasicHandshakeStream::ValidateResponse(int rv) { 554 DCHECK(http_response_info_); 555 const HttpResponseHeaders* headers = http_response_info_->headers.get(); 556 if (rv >= 0) { 557 switch (headers->response_code()) { 558 case HTTP_SWITCHING_PROTOCOLS: 559 OnFinishOpeningHandshake(); 560 return ValidateUpgradeResponse(headers); 561 562 // We need to pass these through for authentication to work. 563 case HTTP_UNAUTHORIZED: 564 case HTTP_PROXY_AUTHENTICATION_REQUIRED: 565 return OK; 566 567 // Other status codes are potentially risky (see the warnings in the 568 // WHATWG WebSocket API spec) and so are dropped by default. 569 default: 570 // A WebSocket server cannot be using HTTP/0.9, so if we see version 571 // 0.9, it means the response was garbage. 572 // Reporting "Unexpected response code: 200" in this case is not 573 // helpful, so use a different error message. 574 if (headers->GetHttpVersion() == HttpVersion(0, 9)) { 575 failure_message_ = 576 "Error during WebSocket handshake: Invalid status line"; 577 } else { 578 failure_message_ = base::StringPrintf( 579 "Error during WebSocket handshake: Unexpected response code: %d", 580 headers->response_code()); 581 } 582 OnFinishOpeningHandshake(); 583 return ERR_INVALID_RESPONSE; 584 } 585 } else { 586 if (rv == ERR_EMPTY_RESPONSE) { 587 failure_message_ = 588 "Connection closed before receiving a handshake response"; 589 return rv; 590 } 591 failure_message_ = 592 std::string("Error during WebSocket handshake: ") + ErrorToString(rv); 593 OnFinishOpeningHandshake(); 594 return rv; 595 } 596} 597 598int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( 599 const HttpResponseHeaders* headers) { 600 extension_params_.reset(new WebSocketExtensionParams); 601 if (ValidateUpgrade(headers, &failure_message_) && 602 ValidateSecWebSocketAccept(headers, 603 handshake_challenge_response_, 604 &failure_message_) && 605 ValidateConnection(headers, &failure_message_) && 606 ValidateSubProtocol(headers, 607 requested_sub_protocols_, 608 &sub_protocol_, 609 &failure_message_) && 610 ValidateExtensions(headers, 611 requested_extensions_, 612 &extensions_, 613 &failure_message_, 614 extension_params_.get())) { 615 return OK; 616 } 617 failure_message_ = "Error during WebSocket handshake: " + failure_message_; 618 return ERR_INVALID_RESPONSE; 619} 620 621} // namespace net 622