1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_deflate_stream.h"
6
7#include <algorithm>
8#include <string>
9
10#include "base/bind.h"
11#include "base/logging.h"
12#include "base/memory/ref_counted.h"
13#include "base/memory/scoped_ptr.h"
14#include "base/memory/scoped_vector.h"
15#include "net/base/completion_callback.h"
16#include "net/base/io_buffer.h"
17#include "net/base/net_errors.h"
18#include "net/websockets/websocket_deflate_predictor.h"
19#include "net/websockets/websocket_deflater.h"
20#include "net/websockets/websocket_errors.h"
21#include "net/websockets/websocket_frame.h"
22#include "net/websockets/websocket_inflater.h"
23#include "net/websockets/websocket_stream.h"
24
25class GURL;
26
27namespace net {
28
namespace {

// Window size, as a base-2 logarithm, used to initialize the inflater for
// incoming messages. It equals the largest value the constructor accepts
// for |client_window_bits|.
const int kWindowBits = 15;

// Size in bytes of the pieces into which deflater/inflater output is cut
// when it is turned into frames. It is also passed for both inflater
// buffer-size arguments.
const size_t kChunkSize = 4096;

}  // namespace
35
36WebSocketDeflateStream::WebSocketDeflateStream(
37    scoped_ptr<WebSocketStream> stream,
38    WebSocketDeflater::ContextTakeOverMode mode,
39    int client_window_bits,
40    scoped_ptr<WebSocketDeflatePredictor> predictor)
41    : stream_(stream.Pass()),
42      deflater_(mode),
43      inflater_(kChunkSize, kChunkSize),
44      reading_state_(NOT_READING),
45      writing_state_(NOT_WRITING),
46      current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
47      current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
48      predictor_(predictor.Pass()) {
49  DCHECK(stream_);
50  DCHECK_GE(client_window_bits, 8);
51  DCHECK_LE(client_window_bits, 15);
52  deflater_.Initialize(client_window_bits);
53  inflater_.Initialize(kWindowBits);
54}
55
56WebSocketDeflateStream::~WebSocketDeflateStream() {}
57
58int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
59                                       const CompletionCallback& callback) {
60  int result = stream_->ReadFrames(
61      frames,
62      base::Bind(&WebSocketDeflateStream::OnReadComplete,
63                 base::Unretained(this),
64                 base::Unretained(frames),
65                 callback));
66  if (result < 0)
67    return result;
68  DCHECK_EQ(OK, result);
69  DCHECK(!frames->empty());
70
71  return InflateAndReadIfNecessary(frames, callback);
72}
73
74int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
75                                        const CompletionCallback& callback) {
76  int result = Deflate(frames);
77  if (result != OK)
78    return result;
79  if (frames->empty())
80    return OK;
81  return stream_->WriteFrames(frames, callback);
82}
83
84void WebSocketDeflateStream::Close() { stream_->Close(); }
85
86std::string WebSocketDeflateStream::GetSubProtocol() const {
87  return stream_->GetSubProtocol();
88}
89
90std::string WebSocketDeflateStream::GetExtensions() const {
91  return stream_->GetExtensions();
92}
93
94void WebSocketDeflateStream::OnReadComplete(
95    ScopedVector<WebSocketFrame>* frames,
96    const CompletionCallback& callback,
97    int result) {
98  if (result != OK) {
99    frames->clear();
100    callback.Run(result);
101    return;
102  }
103
104  int r = InflateAndReadIfNecessary(frames, callback);
105  if (r != ERR_IO_PENDING)
106    callback.Run(r);
107}
108
// Compresses the data frames in |frames| in place. Control frames are
// passed through untouched. Each data message is handled according to the
// writing state chosen by OnMessageStart() at its first frame:
//  - WRITING_UNCOMPRESSED_MESSAGE: frames are passed through unchanged.
//  - WRITING_COMPRESSED_MESSAGE: payloads are fed to |deflater_|. A
//    compressed frame is emitted whenever at least kChunkSize bytes of
//    output are buffered, or when the message's final frame is reached.
//  - WRITING_POSSIBLY_COMPRESSED_MESSAGE: payloads are fed to |deflater_|
//    and the original frames are held until the final frame. At that point
//    AppendPossiblyCompressedMessage() sends the compressed form only if it
//    is smaller than the original payload. Such a message must be complete
//    within a single call (see the DCHECK at the end).
//
// On success, |frames| is replaced with the frames to write, which may be
// empty, and OK is returned. If the deflater fails, ERR_WS_PROTOCOL_ERROR
// is returned. In that case |frames| may still hold NULL slots for frames
// that were already consumed, and |writing_state_| is not reset.
int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
  ScopedVector<WebSocketFrame> frames_to_write;
  // Store frames of the currently processed message if writing_state_ equals to
  // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
  ScopedVector<WebSocketFrame> frames_of_message;
  for (size_t i = 0; i < frames->size(); ++i) {
    DCHECK(!(*frames)[i]->header.reserved1);
    if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
      // Control frames are never compressed.
      frames_to_write.push_back((*frames)[i]);
      (*frames)[i] = NULL;
      continue;
    }
    // OnMessageStart() inspects |frames| itself, so it must run before this
    // frame is taken out of the vector below.
    if (writing_state_ == NOT_WRITING)
      OnMessageStart(*frames, i);

    scoped_ptr<WebSocketFrame> frame((*frames)[i]);
    (*frames)[i] = NULL;
    predictor_->RecordInputDataFrame(frame.get());

    if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
      if (frame->header.final)
        writing_state_ = NOT_WRITING;
      predictor_->RecordWrittenDataFrame(frame.get());
      frames_to_write.push_back(frame.release());
      // Any further frames of this message must go out as continuations.
      current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    } else {
      if (frame->data.get() &&
          !deflater_.AddBytes(frame->data->data(),
                              frame->header.payload_length)) {
        DVLOG(1) << "WebSocket protocol error. "
                 << "deflater_.AddBytes() returns an error.";
        return ERR_WS_PROTOCOL_ERROR;
      }
      // Tell the deflater the message ends with this frame.
      if (frame->header.final && !deflater_.Finish()) {
        DVLOG(1) << "WebSocket protocol error. "
                 << "deflater_.Finish() returns an error.";
        return ERR_WS_PROTOCOL_ERROR;
      }

      if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
        // Below the chunk threshold no frame is emitted yet; the data stays
        // buffered in |deflater_| and the input frame is discarded.
        if (deflater_.CurrentOutputSize() >= kChunkSize ||
            frame->header.final) {
          int result = AppendCompressedFrame(frame->header, &frames_to_write);
          if (result != OK)
            return result;
        }
        if (frame->header.final)
          writing_state_ = NOT_WRITING;
      } else {
        DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
        // Keep the original frame so it can be sent verbatim if compression
        // turns out not to pay off.
        bool final = frame->header.final;
        frames_of_message.push_back(frame.release());
        if (final) {
          int result = AppendPossiblyCompressedMessage(&frames_of_message,
                                                       &frames_to_write);
          if (result != OK)
            return result;
          frames_of_message.clear();
          writing_state_ = NOT_WRITING;
        }
      }
    }
  }
  // A possibly-compressed message may not span several WriteFrames() calls.
  DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
  frames->swap(frames_to_write);
  return OK;
}
176
177void WebSocketDeflateStream::OnMessageStart(
178    const ScopedVector<WebSocketFrame>& frames, size_t index) {
179  WebSocketFrame* frame = frames[index];
180  current_writing_opcode_ = frame->header.opcode;
181  DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
182         current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
183  WebSocketDeflatePredictor::Result prediction =
184      predictor_->Predict(frames, index);
185
186  switch (prediction) {
187    case WebSocketDeflatePredictor::DEFLATE:
188      writing_state_ = WRITING_COMPRESSED_MESSAGE;
189      return;
190    case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
191      writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
192      return;
193    case WebSocketDeflatePredictor::TRY_DEFLATE:
194      writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
195      return;
196  }
197  NOTREACHED();
198}
199
200int WebSocketDeflateStream::AppendCompressedFrame(
201    const WebSocketFrameHeader& header,
202    ScopedVector<WebSocketFrame>* frames_to_write) {
203  const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
204  scoped_refptr<IOBufferWithSize> compressed_payload =
205      deflater_.GetOutput(deflater_.CurrentOutputSize());
206  if (!compressed_payload.get()) {
207    DVLOG(1) << "WebSocket protocol error. "
208             << "deflater_.GetOutput() returns an error.";
209    return ERR_WS_PROTOCOL_ERROR;
210  }
211  scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
212  compressed->header.CopyFrom(header);
213  compressed->header.opcode = opcode;
214  compressed->header.final = header.final;
215  compressed->header.reserved1 =
216      (opcode != WebSocketFrameHeader::kOpCodeContinuation);
217  compressed->data = compressed_payload;
218  compressed->header.payload_length = compressed_payload->size();
219
220  current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
221  predictor_->RecordWrittenDataFrame(compressed.get());
222  frames_to_write->push_back(compressed.release());
223  return OK;
224}
225
226int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
227    ScopedVector<WebSocketFrame>* frames,
228    ScopedVector<WebSocketFrame>* frames_to_write) {
229  DCHECK(!frames->empty());
230
231  const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
232  scoped_refptr<IOBufferWithSize> compressed_payload =
233      deflater_.GetOutput(deflater_.CurrentOutputSize());
234  if (!compressed_payload.get()) {
235    DVLOG(1) << "WebSocket protocol error. "
236             << "deflater_.GetOutput() returns an error.";
237    return ERR_WS_PROTOCOL_ERROR;
238  }
239
240  uint64 original_payload_length = 0;
241  for (size_t i = 0; i < frames->size(); ++i) {
242    WebSocketFrame* frame = (*frames)[i];
243    // Asserts checking that frames represent one whole data message.
244    DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
245    DCHECK_EQ(i == 0,
246              WebSocketFrameHeader::kOpCodeContinuation !=
247              frame->header.opcode);
248    DCHECK_EQ(i == frames->size() - 1, frame->header.final);
249    original_payload_length += frame->header.payload_length;
250  }
251  if (original_payload_length <=
252      static_cast<uint64>(compressed_payload->size())) {
253    // Compression is not effective. Use the original frames.
254    for (size_t i = 0; i < frames->size(); ++i) {
255      WebSocketFrame* frame = (*frames)[i];
256      frames_to_write->push_back(frame);
257      predictor_->RecordWrittenDataFrame(frame);
258      (*frames)[i] = NULL;
259    }
260    frames->weak_clear();
261    return OK;
262  }
263  scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
264  compressed->header.CopyFrom((*frames)[0]->header);
265  compressed->header.opcode = opcode;
266  compressed->header.final = true;
267  compressed->header.reserved1 = true;
268  compressed->data = compressed_payload;
269  compressed->header.payload_length = compressed_payload->size();
270
271  predictor_->RecordWrittenDataFrame(compressed.get());
272  frames_to_write->push_back(compressed.release());
273  return OK;
274}
275
276int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
277  ScopedVector<WebSocketFrame> frames_to_output;
278  ScopedVector<WebSocketFrame> frames_passed;
279  frames->swap(frames_passed);
280  for (size_t i = 0; i < frames_passed.size(); ++i) {
281    scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
282    frames_passed[i] = NULL;
283    DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
284             << " final=" << frame->header.final
285             << " reserved1=" << frame->header.reserved1
286             << " payload_length=" << frame->header.payload_length;
287
288    if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
289      frames_to_output.push_back(frame.release());
290      continue;
291    }
292
293    if (reading_state_ == NOT_READING) {
294      if (frame->header.reserved1)
295        reading_state_ = READING_COMPRESSED_MESSAGE;
296      else
297        reading_state_ = READING_UNCOMPRESSED_MESSAGE;
298      current_reading_opcode_ = frame->header.opcode;
299    } else {
300      if (frame->header.reserved1) {
301        DVLOG(1) << "WebSocket protocol error. "
302                 << "Receiving a non-first frame with RSV1 flag set.";
303        return ERR_WS_PROTOCOL_ERROR;
304      }
305    }
306
307    if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
308      if (frame->header.final)
309        reading_state_ = NOT_READING;
310      current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
311      frames_to_output.push_back(frame.release());
312    } else {
313      DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
314      if (frame->data.get() &&
315          !inflater_.AddBytes(frame->data->data(),
316                              frame->header.payload_length)) {
317        DVLOG(1) << "WebSocket protocol error. "
318                 << "inflater_.AddBytes() returns an error.";
319        return ERR_WS_PROTOCOL_ERROR;
320      }
321      if (frame->header.final) {
322        if (!inflater_.Finish()) {
323          DVLOG(1) << "WebSocket protocol error. "
324                   << "inflater_.Finish() returns an error.";
325          return ERR_WS_PROTOCOL_ERROR;
326        }
327      }
328      // TODO(yhirano): Many frames can be generated by the inflater and
329      // memory consumption can grow.
330      // We could avoid it, but avoiding it makes this class much more
331      // complicated.
332      while (inflater_.CurrentOutputSize() >= kChunkSize ||
333             frame->header.final) {
334        size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
335        scoped_ptr<WebSocketFrame> inflated(
336            new WebSocketFrame(WebSocketFrameHeader::kOpCodeText));
337        scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
338        bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
339        if (!data.get()) {
340          DVLOG(1) << "WebSocket protocol error. "
341                   << "inflater_.GetOutput() returns an error.";
342          return ERR_WS_PROTOCOL_ERROR;
343        }
344        inflated->header.CopyFrom(frame->header);
345        inflated->header.opcode = current_reading_opcode_;
346        inflated->header.final = is_final;
347        inflated->header.reserved1 = false;
348        inflated->data = data;
349        inflated->header.payload_length = data->size();
350        DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
351                 << " final=" << inflated->header.final
352                 << " reserved1=" << inflated->header.reserved1
353                 << " payload_length=" << inflated->header.payload_length;
354        frames_to_output.push_back(inflated.release());
355        current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
356        if (is_final)
357          break;
358      }
359      if (frame->header.final)
360        reading_state_ = NOT_READING;
361    }
362  }
363  frames->swap(frames_to_output);
364  return frames->empty() ? ERR_IO_PENDING : OK;
365}
366
367int WebSocketDeflateStream::InflateAndReadIfNecessary(
368    ScopedVector<WebSocketFrame>* frames,
369    const CompletionCallback& callback) {
370  int result = Inflate(frames);
371  while (result == ERR_IO_PENDING) {
372    DCHECK(frames->empty());
373
374    result = stream_->ReadFrames(
375        frames,
376        base::Bind(&WebSocketDeflateStream::OnReadComplete,
377                   base::Unretained(this),
378                   base::Unretained(frames),
379                   callback));
380    if (result < 0)
381      break;
382    DCHECK_EQ(OK, result);
383    DCHECK(!frames->empty());
384
385    result = Inflate(frames);
386  }
387  if (result < 0)
388    frames->clear();
389  return result;
390}
391
392}  // namespace net
393