1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/checkpoint_reader.h"
17
18#include <unordered_set>
19#include <utility>
20
21#include "tensorflow/core/lib/core/status.h"
22#include "tensorflow/core/lib/core/stringpiece.h"
23#include "tensorflow/core/platform/env.h"
24#include "tensorflow/core/platform/types.h"
25#include "tensorflow/core/util/saved_tensor_slice_util.h"
26
27namespace tensorflow {
28namespace checkpoint {
29
30class TensorSliceReader;
31
32CheckpointReader::CheckpointReader(const string& filename,
33                                   TF_Status* out_status)
34    : reader_(nullptr),
35      v2_reader_(nullptr),
36      var_to_shape_map_(nullptr),
37      var_to_data_type_map_(nullptr) {
38  // Depending on whether this is a V2 ckpt, initializes "reader_" or
39  // "v2_reader_".
40  std::vector<string> v2_path;
41  if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
42      !v2_path.empty()) {
43    v2_reader_.reset(
44        new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
45    if (!v2_reader_->status().ok()) {
46      Set_TF_Status_from_Status(out_status, v2_reader_->status());
47      return;
48    }
49    auto result = BuildV2VarMaps();
50    var_to_shape_map_.swap(result.first);
51    var_to_data_type_map_.swap(result.second);
52  } else {
53    reader_.reset(new TensorSliceReader(filename));
54    if (!reader_->status().ok()) {
55      Set_TF_Status_from_Status(out_status, reader_->status());
56      return;
57    }
58    var_to_shape_map_.reset(
59        new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
60    var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
61        reader_->GetVariableToDataTypeMap()));
62  }
63}
64
65bool CheckpointReader::HasTensor(const string& name) const {
66  if (reader_ != nullptr) {
67    return reader_->HasTensor(name, nullptr, nullptr);
68  }
69  return v2_reader_->Contains(name);
70}
71
72const TensorSliceReader::VarToShapeMap&
73CheckpointReader::GetVariableToShapeMap() const {
74  CHECK(var_to_shape_map_);
75  return *var_to_shape_map_;
76}
77
78const TensorSliceReader::VarToDataTypeMap&
79CheckpointReader::GetVariableToDataTypeMap() const {
80  CHECK(var_to_data_type_map_);
81  return *var_to_data_type_map_;
82}
83
84const string CheckpointReader::DebugString() const {
85  if (reader_ != nullptr) return reader_->DebugString();
86  return v2_reader_->DebugString();
87}
88
89void CheckpointReader::GetTensor(
90    const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
91    TF_Status* out_status) const {
92  Status status;
93  if (reader_ != nullptr) {
94    status = reader_->GetTensor(name, out_tensor);
95  } else {
96    tensorflow::DataType dtype;
97    tensorflow::TensorShape shape;
98    status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
99    if (status.ok()) {
100      out_tensor->reset(new Tensor(dtype, shape));
101      status = v2_reader_->Lookup(name, out_tensor->get());
102      if (!status.ok()) out_tensor->reset();
103    }
104  }
105  if (!status.ok()) {
106    Set_TF_Status_from_Status(out_status, status);
107  }
108}
109
110std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
111          std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
112CheckpointReader::BuildV2VarMaps() {
113  CHECK(v2_reader_ != nullptr);
114  CHECK(v2_reader_->status().ok());
115
116  // First pass: filters out the entries of the slices.
117  std::unordered_set<string> filtered_keys;
118  BundleEntryProto entry;
119  v2_reader_->Seek(kHeaderEntryKey);
120  for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
121    CHECK(entry.ParseFromArray(v2_reader_->value().data(),
122                               v2_reader_->value().size()))
123        << entry.InitializationErrorString();
124    for (int i = 0; i < entry.slices_size(); ++i) {
125      const auto& slice_proto = entry.slices(i);
126      CHECK(filtered_keys
127                .insert(EncodeTensorNameSlice(
128                    v2_reader_->key().ToString() /* full var's name */,
129                    TensorSlice(slice_proto)))
130                .second);
131    }
132  }
133
134  // Second pass: adds the entries, ignoring the filtered keys.
135  std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
136      new TensorSliceReader::VarToShapeMap);
137  std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
138      new TensorSliceReader::VarToDataTypeMap);
139  v2_reader_->Seek(kHeaderEntryKey);
140  for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
141    if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
142    CHECK(entry.ParseFromArray(v2_reader_->value().data(),
143                               v2_reader_->value().size()))
144        << entry.InitializationErrorString();
145    string key = v2_reader_->key().ToString();
146    (*var_to_shape_map)[key] = TensorShape(entry.shape());
147    (*var_to_data_type_map)[key] = DataType(entry.dtype());
148  }
149  // The returned pointers are owned by the caller.
150  return std::make_pair(std::move(var_to_shape_map),
151                        std::move(var_to_data_type_map));
152}
153
154}  // namespace checkpoint
155}  // namespace tensorflow
156