1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/authenticator_test_base.h"
6
7#include "base/base64.h"
8#include "base/file_util.h"
9#include "base/files/file_path.h"
10#include "base/path_service.h"
11#include "base/test/test_timeouts.h"
12#include "base/timer/timer.h"
13#include "net/base/test_data_directory.h"
14#include "remoting/base/rsa_key_pair.h"
15#include "remoting/protocol/authenticator.h"
16#include "remoting/protocol/channel_authenticator.h"
17#include "remoting/protocol/fake_session.h"
18#include "testing/gtest/include/gtest/gtest.h"
19#include "third_party/libjingle/source/talk/xmllite/xmlelement.h"
20
21using testing::_;
22using testing::SaveArg;
23
24namespace remoting {
25namespace protocol {
26
27namespace {
28
29ACTION_P(QuitThreadOnCounter, counter) {
30  --(*counter);
31  EXPECT_GE(*counter, 0);
32  if (*counter == 0)
33    base::MessageLoop::current()->Quit();
34}
35
36}  // namespace
37
38AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {}
39
40AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {}
41
42AuthenticatorTestBase::AuthenticatorTestBase() {}
43
44AuthenticatorTestBase::~AuthenticatorTestBase() {}
45
46void AuthenticatorTestBase::SetUp() {
47  base::FilePath certs_dir(net::GetTestCertsDirectory());
48
49  base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
50  ASSERT_TRUE(base::ReadFileToString(cert_path, &host_cert_));
51
52  base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
53  std::string key_string;
54  ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
55  std::string key_base64;
56  base::Base64Encode(key_string, &key_base64);
57  key_pair_ = RsaKeyPair::FromString(key_base64);
58  ASSERT_TRUE(key_pair_.get());
59  host_public_key_ = key_pair_->GetPublicKey();
60}
61
62void AuthenticatorTestBase::RunAuthExchange() {
63  ContinueAuthExchangeWith(client_.get(),
64                           host_.get(),
65                           client_->started(),
66                           host_->started());
67}
68
69void AuthenticatorTestBase::RunHostInitiatedAuthExchange() {
70  ContinueAuthExchangeWith(host_.get(),
71                           client_.get(),
72                           host_->started(),
73                           client_->started());
74}
75
// static
// This function sends a message from the sender to the receiver and
// recursively calls itself to send the next message from the receiver back to
// the sender until the authentication completes.
//
// Each step takes the next message from |sender| and hands it to
// |receiver|->ProcessMessage(). The callback bound below re-enters this
// function with the roles swapped when |receiver| runs it. The recursion stops
// once the current sender is ACCEPTED or REJECTED.
//
// |sender_started| and |receiver_started| are the started() values captured in
// the previous step. They are used to verify that a party never goes back to
// started() == false.
void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender,
                                                     Authenticator* receiver,
                                                     bool sender_started,
                                                     bool receiver_started) {
  scoped_ptr<buzz::XmlElement> message;
  // It is the sender's turn, so it must not be waiting for input.
  ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state());
  // A terminal state on the sender side ends the exchange.
  if (sender->state() == Authenticator::ACCEPTED ||
      sender->state() == Authenticator::REJECTED)
    return;

  // Verify that once the started flag for either party is set to true,
  // it should always stay true.
  if (receiver_started) {
    ASSERT_TRUE(receiver->started());
  }

  if (sender_started) {
    ASSERT_TRUE(sender->started());
  }

  // Any remaining state must mean the sender has a message to send.
  ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state());
  message = sender->GetNextMessage();
  ASSERT_TRUE(message.get());
  // Taking the message must move the sender out of MESSAGE_READY.
  ASSERT_NE(Authenticator::MESSAGE_READY, sender->state());

  ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state());
  // Swap roles for the next step. Both started() flags are evaluated here,
  // before ProcessMessage() runs, and become the "previous" values checked by
  // the next call.
  receiver->ProcessMessage(message.get(), base::Bind(
      &AuthenticatorTestBase::ContinueAuthExchangeWith,
      base::Unretained(receiver), base::Unretained(sender),
      receiver->started(), sender->started()));
}
111
112void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) {
113  client_fake_socket_.reset(new FakeSocket());
114  host_fake_socket_.reset(new FakeSocket());
115  client_fake_socket_->PairWith(host_fake_socket_.get());
116
117  client_auth_->SecureAndAuthenticate(
118      client_fake_socket_.PassAs<net::StreamSocket>(),
119      base::Bind(&AuthenticatorTestBase::OnClientConnected,
120                 base::Unretained(this)));
121
122  host_auth_->SecureAndAuthenticate(
123      host_fake_socket_.PassAs<net::StreamSocket>(),
124      base::Bind(&AuthenticatorTestBase::OnHostConnected,
125                 base::Unretained(this)));
126
127  // Expect two callbacks to be called - the client callback and the host
128  // callback.
129  int callback_counter = 2;
130
131  EXPECT_CALL(client_callback_, OnDone(net::OK))
132      .WillOnce(QuitThreadOnCounter(&callback_counter));
133  if (expected_fail) {
134    EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED))
135         .WillOnce(QuitThreadOnCounter(&callback_counter));
136  } else {
137    EXPECT_CALL(host_callback_, OnDone(net::OK))
138        .WillOnce(QuitThreadOnCounter(&callback_counter));
139  }
140
141  // Ensure that .Run() does not run unbounded if the callbacks are never
142  // called.
143  base::Timer shutdown_timer(false, false);
144  shutdown_timer.Start(FROM_HERE,
145                       TestTimeouts::action_timeout(),
146                       base::MessageLoop::QuitClosure());
147  message_loop_.Run();
148  shutdown_timer.Stop();
149
150  testing::Mock::VerifyAndClearExpectations(&client_callback_);
151  testing::Mock::VerifyAndClearExpectations(&host_callback_);
152
153  if (!expected_fail) {
154    ASSERT_TRUE(client_socket_.get() != NULL);
155    ASSERT_TRUE(host_socket_.get() != NULL);
156  }
157}
158
159void AuthenticatorTestBase::OnHostConnected(
160    net::Error error,
161    scoped_ptr<net::StreamSocket> socket) {
162  host_callback_.OnDone(error);
163  host_socket_ = socket.Pass();
164}
165
166void AuthenticatorTestBase::OnClientConnected(
167    net::Error error,
168    scoped_ptr<net::StreamSocket> socket) {
169  client_callback_.OnDone(error);
170  client_socket_ = socket.Pass();
171}
172
173}  // namespace protocol
174}  // namespace remoting
175