1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
16
17namespace toco {
18
19void CreateOptionalArray(Model* model, string* input_array_buffer,
20                         const string& array_name) {
21  *input_array_buffer = array_name;
22  model->CreateOptionalArray(array_name);
23}
24
25void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
26                   int src_stride, int src_start_idx1, int src_start_idx2,
27                   Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride,
28                   int dst_start_idx1, int dst_start_idx2, int dim1_copy_size,
29                   int dim2_copy_size) {
30  int src_offset = src_start_idx1 * src_stride + src_start_idx2;
31  int dst_offset = dst_start_idx1 * dst_stride + dst_start_idx2;
32  for (int i = 0; i < dim1_copy_size; i++) {
33    for (int j = 0; j < dim2_copy_size; j++) {
34      int idx_src = src_offset + i * src_stride + j;
35      int idx_dst = dst_offset + i * dst_stride + j;
36      dst_buffer->data[idx_dst] = src_buffer.data[idx_src];
37    }
38  }
39}
40
41Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
42                                                      string* array_name,
43                                                      const Shape& shape) {
44  *array_name = AvailableArrayName(*model, *array_name);
45  auto& array = model->GetOrCreateArray(*array_name);
46  array.data_type = ArrayDataType::kFloat;
47  array.copy_shape(shape);
48  Buffer<ArrayDataType::kFloat>* buffer =
49      &(array.GetMutableBuffer<ArrayDataType::kFloat>());
50  buffer->data.resize(RequiredBufferSizeForShape(shape));
51  return buffer;
52}
53
54void CopySubArrayToArray(Model* model, string* array_name,
55                         const string& tensor_name, int dim1_size,
56                         int dim2_size, const Array& original_array,
57                         int start_idx1, int start_idx2) {
58  // Determine whether it's bias or not, create shape, buffer.
59  bool is_bias = dim2_size == 1;
60  Shape shape = is_bias ? Shape({dim1_size}) : Shape({dim1_size, dim2_size});
61  Buffer<ArrayDataType::kFloat>* buffer =
62      CreateFloatArrayBuffer(model, array_name, shape);
63  auto& orig_buffer = original_array.GetBuffer<ArrayDataType::kFloat>();
64
65  // Copy data from big tensor.
66  CopyArrayData(orig_buffer, is_bias ? 1 : original_array.shape().dims(1),
67                start_idx1, start_idx2, buffer, dim2_size, 0, 0, dim1_size,
68                dim2_size);
69}
70
71void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
72                         int tensor_stride, const Array& sub_array,
73                         int start_idx1, int start_idx2) {
74  // Get tensor data.
75  bool is_bias = sub_array.shape().dims().size() == 1;
76  int dim1_copy_size = sub_array.shape().dims()[0];
77  int dim2_copy_size = is_bias ? 1 : sub_array.shape().dims(1);
78  auto& sub_buffer = sub_array.GetBuffer<ArrayDataType::kFloat>();
79
80  // Copy data from sub tensor.
81  CopyArrayData(sub_buffer, dim2_copy_size, 0, 0, &tensor_buffer,
82                is_bias ? 1 : tensor_stride, start_idx1, start_idx2,
83                dim1_copy_size, dim2_copy_size);
84}
85
86bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array,
87                         string* rnn_array) {
88  for (const auto& rnn_state : model->flags.rnn_states()) {
89    if (rnn_state.back_edge_source_array() == back_edge_source_array) {
90      *rnn_array = rnn_state.state_array();
91      return true;
92    }
93  }
94  return false;
95}
96
97}  // namespace toco
98