/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
158dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include <utility>
168dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
178dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/framework/partial_tensor_shape.h"
188dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h"
19a5b2a0c9a3335d10c4dd3dfdff96149f74a4d120Jiri Simsa#include "tensorflow/core/kernels/data/dataset.h"
20a5b2a0c9a3335d10c4dd3dfdff96149f74a4d120Jiri Simsa#include "tensorflow/core/kernels/data/sql/driver_manager.h"
21a5b2a0c9a3335d10c4dd3dfdff96149f74a4d120Jiri Simsa#include "tensorflow/core/kernels/data/sql/query_connection.h"
228dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/lib/io/inputbuffer.h"
238dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/lib/io/record_reader.h"
248dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower#include "tensorflow/core/lib/strings/stringprintf.h"
258dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
268dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowernamespace tensorflow {
278dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
288dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowernamespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following ops.
318dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
328dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerclass SqlDatasetOp : public DatasetOpKernel {
338dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower public:
348dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower  explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
358dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
368dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
37a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower    for (const DataType& dt : output_types_) {
38be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian      OP_REQUIRES(ctx,
39be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian                  dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 ||
40be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian                      dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 ||
41be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian                      dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE,
42be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian                  errors::InvalidArgument(
43be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian                      "Each element of `output_types_` must be one of: "
44be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian                      "DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, "
45be1916ce7e2ec5bb2a72b392aedb8d16accf1983Daniel Grazian                      "DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE "));
46a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower    }
47a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower    for (const PartialTensorShape& pts : output_shapes_) {
48a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower      OP_REQUIRES(ctx, pts.dims() == 0,
49a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower                  errors::InvalidArgument(
50a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower                      "Each element of `output_shapes_` must be a scalar."));
51a38d387323befd74671764f3187711332afdba5cA. Unique TensorFlower    }
528dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower  }
538dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
548dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    string driver_name;
558dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    OP_REQUIRES_OK(
568dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower        ctx, ParseScalarArgument<string>(ctx, "driver_name", &driver_name));
578dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
588dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    string data_source_name;
598dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "data_source_name",
608dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                                                    &data_source_name));
618dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
628dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    string query;
638dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "query", &query));
648dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
658dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    // TODO(b/64276826) Change this check when we add support for other
668dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    // databases.
678dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    OP_REQUIRES(ctx, driver_name == "sqlite",
688dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                errors::InvalidArgument(tensorflow::strings::Printf(
698dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                    "The database type, %s, is not supported by SqlDataset. "
708dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                    "The set of supported databases is: {'sqlite'}.",
718dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                    driver_name.c_str())));
728dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
738dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    *output = new Dataset(driver_name, data_source_name, query, output_types_,
748dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                          output_shapes_);
758dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower  }
768dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
778dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower private:
788dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower  class Dataset : public DatasetBase {
798dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower   public:
808dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    Dataset(const string& driver_name, const string& data_source_name,
818dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower            const string& query, const DataTypeVector& output_types,
828dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower            const std::vector<PartialTensorShape>& output_shapes)
838dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower        : driver_name_(driver_name),
848dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          data_source_name_(data_source_name),
858dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          query_(query),
868dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          output_types_(output_types),
878dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          output_shapes_(output_shapes) {}
888dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
898dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    std::unique_ptr<IteratorBase> MakeIterator(
908dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower        const string& prefix) const override {
918dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      return std::unique_ptr<IteratorBase>(
928dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          new Iterator({this, strings::StrCat(prefix, "::Sql")}));
938dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    }
948dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
958dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    const DataTypeVector& output_dtypes() const override {
968dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      return output_types_;
978dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    }
988dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
998dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    const std::vector<PartialTensorShape>& output_shapes() const override {
1008dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      return output_shapes_;
1018dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    }
1028dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
1038dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    string DebugString() override { return "SqlDatasetOp::Dataset"; }
1048dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
1058dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower   private:
1068dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    class Iterator : public DatasetIterator<Dataset> {
1078dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower     public:
1088dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      explicit Iterator(const Params& params)
1098dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          : DatasetIterator<Dataset>(params) {}
1108dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      ~Iterator() override {
1118dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower        if (query_connection_initialized_) {
1128dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          Status s = query_connection_->Close();
1138dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          if (!s.ok()) {
1148dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower            LOG(WARNING) << "Failed to close query connection: " << s;
1158dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          }
1168dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower        }
1178dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      }
1188dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
119548df15375488fc06ff663670f88734f3ece4814Derek Murray      Status GetNextInternal(IteratorContext* ctx,
1208dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                             std::vector<Tensor>* out_tensors,
1218dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                             bool* end_of_sequence) override {
1228dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower        mutex_lock l(mu_);
1238dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower        if (!query_connection_initialized_) {
1248dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          query_connection_initialized_ = true;
1258dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          query_connection_ = sql::DriverManager::CreateQueryConnection(
1268dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower              dataset()->driver_name_);
1278dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          Status s = query_connection_->Open(dataset()->data_source_name_,
1288dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                                             dataset()->query_,
1298dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower                                             dataset()->output_types_);
1308dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          if (!s.ok()) {
1318dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower            LOG(WARNING) << "Failed to connect to database: " << s;
1328dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower            return s;
1338dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower          }
1348dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower        }
135548df15375488fc06ff663670f88734f3ece4814Derek Murray        return query_connection_->GetNext(ctx, out_tensors, end_of_sequence);
1368dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      }
1378dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
1388dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower     private:
1398dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      mutex mu_;
1408dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      std::unique_ptr<sql::QueryConnection> query_connection_ GUARDED_BY(mu_);
1418dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower      bool query_connection_initialized_ GUARDED_BY(mu_) = false;
1428dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    };
1438dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    const string driver_name_;
1448dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    const string data_source_name_;
1458dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    const string query_;
1468dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    const DataTypeVector output_types_;
1478dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower    const std::vector<PartialTensorShape> output_shapes_;
1488dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower  };
1498dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower  DataTypeVector output_types_;
1508dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower  std::vector<PartialTensorShape> output_shapes_;
1518dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower};
1528dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
1538dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp);
1548dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
1558dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower}  // namespace
1568dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower
1578dac0eb067662ba69a7eb72e640d8ff8d146e5c9A. Unique TensorFlower}  // namespace tensorflow
158