// array_grad_test.cc revision 3b13cfdb8970fab86d5923d47791004eca92d4ff
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/cc/framework/grad_op_registry.h" 17#include "tensorflow/cc/framework/gradient_checker.h" 18#include "tensorflow/cc/framework/testutil.h" 19#include "tensorflow/cc/gradients/grad_testutil.h" 20#include "tensorflow/cc/ops/standard_ops.h" 21#include "tensorflow/core/framework/tensor_testutil.h" 22#include "tensorflow/core/lib/core/status_test_util.h" 23 24namespace tensorflow { 25using namespace ops; // NOLINT(build/namespaces) 26 27namespace { 28 29class PackGradTest : public ::testing::Test { 30 protected: 31 PackGradTest() : scope_(Scope::NewRootScope()) {} 32 33 void CheckGrad(const Output& grad_input, const int axis) { 34 auto a = ops::Const(scope_, 1, {2, 3}); 35 auto b = ops::Const(scope_, 2, {2, 3}); 36 37 auto pack = Pack(scope_, {a, b}, Pack::Axis(axis)); 38 TF_ASSERT_OK(scope_.status()); 39 40 std::vector<Output> grad_outputs; 41 TF_ASSERT_OK(test::CallGradFunction(scope_, Operation(pack.node()), 42 {grad_input}, &grad_outputs)); 43 44 std::vector<Tensor> outputs; 45 test::GetTensors(scope_, {grad_outputs[0], grad_outputs[1]}, &outputs); 46 47 test::ExpectTensorEqual<int>( 48 outputs[0], test::AsTensor<int>({1, 2, 3, 4, 5, 6}, {2, 3})); 49 test::ExpectTensorEqual<int>( 50 outputs[1], test::AsTensor<int>({7, 8, 9, 10, 11, 12}, {2, 3})); 51 } 52 53 Scope scope_; 
54}; 55 56TEST_F(PackGradTest, Axis0) { 57 CheckGrad( 58 ops::Const(scope_, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 2, 3}), 59 0); 60} 61 62TEST_F(PackGradTest, Axis1) { 63 CheckGrad( 64 ops::Const(scope_, {1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}, {2, 2, 3}), 65 1); 66} 67 68class UnpackGradTest : public ::testing::Test { 69 protected: 70 UnpackGradTest() : scope_(Scope::NewRootScope()) {} 71 72 void CheckGrad(const std::vector<Output>& grad_inputs, const int num, 73 const int axis) { 74 auto a = ops::Const(scope_, 1, {4, 2, 3}); 75 76 auto unpack = Unpack(scope_, a, num, Unpack::Axis(axis)); 77 TF_ASSERT_OK(scope_.status()); 78 79 std::vector<Output> grad_outputs; 80 TF_ASSERT_OK(test::CallGradFunction(scope_, Operation(unpack[0].node()), 81 grad_inputs, &grad_outputs)); 82 83 Tensor expected_output(DT_INT32, {4, 2, 3}); 84 test::FillIota<int32>(&expected_output, 1); 85 86 Tensor output; 87 test::GetTensor(scope_, grad_outputs[0], &output); 88 89 test::ExpectTensorEqual<int>(output, expected_output); 90 } 91 92 Scope scope_; 93}; 94 95TEST_F(UnpackGradTest, Axis0) { 96 auto g0 = ops::Const(scope_, {1, 2, 3, 4, 5, 6}, {2, 3}); 97 auto g1 = ops::Const(scope_, {7, 8, 9, 10, 11, 12}, {2, 3}); 98 auto g2 = ops::Const(scope_, {13, 14, 15, 16, 17, 18}, {2, 3}); 99 auto g3 = ops::Const(scope_, {19, 20, 21, 22, 23, 24}, {2, 3}); 100 CheckGrad({g0, g1, g2, g3}, 4, 0); 101} 102 103TEST_F(UnpackGradTest, Axis1) { 104 auto g0 = 105 ops::Const(scope_, {{1, 2, 3}, {7, 8, 9}, {13, 14, 15}, {19, 20, 21}}); 106 auto g1 = 107 ops::Const(scope_, {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}}); 108 CheckGrad({g0, g1}, 2, 1); 109} 110 111class ArrayGradTest : public ::testing::Test { 112 protected: 113 ArrayGradTest() : scope_(Scope::NewRootScope()) {} 114 115 void RunTest(const Output& x, const TensorShape& x_shape, const Output& y, 116 const TensorShape& y_shape) { 117 float max_error; 118 TF_ASSERT_OK( 119 ComputeGradientError(scope_, x, x_shape, y, y_shape, &max_error)); 
120 EXPECT_LT(max_error, 1e-4); 121 } 122 123 Scope scope_; 124}; 125 126TEST_F(ArrayGradTest, IdentityGrad) { 127 TensorShape shape({5, 2}); 128 auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); 129 auto y = Identity(scope_, x); 130 RunTest(x, shape, y, shape); 131} 132 133} // namespace 134} // namespace tensorflow 135