19e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 29e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 39e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 49e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFloweryou may not use this file except in compliance with the License. 59e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerYou may obtain a copy of the License at 69e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 79e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 89e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 99e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 109e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 119e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerSee the License for the specific language governing permissions and 139e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerlimitations under the License. 149e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower==============================================================================*/ 159e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 169e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_scheduling.h" 179e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 189e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include <memory> 199e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include <string> 209e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 219e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_computation.h" 229e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_instruction.h" 239e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_opcode.h" 249e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/service/hlo_ordering.h" 259e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/shape_util.h" 269e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/tests/hlo_test_base.h" 279e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/types.h" 289e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower#include "tensorflow/compiler/xla/xla_data.pb.h" 299e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 309e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowernamespace xla { 319e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowernamespace { 329e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 339e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerclass MinimumMemoryForSequenceTest : public HloTestBase {}; 349e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 359e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlowerTEST_F(MinimumMemoryForSequenceTest, MultiComputation) { 369e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower auto module = CreateNewModule(); 379e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); 389e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower const Shape tuple_shape = 399e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); 409e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 419e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower auto cond_builder = HloComputation::Builder("WhileCond"); 429e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) 439e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* cond_param = cond_builder.AddInstruction( 449e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); 459e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* cond_iter = cond_builder.AddInstruction( 469e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); 479e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* cond_data = cond_builder.AddInstruction( 489e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); 499e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) 509e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* cond_lt = cond_builder.AddInstruction( 519e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), 529e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloOpcode::kLt, cond_iter, cond_data)); 539e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloComputation* cond_computation = 549e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower module->AddEmbeddedComputation(cond_builder.Build()); 559e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 569e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower auto body_builder = HloComputation::Builder("WhileBody"); 579e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) 589e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* body_param = body_builder.AddInstruction( 599e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction::CreateParameter(0, tuple_shape, "body_param")); 609e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloComputation* body_computation = 619e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower module->AddEmbeddedComputation(body_builder.Build()); 629e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 639e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower auto builder = HloComputation::Builder(TestName()); 649e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower // Entry params: 8 bytes (4 bytes per param), TOTAL=8 659e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* iter = builder.AddInstruction( 669e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); 679e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* data = builder.AddInstruction( 689e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction::CreateParameter(1, scalar_shape, "param_data")); 699e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 709e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* tuple = 719e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); 729e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower // While: 8 bytes (4 bytes per element), TOTAL=32 739e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower // Both cond and body use a max of 24 bytes, TOTAL=56 749e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( 759e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower tuple_shape, cond_computation, body_computation, tuple)); 769e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower HloComputation* entry_computation = 779e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower module->AddEntryComputation(builder.Build()); 789e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 799e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower auto size_fn = [](const LogicalBuffer& buffer) { 809e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); 819e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower }; 829e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 839e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower SequentialHloOrdering::HloModuleSequence module_sequence; 849e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, 859e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower cond_lt}; 869e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower module_sequence[body_computation] = {body_param}; 879e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower module_sequence[entry_computation] = {iter, data, tuple, while_op}; 889e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower EXPECT_EQ(56, 899e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); 909e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower} 919e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower 929e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower} // namespace 939e8005d7771e3f98b0a2ce74e4b0bc3765410a27A. Unique TensorFlower} // namespace xla 94