1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/partial_tensor_shape.h"
16#include "tensorflow/core/framework/tensor.h"
17#include "tensorflow/core/kernels/data/dataset.h"
18
19namespace tensorflow {
20
21namespace {
22
23// See documentation in ../ops/dataset_ops.cc for a high-level
24// description of the following op.
25
26class TensorDatasetOp : public DatasetOpKernel {
27 public:
28  explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
29
30  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
31    // Create a new TensorDatasetOp::Dataset, insert it in the step
32    // container, and return it as the output.
33    OpInputList inputs;
34    OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
35    // TODO(mrry): Validate that the shapes of the "components" tensors match
36    // the "shapes" attr.;
37    std::vector<Tensor> components;
38    components.reserve(inputs.size());
39    for (const Tensor& t : inputs) {
40      components.push_back(t);
41    }
42    *output = new Dataset(ctx, std::move(components));
43  }
44
45 private:
46  class Dataset : public GraphDatasetBase {
47   public:
48    Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
49        : GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
50      for (const Tensor& t : tensors_) {
51        dtypes_.push_back(t.dtype());
52        shapes_.emplace_back(t.shape().dim_sizes());
53      }
54    }
55
56    std::unique_ptr<IteratorBase> MakeIterator(
57        const string& prefix) const override {
58      return std::unique_ptr<IteratorBase>(
59          new Iterator({this, strings::StrCat(prefix, "::FromTensor")}));
60    }
61
62    const DataTypeVector& output_dtypes() const override { return dtypes_; }
63    const std::vector<PartialTensorShape>& output_shapes() const override {
64      return shapes_;
65    }
66
67    string DebugString() override { return "TensorDatasetOp::Dataset"; }
68
69   protected:
70    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
71                              Node** output) const override {
72      std::vector<Node*> components;
73      components.reserve(tensors_.size());
74      for (const Tensor& t : tensors_) {
75        Node* node;
76        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
77        components.emplace_back(node);
78      }
79      AttrValue dtypes;
80      b->BuildAttrValue(dtypes_, &dtypes);
81      TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}},
82                                       {{"Toutput_types", dtypes}}, output));
83      return Status::OK();
84    }
85
86   private:
87    class Iterator : public DatasetIterator<Dataset> {
88     public:
89      explicit Iterator(const Params& params)
90          : DatasetIterator<Dataset>(params), produced_(false) {}
91
92      Status GetNextInternal(IteratorContext* ctx,
93                             std::vector<Tensor>* out_tensors,
94                             bool* end_of_sequence) override {
95        mutex_lock l(mu_);
96        if (!produced_) {
97          *out_tensors = dataset()->tensors_;
98          produced_ = true;
99          *end_of_sequence = false;
100          return Status::OK();
101        } else {
102          *end_of_sequence = true;
103          return Status::OK();
104        }
105      }
106
107     protected:
108      Status SaveInternal(IteratorStateWriter* writer) override {
109        mutex_lock l(mu_);
110        if (produced_)
111          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("produced"), ""));
112        return Status::OK();
113      }
114
115      Status RestoreInternal(IteratorContext* ctx,
116                             IteratorStateReader* reader) override {
117        mutex_lock l(mu_);
118        produced_ = reader->Contains(full_name("produced"));
119        return Status::OK();
120      }
121
122     private:
123      mutex mu_;
124      bool produced_ GUARDED_BY(mu_);
125    };
126
127    const std::vector<Tensor> tensors_;
128    DataTypeVector dtypes_;
129    std::vector<PartialTensorShape> shapes_;
130  };
131};
132
// Registers TensorDatasetOp as the CPU kernel for the "TensorDataset" op.
REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU),
                        TensorDatasetOp);
135
136}  // namespace
137
138}  // namespace tensorflow
139