1/*
2 * Copyright (C) 2014 Square, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package com.squareup.okhttp.internal.ws;
17
18import java.io.EOFException;
19import java.io.IOException;
20import java.net.ProtocolException;
21import okio.Buffer;
22import okio.BufferedSource;
23import okio.Okio;
24import okio.Source;
25import okio.Timeout;
26
27import static com.squareup.okhttp.ws.WebSocket.PayloadType;
28import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_FIN;
29import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV1;
30import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV2;
31import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_FLAG_RSV3;
32import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B0_MASK_OPCODE;
33import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_FLAG_MASK;
34import static com.squareup.okhttp.internal.ws.WebSocketProtocol.B1_MASK_LENGTH;
35import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_BINARY;
36import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTINUATION;
37import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_CLOSE;
38import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PING;
39import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PONG;
40import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_FLAG_CONTROL;
41import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_TEXT;
42import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_LONG;
43import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_MAX;
44import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_SHORT;
45import static com.squareup.okhttp.internal.ws.WebSocketProtocol.toggleMask;
46import static java.lang.Integer.toHexString;
47
48/**
49 * An <a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a>-compatible WebSocket frame reader.
50 */
51public final class WebSocketReader {
52  public interface FrameCallback {
53    void onMessage(BufferedSource source, PayloadType type) throws IOException;
54    void onPing(Buffer buffer);
55    void onPong(Buffer buffer);
56    void onClose(int code, String reason);
57  }
58
59  private final boolean isClient;
60  private final BufferedSource source;
61  private final FrameCallback frameCallback;
62
63  private final Source framedMessageSource = new FramedMessageSource();
64
65  private boolean closed;
66  private boolean messageClosed;
67
68  // Stateful data about the current frame.
69  private int opcode;
70  private long frameLength;
71  private long frameBytesRead;
72  private boolean isFinalFrame;
73  private boolean isControlFrame;
74  private boolean isMasked;
75
76  private final byte[] maskKey = new byte[4];
77  private final byte[] maskBuffer = new byte[2048];
78
79  public WebSocketReader(boolean isClient, BufferedSource source, FrameCallback frameCallback) {
80    if (source == null) throw new NullPointerException("source == null");
81    if (frameCallback == null) throw new NullPointerException("frameCallback == null");
82    this.isClient = isClient;
83    this.source = source;
84    this.frameCallback = frameCallback;
85  }
86
87  /**
88   * Process the next protocol frame.
89   * <ul>
90   * <li>If it is a control frame this will result in a single call to {@link FrameCallback}.</li>
91   * <li>If it is a message frame this will result in a single call to {@link
92   * FrameCallback#onMessage}. If the message spans multiple frames, each interleaved control
93   * frame will result in a corresponding call to {@link FrameCallback}.
94   * </ul>
95   */
96  public void processNextFrame() throws IOException {
97    readHeader();
98    if (isControlFrame) {
99      readControlFrame();
100    } else {
101      readMessageFrame();
102    }
103  }
104
105  private void readHeader() throws IOException {
106    if (closed) throw new IOException("closed");
107
108    int b0 = source.readByte() & 0xff;
109
110    opcode = b0 & B0_MASK_OPCODE;
111    isFinalFrame = (b0 & B0_FLAG_FIN) != 0;
112    isControlFrame = (b0 & OPCODE_FLAG_CONTROL) != 0;
113
114    // Control frames must be final frames (cannot contain continuations).
115    if (isControlFrame && !isFinalFrame) {
116      throw new ProtocolException("Control frames must be final.");
117    }
118
119    boolean reservedFlag1 = (b0 & B0_FLAG_RSV1) != 0;
120    boolean reservedFlag2 = (b0 & B0_FLAG_RSV2) != 0;
121    boolean reservedFlag3 = (b0 & B0_FLAG_RSV3) != 0;
122    if (reservedFlag1 || reservedFlag2 || reservedFlag3) {
123      // Reserved flags are for extensions which we currently do not support.
124      throw new ProtocolException("Reserved flags are unsupported.");
125    }
126
127    int b1 = source.readByte() & 0xff;
128
129    isMasked = (b1 & B1_FLAG_MASK) != 0;
130    if (isMasked == isClient) {
131      // Masked payloads must be read on the server. Unmasked payloads must be read on the client.
132      throw new ProtocolException("Client-sent frames must be masked. Server sent must not.");
133    }
134
135    // Get frame length, optionally reading from follow-up bytes if indicated by special values.
136    frameLength = b1 & B1_MASK_LENGTH;
137    if (frameLength == PAYLOAD_SHORT) {
138      frameLength = source.readShort() & 0xffffL; // Value is unsigned.
139    } else if (frameLength == PAYLOAD_LONG) {
140      frameLength = source.readLong();
141      if (frameLength < 0) {
142        throw new ProtocolException(
143            "Frame length 0x" + Long.toHexString(frameLength) + " > 0x7FFFFFFFFFFFFFFF");
144      }
145    }
146    frameBytesRead = 0;
147
148    if (isControlFrame && frameLength > PAYLOAD_MAX) {
149      throw new ProtocolException("Control frame must be less than " + PAYLOAD_MAX + "B.");
150    }
151
152    if (isMasked) {
153      // Read the masking key as bytes so that they can be used directly for unmasking.
154      source.readFully(maskKey);
155    }
156  }
157
158  private void readControlFrame() throws IOException {
159    Buffer buffer = null;
160    if (frameBytesRead < frameLength) {
161      buffer = new Buffer();
162
163      if (isClient) {
164        source.readFully(buffer, frameLength);
165      } else {
166        while (frameBytesRead < frameLength) {
167          int toRead = (int) Math.min(frameLength - frameBytesRead, maskBuffer.length);
168          int read = source.read(maskBuffer, 0, toRead);
169          if (read == -1) throw new EOFException();
170          toggleMask(maskBuffer, read, maskKey, frameBytesRead);
171          buffer.write(maskBuffer, 0, read);
172          frameBytesRead += read;
173        }
174      }
175    }
176
177    switch (opcode) {
178      case OPCODE_CONTROL_PING:
179        frameCallback.onPing(buffer);
180        break;
181      case OPCODE_CONTROL_PONG:
182        frameCallback.onPong(buffer);
183        break;
184      case OPCODE_CONTROL_CLOSE:
185        int code = 0;
186        String reason = "";
187        if (buffer != null) {
188          if (buffer.size() < 2) {
189            throw new ProtocolException("Close payload must be at least two bytes.");
190          }
191          code = buffer.readShort();
192          if (code < 1000 || code >= 5000) {
193            throw new ProtocolException("Code must be in range [1000,5000): " + code);
194          }
195
196          reason = buffer.readUtf8();
197        }
198        frameCallback.onClose(code, reason);
199        closed = true;
200        break;
201      default:
202        throw new ProtocolException("Unknown control opcode: " + toHexString(opcode));
203    }
204  }
205
206  private void readMessageFrame() throws IOException {
207    PayloadType type;
208    switch (opcode) {
209      case OPCODE_TEXT:
210        type = PayloadType.TEXT;
211        break;
212      case OPCODE_BINARY:
213        type = PayloadType.BINARY;
214        break;
215      default:
216        throw new ProtocolException("Unknown opcode: " + toHexString(opcode));
217    }
218
219    messageClosed = false;
220    frameCallback.onMessage(Okio.buffer(framedMessageSource), type);
221    if (!messageClosed) {
222      throw new IllegalStateException("Listener failed to call close on message payload.");
223    }
224  }
225
226  /** Read headers and process any control frames until we reach a non-control frame. */
227  private void readUntilNonControlFrame() throws IOException {
228    while (!closed) {
229      readHeader();
230      if (!isControlFrame) {
231        break;
232      }
233      readControlFrame();
234    }
235  }
236
237  /**
238   * A special source which knows how to read a message body across one or more frames. Control
239   * frames that occur between fragments will be processed. If the message payload is masked this
240   * will unmask as it's being processed.
241   */
242  private final class FramedMessageSource implements Source {
243    @Override public long read(Buffer sink, long byteCount) throws IOException {
244      if (closed) throw new IOException("closed");
245      if (messageClosed) throw new IllegalStateException("closed");
246
247      if (frameBytesRead == frameLength) {
248        if (isFinalFrame) return -1; // We are exhausted and have no continuations.
249
250        readUntilNonControlFrame();
251        if (opcode != OPCODE_CONTINUATION) {
252          throw new ProtocolException("Expected continuation opcode. Got: " + toHexString(opcode));
253        }
254        if (isFinalFrame && frameLength == 0) {
255          return -1; // Fast-path for empty final frame.
256        }
257      }
258
259      long toRead = Math.min(byteCount, frameLength - frameBytesRead);
260
261      long read;
262      if (isMasked) {
263        toRead = Math.min(toRead, maskBuffer.length);
264        read = source.read(maskBuffer, 0, (int) toRead);
265        if (read == -1) throw new EOFException();
266        toggleMask(maskBuffer, read, maskKey, frameBytesRead);
267        sink.write(maskBuffer, 0, (int) read);
268      } else {
269        read = source.read(sink, toRead);
270        if (read == -1) throw new EOFException();
271      }
272
273      frameBytesRead += read;
274      return read;
275    }
276
277    @Override public Timeout timeout() {
278      return source.timeout();
279    }
280
281    @Override public void close() throws IOException {
282      if (messageClosed) return;
283      messageClosed = true;
284      if (closed) return;
285
286      // Exhaust the remainder of the message, if any.
287      source.skip(frameLength - frameBytesRead);
288      while (!isFinalFrame) {
289        readUntilNonControlFrame();
290        source.skip(frameLength);
291      }
292    }
293  }
294}
295