1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
22966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
32966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreLicensed under the Apache License, Version 2.0 (the "License");
42966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooreyou may not use this file except in compliance with the License.
52966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreYou may obtain a copy of the License at
62966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
72966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore    http://www.apache.org/licenses/LICENSE-2.0
82966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
92966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreUnless required by applicable law or agreed to in writing, software
102966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooredistributed under the License is distributed on an "AS IS" BASIS,
112966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreSee the License for the specific language governing permissions and
132966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moorelimitations under the License.
142966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore==============================================================================*/
152966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
16788f359b7218ad46696c15459c89688ffe70955eA. Unique TensorFlower#include "tensorflow/c/checkpoint_reader.h"
171e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang
181e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang#include <unordered_set>
19a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower#include <utility>
201e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang
212966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore#include "tensorflow/core/lib/core/status.h"
222966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore#include "tensorflow/core/lib/core/stringpiece.h"
232966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore#include "tensorflow/core/platform/env.h"
242966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore#include "tensorflow/core/platform/types.h"
251e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang#include "tensorflow/core/util/saved_tensor_slice_util.h"
262966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
272966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moorenamespace tensorflow {
282966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moorenamespace checkpoint {
292966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
302966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooreclass TensorSliceReader;
312966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
322966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreCheckpointReader::CheckpointReader(const string& filename,
3394d27b6852b3e331fd9d64a0533f0fc27af05bfdA. Unique TensorFlower                                   TF_Status* out_status)
34a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    : reader_(nullptr),
35a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower      v2_reader_(nullptr),
36a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower      var_to_shape_map_(nullptr),
37a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower      var_to_data_type_map_(nullptr) {
38ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  // Depending on whether this is a V2 ckpt, initializes "reader_" or
39ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  // "v2_reader_".
40ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  std::vector<string> v2_path;
41ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
42ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang      !v2_path.empty()) {
43220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower    v2_reader_.reset(
44220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower        new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
45ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    if (!v2_reader_->status().ok()) {
46ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang      Set_TF_Status_from_Status(out_status, v2_reader_->status());
47ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang      return;
48ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    }
49a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    auto result = BuildV2VarMaps();
50a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    var_to_shape_map_.swap(result.first);
51a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    var_to_data_type_map_.swap(result.second);
522966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore  } else {
53220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower    reader_.reset(new TensorSliceReader(filename));
54ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    if (!reader_->status().ok()) {
55ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang      Set_TF_Status_from_Status(out_status, reader_->status());
56ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang      return;
57ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    }
58a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    var_to_shape_map_.reset(
59220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower        new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
60a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
61a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower        reader_->GetVariableToDataTypeMap()));
622966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore  }
632966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore}
642966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
652966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moorebool CheckpointReader::HasTensor(const string& name) const {
66ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  if (reader_ != nullptr) {
67ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    return reader_->HasTensor(name, nullptr, nullptr);
68ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  }
69ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  return v2_reader_->Contains(name);
702966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore}
712966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
722966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooreconst TensorSliceReader::VarToShapeMap&
732966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreCheckpointReader::GetVariableToShapeMap() const {
74a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower  CHECK(var_to_shape_map_);
75a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower  return *var_to_shape_map_;
76a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower}
77a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower
78a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlowerconst TensorSliceReader::VarToDataTypeMap&
79a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlowerCheckpointReader::GetVariableToDataTypeMap() const {
80a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower  CHECK(var_to_data_type_map_);
81a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower  return *var_to_data_type_map_;
822966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore}
832966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
842966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooreconst string CheckpointReader::DebugString() const {
85ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  if (reader_ != nullptr) return reader_->DebugString();
86ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  return v2_reader_->DebugString();
87ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang}
88ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang
89ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yangvoid CheckpointReader::GetTensor(
90ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
91ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    TF_Status* out_status) const {
92ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  Status status;
93ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  if (reader_ != nullptr) {
94ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    status = reader_->GetTensor(name, out_tensor);
95ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  } else {
96cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang    tensorflow::DataType dtype;
97cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang    tensorflow::TensorShape shape;
98cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang    status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
99cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang    if (status.ok()) {
100cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang      out_tensor->reset(new Tensor(dtype, shape));
101cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang      status = v2_reader_->Lookup(name, out_tensor->get());
102cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang      if (!status.ok()) out_tensor->reset();
103cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang    }
104ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  }
105ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  if (!status.ok()) {
106ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    Set_TF_Status_from_Status(out_status, status);
107ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  }
108ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang}
109ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang
110a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlowerstd::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
111a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower          std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
112a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlowerCheckpointReader::BuildV2VarMaps() {
113ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  CHECK(v2_reader_ != nullptr);
114ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  CHECK(v2_reader_->status().ok());
1151e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang
1161e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang  // First pass: filters out the entries of the slices.
1171e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang  std::unordered_set<string> filtered_keys;
1181e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang  BundleEntryProto entry;
119ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  v2_reader_->Seek(kHeaderEntryKey);
1201e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang  for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
1211e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang    CHECK(entry.ParseFromArray(v2_reader_->value().data(),
1221e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang                               v2_reader_->value().size()))
1231e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang        << entry.InitializationErrorString();
1241e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang    for (int i = 0; i < entry.slices_size(); ++i) {
1251e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang      const auto& slice_proto = entry.slices(i);
1261e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang      CHECK(filtered_keys
1271e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang                .insert(EncodeTensorNameSlice(
1281e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang                    v2_reader_->key().ToString() /* full var's name */,
1291e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang                    TensorSlice(slice_proto)))
1301e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang                .second);
1311e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang    }
1321e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang  }
133ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang
1341e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang  // Second pass: adds the entries, ignoring the filtered keys.
135220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower  std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
136220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower      new TensorSliceReader::VarToShapeMap);
137a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower  std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
138a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower      new TensorSliceReader::VarToDataTypeMap);
1391e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang  v2_reader_->Seek(kHeaderEntryKey);
140ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
1411e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang    if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
142ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang    CHECK(entry.ParseFromArray(v2_reader_->value().data(),
1431e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang                               v2_reader_->value().size()))
1441e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang        << entry.InitializationErrorString();
145a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    string key = v2_reader_->key().ToString();
146a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    (*var_to_shape_map)[key] = TensorShape(entry.shape());
147a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower    (*var_to_data_type_map)[key] = DataType(entry.dtype());
148ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang  }
149a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower  // The returned pointers are owned by the caller.
150a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower  return std::make_pair(std::move(var_to_shape_map),
151a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower                        std::move(var_to_data_type_map));
1522966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore}
1532966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore
1542966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore}  // namespace checkpoint
1552966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore}  // namespace tensorflow
156