1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_
17#define TENSORFLOW_KERNELS_CWISE_OPS_H_
18
#include <cmath>
#include <functional>
#include <limits>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
28
29namespace Eigen {
30namespace numext {
31#if GOOGLE_CUDA
32template <>
33EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<float> exp(
34    const std::complex<float>& x) {
35  auto com = ::expf(x.real());
36  auto res_real = com * ::cosf(x.imag());
37  auto res_imag = com * ::sinf(x.imag());
38  return std::complex<float>(res_real, res_imag);
39}
40template <>
41EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<double> exp(
42    const std::complex<double>& x) {
43  auto com = ::exp(x.real());
44  auto res_real = com * ::cos(x.imag());
45  auto res_imag = com * ::sin(x.imag());
46  return std::complex<double>(res_real, res_imag);
47}
48#endif
49}  // namespace numext
50
51namespace internal {
52
// Unary functor returning the inverse hyperbolic sine of its argument.
template <typename T>
struct scalar_asinh_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
    // numext::asinh is used when EIGEN_HAS_CXX11_MATH is set, std::asinh
    // otherwise.
#if EIGEN_HAS_CXX11_MATH
    const T result = numext::asinh(a);
#else
    const T result = std::asinh(a);
#endif  // EIGEN_HAS_CXX11_MATH
    return result;
  }
};
// functor_traits for scalar_asinh_op: Cost is five times T's MulCost, and
// PacketAccess is false (the functor defines no packetOp).
template <typename T>
struct functor_traits<scalar_asinh_op<T>> {
  enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
};
68
// Unary functor returning the inverse hyperbolic cosine of its argument.
template <typename T>
struct scalar_acosh_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
    // numext::acosh is used when EIGEN_HAS_CXX11_MATH is set, std::acosh
    // otherwise.
#if EIGEN_HAS_CXX11_MATH
    const T result = numext::acosh(a);
#else
    const T result = std::acosh(a);
#endif  // EIGEN_HAS_CXX11_MATH
    return result;
  }
};
// functor_traits for scalar_acosh_op: Cost is five times T's MulCost, and
// PacketAccess is false (the functor defines no packetOp).
template <typename T>
struct functor_traits<scalar_acosh_op<T>> {
  enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
};
84
85template <typename T>
86struct scalar_atanh_op {
87  EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op)
88  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
89#if EIGEN_HAS_CXX11_MATH
90    return numext::atanh(a);
91#else
92    return std::atanh(a);
93#endif  // EIGEN_HAS_CXX11_MATH
94  }
95};
// functor_traits for scalar_atanh_op: Cost is five times T's MulCost, and
// PacketAccess is false (the functor defines no packetOp).
template <typename T>
struct functor_traits<scalar_atanh_op<T>> {
  enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
};
100
101// TODO(rmlarsen): This is a workaround for upstream change
102// https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9
103// that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary
104// version of the latter. Remove once we upgrade to Eigen 3.3.
105template <typename Scalar, typename Exponent>
106struct scalar_binary_pow_op_google {
107  EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google)
108  EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
109                                             const Exponent& b) const {
110    return numext::pow(a, b);
111  }
112};
113
114template <typename Scalar, typename Exponent>
115struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
116  enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
117};
118
119template <typename Scalar, typename Exponent>
120struct safe_scalar_binary_pow_op {
121  static_assert(std::is_integral<Scalar>::value, "Integer type expected");
122  static_assert(std::is_integral<Exponent>::value &&
123                    std::is_signed<Exponent>::value,
124                "Signed integer type expected");
125
126  bool* const error;
127
128  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
129      : error(error) {}
130
131  EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
132                                             const Exponent& b) const {
133    const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
134    if (TF_PREDICT_TRUE(safe_b >= 0)) {
135      return numext::pow(a, safe_b);
136    } else {
137      *error = true;
138      return 0;
139    }
140  }
141};
142
143template <typename Scalar, typename Exponent>
144struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
145  enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
146};
147
148template <typename T, typename DivOrMod>
149struct safe_div_or_mod_op {
150  static_assert(std::is_integral<T>::value, "Integer type expected");
151
152  bool* const error;
153
154  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error)
155      : error(error) {}
156
157  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
158                                                           const T& b) const {
159    const T safe_b = tensorflow::internal::SubtleMustCopy(b);
160    if (TF_PREDICT_TRUE(safe_b != 0)) {
161      return DivOrMod()(a, safe_b);
162    } else {
163      *error = true;
164      return 0;
165    }
166  }
167};
168
169template <typename T, typename DivOrMod>
170struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
171  enum {
172    Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost,
173    PacketAccess = false,
174  };
175};
176
177// scalar_left and scalar_right are template helpers to partially
178// apply a binary function.
179//
180// Suppose Binary is a binary functor f(x, y), scalar_left<> is a
181// unary functor g_x(y) = f(x, y), where x is provided via the
182// constructor. Similarly, scalar_right<> is a unary functor g_y(x) =
183// f(x, y).
184
185template <typename Tout, typename Tin, typename Binary>
186struct scalar_left : private Binary {
187  typedef Tout result_type;
188  const Tin* left;
189
190  EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default;
191
192  template <typename... Args>
193  EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
194      : Binary(args...), left(c) {}
195
196  EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
197    return Binary::operator()(*left, right);
198  }
199
200  template <typename Packet>
201  EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
202    const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
203    return Binary::packetOp(left_packet, right_packet);
204  }
205};
206
207template <typename Tout, typename Tin, typename Binary>
208struct functor_traits<scalar_left<Tout, Tin, Binary>> {
209  enum {
210    Cost = functor_traits<Binary>::Cost,
211    PacketAccess = functor_traits<Binary>::PacketAccess,
212  };
213};
214
215template <typename Tout, typename Tin, typename Binary>
216struct scalar_right : private Binary {
217  typedef Tout result_type;
218  const Tin* right;
219
220  EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default;
221
222  template <typename... Args>
223  EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
224      : Binary(args...), right(c) {}
225
226  EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
227    return Binary::operator()(left, *right);
228  }
229
230  template <typename Packet>
231  EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
232    const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
233    return Binary::packetOp(left_packet, right_packet);
234  }
235};
236
237template <typename Tout, typename Tin, typename Binary>
238struct functor_traits<scalar_right<Tout, Tin, Binary>> {
239  enum {
240    Cost = functor_traits<Binary>::Cost,
241    PacketAccess = functor_traits<Binary>::PacketAccess,
242  };
243};
244
// Similar to std::equal_to, but with the DEVICE_FUNC qualifier.
//
// The nested typedefs are the ones std::binary_function<T, T, bool> used to
// provide. They are declared directly because std::binary_function was
// deprecated in C++11 and removed in C++17.
template <class T>
struct equal_to {
  typedef T first_argument_type;
  typedef T second_argument_type;
  typedef bool result_type;

  // Returns true when x == y.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x == y;
  }
};
253
// Similar to std::not_equal_to, but with the DEVICE_FUNC qualifier.
//
// The nested typedefs are the ones std::binary_function<T, T, bool> used to
// provide. They are declared directly because std::binary_function was
// deprecated in C++11 and removed in C++17.
template <class T>
struct not_equal_to {
  typedef T first_argument_type;
  typedef T second_argument_type;
  typedef bool result_type;

  // Returns true when x != y.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x != y;
  }
};
262
// Similar to std::greater, but with the DEVICE_FUNC qualifier.
//
// The nested typedefs are the ones std::binary_function<T, T, bool> used to
// provide. They are declared directly because std::binary_function was
// deprecated in C++11 and removed in C++17.
template <class T>
struct greater {
  typedef T first_argument_type;
  typedef T second_argument_type;
  typedef bool result_type;

  // Returns true when x > y.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x > y;
  }
};
271
// Similar to std::less, but with the DEVICE_FUNC qualifier.
//
// The nested typedefs are the ones std::binary_function<T, T, bool> used to
// provide. They are declared directly because std::binary_function was
// deprecated in C++11 and removed in C++17.
template <class T>
struct less {
  typedef T first_argument_type;
  typedef T second_argument_type;
  typedef bool result_type;

  // Returns true when x < y.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x < y;
  }
};
280
// Similar to std::greater_equal, but with the DEVICE_FUNC qualifier.
//
// The nested typedefs are the ones std::binary_function<T, T, bool> used to
// provide. They are declared directly because std::binary_function was
// deprecated in C++11 and removed in C++17.
template <class T>
struct greater_equal {
  typedef T first_argument_type;
  typedef T second_argument_type;
  typedef bool result_type;

  // Returns true when x >= y.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x >= y;
  }
};
289
// Similar to std::less_equal, but with the DEVICE_FUNC qualifier.
//
// The nested typedefs are the ones std::binary_function<T, T, bool> used to
// provide. They are declared directly because std::binary_function was
// deprecated in C++11 and removed in C++17.
template <class T>
struct less_equal {
  typedef T first_argument_type;
  typedef T second_argument_type;
  typedef bool result_type;

  // Returns true when x <= y.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x <= y;
  }
};
298
// Composes two Eigen functors into one binary functor: the result is
// UnaryFunctor applied to the output of BinaryFunctor(a, b). Both functors
// are default-constructed at each call.
template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor>
struct scalar_compose_op {
  // Scalar form.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
  operator()(const Scalar& a, const Scalar& b) const {
    const auto combined = BinaryFunctor()(a, b);
    return UnaryFunctor()(combined);
  }
  // Packet form: the same composition through each functor's packetOp.
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
  packetOp(const Packet& a, const Packet& b) const {
    const auto combined = BinaryFunctor().packetOp(a, b);
    return UnaryFunctor().packetOp(combined);
  }
};
312
// functor_traits for scalar_compose_op: Cost is the sum of the two component
// functors' costs, and PacketAccess is set only when both components report
// PacketAccess.
template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor>
struct functor_traits<scalar_compose_op<Scalar, UnaryFunctor, BinaryFunctor>> {
  enum {
    Cost = functor_traits<UnaryFunctor>::Cost +
           functor_traits<BinaryFunctor>::Cost,
    PacketAccess = functor_traits<UnaryFunctor>::PacketAccess &&
                   functor_traits<BinaryFunctor>::PacketAccess
  };
};
322
// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
//
// Division that rounds the quotient toward negative infinity.
//
// The result is built from the truncating quotient x / y: when the division
// is inexact and the operands have opposite signs, the truncated quotient is
// one too large, so one is subtracted. Unlike forming
// (|x| + |y| - 1) / |y|, this creates no intermediate larger in magnitude
// than x, so inputs such as (INT_MAX, -1) or (INT_MIN, 1) do not overflow.
// The divisor y must be non-zero.
template <typename T, typename Enable = void>
struct google_floor_div {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    const T quotient = x / y;
    // quotient * y differs from x exactly when the division left a remainder.
    const bool inexact = (quotient * y != x);
    const bool signs_differ = ((x < T(0)) != (y < T(0)));
    return (inexact && signs_differ) ? static_cast<T>(quotient - T(1))
                                     : quotient;
  }
};

// Unsigned operands are never negative, so plain division already rounds
// toward negative infinity.
template <typename T>
struct google_floor_div<
    T, typename std::enable_if<std::is_unsigned<T>::value>::type> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    return x / y;
  }
};
346
// functor_traits for google_floor_div<Scalar> (Enable left at its default):
// Cost is two scalar_div_cost<Scalar, false> plus two additions, and
// PacketAccess is false.
template <typename Scalar>
struct functor_traits<google_floor_div<Scalar>> {
  enum {
    Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
           2 * NumTraits<Scalar>::AddCost,
    PacketAccess = false
  };
};
355
356// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
357template <typename T, typename Enable = void>
358struct google_floor_div_real {
359  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
360                                                           const T& y) const {
361    return Eigen::numext::floor(x / y);
362  }
363};
364
365template <typename Scalar>
366struct functor_traits<google_floor_div_real<Scalar>> {
367  enum {
368    Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
369           2 * NumTraits<Scalar>::AddCost,
370    PacketAccess = false
371  };
372};
373
// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
//
// Floating-point modulus whose non-zero results take the sign of the divisor
// y, built on std::fmod (whose result takes the sign of x).
template <typename T>
struct google_floor_fmod {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL);
    const T remainder = std::fmod(x, y);
    const bool same_sign = ((x < T(0)) == (y < T(0)));
    if (same_sign) {
      return remainder;
    }
    // Operands disagree in sign: shift the truncated remainder by y and
    // reduce it again.
    return std::fmod(remainder + y, y);
  }
};
384
// functor_traits for google_floor_fmod: Cost is two
// scalar_div_cost<Scalar, false> plus two additions, and PacketAccess is
// false.
template <typename Scalar>
struct functor_traits<google_floor_fmod<Scalar>> {
  enum {
    Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
           2 * NumTraits<Scalar>::AddCost,
    PacketAccess = false
  };
};
393
// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
//
// Integer modulus whose non-zero results take the sign of the divisor y,
// built on the % operator.
template <typename T>
struct google_floor_mod {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL);
    const T remainder = x % y;
    const bool same_sign = ((x < T(0)) == (y < T(0)));
    if (same_sign) {
      return remainder;
    }
    // Operands disagree in sign: shift the truncated remainder by y and
    // reduce it again.
    return (remainder + y) % y;
  }
};
404
// functor_traits for google_floor_mod: Cost is two
// scalar_div_cost<Scalar, false> plus two additions, and PacketAccess is
// false.
template <typename Scalar>
struct functor_traits<google_floor_mod<Scalar>> {
  enum {
    Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
           2 * NumTraits<Scalar>::AddCost,
    PacketAccess = false
  };
};
413
// Helper macros for bracketing code that compares floating-point values with
// ==. With a GNU-compatible compiler and a language version newer than C++98,
// DISABLE_FLOAT_EQUALITY_WARNING pushes the diagnostic state and ignores
// -Wfloat-equal, and ENABLE_FLOAT_EQUALITY_WARNING pops it again. Otherwise
// both expand to nothing.
// NOTE(review): nothing between these definitions and their #undef further
// down invokes either macro.
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
423
424template <typename Scalar>
425struct scalar_round_op_google {
426  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
427  operator()(const Scalar& x) const {
428    EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
429                        NUMERIC_TYPE_MUST_BE_REAL)
430
431    Scalar round_val = Eigen::numext::floor(x);
432    const Scalar fraction = x - round_val;
433    if (fraction > Scalar(.5)) {
434      round_val += Scalar(1.0);
435    } else if (fraction == Scalar(.5)) {
436      const Scalar nearest_even_int =
437          round_val - Scalar(2) * Eigen::numext::floor(Scalar(.5) * x);
438      bool is_odd = (nearest_even_int == Scalar(1));
439      if (is_odd) {
440        round_val += Scalar(1);
441      }
442    }
443    return round_val;
444  }
445};
446
447template <typename Scalar>
448struct functor_traits<scalar_round_op_google<Scalar>> {
449  enum { Cost = 4 * NumTraits<Scalar>::AddCost, PacketAccess = false };
450};
451
// Remove the float-equality warning helper macros so they do not leak into
// files that include this header.
#undef ENABLE_FLOAT_EQUALITY_WARNING
#undef DISABLE_FLOAT_EQUALITY_WARNING
454
455template <typename Scalar>
456struct bitwise_xor_op {
457  EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op)
458  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
459  operator()(const Scalar& x, const Scalar& y) const {
460    return x ^ y;
461  }
462  typedef typename Eigen::internal::packet_traits<Scalar>::type Packet;
463  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
464                                                        const Packet& b) const {
465    return Eigen::internal::pxor(a, b);
466  }
467};
468
469template <typename Scalar>
470struct functor_traits<bitwise_xor_op<Scalar>> {
471  enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
472};
473
474}  // end namespace internal
475}  // end namespace Eigen
476
477namespace tensorflow {
478namespace functor {
479
480////////////////////////////////////////////////////////////////////////////////
481// Helpers
482////////////////////////////////////////////////////////////////////////////////
483
484// Base template for functors whose input scalar type is T and
485// output scalar type is R.
486template <typename T, typename F, typename R = T>
487struct base {
488  // func defines operator() and its vectorized version packetOp().
489  typedef F func;
490
491  // If true, the functor's corresponding binary op will instantiate
492  // specialized kernels to perform an optimized broadcast
493  // operation. Each functor for which this is enabled increases the
494  // code size, so by default this is disabled for binary functors and
495  // is enabled on a per-op basis as needed.
496  static const bool use_bcast_optimization = false;
497
498  // operator() has the signature:
499  //  out_type operator()(in_type in0, in_type in1 ...)
500  typedef R out_type;
501  typedef T in_type;
502
503  // TensorFlow provides tensor-ized version of "func". Roughly
504  // speaking, the tensorflow operation has the signature:
505  //   tout_type op(tin_type in0)
506  //   tout_type op(tin_type in0, tin_type in1)
507  //   tout_type op(tin_type in0, in_type scalar)
508  typedef typename TTypes<out_type>::Flat tout_type;
509  typedef typename TTypes<in_type>::ConstFlat tin_type;
510  typedef typename TTypes<in_type>::ConstScalar tscalar_type;
511
512  // Whether the functor can error out.  Currently applies only to integer
513  // div and mod.
514  static const bool has_errors = false;
515};
516
517// For now, we only apply certain speed optimization for
518// float/double's broadcast binary op.
// Type trait: `value` is false for every T except the float and double
// specializations below, where it is true.
template <typename T>
struct use_bcast_optimization {
  static const bool value = false;
};

// float opts in.
template <>
struct use_bcast_optimization<float> {
  static const bool value = true;
};

// double opts in.
template <>
struct use_bcast_optimization<double> {
  static const bool value = true;
};
533
534////////////////////////////////////////////////////////////////////////////////
535// Unary functors
536////////////////////////////////////////////////////////////////////////////////
537
538// abs(x) = |x|
539// neg(x) = - x
540// inverse(x) = 1 / x
541// square(x) = x^2
542// sqrt(x) = x^(1/2)
543// rsqrt(x) = x^(-1/2)
544// exp(x) = e^x
545// expm1(x) = e^x - 1
546// log(x) = natural logarithm of x
547// log1p(x) = natural logarithm of 1 + x
548// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
549// sigmoid = 1 / (1 + exp(-x))  // a.k.a, logistic
550//
// NOTE: We may eventually implement common functions used in NN
// here. E.g., rectifier, softplus, derivatives of tanh, sigmoid, etc.
// For reference, see speech/lstm/eigen_functors.h.
554
// Each functor below binds element type T to the Eigen scalar op named in its
// base<> argument list. Input and output scalar types are both T unless a
// third base<> argument is given.

// abs: the output type is scalar_abs_op<T>::result_type rather than T.
template <typename T>
struct abs : base<T, Eigen::internal::scalar_abs_op<T>,
                  typename Eigen::internal::scalar_abs_op<T>::result_type> {};

// neg: wraps scalar_opposite_op.
template <typename T>
struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {};

// inverse: wraps scalar_inverse_op.
template <typename T>
struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {};

// square: wraps scalar_square_op.
template <typename T>
struct square : base<T, Eigen::internal::scalar_square_op<T>> {};

// sqrt: wraps scalar_sqrt_op.
template <typename T>
struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {};

// rsqrt: wraps scalar_rsqrt_op.
template <typename T>
struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {};

// exp: wraps scalar_exp_op.
template <typename T>
struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {};

// expm1: wraps scalar_expm1_op.
template <typename T>
struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {};

// log: wraps scalar_log_op.
template <typename T>
struct log : base<T, Eigen::internal::scalar_log_op<T>> {};

// log1p: wraps scalar_log1p_op.
template <typename T>
struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {};

// sign: wraps scalar_sign_op.
template <typename T>
struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {};
588
// Hyperbolic functions and their inverses. sinh, cosh and tanh wrap Eigen's
// own scalar ops; asinh, acosh and atanh wrap the scalar_*_op functors
// defined earlier in this header.
template <typename T>
struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {};

template <typename T>
struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {};

template <typename T>
struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {};

template <typename T>
struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {};

template <typename T>
struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {};

template <typename T>
struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {};
606
// Special functions and sigmoid, each wrapping the Eigen scalar op of the
// same name with T as both input and output type.
template <typename T>
struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {};

template <typename T>
struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {};

template <typename T>
struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {};

template <typename T>
struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};

template <typename T>
struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T>> {};
621
// Trigonometric functions and their inverses, each wrapping the Eigen scalar
// op of the same name with T as both input and output type.
template <typename T>
struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};

template <typename T>
struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {};

template <typename T>
struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {};

template <typename T>
struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {};

template <typename T>
struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};

template <typename T>
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};

// logical_not is not a template: it is fixed to bool and wraps
// scalar_boolean_not_op<bool>.
struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
};
642
// Bitwise complement: flips every bit of the operand. Called "invert" for
// consistency with numpy.
template <typename T>
struct invert_op {
  EIGEN_EMPTY_STRUCT_CTOR(invert_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
    const T flipped = ~a;
    return flipped;
  }
};
651
// invert: wraps invert_op with T as both input and output type.
template <typename T>
struct invert : base<T, invert_op<T>> {};
654
// NOTE: std::isinf, std::isnan, std::isfinite are plain functions.
// Therefore we need to wrap them in functors to be used with Eigen's
// type system.
// Classification predicates: input type T, output type bool.
template <typename T>
struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {};

template <typename T>
struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {};

template <typename T>
struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {};

// Rounding functors: input and output type T. floor and ceil wrap Eigen's
// ops; round wraps scalar_round_op_google.
template <typename T>
struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {};

template <typename T>
struct round : base<T, Eigen::internal::scalar_round_op_google<T>> {};

template <typename T>
struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {};
675
/** this should go in Eigen
 * \brief Template functor that rounds a scalar to an integral value by
 * calling the platform's rint.
 */
template <typename Scalar>
struct scalar_rint_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
  operator()(const Scalar& a) const {
    // The spelling of rint depends on the toolchain: global-namespace under
    // CUDA, unqualified on Android, std:: everywhere else.
#if defined(__CUDACC__)
    const Scalar rounded = ::rint(a);
#elif defined(__ANDROID__)
    const Scalar rounded = rint(a);
#else
    const Scalar rounded = std::rint(a);
#endif
    return rounded;
  }
};
693
// rint: wraps scalar_rint_op with T as both input and output type.
template <typename T>
struct rint : base<T, scalar_rint_op<T>> {};
696
697////////////////////////////////////////////////////////////////////////////////
698// Binary functors
699////////////////////////////////////////////////////////////////////////////////
700
701// Binary functors:
702//
703// add(x, y) = x + y
704// sub(x, y) = x - y
705// mul(x, y) = x * y
706// div(x, y) = x / y
707// mod(x, y) = x % y         (int32 and int64 only)
708// fmod(x, y) = fmod(x, y)   (float and double only)
709// pow(x, y) = x ^ y
710// maximum(x, y) = x > y ? x : y
711// minimum(x, y) = x < y ? x : y
712// squared_difference(x, y) = (x - y) * (x - y)
713
// Basic arithmetic. add, sub and mul override use_bcast_optimization with
// true; div keeps the value inherited from base.
template <typename T>
struct add : base<T, Eigen::internal::scalar_sum_op<T>> {
  static const bool use_bcast_optimization = true;
};

template <typename T>
struct sub : base<T, Eigen::internal::scalar_difference_op<T>> {
  static const bool use_bcast_optimization = true;
};

template <typename T>
struct mul : base<T, Eigen::internal::scalar_product_op<T>> {
  static const bool use_bcast_optimization = true;
};

template <typename T>
struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {};
731
// Division and modulus variants. Each safe_* functor wraps its operation in
// safe_div_or_mod_op, which reports a zero divisor through its error flag,
// and therefore sets has_errors = true.

// safe_div: scalar_quotient_op behind the zero-divisor check.
template <typename T>
struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
                              T, Eigen::internal::scalar_quotient_op<T>>> {
  static const bool has_errors = true;
};

// fmod: wraps scalar_fmod_op.
template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};

// mod: wraps scalar_mod2_op.
template <typename T>
struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};

// safe_mod: scalar_mod2_op behind the zero-divisor check.
template <typename T>
struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
                              T, Eigen::internal::scalar_mod2_op<T>>> {
  static const bool has_errors = true;
};

// floor_fmod: wraps google_floor_fmod.
template <typename T>
struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {};

// safe_floor_mod: google_floor_mod behind the zero-divisor check.
template <typename T>
struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op<
                                    T, Eigen::internal::google_floor_mod<T>>> {
  static const bool has_errors = true;
};

// floor_div: wraps google_floor_div.
template <typename T>
struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {};

// safe_floor_div: google_floor_div behind the zero-divisor check.
template <typename T>
struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
                                    T, Eigen::internal::google_floor_div<T>>> {
  static const bool has_errors = true;
};

// floor_div_real: wraps google_floor_div_real.
template <typename T>
struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};
770
// pow: wraps scalar_binary_pow_op_google with base and exponent both of
// type T.
template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};

// safe_pow: wraps safe_scalar_binary_pow_op, which reports a negative
// exponent through its error flag, hence has_errors = true.
template <typename T>
struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
  static const bool has_errors = true;
};

// maximum / minimum: wrap scalar_max_op / scalar_min_op.
template <typename T>
struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {};

template <typename T>
struct minimum : base<T, Eigen::internal::scalar_min_op<T>> {};

// Two-argument special functions, each wrapping the Eigen scalar op of the
// same name.
template <typename T>
struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};

template <typename T>
struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};

template <typename T>
struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {};

template <typename T>
struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {};
796
// Binary functor returning atan2(y, x). Note the argument order: the first
// operand is y and the second is x.
template <typename Scalar>
struct scalar_atan2_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
  operator()(const Scalar& y, const Scalar& x) const {
    // CUDA builds call the global-namespace atan2; other builds call
    // std::atan2.
#if GOOGLE_CUDA
    const Scalar angle = ::atan2(y, x);
#else
    const Scalar angle = std::atan2(y, x);
#endif
    return angle;
  }
};
809
// atan2: wraps scalar_atan2_op.
template <typename T>
struct atan2 : base<T, scalar_atan2_op<T>> {};

// squared_difference: composes scalar_square_op over scalar_difference_op
// through scalar_compose_op, i.e. square(x - y).
template <typename T>
struct squared_difference
    : base<T, Eigen::internal::scalar_compose_op<
                  T, Eigen::internal::scalar_square_op<T>,
                  Eigen::internal::scalar_difference_op<T>>> {};

// Comparisons: input type T, output type bool. Each wraps the comparator of
// the same name defined in Eigen::internal earlier in this header.
template <typename T>
struct less : base<T, Eigen::internal::less<T>, bool> {};

template <typename T>
struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {};

template <typename T>
struct greater : base<T, Eigen::internal::greater<T>, bool> {};

template <typename T>
struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};

template <typename T>
struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};

template <typename T>
struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {};

// Logical connectives: not templates, fixed to bool.
struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {};

struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {};
840
// Binary functor returning the bitwise AND of its operands.
template <typename T>
struct bitwise_and_op {
  EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    const T masked = x & y;
    return masked;
  }
};
849
// Binary functor returning the bitwise OR of its operands.
template <typename T>
struct bitwise_or_op {
  EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    const T merged = x | y;
    return merged;
  }
};
858
// Bitwise binary functors with T as both input and output type. bitwise_and
// and bitwise_or wrap the ops defined in this namespace; bitwise_xor wraps
// Eigen::internal::bitwise_xor_op.
template <typename T>
struct bitwise_and : base<T, bitwise_and_op<T>> {};

template <typename T>
struct bitwise_or : base<T, bitwise_or_op<T>> {};

template <typename T>
struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {};
867
// Left-shift functor for integer T. The shift count is clamped to
// [0, bit width of T - 1] before shifting.
template <typename T>
struct left_shift_op {
  EIGEN_EMPTY_STRUCT_CTOR(left_shift_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    using U = typename std::make_unsigned<T>::type;
    // Largest permitted shift count: the bit width of T minus one. It is held
    // as a T so the clamp compares values of the same signedness.
    const T max_shift = static_cast<T>(std::numeric_limits<U>::digits - 1);
    // Avoids UB: don't shift by larger than the bitwidth of T, and
    // performs left shifts as unsigned shifts.
    T y_clamped = y;
    if (y_clamped < T(0)) {
      y_clamped = T(0);
    } else if (y_clamped > max_shift) {
      y_clamped = max_shift;
    }
    return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped));
  }
};
885
// Right-shift functor for integer T. The shift count is clamped to
// [0, bit width of T - 1] before shifting.
template <typename T>
struct right_shift_op {
  EIGEN_EMPTY_STRUCT_CTOR(right_shift_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    using U = typename std::make_unsigned<T>::type;
    // Largest permitted shift count: the bit width of T minus one. It is held
    // as a T so the clamp compares values of the same signedness.
    const T max_shift = static_cast<T>(std::numeric_limits<U>::digits - 1);
    // Avoids UB: don't shift by larger than the bitwidth of T.
    T y_clamped = y;
    if (y_clamped < T(0)) {
      y_clamped = T(0);
    } else if (y_clamped > max_shift) {
      y_clamped = max_shift;
    }
    // Technically right shifts of signed integers are not necessarily
    // arithmetic shifts according to the C++ standard. However in practice most
    // implementations are arithmetic shifts. If this proves to be a problem in
    // practice, we may need to use an alternative implementation.
    return x >> y_clamped;
  }
};
905
// Shift functors with T as both input and output type, wrapping left_shift_op
// and right_shift_op.
template <typename T>
struct left_shift : base<T, left_shift_op<T>> {};

template <typename T>
struct right_shift : base<T, right_shift_op<T>> {};
911
// Binary functor that builds a std::complex<T> from a real part and an
// imaginary part.
template <typename T>
struct make_complex_func {
  using result_type = std::complex<T>;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real,
                                                               T imag) const {
    return result_type(real, imag);
  }
};
920
921template <typename T>
922struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {};
923
// Functor descriptor wrapping Eigen::internal::scalar_real_op<T>. The third
// argument passed to `base` is T::value_type, so T must expose a nested
// value_type (e.g. std::complex<float>).
template <typename T>
struct get_real
    : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {};
927
// Functor descriptor wrapping Eigen::internal::scalar_imag_op<T>. The third
// argument passed to `base` is T::value_type, so T must expose a nested
// value_type (e.g. std::complex<float>).
template <typename T>
struct get_imag
    : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};
931
// Functor descriptor wrapping Eigen::internal::scalar_arg_op<T>. The third
// argument passed to `base` is T::value_type, so T must expose a nested
// value_type (e.g. std::complex<float>).
template <typename T>
struct get_angle
    : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};
935
// Functor descriptor wrapping Eigen::internal::scalar_conjugate_op<T>; no
// third template argument is passed to `base` here.
template <typename T>
struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};
938
939////////////////////////////////////////////////////////////////////////////////
// Functors that take 1 or 2 tensors, compute the base functor on each
// coefficient of the input tensors, and put the results in the output
// tensor.
943////////////////////////////////////////////////////////////////////////////////
// Declares the element-wise unary kernel for a (Device, Functor) pair. Only
// the call operator's signature is given here; its argument types come from
// Functor::tout_type and Functor::tin_type.
template <typename Device, typename Functor>
struct UnaryFunctor {
  // Computes on device "d": out[i] = Functor(in[i])
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in);
};
950
// Declares the element-wise binary kernels for a (Device, Functor) pair.
// NDIMS is the rank of the tensors and broadcast arrays taken by BCast.
// has_errors defaults to Functor::has_errors. Every method takes a
// `bool* error` out-parameter.
// NOTE(review): presumably *error is set when the functor reports an error;
// confirm against the specializations.
template <typename Device, typename Functor, int NDIMS,
          bool has_errors = Functor::has_errors>
struct BinaryFunctor {
  // Computes on device "d": out[i] = Functor(in0[i], in1[i])
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error);

  // Computes on device "d": out[i] = Functor(scalar[0], in[i])
  void Left(const Device& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error);

  // Computes on device "d": out[i] = Functor(in[i], scalar[0])
  void Right(const Device& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error);

  // Computes on device "d":
  //   out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
  //
  // TODO(zhifengc): make BCast a template member function on NDIMS
  // instead of making BinaryFunctor templated on NDIMS.
  void BCast(const Device& d,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error);
};
982
// Declares the kernel that compares flat tensors x and y under `tolerance`
// on device "d" and writes boolean results into z.
// NOTE(review): the exact predicate (presumably |x[i] - y[i]| < tolerance)
// is defined by the specializations, not here; confirm there.
template <typename Device, typename T>
struct ApproximateEqual {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z);
};
989
990template <int NDIMS>
991bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) {
992  for (size_t i = 0; i < a.size(); ++i) {
993    if (a[i] != 1) return false;
994  }
995  return true;
996}
997
// Declares the select kernel taking a flat boolean condition tensor and flat
// `then`/`else` tensors, writing into the flat `out` tensor on device "d".
// NOTE(review): presumably out[i] = cond_flat[i] ? then_flat[i]
// : else_flat[i]; confirm against the specializations.
template <typename Device, typename T>
struct SelectFunctor {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat);
};
1005
// Declares the select kernel whose condition is a single boolean scalar
// rather than a tensor; `then`/`else`/`out` are flat tensors.
// NOTE(review): presumably out = cond ? then_flat : else_flat as a whole;
// confirm against the specializations.
template <typename Device, typename T>
struct SelectScalarFunctor {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstScalar cond,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat);
};
1013
// Declares the select kernel whose condition is a boolean vector while the
// `then`/`else`/output operands are matrices.
// NOTE(review): presumably each cond_vec entry selects an entire row of the
// flattened-outer-dims matrices; confirm against the specializations.
template <typename Device, typename T>
struct BatchSelectFunctor {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims);
};
1022
1023}  // end namespace functor
1024}  // end namespace tensorflow
1025
1026#endif  // TENSORFLOW_KERNELS_CWISE_OPS_H_
1027