1/*
2 * Copyright (C) 2014 Square, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package com.squareup.okhttp.internal.ws;
17
18import com.squareup.okhttp.MediaType;
19import com.squareup.okhttp.ResponseBody;
20import com.squareup.okhttp.ws.WebSocket;
21import java.io.EOFException;
22import java.io.IOException;
23import java.net.ProtocolException;
24import okio.Buffer;
25import okio.BufferedSource;
26import okio.Okio;
27import okio.Source;
28import okio.Timeout;
29
30import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_FIN;
31import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV1;
32import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV2;
33import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV3;
34import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_MASK_OPCODE;
35import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_FLAG_MASK;
36import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_MASK_LENGTH;
37import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_BINARY;
38import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTINUATION;
39import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_CLOSE;
40import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PING;
41import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PONG;
42import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_FLAG_CONTROL;
43import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_TEXT;
44import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_LONG;
45import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_BYTE_MAX;
46import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_SHORT;
47import static com.squareup.okhttp.internal.ws.WebSocketProtocol.toggleMask;
48import static java.lang.Integer.toHexString;
49
50/**
51 * An <a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a>-compatible WebSocket frame reader.
52 */
53public final class WebSocketReader {
54  public interface FrameCallback {
55    void onMessage(ResponseBody body) throws IOException;
56    void onPing(Buffer buffer);
57    void onPong(Buffer buffer);
58    void onClose(int code, String reason);
59  }
60
61  private final boolean isClient;
62  private final BufferedSource source;
63  private final FrameCallback frameCallback;
64
65  private final Source framedMessageSource = new FramedMessageSource();
66
67  private boolean closed;
68  private boolean messageClosed;
69
70  // Stateful data about the current frame.
71  private int opcode;
72  private long frameLength;
73  private long frameBytesRead;
74  private boolean isFinalFrame;
75  private boolean isControlFrame;
76  private boolean isMasked;
77
78  private final byte[] maskKey = new byte[4];
79  private final byte[] maskBuffer = new byte[2048];
80
81  public WebSocketReader(boolean isClient, BufferedSource source, FrameCallback frameCallback) {
82    if (source == null) throw new NullPointerException("source == null");
83    if (frameCallback == null) throw new NullPointerException("frameCallback == null");
84    this.isClient = isClient;
85    this.source = source;
86    this.frameCallback = frameCallback;
87  }
88
89  /**
90   * Process the next protocol frame.
91   * <ul>
92   * <li>If it is a control frame this will result in a single call to {@link FrameCallback}.</li>
93   * <li>If it is a message frame this will result in a single call to {@link
94   * FrameCallback#onMessage}. If the message spans multiple frames, each interleaved control
95   * frame will result in a corresponding call to {@link FrameCallback}.
96   * </ul>
97   */
98  public void processNextFrame() throws IOException {
99    readHeader();
100    if (isControlFrame) {
101      readControlFrame();
102    } else {
103      readMessageFrame();
104    }
105  }
106
107  private void readHeader() throws IOException {
108    if (closed) throw new IOException("closed");
109
110    int b0 = source.readByte() & 0xff;
111
112    opcode = b0 & B0_MASK_OPCODE;
113    isFinalFrame = (b0 & B0_FLAG_FIN) != 0;
114    isControlFrame = (b0 & OPCODE_FLAG_CONTROL) != 0;
115
116    // Control frames must be final frames (cannot contain continuations).
117    if (isControlFrame && !isFinalFrame) {
118      throw new ProtocolException("Control frames must be final.");
119    }
120
121    boolean reservedFlag1 = (b0 & B0_FLAG_RSV1) != 0;
122    boolean reservedFlag2 = (b0 & B0_FLAG_RSV2) != 0;
123    boolean reservedFlag3 = (b0 & B0_FLAG_RSV3) != 0;
124    if (reservedFlag1 || reservedFlag2 || reservedFlag3) {
125      // Reserved flags are for extensions which we currently do not support.
126      throw new ProtocolException("Reserved flags are unsupported.");
127    }
128
129    int b1 = source.readByte() & 0xff;
130
131    isMasked = (b1 & B1_FLAG_MASK) != 0;
132    if (isMasked == isClient) {
133      // Masked payloads must be read on the server. Unmasked payloads must be read on the client.
134      throw new ProtocolException("Client-sent frames must be masked. Server sent must not.");
135    }
136
137    // Get frame length, optionally reading from follow-up bytes if indicated by special values.
138    frameLength = b1 & B1_MASK_LENGTH;
139    if (frameLength == PAYLOAD_SHORT) {
140      frameLength = source.readShort() & 0xffffL; // Value is unsigned.
141    } else if (frameLength == PAYLOAD_LONG) {
142      frameLength = source.readLong();
143      if (frameLength < 0) {
144        throw new ProtocolException(
145            "Frame length 0x" + Long.toHexString(frameLength) + " > 0x7FFFFFFFFFFFFFFF");
146      }
147    }
148    frameBytesRead = 0;
149
150    if (isControlFrame && frameLength > PAYLOAD_BYTE_MAX) {
151      throw new ProtocolException("Control frame must be less than " + PAYLOAD_BYTE_MAX + "B.");
152    }
153
154    if (isMasked) {
155      // Read the masking key as bytes so that they can be used directly for unmasking.
156      source.readFully(maskKey);
157    }
158  }
159
160  private void readControlFrame() throws IOException {
161    Buffer buffer = null;
162    if (frameBytesRead < frameLength) {
163      buffer = new Buffer();
164
165      if (isClient) {
166        source.readFully(buffer, frameLength);
167      } else {
168        while (frameBytesRead < frameLength) {
169          int toRead = (int) Math.min(frameLength - frameBytesRead, maskBuffer.length);
170          int read = source.read(maskBuffer, 0, toRead);
171          if (read == -1) throw new EOFException();
172          toggleMask(maskBuffer, read, maskKey, frameBytesRead);
173          buffer.write(maskBuffer, 0, read);
174          frameBytesRead += read;
175        }
176      }
177    }
178
179    switch (opcode) {
180      case OPCODE_CONTROL_PING:
181        frameCallback.onPing(buffer);
182        break;
183      case OPCODE_CONTROL_PONG:
184        frameCallback.onPong(buffer);
185        break;
186      case OPCODE_CONTROL_CLOSE:
187        int code = 1000;
188        String reason = "";
189        if (buffer != null) {
190          long bufferSize = buffer.size();
191          if (bufferSize == 1) {
192            throw new ProtocolException("Malformed close payload length of 1.");
193          } else if (bufferSize != 0) {
194            code = buffer.readShort();
195            if (code < 1000 || code >= 5000) {
196              throw new ProtocolException("Code must be in range [1000,5000): " + code);
197            }
198            if ((code >= 1004 && code <= 1006) || (code >= 1012 && code <= 2999)) {
199              throw new ProtocolException("Code " + code + " is reserved and may not be used.");
200            }
201
202            reason = buffer.readUtf8();
203          }
204        }
205        frameCallback.onClose(code, reason);
206        closed = true;
207        break;
208      default:
209        throw new ProtocolException("Unknown control opcode: " + toHexString(opcode));
210    }
211  }
212
213  private void readMessageFrame() throws IOException {
214    final MediaType type;
215    switch (opcode) {
216      case OPCODE_TEXT:
217        type = WebSocket.TEXT;
218        break;
219      case OPCODE_BINARY:
220        type = WebSocket.BINARY;
221        break;
222      default:
223        throw new ProtocolException("Unknown opcode: " + toHexString(opcode));
224    }
225
226    final BufferedSource source = Okio.buffer(framedMessageSource);
227    ResponseBody body = new ResponseBody() {
228      @Override public MediaType contentType() {
229        return type;
230      }
231
232      @Override public long contentLength() throws IOException {
233        return -1;
234      }
235
236      @Override public BufferedSource source() throws IOException {
237        return source;
238      }
239    };
240
241    messageClosed = false;
242    frameCallback.onMessage(body);
243    if (!messageClosed) {
244      throw new IllegalStateException("Listener failed to call close on message payload.");
245    }
246  }
247
248  /** Read headers and process any control frames until we reach a non-control frame. */
249  private void readUntilNonControlFrame() throws IOException {
250    while (!closed) {
251      readHeader();
252      if (!isControlFrame) {
253        break;
254      }
255      readControlFrame();
256    }
257  }
258
259  /**
260   * A special source which knows how to read a message body across one or more frames. Control
261   * frames that occur between fragments will be processed. If the message payload is masked this
262   * will unmask as it's being processed.
263   */
264  private final class FramedMessageSource implements Source {
265    @Override public long read(Buffer sink, long byteCount) throws IOException {
266      if (closed) throw new IOException("closed");
267      if (messageClosed) throw new IllegalStateException("closed");
268
269      if (frameBytesRead == frameLength) {
270        if (isFinalFrame) return -1; // We are exhausted and have no continuations.
271
272        readUntilNonControlFrame();
273        if (opcode != OPCODE_CONTINUATION) {
274          throw new ProtocolException("Expected continuation opcode. Got: " + toHexString(opcode));
275        }
276        if (isFinalFrame && frameLength == 0) {
277          return -1; // Fast-path for empty final frame.
278        }
279      }
280
281      long toRead = Math.min(byteCount, frameLength - frameBytesRead);
282
283      long read;
284      if (isMasked) {
285        toRead = Math.min(toRead, maskBuffer.length);
286        read = source.read(maskBuffer, 0, (int) toRead);
287        if (read == -1) throw new EOFException();
288        toggleMask(maskBuffer, read, maskKey, frameBytesRead);
289        sink.write(maskBuffer, 0, (int) read);
290      } else {
291        read = source.read(sink, toRead);
292        if (read == -1) throw new EOFException();
293      }
294
295      frameBytesRead += read;
296      return read;
297    }
298
299    @Override public Timeout timeout() {
300      return source.timeout();
301    }
302
303    @Override public void close() throws IOException {
304      if (messageClosed) return;
305      messageClosed = true;
306      if (closed) return;
307
308      // Exhaust the remainder of the message, if any.
309      source.skip(frameLength - frameBytesRead);
310      while (!isFinalFrame) {
311        readUntilNonControlFrame();
312        source.skip(frameLength);
313      }
314    }
315  }
316}
317