1/*
2 * Copyright (C) 2014 Square, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package com.squareup.okhttp.internal.ws;
17
18import java.io.EOFException;
19import java.io.IOException;
20import java.util.Random;
21import okio.Buffer;
22import okio.BufferedSink;
23import okio.ByteString;
24import okio.Okio;
25import okio.Sink;
26import org.junit.Rule;
27import org.junit.Test;
28import org.junit.rules.TestRule;
29import org.junit.runner.Description;
30import org.junit.runners.model.Statement;
31
32import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_BINARY;
33import static com.squareup.okhttp.internal.ws.WebSocketProtocol.OPCODE_TEXT;
34import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_BYTE_MAX;
35import static com.squareup.okhttp.internal.ws.WebSocketProtocol.PAYLOAD_SHORT_MAX;
36import static com.squareup.okhttp.internal.ws.WebSocketProtocol.toggleMask;
37import static org.junit.Assert.assertEquals;
38import static org.junit.Assert.fail;
39
40public final class WebSocketWriterTest {
41  private final Buffer data = new Buffer();
42  private final Random random = new Random(0);
43
44  /**
45   * Check all data as verified inside of the test. We do this in a rule instead of @After so that
46   * exceptions thrown from the test do not cause this check to fail.
47   */
48  @Rule public final TestRule noDataLeftBehind = new TestRule() {
49    @Override public Statement apply(final Statement base, Description description) {
50      return new Statement() {
51        @Override public void evaluate() throws Throwable {
52          base.evaluate();
53          assertEquals("Data not empty", "", data.readByteString().hex());
54        }
55      };
56    }
57  };
58
59  // Mutually exclusive. Use the one corresponding to the peer whose behavior you wish to test.
60  private final WebSocketWriter serverWriter = new WebSocketWriter(false, data, random);
61  private final WebSocketWriter clientWriter = new WebSocketWriter(true, data, random);
62
63  @Test public void serverTextMessage() throws IOException {
64    BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_TEXT));
65
66    sink.writeUtf8("Hel").flush();
67    assertData("010348656c");
68
69    sink.writeUtf8("lo").flush();
70    assertData("00026c6f");
71
72    sink.close();
73    assertData("8000");
74  }
75
76  @Test public void closeFlushes() throws IOException {
77    BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_TEXT));
78
79    sink.writeUtf8("Hel").flush();
80    assertData("010348656c");
81
82    sink.writeUtf8("lo").close();
83    assertData("80026c6f");
84  }
85
86  @Test public void noWritesAfterClose() throws IOException {
87    Sink sink = serverWriter.newMessageSink(OPCODE_TEXT);
88
89    sink.close();
90    assertData("8100");
91
92    Buffer payload = new Buffer().writeUtf8("Hello");
93    try {
94      // Write to the unbuffered sink as BufferedSink keeps its own closed state.
95      sink.write(payload, payload.size());
96      fail();
97    } catch (IOException e) {
98      assertEquals("closed", e.getMessage());
99    }
100  }
101
102  @Test public void clientTextMessage() throws IOException {
103    BufferedSink sink = Okio.buffer(clientWriter.newMessageSink(OPCODE_TEXT));
104
105    sink.writeUtf8("Hel").flush();
106    assertData("018360b420bb28d14c");
107
108    sink.writeUtf8("lo").flush();
109    assertData("00823851d9d4543e");
110
111    sink.close();
112    assertData("80807acb933d");
113  }
114
115  @Test public void serverBinaryMessage() throws IOException {
116    BufferedSink sink = Okio.buffer(serverWriter.newMessageSink(OPCODE_BINARY));
117
118    sink.write(binaryData(50)).flush();
119    assertData("0232");
120    assertData(binaryData(50));
121
122    sink.write(binaryData(50)).flush();
123    assertData("0032");
124    assertData(binaryData(50));
125
126    sink.close();
127    assertData("8000");
128  }
129
130  @Test public void serverMessageLengthShort() throws IOException {
131    Sink sink = serverWriter.newMessageSink(OPCODE_BINARY);
132
133    // Create a payload which will overflow the normal payload byte size.
134    Buffer payload = new Buffer();
135    while (payload.completeSegmentByteCount() <= PAYLOAD_BYTE_MAX) {
136      payload.writeByte('0');
137    }
138    long byteCount = payload.completeSegmentByteCount();
139
140    // Write directly to the unbuffered sink. This ensures it will become single frame.
141    sink.write(payload.clone(), byteCount);
142    assertData("027e"); // 'e' == 4-byte follow-up length.
143    assertData(String.format("%04X", payload.completeSegmentByteCount()));
144    assertData(payload.readByteArray());
145
146    sink.close();
147    assertData("8000");
148  }
149
150  @Test public void serverMessageLengthLong() throws IOException {
151    Sink sink = serverWriter.newMessageSink(OPCODE_BINARY);
152
153    // Create a payload which will overflow the normal and short payload byte size.
154    Buffer payload = new Buffer();
155    while (payload.completeSegmentByteCount() <= PAYLOAD_SHORT_MAX) {
156      payload.writeByte('0');
157    }
158    long byteCount = payload.completeSegmentByteCount();
159
160    // Write directly to the unbuffered sink. This ensures it will become single frame.
161    sink.write(payload.clone(), byteCount);
162    assertData("027f"); // 'f' == 16-byte follow-up length.
163    assertData(String.format("%016X", byteCount));
164    assertData(payload.readByteArray(byteCount));
165
166    sink.close();
167    assertData("8000");
168  }
169
170  @Test public void clientBinary() throws IOException {
171    byte[] maskKey1 = new byte[4];
172    random.nextBytes(maskKey1);
173    byte[] maskKey2 = new byte[4];
174    random.nextBytes(maskKey2);
175
176    random.setSeed(0); // Reset the seed so real data matches.
177
178    BufferedSink sink = Okio.buffer(clientWriter.newMessageSink(OPCODE_BINARY));
179
180    byte[] part1 = binaryData(50);
181    sink.write(part1).flush();
182    toggleMask(part1, 50, maskKey1, 0);
183    assertData("02b2");
184    assertData(maskKey1);
185    assertData(part1);
186
187    byte[] part2 = binaryData(50);
188    sink.write(part2).close();
189    toggleMask(part2, 50, maskKey2, 0);
190    assertData("80b2");
191    assertData(maskKey2);
192    assertData(part2);
193  }
194
195  @Test public void serverEmptyClose() throws IOException {
196    serverWriter.writeClose(0, null);
197    assertData("8800");
198  }
199
200  @Test public void serverCloseWithCode() throws IOException {
201    serverWriter.writeClose(1005, null);
202    assertData("880203ed");
203  }
204
205  @Test public void serverCloseWithCodeAndReason() throws IOException {
206    serverWriter.writeClose(1005, "Hello");
207    assertData("880703ed48656c6c6f");
208  }
209
210  @Test public void clientEmptyClose() throws IOException {
211    clientWriter.writeClose(0, null);
212    assertData("888060b420bb");
213  }
214
215  @Test public void clientCloseWithCode() throws IOException {
216    clientWriter.writeClose(1005, null);
217    assertData("888260b420bb6359");
218  }
219
220  @Test public void clientCloseWithCodeAndReason() throws IOException {
221    clientWriter.writeClose(1005, "Hello");
222    assertData("888760b420bb635968de0cd84f");
223  }
224
225  @Test public void closeWithOnlyReasonThrows() throws IOException {
226    clientWriter.writeClose(0, "Hello");
227    assertData("888760b420bb60b468de0cd84f");
228  }
229
230  @Test public void closeCodeOutOfRangeThrows() throws IOException {
231    try {
232      clientWriter.writeClose(98724976, "Hello");
233      fail();
234    } catch (IllegalArgumentException e) {
235      assertEquals("Code must be in range [1000,5000).", e.getMessage());
236    }
237  }
238
239  @Test public void serverEmptyPing() throws IOException {
240    serverWriter.writePing(null);
241    assertData("8900");
242  }
243
244  @Test public void clientEmptyPing() throws IOException {
245    clientWriter.writePing(null);
246    assertData("898060b420bb");
247  }
248
249  @Test public void serverPingWithPayload() throws IOException {
250    serverWriter.writePing(new Buffer().writeUtf8("Hello"));
251    assertData("890548656c6c6f");
252  }
253
254  @Test public void clientPingWithPayload() throws IOException {
255    clientWriter.writePing(new Buffer().writeUtf8("Hello"));
256    assertData("898560b420bb28d14cd70f");
257  }
258
259  @Test public void serverEmptyPong() throws IOException {
260    serverWriter.writePong(null);
261    assertData("8a00");
262  }
263
264  @Test public void clientEmptyPong() throws IOException {
265    clientWriter.writePong(null);
266    assertData("8a8060b420bb");
267  }
268
269  @Test public void serverPongWithPayload() throws IOException {
270    serverWriter.writePong(new Buffer().writeUtf8("Hello"));
271    assertData("8a0548656c6c6f");
272  }
273
274  @Test public void clientPongWithPayload() throws IOException {
275    clientWriter.writePong(new Buffer().writeUtf8("Hello"));
276    assertData("8a8560b420bb28d14cd70f");
277  }
278
279  @Test public void pingTooLongThrows() throws IOException {
280    try {
281      serverWriter.writePing(new Buffer().write(binaryData(1000)));
282      fail();
283    } catch (IllegalArgumentException e) {
284      assertEquals("Payload size must be less than or equal to 125", e.getMessage());
285    }
286  }
287
288  @Test public void pongTooLongThrows() throws IOException {
289    try {
290      serverWriter.writePong(new Buffer().write(binaryData(1000)));
291      fail();
292    } catch (IllegalArgumentException e) {
293      assertEquals("Payload size must be less than or equal to 125", e.getMessage());
294    }
295  }
296
297  @Test public void closeTooLongThrows() throws IOException {
298    try {
299      String longString = ByteString.of(binaryData(75)).hex();
300      serverWriter.writeClose(1000, longString);
301      fail();
302    } catch (IllegalArgumentException e) {
303      assertEquals("Payload size must be less than or equal to 125", e.getMessage());
304    }
305  }
306
307  @Test public void twoMessageSinksThrows() {
308    clientWriter.newMessageSink(OPCODE_TEXT);
309    try {
310      clientWriter.newMessageSink(OPCODE_TEXT);
311      fail();
312    } catch (IllegalStateException e) {
313      assertEquals("Another message writer is active. Did you call close()?", e.getMessage());
314    }
315  }
316
317  private void assertData(String hex) throws EOFException {
318    ByteString expected = ByteString.decodeHex(hex);
319    ByteString actual = data.readByteString(expected.size());
320    assertEquals(expected, actual);
321  }
322
323  private void assertData(byte[] data) throws IOException {
324    int byteCount = 16;
325    for (int i = 0; i < data.length; i += byteCount) {
326      int count = Math.min(byteCount, data.length - i);
327      Buffer expectedChunk = new Buffer();
328      expectedChunk.write(data, i, count);
329      assertEquals("At " + i, expectedChunk.readByteString(), this.data.readByteString(count));
330    }
331  }
332
333  private static byte[] binaryData(int length) {
334    byte[] junk = new byte[length];
335    new Random(0).nextBytes(junk);
336    return junk;
337  }
338}
339