1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
27
28namespace toco {
29
30namespace {
31
32void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
33                      int kheight, int stride_width, int stride_height,
34                      PaddingType padding_type, Shape* output_shape,
35                      FixedPadding* fixed_padding) {
36  const int input_width = input_shape.dims(2);
37  const int input_height = input_shape.dims(1);
38  const int batch = input_shape.dims(0);
39
40  int output_height = 0;
41  int output_width = 0;
42  if (padding_type == PaddingType::kValid) {
43    output_height = (input_height + stride_height - kheight) / stride_height;
44    output_width = (input_width + stride_width - kwidth) / stride_width;
45  } else if (padding_type == PaddingType::kSame) {
46    output_height = (input_height + stride_height - 1) / stride_height;
47    output_width = (input_width + stride_width - 1) / stride_width;
48  } else {
49    LOG(FATAL) << "Only supporting SAME or VALID padding";
50  }
51
52  fixed_padding->height = std::max(
53      0, ((output_height - 1) * stride_height + kheight - input_height) / 2);
54  fixed_padding->width = std::max(
55      0, ((output_width - 1) * stride_width + kwidth - input_width) / 2);
56
57  // Actually had to debug a situation where those were negative due to bad
58  // propagation of placeholder -1 sizes in TensorFlowReshape.
59  CHECK_GT(output_width, 0);
60  CHECK_GT(output_height, 0);
61  output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
62}
63
64void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
65                                     const Shape& input_shape_y,
66                                     Array* output_array) {
67  // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
68  // It zips together the two input shapes and pads with 1 to make them the
69  // same length. For each dimension we broadcast if either dimension is 1 and
70  // otherwise expect them to match.
71  int rank_x = input_shape_x.dimensions_count();
72  int rank_y = input_shape_y.dimensions_count();
73  int rank_out = std::max(rank_x, rank_y);
74  std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
75  dims_out->clear();
76  dims_out->reserve(rank_out);
77  for (int i = 0; i < rank_out; ++i) {
78    int dim_x = i < (rank_out - rank_x)
79                    ? 1
80                    : input_shape_x.dims(i - (rank_out - rank_x));
81    bool dim_y_is_one = i < (rank_out - rank_y);
82    int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
83    if (dim_x == -1 || dim_y == -1) {
84      // One or both dimensions is unknown.
85      QCHECK(false) << "Shapes must be specified";
86    } else if (dim_x == 1 || dim_y == 1) {
87      // Broadcast one dimension to the other that is 1.
88      if (dim_x == 1 && !dim_y_is_one) {
89        // Broadcast dim_y to dim_x (1).
90        dims_out->push_back(dim_y);
91      } else {
92        // Broadcast dim_x to dim_y (1).
93        DCHECK_EQ(dim_y, 1);
94        dims_out->push_back(dim_x);
95      }
96    } else {
97      // Expect the dimensions to match.
98      CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
99      dims_out->push_back(dim_x);
100    }
101  }
102  CHECK(output_array->has_shape());
103}
104
105int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
106  const string& weights_name = op.inputs[1];
107  const auto& weights_shape = model.GetArray(weights_name).shape();
108  if (op.type == OperatorType::kConv ||
109      op.type == OperatorType::kFullyConnected) {
110    return weights_shape.dims(0);
111  } else if (op.type == OperatorType::kDepthwiseConv) {
112    return weights_shape.dims(3);
113  } else {
114    LOG(FATAL) << "Unhandled operator type";
115  }
116}
117
118bool EnsureBiasVectorShape(Model* model, Operator* op) {
119  const string& weights_name = op->inputs[1];
120  const auto& weights_array = model->GetArray(weights_name);
121  // Yield until weights shape has been resolved.
122  if (!weights_array.has_shape()) {
123    return false;
124  }
125
126  if (op->inputs.size() < 3) {
127    return false;
128  }
129  auto& bias_array = model->GetArray(op->inputs[2]);
130  if (bias_array.has_shape()) {
131    return true;
132  }
133
134  const int output_depth = GetOutputDepthFromWeights(*model, *op);
135  bias_array.copy_shape(Shape({output_depth}));
136
137  auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
138  float_buffer.data.resize(output_depth, 0);
139
140  return true;
141}
142
143void ProcessConvOperator(Model* model, ConvOperator* op) {
144  if (!EnsureBiasVectorShape(model, op)) {
145    return;
146  }
147
148  const auto& input_array = model->GetArray(op->inputs[0]);
149  // Yield until input dims have been resolved.
150  if (!input_array.has_shape()) {
151    return;
152  }
153  const auto& input_shape = input_array.shape();
154  CHECK_EQ(input_shape.dimensions_count(), 4);
155
156  const auto& weights_array = model->GetArray(op->inputs[1]);
157  // Yield until weights dims have been resolved.
158  if (!weights_array.has_shape()) {
159    return;
160  }
161  const auto& weights_shape = weights_array.shape();
162  CHECK_EQ(weights_shape.dimensions_count(), 4);
163
164  auto& output_array = model->GetArray(op->outputs[0]);
165  const int output_depth = weights_shape.dims(0);
166  const int kheight = weights_shape.dims(1);
167  const int kwidth = weights_shape.dims(2);
168  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
169                   op->stride_height, op->padding.type,
170                   output_array.mutable_shape(),
171                   &op->padding.GetOrCreateFixedPadding());
172  CHECK_EQ(output_array.shape().dimensions_count(), 4);
173
174  // Set im2col array dimensions if there is one.
175  if (op->outputs.size() == 2) {
176    const auto& output_shape = output_array.shape();
177    const int input_depth = weights_shape.dims(3);
178    auto& im2col_array = model->GetArray(op->outputs[1]);
179    im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
180                                  output_shape.dims(2),
181                                  input_depth * kheight * kwidth});
182  }
183}
184
185void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
186  if (!EnsureBiasVectorShape(model, op)) {
187    return;
188  }
189
190  const auto& input_array = model->GetArray(op->inputs[0]);
191  // Yield until input dims have been resolved.
192  if (!input_array.has_shape()) {
193    return;
194  }
195  const auto& input_shape = input_array.shape();
196  CHECK_EQ(input_shape.dimensions_count(), 4);
197
198  const auto& weights_array = model->GetArray(op->inputs[1]);
199  // Yield until weights dims have been resolved.
200  if (!weights_array.has_shape()) {
201    return;
202  }
203  const auto& weights_shape = weights_array.shape();
204  CHECK_EQ(weights_shape.dimensions_count(), 4);
205
206  const string& output_name = op->outputs[0];
207  const int input_depth = input_shape.dims(3);
208  const int output_depth = weights_shape.dims(3);
209  // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
210  // instead it has to be inferred from the weights dims. However, once we are
211  // here, weights dims have already been converted to our own internal format,
212  // where the multiplier is no longer readily apparent. So instead we get it
213  // as the quotient of output and input depths. We only want to do that when
214  // depth_multiplier had the zero value: any other value should be checked
215  // as done by the next if() below.
216  if (!op->depth_multiplier) {
217    op->depth_multiplier = output_depth / input_depth;
218  }
219  QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
220      << "input/output depths and depth_multiplier don't match";
221
222  const int kheight = weights_shape.dims(1);
223  const int kwidth = weights_shape.dims(2);
224  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
225                   op->stride_height, op->padding.type,
226                   model->GetArray(output_name).mutable_shape(),
227                   &op->padding.GetOrCreateFixedPadding());
228}
229
230void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
231  const auto& input_array = model->GetArray(op->inputs[0]);
232  // Yield until input dims have been resolved.
233  if (!input_array.has_shape()) {
234    return;
235  }
236  const auto& input_shape = input_array.shape();
237  CHECK_EQ(input_shape.dimensions_count(), 4);
238
239  const string& output_name = op->outputs[0];
240  const int block_size = op->block_size;
241  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
242  const int batch = input_shape.dims(0);
243  const int height = input_shape.dims(1);
244  const int width = input_shape.dims(2);
245  const int depth = input_shape.dims(3);
246  QCHECK_EQ(depth % (block_size * block_size), 0);
247
248  model->GetArray(output_name)
249      .copy_shape(Shape({batch, height * block_size, width * block_size,
250                         depth / block_size / block_size}));
251}
252
253void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
254  const auto& input_array = model->GetArray(op->inputs[0]);
255  // Yield until input dims have been resolved.
256  if (!input_array.has_shape()) {
257    return;
258  }
259  const auto& input_shape = input_array.shape();
260  CHECK_EQ(input_shape.dimensions_count(), 4);
261
262  const string& output_name = op->outputs[0];
263  const int block_size = op->block_size;
264  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
265  const int batch = input_shape.dims(0);
266  const int height = input_shape.dims(1);
267  const int width = input_shape.dims(2);
268  const int depth = input_shape.dims(3);
269  QCHECK_EQ(width % block_size, 0);
270  QCHECK_EQ(height % block_size, 0);
271
272  model->GetArray(output_name)
273      .copy_shape(Shape({batch, height / block_size, width / block_size,
274                         depth * block_size * block_size}));
275}
276
277void ProcessFillOperator(Model* model, FillOperator* op) {
278  CHECK_EQ(op->inputs.size(), 2);
279  CHECK_EQ(op->outputs.size(), 1);
280  auto& output_array = model->GetArray(op->outputs[0]);
281  if (output_array.has_shape()) {
282    // We have already run
283    return;
284  }
285
286  auto& dims_array = model->GetArray(op->inputs[0]);
287  if (!dims_array.has_shape()) {
288    // Yield until dims shape been resolved.
289    return;
290  }
291  if (!dims_array.buffer) {
292    // Yield until the dims are constant
293    return;
294  }
295  CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
296  CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 4)
297      << "dims vector can be no larger than 4 values";
298
299  std::vector<int32> const& dims =
300      dims_array.GetBuffer<ArrayDataType::kInt32>().data;
301  *(output_array.mutable_shape()->mutable_dims()) = dims;
302}
303
304void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
305  if (!EnsureBiasVectorShape(model, op)) {
306    return;
307  }
308
309  const auto& input_array = model->GetArray(op->inputs[0]);
310  // Yield until input dims have been resolved.
311  if (!input_array.has_shape()) {
312    return;
313  }
314  const auto& input_shape = input_array.shape();
315  CHECK_GE(input_shape.dimensions_count(), 1);
316
317  const auto& weights_array = model->GetArray(op->inputs[1]);
318  // Yield until weights dims have been resolved.
319  if (!weights_array.has_shape()) {
320    return;
321  }
322  const auto& weights_shape = weights_array.shape();
323
324  const int weights_output_depth = weights_shape.dims(0);
325  CHECK_EQ(weights_shape.dimensions_count(), 2);
326
327  const int input_overall_size = RequiredBufferSizeForShape(input_shape);
328  const int matmul_repeats = input_overall_size / weights_shape.dims(1);
329  CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
330
331  auto& output_array = model->GetArray(op->outputs[0]);
332  output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
333}
334
335void ProcessTensorFlowReshapeOperator(Model* model,
336                                      TensorFlowReshapeOperator* op) {
337  auto& output_array = model->GetArray(op->outputs[0]);
338  if (output_array.has_shape()) {
339    // We have already run
340    return;
341  }
342
343  const auto& input_array = model->GetArray(op->inputs[0]);
344  if (!input_array.has_shape()) {
345    // Yield until input dims have been resolved.
346    return;
347  }
348  const auto& input_shape = input_array.shape();
349
350  auto& shape_array = model->GetArray(op->inputs[1]);
351  if (!shape_array.has_shape()) {
352    // Yield until target_shape shape been resolved.
353    return;
354  }
355  if (!shape_array.buffer) {
356    // Yield until the target_shape is constant
357    return;
358  }
359  CHECK(shape_array.data_type == ArrayDataType::kInt32)
360      << "Reshape dims must be int32";
361
362  // shape_data is the raw array of ints describing the shape
363  // in the TensorFlow node. We intentionally make a copy here, rather than
364  // modify wildcards in-place below, because in some graphs, the same shape
365  // array with a wildcard may be referenced from multiple Reshape nodes, where
366  // the wildcard needs to resolved to distinct values.
367  std::vector<int32> shape_data =
368      shape_array.GetBuffer<ArrayDataType::kInt32>().data;
369  // The Reshape shape may have a wildcard dim, encoded as -1.
370  bool has_wildcard = false;
371  int wildcard_index = 0;
372  int product_non_wildcard_dims = 1;
373  for (int i = 0; i < shape_data.size(); i++) {
374    if (shape_data[i] == -1) {
375      CHECK(!has_wildcard);
376      has_wildcard = true;
377      wildcard_index = i;
378    } else {
379      product_non_wildcard_dims *= shape_data[i];
380    }
381  }
382  const int input_flat_size = RequiredBufferSizeForShape(input_shape);
383  if (has_wildcard) {
384    CHECK_GE(input_flat_size, product_non_wildcard_dims)
385        << "Array not large enough to fill the requested dimensions for "
386           "Reshape op with output \""
387        << op->outputs[0] << "\". Are your input shapes correct?";
388    shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
389  }
390  auto& output_shape = *output_array.mutable_shape();
391  *output_shape.mutable_dims() = shape_data;
392  CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
393      << "Input cannot be reshaped to requested dimensions for Reshape op with "
394         "output \""
395      << op->outputs[0] << "\". Are your input shapes correct?";
396}
397
398void ProcessSimpleOperator(Model* model, Operator* op) {
399  const auto& input_array = model->GetArray(op->inputs[0]);
400  // Yield until input dims have been resolved.
401  if (!input_array.has_shape()) {
402    return;
403  }
404
405  const string& output_name = op->outputs[0];
406  auto& output_array = model->GetArray(output_name);
407  if (output_array.has_shape()) {
408    return;
409  }
410
411  output_array.copy_shape(input_array.shape());
412}
413
414void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
415  CHECK_EQ(op->inputs.size(), 2);
416  const auto& input0_array = model->GetArray(op->inputs[0]);
417  const auto& input1_array = model->GetArray(op->inputs[1]);
418  // Yield until input dims have been resolved.
419  if (!input0_array.has_shape() || !input1_array.has_shape()) {
420    return;
421  }
422  const string& output_name = op->outputs[0];
423  auto& output_array = model->GetArray(output_name);
424  ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
425                                  &output_array);
426}
427
428void ProcessAddNOperator(Model* model, Operator* op) {
429  // Yield until all input dims have been resolved.
430  //
431  // TODO(myenik): Since AddN does not support broadcasting, maybe we could
432  // actually use this to improve shape propagation by propagating the shape of
433  // one input to all other inputs once it is resolved instead of just the
434  // output, since all inputs must be the same size and shape for a well-formed
435  // graph.
436  for (const auto& input : op->inputs) {
437    const auto& input_array = model->GetArray(input);
438    if (!input_array.has_shape()) {
439      return;
440    }
441  }
442
443  // AddN does not support broadcasting, all inputs must be the same shape, so
444  // we just take the first input shape and apply it to the output.
445  const auto& input0_array = model->GetArray(op->inputs[0]);
446  auto& output_array = model->GetArray(op->outputs[0]);
447  output_array.copy_shape(input0_array.shape());
448}
449
450bool KeepDims(const Operator& op) {
451  switch (op.type) {
452    case OperatorType::kTensorFlowMin:
453      return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
454    case OperatorType::kTensorFlowMax:
455      return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
456    case OperatorType::kTensorFlowSum:
457      return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
458    case OperatorType::kMean:
459      return static_cast<const MeanOperator&>(op).keep_dims;
460    default:
461      LOG(FATAL) << "Not a reduction operator!";
462      return false;
463  }
464}
465
466void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
467  CHECK_LE(op->inputs.size(), 2);
468  auto& output_array = model->GetArray(op->outputs[0]);
469  if (output_array.has_shape()) {
470    return;
471  }
472  const auto& input_array = model->GetArray(op->inputs[0]);
473  if (!input_array.has_shape()) {
474    return;
475  }
476  const auto& input_shape = input_array.shape();
477  const bool keep_dims = KeepDims(*op);
478  if (op->inputs.size() == 2) {
479    // There is a reduction_indices input.
480    const auto& reduction_array = model->GetArray(op->inputs[1]);
481    if (!reduction_array.buffer) {
482      return;
483    }
484    CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
485    const auto& reduction_array_vals =
486        reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
487    auto& output_dims = *output_array.mutable_shape()->mutable_dims();
488    output_dims.clear();
489    for (int i = 0; i < input_shape.dimensions_count(); i++) {
490      bool is_reduction_dim = false;
491      for (int r : reduction_array_vals) {
492        if (i == r) {
493          is_reduction_dim = true;
494        }
495      }
496      if (!is_reduction_dim) {
497        output_dims.push_back(input_shape.dims(i));
498      } else if (keep_dims) {
499        output_dims.push_back(1);
500      }
501    }
502  } else {
503    // No reduction_indices means complete reduction to a single scalar.
504    if (keep_dims) {
505      output_array.copy_shape(input_shape);
506    } else {
507      output_array.copy_shape(Shape({}));
508    }
509  }
510}
511
512void ProcessSliceOperator(Model* model, SliceOperator* op) {
513  CHECK_EQ(op->inputs.size(), 3);
514  CHECK_EQ(op->outputs.size(), 1);
515
516  // Yield until the Slice params have been resolved.
517  if (op->begin.empty()) return;
518
519  // Yield until input dims have been resolved.
520  const auto& input_array = model->GetArray(op->inputs[0]);
521  if (!input_array.has_shape()) return;
522  const Shape& input_shape = input_array.shape();
523
524  auto& output_array = model->GetArray(op->outputs[0]);
525  if (output_array.has_shape()) return;
526
527  CHECK_EQ(input_shape.dims().size(), op->size.size());
528  CHECK_EQ(op->begin.size(), op->size.size());
529
530  std::vector<int> output_dims;
531  for (int i = 0; i < op->begin.size(); ++i) {
532    int size = op->size[i];
533    if (size == -1) {
534      size = input_array.shape().dims(i) - op->begin[i];
535    }
536    output_dims.push_back(size);
537  }
538
539  *output_array.mutable_shape()->mutable_dims() = output_dims;
540}
541
542void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
543  const string& input_name = op->inputs[0];
544  const auto& input_array = model->GetArray(input_name);
545  // Yield until input dims have been resolved.
546  if (!input_array.has_shape()) {
547    return;
548  }
549  const auto& input_shape = input_array.shape();
550  const string& output_name = op->outputs[0];
551  Shape* output_shape = model->GetArray(output_name).mutable_shape();
552  ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
553              output_shape);
554}
555
556void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
557  // Yield until input dims have been resolved.
558  for (const auto& input_name : op->inputs) {
559    auto& input_array = model->GetArray(input_name);
560    if (!input_array.has_shape()) {
561      return;
562    }
563  }
564  auto& output_array = model->GetArray(op->outputs[0]);
565  // Use 0 input as basis for output dimensions.
566  const auto& first_input_array = model->GetArray(op->inputs[0]);
567  output_array.copy_shape(first_input_array.shape());
568  // Negative axis means the count starts at the back of the dims().
569  int axis = op->axis;
570  if (axis < 0) axis += first_input_array.shape().dims().size();
571  // Determine the concat size, and enfore that all inputs have
572  // the same dimensions count.
573  int concat_size = 0;
574  for (const auto& input_name : op->inputs) {
575    auto& input_array = model->GetArray(input_name);
576    CHECK(input_array.has_shape());
577    if (input_array.shape().dimensions_count() == 0) {
578      continue;
579    }
580    CHECK_EQ(input_array.shape().dimensions_count(),
581             output_array.shape().dimensions_count());
582    const std::vector<int>& input_dims = input_array.shape().dims();
583    CHECK_LT(axis, input_dims.size());
584    concat_size += input_dims[axis];
585  }
586  // Write out the concat_size on the output array shape.
587  auto& output_shape = *output_array.mutable_shape();
588  auto& output_dims = *output_shape.mutable_dims();
589  CHECK_LT(axis, output_shape.dimensions_count());
590  output_dims[axis] = concat_size;
591}
592
593void ProcessRangeOperator(Model* model, RangeOperator* op) {
594  CHECK_EQ(op->inputs.size(), 3);
595  const auto& start_array = model->GetArray(op->inputs[0]);
596  if (!start_array.has_shape()) {
597    // Yield until input dims have been resolved.
598    return;
599  }
600  const auto& limit_array = model->GetArray(op->inputs[1]);
601  if (!limit_array.has_shape()) {
602    return;
603  }
604  const auto& delta_array = model->GetArray(op->inputs[2]);
605  if (!delta_array.has_shape()) {
606    return;
607  }
608
609  if (!IsConstantParameterArray(*model, op->inputs[0])) {
610    // Yield until inputs are constant.
611    return;
612  }
613  if (!IsConstantParameterArray(*model, op->inputs[1])) {
614    return;
615  }
616  if (!IsConstantParameterArray(*model, op->inputs[2])) {
617    return;
618  }
619
620  CHECK(start_array.data_type == ArrayDataType::kInt32)
621      << "Range op inputs must be int32.";
622  CHECK(limit_array.data_type == ArrayDataType::kInt32)
623      << "Range op inputs must be int32.";
624  CHECK(delta_array.data_type == ArrayDataType::kInt32)
625      << "Range op inputs must be int32.";
626  CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
627      << "Range op inputs must be scalar.";
628  CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
629      << "Range op inputs must be scalar.";
630  CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
631      << "Range op inputs must be scalar.";
632  int size = floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
633                    start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
634                   delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
635
636  // Only set the output shape. Contents are set by ResolveConstantRange.
637  CHECK_EQ(op->outputs.size(), 1);
638  auto& output_array = model->GetArray(op->outputs[0]);
639  Shape* output_shape = output_array.mutable_shape();
640  output_shape->ReplaceDims({size});
641}
642
643void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
644  CHECK_EQ(op->inputs.size(), 2);
645  const string& input_name = op->inputs[1];
646  const auto& input_array = model->GetArray(input_name);
647  // Yield until input dims have been resolved.
648  if (!input_array.has_shape()) {
649    return;
650  }
651  const Shape& input_shape = input_array.shape();
652
653  // Yield until axis is constant.
654  if (!IsConstantParameterArray(*model, op->inputs[0])) {
655    return;
656  }
657
658  const auto& axis_array = model->GetArray(op->inputs[0]);
659
660  // Yield until axis dims have been resolved.
661  if (!axis_array.has_shape()) {
662    return;
663  }
664
665  CHECK(axis_array.data_type == ArrayDataType::kInt32)
666      << "Axis array must be int32.";
667  CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
668      << "Axis array must be scalar.";
669
670  int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
671  if (axis < 0) {
672    axis += input_shape.dimensions_count();
673  }
674
675  const int split_dim = input_shape.dims(axis);
676  CHECK_EQ(split_dim % op->num_split, 0);
677  const int split_depth = split_dim / op->num_split;
678
679  Shape output_shape = input_shape;
680  (*output_shape.mutable_dims())[axis] = split_depth;
681
682  CHECK_EQ(op->outputs.size(), op->num_split);
683  for (const auto& output : op->outputs) {
684    model->GetArray(output).copy_shape(output_shape);
685  }
686}
687
688void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
689  const string& input_name = op->inputs[0];
690  const auto& input_array = model->GetArray(input_name);
691  // Yield until input dims have been resolved.
692  if (!input_array.has_shape()) {
693    return;
694  }
695  const auto& input_shape = input_array.shape();
696  CHECK_EQ(input_shape.dimensions_count(), 4);
697  const string& output_name = op->outputs[0];
698  const int output_depth = input_shape.dims(3);
699  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
700                   op->stride_width, op->stride_height, op->padding.type,
701                   model->GetArray(output_name).mutable_shape(),
702                   &op->padding.GetOrCreateFixedPadding());
703}
704
705void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
706  const string& input_name = op->inputs[0];
707  const auto& input_array = model->GetArray(input_name);
708  // Yield until input dims have been resolved.
709  if (!input_array.has_shape()) {
710    return;
711  }
712  const auto& input_shape = input_array.shape();
713  CHECK_EQ(input_shape.dimensions_count(), 4);
714  const string& output_name = op->outputs[0];
715  const int output_depth = input_shape.dims(3);
716  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
717                   op->stride_width, op->stride_height, op->padding.type,
718                   model->GetArray(output_name).mutable_shape(),
719                   &op->padding.GetOrCreateFixedPadding());
720}
721
722void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
723  const string& input_name = op->inputs[0];
724  const auto& input_array = model->GetArray(input_name);
725  // Yield until input dims have been resolved.
726  if (!input_array.has_shape()) {
727    return;
728  }
729  const auto& input_shape = input_array.shape();
730  if (input_shape.dimensions_count() < 4) {
731    LOG(FATAL) << "missing dimensions for " << input_name;
732  }
733  const string& output_name = op->outputs[0];
734  const int output_depth = input_shape.dims(3);
735  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
736                   op->stride_width, op->stride_height, op->padding.type,
737                   model->GetArray(output_name).mutable_shape(),
738                   &op->padding.GetOrCreateFixedPadding());
739}
740
741void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
742  CHECK_EQ(op->inputs.size(), 2);
743  CHECK_EQ(op->outputs.size(), 1);
744
745  if (!model->GetArray(op->inputs[0]).has_shape() ||
746      !model->GetArray(op->inputs[1]).has_shape()) {
747    return;
748  }
749  const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
750
751  const string& output_size_name = op->inputs[1];
752  const auto& output_size_array = model->GetArray(output_size_name);
753  CHECK(output_size_array.data_type == ArrayDataType::kInt32);
754  CHECK(output_size_array.has_shape());
755  const auto& output_size_shape = output_size_array.shape();
756  CHECK_EQ(output_size_shape.dimensions_count(), 1);
757  CHECK_EQ(output_size_shape.dims(0), 2);
758  if (!output_size_array.buffer) {
759    return;
760  }
761  std::vector<int32> output_shape =
762      output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
763  model->GetArray(op->outputs[0])
764      .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
765                         output_shape[1], input_data_shape.dims(3)}));
766}
767
768void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
769  // Only required for compact LstmCell with default NUM_INPUTS of inputs.
770  if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;
771
772  const auto& input_array =
773      model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
774  // Yield until all input dims have been resolved.
775  if (!input_array.has_shape()) {
776    return;
777  }
778  const auto& input_shape = input_array.shape();
779  CHECK_GE(input_shape.dimensions_count(), 2);
780
781  const auto& prev_activ_array =
782      model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
783  // Yield until all input dims have been resolved.
784  if (!prev_activ_array.has_shape()) {
785    return;
786  }
787  const auto& prev_activ_shape = prev_activ_array.shape();
788  CHECK_GE(prev_activ_shape.dimensions_count(), 2);
789
790  const auto& weights_array =
791      model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
792  // Yield until weights dims have been resolved.
793  if (!weights_array.has_shape()) {
794    return;
795  }
796  const auto& weights_shape = weights_array.shape();
797  CHECK_EQ(weights_shape.dimensions_count(), 2);
798
799  const auto& bias_array =
800      model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
801  // Yield until bias dims have been resolved.
802  if (!bias_array.has_shape()) {
803    return;
804  }
805  const auto& bias_shape = bias_array.shape();
806  CHECK_GE(bias_shape.dimensions_count(), 1);
807
808  const auto& prev_state_array =
809      model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
810  // Yield until all input dims have been resolved.
811  if (!prev_state_array.has_shape()) {
812    return;
813  }
814  const auto& prev_state_shape = prev_state_array.shape();
815  CHECK_GE(prev_state_shape.dimensions_count(), 2);
816
817  const int fc_output_depth = weights_shape.dims(0);
818  CHECK_EQ(fc_output_depth, bias_shape.dims(0));
819  CHECK_EQ(fc_output_depth % 4, 0);
820  const int depth = fc_output_depth / 4;
821
822  const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
823  const int fc_input_depth = weights_shape.dims(1);
824  CHECK_EQ(input_depth + depth, fc_input_depth);
825  Shape output_shape(input_shape);
826  (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
827
828  // Set output dimensions
829  model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
830      .copy_shape(output_shape);
831  model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
832      .copy_shape(output_shape);
833
834  Shape concat_temp_shape(input_shape);
835  (*concat_temp_shape
836        .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
837      fc_input_depth;
838  model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
839      .copy_shape(concat_temp_shape);
840
841  Shape activ_temp_shape(input_shape);
842  (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
843      fc_output_depth;
844  model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
845      .copy_shape(activ_temp_shape);
846}
847
848void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
849  const auto& input_array = model->GetArray(op->inputs[0]);
850  // Yield until input dims have been resolved.
851  if (!input_array.has_shape()) {
852    return;
853  }
854  const auto& input_shape = input_array.shape();
855  // This method only handles input dimensions of 4.
856  if (input_shape.dimensions_count() != 4) {
857    return;
858  }
859  const auto input_height = input_shape.dims(1);
860  const auto input_width = input_shape.dims(2);
861
862  const auto& block_shape_array = model->GetArray(op->inputs[1]);
863  const auto& paddings_array = model->GetArray(op->inputs[2]);
864  const auto& block_shape_array_shape = block_shape_array.shape();
865  const auto& paddings_array_shape = paddings_array.shape();
866  QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
867  QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
868
869  // We only support two dimensions.
870  QCHECK_EQ(block_shape_array_shape.dims(0), 2);
871  if (!block_shape_array.buffer) {
872    return;
873  }
874  QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
875  const auto& block_shape_data =
876      block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
877  auto block_height = block_shape_data[0];
878  auto block_width = block_shape_data[1];
879
880  QCHECK_EQ(paddings_array_shape.dims(0), 2);  // Number of block dimensions
881  QCHECK_EQ(paddings_array_shape.dims(1), 2);  // Two parameters per dimension.
882  if (!paddings_array.buffer) {
883    return;
884  }
885  QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
886  const auto& paddings_data =
887      paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
888  int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
889  int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
890  QCHECK_EQ(height_with_paddings % block_height, 0);
891  QCHECK_EQ(width_with_paddings % block_width, 0);
892  int output_height = height_with_paddings / block_height;
893  int output_width = width_with_paddings / block_width;
894
895  model->GetArray(op->outputs[0])
896      .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
897                         output_height, output_width, input_shape.dims(3)}));
898}
899
900void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
901  const auto& input_array = model->GetArray(op->inputs[0]);
902  // Yield until input dims have been resolved.
903  if (!input_array.has_shape()) {
904    return;
905  }
906  const auto& input_shape = input_array.shape();
907  CHECK_EQ(input_shape.dimensions_count(), 4);
908  const auto input_height = input_shape.dims(1);
909  const auto input_width = input_shape.dims(2);
910
911  const auto& block_shape_array = model->GetArray(op->inputs[1]);
912  const auto& crops_array = model->GetArray(op->inputs[2]);
913  const auto& block_shape_array_shape = block_shape_array.shape();
914  const auto& crops_array_shape = crops_array.shape();
915  QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
916  QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
917
918  // We only support two dimensions.
919  QCHECK_EQ(block_shape_array_shape.dims(0), 2);
920  if (!block_shape_array.buffer) {
921    return;
922  }
923  QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
924  const auto& block_shape_data =
925      block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
926  auto block_height = block_shape_data[0];
927  auto block_width = block_shape_data[1];
928
929  QCHECK_EQ(crops_array_shape.dims(0), 2);  // Number of block dimensions
930  QCHECK_EQ(crops_array_shape.dims(1), 2);  // Two parameters per dimension.
931  if (!crops_array.buffer) {
932    return;
933  }
934  QCHECK(crops_array.data_type == ArrayDataType::kInt32);
935  const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
936  // We don't support crops now.
937  QCHECK_EQ(crops_data[0], 0);
938  QCHECK_EQ(crops_data[1], 0);
939  QCHECK_EQ(crops_data[2], 0);
940  QCHECK_EQ(crops_data[3], 0);
941
942  QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
943
944  int output_height = input_height * block_height;
945  int output_width = input_width * block_width;
946
947  model->GetArray(op->outputs[0])
948      .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
949                         output_height, output_width, input_shape.dims(3)}));
950}
951
952void ProcessGatherOperator(Model* model, GatherOperator* op) {
953  const auto& input_array = model->GetArray(op->inputs[0]);
954  const auto& indices_array = model->GetArray(op->inputs[1]);
955  auto& output_array = model->GetArray(op->outputs[0]);
956
957  // Bail if we already know the output shape.
958  if (output_array.has_shape()) {
959    return;
960  }
961
962  // Yield until input dims have been resolved.
963  if (!input_array.has_shape() || !indices_array.has_shape()) {
964    return;
965  }
966
967  const auto& input_shape = input_array.shape();
968  const auto& indices_shape = indices_array.shape();
969  QCHECK_GE(input_shape.dimensions_count(), 1);
970  op->input_rank = input_shape.dimensions_count();
971
972  // We only support 1-D indices.
973  QCHECK_EQ(indices_shape.dimensions_count(), 1);
974
975  // Copy the input dimensions to the output except for dimension 0,
976  // where the dimension of indices_shape is used.
977  // TODO(mgubin): if axis != 0 this is not true, change when it's supported.
978  auto output_dims = output_array.mutable_shape()->mutable_dims();
979  output_dims->push_back(indices_shape.dims(0));
980  for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
981    output_dims->push_back(input_shape.dims(dim));
982  }
983}
984
985void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
986  const auto& input_values = model->GetArray(op->inputs[0]);
987  const auto& input_k = model->GetArray(op->inputs[1]);
988  auto& output_indexes = model->GetArray(op->outputs[0]);
989  auto& output_values = model->GetArray(op->outputs[1]);
990
991  // Bail if we already know the output shape.
992  if (output_indexes.has_shape()) {
993    QCHECK(output_values.has_shape());
994    return;
995  }
996
997  // Yield until input dims have been resolved.
998  if (!input_values.has_shape()) {
999    return;
1000  }
1001
1002  const auto& input_values_shape = input_values.shape();
1003  auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
1004  auto output_values_dims = output_values.mutable_shape()->mutable_dims();
1005  for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
1006    output_indexes_dims->push_back(input_values_shape.dims(dim));
1007    output_values_dims->push_back(input_values_shape.dims(dim));
1008  }
1009  // If the value is initialized, we can specify the last dimension, otherwise
1010  // unknown.
1011  if (input_k.buffer) {
1012    const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
1013    output_indexes_dims->push_back(k_value);
1014    output_values_dims->push_back(k_value);
1015
1016  } else {
1017    output_indexes_dims->push_back(0);
1018    output_values_dims->push_back(0);
1019  }
1020}
1021
1022void ProcessPadOperator(Model* model, PadOperator* op) {
1023  CHECK_EQ(op->inputs.size(), 2);
1024  CHECK_EQ(op->outputs.size(), 1);
1025
1026  const auto& input_array = model->GetArray(op->inputs[0]);
1027
1028  // Yield until input dims have been resolved.
1029  if (!input_array.has_shape()) return;
1030
1031  if (op->left_padding.empty()) return;
1032  CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1033
1034  auto& output_array = model->GetArray(op->outputs[0]);
1035  if (output_array.has_shape()) return;
1036
1037  Shape output_shape = input_array.shape();
1038  std::vector<int>& dims = *output_shape.mutable_dims();
1039  CHECK_EQ(op->left_padding.size(), dims.size());
1040
1041  for (int i = 0; i < op->left_padding.size(); ++i) {
1042    dims[i] += op->left_padding[i] + op->right_padding[i];
1043  }
1044
1045  output_array.copy_shape(output_shape);
1046}
1047
1048void ProcessRankOperator(Model* model, RankOperator* op) {
1049  CHECK_GE(op->inputs.size(), 1);
1050  CHECK_EQ(op->outputs.size(), 1);
1051  auto& output_array = model->GetArray(op->outputs[0]);
1052  if (output_array.has_shape()) {
1053    // Shape already propagated
1054    return;
1055  }
1056
1057  const auto& input_array = model->GetArray(op->inputs[0]);
1058  if (!input_array.has_shape()) {
1059    // Yield until input dims have been resolved.
1060    return;
1061  }
1062
1063  // Only set the output shape. Array contents are set by
1064  // ResolveConstantShapeOrRank.
1065  Shape* output_shape = output_array.mutable_shape();
1066  output_shape->ReplaceDims({});
1067}
1068
1069void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
1070  CHECK_GE(op->inputs.size(), 1);
1071  CHECK_EQ(op->outputs.size(), 1);
1072  auto& output_array = model->GetArray(op->outputs[0]);
1073  if (output_array.has_shape()) {
1074    // Shape already propagated
1075    return;
1076  }
1077
1078  const auto& input_array = model->GetArray(op->inputs[0]);
1079  if (!input_array.has_shape()) {
1080    // Yield until input dims have been resolved.
1081    return;
1082  }
1083
1084  // Only set the output shape. Array contents are set by
1085  // ResolveConstantShapeOrRank.
1086  Shape* output_shape = output_array.mutable_shape();
1087  output_shape->ReplaceDims({input_array.shape().dimensions_count()});
1088}
1089
1090void ProcessStackOperator(Model* model, StackOperator* op) {
1091  CHECK_GE(op->inputs.size(), 1);
1092  CHECK_EQ(op->outputs.size(), 1);
1093  auto& output_array = model->GetArray(op->outputs[0]);
1094  if (output_array.has_shape()) {
1095    // Shape already propagated
1096    return;
1097  }
1098
1099  std::unique_ptr<Shape> stacked_shape;
1100  for (const auto& input : op->inputs) {
1101    const auto& input_array = model->GetArray(input);
1102    if (!input_array.has_shape()) {
1103      // Yield until all input dims have been resolved.
1104      return;
1105    }
1106
1107    Shape shape = input_array.shape();
1108    if (shape.dimensions_count() == 0) {
1109      // Convert 0D scalars to 1D scalars of shape {1}.
1110      shape.mutable_dims()->push_back(1);
1111    }
1112    if (!stacked_shape) {
1113      stacked_shape.reset(new Shape(shape));
1114    } else {
1115      CHECK(*stacked_shape == shape) << "All input arrays to Stack operators "
1116                                        "must have the same shape. Input \""
1117                                     << input << "\" is different.";
1118    }
1119  }
1120
1121  int axis = op->axis;
1122  if (axis < 0) {
1123    // Handle negative axis
1124    axis += stacked_shape->dims().size() + 1;
1125  }
1126  stacked_shape->mutable_dims()->insert(
1127      stacked_shape->mutable_dims()->begin() + axis, op->inputs.size());
1128  output_array.copy_shape(*stacked_shape);
1129}
1130
1131void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
1132  CHECK_GE(op->inputs.size(), 1);
1133  CHECK_EQ(op->outputs.size(), 1);
1134  auto& output_array = model->GetArray(op->outputs[0]);
1135  if (output_array.has_shape()) {
1136    // Shape already propagated
1137    return;
1138  }
1139
1140  if (op->start_indices.empty() || op->stop_indices.empty() ||
1141      op->strides.empty()) {
1142    // ResolveStridedSliceAttributes has not run yet.
1143    return;
1144  }
1145
1146  const auto& input_array = model->GetArray(op->inputs[0]);
1147  if (!input_array.has_shape()) {
1148    // Yield until input dims have been resolved.
1149    return;
1150  }
1151
1152  if (op->ellipsis_mask != 0) {
1153    // Something like LOG_FIRST_N(WARNING, 10) would be prefferable to reduce
1154    // log noise. However, the TensorFlow logging library does not appear to
1155    // support this.
1156    LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1157                 << "\". ellipsis_mask is not supported (mask="
1158                 << op->ellipsis_mask << ")";
1159    return;
1160  }
1161  if (op->new_axis_mask != 0) {
1162    LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1163                 << "\". new_axis_mask is not supported (mask="
1164                 << op->new_axis_mask << ")";
1165    return;
1166  }
1167
1168  int dim_count = input_array.shape().dimensions_count();
1169  CHECK(op->start_indices.size() == dim_count)
1170      << ": Incorrect number of start indices supplied to StridedSlice op with "
1171         "output \""
1172      << op->outputs[0] << "\". Op requires " << dim_count << " start indices";
1173  CHECK(op->stop_indices.size() == dim_count)
1174      << ": Incorrect number of stop indices supplied to StridedSlice op with "
1175         "output \""
1176      << op->outputs[0] << "\". Op requires " << dim_count << " stop indices";
1177  CHECK(op->strides.size() == dim_count)
1178      << ": Incorrect number of strides supplied to StridedSlice op with "
1179         " output \""
1180      << op->outputs[0] << "\". Op requires " << dim_count << " strides";
1181
1182  // Create output shape
1183  std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
1184
1185  // Compute output shape
1186  for (int i = 0; i < dim_count; ++i) {
1187    const int mask = 1 << i;
1188    int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
1189    if (start < 0) {
1190      // handle negative indices
1191      start += input_array.shape().dims(i);
1192    }
1193    int stop = (op->end_mask & mask) ? input_array.shape().dims(i)
1194                                     : op->stop_indices[i];
1195    if (stop < 0) {
1196      // handle negative indices
1197      stop += input_array.shape().dims(i);
1198    }
1199
1200    int dim_size = ceil((stop - start) / static_cast<float>(op->strides[i]));
1201    dim_size = dim_size < 0 ? 0 : dim_size;
1202    if (op->shrink_axis_mask & mask) {
1203      CHECK_EQ(dim_size, 1) << "Output size for an axis must compute to 1 when "
1204                               "shrinking that axis";
1205    } else {
1206      dims->push_back(dim_size);
1207    }
1208  }
1209}
1210
1211void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
1212  CHECK_EQ(op->inputs.size(), 1);
1213  CHECK_EQ(op->outputs.size(), 1);
1214
1215  const auto& input_array = model->GetArray(op->inputs[0]);
1216
1217  // Yield until input dims have been resolved.
1218  if (!input_array.has_shape()) return;
1219
1220  auto& output_array = model->GetArray(op->outputs[0]);
1221  if (output_array.has_shape()) return;
1222
1223  const std::vector<int>& input_dims = input_array.shape().dims();
1224  std::vector<int> output_dims;
1225
1226  for (int i = 0; i < input_dims.size(); ++i) {
1227    if (input_dims[i] != 1 ||
1228        (!op->squeeze_dims.empty() &&
1229         std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) ==
1230             op->squeeze_dims.end())) {
1231      output_dims.push_back(input_dims[i]);
1232    }
1233  }
1234  *output_array.mutable_shape()->mutable_dims() = output_dims;
1235}
1236
1237void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
1238  CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
1239  const auto& input_array = model->GetArray(op->inputs[0]);
1240  if (!input_array.has_shape()) return;
1241
1242  auto& weights_feature_array = model->GetArray(op->inputs[1]);
1243  if (!weights_feature_array.has_shape()) return;
1244
1245  const auto& weights_time_array = model->GetArray(op->inputs[2]);
1246  if (!weights_time_array.has_shape()) return;
1247
1248  const bool has_bias = (op->inputs.size() == 4);
1249  if (has_bias) {
1250    const auto& bias_array = model->GetArray(op->inputs[3]);
1251    if (!bias_array.has_shape()) return;
1252  }
1253
1254  const int batch_size = input_array.shape().dims()[0];
1255  const int num_units = weights_feature_array.shape().dims()[0];
1256  const int memory_size = weights_time_array.shape().dims()[1];
1257
1258  auto& state_array = model->GetArray(op->outputs[0]);
1259  state_array.mutable_shape()->ReplaceDims(
1260      {batch_size, memory_size * num_units});
1261
1262  auto& output_array = model->GetArray(op->outputs[1]);
1263  output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
1264}
1265
1266void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
1267  auto& output_array = model->GetArray(op->outputs[0]);
1268  if (output_array.has_shape()) {
1269    // We have already run
1270    return;
1271  }
1272
1273  const auto& input_array = model->GetArray(op->inputs[0]);
1274  if (!input_array.has_shape()) {
1275    // Yield until input dims have been resolved.
1276    return;
1277  }
1278  const auto& input_shape = input_array.shape();
1279
1280  auto& perm_array = model->GetArray(op->inputs[1]);
1281  if (!perm_array.has_shape()) {
1282    // Yield until permutation shape been resolved.
1283    return;
1284  }
1285  if (!perm_array.buffer) {
1286    // Yield until the permutation is constant
1287    return;
1288  }
1289  CHECK(perm_array.data_type == ArrayDataType::kInt32)
1290      << "Transpose permutation input must be int32";
1291
1292  std::vector<int32> const& perm =
1293      perm_array.GetBuffer<ArrayDataType::kInt32>().data;
1294  CHECK_EQ(perm.size(), input_shape.dimensions_count())
1295      << "Transpose permutation input " << op->inputs[0]
1296      << " must be same length as input dimensions";
1297  std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
1298  for (int i = 0; i < perm.size(); i++) {
1299    int axis = perm[i];
1300    CHECK_GE(axis, 0);
1301    CHECK_LT(axis, input_shape.dimensions_count());
1302    output_dims->push_back(input_shape.dims(axis));
1303  }
1304}
1305
1306void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
1307  CHECK_EQ(op->inputs.size(), 2);
1308  const auto& input_array = model->GetArray(op->inputs[0]);
1309  // Yield until input dims have been resolved.
1310  if (!input_array.has_shape()) {
1311    return;
1312  }
1313
1314  // The current ArgMax implementation only supports 4-dimensional inputs with
1315  // the last dimension as the axis to perform ArgMax for.
1316  const std::vector<int>& input_dims = input_array.shape().dims();
1317  CHECK_EQ(input_dims.size(), 4);
1318  std::vector<int> output_dims;
1319
1320  output_dims.reserve(input_dims.size() - 1);
1321  for (int i = 0; i < input_dims.size() - 1; ++i) {
1322    output_dims.push_back(input_dims[i]);
1323  }
1324  output_dims.push_back(1);
1325  const string& output_name = op->outputs[0];
1326  auto& output_array = model->GetArray(output_name);
1327  if (output_array.has_shape()) {
1328    return;
1329  }
1330  *output_array.mutable_shape()->mutable_dims() = output_dims;
1331}
1332
1333}  // namespace
1334
1335bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
1336  auto it = model->operators.begin() + op_index;
1337  auto* op = it->get();
1338  std::unordered_map<string, std::vector<int>> old_output_dims;
1339  for (const auto& output : op->outputs) {
1340    if (model->GetArray(output).has_shape()) {
1341      old_output_dims[output] = model->GetArray(output).shape().dims();
1342    }
1343  }
1344
1345  switch (op->type) {
1346    case OperatorType::kBatchNormalization:
1347    case OperatorType::kL2Normalization:
1348    case OperatorType::kDequantize:
1349    case OperatorType::kRelu:
1350    case OperatorType::kRelu1:
1351    case OperatorType::kRelu6:
1352    case OperatorType::kSoftmax:
1353    case OperatorType::kLogSoftmax:
1354    case OperatorType::kLogistic:
1355    case OperatorType::kTanh:
1356    case OperatorType::kLocalResponseNormalization:
1357    case OperatorType::kTensorFlowIdentity:
1358    case OperatorType::kFakeQuant:
1359    case OperatorType::kNeg:
1360    case OperatorType::kTensorFlowRsqrt:
1361    case OperatorType::kTensorFlowSqrt:
1362    case OperatorType::kTensorFlowSquare:
1363    case OperatorType::kTensorFlowAll:
1364    case OperatorType::kTensorFlowAssert:
1365    case OperatorType::kCast:
1366    case OperatorType::kFloor:
1367    case OperatorType::kExp:
1368      ProcessSimpleOperator(model, op);
1369      break;
1370    case OperatorType::kGather:
1371      ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
1372      break;
1373    case OperatorType::kTopK_V2:
1374      ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
1375      break;
1376    case OperatorType::kAdd:
1377    case OperatorType::kSub:
1378    case OperatorType::kMul:
1379    case OperatorType::kDiv:
1380    case OperatorType::kFloorDiv:
1381    case OperatorType::kFloorMod:
1382    case OperatorType::kTensorFlowLess:
1383    case OperatorType::kTensorFlowLessEqual:
1384    case OperatorType::kTensorFlowGreater:
1385    case OperatorType::kTensorFlowMaximum:
1386    case OperatorType::kTensorFlowMinimum:
1387    case OperatorType::kTensorFlowGreaterEqual:
1388      ProcessSimpleBinaryOperator(model, op);
1389      break;
1390    case OperatorType::kAddN:
1391      ProcessAddNOperator(model, op);
1392      break;
1393    case OperatorType::kConv:
1394      ProcessConvOperator(model, static_cast<ConvOperator*>(op));
1395      break;
1396    case OperatorType::kTransposeConv:
1397      // Unimplemented, hopefully another graph transformation will drop it or
1398      // rewrite it.
1399      break;
1400    case OperatorType::kDepthwiseConv:
1401      ProcessDepthwiseConvOperator(model,
1402                                   static_cast<DepthwiseConvOperator*>(op));
1403      break;
1404    case OperatorType::kDepthToSpace:
1405      ProcessDepthToSpaceOperator(model,
1406                                  static_cast<DepthToSpaceOperator*>(op));
1407      break;
1408    case OperatorType::kSpaceToDepth:
1409      ProcessSpaceToDepthOperator(model,
1410                                  static_cast<SpaceToDepthOperator*>(op));
1411      break;
1412    case OperatorType::kFill:
1413      ProcessFillOperator(model, static_cast<FillOperator*>(op));
1414      break;
1415    case OperatorType::kFullyConnected:
1416      ProcessFullyConnectedOperator(model,
1417                                    static_cast<FullyConnectedOperator*>(op));
1418      break;
1419    case OperatorType::kTensorFlowReshape:
1420      ProcessTensorFlowReshapeOperator(
1421          model, static_cast<TensorFlowReshapeOperator*>(op));
1422      break;
1423    case OperatorType::kAveragePool:
1424      ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
1425      break;
1426    case OperatorType::kMaxPool:
1427      ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
1428      break;
1429    case OperatorType::kL2Pool:
1430      ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
1431      break;
1432    case OperatorType::kTensorFlowMin:
1433    case OperatorType::kTensorFlowMax:
1434    case OperatorType::kTensorFlowSum:
1435    case OperatorType::kMean:
1436      ProcessTensorFlowReductionOperator(model, op);
1437      break;
1438
1439    case OperatorType::kSlice:
1440      ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
1441      break;
1442
1443    case OperatorType::kTensorFlowTile:
1444      // We don't currently implement the propagation of fixed sizes through
1445      // a TensorFlow Tile.
1446      //
1447      // Fortunately, we don't need to: so far, we have only dealt with Tile
1448      // or Slice ops in subgraphs that are identified as L2Normalization.
1449      // See IdentifyL2Normalization.
1450      break;
1451    case OperatorType::kTensorFlowSwitch:
1452      // We can't know the sizes of the outputs until we have resolved the
1453      // predicate, and once we have resolved the predicate, the whole
1454      // Switch node will get resolved away.
1455      // See ResolveTensorFlowSwitch.
1456      break;
1457    case OperatorType::kTensorFlowMerge:
1458      // No need to bother resolving TensorFlow Merge ops: other graph
1459      // transformations will remove them anyway.
1460      // See ResolveTensorFlowMerge.
1461      break;
1462    case OperatorType::kTensorFlowSplit:
1463      ProcessTensorFlowSplitOperator(model,
1464                                     static_cast<TensorFlowSplitOperator*>(op));
1465      break;
1466    case OperatorType::kSqueeze:
1467      ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
1468      break;
1469    case OperatorType::kTensorFlowConcat:
1470    case OperatorType::kTensorFlowConcatV2:
1471      // Unimplemented, hopefully another graph transformation will
1472      // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
1473      // will resolve this node to a DepthConcatenation, or else we have
1474      // a more general non-depth concatenation that will hopefully be dropped,
1475      // or else at the moment we will abort.
1476      break;
1477    case OperatorType::kExpandDims:
1478      // Yield until ExpandDims is converted to Reshape
1479      break;
1480    case OperatorType::kRange:
1481      ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
1482      break;
1483    case OperatorType::kRank:
1484      ProcessRankOperator(model, static_cast<RankOperator*>(op));
1485      break;
1486    case OperatorType::kTensorFlowShape:
1487      ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
1488      break;
1489    case OperatorType::kStack:
1490      ProcessStackOperator(model, static_cast<StackOperator*>(op));
1491      break;
1492    case OperatorType::kReorderAxes:
1493      ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
1494      break;
1495    case OperatorType::kConcatenation:
1496      ProcessConcatenationOperator(model,
1497                                   static_cast<ConcatenationOperator*>(op));
1498      break;
1499    case OperatorType::kResizeBilinear:
1500      ProcessResizeBilinearOperator(model,
1501                                    static_cast<ResizeBilinearOperator*>(op));
1502      break;
1503    case OperatorType::kLstmCell:
1504      ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
1505      break;
1506    case OperatorType::kBatchMatMul:
1507    case OperatorType::kTensorFlowMatMul:
1508      // MatMul operators are converted to FullyConnected, after which their
1509      // shapes are propagated.
1510      break;
1511    case OperatorType::kSpaceToBatchND:
1512      ProcessSpaceToBatchNDOperator(model,
1513                                    static_cast<SpaceToBatchNDOperator*>(op));
1514      break;
1515    case OperatorType::kBatchToSpaceND:
1516      ProcessBatchToSpaceNDOperator(model,
1517                                    static_cast<BatchToSpaceNDOperator*>(op));
1518      break;
1519    case OperatorType::kPad:
1520      ProcessPadOperator(model, static_cast<PadOperator*>(op));
1521      break;
1522    case OperatorType::kStridedSlice:
1523      ProcessStridedSliceOperator(model,
1524                                  static_cast<StridedSliceOperator*>(op));
1525      break;
1526    case OperatorType::kArgMax:
1527      ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
1528      break;
1529    case OperatorType::kTensorFlowUnsupported:
1530      break;
1531    case OperatorType::kSvdf:
1532      ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
1533      break;
1534    case OperatorType::kTranspose:
1535      ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
1536      break;
1537    default:
1538      // Unimplemented, another graph transformation should drop it.
1539      LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
1540  }
1541
1542  // Return true if any output dim changed, false if none changed.
1543  // Assumption: no transformation clears an output shape, they only add shapes.
1544  for (const auto& output : op->outputs) {
1545    if (model->GetArray(output).has_shape() &&
1546        (old_output_dims[output] != model->GetArray(output).shape().dims())) {
1547      AddMessageF("Set shape of %s to [%s]", output,
1548                  absl::StrJoin(model->GetArray(output).shape().dims(), ","));
1549      return true;
1550    }
1551  }
1552  return false;
1553}
1554
1555}  // namespace toco
1556