1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/jingle_session.h"
6
7#include "base/bind.h"
8#include "base/message_loop/message_loop.h"
9#include "base/test/test_timeouts.h"
10#include "base/time/time.h"
11#include "net/socket/socket.h"
12#include "net/socket/stream_socket.h"
13#include "net/url_request/url_request_context_getter.h"
14#include "remoting/base/constants.h"
15#include "remoting/jingle_glue/chromium_port_allocator.h"
16#include "remoting/jingle_glue/fake_signal_strategy.h"
17#include "remoting/jingle_glue/network_settings.h"
18#include "remoting/protocol/authenticator.h"
19#include "remoting/protocol/channel_authenticator.h"
20#include "remoting/protocol/connection_tester.h"
21#include "remoting/protocol/fake_authenticator.h"
22#include "remoting/protocol/jingle_session_manager.h"
23#include "remoting/protocol/libjingle_transport_factory.h"
24#include "testing/gmock/include/gmock/gmock.h"
25#include "testing/gtest/include/gtest/gtest.h"
26
27using testing::_;
28using testing::AtLeast;
29using testing::AtMost;
30using testing::DeleteArg;
31using testing::DoAll;
32using testing::InSequence;
33using testing::Invoke;
34using testing::InvokeWithoutArgs;
35using testing::Return;
36using testing::SaveArg;
37using testing::SetArgumentPointee;
38using testing::WithArg;
39
40namespace remoting {
41namespace protocol {
42
43namespace {
44
// JIDs that the fake signaling layer assigns to each endpoint.
const char kHostJid[] = "host1@gmail.com/123";
const char kClientJid[] = "host2@gmail.com/321";

// Name of the stream channel opened by the channel tests.
const char kChannelName[] = "test_channel";

// Payload pushed through a channel by the data-transfer tests:
// kMessages messages of kMessageSize (1 KiB) bytes each.
const int kMessageSize = 1 << 10;
const int kMessages = 100;
53
54void QuitCurrentThread() {
55  base::MessageLoop::current()->PostTask(FROM_HERE,
56                                         base::MessageLoop::QuitClosure());
57}
58
59ACTION(QuitThread) {
60  QuitCurrentThread();
61}
62
63ACTION_P(QuitThreadOnCounter, counter) {
64  --(*counter);
65  EXPECT_GE(*counter, 0);
66  if (*counter == 0)
67    QuitCurrentThread();
68}
69
70class MockSessionManagerListener : public SessionManager::Listener {
71 public:
72  MOCK_METHOD0(OnSessionManagerReady, void());
73  MOCK_METHOD2(OnIncomingSession,
74               void(Session*,
75                    SessionManager::IncomingSessionResponse*));
76};
77
78class MockSessionEventHandler : public Session::EventHandler {
79 public:
80  MOCK_METHOD1(OnSessionStateChange, void(Session::State));
81  MOCK_METHOD2(OnSessionRouteChange, void(const std::string& channel_name,
82                                          const TransportRoute& route));
83};
84
85class MockStreamChannelCallback {
86 public:
87  MOCK_METHOD1(OnDone, void(net::StreamSocket* socket));
88};
89
90}  // namespace
91
92class JingleSessionTest : public testing::Test {
93 public:
94  JingleSessionTest() {
95    message_loop_.reset(new base::MessageLoopForIO());
96  }
97
98  // Helper method that handles OnIncomingSession().
99  void SetHostSession(Session* session) {
100    DCHECK(session);
101    host_session_.reset(session);
102    host_session_->SetEventHandler(&host_session_event_handler_);
103
104    session->set_config(SessionConfig::ForTest());
105  }
106
107  void OnClientChannelCreated(scoped_ptr<net::StreamSocket> socket) {
108    client_channel_callback_.OnDone(socket.get());
109    client_socket_ = socket.Pass();
110  }
111
112  void OnHostChannelCreated(scoped_ptr<net::StreamSocket> socket) {
113    host_channel_callback_.OnDone(socket.get());
114    host_socket_ = socket.Pass();
115  }
116
117 protected:
118  virtual void SetUp() {
119  }
120
121  virtual void TearDown() {
122    CloseSessions();
123    CloseSessionManager();
124    message_loop_->RunUntilIdle();
125  }
126
127  void CloseSessions() {
128    host_socket_.reset();
129    host_session_.reset();
130    client_socket_.reset();
131    client_session_.reset();
132  }
133
134  void CreateSessionManagers(int auth_round_trips,
135                        FakeAuthenticator::Action auth_action) {
136    host_signal_strategy_.reset(new FakeSignalStrategy(kHostJid));
137    client_signal_strategy_.reset(new FakeSignalStrategy(kClientJid));
138    FakeSignalStrategy::Connect(host_signal_strategy_.get(),
139                                client_signal_strategy_.get());
140
141    EXPECT_CALL(host_server_listener_, OnSessionManagerReady())
142        .Times(1);
143
144    NetworkSettings network_settings(NetworkSettings::NAT_TRAVERSAL_OUTGOING);
145
146    scoped_ptr<TransportFactory> host_transport(new LibjingleTransportFactory(
147        NULL,
148        ChromiumPortAllocator::Create(NULL, network_settings)
149            .PassAs<cricket::HttpPortAllocatorBase>(),
150        network_settings));
151    host_server_.reset(new JingleSessionManager(host_transport.Pass()));
152    host_server_->Init(host_signal_strategy_.get(), &host_server_listener_);
153
154    scoped_ptr<AuthenticatorFactory> factory(
155        new FakeHostAuthenticatorFactory(auth_round_trips, auth_action, true));
156    host_server_->set_authenticator_factory(factory.Pass());
157
158    EXPECT_CALL(client_server_listener_, OnSessionManagerReady())
159        .Times(1);
160    scoped_ptr<TransportFactory> client_transport(new LibjingleTransportFactory(
161        NULL,
162        ChromiumPortAllocator::Create(NULL, network_settings)
163            .PassAs<cricket::HttpPortAllocatorBase>(),
164        network_settings));
165    client_server_.reset(
166        new JingleSessionManager(client_transport.Pass()));
167    client_server_->Init(client_signal_strategy_.get(),
168                         &client_server_listener_);
169  }
170
171  void CloseSessionManager() {
172    if (host_server_.get()) {
173      host_server_->Close();
174      host_server_.reset();
175    }
176    if (client_server_.get()) {
177      client_server_->Close();
178      client_server_.reset();
179    }
180    host_signal_strategy_.reset();
181    client_signal_strategy_.reset();
182  }
183
184  void InitiateConnection(int auth_round_trips,
185                          FakeAuthenticator::Action auth_action,
186                          bool expect_fail) {
187    EXPECT_CALL(host_server_listener_, OnIncomingSession(_, _))
188        .WillOnce(DoAll(
189            WithArg<0>(Invoke(this, &JingleSessionTest::SetHostSession)),
190            SetArgumentPointee<1>(protocol::SessionManager::ACCEPT)));
191
192    {
193      InSequence dummy;
194
195      EXPECT_CALL(host_session_event_handler_,
196                  OnSessionStateChange(Session::CONNECTED))
197          .Times(AtMost(1));
198      if (expect_fail) {
199        EXPECT_CALL(host_session_event_handler_,
200                    OnSessionStateChange(Session::FAILED))
201            .Times(1);
202      } else {
203        EXPECT_CALL(host_session_event_handler_,
204                    OnSessionStateChange(Session::AUTHENTICATED))
205            .Times(1);
206        // Expect that the connection will be closed eventually.
207        EXPECT_CALL(host_session_event_handler_,
208                    OnSessionStateChange(Session::CLOSED))
209            .Times(AtMost(1));
210      }
211    }
212
213    {
214      InSequence dummy;
215
216      EXPECT_CALL(client_session_event_handler_,
217                  OnSessionStateChange(Session::CONNECTED))
218          .Times(AtMost(1));
219      if (expect_fail) {
220        EXPECT_CALL(client_session_event_handler_,
221                    OnSessionStateChange(Session::FAILED))
222            .Times(1);
223      } else {
224        EXPECT_CALL(client_session_event_handler_,
225                    OnSessionStateChange(Session::AUTHENTICATED))
226            .Times(1);
227        // Expect that the connection will be closed eventually.
228        EXPECT_CALL(client_session_event_handler_,
229                    OnSessionStateChange(Session::CLOSED))
230            .Times(AtMost(1));
231      }
232    }
233
234    scoped_ptr<Authenticator> authenticator(new FakeAuthenticator(
235        FakeAuthenticator::CLIENT, auth_round_trips, auth_action, true));
236
237    client_session_ = client_server_->Connect(
238        kHostJid, authenticator.Pass(),
239        CandidateSessionConfig::CreateDefault());
240    client_session_->SetEventHandler(&client_session_event_handler_);
241
242    message_loop_->RunUntilIdle();
243  }
244
245  void CreateChannel() {
246    client_session_->GetTransportChannelFactory()->CreateStreamChannel(
247        kChannelName, base::Bind(&JingleSessionTest::OnClientChannelCreated,
248                                 base::Unretained(this)));
249    host_session_->GetTransportChannelFactory()->CreateStreamChannel(
250        kChannelName, base::Bind(&JingleSessionTest::OnHostChannelCreated,
251                                 base::Unretained(this)));
252
253    int counter = 2;
254    ExpectRouteChange(kChannelName);
255    EXPECT_CALL(client_channel_callback_, OnDone(_))
256        .WillOnce(QuitThreadOnCounter(&counter));
257    EXPECT_CALL(host_channel_callback_, OnDone(_))
258        .WillOnce(QuitThreadOnCounter(&counter));
259    message_loop_->Run();
260
261    EXPECT_TRUE(client_socket_.get());
262    EXPECT_TRUE(host_socket_.get());
263  }
264
265  void ExpectRouteChange(const std::string& channel_name) {
266    EXPECT_CALL(host_session_event_handler_,
267                OnSessionRouteChange(channel_name, _))
268        .Times(AtLeast(1));
269    EXPECT_CALL(client_session_event_handler_,
270                OnSessionRouteChange(channel_name, _))
271        .Times(AtLeast(1));
272  }
273
274  scoped_ptr<base::MessageLoopForIO> message_loop_;
275
276  scoped_ptr<FakeSignalStrategy> host_signal_strategy_;
277  scoped_ptr<FakeSignalStrategy> client_signal_strategy_;
278
279  scoped_ptr<JingleSessionManager> host_server_;
280  MockSessionManagerListener host_server_listener_;
281  scoped_ptr<JingleSessionManager> client_server_;
282  MockSessionManagerListener client_server_listener_;
283
284  scoped_ptr<Session> host_session_;
285  MockSessionEventHandler host_session_event_handler_;
286  scoped_ptr<Session> client_session_;
287  MockSessionEventHandler client_session_event_handler_;
288
289  MockStreamChannelCallback client_channel_callback_;
290  MockStreamChannelCallback host_channel_callback_;
291
292  scoped_ptr<net::StreamSocket> client_socket_;
293  scoped_ptr<net::StreamSocket> host_socket_;
294};
295
296
297// Verify that we can create and destroy session managers without a
298// connection.
299TEST_F(JingleSessionTest, CreateAndDestoy) {
300  CreateSessionManagers(1, FakeAuthenticator::ACCEPT);
301}
302
303// Verify that an incoming session can be rejected, and that the
304// status of the connection is set to FAILED in this case.
305TEST_F(JingleSessionTest, RejectConnection) {
306  CreateSessionManagers(1, FakeAuthenticator::ACCEPT);
307
308  // Reject incoming session.
309  EXPECT_CALL(host_server_listener_, OnIncomingSession(_, _))
310      .WillOnce(SetArgumentPointee<1>(protocol::SessionManager::DECLINE));
311
312  {
313    InSequence dummy;
314    EXPECT_CALL(client_session_event_handler_,
315                OnSessionStateChange(Session::FAILED))
316        .Times(1);
317  }
318
319  scoped_ptr<Authenticator> authenticator(new FakeAuthenticator(
320      FakeAuthenticator::CLIENT, 1, FakeAuthenticator::ACCEPT, true));
321  client_session_ = client_server_->Connect(
322      kHostJid, authenticator.Pass(), CandidateSessionConfig::CreateDefault());
323  client_session_->SetEventHandler(&client_session_event_handler_);
324
325  message_loop_->RunUntilIdle();
326}
327
328// Verify that we can connect two endpoints with single-step authentication.
329TEST_F(JingleSessionTest, Connect) {
330  CreateSessionManagers(1, FakeAuthenticator::ACCEPT);
331  InitiateConnection(1, FakeAuthenticator::ACCEPT, false);
332
333  // Verify that the client specified correct initiator value.
334  ASSERT_GT(host_signal_strategy_->received_messages().size(), 0U);
335  const buzz::XmlElement* initiate_xml =
336      host_signal_strategy_->received_messages().front();
337  const buzz::XmlElement* jingle_element =
338      initiate_xml->FirstNamed(buzz::QName(kJingleNamespace, "jingle"));
339  ASSERT_TRUE(jingle_element);
340  ASSERT_EQ(kClientJid,
341            jingle_element->Attr(buzz::QName(std::string(), "initiator")));
342}
343
344// Verify that we can connect two endpoints with multi-step authentication.
345TEST_F(JingleSessionTest, ConnectWithMultistep) {
346  CreateSessionManagers(3, FakeAuthenticator::ACCEPT);
347  InitiateConnection(3, FakeAuthenticator::ACCEPT, false);
348}
349
350// Verify that connection is terminated when single-step auth fails.
351TEST_F(JingleSessionTest, ConnectWithBadAuth) {
352  CreateSessionManagers(1, FakeAuthenticator::REJECT);
353  InitiateConnection(1, FakeAuthenticator::ACCEPT, true);
354}
355
356// Verify that connection is terminated when multi-step auth fails.
357TEST_F(JingleSessionTest, ConnectWithBadMultistepAuth) {
358  CreateSessionManagers(3, FakeAuthenticator::REJECT);
359  InitiateConnection(3, FakeAuthenticator::ACCEPT, true);
360}
361
362// Verify that data can be sent over stream channel.
363TEST_F(JingleSessionTest, TestStreamChannel) {
364  CreateSessionManagers(1, FakeAuthenticator::ACCEPT);
365  ASSERT_NO_FATAL_FAILURE(
366      InitiateConnection(1, FakeAuthenticator::ACCEPT, false));
367
368  ASSERT_NO_FATAL_FAILURE(CreateChannel());
369
370  StreamConnectionTester tester(host_socket_.get(), client_socket_.get(),
371                                kMessageSize, kMessages);
372  tester.Start();
373  message_loop_->Run();
374  tester.CheckResults();
375}
376
377// Verify that data can be sent over a multiplexed channel.
378TEST_F(JingleSessionTest, TestMuxStreamChannel) {
379  CreateSessionManagers(1, FakeAuthenticator::ACCEPT);
380  ASSERT_NO_FATAL_FAILURE(
381      InitiateConnection(1, FakeAuthenticator::ACCEPT, false));
382
383  client_session_->GetMultiplexedChannelFactory()->CreateStreamChannel(
384      kChannelName, base::Bind(&JingleSessionTest::OnClientChannelCreated,
385                               base::Unretained(this)));
386  host_session_->GetMultiplexedChannelFactory()->CreateStreamChannel(
387      kChannelName, base::Bind(&JingleSessionTest::OnHostChannelCreated,
388                               base::Unretained(this)));
389
390  int counter = 2;
391  ExpectRouteChange("mux");
392  EXPECT_CALL(client_channel_callback_, OnDone(_))
393      .WillOnce(QuitThreadOnCounter(&counter));
394  EXPECT_CALL(host_channel_callback_, OnDone(_))
395      .WillOnce(QuitThreadOnCounter(&counter));
396  message_loop_->Run();
397
398  EXPECT_TRUE(client_socket_.get());
399  EXPECT_TRUE(host_socket_.get());
400
401  StreamConnectionTester tester(host_socket_.get(), client_socket_.get(),
402                                kMessageSize, kMessages);
403  tester.Start();
404  message_loop_->Run();
405  tester.CheckResults();
406}
407
408// Verify that we can connect channels with multistep auth.
409TEST_F(JingleSessionTest, TestMultistepAuthStreamChannel) {
410  CreateSessionManagers(3, FakeAuthenticator::ACCEPT);
411  ASSERT_NO_FATAL_FAILURE(
412      InitiateConnection(3, FakeAuthenticator::ACCEPT, false));
413
414  ASSERT_NO_FATAL_FAILURE(CreateChannel());
415
416  StreamConnectionTester tester(host_socket_.get(), client_socket_.get(),
417                                kMessageSize, kMessages);
418  tester.Start();
419  message_loop_->Run();
420  tester.CheckResults();
421}
422
423// Verify that we shutdown properly when channel authentication fails.
424TEST_F(JingleSessionTest, TestFailedChannelAuth) {
425  CreateSessionManagers(1, FakeAuthenticator::REJECT_CHANNEL);
426  ASSERT_NO_FATAL_FAILURE(
427      InitiateConnection(1, FakeAuthenticator::ACCEPT, false));
428
429  client_session_->GetTransportChannelFactory()->CreateStreamChannel(
430      kChannelName, base::Bind(&JingleSessionTest::OnClientChannelCreated,
431                               base::Unretained(this)));
432  host_session_->GetTransportChannelFactory()->CreateStreamChannel(
433      kChannelName, base::Bind(&JingleSessionTest::OnHostChannelCreated,
434                               base::Unretained(this)));
435
436  // Terminate the message loop when we get rejection notification
437  // from the host.
438  EXPECT_CALL(host_channel_callback_, OnDone(NULL))
439      .WillOnce(QuitThread());
440  EXPECT_CALL(client_channel_callback_, OnDone(_))
441      .Times(AtMost(1));
442  ExpectRouteChange(kChannelName);
443
444  message_loop_->Run();
445
446  EXPECT_TRUE(!host_socket_.get());
447}
448
449}  // namespace protocol
450}  // namespace remoting
451