1/*
2 * Copyright (C) 2009 Google Inc.  All rights reserved.
3 * Copyright (C) Research In Motion Limited 2011. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met:
8 *
9 *     * Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 *     * Redistributions in binary form must reproduce the above
12 * copyright notice, this list of conditions and the following disclaimer
13 * in the documentation and/or other materials provided with the
14 * distribution.
15 *     * Neither the name of Google Inc. nor the names of its
16 * contributors may be used to endorse or promote products derived from
17 * this software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32#include "config.h"
33
34#if ENABLE(WEB_SOCKETS)
35
36#include "WebSocketHandshake.h"
37
38#include "Cookie.h"
39#include "CookieJar.h"
40#include "Document.h"
41#include "HTTPHeaderMap.h"
42#include "KURL.h"
43#include "Logging.h"
44#include "ScriptCallStack.h"
45#include "ScriptExecutionContext.h"
46#include "SecurityOrigin.h"
47#include <wtf/CryptographicallyRandomNumber.h>
48#include <wtf/MD5.h>
49#include <wtf/StdLibExtras.h>
50#include <wtf/StringExtras.h>
51#include <wtf/Vector.h>
52#include <wtf/text/AtomicString.h>
53#include <wtf/text/CString.h>
54#include <wtf/text/StringBuilder.h>
55#include <wtf/text/StringConcatenate.h>
56#include <wtf/unicode/CharacterNames.h>
57
58namespace WebCore {
59
// Pool of printable ASCII characters (everything from '!' to '~' except the
// digits '0'-'9') from which generateSecWebSocketKey() draws the random
// characters it inserts into a key.
static const char randomCharacterInSecWebSocketKey[] =
    "!\"#$%&'()*+,-./"
    ":;<=>?@"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "[\\]^_`"
    "abcdefghijklmnopqrstuvwxyz"
    "{|}~";
61
62static String resourceName(const KURL& url)
63{
64    String name = url.path();
65    if (name.isEmpty())
66        name = "/";
67    if (!url.query().isNull())
68        name += "?" + url.query();
69    ASSERT(!name.isEmpty());
70    ASSERT(!name.contains(' '));
71    return name;
72}
73
74static String hostName(const KURL& url, bool secure)
75{
76    ASSERT(url.protocolIs("wss") == secure);
77    StringBuilder builder;
78    builder.append(url.host().lower());
79    if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
80        builder.append(':');
81        builder.append(String::number(url.port()));
82    }
83    return builder.toString();
84}
85
86static const size_t maxConsoleMessageSize = 128;
87static String trimConsoleMessage(const char* p, size_t len)
88{
89    String s = String(p, std::min<size_t>(len, maxConsoleMessageSize));
90    if (len > maxConsoleMessageSize)
91        s.append(horizontalEllipsis);
92    return s;
93}
94
95static uint32_t randomNumberLessThan(uint32_t n)
96{
97    if (!n)
98        return 0;
99    if (n == std::numeric_limits<uint32_t>::max())
100        return cryptographicallyRandomNumber();
101    uint32_t max = std::numeric_limits<uint32_t>::max() - (std::numeric_limits<uint32_t>::max() % n);
102    ASSERT(!(max % n));
103    uint32_t v;
104    do {
105        v = cryptographicallyRandomNumber();
106    } while (v >= max);
107    return v % n;
108}
109
110static void generateSecWebSocketKey(uint32_t& number, String& key)
111{
112    uint32_t space = randomNumberLessThan(12) + 1;
113    uint32_t max = 4294967295U / space;
114    number = randomNumberLessThan(max);
115    uint32_t product = number * space;
116
117    String s = String::number(product);
118    int n = randomNumberLessThan(12) + 1;
119    DEFINE_STATIC_LOCAL(String, randomChars, (randomCharacterInSecWebSocketKey));
120    for (int i = 0; i < n; i++) {
121        int pos = randomNumberLessThan(s.length() + 1);
122        int chpos = randomNumberLessThan(randomChars.length());
123        s.insert(randomChars.substring(chpos, 1), pos);
124    }
125    DEFINE_STATIC_LOCAL(String, spaceChar, (" "));
126    for (uint32_t i = 0; i < space; i++) {
127        int pos = randomNumberLessThan(s.length() - 1) + 1;
128        s.insert(spaceChar, pos);
129    }
130    ASSERT(s[0] != ' ');
131    ASSERT(s[s.length() - 1] != ' ');
132    key = s;
133}
134
135static void generateKey3(unsigned char key3[8])
136{
137    cryptographicallyRandomValues(key3, 8);
138}
139
// Stores |number| into buf[0..3] in big-endian order (most significant byte
// in buf[0]).
static void setChallengeNumber(unsigned char* buf, uint32_t number)
{
    for (int index = 3; index >= 0; --index) {
        buf[index] = number & 0xFF;
        number >>= 8;
    }
}
149
150static void generateExpectedChallengeResponse(uint32_t number1, uint32_t number2, unsigned char key3[8], unsigned char expectedChallenge[16])
151{
152    unsigned char challenge[16];
153    setChallengeNumber(&challenge[0], number1);
154    setChallengeNumber(&challenge[4], number2);
155    memcpy(&challenge[8], key3, 8);
156    MD5 md5;
157    md5.addBytes(challenge, sizeof(challenge));
158    Vector<uint8_t, 16> digest;
159    md5.checksum(digest);
160    memcpy(expectedChallenge, digest.data(), 16);
161}
162
163WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context)
164    : m_url(url)
165    , m_clientProtocol(protocol)
166    , m_secure(m_url.protocolIs("wss"))
167    , m_context(context)
168    , m_mode(Incomplete)
169{
170    uint32_t number1;
171    uint32_t number2;
172    generateSecWebSocketKey(number1, m_secWebSocketKey1);
173    generateSecWebSocketKey(number2, m_secWebSocketKey2);
174    generateKey3(m_key3);
175    generateExpectedChallengeResponse(number1, number2, m_key3, m_expectedChallengeResponse);
176}
177
178WebSocketHandshake::~WebSocketHandshake()
179{
180}
181
182const KURL& WebSocketHandshake::url() const
183{
184    return m_url;
185}
186
187void WebSocketHandshake::setURL(const KURL& url)
188{
189    m_url = url.copy();
190}
191
192const String WebSocketHandshake::host() const
193{
194    return m_url.host().lower();
195}
196
197const String& WebSocketHandshake::clientProtocol() const
198{
199    return m_clientProtocol;
200}
201
202void WebSocketHandshake::setClientProtocol(const String& protocol)
203{
204    m_clientProtocol = protocol;
205}
206
207bool WebSocketHandshake::secure() const
208{
209    return m_secure;
210}
211
212String WebSocketHandshake::clientOrigin() const
213{
214    return m_context->securityOrigin()->toString();
215}
216
217String WebSocketHandshake::clientLocation() const
218{
219    StringBuilder builder;
220    builder.append(m_secure ? "wss" : "ws");
221    builder.append("://");
222    builder.append(hostName(m_url, m_secure));
223    builder.append(resourceName(m_url));
224    return builder.toString();
225}
226
227CString WebSocketHandshake::clientHandshakeMessage() const
228{
229    // Keep the following consistent with clientHandshakeRequest().
230    StringBuilder builder;
231
232    builder.append("GET ");
233    builder.append(resourceName(m_url));
234    builder.append(" HTTP/1.1\r\n");
235
236    Vector<String> fields;
237    fields.append("Upgrade: WebSocket");
238    fields.append("Connection: Upgrade");
239    fields.append("Host: " + hostName(m_url, m_secure));
240    fields.append("Origin: " + clientOrigin());
241    if (!m_clientProtocol.isEmpty())
242        fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
243
244    KURL url = httpURLForAuthenticationAndCookies();
245    if (m_context->isDocument()) {
246        Document* document = static_cast<Document*>(m_context);
247        String cookie = cookieRequestHeaderFieldValue(document, url);
248        if (!cookie.isEmpty())
249            fields.append("Cookie: " + cookie);
250        // Set "Cookie2: <cookie>" if cookies 2 exists for url?
251    }
252
253    fields.append("Sec-WebSocket-Key1: " + m_secWebSocketKey1);
254    fields.append("Sec-WebSocket-Key2: " + m_secWebSocketKey2);
255
256    // Fields in the handshake are sent by the client in a random order; the
257    // order is not meaningful.  Thus, it's ok to send the order we constructed
258    // the fields.
259
260    for (size_t i = 0; i < fields.size(); i++) {
261        builder.append(fields[i]);
262        builder.append("\r\n");
263    }
264
265    builder.append("\r\n");
266
267    CString handshakeHeader = builder.toString().utf8();
268    char* characterBuffer = 0;
269    CString msg = CString::newUninitialized(handshakeHeader.length() + sizeof(m_key3), characterBuffer);
270    memcpy(characterBuffer, handshakeHeader.data(), handshakeHeader.length());
271    memcpy(characterBuffer + handshakeHeader.length(), m_key3, sizeof(m_key3));
272    return msg;
273}
274
275WebSocketHandshakeRequest WebSocketHandshake::clientHandshakeRequest() const
276{
277    // Keep the following consistent with clientHandshakeMessage().
278    // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and
279    // m_key3 in WebSocketHandshakeRequest?
280    WebSocketHandshakeRequest request("GET", m_url);
281    request.addHeaderField("Upgrade", "WebSocket");
282    request.addHeaderField("Connection", "Upgrade");
283    request.addHeaderField("Host", hostName(m_url, m_secure));
284    request.addHeaderField("Origin", clientOrigin());
285    if (!m_clientProtocol.isEmpty())
286        request.addHeaderField("Sec-WebSocket-Protocol:", m_clientProtocol);
287
288    KURL url = httpURLForAuthenticationAndCookies();
289    if (m_context->isDocument()) {
290        Document* document = static_cast<Document*>(m_context);
291        String cookie = cookieRequestHeaderFieldValue(document, url);
292        if (!cookie.isEmpty())
293            request.addHeaderField("Cookie", cookie);
294        // Set "Cookie2: <cookie>" if cookies 2 exists for url?
295    }
296
297    request.addHeaderField("Sec-WebSocket-Key1", m_secWebSocketKey1);
298    request.addHeaderField("Sec-WebSocket-Key2", m_secWebSocketKey2);
299    request.setKey3(m_key3);
300
301    return request;
302}
303
304void WebSocketHandshake::reset()
305{
306    m_mode = Incomplete;
307}
308
309void WebSocketHandshake::clearScriptExecutionContext()
310{
311    m_context = 0;
312}
313
314int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
315{
316    m_mode = Incomplete;
317    int statusCode;
318    String statusText;
319    int lineLength = readStatusLine(header, len, statusCode, statusText);
320    if (lineLength == -1)
321        return -1;
322    if (statusCode == -1) {
323        m_mode = Failed;
324        return len;
325    }
326    LOG(Network, "response code: %d", statusCode);
327    m_response.setStatusCode(statusCode);
328    m_response.setStatusText(statusText);
329    if (statusCode != 101) {
330        m_mode = Failed;
331        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, makeString("Unexpected response code: ", String::number(statusCode)), 0, clientOrigin(), 0);
332        return len;
333    }
334    m_mode = Normal;
335    if (!strnstr(header, "\r\n\r\n", len)) {
336        // Just hasn't been received fully yet.
337        m_mode = Incomplete;
338        return -1;
339    }
340    const char* p = readHTTPHeaders(header + lineLength, header + len);
341    if (!p) {
342        LOG(Network, "readHTTPHeaders failed");
343        m_mode = Failed;
344        return len;
345    }
346    if (!checkResponseHeaders()) {
347        LOG(Network, "header process failed");
348        m_mode = Failed;
349        return p - header;
350    }
351    if (len < static_cast<size_t>(p - header + sizeof(m_expectedChallengeResponse))) {
352        // Just hasn't been received /expected/ yet.
353        m_mode = Incomplete;
354        return -1;
355    }
356    m_response.setChallengeResponse(static_cast<const unsigned char*>(static_cast<const void*>(p)));
357    if (memcmp(p, m_expectedChallengeResponse, sizeof(m_expectedChallengeResponse))) {
358        m_mode = Failed;
359        return (p - header) + sizeof(m_expectedChallengeResponse);
360    }
361    m_mode = Connected;
362    return (p - header) + sizeof(m_expectedChallengeResponse);
363}
364
365WebSocketHandshake::Mode WebSocketHandshake::mode() const
366{
367    return m_mode;
368}
369
370String WebSocketHandshake::serverWebSocketOrigin() const
371{
372    return m_response.headerFields().get("sec-websocket-origin");
373}
374
375String WebSocketHandshake::serverWebSocketLocation() const
376{
377    return m_response.headerFields().get("sec-websocket-location");
378}
379
380String WebSocketHandshake::serverWebSocketProtocol() const
381{
382    return m_response.headerFields().get("sec-websocket-protocol");
383}
384
385String WebSocketHandshake::serverSetCookie() const
386{
387    return m_response.headerFields().get("set-cookie");
388}
389
390String WebSocketHandshake::serverSetCookie2() const
391{
392    return m_response.headerFields().get("set-cookie2");
393}
394
395String WebSocketHandshake::serverUpgrade() const
396{
397    return m_response.headerFields().get("upgrade");
398}
399
400String WebSocketHandshake::serverConnection() const
401{
402    return m_response.headerFields().get("connection");
403}
404
405const WebSocketHandshakeResponse& WebSocketHandshake::serverHandshakeResponse() const
406{
407    return m_response;
408}
409
410KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
411{
412    KURL url = m_url.copy();
413    bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
414    ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
415    return url;
416}
417
418// Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
419// If the line is malformed or the status code is not a 3-digit number,
420// statusCode and statusText will be set to -1 and a null string, respectively.
421int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
422{
423    // Arbitrary size limit to prevent the server from sending an unbounded
424    // amount of data with no newlines and forcing us to buffer it all.
425    static const int maximumLength = 1024;
426
427    statusCode = -1;
428    statusText = String();
429
430    const char* space1 = 0;
431    const char* space2 = 0;
432    const char* p;
433    size_t consumedLength;
434
435    for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
436        if (*p == ' ') {
437            if (!space1)
438                space1 = p;
439            else if (!space2)
440                space2 = p;
441        } else if (*p == '\0') {
442            // The caller isn't prepared to deal with null bytes in status
443            // line. WebSockets specification doesn't prohibit this, but HTTP
444            // does, so we'll just treat this as an error.
445            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line contains embedded null", 0, clientOrigin(), 0);
446            return p + 1 - header;
447        } else if (*p == '\n')
448            break;
449    }
450    if (consumedLength == headerLength)
451        return -1; // We have not received '\n' yet.
452
453    const char* end = p + 1;
454    if (end - header > maximumLength) {
455        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line is too long", 0, clientOrigin(), 0);
456        return maximumLength;
457    }
458    int lineLength = end - header;
459
460    if (!space1 || !space2) {
461        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + trimConsoleMessage(header, lineLength - 1), 0, clientOrigin(), 0);
462        return lineLength;
463    }
464
465    // The line must end with "\r\n".
466    if (*(end - 2) != '\r') {
467        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line does not end with CRLF", 0, clientOrigin(), 0);
468        return lineLength;
469    }
470
471    String statusCodeString(space1 + 1, space2 - space1 - 1);
472    if (statusCodeString.length() != 3) // Status code must consist of three digits.
473        return lineLength;
474    for (int i = 0; i < 3; ++i)
475        if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
476            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Invalid status code: " + statusCodeString, 0, clientOrigin(), 0);
477            return lineLength;
478        }
479
480    bool ok = false;
481    statusCode = statusCodeString.toInt(&ok);
482    ASSERT(ok);
483
484    statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
485    return lineLength;
486}
487
488const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
489{
490    m_response.clearHeaderFields();
491
492    Vector<char> name;
493    Vector<char> value;
494    for (const char* p = start; p < end; p++) {
495        name.clear();
496        value.clear();
497
498        for (; p < end; p++) {
499            switch (*p) {
500            case '\r':
501                if (name.isEmpty()) {
502                    if (p + 1 < end && *(p + 1) == '\n')
503                        return p + 2;
504                    m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF at " + trimConsoleMessage(p, end - p), 0, clientOrigin(), 0);
505                    return 0;
506                }
507                m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin(), 0);
508                return 0;
509            case '\n':
510                m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin(), 0);
511                return 0;
512            case ':':
513                break;
514            default:
515                name.append(*p);
516                continue;
517            }
518            if (*p == ':') {
519                ++p;
520                break;
521            }
522        }
523
524        for (; p < end && *p == 0x20; p++) { }
525
526        for (; p < end; p++) {
527            switch (*p) {
528            case '\r':
529                break;
530            case '\n':
531                m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + trimConsoleMessage(value.data(), value.size()), 0, clientOrigin(), 0);
532                return 0;
533            default:
534                value.append(*p);
535            }
536            if (*p == '\r') {
537                ++p;
538                break;
539            }
540        }
541        if (p >= end || *p != '\n') {
542            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF after value at " + trimConsoleMessage(p, end - p), 0, clientOrigin(), 0);
543            return 0;
544        }
545        AtomicString nameStr = AtomicString::fromUTF8(name.data(), name.size());
546        String valueStr = String::fromUTF8(value.data(), value.size());
547        if (nameStr.isNull()) {
548            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header name", 0, clientOrigin(), 0);
549            return 0;
550        }
551        if (valueStr.isNull()) {
552            m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header value", 0, clientOrigin(), 0);
553            return 0;
554        }
555        LOG(Network, "name=%s value=%s", nameStr.string().utf8().data(), valueStr.utf8().data());
556        m_response.addHeaderField(nameStr, valueStr);
557    }
558    ASSERT_NOT_REACHED();
559    return 0;
560}
561
562bool WebSocketHandshake::checkResponseHeaders()
563{
564    const String& serverWebSocketLocation = this->serverWebSocketLocation();
565    const String& serverWebSocketOrigin = this->serverWebSocketOrigin();
566    const String& serverWebSocketProtocol = this->serverWebSocketProtocol();
567    const String& serverUpgrade = this->serverUpgrade();
568    const String& serverConnection = this->serverConnection();
569
570    if (serverUpgrade.isNull()) {
571        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'Upgrade' header is missing", 0, clientOrigin(), 0);
572        return false;
573    }
574    if (serverConnection.isNull()) {
575        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'Connection' header is missing", 0, clientOrigin(), 0);
576        return false;
577    }
578    if (serverWebSocketOrigin.isNull()) {
579        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'Sec-WebSocket-Origin' header is missing", 0, clientOrigin(), 0);
580        return false;
581    }
582    if (serverWebSocketLocation.isNull()) {
583        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'Sec-WebSocket-Location' header is missing", 0, clientOrigin(), 0);
584        return false;
585    }
586
587    if (!equalIgnoringCase(serverUpgrade, "websocket")) {
588        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'", 0, clientOrigin(), 0);
589        return false;
590    }
591    if (!equalIgnoringCase(serverConnection, "upgrade")) {
592        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'", 0, clientOrigin(), 0);
593        return false;
594    }
595
596    if (clientOrigin() != serverWebSocketOrigin) {
597        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: origin mismatch: " + clientOrigin() + " != " + serverWebSocketOrigin, 0, clientOrigin(), 0);
598        return false;
599    }
600    if (clientLocation() != serverWebSocketLocation) {
601        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: location mismatch: " + clientLocation() + " != " + serverWebSocketLocation, 0, clientOrigin(), 0);
602        return false;
603    }
604    if (!m_clientProtocol.isEmpty() && m_clientProtocol != serverWebSocketProtocol) {
605        m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: protocol mismatch: " + m_clientProtocol + " != " + serverWebSocketProtocol, 0, clientOrigin(), 0);
606        return false;
607    }
608    return true;
609}
610
611} // namespace WebCore
612
613#endif // ENABLE(WEB_SOCKETS)
614