1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/partial_tensor_shape.h"
16#include "tensorflow/core/framework/tensor.h"
17#include "tensorflow/core/kernels/data/dataset.h"
18
19namespace tensorflow {
20
21namespace {
22
23// See documentation in ../ops/dataset_ops.cc for a high-level
24// description of the following op.
25
26class ZipDatasetOp : public DatasetOpKernel {
27 public:
28  explicit ZipDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
29
30  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
31    std::vector<DatasetBase*> inputs;
32    for (size_t i = 0; i < ctx->num_inputs(); ++i) {
33      DatasetBase* input;
34      OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
35      inputs.push_back(input);
36    }
37    *output = new Dataset(ctx, inputs);
38  }
39
40 private:
41  class Dataset : public GraphDatasetBase {
42   public:
43    explicit Dataset(OpKernelContext* ctx,
44                     const std::vector<DatasetBase*>& inputs)
45        : GraphDatasetBase(ctx), inputs_(inputs) {
46      for (const auto& input : inputs_) {
47        input->Ref();
48        for (DataType dt : input->output_dtypes()) {
49          output_dtypes_.push_back(dt);
50        }
51        output_shapes_.insert(output_shapes_.end(),
52                              input->output_shapes().begin(),
53                              input->output_shapes().end());
54      }
55    }
56
57    ~Dataset() override {
58      for (const auto& input : inputs_) {
59        input->Unref();
60      }
61    }
62
63    std::unique_ptr<IteratorBase> MakeIterator(
64        const string& prefix) const override {
65      return std::unique_ptr<IteratorBase>(
66          new Iterator({this, strings::StrCat(prefix, "::Zip")}));
67    }
68
69    const DataTypeVector& output_dtypes() const override {
70      return output_dtypes_;
71    }
72
73    const std::vector<PartialTensorShape>& output_shapes() const override {
74      return output_shapes_;
75    }
76
77    string DebugString() override { return "ZipDatasetOp::Dataset"; }
78
79   protected:
80    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
81                              Node** output) const override {
82      std::vector<Node*> input_graph_nodes;
83      input_graph_nodes.reserve(inputs_.size());
84      for (const auto& input : inputs_) {
85        Node* input_node;
86        TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input, &input_node));
87        input_graph_nodes.emplace_back(input_node);
88      }
89      TF_RETURN_IF_ERROR(b->AddDataset(
90          this, {}, {std::make_pair(0, input_graph_nodes)}, {}, output));
91      return Status::OK();
92    }
93
94   private:
95    class Iterator : public DatasetIterator<Dataset> {
96     public:
97      explicit Iterator(const Params& params)
98          : DatasetIterator<Dataset>(params) {
99        input_impls_.reserve(params.dataset->inputs_.size());
100        size_t idx = 0;
101        for (const auto& input : params.dataset->inputs_) {
102          input_impls_.emplace_back(input->MakeIterator(
103              strings::StrCat(params.prefix, "[", idx++, "]")));
104        }
105      }
106
107      Status GetNextInternal(IteratorContext* ctx,
108                             std::vector<Tensor>* out_tensors,
109                             bool* end_of_sequence) override {
110        mutex_lock l(mu_);
111        if (input_impls_.empty()) {
112          *end_of_sequence = true;
113          return Status::OK();
114        }
115        out_tensors->clear();
116        out_tensors->reserve(dataset()->output_dtypes().size());
117        for (const auto& input_impl : input_impls_) {
118          std::vector<Tensor> input_tensors;
119          TF_RETURN_IF_ERROR(
120              input_impl->GetNext(ctx, &input_tensors, end_of_sequence));
121          if (*end_of_sequence) {
122            break;
123          }
124          out_tensors->insert(out_tensors->end(), input_tensors.begin(),
125                              input_tensors.end());
126        }
127        if (*end_of_sequence) {
128          out_tensors->clear();
129          input_impls_.clear();
130        }
131        return Status::OK();
132      }
133
134     protected:
135      Status SaveInternal(IteratorStateWriter* writer) override {
136        mutex_lock l(mu_);
137        if (input_impls_.empty()) {
138          TF_RETURN_IF_ERROR(
139              writer->WriteScalar(full_name("input_impls_empty"), ""));
140        } else {
141          for (auto& input_impl : input_impls_)
142            TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
143        }
144        return Status::OK();
145      }
146
147      Status RestoreInternal(IteratorContext* ctx,
148                             IteratorStateReader* reader) override {
149        mutex_lock l(mu_);
150        if (reader->Contains(full_name("input_impls_empty"))) {
151          input_impls_.clear();
152        } else {
153          DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size());
154          for (auto& input_impl : input_impls_)
155            TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
156        }
157        return Status::OK();
158      }
159
160     private:
161      mutex mu_;
162      std::vector<std::unique_ptr<IteratorBase>> input_impls_ GUARDED_BY(mu_);
163    };
164
165    const std::vector<DatasetBase*> inputs_;
166    DataTypeVector output_dtypes_;
167    std::vector<PartialTensorShape> output_shapes_;
168  };
169};
170
171REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp);
172
173}  // namespace
174
175}  // namespace tensorflow
176