/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetimes of the tensors are bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node, and outer_context is
// the inference context of that function node in the outer graph.
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
                                     InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(refiner->AddNode(node));
  InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    node_context->set_output(0, outer_context->input(index));

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, so we need to create a new
    // shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    auto* resource = node_context->input_handle_shapes_and_types(0);
    if (resource) {
      outer_context->set_output_handle_shapes_and_types(index, *resource);
    }
  }

  return Status::OK();
}

}  // namespace

// TODO(cwhipkey): When an inference context inside function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to
// the outer context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const tensorflow::FunctionDef* function_def, bool keep_nested_shapes,
    ExtendedInferenceContext* outer_context) {
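  // Instantiated function bodies are cached per FunctionDef, so repeated
  // calls through the same function only pay for instantiation and graph
  // construction once.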
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, outer_context->get_context()->attrs(),
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status = InferShapesForFunctionSubNode(
          node, this, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls the inference lambda for each node after all of its predecessors
    // have been visited, which ensures that nodes are added to the
    // ShapeRefiner in topological order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  if (keep_nested_shapes && inference_status.ok()) {
    // Fill the nested inferences map.
    //
    // The materialized function graph has extra nodes for arguments and
    // return values that are not explicitly listed in the FunctionDef.  We
    // filter out these special nodes here so as not to expose implementation
    // details, keeping only inferences for the nodes listed in the
    // FunctionDef.
    std::unordered_map<string, const NodeDef*> user_defined_nodes;
    for (const auto& node_def : function_def->node_def()) {
      user_defined_nodes[node_def.name()] = &node_def;
    }

    std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
        nested_inferences;
    for (const Node* node : function_nodes) {
      const string& node_name = node->name();
      if (user_defined_nodes.find(node_name) != user_defined_nodes.end()) {
        nested_inferences[node_name] = std::move(node_to_context_[node]);
        node_to_context_.erase(node);
        // By default InferenceContext refers to a NodeDef from the Graph.
        // Change it to the publicly accessible NodeDef of the function
        // definition.
        nested_inferences[node_name]->get_context()->node_def_ =
            user_defined_nodes[node_name];
      }
    }
    outer_context->set_nested_inferences(std::move(nested_inferences));
  } else {
    // Delete the contexts created for the function nodes to save memory.
    for (const Node* node : function_nodes) {
      node_to_context_.erase(node);
    }
  }

  return inference_status;
}

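// Creates an InferenceContext for 'node', seeded with the output shapes of
// the node's already-added inputs, runs the op's shape inference function,
// and stores the resulting context in node_to_context_.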
Status ShapeRefiner::AddNode(const Node* node) {
  // For each input of this node, fetch the corresponding shape from the
  // input's InferenceContext and store it in a vector indexed by this
  // node's input index.
  std::vector<Node*> input_nodes(node->num_inputs());
  std::vector<ShapeHandle> input_shapes(node->num_inputs());
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types(node->num_inputs());
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", e->dst_input(), " ('", input->name(), "') for '",
          node->name(), "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = it->second->get_context();
    DCHECK_GE(e->dst_input(), 0);
    input_nodes[e->dst_input()] = input;
    input_shapes[e->dst_input()] = c->output(e->src_output());

    // Only propagate handle data of edges which are carrying resource handles.
    if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
      const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
      if (in_v != nullptr) {
        input_handle_shapes_and_types[e->dst_input()].reset(
            new std::vector<ShapeAndType>(*in_v));
      }
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  // This needs to be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<ShapeHandle> input_tensors_as_shapes;

  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> c(
      new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
                           input_shapes, input_tensors, input_tensors_as_shapes,
                           std::move(input_handle_shapes_and_types)));
  if (!c->construction_status().ok()) {
    return c->construction_status();
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(c), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs?  At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

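// Re-runs shape inference for 'node' when the shapes or handle data of its
// inputs may have changed.  With relax == false incoming shapes are merged
// into the stored input shapes; with relax == true they are relaxed.
// *refined reports whether anything observable changed, and the shape
// function is only re-run in that case.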
Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
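    // Merging only tightens a shape: the result must be compatible with both
    // the stored and the incoming shape (e.g. merging [?, 2] with [1, ?]
    // gives [1, 2]).  Relaxing instead generalizes, turning disagreeing
    // dimensions into unknowns (the same pair relaxes to [?, ?]).  In either
    // case *refined is only set when the stored input changed in a defined
    // way.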
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

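// Attempts to evaluate the tensor flowing into input 'dst_idx' of 'node' as
// a constant.  *evaluated is set only when a value could be produced, either
// directly from a Const source node or by extracting and running a constant
// subgraph with the GraphRunner.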
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   Tensor* result) {
  *evaluated = false;

  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  // Simple case: the source node is a constant
  const Node* src = input_edge->src();
  if (src->IsConstant()) {
    if (result->FromProto(src->def().attr().at("value").tensor())) {
      *evaluated = true;
      return Status::OK();
    }
  }

  if (disable_constant_propagation_) {
    return Status::OK();
  }

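  // Otherwise, try to evaluate the tensor by extracting the largest provably
  // constant subgraph feeding this edge and running it with the GraphRunner.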
  bool is_constant_graph = false;
  Graph subgraph(ops_registry_);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version_);
  subgraph.set_versions(versions);

  // We identify the possibly constant subgraph to evaluate by
  // recursively iterating backwards through the inputs to 'node'
  // until we either 1) find an already existing input to our subgraph
  // (filled in `const_inputs`), 2) discover our graph is not constant,
  // or 3) hit a root node.
  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
      input_edge->src(), &subgraph, &is_constant_graph, &const_inputs));
  if (!is_constant_graph) {
    return Status::OK();
  }
  const string output_tensor_name =
      strings::StrCat(input_edge->src()->name(), ":", input_edge->src_output());
  std::vector<Tensor> outputs;

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner_.Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If any kernel in the constant graph is not registered in the process,
  // GraphRunner::Run may fail, in which case we cannot propagate constants,
  // so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph.  As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (outputs[0].TotalBytes() <= kMaxTensorSize) {
      const_tensor_map_[output_tensor_name] = outputs[0];
    }
  }
  return Status::OK();
}

Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge,
                                                           Tensor* output,
                                                           bool* success) {
  *success = false;
  const Node* node = edge->src();
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    return errors::FailedPrecondition("Node does not have context.");
  }
  InferenceContext* c = it->second->get_context();
  if (node->type_string() == "Shape") {
    // If the input shape to the Shape op is fully defined, we can infer the
    // Shape op's output tensor.
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::FailedPrecondition(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Rank") {
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::FailedPrecondition(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return Status::OK();
}

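// Builds, in 'out_graph', a copy of the portion of the graph feeding
// 'target_node' that is provably constant, walking data edges backwards from
// the target.  *is_constant_graph is left false when a stateful node, Merge,
// Enter/Exit, PlaceholderWithDefault or a non-constant root node is
// encountered.  Tensors that have already been materialized (or can be
// inferred from input shapes) are added to 'const_inputs' as feeds instead
// of being recursed into.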
Status ShapeRefiner::ExtractConstantSubgraph(
    Node* target_node, Graph* out_graph, bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;

  if (target_node->op_def().is_stateful()) {
    return Status::OK();
  }

  if (target_node->type_string() == "PlaceholderWithDefault") {
    return Status::OK();
  }

  // TODO(skyewm): more of the filtering applied to input nodes below should
  // be applied to target_node here.

  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

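  // Maps each node of the original graph to its copy in 'out_graph', along
  // with a flag recording whether that node's inputs have already been
  // queued, so nodes reachable along multiple paths are copied and expanded
  // only once.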
  std::map<Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(target_node);
  old_to_new_and_recursed[target_node].new_node = target_node_copy;
  old_to_new_and_recursed[target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node->in_edges()) {
    // TODO(vrv): What do we do about control edges?  Based on our
    // definition of a constant graph, we should be free to ignore
    // control edges since the order in which a constant graph is
    // executed should be the same regardless of when nodes run: we
    // should only need to recurse down data edges.
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant.
    if (current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in.  Don't constant fold through merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return Status::OK();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant.
    if (current_node->num_inputs() == 0) {
      if (!current_node->IsConstant()) {
        // Generator node is not a constant, so subgraph is not
        // constant.
        *is_constant_graph = false;
        return Status::OK();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of the node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        current_edge, &tensor_inferred, &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already, then add
    // it to the list of inputs to feed and do not recurse further.
    auto it = const_tensor_map_.find(output_tensor_name);
    if (it != const_tensor_map_.end() &&
        const_inputs_added.count(output_tensor_name) == 0) {
      const_inputs->emplace_back(output_tensor_name, it->second);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }

  return Status::OK();
}

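// Computes, in 'result', a possibly-partial shape for the tensor flowing
// into input 'dst_idx' of 'node', interpreting that tensor as a shape
// vector.  Shape-producing ops (Shape, ShapeN, Pack, Concat/ConcatV2) are
// handled symbolically from their inputs' contexts; anything else falls
// back to constant evaluation.  Values that cannot be determined become
// unknown dimensions: e.g. a Pack of (2, n) where n cannot be evaluated
// yields the partial shape [2, ?].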
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
                                          const Node* node, int dst_idx,
                                          ShapeHandle* result) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());
  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack concatenates its input scalars to form the shape tensor vector.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      Tensor scalar;
      bool evaluated = false;
      TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
                                                       &evaluated, &scalar));
      if (evaluated) {
        int64 size;
        if (scalar.dtype() == DT_INT32) {
          size = scalar.scalar<int32>()();
        } else if (scalar.dtype() == DT_INT64) {
          size = scalar.scalar<int64>()();
        } else {
          return errors::InvalidArgument("Pack input must be int32 or int64");
        }
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is the concat dim; for ConcatV2 it is the last
    // input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat concatenates its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // The concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                              i, &sub_result));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(
        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

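// Runs the shape inference function for 'node', re-running it whenever
// materializing a requested constant input tensor (or its partial shape)
// provides new information.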
Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && op_reg_data->is_function_op) {
      // Special inference logic for user-defined functions.

      auto* func_def = function_library_->Find(op_reg_data->op_def.name());
      if (func_def) {
        return InferShapesForFunction(func_def, keep_nested_shape_inferences_,
                                      ec);
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below.  If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(
            EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

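// Returns true only when both shapes are the same handle, or have the same
// known rank and every pair of dimensions is either the same handle or has
// the same known (non-negative) value.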
bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64 val0 = c->Value(c->Dim(s0, i));
      int64 val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

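// A vector of handle data counts as updated when its length changed, any
// dtype changed, or any shape is no longer "same defined" as before.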
bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow