1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <memory>
16#include <string>
17#include <unordered_map>
18#include <vector>
19
20#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
21#include "tensorflow/contrib/lite/toco/model.h"
22#include "tensorflow/contrib/lite/toco/tooling_util.h"
23#include "tensorflow/core/platform/logging.h"
24
25namespace toco {
26
27bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
28  auto conv_it = model->operators.begin() + op_index;
29  if (conv_it->get()->type != OperatorType::kConv) {
30    return false;
31  }
32  const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
33  if (conv_op->stride_width != conv_op->stride_height) {
34    return false;
35  }
36  auto& weights_array = model->GetArray(conv_op->inputs[1]);
37  if (!weights_array.buffer) {
38    // Yield until the weights are resolved as a constant array.
39    return false;
40  }
41  if (weights_array.data_type != ArrayDataType::kFloat) {
42    return false;
43  }
44  if (weights_array.shape().dims(3) != 1) {
45    // Not a pure convolution: Conv does accumulation across the depth
46    // dimension.
47    return false;
48  }
49  // At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
50  AddMessageF(
51      "%s is purely convolutional (input/weights depth is 1), replacing it by "
52      "a DepthwiseConv.",
53      LogName(*conv_op));
54  auto* depthwiseconv_op = new DepthwiseConvOperator;
55  // Conv and DepthwiseConv take the same inputs
56  depthwiseconv_op->inputs = conv_op->inputs;
57  // Conv may have a 2nd output for im2col
58  depthwiseconv_op->outputs = {conv_op->outputs[0]};
59  if (conv_op->outputs.size() > 1) {
60    // delete the im2col array.
61    model->EraseArray(conv_op->outputs[1]);
62  }
63  depthwiseconv_op->fused_activation_function =
64      conv_op->fused_activation_function;
65  // Let PropagateFixedSizes recompute fixed padding, just in case some day it
66  // may be different for Conv vs DepthwiseConv.
67  depthwiseconv_op->padding.type = conv_op->padding.type;
68  depthwiseconv_op->stride_height = conv_op->stride_height;
69  depthwiseconv_op->stride_width = conv_op->stride_width;
70  depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
71  // Replace the operator in the graph.
72  const auto depthwiseconv_it =
73      model->operators.emplace(conv_it, depthwiseconv_op);
74  conv_it = depthwiseconv_it + 1;
75  CHECK_EQ(conv_it->get(), conv_op);
76  model->operators.erase(conv_it);
77  // Shuffle the weights.
78  const auto& weights_shape = weights_array.shape();
79  auto& weights_buffer =
80      weights_array.GetMutableBuffer<ArrayDataType::kFloat>();
81  const std::vector<float>& conv_weights_data = weights_buffer.data;
82  std::vector<float> depthwise_conv_weights_data(conv_weights_data.size());
83  const int depth = weights_shape.dims(0);
84  const int width = weights_shape.dims(1);
85  const int height = weights_shape.dims(2);
86  const int width_height = width * height;
87  for (int c = 0; c < depth; c++) {
88    for (int xy = 0; xy < width_height; xy++) {
89      depthwise_conv_weights_data[c + depth * xy] =
90          conv_weights_data[xy + width_height * c];
91    }
92  }
93  *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
94  weights_buffer.data = depthwise_conv_weights_data;
95  return true;
96}
97
98}  // namespace toco
99