1/* 2 * Copyright (C) 2014 Square, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16package com.squareup.okhttp.internal.ws; 17 18import java.io.EOFException; 19import java.io.IOException; 20import java.net.ProtocolException; 21import okio.Buffer; 22import okio.BufferedSource; 23import okio.Okio; 24import okio.Source; 25import okio.Timeout; 26 27import static com.squareup.okhttp.ws.WebSocket.PayloadType; 28import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_FIN; 29import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV1; 30import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV2; 31import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV3; 32import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_MASK_OPCODE; 33import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_FLAG_MASK; 34import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_MASK_LENGTH; 35import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_BINARY; 36import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTINUATION; 37import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_CLOSE; 38import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PING; 39import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PONG; 40import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_FLAG_CONTROL; 41import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_TEXT; 42import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_LONG; 43import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_MAX; 44import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_SHORT; 45import static com.squareup.okhttp.internal.ws.WebSocketProtocol.toggleMask; 46import static java.lang.Integer.toHexString; 47 48/** 49 * An <a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a>-compatible WebSocket frame reader. 50 */ 51public final class WebSocketReader { 52 public interface FrameCallback { 53 void onMessage(BufferedSource source, PayloadType type) throws IOException; 54 void onPing(Buffer buffer); 55 void onPong(Buffer buffer); 56 void onClose(int code, String reason); 57 } 58 59 private final boolean isClient; 60 private final BufferedSource source; 61 private final FrameCallback frameCallback; 62 63 private final Source framedMessageSource = new FramedMessageSource(); 64 65 private boolean closed; 66 private boolean messageClosed; 67 68 // Stateful data about the current frame. 69 private int opcode; 70 private long frameLength; 71 private long frameBytesRead; 72 private boolean isFinalFrame; 73 private boolean isControlFrame; 74 private boolean isMasked; 75 76 private final byte[] maskKey = new byte[4]; 77 private final byte[] maskBuffer = new byte[2048]; 78 79 public WebSocketReader(boolean isClient, BufferedSource source, FrameCallback frameCallback) { 80 if (source == null) throw new NullPointerException("source == null"); 81 if (frameCallback == null) throw new NullPointerException("frameCallback == null"); 82 this.isClient = isClient; 83 this.source = source; 84 this.frameCallback = frameCallback; 85 } 86 87 /** 88 * Process the next protocol frame. 89 * <ul> 90 * <li>If it is a control frame this will result in a single call to {@link FrameCallback}.</li> 91 * <li>If it is a message frame this will result in a single call to {@link 92 * FrameCallback#onMessage}. If the message spans multiple frames, each interleaved control 93 * frame will result in a corresponding call to {@link FrameCallback}. 94 * </ul> 95 */ 96 public void processNextFrame() throws IOException { 97 readHeader(); 98 if (isControlFrame) { 99 readControlFrame(); 100 } else { 101 readMessageFrame(); 102 } 103 } 104 105 private void readHeader() throws IOException { 106 if (closed) throw new IOException("closed"); 107 108 int b0 = source.readByte() & 0xff; 109 110 opcode = b0 & B0_MASK_OPCODE; 111 isFinalFrame = (b0 & B0_FLAG_FIN) != 0; 112 isControlFrame = (b0 & OPCODE_FLAG_CONTROL) != 0; 113 114 // Control frames must be final frames (cannot contain continuations). 115 if (isControlFrame && !isFinalFrame) { 116 throw new ProtocolException("Control frames must be final."); 117 } 118 119 boolean reservedFlag1 = (b0 & B0_FLAG_RSV1) != 0; 120 boolean reservedFlag2 = (b0 & B0_FLAG_RSV2) != 0; 121 boolean reservedFlag3 = (b0 & B0_FLAG_RSV3) != 0; 122 if (reservedFlag1 || reservedFlag2 || reservedFlag3) { 123 // Reserved flags are for extensions which we currently do not support. 124 throw new ProtocolException("Reserved flags are unsupported."); 125 } 126 127 int b1 = source.readByte() & 0xff; 128 129 isMasked = (b1 & B1_FLAG_MASK) != 0; 130 if (isMasked == isClient) { 131 // Masked payloads must be read on the server. Unmasked payloads must be read on the client. 132 throw new ProtocolException("Client-sent frames must be masked. Server sent must not."); 133 } 134 135 // Get frame length, optionally reading from follow-up bytes if indicated by special values. 136 frameLength = b1 & B1_MASK_LENGTH; 137 if (frameLength == PAYLOAD_SHORT) { 138 frameLength = source.readShort() & 0xffffL; // Value is unsigned. 139 } else if (frameLength == PAYLOAD_LONG) { 140 frameLength = source.readLong(); 141 if (frameLength < 0) { 142 throw new ProtocolException( 143 "Frame length 0x" + Long.toHexString(frameLength) + " > 0x7FFFFFFFFFFFFFFF"); 144 } 145 } 146 frameBytesRead = 0; 147 148 if (isControlFrame && frameLength > PAYLOAD_MAX) { 149 throw new ProtocolException("Control frame must be less than " + PAYLOAD_MAX + "B."); 150 } 151 152 if (isMasked) { 153 // Read the masking key as bytes so that they can be used directly for unmasking. 154 source.readFully(maskKey); 155 } 156 } 157 158 private void readControlFrame() throws IOException { 159 Buffer buffer = null; 160 if (frameBytesRead < frameLength) { 161 buffer = new Buffer(); 162 163 if (isClient) { 164 source.readFully(buffer, frameLength); 165 } else { 166 while (frameBytesRead < frameLength) { 167 int toRead = (int) Math.min(frameLength - frameBytesRead, maskBuffer.length); 168 int read = source.read(maskBuffer, 0, toRead); 169 if (read == -1) throw new EOFException(); 170 toggleMask(maskBuffer, read, maskKey, frameBytesRead); 171 buffer.write(maskBuffer, 0, read); 172 frameBytesRead += read; 173 } 174 } 175 } 176 177 switch (opcode) { 178 case OPCODE_CONTROL_PING: 179 frameCallback.onPing(buffer); 180 break; 181 case OPCODE_CONTROL_PONG: 182 frameCallback.onPong(buffer); 183 break; 184 case OPCODE_CONTROL_CLOSE: 185 int code = 0; 186 String reason = ""; 187 if (buffer != null) { 188 if (buffer.size() < 2) { 189 throw new ProtocolException("Close payload must be at least two bytes."); 190 } 191 code = buffer.readShort(); 192 if (code < 1000 || code >= 5000) { 193 throw new ProtocolException("Code must be in range [1000,5000): " + code); 194 } 195 196 reason = buffer.readUtf8(); 197 } 198 frameCallback.onClose(code, reason); 199 closed = true; 200 break; 201 default: 202 throw new ProtocolException("Unknown control opcode: " + toHexString(opcode)); 203 } 204 } 205 206 private void readMessageFrame() throws IOException { 207 PayloadType type; 208 switch (opcode) { 209 case OPCODE_TEXT: 210 type = PayloadType.TEXT; 211 break; 212 case OPCODE_BINARY: 213 type = PayloadType.BINARY; 214 break; 215 default: 216 throw new ProtocolException("Unknown opcode: " + toHexString(opcode)); 217 } 218 219 messageClosed = false; 220 frameCallback.onMessage(Okio.buffer(framedMessageSource), type); 221 if (!messageClosed) { 222 throw new IllegalStateException("Listener failed to call close on message payload."); 223 } 224 } 225 226 /** Read headers and process any control frames until we reach a non-control frame. */ 227 private void readUntilNonControlFrame() throws IOException { 228 while (!closed) { 229 readHeader(); 230 if (!isControlFrame) { 231 break; 232 } 233 readControlFrame(); 234 } 235 } 236 237 /** 238 * A special source which knows how to read a message body across one or more frames. Control 239 * frames that occur between fragments will be processed. If the message payload is masked this 240 * will unmask as it's being processed. 241 */ 242 private final class FramedMessageSource implements Source { 243 @Override public long read(Buffer sink, long byteCount) throws IOException { 244 if (closed) throw new IOException("closed"); 245 if (messageClosed) throw new IllegalStateException("closed"); 246 247 if (frameBytesRead == frameLength) { 248 if (isFinalFrame) return -1; // We are exhausted and have no continuations. 249 250 readUntilNonControlFrame(); 251 if (opcode != OPCODE_CONTINUATION) { 252 throw new ProtocolException("Expected continuation opcode. Got: " + toHexString(opcode)); 253 } 254 if (isFinalFrame && frameLength == 0) { 255 return -1; // Fast-path for empty final frame. 256 } 257 } 258 259 long toRead = Math.min(byteCount, frameLength - frameBytesRead); 260 261 long read; 262 if (isMasked) { 263 toRead = Math.min(toRead, maskBuffer.length); 264 read = source.read(maskBuffer, 0, (int) toRead); 265 if (read == -1) throw new EOFException(); 266 toggleMask(maskBuffer, read, maskKey, frameBytesRead); 267 sink.write(maskBuffer, 0, (int) read); 268 } else { 269 read = source.read(sink, toRead); 270 if (read == -1) throw new EOFException(); 271 } 272 273 frameBytesRead += read; 274 return read; 275 } 276 277 @Override public Timeout timeout() { 278 return source.timeout(); 279 } 280 281 @Override public void close() throws IOException { 282 if (messageClosed) return; 283 messageClosed = true; 284 if (closed) return; 285 286 // Exhaust the remainder of the message, if any. 287 source.skip(frameLength - frameBytesRead); 288 while (!isFinalFrame) { 289 readUntilNonControlFrame(); 290 source.skip(frameLength); 291 } 292 } 293 } 294} 295