1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
#include "chrome/browser/extensions/blacklist_state_fetcher.h"

#include <vector>

#include "base/stl_util.h"
#include "base/strings/stringprintf.h"
#include "chrome/browser/browser_process.h"
#include "chrome/browser/safe_browsing/protocol_manager_helper.h"
#include "chrome/browser/safe_browsing/safe_browsing_service.h"
#include "chrome/common/safe_browsing/crx_info.pb.h"
#include "google_apis/google_api_keys.h"
#include "net/base/escape.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_getter.h"
#include "net/url_request/url_request_status.h"
#include "url/gurl.h"
19
20using content::BrowserThread;
21
22namespace {
23
24class BlacklistRequestContextGetter : public net::URLRequestContextGetter {
25 public:
26  explicit BlacklistRequestContextGetter(
27      net::URLRequestContextGetter* parent_context_getter) :
28          network_task_runner_(
29              BrowserThread::GetMessageLoopProxyForThread(BrowserThread::IO)) {
30    DCHECK_CURRENTLY_ON(BrowserThread::IO);
31    url_request_context_.reset(new net::URLRequestContext());
32    url_request_context_->CopyFrom(
33        parent_context_getter->GetURLRequestContext());
34  }
35
36  static void Create(
37      scoped_refptr<net::URLRequestContextGetter> parent_context_getter,
38      base::Callback<void(scoped_refptr<net::URLRequestContextGetter>)>
39          callback) {
40    DCHECK_CURRENTLY_ON(BrowserThread::IO);
41
42    scoped_refptr<net::URLRequestContextGetter> context_getter =
43        new BlacklistRequestContextGetter(parent_context_getter.get());
44    BrowserThread::PostTask(BrowserThread::UI,
45                            FROM_HERE,
46                            base::Bind(callback, context_getter));
47  }
48
49  virtual net::URLRequestContext* GetURLRequestContext() OVERRIDE {
50    DCHECK_CURRENTLY_ON(BrowserThread::IO);
51    return url_request_context_.get();
52  }
53
54  virtual scoped_refptr<base::SingleThreadTaskRunner> GetNetworkTaskRunner()
55      const OVERRIDE {
56    return network_task_runner_;
57  }
58
59 protected:
60  virtual ~BlacklistRequestContextGetter() {
61    url_request_context_->AssertNoURLRequests();
62  }
63
64 private:
65  scoped_ptr<net::URLRequestContext> url_request_context_;
66  scoped_refptr<base::SingleThreadTaskRunner> network_task_runner_;
67};
68
69}  // namespace
70
71namespace extensions {
72
73BlacklistStateFetcher::BlacklistStateFetcher()
74    : url_fetcher_id_(0),
75      weak_ptr_factory_(this) {}
76
77BlacklistStateFetcher::~BlacklistStateFetcher() {
78  DCHECK_CURRENTLY_ON(BrowserThread::UI);
79  STLDeleteContainerPairFirstPointers(requests_.begin(), requests_.end());
80  requests_.clear();
81}
82
83void BlacklistStateFetcher::Request(const std::string& id,
84                                    const RequestCallback& callback) {
85  DCHECK_CURRENTLY_ON(BrowserThread::UI);
86  if (!safe_browsing_config_) {
87    if (g_browser_process && g_browser_process->safe_browsing_service()) {
88      SetSafeBrowsingConfig(
89          g_browser_process->safe_browsing_service()->GetProtocolConfig());
90    } else {
91      base::MessageLoopProxy::current()->PostTask(
92          FROM_HERE, base::Bind(callback, BLACKLISTED_UNKNOWN));
93      return;
94    }
95  }
96
97  bool request_already_sent = ContainsKey(callbacks_, id);
98  callbacks_.insert(std::make_pair(id, callback));
99  if (request_already_sent)
100    return;
101
102  if (url_request_context_getter_.get() || !g_browser_process ||
103      !g_browser_process->safe_browsing_service()) {
104    SendRequest(id);
105  } else {
106    scoped_refptr<net::URLRequestContextGetter> parent_request_context;
107    if (g_browser_process && g_browser_process->safe_browsing_service()) {
108      parent_request_context = g_browser_process->safe_browsing_service()
109                                                ->url_request_context();
110    } else {
111      parent_request_context = parent_request_context_for_test_;
112    }
113
114    BrowserThread::PostTask(
115        BrowserThread::IO, FROM_HERE,
116        base::Bind(&BlacklistRequestContextGetter::Create,
117                   parent_request_context,
118                   base::Bind(&BlacklistStateFetcher::SaveRequestContext,
119                              weak_ptr_factory_.GetWeakPtr(),
120                              id)));
121  }
122}
123
124void BlacklistStateFetcher::SaveRequestContext(
125    const std::string& id,
126    scoped_refptr<net::URLRequestContextGetter> request_context_getter) {
127  DCHECK_CURRENTLY_ON(BrowserThread::UI);
128  if (!url_request_context_getter_.get())
129    url_request_context_getter_ = request_context_getter;
130  SendRequest(id);
131}
132
133void BlacklistStateFetcher::SendRequest(const std::string& id) {
134  DCHECK_CURRENTLY_ON(BrowserThread::UI);
135
136  ClientCRXListInfoRequest request;
137  request.set_id(id);
138  std::string request_str;
139  request.SerializeToString(&request_str);
140
141  GURL request_url = RequestUrl();
142  net::URLFetcher* fetcher = net::URLFetcher::Create(url_fetcher_id_++,
143                                                     request_url,
144                                                     net::URLFetcher::POST,
145                                                     this);
146  requests_[fetcher] = id;
147  fetcher->SetAutomaticallyRetryOn5xx(false);  // Don't retry on error.
148  fetcher->SetRequestContext(url_request_context_getter_.get());
149  fetcher->SetUploadData("application/octet-stream", request_str);
150  fetcher->Start();
151}
152
153void BlacklistStateFetcher::SetSafeBrowsingConfig(
154    const SafeBrowsingProtocolConfig& config) {
155  safe_browsing_config_.reset(new SafeBrowsingProtocolConfig(config));
156}
157
158void BlacklistStateFetcher::SetURLRequestContextForTest(
159      net::URLRequestContextGetter* parent_request_context) {
160  parent_request_context_for_test_ = parent_request_context;
161}
162
163GURL BlacklistStateFetcher::RequestUrl() const {
164  std::string url = base::StringPrintf(
165      "%s/%s?client=%s&appver=%s&pver=2.2",
166      safe_browsing_config_->url_prefix.c_str(),
167      "clientreport/crx-list-info",
168      safe_browsing_config_->client_name.c_str(),
169      safe_browsing_config_->version.c_str());
170  std::string api_key = google_apis::GetAPIKey();
171  if (!api_key.empty()) {
172    base::StringAppendF(&url, "&key=%s",
173                        net::EscapeQueryParamValue(api_key, true).c_str());
174  }
175  return GURL(url);
176}
177
178void BlacklistStateFetcher::OnURLFetchComplete(const net::URLFetcher* source) {
179  DCHECK_CURRENTLY_ON(BrowserThread::UI);
180
181  std::map<const net::URLFetcher*, std::string>::iterator it =
182     requests_.find(source);
183  if (it == requests_.end()) {
184    NOTREACHED();
185    return;
186  }
187
188  scoped_ptr<const net::URLFetcher> fetcher;
189
190  fetcher.reset(it->first);
191  std::string id = it->second;
192  requests_.erase(it);
193
194  BlacklistState state;
195
196  if (source->GetStatus().is_success() && source->GetResponseCode() == 200) {
197    std::string data;
198    source->GetResponseAsString(&data);
199    ClientCRXListInfoResponse response;
200    if (response.ParseFromString(data)) {
201      state = static_cast<BlacklistState>(response.verdict());
202    } else {
203      state = BLACKLISTED_UNKNOWN;
204    }
205  } else {
206    if (source->GetStatus().status() == net::URLRequestStatus::FAILED) {
207      VLOG(1) << "Blacklist request for: " << id
208              << " failed with error: " << source->GetStatus().error();
209    } else {
210      VLOG(1) << "Blacklist request for: " << id
211              << " failed with error: " << source->GetResponseCode();
212    }
213
214    state = BLACKLISTED_UNKNOWN;
215  }
216
217  std::pair<CallbackMultiMap::iterator, CallbackMultiMap::iterator> range =
218      callbacks_.equal_range(id);
219  for (CallbackMultiMap::const_iterator callback_it = range.first;
220       callback_it != range.second;
221       ++callback_it) {
222    callback_it->second.Run(state);
223  }
224
225  callbacks_.erase(range.first, range.second);
226}
227
228}  // namespace extensions
229
230