1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_ 17#define TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_ 18 19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20#include "tensorflow/core/framework/tensor_types.h" 21 22#include "tensorflow/core/kernels/aggregate_ops.h" 23 24typedef Eigen::ThreadPoolDevice CPUDevice; 25 26#ifdef TENSORFLOW_USE_SYCL 27typedef Eigen::SyclDevice SYCLDevice; 28#endif // TENSORFLOW_USE_SYCL 29 30namespace tensorflow { 31 32// Partial specializations for a CPUDevice, that uses the Eigen implementation 33// from AddNEigenImpl. 34namespace functor { 35template <typename T> 36struct Add2Functor<CPUDevice, T> { 37 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 38 typename TTypes<T>::ConstFlat in1, 39 typename TTypes<T>::ConstFlat in2) { 40 Add2EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2); 41 } 42}; 43template <typename T> 44struct Add3Functor<CPUDevice, T> { 45 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 46 typename TTypes<T>::ConstFlat in1, 47 typename TTypes<T>::ConstFlat in2, 48 typename TTypes<T>::ConstFlat in3) { 49 Add3EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3); 50 } 51}; 52template <typename T> 53struct Add4Functor<CPUDevice, T> { 54 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 55 typename TTypes<T>::ConstFlat in1, 56 typename TTypes<T>::ConstFlat in2, 57 typename TTypes<T>::ConstFlat in3, 58 typename TTypes<T>::ConstFlat in4) { 59 Add4EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4); 60 } 61}; 62template <typename T> 63struct Add5Functor<CPUDevice, T> { 64 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 65 typename TTypes<T>::ConstFlat in1, 66 typename TTypes<T>::ConstFlat in2, 67 typename TTypes<T>::ConstFlat in3, 68 typename TTypes<T>::ConstFlat in4, 69 typename TTypes<T>::ConstFlat in5) { 70 Add5EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5); 71 } 72}; 73template <typename T> 74struct Add6Functor<CPUDevice, T> { 75 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 76 typename TTypes<T>::ConstFlat in1, 77 typename TTypes<T>::ConstFlat in2, 78 typename TTypes<T>::ConstFlat in3, 79 typename TTypes<T>::ConstFlat in4, 80 typename TTypes<T>::ConstFlat in5, 81 typename TTypes<T>::ConstFlat in6) { 82 Add6EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6); 83 } 84}; 85template <typename T> 86struct Add7Functor<CPUDevice, T> { 87 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 88 typename TTypes<T>::ConstFlat in1, 89 typename TTypes<T>::ConstFlat in2, 90 typename TTypes<T>::ConstFlat in3, 91 typename TTypes<T>::ConstFlat in4, 92 typename TTypes<T>::ConstFlat in5, 93 typename TTypes<T>::ConstFlat in6, 94 typename TTypes<T>::ConstFlat in7) { 95 Add7EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 96 in7); 97 } 98}; 99 100template <typename T> 101struct Add8Functor<CPUDevice, T> { 102 void operator()( 103 const CPUDevice& d, typename TTypes<T>::Flat out, 104 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 105 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 106 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 107 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 108 Add8EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 109 in7, in8); 110 } 111}; 112 113template <typename T> 114struct Add8pFunctor<CPUDevice, T> { 115 void operator()( 116 const CPUDevice& d, typename TTypes<T>::Flat out, 117 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 118 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 119 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 120 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 121 Add8pEigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 122 in7, in8); 123 } 124}; 125 126template <typename T> 127struct Add9Functor<CPUDevice, T> { 128 void operator()( 129 const CPUDevice& d, typename TTypes<T>::Flat out, 130 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 131 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 132 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 133 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, 134 typename TTypes<T>::ConstFlat in9) { 135 Add9EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 136 in7, in8, in9); 137 } 138}; 139 140#ifdef TENSORFLOW_USE_SYCL 141// Partial specializations for a SYCLDevice, that uses the Eigen implementation 142// from AddNEigenImpl. 143template <typename T> 144struct Add2Functor<SYCLDevice, T> { 145 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 146 typename TTypes<T>::ConstFlat in1, 147 typename TTypes<T>::ConstFlat in2) { 148 Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2); 149 } 150}; 151template <typename T> 152struct Add3Functor<SYCLDevice, T> { 153 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 154 typename TTypes<T>::ConstFlat in1, 155 typename TTypes<T>::ConstFlat in2, 156 typename TTypes<T>::ConstFlat in3) { 157 Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3); 158 } 159}; 160template <typename T> 161struct Add4Functor<SYCLDevice, T> { 162 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 163 typename TTypes<T>::ConstFlat in1, 164 typename TTypes<T>::ConstFlat in2, 165 typename TTypes<T>::ConstFlat in3, 166 typename TTypes<T>::ConstFlat in4) { 167 Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4); 168 } 169}; 170template <typename T> 171struct Add5Functor<SYCLDevice, T> { 172 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 173 typename TTypes<T>::ConstFlat in1, 174 typename TTypes<T>::ConstFlat in2, 175 typename TTypes<T>::ConstFlat in3, 176 typename TTypes<T>::ConstFlat in4, 177 typename TTypes<T>::ConstFlat in5) { 178 Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5); 179 } 180}; 181template <typename T> 182struct Add6Functor<SYCLDevice, T> { 183 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 184 typename TTypes<T>::ConstFlat in1, 185 typename TTypes<T>::ConstFlat in2, 186 typename TTypes<T>::ConstFlat in3, 187 typename TTypes<T>::ConstFlat in4, 188 typename TTypes<T>::ConstFlat in5, 189 typename TTypes<T>::ConstFlat in6) { 190 Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6); 191 } 192}; 193template <typename T> 194struct Add7Functor<SYCLDevice, T> { 195 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 196 typename TTypes<T>::ConstFlat in1, 197 typename TTypes<T>::ConstFlat in2, 198 typename TTypes<T>::ConstFlat in3, 199 typename TTypes<T>::ConstFlat in4, 200 typename TTypes<T>::ConstFlat in5, 201 typename TTypes<T>::ConstFlat in6, 202 typename TTypes<T>::ConstFlat in7) { 203 Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 204 in7); 205 } 206}; 207 208template <typename T> 209struct Add8Functor<SYCLDevice, T> { 210 void operator()( 211 const SYCLDevice& d, typename TTypes<T>::Flat out, 212 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 213 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 214 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 215 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 216 Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 217 in7, in8); 218 } 219}; 220 221template <typename T> 222struct Add8pFunctor<SYCLDevice, T> { 223 void operator()( 224 const SYCLDevice& d, typename TTypes<T>::Flat out, 225 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 226 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 227 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 228 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 229 Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 230 in7, in8); 231 } 232}; 233 234template <typename T> 235struct Add9Functor<SYCLDevice, T> { 236 void operator()( 237 const SYCLDevice& d, typename TTypes<T>::Flat out, 238 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 239 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 240 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 241 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, 242 typename TTypes<T>::ConstFlat in9) { 243 Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 244 in7, in8, in9); 245 } 246}; 247#endif // TENSORFLOW_USE_SYCL 248 249} // namespace functor 250 251} // namespace tensorflow 252 253#endif // TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_ 254