1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/core/framework/partial_tensor_shape.h" 16#include "tensorflow/core/framework/tensor.h" 17#include "tensorflow/core/kernels/data/dataset.h" 18 19namespace tensorflow { 20 21namespace { 22 23// See documentation in ../ops/dataset_ops.cc for a high-level 24// description of the following op. 25 26class TensorDatasetOp : public DatasetOpKernel { 27 public: 28 explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} 29 30 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 31 // Create a new TensorDatasetOp::Dataset, insert it in the step 32 // container, and return it as the output. 33 OpInputList inputs; 34 OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs)); 35 // TODO(mrry): Validate that the shapes of the "components" tensors match 36 // the "shapes" attr.; 37 std::vector<Tensor> components; 38 components.reserve(inputs.size()); 39 for (const Tensor& t : inputs) { 40 components.push_back(t); 41 } 42 *output = new Dataset(ctx, std::move(components)); 43 } 44 45 private: 46 class Dataset : public GraphDatasetBase { 47 public: 48 Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors) 49 : GraphDatasetBase(ctx), tensors_(std::move(tensors)) { 50 for (const Tensor& t : tensors_) { 51 dtypes_.push_back(t.dtype()); 52 shapes_.emplace_back(t.shape().dim_sizes()); 53 } 54 } 55 56 std::unique_ptr<IteratorBase> MakeIterator( 57 const string& prefix) const override { 58 return std::unique_ptr<IteratorBase>( 59 new Iterator({this, strings::StrCat(prefix, "::FromTensor")})); 60 } 61 62 const DataTypeVector& output_dtypes() const override { return dtypes_; } 63 const std::vector<PartialTensorShape>& output_shapes() const override { 64 return shapes_; 65 } 66 67 string DebugString() override { return "TensorDatasetOp::Dataset"; } 68 69 protected: 70 Status AsGraphDefInternal(DatasetGraphDefBuilder* b, 71 Node** output) const override { 72 std::vector<Node*> components; 73 components.reserve(tensors_.size()); 74 for (const Tensor& t : tensors_) { 75 Node* node; 76 TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); 77 components.emplace_back(node); 78 } 79 AttrValue dtypes; 80 b->BuildAttrValue(dtypes_, &dtypes); 81 TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}}, 82 {{"Toutput_types", dtypes}}, output)); 83 return Status::OK(); 84 } 85 86 private: 87 class Iterator : public DatasetIterator<Dataset> { 88 public: 89 explicit Iterator(const Params& params) 90 : DatasetIterator<Dataset>(params), produced_(false) {} 91 92 Status GetNextInternal(IteratorContext* ctx, 93 std::vector<Tensor>* out_tensors, 94 bool* end_of_sequence) override { 95 mutex_lock l(mu_); 96 if (!produced_) { 97 *out_tensors = dataset()->tensors_; 98 produced_ = true; 99 *end_of_sequence = false; 100 return Status::OK(); 101 } else { 102 *end_of_sequence = true; 103 return Status::OK(); 104 } 105 } 106 107 protected: 108 Status SaveInternal(IteratorStateWriter* writer) override { 109 mutex_lock l(mu_); 110 if (produced_) 111 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("produced"), "")); 112 return Status::OK(); 113 } 114 115 Status RestoreInternal(IteratorContext* ctx, 116 IteratorStateReader* reader) override { 117 mutex_lock l(mu_); 118 produced_ = reader->Contains(full_name("produced")); 119 return Status::OK(); 120 } 121 122 private: 123 mutex mu_; 124 bool produced_ GUARDED_BY(mu_); 125 }; 126 127 const std::vector<Tensor> tensors_; 128 DataTypeVector dtypes_; 129 std::vector<PartialTensorShape> shapes_; 130 }; 131}; 132 133REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU), 134 TensorDatasetOp); 135 136} // namespace 137 138} // namespace tensorflow 139