client_side_detection_service_unittest.cc revision 72a454cd3513ac24fbdd0e0cb9ad70b86a99b801
1// Copyright (c) 2010 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <map>
6#include <queue>
7#include <string>
8
9#include "base/callback.h"
10#include "base/file_path.h"
11#include "base/file_util.h"
12#include "base/file_util_proxy.h"
13#include "base/logging.h"
14#include "base/message_loop.h"
15#include "base/platform_file.h"
16#include "base/scoped_ptr.h"
17#include "base/scoped_temp_dir.h"
18#include "base/task.h"
19#include "base/time.h"
20#include "testing/gtest/include/gtest/gtest.h"
21#include "chrome/browser/browser_thread.h"
22#include "chrome/browser/renderer_host/test/test_render_view_host.h"
23#include "chrome/browser/safe_browsing/client_side_detection_service.h"
24#include "chrome/browser/safe_browsing/csd.pb.h"
25#include "chrome/common/render_messages.h"
26#include "chrome/common/net/test_url_fetcher_factory.h"
27#include "chrome/common/net/url_fetcher.h"
28#include "googleurl/src/gurl.h"
29#include "ipc/ipc_channel.h"
30#include "ipc/ipc_test_sink.h"
31#include "net/url_request/url_request_status.h"
32
33namespace safe_browsing {
34
35class ClientSideDetectionServiceTest : public testing::Test {
36 protected:
37  virtual void SetUp() {
38    file_thread_.reset(new BrowserThread(BrowserThread::FILE, &msg_loop_));
39
40    factory_.reset(new FakeURLFetcherFactory());
41    URLFetcher::set_factory(factory_.get());
42
43    browser_thread_.reset(new BrowserThread(BrowserThread::UI, &msg_loop_));
44  }
45
46  virtual void TearDown() {
47    msg_loop_.RunAllPending();
48    csd_service_.reset();
49    URLFetcher::set_factory(NULL);
50    file_thread_.reset();
51    browser_thread_.reset();
52  }
53
54  base::PlatformFile GetModelFile() {
55    model_file_ = base::kInvalidPlatformFileValue;
56    csd_service_->GetModelFile(NewCallback(
57        this, &ClientSideDetectionServiceTest::GetModelFileDone));
58    // This method will block this thread until GetModelFileDone is called.
59    msg_loop_.Run();
60    return model_file_;
61  }
62
63  std::string ReadModelFile(base::PlatformFile model_file) {
64    char buf[1024];
65    int n = base::ReadPlatformFile(model_file, 0, buf, 1024);
66    EXPECT_LE(0, n);
67    return (n < 0) ? "" : std::string(buf, n);
68  }
69
70  bool SendClientReportPhishingRequest(const GURL& phishing_url,
71                                       double score) {
72    csd_service_->SendClientReportPhishingRequest(
73        phishing_url,
74        score,
75        NewCallback(this, &ClientSideDetectionServiceTest::SendRequestDone));
76    phishing_url_ = phishing_url;
77    msg_loop_.Run();  // Waits until callback is called.
78    return is_phishing_;
79  }
80
81  void SetModelFetchResponse(std::string response_data, bool success) {
82    factory_->SetFakeResponse(ClientSideDetectionService::kClientModelUrl,
83                              response_data, success);
84  }
85
86  void SetClientReportPhishingResponse(std::string response_data,
87                                       bool success) {
88    factory_->SetFakeResponse(
89        ClientSideDetectionService::kClientReportPhishingUrl,
90        response_data, success);
91  }
92
93  int GetNumReports() {
94    return csd_service_->GetNumReports();
95  }
96
97  std::queue<base::Time>& GetPhishingReportTimes() {
98    return csd_service_->phishing_report_times_;
99  }
100
101  void SetCache(const GURL& gurl, bool is_phishing, base::Time time) {
102    csd_service_->cache_[gurl] =
103        make_linked_ptr(new ClientSideDetectionService::CacheState(is_phishing,
104                                                                   time));
105  }
106
107  void TestCache() {
108    ClientSideDetectionService::PhishingCache& cache = csd_service_->cache_;
109    base::Time now = base::Time::Now();
110    base::Time time = now - ClientSideDetectionService::kNegativeCacheInterval +
111        base::TimeDelta::FromMinutes(5);
112    cache[GURL("http://first.url.com/")] =
113        make_linked_ptr(new ClientSideDetectionService::CacheState(false,
114                                                                   time));
115
116    time = now - ClientSideDetectionService::kNegativeCacheInterval -
117        base::TimeDelta::FromHours(1);
118    cache[GURL("http://second.url.com/")] =
119        make_linked_ptr(new ClientSideDetectionService::CacheState(false,
120                                                                   time));
121
122    time = now - ClientSideDetectionService::kPositiveCacheInterval -
123        base::TimeDelta::FromMinutes(5);
124    cache[GURL("http://third.url.com/")] =
125        make_linked_ptr(new ClientSideDetectionService::CacheState(true, time));
126
127    time = now - ClientSideDetectionService::kPositiveCacheInterval +
128        base::TimeDelta::FromMinutes(5);
129    cache[GURL("http://fourth.url.com/")] =
130        make_linked_ptr(new ClientSideDetectionService::CacheState(true, time));
131
132    csd_service_->UpdateCache();
133
134    // 3 elements should be in the cache, the first, third, and fourth.
135    EXPECT_EQ(3U, cache.size());
136    EXPECT_TRUE(cache.find(GURL("http://first.url.com/")) != cache.end());
137    EXPECT_TRUE(cache.find(GURL("http://third.url.com/")) != cache.end());
138    EXPECT_TRUE(cache.find(GURL("http://fourth.url.com/")) != cache.end());
139
140    // While 3 elements remain, only the first and the fourth are actually
141    // valid.
142    bool is_phishing;
143    EXPECT_TRUE(csd_service_->GetCachedResult(GURL("http://first.url.com"),
144                                              &is_phishing));
145    EXPECT_FALSE(is_phishing);
146    EXPECT_FALSE(csd_service_->GetCachedResult(GURL("http://third.url.com"),
147                                               &is_phishing));
148    EXPECT_TRUE(csd_service_->GetCachedResult(GURL("http://fourth.url.com"),
149                                              &is_phishing));
150    EXPECT_TRUE(is_phishing);
151  }
152
153 protected:
154  scoped_ptr<ClientSideDetectionService> csd_service_;
155  scoped_ptr<FakeURLFetcherFactory> factory_;
156  MessageLoop msg_loop_;
157
158 private:
159  void GetModelFileDone(base::PlatformFile model_file) {
160    model_file_ = model_file;
161    msg_loop_.Quit();
162  }
163
164  void SendRequestDone(GURL phishing_url, bool is_phishing) {
165    ASSERT_EQ(phishing_url, phishing_url_);
166    is_phishing_ = is_phishing;
167    msg_loop_.Quit();
168  }
169
170  scoped_ptr<BrowserThread> browser_thread_;
171  base::PlatformFile model_file_;
172  scoped_ptr<BrowserThread> file_thread_;
173
174  GURL phishing_url_;
175  bool is_phishing_;
176};
177
178TEST_F(ClientSideDetectionServiceTest, TestFetchingModel) {
179  ScopedTempDir tmp_dir;
180  ASSERT_TRUE(tmp_dir.CreateUniqueTempDir());
181  FilePath model_path = tmp_dir.path().AppendASCII("model");
182
183  // The first time we create the csd service the model file does not exist so
184  // we expect there to be a fetch.
185  SetModelFetchResponse("BOGUS MODEL", true);
186  csd_service_.reset(ClientSideDetectionService::Create(model_path, NULL));
187  base::PlatformFile model_file = GetModelFile();
188  EXPECT_NE(model_file, base::kInvalidPlatformFileValue);
189  EXPECT_EQ(ReadModelFile(model_file), "BOGUS MODEL");
190
191  // If you call GetModelFile() multiple times you always get the same platform
192  // file back.  We don't re-open the file.
193  EXPECT_EQ(GetModelFile(), model_file);
194
195  // The second time the model already exists on disk.  In this case there
196  // should not be any fetch.  To ensure that we clear the factory.
197  factory_->ClearFakeReponses();
198  csd_service_.reset(ClientSideDetectionService::Create(model_path, NULL));
199  model_file = GetModelFile();
200  EXPECT_NE(model_file, base::kInvalidPlatformFileValue);
201  EXPECT_EQ(ReadModelFile(model_file), "BOGUS MODEL");
202
203  // If the model does not exist and the fetch fails we should get an error.
204  model_path = tmp_dir.path().AppendASCII("another_model");
205  SetModelFetchResponse("", false /* success */);
206  csd_service_.reset(ClientSideDetectionService::Create(model_path, NULL));
207  EXPECT_EQ(GetModelFile(), base::kInvalidPlatformFileValue);
208}
209
210TEST_F(ClientSideDetectionServiceTest, ServiceObjectDeletedBeforeCallbackDone) {
211  SetModelFetchResponse("bogus model", true /* success */);
212  ScopedTempDir tmp_dir;
213  ASSERT_TRUE(tmp_dir.CreateUniqueTempDir());
214  csd_service_.reset(ClientSideDetectionService::Create(
215      tmp_dir.path().AppendASCII("model"), NULL));
216  EXPECT_TRUE(csd_service_.get() != NULL);
217  // We delete the client-side detection service class even though the callbacks
218  // haven't run yet.
219  csd_service_.reset();
220  // Waiting for the callbacks to run should not crash even if the service
221  // object is gone.
222  msg_loop_.RunAllPending();
223}
224
225TEST_F(ClientSideDetectionServiceTest, SendClientReportPhishingRequest) {
226  SetModelFetchResponse("bogus model", true /* success */);
227  ScopedTempDir tmp_dir;
228  ASSERT_TRUE(tmp_dir.CreateUniqueTempDir());
229  csd_service_.reset(ClientSideDetectionService::Create(
230      tmp_dir.path().AppendASCII("model"), NULL));
231
232  GURL url("http://a.com/");
233  double score = 0.4;  // Some random client score.
234
235  base::Time before = base::Time::Now();
236
237  // Invalid response body from the server.
238  SetClientReportPhishingResponse("invalid proto response", true /* success */);
239  EXPECT_FALSE(SendClientReportPhishingRequest(url, score));
240
241  // Normal behavior.
242  ClientPhishingResponse response;
243  response.set_phishy(true);
244  SetClientReportPhishingResponse(response.SerializeAsString(),
245                                  true /* success */);
246  EXPECT_TRUE(SendClientReportPhishingRequest(url, score));
247
248  // Caching causes this to still count as phishy.
249  response.set_phishy(false);
250  SetClientReportPhishingResponse(response.SerializeAsString(),
251                                  true /* success */);
252  EXPECT_TRUE(SendClientReportPhishingRequest(url, score));
253
254  // This request will fail and should not be cached.
255  GURL second_url("http://b.com/");
256  response.set_phishy(false);
257  SetClientReportPhishingResponse(response.SerializeAsString(),
258                                  false /* success*/);
259  EXPECT_FALSE(SendClientReportPhishingRequest(second_url, score));
260
261  // Verify that the previous request was not cached.
262  response.set_phishy(true);
263  SetClientReportPhishingResponse(response.SerializeAsString(),
264                                  true /* success */);
265  EXPECT_TRUE(SendClientReportPhishingRequest(second_url, score));
266
267  // This request is blocked because it's not in the cache and we have more
268  // than 3 requests.
269  GURL third_url("http://c.com");
270  response.set_phishy(true);
271  SetClientReportPhishingResponse(response.SerializeAsString(),
272                                  true /* success */);
273  EXPECT_FALSE(SendClientReportPhishingRequest(third_url, score));
274
275  // Verify that caching still works even when new requests are blocked.
276  response.set_phishy(true);
277  SetClientReportPhishingResponse(response.SerializeAsString(),
278                                  true /* success */);
279  EXPECT_TRUE(SendClientReportPhishingRequest(url, score));
280
281  // Verify that we allow cache refreshing even when requests are blocked.
282  base::Time cache_time = base::Time::Now() - base::TimeDelta::FromHours(1);
283  SetCache(second_url, true, cache_time);
284
285  // Even though this element is in the cache, it's not currently valid so
286  // we make request and return that value instead.
287  response.set_phishy(false);
288  SetClientReportPhishingResponse(response.SerializeAsString(),
289                                  true /* success */);
290  EXPECT_FALSE(SendClientReportPhishingRequest(second_url, score));
291
292  base::Time after = base::Time::Now();
293
294  // Check that we have recorded 5 requests, all within the correct time range.
295  // The blocked request and the cached requests should not be present.
296  std::queue<base::Time>& report_times = GetPhishingReportTimes();
297  EXPECT_EQ(5U, report_times.size());
298  while (!report_times.empty()) {
299    base::Time time = report_times.back();
300    report_times.pop();
301    EXPECT_LE(before, time);
302    EXPECT_GE(after, time);
303  }
304}
305
306TEST_F(ClientSideDetectionServiceTest, GetNumReportTest) {
307  SetModelFetchResponse("bogus model", true /* success */);
308  ScopedTempDir tmp_dir;
309  ASSERT_TRUE(tmp_dir.CreateUniqueTempDir());
310  csd_service_.reset(ClientSideDetectionService::Create(
311      tmp_dir.path().AppendASCII("model"), NULL));
312
313  std::queue<base::Time>& report_times = GetPhishingReportTimes();
314  base::Time now = base::Time::Now();
315  base::TimeDelta twenty_five_hours = base::TimeDelta::FromHours(25);
316  report_times.push(now - twenty_five_hours);
317  report_times.push(now - twenty_five_hours);
318  report_times.push(now);
319  report_times.push(now);
320
321  EXPECT_EQ(2, GetNumReports());
322}
323
324TEST_F(ClientSideDetectionServiceTest, CacheTest) {
325  SetModelFetchResponse("bogus model", true /* success */);
326  ScopedTempDir tmp_dir;
327  ASSERT_TRUE(tmp_dir.CreateUniqueTempDir());
328  csd_service_.reset(ClientSideDetectionService::Create(
329      tmp_dir.path().AppendASCII("model"), NULL));
330
331  TestCache();
332}
333
334// We use a separate test fixture for testing the ClientSideDetectionService's
335// handling of load notifications from TabContents.  This uses
336// RenderViewHostTestHarness to set up a fake TabContents and related objects.
337class ClientSideDetectionServiceHooksTest : public RenderViewHostTestHarness,
338                                            public IPC::Channel::Listener {
339 public:
340  // IPC::Channel::Listener
341  virtual bool OnMessageReceived(const IPC::Message& msg) {
342    if (msg.type() == ViewMsg_StartPhishingDetection::ID) {
343      received_msg_ = msg;
344      did_receive_msg_ = true;
345      return true;
346    }
347    return false;
348  }
349
350 protected:
351  virtual void SetUp() {
352    RenderViewHostTestHarness::SetUp();
353    file_thread_.reset(new BrowserThread(BrowserThread::FILE, &message_loop_));
354    ui_thread_.reset(new BrowserThread(BrowserThread::UI, &message_loop_));
355
356    // We're not exercising model fetching here, so just set up a canned
357    // success response.
358    factory_.reset(new FakeURLFetcherFactory());
359    factory_->SetFakeResponse(ClientSideDetectionService::kClientModelUrl,
360                              "dummy model data", true);
361    URLFetcher::set_factory(factory_.get());
362
363    process()->sink().AddFilter(this);
364  }
365
366  virtual void TearDown() {
367    process()->sink().RemoveFilter(this);
368    URLFetcher::set_factory(NULL);
369    file_thread_.reset();
370    ui_thread_.reset();
371    RenderViewHostTestHarness::TearDown();
372  }
373
374  scoped_ptr<FakeURLFetcherFactory> factory_;
375  scoped_ptr<BrowserThread> ui_thread_;
376  scoped_ptr<BrowserThread> file_thread_;
377  IPC::Message received_msg_;
378  bool did_receive_msg_;
379};
380
381TEST_F(ClientSideDetectionServiceHooksTest, ShouldClassifyUrl) {
382  ScopedTempDir tmp_dir;
383  ASSERT_TRUE(tmp_dir.CreateUniqueTempDir());
384  FilePath model_path = tmp_dir.path().AppendASCII("model");
385
386  scoped_ptr<ClientSideDetectionService> csd_service(
387      ClientSideDetectionService::Create(model_path, NULL));
388
389  // Navigate the tab to a page.  We should see a StartPhishingDetection IPC.
390  did_receive_msg_ = false;
391  NavigateAndCommit(GURL("http://host.com/"));
392  // The IPC is sent asynchronously, so run the message loop to wait for
393  // the message.
394  MessageLoop::current()->RunAllPending();
395  ASSERT_TRUE(did_receive_msg_);
396
397  Tuple1<GURL> url;
398  ViewMsg_StartPhishingDetection::Read(&received_msg_, &url);
399  EXPECT_EQ(GURL("http://host.com/"), url.a);
400  EXPECT_EQ(rvh()->routing_id(), received_msg_.routing_id());
401
402  // Now try an in-page navigation.  This should not trigger an IPC.
403  did_receive_msg_ = false;
404  NavigateAndCommit(GURL("http://host.com/#foo"));
405  MessageLoop::current()->RunAllPending();
406  ASSERT_FALSE(did_receive_msg_);
407}
408
409}  // namespace safe_browsing
410