129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFloweryou may not use this file except in compliance with the License. 529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerYou may obtain a copy of the License at 629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 1029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 1129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerSee the License for the specific language governing permissions and 1329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerlimitations under the License. 1429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower==============================================================================*/ 1529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower#include <vector> 1629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 1729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 1829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/model.h" 1929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower#include "tensorflow/contrib/lite/toco/tooling_util.h" 2029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower#include "tensorflow/core/platform/logging.h" 2129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 2229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowernamespace toco { 2329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 2429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowernamespace { 2529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 2629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowertemplate <ArrayDataType Type> 2729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowervoid Stack(Model* model, StackOperator const& op) { 2829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower auto& output_array = model->GetArray(op.outputs[0]); 2929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower CHECK(output_array.data_type == Type); 3029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 3129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower // Create a buffer for the output array 3229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower std::vector<DataType<Type>>& output_data = 3329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower output_array.GetMutableBuffer<Type>().data; 3429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower output_data.resize(RequiredBufferSizeForShape(output_array.shape())); 3529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 3629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower // Stack inputs into buffer 3729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower CHECK_EQ(op.axis, 0) << "Stacking only supported along first axis"; 3829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower int dst_offset = 0; 3929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower for (int i = 0; i < op.inputs.size(); i++) { 4029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower // Append array data to output for each input array 4129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower const auto& input_array = model->GetArray(op.inputs[i]); 4229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower int input_size = RequiredBufferSizeForShape(input_array.shape()); 4329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower memcpy(&output_data[dst_offset], &input_array.GetBuffer<Type>().data[0], 4429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower input_size * sizeof(Type)); 4529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower dst_offset += input_size; 4629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 4729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower CHECK_EQ(dst_offset, output_data.size()); 4829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower} 4929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 5029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower} // namespace 5129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 5229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlowerbool ResolveConstantStack::Run(Model* model, std::size_t op_index) { 5329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower auto it = model->operators.begin() + op_index; 5429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower const auto* base_op = it->get(); 5529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower if (base_op->type != OperatorType::kStack) { 5629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower return false; 5729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 5829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower const auto* op = static_cast<const StackOperator*>(base_op); 5929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 6029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower CHECK_GE(op->inputs.size(), 1); 6129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower CHECK_EQ(op->outputs.size(), 1); 6229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower auto& output_array = model->GetArray(op->outputs[0]); 6329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower if (output_array.data_type == ArrayDataType::kNone) { 6429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower // Yield until the output type has been set by PropagateArrayDataTypes 6529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower return false; 6629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 6729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 6829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower if (!output_array.has_shape()) { 6929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower // Yield until the output shape has been set by PropagateFixedShapes 7029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower return false; 7129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 7229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 7329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower for (const auto& input : op->inputs) { 7429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower if (!IsConstantParameterArray(*model, input)) { 7529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower // Yield if any input is mutable 7629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower return false; 7729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 7829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 7929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 8029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower CHECK(!output_array.buffer); 8129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower switch (output_array.data_type) { 8229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower case ArrayDataType::kFloat: 8329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower Stack<ArrayDataType::kFloat>(model, *op); 8429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower break; 8529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower case ArrayDataType::kUint8: 8629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower Stack<ArrayDataType::kUint8>(model, *op); 8729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower break; 8829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower case ArrayDataType::kInt32: 8929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower Stack<ArrayDataType::kInt32>(model, *op); 9029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower break; 9129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower case ArrayDataType::kInt64: 9229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower Stack<ArrayDataType::kInt64>(model, *op); 9329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower break; 9429baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower default: 9529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower LOG(FATAL) << "Unsupported data type given to Stack op with output \"" 9629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower << op->outputs[0] << "\""; 9729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower break; 9829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 9929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 10029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower // Erase input arrays if no longer used 10129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower for (const auto& input : op->inputs) { 10229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower if (IsDiscardableArray(*model, input) && 10329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower CountOpsWithInput(*model, input) == 1) { 104ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(input); 10529baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 10629baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower } 10729baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 10829baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower // Erase the operator 10929baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower model->operators.erase(it); 11029baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower return true; 11129baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower} 11229baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower 11329baea36e3b374a852ad3dedc1c3719016febdc4A. Unique TensorFlower} // namespace toco 114