11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 181151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower#include <unordered_map> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector> 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_computation.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h" 241151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower#include "tensorflow/compiler/xla/status_macros.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/util.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/core/errors.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace gpu { 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// TODO(b/33483676) Use an expression tree to specify computations to pattern 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// match for while transformations. 351151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 361151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// ExprTree is a simple recursive data structure used to express computation 371151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// patterns to match. 381151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 391151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each 4046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower// of type ExprTree). Operands can be added by specifying the index and 4146737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower// HloOpcode of the operand. 421151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 431151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// For example, the following computation: 441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 451151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Parameter 461151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// | 471d7d8667d19e61bc65f35a6dae33563b2acadaacManHyuk// Const GetTupleElement 481151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// \ / 491151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Add (root) 501151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Can be matched with the following expression tree: 521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// ExprTree add(HloOpcode::kAdd, 541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// ExprTree(HloOpcode::kConstant), 551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// ExprTree(HloOpcode::kGetTupleElement, 561151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// tuple_index, ExprTree(HloOpcode::kParameter))); 571151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 581151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Match the ExprTree root against an Hlo graph: 591151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 601151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// ExprTree::TaggedInstructionMap tagged_instructions; 611151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(), 621151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// &tagged_instructions)); 631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 641151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Instructions that are "tagged" with a context-specific string will 650b6f2189cf879c89f6d72a27de7800ccea095605ManHyuk// be returned in 'tagged_instructions' for further processing (i.e. parsing 661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// constants or recording the tuple_index). 671151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// 681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlowerclass ExprTree { 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {} 711151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {} 721151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) { 731151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower SetOperand(0, operand0); 741151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 751151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0) 761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower : opcode_(opcode) { 771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower SetOperand(index0, operand0); 781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 791151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0, 801151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64 index1, const ExprTree& operand1) 811151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower : opcode_(opcode) { 821151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower SetOperand(index0, operand0); 831151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower SetOperand(index1, operand1); 841151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0) 861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower : opcode_(opcode), tag_(tag) { 871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower SetOperand(0, operand0); 881151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 891151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1) 901151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower : opcode_(opcode) { 911151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower SetOperand(0, operand0); 921151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower SetOperand(1, operand1); 931151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 951151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(const ExprTree& to_copy) { 961151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower opcode_ = to_copy.opcode_; 971151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower tag_ = to_copy.tag_; 981151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (to_copy.fused_root_tree_ != nullptr) { 991151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_)); 1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1011151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower for (auto& pair : to_copy.operands_) { 1021151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower CHECK(operands_.find(pair.first) == operands_.end()); 1031151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower operands_.insert(std::make_pair( 1041151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower pair.first, std::unique_ptr<ExprTree>(new ExprTree(*pair.second)))); 1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1081151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower void SetFusedRoot(const ExprTree& fused_root) { 1091151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower fused_root_tree_.reset(new ExprTree(fused_root)); 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1121151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower typedef std::unordered_map<string, const HloInstruction*> 1131151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TaggedInstructionMap; 1141151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 1151151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Matches 'instruction' HloOpcode against 'opcode_'. 1161151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Recursively matches each operand in 'operands_'. 1171151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Recursively matches fused instructions starting at 'fused_root_tree_' 1181151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // if 'opcode_ == kFusion'. 1191151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Returns OK status, and instructions in 'tagged_instructions' for each 1201151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // matched ExprTree node with a non-empty 'tag_'. 1211151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Returns error message on failure. 1221151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower Status Match(const HloInstruction* instruction, 1231151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TaggedInstructionMap* tagged_instructions) const { 1241151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (opcode_ != instruction->opcode()) { 125fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower return InvalidArgument("got opcode %s, want %s", 126fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower HloOpcodeString(instruction->opcode()).c_str(), 127fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower HloOpcodeString(opcode_).c_str()); 1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1291151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 130fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; 1311151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (!tag_.empty()) { 1321151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower tagged_instructions->insert({tag_, instruction}); 1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1341151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 1351151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (instruction->opcode() == HloOpcode::kFusion) { 1361151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower CHECK(fused_root_tree_ != nullptr); 1371151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Match fused instructions for this node starting a 'fused_root_tree'. 1381151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_RETURN_IF_ERROR(fused_root_tree_->Match( 1391151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower instruction->fused_expression_root(), tagged_instructions)); 1401151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 1411151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 1421151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Match each operand in 'operands_'. 1431151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower for (auto& pair : operands_) { 1441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), 1451151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower tagged_instructions)); 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::Status::OK(); 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1501151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower private: 1511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower void SetOperand(int64 index, const ExprTree& operand) { 1521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower CHECK_EQ(0, operands_.count(index)); 1531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower operands_.insert(std::make_pair(index, MakeUnique<ExprTree>(operand))); 1541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 1551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 1561151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower HloOpcode opcode_; 1571151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower std::unordered_map<int64, std::unique_ptr<ExprTree>> operands_; 1581151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower std::unique_ptr<ExprTree> fused_root_tree_; 1591151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower string tag_; 1601151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower}; 1611151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 1621151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// MatcherBase is a base class that provides common functionality for 1631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// sub-classes which match specific target sub-computations (i.e. loop 1641151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// induction variable initialization, comparison and update). 1651151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlowerclass MatcherBase { 1661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower public: 1671151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower MatcherBase() {} 1681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower virtual ~MatcherBase() {} 1691151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 1701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Attempts to match each ExprTree in 'expr_trees_'. 17153cb26d05a5c2080d8022124178b1cc43a30ffe5A. Unique TensorFlower // Returns OK on the first successful match, error status otherwise. 1721151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower virtual tensorflow::Status Run() { 1731151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower Status status; 1741151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower for (const ExprTree& expr_tree : expr_trees_) { 1751151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower status = MatchExprTree(expr_tree); 1761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (status.ok()) { 1771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return status; 1781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1801151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return status; 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1831151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower virtual Status MatchExprTree(const ExprTree& expr_tree) = 0; 1841151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 1851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Returns the constant value parsed form kConstant 'instruction'. 1861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Returns error status otherwise. 1871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower Status ParseConstInteger(const HloInstruction* instruction, 1881151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64* const_value) const { 1891151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower CHECK_EQ(HloOpcode::kConstant, instruction->opcode()); 1901151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower PrimitiveType element_type = instruction->shape().element_type(); 1911151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (element_type != S32 && element_type != S64) { 1921151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Expected constant of integral type."); 1931151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 1941151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const Literal& literal = instruction->literal(); 1951151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower PrimitiveType type = literal.shape().element_type(); 1961151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (type != S32 && type != S64) { 1971151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Must use S32 or S64 integral types."); 1981151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 1991151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (type == S32) { 20046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *const_value = static_cast<int64>(literal.GetFirstElement<int32>()); 2011151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } else if (type == S64) { 20246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower *const_value = literal.GetFirstElement<int64>(); 2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::Status::OK(); 2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2071151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower StatusOr<const HloInstruction*> GetTaggedInstruction( 2081151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const string& tag, 2091151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const ExprTree::TaggedInstructionMap& tagged_instructions) { 2101151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower auto it = tagged_instructions.find(tag); 2111151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (it == tagged_instructions.end()) { 2121151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Cound not find instruction for tag: %s", 2131151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower tag.c_str()); 2141151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 2151151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return it->second; 2161151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 2171151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected: 2191151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower std::vector<ExprTree> expr_trees_; 2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private: 2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); 2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 22528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower// WhileConditionComputationMatcher attempts to match a target computation 2261151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// pattern in the while condition sub-computation. 2271151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// If the target pattern is matched, two pieces of information are extracted 2281151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// from 'tagged' instructions returned by the matcher: 2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// 2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// *) 'tuple_index': 2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// *) The loop induction variable tuple_index from the GetTupleElement 2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// instruction of the matched computation. 2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// *) Used in subsequent matching passes of while init operand and body 2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// computations to select loop induction variable tuple element. 2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// 2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// *) 'loop_limit': 2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// *) The integral value from Constant root operand in matched computation. 2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// *) Used as the constant for the loop limit. 2391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// 2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass WhileConditionComputationMatcher : public MatcherBase { 2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 2426882effb863dcd0da00d3287959deac46734a0b2A. Unique TensorFlower explicit WhileConditionComputationMatcher(const HloComputation* computation) 2431151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower : computation_(computation) { 2441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower expr_trees_.emplace_back(BuildCondExprTree()); 2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2471151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64 loop_limit() const { return loop_limit_; } 2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 tuple_index() const { return tuple_index_; } 2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private: 2511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Builds expression tree for the following condition computation: 2521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // 2531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Const Parameter 2541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // \ / 2551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Fusion ------------> FusionParam FusionParam 2561151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // \ / 2571151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // GTE / 2581151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // \ / 2591151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // LessThan (fused root) 2601151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // 2611151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree BuildCondExprTree() { 2621151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Build ExprTree for fused instructions. 2631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree fused_root( 2641151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower HloOpcode::kLt, 2651151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kGetTupleElement, "gte", 2661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")), 2671151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kParameter)); 2681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 2691151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Build top-level computation. 2701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree root(HloOpcode::kFusion, 2711151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kConstant, "loop_limit"), 2721151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kParameter, "param0")); 2731151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 2741151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower root.SetFusedRoot(fused_root); 2751151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return root; 2761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 2771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 2781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower Status MatchExprTree(const ExprTree& expr_tree) override { 279fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower VLOG(2) << "MATCHING while condition"; 2801151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree::TaggedInstructionMap tagged_instructions; 2811151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), 2821151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower &tagged_instructions)); 2831151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 2841151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Get tagged GTE instruction and set 'tuple_index_'. 2851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_ASSIGN_OR_RETURN(const HloInstruction* gte, 2861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower GetTaggedInstruction("gte", tagged_instructions)); 2871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower tuple_index_ = gte->tuple_index(); 2881151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 2891151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Get tagged Constant instruction and parse 'loop_limit_'. 2901151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_ASSIGN_OR_RETURN( 2911151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const HloInstruction* const_hlo, 2921151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower GetTaggedInstruction("loop_limit", tagged_instructions)); 2931151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_)); 2941151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 2951151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Get tagged "param0" instruction, and check that it matches 2961151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // 'computation_' parameter 0. 2971151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_ASSIGN_OR_RETURN(const HloInstruction* param0, 2981151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower GetTaggedInstruction("param0", tagged_instructions)); 2991151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (param0 != computation_->parameter_instruction(0)) { 3001151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Unexpected Parameter0 instruction : %s", 3011151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower param0->name().c_str()); 3021151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 3031151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 3041151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Get tagged 'gte.fusion_param.param0', find its associated fusion operand, 3051151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // and compare it to 'computation_' parameter0. 3061151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_ASSIGN_OR_RETURN( 3071151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const HloInstruction* gte_fusion_param0, 3081151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); 3091151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); 3101151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower CHECK(gte_fusion_param0->IsFused()); 311e565d1f1fced69789feb10f1ea1241157ec95f93A. Unique TensorFlower if (gte_fusion_param0->parent()->FusionInstruction()->operand( 3121151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower gte_fusion_param0->parameter_number()) != 3131151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower computation_->parameter_instruction(0)) { 3141151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Could not match fusion param: %s", 3151151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower gte_fusion_param0->name().c_str()); 3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3171151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::Status::OK(); 3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3211151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const HloComputation* computation_; 3221151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 3231151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64 loop_limit_ = -1; 3241151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64 tuple_index_ = -1; 3251151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); 3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3291151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// WhileInitOperandMatcher matches a target computation pattern of the 3301151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// while instructions 'init' operand, indexing the tuple at 'tuple_index'. 3311151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// On success, parses constant 'loop_start' which represents the loop induction 3321151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// variable start values, then returns OK. 3331151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Returns error status otherwise. 3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass WhileInitOperandMatcher : public MatcherBase { 3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WhileInitOperandMatcher(const HloInstruction* while_hlo, 3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 tuple_index) 3381151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower : while_hlo_(while_hlo), tuple_index_(tuple_index) { 3391151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower expr_trees_.emplace_back(BuildInitExprTree()); 3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3421151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64 loop_start() const { return loop_start_; } 3431151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 3441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower private: 3451151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Builds expression tree for the following while init operand subcomputation: 3461151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // 3471151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Const 3481151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // | 3491151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Copy 3501151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // | 3511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Tuple0 3521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // | 3531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // While 3541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // 3551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree BuildInitExprTree() { 356fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower return ExprTree( 357fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower HloOpcode::kWhile, "while", 358fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower ExprTree(HloOpcode::kTuple, tuple_index_, 359fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower ExprTree(HloOpcode::kCopy, 360fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower ExprTree(HloOpcode::kConstant, "loop_start")))); 3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower Status MatchExprTree(const ExprTree& expr_tree) override { 364fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower VLOG(2) << "MATCHING while init"; 3651151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree::TaggedInstructionMap tagged_instructions; 3661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); 3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Get tagged while instruction check against 'while_hlo_'. 3691151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo, 3701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower GetTaggedInstruction("while", tagged_instructions)); 3711151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (while_hlo != while_hlo_) { 3721151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Expected While for instruction : %s", 3731151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower while_hlo->name().c_str()); 3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Get tagged Constant instruction and parse 'loop_start_'. 3771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_ASSIGN_OR_RETURN( 3781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const HloInstruction* const_hlo, 3791151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower GetTaggedInstruction("loop_start", tagged_instructions)); 3801151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); 3811151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::Status::OK(); 3831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const HloInstruction* while_hlo_; 3861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const int64 tuple_index_; 3871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 3881151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64 loop_start_ = -1; 3891151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); 3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 3931151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// WhileBodyComputationMatcher matches a target computation pattern for 3941151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// the loop induction variable update. Matching proceeds from the while body 3951151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// computation root[tuple_index] to param[tuple_index], where 'tuple_index' 3961151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// If the target pattern is matched, parses a constant which represents the 3971151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// loop induction variable increment value, then returns status OK. 3981151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Returns error status otherwise. 3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass WhileBodyComputationMatcher : public MatcherBase { 4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WhileBodyComputationMatcher(const HloComputation* computation, 4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const int64 tuple_index) 4031151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower : computation_(computation), tuple_index_(tuple_index) { 4041151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower expr_trees_.emplace_back(BuildBodyExprTree(0, 1)); 4051151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower expr_trees_.emplace_back(BuildBodyExprTree(1, 0)); 4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4081151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64 loop_increment() const { return loop_increment_; } 4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private: 4111151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Builds expression tree for the following while body computation: 4121151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // 4131151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // 4141151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // FusionParam FusionParam 4151151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // \ / 4161151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Const Param \ GTE1 4171151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // \ / \ / 4181151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Fusion -----------> Add 4191151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // | 4201151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Copy 4211151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // | 4221151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Tuple0 4231151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // 4241151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) { 4251151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Build ExprTree for fused instructions. 4261151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree gte1 = 4271151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kGetTupleElement, "gte", 4281151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")); 4291151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree fused_root(HloOpcode::kAdd, const_index, 4301151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kParameter), gte_index, gte1); 4311151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 4321151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Build fusion instruction (and set fused root). 4331151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree fusion(HloOpcode::kFusion, 0, 4341151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kConstant, "loop_increment"), 1, 4351151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree(HloOpcode::kParameter, "param0")); 4361151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower fusion.SetFusedRoot(fused_root); 4371151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 4381151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Build top-level computation. 439fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower ExprTree tuple0(HloOpcode::kTuple, tuple_index_, 440fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower ExprTree(HloOpcode::kCopy, fusion)); 4411151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return tuple0; 4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower Status MatchExprTree(const ExprTree& expr_tree) override { 445fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower VLOG(2) << "MATCHING while body"; 4461151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower ExprTree::TaggedInstructionMap tagged_instructions; 4471151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), 4481151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower &tagged_instructions)); 4491151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 4501151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower for (const auto& pair : tagged_instructions) { 4511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const auto& tag = pair.first; 4521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const auto& inst = pair.second; 4531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 4541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower if (tag == "gte" && inst->tuple_index() != tuple_index_) { 4551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Check that the matched GTE instruction is at the 'tuple_index' we 4561151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // matched in the while condition computation. 4571151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Unexpected tuple index instruction : %s", 4581151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower inst->name().c_str()); 4591151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } else if (tag == "loop_increment") { 4601151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Parse the constant which represents the loop induction variable 4611151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // increment value. 4621151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); 4631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } else if (tag == "param0" && 4641151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower inst != computation_->parameter_instruction(0)) { 4651151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Check that the matched parameter == parameter 0 from 'computation_'. 4661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Unexpected Parameter0 instruction : %s", 4671151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower inst->name().c_str()); 4681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } else if (tag == "gte.fusion_param.param0") { 4691151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower // Fusion parameter: lookup and compare with associated fusion operand. 4701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower CHECK_EQ(HloOpcode::kParameter, inst->opcode()); 4711151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower CHECK(inst->IsFused()); 472e565d1f1fced69789feb10f1ea1241157ec95f93A. Unique TensorFlower if (inst->parent()->FusionInstruction()->operand( 473e565d1f1fced69789feb10f1ea1241157ec95f93A. Unique TensorFlower inst->parameter_number()) != 4741151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower computation_->parameter_instruction(0)) { 4751151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower return InvalidArgument("Could not match fusion param: %s", 4761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower inst->name().c_str()); 4771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 4781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower } 4791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return tensorflow::Status::OK(); 4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4831151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const HloComputation* computation_; 4841151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower const int64 tuple_index_; 4851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 4861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower int64 loop_increment_ = -1; 4871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower 4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); 4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsStatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor( 4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const HloInstruction* while_hlo) { 4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (while_hlo->opcode() != HloOpcode::kWhile) { 4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return InvalidArgument("Expected While instruction."); 4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); 5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TF_RETURN_IF_ERROR(cond_matcher.Run()); 5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); 5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TF_RETURN_IF_ERROR(init_matcher.Run()); 5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), 5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins cond_matcher.tuple_index()); 5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TF_RETURN_IF_ERROR(body_matcher.Run()); 5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Check for valid For loop parameters. 5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { 5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return InvalidArgument("Loop start must be less than loop limit."); 5121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (body_matcher.loop_increment() <= 0) { 5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return InvalidArgument("Loop increment must greater than zero."); 5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), 5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins body_matcher.loop_increment()); 5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace gpu 5211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 522