1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/pairing_registry.h"
6
7#include <stdlib.h>
8
9#include <algorithm>
10
11#include "base/bind.h"
12#include "base/compiler_specific.h"
13#include "base/memory/scoped_ptr.h"
14#include "base/message_loop/message_loop.h"
15#include "base/run_loop.h"
16#include "base/thread_task_runner_handle.h"
17#include "base/values.h"
18#include "remoting/protocol/protocol_mock_objects.h"
19#include "testing/gmock/include/gmock/gmock.h"
20#include "testing/gtest/include/gtest/gtest.h"
21
22using testing::Sequence;
23
24namespace {
25
26using remoting::protocol::PairingRegistry;
27
28class MockPairingRegistryCallbacks {
29 public:
30  MockPairingRegistryCallbacks() {}
31  virtual ~MockPairingRegistryCallbacks() {}
32
33  MOCK_METHOD1(DoneCallback, void(bool));
34  MOCK_METHOD1(GetAllPairingsCallbackPtr, void(base::ListValue*));
35  MOCK_METHOD1(GetPairingCallback, void(PairingRegistry::Pairing));
36
37  void GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings) {
38    GetAllPairingsCallbackPtr(pairings.get());
39  }
40
41 private:
42  DISALLOW_COPY_AND_ASSIGN(MockPairingRegistryCallbacks);
43};
44
45// Verify that a pairing Dictionary has correct entries, but doesn't include
46// any shared secret.
47void VerifyPairing(PairingRegistry::Pairing expected,
48                   const base::DictionaryValue& actual) {
49  std::string value;
50  EXPECT_TRUE(actual.GetString(PairingRegistry::kClientNameKey, &value));
51  EXPECT_EQ(expected.client_name(), value);
52  EXPECT_TRUE(actual.GetString(PairingRegistry::kClientIdKey, &value));
53  EXPECT_EQ(expected.client_id(), value);
54
55  EXPECT_FALSE(actual.HasKey(PairingRegistry::kSharedSecretKey));
56}
57
58}  // namespace
59
60namespace remoting {
61namespace protocol {
62
63class PairingRegistryTest : public testing::Test {
64 public:
65  virtual void SetUp() OVERRIDE {
66    callback_count_ = 0;
67  }
68
69  void set_pairings(scoped_ptr<base::ListValue> pairings) {
70    pairings_ = pairings.Pass();
71  }
72
73  void ExpectSecret(const std::string& expected,
74                    PairingRegistry::Pairing actual) {
75    EXPECT_EQ(expected, actual.shared_secret());
76    ++callback_count_;
77  }
78
79  void ExpectSaveSuccess(bool success) {
80    EXPECT_TRUE(success);
81    ++callback_count_;
82  }
83
84 protected:
85  base::MessageLoop message_loop_;
86  base::RunLoop run_loop_;
87
88  int callback_count_;
89  scoped_ptr<base::ListValue> pairings_;
90};
91
92TEST_F(PairingRegistryTest, CreateAndGetPairings) {
93  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
94      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
95  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("my_client");
96  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("my_client");
97
98  EXPECT_NE(pairing_1.shared_secret(), pairing_2.shared_secret());
99
100  registry->GetPairing(pairing_1.client_id(),
101                       base::Bind(&PairingRegistryTest::ExpectSecret,
102                                  base::Unretained(this),
103                                  pairing_1.shared_secret()));
104  EXPECT_EQ(1, callback_count_);
105
106  // Check that the second client is paired with a different shared secret.
107  registry->GetPairing(pairing_2.client_id(),
108                       base::Bind(&PairingRegistryTest::ExpectSecret,
109                                  base::Unretained(this),
110                                  pairing_2.shared_secret()));
111  EXPECT_EQ(2, callback_count_);
112}
113
114TEST_F(PairingRegistryTest, GetAllPairings) {
115  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
116      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
117  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
118  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
119
120  registry->GetAllPairings(
121      base::Bind(&PairingRegistryTest::set_pairings,
122                 base::Unretained(this)));
123
124  ASSERT_EQ(2u, pairings_->GetSize());
125  const base::DictionaryValue* actual_pairing_1;
126  const base::DictionaryValue* actual_pairing_2;
127  ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_1));
128  ASSERT_TRUE(pairings_->GetDictionary(1, &actual_pairing_2));
129
130  // Ordering is not guaranteed, so swap if necessary.
131  std::string actual_client_id;
132  ASSERT_TRUE(actual_pairing_1->GetString(PairingRegistry::kClientIdKey,
133                                          &actual_client_id));
134  if (actual_client_id != pairing_1.client_id()) {
135    std::swap(actual_pairing_1, actual_pairing_2);
136  }
137
138  VerifyPairing(pairing_1, *actual_pairing_1);
139  VerifyPairing(pairing_2, *actual_pairing_2);
140}
141
142TEST_F(PairingRegistryTest, DeletePairing) {
143  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
144      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
145  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
146  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
147
148  registry->DeletePairing(
149      pairing_1.client_id(),
150      base::Bind(&PairingRegistryTest::ExpectSaveSuccess,
151                 base::Unretained(this)));
152
153  // Re-read the list, and verify it only has the pairing_2 client.
154  registry->GetAllPairings(
155      base::Bind(&PairingRegistryTest::set_pairings,
156                 base::Unretained(this)));
157
158  ASSERT_EQ(1u, pairings_->GetSize());
159  const base::DictionaryValue* actual_pairing_2;
160  ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_2));
161  std::string actual_client_id;
162  ASSERT_TRUE(actual_pairing_2->GetString(PairingRegistry::kClientIdKey,
163                                          &actual_client_id));
164  EXPECT_EQ(pairing_2.client_id(), actual_client_id);
165}
166
167TEST_F(PairingRegistryTest, ClearAllPairings) {
168  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
169      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
170  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
171  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
172
173  registry->ClearAllPairings(
174      base::Bind(&PairingRegistryTest::ExpectSaveSuccess,
175                 base::Unretained(this)));
176
177  // Re-read the list, and verify it is empty.
178  registry->GetAllPairings(
179      base::Bind(&PairingRegistryTest::set_pairings,
180                 base::Unretained(this)));
181
182  EXPECT_TRUE(pairings_->empty());
183}
184
185ACTION_P(QuitMessageLoop, callback) {
186  callback.Run();
187}
188
189MATCHER_P(EqualsClientName, client_name, "") {
190  return arg.client_name() == client_name;
191}
192
193MATCHER(NoPairings, "") {
194  return arg->empty();
195}
196
// Issues a burst of requests against a real (asynchronous) PairingRegistry
// without running the message loop in between, then checks via a gmock
// Sequence that the callbacks arrive in the order the requests were issued
// and that each one observes the effects of the requests before it.
TEST_F(PairingRegistryTest, SerializedRequests) {
  MockPairingRegistryCallbacks callbacks;
  // Every expectation below is added to |s|, so gmock requires the callbacks
  // to fire in exactly this order.
  Sequence s;
  // Two initial lookups see both freshly created pairings.
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1")))
      .InSequence(s);
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client2")))
      .InSequence(s);
  // DeletePairing(pairing_2) succeeds.
  EXPECT_CALL(callbacks, DoneCallback(true))
      .InSequence(s);
  // client1 is still present; the deleted client2 lookup yields a pairing
  // with an empty client name.
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1")))
      .InSequence(s);
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("")))
      .InSequence(s);
  // ClearAllPairings() succeeds and the following listing is empty.
  EXPECT_CALL(callbacks, DoneCallback(true))
      .InSequence(s);
  EXPECT_CALL(callbacks, GetAllPairingsCallbackPtr(NoPairings()))
      .InSequence(s);
  // A pairing created after the clear can be looked up. This is the last
  // expected callback, so it quits |run_loop_|.
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client3")))
      .InSequence(s)
      .WillOnce(QuitMessageLoop(run_loop_.QuitClosure()));

  // Unlike the other tests, this uses the real PairingRegistry bound to the
  // current thread's task runner, so callbacks are delivered from |run_loop_|.
  scoped_refptr<PairingRegistry> registry = new PairingRegistry(
      base::ThreadTaskRunnerHandle::Get(),
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
  // The request order below must match the expectation order above.
  registry->GetPairing(
      pairing_1.client_id(),
      base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
                 base::Unretained(&callbacks)));
  registry->GetPairing(
      pairing_2.client_id(),
      base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
                 base::Unretained(&callbacks)));
  registry->DeletePairing(
      pairing_2.client_id(),
      base::Bind(&MockPairingRegistryCallbacks::DoneCallback,
                 base::Unretained(&callbacks)));
  registry->GetPairing(
      pairing_1.client_id(),
      base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
                 base::Unretained(&callbacks)));
  registry->GetPairing(
      pairing_2.client_id(),
      base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
                 base::Unretained(&callbacks)));
  registry->ClearAllPairings(
      base::Bind(&MockPairingRegistryCallbacks::DoneCallback,
                 base::Unretained(&callbacks)));
  registry->GetAllPairings(
      base::Bind(&MockPairingRegistryCallbacks::GetAllPairingsCallback,
                 base::Unretained(&callbacks)));
  PairingRegistry::Pairing pairing_3 = registry->CreatePairing("client3");
  registry->GetPairing(
      pairing_3.client_id(),
      base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
                 base::Unretained(&callbacks)));

  // Returns only once the client3 callback quits the loop. If that callback
  // never fires, this blocks rather than failing; presumably the test
  // launcher's timeout catches that case.
  run_loop_.Run();
}
257
258}  // namespace protocol
259}  // namespace remoting
260