1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/host/it2me/it2me_native_messaging_host.h"
6
7#include "base/basictypes.h"
8#include "base/compiler_specific.h"
9#include "base/json/json_reader.h"
10#include "base/json/json_writer.h"
11#include "base/message_loop/message_loop.h"
12#include "base/run_loop.h"
13#include "base/stl_util.h"
14#include "base/strings/stringize_macros.h"
15#include "base/values.h"
16#include "net/base/file_stream.h"
17#include "net/base/net_util.h"
18#include "remoting/base/auto_thread_task_runner.h"
19#include "remoting/host/chromoting_host_context.h"
20#include "remoting/host/native_messaging/pipe_messaging_channel.h"
21#include "remoting/host/setup/test_util.h"
22#include "testing/gtest/include/gtest/gtest.h"
23
24namespace remoting {
25
26namespace {
27
// Canned values reported by MockIt2MeHost::Connect() to its observer.
// VerifyConnectResponses() expects them to be echoed back in the
// "hostStateChanged" messages (accessCode / accessCodeLifetime / client).
const char kTestAccessCode[] = "888888";
const int kTestAccessCodeLifetimeInSeconds = 666;
const char kTestClientUsername[] = "some_user@gmail.com";
31
32void VerifyId(scoped_ptr<base::DictionaryValue> response, int expected_value) {
33  ASSERT_TRUE(response);
34
35  int value;
36  EXPECT_TRUE(response->GetInteger("id", &value));
37  EXPECT_EQ(expected_value, value);
38}
39
40void VerifyStringProperty(scoped_ptr<base::DictionaryValue> response,
41                          const std::string& name,
42                          const std::string& expected_value) {
43  ASSERT_TRUE(response);
44
45  std::string value;
46  EXPECT_TRUE(response->GetString(name, &value));
47  EXPECT_EQ(expected_value, value);
48}
49
50// Verity the values of the "type" and "id" properties
51void VerifyCommonProperties(scoped_ptr<base::DictionaryValue> response,
52                            const std::string& type,
53                            int id) {
54  ASSERT_TRUE(response);
55
56  std::string string_value;
57  EXPECT_TRUE(response->GetString("type", &string_value));
58  EXPECT_EQ(type, string_value);
59
60  int int_value;
61  EXPECT_TRUE(response->GetInteger("id", &int_value));
62  EXPECT_EQ(id, int_value);
63}
64
65}  // namespace
66
67class MockIt2MeHost : public It2MeHost {
68 public:
69  MockIt2MeHost(ChromotingHostContext* context,
70                scoped_refptr<base::SingleThreadTaskRunner> task_runner,
71                base::WeakPtr<It2MeHost::Observer> observer,
72                const XmppSignalStrategy::XmppServerConfig& xmpp_server_config,
73                const std::string& directory_bot_jid)
74      : It2MeHost(context,
75                  task_runner,
76                  observer,
77                  xmpp_server_config,
78                  directory_bot_jid) {}
79
80  // It2MeHost overrides
81  virtual void Connect() OVERRIDE;
82  virtual void Disconnect() OVERRIDE;
83  virtual void RequestNatPolicy() OVERRIDE;
84
85 private:
86  virtual ~MockIt2MeHost() {}
87
88  void RunSetState(It2MeHostState state);
89
90  DISALLOW_COPY_AND_ASSIGN(MockIt2MeHost);
91};
92
93void MockIt2MeHost::Connect() {
94  if (!host_context()->ui_task_runner()->BelongsToCurrentThread()) {
95    DCHECK(task_runner()->BelongsToCurrentThread());
96    host_context()->ui_task_runner()->PostTask(
97        FROM_HERE, base::Bind(&MockIt2MeHost::Connect, this));
98    return;
99  }
100
101  RunSetState(kStarting);
102  RunSetState(kRequestedAccessCode);
103
104  std::string access_code(kTestAccessCode);
105  base::TimeDelta lifetime =
106      base::TimeDelta::FromSeconds(kTestAccessCodeLifetimeInSeconds);
107  task_runner()->PostTask(FROM_HERE,
108                          base::Bind(&It2MeHost::Observer::OnStoreAccessCode,
109                                     observer(),
110                                     access_code,
111                                     lifetime));
112
113  RunSetState(kReceivedAccessCode);
114
115  std::string client_username(kTestClientUsername);
116  task_runner()->PostTask(
117      FROM_HERE,
118      base::Bind(&It2MeHost::Observer::OnClientAuthenticated,
119                 observer(),
120                 client_username));
121
122  RunSetState(kConnected);
123}
124
125void MockIt2MeHost::Disconnect() {
126  if (!host_context()->network_task_runner()->BelongsToCurrentThread()) {
127    DCHECK(task_runner()->BelongsToCurrentThread());
128    host_context()->network_task_runner()->PostTask(
129        FROM_HERE, base::Bind(&MockIt2MeHost::Disconnect, this));
130    return;
131  }
132
133  RunSetState(kDisconnecting);
134  RunSetState(kDisconnected);
135}
136
137void MockIt2MeHost::RequestNatPolicy() {}
138
139void MockIt2MeHost::RunSetState(It2MeHostState state) {
140  if (!host_context()->network_task_runner()->BelongsToCurrentThread()) {
141    host_context()->network_task_runner()->PostTask(
142        FROM_HERE, base::Bind(&It2MeHost::SetStateForTesting, this, state));
143  } else {
144    SetStateForTesting(state);
145  }
146}
147
148class MockIt2MeHostFactory : public It2MeHostFactory {
149 public:
150  MockIt2MeHostFactory() {}
151  virtual scoped_refptr<It2MeHost> CreateIt2MeHost(
152      ChromotingHostContext* context,
153      scoped_refptr<base::SingleThreadTaskRunner> task_runner,
154      base::WeakPtr<It2MeHost::Observer> observer,
155      const XmppSignalStrategy::XmppServerConfig& xmpp_server_config,
156      const std::string& directory_bot_jid) OVERRIDE {
157    return new MockIt2MeHost(
158        context, task_runner, observer, xmpp_server_config, directory_bot_jid);
159  }
160
161 private:
162  DISALLOW_COPY_AND_ASSIGN(MockIt2MeHostFactory);
163};  // MockIt2MeHostFactory
164
// Test fixture that runs an It2MeNativeMessagingHost (built with a
// MockIt2MeHostFactory) on a dedicated thread. The test body talks to it over
// a pair of pipes, each message framed as a uint32 length followed by JSON.
class It2MeNativeMessagingHostTest : public testing::Test {
 public:
  It2MeNativeMessagingHostTest() {}
  virtual ~It2MeNativeMessagingHostTest() {}

  virtual void SetUp() OVERRIDE;
  virtual void TearDown() OVERRIDE;

 protected:
  // Reads one length-prefixed JSON message from the output pipe. Returns NULL
  // if the pipe is closed, the payload is truncated, or the payload is not a
  // JSON dictionary.
  scoped_ptr<base::DictionaryValue> ReadMessageFromOutputPipe();
  // Serializes |message| to JSON and writes it to the input pipe, preceded by
  // its uint32 length.
  void WriteMessageToInputPipe(const base::Value& message);

  // Each of these reads the next message(s) from the output pipe and checks
  // them against the expected response(s).
  void VerifyHelloResponse(int request_id);
  void VerifyErrorResponse();
  void VerifyConnectResponses(int request_id);
  void VerifyDisconnectResponses(int request_id);

  // The host should shut down when it receives a malformed request.
  // This is tested by sending a known-good "hello" request, followed by
  // |message|, followed by the known-good request again. The output must
  // contain the response to the first good request, then an "error" response
  // if |expect_error_response| is true, and nothing else.
  void TestBadRequest(const base::Value& message, bool expect_error_response);
  // Runs one connect/disconnect cycle and verifies every response.
  void TestConnect();

 private:
  // Runs on the host thread: creates the pipes and starts |host_|.
  void StartHost();
  // Runs on the host thread when the host shuts down: destroys |host_| and
  // releases |host_task_runner_|, which leads to ExitTest() being called.
  void StopHost();
  // Quits |test_run_loop_| on the test thread.
  void ExitTest();

  // Each test creates two unidirectional pipes: "input" and "output".
  // It2MeNativeMessagingHost reads from input_read_file and writes to
  // output_write_file (both are locals in StartHost() handed to the host).
  // The unittest supplies data to |input_write_file_|, and verifies output
  // from |output_read_file_|.
  //
  // unittest -> [input] -> It2MeNativeMessagingHost -> [output] -> unittest
  base::File input_write_file_;
  base::File output_read_file_;

  // Message loop of the test thread.
  scoped_ptr<base::MessageLoop> test_message_loop_;
  // Used to block the test thread until the host has started (SetUp) or
  // finished shutting down (TearDown).
  scoped_ptr<base::RunLoop> test_run_loop_;

  // Dedicated thread on which |host_| runs.
  scoped_ptr<base::Thread> host_thread_;
  // NOTE(review): not referenced anywhere in this file; candidate for removal.
  scoped_ptr<base::RunLoop> host_run_loop_;

  // Task runner of the host thread.
  scoped_refptr<AutoThreadTaskRunner> host_task_runner_;
  // The host under test; created in StartHost() and destroyed in StopHost(),
  // both on the host thread.
  scoped_ptr<remoting::It2MeNativeMessagingHost> host_;

  DISALLOW_COPY_AND_ASSIGN(It2MeNativeMessagingHostTest);
};
216
217void It2MeNativeMessagingHostTest::SetUp() {
218  test_message_loop_.reset(new base::MessageLoop());
219  test_run_loop_.reset(new base::RunLoop());
220
221  // Run the host on a dedicated thread.
222  host_thread_.reset(new base::Thread("host_thread"));
223  host_thread_->Start();
224
225  host_task_runner_ = new AutoThreadTaskRunner(
226      host_thread_->message_loop_proxy(),
227      base::Bind(&It2MeNativeMessagingHostTest::ExitTest,
228                 base::Unretained(this)));
229
230  host_task_runner_->PostTask(
231      FROM_HERE,
232      base::Bind(&It2MeNativeMessagingHostTest::StartHost,
233                 base::Unretained(this)));
234
235  // Wait until the host finishes starting.
236  test_run_loop_->Run();
237}
238
239void It2MeNativeMessagingHostTest::TearDown() {
240  // Closing the write-end of the input will send an EOF to the native
241  // messaging reader. This will trigger a host shutdown.
242  input_write_file_.Close();
243
244  // Start a new RunLoop and Wait until the host finishes shutting down.
245  test_run_loop_.reset(new base::RunLoop());
246  test_run_loop_->Run();
247
248  // Verify there are no more message in the output pipe.
249  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
250  EXPECT_FALSE(response);
251
252  // The It2MeNativeMessagingHost dtor closes the handles that are passed to it.
253  // So the only handle left to close is |output_read_file_|.
254  output_read_file_.Close();
255}
256
257scoped_ptr<base::DictionaryValue>
258It2MeNativeMessagingHostTest::ReadMessageFromOutputPipe() {
259  uint32 length;
260  int read_result = output_read_file_.ReadAtCurrentPos(
261      reinterpret_cast<char*>(&length), sizeof(length));
262  if (read_result != sizeof(length)) {
263    // The output pipe has been closed, return an empty message.
264    return scoped_ptr<base::DictionaryValue>();
265  }
266
267  std::string message_json(length, '\0');
268  read_result = output_read_file_.ReadAtCurrentPos(
269      string_as_array(&message_json), length);
270  if (read_result != static_cast<int>(length)) {
271    LOG(ERROR) << "Message size (" << read_result
272               << ") doesn't match the header (" << length << ").";
273    return scoped_ptr<base::DictionaryValue>();
274  }
275
276  scoped_ptr<base::Value> message(base::JSONReader::Read(message_json));
277  if (!message || !message->IsType(base::Value::TYPE_DICTIONARY)) {
278    LOG(ERROR) << "Malformed message:" << message_json;
279    return scoped_ptr<base::DictionaryValue>();
280  }
281
282  return scoped_ptr<base::DictionaryValue>(
283      static_cast<base::DictionaryValue*>(message.release()));
284}
285
286void It2MeNativeMessagingHostTest::WriteMessageToInputPipe(
287    const base::Value& message) {
288  std::string message_json;
289  base::JSONWriter::Write(&message, &message_json);
290
291  uint32 length = message_json.length();
292  input_write_file_.WriteAtCurrentPos(reinterpret_cast<char*>(&length),
293                                      sizeof(length));
294  input_write_file_.WriteAtCurrentPos(message_json.data(), length);
295}
296
297void It2MeNativeMessagingHostTest::VerifyHelloResponse(int request_id) {
298  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
299  VerifyCommonProperties(response.Pass(), "helloResponse", request_id);
300}
301
302void It2MeNativeMessagingHostTest::VerifyErrorResponse() {
303  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
304  VerifyStringProperty(response.Pass(), "type", "error");
305}
306
307void It2MeNativeMessagingHostTest::VerifyConnectResponses(int request_id) {
308  bool connect_response_received = false;
309  bool starting_received = false;
310  bool requestedAccessCode_received = false;
311  bool receivedAccessCode_received = false;
312  bool connected_received = false;
313
314  // We expect a total of 5 messages: 1 connectResponse and 4 hostStateChanged.
315  for (int i = 0; i < 5; ++i) {
316    scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
317    ASSERT_TRUE(response);
318
319    std::string type;
320    ASSERT_TRUE(response->GetString("type", &type));
321
322    if (type == "connectResponse") {
323      EXPECT_FALSE(connect_response_received);
324      connect_response_received = true;
325      VerifyId(response.Pass(), request_id);
326    } else if (type == "hostStateChanged") {
327      std::string state;
328      ASSERT_TRUE(response->GetString("state", &state));
329
330      std::string value;
331      if (state == It2MeNativeMessagingHost::HostStateToString(kStarting)) {
332        EXPECT_FALSE(starting_received);
333        starting_received = true;
334      } else if (state == It2MeNativeMessagingHost::HostStateToString(
335                              kRequestedAccessCode)) {
336        EXPECT_FALSE(requestedAccessCode_received);
337        requestedAccessCode_received = true;
338      } else if (state == It2MeNativeMessagingHost::HostStateToString(
339                              kReceivedAccessCode)) {
340        EXPECT_FALSE(receivedAccessCode_received);
341        receivedAccessCode_received = true;
342
343        EXPECT_TRUE(response->GetString("accessCode", &value));
344        EXPECT_EQ(kTestAccessCode, value);
345
346        int accessCodeLifetime;
347        EXPECT_TRUE(
348            response->GetInteger("accessCodeLifetime", &accessCodeLifetime));
349        EXPECT_EQ(kTestAccessCodeLifetimeInSeconds, accessCodeLifetime);
350      } else if (state ==
351                 It2MeNativeMessagingHost::HostStateToString(kConnected)) {
352        EXPECT_FALSE(connected_received);
353        connected_received = true;
354
355        EXPECT_TRUE(response->GetString("client", &value));
356        EXPECT_EQ(kTestClientUsername, value);
357      } else {
358        ADD_FAILURE() << "Unexpected host state: " << state;
359      }
360    } else {
361      ADD_FAILURE() << "Unexpected message type: " << type;
362    }
363  }
364}
365
366void It2MeNativeMessagingHostTest::VerifyDisconnectResponses(int request_id) {
367  bool disconnect_response_received = false;
368  bool disconnecting_received = false;
369  bool disconnected_received = false;
370
371  // We expect a total of 3 messages: 1 connectResponse and 2 hostStateChanged.
372  for (int i = 0; i < 3; ++i) {
373    scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
374    ASSERT_TRUE(response);
375
376    std::string type;
377    ASSERT_TRUE(response->GetString("type", &type));
378
379    if (type == "disconnectResponse") {
380      EXPECT_FALSE(disconnect_response_received);
381      disconnect_response_received = true;
382      VerifyId(response.Pass(), request_id);
383    } else if (type == "hostStateChanged") {
384      std::string state;
385      ASSERT_TRUE(response->GetString("state", &state));
386      if (state ==
387          It2MeNativeMessagingHost::HostStateToString(kDisconnecting)) {
388        EXPECT_FALSE(disconnecting_received);
389        disconnecting_received = true;
390      } else if (state ==
391                 It2MeNativeMessagingHost::HostStateToString(kDisconnected)) {
392        EXPECT_FALSE(disconnected_received);
393        disconnected_received = true;
394      } else {
395        ADD_FAILURE() << "Unexpected host state: " << state;
396      }
397    } else {
398      ADD_FAILURE() << "Unexpected message type: " << type;
399    }
400  }
401}
402
403void It2MeNativeMessagingHostTest::TestBadRequest(const base::Value& message,
404                                                  bool expect_error_response) {
405  base::DictionaryValue good_message;
406  good_message.SetString("type", "hello");
407  good_message.SetInteger("id", 1);
408
409  WriteMessageToInputPipe(good_message);
410  WriteMessageToInputPipe(message);
411  WriteMessageToInputPipe(good_message);
412
413  VerifyHelloResponse(1);
414
415  if (expect_error_response)
416    VerifyErrorResponse();
417
418  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
419  EXPECT_FALSE(response);
420}
421
422void It2MeNativeMessagingHostTest::StartHost() {
423  DCHECK(host_task_runner_->RunsTasksOnCurrentThread());
424
425  base::File input_read_file;
426  base::File output_write_file;
427
428  ASSERT_TRUE(MakePipe(&input_read_file, &input_write_file_));
429  ASSERT_TRUE(MakePipe(&output_read_file_, &output_write_file));
430
431  // Creating a native messaging host with a mock It2MeHostFactory.
432  scoped_ptr<It2MeHostFactory> factory(new MockIt2MeHostFactory());
433
434  scoped_ptr<extensions::NativeMessagingChannel> channel(
435      new PipeMessagingChannel(input_read_file.Pass(),
436                               output_write_file.Pass()));
437
438  host_.reset(new It2MeNativeMessagingHost(
439      host_task_runner_,
440      channel.Pass(),
441      factory.Pass()));
442  host_->Start(base::Bind(&It2MeNativeMessagingHostTest::StopHost,
443                          base::Unretained(this)));
444
445  // Notify the test that the host has finished starting up.
446  test_message_loop_->message_loop_proxy()->PostTask(
447      FROM_HERE, test_run_loop_->QuitClosure());
448}
449
450void It2MeNativeMessagingHostTest::StopHost() {
451  DCHECK(host_task_runner_->RunsTasksOnCurrentThread());
452
453  host_.reset();
454
455  // Wait till all shutdown tasks have completed.
456  base::RunLoop().RunUntilIdle();
457
458  // Trigger a test shutdown via ExitTest().
459  host_task_runner_ = NULL;
460}
461
462void It2MeNativeMessagingHostTest::ExitTest() {
463  if (!test_message_loop_->message_loop_proxy()->RunsTasksOnCurrentThread()) {
464    test_message_loop_->message_loop_proxy()->PostTask(
465        FROM_HERE,
466        base::Bind(&It2MeNativeMessagingHostTest::ExitTest,
467                   base::Unretained(this)));
468    return;
469  }
470  test_run_loop_->Quit();
471}
472
473void It2MeNativeMessagingHostTest::TestConnect() {
474  base::DictionaryValue connect_message;
475  int next_id = 0;
476
477  // Send the "connect" request.
478  connect_message.SetInteger("id", ++next_id);
479  connect_message.SetString("type", "connect");
480  connect_message.SetString("xmppServerAddress", "talk.google.com:5222");
481  connect_message.SetBoolean("xmppServerUseTls", true);
482  connect_message.SetString("directoryBotJid", "remoting@bot.talk.google.com");
483  connect_message.SetString("userName", "chromo.pyauto@gmail.com");
484  connect_message.SetString("authServiceWithToken", "oauth2:sometoken");
485  WriteMessageToInputPipe(connect_message);
486
487  VerifyConnectResponses(next_id);
488
489  base::DictionaryValue disconnect_message;
490  disconnect_message.SetInteger("id", ++next_id);
491  disconnect_message.SetString("type", "disconnect");
492  WriteMessageToInputPipe(disconnect_message);
493
494  VerifyDisconnectResponses(next_id);
495}
496
497// Test hello request.
498TEST_F(It2MeNativeMessagingHostTest, Hello) {
499  int next_id = 0;
500  base::DictionaryValue message;
501  message.SetInteger("id", ++next_id);
502  message.SetString("type", "hello");
503  WriteMessageToInputPipe(message);
504
505  VerifyHelloResponse(next_id);
506}
507
508// Verify that response ID matches request ID.
509TEST_F(It2MeNativeMessagingHostTest, Id) {
510  base::DictionaryValue message;
511  message.SetString("type", "hello");
512  WriteMessageToInputPipe(message);
513  message.SetString("id", "42");
514  WriteMessageToInputPipe(message);
515
516  scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
517  EXPECT_TRUE(response);
518  std::string value;
519  EXPECT_FALSE(response->GetString("id", &value));
520
521  response = ReadMessageFromOutputPipe();
522  EXPECT_TRUE(response);
523  EXPECT_TRUE(response->GetString("id", &value));
524  EXPECT_EQ("42", value);
525}
526
527TEST_F(It2MeNativeMessagingHostTest, Connect) {
528  // A new It2MeHost instance is created for every it2me session. The native
529  // messaging host, on the other hand, is long lived. This test verifies
530  // multiple It2Me host startup and shutdowns.
531  for (int i = 0; i < 3; ++i)
532    TestConnect();
533}
534
535// Verify non-Dictionary requests are rejected.
536TEST_F(It2MeNativeMessagingHostTest, WrongFormat) {
537  base::ListValue message;
538  // No "error" response will be sent for non-Dictionary messages.
539  TestBadRequest(message, false);
540}
541
542// Verify requests with no type are rejected.
543TEST_F(It2MeNativeMessagingHostTest, MissingType) {
544  base::DictionaryValue message;
545  TestBadRequest(message, true);
546}
547
548// Verify rejection if type is unrecognized.
549TEST_F(It2MeNativeMessagingHostTest, InvalidType) {
550  base::DictionaryValue message;
551  message.SetString("type", "xxx");
552  TestBadRequest(message, true);
553}
554
555}  // namespace remoting
556
557