129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFloweryou may not use this file except in compliance with the License.
529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerYou may obtain a copy of the License at
629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
1029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
1129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerSee the License for the specific language governing permissions and
1329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerlimitations under the License.
1429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower==============================================================================*/
#include <algorithm>
#include <cstddef>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
2129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
2229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowernamespace toco {
2329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
2429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowernamespace {
2529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
2629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowertemplate <ArrayDataType Type>
2729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowervoid Stack(Model* model, StackOperator const& op) {
2829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  auto& output_array = model->GetArray(op.outputs[0]);
2929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  CHECK(output_array.data_type == Type);
3029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
3129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  // Create a buffer for the output array
3229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  std::vector<DataType<Type>>& output_data =
3329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      output_array.GetMutableBuffer<Type>().data;
3429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  output_data.resize(RequiredBufferSizeForShape(output_array.shape()));
3529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
3629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  // Stack inputs into buffer
3729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  CHECK_EQ(op.axis, 0) << "Stacking only supported along first axis";
3829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  int dst_offset = 0;
3929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  for (int i = 0; i < op.inputs.size(); i++) {
4029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    // Append array data to output for each input array
4129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    const auto& input_array = model->GetArray(op.inputs[i]);
4229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    int input_size = RequiredBufferSizeForShape(input_array.shape());
4329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    memcpy(&output_data[dst_offset], &input_array.GetBuffer<Type>().data[0],
4429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower           input_size * sizeof(Type));
4529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    dst_offset += input_size;
4629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  }
4729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  CHECK_EQ(dst_offset, output_data.size());
4829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower}
4929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
5029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower}  // namespace
5129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
5229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerbool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
5329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  auto it = model->operators.begin() + op_index;
5429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  const auto* base_op = it->get();
5529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  if (base_op->type != OperatorType::kStack) {
5629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    return false;
5729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  }
5829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  const auto* op = static_cast<const StackOperator*>(base_op);
5929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
6029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  CHECK_GE(op->inputs.size(), 1);
6129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  CHECK_EQ(op->outputs.size(), 1);
6229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  auto& output_array = model->GetArray(op->outputs[0]);
6329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  if (output_array.data_type == ArrayDataType::kNone) {
6429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    // Yield until the output type has been set by PropagateArrayDataTypes
6529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    return false;
6629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  }
6729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
6829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  if (!output_array.has_shape()) {
6929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    // Yield until the output shape has been set by PropagateFixedShapes
7029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    return false;
7129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  }
7229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
7329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  for (const auto& input : op->inputs) {
7429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    if (!IsConstantParameterArray(*model, input)) {
7529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      // Yield if any input is mutable
7629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      return false;
7729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    }
7829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  }
7929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
8029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  CHECK(!output_array.buffer);
8129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  switch (output_array.data_type) {
8229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    case ArrayDataType::kFloat:
8329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      Stack<ArrayDataType::kFloat>(model, *op);
8429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      break;
8529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    case ArrayDataType::kUint8:
8629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      Stack<ArrayDataType::kUint8>(model, *op);
8729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      break;
8829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    case ArrayDataType::kInt32:
8929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      Stack<ArrayDataType::kInt32>(model, *op);
9029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      break;
9129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    case ArrayDataType::kInt64:
9229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      Stack<ArrayDataType::kInt64>(model, *op);
9329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      break;
9429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    default:
9529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      LOG(FATAL) << "Unsupported data type given to Stack op with output \""
9629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower                 << op->outputs[0] << "\"";
9729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower      break;
9829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  }
9929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
10029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  // Erase input arrays if no longer used
10129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  for (const auto& input : op->inputs) {
10229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    if (IsDiscardableArray(*model, input) &&
10329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower        CountOpsWithInput(*model, input) == 1) {
104ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower      model->EraseArray(input);
10529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower    }
10629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  }
10729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
10829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  // Erase the operator
10929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  model->operators.erase(it);
11029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower  return true;
11129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower}
11229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower
11329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower}  // namespace toco
114