1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/api_def/update_api_def.h"
16
#include <ctype.h>

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>
21
22#include "tensorflow/core/api_def/excluded_ops.h"
23#include "tensorflow/core/framework/api_def.pb.h"
24#include "tensorflow/core/framework/op.h"
25#include "tensorflow/core/framework/op_def_builder.h"
26#include "tensorflow/core/framework/op_gen_lib.h"
27#include "tensorflow/core/lib/core/status.h"
28#include "tensorflow/core/lib/io/path.h"
29#include "tensorflow/core/lib/strings/stringprintf.h"
30#include "tensorflow/core/platform/env.h"
31
32namespace tensorflow {
33
34namespace {
// File name of a per-op ApiDef file; "%s" is filled in with the op name
// (see CreateApiDefs).
constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt";
// Opening and closing delimiters of a .Doc() call whose argument is a raw
// string literal using the "doc" delimiter. RemoveDoc() searches source files
// for these exact strings, so docs written any other way are not found.
// TODO(annarev): look into supporting other prefixes, not just 'doc'.
constexpr char kDocStart[] = ".Doc(R\"doc(";
constexpr char kDocEnd[] = ")doc\")";
39
40// Updates api_def based on the given op.
41void FillBaseApiDef(ApiDef* api_def, const OpDef& op) {
42  api_def->set_graph_op_name(op.name());
43  // Add arg docs
44  for (auto& input_arg : op.input_arg()) {
45    if (!input_arg.description().empty()) {
46      auto* api_def_in_arg = api_def->add_in_arg();
47      api_def_in_arg->set_name(input_arg.name());
48      api_def_in_arg->set_description(input_arg.description());
49    }
50  }
51  for (auto& output_arg : op.output_arg()) {
52    if (!output_arg.description().empty()) {
53      auto* api_def_out_arg = api_def->add_out_arg();
54      api_def_out_arg->set_name(output_arg.name());
55      api_def_out_arg->set_description(output_arg.description());
56    }
57  }
58  // Add attr docs
59  for (auto& attr : op.attr()) {
60    if (!attr.description().empty()) {
61      auto* api_def_attr = api_def->add_attr();
62      api_def_attr->set_name(attr.name());
63      api_def_attr->set_description(attr.description());
64    }
65  }
66  // Add docs
67  api_def->set_summary(op.summary());
68  api_def->set_description(op.description());
69}
70
71// Returns true if op has any description or summary.
72bool OpHasDocs(const OpDef& op) {
73  if (!op.summary().empty() || !op.description().empty()) {
74    return true;
75  }
76  for (const auto& arg : op.input_arg()) {
77    if (!arg.description().empty()) {
78      return true;
79    }
80  }
81  for (const auto& arg : op.output_arg()) {
82    if (!arg.description().empty()) {
83      return true;
84    }
85  }
86  for (const auto& attr : op.attr()) {
87    if (!attr.description().empty()) {
88      return true;
89    }
90  }
91  return false;
92}
93
94// Returns true if summary and all descriptions are the same in op1
95// and op2.
96bool CheckDocsMatch(const OpDef& op1, const OpDef& op2) {
97  if (op1.summary() != op2.summary() ||
98      op1.description() != op2.description() ||
99      op1.input_arg_size() != op2.input_arg_size() ||
100      op1.output_arg_size() != op2.output_arg_size() ||
101      op1.attr_size() != op2.attr_size()) {
102    return false;
103  }
104  // Iterate over args and attrs to compare their docs.
105  for (int i = 0; i < op1.input_arg_size(); ++i) {
106    if (op1.input_arg(i).description() != op2.input_arg(i).description()) {
107      return false;
108    }
109  }
110  for (int i = 0; i < op1.output_arg_size(); ++i) {
111    if (op1.output_arg(i).description() != op2.output_arg(i).description()) {
112      return false;
113    }
114  }
115  for (int i = 0; i < op1.attr_size(); ++i) {
116    if (op1.attr(i).description() != op2.attr(i).description()) {
117      return false;
118    }
119  }
120  return true;
121}
122
123// Returns true if descriptions and summaries in op match a
124// given single doc-string.
125bool ValidateOpDocs(const OpDef& op, const string& doc) {
126  OpDefBuilder b(op.name());
127  // We don't really care about type we use for arguments and
128  // attributes. We just want to make sure attribute and argument names
129  // are added so that descriptions can be assigned to them when parsing
130  // documentation.
131  for (const auto& arg : op.input_arg()) {
132    b.Input(arg.name() + ":string");
133  }
134  for (const auto& arg : op.output_arg()) {
135    b.Output(arg.name() + ":string");
136  }
137  for (const auto& attr : op.attr()) {
138    b.Attr(attr.name() + ":string");
139  }
140  b.Doc(doc);
141  OpRegistrationData op_reg_data;
142  TF_CHECK_OK(b.Finalize(&op_reg_data));
143  return CheckDocsMatch(op, op_reg_data.op_def);
144}
145}  // namespace
146
147string RemoveDoc(const OpDef& op, const string& file_contents,
148                 size_t start_location) {
149  // Look for a line starting with .Doc( after the REGISTER_OP.
150  const auto doc_start_location = file_contents.find(kDocStart, start_location);
151  const string format_error = strings::Printf(
152      "Could not find %s doc for removal. Make sure the doc is defined with "
153      "'%s' prefix and '%s' suffix or remove the doc manually.",
154      op.name().c_str(), kDocStart, kDocEnd);
155  if (doc_start_location == string::npos) {
156    std::cerr << format_error << std::endl;
157    LOG(ERROR) << "Didn't find doc start";
158    return file_contents;
159  }
160  const auto doc_end_location = file_contents.find(kDocEnd, doc_start_location);
161  if (doc_end_location == string::npos) {
162    LOG(ERROR) << "Didn't find doc start";
163    std::cerr << format_error << std::endl;
164    return file_contents;
165  }
166
167  const auto doc_start_size = sizeof(kDocStart) - 1;
168  string doc_text = file_contents.substr(
169      doc_start_location + doc_start_size,
170      doc_end_location - doc_start_location - doc_start_size);
171
172  // Make sure the doc text we found actually matches OpDef docs to
173  // avoid removing incorrect text.
174  if (!ValidateOpDocs(op, doc_text)) {
175    LOG(ERROR) << "Invalid doc: " << doc_text;
176    std::cerr << format_error << std::endl;
177    return file_contents;
178  }
179  // Remove .Doc call.
180  auto before_doc = file_contents.substr(0, doc_start_location);
181  str_util::StripTrailingWhitespace(&before_doc);
182  return before_doc +
183         file_contents.substr(doc_end_location + sizeof(kDocEnd) - 1);
184}
185
186namespace {
187// Remove .Doc calls that follow REGISTER_OP calls for the given ops.
188// We search for REGISTER_OP calls in the given op_files list.
189void RemoveDocs(const std::vector<const OpDef*>& ops,
190                const std::vector<string>& op_files) {
191  // Set of ops that we already found REGISTER_OP calls for.
192  std::set<string> processed_ops;
193
194  for (const auto& file : op_files) {
195    string file_contents;
196    bool file_contents_updated = false;
197    TF_CHECK_OK(ReadFileToString(Env::Default(), file, &file_contents));
198
199    for (auto op : ops) {
200      if (processed_ops.find(op->name()) != processed_ops.end()) {
201        // We already found REGISTER_OP call for this op in another file.
202        continue;
203      }
204      string register_call =
205          strings::Printf("REGISTER_OP(\"%s\")", op->name().c_str());
206      const auto register_call_location = file_contents.find(register_call);
207      // Find REGISTER_OP(OpName) call.
208      if (register_call_location == string::npos) {
209        continue;
210      }
211      std::cout << "Removing .Doc call for " << op->name() << " from " << file
212                << "." << std::endl;
213      file_contents = RemoveDoc(*op, file_contents, register_call_location);
214      file_contents_updated = true;
215
216      processed_ops.insert(op->name());
217    }
218    if (file_contents_updated) {
219      TF_CHECK_OK(WriteStringToFile(Env::Default(), file, file_contents))
220          << "Could not remove .Doc calls in " << file
221          << ". Make sure the file is writable.";
222    }
223  }
224}
225}  // namespace
226
227// Returns ApiDefs text representation in multi-line format
228// constructed based on the given op.
229string CreateApiDef(const OpDef& op) {
230  ApiDefs api_defs;
231  FillBaseApiDef(api_defs.add_op(), op);
232
233  const std::vector<string> multi_line_fields = {"description"};
234  string new_api_defs_str = api_defs.DebugString();
235  return PBTxtToMultiline(new_api_defs_str, multi_line_fields);
236}
237
238// Creates ApiDef files for any new ops.
239// If op_file_pattern is not empty, then also removes .Doc calls from
240// new op registrations in these files.
241void CreateApiDefs(const OpList& ops, const string& api_def_dir,
242                   const string& op_file_pattern) {
243  auto* excluded_ops = GetExcludedOps();
244  std::vector<const OpDef*> new_ops_with_docs;
245
246  for (const auto& op : ops.op()) {
247    if (excluded_ops->find(op.name()) != excluded_ops->end()) {
248      continue;
249    }
250    // Form the expected ApiDef path.
251    string file_path =
252        io::JoinPath(tensorflow::string(api_def_dir), kApiDefFileFormat);
253    file_path = strings::Printf(file_path.c_str(), op.name().c_str());
254
255    // Create ApiDef if it doesn't exist.
256    if (!Env::Default()->FileExists(file_path).ok()) {
257      std::cout << "Creating ApiDef file " << file_path << std::endl;
258      const auto& api_def_text = CreateApiDef(op);
259      TF_CHECK_OK(WriteStringToFile(Env::Default(), file_path, api_def_text));
260
261      if (OpHasDocs(op)) {
262        new_ops_with_docs.push_back(&op);
263      }
264    }
265  }
266  if (!op_file_pattern.empty()) {
267    std::vector<string> op_files;
268    TF_CHECK_OK(Env::Default()->GetMatchingPaths(op_file_pattern, &op_files));
269    RemoveDocs(new_ops_with_docs, op_files);
270  }
271}
272}  // namespace tensorflow
273