sttable.h revision dfd8b8327b93660601d016cdc6f29f433b45a8d8
1// sttable.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// A generic string-to-type table file format
20//
21// This is not meant as a generalization of SSTable. This is more of
22// a simple replacement for SSTable in order to provide an open-source
23// implementation of the FAR format for the external version of the
24// FST Library.
25
26#ifndef FST_EXTENSIONS_FAR_STTABLE_H_
27#define FST_EXTENSIONS_FAR_STTABLE_H_
28
29#include <algorithm>
30#include <iostream>
31#include <fstream>
32#include <sstream>
33#include <fst/util.h>
34
35namespace fst {
36
// Magic number identifying an STTable file. The writer emits it as the first
// int32 of the file and the readers reject any file whose first int32 differs.
37static const int32 kSTTableMagicNumber = 2125656924;
// STTable file format version. It is written as the int32 that follows the
// magic number, and the readers reject any file carrying a different value.
38static const int32 kSTTableFileVersion = 1;
39
40// String-to-type table writing class for object of type 'T' using functor 'W'
41// to write an object of type 'T' to a stream. 'W' must conform to the
42// following interface:
43//
44//   struct Writer {
45//     void operator()(ostream &, const T &) const;
46//   };
47//
48template <class T, class W>
49class STTableWriter {
50 public:
51  typedef T EntryType;
52  typedef W EntryWriter;
53
54  explicit STTableWriter(const string &filename)
55      : stream_(filename.c_str(), ofstream::out | ofstream::binary),
56        error_(false) {
57    WriteType(stream_, kSTTableMagicNumber);
58    WriteType(stream_, kSTTableFileVersion);
59    if (!stream_) {
60      FSTERROR() << "STTableWriter::STTableWriter: error writing to file: "
61                 << filename;
62      error_=true;
63    }
64  }
65
66  static STTableWriter<T, W> *Create(const string &filename) {
67    if (filename.empty()) {
68      LOG(ERROR) << "STTableWriter: writing to standard out unsupported.";
69      return 0;
70    }
71    return new STTableWriter<T, W>(filename);
72  }
73
74  void Add(const string &key, const T &t) {
75    if (key == "") {
76      FSTERROR() << "STTableWriter::Add: key empty: " << key;
77      error_ = true;
78    } else if (key < last_key_) {
79      FSTERROR() << "STTableWriter::Add: key disorder: " << key;
80      error_ = true;
81    }
82    if (error_) return;
83    last_key_ = key;
84    positions_.push_back(stream_.tellp());
85    WriteType(stream_, key);
86    entry_writer_(stream_, t);
87  }
88
89  bool Error() const { return error_; }
90
91  ~STTableWriter() {
92    WriteType(stream_, positions_);
93    WriteType(stream_, static_cast<int64>(positions_.size()));
94  }
95
96 private:
97  EntryWriter entry_writer_;  // Write functor for 'EntryType'
98  ofstream stream_;           // Output stream
99  vector<int64> positions_;   // Position in file of each key-entry pair
100  string last_key_;           // Last key
101  bool error_;
102
103  DISALLOW_COPY_AND_ASSIGN(STTableWriter);
104};
105
106
107// String-to-type table reading class for object of type 'T' using functor 'R'
108// to read an object of type 'T' from a stream. 'R' must conform to the
109// following interface:
110//
111//   struct Reader {
112//     T *operator()(istream &) const;
113//   };
114//
115template <class T, class R>
116class STTableReader {
117 public:
118  typedef T EntryType;
119  typedef R EntryReader;
120
121  explicit STTableReader(const vector<string> &filenames)
122      : sources_(filenames), entry_(0), error_(false) {
123    compare_ = new Compare(&keys_);
124    keys_.resize(filenames.size());
125    streams_.resize(filenames.size(), 0);
126    positions_.resize(filenames.size());
127    for (size_t i = 0; i < filenames.size(); ++i) {
128      streams_[i] = new ifstream(
129          filenames[i].c_str(), ifstream::in | ifstream::binary);
130      int32 magic_number = 0, file_version = 0;
131      ReadType(*streams_[i], &magic_number);
132      ReadType(*streams_[i], &file_version);
133      if (magic_number != kSTTableMagicNumber) {
134        FSTERROR() << "STTableReader::STTableReader: wrong file type: "
135                   << filenames[i];
136        error_ = true;
137        return;
138      }
139      if (file_version != kSTTableFileVersion) {
140        FSTERROR() << "STTableReader::STTableReader: wrong file version: "
141                   << filenames[i];
142        error_ = true;
143        return;
144      }
145      int64 num_entries;
146      streams_[i]->seekg(-static_cast<int>(sizeof(int64)), ios_base::end);
147      ReadType(*streams_[i], &num_entries);
148      streams_[i]->seekg(-static_cast<int>(sizeof(int64)) *
149                         (num_entries + 1), ios_base::end);
150      positions_[i].resize(num_entries);
151      for (size_t j = 0; (j < num_entries) && (*streams_[i]); ++j)
152        ReadType(*streams_[i], &(positions_[i][j]));
153      streams_[i]->seekg(positions_[i][0]);
154      if (!*streams_[i]) {
155        FSTERROR() << "STTableReader::STTableReader: error reading file: "
156                   << filenames[i];
157        error_ = true;
158        return;
159      }
160
161    }
162    MakeHeap();
163  }
164
165  ~STTableReader() {
166    for (size_t i = 0; i < streams_.size(); ++i)
167      delete streams_[i];
168    delete compare_;
169    if (entry_)
170      delete entry_;
171  }
172
173  static STTableReader<T, R> *Open(const string &filename) {
174    if (filename.empty()) {
175      LOG(ERROR) << "STTableReader: reading from standard in not supported";
176      return 0;
177    }
178    vector<string> filenames;
179    filenames.push_back(filename);
180    return new STTableReader<T, R>(filenames);
181  }
182
183  static STTableReader<T, R> *Open(const vector<string> &filenames) {
184    return new STTableReader<T, R>(filenames);
185  }
186
187  void Reset() {
188    if (error_) return;
189    for (size_t i = 0; i < streams_.size(); ++i)
190      streams_[i]->seekg(positions_[i].front());
191    MakeHeap();
192  }
193
194  bool Find(const string &key) {
195    if (error_) return false;
196    for (size_t i = 0; i < streams_.size(); ++i)
197      LowerBound(i, key);
198    MakeHeap();
199    return keys_[current_] == key;
200  }
201
202  bool Done() const { return error_ || heap_.empty(); }
203
204  void Next() {
205    if (error_) return;
206    if (streams_[current_]->tellg() <= positions_[current_].back()) {
207      ReadType(*(streams_[current_]), &(keys_[current_]));
208      if (!*streams_[current_]) {
209        FSTERROR() << "STTableReader: error reading file: "
210                   << sources_[current_];
211        error_ = true;
212        return;
213      }
214      push_heap(heap_.begin(), heap_.end(), *compare_);
215    } else {
216      heap_.pop_back();
217    }
218    if (!heap_.empty())
219      PopHeap();
220  }
221
222  const string &GetKey() const {
223    return keys_[current_];
224  }
225
226  const EntryType &GetEntry() const {
227    return *entry_;
228  }
229
230  bool Error() const { return error_; }
231
232 private:
233  // Comparison functor used to compare stream IDs in the heap
234  struct Compare {
235    Compare(const vector<string> *keys) : keys_(keys) {}
236
237    bool operator()(size_t i, size_t j) const {
238      return (*keys_)[i] > (*keys_)[j];
239    };
240
241   private:
242    const vector<string> *keys_;
243  };
244
245  // Position the stream with ID 'id' at the position corresponding
246  // to the lower bound for key 'find_key'
247  void LowerBound(size_t id, const string &find_key) {
248    ifstream *strm = streams_[id];
249    const vector<int64> &positions = positions_[id];
250    size_t low = 0, high = positions.size() - 1;
251
252    while (low < high) {
253      size_t mid = (low + high)/2;
254      strm->seekg(positions[mid]);
255      string key;
256      ReadType(*strm, &key);
257      if (key > find_key) {
258        high = mid;
259      } else if (key < find_key) {
260        low = mid + 1;
261      } else {
262        for (size_t i = mid; i > low; --i) {
263          strm->seekg(positions[i - 1]);
264          ReadType(*strm, &key);
265          if (key != find_key) {
266            strm->seekg(positions[i]);
267            return;
268          }
269        }
270        strm->seekg(positions[low]);
271        return;
272      }
273    }
274    strm->seekg(positions[low]);
275  }
276
277  // Add all streams to the heap
278  void MakeHeap() {
279    heap_.clear();
280    for (size_t i = 0; i < streams_.size(); ++i) {
281      ReadType(*streams_[i], &(keys_[i]));
282      if (!*streams_[i]) {
283        FSTERROR() << "STTableReader: error reading file: " << sources_[i];
284        error_ = true;
285        return;
286      }
287      heap_.push_back(i);
288    }
289    make_heap(heap_.begin(), heap_.end(), *compare_);
290    PopHeap();
291  }
292
293  // Position the stream with the lowest key at the top
294  // of the heap, set 'current_' to the ID of that stream
295  // and read the current entry from that stream
296  void PopHeap() {
297    pop_heap(heap_.begin(), heap_.end(), *compare_);
298    current_ = heap_.back();
299    if (entry_)
300      delete entry_;
301    entry_ = entry_reader_(*streams_[current_]);
302    if (!entry_)
303      error_ = true;
304    if (!*streams_[current_]) {
305      FSTERROR() << "STTableReader: error reading entry for key: "
306                 << keys_[current_] << ", file: " << sources_[current_];
307      error_ = true;
308    }
309  }
310
311
312  EntryReader entry_reader_;   // Read functor for 'EntryType'
313  vector<ifstream*> streams_;  // Input streams
314  vector<string> sources_;     // and corresponding file names
315  vector<vector<int64> > positions_;  // Index of positions for each stream
316  vector<string> keys_;  // Lowest unread key for each stream
317  vector<int64> heap_;   // Heap containing ID of streams with unread keys
318  int64 current_;        // Id of current stream to be read
319  Compare *compare_;     // Functor comparing stream IDs for the heap
320  mutable EntryType *entry_;  // Pointer to the currently read entry
321  bool error_;
322
323  DISALLOW_COPY_AND_ASSIGN(STTableReader);
324};
325
326
327// String-to-type table header reading function template on the entry header
328// type 'H' having a member function:
329//   Read(istream &strm, const string &filename);
330// Checks that 'filename' is an STTable and calls H::Read() on the last
331// entry in the STTable.
332template <class H>
333bool ReadSTTableHeader(const string &filename, H *header) {
334  ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
335  int32 magic_number = 0, file_version = 0;
336  ReadType(strm, &magic_number);
337  ReadType(strm, &file_version);
338  if (magic_number != kSTTableMagicNumber) {
339    LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename;
340    return false;
341  }
342  if (file_version != kSTTableFileVersion) {
343    LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename;
344    return false;
345  }
346  int64 i = -1;
347  strm.seekg(-static_cast<int>(sizeof(int64)), ios_base::end);
348  ReadType(strm, &i);  // Read number of entries
349  if (!strm) {
350    LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
351    return false;
352  }
353  if (i == 0) return true;  // No entry header to read
354  strm.seekg(-2 * static_cast<int>(sizeof(int64)), ios_base::end);
355  ReadType(strm, &i);  // Read position for last entry in file
356  strm.seekg(i);
357  string key;
358  ReadType(strm, &key);
359  header->Read(strm, filename + ":" + key);
360  if (!strm) {
361    LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
362    return false;
363  }
364  return true;
365}
366
367bool IsSTTable(const string &filename);
368
369}  // namespace fst
370
371#endif  // FST_EXTENSIONS_FAR_STTABLE_H_
372