1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License"); 49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License. 59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at 69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur http://www.apache.org/licenses/LICENSE-2.0 89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software 109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS, 119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and 139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License. 149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/ 159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_CAST_OP_H_ 17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_CAST_OP_H_ 18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 1956313def004795f75ef8281a0294c958d28f1e06Vijay Vasudevan#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/bfloat16.h" 21725206e677a9f1e343319293a347862335ff776bPeter Hawkins#include "tensorflow/core/framework/op_kernel.h" 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/framework/tensor_types.h" 23725206e677a9f1e343319293a347862335ff776bPeter Hawkins#include "tensorflow/core/framework/types.h" 2446231cf242c19d74af75370eefd9e9b7c504c08aVijay Vasudevan#include "tensorflow/core/platform/cpu_info.h" 253ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h" 26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow { 28725206e677a9f1e343319293a347862335ff776bPeter Hawkins 29725206e677a9f1e343319293a347862335ff776bPeter Hawkins// Common base class of Cast kernels 30725206e677a9f1e343319293a347862335ff776bPeter Hawkinsclass CastOpBase : public OpKernel { 31725206e677a9f1e343319293a347862335ff776bPeter Hawkins public: 32725206e677a9f1e343319293a347862335ff776bPeter Hawkins explicit CastOpBase(OpKernelConstruction* ctx); 33725206e677a9f1e343319293a347862335ff776bPeter Hawkins 34725206e677a9f1e343319293a347862335ff776bPeter Hawkins void Compute(OpKernelContext* ctx) override; 35725206e677a9f1e343319293a347862335ff776bPeter Hawkins 36725206e677a9f1e343319293a347862335ff776bPeter Hawkins protected: 37725206e677a9f1e343319293a347862335ff776bPeter Hawkins DataType src_dtype_; 38725206e677a9f1e343319293a347862335ff776bPeter Hawkins DataType dst_dtype_; 39725206e677a9f1e343319293a347862335ff776bPeter Hawkins std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr; 40725206e677a9f1e343319293a347862335ff776bPeter Hawkins 41725206e677a9f1e343319293a347862335ff776bPeter Hawkins Status Unimplemented(); 42725206e677a9f1e343319293a347862335ff776bPeter Hawkins 43725206e677a9f1e343319293a347862335ff776bPeter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); 44725206e677a9f1e343319293a347862335ff776bPeter Hawkins}; 45725206e677a9f1e343319293a347862335ff776bPeter Hawkins 46725206e677a9f1e343319293a347862335ff776bPeter Hawkins// CPU implementation of Cast 47725206e677a9f1e343319293a347862335ff776bPeter Hawkinsclass CpuCastOp : public CastOpBase { 48725206e677a9f1e343319293a347862335ff776bPeter Hawkins public: 49725206e677a9f1e343319293a347862335ff776bPeter Hawkins explicit CpuCastOp(OpKernelConstruction* ctx); 50725206e677a9f1e343319293a347862335ff776bPeter Hawkins 51725206e677a9f1e343319293a347862335ff776bPeter Hawkins private: 52725206e677a9f1e343319293a347862335ff776bPeter Hawkins Status Prepare(); 53725206e677a9f1e343319293a347862335ff776bPeter Hawkins}; 54725206e677a9f1e343319293a347862335ff776bPeter Hawkins 55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace functor { 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename Device, typename Tout, typename Tin> 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurvoid Cast(const Device& d, typename TTypes<Tout>::Flat o, 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur typename TTypes<Tin>::ConstFlat i) { 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur o.device(d) = i.template cast<Tout>(); 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename Device, typename Tout, typename Tin> 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurstruct CastFunctor { 65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void operator()(const Device& d, typename TTypes<Tout>::Flat o, 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur typename TTypes<Tin>::ConstFlat i); 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // end namespace functor 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // end namespace tensorflow 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace Eigen { 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace internal { 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 75a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner// Eigen can't convert to/from complex numbers, because it is limited to cases 76a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner// that can be static_casted. But numpy is able to cast to/from complex, which 77a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner// we want to replicate. So we add specializations for complex here. 787ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To> 79a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct scalar_cast_op<std::complex<From>, To> { 807ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To 817ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner operator()(const std::complex<From>& a) const { 821b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu // Replicate numpy behavior of returning just the real part 83a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner return static_cast<To>(a.real()); 84a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner } 85a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner}; 86a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner 877ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To> 88a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct scalar_cast_op<From, std::complex<To>> { 897ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()( 907ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner const From& a) const { 911b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu // Replicate numpy behavior of setting the imaginary part to 0 92a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner return std::complex<To>(static_cast<To>(a), To(0)); 93a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner } 94a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner}; 95a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner 967ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To> 97a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct scalar_cast_op<std::complex<From>, std::complex<To>> { 987ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()( 997ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner const std::complex<From>& a) const { 100a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner return std::complex<To>(static_cast<To>(a.real()), 101a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner static_cast<To>(a.imag())); 102a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner } 103a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner}; 104a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner 1057ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To> 106a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct functor_traits_complex_impl { 107a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner enum { Cost = NumTraits<To>::AddCost, PacketAccess = false }; 108a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner}; 109a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner 1107ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To> 111a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct functor_traits<scalar_cast_op<std::complex<From>, To>> 1127ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner : functor_traits_complex_impl<std::complex<From>, To> {}; 1137ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To> 114a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct functor_traits<scalar_cast_op<From, std::complex<To>>> 1157ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner : functor_traits_complex_impl<From, std::complex<To>> {}; 116a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner// Needed to avoid ambiguous partial specialization 1177ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinertemplate <typename From, typename To> 118a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steinerstruct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>> 1197ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steiner : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {}; 120a92da9e09fb476a9b267499e326919a89b826fb7Benoit Steiner 121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Specialized cast op impls for bfloat16. 122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <> 1237ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinerstruct scalar_cast_op<::tensorflow::bfloat16, float> { 124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur typedef float result_type; 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()( 127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const ::tensorflow::bfloat16& a) const { 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float ret; 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur uint16_t* p = reinterpret_cast<uint16_t*>(&ret); 130769496bb0d3093a5531268635d364950c93327d6Patrick Nguyen#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ 131191825e63f341a4e7777b85254f616e541000d5cA. Unique TensorFlower p[0] = a.value; 132191825e63f341a4e7777b85254f616e541000d5cA. Unique TensorFlower p[1] = 0; 133191825e63f341a4e7777b85254f616e541000d5cA. Unique TensorFlower#else 134982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_assert(::tensorflow::port::kLittleEndian, 135982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen "Not a little endian system!"); 136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur p[0] = 0; 137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur p[1] = a.value; 138769496bb0d3093a5531268635d364950c93327d6Patrick Nguyen#endif 139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return ret; 140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <> 1447ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinerstruct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> { 145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; 146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <> 149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurstruct scalar_cast_op<float, ::tensorflow::bfloat16> { 150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) 151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur typedef ::tensorflow::bfloat16 result_type; 152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()( 153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const float a) const { 154c0644791cfc064d5e4652271e51d826aeccad0c2A. Unique TensorFlower return ::tensorflow::bfloat16(a); 155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 158f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <> 1597ef9ebd03a4ab90eb0cdf82a1f569257436f57d7Benoit Steinerstruct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> { 160f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; 161f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 162f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 163f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace internal 164f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace Eigen 165f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 166f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif // TENSORFLOW_KERNELS_CAST_OP_H_ 167