1/* 2 * Copyright (C) 2014 Square, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16package com.squareup.okhttp.internal.ws; 17 18import com.squareup.okhttp.MediaType; 19import com.squareup.okhttp.ResponseBody; 20import com.squareup.okhttp.ws.WebSocket; 21import java.io.EOFException; 22import java.io.IOException; 23import java.net.ProtocolException; 24import okio.Buffer; 25import okio.BufferedSource; 26import okio.Okio; 27import okio.Source; 28import okio.Timeout; 29 30import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_FIN; 31import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV1; 32import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV2; 33import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV3; 34import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_MASK_OPCODE; 35import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_FLAG_MASK; 36import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_MASK_LENGTH; 37import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_BINARY; 38import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTINUATION; 39import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_CLOSE; 40import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PING; 41import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PONG; 42import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_FLAG_CONTROL; 43import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_TEXT; 44import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_LONG; 45import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_BYTE_MAX; 46import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_SHORT; 47import static com.squareup.okhttp.internal.ws.WebSocketProtocol.toggleMask; 48import static java.lang.Integer.toHexString; 49 50/** 51 * An <a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a>-compatible WebSocket frame reader. 52 */ 53public final class WebSocketReader { 54 public interface FrameCallback { 55 void onMessage(ResponseBody body) throws IOException; 56 void onPing(Buffer buffer); 57 void onPong(Buffer buffer); 58 void onClose(int code, String reason); 59 } 60 61 private final boolean isClient; 62 private final BufferedSource source; 63 private final FrameCallback frameCallback; 64 65 private final Source framedMessageSource = new FramedMessageSource(); 66 67 private boolean closed; 68 private boolean messageClosed; 69 70 // Stateful data about the current frame. 71 private int opcode; 72 private long frameLength; 73 private long frameBytesRead; 74 private boolean isFinalFrame; 75 private boolean isControlFrame; 76 private boolean isMasked; 77 78 private final byte[] maskKey = new byte[4]; 79 private final byte[] maskBuffer = new byte[2048]; 80 81 public WebSocketReader(boolean isClient, BufferedSource source, FrameCallback frameCallback) { 82 if (source == null) throw new NullPointerException("source == null"); 83 if (frameCallback == null) throw new NullPointerException("frameCallback == null"); 84 this.isClient = isClient; 85 this.source = source; 86 this.frameCallback = frameCallback; 87 } 88 89 /** 90 * Process the next protocol frame. 91 * <ul> 92 * <li>If it is a control frame this will result in a single call to {@link FrameCallback}.</li> 93 * <li>If it is a message frame this will result in a single call to {@link 94 * FrameCallback#onMessage}. If the message spans multiple frames, each interleaved control 95 * frame will result in a corresponding call to {@link FrameCallback}. 96 * </ul> 97 */ 98 public void processNextFrame() throws IOException { 99 readHeader(); 100 if (isControlFrame) { 101 readControlFrame(); 102 } else { 103 readMessageFrame(); 104 } 105 } 106 107 private void readHeader() throws IOException { 108 if (closed) throw new IOException("closed"); 109 110 int b0 = source.readByte() & 0xff; 111 112 opcode = b0 & B0_MASK_OPCODE; 113 isFinalFrame = (b0 & B0_FLAG_FIN) != 0; 114 isControlFrame = (b0 & OPCODE_FLAG_CONTROL) != 0; 115 116 // Control frames must be final frames (cannot contain continuations). 117 if (isControlFrame && !isFinalFrame) { 118 throw new ProtocolException("Control frames must be final."); 119 } 120 121 boolean reservedFlag1 = (b0 & B0_FLAG_RSV1) != 0; 122 boolean reservedFlag2 = (b0 & B0_FLAG_RSV2) != 0; 123 boolean reservedFlag3 = (b0 & B0_FLAG_RSV3) != 0; 124 if (reservedFlag1 || reservedFlag2 || reservedFlag3) { 125 // Reserved flags are for extensions which we currently do not support. 126 throw new ProtocolException("Reserved flags are unsupported."); 127 } 128 129 int b1 = source.readByte() & 0xff; 130 131 isMasked = (b1 & B1_FLAG_MASK) != 0; 132 if (isMasked == isClient) { 133 // Masked payloads must be read on the server. Unmasked payloads must be read on the client. 134 throw new ProtocolException("Client-sent frames must be masked. Server sent must not."); 135 } 136 137 // Get frame length, optionally reading from follow-up bytes if indicated by special values. 138 frameLength = b1 & B1_MASK_LENGTH; 139 if (frameLength == PAYLOAD_SHORT) { 140 frameLength = source.readShort() & 0xffffL; // Value is unsigned. 141 } else if (frameLength == PAYLOAD_LONG) { 142 frameLength = source.readLong(); 143 if (frameLength < 0) { 144 throw new ProtocolException( 145 "Frame length 0x" + Long.toHexString(frameLength) + " > 0x7FFFFFFFFFFFFFFF"); 146 } 147 } 148 frameBytesRead = 0; 149 150 if (isControlFrame && frameLength > PAYLOAD_BYTE_MAX) { 151 throw new ProtocolException("Control frame must be less than " + PAYLOAD_BYTE_MAX + "B."); 152 } 153 154 if (isMasked) { 155 // Read the masking key as bytes so that they can be used directly for unmasking. 156 source.readFully(maskKey); 157 } 158 } 159 160 private void readControlFrame() throws IOException { 161 Buffer buffer = null; 162 if (frameBytesRead < frameLength) { 163 buffer = new Buffer(); 164 165 if (isClient) { 166 source.readFully(buffer, frameLength); 167 } else { 168 while (frameBytesRead < frameLength) { 169 int toRead = (int) Math.min(frameLength - frameBytesRead, maskBuffer.length); 170 int read = source.read(maskBuffer, 0, toRead); 171 if (read == -1) throw new EOFException(); 172 toggleMask(maskBuffer, read, maskKey, frameBytesRead); 173 buffer.write(maskBuffer, 0, read); 174 frameBytesRead += read; 175 } 176 } 177 } 178 179 switch (opcode) { 180 case OPCODE_CONTROL_PING: 181 frameCallback.onPing(buffer); 182 break; 183 case OPCODE_CONTROL_PONG: 184 frameCallback.onPong(buffer); 185 break; 186 case OPCODE_CONTROL_CLOSE: 187 int code = 1000; 188 String reason = ""; 189 if (buffer != null) { 190 long bufferSize = buffer.size(); 191 if (bufferSize == 1) { 192 throw new ProtocolException("Malformed close payload length of 1."); 193 } else if (bufferSize != 0) { 194 code = buffer.readShort(); 195 if (code < 1000 || code >= 5000) { 196 throw new ProtocolException("Code must be in range [1000,5000): " + code); 197 } 198 if ((code >= 1004 && code <= 1006) || (code >= 1012 && code <= 2999)) { 199 throw new ProtocolException("Code " + code + " is reserved and may not be used."); 200 } 201 202 reason = buffer.readUtf8(); 203 } 204 } 205 frameCallback.onClose(code, reason); 206 closed = true; 207 break; 208 default: 209 throw new ProtocolException("Unknown control opcode: " + toHexString(opcode)); 210 } 211 } 212 213 private void readMessageFrame() throws IOException { 214 final MediaType type; 215 switch (opcode) { 216 case OPCODE_TEXT: 217 type = WebSocket.TEXT; 218 break; 219 case OPCODE_BINARY: 220 type = WebSocket.BINARY; 221 break; 222 default: 223 throw new ProtocolException("Unknown opcode: " + toHexString(opcode)); 224 } 225 226 final BufferedSource source = Okio.buffer(framedMessageSource); 227 ResponseBody body = new ResponseBody() { 228 @Override public MediaType contentType() { 229 return type; 230 } 231 232 @Override public long contentLength() throws IOException { 233 return -1; 234 } 235 236 @Override public BufferedSource source() throws IOException { 237 return source; 238 } 239 }; 240 241 messageClosed = false; 242 frameCallback.onMessage(body); 243 if (!messageClosed) { 244 throw new IllegalStateException("Listener failed to call close on message payload."); 245 } 246 } 247 248 /** Read headers and process any control frames until we reach a non-control frame. */ 249 private void readUntilNonControlFrame() throws IOException { 250 while (!closed) { 251 readHeader(); 252 if (!isControlFrame) { 253 break; 254 } 255 readControlFrame(); 256 } 257 } 258 259 /** 260 * A special source which knows how to read a message body across one or more frames. Control 261 * frames that occur between fragments will be processed. If the message payload is masked this 262 * will unmask as it's being processed. 263 */ 264 private final class FramedMessageSource implements Source { 265 @Override public long read(Buffer sink, long byteCount) throws IOException { 266 if (closed) throw new IOException("closed"); 267 if (messageClosed) throw new IllegalStateException("closed"); 268 269 if (frameBytesRead == frameLength) { 270 if (isFinalFrame) return -1; // We are exhausted and have no continuations. 271 272 readUntilNonControlFrame(); 273 if (opcode != OPCODE_CONTINUATION) { 274 throw new ProtocolException("Expected continuation opcode. Got: " + toHexString(opcode)); 275 } 276 if (isFinalFrame && frameLength == 0) { 277 return -1; // Fast-path for empty final frame. 278 } 279 } 280 281 long toRead = Math.min(byteCount, frameLength - frameBytesRead); 282 283 long read; 284 if (isMasked) { 285 toRead = Math.min(toRead, maskBuffer.length); 286 read = source.read(maskBuffer, 0, (int) toRead); 287 if (read == -1) throw new EOFException(); 288 toggleMask(maskBuffer, read, maskKey, frameBytesRead); 289 sink.write(maskBuffer, 0, (int) read); 290 } else { 291 read = source.read(sink, toRead); 292 if (read == -1) throw new EOFException(); 293 } 294 295 frameBytesRead += read; 296 return read; 297 } 298 299 @Override public Timeout timeout() { 300 return source.timeout(); 301 } 302 303 @Override public void close() throws IOException { 304 if (messageClosed) return; 305 messageClosed = true; 306 if (closed) return; 307 308 // Exhaust the remainder of the message, if any. 309 source.skip(frameLength - frameBytesRead); 310 while (!isFinalFrame) { 311 readUntilNonControlFrame(); 312 source.skip(frameLength); 313 } 314 } 315 } 316} 317