1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "remoting/protocol/authenticator_test_base.h" 6 7#include "base/base64.h" 8#include "base/file_util.h" 9#include "base/files/file_path.h" 10#include "base/path_service.h" 11#include "base/test/test_timeouts.h" 12#include "base/timer/timer.h" 13#include "net/base/test_data_directory.h" 14#include "remoting/base/rsa_key_pair.h" 15#include "remoting/protocol/authenticator.h" 16#include "remoting/protocol/channel_authenticator.h" 17#include "remoting/protocol/fake_session.h" 18#include "testing/gtest/include/gtest/gtest.h" 19#include "third_party/libjingle/source/talk/xmllite/xmlelement.h" 20 21using testing::_; 22using testing::SaveArg; 23 24namespace remoting { 25namespace protocol { 26 27namespace { 28 29ACTION_P(QuitThreadOnCounter, counter) { 30 --(*counter); 31 EXPECT_GE(*counter, 0); 32 if (*counter == 0) 33 base::MessageLoop::current()->Quit(); 34} 35 36} // namespace 37 38AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {} 39 40AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {} 41 42AuthenticatorTestBase::AuthenticatorTestBase() {} 43 44AuthenticatorTestBase::~AuthenticatorTestBase() {} 45 46void AuthenticatorTestBase::SetUp() { 47 base::FilePath certs_dir(net::GetTestCertsDirectory()); 48 49 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 50 ASSERT_TRUE(base::ReadFileToString(cert_path, &host_cert_)); 51 52 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 53 std::string key_string; 54 ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); 55 std::string key_base64; 56 base::Base64Encode(key_string, &key_base64); 57 key_pair_ = RsaKeyPair::FromString(key_base64); 58 ASSERT_TRUE(key_pair_.get()); 59 host_public_key_ = key_pair_->GetPublicKey(); 60} 61 62void AuthenticatorTestBase::RunAuthExchange() { 63 ContinueAuthExchangeWith(client_.get(), 64 host_.get(), 65 client_->started(), 66 host_->started()); 67} 68 69void AuthenticatorTestBase::RunHostInitiatedAuthExchange() { 70 ContinueAuthExchangeWith(host_.get(), 71 client_.get(), 72 host_->started(), 73 client_->started()); 74} 75 76// static 77// This function sends a message from the sender and receiver and recursively 78// calls itself to the send the next message from the receiver to the sender 79// untils the authentication completes. 80void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender, 81 Authenticator* receiver, 82 bool sender_started, 83 bool receiver_started) { 84 scoped_ptr<buzz::XmlElement> message; 85 ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state()); 86 if (sender->state() == Authenticator::ACCEPTED || 87 sender->state() == Authenticator::REJECTED) 88 return; 89 90 // Verify that once the started flag for either party is set to true, 91 // it should always stay true. 92 if (receiver_started) { 93 ASSERT_TRUE(receiver->started()); 94 } 95 96 if (sender_started) { 97 ASSERT_TRUE(sender->started()); 98 } 99 100 ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state()); 101 message = sender->GetNextMessage(); 102 ASSERT_TRUE(message.get()); 103 ASSERT_NE(Authenticator::MESSAGE_READY, sender->state()); 104 105 ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state()); 106 receiver->ProcessMessage(message.get(), base::Bind( 107 &AuthenticatorTestBase::ContinueAuthExchangeWith, 108 base::Unretained(receiver), base::Unretained(sender), 109 receiver->started(), sender->started())); 110} 111 112void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) { 113 client_fake_socket_.reset(new FakeSocket()); 114 host_fake_socket_.reset(new FakeSocket()); 115 client_fake_socket_->PairWith(host_fake_socket_.get()); 116 117 client_auth_->SecureAndAuthenticate( 118 client_fake_socket_.PassAs<net::StreamSocket>(), 119 base::Bind(&AuthenticatorTestBase::OnClientConnected, 120 base::Unretained(this))); 121 122 host_auth_->SecureAndAuthenticate( 123 host_fake_socket_.PassAs<net::StreamSocket>(), 124 base::Bind(&AuthenticatorTestBase::OnHostConnected, 125 base::Unretained(this))); 126 127 // Expect two callbacks to be called - the client callback and the host 128 // callback. 129 int callback_counter = 2; 130 131 EXPECT_CALL(client_callback_, OnDone(net::OK)) 132 .WillOnce(QuitThreadOnCounter(&callback_counter)); 133 if (expected_fail) { 134 EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED)) 135 .WillOnce(QuitThreadOnCounter(&callback_counter)); 136 } else { 137 EXPECT_CALL(host_callback_, OnDone(net::OK)) 138 .WillOnce(QuitThreadOnCounter(&callback_counter)); 139 } 140 141 // Ensure that .Run() does not run unbounded if the callbacks are never 142 // called. 143 base::Timer shutdown_timer(false, false); 144 shutdown_timer.Start(FROM_HERE, 145 TestTimeouts::action_timeout(), 146 base::MessageLoop::QuitClosure()); 147 message_loop_.Run(); 148 shutdown_timer.Stop(); 149 150 testing::Mock::VerifyAndClearExpectations(&client_callback_); 151 testing::Mock::VerifyAndClearExpectations(&host_callback_); 152 153 if (!expected_fail) { 154 ASSERT_TRUE(client_socket_.get() != NULL); 155 ASSERT_TRUE(host_socket_.get() != NULL); 156 } 157} 158 159void AuthenticatorTestBase::OnHostConnected( 160 net::Error error, 161 scoped_ptr<net::StreamSocket> socket) { 162 host_callback_.OnDone(error); 163 host_socket_ = socket.Pass(); 164} 165 166void AuthenticatorTestBase::OnClientConnected( 167 net::Error error, 168 scoped_ptr<net::StreamSocket> socket) { 169 client_callback_.OnDone(error); 170 client_socket_ = socket.Pass(); 171} 172 173} // namespace protocol 174} // namespace remoting 175