// Copyright (c) 2010 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/safe_browsing/bloom_filter.h"

#include <string.h>

#include <algorithm>

#include "base/logging.h"
#include "base/metrics/histogram.h"
#include "base/rand_util.h"
#include "net/base/file_stream.h"
#include "net/base/net_errors.h"

namespace {

// The Jenkins 96 bit mix function:
// http://www.concentric.net/~Ttwang/tech/inthash.htm
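// Mixes the low and high 32-bit halves of a 64-bit random hash key with the
// 32-bit prefix |c|, so each key in |hash_keys_| acts as an independent hash
// function over the prefix.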
uint32 HashMix(BloomFilter::HashKey hash_key, uint32 c) {
  uint32 a = static_cast<uint32>(hash_key)       & 0xFFFFFFFF;
  uint32 b = static_cast<uint32>(hash_key >> 32) & 0xFFFFFFFF;

  a -= (b + c);  a ^= (c >> 13);
  b -= (c + a);  b ^= (a << 8);
  c -= (a + b);  c ^= (b >> 13);
  a -= (b + c);  a ^= (c >> 12);
  b -= (c + a);  b ^= (a << 16);
  c -= (a + b);  c ^= (b >> 5);
  a -= (b + c);  a ^= (c >> 3);
  b -= (c + a);  b ^= (a << 10);
  c -= (a + b);  c ^= (b >> 15);

  return c;
}

}  // namespace

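// Computes the filter size in bits: kBloomFilterSizeRatio bits per key, with
// the key count clamped below by kBloomFilterMinSize and the total capped at
// kBloomFilterMaxSize bytes (hence the "* 8").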
// static
int BloomFilter::FilterSizeForKeyCount(int key_count) {
  const int default_min = BloomFilter::kBloomFilterMinSize;
  const int number_of_keys = std::max(key_count, default_min);
  return std::min(number_of_keys * BloomFilter::kBloomFilterSizeRatio,
                  BloomFilter::kBloomFilterMaxSize * 8);
}

// static
void BloomFilter::RecordFailure(FailureType failure_type) {
  UMA_HISTOGRAM_ENUMERATION("SB2.BloomFailure", failure_type,
                            FAILURE_FILTER_MAX);
}

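// Constructs an empty filter: picks kNumHashKeys random 64-bit hash keys and
// allocates a zeroed bit array, rounding the requested bit count up to a
// whole number of bytes.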
BloomFilter::BloomFilter(int bit_size) {
  for (int i = 0; i < kNumHashKeys; ++i)
    hash_keys_.push_back(base::RandUint64());

  // Round up to the next byte boundary that fits bit_size.
  byte_size_ = (bit_size + 7) / 8;
  bit_size_ = byte_size_ * 8;
  DCHECK_LE(bit_size, bit_size_);  // At least as many bits as requested.
  data_.reset(new char[byte_size_]);
  memset(data_.get(), 0, byte_size_);
}

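// Constructs a filter from previously serialized state; takes ownership of
// |data| (see LoadFile() below).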
BloomFilter::BloomFilter(char* data, int size, const HashKeys& keys)
    : hash_keys_(keys) {
  byte_size_ = size;
  bit_size_ = byte_size_ * 8;
  data_.reset(data);
}

BloomFilter::~BloomFilter() {
}

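// Sets one bit per hash key for the 32 bits of |hash|.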
void BloomFilter::Insert(SBPrefix hash) {
  uint32 hash_uint32 = static_cast<uint32>(hash);
  for (size_t i = 0; i < hash_keys_.size(); ++i) {
    uint32 index = HashMix(hash_keys_[i], hash_uint32) % bit_size_;
    data_[index / 8] |= 1 << (index % 8);
  }
}

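// Returns true if every bit selected by the hash keys is set. False positives
// are possible (an unrelated prefix may map only to bits that happen to be
// set), but a prefix that was inserted is never reported as absent.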
bool BloomFilter::Exists(SBPrefix hash) const {
  uint32 hash_uint32 = static_cast<uint32>(hash);
  for (size_t i = 0; i < hash_keys_.size(); ++i) {
    uint32 index = HashMix(hash_keys_[i], hash_uint32) % bit_size_;
    if (!(data_[index / 8] & (1 << (index % 8))))
      return false;
  }
  return true;
}

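// On-disk filter format, as read here and written by WriteFile():
//   int      file version (must equal kFileVersion)
//   int      number of hash keys
//   HashKey  hash keys, one per key (64 bits each)
//   char[]   the filter bit array (all remaining bytes)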
// static
BloomFilter* BloomFilter::LoadFile(const FilePath& filter_name) {
  net::FileStream filter;

  if (filter.Open(filter_name,
                  base::PLATFORM_FILE_OPEN |
                  base::PLATFORM_FILE_READ) != net::OK) {
    RecordFailure(FAILURE_FILTER_READ_OPEN);
    return NULL;
  }

  // Make sure we have a file version that we can understand.
  int file_version;
  int bytes_read = filter.Read(reinterpret_cast<char*>(&file_version),
                               sizeof(file_version), NULL);
  if (bytes_read != sizeof(file_version) || file_version != kFileVersion) {
    RecordFailure(FAILURE_FILTER_READ_VERSION);
    return NULL;
  }

  // Get all the random hash keys.
  int num_keys;
  bytes_read = filter.Read(reinterpret_cast<char*>(&num_keys),
                           sizeof(num_keys), NULL);
  if (bytes_read != sizeof(num_keys) ||
      num_keys < 1 || num_keys > kNumHashKeys) {
    RecordFailure(FAILURE_FILTER_READ_NUM_KEYS);
    return NULL;
  }

  HashKeys hash_keys;
  for (int i = 0; i < num_keys; ++i) {
    HashKey key;
    bytes_read = filter.Read(reinterpret_cast<char*>(&key), sizeof(key), NULL);
    if (bytes_read != sizeof(key)) {
      RecordFailure(FAILURE_FILTER_READ_KEY);
      return NULL;
    }
    hash_keys.push_back(key);
  }

  // Read in the filter data, with sanity checks on min and max sizes.
  int64 remaining64 = filter.Available();
  if (remaining64 < kBloomFilterMinSize) {
    RecordFailure(FAILURE_FILTER_READ_DATA_MINSIZE);
    return NULL;
  } else if (remaining64 > kBloomFilterMaxSize) {
    RecordFailure(FAILURE_FILTER_READ_DATA_MAXSIZE);
    return NULL;
  }

  int byte_size = static_cast<int>(remaining64);
  scoped_array<char> data(new char[byte_size]);
  bytes_read = filter.Read(data.get(), byte_size, NULL);
  if (bytes_read < byte_size) {
    RecordFailure(FAILURE_FILTER_READ_DATA_SHORT);
    return NULL;
  } else if (bytes_read != byte_size) {
    RecordFailure(FAILURE_FILTER_READ_DATA);
    return NULL;
  }

  // We've read everything okay, commit the data.
  return new BloomFilter(data.release(), byte_size, hash_keys);
}

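// Serializes the filter in the same order LoadFile() reads it: version,
// number of hash keys, the keys themselves, then the raw bit array.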
bool BloomFilter::WriteFile(const FilePath& filter_name) const {
  net::FileStream filter;

  if (filter.Open(filter_name,
                  base::PLATFORM_FILE_WRITE |
                  base::PLATFORM_FILE_CREATE_ALWAYS) != net::OK)
    return false;

  // Write the version information.
  int version = kFileVersion;
  int bytes_written = filter.Write(reinterpret_cast<char*>(&version),
                                   sizeof(version), NULL);
  if (bytes_written != sizeof(version))
    return false;

  // Write the number of random hash keys.
  int num_keys = static_cast<int>(hash_keys_.size());
  bytes_written = filter.Write(reinterpret_cast<char*>(&num_keys),
                               sizeof(num_keys), NULL);
  if (bytes_written != sizeof(num_keys))
    return false;

  for (int i = 0; i < num_keys; ++i) {
    bytes_written = filter.Write(reinterpret_cast<const char*>(&hash_keys_[i]),
                                 sizeof(hash_keys_[i]), NULL);
    if (bytes_written != sizeof(hash_keys_[i]))
      return false;
  }

  // Write the filter data.
  bytes_written = filter.Write(data_.get(), byte_size_, NULL);
  if (bytes_written != byte_size_)
    return false;

  return true;
}