1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/authenticator_test_base.h"
6
7#include "base/base64.h"
8#include "base/file_util.h"
9#include "base/files/file_path.h"
10#include "base/path_service.h"
11#include "base/test/test_timeouts.h"
12#include "base/timer/timer.h"
13#include "net/base/test_data_directory.h"
14#include "remoting/base/rsa_key_pair.h"
15#include "remoting/protocol/authenticator.h"
16#include "remoting/protocol/channel_authenticator.h"
17#include "remoting/protocol/fake_session.h"
18#include "testing/gtest/include/gtest/gtest.h"
19#include "third_party/libjingle/source/talk/xmllite/xmlelement.h"
20
21using testing::_;
22using testing::SaveArg;
23
24namespace remoting {
25namespace protocol {
26
27namespace {
28
29ACTION_P(QuitThreadOnCounter, counter) {
30  --(*counter);
31  EXPECT_GE(*counter, 0);
32  if (*counter == 0)
33    base::MessageLoop::current()->Quit();
34}
35
36}  // namespace
37
38AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {}
39
40AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {}
41
42AuthenticatorTestBase::AuthenticatorTestBase() {}
43
44AuthenticatorTestBase::~AuthenticatorTestBase() {}
45
46void AuthenticatorTestBase::SetUp() {
47  base::FilePath certs_dir(net::GetTestCertsDirectory());
48
49  base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
50  ASSERT_TRUE(base::ReadFileToString(cert_path, &host_cert_));
51
52  base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
53  std::string key_string;
54  ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
55  std::string key_base64;
56  base::Base64Encode(key_string, &key_base64);
57  key_pair_ = RsaKeyPair::FromString(key_base64);
58  ASSERT_TRUE(key_pair_.get());
59  host_public_key_ = key_pair_->GetPublicKey();
60}
61
62void AuthenticatorTestBase::RunAuthExchange() {
63  ContinueAuthExchangeWith(client_.get(), host_.get());
64}
65
66void AuthenticatorTestBase::RunHostInitiatedAuthExchange() {
67  ContinueAuthExchangeWith(host_.get(), client_.get());
68}
69
70// static
71void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender,
72                                                     Authenticator* receiver) {
73  scoped_ptr<buzz::XmlElement> message;
74  ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state());
75  if (sender->state() == Authenticator::ACCEPTED ||
76      sender->state() == Authenticator::REJECTED)
77    return;
78  // Pass message from client to host.
79  ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state());
80  message = sender->GetNextMessage();
81  ASSERT_TRUE(message.get());
82  ASSERT_NE(Authenticator::MESSAGE_READY, sender->state());
83
84  ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state());
85  receiver->ProcessMessage(message.get(), base::Bind(
86      &AuthenticatorTestBase::ContinueAuthExchangeWith,
87      base::Unretained(receiver), base::Unretained(sender)));
88}
89
90void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) {
91  client_fake_socket_.reset(new FakeSocket());
92  host_fake_socket_.reset(new FakeSocket());
93  client_fake_socket_->PairWith(host_fake_socket_.get());
94
95  client_auth_->SecureAndAuthenticate(
96      client_fake_socket_.PassAs<net::StreamSocket>(),
97      base::Bind(&AuthenticatorTestBase::OnClientConnected,
98                 base::Unretained(this)));
99
100  host_auth_->SecureAndAuthenticate(
101      host_fake_socket_.PassAs<net::StreamSocket>(),
102      base::Bind(&AuthenticatorTestBase::OnHostConnected,
103                 base::Unretained(this)));
104
105  // Expect two callbacks to be called - the client callback and the host
106  // callback.
107  int callback_counter = 2;
108
109  EXPECT_CALL(client_callback_, OnDone(net::OK))
110      .WillOnce(QuitThreadOnCounter(&callback_counter));
111  if (expected_fail) {
112    EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED))
113         .WillOnce(QuitThreadOnCounter(&callback_counter));
114  } else {
115    EXPECT_CALL(host_callback_, OnDone(net::OK))
116        .WillOnce(QuitThreadOnCounter(&callback_counter));
117  }
118
119  // Ensure that .Run() does not run unbounded if the callbacks are never
120  // called.
121  base::Timer shutdown_timer(false, false);
122  shutdown_timer.Start(FROM_HERE,
123                       TestTimeouts::action_timeout(),
124                       base::MessageLoop::QuitClosure());
125  message_loop_.Run();
126  shutdown_timer.Stop();
127
128  testing::Mock::VerifyAndClearExpectations(&client_callback_);
129  testing::Mock::VerifyAndClearExpectations(&host_callback_);
130
131  if (!expected_fail) {
132    ASSERT_TRUE(client_socket_.get() != NULL);
133    ASSERT_TRUE(host_socket_.get() != NULL);
134  }
135}
136
137void AuthenticatorTestBase::OnHostConnected(
138    net::Error error,
139    scoped_ptr<net::StreamSocket> socket) {
140  host_callback_.OnDone(error);
141  host_socket_ = socket.Pass();
142}
143
144void AuthenticatorTestBase::OnClientConnected(
145    net::Error error,
146    scoped_ptr<net::StreamSocket> socket) {
147  client_callback_.OnDone(error);
148  client_socket_ = socket.Pass();
149}
150
151}  // namespace protocol
152}  // namespace remoting
153