1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/spdy/hpack_huffman_table.h"
6
7#include <algorithm>
8#include <cmath>
9
10#include "base/logging.h"
11#include "net/spdy/hpack_input_stream.h"
12#include "net/spdy/hpack_output_stream.h"
13
14namespace net {
15
16using base::StringPiece;
17using std::string;
18
19namespace {
20
21// How many bits to index in the root decode table.
22const uint8 kDecodeTableRootBits = 9;
23// Maximum number of bits to index in successive decode tables.
24const uint8 kDecodeTableBranchBits = 6;
25
26bool SymbolLengthAndIdCompare(const HpackHuffmanSymbol& a,
27                              const HpackHuffmanSymbol& b) {
28  if (a.length == b.length) {
29    return a.id < b.id;
30  }
31  return a.length < b.length;
32}
33bool SymbolIdCompare(const HpackHuffmanSymbol& a,
34                     const HpackHuffmanSymbol& b) {
35  return a.id < b.id;
36}
37
38}  // namespace
39
40HpackHuffmanTable::DecodeEntry::DecodeEntry()
41  : next_table_index(0), length(0), symbol_id(0) {
42}
43HpackHuffmanTable::DecodeEntry::DecodeEntry(uint8 next_table_index,
44                                            uint8 length,
45                                            uint16 symbol_id)
46  : next_table_index(next_table_index), length(length), symbol_id(symbol_id) {
47}
48size_t HpackHuffmanTable::DecodeTable::size() const {
49  return size_t(1) << indexed_length;
50}
51
52HpackHuffmanTable::HpackHuffmanTable() {}
53
54HpackHuffmanTable::~HpackHuffmanTable() {}
55
56bool HpackHuffmanTable::Initialize(const HpackHuffmanSymbol* input_symbols,
57                                   size_t symbol_count) {
58  CHECK(!IsInitialized());
59
60  std::vector<Symbol> symbols(symbol_count);
61  // Validate symbol id sequence, and copy into |symbols|.
62  for (size_t i = 0; i != symbol_count; i++) {
63    if (i != input_symbols[i].id) {
64      failed_symbol_id_ = i;
65      return false;
66    }
67    symbols[i] = input_symbols[i];
68  }
69  // Order on length and ID ascending, to verify symbol codes are canonical.
70  std::sort(symbols.begin(), symbols.end(), SymbolLengthAndIdCompare);
71  if (symbols[0].code != 0) {
72    failed_symbol_id_ = 0;
73    return false;
74  }
75  for (size_t i = 1; i != symbols.size(); i++) {
76    unsigned code_shift = 32 - symbols[i-1].length;
77    uint32 code = symbols[i-1].code + (1 << code_shift);
78
79    if (code != symbols[i].code) {
80      failed_symbol_id_ = symbols[i].id;
81      return false;
82    }
83    if (code < symbols[i-1].code) {
84      // An integer overflow occurred. This implies the input
85      // lengths do not represent a valid Huffman code.
86      failed_symbol_id_ = symbols[i].id;
87      return false;
88    }
89  }
90  if (symbols.back().length < 8) {
91    // At least one code (such as an EOS symbol) must be 8 bits or longer.
92    // Without this, some inputs will not be encodable in a whole number
93    // of bytes.
94    return false;
95  }
96  pad_bits_ = static_cast<uint8>(symbols.back().code >> 24);
97
98  BuildDecodeTables(symbols);
99  // Order on symbol ID ascending.
100  std::sort(symbols.begin(), symbols.end(), SymbolIdCompare);
101  BuildEncodeTable(symbols);
102  return true;
103}
104
105void HpackHuffmanTable::BuildEncodeTable(const std::vector<Symbol>& symbols) {
106  for (size_t i = 0; i != symbols.size(); i++) {
107    const Symbol& symbol = symbols[i];
108    CHECK_EQ(i, symbol.id);
109    code_by_id_.push_back(symbol.code);
110    length_by_id_.push_back(symbol.length);
111  }
112}
113
114void HpackHuffmanTable::BuildDecodeTables(const std::vector<Symbol>& symbols) {
115  AddDecodeTable(0, kDecodeTableRootBits);
116  // We wish to maximize the flatness of the DecodeTable hierarchy (subject to
117  // the |kDecodeTableBranchBits| constraint), and to minimize the size of
118  // child tables. To achieve this, we iterate in order of descending code
119  // length. This ensures that child tables are visited with their longest
120  // entry first, and that the child can therefore be minimally sized to hold
121  // that entry without fear of introducing unneccesary branches later.
122  for (std::vector<Symbol>::const_reverse_iterator it = symbols.rbegin();
123       it != symbols.rend(); ++it) {
124    uint8 table_index = 0;
125    while (true) {
126      const DecodeTable table = decode_tables_[table_index];
127
128      // Mask and shift the portion of the code being indexed into low bits.
129      uint32 index = (it->code << table.prefix_length);
130      index = index >> (32 - table.indexed_length);
131
132      CHECK_LT(index, table.size());
133      DecodeEntry entry = Entry(table, index);
134
135      uint8 total_indexed = table.prefix_length + table.indexed_length;
136      if (total_indexed >= it->length) {
137        // We're writing a terminal entry.
138        entry.length = it->length;
139        entry.symbol_id = it->id;
140        entry.next_table_index = table_index;
141        SetEntry(table, index, entry);
142        break;
143      }
144
145      if (entry.length == 0) {
146        // First visit to this placeholder. We need to create a new table.
147        CHECK_EQ(entry.next_table_index, 0);
148        entry.length = it->length;
149        entry.next_table_index = AddDecodeTable(
150            total_indexed,  // Becomes the new table prefix.
151            std::min<uint8>(kDecodeTableBranchBits,
152                            entry.length - total_indexed));
153        SetEntry(table, index, entry);
154      }
155      CHECK_NE(entry.next_table_index, table_index);
156      table_index = entry.next_table_index;
157    }
158  }
159  // Fill shorter table entries into the additional entry spots they map to.
160  for (size_t i = 0; i != decode_tables_.size(); i++) {
161    const DecodeTable& table = decode_tables_[i];
162    uint8 total_indexed = table.prefix_length + table.indexed_length;
163
164    size_t j = 0;
165    while (j != table.size()) {
166      const DecodeEntry& entry = Entry(table, j);
167      if (entry.length != 0 && entry.length < total_indexed) {
168        // The difference between entry & table bit counts tells us how
169        // many additional entries map to this one.
170        size_t fill_count = 1 << (total_indexed - entry.length);
171        CHECK_LE(j + fill_count, table.size());
172
173        for (size_t k = 1; k != fill_count; k++) {
174          CHECK_EQ(Entry(table, j + k).length, 0);
175          SetEntry(table, j + k, entry);
176        }
177        j += fill_count;
178      } else {
179        j++;
180      }
181    }
182  }
183}
184
185uint8 HpackHuffmanTable::AddDecodeTable(uint8 prefix, uint8 indexed) {
186  CHECK_LT(decode_tables_.size(), 255u);
187  {
188    DecodeTable table;
189    table.prefix_length = prefix;
190    table.indexed_length = indexed;
191    table.entries_offset = decode_entries_.size();
192    decode_tables_.push_back(table);
193  }
194  decode_entries_.resize(decode_entries_.size() + (size_t(1) << indexed));
195  return static_cast<uint8>(decode_tables_.size() - 1);
196}
197
198const HpackHuffmanTable::DecodeEntry& HpackHuffmanTable::Entry(
199    const DecodeTable& table,
200    uint32 index) const {
201  DCHECK_LT(index, table.size());
202  DCHECK_LT(table.entries_offset + index, decode_entries_.size());
203  return decode_entries_[table.entries_offset + index];
204}
205
206void HpackHuffmanTable::SetEntry(const DecodeTable& table,
207                                 uint32 index,
208                                 const DecodeEntry& entry) {
209  CHECK_LT(index, table.size());
210  CHECK_LT(table.entries_offset + index, decode_entries_.size());
211  decode_entries_[table.entries_offset + index] = entry;
212}
213
214bool HpackHuffmanTable::IsInitialized() const {
215  return !code_by_id_.empty();
216}
217
218void HpackHuffmanTable::EncodeString(StringPiece in,
219                                     HpackOutputStream* out) const {
220  size_t bit_remnant = 0;
221  for (size_t i = 0; i != in.size(); i++) {
222    uint16 symbol_id = static_cast<uint8>(in[i]);
223    CHECK_GT(code_by_id_.size(), symbol_id);
224
225    // Load, and shift code to low bits.
226    unsigned length = length_by_id_[symbol_id];
227    uint32 code = code_by_id_[symbol_id] >> (32 - length);
228
229    bit_remnant = (bit_remnant + length) % 8;
230
231    if (length > 24) {
232      out->AppendBits(static_cast<uint8>(code >> 24), length - 24);
233      length = 24;
234    }
235    if (length > 16) {
236      out->AppendBits(static_cast<uint8>(code >> 16), length - 16);
237      length = 16;
238    }
239    if (length > 8) {
240      out->AppendBits(static_cast<uint8>(code >> 8), length - 8);
241      length = 8;
242    }
243    out->AppendBits(static_cast<uint8>(code), length);
244  }
245  if (bit_remnant != 0) {
246    // Pad current byte as required.
247    out->AppendBits(pad_bits_ >> bit_remnant, 8 - bit_remnant);
248  }
249}
250
251size_t HpackHuffmanTable::EncodedSize(StringPiece in) const {
252  size_t bit_count = 0;
253  for (size_t i = 0; i != in.size(); i++) {
254    uint16 symbol_id = static_cast<uint8>(in[i]);
255    CHECK_GT(code_by_id_.size(), symbol_id);
256
257    bit_count += length_by_id_[symbol_id];
258  }
259  if (bit_count % 8 != 0) {
260    bit_count += 8 - bit_count % 8;
261  }
262  return bit_count / 8;
263}
264
265bool HpackHuffmanTable::DecodeString(HpackInputStream* in,
266                                     size_t out_capacity,
267                                     string* out) const {
268  // Number of decode iterations required for a 32-bit code.
269  const int kDecodeIterations = static_cast<int>(
270      std::ceil((32.f - kDecodeTableRootBits) / kDecodeTableBranchBits));
271
272  out->clear();
273
274  // Current input, stored in the high |bits_available| bits of |bits|.
275  uint32 bits = 0;
276  size_t bits_available = 0;
277  bool peeked_success = in->PeekBits(&bits_available, &bits);
278
279  while (true) {
280    const DecodeTable* table = &decode_tables_[0];
281    uint32 index = bits >> (32 - kDecodeTableRootBits);
282
283    for (int i = 0; i != kDecodeIterations; i++) {
284      DCHECK_LT(index, table->size());
285      DCHECK_LT(Entry(*table, index).next_table_index, decode_tables_.size());
286
287      table = &decode_tables_[Entry(*table, index).next_table_index];
288      // Mask and shift the portion of the code being indexed into low bits.
289      index = (bits << table->prefix_length) >> (32 - table->indexed_length);
290    }
291    const DecodeEntry& entry = Entry(*table, index);
292
293    if (entry.length > bits_available) {
294      if (!peeked_success) {
295        // Unable to read enough input for a match. If only a portion of
296        // the last byte remains, this is a successful EOF condition.
297        in->ConsumeByteRemainder();
298        return !in->HasMoreData();
299      }
300    } else if (entry.length == 0) {
301      // The input is an invalid prefix, larger than any prefix in the table.
302      return false;
303    } else {
304      if (out->size() == out_capacity) {
305        // This code would cause us to overflow |out_capacity|.
306        return false;
307      }
308      if (entry.symbol_id < 256) {
309        // Assume symbols >= 256 are used for padding.
310        out->push_back(static_cast<char>(entry.symbol_id));
311      }
312
313      in->ConsumeBits(entry.length);
314      bits = bits << entry.length;
315      bits_available -= entry.length;
316    }
317    peeked_success = in->PeekBits(&bits_available, &bits);
318  }
319  NOTREACHED();
320  return false;
321}
322
323}  // namespace net
324