1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 22966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 32966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreLicensed under the Apache License, Version 2.0 (the "License"); 42966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooreyou may not use this file except in compliance with the License. 52966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreYou may obtain a copy of the License at 62966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 72966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore http://www.apache.org/licenses/LICENSE-2.0 82966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 92966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreUnless required by applicable law or agreed to in writing, software 102966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooredistributed under the License is distributed on an "AS IS" BASIS, 112966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 122966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreSee the License for the specific language governing permissions and 132966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moorelimitations under the License. 142966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore==============================================================================*/ 152966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 16788f359b7218ad46696c15459c89688ffe70955eA. Unique TensorFlower#include "tensorflow/c/checkpoint_reader.h" 171e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang 181e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang#include <unordered_set> 19a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower#include <utility> 201e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang 212966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore#include "tensorflow/core/lib/core/status.h" 222966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore#include "tensorflow/core/lib/core/stringpiece.h" 232966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore#include "tensorflow/core/platform/env.h" 242966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore#include "tensorflow/core/platform/types.h" 251e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang#include "tensorflow/core/util/saved_tensor_slice_util.h" 262966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 272966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moorenamespace tensorflow { 282966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moorenamespace checkpoint { 292966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 302966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooreclass TensorSliceReader; 312966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 322966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreCheckpointReader::CheckpointReader(const string& filename, 3394d27b6852b3e331fd9d64a0533f0fc27af05bfdA. Unique TensorFlower TF_Status* out_status) 34a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower : reader_(nullptr), 35a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower v2_reader_(nullptr), 36a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower var_to_shape_map_(nullptr), 37a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower var_to_data_type_map_(nullptr) { 38ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang // Depending on whether this is a V2 ckpt, initializes "reader_" or 39ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang // "v2_reader_". 40ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang std::vector<string> v2_path; 41ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() && 42ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang !v2_path.empty()) { 43220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower v2_reader_.reset( 44220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); 45ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang if (!v2_reader_->status().ok()) { 46ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang Set_TF_Status_from_Status(out_status, v2_reader_->status()); 47ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang return; 48ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang } 49a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower auto result = BuildV2VarMaps(); 50a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower var_to_shape_map_.swap(result.first); 51a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower var_to_data_type_map_.swap(result.second); 522966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore } else { 53220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower reader_.reset(new TensorSliceReader(filename)); 54ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang if (!reader_->status().ok()) { 55ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang Set_TF_Status_from_Status(out_status, reader_->status()); 56ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang return; 57ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang } 58a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower var_to_shape_map_.reset( 59220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap())); 60a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap( 61a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower reader_->GetVariableToDataTypeMap())); 622966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore } 632966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore} 642966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 652966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moorebool CheckpointReader::HasTensor(const string& name) const { 66ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang if (reader_ != nullptr) { 67ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang return reader_->HasTensor(name, nullptr, nullptr); 68ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang } 69ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang return v2_reader_->Contains(name); 702966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore} 712966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 722966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooreconst TensorSliceReader::VarToShapeMap& 732966fcdae0cf7d75f4eca027f3003e355089b0edSherry MooreCheckpointReader::GetVariableToShapeMap() const { 74a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower CHECK(var_to_shape_map_); 75a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower return *var_to_shape_map_; 76a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower} 77a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower 78a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlowerconst TensorSliceReader::VarToDataTypeMap& 79a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlowerCheckpointReader::GetVariableToDataTypeMap() const { 80a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower CHECK(var_to_data_type_map_); 81a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower return *var_to_data_type_map_; 822966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore} 832966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 842966fcdae0cf7d75f4eca027f3003e355089b0edSherry Mooreconst string CheckpointReader::DebugString() const { 85ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang if (reader_ != nullptr) return reader_->DebugString(); 86ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang return v2_reader_->DebugString(); 87ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang} 88ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang 89ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yangvoid CheckpointReader::GetTensor( 90ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor, 91ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang TF_Status* out_status) const { 92ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang Status status; 93ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang if (reader_ != nullptr) { 94ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang status = reader_->GetTensor(name, out_tensor); 95ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang } else { 96cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang tensorflow::DataType dtype; 97cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang tensorflow::TensorShape shape; 98cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape); 99cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang if (status.ok()) { 100cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang out_tensor->reset(new Tensor(dtype, shape)); 101cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang status = v2_reader_->Lookup(name, out_tensor->get()); 102cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang if (!status.ok()) out_tensor->reset(); 103cae3713cd4a2a191f10012a8efab19c721d41742Zongheng Yang } 104ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang } 105ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang if (!status.ok()) { 106ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang Set_TF_Status_from_Status(out_status, status); 107ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang } 108ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang} 109ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang 110a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlowerstd::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>, 111a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower std::unique_ptr<TensorSliceReader::VarToDataTypeMap>> 112a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlowerCheckpointReader::BuildV2VarMaps() { 113ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang CHECK(v2_reader_ != nullptr); 114ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang CHECK(v2_reader_->status().ok()); 1151e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang 1161e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang // First pass: filters out the entries of the slices. 1171e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang std::unordered_set<string> filtered_keys; 1181e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang BundleEntryProto entry; 119ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang v2_reader_->Seek(kHeaderEntryKey); 1201e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { 1211e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang CHECK(entry.ParseFromArray(v2_reader_->value().data(), 1221e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang v2_reader_->value().size())) 1231e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang << entry.InitializationErrorString(); 1241e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang for (int i = 0; i < entry.slices_size(); ++i) { 1251e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang const auto& slice_proto = entry.slices(i); 1261e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang CHECK(filtered_keys 1271e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang .insert(EncodeTensorNameSlice( 1281e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang v2_reader_->key().ToString() /* full var's name */, 1291e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang TensorSlice(slice_proto))) 1301e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang .second); 1311e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang } 1321e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang } 133ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang 1341e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang // Second pass: adds the entries, ignoring the filtered keys. 135220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map( 136220515bffdf1df5379a7f8921f5a12deb2e0dee7A. Unique TensorFlower new TensorSliceReader::VarToShapeMap); 137a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map( 138a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower new TensorSliceReader::VarToDataTypeMap); 1391e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang v2_reader_->Seek(kHeaderEntryKey); 140ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { 1411e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; 142ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang CHECK(entry.ParseFromArray(v2_reader_->value().data(), 1431e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang v2_reader_->value().size())) 1441e324a0f2a67cfa651677bd381bf1bf2adc3e2f8Zongheng Yang << entry.InitializationErrorString(); 145a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower string key = v2_reader_->key().ToString(); 146a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower (*var_to_shape_map)[key] = TensorShape(entry.shape()); 147a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower (*var_to_data_type_map)[key] = DataType(entry.dtype()); 148ecdf0b202a2bfcff7985e62da727397bd8c67a91Zongheng Yang } 149a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower // The returned pointers are owned by the caller. 150a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower return std::make_pair(std::move(var_to_shape_map), 151a8c5d5fe011e796593d20c74d8b927c014a27c89A. Unique TensorFlower std::move(var_to_data_type_map)); 1522966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore} 1532966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore 1542966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore} // namespace checkpoint 1552966fcdae0cf7d75f4eca027f3003e355089b0edSherry Moore} // namespace tensorflow 156