websocket_basic_stream.cc revision 4e180b6a0b4720a9b8e9e959a882386f690f08ff
1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_basic_stream.h"
6
7#include <algorithm>
8#include <limits>
9#include <string>
10#include <vector>
11
12#include "base/basictypes.h"
13#include "base/bind.h"
14#include "base/logging.h"
15#include "base/safe_numerics.h"
16#include "net/base/io_buffer.h"
17#include "net/base/net_errors.h"
18#include "net/socket/client_socket_handle.h"
19#include "net/websockets/websocket_errors.h"
20#include "net/websockets/websocket_frame.h"
21#include "net/websockets/websocket_frame_parser.h"
22
23namespace net {
24
25namespace {
26
// Maximum payload length, in bytes, accepted for a control frame (RFC 6455
// section 5.5 limits control frame payloads to 125 bytes). A frame with a
// known control opcode and a larger payload is rejected with
// ERR_WS_PROTOCOL_ERROR, and this value is also the capacity of the buffer
// used to reassemble a control frame that arrives split across chunks.
//
// This uses type uint64 to match the definition of
// WebSocketFrameHeader::payload_length in websocket_frame.h.
const uint64 kMaxControlFramePayload = 125;

// The number of bytes to attempt to read at a time. This is the size of the
// buffer handed to the socket on every Read() call.
// TODO(ricea): See if there is a better number or algorithm to fulfill our
// requirements:
//  1. We would like to use minimal memory on low-bandwidth or idle connections
//  2. We would like to read as close to line speed as possible on
//     high-bandwidth connections
//  3. We can't afford to cause jank on the IO thread by copying large buffers
//     around
//  4. We would like to hit any sweet-spots that might exist in terms of network
//     packet sizes / encryption block sizes / IPC alignment issues, etc.
const int kReadBufferSize = 32 * 1024;

// Iterator over the frames held by a ScopedVector<WebSocketFrame>; used when
// sizing and serializing frames for a write.
typedef ScopedVector<WebSocketFrame>::const_iterator WebSocketFrameIterator;
44
45// Returns the total serialized size of |frames|. This function assumes that
46// |frames| will be serialized with mask field. This function forces the
47// masked bit of the frames on.
48int CalculateSerializedSizeAndTurnOnMaskBit(
49    ScopedVector<WebSocketFrame>* frames) {
50  const int kMaximumTotalSize = std::numeric_limits<int>::max();
51
52  int total_size = 0;
53  for (WebSocketFrameIterator it = frames->begin();
54       it != frames->end(); ++it) {
55    WebSocketFrame* frame = *it;
56    // Force the masked bit on.
57    frame->header.masked = true;
58    // We enforce flow control so the renderer should never be able to force us
59    // to cache anywhere near 2GB of frames.
60    int frame_size = frame->header.payload_length +
61                     GetWebSocketFrameHeaderSize(frame->header);
62    CHECK_GE(kMaximumTotalSize - total_size, frame_size)
63        << "Aborting to prevent overflow";
64    total_size += frame_size;
65  }
66  return total_size;
67}
68
69}  // namespace
70
71WebSocketBasicStream::WebSocketBasicStream(
72    scoped_ptr<ClientSocketHandle> connection)
73    : read_buffer_(new IOBufferWithSize(kReadBufferSize)),
74      connection_(connection.Pass()),
75      generate_websocket_masking_key_(&GenerateWebSocketMaskingKey) {
76  DCHECK(connection_->is_initialized());
77}
78
79WebSocketBasicStream::~WebSocketBasicStream() { Close(); }
80
81int WebSocketBasicStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
82                                     const CompletionCallback& callback) {
83  DCHECK(frames->empty());
84  // If there is data left over after parsing the HTTP headers, attempt to parse
85  // it as WebSocket frames.
86  if (http_read_buffer_) {
87    DCHECK_GE(http_read_buffer_->offset(), 0);
88    // We cannot simply copy the data into read_buffer_, as it might be too
89    // large.
90    scoped_refptr<GrowableIOBuffer> buffered_data;
91    buffered_data.swap(http_read_buffer_);
92    DCHECK(http_read_buffer_.get() == NULL);
93    ScopedVector<WebSocketFrameChunk> frame_chunks;
94    if (!parser_.Decode(buffered_data->StartOfBuffer(),
95                        buffered_data->offset(),
96                        &frame_chunks))
97      return WebSocketErrorToNetError(parser_.websocket_error());
98    if (!frame_chunks.empty()) {
99      int result = ConvertChunksToFrames(&frame_chunks, frames);
100      if (result != ERR_IO_PENDING)
101        return result;
102    }
103  }
104
105  // Run until socket stops giving us data or we get some frames.
106  while (true) {
107    // base::Unretained(this) here is safe because net::Socket guarantees not to
108    // call any callbacks after Disconnect(), which we call from the
109    // destructor. The caller of ReadFrames() is required to keep |frames|
110    // valid.
111    int result = connection_->socket()->Read(
112        read_buffer_.get(),
113        read_buffer_->size(),
114        base::Bind(&WebSocketBasicStream::OnReadComplete,
115                   base::Unretained(this),
116                   base::Unretained(frames),
117                   callback));
118    if (result == ERR_IO_PENDING)
119      return result;
120    result = HandleReadResult(result, frames);
121    if (result != ERR_IO_PENDING)
122      return result;
123    DCHECK(frames->empty());
124  }
125}
126
127int WebSocketBasicStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
128                                      const CompletionCallback& callback) {
129  // This function always concatenates all frames into a single buffer.
130  // TODO(ricea): Investigate whether it would be better in some cases to
131  // perform multiple writes with smaller buffers.
132  //
133  // First calculate the size of the buffer we need to allocate.
134  int total_size = CalculateSerializedSizeAndTurnOnMaskBit(frames);
135  scoped_refptr<IOBufferWithSize> combined_buffer(
136      new IOBufferWithSize(total_size));
137
138  char* dest = combined_buffer->data();
139  int remaining_size = total_size;
140  for (WebSocketFrameIterator it = frames->begin();
141       it != frames->end(); ++it) {
142    WebSocketFrame* frame = *it;
143    WebSocketMaskingKey mask = generate_websocket_masking_key_();
144    int result =
145        WriteWebSocketFrameHeader(frame->header, &mask, dest, remaining_size);
146    DCHECK_NE(ERR_INVALID_ARGUMENT, result)
147        << "WriteWebSocketFrameHeader() says that " << remaining_size
148        << " is not enough to write the header in. This should not happen.";
149    CHECK_GE(result, 0) << "Potentially security-critical check failed";
150    dest += result;
151    remaining_size -= result;
152
153    const char* const frame_data = frame->data->data();
154    const int frame_size = frame->header.payload_length;
155    CHECK_GE(remaining_size, frame_size);
156    std::copy(frame_data, frame_data + frame_size, dest);
157    MaskWebSocketFramePayload(mask, 0, dest, frame_size);
158    dest += frame_size;
159    remaining_size -= frame_size;
160  }
161  DCHECK_EQ(0, remaining_size) << "Buffer size calculation was wrong; "
162                               << remaining_size << " bytes left over.";
163  scoped_refptr<DrainableIOBuffer> drainable_buffer(
164      new DrainableIOBuffer(combined_buffer, total_size));
165  return WriteEverything(drainable_buffer, callback);
166}
167
168void WebSocketBasicStream::Close() { connection_->socket()->Disconnect(); }
169
170std::string WebSocketBasicStream::GetSubProtocol() const {
171  return sub_protocol_;
172}
173
174std::string WebSocketBasicStream::GetExtensions() const { return extensions_; }
175
176int WebSocketBasicStream::SendHandshakeRequest(
177    const GURL& url,
178    const HttpRequestHeaders& headers,
179    HttpResponseInfo* response_info,
180    const CompletionCallback& callback) {
181  // TODO(ricea): Implement handshake-related functionality.
182  NOTREACHED();
183  return ERR_NOT_IMPLEMENTED;
184}
185
186int WebSocketBasicStream::ReadHandshakeResponse(
187    const CompletionCallback& callback) {
188  NOTREACHED();
189  return ERR_NOT_IMPLEMENTED;
190}
191
192/*static*/
193scoped_ptr<WebSocketBasicStream>
194WebSocketBasicStream::CreateWebSocketBasicStreamForTesting(
195    scoped_ptr<ClientSocketHandle> connection,
196    const scoped_refptr<GrowableIOBuffer>& http_read_buffer,
197    const std::string& sub_protocol,
198    const std::string& extensions,
199    WebSocketMaskingKeyGeneratorFunction key_generator_function) {
200  scoped_ptr<WebSocketBasicStream> stream(
201      new WebSocketBasicStream(connection.Pass()));
202  if (http_read_buffer) {
203    stream->http_read_buffer_ = http_read_buffer;
204  }
205  stream->sub_protocol_ = sub_protocol;
206  stream->extensions_ = extensions;
207  stream->generate_websocket_masking_key_ = key_generator_function;
208  return stream.Pass();
209}
210
211int WebSocketBasicStream::WriteEverything(
212    const scoped_refptr<DrainableIOBuffer>& buffer,
213    const CompletionCallback& callback) {
214  while (buffer->BytesRemaining() > 0) {
215    // The use of base::Unretained() here is safe because on destruction we
216    // disconnect the socket, preventing any further callbacks.
217    int result = connection_->socket()->Write(
218        buffer.get(),
219        buffer->BytesRemaining(),
220        base::Bind(&WebSocketBasicStream::OnWriteComplete,
221                   base::Unretained(this),
222                   buffer,
223                   callback));
224    if (result > 0) {
225      buffer->DidConsume(result);
226    } else {
227      return result;
228    }
229  }
230  return OK;
231}
232
233void WebSocketBasicStream::OnWriteComplete(
234    const scoped_refptr<DrainableIOBuffer>& buffer,
235    const CompletionCallback& callback,
236    int result) {
237  if (result < 0) {
238    DCHECK_NE(ERR_IO_PENDING, result);
239    callback.Run(result);
240    return;
241  }
242
243  DCHECK_NE(0, result);
244  buffer->DidConsume(result);
245  result = WriteEverything(buffer, callback);
246  if (result != ERR_IO_PENDING)
247    callback.Run(result);
248}
249
250int WebSocketBasicStream::HandleReadResult(
251    int result,
252    ScopedVector<WebSocketFrame>* frames) {
253  DCHECK_NE(ERR_IO_PENDING, result);
254  DCHECK(frames->empty());
255  if (result < 0)
256    return result;
257  if (result == 0)
258    return ERR_CONNECTION_CLOSED;
259  ScopedVector<WebSocketFrameChunk> frame_chunks;
260  if (!parser_.Decode(read_buffer_->data(), result, &frame_chunks))
261    return WebSocketErrorToNetError(parser_.websocket_error());
262  if (frame_chunks.empty())
263    return ERR_IO_PENDING;
264  return ConvertChunksToFrames(&frame_chunks, frames);
265}
266
267int WebSocketBasicStream::ConvertChunksToFrames(
268    ScopedVector<WebSocketFrameChunk>* frame_chunks,
269    ScopedVector<WebSocketFrame>* frames) {
270  for (size_t i = 0; i < frame_chunks->size(); ++i) {
271    scoped_ptr<WebSocketFrame> frame;
272    int result = ConvertChunkToFrame(
273        scoped_ptr<WebSocketFrameChunk>((*frame_chunks)[i]), &frame);
274    (*frame_chunks)[i] = NULL;
275    if (result != OK)
276      return result;
277    if (frame)
278      frames->push_back(frame.release());
279  }
280  // All the elements of |frame_chunks| are now NULL, so there is no point in
281  // calling delete on them all.
282  frame_chunks->weak_clear();
283  if (frames->empty())
284    return ERR_IO_PENDING;
285  return OK;
286}
287
288int WebSocketBasicStream::ConvertChunkToFrame(
289    scoped_ptr<WebSocketFrameChunk> chunk,
290    scoped_ptr<WebSocketFrame>* frame) {
291  DCHECK(frame->get() == NULL);
292  bool is_first_chunk = false;
293  if (chunk->header) {
294    DCHECK(current_frame_header_ == NULL)
295        << "Received the header for a new frame without notification that "
296        << "the previous frame was complete (bug in WebSocketFrameParser?)";
297    is_first_chunk = true;
298    current_frame_header_.swap(chunk->header);
299  }
300  const int chunk_size = chunk->data ? chunk->data->size() : 0;
301  DCHECK(current_frame_header_) << "Unexpected header-less chunk received "
302                                << "(final_chunk = " << chunk->final_chunk
303                                << ", data size = " << chunk_size
304                                << ") (bug in WebSocketFrameParser?)";
305  scoped_refptr<IOBufferWithSize> data_buffer;
306  data_buffer.swap(chunk->data);
307  const bool is_final_chunk = chunk->final_chunk;
308  const WebSocketFrameHeader::OpCode opcode = current_frame_header_->opcode;
309  if (WebSocketFrameHeader::IsKnownControlOpCode(opcode)) {
310    bool protocol_error = false;
311    if (!current_frame_header_->final) {
312      DVLOG(1) << "WebSocket protocol error. Control frame, opcode=" << opcode
313               << " received with FIN bit unset.";
314      protocol_error = true;
315    }
316    if (current_frame_header_->payload_length > kMaxControlFramePayload) {
317      DVLOG(1) << "WebSocket protocol error. Control frame, opcode=" << opcode
318               << ", payload_length=" << current_frame_header_->payload_length
319               << " exceeds maximum payload length for a control message.";
320      protocol_error = true;
321    }
322    if (protocol_error) {
323      current_frame_header_.reset();
324      return ERR_WS_PROTOCOL_ERROR;
325    }
326    if (!is_final_chunk) {
327      DVLOG(2) << "Encountered a split control frame, opcode " << opcode;
328      if (incomplete_control_frame_body_) {
329        DVLOG(3) << "Appending to an existing split control frame.";
330        AddToIncompleteControlFrameBody(data_buffer);
331      } else {
332        DVLOG(3) << "Creating new storage for an incomplete control frame.";
333        incomplete_control_frame_body_ = new GrowableIOBuffer();
334        // This method checks for oversize control frames above, so as long as
335        // the frame parser is working correctly, this won't overflow. If a bug
336        // does cause it to overflow, it will CHECK() in
337        // AddToIncompleteControlFrameBody() without writing outside the buffer.
338        incomplete_control_frame_body_->SetCapacity(kMaxControlFramePayload);
339        AddToIncompleteControlFrameBody(data_buffer);
340      }
341      return OK;
342    }
343    if (incomplete_control_frame_body_) {
344      DVLOG(2) << "Rejoining a split control frame, opcode " << opcode;
345      AddToIncompleteControlFrameBody(data_buffer);
346      const int body_size = incomplete_control_frame_body_->offset();
347      DCHECK_EQ(body_size,
348                static_cast<int>(current_frame_header_->payload_length));
349      scoped_refptr<IOBufferWithSize> body = new IOBufferWithSize(body_size);
350      memcpy(body->data(),
351             incomplete_control_frame_body_->StartOfBuffer(),
352             body_size);
353      incomplete_control_frame_body_ = NULL;  // Frame now complete.
354      DCHECK(is_final_chunk);
355      *frame = CreateFrame(is_final_chunk, body);
356      return OK;
357    }
358  }
359
360  // Apply basic sanity checks to the |payload_length| field from the frame
361  // header. A check for exact equality can only be used when the whole frame
362  // arrives in one chunk.
363  DCHECK_GE(current_frame_header_->payload_length,
364            base::checked_numeric_cast<uint64>(chunk_size));
365  DCHECK(!is_first_chunk || !is_final_chunk ||
366         current_frame_header_->payload_length ==
367             base::checked_numeric_cast<uint64>(chunk_size));
368
369  // Convert the chunk to a complete frame.
370  *frame = CreateFrame(is_final_chunk, data_buffer);
371  return OK;
372}
373
374scoped_ptr<WebSocketFrame> WebSocketBasicStream::CreateFrame(
375    bool is_final_chunk,
376    const scoped_refptr<IOBufferWithSize>& data) {
377  scoped_ptr<WebSocketFrame> result_frame;
378  const bool is_final_chunk_in_message =
379      is_final_chunk && current_frame_header_->final;
380  const int data_size = data ? data->size() : 0;
381  const WebSocketFrameHeader::OpCode opcode = current_frame_header_->opcode;
382  // Empty frames convey no useful information unless they are the first frame
383  // (containing the type and flags) or have the "final" bit set.
384  if (is_final_chunk_in_message || data_size > 0 ||
385      current_frame_header_->opcode !=
386          WebSocketFrameHeader::kOpCodeContinuation) {
387    result_frame.reset(new WebSocketFrame(opcode));
388    result_frame->header.CopyFrom(*current_frame_header_);
389    result_frame->header.final = is_final_chunk_in_message;
390    result_frame->header.payload_length = data_size;
391    result_frame->data = data;
392    // Ensure that opcodes Text and Binary are only used for the first frame in
393    // the message.
394    if (WebSocketFrameHeader::IsKnownDataOpCode(opcode))
395      current_frame_header_->opcode = WebSocketFrameHeader::kOpCodeContinuation;
396  }
397  // Make sure that a frame header is not applied to any chunks that do not
398  // belong to it.
399  if (is_final_chunk)
400    current_frame_header_.reset();
401  return result_frame.Pass();
402}
403
404void WebSocketBasicStream::AddToIncompleteControlFrameBody(
405    const scoped_refptr<IOBufferWithSize>& data_buffer) {
406  if (!data_buffer)
407    return;
408  const int new_offset =
409      incomplete_control_frame_body_->offset() + data_buffer->size();
410  CHECK_GE(incomplete_control_frame_body_->capacity(), new_offset)
411      << "Control frame body larger than frame header indicates; frame parser "
412         "bug?";
413  memcpy(incomplete_control_frame_body_->data(),
414         data_buffer->data(),
415         data_buffer->size());
416  incomplete_control_frame_body_->set_offset(new_offset);
417}
418
419void WebSocketBasicStream::OnReadComplete(ScopedVector<WebSocketFrame>* frames,
420                                          const CompletionCallback& callback,
421                                          int result) {
422  result = HandleReadResult(result, frames);
423  if (result == ERR_IO_PENDING)
424    result = ReadFrames(frames, callback);
425  if (result != ERR_IO_PENDING)
426    callback.Run(result);
427}
428
429}  // namespace net
430