1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/core/framework/partial_tensor_shape.h" 16#include "tensorflow/core/framework/tensor.h" 17#include "tensorflow/core/kernels/data/dataset.h" 18 19namespace tensorflow { 20 21namespace { 22 23// See documentation in ../ops/dataset_ops.cc for a high-level 24// description of the following op. 25 26class ZipDatasetOp : public DatasetOpKernel { 27 public: 28 explicit ZipDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} 29 30 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 31 std::vector<DatasetBase*> inputs; 32 for (size_t i = 0; i < ctx->num_inputs(); ++i) { 33 DatasetBase* input; 34 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); 35 inputs.push_back(input); 36 } 37 *output = new Dataset(ctx, inputs); 38 } 39 40 private: 41 class Dataset : public GraphDatasetBase { 42 public: 43 explicit Dataset(OpKernelContext* ctx, 44 const std::vector<DatasetBase*>& inputs) 45 : GraphDatasetBase(ctx), inputs_(inputs) { 46 for (const auto& input : inputs_) { 47 input->Ref(); 48 for (DataType dt : input->output_dtypes()) { 49 output_dtypes_.push_back(dt); 50 } 51 output_shapes_.insert(output_shapes_.end(), 52 input->output_shapes().begin(), 53 input->output_shapes().end()); 54 } 55 } 56 57 ~Dataset() override { 58 for (const auto& input : inputs_) { 59 input->Unref(); 60 } 61 } 62 63 std::unique_ptr<IteratorBase> MakeIterator( 64 const string& prefix) const override { 65 return std::unique_ptr<IteratorBase>( 66 new Iterator({this, strings::StrCat(prefix, "::Zip")})); 67 } 68 69 const DataTypeVector& output_dtypes() const override { 70 return output_dtypes_; 71 } 72 73 const std::vector<PartialTensorShape>& output_shapes() const override { 74 return output_shapes_; 75 } 76 77 string DebugString() override { return "ZipDatasetOp::Dataset"; } 78 79 protected: 80 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 81 Node** output) const override { 82 std::vector<Node*> input_graph_nodes; 83 input_graph_nodes.reserve(inputs_.size()); 84 for (const auto& input : inputs_) { 85 Node* input_node; 86 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input, &input_node)); 87 input_graph_nodes.emplace_back(input_node); 88 } 89 TF_RETURN_IF_ERROR(b->AddDataset( 90 this, {}, {std::make_pair(0, input_graph_nodes)}, {}, output)); 91 return Status::OK(); 92 } 93 94 private: 95 class Iterator : public DatasetIterator<Dataset> { 96 public: 97 explicit Iterator(const Params& params) 98 : DatasetIterator<Dataset>(params) { 99 input_impls_.reserve(params.dataset->inputs_.size()); 100 size_t idx = 0; 101 for (const auto& input : params.dataset->inputs_) { 102 input_impls_.emplace_back(input->MakeIterator( 103 strings::StrCat(params.prefix, "[", idx++, "]"))); 104 } 105 } 106 107 Status GetNextInternal(IteratorContext* ctx, 108 std::vector<Tensor>* out_tensors, 109 bool* end_of_sequence) override { 110 mutex_lock l(mu_); 111 if (input_impls_.empty()) { 112 *end_of_sequence = true; 113 return Status::OK(); 114 } 115 out_tensors->clear(); 116 out_tensors->reserve(dataset()->output_dtypes().size()); 117 for (const auto& input_impl : input_impls_) { 118 std::vector<Tensor> input_tensors; 119 TF_RETURN_IF_ERROR( 120 input_impl->GetNext(ctx, &input_tensors, end_of_sequence)); 121 if (*end_of_sequence) { 122 break; 123 } 124 out_tensors->insert(out_tensors->end(), input_tensors.begin(), 125 input_tensors.end()); 126 } 127 if (*end_of_sequence) { 128 out_tensors->clear(); 129 input_impls_.clear(); 130 } 131 return Status::OK(); 132 } 133 134 protected: 135 Status SaveInternal(IteratorStateWriter* writer) override { 136 mutex_lock l(mu_); 137 if (input_impls_.empty()) { 138 TF_RETURN_IF_ERROR( 139 writer->WriteScalar(full_name("input_impls_empty"), "")); 140 } else { 141 for (auto& input_impl : input_impls_) 142 TF_RETURN_IF_ERROR(SaveParent(writer, input_impl)); 143 } 144 return Status::OK(); 145 } 146 147 Status RestoreInternal(IteratorContext* ctx, 148 IteratorStateReader* reader) override { 149 mutex_lock l(mu_); 150 if (reader->Contains(full_name("input_impls_empty"))) { 151 input_impls_.clear(); 152 } else { 153 DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size()); 154 for (auto& input_impl : input_impls_) 155 TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl)); 156 } 157 return Status::OK(); 158 } 159 160 private: 161 mutex mu_; 162 std::vector<std::unique_ptr<IteratorBase>> input_impls_ GUARDED_BY(mu_); 163 }; 164 165 const std::vector<DatasetBase*> inputs_; 166 DataTypeVector output_dtypes_; 167 std::vector<PartialTensorShape> output_shapes_; 168 }; 169}; 170 171REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp); 172 173} // namespace 174 175} // namespace tensorflow 176