1// Copyright (c) 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/test/chromedriver/net/adb_client_socket.h"
6
7#include "base/bind.h"
8#include "base/compiler_specific.h"
9#include "base/strings/string_number_conversions.h"
10#include "base/strings/string_util.h"
11#include "base/strings/stringprintf.h"
12#include "net/base/address_list.h"
13#include "net/base/completion_callback.h"
14#include "net/base/net_errors.h"
15#include "net/base/net_util.h"
16#include "net/socket/tcp_client_socket.h"
17
18namespace {
19
// Size, in bytes, of every buffer handed to socket Read() calls in this file.
const int kBufferSize = 16384;  // 16 KiB
// Status word that must start a successful ADB reply (see OnResponseHeader).
const char kOkayResponse[] = "OKAY";
// printf-style command templates; the %s is filled with a device serial or a
// socket name by AdbTransportSocket.
const char kHostTransportCommand[] = "host:transport:%s";
const char kLocalAbstractCommand[] = "localabstract:%s";
// Every connection opened by this file targets the IPv4 loopback address.
const char kLocalhost[] = "127.0.0.1";
25
// Result callback for command-style operations: (net error code, or a
// non-negative result on success; response text). Used by the helper classes
// below.
typedef base::Callback<void(int, const std::string&)> CommandCallback;
// Result callback for operations that produce a connected stream socket:
// (net error code or net::OK; socket). On success the receiver is handed the
// socket pointer; see each producer for what is passed on failure.
typedef base::Callback<void(int, net::StreamSocket*)> SocketCallback;
28
// Frames |message| for the ADB wire protocol: a 4-digit uppercase hexadecimal
// length prefix followed by the message bytes. Only the low 16 bits of the
// length are encoded, so messages longer than 0xFFFF bytes get a truncated
// prefix; callers that can receive arbitrary input must reject them first.
std::string EncodeMessage(const std::string& message) {
  static const char kHexChars[] = "0123456789ABCDEF";

  // Extract the nibbles arithmetically. The previous implementation read the
  // raw bytes of a size_t, which only produced the right digits on
  // little-endian hosts.
  size_t length = message.length();
  std::string result(4, '0');
  for (int i = 3; i >= 0; --i) {
    result[i] = kHexChars[length & 0xf];
    length >>= 4;
  }
  return result + message;
}
42
43class AdbTransportSocket : public AdbClientSocket {
44 public:
45  AdbTransportSocket(int port,
46                     const std::string& serial,
47                     const std::string& socket_name,
48                     const SocketCallback& callback)
49    : AdbClientSocket(port),
50      serial_(serial),
51      socket_name_(socket_name),
52      callback_(callback) {
53    Connect(base::Bind(&AdbTransportSocket::OnConnected,
54                       base::Unretained(this)));
55  }
56
57 private:
58  ~AdbTransportSocket() {}
59
60  void OnConnected(int result) {
61    if (!CheckNetResultOrDie(result))
62      return;
63    SendCommand(base::StringPrintf(kHostTransportCommand, serial_.c_str()),
64        true, base::Bind(&AdbTransportSocket::SendLocalAbstract,
65                         base::Unretained(this)));
66  }
67
68  void SendLocalAbstract(int result, const std::string& response) {
69    if (!CheckNetResultOrDie(result))
70      return;
71    SendCommand(base::StringPrintf(kLocalAbstractCommand, socket_name_.c_str()),
72        true, base::Bind(&AdbTransportSocket::OnSocketAvailable,
73                         base::Unretained(this)));
74  }
75
76  void OnSocketAvailable(int result, const std::string& response) {
77    if (!CheckNetResultOrDie(result))
78      return;
79    callback_.Run(net::OK, socket_.release());
80    delete this;
81  }
82
83  bool CheckNetResultOrDie(int result) {
84    if (result >= 0)
85      return true;
86    callback_.Run(result, NULL);
87    delete this;
88    return false;
89  }
90
91  std::string serial_;
92  std::string socket_name_;
93  SocketCallback callback_;
94};
95
96class HttpOverAdbSocket {
97 public:
98  HttpOverAdbSocket(int port,
99                    const std::string& serial,
100                    const std::string& socket_name,
101                    const std::string& request,
102                    const CommandCallback& callback)
103    : request_(request),
104      command_callback_(callback),
105      body_pos_(0) {
106    Connect(port, serial, socket_name);
107  }
108
109  HttpOverAdbSocket(int port,
110                    const std::string& serial,
111                    const std::string& socket_name,
112                    const std::string& request,
113                    const SocketCallback& callback)
114    : request_(request),
115      socket_callback_(callback),
116      body_pos_(0) {
117    Connect(port, serial, socket_name);
118  }
119
120 private:
121  ~HttpOverAdbSocket() {
122  }
123
124  void Connect(int port,
125               const std::string& serial,
126               const std::string& socket_name) {
127    AdbClientSocket::TransportQuery(
128        port, serial, socket_name,
129        base::Bind(&HttpOverAdbSocket::OnSocketAvailable,
130                   base::Unretained(this)));
131  }
132
133  void OnSocketAvailable(int result,
134                         net::StreamSocket* socket) {
135    if (!CheckNetResultOrDie(result))
136      return;
137
138    socket_.reset(socket);
139
140    scoped_refptr<net::StringIOBuffer> request_buffer =
141        new net::StringIOBuffer(request_);
142
143    result = socket_->Write(
144        request_buffer.get(),
145        request_buffer->size(),
146        base::Bind(&HttpOverAdbSocket::ReadResponse, base::Unretained(this)));
147    if (result != net::ERR_IO_PENDING)
148      ReadResponse(result);
149  }
150
151  void ReadResponse(int result) {
152    if (!CheckNetResultOrDie(result))
153      return;
154
155    scoped_refptr<net::IOBuffer> response_buffer =
156        new net::IOBuffer(kBufferSize);
157
158    result = socket_->Read(response_buffer.get(),
159                           kBufferSize,
160                           base::Bind(&HttpOverAdbSocket::OnResponseData,
161                                      base::Unretained(this),
162                                      response_buffer,
163                                      -1));
164    if (result != net::ERR_IO_PENDING)
165      OnResponseData(response_buffer, -1, result);
166  }
167
168  void OnResponseData(scoped_refptr<net::IOBuffer> response_buffer,
169                      int bytes_total,
170                      int result) {
171    if (!CheckNetResultOrDie(result))
172      return;
173    if (result == 0) {
174      CheckNetResultOrDie(net::ERR_CONNECTION_CLOSED);
175      return;
176    }
177
178    response_ += std::string(response_buffer->data(), result);
179    int expected_length = 0;
180    if (bytes_total < 0) {
181      size_t content_pos = response_.find("Content-Length:");
182      if (content_pos != std::string::npos) {
183        size_t endline_pos = response_.find("\n", content_pos);
184        if (endline_pos != std::string::npos) {
185          std::string len = response_.substr(content_pos + 15,
186                                             endline_pos - content_pos - 15);
187          base::TrimWhitespace(len, base::TRIM_ALL, &len);
188          if (!base::StringToInt(len, &expected_length)) {
189            CheckNetResultOrDie(net::ERR_FAILED);
190            return;
191          }
192        }
193      }
194
195      body_pos_ = response_.find("\r\n\r\n");
196      if (body_pos_ != std::string::npos) {
197        body_pos_ += 4;
198        bytes_total = body_pos_ + expected_length;
199      }
200    }
201
202    if (bytes_total == static_cast<int>(response_.length())) {
203      if (!command_callback_.is_null())
204        command_callback_.Run(body_pos_, response_);
205      else
206        socket_callback_.Run(net::OK, socket_.release());
207      delete this;
208      return;
209    }
210
211    result = socket_->Read(response_buffer.get(),
212                           kBufferSize,
213                           base::Bind(&HttpOverAdbSocket::OnResponseData,
214                                      base::Unretained(this),
215                                      response_buffer,
216                                      bytes_total));
217    if (result != net::ERR_IO_PENDING)
218      OnResponseData(response_buffer, bytes_total, result);
219  }
220
221  bool CheckNetResultOrDie(int result) {
222    if (result >= 0)
223      return true;
224    if (!command_callback_.is_null())
225      command_callback_.Run(result, std::string());
226    else
227      socket_callback_.Run(result, NULL);
228    delete this;
229    return false;
230  }
231
232  scoped_ptr<net::StreamSocket> socket_;
233  std::string request_;
234  std::string response_;
235  CommandCallback command_callback_;
236  SocketCallback socket_callback_;
237  size_t body_pos_;
238};
239
240class AdbQuerySocket : AdbClientSocket {
241 public:
242  AdbQuerySocket(int port,
243                 const std::string& query,
244                 const CommandCallback& callback)
245      : AdbClientSocket(port),
246        current_query_(0),
247        callback_(callback) {
248    if (Tokenize(query, "|", &queries_) == 0) {
249      CheckNetResultOrDie(net::ERR_INVALID_ARGUMENT);
250      return;
251    }
252    Connect(base::Bind(&AdbQuerySocket::SendNextQuery,
253                       base::Unretained(this)));
254  }
255
256 private:
257  ~AdbQuerySocket() {
258  }
259
260  void SendNextQuery(int result) {
261    if (!CheckNetResultOrDie(result))
262      return;
263    std::string query = queries_[current_query_];
264    if (query.length() > 0xFFFF) {
265      CheckNetResultOrDie(net::ERR_MSG_TOO_BIG);
266      return;
267    }
268    bool is_void = current_query_ < queries_.size() - 1;
269    SendCommand(query, is_void,
270        base::Bind(&AdbQuerySocket::OnResponse, base::Unretained(this)));
271  }
272
273  void OnResponse(int result, const std::string& response) {
274    if (++current_query_ < queries_.size()) {
275      SendNextQuery(net::OK);
276    } else {
277      callback_.Run(result, response);
278      delete this;
279    }
280  }
281
282  bool CheckNetResultOrDie(int result) {
283    if (result >= 0)
284      return true;
285    callback_.Run(result, std::string());
286    delete this;
287    return false;
288  }
289
290  std::vector<std::string> queries_;
291  size_t current_query_;
292  CommandCallback callback_;
293};
294
295}  // namespace
296
297// static
298void AdbClientSocket::AdbQuery(int port,
299                               const std::string& query,
300                               const CommandCallback& callback) {
301  new AdbQuerySocket(port, query, callback);
302}
303
#if defined(DEBUG_DEVTOOLS)
// Completion handler for the desktop (plain TCP) path of TransportQuery.
// Adapts the connect result to a SocketCallback.
//  - On success, ownership of |socket| passes to |callback|.
//  - On failure, the socket is destroyed here and NULL is passed instead,
//    matching AdbTransportSocket. Callers never take ownership on error, so
//    passing the failed socket through leaked it.
static void UseTransportQueryForDesktop(const SocketCallback& callback,
                                        net::StreamSocket* socket,
                                        int result) {
  if (result < 0) {
    delete socket;
    socket = NULL;
  }
  callback.Run(result, socket);
}
#endif  // defined(DEBUG_DEVTOOLS)
311
312// static
313void AdbClientSocket::TransportQuery(int port,
314                                     const std::string& serial,
315                                     const std::string& socket_name,
316                                     const SocketCallback& callback) {
317#if defined(DEBUG_DEVTOOLS)
318  if (serial.empty()) {
319    // Use plain socket for remote debugging on Desktop (debugging purposes).
320    net::IPAddressNumber ip_number;
321    net::ParseIPLiteralToNumber(kLocalhost, &ip_number);
322
323    int tcp_port = 0;
324    if (!base::StringToInt(socket_name, &tcp_port))
325      tcp_port = 9222;
326
327    net::AddressList address_list =
328        net::AddressList::CreateFromIPAddress(ip_number, tcp_port);
329    net::TCPClientSocket* socket = new net::TCPClientSocket(
330        address_list, NULL, net::NetLog::Source());
331    socket->Connect(base::Bind(&UseTransportQueryForDesktop, callback, socket));
332    return;
333  }
334#endif  // defined(DEBUG_DEVTOOLS)
335  new AdbTransportSocket(port, serial, socket_name, callback);
336}
337
338// static
339void AdbClientSocket::HttpQuery(int port,
340                                const std::string& serial,
341                                const std::string& socket_name,
342                                const std::string& request_path,
343                                const CommandCallback& callback) {
344  new HttpOverAdbSocket(port, serial, socket_name, request_path,
345      callback);
346}
347
348// static
349void AdbClientSocket::HttpQuery(int port,
350                                const std::string& serial,
351                                const std::string& socket_name,
352                                const std::string& request_path,
353                                const SocketCallback& callback) {
354  new HttpOverAdbSocket(port, serial, socket_name, request_path,
355      callback);
356}
357
358AdbClientSocket::AdbClientSocket(int port)
359    : host_(kLocalhost), port_(port) {
360}
361
362AdbClientSocket::~AdbClientSocket() {
363}
364
365void AdbClientSocket::Connect(const net::CompletionCallback& callback) {
366  net::IPAddressNumber ip_number;
367  if (!net::ParseIPLiteralToNumber(host_, &ip_number)) {
368    callback.Run(net::ERR_FAILED);
369    return;
370  }
371
372  net::AddressList address_list =
373      net::AddressList::CreateFromIPAddress(ip_number, port_);
374  socket_.reset(new net::TCPClientSocket(address_list, NULL,
375                                         net::NetLog::Source()));
376  int result = socket_->Connect(callback);
377  if (result != net::ERR_IO_PENDING)
378    callback.Run(result);
379}
380
381void AdbClientSocket::SendCommand(const std::string& command,
382                                  bool is_void,
383                                  const CommandCallback& callback) {
384  scoped_refptr<net::StringIOBuffer> request_buffer =
385      new net::StringIOBuffer(EncodeMessage(command));
386  int result = socket_->Write(request_buffer.get(),
387                              request_buffer->size(),
388                              base::Bind(&AdbClientSocket::ReadResponse,
389                                         base::Unretained(this),
390                                         callback,
391                                         is_void));
392  if (result != net::ERR_IO_PENDING)
393    ReadResponse(callback, is_void, result);
394}
395
396void AdbClientSocket::ReadResponse(const CommandCallback& callback,
397                                   bool is_void,
398                                   int result) {
399  if (result < 0) {
400    callback.Run(result, "IO error");
401    return;
402  }
403  scoped_refptr<net::IOBuffer> response_buffer =
404      new net::IOBuffer(kBufferSize);
405  result = socket_->Read(response_buffer.get(),
406                         kBufferSize,
407                         base::Bind(&AdbClientSocket::OnResponseHeader,
408                                    base::Unretained(this),
409                                    callback,
410                                    is_void,
411                                    response_buffer));
412  if (result != net::ERR_IO_PENDING)
413    OnResponseHeader(callback, is_void, response_buffer, result);
414}
415
416void AdbClientSocket::OnResponseHeader(
417    const CommandCallback& callback,
418    bool is_void,
419    scoped_refptr<net::IOBuffer> response_buffer,
420    int result) {
421  if (result <= 0) {
422    callback.Run(result == 0 ? net::ERR_CONNECTION_CLOSED : result,
423                 "IO error");
424    return;
425  }
426
427  std::string data = std::string(response_buffer->data(), result);
428  if (result < 4) {
429    callback.Run(net::ERR_FAILED, "Response is too short: " + data);
430    return;
431  }
432
433  std::string status = data.substr(0, 4);
434  if (status != kOkayResponse) {
435    callback.Run(net::ERR_FAILED, data);
436    return;
437  }
438
439  data = data.substr(4);
440
441  if (!is_void) {
442    int payload_length = 0;
443    int bytes_left = -1;
444    if (data.length() >= 4 &&
445        base::HexStringToInt(data.substr(0, 4), &payload_length)) {
446      data = data.substr(4);
447      bytes_left = payload_length - result + 8;
448    } else {
449      bytes_left = -1;
450    }
451    OnResponseData(callback, data, response_buffer, bytes_left, 0);
452  } else {
453    callback.Run(net::OK, data);
454  }
455}
456
457void AdbClientSocket::OnResponseData(
458    const CommandCallback& callback,
459    const std::string& response,
460    scoped_refptr<net::IOBuffer> response_buffer,
461    int bytes_left,
462    int result) {
463  if (result < 0) {
464    callback.Run(result, "IO error");
465    return;
466  }
467
468  bytes_left -= result;
469  std::string new_response =
470      response + std::string(response_buffer->data(), result);
471  if (bytes_left == 0) {
472    callback.Run(net::OK, new_response);
473    return;
474  }
475
476  // Read tail
477  result = socket_->Read(response_buffer.get(),
478                         kBufferSize,
479                         base::Bind(&AdbClientSocket::OnResponseData,
480                                    base::Unretained(this),
481                                    callback,
482                                    new_response,
483                                    response_buffer,
484                                    bytes_left));
485  if (result > 0)
486    OnResponseData(callback, new_response, response_buffer, bytes_left, result);
487  else if (result != net::ERR_IO_PENDING)
488    callback.Run(net::OK, new_response);
489}
490