1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// A tensor bundle is a set of immutable persistent files storing a set of named
17// tensors.  It is designed for checkpointing TensorFlow tensors.
18//
19// The paths of the managed files share a common prefix; e.g., with the prefix:
20//   /fs/model/train/ckpt-step/ckpt
21//
22// the bundle may contain a metadata file, and sharded data files:
23//   /fs/model/train/ckpt-step/
24//       ckpt.index
25//       ckpt.data-00000-of-00020
26//       ckpt.data-00001-of-00020
27//       ...
28//       ckpt.data-00019-of-00020
29//
30// The ".index" file is a string-string immutable table
31// (tensorflow::table::Table).  Each key is a name of a tensor and its value is
32// a serialized BundleEntryProto.  Each BundleEntryProto describes the metadata
33// of a tensor: which of the "data" files contains the content of a tensor, the
34// offset into that file, checksum, some auxiliary data, etc.
35//
36// A tensor bundle can be accessed randomly using a BundleReader.  Usage:
37//
38//   BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt");
39//   reader.Lookup("name", &tensor);
40//
// A tensor bundle can be built using BundleWriter.  Each BundleWriter builds a
// single data file bundle.  Multiple bundles can then be merged by
// MergeBundles() without reading and writing large chunks of data: it reads
// the metadata files and outputs a single merged metadata file.  Typical usage:
45//
46//   worker 0:
47//     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step");
48//     writer.Add(...);  // Adds the tensors on this worker.
49//     writer.Finish();  // Flushes.
50//   worker 1:
51//     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step");
52//     writer.Add(...);
53//     writer.Finish();
54//   worker 2:
55//     MergeBundles(env,
56//       {"/fs/model/train/ckpt-step/tmp/worker0-step",
57//        "/fs/model/train/ckpt-step/tmp/worker1-step"},
58//       "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
59//
60
61#ifndef TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
62#define TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
63
64#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
65
66#include <map>
67#include <string>
68#include <unordered_map>
69
70#include "tensorflow/core/framework/tensor.h"
71#include "tensorflow/core/framework/tensor_shape.h"
72#include "tensorflow/core/framework/tensor_slice.h"
73#include "tensorflow/core/lib/core/status.h"
74#include "tensorflow/core/lib/gtl/array_slice.h"
75#include "tensorflow/core/lib/io/inputbuffer.h"
76#include "tensorflow/core/lib/io/table.h"
77#include "tensorflow/core/platform/env.h"
78#include "tensorflow/core/platform/file_system.h"
79#include "tensorflow/core/platform/macros.h"
80#include "tensorflow/core/platform/types.h"
81#include "tensorflow/core/util/tensor_bundle/naming.h"
82#include "tensorflow/core/util/tensor_slice_set.h"
83
84namespace tensorflow {
85
// Forward declaration; the full definition appears near the end of this
// header.  BundleWriter refers to it through a std::unique_ptr member.
class FileOutputBuffer;
87
// Version information for the tensor bundle format.  The rules are the same
// as those described in tensorflow/core/public/version.h.
//
// History:
// 0. Any tensor bundles produced before this field was added.
// 1. Added this field (2016-09-14).
extern const int kTensorBundleMinProducer;
extern const int kTensorBundleMinConsumer;
extern const int kTensorBundleVersion;

// Key of the header entry in the metadata table.  It is the empty string, so
// it is always the first key in the table; the value stored under it is a
// BundleHeaderProto.
extern const char* const kHeaderEntryKey;
101
102// Builds a string-string table of tensor names to BundleEntryProto (metadata).
103//
104// On construction, attempts to create a directory given by the dirname of
105// "prefix", so "status()" must be checked before calling any member functions.
106//
107// All threads accessing the same BundleWriter must synchronize.
108class BundleWriter {
109 public:
110  struct Options {
111    Options() {}
112    // Alignment, in bytes, for tensor data.
113    // Must be >= 1. The default size of 1 densely packs tensors.
114    int data_alignment{1};
115  };
116  BundleWriter(Env* env, StringPiece prefix,
117               const Options& options = Options());
118
119  // Adds the tensor "val" under key "key".
120  // Across calls "key" must be unique but can be added in any order.
121  Status Add(StringPiece key, const Tensor& val);
122
123  // Partitioned variables support.
124  // A slice of a full tensor is stored in two entries in the metadata table:
125  //
126  //   full_tensor_key   -> BundleEntryProto, describing all stored slices
127  //                        of this full tensor.  Does not append to the data
128  //                        file.
129  //   encoded slice key -> BundleEntryProto, describing one particular slice.
130  //                        Appends values of this slice to the data file.
131  //
132  // Slices of a full tensor can be added in any order.
133  //
134  // If a full tensor has slices placed on N devices and N BundleWriter's are
135  // concurrently used, the caller must use MergeBundles() to ensure that a
136  // consistent entry for "full_tensor_key" is produced.
137  //
138  // Returns an error if the same slice is added the second time.
139  Status AddSlice(StringPiece full_tensor_key,
140                  const TensorShape& full_tensor_shape,
141                  const TensorSlice& slice_spec, const Tensor& slice_tensor);
142
143  // Finishes the writer and flushes.
144  Status Finish() TF_MUST_USE_RESULT;
145
146  Status status() const { return status_; }
147
148 private:
149  Env* const env_;  // Not owned.
150  const Options options_;
151  const string prefix_;
152  const string tmp_metadata_path_;
153  const string tmp_data_path_;
154  std::unique_ptr<FileOutputBuffer> out_;
155  int64 size_;  // Number of bytes written into out_.
156  std::map<string, BundleEntryProto> entries_;
157  Status status_;
158
159  TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter);
160};
161
162// Merges a set of bundles (given their prefixes) into a single bundle with the
163// given "merged_prefix".  The merged metadata is guaranteed to be consistent.
164//
165// If there are N bundles in "prefixes", during the merge the data files will be
166// renamed to contain a proper sharded file spec, with num_shards set to the sum
167// of num_shards across the N input bundles.
168//
169// The caller should only rely on the metadata file of the merged bundle to
170// query information about a tensor.  In particular, this function does not
171// guarantee not to re-order the input data files.
172//
173// Once merged, makes a best effort to delete the old metadata files.
174// Returns OK iff all bundles are successfully merged.
175Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
176                    StringPiece merged_prefix);
177
178// On construction, silently attempts to read the metadata associated with
179// "prefix".  If caller intends to call any function afterwards, "status()"
180// must be checked.
181// All threads accessing the same BundleReader must synchronize.
182class BundleReader {
183 public:
184  BundleReader(Env* const env, StringPiece prefix);
185  ~BundleReader();
186
187  // Is ok() iff the reader construction is successful (completed the read of
188  // the metadata).
189  Status status() const { return status_; }
190
191  // Queries whether the bundle contains an entry keyed by "key".  Calls Seek()
192  // internally, so this call invalidates the reader's current position.
193  // REQUIRES: status().ok()
194  bool Contains(StringPiece key);
195
196  // Looks up the dtype and the shape of the tensor keyed by "key".
197  // REQUIRES: status().ok()
198  Status LookupDtypeAndShape(StringPiece key, DataType* dtype,
199                             TensorShape* shape) TF_MUST_USE_RESULT;
200
201  // Looks up the shape of the tensor keyed by "key".
202  // Clears "shape" if not found.
203  // REQUIRES: status().ok()
204  Status LookupTensorShape(StringPiece key,
205                           TensorShape* shape) TF_MUST_USE_RESULT;
206
207  // Looks up the tensor keyed by "key".  If "key" refers to a partitioned
208  // tensor, attempts to look up the full contents using all stored slices.
209  //
210  // Caller must make sure "val" has the same shape and dtype as the
211  // corresponding contents, so that its buffer can be filled without needing
212  // extra allocation.  These can be queried via "LookupDtypeAndShape()".
213  //
214  // On error, "val" may contain nonsense data.  Returns a NotFound error if
215  // tensor keyed by "key" does not exist in this bundle.
216  //
217  // Validates the stored crc32c checksum against the restored bytes.
218  // REQUIRES: status().ok()
219  Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;
220
221  // Looks up the tensor pointed to by the internal iterator.
222  //
223  // On error, "val" may contain nonsense data.
224  //
225  // Validates the stored crc32c checksum against the restored bytes.
226  // REQUIRES: status().ok() && Valid()
227  Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT;
228
229  // Looks up the slices of the tensor keyed by "key".  On OK, "slices"
230  // is non-empty if and only if the tensor is a partitioned tensor.
231  //
232  // Warning - there is no guaranteed ordering for the returned slices, so
233  // a slice with a larger start index in some dimension could come before
234  // another slice with a smaller start index in the same dimension.
235  // REQUIRES: status().ok()
236  Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
237      TF_MUST_USE_RESULT;
238
239  // Looks up a specific slice of a partitioned tensor.
240  // It is only required that the stored slices cover the requested slice,
241  // namely "slice_spec" is a subset of the union of the stored slices.
242  // REQUIRES: status().ok()
243  Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec,
244                     Tensor* val) TF_MUST_USE_RESULT;
245
246  // Seeks to the first position in the bundle whose key is no less than "key".
247  // REQUIRES: status().ok()
248  void Seek(StringPiece key) { return iter_->Seek(key); }
249  // Moves to the next position in the bundle.
250  // REQUIRES: status().ok()
251  void Next() const { iter_->Next(); }
252  // Returns true iff the reader is positioned to a key/val pair.
253  // REQUIRES: status().ok()
254  bool Valid() const { return iter_->Valid(); }
255
256  // Returns the key at the current position.
257  // REQUIRES: status().ok() && Valid()
258  StringPiece key() const { return iter_->key(); }
259  // Returns the raw value at the current position.
260  // REQUIRES: status().ok() && Valid()
261  StringPiece value() const { return iter_->value(); }
262
263  string DebugString();
264
265 private:
266  // Seeks for "key" and reads the metadata proto.
267  // On non-OK return, clears "entry" for the caller.
268  // REQUIRES: status().ok()
269  Status GetBundleEntryProto(StringPiece key,
270                             BundleEntryProto* entry) TF_MUST_USE_RESULT;
271
272  // Reads the tensor value described by the metadata proto "entry".
273  // Usage for "val" follows the comment of "Lookup()".
274  Status GetValue(const BundleEntryProto& entry,
275                  Tensor* val) TF_MUST_USE_RESULT;
276
277  // Reads the slice described by "slice_spec".  The corresponding full tensor
278  // has key "ful_tensor_key" and metadata proto "full_tensor_entry".
279  // REQUIRES: full_tensor_entry.slices_size() > 0
280  Status GetSliceValue(StringPiece full_tensor_key,
281                       const BundleEntryProto& full_tensor_entry,
282                       const TensorSlice& slice_spec,
283                       Tensor* val) TF_MUST_USE_RESULT;
284
285  Env* env_;  // Not owned.
286  const string prefix_;
287
288  Status status_;
289  RandomAccessFile* metadata_;  // Owned.
290  table::Table* table_;
291  table::Iterator* iter_;
292  // Owned the InputBuffer objects and their underlying RandomAccessFile's.
293  std::unordered_map<int32, io::InputBuffer*> data_;
294
295  // Maps each partitioned tensor's key to its stored slices (represented in a
296  // TensorSliceSet).  Populated on-demand.
297  std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_;
298
299  // Expected number of data file shards in the bundle.  Extracted by reading
300  // the header entry in the metadata table.
301  int num_shards_;
302
303  friend class TensorBundleAlignmentTest;  // For testing data alignment.
304
305  TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
306};
307
308// A buffering wrapper for a WritableFile.  Useful if the caller wishes to issue
309// small writes to a file (e.g. writing out a list of small varints).
310// External synchronization must be used in the presence of concurrent callers.
311class FileOutputBuffer {
312 public:
313  FileOutputBuffer(WritableFile* file, size_t buffer_size)
314      : file_(file), position_(0), buffer_size_(buffer_size) {
315    DCHECK_GT(buffer_size, 0);
316    buffer_.resize(buffer_size);
317  }
318  ~FileOutputBuffer();
319
320  // Buffered append.
321  Status Append(StringPiece data);
322
323  // Returns the running crc32c checksum of all currently appended bytes.
324  uint32 crc32c() { return crc32c_; }
325  // Clears the running crc32c checksum.
326  void clear_crc32c() { crc32c_ = 0; }
327
328  // Appends the buffered data, then closes the underlying file.
329  Status Close();
330
331 private:
332  // Appends the buffered data to the underlying file. Does NOT flush the file.
333  Status FlushBuffer();
334
335  WritableFile* file_;  // Owned.
336
337  // buffer_[0, position_) holds the buffered data not yet appended to the
338  // underlying file.
339  size_t position_;
340  const size_t buffer_size_;
341  std::vector<char> buffer_;
342
343  // Checksum of all appended bytes since construction or last clear_crc32c().
344  uint32 crc32c_ = 0;
345};
346
347}  // namespace tensorflow
348
349#endif  // TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
350