1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <vector>
16
17#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
18#include "tensorflow/contrib/lite/toco/model.h"
19#include "tensorflow/contrib/lite/toco/tooling_util.h"
20#include "tensorflow/core/platform/logging.h"
21
22namespace toco {
23
24template <ArrayDataType Type>
25bool ComputeFillArray(Model* model, FillOperator* op) {
26  const auto& val_array = model->GetArray(op->inputs[1]);
27  auto& output_array = model->GetArray(op->outputs[0]);
28
29  CHECK(val_array.data_type == Type);
30  CHECK(output_array.data_type == Type);
31
32  // Compute the array data
33  std::vector<DataType<Type>>& data =
34      output_array.GetMutableBuffer<Type>().data;
35  data.resize(RequiredBufferSizeForShape(output_array.shape()));
36  DataType<Type> fill_val = val_array.GetBuffer<Type>().data[0];
37  for (size_t i = 0; i < data.size(); i++) {
38    data[i] = fill_val;
39  }
40
41  return true;
42}
43
44bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
45  const auto fill_it = model->operators.begin() + op_index;
46  auto* base_op = fill_it->get();
47  if (base_op->type != OperatorType::kFill) {
48    return false;
49  }
50  auto* op = static_cast<FillOperator*>(base_op);
51
52  CHECK_EQ(op->inputs.size(), 2);
53  CHECK_EQ(op->outputs.size(), 1);
54
55  auto& output_array = model->GetArray(op->outputs[0]);
56  if (output_array.data_type == ArrayDataType::kNone) {
57    // Yield until the output type has been set by PropagateArrayDataTypes
58    return false;
59  }
60
61  if (!output_array.has_shape()) {
62    // Yield until the output shape has been set by PropagateFixedShapes
63    return false;
64  }
65
66  const auto& val_array = model->GetArray(op->inputs[1]);
67  if (!val_array.has_shape()) {
68    // Yield until the value shape has been resolved.
69    return false;
70  }
71  if (!IsConstantParameterArray(*model, op->inputs[1])) {
72    // Yield until the value is constant.
73    return false;
74  }
75  CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1);
76
77  switch (output_array.data_type) {
78    case ArrayDataType::kFloat:
79      if (!ComputeFillArray<ArrayDataType::kFloat>(model, op)) {
80        return false;
81      }
82      break;
83    case ArrayDataType::kUint8:
84      if (!ComputeFillArray<ArrayDataType::kUint8>(model, op)) {
85        return false;
86      }
87      break;
88    case ArrayDataType::kInt32:
89      if (!ComputeFillArray<ArrayDataType::kInt32>(model, op)) {
90        return false;
91      }
92      break;
93    case ArrayDataType::kInt64:
94      if (!ComputeFillArray<ArrayDataType::kInt64>(model, op)) {
95        return false;
96      }
97      break;
98    default:
99      LOG(FATAL) << "Unsupported data type given to Fill op with output \""
100                 << op->outputs[0] << "\"";
101      break;
102  }
103
104  // Erase input arrays if no longer used
105  if (IsDiscardableArray(*model, op->inputs[0]) &&
106      CountOpsWithInput(*model, op->inputs[0]) == 1) {
107    model->EraseArray(op->inputs[0]);
108  }
109  if (IsDiscardableArray(*model, op->inputs[1]) &&
110      CountOpsWithInput(*model, op->inputs[1]) == 1) {
111    model->EraseArray(op->inputs[1]);
112  }
113
114  // Erase the operator
115  model->operators.erase(fill_it);
116
117  return true;
118}
119
120}  // namespace toco
121