/* * Copyright (C) 2014 Square, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.squareup.okhttp.ws; import com.squareup.okhttp.OkHttpClient; import com.squareup.okhttp.Request; import com.squareup.okhttp.Response; import com.squareup.okhttp.internal.SslContextBuilder; import com.squareup.okhttp.mockwebserver.MockResponse; import com.squareup.okhttp.mockwebserver.MockWebServer; import com.squareup.okhttp.testing.RecordingHostnameVerifier; import java.io.IOException; import java.net.ProtocolException; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLContext; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; import org.junit.After; import org.junit.Rule; import org.junit.Test; import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT; public final class WebSocketCallTest { @Rule public final MockWebServer server = new MockWebServer(); private final SSLContext sslContext = SslContextBuilder.localhost(); private final WebSocketRecorder listener = new WebSocketRecorder(); private final OkHttpClient client = new OkHttpClient(); private final Random random = new Random(0); @After public void tearDown() { listener.assertExhausted(); } @Test public void clientPingPong() throws IOException { WebSocketListener serverListener = new EmptyWebSocketListener(); server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); WebSocket webSocket = awaitWebSocket(); webSocket.sendPing(new Buffer().writeUtf8("Hello, WebSockets!")); listener.assertPong(new Buffer().writeUtf8("Hello, WebSockets!")); } @Test public void clientMessage() throws IOException { WebSocketRecorder serverListener = new WebSocketRecorder(); server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); WebSocket webSocket = awaitWebSocket(); webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!")); serverListener.assertTextMessage("Hello, WebSockets!"); } @Test public void serverMessage() throws IOException { WebSocketListener serverListener = new EmptyWebSocketListener() { @Override public void onOpen(final WebSocket webSocket, Response response) { new Thread() { @Override public void run() { try { webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!")); } catch (IOException e) { throw new AssertionError(e); } } }.start(); } }; server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); awaitWebSocket(); listener.assertTextMessage("Hello, WebSockets!"); } @Test public void clientStreamingMessage() throws IOException { WebSocketRecorder serverListener = new WebSocketRecorder(); server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); WebSocket webSocket = awaitWebSocket(); BufferedSink sink = webSocket.newMessageSink(TEXT); sink.writeUtf8("Hello, ").flush(); sink.writeUtf8("WebSockets!").flush(); sink.close(); serverListener.assertTextMessage("Hello, WebSockets!"); } @Test public void serverStreamingMessage() throws IOException { WebSocketListener serverListener = new EmptyWebSocketListener() { @Override public void onOpen(final WebSocket webSocket, Response response) { new Thread() { @Override public void run() { try { BufferedSink sink = webSocket.newMessageSink(TEXT); sink.writeUtf8("Hello, ").flush(); sink.writeUtf8("WebSockets!").flush(); sink.close(); } catch (IOException e) { throw new AssertionError(e); } } }.start(); } }; server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); awaitWebSocket(); listener.assertTextMessage("Hello, WebSockets!"); } @Test public void okButNotOk() { server.enqueue(new MockResponse()); awaitWebSocket(); listener.assertFailure(ProtocolException.class, "Expected HTTP 101 response but was '200 OK'"); } @Test public void notFound() { server.enqueue(new MockResponse().setStatus("HTTP/1.1 404 Not Found")); awaitWebSocket(); listener.assertFailure(ProtocolException.class, "Expected HTTP 101 response but was '404 Not Found'"); } @Test public void missingConnectionHeader() { server.enqueue(new MockResponse() .setResponseCode(101) .setHeader("Upgrade", "websocket") .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); awaitWebSocket(); listener.assertFailure(ProtocolException.class, "Expected 'Connection' header value 'Upgrade' but was 'null'"); } @Test public void wrongConnectionHeader() { server.enqueue(new MockResponse() .setResponseCode(101) .setHeader("Upgrade", "websocket") .setHeader("Connection", "Downgrade") .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); awaitWebSocket(); listener.assertFailure(ProtocolException.class, "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'"); } @Test public void missingUpgradeHeader() { server.enqueue(new MockResponse() .setResponseCode(101) .setHeader("Connection", "Upgrade") .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); awaitWebSocket(); listener.assertFailure(ProtocolException.class, "Expected 'Upgrade' header value 'websocket' but was 'null'"); } @Test public void wrongUpgradeHeader() { server.enqueue(new MockResponse() .setResponseCode(101) .setHeader("Connection", "Upgrade") .setHeader("Upgrade", "Pepsi") .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); awaitWebSocket(); listener.assertFailure(ProtocolException.class, "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'"); } @Test public void missingMagicHeader() { server.enqueue(new MockResponse() .setResponseCode(101) .setHeader("Connection", "Upgrade") .setHeader("Upgrade", "websocket")); awaitWebSocket(); listener.assertFailure(ProtocolException.class, "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'"); } @Test public void wrongMagicHeader() { server.enqueue(new MockResponse() .setResponseCode(101) .setHeader("Connection", "Upgrade") .setHeader("Upgrade", "websocket") .setHeader("Sec-WebSocket-Accept", "magic")); awaitWebSocket(); listener.assertFailure(ProtocolException.class, "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'"); } @Test public void wsScheme() throws IOException { websocketScheme("ws"); } @Test public void wsUppercaseScheme() throws IOException { websocketScheme("WS"); } @Test public void wssScheme() throws IOException { server.useHttps(sslContext.getSocketFactory(), false); client.setSslSocketFactory(sslContext.getSocketFactory()); client.setHostnameVerifier(new RecordingHostnameVerifier()); websocketScheme("wss"); } @Test public void httpsScheme() throws IOException { server.useHttps(sslContext.getSocketFactory(), false); client.setSslSocketFactory(sslContext.getSocketFactory()); client.setHostnameVerifier(new RecordingHostnameVerifier()); websocketScheme("https"); } private void websocketScheme(String scheme) throws IOException { WebSocketRecorder serverListener = new WebSocketRecorder(); server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); Request request1 = new Request.Builder() .url(scheme + "://" + server.getHostName() + ":" + server.getPort() + "/") .build(); WebSocket webSocket = awaitWebSocket(request1); webSocket.sendMessage(TEXT, new Buffer().writeUtf8("abc")); serverListener.assertTextMessage("abc"); } private WebSocket awaitWebSocket() { return awaitWebSocket(new Request.Builder().get().url(server.url("/")).build()); } private WebSocket awaitWebSocket(Request request) { WebSocketCall call = new WebSocketCall(client, request, random); final AtomicReference responseRef = new AtomicReference<>(); final AtomicReference webSocketRef = new AtomicReference<>(); final AtomicReference failureRef = new AtomicReference<>(); final CountDownLatch latch = new CountDownLatch(1); call.enqueue(new WebSocketListener() { @Override public void onOpen(WebSocket webSocket, Response response) { webSocketRef.set(webSocket); responseRef.set(response); latch.countDown(); } @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type) throws IOException { listener.onMessage(payload, type); } @Override public void onPong(Buffer payload) { listener.onPong(payload); } @Override public void onClose(int code, String reason) { listener.onClose(code, reason); } @Override public void onFailure(IOException e, Response response) { listener.onFailure(e, null); failureRef.set(e); latch.countDown(); } }); try { if (!latch.await(10, TimeUnit.SECONDS)) { throw new AssertionError("Timed out."); } } catch (InterruptedException e) { throw new AssertionError(e); } return webSocketRef.get(); } private static class EmptyWebSocketListener implements WebSocketListener { @Override public void onOpen(WebSocket webSocket, Response response) { } @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type) throws IOException { } @Override public void onPong(Buffer payload) { } @Override public void onClose(int code, String reason) { } @Override public void onFailure(IOException e, Response response) { } } }