1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <iostream>
16#include <string>
17#include <vector>
18
19#include "tensorflow/contrib/lite/toco/model.h"
20#include "tensorflow/contrib/lite/toco/tooling_util.h"
21
22namespace toco {
23
24// For consistency with the parameters defined in extended LstmCell's kernel
25// (tensorflow/contrib/lite/kernels/lstm.cc),
26// use lowercase for these constants.
27
28enum ExtendedLstmCellInputs {
29  kInputTensor = 0,
30  kInputToInputWeightsTensor = 1,  // Optional
31  kInputToForgetWeightsTensor = 2,
32  kInputToCellWeightsTensor = 3,
33  kInputToOutputWeightsTensor = 4,
34  kRecurrentToInputWeightsTensor = 5,  // Optional
35  kRecurrentToForgetWeightsTensor = 6,
36  kRecurrentToCellWeightsTensor = 7,
37  kRecurrentToOutputWeightsTensor = 8,
38  kCellToInputWeightsTensor = 9,    // Optional
39  kCellToForgetWeightsTensor = 10,  // Optional
40  kCellToOutputWeightsTensor = 11,  // Optional
41  kInputGateBiasTensor = 12,        // Optional
42  kForgetGateBiasTensor = 13,
43  kCellGateBiasTensor = 14,
44  kOutputGateBiasTensor = 15,
45  kProjectionWeightsTensor = 16,  // Optional
46  kProjectionBiasTensor = 17,     // Optional
47  kExtendedLstmInputCount = 18
48};
49
50enum ExtendedLstmCellOutputs {
51  kScratchBufferTensor = 0,
52  kOutputStateTensor = 1,
53  kCellStateTensor = 2,
54  kOutputTensor = 3
55};
56
57// Create optional array used for optional tensor in ExtendedLstmCell inputs.
58void CreateOptionalArray(Model* model, string* input_array_buffer,
59                         const string& array_name);
60
61// Create float array and get its buffer.
62Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
63                                                      string* array_name,
64                                                      const Shape& shape);
65
66// Copy data from one array to the other one (supports 1D and 2D array),
67// for 1D array, the 2nd dim's size is 1.
68// Arguments:
69//   src_buffer: the source buffer
70//   src_stride: the stride of source buffer, i.e., 2nd dim's size
71//   src_start_idx1: the 1st dim index of start point in src matrix
72//   src_start_idx2: the 2nd dim index of start point in src matrix
73//   dst_buffer: the destination buffer
74//   dst_stride: the stride of destination buffer, i.e., 2nd dim's size
75//   dst_start_idx1: the 1st dim index of start point in dst matrix
76//   dst_start_idx2: the 2nd dim index of start point in dst matrix
77//   dim1_copy_size: 1st dim size of copy data
78//   dim2_copy_size: 2nd dim size of copy data
79void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
80                   int src_stride, int src_start_idx1, int src_start_idx2,
81                   Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride,
82                   int dst_start_idx1, int dst_start_idx2, int dim1_copy_size,
83                   int dim2_copy_size);
84
85// Copy a subset of array data and create a smaller array,
86// mostly used for spliting weights and bias for Lstm cell.
87void CopySubArrayToArray(Model* model, string* array_name,
88                         const string& tensor_name, int dim1_size,
89                         int dim2_size, const Array& original_array,
90                         int start_idx1, int start_idx2);
91
92// Copy array data to a large array's submatrix,
93// mostly used for merging weights and bias for Lstm cell.
94void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
95                         int tensor_stride, const Array& sub_array,
96                         int start_idx1, int start_idx2);
97
98// Get mating rnn array inputs using rnn_states flag.
99bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array,
100                         string* rnn_array);
101
102}  // namespace toco
103