1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "device/test/usb_test_gadget.h"
6
7#include <string>
8#include <vector>
9
10#include "base/command_line.h"
11#include "base/compiler_specific.h"
12#include "base/files/file.h"
13#include "base/files/file_path.h"
14#include "base/logging.h"
15#include "base/macros.h"
16#include "base/memory/ref_counted.h"
17#include "base/memory/scoped_ptr.h"
18#include "base/path_service.h"
19#include "base/process/process.h"
20#include "base/run_loop.h"
21#include "base/strings/string_number_conversions.h"
22#include "base/strings/stringprintf.h"
23#include "base/strings/utf_string_conversions.h"
24#include "base/threading/platform_thread.h"
25#include "base/time/time.h"
26#include "device/usb/usb_device.h"
27#include "device/usb/usb_device_handle.h"
28#include "device/usb/usb_service.h"
29#include "net/proxy/proxy_service.h"
30#include "net/url_request/url_fetcher.h"
31#include "net/url_request/url_fetcher_delegate.h"
32#include "net/url_request/url_request_context.h"
33#include "net/url_request/url_request_context_builder.h"
34#include "net/url_request/url_request_context_getter.h"
35#include "url/gurl.h"
36
37using ::base::PlatformThread;
38using ::base::TimeDelta;
39
40namespace device {
41
42namespace {
43
// Command-line switch that must be present for gadget tests to run.
const char kCommandLineSwitch[] = "enable-gadget-tests";

// Polling parameters. Every wait loop below sleeps kRetryPeriod milliseconds
// between attempts, so 100 attempts is roughly five seconds of waiting.
const int kRetryPeriod = 50;  // Milliseconds.
const int kClaimRetries = 100;
const int kDisconnectRetries = 100;
const int kReconnectRetries = 100;
const int kUpdateRetries = 100;
50
// Describes one gadget personality:
//  - type: the UsbTestGadget::Type it implements.
//  - http_resource: the HTTP resource POSTed to the gadget to switch it into
//    this personality (see SetType()).
//  - product_id: the USB product id the gadget enumerates with afterwards.
// Field order matters: kConfigurations is aggregate-initialized positionally.
struct UsbTestGadgetConfiguration {
  UsbTestGadget::Type type;
  const char* http_resource;
  uint16 product_id;
};
56
57static const struct UsbTestGadgetConfiguration kConfigurations[] = {
58    {UsbTestGadget::DEFAULT, "/unconfigure", 0x58F0},
59    {UsbTestGadget::KEYBOARD, "/keyboard/configure", 0x58F1},
60    {UsbTestGadget::MOUSE, "/mouse/configure", 0x58F2},
61    {UsbTestGadget::HID_ECHO, "/hid_echo/configure", 0x58F3},
62    {UsbTestGadget::ECHO, "/echo/configure", 0x58F4},
63};
64
// Concrete UsbTestGadget. Controls a gadget over HTTP, using the gadget's
// serial number as the host name. Also keeps |device_| pointed at the matching
// UsbDevice as the gadget re-enumerates with different product ids.
// Instances are created only through UsbTestGadget::Claim(): the constructor
// is protected and UsbTestGadget is a friend.
class UsbTestGadgetImpl : public UsbTestGadget {
 public:
  // Unclaims the gadget if one was claimed (i.e. |device_address_| is set).
  virtual ~UsbTestGadgetImpl();

  virtual bool Unclaim() OVERRIDE;
  virtual bool Disconnect() OVERRIDE;
  virtual bool Reconnect() OVERRIDE;
  virtual bool SetType(Type type) OVERRIDE;
  virtual UsbDevice* GetDevice() const OVERRIDE;
  virtual std::string GetSerialNumber() const OVERRIDE;

 protected:
  UsbTestGadgetImpl();

 private:
  // Creates a fetcher for |url| that issues its request via
  // |request_context_| and reports completion to |delegate|.
  scoped_ptr<net::URLFetcher> CreateURLFetcher(
      const GURL& url,
      net::URLFetcher::RequestType request_type,
      net::URLFetcherDelegate* delegate);
  // POSTs |form_data| to |url|, blocks until done, returns the HTTP status.
  int SimplePOSTRequest(const GURL& url, const std::string& form_data);
  // Claims a gadget in its default configuration and brings its software up
  // to the local version.
  bool FindUnclaimed();
  // Reads the gadget's software version from GET /version.
  bool GetVersion(std::string* version);
  // Uploads the local package to the gadget and waits for it to come back.
  bool Update();
  // Re-locates the claimed gadget on the bus; requires |device_| to be NULL.
  bool FindClaimed();
  // Read usb_gadget.zip.md5 / usb_gadget.zip from the executable's directory.
  bool ReadLocalVersion(std::string* version);
  bool ReadLocalPackage(std::string* package);
  // Reads a whole file into |content|.
  bool ReadFile(const base::FilePath& file_path, std::string* content);

  // URLFetcherDelegate that lets the caller block, by running a RunLoop,
  // until the fetch completes.
  class Delegate : public net::URLFetcherDelegate {
   public:
    Delegate() {}
    virtual ~Delegate() {}

    // Runs |run_loop_| until OnURLFetchComplete() quits it.
    void WaitForCompletion() {
      run_loop_.Run();
    }

    virtual void OnURLFetchComplete(const net::URLFetcher* source) OVERRIDE {
      run_loop_.Quit();
    }

   private:
    base::RunLoop run_loop_;

    DISALLOW_COPY_AND_ASSIGN(Delegate);
  };

  // The UsbDevice for the claimed gadget. Cleared while waiting for the
  // gadget to re-enumerate, so it may be NULL.
  scoped_refptr<UsbDevice> device_;
  // Serial number of the claimed gadget; also used as its HTTP host name.
  // Empty until a gadget has been claimed.
  std::string device_address_;
  // Proxy-less request context used for every HTTP request to the gadget.
  scoped_ptr<net::URLRequestContext> request_context_;
  // Identifies this claimant to the gadget: "<hex process id>:<this>".
  std::string session_id_;
  // Obtained from UsbService::GetInstance(); never deleted by this class.
  UsbService* usb_service_;

  friend class UsbTestGadget;

  DISALLOW_COPY_AND_ASSIGN(UsbTestGadgetImpl);
};
122
123}  // namespace
124
125bool UsbTestGadget::IsTestEnabled() {
126  base::CommandLine* command_line = CommandLine::ForCurrentProcess();
127  return command_line->HasSwitch(kCommandLineSwitch);
128}
129
130scoped_ptr<UsbTestGadget> UsbTestGadget::Claim() {
131  scoped_ptr<UsbTestGadgetImpl> gadget(new UsbTestGadgetImpl);
132
133  int retries = kClaimRetries;
134  while (!gadget->FindUnclaimed()) {
135    if (--retries == 0) {
136      LOG(ERROR) << "Failed to find an unclaimed device.";
137      return scoped_ptr<UsbTestGadget>();
138    }
139    PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
140  }
141  VLOG(1) << "It took " << (kClaimRetries - retries)
142          << " retries to find an unclaimed device.";
143
144  return gadget.PassAs<UsbTestGadget>();
145}
146
147UsbTestGadgetImpl::UsbTestGadgetImpl() {
148  net::URLRequestContextBuilder context_builder;
149  context_builder.set_proxy_service(net::ProxyService::CreateDirect());
150  request_context_.reset(context_builder.Build());
151
152  base::ProcessId process_id = base::Process::Current().pid();
153  session_id_ = base::StringPrintf(
154      "%s:%p", base::HexEncode(&process_id, sizeof(process_id)).c_str(), this);
155
156  usb_service_ = UsbService::GetInstance(NULL);
157}
158
159UsbTestGadgetImpl::~UsbTestGadgetImpl() {
160  if (!device_address_.empty()) {
161    Unclaim();
162  }
163}
164
165UsbDevice* UsbTestGadgetImpl::GetDevice() const {
166  return device_.get();
167}
168
169std::string UsbTestGadgetImpl::GetSerialNumber() const {
170  return device_address_;
171}
172
173scoped_ptr<net::URLFetcher> UsbTestGadgetImpl::CreateURLFetcher(
174    const GURL& url, net::URLFetcher::RequestType request_type,
175    net::URLFetcherDelegate* delegate) {
176  scoped_ptr<net::URLFetcher> url_fetcher(
177      net::URLFetcher::Create(url, request_type, delegate));
178
179  url_fetcher->SetRequestContext(
180      new net::TrivialURLRequestContextGetter(
181          request_context_.get(),
182          base::MessageLoop::current()->message_loop_proxy()));
183
184  return url_fetcher.PassAs<net::URLFetcher>();
185}
186
187int UsbTestGadgetImpl::SimplePOSTRequest(const GURL& url,
188                                         const std::string& form_data) {
189  Delegate delegate;
190  scoped_ptr<net::URLFetcher> url_fetcher =
191    CreateURLFetcher(url, net::URLFetcher::POST, &delegate);
192
193  url_fetcher->SetUploadData("application/x-www-form-urlencoded", form_data);
194  url_fetcher->Start();
195  delegate.WaitForCompletion();
196
197  return url_fetcher->GetResponseCode();
198}
199
200bool UsbTestGadgetImpl::FindUnclaimed() {
201  std::vector<scoped_refptr<UsbDevice> > devices;
202  usb_service_->GetDevices(&devices);
203
204  for (std::vector<scoped_refptr<UsbDevice> >::const_iterator iter =
205         devices.begin(); iter != devices.end(); ++iter) {
206    const scoped_refptr<UsbDevice> &device = *iter;
207    if (device->vendor_id() == 0x18D1 && device->product_id() == 0x58F0) {
208      base::string16 serial_utf16;
209      if (!device->GetSerialNumber(&serial_utf16)) {
210        continue;
211      }
212
213      const std::string serial = base::UTF16ToUTF8(serial_utf16);
214      const GURL url("http://" + serial + "/claim");
215      const std::string form_data = base::StringPrintf(
216          "session_id=%s",
217          net::EscapeUrlEncodedData(session_id_, true).c_str());
218      const int response_code = SimplePOSTRequest(url, form_data);
219
220      if (response_code == 200) {
221        device_address_ = serial;
222        device_ = device;
223        break;
224      }
225
226      // The device is probably claimed by another process.
227      if (response_code != 403) {
228        LOG(WARNING) << "Unexpected HTTP " << response_code << " from /claim.";
229      }
230    }
231  }
232
233  std::string local_version;
234  std::string version;
235  if (!ReadLocalVersion(&local_version) ||
236      !GetVersion(&version)) {
237    return false;
238  }
239
240  if (version == local_version) {
241    return true;
242  }
243
244  return Update();
245}
246
247bool UsbTestGadgetImpl::GetVersion(std::string* version) {
248  Delegate delegate;
249  const GURL url("http://" + device_address_ + "/version");
250  scoped_ptr<net::URLFetcher> url_fetcher =
251      CreateURLFetcher(url, net::URLFetcher::GET, &delegate);
252
253  url_fetcher->Start();
254  delegate.WaitForCompletion();
255
256  const int response_code = url_fetcher->GetResponseCode();
257  if (response_code != 200) {
258    VLOG(2) << "Unexpected HTTP " << response_code << " from /version.";
259    return false;
260  }
261
262  STLClearObject(version);
263  if (!url_fetcher->GetResponseAsString(version)) {
264    VLOG(2) << "Failed to read body from /version.";
265    return false;
266  }
267  return true;
268}
269
270bool UsbTestGadgetImpl::Update() {
271  std::string version;
272  if (!ReadLocalVersion(&version)) {
273    return false;
274  }
275  LOG(INFO) << "Updating " << device_address_ << " to " << version << "...";
276
277  Delegate delegate;
278  const GURL url("http://" + device_address_ + "/update");
279  scoped_ptr<net::URLFetcher> url_fetcher =
280      CreateURLFetcher(url, net::URLFetcher::POST, &delegate);
281
282  const std::string mime_header =
283      base::StringPrintf(
284      "--foo\r\n"
285      "Content-Disposition: form-data; name=\"file\"; "
286          "filename=\"usb_gadget-%s.zip\"\r\n"
287      "Content-Type: application/octet-stream\r\n"
288      "\r\n", version.c_str());
289  const std::string mime_footer("\r\n--foo--\r\n");
290
291  std::string package;
292  if (!ReadLocalPackage(&package)) {
293    return false;
294  }
295
296  url_fetcher->SetUploadData("multipart/form-data; boundary=foo",
297                             mime_header + package + mime_footer);
298  url_fetcher->Start();
299  delegate.WaitForCompletion();
300
301  const int response_code = url_fetcher->GetResponseCode();
302  if (response_code != 200) {
303    LOG(ERROR) << "Unexpected HTTP " << response_code << " from /update.";
304    return false;
305  }
306
307  int retries = kUpdateRetries;
308  std::string new_version;
309  while (!GetVersion(&new_version) || new_version != version) {
310    if (--retries == 0) {
311      LOG(ERROR) << "Device not responding with new version.";
312      return false;
313    }
314    PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
315  }
316  VLOG(1) << "It took " << (kUpdateRetries - retries)
317          << " retries to see the new version.";
318
319  // Release the old reference to the device and try to open a new one.
320  device_ = NULL;
321  retries = kReconnectRetries;
322  while (!FindClaimed()) {
323    if (--retries == 0) {
324      LOG(ERROR) << "Failed to find updated device.";
325      return false;
326    }
327    PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
328  }
329  VLOG(1) << "It took " << (kReconnectRetries - retries)
330          << " retries to find the updated device.";
331
332  return true;
333}
334
335bool UsbTestGadgetImpl::FindClaimed() {
336  CHECK(!device_.get());
337
338  std::string expected_serial = GetSerialNumber();
339
340  std::vector<scoped_refptr<UsbDevice> > devices;
341  usb_service_->GetDevices(&devices);
342
343  for (std::vector<scoped_refptr<UsbDevice> >::iterator iter =
344         devices.begin(); iter != devices.end(); ++iter) {
345    scoped_refptr<UsbDevice> &device = *iter;
346
347    if (device->vendor_id() == 0x18D1) {
348      const uint16 product_id = device->product_id();
349      bool found = false;
350      for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
351        if (product_id == kConfigurations[i].product_id) {
352          found = true;
353          break;
354        }
355      }
356      if (!found) {
357        continue;
358      }
359
360      base::string16 serial_utf16;
361      if (!device->GetSerialNumber(&serial_utf16)) {
362        continue;
363      }
364
365      std::string serial = base::UTF16ToUTF8(serial_utf16);
366      if (serial != expected_serial) {
367        continue;
368      }
369
370      device_ = device;
371      return true;
372    }
373  }
374
375  return false;
376}
377
378bool UsbTestGadgetImpl::ReadLocalVersion(std::string* version) {
379  base::FilePath file_path;
380  CHECK(PathService::Get(base::DIR_EXE, &file_path));
381  file_path = file_path.AppendASCII("usb_gadget.zip.md5");
382
383  return ReadFile(file_path, version);
384}
385
386bool UsbTestGadgetImpl::ReadLocalPackage(std::string* package) {
387  base::FilePath file_path;
388  CHECK(PathService::Get(base::DIR_EXE, &file_path));
389  file_path = file_path.AppendASCII("usb_gadget.zip");
390
391  return ReadFile(file_path, package);
392}
393
394bool UsbTestGadgetImpl::ReadFile(const base::FilePath& file_path,
395                                 std::string* content) {
396  base::File file(file_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
397  if (!file.IsValid()) {
398    LOG(ERROR) << "Cannot open " << file_path.MaybeAsASCII() << ": "
399               << base::File::ErrorToString(file.error_details());
400    return false;
401  }
402
403  STLClearObject(content);
404  int rv;
405  do {
406    char buf[4096];
407    rv = file.ReadAtCurrentPos(buf, sizeof buf);
408    if (rv == -1) {
409      LOG(ERROR) << "Cannot read " << file_path.MaybeAsASCII() << ": "
410                 << base::File::ErrorToString(file.error_details());
411      return false;
412    }
413    content->append(buf, rv);
414  } while (rv > 0);
415
416  return true;
417}
418
419bool UsbTestGadgetImpl::Unclaim() {
420  VLOG(1) << "Releasing the device at " << device_address_ << ".";
421
422  const GURL url("http://" + device_address_ + "/unclaim");
423  const int response_code = SimplePOSTRequest(url, "");
424
425  if (response_code != 200) {
426    LOG(ERROR) << "Unexpected HTTP " << response_code << " from /unclaim.";
427    return false;
428  }
429  return true;
430}
431
432bool UsbTestGadgetImpl::SetType(Type type) {
433  const struct UsbTestGadgetConfiguration* config = NULL;
434  for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
435    if (kConfigurations[i].type == type) {
436      config = &kConfigurations[i];
437    }
438  }
439  CHECK(config);
440
441  const GURL url("http://" + device_address_ + config->http_resource);
442  const int response_code = SimplePOSTRequest(url, "");
443
444  if (response_code != 200) {
445    LOG(ERROR) << "Unexpected HTTP " << response_code
446               << " from " << config->http_resource << ".";
447    return false;
448  }
449
450  // Release the old reference to the device and try to open a new one.
451  int retries = kReconnectRetries;
452  while (true) {
453    device_ = NULL;
454    if (FindClaimed() && device_->product_id() == config->product_id) {
455      break;
456    }
457    if (--retries == 0) {
458      LOG(ERROR) << "Failed to find updated device.";
459      return false;
460    }
461    PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
462  }
463  VLOG(1) << "It took " << (kReconnectRetries - retries)
464          << " retries to find the updated device.";
465
466  return true;
467}
468
469bool UsbTestGadgetImpl::Disconnect() {
470  const GURL url("http://" + device_address_ + "/disconnect");
471  const int response_code = SimplePOSTRequest(url, "");
472
473  if (response_code != 200) {
474    LOG(ERROR) << "Unexpected HTTP " << response_code << " from /disconnect.";
475    return false;
476  }
477
478  // Release the old reference to the device and wait until it can't be found.
479  int retries = kDisconnectRetries;
480  while (true) {
481    device_ = NULL;
482    if (!FindClaimed()) {
483      break;
484    }
485    if (--retries == 0) {
486      LOG(ERROR) << "Device did not disconnect.";
487      return false;
488    }
489    PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
490  }
491  VLOG(1) << "It took " << (kDisconnectRetries - retries)
492          << " retries for the device to disconnect.";
493
494  return true;
495}
496
497bool UsbTestGadgetImpl::Reconnect() {
498  const GURL url("http://" + device_address_ + "/reconnect");
499  const int response_code = SimplePOSTRequest(url, "");
500
501  if (response_code != 200) {
502    LOG(ERROR) << "Unexpected HTTP " << response_code << " from /reconnect.";
503    return false;
504  }
505
506  int retries = kDisconnectRetries;
507  while (true) {
508    if (FindClaimed()) {
509      break;
510    }
511    if (--retries == 0) {
512      LOG(ERROR) << "Device did not reconnect.";
513      return false;
514    }
515    PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
516  }
517  VLOG(1) << "It took " << (kDisconnectRetries - retries)
518          << " retries for the device to reconnect.";
519
520  return true;
521}
522
523}  // namespace device
524