1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "components/copresence/rpc/rpc_handler.h"
6
7#include <map>
8#include <string>
9#include <vector>
10
11#include "base/bind.h"
12#include "base/bind_helpers.h"
13#include "base/message_loop/message_loop.h"
14#include "components/copresence/handlers/directive_handler.h"
15#include "components/copresence/proto/data.pb.h"
16#include "components/copresence/proto/enums.pb.h"
17#include "components/copresence/proto/rpcs.pb.h"
18#include "net/http/http_status_code.h"
19#include "testing/gtest/include/gtest/gtest.h"
20
21using google::protobuf::MessageLite;
22using google::protobuf::RepeatedPtrField;
23
24namespace copresence {
25
26namespace {
27
// Platform version string returned by the test delegate's
// GetPlatformVersionString(); request-header tests compare against it.
static const char kChromeVersion[] = "Chrome Version String";
29
30void CreateSubscribedMessage(const std::vector<std::string>& subscription_ids,
31                             const std::string& message_string,
32                             SubscribedMessage* message_proto) {
33  message_proto->mutable_published_message()->set_payload(message_string);
34  for (std::vector<std::string>::const_iterator subscription_id =
35           subscription_ids.begin();
36       subscription_id != subscription_ids.end();
37       ++subscription_id) {
38    message_proto->add_subscription_id(*subscription_id);
39  }
40}
41
42// TODO(ckehoe): Make DirectiveHandler an interface.
43class FakeDirectiveHandler : public DirectiveHandler {
44 public:
45  FakeDirectiveHandler() {}
46  virtual ~FakeDirectiveHandler() {}
47
48  const std::vector<Directive>& added_directives() const {
49    return added_directives_;
50  }
51
52  virtual void Initialize(
53      const AudioRecorder::DecodeSamplesCallback& decode_cb,
54      const AudioDirectiveHandler::EncodeTokenCallback& encode_cb) OVERRIDE {}
55
56  virtual void AddDirective(const Directive& directive) OVERRIDE {
57    added_directives_.push_back(directive);
58  }
59
60  virtual void RemoveDirectives(const std::string& op_id) OVERRIDE {
61    // TODO(ckehoe): Add a parallel implementation when prod has one.
62  }
63
64 private:
65  std::vector<Directive> added_directives_;
66
67  DISALLOW_COPY_AND_ASSIGN(FakeDirectiveHandler);
68};
69
70}  // namespace
71
72class RpcHandlerTest : public testing::Test, public CopresenceDelegate {
73 public:
74  RpcHandlerTest() : rpc_handler_(this), status_(SUCCESS), api_key_("API key") {
75    rpc_handler_.server_post_callback_ =
76        base::Bind(&RpcHandlerTest::CaptureHttpPost, base::Unretained(this));
77    rpc_handler_.device_id_ = "Device ID";
78  }
79
80  void CaptureHttpPost(
81      net::URLRequestContextGetter* url_context_getter,
82      const std::string& rpc_name,
83      scoped_ptr<MessageLite> request_proto,
84      const RpcHandler::PostCleanupCallback& response_callback) {
85    rpc_name_ = rpc_name;
86    request_proto_ = request_proto.Pass();
87  }
88
89  void CaptureStatus(CopresenceStatus status) {
90    status_ = status;
91  }
92
93  inline const ReportRequest* GetReportSent() {
94    return static_cast<ReportRequest*>(request_proto_.get());
95  }
96
97  const TokenTechnology& GetTokenTechnologyFromReport() {
98    return GetReportSent()->update_signals_request().state().capabilities()
99        .token_technology(0);
100  }
101
102  const RepeatedPtrField<PublishedMessage>& GetMessagesPublished() {
103    return GetReportSent()->manage_messages_request().message_to_publish();
104  }
105
106  const RepeatedPtrField<Subscription>& GetSubscriptionsSent() {
107    return GetReportSent()->manage_subscriptions_request().subscription();
108  }
109
110  void SetDeviceId(const std::string& device_id) {
111    rpc_handler_.device_id_ = device_id;
112  }
113
114  const std::string& GetDeviceId() {
115    return rpc_handler_.device_id_;
116  }
117
118  void AddInvalidToken(const std::string& token) {
119    rpc_handler_.invalid_audio_token_cache_.Add(token, true);
120  }
121
122  bool TokenIsInvalid(const std::string& token) {
123    return rpc_handler_.invalid_audio_token_cache_.HasKey(token);
124  }
125
126  FakeDirectiveHandler* InstallFakeDirectiveHandler() {
127    FakeDirectiveHandler* handler = new FakeDirectiveHandler;
128    rpc_handler_.directive_handler_.reset(handler);
129    return handler;
130  }
131
132  void InvokeReportResponseHandler(int status_code,
133                                   const std::string& response) {
134    rpc_handler_.ReportResponseHandler(
135        base::Bind(&RpcHandlerTest::CaptureStatus, base::Unretained(this)),
136        NULL,
137        status_code,
138        response);
139  }
140
141  // CopresenceDelegate implementation
142
143  virtual void HandleMessages(
144      const std::string& app_id,
145      const std::string& subscription_id,
146      const std::vector<Message>& messages) OVERRIDE {
147    // app_id is unused for now, pending a server fix.
148    messages_by_subscription_[subscription_id] = messages;
149  }
150
151  virtual net::URLRequestContextGetter* GetRequestContext() const OVERRIDE {
152    return NULL;
153  }
154
155  virtual const std::string GetPlatformVersionString() const OVERRIDE {
156    return kChromeVersion;
157  }
158
159  virtual const std::string GetAPIKey() const OVERRIDE {
160    return api_key_;
161  }
162
163  virtual WhispernetClient* GetWhispernetClient() OVERRIDE {
164    return NULL;
165  }
166
167 protected:
168  // For rpc_handler_.invalid_audio_token_cache_
169  base::MessageLoop message_loop_;
170
171  RpcHandler rpc_handler_;
172  CopresenceStatus status_;
173  std::string api_key_;
174
175  std::string rpc_name_;
176  scoped_ptr<MessageLite> request_proto_;
177  std::map<std::string, std::vector<Message>> messages_by_subscription_;
178};
179
180TEST_F(RpcHandlerTest, Initialize) {
181  SetDeviceId("");
182  rpc_handler_.Initialize(RpcHandler::SuccessCallback());
183  RegisterDeviceRequest* registration =
184      static_cast<RegisterDeviceRequest*>(request_proto_.get());
185  Identity identity = registration->device_identifiers().registrant();
186  EXPECT_EQ(CHROME, identity.type());
187  EXPECT_FALSE(identity.chrome_id().empty());
188}
189
190TEST_F(RpcHandlerTest, CreateRequestHeader) {
191  SetDeviceId("CreateRequestHeader Device ID");
192  rpc_handler_.SendReportRequest(make_scoped_ptr(new ReportRequest),
193                                 "CreateRequestHeader App ID",
194                                 StatusCallback());
195  EXPECT_EQ(RpcHandler::kReportRequestRpcName, rpc_name_);
196  ReportRequest* report = static_cast<ReportRequest*>(request_proto_.get());
197  EXPECT_EQ(kChromeVersion,
198            report->header().framework_version().version_name());
199  EXPECT_EQ("CreateRequestHeader App ID",
200            report->header().client_version().client());
201  EXPECT_EQ("CreateRequestHeader Device ID",
202            report->header().registered_device_id());
203  EXPECT_EQ(CHROME_PLATFORM_TYPE,
204            report->header().device_fingerprint().type());
205}
206
207TEST_F(RpcHandlerTest, ReportTokens) {
208  std::vector<AudioToken> test_tokens;
209  test_tokens.push_back(AudioToken("token 1", false));
210  test_tokens.push_back(AudioToken("token 2", true));
211  test_tokens.push_back(AudioToken("token 3", false));
212  AddInvalidToken("token 2");
213
214  rpc_handler_.ReportTokens(test_tokens);
215  EXPECT_EQ(RpcHandler::kReportRequestRpcName, rpc_name_);
216  ReportRequest* report = static_cast<ReportRequest*>(request_proto_.get());
217  google::protobuf::RepeatedPtrField<TokenObservation> tokens_sent =
218      report->update_signals_request().token_observation();
219  ASSERT_EQ(2, tokens_sent.size());
220  EXPECT_EQ("token 1", tokens_sent.Get(0).token_id());
221  EXPECT_EQ("token 3", tokens_sent.Get(1).token_id());
222}
223
224TEST_F(RpcHandlerTest, ReportResponseHandler) {
225  // Fail on HTTP status != 200.
226  ReportResponse empty_response;
227  empty_response.mutable_header()->mutable_status()->set_code(OK);
228  std::string serialized_empty_response;
229  ASSERT_TRUE(empty_response.SerializeToString(&serialized_empty_response));
230  status_ = SUCCESS;
231  InvokeReportResponseHandler(net::HTTP_BAD_REQUEST, serialized_empty_response);
232  EXPECT_EQ(FAIL, status_);
233
234  std::vector<std::string> subscription_1(1, "Subscription 1");
235  std::vector<std::string> subscription_2(1, "Subscription 2");
236  std::vector<std::string> both_subscriptions;
237  both_subscriptions.push_back("Subscription 1");
238  both_subscriptions.push_back("Subscription 2");
239
240  ReportResponse test_response;
241  test_response.mutable_header()->mutable_status()->set_code(OK);
242  UpdateSignalsResponse* update_response =
243      test_response.mutable_update_signals_response();
244  update_response->set_status(util::error::OK);
245  Token* invalid_token = update_response->add_token();
246  invalid_token->set_id("bad token");
247  invalid_token->set_status(INVALID);
248  CreateSubscribedMessage(
249      subscription_1, "Message A", update_response->add_message());
250  CreateSubscribedMessage(
251      subscription_2, "Message B", update_response->add_message());
252  CreateSubscribedMessage(
253      both_subscriptions, "Message C", update_response->add_message());
254  update_response->add_directive()->set_subscription_id("Subscription 1");
255  update_response->add_directive()->set_subscription_id("Subscription 2");
256
257  messages_by_subscription_.clear();
258  FakeDirectiveHandler* directive_handler = InstallFakeDirectiveHandler();
259  std::string serialized_proto;
260  ASSERT_TRUE(test_response.SerializeToString(&serialized_proto));
261  status_ = FAIL;
262  InvokeReportResponseHandler(net::HTTP_OK, serialized_proto);
263
264  EXPECT_EQ(SUCCESS, status_);
265  EXPECT_TRUE(TokenIsInvalid("bad token"));
266  ASSERT_EQ(2U, messages_by_subscription_.size());
267  ASSERT_EQ(2U, messages_by_subscription_["Subscription 1"].size());
268  ASSERT_EQ(2U, messages_by_subscription_["Subscription 2"].size());
269  EXPECT_EQ("Message A",
270            messages_by_subscription_["Subscription 1"][0].payload());
271  EXPECT_EQ("Message B",
272            messages_by_subscription_["Subscription 2"][0].payload());
273  EXPECT_EQ("Message C",
274            messages_by_subscription_["Subscription 1"][1].payload());
275  EXPECT_EQ("Message C",
276            messages_by_subscription_["Subscription 2"][1].payload());
277
278  ASSERT_EQ(2U, directive_handler->added_directives().size());
279  EXPECT_EQ("Subscription 1",
280            directive_handler->added_directives()[0].subscription_id());
281  EXPECT_EQ("Subscription 2",
282            directive_handler->added_directives()[1].subscription_id());
283}
284
285}  // namespace copresence
286