/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define EIGEN_USE_THREADS
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
18e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower#include "tensorflow/core/kernels/concat_lib_cpu.h"
19b481783fe0e00a86f6feb20a8dcad5fc4fc936a4Josh Levenberg#include <vector>
20b481783fe0e00a86f6feb20a8dcad5fc4fc936a4Josh Levenberg#include "tensorflow/core/framework/register_types.h"
21e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower#include "tensorflow/core/kernels/concat_lib.h"
22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
25e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlowernamespace {
26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename T>
27e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlowerstruct MemCpyCopier {
28e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower  inline void Copy(T* dst, const T* src, int input_index, size_t n) {
29e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower    if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
30e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower      memcpy(dst, src, n * sizeof(T));
31e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower    } else {
32e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower      for (size_t k = 0; k < n; ++k) {
33e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower        *dst++ = *src++;
34e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower      }
35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    }
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
37e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower};
38464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlowertemplate <>
39464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlowerstruct MemCpyCopier<ResourceHandle> {
40464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlower  inline void Copy(ResourceHandle* dst, const ResourceHandle* src,
41464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlower                   int input_index, size_t n) {
42464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlower    for (size_t k = 0; k < n; ++k) {
43464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlower      *dst++ = *src++;
44464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlower    }
45464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlower  }
46464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlower};
47464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlower
48e28cc25e80978a6f3b06f5b13f12a1a9c0b2748dA. Unique TensorFlower}  // namespace
49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename T>
510f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowervoid ConcatCPU(
520f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower    DeviceBase* d,
530f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
540f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower        inputs,
550f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower    typename TTypes<T, 2>::Matrix* output) {
569b7c47c1d48dfbe69e2ab62aae6146823ba7e664A. Unique TensorFlower  if (std::is_same<T, string>::value) {
579b7c47c1d48dfbe69e2ab62aae6146823ba7e664A. Unique TensorFlower    // use a large cost here to force strings to be handled by separate threads
589b7c47c1d48dfbe69e2ab62aae6146823ba7e664A. Unique TensorFlower    ConcatCPUImpl<T>(d, inputs, 100000, MemCpyCopier<T>(), output);
599b7c47c1d48dfbe69e2ab62aae6146823ba7e664A. Unique TensorFlower  } else {
609b7c47c1d48dfbe69e2ab62aae6146823ba7e664A. Unique TensorFlower    ConcatCPUImpl<T>(d, inputs, sizeof(T) /* cost_per_unit */,
619b7c47c1d48dfbe69e2ab62aae6146823ba7e664A. Unique TensorFlower                     MemCpyCopier<T>(), output);
629b7c47c1d48dfbe69e2ab62aae6146823ba7e664A. Unique TensorFlower  }
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}
64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define REGISTER(T)                                                            \
66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  template void ConcatCPU<T>(                                                  \
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      DeviceBase*,                                                             \
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&, \
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      typename TTypes<T, 2>::Matrix* output);
70464a7e2315155fc71f98ea5e9ae8ef1bccb835e1A. Unique TensorFlowerTF_CALL_ALL_TYPES(REGISTER)
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurREGISTER(quint8)
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurREGISTER(qint8)
7355819a3fa1cba3de850922b972b29d643fc345eeA. Unique TensorFlowerREGISTER(quint16)
7455819a3fa1cba3de850922b972b29d643fc345eeA. Unique TensorFlowerREGISTER(qint16)
75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath KudlurREGISTER(qint32)
76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
776b753f33c99ae6010bfc7ec6b751a8e0a1bcdfabA. Unique TensorFlower#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
786b753f33c99ae6010bfc7ec6b751a8e0a1bcdfabA. Unique TensorFlower    !defined(__ANDROID_TYPES_FULL__)
79f88cd9195589e41f011c68995d61d760dc2e1a83Jiri Simsa    // Primarily used for SavedModel support on mobile. Registering it here only
80f88cd9195589e41f011c68995d61d760dc2e1a83Jiri Simsa    // if __ANDROID_TYPES_FULL__ is not defined (which already registers string)
81f88cd9195589e41f011c68995d61d760dc2e1a83Jiri Simsa    // to avoid duplicate registration.
82f88cd9195589e41f011c68995d61d760dc2e1a83Jiri Simsa    REGISTER(string);
83ca6b88eb95b089010a1b970e0de7398195b5bccaA. Unique TensorFlower#endif  // defined(IS_MOBILE_PLATFORM) &&
846b753f33c99ae6010bfc7ec6b751a8e0a1bcdfabA. Unique TensorFlower        // !defined(SUPPORT_SELECTIVE_REGISTRATION) &&
856b753f33c99ae6010bfc7ec6b751a8e0a1bcdfabA. Unique TensorFlower        // !defined(__ANDROID_TYPES_FULL__)
86ca6b88eb95b089010a1b970e0de7398195b5bccaA. Unique TensorFlower
873e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#ifdef TENSORFLOW_USE_SYCL
883e975ea978bac4d861bb09328b06f3c316212611Andrew Harptemplate <typename T>
890f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlowervoid ConcatSYCL(
900f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower    const Eigen::SyclDevice& d,
910f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
920f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower        inputs,
930f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower    typename TTypes<T, 2>::Matrix* output) {
943e975ea978bac4d861bb09328b06f3c316212611Andrew Harp  ConcatSYCLImpl<T>(d, inputs, sizeof(T) /* cost_per_unit */, MemCpyCopier<T>(),
950f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower                    output);
963e975ea978bac4d861bb09328b06f3c316212611Andrew Harp}
970f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower#define REGISTER_SYCL(T)                                                       \
980f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower  template void ConcatSYCL<T>(                                                 \
990f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      const Eigen::SyclDevice&,                                                \
1000f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&, \
1010f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower      typename TTypes<T, 2>::Matrix* output);
1023e975ea978bac4d861bb09328b06f3c316212611Andrew Harp
1031b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan HseuTF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL)
1043e975ea978bac4d861bb09328b06f3c316212611Andrew Harp
1053e975ea978bac4d861bb09328b06f3c316212611Andrew Harp#undef REGISTER_SYCL
1060f65c8f572201f8838189f3e3c3e455759112c14A. Unique TensorFlower#endif  // TENSORFLOW_USE_SYCL
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace tensorflow
108