/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"

#include <algorithm>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace xla {
namespace gpu {
namespace {

using LayoutAssignmentTest = HloTestBase;
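
// These tests build small HLO graphs by hand and run the GpuLayoutAssignment
// pass over them.  GpuLayoutAssignment extends the plain LayoutAssignment
// pass with GPU-specific constraints, e.g. that the data operand and result
// of a cuDNN batch-norm custom-call agree on layout.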

TEST_F(LayoutAssignmentTest, Elementwise) {
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape ashape_in_row_major(ashape);
  Shape ashape_in_col_major(ashape);
  *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
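  // Note: XLA layouts list dimensions minor-to-major, so for a rank-2 shape
  // {1, 0} is row-major and {0, 1} is column-major.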

  // Enumerate all possible combinations of layouts.
  for (const Shape& lhs_shape_with_layout :
       {ashape_in_row_major, ashape_in_col_major}) {
    for (const Shape& rhs_shape_with_layout :
         {ashape_in_row_major, ashape_in_col_major}) {
      for (const Shape& result_shape_with_layout :
           {ashape_in_row_major, ashape_in_col_major}) {
        // GpuLayoutAssignment should assign the same layout to "add" and its
        // two operands.
        auto builder = HloComputation::Builder(TestName());
        auto x = builder.AddInstruction(
            HloInstruction::CreateParameter(0, ashape, "x"));
        auto y = builder.AddInstruction(
            HloInstruction::CreateParameter(1, ashape, "y"));
        auto add = builder.AddInstruction(
            HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y));
        auto module = CreateNewModule();
        HloComputation* computation =
            module->AddEntryComputation(builder.Build(add));

        ComputationLayout computation_layout(
            computation->ComputeProgramShape());
        *computation_layout.mutable_parameter_layout(0) =
            ShapeLayout(lhs_shape_with_layout);
        *computation_layout.mutable_parameter_layout(1) =
            ShapeLayout(rhs_shape_with_layout);
        *computation_layout.mutable_result_layout() =
            ShapeLayout(result_shape_with_layout);

        GpuLayoutAssignment layout_assignment(&computation_layout);
        EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());

        for (const HloInstruction* operand : add->operands()) {
          EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
                                        operand->shape().layout()));
        }
      }
    }
  }
}

// Returns a list of shapes with all possible layouts of the given shape,
// including a copy of the shape with no layout.
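// For example, for a rank-2 shape this yields three shapes: one with no
// layout, one with layout {0, 1}, and one with layout {1, 0}.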
std::vector<Shape> AllLayoutsOf(const Shape& s) {
  std::vector<int64> layout_vec(s.dimensions_size());
  std::iota(layout_vec.begin(), layout_vec.end(), 0);

  std::vector<Shape> shapes;
  shapes.push_back(s);
  shapes.back().clear_layout();

  do {
    shapes.push_back(s);
    *shapes.back().mutable_layout() = LayoutUtil::MakeLayout(layout_vec);
  } while (std::next_permutation(layout_vec.begin(), layout_vec.end()));

  return shapes;
}

TEST_F(LayoutAssignmentTest, BatchNormInference) {
  const int64 kFeatureIndex = 1;

  // The shape of the data operand to BatchNormInference and of the output of
  // the BatchNormInference call.
  Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100});

  // The shape of the scale, offset, mean, and variance inputs to
  // BatchNormInference.  These are rank 1, with as many elements as there are
  // in the kFeatureIndex dim of shape.
  Shape aux_shape =
      ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)});

  for (const Shape& input_shape : AllLayoutsOf(shape)) {
    for (const Shape& result_shape : AllLayoutsOf(shape)) {
      SCOPED_TRACE(tensorflow::strings::StrCat(
          "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
          ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));

      auto builder = HloComputation::Builder(TestName());
      auto* operand = builder.AddInstruction(
          HloInstruction::CreateParameter(0, shape, "operand"));
      auto* scale = builder.AddInstruction(
          HloInstruction::CreateParameter(1, aux_shape, "scale"));
      auto* offset = builder.AddInstruction(
          HloInstruction::CreateParameter(2, aux_shape, "offset"));
      auto* mean = builder.AddInstruction(
          HloInstruction::CreateParameter(3, aux_shape, "mean"));
      auto* variance = builder.AddInstruction(
          HloInstruction::CreateParameter(4, aux_shape, "variance"));

      auto* epsilon = builder.AddInstruction(
          HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
      auto* feature_index =
          builder.AddInstruction(HloInstruction::CreateConstant(
              Literal::CreateR0<int64>(kFeatureIndex)));

      auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall(
          shape,
          {operand, scale, offset, mean, variance, epsilon, feature_index},
          kCudnnBatchNormForwardInferenceCallTarget));

      auto module = CreateNewModule();
      HloComputation* computation =
          module->AddEntryComputation(builder.Build(batchnorm));

      ComputationLayout computation_layout(computation->ComputeProgramShape());

      if (input_shape.has_layout()) {
        *computation_layout.mutable_parameter_layout(0) =
            ShapeLayout(input_shape);
      }

      if (result_shape.has_layout()) {
        *computation_layout.mutable_result_layout() = ShapeLayout(result_shape);
      }

      GpuLayoutAssignment layout_assignment(&computation_layout);
      EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());

      // The first operand to batchnorm should have the same layout as the
      // result.
      EXPECT_TRUE(LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(),
                                    batchnorm->shape().layout()))
          << batchnorm->ToString();
    }
  }
}

TEST_F(LayoutAssignmentTest, BatchNormTraining) {
  const int64 kFeatureIndex = 1;

  // The shape of the data operand to BatchNormTraining.
  Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100});

  // The shape of the offset and scale inputs to BatchNormTraining.  These are
  // rank 1, with as many elements as there are in the kFeatureIndex dim of
  // shape.
  Shape offset_scale_shape =
      ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)});

  // Shape of the output of our BatchNormTraining op.
  Shape batchnorm_shape = ShapeUtil::MakeTupleShape(
      {shape, offset_scale_shape, offset_scale_shape});
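  // The last two tuple elements are rank-1, feature-sized batch statistics;
  // only the layout of element 0 is checked below.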

  // Enumerate all combinations of shapes.
  for (const Shape& input_shape : AllLayoutsOf(shape)) {
    for (const Shape& result_shape : AllLayoutsOf(shape)) {
      SCOPED_TRACE(tensorflow::strings::StrCat(
          "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
          ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));

      auto builder = HloComputation::Builder(TestName());
      auto* operand = builder.AddInstruction(
          HloInstruction::CreateParameter(0, shape, "operand"));
      auto* scale = builder.AddInstruction(
          HloInstruction::CreateParameter(1, offset_scale_shape, "scale"));
      auto* offset = builder.AddInstruction(
          HloInstruction::CreateParameter(2, offset_scale_shape, "offset"));

      auto* epsilon = builder.AddInstruction(
          HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
      auto* feature_index =
          builder.AddInstruction(HloInstruction::CreateConstant(
              Literal::CreateR0<int64>(kFeatureIndex)));

      auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall(
          batchnorm_shape, {operand, scale, offset, epsilon, feature_index},
          kCudnnBatchNormForwardTrainingCallTarget));

      auto module = CreateNewModule();
      HloComputation* computation =
          module->AddEntryComputation(builder.Build(batchnorm));

      ComputationLayout computation_layout(computation->ComputeProgramShape());

      if (input_shape.has_layout()) {
        *computation_layout.mutable_parameter_layout(0) =
            ShapeLayout(input_shape);
      }

      if (result_shape.has_layout()) {
        *computation_layout.mutable_result_layout() =
            ShapeLayout(ShapeUtil::MakeTupleShape(
                {result_shape, offset_scale_shape, offset_scale_shape}));
      }

      GpuLayoutAssignment layout_assignment(&computation_layout);
      EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());

      // The first operand to batchnorm should have the same layout as the
      // first element of the result tuple.
      EXPECT_TRUE(
          LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(),
                            batchnorm->shape().tuple_shapes(0).layout()))
          << batchnorm->ToString();
    }
  }
}

TEST_F(LayoutAssignmentTest, BatchNormGrad) {
  const int64 kFeatureIndex = 1;

  // The shape of the data operand to BatchNormGrad.
  Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100});

  // The shape of the scale, mean, and variance inputs to BatchNormGrad.  These
  // are rank 1, with as many elements as there are in the kFeatureIndex dim of
  // shape.
  Shape scale_shape =
      ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)});

  // Shape of the output of our BatchNormGrad op.
  Shape batchnorm_shape =
      ShapeUtil::MakeTupleShape({shape, scale_shape, scale_shape});
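  // The tuple elements correspond to (grad_data, grad_scale, grad_offset)
  // under XLA's batch-norm-grad semantics; only element 0's layout is used in
  // the checks below.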

  // Enumerate all combinations of shapes plus whether we're constraining param
  // 0 or param 4.
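  // (Param 4, the incoming gradient, is the only other parameter with the
  // full data shape, so constraining it exercises the same layout propagation
  // as constraining param 0.)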
  for (const Shape& input_shape : AllLayoutsOf(shape)) {
    for (const Shape& result_shape : AllLayoutsOf(shape)) {
      for (int constrained_param_no : {0, 4}) {
        SCOPED_TRACE(tensorflow::strings::StrCat(
            "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
            ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape),
            ", constrained_param_no=", constrained_param_no));

        auto builder = HloComputation::Builder(TestName());
        auto* operand = builder.AddInstruction(
            HloInstruction::CreateParameter(0, shape, "operand"));
        auto* scale = builder.AddInstruction(
            HloInstruction::CreateParameter(1, scale_shape, "scale"));
        auto* mean = builder.AddInstruction(
            HloInstruction::CreateParameter(2, scale_shape, "mean"));
        auto* var = builder.AddInstruction(
            HloInstruction::CreateParameter(3, scale_shape, "var"));
        auto* grad_output = builder.AddInstruction(
            HloInstruction::CreateParameter(4, shape, "grad_output"));

        auto* epsilon = builder.AddInstruction(
            HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
        auto* feature_index =
            builder.AddInstruction(HloInstruction::CreateConstant(
                Literal::CreateR0<int64>(kFeatureIndex)));

        auto* batchnorm =
            builder.AddInstruction(HloInstruction::CreateCustomCall(
                batchnorm_shape,
                {operand, scale, mean, var, grad_output, epsilon,
                 feature_index},
                kCudnnBatchNormBackwardCallTarget));

        auto module = CreateNewModule();
        HloComputation* computation =
            module->AddEntryComputation(builder.Build(batchnorm));

        ComputationLayout computation_layout(
            computation->ComputeProgramShape());

        if (input_shape.has_layout()) {
          *computation_layout.mutable_parameter_layout(constrained_param_no) =
              ShapeLayout(input_shape);
        }

        if (result_shape.has_layout()) {
          *computation_layout.mutable_result_layout() =
              ShapeLayout(ShapeUtil::MakeTupleShape(
                  {result_shape, scale_shape, scale_shape}));
        }

        GpuLayoutAssignment layout_assignment(&computation_layout);
        EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());

        // Operands 0 and 4 of the batchnorm call should have the same layout
        // as the first element of the result tuple.
        EXPECT_TRUE(
            LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(),
                              batchnorm->shape().tuple_shapes(0).layout()))
            << batchnorm->ToString();
        EXPECT_TRUE(
            LayoutUtil::Equal(batchnorm->operand(4)->shape().layout(),
                              batchnorm->shape().tuple_shapes(0).layout()))
            << batchnorm->ToString();
      }
    }
  }
}

}  // namespace
}  // namespace gpu
}  // namespace xla