1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/host/heartbeat_sender.h"
6
7#include <set>
8
9#include "base/memory/ref_counted.h"
10#include "base/message_loop/message_loop.h"
11#include "base/message_loop/message_loop_proxy.h"
12#include "base/run_loop.h"
13#include "base/strings/string_number_conversions.h"
14#include "remoting/base/constants.h"
15#include "remoting/base/rsa_key_pair.h"
16#include "remoting/base/test_rsa_key_pair.h"
17#include "remoting/signaling/iq_sender.h"
18#include "remoting/signaling/mock_signal_strategy.h"
19#include "testing/gmock/include/gmock/gmock.h"
20#include "testing/gtest/include/gtest/gtest.h"
21#include "third_party/libjingle/source/talk/xmpp/constants.h"
22#include "third_party/webrtc/libjingle/xmllite/xmlelement.h"
23
24using buzz::QName;
25using buzz::XmlElement;
26
27using testing::_;
28using testing::DeleteArg;
29using testing::DoAll;
30using testing::Invoke;
31using testing::NotNull;
32using testing::Return;
33using testing::SaveArg;
34
35namespace remoting {
36
37namespace {
38
39const char kTestBotJid[] = "remotingunittest@bot.talk.google.com";
40const char kHostId[] = "0";
41const char kTestJid[] = "user@gmail.com/chromoting123";
42const char kStanzaId[] = "123";
43
44class MockListener : public HeartbeatSender::Listener {
45 public:
46  // Overridden from HeartbeatSender::Listener
47  virtual void OnUnknownHostIdError() OVERRIDE {
48    NOTREACHED();
49  }
50
51  // Overridden from HeartbeatSender::Listener
52  MOCK_METHOD0(OnHeartbeatSuccessful, void());
53};
54
55}  // namespace
56
57ACTION_P(AddListener, list) {
58  list->insert(arg0);
59}
60ACTION_P(RemoveListener, list) {
61  EXPECT_TRUE(list->find(arg0) != list->end());
62  list->erase(arg0);
63}
64
65class HeartbeatSenderTest
66    : public testing::Test {
67 protected:
68  virtual void SetUp() OVERRIDE {
69    key_pair_ = RsaKeyPair::FromString(kTestRsaKeyPair);
70    ASSERT_TRUE(key_pair_.get());
71
72    EXPECT_CALL(signal_strategy_, GetState())
73        .WillOnce(Return(SignalStrategy::DISCONNECTED));
74    EXPECT_CALL(signal_strategy_, AddListener(NotNull()))
75        .WillRepeatedly(AddListener(&signal_strategy_listeners_));
76    EXPECT_CALL(signal_strategy_, RemoveListener(NotNull()))
77        .WillRepeatedly(RemoveListener(&signal_strategy_listeners_));
78    EXPECT_CALL(signal_strategy_, GetLocalJid())
79        .WillRepeatedly(Return(kTestJid));
80
81    heartbeat_sender_.reset(new HeartbeatSender(
82        &mock_listener_, kHostId, &signal_strategy_, key_pair_, kTestBotJid));
83  }
84
85  virtual void TearDown() OVERRIDE {
86    heartbeat_sender_.reset();
87    EXPECT_TRUE(signal_strategy_listeners_.empty());
88  }
89
90  void ValidateHeartbeatStanza(XmlElement* stanza,
91                               const char* expectedSequenceId);
92
93  base::MessageLoop message_loop_;
94  MockSignalStrategy signal_strategy_;
95  MockListener mock_listener_;
96  std::set<SignalStrategy::Listener*> signal_strategy_listeners_;
97  scoped_refptr<RsaKeyPair> key_pair_;
98  scoped_ptr<HeartbeatSender> heartbeat_sender_;
99};
100
101// Call Start() followed by Stop(), and make sure a valid heartbeat is sent.
102TEST_F(HeartbeatSenderTest, DoSendStanza) {
103  XmlElement* sent_iq = NULL;
104  EXPECT_CALL(signal_strategy_, GetLocalJid())
105      .WillRepeatedly(Return(kTestJid));
106  EXPECT_CALL(signal_strategy_, GetNextId())
107      .WillOnce(Return(kStanzaId));
108  EXPECT_CALL(signal_strategy_, SendStanzaPtr(NotNull()))
109      .WillOnce(DoAll(SaveArg<0>(&sent_iq), Return(true)));
110
111  heartbeat_sender_->OnSignalStrategyStateChange(SignalStrategy::CONNECTED);
112  base::RunLoop().RunUntilIdle();
113
114  scoped_ptr<XmlElement> stanza(sent_iq);
115  ASSERT_TRUE(stanza != NULL);
116  ValidateHeartbeatStanza(stanza.get(), "0");
117
118  heartbeat_sender_->OnSignalStrategyStateChange(SignalStrategy::DISCONNECTED);
119  base::RunLoop().RunUntilIdle();
120}
121
122// Call Start() followed by Stop(), twice, and make sure two valid heartbeats
123// are sent, with the correct sequence IDs.
124TEST_F(HeartbeatSenderTest, DoSendStanzaTwice) {
125  XmlElement* sent_iq = NULL;
126  EXPECT_CALL(signal_strategy_, GetLocalJid())
127      .WillRepeatedly(Return(kTestJid));
128  EXPECT_CALL(signal_strategy_, GetNextId())
129      .WillOnce(Return(kStanzaId));
130  EXPECT_CALL(signal_strategy_, SendStanzaPtr(NotNull()))
131      .WillOnce(DoAll(SaveArg<0>(&sent_iq), Return(true)));
132
133  heartbeat_sender_->OnSignalStrategyStateChange(SignalStrategy::CONNECTED);
134  base::RunLoop().RunUntilIdle();
135
136  scoped_ptr<XmlElement> stanza(sent_iq);
137  ASSERT_TRUE(stanza != NULL);
138  ValidateHeartbeatStanza(stanza.get(), "0");
139
140  heartbeat_sender_->OnSignalStrategyStateChange(SignalStrategy::DISCONNECTED);
141  base::RunLoop().RunUntilIdle();
142
143  EXPECT_CALL(signal_strategy_, GetLocalJid())
144      .WillRepeatedly(Return(kTestJid));
145  EXPECT_CALL(signal_strategy_, GetNextId())
146      .WillOnce(Return(kStanzaId + 1));
147  EXPECT_CALL(signal_strategy_, SendStanzaPtr(NotNull()))
148      .WillOnce(DoAll(SaveArg<0>(&sent_iq), Return(true)));
149
150  heartbeat_sender_->OnSignalStrategyStateChange(SignalStrategy::CONNECTED);
151  base::RunLoop().RunUntilIdle();
152
153  scoped_ptr<XmlElement> stanza2(sent_iq);
154  ValidateHeartbeatStanza(stanza2.get(), "1");
155
156  heartbeat_sender_->OnSignalStrategyStateChange(SignalStrategy::DISCONNECTED);
157  base::RunLoop().RunUntilIdle();
158}
159
160// Call Start() followed by Stop(), make sure a valid Iq stanza is sent,
161// reply with an expected sequence ID, and make sure two valid heartbeats
162// are sent, with the correct sequence IDs.
163TEST_F(HeartbeatSenderTest, DoSendStanzaWithExpectedSequenceId) {
164  XmlElement* sent_iq = NULL;
165  EXPECT_CALL(signal_strategy_, GetLocalJid())
166      .WillRepeatedly(Return(kTestJid));
167  EXPECT_CALL(signal_strategy_, GetNextId())
168      .WillOnce(Return(kStanzaId));
169  EXPECT_CALL(signal_strategy_, SendStanzaPtr(NotNull()))
170      .WillOnce(DoAll(SaveArg<0>(&sent_iq), Return(true)));
171
172  heartbeat_sender_->OnSignalStrategyStateChange(SignalStrategy::CONNECTED);
173  base::RunLoop().RunUntilIdle();
174
175  scoped_ptr<XmlElement> stanza(sent_iq);
176  ASSERT_TRUE(stanza != NULL);
177  ValidateHeartbeatStanza(stanza.get(), "0");
178
179  XmlElement* sent_iq2 = NULL;
180  EXPECT_CALL(signal_strategy_, GetLocalJid())
181      .WillRepeatedly(Return(kTestJid));
182  EXPECT_CALL(signal_strategy_, GetNextId())
183      .WillOnce(Return(kStanzaId + 1));
184  EXPECT_CALL(signal_strategy_, SendStanzaPtr(NotNull()))
185      .WillOnce(DoAll(SaveArg<0>(&sent_iq2), Return(true)));
186  EXPECT_CALL(mock_listener_, OnHeartbeatSuccessful());
187
188  scoped_ptr<XmlElement> response(new XmlElement(buzz::QN_IQ));
189  response->AddAttr(QName(std::string(), "type"), "result");
190  XmlElement* result =
191      new XmlElement(QName(kChromotingXmlNamespace, "heartbeat-result"));
192  response->AddElement(result);
193  XmlElement* expected_sequence_id = new XmlElement(
194      QName(kChromotingXmlNamespace, "expected-sequence-id"));
195  result->AddElement(expected_sequence_id);
196  const int kExpectedSequenceId = 456;
197  expected_sequence_id->AddText(base::IntToString(kExpectedSequenceId));
198  heartbeat_sender_->ProcessResponse(NULL, response.get());
199  base::RunLoop().RunUntilIdle();
200
201  scoped_ptr<XmlElement> stanza2(sent_iq2);
202  ASSERT_TRUE(stanza2 != NULL);
203  ValidateHeartbeatStanza(stanza2.get(),
204                          base::IntToString(kExpectedSequenceId).c_str());
205
206  heartbeat_sender_->OnSignalStrategyStateChange(SignalStrategy::DISCONNECTED);
207  base::RunLoop().RunUntilIdle();
208}
209
210// Verify that ProcessResponse parses set-interval result.
211TEST_F(HeartbeatSenderTest, ProcessResponseSetInterval) {
212  EXPECT_CALL(mock_listener_, OnHeartbeatSuccessful());
213
214  scoped_ptr<XmlElement> response(new XmlElement(buzz::QN_IQ));
215  response->AddAttr(QName(std::string(), "type"), "result");
216
217  XmlElement* result = new XmlElement(
218      QName(kChromotingXmlNamespace, "heartbeat-result"));
219  response->AddElement(result);
220
221  XmlElement* set_interval = new XmlElement(
222      QName(kChromotingXmlNamespace, "set-interval"));
223  result->AddElement(set_interval);
224
225  const int kTestInterval = 123;
226  set_interval->AddText(base::IntToString(kTestInterval));
227
228  heartbeat_sender_->ProcessResponse(NULL, response.get());
229
230  EXPECT_EQ(kTestInterval * 1000, heartbeat_sender_->interval_ms_);
231}
232
233// Validate a heartbeat stanza.
234void HeartbeatSenderTest::ValidateHeartbeatStanza(
235    XmlElement* stanza, const char* expectedSequenceId) {
236  EXPECT_EQ(stanza->Attr(buzz::QName(std::string(), "to")),
237            std::string(kTestBotJid));
238  EXPECT_EQ(stanza->Attr(buzz::QName(std::string(), "type")), "set");
239  XmlElement* heartbeat_stanza =
240      stanza->FirstNamed(QName(kChromotingXmlNamespace, "heartbeat"));
241  ASSERT_TRUE(heartbeat_stanza != NULL);
242  EXPECT_EQ(expectedSequenceId, heartbeat_stanza->Attr(
243      buzz::QName(kChromotingXmlNamespace, "sequence-id")));
244  EXPECT_EQ(std::string(kHostId),
245            heartbeat_stanza->Attr(QName(kChromotingXmlNamespace, "hostid")));
246
247  QName signature_tag(kChromotingXmlNamespace, "signature");
248  XmlElement* signature = heartbeat_stanza->FirstNamed(signature_tag);
249  ASSERT_TRUE(signature != NULL);
250  EXPECT_TRUE(heartbeat_stanza->NextNamed(signature_tag) == NULL);
251
252  scoped_refptr<RsaKeyPair> key_pair = RsaKeyPair::FromString(kTestRsaKeyPair);
253  ASSERT_TRUE(key_pair.get());
254  std::string expected_signature =
255      key_pair->SignMessage(std::string(kTestJid) + ' ' + expectedSequenceId);
256  EXPECT_EQ(expected_signature, signature->BodyText());
257}
258
259}  // namespace remoting
260