it2me_native_messaging_host_unittest.cc revision f2477e01787aa58f445919b809d89e252beef54f
1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/host/it2me/it2me_native_messaging_host.h"
6
7#include "base/basictypes.h"
8#include "base/compiler_specific.h"
9#include "base/json/json_reader.h"
10#include "base/json/json_writer.h"
11#include "base/message_loop/message_loop.h"
12#include "base/run_loop.h"
13#include "base/stl_util.h"
14#include "base/strings/stringize_macros.h"
15#include "base/values.h"
16#include "net/base/file_stream.h"
17#include "net/base/net_util.h"
18#include "remoting/base/auto_thread_task_runner.h"
19#include "remoting/host/chromoting_host_context.h"
20#include "remoting/host/native_messaging/native_messaging_channel.h"
21#include "remoting/host/setup/test_util.h"
22#include "testing/gtest/include/gtest/gtest.h"
23
24namespace {
25
// Canned values that MockIt2MeHost hands to its observer while "connecting".
// VerifyConnectResponses() expects to find them in the host's messages.
const char kTestAccessCode[] = "888888";
const int kTestAccessCodeLifetimeInSeconds = 666;
const char kTestClientUsername[] = "some_user@gmail.com";
29
30void VerifyId(scoped_ptr<base::DictionaryValue> response, int expected_value) {
31  ASSERT_TRUE(response);
32
33  int value;
34  EXPECT_TRUE(response->GetInteger("id", &value));
35  EXPECT_EQ(expected_value, value);
36}
37
38void VerifyStringProperty(scoped_ptr<base::DictionaryValue> response,
39                          const std::string& name,
40                          const std::string& expected_value) {
41  ASSERT_TRUE(response);
42
43  std::string value;
44  EXPECT_TRUE(response->GetString(name, &value));
45  EXPECT_EQ(expected_value, value);
46}
47
48// Verity the values of the "type" and "id" properties
49void VerifyCommonProperties(scoped_ptr<base::DictionaryValue> response,
50                            const std::string& type,
51                            int id) {
52  ASSERT_TRUE(response);
53
54  std::string string_value;
55  EXPECT_TRUE(response->GetString("type", &string_value));
56  EXPECT_EQ(type, string_value);
57
58  int int_value;
59  EXPECT_TRUE(response->GetInteger("id", &int_value));
60  EXPECT_EQ(id, int_value);
61}
62
63}  // namespace
64
65namespace remoting {
66
67class MockIt2MeHost : public It2MeHost {
68 public:
69  MockIt2MeHost(ChromotingHostContext* context,
70                scoped_refptr<base::SingleThreadTaskRunner> task_runner,
71                base::WeakPtr<It2MeHost::Observer> observer,
72                const XmppSignalStrategy::XmppServerConfig& xmpp_server_config,
73                const std::string& directory_bot_jid)
74      : It2MeHost(context,
75                  task_runner,
76                  observer,
77                  xmpp_server_config,
78                  directory_bot_jid) {}
79
80  // It2MeHost overrides
81  virtual void Connect() OVERRIDE;
82  virtual void Disconnect() OVERRIDE;
83  virtual void RequestNatPolicy() OVERRIDE;
84
85 private:
86  virtual ~MockIt2MeHost() {}
87
88  void RunSetState(It2MeHostState state);
89
90  DISALLOW_COPY_AND_ASSIGN(MockIt2MeHost);
91};
92
93void MockIt2MeHost::Connect() {
94  if (!host_context()->ui_task_runner()->BelongsToCurrentThread()) {
95    DCHECK(task_runner()->BelongsToCurrentThread());
96    host_context()->ui_task_runner()->PostTask(
97        FROM_HERE, base::Bind(&MockIt2MeHost::Connect, this));
98    return;
99  }
100
101  RunSetState(kStarting);
102  RunSetState(kRequestedAccessCode);
103
104  std::string access_code(kTestAccessCode);
105  base::TimeDelta lifetime =
106      base::TimeDelta::FromSeconds(kTestAccessCodeLifetimeInSeconds);
107  task_runner()->PostTask(FROM_HERE,
108                          base::Bind(&It2MeHost::Observer::OnStoreAccessCode,
109                                     observer(),
110                                     access_code,
111                                     lifetime));
112
113  RunSetState(kReceivedAccessCode);
114
115  std::string client_username(kTestClientUsername);
116  task_runner()->PostTask(
117      FROM_HERE,
118      base::Bind(&It2MeHost::Observer::OnClientAuthenticated,
119                 observer(),
120                 client_username));
121
122  RunSetState(kConnected);
123}
124
125void MockIt2MeHost::Disconnect() {
126  if (!host_context()->network_task_runner()->BelongsToCurrentThread()) {
127    DCHECK(task_runner()->BelongsToCurrentThread());
128    host_context()->network_task_runner()->PostTask(
129        FROM_HERE, base::Bind(&MockIt2MeHost::Disconnect, this));
130    return;
131  }
132
133  RunSetState(kDisconnecting);
134  RunSetState(kDisconnected);
135}
136
137void MockIt2MeHost::RequestNatPolicy() {}
138
139void MockIt2MeHost::RunSetState(It2MeHostState state) {
140  if (!host_context()->network_task_runner()->BelongsToCurrentThread()) {
141    host_context()->network_task_runner()->PostTask(
142        FROM_HERE, base::Bind(&It2MeHost::SetStateForTesting, this, state));
143  } else {
144    SetStateForTesting(state);
145  }
146}
147
148class MockIt2MeHostFactory : public It2MeHostFactory {
149 public:
150  MockIt2MeHostFactory() {}
151  virtual scoped_refptr<It2MeHost> CreateIt2MeHost(
152      ChromotingHostContext* context,
153      scoped_refptr<base::SingleThreadTaskRunner> task_runner,
154      base::WeakPtr<It2MeHost::Observer> observer,
155      const XmppSignalStrategy::XmppServerConfig& xmpp_server_config,
156      const std::string& directory_bot_jid) OVERRIDE {
157    return new MockIt2MeHost(
158        context, task_runner, observer, xmpp_server_config, directory_bot_jid);
159  }
160
161 private:
162  DISALLOW_COPY_AND_ASSIGN(MockIt2MeHostFactory);
163};  // MockIt2MeHostFactory
164
165class It2MeNativeMessagingHostTest : public testing::Test {
166 public:
167  It2MeNativeMessagingHostTest() {}
168  virtual ~It2MeNativeMessagingHostTest() {}
169
170  virtual void SetUp() OVERRIDE;
171  virtual void TearDown() OVERRIDE;
172
173 protected:
174  scoped_ptr<base::DictionaryValue> ReadMessageFromOutputPipe();
175  void WriteMessageToInputPipe(const base::Value& message);
176
177  void VerifyHelloResponse(int request_id);
178  void VerifyErrorResponse();
179  void VerifyConnectResponses(int request_id);
180  void VerifyDisconnectResponses(int request_id);
181
182  // The Host process should shut down when it receives a malformed request.
183  // This is tested by sending a known-good request, followed by |message|,
184  // followed by the known-good request again. The response file should only
185  // contain a single response from the first good request.
186  void TestBadRequest(const base::Value& message, bool expect_error_response);
187  void TestConnect();
188
189 private:
190  void StartHost(base::PlatformFile input, base::PlatformFile output);
191  void StopHost();
192  void ExitTest();
193
194  // Each test creates two unidirectional pipes: "input" and "output".
195  // NativeMessagingHost reads from input_read_handle and writes to
196  // output_write_handle. The unittest supplies data to input_write_handle, and
197  // verifies output from output_read_handle.
198  //
199  // unittest -> [input] -> NativeMessagingHost -> [output] -> unittest
200  base::PlatformFile input_write_handle_;
201  base::PlatformFile output_read_handle_;
202
203  // Message loop of the test thread.
204  scoped_ptr<base::MessageLoop> test_message_loop_;
205  scoped_ptr<base::RunLoop> test_run_loop_;
206
207  scoped_ptr<base::Thread> host_thread_;
208  scoped_ptr<base::RunLoop> host_run_loop_;
209
210  // Task runner of the host thread.
211  scoped_refptr<AutoThreadTaskRunner> host_task_runner_;
212  scoped_ptr<remoting::NativeMessagingChannel> channel_;
213
214  DISALLOW_COPY_AND_ASSIGN(It2MeNativeMessagingHostTest);
215};
216
217void It2MeNativeMessagingHostTest::SetUp() {
218  base::PlatformFile input_read_handle;
219  base::PlatformFile output_write_handle;
220
221  ASSERT_TRUE(MakePipe(&input_read_handle, &input_write_handle_));
222  ASSERT_TRUE(MakePipe(&output_read_handle_, &output_write_handle));
223
224  test_message_loop_.reset(new base::MessageLoop());
225  test_run_loop_.reset(new base::RunLoop());
226
227  // Run the host on a dedicated thread.
228  host_thread_.reset(new base::Thread("host_thread"));
229  host_thread_->Start();
230
231  host_task_runner_ = new AutoThreadTaskRunner(
232      host_thread_->message_loop_proxy(),
233      base::Bind(&It2MeNativeMessagingHostTest::ExitTest,
234                 base::Unretained(this)));
235
236  host_task_runner_->PostTask(
237      FROM_HERE,
238      base::Bind(&It2MeNativeMessagingHostTest::StartHost,
239                 base::Unretained(this),
240                 input_read_handle,
241                 output_write_handle));
242
243  // Wait until the host finishes starting.
244  test_run_loop_->Run();
245}
246
247void It2MeNativeMessagingHostTest::TearDown() {
248  // Closing the write-end of the input will send an EOF to the native
249  // messaging reader. This will trigger a host shutdown.
250  base::ClosePlatformFile(input_write_handle_);
251
252  // Start a new RunLoop and Wait until the host finishes shutting down.
253  test_run_loop_.reset(new base::RunLoop());
254  test_run_loop_->Run();
255
256  // Verify there are no more message in the output pipe.
257  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
258  EXPECT_FALSE(response);
259
260  // The It2MeNativeMessagingHost dtor closes the handles that are passed to it.
261  // So the only handle left to close is |output_read_handle_|.
262  base::ClosePlatformFile(output_read_handle_);
263}
264
265scoped_ptr<base::DictionaryValue>
266It2MeNativeMessagingHostTest::ReadMessageFromOutputPipe() {
267  uint32 length;
268  int read_result = base::ReadPlatformFileAtCurrentPos(
269      output_read_handle_, reinterpret_cast<char*>(&length), sizeof(length));
270  if (read_result != sizeof(length)) {
271    LOG(ERROR) << "Invalid message header.";
272    return scoped_ptr<base::DictionaryValue>();
273  }
274
275  std::string message_json(length, '\0');
276  read_result = base::ReadPlatformFileAtCurrentPos(
277      output_read_handle_, string_as_array(&message_json), length);
278  if (read_result != static_cast<int>(length)) {
279    LOG(ERROR) << "Message size (" << read_result
280               << ") doesn't match the header (" << length << ").";
281    return scoped_ptr<base::DictionaryValue>();
282  }
283
284  scoped_ptr<base::Value> message(base::JSONReader::Read(message_json));
285  if (!message || !message->IsType(base::Value::TYPE_DICTIONARY)) {
286    LOG(ERROR) << "Malformed message:" << message_json;
287    return scoped_ptr<base::DictionaryValue>();
288  }
289
290  return scoped_ptr<base::DictionaryValue>(
291      static_cast<base::DictionaryValue*>(message.release()));
292}
293
294void It2MeNativeMessagingHostTest::WriteMessageToInputPipe(
295    const base::Value& message) {
296  std::string message_json;
297  base::JSONWriter::Write(&message, &message_json);
298
299  uint32 length = message_json.length();
300  base::WritePlatformFileAtCurrentPos(
301      input_write_handle_, reinterpret_cast<char*>(&length), sizeof(length));
302  base::WritePlatformFileAtCurrentPos(
303      input_write_handle_, message_json.data(), length);
304}
305
306void It2MeNativeMessagingHostTest::VerifyHelloResponse(int request_id) {
307  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
308  VerifyCommonProperties(response.Pass(), "helloResponse", request_id);
309}
310
311void It2MeNativeMessagingHostTest::VerifyErrorResponse() {
312  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
313  VerifyStringProperty(response.Pass(), "type", "error");
314}
315
316void It2MeNativeMessagingHostTest::VerifyConnectResponses(int request_id) {
317  bool connect_response_received = false;
318  bool starting_received = false;
319  bool requestedAccessCode_received = false;
320  bool receivedAccessCode_received = false;
321  bool connected_received = false;
322
323  // We expect a total of 5 messages: 1 connectResponse and 4 hostStateChanged.
324  for (int i = 0; i < 5; ++i) {
325    scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
326    ASSERT_TRUE(response);
327
328    std::string type;
329    ASSERT_TRUE(response->GetString("type", &type));
330
331    if (type == "connectResponse") {
332      EXPECT_FALSE(connect_response_received);
333      connect_response_received = true;
334      VerifyId(response.Pass(), request_id);
335    } else if (type == "hostStateChanged") {
336      std::string state;
337      ASSERT_TRUE(response->GetString("state", &state));
338
339      std::string value;
340      if (state == It2MeNativeMessagingHost::HostStateToString(kStarting)) {
341        EXPECT_FALSE(starting_received);
342        starting_received = true;
343      } else if (state == It2MeNativeMessagingHost::HostStateToString(
344                              kRequestedAccessCode)) {
345        EXPECT_FALSE(requestedAccessCode_received);
346        requestedAccessCode_received = true;
347      } else if (state == It2MeNativeMessagingHost::HostStateToString(
348                              kReceivedAccessCode)) {
349        EXPECT_FALSE(receivedAccessCode_received);
350        receivedAccessCode_received = true;
351
352        EXPECT_TRUE(response->GetString("accessCode", &value));
353        EXPECT_EQ(kTestAccessCode, value);
354
355        int accessCodeLifetime;
356        EXPECT_TRUE(
357            response->GetInteger("accessCodeLifetime", &accessCodeLifetime));
358        EXPECT_EQ(kTestAccessCodeLifetimeInSeconds, accessCodeLifetime);
359      } else if (state ==
360                 It2MeNativeMessagingHost::HostStateToString(kConnected)) {
361        EXPECT_FALSE(connected_received);
362        connected_received = true;
363
364        EXPECT_TRUE(response->GetString("client", &value));
365        EXPECT_EQ(kTestClientUsername, value);
366      } else {
367        ADD_FAILURE() << "Unexpected host state: " << state;
368      }
369    } else {
370      ADD_FAILURE() << "Unexpected message type: " << type;
371    }
372  }
373}
374
375void It2MeNativeMessagingHostTest::VerifyDisconnectResponses(int request_id) {
376  bool disconnect_response_received = false;
377  bool disconnecting_received = false;
378  bool disconnected_received = false;
379
380  // We expect a total of 3 messages: 1 connectResponse and 2 hostStateChanged.
381  for (int i = 0; i < 3; ++i) {
382    scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
383    ASSERT_TRUE(response);
384
385    std::string type;
386    ASSERT_TRUE(response->GetString("type", &type));
387
388    if (type == "disconnectResponse") {
389      EXPECT_FALSE(disconnect_response_received);
390      disconnect_response_received = true;
391      VerifyId(response.Pass(), request_id);
392    } else if (type == "hostStateChanged") {
393      std::string state;
394      ASSERT_TRUE(response->GetString("state", &state));
395      if (state ==
396          It2MeNativeMessagingHost::HostStateToString(kDisconnecting)) {
397        EXPECT_FALSE(disconnecting_received);
398        disconnecting_received = true;
399      } else if (state ==
400                 It2MeNativeMessagingHost::HostStateToString(kDisconnected)) {
401        EXPECT_FALSE(disconnected_received);
402        disconnected_received = true;
403      } else {
404        ADD_FAILURE() << "Unexpected host state: " << state;
405      }
406    } else {
407      ADD_FAILURE() << "Unexpected message type: " << type;
408    }
409  }
410}
411
412void It2MeNativeMessagingHostTest::TestBadRequest(const base::Value& message,
413                                                  bool expect_error_response) {
414  base::DictionaryValue good_message;
415  good_message.SetString("type", "hello");
416  good_message.SetInteger("id", 1);
417
418  WriteMessageToInputPipe(good_message);
419  WriteMessageToInputPipe(message);
420  WriteMessageToInputPipe(good_message);
421
422  VerifyHelloResponse(1);
423
424  if (expect_error_response)
425    VerifyErrorResponse();
426
427  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
428  EXPECT_FALSE(response);
429}
430
431void It2MeNativeMessagingHostTest::StartHost(base::PlatformFile input,
432                                             base::PlatformFile output) {
433  DCHECK(host_task_runner_->RunsTasksOnCurrentThread());
434
435  // Creating a native messaging host with a mock It2MeHostFactory.
436  scoped_ptr<It2MeHostFactory> factory(new MockIt2MeHostFactory());
437  scoped_ptr<NativeMessagingChannel::Delegate> host(
438      new It2MeNativeMessagingHost(host_task_runner_, factory.Pass()));
439
440  // Set up and start the native messaging channel.
441  channel_.reset(new NativeMessagingChannel(host.Pass(), input, output));
442  channel_->Start(base::Bind(&It2MeNativeMessagingHostTest::StopHost,
443                             base::Unretained(this)));
444
445  // Notify the test that the host has finished starting up.
446  test_message_loop_->message_loop_proxy()->PostTask(
447      FROM_HERE, test_run_loop_->QuitClosure());
448}
449
450void It2MeNativeMessagingHostTest::StopHost() {
451  DCHECK(host_task_runner_->RunsTasksOnCurrentThread());
452
453  // The NativeMessagingChannel dtor will destroy the reader, the writer,
454  // and the delegate (the native messaging host).
455  channel_.reset();
456
457  // Wait till all shutdown tasks have completed.
458  base::MessageLoop::current()->RunUntilIdle();
459
460  // Trigger a test shutdown via ExitTest().
461  host_task_runner_ = NULL;
462}
463
464void It2MeNativeMessagingHostTest::ExitTest() {
465  if (!test_message_loop_->message_loop_proxy()->RunsTasksOnCurrentThread()) {
466    test_message_loop_->message_loop_proxy()->PostTask(
467        FROM_HERE,
468        base::Bind(&It2MeNativeMessagingHostTest::ExitTest,
469                   base::Unretained(this)));
470    return;
471  }
472  test_run_loop_->Quit();
473}
474
475void It2MeNativeMessagingHostTest::TestConnect() {
476  base::DictionaryValue connect_message;
477  int next_id = 0;
478
479  // Send the "connect" request.
480  connect_message.SetInteger("id", ++next_id);
481  connect_message.SetString("type", "connect");
482  connect_message.SetString("xmppServerAddress", "talk.google.com:5222");
483  connect_message.SetBoolean("xmppServerUseTls", true);
484  connect_message.SetString("directoryBotJid", "remoting@bot.talk.google.com");
485  connect_message.SetString("userName", "chromo.pyauto@gmail.com");
486  connect_message.SetString("authServiceWithToken", "oauth2:sometoken");
487  WriteMessageToInputPipe(connect_message);
488
489  VerifyConnectResponses(next_id);
490
491  base::DictionaryValue disconnect_message;
492  disconnect_message.SetInteger("id", ++next_id);
493  disconnect_message.SetString("type", "disconnect");
494  WriteMessageToInputPipe(disconnect_message);
495
496  VerifyDisconnectResponses(next_id);
497}
498
499// Test hello request.
500TEST_F(It2MeNativeMessagingHostTest, Hello) {
501  int next_id = 0;
502  base::DictionaryValue message;
503  message.SetInteger("id", ++next_id);
504  message.SetString("type", "hello");
505  WriteMessageToInputPipe(message);
506
507  VerifyHelloResponse(next_id);
508}
509
510// Verify that response ID matches request ID.
511TEST_F(It2MeNativeMessagingHostTest, Id) {
512  base::DictionaryValue message;
513  message.SetString("type", "hello");
514  WriteMessageToInputPipe(message);
515  message.SetString("id", "42");
516  WriteMessageToInputPipe(message);
517
518  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
519  EXPECT_TRUE(response);
520  std::string value;
521  EXPECT_FALSE(response->GetString("id", &value));
522
523  response = ReadMessageFromOutputPipe();
524  EXPECT_TRUE(response);
525  EXPECT_TRUE(response->GetString("id", &value));
526  EXPECT_EQ("42", value);
527}
528
529TEST_F(It2MeNativeMessagingHostTest, Connect) {
530  // A new It2MeHost instance is created for every it2me session. The native
531  // messaging host, on the other hand, is long lived. This test verifies
532  // multiple It2Me host startup and shutdowns.
533  for (int i = 0; i < 3; ++i)
534    TestConnect();
535}
536
537// Verify non-Dictionary requests are rejected.
538TEST_F(It2MeNativeMessagingHostTest, WrongFormat) {
539  base::ListValue message;
540  // No "error" response will be sent for non-Dictionary messages.
541  TestBadRequest(message, false);
542}
543
544// Verify requests with no type are rejected.
545TEST_F(It2MeNativeMessagingHostTest, MissingType) {
546  base::DictionaryValue message;
547  TestBadRequest(message, true);
548}
549
550// Verify rejection if type is unrecognized.
551TEST_F(It2MeNativeMessagingHostTest, InvalidType) {
552  base::DictionaryValue message;
553  message.SetString("type", "xxx");
554  TestBadRequest(message, true);
555}
556
557}  // namespace remoting
558