/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto conv_it = model->operators.begin() + op_index;
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (conv_it->get()->type != OperatorType::kConv) {
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (conv_op->stride_width != conv_op->stride_height) {
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& weights_array = model->GetArray(conv_op->inputs[1]);
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!weights_array.buffer) {
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    // Yield until the weights are resolved as a constant array.
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (weights_array.data_type != ArrayDataType::kFloat) {
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (weights_array.shape().dims(3) != 1) {
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    // Not a pure convolution: Conv does accumulation across the depth
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    // dimension.
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  AddMessageF(
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      "%s is purely convolutional (input/weights depth is 1), replacing it by "
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      "a DepthwiseConv.",
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      LogName(*conv_op));
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* depthwiseconv_op = new DepthwiseConvOperator;
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Conv and DepthwiseConv take the same inputs
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  depthwiseconv_op->inputs = conv_op->inputs;
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Conv may have a 2nd output for im2col
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  depthwiseconv_op->outputs = {conv_op->outputs[0]};
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (conv_op->outputs.size() > 1) {
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    // delete the im2col array.
61ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower    model->EraseArray(conv_op->outputs[1]);
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  depthwiseconv_op->fused_activation_function =
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      conv_op->fused_activation_function;
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Let PropagateFixedSizes recompute fixed padding, just in case some day it
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // may be different for Conv vs DepthwiseConv.
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  depthwiseconv_op->padding.type = conv_op->padding.type;
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  depthwiseconv_op->stride_height = conv_op->stride_height;
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  depthwiseconv_op->stride_width = conv_op->stride_width;
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Replace the operator in the graph.
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto depthwiseconv_it =
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      model->operators.emplace(conv_it, depthwiseconv_op);
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  conv_it = depthwiseconv_it + 1;
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK_EQ(conv_it->get(), conv_op);
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(conv_it);
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Shuffle the weights.
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& weights_shape = weights_array.shape();
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& weights_buffer =
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      weights_array.GetMutableBuffer<ArrayDataType::kFloat>();
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const std::vector<float>& conv_weights_data = weights_buffer.data;
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  std::vector<float> depthwise_conv_weights_data(conv_weights_data.size());
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int depth = weights_shape.dims(0);
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int width = weights_shape.dims(1);
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int height = weights_shape.dims(2);
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int width_height = width * height;
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (int c = 0; c < depth; c++) {
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    for (int xy = 0; xy < width_height; xy++) {
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      depthwise_conv_weights_data[c + depth * xy] =
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle          conv_weights_data[xy + width_height * c];
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  weights_buffer.data = depthwise_conv_weights_data;
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
99