// math_grad_test.cc revision 23441bf71dbd0a10bcace160165ce01bf6deac9a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/cc/framework/grad_op_registry.h"
17#include "tensorflow/cc/ops/standard_ops.h"
18#include "tensorflow/core/framework/node_def_util.h"
19#include "tensorflow/core/framework/tensor_testutil.h"
20#include "tensorflow/core/graph/default_device.h"
21#include "tensorflow/core/lib/core/status_test_util.h"
22#include "tensorflow/core/lib/random/random.h"
23#include "tensorflow/core/platform/test.h"
24#include "tensorflow/core/public/session.h"
25
26namespace tensorflow {
27using namespace ops;  // NOLINT(build/namespaces)
28
29namespace {
30
31// TODO(andydavis) Test gradient function against numeric gradients output.
32// TODO(andydavis) As more gradients are added move common test functions
33// to a testutil library.
34class MathGradTest : public ::testing::Test {
35 protected:
36  MathGradTest() : root_(Scope::NewRootScope()) {}
37
38  void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
39                         const bool t_y, const Output& dz,
40                         std::vector<Tensor>* out) {
41    // Compute forward MatMul: z = MatMul(x, y).
42    auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
43    TF_EXPECT_OK(root_.status());
44    CHECK_NOTNULL(z.node());
45    std::vector<Output> grad_outputs;
46    // Call MatMulGrad which populates 'grad_outputs'.
47    CallGradFunction(Operation(z.node()), {dz}, &grad_outputs);
48    EXPECT_EQ(2, grad_outputs.size());
49    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
50    GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
51  }
52
53  void CallGradFunction(const Operation& op,
54                        const std::vector<Output>& grad_inputs,
55                        std::vector<Output>* grad_outputs) {
56    GradFunc grad_fn;
57    TF_EXPECT_OK(
58        GradOpRegistry::Global()->Lookup(op.node()->type_string(), &grad_fn));
59    TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs));
60    TF_EXPECT_OK(root_.status());
61  }
62
63  Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
64                       const bool t_y) {
65    auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
66    TF_EXPECT_OK(root_.status());
67    Tensor out;
68    GetTensor(root_, z, &out);
69    return out;
70  }
71
72  void RandMatMulGradData(const bool tx, const bool ty,
73                          std::vector<Tensor>* data) {
74    // z = MatMul(x, y)
75    const int m = Rand();
76    const int k = Rand();
77    const int n = Rand();
78    // x.shape = [m, k]
79    const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
80    data->emplace_back(DT_FLOAT, x_shape);
81    RandTensor(&data->back());
82    // y.shape = [k, n]
83    const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
84    data->emplace_back(DT_FLOAT, y_shape);
85    RandTensor(&data->back());
86    // z.shape = [m, n]
87    data->emplace_back(DT_FLOAT, TensorShape({m, n}));
88    RandTensor(&data->back());
89  }
90
91  void RandTensor(Tensor* t) {
92    test::FillFn<float>(
93        t, [this](const int i) { return static_cast<float>(Rand()); });
94  }
95
96  int Rand() { return 1 + (random::New64() % 10); }
97
98  // TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class.
99  // Note: they should be moved to a general/non-grad specific testutil class.
100  void GetTensors(const Scope& scope, OutputList tensors,
101                  std::vector<Tensor>* out) {
102    SessionOptions options;
103    std::unique_ptr<Session> session(NewSession(options));
104    GraphDef def;
105    scope.graph()->ToGraphDef(&def);
106
107    graph::SetDefaultDevice("/cpu:0", &def);
108
109    TF_CHECK_OK(session->Create(def));
110    std::vector<string> names;
111    for (const auto& t : tensors) {
112      names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
113    }
114    TF_CHECK_OK(session->Run({}, names, {}, out));
115    TF_CHECK_OK(session->Close());
116  }
117
118  void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
119    std::vector<Tensor> outputs;
120    GetTensors(scope, {tensor}, &outputs);
121    *out = outputs[0];
122  }
123
124  Scope root_;
125};
126
127TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
128  std::vector<Tensor> data;
129  RandMatMulGradData(false, false, &data);
130  auto x = Const(root_, data[0]);
131  auto y = Const(root_, data[1]);
132  auto dz = Const(root_, data[2]);
133
134  std::vector<Tensor> grad_outputs;
135  ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);
136
137  test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
138  test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
139}
140
141TEST_F(MathGradTest, MatMulGrad_TransposeX) {
142  std::vector<Tensor> data;
143  RandMatMulGradData(true, false, &data);
144  auto x = Const(root_, data[0]);
145  auto y = Const(root_, data[1]);
146  auto dz = Const(root_, data[2]);
147
148  std::vector<Tensor> grad_outputs;
149  ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);
150
151  test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
152  test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
153}
154
155TEST_F(MathGradTest, MatMulGrad_TransposeY) {
156  std::vector<Tensor> data;
157  RandMatMulGradData(false, true, &data);
158  auto x = Const(root_, data[0]);
159  auto y = Const(root_, data[1]);
160  auto dz = Const(root_, data[2]);
161
162  std::vector<Tensor> grad_outputs;
163  ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);
164
165  test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
166  test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
167}
168
169TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
170  std::vector<Tensor> data;
171  RandMatMulGradData(true, true, &data);
172  auto x = Const(root_, data[0]);
173  auto y = Const(root_, data[1]);
174  auto dz = Const(root_, data[2]);
175
176  std::vector<Tensor> grad_outputs;
177  ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);
178
179  test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
180  test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
181}
182
183}  // namespace
184}  // namespace tensorflow
185