sttable.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// sttable.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// A generic string-to-type table file format
20//
21// This is not meant as a generalization of SSTable. This is more of
22// a simple replacement for SSTable in order to provide an open-source
23// implementation of the FAR format for the external version of the
24// FST Library.
25
26#ifndef FST_EXTENSIONS_FAR_STTABLE_H_
27#define FST_EXTENSIONS_FAR_STTABLE_H_
28
29#include <algorithm>
30#include <iostream>
31#include <fstream>
32#include <fst/util.h>
33
34namespace fst {
35
// Magic number identifying an STTable file. STTableWriter writes it as the
// first value of the file; STTableReader and ReadSTTableHeader reject any
// file whose leading int32 differs from it.
static const int32 kSTTableMagicNumber = 2125656924;
// File format version, written right after the magic number. The readers in
// this header reject any file whose version differs from this value.
static const int32 kSTTableFileVersion = 1;
38
39// String-to-type table writing class for object of type 'T' using functor 'W'
40// to write an object of type 'T' from a stream. 'W' must conform to the
41// following interface:
42//
43//   struct Writer {
44//     void operator()(ostream &, const T &) const;
45//   };
46//
47template <class T, class W>
48class STTableWriter {
49 public:
50  typedef T EntryType;
51  typedef W EntryWriter;
52
53  explicit STTableWriter(const string &filename)
54      : stream_(filename.c_str(), ofstream::out | ofstream::binary),
55        error_(false) {
56    WriteType(stream_, kSTTableMagicNumber);
57    WriteType(stream_, kSTTableFileVersion);
58    if (!stream_) {
59      FSTERROR() << "STTableWriter::STTableWriter: error writing to file: "
60                 << filename;
61      error_=true;
62    }
63  }
64
65  static STTableWriter<T, W> *Create(const string &filename) {
66    if (filename.empty()) {
67      LOG(ERROR) << "STTableWriter: writing to standard out unsupported.";
68      return 0;
69    }
70    return new STTableWriter<T, W>(filename);
71  }
72
73  void Add(const string &key, const T &t) {
74    if (key == "") {
75      FSTERROR() << "STTableWriter::Add: key empty: " << key;
76      error_ = true;
77    } else if (key < last_key_) {
78      FSTERROR() << "STTableWriter::Add: key disorder: " << key;
79      error_ = true;
80    }
81    if (error_) return;
82    last_key_ = key;
83    positions_.push_back(stream_.tellp());
84    WriteType(stream_, key);
85    entry_writer_(stream_, t);
86  }
87
88  bool Error() const { return error_; }
89
90  ~STTableWriter() {
91    WriteType(stream_, positions_);
92    WriteType(stream_, static_cast<int64>(positions_.size()));
93  }
94
95 private:
96  EntryWriter entry_writer_;  // Write functor for 'EntryType'
97  ofstream stream_;           // Output stream
98  vector<int64> positions_;   // Position in file of each key-entry pair
99  string last_key_;           // Last key
100  bool error_;
101
102  DISALLOW_COPY_AND_ASSIGN(STTableWriter);
103};
104
105
106// String-to-type table reading class for object of type 'T' using functor 'R'
107// to read an object of type 'T' form a stream. 'R' must conform to the
108// following interface:
109//
110//   struct Reader {
111//     T *operator()(istream &) const;
112//   };
113//
114template <class T, class R>
115class STTableReader {
116 public:
117  typedef T EntryType;
118  typedef R EntryReader;
119
120  explicit STTableReader(const vector<string> &filenames)
121      : sources_(filenames), entry_(0), error_(false) {
122    compare_ = new Compare(&keys_);
123    keys_.resize(filenames.size());
124    streams_.resize(filenames.size(), 0);
125    positions_.resize(filenames.size());
126    for (size_t i = 0; i < filenames.size(); ++i) {
127      streams_[i] = new ifstream(
128          filenames[i].c_str(), ifstream::in | ifstream::binary);
129      int32 magic_number = 0, file_version = 0;
130      ReadType(*streams_[i], &magic_number);
131      ReadType(*streams_[i], &file_version);
132      if (magic_number != kSTTableMagicNumber) {
133        FSTERROR() << "STTableReader::STTableReader: wrong file type: "
134                   << filenames[i];
135        error_ = true;
136        return;
137      }
138      if (file_version != kSTTableFileVersion) {
139        FSTERROR() << "STTableReader::STTableReader: wrong file version: "
140                   << filenames[i];
141        error_ = true;
142        return;
143      }
144      int64 num_entries;
145      streams_[i]->seekg(-static_cast<int>(sizeof(int64)), ios_base::end);
146      ReadType(*streams_[i], &num_entries);
147      streams_[i]->seekg(-static_cast<int>(sizeof(int64)) *
148                         (num_entries + 1), ios_base::end);
149      positions_[i].resize(num_entries);
150      for (size_t j = 0; (j < num_entries) && (*streams_[i]); ++j)
151        ReadType(*streams_[i], &(positions_[i][j]));
152      streams_[i]->seekg(positions_[i][0]);
153      if (!*streams_[i]) {
154        FSTERROR() << "STTableReader::STTableReader: error reading file: "
155                   << filenames[i];
156        error_ = true;
157        return;
158      }
159
160    }
161    MakeHeap();
162  }
163
164  ~STTableReader() {
165    for (size_t i = 0; i < streams_.size(); ++i)
166      delete streams_[i];
167    delete compare_;
168    if (entry_)
169      delete entry_;
170  }
171
172  static STTableReader<T, R> *Open(const string &filename) {
173    if (filename.empty()) {
174      LOG(ERROR) << "STTableReader: reading from standard in not supported";
175      return 0;
176    }
177    vector<string> filenames;
178    filenames.push_back(filename);
179    return new STTableReader<T, R>(filenames);
180  }
181
182  static STTableReader<T, R> *Open(const vector<string> &filenames) {
183    return new STTableReader<T, R>(filenames);
184  }
185
186  void Reset() {
187    if (error_) return;
188    for (size_t i = 0; i < streams_.size(); ++i)
189      streams_[i]->seekg(positions_[i].front());
190    MakeHeap();
191  }
192
193  bool Find(const string &key) {
194    if (error_) return false;
195    for (size_t i = 0; i < streams_.size(); ++i)
196      LowerBound(i, key);
197    MakeHeap();
198    return keys_[current_] == key;
199  }
200
201  bool Done() const { return error_ || heap_.empty(); }
202
203  void Next() {
204    if (error_) return;
205    if (streams_[current_]->tellg() <= positions_[current_].back()) {
206      ReadType(*(streams_[current_]), &(keys_[current_]));
207      if (!*streams_[current_]) {
208        FSTERROR() << "STTableReader: error reading file: "
209                   << sources_[current_];
210        error_ = true;
211        return;
212      }
213      push_heap(heap_.begin(), heap_.end(), *compare_);
214    } else {
215      heap_.pop_back();
216    }
217    if (!heap_.empty())
218      PopHeap();
219  }
220
221  const string &GetKey() const {
222    return keys_[current_];
223  }
224
225  const EntryType &GetEntry() const {
226    return *entry_;
227  }
228
229  bool Error() const { return error_; }
230
231 private:
232  // Comparison functor used to compare stream IDs in the heap
233  struct Compare {
234    Compare(const vector<string> *keys) : keys_(keys) {}
235
236    bool operator()(size_t i, size_t j) const {
237      return (*keys_)[i] > (*keys_)[j];
238    };
239
240   private:
241    const vector<string> *keys_;
242  };
243
244  // Position the stream with ID 'id' at the position corresponding
245  // to the lower bound for key 'find_key'
246  void LowerBound(size_t id, const string &find_key) {
247    ifstream *strm = streams_[id];
248    const vector<int64> &positions = positions_[id];
249    size_t low = 0, high = positions.size() - 1;
250
251    while (low < high) {
252      size_t mid = (low + high)/2;
253      strm->seekg(positions[mid]);
254      string key;
255      ReadType(*strm, &key);
256      if (key > find_key) {
257        high = mid;
258      } else if (key < find_key) {
259        low = mid + 1;
260      } else {
261        for (size_t i = mid; i > low; --i) {
262          strm->seekg(positions[i - 1]);
263          ReadType(*strm, &key);
264          if (key != find_key) {
265            strm->seekg(positions[i]);
266            return;
267          }
268        }
269        strm->seekg(positions[low]);
270        return;
271      }
272    }
273    strm->seekg(positions[low]);
274  }
275
276  // Add all streams to the heap
277  void MakeHeap() {
278    heap_.clear();
279    for (size_t i = 0; i < streams_.size(); ++i) {
280      ReadType(*streams_[i], &(keys_[i]));
281      if (!*streams_[i]) {
282        FSTERROR() << "STTableReader: error reading file: " << sources_[i];
283        error_ = true;
284        return;
285      }
286      heap_.push_back(i);
287    }
288    make_heap(heap_.begin(), heap_.end(), *compare_);
289    PopHeap();
290  }
291
292  // Position the stream with the lowest key at the top
293  // of the heap, set 'current_' to the ID of that stream
294  // and read the current entry from that stream
295  void PopHeap() {
296    pop_heap(heap_.begin(), heap_.end(), *compare_);
297    current_ = heap_.back();
298    if (entry_)
299      delete entry_;
300    entry_ = entry_reader_(*streams_[current_]);
301    if (!entry_)
302      error_ = true;
303    if (!*streams_[current_]) {
304      FSTERROR() << "STTableReader: error reading entry for key: "
305                 << keys_[current_] << ", file: " << sources_[current_];
306      error_ = true;
307    }
308  }
309
310
311  EntryReader entry_reader_;   // Read functor for 'EntryType'
312  vector<ifstream*> streams_;  // Input streams
313  vector<string> sources_;     // and corresponding file names
314  vector<vector<int64> > positions_;  // Index of positions for each stream
315  vector<string> keys_;  // Lowest unread key for each stream
316  vector<int64> heap_;   // Heap containing ID of streams with unread keys
317  int64 current_;        // Id of current stream to be read
318  Compare *compare_;     // Functor comparing stream IDs for the heap
319  mutable EntryType *entry_;  // Pointer to the currently read entry
320  bool error_;
321
322  DISALLOW_COPY_AND_ASSIGN(STTableReader);
323};
324
325
326// String-to-type table header reading function template on the entry header
327// type 'H' having a member function:
328//   Read(istream &strm, const string &filename);
329// Checks that 'filename' is an STTable and call the H::Read() on the last
330// entry in the STTable.
331template <class H>
332bool ReadSTTableHeader(const string &filename, H *header) {
333  ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
334  int32 magic_number = 0, file_version = 0;
335  ReadType(strm, &magic_number);
336  ReadType(strm, &file_version);
337  if (magic_number != kSTTableMagicNumber) {
338    LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename;
339    return false;
340  }
341  if (file_version != kSTTableFileVersion) {
342    LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename;
343    return false;
344  }
345  int64 i = -1;
346  strm.seekg(-static_cast<int>(sizeof(int64)), ios_base::end);
347  ReadType(strm, &i);  // Read number of entries
348  if (!strm) {
349    LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
350    return false;
351  }
352  if (i == 0) return true;  // No entry header to read
353  strm.seekg(-2 * static_cast<int>(sizeof(int64)), ios_base::end);
354  ReadType(strm, &i);  // Read position for last entry in file
355  strm.seekg(i);
356  string key;
357  ReadType(strm, &key);
358  header->Read(strm, filename + ":" + key);
359  if (!strm) {
360    LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
361    return false;
362  }
363  return true;
364}
365
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably reports whether 'filename' is an STTable file --
// confirm against the definition.
bool IsSTTable(const string &filename);
367
368}  // namespace fst
369
370#endif  // FST_EXTENSIONS_FAR_STTABLE_H_
371