// math_grad.cc revision bd548874a53c5f934e480991e9ca730b8d73657a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
14cd727f537d5085eec7f1b8f9c1d33922d4de75d4Prodyut Hazarika==============================================================================*/ 15cd727f537d5085eec7f1b8f9c1d33922d4de75d4Prodyut Hazarika 16f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li#include "tensorflow/cc/ops/standard_ops.h" 17f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li 18f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li#include "tensorflow/cc/framework/grad_op_registry.h" 19f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li 20f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Linamespace tensorflow { 21f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Linamespace ops { 22f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Linamespace { 230594c42af26255fd8d3d7d39c0cb0e2da5b8841bThierry Strudel 240594c42af26255fd8d3d7d39c0cb0e2da5b8841bThierry Strudel// MatMulGrad helper function used to compute two MatMul operations 250594c42af26255fd8d3d7d39c0cb0e2da5b8841bThierry Strudel// based on input matrix transposition combinations. 26f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie LiStatus MatMulGradHelper(const Scope& scope, const bool is_batch, 27f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li const Output& x0, const bool adj_x0, const Output& x1, 28f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li const bool adj_x1, const Output& y0, const bool adj_y0, 29f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li const Output& y1, const bool adj_y1, 30c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu std::vector<Output>* grad_outputs) { 31f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li if (is_batch == false) { 32f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li auto dx = 33f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1)); 34f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li grad_outputs->push_back(dx); 35af3bf2227c951a59e2dcc44ab90790d247225375Andy Qiu auto dy = 36af3bf2227c951a59e2dcc44ab90790d247225375Andy Qiu MatMul(scope, y0, y1, 
MatMul::TransposeA(adj_y0).TransposeB(adj_y1)); 37f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li grad_outputs->push_back(dy); 38f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li } else { 39f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li auto dx = 40f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1)); 41c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu grad_outputs->push_back(dx); 42f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li auto dy = 43f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1)); 44f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li grad_outputs->push_back(dy); 459d1d3833469f52dbd2a017702bf0116fddc703bcAndy Qiu } 46f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li return scope.status(); 47f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li} 48f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li 49c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu// MatMulGrad common used to read and check node attr state, and determine 50c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu// proper MatMul products for gradients based on input matrix transposition 51f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li// combinations. 5250fdcdd34912ff25d41e4d339298043d97e56ea3Sun, Jian// TODO(andydavis) Re-use this function for BatchMatMulGrad. 
5350fdcdd34912ff25d41e4d339298043d97e56ea3Sun, JianStatus MatMulGradCommon(const Scope& scope, const Operation& op, 54979bcaa58b8db871baf3fd8cc03071d35190f194Lei Zhang const bool is_batch, 55f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li const std::vector<Output>& grad_inputs, 56f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li const string& attr_adj_x, const string& attr_adj_y, 57c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu std::vector<Output>* grad_outputs) { 58f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li DataType dtype; 5903ec9fafe981e98a32150dfb1ded2da6a84c212dAndy Qiu TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype)); 60f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { 61f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li return errors::Unimplemented( 62f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li "MatMul gradient for complex data type is not supported yet."); 63f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li } 64c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu 65f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li bool ta; 66f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li bool tb; 67f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta)); 68f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb)); 69f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li 70c0529447ae16f023dfab2978ea2b245f368e893bAndy Qiu if (!ta && !tb) { 71f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1), 72f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li true, op.input(0), true, grad_inputs[0], false, 73af3bf2227c951a59e2dcc44ab90790d247225375Andy Qiu grad_outputs); 74f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li } else if (!ta && tb) { 75f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie 
Li return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1), 76f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li false, grad_inputs[0], true, op.input(0), false, 77f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li grad_outputs); 78f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li } else if (ta && !tb) { 79f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li return MatMulGradHelper(scope, is_batch, op.input(1), false, grad_inputs[0], 80f6d5b36e320f093f08855d64fa3d565eacae3c4bJackie Li true, op.input(0), false, grad_inputs[0], false, 81 grad_outputs); 82 } 83 return MatMulGradHelper(scope, is_batch, op.input(1), true, grad_inputs[0], 84 true, grad_inputs[0], true, op.input(0), true, 85 grad_outputs); 86} 87 88Status MatMulGrad(const Scope& scope, const Operation& op, 89 const std::vector<Output>& grad_inputs, 90 std::vector<Output>* grad_outputs) { 91 return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a", 92 "transpose_b", grad_outputs); 93} 94 95REGISTER_GRADIENT_OP("MatMul", MatMulGrad); 96 97Status BatchMatMulGrad(const Scope& scope, const Operation& op, 98 const std::vector<Output>& grad_inputs, 99 std::vector<Output>* grad_outputs) { 100 return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y", 101 grad_outputs); 102} 103REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad); 104 105} // anonymous namespace 106} // namespace ops 107} // namespace tensorflow 108