SocketTunnelClient.java revision 1320f92c476a1ad9d19dba2a48c72b75566198e9
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5package org.chromium.components.devtools_bridge;
6
7import android.net.LocalServerSocket;
8import android.net.LocalSocket;
9import android.util.Log;
10
11import java.io.IOException;
12import java.util.HashMap;
13import java.util.Map;
14import java.util.concurrent.ConcurrentHashMap;
15import java.util.concurrent.ConcurrentMap;
16import java.util.concurrent.ExecutorService;
17import java.util.concurrent.Executors;
18import java.util.concurrent.atomic.AtomicReference;
19
20/**
21 * Listens LocalServerSocket and tunnels all connections to the SocketTunnelServer.
22 */
23public class SocketTunnelClient extends SocketTunnelBase {
24    private static final String TAG = "SocketTunnelClient";
25
26    private enum State {
27        INITIAL, RUNNING, STOPPED
28    }
29
30    private final AtomicReference<State> mState = new AtomicReference<State>(State.INITIAL);
31
32    private final LocalServerSocket mSocket;
33    private final ExecutorService mThreadPool = Executors.newCachedThreadPool();
34
35    // Connections with opened server to client stream. Always accesses on signaling thread.
36    private final Map<Integer, Connection> mServerConnections =
37            new HashMap<Integer, Connection>();
38
39    // Accepted connections are kept here until server returns SERVER_OPEN_ACK or SERVER_CLOSE.
40    // New connections are added in the listening loop, checked and removed on signaling thread.
41    // So add/read/remove synchronized through message round trip.
42    private final ConcurrentMap<Integer, Connection> mPendingConnections =
43            new ConcurrentHashMap<Integer, Connection>();
44
45    private final IdRegistry mIdRegistry = new IdRegistry(MIN_CONNECTION_ID, MAX_CONNECTION_ID, 2);
46
47    /**
48     * This class responsible for generating valid connection IDs. It count usage of connection:
49     * one user for client to server stream and one for server to client one. When both are closed
50     * it's safe to reuse ID.
51     */
52    private static final class IdRegistry {
53        private final int[] mLocks;
54        private final int mMin;
55        private final int mMax;
56        private final int mMaxLocks;
57        private final Object mLock = new Object();
58
59        public IdRegistry(int minId, int maxId, int maxLocks) {
60            assert minId < maxId;
61            assert maxLocks > 0;
62
63            mMin = minId;
64            mMax = maxId;
65            mMaxLocks = maxLocks;
66            mLocks = new int[maxId - minId + 1];
67        }
68
69        public void lock(int id) {
70            synchronized (mLock) {
71                int index = toIndex(id);
72                if (mLocks[index] == 0 || mLocks[index] == mMaxLocks) {
73                    throw new RuntimeException();
74                }
75                mLocks[index]++;
76            }
77        }
78
79        public void release(int id) {
80            synchronized (mLock) {
81                int index = toIndex(id);
82                if (mLocks[index] == 0) {
83                    throw new RuntimeException("Releasing unlocked id " + Integer.toString(id));
84                }
85                mLocks[index]--;
86            }
87        }
88
89        public boolean isLocked(int id) {
90            synchronized (mLock) {
91                return mLocks[toIndex(id)] > 0;
92            }
93        }
94
95        public int generate() throws NoIdAvailableException {
96            synchronized (mLock) {
97                for (int id = mMin; id != mMax; id++) {
98                    int index = toIndex(id);
99                    if (mLocks[index] == 0) {
100                        mLocks[index] = 1;
101                        return id;
102                    }
103                }
104            }
105            throw new NoIdAvailableException();
106        }
107
108        private int toIndex(int id) {
109            if (id < mMin || id > mMax) {
110                throw new RuntimeException();
111            }
112            return id - mMin;
113        }
114    }
115
116    private static class NoIdAvailableException extends Exception {}
117
118    public SocketTunnelClient(String socketName) throws IOException {
119        mSocket = new LocalServerSocket(socketName);
120    }
121
122    public boolean hasConnections() {
123        return mServerConnections.size() + mPendingConnections.size() > 0;
124    }
125
126    @Override
127    public AbstractDataChannel unbind() {
128        AbstractDataChannel dataChannel = super.unbind();
129        close();
130        return dataChannel;
131    }
132
133    public void close() {
134        if (mState.get() != State.STOPPED) closeSocket();
135    }
136
137    @Override
138    protected void onReceivedDataPacket(int connectionId, byte[] data) throws ProtocolError {
139        checkCalledOnSignalingThread();
140
141        if (!mServerConnections.containsKey(connectionId))
142            throw new ProtocolError("Unknows connection id");
143
144        mServerConnections.get(connectionId).onReceivedDataPacket(data);
145    }
146
147    @Override
148    protected void onReceivedControlPacket(int connectionId, byte opCode) throws ProtocolError {
149        switch (opCode) {
150            case SERVER_OPEN_ACK:
151                onServerOpenAck(connectionId);
152                break;
153
154            case SERVER_CLOSE:
155                onServerClose(connectionId);
156                break;
157
158            default:
159                throw new ProtocolError("Invalid opCode");
160        }
161    }
162
163    private void onServerOpenAck(int connectionId) throws ProtocolError {
164        checkCalledOnSignalingThread();
165
166        if (mServerConnections.containsKey(connectionId)) {
167            throw new ProtocolError("Connection already acknowledged");
168        }
169
170        if (!mPendingConnections.containsKey(connectionId)) {
171            throw new ProtocolError("Unknow connection id");
172        }
173
174        // Check/get is safe since it can be only removed on this thread.
175        Connection connection = mPendingConnections.get(connectionId);
176        mPendingConnections.remove(connectionId);
177
178        mServerConnections.put(connectionId, connection);
179
180        // Lock for client to server stream.
181        mIdRegistry.lock(connectionId);
182        mThreadPool.execute(connection);
183    }
184
185    private void onServerClose(int connectionId) throws ProtocolError {
186        checkCalledOnSignalingThread();
187
188        if (mServerConnections.containsKey(connectionId)) {
189            Connection connection = mServerConnections.get(connectionId);
190            mServerConnections.remove(connectionId);
191            mIdRegistry.release(connectionId); // Release sever to client stream.
192            connection.closedByServer();
193        } else if (mPendingConnections.containsKey(connectionId)) {
194            Connection connection = mPendingConnections.get(connectionId);
195            mPendingConnections.remove(connectionId);
196            connection.closedByServer();
197            sendToDataChannel(buildControlPacket(connectionId, CLIENT_CLOSE));
198            mIdRegistry.release(connectionId); // Release sever to client stream.
199        } else {
200            throw new ProtocolError("Closing unknown connection");
201        }
202    }
203
204    @Override
205    protected void onDataChannelOpened() {
206        if (!mState.compareAndSet(State.INITIAL, State.RUNNING)) {
207            throw new InvalidStateException();
208        }
209
210        mThreadPool.execute(new Runnable() {
211            @Override
212            public void run() {
213                runListenLoop();
214            }
215        });
216    }
217
218    @Override
219    protected void onDataChannelClosed() {
220        // All new connections will be rejected.
221        if (!mState.compareAndSet(State.RUNNING, State.STOPPED)) {
222            throw new InvalidStateException();
223        }
224
225        for (Connection connection : mServerConnections.values()) {
226            connection.terminate();
227        }
228
229        for (Connection connection : mPendingConnections.values()) {
230            connection.terminate();
231        }
232
233        closeSocket();
234
235        mThreadPool.shutdown();
236    }
237
238    private void closeSocket() {
239        try {
240            mSocket.close();
241        } catch (IOException e) {
242            Log.d(TAG, "Failed to close socket: " + e);
243            onSocketException(e, -1);
244        }
245    }
246
247    private void runListenLoop() {
248        try {
249            while (true) {
250                LocalSocket socket = mSocket.accept();
251                State state = mState.get();
252                if (mState.get() == State.RUNNING) {
253                    // Make sure no socket processed when stopped.
254                    clientOpenConnection(socket);
255                } else {
256                    socket.close();
257                }
258            }
259        } catch (IOException e) {
260            if (mState.get() != State.RUNNING) {
261                onSocketException(e, -1);
262            }
263            // Else exception expected (socket closed).
264        }
265    }
266
267    private void clientOpenConnection(LocalSocket socket) throws IOException {
268        try {
269            int id = mIdRegistry.generate();  // id generated locked for server to client stream.
270            Connection connection = new Connection(id, socket);
271            mPendingConnections.put(id, connection);
272            sendToDataChannel(buildControlPacket(id, CLIENT_OPEN));
273        } catch (NoIdAvailableException e) {
274            socket.close();
275        }
276    }
277
278    private final class Connection extends ConnectionBase implements Runnable {
279        public Connection(int id, LocalSocket socket) {
280            super(id, socket);
281        }
282
283        public void closedByServer() {
284            shutdownOutput();
285        }
286
287        @Override
288        public void run() {
289            assert mIdRegistry.isLocked(mId);
290
291            runReadingLoop();
292
293            shutdownInput();
294            sendToDataChannel(buildControlPacket(mId, CLIENT_CLOSE));
295            mIdRegistry.release(mId);  // Unlock for client to server stream.
296        }
297    }
298
299    /**
300     * Method called in inappropriate state.
301     */
302    public static class InvalidStateException extends RuntimeException {}
303}
304