1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/test/spawned_test_server/spawner_communicator.h"
6
7#include "base/json/json_reader.h"
8#include "base/logging.h"
9#include "base/strings/stringprintf.h"
10#include "base/supports_user_data.h"
11#include "base/test/test_timeouts.h"
12#include "base/time/time.h"
13#include "base/values.h"
14#include "build/build_config.h"
15#include "net/base/net_util.h"
16#include "net/base/request_priority.h"
17#include "net/base/upload_bytes_element_reader.h"
18#include "net/base/upload_data_stream.h"
19#include "net/http/http_response_headers.h"
20#include "net/url_request/url_request_test_util.h"
21#include "url/gurl.h"
22
23namespace net {
24
25namespace {
26
27GURL GenerateSpawnerCommandURL(const std::string& command, uint16 port) {
28  // Always performs HTTP request for sending command to the spawner server.
29  return GURL(base::StringPrintf("%s:%u/%s", "http://127.0.0.1", port,
30                                 command.c_str()));
31}
32
// Size, in bytes, of the scratch buffer each request reads response data into.
// Declared const so it cannot be modified at runtime and is usable as a
// compile-time constant.
const int kBufferSize = 2048;
34
35// A class to hold all data needed to send a command to spawner server.
36class SpawnerRequestData : public base::SupportsUserData::Data {
37 public:
38  SpawnerRequestData(int id, int* result_code, std::string* data_received)
39      : request_id_(id),
40        buf_(new IOBuffer(kBufferSize)),
41        result_code_(result_code),
42        data_received_(data_received),
43        response_started_count_(0) {
44    DCHECK(result_code);
45    *result_code_ = OK;
46    DCHECK(data_received);
47    data_received_->clear();
48  }
49
50  virtual ~SpawnerRequestData() {}
51
52  bool DoesRequestIdMatch(int request_id) const {
53    return request_id_ == request_id;
54  }
55
56  IOBuffer* buf() const { return buf_.get(); }
57
58  bool IsResultOK() const { return *result_code_ == OK; }
59
60  void ClearReceivedData() { data_received_->clear(); }
61
62  void SetResultCode(int result_code) { *result_code_ = result_code; }
63
64  void IncreaseResponseStartedCount() { response_started_count_++; }
65
66  int response_started_count() const { return response_started_count_; }
67
68  // Write data read from URLRequest::Read() to |data_received_|. Returns true
69  // if |num_bytes| is great than 0. |num_bytes| is 0 for EOF, < 0 on errors.
70  bool ConsumeBytesRead(int num_bytes) {
71    // Error while reading, or EOF.
72    if (num_bytes <= 0)
73      return false;
74
75    data_received_->append(buf_->data(), num_bytes);
76    return true;
77  }
78
79 private:
80  // Unique ID for the current request.
81  int request_id_;
82
83  // Buffer that URLRequest writes into.
84  scoped_refptr<IOBuffer> buf_;
85
86  // Holds the error condition that was hit on the current request, or OK.
87  int* result_code_;
88
89  // Data received from server;
90  std::string* data_received_;
91
92  // Used to track how many times the OnResponseStarted get called after
93  // sending a command to spawner server.
94  int response_started_count_;
95
96  DISALLOW_COPY_AND_ASSIGN(SpawnerRequestData);
97};
98
99}  // namespace
100
// Creates a communicator for the test server spawner listening on |port|.
// No thread is started here; StartIOThread() starts |io_thread_| when the
// first command is sent.
SpawnerCommunicator::SpawnerCommunicator(uint16 port)
    : io_thread_("spawner_communicator"),
      // NOTE(review): presumably (manual_reset, initially_signaled) of
      // base::WaitableEvent -- confirm against the header.
      event_(false, false),
      port_(port),
      next_id_(0),
      is_running_(false),
      weak_factory_(this) {}
108
// The IO thread must already be stopped (Shutdown(), reached through
// StopServer()) before the communicator is destroyed.
SpawnerCommunicator::~SpawnerCommunicator() {
  DCHECK(!is_running_);
}
112
// Blocks the calling thread until |event_| is signaled (done at the end of
// OnSpawnerCommandCompleted() on the IO thread), then resets it for the next
// command. Must not run on the IO thread, since that thread does the signaling.
void SpawnerCommunicator::WaitForResponse() {
  DCHECK_NE(base::MessageLoop::current(), io_thread_.message_loop());
  event_.Wait();
  event_.Reset();
}
118
119void SpawnerCommunicator::StartIOThread() {
120  DCHECK_NE(base::MessageLoop::current(), io_thread_.message_loop());
121  if (is_running_)
122    return;
123
124  allowed_port_.reset(new ScopedPortException(port_));
125  base::Thread::Options options;
126  options.message_loop_type = base::MessageLoop::TYPE_IO;
127  is_running_ = io_thread_.StartWithOptions(options);
128  DCHECK(is_running_);
129}
130
// Stops the IO thread and releases the ScopedPortException created in
// StartIOThread(). Must be called off the IO thread, while the thread is
// running and no command is in flight: the request and its context are
// released on the IO thread by OnSpawnerCommandCompleted().
void SpawnerCommunicator::Shutdown() {
  DCHECK_NE(base::MessageLoop::current(), io_thread_.message_loop());
  DCHECK(is_running_);
  // The request and its context should be created and destroyed only on the
  // IO thread.
  DCHECK(!cur_request_.get());
  DCHECK(!context_.get());
  is_running_ = false;
  io_thread_.Stop();
  allowed_port_.reset();
}
142
143void SpawnerCommunicator::SendCommandAndWaitForResult(
144    const std::string& command,
145    const std::string& post_data,
146    int* result_code,
147    std::string* data_received) {
148  if (!result_code || !data_received)
149    return;
150  // Start the communicator thread to talk to test server spawner.
151  StartIOThread();
152  DCHECK(io_thread_.message_loop());
153
154  // Since the method will be blocked until SpawnerCommunicator gets result
155  // from the spawner server or timed-out. It's safe to use base::Unretained
156  // when using base::Bind.
157  io_thread_.message_loop()->PostTask(FROM_HERE, base::Bind(
158      &SpawnerCommunicator::SendCommandAndWaitForResultOnIOThread,
159      base::Unretained(this), command, post_data, result_code, data_received));
160  WaitForResponse();
161}
162
163void SpawnerCommunicator::SendCommandAndWaitForResultOnIOThread(
164    const std::string& command,
165    const std::string& post_data,
166    int* result_code,
167    std::string* data_received) {
168  base::MessageLoop* loop = io_thread_.message_loop();
169  DCHECK(loop);
170  DCHECK_EQ(base::MessageLoop::current(), loop);
171
172  // Prepare the URLRequest for sending the command.
173  DCHECK(!cur_request_.get());
174  context_.reset(new TestURLRequestContext);
175  cur_request_ = context_->CreateRequest(
176      GenerateSpawnerCommandURL(command, port_), DEFAULT_PRIORITY, this, NULL);
177  DCHECK(cur_request_);
178  int current_request_id = ++next_id_;
179  SpawnerRequestData* data = new SpawnerRequestData(current_request_id,
180                                                    result_code,
181                                                    data_received);
182  DCHECK(data);
183  cur_request_->SetUserData(this, data);
184
185  if (post_data.empty()) {
186    cur_request_->set_method("GET");
187  } else {
188    cur_request_->set_method("POST");
189    scoped_ptr<UploadElementReader> reader(
190        UploadOwnedBytesElementReader::CreateWithString(post_data));
191    cur_request_->set_upload(make_scoped_ptr(
192        UploadDataStream::CreateWithReader(reader.Pass(), 0)));
193    net::HttpRequestHeaders headers;
194    headers.SetHeader(net::HttpRequestHeaders::kContentType,
195                      "application/json");
196    cur_request_->SetExtraRequestHeaders(headers);
197  }
198
199  // Post a task to timeout this request if it takes too long.
200  base::MessageLoop::current()->PostDelayedTask(
201      FROM_HERE,
202      base::Bind(&SpawnerCommunicator::OnTimeout,
203                 weak_factory_.GetWeakPtr(),
204                 current_request_id),
205      TestTimeouts::action_max_timeout());
206
207  // Start the request.
208  cur_request_->Start();
209}
210
211void SpawnerCommunicator::OnTimeout(int id) {
212  // Timeout tasks may outlive the URLRequest they reference. Make sure it
213  // is still applicable.
214  if (!cur_request_.get())
215    return;
216  SpawnerRequestData* data =
217      static_cast<SpawnerRequestData*>(cur_request_->GetUserData(this));
218  DCHECK(data);
219
220  if (!data->DoesRequestIdMatch(id))
221    return;
222  // Set the result code and cancel the timed-out task.
223  data->SetResultCode(ERR_TIMED_OUT);
224  cur_request_->Cancel();
225  OnSpawnerCommandCompleted(cur_request_.get());
226}
227
228void SpawnerCommunicator::OnSpawnerCommandCompleted(URLRequest* request) {
229  if (!cur_request_.get())
230    return;
231  DCHECK_EQ(request, cur_request_.get());
232  SpawnerRequestData* data =
233      static_cast<SpawnerRequestData*>(cur_request_->GetUserData(this));
234  DCHECK(data);
235
236  // If request is faild,return the error code.
237  if (!cur_request_->status().is_success())
238    data->SetResultCode(cur_request_->status().error());
239
240  if (!data->IsResultOK()) {
241    LOG(ERROR) << "request failed, status: "
242               << static_cast<int>(request->status().status())
243               << ", error: " << request->status().error();
244    // Clear the buffer of received data if any net error happened.
245    data->ClearReceivedData();
246  } else {
247    DCHECK_EQ(1, data->response_started_count());
248  }
249
250  // Clear current request to indicate the completion of sending a command
251  // to spawner server and getting the result.
252  cur_request_.reset();
253  context_.reset();
254  // Invalidate the weak pointers on the IO thread.
255  weak_factory_.InvalidateWeakPtrs();
256
257  // Wakeup the caller in user thread.
258  event_.Signal();
259}
260
261void SpawnerCommunicator::ReadResult(URLRequest* request) {
262  DCHECK_EQ(request, cur_request_.get());
263  SpawnerRequestData* data =
264      static_cast<SpawnerRequestData*>(cur_request_->GetUserData(this));
265  DCHECK(data);
266
267  IOBuffer* buf = data->buf();
268  // Read as many bytes as are available synchronously.
269  while (true) {
270    int num_bytes;
271    if (!request->Read(buf, kBufferSize, &num_bytes)) {
272      // Check whether the read failed synchronously.
273      if (!request->status().is_io_pending())
274        OnSpawnerCommandCompleted(request);
275      return;
276    }
277    if (!data->ConsumeBytesRead(num_bytes)) {
278      OnSpawnerCommandCompleted(request);
279      return;
280    }
281  }
282}
283
284void SpawnerCommunicator::OnResponseStarted(URLRequest* request) {
285  DCHECK_EQ(request, cur_request_.get());
286  SpawnerRequestData* data =
287      static_cast<SpawnerRequestData*>(cur_request_->GetUserData(this));
288  DCHECK(data);
289
290  data->IncreaseResponseStartedCount();
291
292  if (!request->status().is_success()) {
293    OnSpawnerCommandCompleted(request);
294    return;
295  }
296
297  // Require HTTP responses to have a success status code.
298  if (request->GetResponseCode() != 200) {
299    LOG(ERROR) << "Spawner server returned bad status: "
300               << request->response_headers()->GetStatusLine();
301    data->SetResultCode(ERR_FAILED);
302    request->Cancel();
303    OnSpawnerCommandCompleted(request);
304    return;
305  }
306
307  ReadResult(request);
308}
309
310void SpawnerCommunicator::OnReadCompleted(URLRequest* request, int num_bytes) {
311  if (!cur_request_.get())
312    return;
313  DCHECK_EQ(request, cur_request_.get());
314  SpawnerRequestData* data =
315      static_cast<SpawnerRequestData*>(cur_request_->GetUserData(this));
316  DCHECK(data);
317
318  if (data->ConsumeBytesRead(num_bytes)) {
319    // Keep reading.
320    ReadResult(request);
321  } else {
322    OnSpawnerCommandCompleted(request);
323  }
324}
325
326bool SpawnerCommunicator::StartServer(const std::string& arguments,
327                                      uint16* port) {
328  *port = 0;
329  // Send the start command to spawner server to start the Python test server
330  // on remote machine.
331  std::string server_return_data;
332  int result_code;
333  SendCommandAndWaitForResult("start", arguments, &result_code,
334                              &server_return_data);
335  if (OK != result_code || server_return_data.empty())
336    return false;
337
338  // Check whether the data returned from spawner server is JSON-formatted.
339  scoped_ptr<base::Value> value(base::JSONReader::Read(server_return_data));
340  if (!value.get() || !value->IsType(base::Value::TYPE_DICTIONARY)) {
341    LOG(ERROR) << "Invalid server data: " << server_return_data.c_str();
342    return false;
343  }
344
345  // Check whether spawner server returns valid data.
346  base::DictionaryValue* server_data =
347      static_cast<base::DictionaryValue*>(value.get());
348  std::string message;
349  if (!server_data->GetString("message", &message) || message != "started") {
350    LOG(ERROR) << "Invalid message in server data: ";
351    return false;
352  }
353  int int_port;
354  if (!server_data->GetInteger("port", &int_port) || int_port <= 0 ||
355      int_port > kuint16max) {
356    LOG(ERROR) << "Invalid port value: " << int_port;
357    return false;
358  }
359  *port = static_cast<uint16>(int_port);
360  return true;
361}
362
363bool SpawnerCommunicator::StopServer() {
364  // It's OK to stop the SpawnerCommunicator without starting it. Some tests
365  // have test server on their test fixture but do not actually use it.
366  if (!is_running_)
367    return true;
368
369  // When the test is done, ask the test server spawner to kill the test server
370  // on the remote machine.
371  std::string server_return_data;
372  int result_code;
373  SendCommandAndWaitForResult("kill", "", &result_code, &server_return_data);
374  Shutdown();
375  if (OK != result_code || server_return_data != "killed")
376    return false;
377  return true;
378}
379
380}  // namespace net
381