1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_ 17#define TENSORFLOW_KERNELS_CWISE_OPS_H_ 18 19#include <cmath> 20#include <functional> 21#include <type_traits> 22 23#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 25#include "tensorflow/core/framework/numeric_types.h" 26#include "tensorflow/core/framework/tensor_types.h" 27#include "tensorflow/core/kernels/bounds_check.h" 28 29namespace Eigen { 30namespace numext { 31#if GOOGLE_CUDA 32template <> 33EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<float> exp( 34 const std::complex<float>& x) { 35 auto com = ::expf(x.real()); 36 auto res_real = com * ::cosf(x.imag()); 37 auto res_imag = com * ::sinf(x.imag()); 38 return std::complex<float>(res_real, res_imag); 39} 40template <> 41EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<double> exp( 42 const std::complex<double>& x) { 43 auto com = ::exp(x.real()); 44 auto res_real = com * ::cos(x.imag()); 45 auto res_imag = com * ::sin(x.imag()); 46 return std::complex<double>(res_real, res_imag); 47} 48#endif 49} // namespace numext 50 51namespace internal { 52 53template <typename T> 54struct scalar_asinh_op { 55 EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op) 56 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 57#if EIGEN_HAS_CXX11_MATH 58 return numext::asinh(a); 59#else 60 return std::asinh(a); 61#endif // EIGEN_HAS_CXX11_MATH 62 } 63}; 64template <typename T> 65struct functor_traits<scalar_asinh_op<T>> { 66 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 67}; 68 69template <typename T> 70struct scalar_acosh_op { 71 EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op) 72 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 73#if EIGEN_HAS_CXX11_MATH 74 return numext::acosh(a); 75#else 76 return std::acosh(a); 77#endif // EIGEN_HAS_CXX11_MATH 78 } 79}; 80template <typename T> 81struct functor_traits<scalar_acosh_op<T>> { 82 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 83}; 84 85template <typename T> 86struct scalar_atanh_op { 87 EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op) 88 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 89#if EIGEN_HAS_CXX11_MATH 90 return numext::atanh(a); 91#else 92 return std::atanh(a); 93#endif // EIGEN_HAS_CXX11_MATH 94 } 95}; 96template <typename T> 97struct functor_traits<scalar_atanh_op<T>> { 98 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 99}; 100 101// TODO(rmlarsen): This is a workaround for upstream change 102// https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9 103// that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary 104// version of the latter. Remove once we upgrade to Eigen 3.3. 105template <typename Scalar, typename Exponent> 106struct scalar_binary_pow_op_google { 107 EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google) 108 EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, 109 const Exponent& b) const { 110 return numext::pow(a, b); 111 } 112}; 113 114template <typename Scalar, typename Exponent> 115struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> { 116 enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; 117}; 118 119template <typename Scalar, typename Exponent> 120struct safe_scalar_binary_pow_op { 121 static_assert(std::is_integral<Scalar>::value, "Integer type expected"); 122 static_assert(std::is_integral<Exponent>::value && 123 std::is_signed<Exponent>::value, 124 "Signed integer type expected"); 125 126 bool* const error; 127 128 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) 129 : error(error) {} 130 131 EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, 132 const Exponent& b) const { 133 const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); 134 if (TF_PREDICT_TRUE(safe_b >= 0)) { 135 return numext::pow(a, safe_b); 136 } else { 137 *error = true; 138 return 0; 139 } 140 } 141}; 142 143template <typename Scalar, typename Exponent> 144struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> { 145 enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; 146}; 147 148template <typename T, typename DivOrMod> 149struct safe_div_or_mod_op { 150 static_assert(std::is_integral<T>::value, "Integer type expected"); 151 152 bool* const error; 153 154 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error) 155 : error(error) {} 156 157 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a, 158 const T& b) const { 159 const T safe_b = tensorflow::internal::SubtleMustCopy(b); 160 if (TF_PREDICT_TRUE(safe_b != 0)) { 161 return DivOrMod()(a, safe_b); 162 } else { 163 *error = true; 164 return 0; 165 } 166 } 167}; 168 169template <typename T, typename DivOrMod> 170struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> { 171 enum { 172 Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost, 173 PacketAccess = false, 174 }; 175}; 176 177// scalar_left and scalar_right are template helpers to partially 178// apply a binary function. 179// 180// Suppose Binary is a binary functor f(x, y), scalar_left<> is a 181// unary functor g_x(y) = f(x, y), where x is provided via the 182// constructor. Similarly, scalar_right<> is a unary functor g_y(x) = 183// f(x, y). 184 185template <typename Tout, typename Tin, typename Binary> 186struct scalar_left : private Binary { 187 typedef Tout result_type; 188 const Tin* left; 189 190 EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default; 191 192 template <typename... Args> 193 EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args) 194 : Binary(args...), left(c) {} 195 196 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { 197 return Binary::operator()(*left, right); 198 } 199 200 template <typename Packet> 201 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { 202 const Packet left_packet = Eigen::internal::pset1<Packet>(*left); 203 return Binary::packetOp(left_packet, right_packet); 204 } 205}; 206 207template <typename Tout, typename Tin, typename Binary> 208struct functor_traits<scalar_left<Tout, Tin, Binary>> { 209 enum { 210 Cost = functor_traits<Binary>::Cost, 211 PacketAccess = functor_traits<Binary>::PacketAccess, 212 }; 213}; 214 215template <typename Tout, typename Tin, typename Binary> 216struct scalar_right : private Binary { 217 typedef Tout result_type; 218 const Tin* right; 219 220 EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default; 221 222 template <typename... Args> 223 EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args) 224 : Binary(args...), right(c) {} 225 226 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { 227 return Binary::operator()(left, *right); 228 } 229 230 template <typename Packet> 231 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { 232 const Packet right_packet = Eigen::internal::pset1<Packet>(*right); 233 return Binary::packetOp(left_packet, right_packet); 234 } 235}; 236 237template <typename Tout, typename Tin, typename Binary> 238struct functor_traits<scalar_right<Tout, Tin, Binary>> { 239 enum { 240 Cost = functor_traits<Binary>::Cost, 241 PacketAccess = functor_traits<Binary>::PacketAccess, 242 }; 243}; 244 245// similar to std::equal_to, but with the DEVICE_FUNC qualifier 246template <class T> 247struct equal_to : std::binary_function<T, T, bool> { 248 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 249 const T& y) const { 250 return x == y; 251 } 252}; 253 254// similar to std::not_equal_to, but with the DEVICE_FUNC qualifier 255template <class T> 256struct not_equal_to : std::binary_function<T, T, bool> { 257 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 258 const T& y) const { 259 return x != y; 260 } 261}; 262 263// similar to std::greater, but with the DEVICE_FUNC qualifier 264template <class T> 265struct greater : std::binary_function<T, T, bool> { 266 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 267 const T& y) const { 268 return x > y; 269 } 270}; 271 272// similar to std::less, but with the DEVICE_FUNC qualifier 273template <class T> 274struct less : std::binary_function<T, T, bool> { 275 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 276 const T& y) const { 277 return x < y; 278 } 279}; 280 281// similar to std::greater_equal, but with the DEVICE_FUNC qualifier 282template <class T> 283struct greater_equal : std::binary_function<T, T, bool> { 284 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 285 const T& y) const { 286 return x >= y; 287 } 288}; 289 290// similar to std::less_equal, but with the DEVICE_FUNC qualifier 291template <class T> 292struct less_equal : std::binary_function<T, T, bool> { 293 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 294 const T& y) const { 295 return x <= y; 296 } 297}; 298 299// Functor that enables composition of multiple Eigen functors. 300template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor> 301struct scalar_compose_op { 302 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 303 operator()(const Scalar& a, const Scalar& b) const { 304 return UnaryFunctor()(BinaryFunctor()(a, b)); 305 } 306 template <typename Packet> 307 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 308 packetOp(const Packet& a, const Packet& b) const { 309 return UnaryFunctor().packetOp(BinaryFunctor().packetOp(a, b)); 310 } 311}; 312 313template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor> 314struct functor_traits<scalar_compose_op<Scalar, UnaryFunctor, BinaryFunctor>> { 315 enum { 316 Cost = functor_traits<UnaryFunctor>::Cost + 317 functor_traits<BinaryFunctor>::Cost, 318 PacketAccess = functor_traits<UnaryFunctor>::PacketAccess && 319 functor_traits<BinaryFunctor>::PacketAccess 320 }; 321}; 322 323// TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 324template <typename T, typename Enable = void> 325struct google_floor_div { 326 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 327 const T& y) const { 328 if ((x < T(0)) != (y < T(0))) { 329 T abs_x = std::abs(x); 330 T abs_y = std::abs(y); 331 return -(abs_x + abs_y - 1) / abs_y; 332 } else { 333 return x / y; 334 } 335 } 336}; 337 338template <typename T> 339struct google_floor_div< 340 T, typename std::enable_if<std::is_unsigned<T>::value>::type> { 341 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 342 const T& y) const { 343 return x / y; 344 } 345}; 346 347template <typename Scalar> 348struct functor_traits<google_floor_div<Scalar>> { 349 enum { 350 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 351 2 * NumTraits<Scalar>::AddCost, 352 PacketAccess = false 353 }; 354}; 355 356// TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 357template <typename T, typename Enable = void> 358struct google_floor_div_real { 359 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 360 const T& y) const { 361 return Eigen::numext::floor(x / y); 362 } 363}; 364 365template <typename Scalar> 366struct functor_traits<google_floor_div_real<Scalar>> { 367 enum { 368 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 369 2 * NumTraits<Scalar>::AddCost, 370 PacketAccess = false 371 }; 372}; 373 374// TODO(b//32239616): This kernel should be moved into Eigen and vectorized. 375template <typename T> 376struct google_floor_fmod { 377 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 378 const T& y) const { 379 // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL); 380 T trunc_mod = std::fmod(x, y); 381 return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); 382 } 383}; 384 385template <typename Scalar> 386struct functor_traits<google_floor_fmod<Scalar>> { 387 enum { 388 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 389 2 * NumTraits<Scalar>::AddCost, 390 PacketAccess = false 391 }; 392}; 393 394// TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 395template <typename T> 396struct google_floor_mod { 397 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 398 const T& y) const { 399 // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL); 400 T trunc_mod = x % y; 401 return (x < T(0)) == (y < T(0)) ? trunc_mod : (trunc_mod + y) % y; 402 } 403}; 404 405template <typename Scalar> 406struct functor_traits<google_floor_mod<Scalar>> { 407 enum { 408 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 409 2 * NumTraits<Scalar>::AddCost, 410 PacketAccess = false 411 }; 412}; 413 414#if EIGEN_COMP_GNUC && __cplusplus > 199711L 415#define DISABLE_FLOAT_EQUALITY_WARNING \ 416 _Pragma("GCC diagnostic push") \ 417 _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") 418#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") 419#else 420#define DISABLE_FLOAT_EQUALITY_WARNING 421#define ENABLE_FLOAT_EQUALITY_WARNING 422#endif 423 424template <typename Scalar> 425struct scalar_round_op_google { 426 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 427 operator()(const Scalar& x) const { 428 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), 429 NUMERIC_TYPE_MUST_BE_REAL) 430 431 Scalar round_val = Eigen::numext::floor(x); 432 const Scalar fraction = x - round_val; 433 if (fraction > Scalar(.5)) { 434 round_val += Scalar(1.0); 435 } else if (fraction == Scalar(.5)) { 436 const Scalar nearest_even_int = 437 round_val - Scalar(2) * Eigen::numext::floor(Scalar(.5) * x); 438 bool is_odd = (nearest_even_int == Scalar(1)); 439 if (is_odd) { 440 round_val += Scalar(1); 441 } 442 } 443 return round_val; 444 } 445}; 446 447template <typename Scalar> 448struct functor_traits<scalar_round_op_google<Scalar>> { 449 enum { Cost = 4 * NumTraits<Scalar>::AddCost, PacketAccess = false }; 450}; 451 452#undef ENABLE_FLOAT_EQUALITY_WARNING 453#undef DISABLE_FLOAT_EQUALITY_WARNING 454 455template <typename Scalar> 456struct bitwise_xor_op { 457 EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op) 458 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 459 operator()(const Scalar& x, const Scalar& y) const { 460 return x ^ y; 461 } 462 typedef typename Eigen::internal::packet_traits<Scalar>::type Packet; 463 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 464 const Packet& b) const { 465 return Eigen::internal::pxor(a, b); 466 } 467}; 468 469template <typename Scalar> 470struct functor_traits<bitwise_xor_op<Scalar>> { 471 enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true }; 472}; 473 474} // end namespace internal 475} // end namespace Eigen 476 477namespace tensorflow { 478namespace functor { 479 480//////////////////////////////////////////////////////////////////////////////// 481// Helpers 482//////////////////////////////////////////////////////////////////////////////// 483 484// Base template for functors whose input scalar type is T and 485// output scalar type is R. 486template <typename T, typename F, typename R = T> 487struct base { 488 // func defines operator() and its vectorized version packetOp(). 489 typedef F func; 490 491 // If true, the functor's corresponding binary op will instantiate 492 // specialized kernels to perform an optimized broadcast 493 // operation. Each functor for which this is enabled increases the 494 // code size, so by default this is disabled for binary functors and 495 // is enabled on a per-op basis as needed. 496 static const bool use_bcast_optimization = false; 497 498 // operator() has the signature: 499 // out_type operator()(in_type in0, in_type in1 ...) 500 typedef R out_type; 501 typedef T in_type; 502 503 // TensorFlow provides tensor-ized version of "func". Roughly 504 // speaking, the tensorflow operation has the signature: 505 // tout_type op(tin_type in0) 506 // tout_type op(tin_type in0, tin_type in1) 507 // tout_type op(tin_type in0, in_type scalar) 508 typedef typename TTypes<out_type>::Flat tout_type; 509 typedef typename TTypes<in_type>::ConstFlat tin_type; 510 typedef typename TTypes<in_type>::ConstScalar tscalar_type; 511 512 // Whether the functor can error out. Currently applies only to integer 513 // div and mod. 514 static const bool has_errors = false; 515}; 516 517// For now, we only apply certain speed optimization for 518// float/double's broadcast binary op. 519template <typename T> 520struct use_bcast_optimization { 521 static const bool value = false; 522}; 523 524template <> 525struct use_bcast_optimization<float> { 526 static const bool value = true; 527}; 528 529template <> 530struct use_bcast_optimization<double> { 531 static const bool value = true; 532}; 533 534//////////////////////////////////////////////////////////////////////////////// 535// Unary functors 536//////////////////////////////////////////////////////////////////////////////// 537 538// abs(x) = |x| 539// neg(x) = - x 540// inverse(x) = 1 / x 541// square(x) = x^2 542// sqrt(x) = x^(1/2) 543// rsqrt(x) = x^(-1/2) 544// exp(x) = e^x 545// expm1(x) = e^x - 1 546// log(x) = natural logarithm of x 547// log1p(x) = natural logarithm of 1 + x 548// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) 549// sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic 550// 551// NOTE: We may eventually implement common functions used in NN 552// here. E.g., rectifier, softplus, derivatives of tanh, sigmod, etc. 553// For reference, see speech/lstm/eigen_functors.h. 554 555template <typename T> 556struct abs : base<T, Eigen::internal::scalar_abs_op<T>, 557 typename Eigen::internal::scalar_abs_op<T>::result_type> {}; 558 559template <typename T> 560struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {}; 561 562template <typename T> 563struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {}; 564 565template <typename T> 566struct square : base<T, Eigen::internal::scalar_square_op<T>> {}; 567 568template <typename T> 569struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {}; 570 571template <typename T> 572struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {}; 573 574template <typename T> 575struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {}; 576 577template <typename T> 578struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {}; 579 580template <typename T> 581struct log : base<T, Eigen::internal::scalar_log_op<T>> {}; 582 583template <typename T> 584struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {}; 585 586template <typename T> 587struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {}; 588 589template <typename T> 590struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {}; 591 592template <typename T> 593struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {}; 594 595template <typename T> 596struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {}; 597 598template <typename T> 599struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {}; 600 601template <typename T> 602struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {}; 603 604template <typename T> 605struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {}; 606 607template <typename T> 608struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {}; 609 610template <typename T> 611struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {}; 612 613template <typename T> 614struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {}; 615 616template <typename T> 617struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {}; 618 619template <typename T> 620struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T>> {}; 621 622template <typename T> 623struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {}; 624 625template <typename T> 626struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {}; 627 628template <typename T> 629struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {}; 630 631template <typename T> 632struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {}; 633 634template <typename T> 635struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {}; 636 637template <typename T> 638struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {}; 639 640struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> { 641}; 642 643// Flip all bits. Named invert to be consistent with numpy. 644template <typename T> 645struct invert_op { 646 EIGEN_EMPTY_STRUCT_CTOR(invert_op) 647 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 648 return ~a; 649 } 650}; 651 652template <typename T> 653struct invert : base<T, invert_op<T>> {}; 654 655// NOTE: std::isinf, std::isnan, std::isfinite are plain function. 656// Therefore we need to wrap them in functors to be used with Eigen's 657// type system. 658template <typename T> 659struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {}; 660 661template <typename T> 662struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {}; 663 664template <typename T> 665struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {}; 666 667template <typename T> 668struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {}; 669 670template <typename T> 671struct round : base<T, Eigen::internal::scalar_round_op_google<T>> {}; 672 673template <typename T> 674struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {}; 675 676/** this should go in Eigen 677 * \brief Template functor to compute the round to int value of a scalar 678 */ 679template <typename Scalar> 680struct scalar_rint_op { 681 EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op) 682 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 683 operator()(const Scalar& a) const { 684#if defined(__CUDACC__) 685 return ::rint(a); 686#elif defined(__ANDROID__) 687 return rint(a); 688#else 689 return std::rint(a); 690#endif 691 } 692}; 693 694template <typename T> 695struct rint : base<T, scalar_rint_op<T>> {}; 696 697//////////////////////////////////////////////////////////////////////////////// 698// Binary functors 699//////////////////////////////////////////////////////////////////////////////// 700 701// Binary functors: 702// 703// add(x, y) = x + y 704// sub(x, y) = x - y 705// mul(x, y) = x * y 706// div(x, y) = x / y 707// mod(x, y) = x % y (int32 and int64 only) 708// fmod(x, y) = fmod(x, y) (float and double only) 709// pow(x, y) = x ^ y 710// maximum(x, y) = x > y ? x : y 711// minimum(x, y) = x < y ? x : y 712// squared_difference(x, y) = (x - y) * (x - y) 713 714template <typename T> 715struct add : base<T, Eigen::internal::scalar_sum_op<T>> { 716 static const bool use_bcast_optimization = true; 717}; 718 719template <typename T> 720struct sub : base<T, Eigen::internal::scalar_difference_op<T>> { 721 static const bool use_bcast_optimization = true; 722}; 723 724template <typename T> 725struct mul : base<T, Eigen::internal::scalar_product_op<T>> { 726 static const bool use_bcast_optimization = true; 727}; 728 729template <typename T> 730struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {}; 731 732template <typename T> 733struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op< 734 T, Eigen::internal::scalar_quotient_op<T>>> { 735 static const bool has_errors = true; 736}; 737 738template <typename T> 739struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {}; 740 741template <typename T> 742struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {}; 743 744template <typename T> 745struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op< 746 T, Eigen::internal::scalar_mod2_op<T>>> { 747 static const bool has_errors = true; 748}; 749 750template <typename T> 751struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {}; 752 753template <typename T> 754struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op< 755 T, Eigen::internal::google_floor_mod<T>>> { 756 static const bool has_errors = true; 757}; 758 759template <typename T> 760struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {}; 761 762template <typename T> 763struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op< 764 T, Eigen::internal::google_floor_div<T>>> { 765 static const bool has_errors = true; 766}; 767 768template <typename T> 769struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {}; 770 771template <typename T> 772struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {}; 773 774template <typename T> 775struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> { 776 static const bool has_errors = true; 777}; 778 779template <typename T> 780struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {}; 781 782template <typename T> 783struct minimum : base<T, Eigen::internal::scalar_min_op<T>> {}; 784 785template <typename T> 786struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {}; 787 788template <typename T> 789struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {}; 790 791template <typename T> 792struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {}; 793 794template <typename T> 795struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {}; 796 797template <typename Scalar> 798struct scalar_atan2_op { 799 EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op) 800 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 801 operator()(const Scalar& y, const Scalar& x) const { 802#if GOOGLE_CUDA 803 return ::atan2(y, x); 804#else 805 return std::atan2(y, x); 806#endif 807 } 808}; 809 810template <typename T> 811struct atan2 : base<T, scalar_atan2_op<T>> {}; 812 813template <typename T> 814struct squared_difference 815 : base<T, Eigen::internal::scalar_compose_op< 816 T, Eigen::internal::scalar_square_op<T>, 817 Eigen::internal::scalar_difference_op<T>>> {}; 818 819template <typename T> 820struct less : base<T, Eigen::internal::less<T>, bool> {}; 821 822template <typename T> 823struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {}; 824 825template <typename T> 826struct greater : base<T, Eigen::internal::greater<T>, bool> {}; 827 828template <typename T> 829struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {}; 830 831template <typename T> 832struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {}; 833 834template <typename T> 835struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {}; 836 837struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {}; 838 839struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {}; 840 841template <typename T> 842struct bitwise_and_op { 843 EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op) 844 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 845 const T& y) const { 846 return x & y; 847 } 848}; 849 850template <typename T> 851struct bitwise_or_op { 852 EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op) 853 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 854 const T& y) const { 855 return x | y; 856 } 857}; 858 859template <typename T> 860struct bitwise_and : base<T, bitwise_and_op<T>> {}; 861 862template <typename T> 863struct bitwise_or : base<T, bitwise_or_op<T>> {}; 864 865template <typename T> 866struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {}; 867 868template <typename T> 869struct left_shift_op { 870 EIGEN_EMPTY_STRUCT_CTOR(left_shift_op) 871 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 872 const T& y) const { 873 // Avoids UB: don't shift by larger than the bitwidth of T, and 874 // performs left shifts as unsigned shifts. 875 T y_clamped = y; 876 if (y_clamped < 0) { 877 y_clamped = 0; 878 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 879 y_clamped = sizeof(T) * CHAR_BIT - 1; 880 } 881 using U = typename std::make_unsigned<T>::type; 882 return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped)); 883 } 884}; 885 886template <typename T> 887struct right_shift_op { 888 EIGEN_EMPTY_STRUCT_CTOR(right_shift_op) 889 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 890 const T& y) const { 891 // Avoids UB: don't shift by larger than the bitwidth of T. 892 T y_clamped = y; 893 if (y_clamped < 0) { 894 y_clamped = 0; 895 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 896 y_clamped = sizeof(T) * CHAR_BIT - 1; 897 } 898 // Technically right shifts of signed integers are not necessarily 899 // arithmetic shifts according to the C++ standard. However in practice most 900 // implementations are arithmetic shifts. If this proves to be a problem in 901 // practice, we may need to use an alternative implementation. 902 return x >> y_clamped; 903 } 904}; 905 906template <typename T> 907struct left_shift : base<T, left_shift_op<T>> {}; 908 909template <typename T> 910struct right_shift : base<T, right_shift_op<T>> {}; 911 912template <typename T> 913struct make_complex_func { 914 typedef std::complex<T> result_type; 915 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real, 916 T imag) const { 917 return std::complex<T>(real, imag); 918 } 919}; 920 921template <typename T> 922struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {}; 923 924template <typename T> 925struct get_real 926 : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {}; 927 928template <typename T> 929struct get_imag 930 : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {}; 931 932template <typename T> 933struct get_angle 934 : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {}; 935 936template <typename T> 937struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {}; 938 939//////////////////////////////////////////////////////////////////////////////// 940// Functors takes 1 or 2 tensors, computes the base functor on 941// coefficient of the input tensors and puts the results in the output 942// tensor. 943//////////////////////////////////////////////////////////////////////////////// 944template <typename Device, typename Functor> 945struct UnaryFunctor { 946 // Computes on device "d": out[i] = Functor(in[i]) 947 void operator()(const Device& d, typename Functor::tout_type out, 948 typename Functor::tin_type in); 949}; 950 951template <typename Device, typename Functor, int NDIMS, 952 bool has_errors = Functor::has_errors> 953struct BinaryFunctor { 954 // Computes on device "d": out[i] = Functor(in0[i], in1[i]) 955 void operator()(const Device& d, typename Functor::tout_type out, 956 typename Functor::tin_type in0, 957 typename Functor::tin_type in1, bool* error); 958 959 // Computes on device "d": out[i] = Functor(scalar[0], in[i]) 960 void Left(const Device& d, typename Functor::tout_type out, 961 typename Functor::tscalar_type scalar, 962 typename Functor::tin_type in, bool* error); 963 964 // Computes on device "d": out[i] = Functor(in[i], scalar[0]) 965 void Right(const Device& d, typename Functor::tout_type out, 966 typename Functor::tin_type in, 967 typename Functor::tscalar_type scalar, bool* error); 968 969 // Computes on device "d": 970 // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1)) 971 // 972 // TODO(zhifengc): makes BCast a template member function on NDIMS 973 // instead making BinaryFunctor templates on NDIMS. 974 void BCast(const Device& d, 975 typename TTypes<typename Functor::out_type, NDIMS>::Tensor out, 976 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0, 977 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0, 978 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1, 979 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1, 980 bool* error); 981}; 982 983template <typename Device, typename T> 984struct ApproximateEqual { 985 void operator()(const Device& d, typename TTypes<T>::ConstFlat x, 986 typename TTypes<T>::ConstFlat y, T tolerance, 987 typename TTypes<bool>::Flat z); 988}; 989 990template <int NDIMS> 991bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) { 992 for (size_t i = 0; i < a.size(); ++i) { 993 if (a[i] != 1) return false; 994 } 995 return true; 996} 997 998template <typename Device, typename T> 999struct SelectFunctor { 1000 void operator()(const Device& d, typename TTypes<T>::Flat out, 1001 typename TTypes<bool>::ConstFlat cond_flat, 1002 typename TTypes<T>::ConstFlat then_flat, 1003 typename TTypes<T>::ConstFlat else_flat); 1004}; 1005 1006template <typename Device, typename T> 1007struct SelectScalarFunctor { 1008 void operator()(const Device& d, typename TTypes<T>::Flat out, 1009 typename TTypes<bool>::ConstScalar cond, 1010 typename TTypes<T>::ConstFlat then_flat, 1011 typename TTypes<T>::ConstFlat else_flat); 1012}; 1013 1014template <typename Device, typename T> 1015struct BatchSelectFunctor { 1016 void operator()(const Device& d, 1017 typename TTypes<T>::Matrix output_flat_outer_dims, 1018 TTypes<bool>::ConstVec cond_vec, 1019 typename TTypes<T>::ConstMatrix then_flat_outer_dims, 1020 typename TTypes<T>::ConstMatrix else_flat_outer_dims); 1021}; 1022 1023} // end namespace functor 1024} // end namespace tensorflow 1025 1026#endif // TENSORFLOW_KERNELS_CWISE_OPS_H_ 1027