1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/python/eager/python_eager_op_gen.h"
17
18#include <memory>
19#include <string>
20#include <unordered_set>
21#include <vector>
22
23#include "tensorflow/core/framework/op.h"
24#include "tensorflow/core/framework/op_def.pb.h"
25#include "tensorflow/core/framework/op_gen_lib.h"
26#include "tensorflow/core/lib/core/errors.h"
27#include "tensorflow/core/lib/io/inputbuffer.h"
28#include "tensorflow/core/lib/io/path.h"
29#include "tensorflow/core/lib/strings/scanner.h"
30#include "tensorflow/core/platform/env.h"
31#include "tensorflow/core/platform/init_main.h"
32#include "tensorflow/core/platform/logging.h"
33
34namespace tensorflow {
35namespace {
36
37Status ReadOpListFromFile(const string& filename,
38                          std::vector<string>* op_list) {
39  std::unique_ptr<RandomAccessFile> file;
40  TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename, &file));
41  std::unique_ptr<io::InputBuffer> input_buffer(
42      new io::InputBuffer(file.get(), 256 << 10));
43  string line_contents;
44  Status s = input_buffer->ReadLine(&line_contents);
45  while (s.ok()) {
46    // The parser assumes that the op name is the first string on each
47    // line with no preceding whitespace, and ignores lines that do
48    // not start with an op name as a comment.
49    strings::Scanner scanner{StringPiece(line_contents)};
50    StringPiece op_name;
51    if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT)
52            .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
53            .GetResult(nullptr, &op_name)) {
54      op_list->emplace_back(op_name.ToString());
55    }
56    s = input_buffer->ReadLine(&line_contents);
57  }
58  if (!errors::IsOutOfRange(s)) return s;
59  return Status::OK();
60}
61
62// The argument parsing is deliberately simplistic to support our only
63// known use cases:
64//
65// 1. Read all op names from a file.
66// 2. Read all op names from the arg as a comma-delimited list.
67//
68// Expected command-line argument syntax:
69// ARG ::= '@' FILENAME
70//       |  OP_NAME [',' OP_NAME]*
71//       |  ''
72Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) {
73  std::vector<string> op_names = str_util::Split(arg, ',');
74  if (op_names.size() == 1 && op_names[0].empty()) {
75    return Status::OK();
76  } else if (op_names.size() == 1 && op_names[0].substr(0, 1) == "@") {
77    const string filename = op_names[0].substr(1);
78    return tensorflow::ReadOpListFromFile(filename, op_list);
79  } else {
80    *op_list = std::move(op_names);
81  }
82  return Status::OK();
83}
84
85// Use the name of the current executable to infer the C++ source file
86// where the REGISTER_OP() call for the operator can be found.
87// Returns the name of the file.
88// Returns an empty string if the current executable's name does not
89// follow a known pattern.
90string InferSourceFileName(const char* argv_zero) {
91  StringPiece command_str = io::Basename(argv_zero);
92
93  // For built-in ops, the Bazel build creates a separate executable
94  // with the name gen_<op type>_ops_py_wrappers_cc containing the
95  // operators defined in <op type>_ops.cc
96  const char* kExecPrefix = "gen_";
97  const char* kExecSuffix = "_py_wrappers_cc";
98  if (command_str.Consume(kExecPrefix) && command_str.ends_with(kExecSuffix)) {
99    command_str.remove_suffix(strlen(kExecSuffix));
100    return strings::StrCat(command_str, ".cc");
101  } else {
102    return string("");
103  }
104}
105
106void PrintAllPythonOps(const std::vector<string>& op_list,
107                       const std::vector<string>& api_def_dirs,
108                       const string& source_file_name, bool require_shapes,
109                       bool op_list_is_whitelist) {
110  OpList ops;
111  OpRegistry::Global()->Export(false, &ops);
112
113  ApiDefMap api_def_map(ops);
114  if (!api_def_dirs.empty()) {
115    Env* env = Env::Default();
116
117    for (const auto& api_def_dir : api_def_dirs) {
118      std::vector<string> api_files;
119      TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
120                                        &api_files));
121      TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
122    }
123    api_def_map.UpdateDocs();
124  }
125
126  if (op_list_is_whitelist) {
127    std::unordered_set<string> whitelist(op_list.begin(), op_list.end());
128    OpList pruned_ops;
129    for (const auto& op_def : ops.op()) {
130      if (whitelist.find(op_def.name()) != whitelist.end()) {
131        *pruned_ops.mutable_op()->Add() = op_def;
132      }
133    }
134    PrintEagerPythonOps(pruned_ops, api_def_map, {}, require_shapes,
135                        source_file_name);
136  } else {
137    PrintEagerPythonOps(ops, api_def_map, op_list, require_shapes,
138                        source_file_name);
139  }
140}
141
142}  // namespace
143}  // namespace tensorflow
144
145int main(int argc, char* argv[]) {
146  tensorflow::port::InitMain(argv[0], &argc, &argv);
147
148  tensorflow::string source_file_name =
149      tensorflow::InferSourceFileName(argv[0]);
150
151  // Usage:
152  //   gen_main api_def_dir1,api_def_dir2,...
153  //       [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
154  if (argc < 3) {
155    return -1;
156  }
157  std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
158      argv[1], ",", tensorflow::str_util::SkipEmpty());
159
160  if (argc == 3) {
161    tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
162                                  tensorflow::string(argv[2]) == "1",
163                                  false /* op_list_is_whitelist */);
164  } else if (argc == 4) {
165    std::vector<tensorflow::string> hidden_ops;
166    TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
167    tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
168                                  tensorflow::string(argv[3]) == "1",
169                                  false /* op_list_is_whitelist */);
170  } else if (argc == 5) {
171    std::vector<tensorflow::string> op_list;
172    TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
173    tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
174                                  tensorflow::string(argv[3]) == "1",
175                                  tensorflow::string(argv[4]) == "1");
176  } else {
177    return -1;
178  }
179  return 0;
180}
181