1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_CAST_OP_H_
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_CAST_OP_H_
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
#include <cstdint>
#include <cstring>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/types.h"
26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
28725206e677a9f1e343319293a347862335ff776bPeter Hawkins
29725206e677a9f1e343319293a347862335ff776bPeter Hawkins// Common base class of Cast kernels
30725206e677a9f1e343319293a347862335ff776bPeter Hawkinsclass CastOpBase : public OpKernel {
31725206e677a9f1e343319293a347862335ff776bPeter Hawkins public:
32725206e677a9f1e343319293a347862335ff776bPeter Hawkins  explicit CastOpBase(OpKernelConstruction* ctx);
33725206e677a9f1e343319293a347862335ff776bPeter Hawkins
34725206e677a9f1e343319293a347862335ff776bPeter Hawkins  void Compute(OpKernelContext* ctx) override;
35725206e677a9f1e343319293a347862335ff776bPeter Hawkins
36725206e677a9f1e343319293a347862335ff776bPeter Hawkins protected:
37725206e677a9f1e343319293a347862335ff776bPeter Hawkins  DataType src_dtype_;
38725206e677a9f1e343319293a347862335ff776bPeter Hawkins  DataType dst_dtype_;
39725206e677a9f1e343319293a347862335ff776bPeter Hawkins  std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
40725206e677a9f1e343319293a347862335ff776bPeter Hawkins
41725206e677a9f1e343319293a347862335ff776bPeter Hawkins  Status Unimplemented();
42725206e677a9f1e343319293a347862335ff776bPeter Hawkins
43725206e677a9f1e343319293a347862335ff776bPeter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
44725206e677a9f1e343319293a347862335ff776bPeter Hawkins};
45725206e677a9f1e343319293a347862335ff776bPeter Hawkins
46725206e677a9f1e343319293a347862335ff776bPeter Hawkins// CPU implementation of Cast
47725206e677a9f1e343319293a347862335ff776bPeter Hawkinsclass CpuCastOp : public CastOpBase {
48725206e677a9f1e343319293a347862335ff776bPeter Hawkins public:
49725206e677a9f1e343319293a347862335ff776bPeter Hawkins  explicit CpuCastOp(OpKernelConstruction* ctx);
50725206e677a9f1e343319293a347862335ff776bPeter Hawkins
51725206e677a9f1e343319293a347862335ff776bPeter Hawkins private:
52725206e677a9f1e343319293a347862335ff776bPeter Hawkins  Status Prepare();
53725206e677a9f1e343319293a347862335ff776bPeter Hawkins};
54725206e677a9f1e343319293a347862335ff776bPeter Hawkins
55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace functor {
56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename Device, typename Tout, typename Tin>
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurvoid Cast(const Device& d, typename TTypes<Tout>::Flat o,
59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur          typename TTypes<Tin>::ConstFlat i) {
60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  o.device(d) = i.template cast<Tout>();
61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}
62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename Device, typename Tout, typename Tin>
64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurstruct CastFunctor {
65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void operator()(const Device& d, typename TTypes<Tout>::Flat o,
66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                  typename TTypes<Tin>::ConstFlat i);
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // end namespace functor
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // end namespace tensorflow
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace Eigen {
73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace internal {
74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
75a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner// Eigen can't convert to/from complex numbers, because it is limited to cases
76a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner// that can be static_casted. But numpy is able to cast to/from complex, which
77a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner// we want to replicate. So we add specializations for complex here.
787ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To>
79a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct scalar_cast_op<std::complex<From>, To> {
807ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To
817ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner  operator()(const std::complex<From>& a) const {
821b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu    // Replicate numpy behavior of returning just the real part
83a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner    return static_cast<To>(a.real());
84a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner  }
85a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner};
86a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner
877ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To>
88a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct scalar_cast_op<From, std::complex<To>> {
897ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
907ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner      const From& a) const {
911b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu    // Replicate numpy behavior of setting the imaginary part to 0
92a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner    return std::complex<To>(static_cast<To>(a), To(0));
93a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner  }
94a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner};
95a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner
967ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To>
97a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct scalar_cast_op<std::complex<From>, std::complex<To>> {
987ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
997ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner      const std::complex<From>& a) const {
100a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner    return std::complex<To>(static_cast<To>(a.real()),
101a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner                            static_cast<To>(a.imag()));
102a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner  }
103a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner};
104a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner
1057ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To>
106a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct functor_traits_complex_impl {
107a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner  enum { Cost = NumTraits<To>::AddCost, PacketAccess = false };
108a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner};
109a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner
1107ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To>
111a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct functor_traits<scalar_cast_op<std::complex<From>, To>>
1127ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner    : functor_traits_complex_impl<std::complex<From>, To> {};
1137ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To>
114a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct functor_traits<scalar_cast_op<From, std::complex<To>>>
1157ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner    : functor_traits_complex_impl<From, std::complex<To>> {};
116a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner// Needed to avoid ambiguous partial specialization
1177ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To>
118a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
1197ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner    : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
120a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner
121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Specialized cast op impls for bfloat16.
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <>
1237ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinerstruct scalar_cast_op<::tensorflow::bfloat16, float> {
124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  typedef float result_type;
126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      const ::tensorflow::bfloat16& a) const {
128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    float ret;
129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    uint16_t* p = reinterpret_cast<uint16_t*>(&ret);
130769496bb0d3093a5531268635d364950c93327d6Patrick Nguyen#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
131191825e63f341a4e7777b85254f616e541000d5cA. Unique TensorFlower    p[0] = a.value;
132191825e63f341a4e7777b85254f616e541000d5cA. Unique TensorFlower    p[1] = 0;
133191825e63f341a4e7777b85254f616e541000d5cA. Unique TensorFlower#else
134982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    static_assert(::tensorflow::port::kLittleEndian,
135982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                  "Not a little endian system!");
136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    p[0] = 0;
137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    p[1] = a.value;
138769496bb0d3093a5531268635d364950c93327d6Patrick Nguyen#endif
139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return ret;
140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <>
1447ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinerstruct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> {
145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <>
149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurstruct scalar_cast_op<float, ::tensorflow::bfloat16> {
150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  typedef ::tensorflow::bfloat16 result_type;
152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()(
153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      const float a) const {
154c0644791cfc064d5e4652271e51d826aeccad0c2A. Unique TensorFlower    return ::tensorflow::bfloat16(a);
155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
158f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <>
1597ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinerstruct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
160f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
161f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
162f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
163f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace internal
164f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace Eigen
165f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
166f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // TENSORFLOW_KERNELS_CAST_OP_H_
167