/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
159e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
169e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
179e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
189e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include <memory>
199e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include <string>
209e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
219e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_computation.h"
229e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h"
239e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_opcode.h"
249e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_ordering.h"
259e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/shape_util.h"
269e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
279e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/types.h"
289e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/xla_data.pb.h"
299e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
309e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowernamespace xla {
319e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowernamespace {
329e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
339e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerclass MinimumMemoryForSequenceTest : public HloTestBase {};
349e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
359e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerTEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
369e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  auto module = CreateNewModule();
379e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
389e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  const Shape tuple_shape =
399e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
409e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
419e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  auto cond_builder = HloComputation::Builder("WhileCond");
429e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
439e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* cond_param = cond_builder.AddInstruction(
449e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
459e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* cond_iter = cond_builder.AddInstruction(
469e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
479e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* cond_data = cond_builder.AddInstruction(
489e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
499e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
509e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* cond_lt = cond_builder.AddInstruction(
519e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
529e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower                                   HloOpcode::kLt, cond_iter, cond_data));
539e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloComputation* cond_computation =
549e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      module->AddEmbeddedComputation(cond_builder.Build());
559e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
569e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  auto body_builder = HloComputation::Builder("WhileBody");
579e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
589e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* body_param = body_builder.AddInstruction(
599e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
609e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloComputation* body_computation =
619e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      module->AddEmbeddedComputation(body_builder.Build());
629e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
639e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  auto builder = HloComputation::Builder(TestName());
649e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
659e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* iter = builder.AddInstruction(
669e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
679e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* data = builder.AddInstruction(
689e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
699e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
709e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* tuple =
719e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
729e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  // While: 8 bytes (4 bytes per element), TOTAL=32
739e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  // Both cond and body use a max of 24 bytes, TOTAL=56
749e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
759e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      tuple_shape, cond_computation, body_computation, tuple));
769e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  HloComputation* entry_computation =
779e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower      module->AddEntryComputation(builder.Build());
789e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
799e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  auto size_fn = [](const LogicalBuffer& buffer) {
809e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
819e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  };
829e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
839e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  SequentialHloOrdering::HloModuleSequence module_sequence;
849e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
859e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower                                       cond_lt};
869e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  module_sequence[body_computation] = {body_param};
879e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  module_sequence[entry_computation] = {iter, data, tuple, while_op};
889e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower  EXPECT_EQ(56,
899e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower            MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie());
909e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower}
919e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower
929e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower}  // namespace
939e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower}  // namespace xla
94