1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
16#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
17
18#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
19
20namespace tflite {
21
22enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
23
// Shape and memory layout of a Rank-dimensional array.
//
// An element at coordinates (i_0, ..., i_{Rank-1}) lives at flat offset
// sum_k i_k * strides[k], with each i_k in [0, sizes[k]). When the array is
// densely packed, strides[0] == 1, so dimension 0 is the fastest-varying one.
template <int Rank>
struct Dims {
  // Extent of each dimension.
  int sizes[Rank];
  // Offset contribution of one step along each dimension.
  int strides[Rank];
};
29
30// Gets next index to iterate through a multidimensional array.
31inline bool NextIndex(const int num_dims, const int* dims, int* current) {
32  TFLITE_DCHECK_GT(num_dims, 0);
33  TFLITE_DCHECK(dims != nullptr);
34  TFLITE_DCHECK(current != nullptr);
35  int carry = 1;
36  for (int idx = num_dims - 1; idx >= 0; --idx) {
37    int current_val = current[idx] + carry;
38    TFLITE_DCHECK_GE(dims[idx], current_val);
39    if (dims[idx] == current_val) {
40      current[idx] = 0;
41    } else {
42      current[idx] = current_val;
43      carry = 0;
44      break;
45    }
46  }
47  return (carry == 0);
48}
49
50// Gets offset of index if reducing on axis. When reducing, the flattened offset
51// will not change, if the input index changes on the given axis. For example,
52// if you have a 3D tensor and you are reducing to 2D by eliminating axis 0,
53// then index (0, 1, 2) and index (1, 1, 2) will map to the same flattened
54// offset.
55// TODO(kanlig): uses Dims to represent dimensions.
56inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
57                                  const int* index, const int num_axis,
58                                  const int* axis) {
59  TFLITE_DCHECK_GT(num_dims, 0);
60  TFLITE_DCHECK(dims != nullptr);
61  TFLITE_DCHECK(index != nullptr);
62  size_t offset = 0;
63  for (int idx = 0; idx < num_dims; ++idx) {
64    // if we need to skip this axis
65    bool is_axis = false;
66    if (axis != nullptr) {
67      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
68        if (idx == axis[axis_idx]) {
69          is_axis = true;
70          break;
71        }
72      }
73    }
74    if (!is_axis) {
75      offset = offset * static_cast<size_t>(dims[idx]) +
76               static_cast<size_t>(index[idx]);
77    }
78  }
79  return offset;
80}
81
82inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
83  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
84  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
85  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
86  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
87  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
88         i3 * dims.strides[3];
89}
90
91inline int Offset(const Dims<4>& dims, int* index) {
92  return Offset(dims, index[0], index[1], index[2], index[3]);
93}
94
95// Get array size, DCHECKing that the dim index is in range.
96template <int N>
97int ArraySize(const Dims<N>& array, int index) {
98  TFLITE_DCHECK(index >= 0 && index < N);
99  return array.sizes[index];
100}
101
// Returns the extent shared by dimension `index1` of `array1` and dimension
// `index2` of `array2`, DCHECKing that the two extents are equal.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  const int common = ArraySize(array1, index1);
  TFLITE_DCHECK_EQ(common, ArraySize(array2, index2));
  return common;
}
109
// Returns the extent shared by three or more (array, dimension-index) pairs,
// DCHECKing that they all agree.
//
// Every later pair is compared against the first (array1, index1) pair, one
// pair per recursion step, until the two-pair overload ends the recursion.
// The trailing pairs are taken by const reference so that Dims structs are
// not copied again at every recursion level.
template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2,
                      const Args&... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}
116
117inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
118  int max_offset = 0;
119  for (int i = 0; i < 4; i++) {
120    max_offset += (dims.sizes[i] - 1) * dims.strides[i];
121  }
122  return max_offset + 1;
123}
124
125template <int N>
126bool IsPackedWithoutStrides(const Dims<N>& dims) {
127  int expected_stride = 1;
128  for (int d = 0; d < N; d++) {
129    if (dims.strides[d] != expected_stride) return false;
130    expected_stride *= dims.sizes[d];
131  }
132  return true;
133}
134
135}  // namespace tflite
136
137#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
138