1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
3084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
4084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFloweryou may not use this file except in compliance with the License.
5084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerYou may obtain a copy of the License at
6084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
7084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
8084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
9084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
10084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
11084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerSee the License for the specific language governing permissions and
13084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerlimitations under the License.
14084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower==============================================================================*/
15084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
169f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlower#include <memory>
17b481783fe0e00a86f6feb20a8dcad5fc4fc936a4Josh Levenberg#include <vector>
189f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlower
19084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower#include "tensorflow/core/framework/function_testlib.h"
20084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower#include "tensorflow/core/framework/tensor_testutil.h"
21084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower#include "tensorflow/core/platform/test.h"
22084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower#include "tensorflow/core/public/session.h"
23084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
24084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowernamespace tensorflow {
259f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowernamespace {
26084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
27084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowernamespace f = test::function;
289f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerusing FDH = FunctionDefHelper;
29084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
309f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerstd::unique_ptr<Session> NewSession() {
31084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  SessionOptions opts;
32084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  (*opts.config.mutable_device_count())["CPU"] = 1;
339f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlower  return std::unique_ptr<Session>(NewSession(opts));
34084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
35084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
36084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerstd::vector<Tensor> PackGrad(const Tensor& x0, const Tensor& x1,
37eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                             const Tensor& dy, int axis) {
38084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto T = DT_FLOAT;
39084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto gdef = test::function::GDef(
40084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      {f::NDef("x0", "Placeholder", {}, {{"dtype", T}}),
41084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("x1", "Placeholder", {}, {{"dtype", T}}),
42eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower       f::NDef("axis", "Placeholder", {}, {{"dtype", DT_INT32}}),
43084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
44084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
45eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower               {{"f", FDH::FunctionRef("Pack",
46eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                                       {{"N", 2}, {"T", T}, {"axis", axis}})},
47084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                {"Tin", DataTypeSlice{T, T, T}},
48084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                {"Tout", DataTypeSlice{T, T}}})});
49084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
50084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto sess = NewSession();
51084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
52084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  std::vector<Tensor> out;
53eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower  TF_CHECK_OK(sess->Run({{"x0:0", x0},
54eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                         {"x1:0", x1},
55eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                         {"axis:0", test::AsScalar(axis)},
56eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                         {"dy:0", dy}},
57084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                        {"dx:0", "dx:1"}, {}, &out));
58084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  CHECK_EQ(out.size(), 2);
59084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  TF_CHECK_OK(sess->Close());
60084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  return out;
61084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
62084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
639f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, PackGrad) {
64084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor x0(DT_FLOAT, {2, 3});
65084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  x0.flat<float>().setZero();
66084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor x1(DT_FLOAT, {2, 3});
67084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  x1.flat<float>().setZero();
68084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor dy(DT_FLOAT, {2, 2, 3});
69084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::FillIota<float>(&dy, 0);
70eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower  auto dx = PackGrad(x0, x1, dy, 0);
71084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::ExpectClose(dx[0],
72084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                    test::AsTensor<float>({0., 1., 2., 3., 4., 5.}, {2, 3}));
73084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::ExpectClose(dx[1],
74084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                    test::AsTensor<float>({6., 7., 8., 9., 10., 11.}, {2, 3}));
75084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
76084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
77084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerstd::vector<Tensor> UnpackGrad(const Tensor& x, const Tensor& dy0,
78eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                               const Tensor& dy1, int axis) {
79084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto T = DT_FLOAT;
80084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto gdef = test::function::GDef(
81084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
82eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower       f::NDef("axis", "Placeholder", {}, {{"dtype", DT_INT32}}),
83084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dy0", "Placeholder", {}, {{"dtype", T}}),
84084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dy1", "Placeholder", {}, {{"dtype", T}}),
85084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"x", "dy0", "dy1"},
86eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower               {{"f", FDH::FunctionRef("Unpack",
87eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                                       {{"num", 2}, {"T", T}, {"axis", axis}})},
88084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                {"Tin", DataTypeSlice{T, T, T}},
89084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                {"Tout", DataTypeSlice{T}}})});
90084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
91084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto sess = NewSession();
92084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
93084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  std::vector<Tensor> out;
94eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower  TF_CHECK_OK(sess->Run({{"x:0", x},
95eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                         {"axis:0", test::AsScalar(axis)},
96eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                         {"dy0:0", dy0},
97eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                         {"dy1:0", dy1}},
98eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower                        {"dx:0"}, {}, &out));
99084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  CHECK_EQ(out.size(), 1);
100084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  TF_CHECK_OK(sess->Close());
101084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  return out;
102084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
103084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
1049f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, UnpackGrad) {
105084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor x(DT_FLOAT, {2, 2, 3});
106084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  x.flat<float>().setZero();
107084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor dy0(DT_FLOAT, {2, 3});
108084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor dy1(DT_FLOAT, {2, 3});
109084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::FillIota<float>(&dy0, 0);
110084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::FillIota<float>(&dy1, 100);
111eff93149a6dc8e6826898fd9f9c28c81e21c9836A. Unique TensorFlower  auto dx = UnpackGrad(x, dy0, dy1, 0);
112084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::ExpectClose(dx[0], test::AsTensor<float>({0., 1., 2., 3., 4., 5., 100.,
113084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                                                  101., 102., 103., 104., 105.},
114084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                                                 {2, 2, 3}));
115084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
116084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
117084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerstd::vector<Tensor> ConcatGrad(int dim, const Tensor& x0, const Tensor& x1,
118084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                               const Tensor& dy) {
119084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto T = DT_FLOAT;
120084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto gdef = test::function::GDef(
121084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      {f::NDef("dim", "Placeholder", {}, {{"dtype", DT_INT32}}),
122084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("x0", "Placeholder", {}, {{"dtype", T}}),
123084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("x1", "Placeholder", {}, {{"dtype", T}}),
124084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
125084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"dim", "x0", "x1", "dy"},
126084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower               {{"f", FDH::FunctionRef("Concat", {{"N", 2}, {"T", T}})},
127084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                {"Tin", DataTypeSlice{DT_INT32, T, T, T}},
128084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                {"Tout", DataTypeSlice{DT_INT32, T, T}}})});
129084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
130084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto sess = NewSession();
131084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
132084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  std::vector<Tensor> out;
133cf6cdcf6dc7e5dba2b0f729b86bf1f14ca73666eManjunath Kudlur  TF_CHECK_OK(sess->Run(
134cf6cdcf6dc7e5dba2b0f729b86bf1f14ca73666eManjunath Kudlur      {{"dim", test::AsScalar(dim)}, {"x0:0", x0}, {"x1:0", x1}, {"dy:0", dy}},
135cf6cdcf6dc7e5dba2b0f729b86bf1f14ca73666eManjunath Kudlur      {"dx:0", "dx:1", "dx:2"}, {}, &out));
136084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  CHECK_EQ(out.size(), 3);
137084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  TF_CHECK_OK(sess->Close());
138084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  return out;
139084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
140084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
14195b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlowerstd::vector<Tensor> ConcatGradV2(int dim, const Tensor& x0, const Tensor& x1,
14295b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower                                 const Tensor& dy) {
14395b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  auto T = DT_FLOAT;
14495b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  auto gdef = test::function::GDef(
14595b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower      {f::NDef("x0", "Placeholder", {}, {{"dtype", T}}),
14695b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower       f::NDef("x1", "Placeholder", {}, {{"dtype", T}}),
14795b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower       f::NDef("dim", "Placeholder", {}, {{"dtype", DT_INT32}}),
14895b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
14995b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"x0", "x1", "dim", "dy"},
15095b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower               {{"f", FDH::FunctionRef("ConcatV2", {{"N", 2}, {"T", T}})},
15195b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower                {"Tin", DataTypeSlice{T, T, DT_INT32, T}},
15295b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower                {"Tout", DataTypeSlice{T, T, DT_INT32}}})});
15395b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
15495b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  auto sess = NewSession();
15595b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
15695b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  std::vector<Tensor> out;
15795b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  TF_CHECK_OK(sess->Run(
15895b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower      {{"x0:0", x0}, {"x1:0", x1}, {"dim", test::AsScalar(dim)}, {"dy:0", dy}},
15995b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower      {"dx:0", "dx:1", "dx:2"}, {}, &out));
16095b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  CHECK_EQ(out.size(), 3);
16195b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  TF_CHECK_OK(sess->Close());
16295b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  return out;
16395b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower}
16495b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower
1659f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, ConcatGrad) {
166084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor x0(DT_FLOAT, {2, 3, 5});
167084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  x0.flat<float>().setZero();
168084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor x1(DT_FLOAT, {2, 1, 5});
169084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  x1.flat<float>().setZero();
170084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor dy(DT_FLOAT, {2, 4, 5});
171084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::FillIota<float>(&dy, 0);
172f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower
173f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower  // Test Concat.
174084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto dx = ConcatGrad(1, x0, x1, dy);
175cf6cdcf6dc7e5dba2b0f729b86bf1f14ca73666eManjunath Kudlur  test::ExpectTensorEqual<int32>(dx[0], test::AsScalar(0));
176084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::ExpectClose(
177084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      dx[1],
178084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      test::AsTensor<float>({0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,
179084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                             10., 11., 12., 13., 14., 20., 21., 22., 23., 24.,
180084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                             25., 26., 27., 28., 29., 30., 31., 32., 33., 34.},
181084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                            {2, 3, 5}));
182084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::ExpectClose(dx[2], test::AsTensor<float>({15., 16., 17., 18., 19., 35.,
183084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                                                  36., 37., 38., 39.},
184084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                                                 {2, 1, 5}));
18595b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower
186f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower  // Test ConcatV2 with positive concat axis.
18795b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  dx = ConcatGradV2(1, x0, x1, dy);
18895b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  test::ExpectTensorEqual<int32>(dx[dx.size() - 1], test::AsScalar(0));
18995b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  test::ExpectClose(
19095b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower      dx[0],
19195b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower      test::AsTensor<float>({0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,
19295b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower                             10., 11., 12., 13., 14., 20., 21., 22., 23., 24.,
19395b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower                             25., 26., 27., 28., 29., 30., 31., 32., 33., 34.},
19495b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower                            {2, 3, 5}));
195f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower  test::ExpectClose(dx[1], test::AsTensor<float>({15., 16., 17., 18., 19., 35.,
196f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower                                                  36., 37., 38., 39.},
197f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower                                                 {2, 1, 5}));
198f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower
199f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower  // Test ConcatV2 with negative concat axis.
200f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower  dx = ConcatGradV2(-2, x0, x1, dy);
201f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower  test::ExpectTensorEqual<int32>(dx[dx.size() - 1], test::AsScalar(0));
202f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower  test::ExpectClose(
203f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower      dx[0],
204f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower      test::AsTensor<float>({0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,
205f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower                             10., 11., 12., 13., 14., 20., 21., 22., 23., 24.,
206f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower                             25., 26., 27., 28., 29., 30., 31., 32., 33., 34.},
207f3d5990d6971faf9b15f5e6fa53e2e640742b8e0A. Unique TensorFlower                            {2, 3, 5}));
20895b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower  test::ExpectClose(dx[1], test::AsTensor<float>({15., 16., 17., 18., 19., 35.,
20995b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower                                                  36., 37., 38., 39.},
21095b55a97df1817c5de817e84d667d3b7a17f50f0A. Unique TensorFlower                                                 {2, 1, 5}));
211084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
212084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
213084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlowerstd::vector<Tensor> SplitGrad(int dim, const Tensor& x, const Tensor& dy0,
214084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                              const Tensor& dy1) {
215084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto T = DT_FLOAT;
216084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto gdef = test::function::GDef(
217084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      {f::NDef("dim", "Placeholder", {}, {{"dtype", DT_INT32}}),
218084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
219084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dy0", "Placeholder", {}, {{"dtype", T}}),
220084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dy1", "Placeholder", {}, {{"dtype", T}}),
221084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"dim", "x", "dy0", "dy1"},
222084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower               {{"f", FDH::FunctionRef(
223084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                          "Split",
224084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                          {{"split_dim", dim}, {"num_split", 2}, {"T", T}})},
225084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                {"Tin", DataTypeSlice{DT_INT32, T, T, T}},
226084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                {"Tout", DataTypeSlice{DT_INT32, T}}})});
227084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
228084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto sess = NewSession();
229084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
230084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  std::vector<Tensor> out;
231cf6cdcf6dc7e5dba2b0f729b86bf1f14ca73666eManjunath Kudlur  TF_CHECK_OK(sess->Run({{"dim", test::AsScalar(dim)},
232084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                         {"x:0", x},
233084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                         {"dy0:0", dy0},
234084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                         {"dy1:0", dy1}},
235084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                        {"dx:0", "dx:1"}, {}, &out));
236084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  CHECK_EQ(out.size(), 2);
237084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  TF_CHECK_OK(sess->Close());
238084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  return out;
239084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
240084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
2419f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, SplitGrad) {
242084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor x(DT_FLOAT, {2, 4, 5});
243084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  x.flat<float>().setZero();
244084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor dy0(DT_FLOAT, {2, 2, 5});
245084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  Tensor dy1(DT_FLOAT, {2, 2, 5});
246084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::FillIota<float>(&dy0, 0);
247084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::FillIota<float>(&dy1, 100);
248084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  auto dx = SplitGrad(1, x, dy0, dy1);
249cf6cdcf6dc7e5dba2b0f729b86bf1f14ca73666eManjunath Kudlur  test::ExpectTensorEqual<int32>(dx[0], test::AsScalar(0));
250084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  test::ExpectClose(
251084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      dx[1], test::AsTensor<float>(
252084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                 {0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,
253084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                  100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,
254084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                  10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,
255084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                  110., 111., 112., 113., 114., 115., 116., 117., 118., 119.},
256084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower                 {2, 4, 5}));
257084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}
258084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
25923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlowerstd::vector<Tensor> ReshapeGrad(const Tensor& x, const Tensor& s,
26023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                                const Tensor& dy) {
26123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto T = DT_FLOAT;
26223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto gdef = test::function::GDef(
26323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
26423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("s", "Placeholder", {}, {{"dtype", DT_INT32}}),
26523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
26623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"x", "s", "dy"},
26723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower               {{"f", FDH::FunctionRef("Reshape", {{"T", T}})},
26823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                {"Tin", DataTypeSlice{T, DT_INT32, T}},
26923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                {"Tout", DataTypeSlice{T, DT_INT32}}})});
27023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
27123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto sess = NewSession();
27223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
27323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  std::vector<Tensor> out;
27423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Run({{"x:0", x}, {"s:0", s}, {"dy:0", dy}},
27523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                        {"dx:0", "dx:1"}, {}, &out));
27623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  CHECK_EQ(out.size(), 2);
27723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Close());
27823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  return out;
27923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower}
28023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower
2819f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, ReshapeGrad) {
28223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  Tensor x(DT_FLOAT, {2, 4, 5});
28323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  x.flat<float>().setZero();
28423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto s = test::AsTensor<int32>({8, 5});
28523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  Tensor dy(DT_FLOAT, {8, 5});
28623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::FillIota<float>(&dy, 73);
28723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto dx = ReshapeGrad(x, s, dy);
28823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::ExpectClose(
28923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower      dx[0], test::AsTensor<float>(
29023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                 {73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,
29123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                  83.,  84.,  85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,
29223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                  93.,  94.,  95.,  96.,  97.,  98.,  99.,  100., 101., 102.,
29323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                  103., 104., 105., 106., 107., 108., 109., 110., 111., 112.},
29423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                 {2, 4, 5}));
29523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0}));
29623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower}
29723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower
29823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlowerstd::vector<Tensor> ExpandDimsGrad(const Tensor& x, const Tensor& s,
29923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                                   const Tensor& dy) {
30023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto T = DT_FLOAT;
30123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto gdef = test::function::GDef(
30223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
30323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("s", "Placeholder", {}, {{"dtype", DT_INT32}}),
30423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
30523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"x", "s", "dy"},
30623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower               {{"f", FDH::FunctionRef("ExpandDims", {{"T", T}})},
30723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                {"Tin", DataTypeSlice{T, DT_INT32, T}},
30823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                {"Tout", DataTypeSlice{T, DT_INT32}}})});
30923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
31023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto sess = NewSession();
31123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
31223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  std::vector<Tensor> out;
31323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Run({{"x:0", x}, {"s:0", s}, {"dy:0", dy}},
31423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                        {"dx:0", "dx:1"}, {}, &out));
31523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  CHECK_EQ(out.size(), 2);
31623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Close());
31723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  return out;
31823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower}
31923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower
3209f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, ExpandDimsGrad) {
32123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  Tensor x(DT_FLOAT, {2, 4, 5});
32223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  x.flat<float>().setZero();
32323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto s = test::AsTensor<int32>({1});
32423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  Tensor dy(DT_FLOAT, {2, 1, 4, 5});
32523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::FillIota<float>(&dy, 73);
32623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto dx = ExpandDimsGrad(x, s, dy);
32723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::ExpectClose(
32823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower      dx[0], test::AsTensor<float>(
32923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                 {73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,
33023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                  83.,  84.,  85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,
33123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                  93.,  94.,  95.,  96.,  97.,  98.,  99.,  100., 101., 102.,
33223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                  103., 104., 105., 106., 107., 108., 109., 110., 111., 112.},
33323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                 {2, 4, 5}));
33423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0}));
33523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower}
33623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower
33777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlowerstd::vector<Tensor> SqueezeGrad(const Tensor& x, const Tensor& dy) {
33877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto T = DT_FLOAT;
33977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto gdef = test::function::GDef(
34077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
34177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
34277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"x", "dy"},
34377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower               {{"f", FDH::FunctionRef("Squeeze", {{"T", T}})},
34477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                {"Tin", DataTypeSlice{T, T}},
34577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                {"Tout", DataTypeSlice{T}}})});
34677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
34777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto sess = NewSession();
34877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
34977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  std::vector<Tensor> out;
35077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Run({{"x:0", x}, {"dy:0", dy}}, {"dx:0"}, {}, &out));
35177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  CHECK_EQ(out.size(), 1);
35277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Close());
35377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  return out;
35477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower}
35577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower
3569f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, SqueezeGrad) {
35777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  Tensor x(DT_FLOAT, {2, 1, 3});
35877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  x.flat<float>().setZero();
35977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  Tensor dy(DT_FLOAT, {2, 3});
36077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::FillIota<float>(&dy, 1);
36177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto dx = SqueezeGrad(x, dy);
36277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::ExpectClose(dx[0],
36377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                    test::AsTensor<float>({1., 2., 3., 4., 5., 6.}, {2, 1, 3}));
36477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower}
36577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower
36623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlowerstd::vector<Tensor> TransposeGrad(const Tensor& x, const Tensor& p,
36723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                                  const Tensor& dy) {
36823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto T = DT_FLOAT;
36923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto gdef = test::function::GDef(
37023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
37123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("p", "Placeholder", {}, {{"dtype", DT_INT32}}),
37223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
37323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"x", "p", "dy"},
37423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower               {{"f", FDH::FunctionRef("Transpose", {{"T", T}})},
37523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                {"Tin", DataTypeSlice{T, DT_INT32, T}},
37623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                {"Tout", DataTypeSlice{T, DT_INT32}}})});
37723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
37823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto sess = NewSession();
37923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
38023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  std::vector<Tensor> out;
38123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Run({{"x:0", x}, {"p:0", p}, {"dy:0", dy}},
38223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                        {"dx:0", "dx:1"}, {}, &out));
38323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  CHECK_EQ(out.size(), 2);
38423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  TF_CHECK_OK(sess->Close());
38523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  return out;
38623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower}
38723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower
3889f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, TransposeGrad) {
38923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  Tensor x(DT_FLOAT, {2, 4, 5});
39023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  x.flat<float>().setZero();
39123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto p = test::AsTensor<int32>({2, 0, 1});
39223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  Tensor dy(DT_FLOAT, {5, 2, 4});
39323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::FillIota<float>(&dy, 0);
39423eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  auto dx = TransposeGrad(x, p, dy);
39523eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::ExpectClose(dx[0], test::AsTensor<float>(
39623eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                               {0., 8.,  16., 24., 32., 1., 9.,  17., 25., 33.,
39723eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                                2., 10., 18., 26., 34., 3., 11., 19., 27., 35.,
39823eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                                4., 12., 20., 28., 36., 5., 13., 21., 29., 37.,
39923eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                                6., 14., 22., 30., 38., 7., 15., 23., 31., 39.},
40023eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower                               {2, 4, 5}));
40123eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
40223eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower}
40323eccfb17635bce1c19b668986dceae1281ccee8A. Unique TensorFlower
40477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlowerstd::vector<Tensor> ReverseGrad(const Tensor& x, const Tensor& dims,
40577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                                const Tensor& dy) {
40677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto T = DT_FLOAT;
40777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto gdef = test::function::GDef(
40877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
40977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef("dims", "Placeholder", {}, {{"dtype", DT_BOOL}}),
41077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
41177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef("dx", "SymbolicGradient", {"x", "dims", "dy"},
41277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower               {{"f", FDH::FunctionRef("Reverse", {{"T", T}})},
41377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                {"Tin", DataTypeSlice{T, DT_BOOL, T}},
41477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                {"Tout", DataTypeSlice{T, DT_BOOL}}})});
41577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
41677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto sess = NewSession();
41777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
41877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  std::vector<Tensor> out;
41977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Run({{"x:0", x}, {"dims:0", dims}, {"dy:0", dy}},
42077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                        {"dx:0", "dx:1"}, {}, &out));
42177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  CHECK_EQ(out.size(), 2);
42277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Close());
42377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  return out;
42477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower}
42577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower
4269f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, ReverseGrad) {
42777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  Tensor x(DT_FLOAT, {2, 3});
42877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  x.flat<float>().setZero();
42977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto dims = test::AsTensor<bool>({false, true});
43077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  Tensor dy(DT_FLOAT, {2, 3});
43177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::FillIota<float>(&dy, 1);
43277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto dx = ReverseGrad(x, dims, dy);
43377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::ExpectClose(dx[0],
43477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                    test::AsTensor<float>({3., 2., 1., 6., 5., 4.}, {2, 3}));
43577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::ExpectTensorEqual<bool>(dx[1], test::AsTensor<bool>({false, false}));
43677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower}
43777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower
438f11a06002dbb219f446768fb640f09569e61e9aeAndrew Sellestd::vector<Tensor> ReverseV2Grad(const Tensor& x, const Tensor& axis,
439f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle                                  const Tensor& dy) {
440f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  auto T = DT_FLOAT;
441f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  auto Tidx = DT_INT32;
442f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  auto gdef = test::function::GDef(
443f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
444f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle       f::NDef("axis", "Placeholder", {}, {{"dtype", DT_INT32}}),
445f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
446f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle       f::NDef(
447f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle           "dx", "SymbolicGradient", {"x", "axis", "dy"},
448f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle           {{"f", FDH::FunctionRef("ReverseV2", {{"T", T}, {"Tidx", Tidx}})},
449f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle            {"Tin", DataTypeSlice{T, DT_INT32, T}},
450f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle            {"Tout", DataTypeSlice{T, DT_INT32}}})});
451f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  VLOG(1) << DebugStringWhole(gdef);
452f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  auto sess = NewSession();
453f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  TF_CHECK_OK(sess->Create(gdef));
454f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  std::vector<Tensor> out;
455f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  TF_CHECK_OK(sess->Run({{"x:0", x}, {"axis:0", axis}, {"dy:0", dy}},
456f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle                        {"dx:0", "dx:1"}, {}, &out));
457f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  CHECK_EQ(out.size(), 2);
458f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  TF_CHECK_OK(sess->Close());
459f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  return out;
460f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle}
461f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle
4629f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, ReverseV2Grad) {
463f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  Tensor x(DT_FLOAT, {2, 3});
464f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  x.flat<float>().setZero();
465f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  auto axis = test::AsTensor<int32>({1});
466f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  Tensor dy(DT_FLOAT, {2, 3});
467f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  test::FillIota<float>(&dy, 1);
468f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  auto dx = ReverseV2Grad(x, axis, dy);
469f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  test::ExpectTensorEqual<float>(
470f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle      dx[0], test::AsTensor<float>({3., 2., 1., 6., 5., 4.}, {2, 3}));
471f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0}));
472f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle}
473f11a06002dbb219f446768fb640f09569e61e9aeAndrew Selle
47477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlowerstd::vector<Tensor> SliceGrad(const Tensor& x, const Tensor& b, const Tensor& s,
47577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                              const Tensor& dy) {
47677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto T = DT_FLOAT;
47777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto gdef = test::function::GDef(
47877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
47977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef("b", "Placeholder", {}, {{"dtype", DT_INT32}}),
48077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef("s", "Placeholder", {}, {{"dtype", DT_INT32}}),
48177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
48277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower       f::NDef(
48377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower           "dx", "SymbolicGradient", {"x", "b", "s", "dy"},
48477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower           {{"f", FDH::FunctionRef("Slice", {{"T", T}, {"Index", DT_INT32}})},
48577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower            {"Tin", DataTypeSlice{T, DT_INT32, DT_INT32, T}},
48677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower            {"Tout", DataTypeSlice{T, DT_INT32, DT_INT32}}})});
48777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  VLOG(1) << DebugStringWhole(gdef);
48877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto sess = NewSession();
48977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Create(gdef));
49077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  std::vector<Tensor> out;
49177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Run({{"x:0", x}, {"b:0", b}, {"s:0", s}, {"dy:0", dy}},
49277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                        {"dx:0", "dx:1", "dx:2"}, {}, &out));
49377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  CHECK_EQ(out.size(), 3);
49477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  TF_CHECK_OK(sess->Close());
49577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  return out;
49677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower}
49777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower
4989f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, SliceGrad) {
49977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  Tensor x(DT_FLOAT, {2, 3, 4});
50077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  x.flat<float>().setZero();
50177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto begin = test::AsTensor<int32>({1, 1, 1});
50277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto size = test::AsTensor<int32>({1, 2, 2});
50377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  Tensor dy(DT_FLOAT, {1, 2, 2});
50477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::FillIota<float>(&dy, 1);
50577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  auto dx = SliceGrad(x, begin, size, dy);
50677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::ExpectClose(dx[0],
50777ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                    test::AsTensor<float>(
50877ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                        {
50977ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
51077ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                            0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0.,
51177ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                        },
51277ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower                        {2, 3, 4}));
51377ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
51477ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower  test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
51577ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower}
51677ac22f5226044f647d6185b77521b55023145ebA. Unique TensorFlower
517f2f7dba849baeecd473423eca560ef0f28300327Andrew Sellestd::vector<Tensor> StridedSliceGrad(const Tensor& x, const Tensor& begin,
518f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                     const Tensor& end, const Tensor& strides,
519f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                     const Tensor& dy, int32 begin_mask,
520f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                     int32 end_mask, int32 ellipsis_mask,
521f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                     int32 new_axis_mask,
522f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                     int32 shrink_axis_mask) {
523f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  auto T = DT_FLOAT;
524f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  auto gdef = test::function::GDef(
525f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
526f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle       f::NDef("begin", "Placeholder", {}, {{"dtype", DT_INT32}}),
527f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle       f::NDef("end", "Placeholder", {}, {{"dtype", DT_INT32}}),
528f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle       f::NDef("strides", "Placeholder", {}, {{"dtype", DT_INT32}}),
529f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
530f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle       f::NDef(
531f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle           "dx", "SymbolicGradient", {"x", "begin", "end", "strides", "dy"},
532f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle           {{"f", FDH::FunctionRef("StridedSlice",
533f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                   {
534f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                       {"T", T},
535f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                       {"Index", DT_INT32},
536f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                       {"begin_mask", begin_mask},
537f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                       {"end_mask", end_mask},
538f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                       {"new_axis_mask", new_axis_mask},
539f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                       {"shrink_axis_mask", shrink_axis_mask},
540f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                       {"ellipsis_mask", ellipsis_mask},
541f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                                   })},
542f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle            {"Tin", DataTypeSlice{T, DT_INT32, DT_INT32, DT_INT32, T}},
543f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle            {"Tout", DataTypeSlice{T, DT_INT32, DT_INT32, DT_INT32}}})});
544f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  VLOG(1) << DebugStringWhole(gdef);
545f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  auto sess = NewSession();
546f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  TF_CHECK_OK(sess->Create(gdef));
547f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  std::vector<Tensor> out;
548f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  TF_CHECK_OK(sess->Run({{"x:0", x},
549f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                         {"begin:0", begin},
550f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                         {"end:0", end},
551f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                         {"strides:0", strides},
552f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                         {"dy:0", dy}},
553f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                        {"dx:0", "dx:1", "dx:2", "dx:3"}, {}, &out));
554f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  CHECK_EQ(out.size(), 4);
555f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  TF_CHECK_OK(sess->Close());
556f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  return out;
557f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle}
558f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle
559e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Sellestd::vector<Tensor> StridedSliceGradGrad(
560e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    const Tensor& shape, const Tensor& begin, const Tensor& end,
561e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    const Tensor& strides, const Tensor& dy, const Tensor& grad,
562e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    int32 begin_mask, int32 end_mask, int32 ellipsis_mask, int32 new_axis_mask,
563e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    int32 shrink_axis_mask) {
564e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  auto T = DT_FLOAT;
565e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  auto gdef = test::function::GDef(
566e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle      {f::NDef("shape", "Placeholder", {}, {{"dtype", DT_INT32}}),
567e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle       f::NDef("begin", "Placeholder", {}, {{"dtype", DT_INT32}}),
568e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle       f::NDef("end", "Placeholder", {}, {{"dtype", DT_INT32}}),
569e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle       f::NDef("strides", "Placeholder", {}, {{"dtype", DT_INT32}}),
570e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
571e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle       f::NDef("grad", "Placeholder", {}, {{"dtype", T}}),
572e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle       f::NDef(
573e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle           "dx", "SymbolicGradient",
574e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle           {"shape", "begin", "end", "strides", "dy", "grad"},
575e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle           {{"f", FDH::FunctionRef("StridedSliceGrad",
576e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                   {
577e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                       {"T", T},
578e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                       {"Index", DT_INT32},
579e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                       {"begin_mask", begin_mask},
580e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                       {"end_mask", end_mask},
581e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                       {"new_axis_mask", new_axis_mask},
582e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                       {"shrink_axis_mask", shrink_axis_mask},
583e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                       {"ellipsis_mask", ellipsis_mask},
584e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                   })},
585e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle            {"Tin",
586e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle             DataTypeSlice{DT_INT32, DT_INT32, DT_INT32, DT_INT32, T, T}},
587e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle            {"Tout",
588e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle             DataTypeSlice{DT_INT32, DT_INT32, DT_INT32, DT_INT32, T}}})});
589e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  VLOG(1) << DebugStringWhole(gdef);
590e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  auto sess = NewSession();
591e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  TF_CHECK_OK(sess->Create(gdef));
592e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  std::vector<Tensor> out;
593e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  TF_CHECK_OK(sess->Run({{"shape:0", shape},
594e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                         {"begin:0", begin},
595e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                         {"end:0", end},
596e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                         {"strides:0", strides},
597e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                         {"dy:0", dy},
598e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                         {"grad:0", grad}},
599e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                        {"dx:0", "dx:1", "dx:2", "dx:3", "dx:4"}, {}, &out));
600e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  CHECK_EQ(out.size(), 5);
601e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  TF_CHECK_OK(sess->Close());
602e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  return out;
603e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle}
604e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle
6059f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlowerTEST(ArrayGradTest, StridedSliceGrad) {
606f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  Tensor x(DT_FLOAT, {2, 3, 4});
607f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  x.flat<float>().setZero();
608e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle  Tensor x_shape = test::AsTensor<int32>({2, 3, 4}, {3});
609f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle
610f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  {
611f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto start = test::AsTensor<int32>({1, 1, 1});
612f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto stop = test::AsTensor<int32>({2, 3, 3});
613f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto strides = test::AsTensor<int32>({1, 1, 1});
614f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    Tensor dy(DT_FLOAT, {1, 2, 2});
615f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::FillIota<float>(&dy, 1);
616f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 0,
617f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle        ellipsis_mask = 0;
618f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto dx =
619f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
620f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
621f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectClose(dx[0],
622f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                      test::AsTensor<float>(
623f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          {
624f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
625f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                              0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0.,
626f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          },
627f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          {2, 3, 4}));
628f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
629f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
630e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    auto ddx = StridedSliceGradGrad(x_shape, start, stop, strides, dy, dx[0],
631e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                    begin_mask, end_mask, ellipsis_mask,
632e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                    new_axis_mask, shrink_axis_mask);
633e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    test::ExpectClose(ddx[4], dy);
634f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  }
635f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle
636f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  // test equivalent of python tf.gradients(foo[1:2, 1:3, 1:3])
637f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  {
638f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto start = test::AsTensor<int32>({1, 1, 1});
639f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto stop = test::AsTensor<int32>({2, 3, 3});
640f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto strides = test::AsTensor<int32>({1, 1, 1});
641f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    Tensor dy(DT_FLOAT, {1, 2, 2});
642f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::FillIota<float>(&dy, 1);
643f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 0,
644f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle        ellipsis_mask = 0;
645f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto dx =
646f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
647f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
648f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectClose(dx[0],
649f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                      test::AsTensor<float>(
650f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          {
651f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
652f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                              0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0.,
653f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          },
654f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          {2, 3, 4}));
655f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
656f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
657e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    auto ddx = StridedSliceGradGrad(x_shape, start, stop, strides, dy, dx[0],
658e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                    begin_mask, end_mask, ellipsis_mask,
659e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                    new_axis_mask, shrink_axis_mask);
660e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    test::ExpectClose(ddx[4], dy);
661f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  }
662f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle
663f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  // test equivalent of python tf.gradients(foo[1, 1:, :-2, None])
664f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  {
665f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    int dontcare = 66;
666f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto start = test::AsTensor<int32>({1, 1, dontcare, dontcare});
667f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto stop = test::AsTensor<int32>({2, dontcare, -2, dontcare});
668f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto strides = test::AsTensor<int32>({1, 1, 1, dontcare});
669f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    Tensor dy(DT_FLOAT, {2, 2, 1});
670f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::FillIota<float>(&dy, 1);
671f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    int begin_mask = 4, end_mask = 2, new_axis_mask = 8, shrink_axis_mask = 1,
672f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle        ellipsis_mask = 0;
673f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto dx =
674f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
675f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
676f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectClose(dx[0],
677f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                      test::AsTensor<float>(
678f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          {
679f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
680f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                              0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0., 0.,
681f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          },
682f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          {2, 3, 4}));
683f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0, 0}));
684f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0, 0}));
685e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    auto ddx = StridedSliceGradGrad(x_shape, start, stop, strides, dy, dx[0],
686e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                    begin_mask, end_mask, ellipsis_mask,
687e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                    new_axis_mask, shrink_axis_mask);
688e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    test::ExpectClose(ddx[4], dy);
689f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  }
690f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle
691f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  // test equivalent of tf.gradients(foo[1, ...]) i.e. foo[1, 0:3, 0:4]
692f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  {
693f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    int dontcare = 66;
694f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto start = test::AsTensor<int32>({1, dontcare});
695f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto stop = test::AsTensor<int32>({2, dontcare});
696f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto strides = test::AsTensor<int32>({1, 1});
697f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    Tensor dy(DT_FLOAT, {3, 4});
698f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::FillIota<float>(&dy, 1);
699f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 1,
700f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle        ellipsis_mask = 2;
701f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    auto dx =
702f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
703f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
704f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectClose(dx[0],
705f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                      test::AsTensor<float>(
706f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          {
707f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,  0.,  0.,
708f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                              1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.,
709f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          },
710f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle                          {2, 3, 4}));
711f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0}));
712f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0}));
713e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    auto ddx = StridedSliceGradGrad(x_shape, start, stop, strides, dy, dx[0],
714e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                    begin_mask, end_mask, ellipsis_mask,
715e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle                                    new_axis_mask, shrink_axis_mask);
716e3fdccaa4c4d834367e4e17d24518bc0a8e770f5Andrew Selle    test::ExpectClose(ddx[4], dy);
717f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle  }
718f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle}
719f2f7dba849baeecd473423eca560ef0f28300327Andrew Selle
7209f10f60fbd9fefaf225c1985014010b6b2f738c1A. Unique TensorFlower}  // namespace
721084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower}  // namespace tensorflow
722