1// Copyright 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/extensions/blacklist.h"
6
7#include <algorithm>
8#include <iterator>
9
10#include "base/bind.h"
11#include "base/lazy_instance.h"
12#include "base/memory/ref_counted.h"
13#include "base/prefs/pref_service.h"
14#include "base/stl_util.h"
15#include "chrome/browser/browser_process.h"
16#include "chrome/browser/chrome_notification_types.h"
17#include "chrome/browser/extensions/blacklist_state_fetcher.h"
18#include "chrome/browser/safe_browsing/safe_browsing_service.h"
19#include "chrome/browser/safe_browsing/safe_browsing_util.h"
20#include "chrome/common/pref_names.h"
21#include "content/public/browser/notification_details.h"
22#include "content/public/browser/notification_source.h"
23#include "extensions/browser/extension_prefs.h"
24
25using content::BrowserThread;
26
27namespace extensions {
28
29namespace {
30
31// The safe browsing database manager to use. Make this a global/static variable
32// rather than a member of Blacklist because Blacklist accesses the real
33// database manager before it has a chance to get a fake one.
34class LazySafeBrowsingDatabaseManager {
35 public:
36  LazySafeBrowsingDatabaseManager() {
37#if defined(FULL_SAFE_BROWSING) || defined(MOBILE_SAFE_BROWSING)
38    if (g_browser_process && g_browser_process->safe_browsing_service()) {
39      instance_ =
40          g_browser_process->safe_browsing_service()->database_manager();
41    }
42#endif
43  }
44
45  scoped_refptr<SafeBrowsingDatabaseManager> get() {
46    return instance_;
47  }
48
49  void set(scoped_refptr<SafeBrowsingDatabaseManager> instance) {
50    instance_ = instance;
51  }
52
53 private:
54  scoped_refptr<SafeBrowsingDatabaseManager> instance_;
55};
56
57static base::LazyInstance<LazySafeBrowsingDatabaseManager> g_database_manager =
58    LAZY_INSTANCE_INITIALIZER;
59
60// Implementation of SafeBrowsingDatabaseManager::Client, the class which is
61// called back from safebrowsing queries.
62//
63// Constructed on any thread but lives on the IO from then on.
64class SafeBrowsingClientImpl
65    : public SafeBrowsingDatabaseManager::Client,
66      public base::RefCountedThreadSafe<SafeBrowsingClientImpl> {
67 public:
68  typedef base::Callback<void(const std::set<std::string>&)> OnResultCallback;
69
70  // Constructs a client to query the database manager for |extension_ids| and
71  // run |callback| with the IDs of those which have been blacklisted.
72  SafeBrowsingClientImpl(
73      const std::set<std::string>& extension_ids,
74      const OnResultCallback& callback)
75      : callback_message_loop_(base::MessageLoopProxy::current()),
76        callback_(callback) {
77    BrowserThread::PostTask(
78        BrowserThread::IO,
79        FROM_HERE,
80        base::Bind(&SafeBrowsingClientImpl::StartCheck, this,
81                   g_database_manager.Get().get(),
82                   extension_ids));
83  }
84
85 private:
86  friend class base::RefCountedThreadSafe<SafeBrowsingClientImpl>;
87  virtual ~SafeBrowsingClientImpl() {}
88
89  // Pass |database_manager| as a parameter to avoid touching
90  // SafeBrowsingService on the IO thread.
91  void StartCheck(scoped_refptr<SafeBrowsingDatabaseManager> database_manager,
92                  const std::set<std::string>& extension_ids) {
93    DCHECK_CURRENTLY_ON(BrowserThread::IO);
94    if (database_manager->CheckExtensionIDs(extension_ids, this)) {
95      // Definitely not blacklisted. Callback immediately.
96      callback_message_loop_->PostTask(
97          FROM_HERE,
98          base::Bind(callback_, std::set<std::string>()));
99      return;
100    }
101    // Something might be blacklisted, response will come in
102    // OnCheckExtensionsResult.
103    AddRef();  // Balanced in OnCheckExtensionsResult
104  }
105
106  virtual void OnCheckExtensionsResult(
107      const std::set<std::string>& hits) OVERRIDE {
108    DCHECK_CURRENTLY_ON(BrowserThread::IO);
109    callback_message_loop_->PostTask(FROM_HERE, base::Bind(callback_, hits));
110    Release();  // Balanced in StartCheck.
111  }
112
113  scoped_refptr<base::MessageLoopProxy> callback_message_loop_;
114  OnResultCallback callback_;
115
116  DISALLOW_COPY_AND_ASSIGN(SafeBrowsingClientImpl);
117};
118
119void CheckOneExtensionState(
120    const Blacklist::IsBlacklistedCallback& callback,
121    const Blacklist::BlacklistStateMap& state_map) {
122  callback.Run(state_map.empty() ? NOT_BLACKLISTED : state_map.begin()->second);
123}
124
125void GetMalwareFromBlacklistStateMap(
126    const Blacklist::GetMalwareIDsCallback& callback,
127    const Blacklist::BlacklistStateMap& state_map) {
128  std::set<std::string> malware;
129  for (Blacklist::BlacklistStateMap::const_iterator it = state_map.begin();
130       it != state_map.end(); ++it) {
131    // TODO(oleg): UNKNOWN is treated as MALWARE for backwards compatibility.
132    // In future GetMalwareIDs will be removed and the caller will have to
133    // deal with BLACKLISTED_UNKNOWN state returned from GetBlacklistedIDs.
134    if (it->second == BLACKLISTED_MALWARE || it->second == BLACKLISTED_UNKNOWN)
135      malware.insert(it->first);
136  }
137  callback.Run(malware);
138}
139
140}  // namespace
141
142Blacklist::Observer::Observer(Blacklist* blacklist) : blacklist_(blacklist) {
143  blacklist_->AddObserver(this);
144}
145
146Blacklist::Observer::~Observer() {
147  blacklist_->RemoveObserver(this);
148}
149
150Blacklist::ScopedDatabaseManagerForTest::ScopedDatabaseManagerForTest(
151    scoped_refptr<SafeBrowsingDatabaseManager> database_manager)
152    : original_(GetDatabaseManager()) {
153  SetDatabaseManager(database_manager);
154}
155
156Blacklist::ScopedDatabaseManagerForTest::~ScopedDatabaseManagerForTest() {
157  SetDatabaseManager(original_);
158}
159
160Blacklist::Blacklist(ExtensionPrefs* prefs) {
161  scoped_refptr<SafeBrowsingDatabaseManager> database_manager =
162      g_database_manager.Get().get();
163  if (database_manager.get()) {
164    registrar_.Add(
165        this,
166        chrome::NOTIFICATION_SAFE_BROWSING_UPDATE_COMPLETE,
167        content::Source<SafeBrowsingDatabaseManager>(database_manager.get()));
168  }
169
170  // Clear out the old prefs-backed blacklist, stored as empty extension entries
171  // with just a "blacklisted" property.
172  //
173  // TODO(kalman): Delete this block of code, see http://crbug.com/295882.
174  std::set<std::string> blacklisted = prefs->GetBlacklistedExtensions();
175  for (std::set<std::string>::iterator it = blacklisted.begin();
176       it != blacklisted.end(); ++it) {
177    if (!prefs->GetInstalledExtensionInfo(*it))
178      prefs->DeleteExtensionPrefs(*it);
179  }
180}
181
182Blacklist::~Blacklist() {
183}
184
185void Blacklist::GetBlacklistedIDs(const std::set<std::string>& ids,
186                                  const GetBlacklistedIDsCallback& callback) {
187  DCHECK_CURRENTLY_ON(BrowserThread::UI);
188
189  if (ids.empty() || !g_database_manager.Get().get().get()) {
190    base::MessageLoopProxy::current()->PostTask(
191        FROM_HERE, base::Bind(callback, BlacklistStateMap()));
192    return;
193  }
194
195  // Constructing the SafeBrowsingClientImpl begins the process of asking
196  // safebrowsing for the blacklisted extensions. The set of blacklisted
197  // extensions returned by SafeBrowsing will then be passed to
198  // GetBlacklistStateIDs to get the particular BlacklistState for each id.
199  new SafeBrowsingClientImpl(
200      ids, base::Bind(&Blacklist::GetBlacklistStateForIDs, AsWeakPtr(),
201                      callback));
202}
203
204void Blacklist::GetMalwareIDs(const std::set<std::string>& ids,
205                              const GetMalwareIDsCallback& callback) {
206  GetBlacklistedIDs(ids, base::Bind(&GetMalwareFromBlacklistStateMap,
207                                    callback));
208}
209
210
211void Blacklist::IsBlacklisted(const std::string& extension_id,
212                              const IsBlacklistedCallback& callback) {
213  std::set<std::string> check;
214  check.insert(extension_id);
215  GetBlacklistedIDs(check, base::Bind(&CheckOneExtensionState, callback));
216}
217
218void Blacklist::GetBlacklistStateForIDs(
219    const GetBlacklistedIDsCallback& callback,
220    const std::set<std::string>& blacklisted_ids) {
221  DCHECK_CURRENTLY_ON(BrowserThread::UI);
222
223  std::set<std::string> ids_unknown_state;
224  BlacklistStateMap extensions_state;
225  for (std::set<std::string>::const_iterator it = blacklisted_ids.begin();
226       it != blacklisted_ids.end(); ++it) {
227    BlacklistStateMap::const_iterator cache_it =
228        blacklist_state_cache_.find(*it);
229    if (cache_it == blacklist_state_cache_.end() ||
230        cache_it->second == BLACKLISTED_UNKNOWN)  // Do not return UNKNOWN
231                                                  // from cache, retry request.
232      ids_unknown_state.insert(*it);
233    else
234      extensions_state[*it] = cache_it->second;
235  }
236
237  if (ids_unknown_state.empty()) {
238    callback.Run(extensions_state);
239  } else {
240    // After the extension blacklist states have been downloaded, call this
241    // functions again, but prevent infinite cycle in case server is offline
242    // or some other reason prevents us from receiving the blacklist state for
243    // these extensions.
244    RequestExtensionsBlacklistState(
245        ids_unknown_state,
246        base::Bind(&Blacklist::ReturnBlacklistStateMap, AsWeakPtr(),
247                   callback, blacklisted_ids));
248  }
249}
250
251void Blacklist::ReturnBlacklistStateMap(
252    const GetBlacklistedIDsCallback& callback,
253    const std::set<std::string>& blacklisted_ids) {
254  BlacklistStateMap extensions_state;
255  for (std::set<std::string>::const_iterator it = blacklisted_ids.begin();
256       it != blacklisted_ids.end(); ++it) {
257    BlacklistStateMap::const_iterator cache_it =
258        blacklist_state_cache_.find(*it);
259    if (cache_it != blacklist_state_cache_.end())
260      extensions_state[*it] = cache_it->second;
261    // If for some reason we still haven't cached the state of this extension,
262    // we silently skip it.
263  }
264
265  callback.Run(extensions_state);
266}
267
268void Blacklist::RequestExtensionsBlacklistState(
269    const std::set<std::string>& ids, const base::Callback<void()>& callback) {
270  DCHECK_CURRENTLY_ON(BrowserThread::UI);
271  if (!state_fetcher_)
272    state_fetcher_.reset(new BlacklistStateFetcher());
273
274  state_requests_.push_back(
275      make_pair(std::vector<std::string>(ids.begin(), ids.end()), callback));
276  for (std::set<std::string>::const_iterator it = ids.begin();
277       it != ids.end();
278       ++it) {
279    state_fetcher_->Request(
280        *it,
281        base::Bind(&Blacklist::OnBlacklistStateReceived, AsWeakPtr(), *it));
282  }
283}
284
285void Blacklist::OnBlacklistStateReceived(const std::string& id,
286                                         BlacklistState state) {
287  DCHECK_CURRENTLY_ON(BrowserThread::UI);
288  blacklist_state_cache_[id] = state;
289
290  // Go through the opened requests and call the callbacks for those requests
291  // for which we already got all the required blacklist states.
292  StateRequestsList::iterator requests_it = state_requests_.begin();
293  while (requests_it != state_requests_.end()) {
294    const std::vector<std::string>& ids = requests_it->first;
295
296    bool have_all_in_cache = true;
297    for (std::vector<std::string>::const_iterator ids_it = ids.begin();
298         ids_it != ids.end();
299         ++ids_it) {
300      if (!ContainsKey(blacklist_state_cache_, *ids_it)) {
301        have_all_in_cache = false;
302        break;
303      }
304    }
305
306    if (have_all_in_cache) {
307      requests_it->second.Run();
308      requests_it = state_requests_.erase(requests_it); // returns next element
309    } else {
310      ++requests_it;
311    }
312  }
313}
314
315void Blacklist::SetBlacklistStateFetcherForTest(
316    BlacklistStateFetcher* fetcher) {
317  state_fetcher_.reset(fetcher);
318}
319
320BlacklistStateFetcher* Blacklist::ResetBlacklistStateFetcherForTest() {
321  return state_fetcher_.release();
322}
323
324void Blacklist::AddObserver(Observer* observer) {
325  DCHECK_CURRENTLY_ON(BrowserThread::UI);
326  observers_.AddObserver(observer);
327}
328
329void Blacklist::RemoveObserver(Observer* observer) {
330  DCHECK_CURRENTLY_ON(BrowserThread::UI);
331  observers_.RemoveObserver(observer);
332}
333
334// static
335void Blacklist::SetDatabaseManager(
336    scoped_refptr<SafeBrowsingDatabaseManager> database_manager) {
337  g_database_manager.Get().set(database_manager);
338}
339
340// static
341scoped_refptr<SafeBrowsingDatabaseManager> Blacklist::GetDatabaseManager() {
342  return g_database_manager.Get().get();
343}
344
345void Blacklist::Observe(int type,
346                        const content::NotificationSource& source,
347                        const content::NotificationDetails& details) {
348  DCHECK_EQ(chrome::NOTIFICATION_SAFE_BROWSING_UPDATE_COMPLETE, type);
349  FOR_EACH_OBSERVER(Observer, observers_, OnBlacklistUpdated());
350}
351
352}  // namespace extensions
353