math_grad.cc revision bd548874a53c5f934e480991e9ca730b8d73657a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15cd727f537d5085eec7f1b8f9c1d33922d4de75d4Prodyut Hazarika
16f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li#include "tensorflow/cc/ops/standard_ops.h"
17f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li
18f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li#include "tensorflow/cc/framework/grad_op_registry.h"
19f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li
20f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Linamespace tensorflow {
21f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Linamespace ops {
22f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Linamespace {
230594c42af26255fd8d3d7d39c0cb0e2da5b8841bThierry Strudel
240594c42af26255fd8d3d7d39c0cb0e2da5b8841bThierry Strudel// MatMulGrad helper function used to compute two MatMul operations
250594c42af26255fd8d3d7d39c0cb0e2da5b8841bThierry Strudel// based on input matrix transposition combinations.
26f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie LiStatus MatMulGradHelper(const Scope& scope, const bool is_batch,
27f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                        const Output& x0, const bool adj_x0, const Output& x1,
28f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                        const bool adj_x1, const Output& y0, const bool adj_y0,
29f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                        const Output& y1, const bool adj_y1,
30c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu                        std::vector<Output>* grad_outputs) {
31f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  if (is_batch == false) {
32f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    auto dx =
33f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
34f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    grad_outputs->push_back(dx);
35af3bf2227c951a59e2dcc44ab90790d247225375Andy Qiu    auto dy =
36af3bf2227c951a59e2dcc44ab90790d247225375Andy Qiu        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
37f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    grad_outputs->push_back(dy);
38f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  } else {
39f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    auto dx =
40f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
41c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu    grad_outputs->push_back(dx);
42f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    auto dy =
43f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
44f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    grad_outputs->push_back(dy);
459d1d3833469f52dbd2a017702bf0116fddc703bcAndy Qiu  }
46f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  return scope.status();
47f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li}
48f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li
49c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu// MatMulGrad common used to read and check node attr state, and determine
50c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu// proper MatMul products for gradients based on input matrix transposition
51f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li// combinations.
5250fdcdd34912ff25d41e4d339298043d97e56ea3Sun, Jian// TODO(andydavis) Re-use this function for BatchMatMulGrad.
5350fdcdd34912ff25d41e4d339298043d97e56ea3Sun, JianStatus MatMulGradCommon(const Scope& scope, const Operation& op,
54979bcaa58b8db871baf3fd8cc03071d35190f194Lei Zhang                        const bool is_batch,
55f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                        const std::vector<Output>& grad_inputs,
56f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                        const string& attr_adj_x, const string& attr_adj_y,
57c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu                        std::vector<Output>* grad_outputs) {
58f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  DataType dtype;
5903ec9fafe981e98a32150dfb1ded2da6a84c212dAndy Qiu  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
60f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
61f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    return errors::Unimplemented(
62f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li        "MatMul gradient for complex data type is not supported yet.");
63f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  }
64c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu
65f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  bool ta;
66f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  bool tb;
67f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
68f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
69f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li
70c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu  if (!ta && !tb) {
71f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
72f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                            true, op.input(0), true, grad_inputs[0], false,
73af3bf2227c951a59e2dcc44ab90790d247225375Andy Qiu                            grad_outputs);
74f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  } else if (!ta && tb) {
75f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
76f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                            false, grad_inputs[0], true, op.input(0), false,
77f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                            grad_outputs);
78f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li  } else if (ta && !tb) {
79f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li    return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0],
80f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li                            true, op.input(0), false, grad_inputs[0], false,
81                            grad_outputs);
82  }
83  return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0],
84                          true, grad_inputs[0], true, op.input(0), true,
85                          grad_outputs);
86}
87
88Status MatMulGrad(const Scope& scope, const Operation& op,
89                  const std::vector<Output>& grad_inputs,
90                  std::vector<Output>* grad_outputs) {
91  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
92                          "transpose_b", grad_outputs);
93}
94
95REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
96
97Status BatchMatMulGrad(const Scope& scope, const Operation& op,
98                       const std::vector<Output>& grad_inputs,
99                       std::vector<Output>* grad_outputs) {
100  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
101                          grad_outputs);
102}
103REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);
104
105}  // anonymous namespace
106}  // namespace ops
107}  // namespace tensorflow
108