1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/websockets/websocket_handshake_handler.h" 6 7#include <limits> 8 9#include "base/base64.h" 10#include "base/md5.h" 11#include "base/sha1.h" 12#include "base/strings/string_number_conversions.h" 13#include "base/strings/string_piece.h" 14#include "base/strings/string_tokenizer.h" 15#include "base/strings/string_util.h" 16#include "base/strings/stringprintf.h" 17#include "net/http/http_response_headers.h" 18#include "net/http/http_util.h" 19#include "url/gurl.h" 20 21namespace { 22 23const size_t kRequestKey3Size = 8U; 24const size_t kResponseKeySize = 16U; 25 26// First version that introduced new WebSocket handshake which does not 27// require sending "key3" or "response key" data after headers. 28const int kMinVersionOfHybiNewHandshake = 4; 29 30// Used when we calculate the value of Sec-WebSocket-Accept. 31const char* const kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 32 33void ParseHandshakeHeader( 34 const char* handshake_message, int len, 35 std::string* status_line, 36 std::string* headers) { 37 size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n"); 38 if (i == base::StringPiece::npos) { 39 *status_line = std::string(handshake_message, len); 40 *headers = ""; 41 return; 42 } 43 // |status_line| includes \r\n. 44 *status_line = std::string(handshake_message, i + 2); 45 46 int header_len = len - (i + 2) - 2; 47 if (header_len > 0) { 48 // |handshake_message| includes trailing \r\n\r\n. 49 // |headers| doesn't include 2nd \r\n. 50 *headers = std::string(handshake_message + i + 2, header_len); 51 } else { 52 *headers = ""; 53 } 54} 55 56void FetchHeaders(const std::string& headers, 57 const char* const headers_to_get[], 58 size_t headers_to_get_len, 59 std::vector<std::string>* values) { 60 net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n"); 61 while (iter.GetNext()) { 62 for (size_t i = 0; i < headers_to_get_len; i++) { 63 if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), 64 headers_to_get[i])) { 65 values->push_back(iter.values()); 66 } 67 } 68 } 69} 70 71bool GetHeaderName(std::string::const_iterator line_begin, 72 std::string::const_iterator line_end, 73 std::string::const_iterator* name_begin, 74 std::string::const_iterator* name_end) { 75 std::string::const_iterator colon = std::find(line_begin, line_end, ':'); 76 if (colon == line_end) { 77 return false; 78 } 79 *name_begin = line_begin; 80 *name_end = colon; 81 if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin)) 82 return false; 83 net::HttpUtil::TrimLWS(name_begin, name_end); 84 return true; 85} 86 87// Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that 88// is, lines that are not formatted as "<name>: <value>\r\n". 89std::string FilterHeaders( 90 const std::string& headers, 91 const char* const headers_to_remove[], 92 size_t headers_to_remove_len) { 93 std::string filtered_headers; 94 95 base::StringTokenizer lines(headers.begin(), headers.end(), "\r\n"); 96 while (lines.GetNext()) { 97 std::string::const_iterator line_begin = lines.token_begin(); 98 std::string::const_iterator line_end = lines.token_end(); 99 std::string::const_iterator name_begin; 100 std::string::const_iterator name_end; 101 bool should_remove = false; 102 if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) { 103 for (size_t i = 0; i < headers_to_remove_len; ++i) { 104 if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) { 105 should_remove = true; 106 break; 107 } 108 } 109 } 110 if (!should_remove) { 111 filtered_headers.append(line_begin, line_end); 112 filtered_headers.append("\r\n"); 113 } 114 } 115 return filtered_headers; 116} 117 118int GetVersionFromRequest(const std::string& request_headers) { 119 std::vector<std::string> values; 120 const char* const headers_to_get[2] = { "sec-websocket-version", 121 "sec-websocket-draft" }; 122 FetchHeaders(request_headers, headers_to_get, 2, &values); 123 DCHECK_LE(values.size(), 1U); 124 if (values.empty()) 125 return 0; 126 int version; 127 bool conversion_success = base::StringToInt(values[0], &version); 128 DCHECK(conversion_success); 129 DCHECK_GE(version, 1); 130 return version; 131} 132 133} // namespace 134 135namespace net { 136 137namespace internal { 138 139void GetKeyNumber(const std::string& key, std::string* challenge) { 140 uint32 key_number = 0; 141 uint32 spaces = 0; 142 for (size_t i = 0; i < key.size(); ++i) { 143 if (isdigit(key[i])) { 144 // key_number should not overflow. (it comes from 145 // WebCore/websockets/WebSocketHandshake.cpp). 146 // Trust, but verify. 147 DCHECK_GE((std::numeric_limits<uint32>::max() - (key[i] - '0')) / 10, 148 key_number) << "Supplied key would overflow"; 149 key_number = key_number * 10 + key[i] - '0'; 150 } else if (key[i] == ' ') { 151 ++spaces; 152 } 153 } 154 DCHECK_NE(0u, spaces) << "Key must contain at least one space"; 155 if (spaces == 0) 156 return; 157 DCHECK_EQ(0u, key_number % spaces) << "Key number must be an integral " 158 << "multiple of the number of spaces"; 159 key_number /= spaces; 160 161 char part[4]; 162 for (int i = 0; i < 4; i++) { 163 part[3 - i] = key_number & 0xFF; 164 key_number >>= 8; 165 } 166 challenge->append(part, 4); 167} 168 169} // namespace internal 170 171WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler() 172 : original_length_(0), 173 raw_length_(0), 174 protocol_version_(-1) {} 175 176bool WebSocketHandshakeRequestHandler::ParseRequest( 177 const char* data, int length) { 178 DCHECK_GT(length, 0); 179 std::string input(data, length); 180 int input_header_length = 181 HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0); 182 if (input_header_length <= 0) 183 return false; 184 185 ParseHandshakeHeader(input.data(), 186 input_header_length, 187 &status_line_, 188 &headers_); 189 190 // WebSocket protocol drafts hixie-76 (hybi-00), hybi-01, 02 and 03 require 191 // the clients to send key3 after the handshake request header fields. 192 // Hybi-04 and later drafts, on the other hand, no longer have key3 193 // in the handshake format. 194 protocol_version_ = GetVersionFromRequest(headers_); 195 DCHECK_GE(protocol_version_, 0); 196 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) { 197 key3_ = ""; 198 original_length_ = input_header_length; 199 return true; 200 } 201 202 if (input_header_length + kRequestKey3Size > input.size()) 203 return false; 204 205 // Assumes WebKit doesn't send any data after handshake request message 206 // until handshake is finished. 207 // Thus, |key3_| is part of handshake message, and not in part 208 // of WebSocket frame stream. 209 DCHECK_EQ(kRequestKey3Size, input.size() - input_header_length); 210 key3_ = std::string(input.data() + input_header_length, 211 input.size() - input_header_length); 212 original_length_ = input.size(); 213 return true; 214} 215 216size_t WebSocketHandshakeRequestHandler::original_length() const { 217 return original_length_; 218} 219 220void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing( 221 const std::string& name, const std::string& value) { 222 DCHECK(!headers_.empty()); 223 HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_); 224} 225 226void WebSocketHandshakeRequestHandler::RemoveHeaders( 227 const char* const headers_to_remove[], 228 size_t headers_to_remove_len) { 229 DCHECK(!headers_.empty()); 230 headers_ = FilterHeaders( 231 headers_, headers_to_remove, headers_to_remove_len); 232} 233 234HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo( 235 const GURL& url, std::string* challenge) { 236 HttpRequestInfo request_info; 237 request_info.url = url; 238 size_t method_end = base::StringPiece(status_line_).find_first_of(" "); 239 if (method_end != base::StringPiece::npos) 240 request_info.method = std::string(status_line_.data(), method_end); 241 242 request_info.extra_headers.Clear(); 243 request_info.extra_headers.AddHeadersFromString(headers_); 244 245 request_info.extra_headers.RemoveHeader("Upgrade"); 246 request_info.extra_headers.RemoveHeader("Connection"); 247 248 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) { 249 std::string key; 250 bool header_present = 251 request_info.extra_headers.GetHeader("Sec-WebSocket-Key", &key); 252 DCHECK(header_present); 253 request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key"); 254 *challenge = key; 255 } else { 256 challenge->clear(); 257 std::string key; 258 bool header_present = 259 request_info.extra_headers.GetHeader("Sec-WebSocket-Key1", &key); 260 DCHECK(header_present); 261 request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key1"); 262 internal::GetKeyNumber(key, challenge); 263 264 header_present = 265 request_info.extra_headers.GetHeader("Sec-WebSocket-Key2", &key); 266 DCHECK(header_present); 267 request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key2"); 268 internal::GetKeyNumber(key, challenge); 269 270 challenge->append(key3_); 271 } 272 273 return request_info; 274} 275 276bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock( 277 const GURL& url, 278 SpdyHeaderBlock* headers, 279 std::string* challenge, 280 int spdy_protocol_version) { 281 // Construct opening handshake request headers as a SPDY header block. 282 // For details, see WebSocket Layering over SPDY/3 Draft 8. 283 if (spdy_protocol_version <= 2) { 284 (*headers)["path"] = url.path(); 285 (*headers)["version"] = 286 base::StringPrintf("%s%d", "WebSocket/", protocol_version_); 287 (*headers)["scheme"] = url.scheme(); 288 } else { 289 (*headers)[":path"] = url.path(); 290 (*headers)[":version"] = 291 base::StringPrintf("%s%d", "WebSocket/", protocol_version_); 292 (*headers)[":scheme"] = url.scheme(); 293 } 294 295 HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n"); 296 while (iter.GetNext()) { 297 if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), "upgrade") || 298 LowerCaseEqualsASCII(iter.name_begin(), 299 iter.name_end(), 300 "connection") || 301 LowerCaseEqualsASCII(iter.name_begin(), 302 iter.name_end(), 303 "sec-websocket-version")) { 304 // These headers must be ignored. 305 continue; 306 } else if (LowerCaseEqualsASCII(iter.name_begin(), 307 iter.name_end(), 308 "sec-websocket-key")) { 309 *challenge = iter.values(); 310 // Sec-WebSocket-Key is not sent to the server. 311 continue; 312 } else if (LowerCaseEqualsASCII(iter.name_begin(), 313 iter.name_end(), 314 "host") || 315 LowerCaseEqualsASCII(iter.name_begin(), 316 iter.name_end(), 317 "origin") || 318 LowerCaseEqualsASCII(iter.name_begin(), 319 iter.name_end(), 320 "sec-websocket-protocol") || 321 LowerCaseEqualsASCII(iter.name_begin(), 322 iter.name_end(), 323 "sec-websocket-extensions")) { 324 // TODO(toyoshim): Some WebSocket extensions may not be compatible with 325 // SPDY. We should omit them from a Sec-WebSocket-Extension header. 326 std::string name; 327 if (spdy_protocol_version <= 2) 328 name = StringToLowerASCII(iter.name()); 329 else 330 name = ":" + StringToLowerASCII(iter.name()); 331 (*headers)[name] = iter.values(); 332 continue; 333 } 334 // Others should be sent out to |headers|. 335 std::string name = StringToLowerASCII(iter.name()); 336 SpdyHeaderBlock::iterator found = headers->find(name); 337 if (found == headers->end()) { 338 (*headers)[name] = iter.values(); 339 } else { 340 // For now, websocket doesn't use multiple headers, but follows to http. 341 found->second.append(1, '\0'); // +=() doesn't append 0's 342 found->second.append(iter.values()); 343 } 344 } 345 346 return true; 347} 348 349std::string WebSocketHandshakeRequestHandler::GetRawRequest() { 350 DCHECK(!status_line_.empty()); 351 DCHECK(!headers_.empty()); 352 // The following works on both hybi-04 and older handshake, 353 // because |key3_| is guaranteed to be empty if the handshake was hybi-04's. 354 std::string raw_request = status_line_ + headers_ + "\r\n" + key3_; 355 raw_length_ = raw_request.size(); 356 return raw_request; 357} 358 359size_t WebSocketHandshakeRequestHandler::raw_length() const { 360 DCHECK_GT(raw_length_, 0); 361 return raw_length_; 362} 363 364int WebSocketHandshakeRequestHandler::protocol_version() const { 365 DCHECK_GE(protocol_version_, 0); 366 return protocol_version_; 367} 368 369WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler() 370 : original_header_length_(0), 371 protocol_version_(0) {} 372 373WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {} 374 375int WebSocketHandshakeResponseHandler::protocol_version() const { 376 DCHECK_GE(protocol_version_, 0); 377 return protocol_version_; 378} 379 380void WebSocketHandshakeResponseHandler::set_protocol_version( 381 int protocol_version) { 382 DCHECK_GE(protocol_version, 0); 383 protocol_version_ = protocol_version; 384} 385 386size_t WebSocketHandshakeResponseHandler::ParseRawResponse( 387 const char* data, int length) { 388 DCHECK_GT(length, 0); 389 if (HasResponse()) { 390 DCHECK(!status_line_.empty()); 391 // headers_ might be empty for wrong response from server. 392 return 0; 393 } 394 395 size_t old_original_length = original_.size(); 396 397 original_.append(data, length); 398 // TODO(ukai): fail fast when response gives wrong status code. 399 original_header_length_ = HttpUtil::LocateEndOfHeaders( 400 original_.data(), original_.size(), 0); 401 if (!HasResponse()) 402 return length; 403 404 ParseHandshakeHeader(original_.data(), 405 original_header_length_, 406 &status_line_, 407 &headers_); 408 int header_size = status_line_.size() + headers_.size(); 409 DCHECK_GE(original_header_length_, header_size); 410 header_separator_ = std::string(original_.data() + header_size, 411 original_header_length_ - header_size); 412 key_ = std::string(original_.data() + original_header_length_, 413 GetResponseKeySize()); 414 return original_header_length_ + GetResponseKeySize() - old_original_length; 415} 416 417bool WebSocketHandshakeResponseHandler::HasResponse() const { 418 return original_header_length_ > 0 && 419 original_header_length_ + GetResponseKeySize() <= original_.size(); 420} 421 422bool WebSocketHandshakeResponseHandler::ParseResponseInfo( 423 const HttpResponseInfo& response_info, 424 const std::string& challenge) { 425 if (!response_info.headers.get()) 426 return false; 427 428 std::string response_message; 429 response_message = response_info.headers->GetStatusLine(); 430 response_message += "\r\n"; 431 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) 432 response_message += "Upgrade: websocket\r\n"; 433 else 434 response_message += "Upgrade: WebSocket\r\n"; 435 response_message += "Connection: Upgrade\r\n"; 436 437 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) { 438 std::string hash = base::SHA1HashString(challenge + kWebSocketGuid); 439 std::string websocket_accept; 440 bool encode_success = base::Base64Encode(hash, &websocket_accept); 441 DCHECK(encode_success); 442 response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n"; 443 } 444 445 void* iter = NULL; 446 std::string name; 447 std::string value; 448 while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) { 449 response_message += name + ": " + value + "\r\n"; 450 } 451 response_message += "\r\n"; 452 453 if (protocol_version_ < kMinVersionOfHybiNewHandshake) { 454 base::MD5Digest digest; 455 base::MD5Sum(challenge.data(), challenge.size(), &digest); 456 457 const char* digest_data = reinterpret_cast<char*>(digest.a); 458 response_message.append(digest_data, sizeof(digest.a)); 459 } 460 461 return ParseRawResponse(response_message.data(), 462 response_message.size()) == response_message.size(); 463} 464 465bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock( 466 const SpdyHeaderBlock& headers, 467 const std::string& challenge, 468 int spdy_protocol_version) { 469 SpdyHeaderBlock::const_iterator status; 470 if (spdy_protocol_version <= 2) 471 status = headers.find("status"); 472 else 473 status = headers.find(":status"); 474 if (status == headers.end()) 475 return false; 476 std::string response_message; 477 response_message = 478 base::StringPrintf("%s%s\r\n", "HTTP/1.1 ", status->second.c_str()); 479 response_message += "Upgrade: websocket\r\n"; 480 response_message += "Connection: Upgrade\r\n"; 481 482 std::string hash = base::SHA1HashString(challenge + kWebSocketGuid); 483 std::string websocket_accept; 484 bool encode_success = base::Base64Encode(hash, &websocket_accept); 485 DCHECK(encode_success); 486 response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n"; 487 488 for (SpdyHeaderBlock::const_iterator iter = headers.begin(); 489 iter != headers.end(); 490 ++iter) { 491 // For each value, if the server sends a NUL-separated list of values, 492 // we separate that back out into individual headers for each value 493 // in the list. 494 if ((spdy_protocol_version <= 2 && 495 LowerCaseEqualsASCII(iter->first, "status")) || 496 (spdy_protocol_version >= 3 && 497 LowerCaseEqualsASCII(iter->first, ":status"))) { 498 // The status value is already handled as the first line of 499 // |response_message|. Just skip here. 500 continue; 501 } 502 const std::string& value = iter->second; 503 size_t start = 0; 504 size_t end = 0; 505 do { 506 end = value.find('\0', start); 507 std::string tval; 508 if (end != std::string::npos) 509 tval = value.substr(start, (end - start)); 510 else 511 tval = value.substr(start); 512 if (spdy_protocol_version >= 3 && 513 (LowerCaseEqualsASCII(iter->first, ":sec-websocket-protocol") || 514 LowerCaseEqualsASCII(iter->first, ":sec-websocket-extensions"))) 515 response_message += iter->first.substr(1) + ": " + tval + "\r\n"; 516 else 517 response_message += iter->first + ": " + tval + "\r\n"; 518 start = end + 1; 519 } while (end != std::string::npos); 520 } 521 response_message += "\r\n"; 522 523 return ParseRawResponse(response_message.data(), 524 response_message.size()) == response_message.size(); 525} 526 527void WebSocketHandshakeResponseHandler::GetHeaders( 528 const char* const headers_to_get[], 529 size_t headers_to_get_len, 530 std::vector<std::string>* values) { 531 DCHECK(HasResponse()); 532 DCHECK(!status_line_.empty()); 533 // headers_ might be empty for wrong response from server. 534 if (headers_.empty()) 535 return; 536 537 FetchHeaders(headers_, headers_to_get, headers_to_get_len, values); 538} 539 540void WebSocketHandshakeResponseHandler::RemoveHeaders( 541 const char* const headers_to_remove[], 542 size_t headers_to_remove_len) { 543 DCHECK(HasResponse()); 544 DCHECK(!status_line_.empty()); 545 // headers_ might be empty for wrong response from server. 546 if (headers_.empty()) 547 return; 548 549 headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len); 550} 551 552std::string WebSocketHandshakeResponseHandler::GetRawResponse() const { 553 DCHECK(HasResponse()); 554 return std::string(original_.data(), 555 original_header_length_ + GetResponseKeySize()); 556} 557 558std::string WebSocketHandshakeResponseHandler::GetResponse() { 559 DCHECK(HasResponse()); 560 DCHECK(!status_line_.empty()); 561 // headers_ might be empty for wrong response from server. 562 DCHECK_EQ(GetResponseKeySize(), key_.size()); 563 564 return status_line_ + headers_ + header_separator_ + key_; 565} 566 567size_t WebSocketHandshakeResponseHandler::GetResponseKeySize() const { 568 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) 569 return 0; 570 return kResponseKeySize; 571} 572 573} // namespace net 574