1// Copyright (c) 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/policy/cloud/test_request_interceptor.h"
6
7#include <limits>
8#include <queue>
9
10#include "base/bind.h"
11#include "base/bind_helpers.h"
12#include "base/memory/scoped_ptr.h"
13#include "base/run_loop.h"
14#include "base/sequenced_task_runner.h"
15#include "content/public/browser/browser_thread.h"
16#include "net/base/net_errors.h"
17#include "net/base/upload_bytes_element_reader.h"
18#include "net/base/upload_data_stream.h"
19#include "net/base/upload_element_reader.h"
20#include "net/test/url_request/url_request_mock_http_job.h"
21#include "net/url_request/url_request_error_job.h"
22#include "net/url_request/url_request_filter.h"
23#include "net/url_request/url_request_interceptor.h"
24#include "net/url_request/url_request_test_job.h"
25#include "url/gurl.h"
26
27namespace em = enterprise_management;
28
29namespace policy {
30
31namespace {
32
33// Helper callback for jobs that should fail with a network |error|.
34net::URLRequestJob* ErrorJobCallback(int error,
35                                     net::URLRequest* request,
36                                     net::NetworkDelegate* network_delegate) {
37  return new net::URLRequestErrorJob(request, network_delegate, error);
38}
39
40// Helper callback for jobs that should fail with a 400 HTTP error.
41net::URLRequestJob* BadRequestJobCallback(
42    net::URLRequest* request,
43    net::NetworkDelegate* network_delegate) {
44  static const char kBadHeaders[] =
45      "HTTP/1.1 400 Bad request\0"
46      "Content-type: application/protobuf\0"
47      "\0";
48  std::string headers(kBadHeaders, arraysize(kBadHeaders));
49  return new net::URLRequestTestJob(
50      request, network_delegate, headers, std::string(), true);
51}
52
53net::URLRequestJob* FileJobCallback(const base::FilePath& file_path,
54                                    net::URLRequest* request,
55                                    net::NetworkDelegate* network_delegate) {
56  return new net::URLRequestMockHTTPJob(
57      request,
58      network_delegate,
59      file_path,
60      content::BrowserThread::GetBlockingPool()
61          ->GetTaskRunnerWithShutdownBehavior(
62              base::SequencedWorkerPool::SKIP_ON_SHUTDOWN));
63}
64
65// Parses the upload data in |request| into |request_msg|, and validates the
66// request. The query string in the URL must contain the |expected_type| for
67// the "request" parameter. Returns true if all checks succeeded, and the
68// request data has been parsed into |request_msg|.
69bool ValidRequest(net::URLRequest* request,
70                  const std::string& expected_type,
71                  em::DeviceManagementRequest* request_msg) {
72  if (request->method() != "POST")
73    return false;
74  std::string spec = request->url().spec();
75  if (spec.find("request=" + expected_type) == std::string::npos)
76    return false;
77
78  // This assumes that the payload data was set from a single string. In that
79  // case the UploadDataStream has a single UploadBytesElementReader with the
80  // data in memory.
81  const net::UploadDataStream* stream = request->get_upload();
82  if (!stream)
83    return false;
84  const ScopedVector<net::UploadElementReader>& readers =
85      stream->element_readers();
86  if (readers.size() != 1u)
87    return false;
88  const net::UploadBytesElementReader* reader = readers[0]->AsBytesReader();
89  if (!reader)
90    return false;
91  std::string data(reader->bytes(), reader->length());
92  if (!request_msg->ParseFromString(data))
93    return false;
94
95  return true;
96}
97
98// Helper callback for register jobs that should suceed. Validates the request
99// parameters and returns an appropriate response job. If |expect_reregister|
100// is true then the reregister flag must be set in the DeviceRegisterRequest
101// protobuf.
102net::URLRequestJob* RegisterJobCallback(
103    em::DeviceRegisterRequest::Type expected_type,
104    bool expect_reregister,
105    net::URLRequest* request,
106    net::NetworkDelegate* network_delegate) {
107  em::DeviceManagementRequest request_msg;
108  if (!ValidRequest(request, "register", &request_msg))
109    return BadRequestJobCallback(request, network_delegate);
110
111  if (!request_msg.has_register_request() ||
112      request_msg.has_unregister_request() ||
113      request_msg.has_policy_request() ||
114      request_msg.has_device_status_report_request() ||
115      request_msg.has_session_status_report_request() ||
116      request_msg.has_auto_enrollment_request()) {
117    return BadRequestJobCallback(request, network_delegate);
118  }
119
120  const em::DeviceRegisterRequest& register_request =
121      request_msg.register_request();
122  if (expect_reregister &&
123      (!register_request.has_reregister() || !register_request.reregister())) {
124    return BadRequestJobCallback(request, network_delegate);
125  } else if (!expect_reregister &&
126             register_request.has_reregister() &&
127             register_request.reregister()) {
128    return BadRequestJobCallback(request, network_delegate);
129  }
130
131  if (!register_request.has_type() || register_request.type() != expected_type)
132    return BadRequestJobCallback(request, network_delegate);
133
134  em::DeviceManagementResponse response;
135  em::DeviceRegisterResponse* register_response =
136      response.mutable_register_response();
137  register_response->set_device_management_token("s3cr3t70k3n");
138  std::string data;
139  response.SerializeToString(&data);
140
141  static const char kGoodHeaders[] =
142      "HTTP/1.1 200 OK\0"
143      "Content-type: application/protobuf\0"
144      "\0";
145  std::string headers(kGoodHeaders, arraysize(kGoodHeaders));
146  return new net::URLRequestTestJob(
147      request, network_delegate, headers, data, true);
148}
149
150void RegisterHttpInterceptor(
151    const std::string& hostname,
152    scoped_ptr<net::URLRequestInterceptor> interceptor) {
153  net::URLRequestFilter::GetInstance()->AddHostnameInterceptor(
154      "http", hostname, interceptor.Pass());
155}
156
157}  // namespace
158
159class TestRequestInterceptor::Delegate : public net::URLRequestInterceptor {
160 public:
161  Delegate(const std::string& hostname,
162           scoped_refptr<base::SequencedTaskRunner> io_task_runner);
163  virtual ~Delegate();
164
165  // net::URLRequestInterceptor implementation:
166  virtual net::URLRequestJob* MaybeInterceptRequest(
167      net::URLRequest* request,
168      net::NetworkDelegate* network_delegate) const OVERRIDE;
169
170  void GetPendingSize(size_t* pending_size) const;
171  void PushJobCallback(const JobCallback& callback);
172
173 private:
174  const std::string hostname_;
175  scoped_refptr<base::SequencedTaskRunner> io_task_runner_;
176
177  // The queue of pending callbacks. 'mutable' because MaybeCreateJob() is a
178  // const method; it can't reenter though, because it runs exclusively on
179  // the IO thread.
180  mutable std::queue<JobCallback> pending_job_callbacks_;
181};
182
183TestRequestInterceptor::Delegate::Delegate(
184    const std::string& hostname,
185    scoped_refptr<base::SequencedTaskRunner> io_task_runner)
186    : hostname_(hostname), io_task_runner_(io_task_runner) {}
187
188TestRequestInterceptor::Delegate::~Delegate() {}
189
190net::URLRequestJob* TestRequestInterceptor::Delegate::MaybeInterceptRequest(
191    net::URLRequest* request,
192    net::NetworkDelegate* network_delegate) const {
193  CHECK(io_task_runner_->RunsTasksOnCurrentThread());
194
195  if (request->url().host() != hostname_) {
196    // Reject requests to other servers.
197    return ErrorJobCallback(
198        net::ERR_CONNECTION_REFUSED, request, network_delegate);
199  }
200
201  if (pending_job_callbacks_.empty()) {
202    // Reject dmserver requests by default.
203    return BadRequestJobCallback(request, network_delegate);
204  }
205
206  JobCallback callback = pending_job_callbacks_.front();
207  pending_job_callbacks_.pop();
208  return callback.Run(request, network_delegate);
209}
210
211void TestRequestInterceptor::Delegate::GetPendingSize(
212    size_t* pending_size) const {
213  CHECK(io_task_runner_->RunsTasksOnCurrentThread());
214  *pending_size = pending_job_callbacks_.size();
215}
216
217void TestRequestInterceptor::Delegate::PushJobCallback(
218    const JobCallback& callback) {
219  CHECK(io_task_runner_->RunsTasksOnCurrentThread());
220  pending_job_callbacks_.push(callback);
221}
222
223TestRequestInterceptor::TestRequestInterceptor(const std::string& hostname,
224    scoped_refptr<base::SequencedTaskRunner> io_task_runner)
225    : hostname_(hostname),
226      io_task_runner_(io_task_runner) {
227  delegate_ = new Delegate(hostname_, io_task_runner_);
228  scoped_ptr<net::URLRequestInterceptor> interceptor(delegate_);
229  PostToIOAndWait(
230      base::Bind(&RegisterHttpInterceptor, hostname_,
231                 base::Passed(&interceptor)));
232}
233
234TestRequestInterceptor::~TestRequestInterceptor() {
235  // RemoveHostnameHandler() destroys the |delegate_|, which is owned by
236  // the URLRequestFilter.
237  delegate_ = NULL;
238  PostToIOAndWait(
239      base::Bind(&net::URLRequestFilter::RemoveHostnameHandler,
240                 base::Unretained(net::URLRequestFilter::GetInstance()),
241                 "http", hostname_));
242}
243
244size_t TestRequestInterceptor::GetPendingSize() {
245  size_t pending_size = std::numeric_limits<size_t>::max();
246  PostToIOAndWait(base::Bind(&Delegate::GetPendingSize,
247                             base::Unretained(delegate_),
248                             &pending_size));
249  return pending_size;
250}
251
252void TestRequestInterceptor::PushJobCallback(const JobCallback& callback) {
253  PostToIOAndWait(base::Bind(&Delegate::PushJobCallback,
254                             base::Unretained(delegate_),
255                             callback));
256}
257
258// static
259TestRequestInterceptor::JobCallback TestRequestInterceptor::ErrorJob(
260    int error) {
261  return base::Bind(&ErrorJobCallback, error);
262}
263
264// static
265TestRequestInterceptor::JobCallback TestRequestInterceptor::BadRequestJob() {
266  return base::Bind(&BadRequestJobCallback);
267}
268
269// static
270TestRequestInterceptor::JobCallback TestRequestInterceptor::RegisterJob(
271    em::DeviceRegisterRequest::Type expected_type,
272    bool expect_reregister) {
273  return base::Bind(&RegisterJobCallback, expected_type, expect_reregister);
274}
275
276// static
277TestRequestInterceptor::JobCallback TestRequestInterceptor::FileJob(
278    const base::FilePath& file_path) {
279  return base::Bind(&FileJobCallback, file_path);
280}
281
282void TestRequestInterceptor::PostToIOAndWait(const base::Closure& task) {
283  io_task_runner_->PostTask(FROM_HERE, task);
284  base::RunLoop run_loop;
285  io_task_runner_->PostTask(
286      FROM_HERE,
287      base::Bind(
288          base::IgnoreResult(&base::MessageLoopProxy::PostTask),
289          base::MessageLoopProxy::current(),
290          FROM_HERE,
291          run_loop.QuitClosure()));
292  run_loop.Run();
293}
294
295}  // namespace policy
296