18dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 28dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 38dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 48dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFloweryou may not use this file except in compliance with the License. 58dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerYou may obtain a copy of the License at 68dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 78dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 88dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 98dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 108dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 118dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 128dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerSee the License for the specific language governing permissions and 138dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerlimitations under the License. 148dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower==============================================================================*/ 158dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include <utility> 168dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 178dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/framework/partial_tensor_shape.h" 188dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h" 19a5b2a0c9a3335d10c4dd3dfdff96149f74a4d120Jiri Simsa#include "tensorflow/core/kernels/data/dataset.h" 20a5b2a0c9a3335d10c4dd3dfdff96149f74a4d120Jiri Simsa#include "tensorflow/core/kernels/data/sql/driver_manager.h" 21a5b2a0c9a3335d10c4dd3dfdff96149f74a4d120Jiri Simsa#include "tensorflow/core/kernels/data/sql/query_connection.h" 228dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/lib/io/inputbuffer.h" 238dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/lib/io/record_reader.h" 248dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/lib/strings/stringprintf.h" 258dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 268dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowernamespace tensorflow { 278dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 288dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowernamespace { 298dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower// See documentation in ../ops/dataset_ops.cc for a high-level 308dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower// description of the following ops. 318dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 328dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerclass SqlDatasetOp : public DatasetOpKernel { 338dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower public: 348dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { 358dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 368dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 37a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower for (const DataType& dt : output_types_) { 38be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian OP_REQUIRES(ctx, 39be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 || 40be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 || 41be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE, 42be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian errors::InvalidArgument( 43be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian "Each element of `output_types_` must be one of: " 44be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian "DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, " 45be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian "DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE ")); 46a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower } 47a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower for (const PartialTensorShape& pts : output_shapes_) { 48a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower OP_REQUIRES(ctx, pts.dims() == 0, 49a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower errors::InvalidArgument( 50a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower "Each element of `output_shapes_` must be a scalar.")); 51a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower } 528dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 538dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 548dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower string driver_name; 558dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower OP_REQUIRES_OK( 568dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower ctx, ParseScalarArgument<string>(ctx, "driver_name", &driver_name)); 578dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 588dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower string data_source_name; 598dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "data_source_name", 608dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower &data_source_name)); 618dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 628dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower string query; 638dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "query", &query)); 648dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 658dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower // TODO(b/64276826) Change this check when we add support for other 668dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower // databases. 678dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower OP_REQUIRES(ctx, driver_name == "sqlite", 688dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower errors::InvalidArgument(tensorflow::strings::Printf( 698dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower "The database type, %s, is not supported by SqlDataset. " 708dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower "The set of supported databases is: {'sqlite'}.", 718dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower driver_name.c_str()))); 728dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 738dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower *output = new Dataset(driver_name, data_source_name, query, output_types_, 748dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower output_shapes_); 758dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 768dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 778dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower private: 788dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower class Dataset : public DatasetBase { 798dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower public: 808dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower Dataset(const string& driver_name, const string& data_source_name, 818dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const string& query, const DataTypeVector& output_types, 828dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const std::vector<PartialTensorShape>& output_shapes) 838dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower : driver_name_(driver_name), 848dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower data_source_name_(data_source_name), 858dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower query_(query), 868dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower output_types_(output_types), 878dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower output_shapes_(output_shapes) {} 888dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 898dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower std::unique_ptr<IteratorBase> MakeIterator( 908dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const string& prefix) const override { 918dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower return std::unique_ptr<IteratorBase>( 928dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower new Iterator({this, strings::StrCat(prefix, "::Sql")})); 938dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 948dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 958dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const DataTypeVector& output_dtypes() const override { 968dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower return output_types_; 978dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 988dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 998dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const std::vector<PartialTensorShape>& output_shapes() const override { 1008dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower return output_shapes_; 1018dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 1028dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 1038dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower string DebugString() override { return "SqlDatasetOp::Dataset"; } 1048dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 1058dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower private: 1068dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower class Iterator : public DatasetIterator<Dataset> { 1078dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower public: 1088dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower explicit Iterator(const Params& params) 1098dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower : DatasetIterator<Dataset>(params) {} 1108dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower ~Iterator() override { 1118dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower if (query_connection_initialized_) { 1128dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower Status s = query_connection_->Close(); 1138dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower if (!s.ok()) { 1148dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower LOG(WARNING) << "Failed to close query connection: " << s; 1158dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 1168dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 1178dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 1188dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 119548df15375488fc06ff663670f88734f3ece4814Derek Murray Status GetNextInternal(IteratorContext* ctx, 1208dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower std::vector<Tensor>* out_tensors, 1218dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower bool* end_of_sequence) override { 1228dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower mutex_lock l(mu_); 1238dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower if (!query_connection_initialized_) { 1248dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower query_connection_initialized_ = true; 1258dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower query_connection_ = sql::DriverManager::CreateQueryConnection( 1268dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower dataset()->driver_name_); 1278dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower Status s = query_connection_->Open(dataset()->data_source_name_, 1288dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower dataset()->query_, 1298dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower dataset()->output_types_); 1308dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower if (!s.ok()) { 1318dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower LOG(WARNING) << "Failed to connect to database: " << s; 1328dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower return s; 1338dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 1348dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 135548df15375488fc06ff663670f88734f3ece4814Derek Murray return query_connection_->GetNext(ctx, out_tensors, end_of_sequence); 1368dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower } 1378dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 1388dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower private: 1398dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower mutex mu_; 1408dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower std::unique_ptr<sql::QueryConnection> query_connection_ GUARDED_BY(mu_); 1418dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower bool query_connection_initialized_ GUARDED_BY(mu_) = false; 1428dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower }; 1438dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const string driver_name_; 1448dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const string data_source_name_; 1458dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const string query_; 1468dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const DataTypeVector output_types_; 1478dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower const std::vector<PartialTensorShape> output_shapes_; 1488dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower }; 1498dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower DataTypeVector output_types_; 1508dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower std::vector<PartialTensorShape> output_shapes_; 1518dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower}; 1528dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 1538dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp); 1548dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 1558dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower} // namespace 1568dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower 1578dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower} // namespace tensorflow 158