1/*
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.io.EOFException;
20import java.io.IOException;
21import java.io.InputStream;
22import java.io.OutputStream;
23import java.net.InetAddress;
24import java.net.SocketException;
25import java.nio.channels.ClosedChannelException;
26import javax.net.ssl.SSLException;
27import javax.net.ssl.SSLSocket;
28import javax.net.ssl.SSLSocketFactory;
29
30/**
31 * Client-side endpoint. Provides basic services for sending/receiving messages from the client
32 * socket.
33 */
34final class ClientEndpoint {
35    private final SSLSocket socket;
36    private InputStream input;
37    private OutputStream output;
38
39    ClientEndpoint(SSLSocketFactory socketFactory, ChannelType channelType, int port,
40            String[] protocols, String[] ciphers) throws IOException {
41        socket = channelType.newClientSocket(socketFactory, InetAddress.getLoopbackAddress(), port);
42        socket.setEnabledProtocols(protocols);
43        socket.setEnabledCipherSuites(ciphers);
44    }
45
46    void start() {
47        try {
48            socket.startHandshake();
49            input = socket.getInputStream();
50            output = socket.getOutputStream();
51        } catch (IOException e) {
52            e.printStackTrace();
53            throw new RuntimeException(e);
54        }
55    }
56
57    void stop() {
58        try {
59            socket.close();
60        } catch (IOException e) {
61            throw new RuntimeException(e);
62        }
63    }
64
65    int readMessage(byte[] buffer) {
66        try {
67            int totalBytesRead = 0;
68            while (totalBytesRead < buffer.length) {
69                int remaining = buffer.length - totalBytesRead;
70                int bytesRead = input.read(buffer, totalBytesRead, remaining);
71                if (bytesRead == -1) {
72                    break;
73                }
74                totalBytesRead += bytesRead;
75            }
76            return totalBytesRead;
77        } catch (SSLException e) {
78            if (e.getCause() instanceof EOFException) {
79                return -1;
80            }
81            throw new RuntimeException(e);
82        } catch (ClosedChannelException e) {
83            // Thrown for channel-based sockets. Just treat like EOF.
84            return -1;
85        }  catch (SocketException e) {
86            // The socket was broken. Just treat like EOF.
87            return -1;
88        } catch (IOException e) {
89            throw new RuntimeException(e);
90        }
91    }
92
93    void sendMessage(byte[] data) {
94        try {
95            output.write(data);
96        } catch (IOException e) {
97            throw new RuntimeException(e);
98        }
99    }
100
101    void flush() {
102        try {
103            output.flush();
104        } catch (IOException e) {
105            throw new RuntimeException(e);
106        }
107    }
108}
109