// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/websockets/websocket_deflate_stream.h"

#include <algorithm>
#include <string>

#include "base/bind.h"
#include "base/logging.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/memory/scoped_vector.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/websockets/websocket_deflate_predictor.h"
#include "net/websockets/websocket_deflater.h"
#include "net/websockets/websocket_errors.h"
#include "net/websockets/websocket_frame.h"
#include "net/websockets/websocket_inflater.h"
#include "net/websockets/websocket_stream.h"

class GURL;

namespace net {

namespace {

// Window-size parameter passed to WebSocketInflater::Initialize(). It equals
// the upper bound (15) that the constructor also enforces on
// |client_window_bits|.
const int kWindowBits = 15;
// Chunking threshold in bytes. It is passed as both WebSocketInflater
// constructor arguments (presumably buffer capacities -- confirm in
// websocket_inflater.h).
// - Deflate() emits a compressed frame, taking all pending deflater output,
//   once at least this much output is pending.
// - Inflate() requests at most this many bytes per inflated frame.
const size_t kChunkSize = 4 * 1024;

}  // namespace

// Wraps |stream|.
// - Outgoing data messages may be compressed by a deflater configured with
//   |mode| and |client_window_bits| (DCHECKed to lie in [8, 15]).
// - |predictor| decides, per outgoing message, whether to compress.
// - Incoming compressed messages are decompressed by an inflater initialized
//   with kWindowBits.
// NOTE(review): the return values of deflater_.Initialize() and
// inflater_.Initialize() are ignored here; confirm whether they can fail.
WebSocketDeflateStream::WebSocketDeflateStream(
    scoped_ptr<WebSocketStream> stream,
    WebSocketDeflater::ContextTakeOverMode mode,
    int client_window_bits,
    scoped_ptr<WebSocketDeflatePredictor> predictor)
    : stream_(stream.Pass()),
      deflater_(mode),
      inflater_(kChunkSize, kChunkSize),
      reading_state_(NOT_READING),
      writing_state_(NOT_WRITING),
      current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
      current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
      predictor_(predictor.Pass()) {
  DCHECK(stream_);
  DCHECK_GE(client_window_bits, 8);
  DCHECK_LE(client_window_bits, 15);
  deflater_.Initialize(client_window_bits);
  inflater_.Initialize(kWindowBits);
}

WebSocketDeflateStream::~WebSocketDeflateStream() {}

// Reads frames from the underlying stream and inflates compressed messages in
// place. If the frames read so far produce no output frames (for example, a
// compressed message whose inflated output has not reached kChunkSize), more
// reads are issued until output is available.
// Returns one of:
// - OK, with |*frames| populated;
// - a negative net error;
// - ERR_IO_PENDING, in which case |callback| is run later from
//   OnReadComplete().
// Because of base::Unretained(), |this| and |frames| must outlive any pending
// read.
int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
                                       const CompletionCallback& callback) {
  int result = stream_->ReadFrames(
      frames,
      base::Bind(&WebSocketDeflateStream::OnReadComplete,
                 base::Unretained(this),
                 base::Unretained(frames),
                 callback));
  if (result < 0)
    return result;
  DCHECK_EQ(OK, result);
  DCHECK(!frames->empty());

  return InflateAndReadIfNecessary(frames, callback);
}

// Compresses outgoing data frames in place (see Deflate()) and passes the
// result to the underlying stream. If the deflater absorbed every frame and
// nothing is ready to send yet, returns OK immediately. In that case the
// underlying stream is not called and |callback| is not run.
int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
                                        const CompletionCallback& callback) {
  int result = Deflate(frames);
  if (result != OK)
    return result;
  if (frames->empty())
    return OK;
  return stream_->WriteFrames(frames, callback);
}

// Forwarded unchanged to the underlying stream.
void WebSocketDeflateStream::Close() { stream_->Close(); }

// Forwarded unchanged to the underlying stream.
std::string WebSocketDeflateStream::GetSubProtocol() const {
  return stream_->GetSubProtocol();
}

// Forwarded unchanged to the underlying stream.
std::string WebSocketDeflateStream::GetExtensions() const {
  return stream_->GetExtensions();
}

// Completion handler for an asynchronous read of the underlying stream.
// - On failure, |*frames| is discarded and the error is passed to |callback|.
// - On success, the frames are inflated, possibly issuing more reads.
//   |callback| runs now unless another read went asynchronous. That read was
//   bound to this method again, so |callback| runs when it completes.
void WebSocketDeflateStream::OnReadComplete(
    ScopedVector<WebSocketFrame>* frames,
    const CompletionCallback& callback,
    int result) {
  if (result != OK) {
    frames->clear();
    callback.Run(result);
    return;
  }

  int r = InflateAndReadIfNecessary(frames, callback);
  if (r != ERR_IO_PENDING)
    callback.Run(r);
}

// Compresses the data frames in |*frames| in place, replacing its contents
// with the frames to write.
//
// Frames whose opcode is not a known data opcode (e.g. control frames) are
// forwarded immediately. They can therefore be written ahead of earlier data
// whose compressed form is still buffered in the deflater.
//
// For each data message, |writing_state_| (chosen by OnMessageStart()) is one
// of:
//   WRITING_UNCOMPRESSED_MESSAGE: frames pass through unchanged.
//   WRITING_COMPRESSED_MESSAGE: frames feed the deflater. A compressed frame
//     is emitted whenever kChunkSize bytes of output are pending or the
//     message ends.
//   WRITING_POSSIBLY_COMPRESSED_MESSAGE: the whole message is collected and
//     AppendPossiblyCompressedMessage() picks the smaller encoding.
//
// Because |writing_state_| is a member, a message in the first two modes may
// span several calls. The third mode must complete within one call because
// |frames_of_message| is local; only the trailing DCHECK enforces this.
//
// Returns OK or ERR_WS_PROTOCOL_ERROR.
// NOTE(review): on error, |*frames| is not reset and may hold NULL entries for
// frames that were already consumed.
int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
  ScopedVector<WebSocketFrame> frames_to_write;
  // Holds the frames of the message being processed while writing_state_ is
  // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
  ScopedVector<WebSocketFrame> frames_of_message;
  for (size_t i = 0; i < frames->size(); ++i) {
    DCHECK(!(*frames)[i]->header.reserved1);
    if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
      // Not a data frame: pass it through untouched.
      frames_to_write.push_back((*frames)[i]);
      (*frames)[i] = NULL;
      continue;
    }
    // First data frame of a new message: decide how to encode the message.
    if (writing_state_ == NOT_WRITING)
      OnMessageStart(*frames, i);

    scoped_ptr<WebSocketFrame> frame((*frames)[i]);
    (*frames)[i] = NULL;
    // Every input data frame is reported to the predictor, whatever encoding
    // is chosen for it.
    predictor_->RecordInputDataFrame(frame.get());

    if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
      if (frame->header.final)
        writing_state_ = NOT_WRITING;
      predictor_->RecordWrittenDataFrame(frame.get());
      frames_to_write.push_back(frame.release());
      current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    } else {
      // Both compressing modes feed the payload to the deflater and call
      // Finish() on the message's last frame.
      if (frame->data.get() &&
          !deflater_.AddBytes(frame->data->data(),
                              frame->header.payload_length)) {
        DVLOG(1) << "WebSocket protocol error. "
                 << "deflater_.AddBytes() returns an error.";
        return ERR_WS_PROTOCOL_ERROR;
      }
      if (frame->header.final && !deflater_.Finish()) {
        DVLOG(1) << "WebSocket protocol error. "
                 << "deflater_.Finish() returns an error.";
        return ERR_WS_PROTOCOL_ERROR;
      }

      if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
        // Emit a frame once a chunk's worth of output is pending, or at the
        // end of the message.
        if (deflater_.CurrentOutputSize() >= kChunkSize ||
            frame->header.final) {
          int result = AppendCompressedFrame(frame->header, &frames_to_write);
          if (result != OK)
            return result;
        }
        if (frame->header.final)
          writing_state_ = NOT_WRITING;
      } else {
        DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
        bool final = frame->header.final;
        frames_of_message.push_back(frame.release());
        if (final) {
          int result = AppendPossiblyCompressedMessage(&frames_of_message,
                                                       &frames_to_write);
          if (result != OK)
            return result;
          frames_of_message.clear();
          writing_state_ = NOT_WRITING;
        }
      }
    }
  }
  // A non-final message in this mode would have been dropped together with
  // the local |frames_of_message|.
  DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
  frames->swap(frames_to_write);
  return OK;
}

// Called at frames[index], the first data frame of an outgoing message.
// - Records the message opcode, which is used for the first frame emitted.
// - Asks the predictor, which is given the whole |frames| batch, how to encode
//   the message, and maps its answer onto |writing_state_|.
void WebSocketDeflateStream::OnMessageStart(
    const ScopedVector<WebSocketFrame>& frames, size_t index) {
  WebSocketFrame* frame = frames[index];
  current_writing_opcode_ = frame->header.opcode;
  DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
         current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
  WebSocketDeflatePredictor::Result prediction =
      predictor_->Predict(frames, index);

  switch (prediction) {
    case WebSocketDeflatePredictor::DEFLATE:
      writing_state_ = WRITING_COMPRESSED_MESSAGE;
      return;
    case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
      writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
      return;
    case WebSocketDeflatePredictor::TRY_DEFLATE:
      writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
      return;
  }
  NOTREACHED();
}

// Drains all pending deflater output into one frame and appends it to
// |*frames_to_write|.
// - The frame copies |header| but uses the current writing opcode.
// - RSV1 (reserved1) is set only on the message's first frame, i.e. when the
//   opcode is not a continuation; later frames of the message become
//   continuation frames.
// Returns OK, or ERR_WS_PROTOCOL_ERROR if the deflater yields no buffer.
int WebSocketDeflateStream::AppendCompressedFrame(
    const WebSocketFrameHeader& header,
    ScopedVector<WebSocketFrame>* frames_to_write) {
  const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
  scoped_refptr<IOBufferWithSize> compressed_payload =
      deflater_.GetOutput(deflater_.CurrentOutputSize());
  if (!compressed_payload.get()) {
    DVLOG(1) << "WebSocket protocol error. "
             << "deflater_.GetOutput() returns an error.";
    return ERR_WS_PROTOCOL_ERROR;
  }
  scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
  compressed->header.CopyFrom(header);
  compressed->header.opcode = opcode;
  compressed->header.final = header.final;
  compressed->header.reserved1 =
      (opcode != WebSocketFrameHeader::kOpCodeContinuation);
  compressed->data = compressed_payload;
  compressed->header.payload_length = compressed_payload->size();

  current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
  predictor_->RecordWrittenDataFrame(compressed.get());
  frames_to_write->push_back(compressed.release());
  return OK;
}

// Emits the complete message in |*frames|, which has already been fed to the
// deflater, in one of two forms:
// - one final compressed frame with RSV1 set, used only when it is strictly
//   smaller than the original payload;
// - otherwise (ties included) the original frames.
// The deflater output is drained in both cases.
// Ownership:
// - Originals used: they move to |*frames_to_write| and |*frames| is emptied.
// - Compressed frame used: |*frames| keeps the originals for the caller to
//   free.
// Returns OK, or ERR_WS_PROTOCOL_ERROR if the deflater yields no buffer.
// NOTE(review): the deflater has processed the message bytes even when they
// are sent uncompressed. Confirm that the predictor only returns TRY_DEFLATE
// when this cannot desynchronize a shared compression context (e.g. without
// context takeover).
int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
    ScopedVector<WebSocketFrame>* frames,
    ScopedVector<WebSocketFrame>* frames_to_write) {
  DCHECK(!frames->empty());

  const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
  scoped_refptr<IOBufferWithSize> compressed_payload =
      deflater_.GetOutput(deflater_.CurrentOutputSize());
  if (!compressed_payload.get()) {
    DVLOG(1) << "WebSocket protocol error. "
             << "deflater_.GetOutput() returns an error.";
    return ERR_WS_PROTOCOL_ERROR;
  }

  uint64 original_payload_length = 0;
  for (size_t i = 0; i < frames->size(); ++i) {
    WebSocketFrame* frame = (*frames)[i];
    // Asserts checking that frames represent one whole data message.
    DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
    DCHECK_EQ(i == 0,
              WebSocketFrameHeader::kOpCodeContinuation !=
              frame->header.opcode);
    DCHECK_EQ(i == frames->size() - 1, frame->header.final);
    original_payload_length += frame->header.payload_length;
  }
  if (original_payload_length <=
      static_cast<uint64>(compressed_payload->size())) {
    // Compression is not effective. Use the original frames.
    for (size_t i = 0; i < frames->size(); ++i) {
      WebSocketFrame* frame = (*frames)[i];
      frames_to_write->push_back(frame);
      predictor_->RecordWrittenDataFrame(frame);
      (*frames)[i] = NULL;
    }
    // Ownership was transferred above; drop the now-NULL slots without
    // deleting anything.
    frames->weak_clear();
    return OK;
  }
  scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
  compressed->header.CopyFrom((*frames)[0]->header);
  compressed->header.opcode = opcode;
  compressed->header.final = true;
  compressed->header.reserved1 = true;
  compressed->data = compressed_payload;
  compressed->header.payload_length = compressed_payload->size();

  predictor_->RecordWrittenDataFrame(compressed.get());
  frames_to_write->push_back(compressed.release());
  return OK;
}

// Decompresses the data messages in |*frames| in place.
// - Frames with non-data opcodes pass through unchanged.
// - A message is compressed when its first frame has RSV1 set. RSV1 on any
//   later frame of a message is a protocol error.
// - For each input frame of a compressed message, inflated output is emitted
//   in frames of at most kChunkSize requested bytes. A non-final input frame
//   emits only full chunks; the rest waits for later frames. The final input
//   frame drains everything and emits the message's final frame.
// Returns one of:
// - OK, when at least one frame was produced;
// - ERR_IO_PENDING, when none was (the caller must read more);
// - ERR_WS_PROTOCOL_ERROR. In this case |*frames| has already been emptied.
int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
  ScopedVector<WebSocketFrame> frames_to_output;
  ScopedVector<WebSocketFrame> frames_passed;
  frames->swap(frames_passed);
  for (size_t i = 0; i < frames_passed.size(); ++i) {
    scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
    frames_passed[i] = NULL;
    DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
             << " final=" << frame->header.final
             << " reserved1=" << frame->header.reserved1
             << " payload_length=" << frame->header.payload_length;

    if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
      frames_to_output.push_back(frame.release());
      continue;
    }

    if (reading_state_ == NOT_READING) {
      // First frame of a message: RSV1 selects compressed vs. uncompressed.
      if (frame->header.reserved1)
        reading_state_ = READING_COMPRESSED_MESSAGE;
      else
        reading_state_ = READING_UNCOMPRESSED_MESSAGE;
      current_reading_opcode_ = frame->header.opcode;
    } else {
      if (frame->header.reserved1) {
        DVLOG(1) << "WebSocket protocol error. "
                 << "Receiving a non-first frame with RSV1 flag set.";
        return ERR_WS_PROTOCOL_ERROR;
      }
    }

    if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
      if (frame->header.final)
        reading_state_ = NOT_READING;
      current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
      frames_to_output.push_back(frame.release());
    } else {
      DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
      if (frame->data.get() &&
          !inflater_.AddBytes(frame->data->data(),
                              frame->header.payload_length)) {
        DVLOG(1) << "WebSocket protocol error. "
                 << "inflater_.AddBytes() returns an error.";
        return ERR_WS_PROTOCOL_ERROR;
      }
      if (frame->header.final) {
        if (!inflater_.Finish()) {
          DVLOG(1) << "WebSocket protocol error. "
                   << "inflater_.Finish() returns an error.";
          return ERR_WS_PROTOCOL_ERROR;
        }
      }
      // TODO(yhirano): Many frames can be generated by the inflater and
      // memory consumption can grow.
      // We could avoid it, but avoiding it makes this class much more
      // complicated.
      while (inflater_.CurrentOutputSize() >= kChunkSize ||
             frame->header.final) {
        size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
        scoped_ptr<WebSocketFrame> inflated(
            new WebSocketFrame(WebSocketFrameHeader::kOpCodeText));
        scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
        // The chunk that empties the inflater on the message's final input
        // frame ends the message. This must be computed after GetOutput().
        bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
        if (!data.get()) {
          DVLOG(1) << "WebSocket protocol error. "
                   << "inflater_.GetOutput() returns an error.";
          return ERR_WS_PROTOCOL_ERROR;
        }
        inflated->header.CopyFrom(frame->header);
        inflated->header.opcode = current_reading_opcode_;
        inflated->header.final = is_final;
        inflated->header.reserved1 = false;
        inflated->data = data;
        inflated->header.payload_length = data->size();
        DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
                 << " final=" << inflated->header.final
                 << " reserved1=" << inflated->header.reserved1
                 << " payload_length=" << inflated->header.payload_length;
        frames_to_output.push_back(inflated.release());
        current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
        if (is_final)
          break;
      }
      if (frame->header.final)
        reading_state_ = NOT_READING;
    }
  }
  frames->swap(frames_to_output);
  return frames->empty() ? ERR_IO_PENDING : OK;
}

// Inflates |*frames|. While that yields no frames, synchronously reads more
// from the underlying stream and retries.
// Returns either:
// - the first Inflate() result that is not ERR_IO_PENDING, or
// - the underlying read's result if it fails or goes asynchronous; in the
//   asynchronous case OnReadComplete() resumes the work.
// |*frames| is cleared whenever a negative result, including ERR_IO_PENDING,
// is returned.
int WebSocketDeflateStream::InflateAndReadIfNecessary(
    ScopedVector<WebSocketFrame>* frames,
    const CompletionCallback& callback) {
  int result = Inflate(frames);
  while (result == ERR_IO_PENDING) {
    DCHECK(frames->empty());

    result = stream_->ReadFrames(
        frames,
        base::Bind(&WebSocketDeflateStream::OnReadComplete,
                   base::Unretained(this),
                   base::Unretained(frames),
                   callback));
    if (result < 0)
      break;
    DCHECK_EQ(OK, result);
    DCHECK(!frames->empty());

    result = Inflate(frames);
  }
  if (result < 0)
    frames->clear();
  return result;
}

}  // namespace net