1/*
2 * Copyright 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import static javax.net.ssl.SSLEngineResult.Status.OK;
20
21import java.io.EOFException;
22import java.io.FileDescriptor;
23import java.io.IOException;
24import java.io.InputStream;
25import java.io.OutputStream;
26import java.net.Socket;
27import java.net.SocketException;
28import java.nio.ByteBuffer;
29import java.nio.channels.SocketChannel;
30import java.security.PrivateKey;
31import java.security.cert.CertificateException;
32import javax.crypto.SecretKey;
33import javax.net.ssl.SSLEngineResult;
34import javax.net.ssl.SSLException;
35import javax.net.ssl.SSLSession;
36import javax.net.ssl.X509KeyManager;
37import javax.security.auth.x500.X500Principal;
38
39/**
40 * Implements crypto handling by delegating to OpenSSLEngine. Used for socket implementations
41 * that are not backed by a real OS socket.
42 *
43 * @hide
44 */
45final class OpenSSLEngineSocketImpl extends OpenSSLSocketImplWrapper {
46    private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);
47
48    private final OpenSSLEngineImpl engine;
49    private final Socket socket;
50    private final OutputStreamWrapper outputStreamWrapper;
51    private final InputStreamWrapper inputStreamWrapper;
52    private boolean handshakeComplete;
53
54    OpenSSLEngineSocketImpl(Socket socket, String hostname, int port, boolean autoClose,
55            SSLParametersImpl sslParameters) throws IOException {
56        super(socket, hostname, port, autoClose, sslParameters);
57        this.socket = socket;
58        engine = new OpenSSLEngineImpl(hostname, port, sslParameters);
59
60        // When the handshake completes, notify any listeners.
61        engine.setHandshakeListener(new HandshakeListener() {
62            @Override
63            public void onHandshakeFinished() {
64                if (!handshakeComplete) {
65                    handshakeComplete = true;
66                    OpenSSLEngineSocketImpl.this.notifyHandshakeCompletedListeners();
67                }
68            }
69        });
70        outputStreamWrapper = new OutputStreamWrapper();
71        inputStreamWrapper = new InputStreamWrapper();
72        engine.setUseClientMode(sslParameters.getUseClientMode());
73    }
74
75    @Override
76    public void startHandshake() throws IOException {
77        // Trigger the handshake
78        boolean beginHandshakeCalled = false;
79        while (!handshakeComplete) {
80            switch (engine.getHandshakeStatus()) {
81                case NOT_HANDSHAKING: {
82                    if (!beginHandshakeCalled) {
83                        beginHandshakeCalled = true;
84                        engine.beginHandshake();
85                        break;
86                    }
87                    break;
88                }
89                case FINISHED: {
90                    return;
91                }
92                case NEED_WRAP: {
93                    outputStreamWrapper.write(EMPTY_BUFFER);
94                    break;
95                }
96                case NEED_UNWRAP: {
97                    if (inputStreamWrapper.read(EmptyArray.BYTE) == -1) {
98                        // Can't complete the handshake due to EOF.
99                        throw new EOFException();
100                    }
101                    break;
102                }
103                case NEED_TASK: {
104                    throw new IllegalStateException("OpenSSLEngineImpl returned NEED_TASK");
105                }
106                default: { break; }
107            }
108        }
109    }
110
111    @Override
112    public void onSSLStateChange(int type, int val) {
113        throw new AssertionError("Should be handled by engine");
114    }
115
116    @Override
117    public void verifyCertificateChain(long[] certRefs, String authMethod)
118            throws CertificateException {
119        throw new AssertionError("Should be handled by engine");
120    }
121
122    @Override
123    public InputStream getInputStream() throws IOException {
124        return inputStreamWrapper;
125    }
126
127    @Override
128    public OutputStream getOutputStream() throws IOException {
129        return outputStreamWrapper;
130    }
131
132    @Override
133    public SSLSession getSession() {
134        return engine.getSession();
135    }
136
137    @Override
138    public boolean getEnableSessionCreation() {
139        return super.getEnableSessionCreation();
140    }
141
142    @Override
143    public void setEnableSessionCreation(boolean flag) {
144        super.setEnableSessionCreation(flag);
145    }
146
147    @Override
148    public String[] getSupportedCipherSuites() {
149        return super.getSupportedCipherSuites();
150    }
151
152    @Override
153    public String[] getEnabledCipherSuites() {
154        return super.getEnabledCipherSuites();
155    }
156
157    @Override
158    public void setEnabledCipherSuites(String[] suites) {
159        super.setEnabledCipherSuites(suites);
160    }
161
162    @Override
163    public String[] getSupportedProtocols() {
164        return super.getSupportedProtocols();
165    }
166
167    @Override
168    public String[] getEnabledProtocols() {
169        return super.getEnabledProtocols();
170    }
171
172    @Override
173    public void setEnabledProtocols(String[] protocols) {
174        super.setEnabledProtocols(protocols);
175    }
176
177    @Override
178    public void setUseSessionTickets(boolean useSessionTickets) {
179        super.setUseSessionTickets(useSessionTickets);
180    }
181
182    @Override
183    public void setHostname(String hostname) {
184        super.setHostname(hostname);
185    }
186
187    @Override
188    public void setChannelIdEnabled(boolean enabled) {
189        super.setChannelIdEnabled(enabled);
190    }
191
192    @Override
193    public byte[] getChannelId() throws SSLException {
194        return super.getChannelId();
195    }
196
197    @Override
198    public void setChannelIdPrivateKey(PrivateKey privateKey) {
199        super.setChannelIdPrivateKey(privateKey);
200    }
201
202    @Override
203    public boolean getUseClientMode() {
204        return super.getUseClientMode();
205    }
206
207    @Override
208    public void setUseClientMode(boolean mode) {
209        engine.setUseClientMode(mode);
210    }
211
212    @Override
213    public boolean getWantClientAuth() {
214        return super.getWantClientAuth();
215    }
216
217    @Override
218    public boolean getNeedClientAuth() {
219        return super.getNeedClientAuth();
220    }
221
222    @Override
223    public void setNeedClientAuth(boolean need) {
224        super.setNeedClientAuth(need);
225    }
226
227    @Override
228    public void setWantClientAuth(boolean want) {
229        super.setWantClientAuth(want);
230    }
231
232    @Override
233    public void sendUrgentData(int data) throws IOException {
234        super.sendUrgentData(data);
235    }
236
237    @Override
238    public void setOOBInline(boolean on) throws SocketException {
239        super.setOOBInline(on);
240    }
241
242    @Override
243    public void setSoWriteTimeout(int writeTimeoutMilliseconds) throws SocketException {
244        throw new UnsupportedOperationException("Not supported");
245    }
246
247    @Override
248    public int getSoWriteTimeout() throws SocketException {
249        return 0;
250    }
251
252    @Override
253    public void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketException {
254        throw new UnsupportedOperationException("Not supported");
255    }
256
257    @Override
258    public synchronized void close() throws IOException {
259        // Closing Socket.
260        engine.closeInbound();
261        engine.closeOutbound();
262        socket.close();
263    }
264
265    @Override
266    protected void finalize() throws Throwable {
267        super.finalize();
268    }
269
270    @Override
271    public SocketChannel getChannel() {
272        return super.getChannel();
273    }
274
275    @Override
276    public FileDescriptor getFileDescriptor$() {
277        throw new UnsupportedOperationException("Not supported");
278    }
279
280    @Override
281    public byte[] getNpnSelectedProtocol() {
282        return null;
283    }
284
285    @Override
286    public byte[] getAlpnSelectedProtocol() {
287        return engine.getAlpnSelectedProtocol();
288    }
289
290    @Override
291    public void setNpnProtocols(byte[] npnProtocols) {
292        super.setNpnProtocols(npnProtocols);
293    }
294
295    @Override
296    public void setAlpnProtocols(byte[] alpnProtocols) {
297        super.setAlpnProtocols(alpnProtocols);
298    }
299
300    @Override
301    public String chooseServerAlias(X509KeyManager keyManager, String keyType) {
302        return engine.chooseServerAlias(keyManager, keyType);
303    }
304
305    @Override
306    public String chooseClientAlias(
307            X509KeyManager keyManager, X500Principal[] issuers, String[] keyTypes) {
308        return engine.chooseClientAlias(keyManager, issuers, keyTypes);
309    }
310
311    @Override
312    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
313    public String chooseServerPSKIdentityHint(PSKKeyManager keyManager) {
314        return engine.chooseServerPSKIdentityHint(keyManager);
315    }
316
317    @Override
318    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
319    public String chooseClientPSKIdentity(PSKKeyManager keyManager, String identityHint) {
320        return engine.chooseClientPSKIdentity(keyManager, identityHint);
321    }
322
323    @Override
324    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
325    public SecretKey getPSKKey(PSKKeyManager keyManager, String identityHint, String identity) {
326        return engine.getPSKKey(keyManager, identityHint, identity);
327    }
328
329    /**
330     * Wrap bytes written to the underlying socket.
331     */
332    private final class OutputStreamWrapper extends OutputStream {
333        private final Object stateLock = new Object();
334        private ByteBuffer target;
335        private OutputStream socketOutputStream;
336        private SocketChannel socketChannel;
337
338        OutputStreamWrapper() {}
339
340        @Override
341        public void write(int b) throws IOException {
342            write(new byte[] {(byte) b});
343        }
344
345        @Override
346        public void write(byte[] b) throws IOException {
347            write(ByteBuffer.wrap(b));
348        }
349
350        @Override
351        public void write(byte[] b, int off, int len) throws IOException {
352            write(ByteBuffer.wrap(b, off, len));
353        }
354
355        private void write(ByteBuffer buffer) throws IOException {
356            synchronized (stateLock) {
357                try {
358                    init();
359
360                    // Need to loop through at least once to enable handshaking where no application
361                    // bytes are
362                    // processed.
363                    int len = buffer.remaining();
364                    SSLEngineResult engineResult;
365                    do {
366                        target.clear();
367                        engineResult = engine.wrap(buffer, target);
368                        if (engineResult.getStatus() != OK) {
369                            throw new SSLException(
370                                    "Unexpected engine result " + engineResult.getStatus());
371                        }
372                        if (target.position() != engineResult.bytesProduced()) {
373                            throw new SSLException("Engine bytesProduced "
374                                    + engineResult.bytesProduced()
375                                    + " does not match bytes written " + target.position());
376                        }
377                        len -= engineResult.bytesConsumed();
378                        if (len != buffer.remaining()) {
379                            throw new SSLException(
380                                    "Engine did not read the correct number of bytes");
381                        }
382
383                        target.flip();
384
385                        // Write the data to the socket.
386                        if (socketChannel != null) {
387                            // Loop until all of the data is written to the channel. Typically,
388                            // SocketChannel writes will return only after all bytes are written,
389                            // so we won't really loop here.
390                            while (target.hasRemaining()) {
391                                socketChannel.write(target);
392                            }
393                        } else {
394                            // Target is a heap buffer.
395                            socketOutputStream.write(target.array(), 0, target.limit());
396                        }
397                    } while (len > 0);
398                } catch (IOException e) {
399                    e.printStackTrace();
400                    throw e;
401                } catch (RuntimeException e) {
402                    e.printStackTrace();
403                    throw e;
404                }
405            }
406        }
407
408        @Override
409        public void flush() throws IOException {
410            synchronized (stateLock) {
411                init();
412                socketOutputStream.flush();
413            }
414        }
415
416        @Override
417        public void close() throws IOException {
418            socket.close();
419        }
420
421        private void init() throws IOException {
422            if (socketOutputStream == null) {
423                socketOutputStream = socket.getOutputStream();
424                socketChannel = socket.getChannel();
425                if (socketChannel != null) {
426                    // Optimization. Using direct buffers wherever possible to avoid passing
427                    // arrays to JNI.
428                    target = ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize());
429                } else {
430                    target = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
431                }
432            }
433        }
434    }
435
436    /**
437     * Unwrap bytes read from the underlying socket.
438     */
439    private final class InputStreamWrapper extends InputStream {
440        private final Object stateLock = new Object();
441        private final byte[] singleByte = new byte[1];
442        private final ByteBuffer fromEngine;
443        private ByteBuffer fromSocket;
444        private InputStream socketInputStream;
445        private SocketChannel socketChannel;
446
447        InputStreamWrapper() {
448            fromEngine = ByteBuffer.allocateDirect(engine.getSession().getApplicationBufferSize());
449            // Initially fromEngine.remaining() == 0.
450            fromEngine.flip();
451        }
452
453        @Override
454        public int read() throws IOException {
455            synchronized (stateLock) {
456                // Handle returning of -1 if EOF is reached.
457                int count = read(singleByte, 0, 1);
458                if (count == -1) {
459                    // Handle EOF.
460                    return -1;
461                }
462                if (count != 1) {
463                    throw new SSLException("read incorrect number of bytes " + count);
464                }
465                return (int) singleByte[0];
466            }
467        }
468
469        @Override
470        public int read(byte[] b) throws IOException {
471            return read(b, 0, b.length);
472        }
473
474        @Override
475        public int read(byte[] b, int off, int len) throws IOException {
476            synchronized (stateLock) {
477                try {
478                    // Make sure the input stream has been created.
479                    init();
480
481                    for (;;) {
482                        // Serve any remaining data from the engine first.
483                        if (fromEngine.remaining() > 0) {
484                            int readFromEngine = Math.min(fromEngine.remaining(), len);
485                            fromEngine.get(b, off, readFromEngine);
486                            return readFromEngine;
487                        }
488
489                        // Try to unwrap any data already in the socket buffer.
490                        boolean needMoreData = true;
491                        if (fromSocket.position() > 0) {
492                            // Unwrap the unencrypted bytes into the engine buffer.
493                            fromSocket.flip();
494                            fromEngine.clear();
495                            SSLEngineResult engineResult = engine.unwrap(fromSocket, fromEngine);
496
497                            // Shift any remaining data to the beginning of the buffer so that
498                            // we can accommodate the next full packet. After this is called,
499                            // limit will be restored to capacity and position will point just
500                            // past the end of the data.
501                            fromSocket.compact();
502                            fromEngine.flip();
503
504                            switch (engineResult.getStatus()) {
505                                case BUFFER_UNDERFLOW: {
506                                    if (engineResult.bytesProduced() == 0) {
507                                        // Need to read more data from the socket.
508                                        break;
509                                    }
510                                    // Also serve the data that was produced.
511                                    needMoreData = false;
512                                    break;
513                                }
514                                case OK: {
515                                    // We processed the entire packet successfully.
516                                    needMoreData = false;
517                                    break;
518                                }
519                                case CLOSED: {
520                                    // EOF
521                                    return -1;
522                                }
523                                default: {
524                                    // Anything else is an error.
525                                    throw new SSLException(
526                                            "Unexpected engine result " + engineResult.getStatus());
527                                }
528                            }
529
530                            if (!needMoreData && engineResult.bytesProduced() == 0) {
531                                // Read successfully, but produced no data. Possibly part of a
532                                // handshake.
533                                return 0;
534                            }
535                        }
536
537                        // Read more data from the socket.
538                        if (needMoreData && readFromSocket() == -1) {
539                            // Failed to read the next encrypted packet before reaching EOF.
540                            return -1;
541                        }
542
543                        // Continue the loop and return the data from the engine buffer.
544                    }
545                } catch (IOException e) {
546                    e.printStackTrace();
547                    throw e;
548                } catch (RuntimeException e) {
549                    e.printStackTrace();
550                    throw e;
551                }
552            }
553        }
554
555        private void init() throws IOException {
556            if (socketInputStream == null) {
557                socketInputStream = socket.getInputStream();
558                socketChannel = socket.getChannel();
559                if (socketChannel != null) {
560                    fromSocket =
561                            ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize());
562                } else {
563                    fromSocket = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
564                }
565            }
566        }
567
568        private int readFromSocket() throws IOException {
569            if (socketChannel != null) {
570                return socketChannel.read(fromSocket);
571            }
572            // Read directly to the underlying array and increment the buffer position if
573            // appropriate.
574            int read = socketInputStream.read(
575                    fromSocket.array(), fromSocket.position(), fromSocket.remaining());
576            if (read > 0) {
577                fromSocket.position(fromSocket.position() + read);
578            }
579            return read;
580        }
581    }
582}
583