1/*
2 * Copyright (C) 2014 Square, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package com.squareup.okhttp;
17
import com.squareup.okhttp.internal.NamedRunnable;
import com.squareup.okhttp.internal.Util;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.net.Proxy;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.Okio;
38
39/**
40 * A limited implementation of SOCKS Protocol Version 5, intended to be similar to MockWebServer.
41 * See <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC 1928</a>.
42 */
43public final class SocksProxy {
44  public final String HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS = "onlyProxyCanResolveMe.org";
45
46  private static final int VERSION_5 = 5;
47  private static final int METHOD_NONE = 0xff;
48  private static final int METHOD_NO_AUTHENTICATION_REQUIRED = 0;
49  private static final int ADDRESS_TYPE_IPV4 = 1;
50  private static final int ADDRESS_TYPE_DOMAIN_NAME = 3;
51  private static final int COMMAND_CONNECT = 1;
52  private static final int REPLY_SUCCEEDED = 0;
53
54  private static final Logger logger = Logger.getLogger(SocksProxy.class.getName());
55
56  private final ExecutorService executor = Executors.newCachedThreadPool(
57      Util.threadFactory("SocksProxy", false));
58
59  private ServerSocket serverSocket;
60  private AtomicInteger connectionCount = new AtomicInteger();
61
62  public void play() throws IOException {
63    serverSocket = new ServerSocket(0);
64    executor.execute(new NamedRunnable("SocksProxy %s", serverSocket.getLocalPort()) {
65      @Override protected void execute() {
66        try {
67          while (true) {
68            Socket socket = serverSocket.accept();
69            connectionCount.incrementAndGet();
70            service(socket);
71          }
72        } catch (SocketException e) {
73          logger.info(name + " done accepting connections: " + e.getMessage());
74        } catch (IOException e) {
75          logger.log(Level.WARNING, name + " failed unexpectedly", e);
76        }
77      }
78    });
79  }
80
81  public Proxy proxy() {
82    return new Proxy(Proxy.Type.SOCKS, InetSocketAddress.createUnresolved(
83        "localhost", serverSocket.getLocalPort()));
84  }
85
86  public int connectionCount() {
87    return connectionCount.get();
88  }
89
90  public void shutdown() throws Exception {
91    serverSocket.close();
92    executor.shutdown();
93    if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
94      throw new IOException("Gave up waiting for executor to shut down");
95    }
96  }
97
98  private void service(final Socket from) {
99    executor.execute(new NamedRunnable("SocksProxy %s", from.getRemoteSocketAddress()) {
100      @Override protected void execute() {
101        try {
102          BufferedSource fromSource = Okio.buffer(Okio.source(from));
103          BufferedSink fromSink = Okio.buffer(Okio.sink(from));
104          hello(fromSource, fromSink);
105          acceptCommand(from.getInetAddress(), fromSource, fromSink);
106        } catch (IOException e) {
107          logger.log(Level.WARNING, name + " failed", e);
108          Util.closeQuietly(from);
109        }
110      }
111    });
112  }
113
114  private void hello(BufferedSource fromSource, BufferedSink fromSink) throws IOException {
115    int version = fromSource.readByte() & 0xff;
116    int methodCount = fromSource.readByte() & 0xff;
117    int selectedMethod = METHOD_NONE;
118
119    if (version != VERSION_5) {
120      throw new ProtocolException("unsupported version: " + version);
121    }
122
123    for (int i = 0; i < methodCount; i++) {
124      int candidateMethod = fromSource.readByte() & 0xff;
125      if (candidateMethod == METHOD_NO_AUTHENTICATION_REQUIRED) {
126        selectedMethod = candidateMethod;
127      }
128    }
129
130    switch (selectedMethod) {
131      case METHOD_NO_AUTHENTICATION_REQUIRED:
132        fromSink.writeByte(VERSION_5);
133        fromSink.writeByte(selectedMethod);
134        fromSink.emit();
135        break;
136
137      default:
138        throw new ProtocolException("unsupported method: " + selectedMethod);
139    }
140  }
141
142  private void acceptCommand(InetAddress fromAddress, BufferedSource fromSource,
143      BufferedSink fromSink) throws IOException {
144    // Read the command.
145    int version = fromSource.readByte() & 0xff;
146    if (version != VERSION_5) throw new ProtocolException("unexpected version: " + version);
147    int command = fromSource.readByte() & 0xff;
148    int reserved = fromSource.readByte() & 0xff;
149    if (reserved != 0) throw new ProtocolException("unexpected reserved: " + reserved);
150
151    int addressType = fromSource.readByte() & 0xff;
152    InetAddress toAddress;
153    switch (addressType) {
154      case ADDRESS_TYPE_IPV4:
155        toAddress = InetAddress.getByAddress(fromSource.readByteArray(4L));
156        break;
157
158      case ADDRESS_TYPE_DOMAIN_NAME:
159        int domainNameLength = fromSource.readByte() & 0xff;
160        String domainName = fromSource.readUtf8(domainNameLength);
161        // Resolve HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS to localhost.
162        toAddress = domainName.equalsIgnoreCase(HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS)
163            ? InetAddress.getByName("localhost")
164            : InetAddress.getByName(domainName);
165        break;
166
167      default:
168        throw new ProtocolException("unsupported address type: " + addressType);
169    }
170
171    int port = fromSource.readShort() & 0xffff;
172
173    switch (command) {
174      case COMMAND_CONNECT:
175        Socket toSocket = new Socket(toAddress, port);
176        byte[] localAddress = toSocket.getLocalAddress().getAddress();
177        if (localAddress.length != 4) {
178          throw new ProtocolException("unexpected address: " + toSocket.getLocalAddress());
179        }
180
181        // Write the reply.
182        fromSink.writeByte(VERSION_5);
183        fromSink.writeByte(REPLY_SUCCEEDED);
184        fromSink.writeByte(0);
185        fromSink.writeByte(ADDRESS_TYPE_IPV4);
186        fromSink.write(localAddress);
187        fromSink.writeShort(toSocket.getLocalPort());
188        fromSink.emit();
189
190        logger.log(Level.INFO, "SocksProxy connected " + fromAddress + " to " + toAddress);
191
192        // Copy sources to sinks in both directions.
193        BufferedSource toSource = Okio.buffer(Okio.source(toSocket));
194        BufferedSink toSink = Okio.buffer(Okio.sink(toSocket));
195        transfer(fromAddress, toAddress, fromSource, toSink);
196        transfer(fromAddress, toAddress, toSource, fromSink);
197        break;
198
199      default:
200        throw new ProtocolException("unexpected command: " + command);
201    }
202  }
203
204  private void transfer(final InetAddress fromAddress, final InetAddress toAddress,
205      final BufferedSource source, final BufferedSink sink) {
206    executor.execute(new NamedRunnable("SocksProxy %s to %s", fromAddress, toAddress) {
207      @Override protected void execute() {
208        Buffer buffer = new Buffer();
209        try {
210          while (true) {
211            long byteCount = source.read(buffer, 2048L);
212            if (byteCount == -1L) break;
213            sink.write(buffer, byteCount);
214            sink.emit();
215          }
216        } catch (SocketException e) {
217          logger.info(name + " done: " + e.getMessage());
218        } catch (IOException e) {
219          logger.log(Level.WARNING, name + " failed", e);
220        }
221
222        try {
223          source.close();
224        } catch (IOException e) {
225          logger.log(Level.WARNING, name + " failed", e);
226        }
227
228        try {
229          sink.close();
230        } catch (IOException e) {
231          logger.log(Level.WARNING, name + " failed", e);
232        }
233      }
234    });
235  }
236}
237