// math_grad_test.cc revision 23441bf71dbd0a10bcace160165ce01bf6deac9a
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/cc/framework/grad_op_registry.h" 17#include "tensorflow/cc/ops/standard_ops.h" 18#include "tensorflow/core/framework/node_def_util.h" 19#include "tensorflow/core/framework/tensor_testutil.h" 20#include "tensorflow/core/graph/default_device.h" 21#include "tensorflow/core/lib/core/status_test_util.h" 22#include "tensorflow/core/lib/random/random.h" 23#include "tensorflow/core/platform/test.h" 24#include "tensorflow/core/public/session.h" 25 26namespace tensorflow { 27using namespace ops; // NOLINT(build/namespaces) 28 29namespace { 30 31// TODO(andydavis) Test gradient function against numeric gradients output. 32// TODO(andydavis) As more gradients are added move common test functions 33// to a testutil library. 34class MathGradTest : public ::testing::Test { 35 protected: 36 MathGradTest() : root_(Scope::NewRootScope()) {} 37 38 void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y, 39 const bool t_y, const Output& dz, 40 std::vector<Tensor>* out) { 41 // Compute forward MatMul: z = MatMul(x, y). 42 auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 43 TF_EXPECT_OK(root_.status()); 44 CHECK_NOTNULL(z.node()); 45 std::vector<Output> grad_outputs; 46 // Call MatMulGrad which populates 'grad_outputs'. 
47 CallGradFunction(Operation(z.node()), {dz}, &grad_outputs); 48 EXPECT_EQ(2, grad_outputs.size()); 49 // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'. 50 GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out); 51 } 52 53 void CallGradFunction(const Operation& op, 54 const std::vector<Output>& grad_inputs, 55 std::vector<Output>* grad_outputs) { 56 GradFunc grad_fn; 57 TF_EXPECT_OK( 58 GradOpRegistry::Global()->Lookup(op.node()->type_string(), &grad_fn)); 59 TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs)); 60 TF_EXPECT_OK(root_.status()); 61 } 62 63 Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y, 64 const bool t_y) { 65 auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); 66 TF_EXPECT_OK(root_.status()); 67 Tensor out; 68 GetTensor(root_, z, &out); 69 return out; 70 } 71 72 void RandMatMulGradData(const bool tx, const bool ty, 73 std::vector<Tensor>* data) { 74 // z = MatMul(x, y) 75 const int m = Rand(); 76 const int k = Rand(); 77 const int n = Rand(); 78 // x.shape = [m, k] 79 const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k}); 80 data->emplace_back(DT_FLOAT, x_shape); 81 RandTensor(&data->back()); 82 // y.shape = [k, n] 83 const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n}); 84 data->emplace_back(DT_FLOAT, y_shape); 85 RandTensor(&data->back()); 86 // z.shape = [m, n] 87 data->emplace_back(DT_FLOAT, TensorShape({m, n})); 88 RandTensor(&data->back()); 89 } 90 91 void RandTensor(Tensor* t) { 92 test::FillFn<float>( 93 t, [this](const int i) { return static_cast<float>(Rand()); }); 94 } 95 96 int Rand() { return 1 + (random::New64() % 10); } 97 98 // TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class. 99 // Note: they should be moved to a general/non-grad specific testutil class. 
100 void GetTensors(const Scope& scope, OutputList tensors, 101 std::vector<Tensor>* out) { 102 SessionOptions options; 103 std::unique_ptr<Session> session(NewSession(options)); 104 GraphDef def; 105 scope.graph()->ToGraphDef(&def); 106 107 graph::SetDefaultDevice("/cpu:0", &def); 108 109 TF_CHECK_OK(session->Create(def)); 110 std::vector<string> names; 111 for (const auto& t : tensors) { 112 names.push_back(strings::StrCat(t.node()->name(), ":", t.index())); 113 } 114 TF_CHECK_OK(session->Run({}, names, {}, out)); 115 TF_CHECK_OK(session->Close()); 116 } 117 118 void GetTensor(const Scope& scope, Output tensor, Tensor* out) { 119 std::vector<Tensor> outputs; 120 GetTensors(scope, {tensor}, &outputs); 121 *out = outputs[0]; 122 } 123 124 Scope root_; 125}; 126 127TEST_F(MathGradTest, MatMulGrad_NoTranspose) { 128 std::vector<Tensor> data; 129 RandMatMulGradData(false, false, &data); 130 auto x = Const(root_, data[0]); 131 auto y = Const(root_, data[1]); 132 auto dz = Const(root_, data[2]); 133 134 std::vector<Tensor> grad_outputs; 135 ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs); 136 137 test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true)); 138 test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false)); 139} 140 141TEST_F(MathGradTest, MatMulGrad_TransposeX) { 142 std::vector<Tensor> data; 143 RandMatMulGradData(true, false, &data); 144 auto x = Const(root_, data[0]); 145 auto y = Const(root_, data[1]); 146 auto dz = Const(root_, data[2]); 147 148 std::vector<Tensor> grad_outputs; 149 ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs); 150 151 test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true)); 152 test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false)); 153} 154 155TEST_F(MathGradTest, MatMulGrad_TransposeY) { 156 std::vector<Tensor> data; 157 RandMatMulGradData(false, true, &data); 158 auto x = Const(root_, data[0]); 159 auto y = Const(root_, data[1]); 160 auto dz = Const(root_, 
data[2]); 161 162 std::vector<Tensor> grad_outputs; 163 ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs); 164 165 test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false)); 166 test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false)); 167} 168 169TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { 170 std::vector<Tensor> data; 171 RandMatMulGradData(true, true, &data); 172 auto x = Const(root_, data[0]); 173 auto y = Const(root_, data[1]); 174 auto dz = Const(root_, data[2]); 175 176 std::vector<Tensor> grad_outputs; 177 ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs); 178 179 test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true)); 180 test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true)); 181} 182 183} // namespace 184} // namespace tensorflow 185