1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_handshake.h"
6
7#include <algorithm>
8#include <vector>
9
10#include "base/logging.h"
11#include "base/md5.h"
12#include "base/memory/ref_counted.h"
13#include "base/rand_util.h"
14#include "base/string_number_conversions.h"
15#include "base/string_util.h"
16#include "base/stringprintf.h"
17#include "net/http/http_response_headers.h"
18#include "net/http/http_util.h"
19
20namespace net {
21
// Port numbers that GetHostFieldValue() compares against when deciding
// whether the port has to be spelled out in the Host field: the first one is
// used for non-secure URLs, the second one when is_secure() is true.
const int WebSocketHandshake::kWebSocketPort = 80;
const int WebSocketHandshake::kSecureWebSocketPort = 443;
24
25WebSocketHandshake::WebSocketHandshake(
26    const GURL& url,
27    const std::string& origin,
28    const std::string& location,
29    const std::string& protocol)
30    : url_(url),
31      origin_(origin),
32      location_(location),
33      protocol_(protocol),
34      mode_(MODE_INCOMPLETE) {
35}
36
37WebSocketHandshake::~WebSocketHandshake() {
38}
39
40bool WebSocketHandshake::is_secure() const {
41  return url_.SchemeIs("wss");
42}
43
44std::string WebSocketHandshake::CreateClientHandshakeMessage() {
45  if (!parameter_.get()) {
46    parameter_.reset(new Parameter);
47    parameter_->GenerateKeys();
48  }
49  std::string msg;
50
51  // WebSocket protocol 4.1 Opening handshake.
52
53  msg = "GET ";
54  msg += GetResourceName();
55  msg += " HTTP/1.1\r\n";
56
57  std::vector<std::string> fields;
58
59  fields.push_back("Upgrade: WebSocket");
60  fields.push_back("Connection: Upgrade");
61
62  fields.push_back("Host: " + GetHostFieldValue());
63
64  fields.push_back("Origin: " + GetOriginFieldValue());
65
66  if (!protocol_.empty())
67    fields.push_back("Sec-WebSocket-Protocol: " + protocol_);
68
69  // TODO(ukai): Add cookie if necessary.
70
71  fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1());
72  fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2());
73
74  std::random_shuffle(fields.begin(), fields.end(), base::RandGenerator);
75
76  for (size_t i = 0; i < fields.size(); i++) {
77    msg += fields[i] + "\r\n";
78  }
79  msg += "\r\n";
80
81  msg.append(parameter_->GetKey3());
82  return msg;
83}
84
85int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) {
86  mode_ = MODE_INCOMPLETE;
87  int eoh = HttpUtil::LocateEndOfHeaders(data, len);
88  if (eoh < 0)
89    return -1;
90
91  scoped_refptr<HttpResponseHeaders> headers(
92      new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
93
94  if (headers->response_code() != 101) {
95    mode_ = MODE_FAILED;
96    DVLOG(1) << "Bad response code: " << headers->response_code();
97    return eoh;
98  }
99  mode_ = MODE_NORMAL;
100  if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) {
101    DVLOG(1) << "Process Headers failed: " << std::string(data, eoh);
102    mode_ = MODE_FAILED;
103    return eoh;
104  }
105  if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) {
106    mode_ = MODE_INCOMPLETE;
107    return -1;
108  }
109  uint8 expected[Parameter::kExpectedResponseSize];
110  parameter_->GetExpectedResponse(expected);
111  if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) {
112    mode_ = MODE_FAILED;
113    return eoh + Parameter::kExpectedResponseSize;
114  }
115  mode_ = MODE_CONNECTED;
116  return eoh + Parameter::kExpectedResponseSize;
117}
118
119std::string WebSocketHandshake::GetResourceName() const {
120  std::string resource_name = url_.path();
121  if (url_.has_query()) {
122    resource_name += "?";
123    resource_name += url_.query();
124  }
125  return resource_name;
126}
127
128std::string WebSocketHandshake::GetHostFieldValue() const {
129  // url_.host() is expected to be encoded in punnycode here.
130  std::string host = StringToLowerASCII(url_.host());
131  if (url_.has_port()) {
132    bool secure = is_secure();
133    int port = url_.EffectiveIntPort();
134    if ((!secure &&
135         port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
136        (secure &&
137         port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
138      host += ":";
139      host += base::IntToString(port);
140    }
141  }
142  return host;
143}
144
145std::string WebSocketHandshake::GetOriginFieldValue() const {
146  // It's OK to lowercase the origin as the Origin header does not contain
147  // the path or query portions, as per
148  // http://tools.ietf.org/html/draft-abarth-origin-00.
149  //
150  // TODO(satorux): Should we trim the port portion here if it's 80 for
151  // http:// or 443 for https:// ? Or can we assume it's done by the
152  // client of the library?
153  return StringToLowerASCII(origin_);
154}
155
156/* static */
157bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers,
158                                         const std::string& name,
159                                         std::string* value) {
160  std::string first_value;
161  void* iter = NULL;
162  if (!headers.EnumerateHeader(&iter, name, &first_value))
163    return false;
164
165  // Checks no more |name| found in |headers|.
166  // Second call of EnumerateHeader() must return false.
167  std::string second_value;
168  if (headers.EnumerateHeader(&iter, name, &second_value))
169    return false;
170  *value = first_value;
171  return true;
172}
173
174bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) {
175  std::string value;
176  if (!GetSingleHeader(headers, "upgrade", &value) ||
177      value != "WebSocket")
178    return false;
179
180  if (!GetSingleHeader(headers, "connection", &value) ||
181      !LowerCaseEqualsASCII(value, "upgrade"))
182    return false;
183
184  if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_))
185    return false;
186
187  if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_))
188    return false;
189
190  // If |protocol_| is not specified by client, we don't care if there's
191  // protocol field or not as specified in the spec.
192  if (!protocol_.empty()
193      && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_))
194    return false;
195  return true;
196}
197
198bool WebSocketHandshake::CheckResponseHeaders() const {
199  DCHECK(mode_ == MODE_NORMAL);
200  if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str()))
201    return false;
202  if (location_ != ws_location_)
203    return false;
204  if (!protocol_.empty() && protocol_ != ws_protocol_)
205    return false;
206  return true;
207}
208
209namespace {
210
211// unsigned int version of base::RandInt().
212// we can't use base::RandInt(), because max would be negative if it is
213// represented as int, so DCHECK(min <= max) fails.
214uint32 RandUint32(uint32 min, uint32 max) {
215  DCHECK(min <= max);
216
217  uint64 range = static_cast<int64>(max) - min + 1;
218  uint64 number = base::RandUint64();
219  // TODO(ukai): fix to be uniform.
220  // the distribution of the result of modulo will be biased.
221  uint32 result = min + static_cast<uint32>(number % range);
222  DCHECK(result >= min && result <= max);
223  return result;
224}
225
226}
227
228uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) =
229    RandUint32;
230uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39];
231
232WebSocketHandshake::Parameter::Parameter()
233    : number_1_(0), number_2_(0) {
234  if (randomCharacterInSecWebSocketKey[0] == '\0') {
235    int i = 0;
236    for (int ch = 0x21; ch <= 0x2F; ch++, i++)
237      randomCharacterInSecWebSocketKey[i] = ch;
238    for (int ch = 0x3A; ch <= 0x7E; ch++, i++)
239      randomCharacterInSecWebSocketKey[i] = ch;
240  }
241}
242
243WebSocketHandshake::Parameter::~Parameter() {}
244
// Generates all three keys used by the handshake: |number_1_| with |key_1_|,
// |number_2_| with |key_2_|, and finally |key_3_|.  Each step draws from
// |rand_|, so the order of the calls determines which random values end up
// in which key.
void WebSocketHandshake::Parameter::GenerateKeys() {
  GenerateSecWebSocketKey(&number_1_, &key_1_);
  GenerateSecWebSocketKey(&number_2_, &key_2_);
  GenerateKey3();
}
250
// Stores |number| into buf[0..3] in big-endian order, i.e. the most
// significant byte goes to buf[0] and the least significant one to buf[3].
static void SetChallengeNumber(uint8* buf, uint32 number) {
  int shift = 24;
  for (int index = 0; index < 4; ++index) {
    buf[index] = static_cast<uint8>((number >> shift) & 0xFF);
    shift -= 8;
  }
}
259
260void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const {
261  uint8 challenge[kExpectedResponseSize];
262  SetChallengeNumber(&challenge[0], number_1_);
263  SetChallengeNumber(&challenge[4], number_2_);
264  memcpy(&challenge[8], key_3_.data(), kKey3Size);
265  MD5Digest digest;
266  MD5Sum(challenge, kExpectedResponseSize, &digest);
267  memcpy(expected, digest.a, kExpectedResponseSize);
268}
269
/* static */
// Replaces the function stored in |rand_|, which the key generation code
// calls with a (min, max) pair whenever it needs a random number.  |rand_|
// is static, so the replacement applies to every Parameter.
void WebSocketHandshake::Parameter::SetRandomNumberGenerator(
    uint32 (*rand)(uint32 min, uint32 max)) {
  rand_ = rand;
}
275
276void WebSocketHandshake::Parameter::GenerateSecWebSocketKey(
277    uint32* number, std::string* key) {
278  uint32 space = rand_(1, 12);
279  uint32 max = 4294967295U / space;
280  *number = rand_(0, max);
281  uint32 product = *number * space;
282
283  std::string s = base::StringPrintf("%u", product);
284  int n = rand_(1, 12);
285  for (int i = 0; i < n; i++) {
286    int pos = rand_(0, s.length());
287    int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1);
288    s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) +
289        s.substr(pos);
290  }
291  for (uint32 i = 0; i < space; i++) {
292    int pos = rand_(1, s.length() - 1);
293    s = s.substr(0, pos) + " " + s.substr(pos);
294  }
295  *key = s;
296}
297
298void WebSocketHandshake::Parameter::GenerateKey3() {
299  key_3_.clear();
300  for (int i = 0; i < 8; i++) {
301    key_3_.append(1, rand_(0, 255));
302  }
303}
304
305}  // namespace net
306