1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
17#define TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
18
19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20#include "tensorflow/core/framework/tensor_types.h"
21
22#include "tensorflow/core/kernels/aggregate_ops.h"
23
24typedef Eigen::ThreadPoolDevice CPUDevice;
25
26#ifdef TENSORFLOW_USE_SYCL
27typedef Eigen::SyclDevice SYCLDevice;
28#endif  // TENSORFLOW_USE_SYCL
29
30namespace tensorflow {
31
// Partial specializations for CPUDevice that use the Eigen implementation
// from AddNEigenImpl.
34namespace functor {
35template <typename T>
36struct Add2Functor<CPUDevice, T> {
37  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
38                  typename TTypes<T>::ConstFlat in1,
39                  typename TTypes<T>::ConstFlat in2) {
40    Add2EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2);
41  }
42};
43template <typename T>
44struct Add3Functor<CPUDevice, T> {
45  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
46                  typename TTypes<T>::ConstFlat in1,
47                  typename TTypes<T>::ConstFlat in2,
48                  typename TTypes<T>::ConstFlat in3) {
49    Add3EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3);
50  }
51};
52template <typename T>
53struct Add4Functor<CPUDevice, T> {
54  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
55                  typename TTypes<T>::ConstFlat in1,
56                  typename TTypes<T>::ConstFlat in2,
57                  typename TTypes<T>::ConstFlat in3,
58                  typename TTypes<T>::ConstFlat in4) {
59    Add4EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4);
60  }
61};
62template <typename T>
63struct Add5Functor<CPUDevice, T> {
64  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
65                  typename TTypes<T>::ConstFlat in1,
66                  typename TTypes<T>::ConstFlat in2,
67                  typename TTypes<T>::ConstFlat in3,
68                  typename TTypes<T>::ConstFlat in4,
69                  typename TTypes<T>::ConstFlat in5) {
70    Add5EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
71  }
72};
73template <typename T>
74struct Add6Functor<CPUDevice, T> {
75  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
76                  typename TTypes<T>::ConstFlat in1,
77                  typename TTypes<T>::ConstFlat in2,
78                  typename TTypes<T>::ConstFlat in3,
79                  typename TTypes<T>::ConstFlat in4,
80                  typename TTypes<T>::ConstFlat in5,
81                  typename TTypes<T>::ConstFlat in6) {
82    Add6EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
83  }
84};
85template <typename T>
86struct Add7Functor<CPUDevice, T> {
87  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
88                  typename TTypes<T>::ConstFlat in1,
89                  typename TTypes<T>::ConstFlat in2,
90                  typename TTypes<T>::ConstFlat in3,
91                  typename TTypes<T>::ConstFlat in4,
92                  typename TTypes<T>::ConstFlat in5,
93                  typename TTypes<T>::ConstFlat in6,
94                  typename TTypes<T>::ConstFlat in7) {
95    Add7EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
96                                         in7);
97  }
98};
99
100template <typename T>
101struct Add8Functor<CPUDevice, T> {
102  void operator()(
103      const CPUDevice& d, typename TTypes<T>::Flat out,
104      typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
105      typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
106      typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
107      typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
108    Add8EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
109                                         in7, in8);
110  }
111};
112
113template <typename T>
114struct Add8pFunctor<CPUDevice, T> {
115  void operator()(
116      const CPUDevice& d, typename TTypes<T>::Flat out,
117      typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
118      typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
119      typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
120      typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
121    Add8pEigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
122                                          in7, in8);
123  }
124};
125
126template <typename T>
127struct Add9Functor<CPUDevice, T> {
128  void operator()(
129      const CPUDevice& d, typename TTypes<T>::Flat out,
130      typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
131      typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
132      typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
133      typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
134      typename TTypes<T>::ConstFlat in9) {
135    Add9EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
136                                         in7, in8, in9);
137  }
138};
139
140#ifdef TENSORFLOW_USE_SYCL
// Partial specializations for SYCLDevice that use the Eigen implementation
// from AddNEigenImpl.
143template <typename T>
144struct Add2Functor<SYCLDevice, T> {
145  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
146                  typename TTypes<T>::ConstFlat in1,
147                  typename TTypes<T>::ConstFlat in2) {
148    Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2);
149  }
150};
151template <typename T>
152struct Add3Functor<SYCLDevice, T> {
153  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
154                  typename TTypes<T>::ConstFlat in1,
155                  typename TTypes<T>::ConstFlat in2,
156                  typename TTypes<T>::ConstFlat in3) {
157    Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3);
158  }
159};
160template <typename T>
161struct Add4Functor<SYCLDevice, T> {
162  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
163                  typename TTypes<T>::ConstFlat in1,
164                  typename TTypes<T>::ConstFlat in2,
165                  typename TTypes<T>::ConstFlat in3,
166                  typename TTypes<T>::ConstFlat in4) {
167    Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4);
168  }
169};
170template <typename T>
171struct Add5Functor<SYCLDevice, T> {
172  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
173                  typename TTypes<T>::ConstFlat in1,
174                  typename TTypes<T>::ConstFlat in2,
175                  typename TTypes<T>::ConstFlat in3,
176                  typename TTypes<T>::ConstFlat in4,
177                  typename TTypes<T>::ConstFlat in5) {
178    Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
179  }
180};
181template <typename T>
182struct Add6Functor<SYCLDevice, T> {
183  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
184                  typename TTypes<T>::ConstFlat in1,
185                  typename TTypes<T>::ConstFlat in2,
186                  typename TTypes<T>::ConstFlat in3,
187                  typename TTypes<T>::ConstFlat in4,
188                  typename TTypes<T>::ConstFlat in5,
189                  typename TTypes<T>::ConstFlat in6) {
190    Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
191  }
192};
193template <typename T>
194struct Add7Functor<SYCLDevice, T> {
195  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
196                  typename TTypes<T>::ConstFlat in1,
197                  typename TTypes<T>::ConstFlat in2,
198                  typename TTypes<T>::ConstFlat in3,
199                  typename TTypes<T>::ConstFlat in4,
200                  typename TTypes<T>::ConstFlat in5,
201                  typename TTypes<T>::ConstFlat in6,
202                  typename TTypes<T>::ConstFlat in7) {
203    Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
204                                          in7);
205  }
206};
207
208template <typename T>
209struct Add8Functor<SYCLDevice, T> {
210  void operator()(
211      const SYCLDevice& d, typename TTypes<T>::Flat out,
212      typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
213      typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
214      typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
215      typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
216    Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
217                                          in7, in8);
218  }
219};
220
221template <typename T>
222struct Add8pFunctor<SYCLDevice, T> {
223  void operator()(
224      const SYCLDevice& d, typename TTypes<T>::Flat out,
225      typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
226      typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
227      typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
228      typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
229    Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
230                                           in7, in8);
231  }
232};
233
234template <typename T>
235struct Add9Functor<SYCLDevice, T> {
236  void operator()(
237      const SYCLDevice& d, typename TTypes<T>::Flat out,
238      typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
239      typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
240      typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
241      typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
242      typename TTypes<T>::ConstFlat in9) {
243    Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
244                                          in7, in8, in9);
245  }
246};
247#endif  // TENSORFLOW_USE_SYCL
248
249}  // namespace functor
250
251}  // namespace tensorflow
252
253#endif  // TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
254