11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower#include <unordered_map>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector>
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/literal_util.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_computation.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
241151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower#include "tensorflow/compiler/xla/status_macros.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/util.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/core/errors.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace gpu {
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// TODO(b/33483676) Use an expression tree to specify computations to pattern
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// match for while transformations.
351151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
361151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// ExprTree is a simple recursive data structure used to express computation
371151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// patterns to match.
381151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
// Each ExprTree node consists of an HloOpcode and a set of operands (each
// of type ExprTree). Operands are specified as (operand index, ExprTree)
// pairs; when no indices are given, operands are assigned indices 0 and 1
// in order.
421151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
431151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// For example, the following computation:
441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
451151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//            Parameter
461151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//               |
471d7d8667d19e61bc65f35a6dae33563b2acadaacManHyuk//   Const  GetTupleElement
481151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//      \   /
491151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//       Add (root)
501151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Can be matched with the following expression tree:
521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//   ExprTree add(HloOpcode::kAdd,
541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//                ExprTree(HloOpcode::kConstant),
551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//                ExprTree(HloOpcode::kGetTupleElement,
561151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//                         tuple_index, ExprTree(HloOpcode::kParameter)));
571151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
581151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Match the ExprTree root against an Hlo graph:
591151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
601151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//   ExprTree::TaggedInstructionMap tagged_instructions;
611151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//   TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(),
621151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//                                &tagged_instructions));
631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
// Instructions that are "tagged" with a context-specific string will
// be returned in 'tagged_instructions' for further processing (e.g. parsing
// constants or recording the tuple_index).
671151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower//
681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlowerclass ExprTree {
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {}
711151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {}
721151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) {
731151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    SetOperand(0, operand0);
741151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  }
751151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0)
761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      : opcode_(opcode) {
771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    SetOperand(index0, operand0);
781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  }
791151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0,
801151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower           int64 index1, const ExprTree& operand1)
811151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      : opcode_(opcode) {
821151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    SetOperand(index0, operand0);
831151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    SetOperand(index1, operand1);
841151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  }
851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0)
861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      : opcode_(opcode), tag_(tag) {
871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    SetOperand(0, operand0);
881151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  }
891151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1)
901151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      : opcode_(opcode) {
911151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    SetOperand(0, operand0);
921151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    SetOperand(1, operand1);
931151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  }
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
951151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree(const ExprTree& to_copy) {
961151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    opcode_ = to_copy.opcode_;
971151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    tag_ = to_copy.tag_;
981151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (to_copy.fused_root_tree_ != nullptr) {
991151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_));
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1011151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    for (auto& pair : to_copy.operands_) {
1021151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      CHECK(operands_.find(pair.first) == operands_.end());
1031151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      operands_.insert(std::make_pair(
1041151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower          pair.first, std::unique_ptr<ExprTree>(new ExprTree(*pair.second))));
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1081151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  void SetFusedRoot(const ExprTree& fused_root) {
1091151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    fused_root_tree_.reset(new ExprTree(fused_root));
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1121151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  typedef std::unordered_map<string, const HloInstruction*>
1131151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      TaggedInstructionMap;
1141151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
1151151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Matches 'instruction' HloOpcode against 'opcode_'.
1161151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Recursively matches each operand in 'operands_'.
1171151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Recursively matches fused instructions starting at 'fused_root_tree_'
1181151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // if 'opcode_ == kFusion'.
1191151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Returns OK status, and instructions in 'tagged_instructions' for each
1201151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // matched ExprTree node with a non-empty 'tag_'.
1211151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Returns error message on failure.
1221151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  Status Match(const HloInstruction* instruction,
1231151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower               TaggedInstructionMap* tagged_instructions) const {
1241151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (opcode_ != instruction->opcode()) {
125fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower      return InvalidArgument("got opcode %s, want %s",
126fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower                             HloOpcodeString(instruction->opcode()).c_str(),
127fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower                             HloOpcodeString(opcode_).c_str());
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1291151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
130fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower    VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_;
1311151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (!tag_.empty()) {
1321151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      tagged_instructions->insert({tag_, instruction});
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1341151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
1351151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (instruction->opcode() == HloOpcode::kFusion) {
1361151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      CHECK(fused_root_tree_ != nullptr);
1371151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      // Match fused instructions for this node starting a 'fused_root_tree'.
1381151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      TF_RETURN_IF_ERROR(fused_root_tree_->Match(
1391151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower          instruction->fused_expression_root(), tagged_instructions));
1401151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    }
1411151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
1421151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Match each operand in 'operands_'.
1431151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    for (auto& pair : operands_) {
1441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first),
1451151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                                            tagged_instructions));
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return tensorflow::Status::OK();
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1501151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower private:
1511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  void SetOperand(int64 index, const ExprTree& operand) {
1521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    CHECK_EQ(0, operands_.count(index));
1531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    operands_.insert(std::make_pair(index, MakeUnique<ExprTree>(operand)));
1541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  }
1551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
1561151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  HloOpcode opcode_;
1571151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  std::unordered_map<int64, std::unique_ptr<ExprTree>> operands_;
1581151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  std::unique_ptr<ExprTree> fused_root_tree_;
1591151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  string tag_;
1601151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower};
1611151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
1621151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// MatcherBase is a base class that provides common functionality for
1631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// sub-classes which match specific target sub-computations (i.e. loop
1641151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// induction variable initialization, comparison and update).
1651151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlowerclass MatcherBase {
1661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower public:
1671151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  MatcherBase() {}
1681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  virtual ~MatcherBase() {}
1691151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
1701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Attempts to match each ExprTree in 'expr_trees_'.
17153cb26d05a5c2080d8022124178b1cc43a30ffe5A. Unique TensorFlower  // Returns OK on the first successful match, error status otherwise.
1721151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  virtual tensorflow::Status Run() {
1731151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    Status status;
1741151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    for (const ExprTree& expr_tree : expr_trees_) {
1751151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      status = MatchExprTree(expr_tree);
1761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      if (status.ok()) {
1771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        return status;
1781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      }
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1801151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    return status;
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1831151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  virtual Status MatchExprTree(const ExprTree& expr_tree) = 0;
1841151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
1851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Returns the constant value parsed form kConstant 'instruction'.
1861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Returns error status otherwise.
1871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  Status ParseConstInteger(const HloInstruction* instruction,
1881151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                           int64* const_value) const {
1891151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    CHECK_EQ(HloOpcode::kConstant, instruction->opcode());
1901151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    PrimitiveType element_type = instruction->shape().element_type();
1911151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (element_type != S32 && element_type != S64) {
1921151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      return InvalidArgument("Expected constant of integral type.");
1931151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    }
1941151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    const Literal& literal = instruction->literal();
1951151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    PrimitiveType type = literal.shape().element_type();
1961151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (type != S32 && type != S64) {
1971151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      return InvalidArgument("Must use S32 or S64 integral types.");
1981151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    }
1991151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (type == S32) {
20046737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *const_value = static_cast<int64>(literal.GetFirstElement<int32>());
2011151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    } else if (type == S64) {
20246737e4e81314f7482bfd6a710f126a27f5d7975A. Unique TensorFlower      *const_value = literal.GetFirstElement<int64>();
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
2041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return tensorflow::Status::OK();
2051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2071151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  StatusOr<const HloInstruction*> GetTaggedInstruction(
2081151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      const string& tag,
2091151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      const ExprTree::TaggedInstructionMap& tagged_instructions) {
2101151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    auto it = tagged_instructions.find(tag);
2111151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (it == tagged_instructions.end()) {
2121151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      return InvalidArgument("Cound not find instruction for tag: %s",
2131151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                             tag.c_str());
2141151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    }
2151151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    return it->second;
2161151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  }
2171151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
2181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins protected:
2191151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  std::vector<ExprTree> expr_trees_;
2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private:
2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase);
2231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
22528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower// WhileConditionComputationMatcher attempts to match a target computation
2261151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// pattern in the while condition sub-computation.
2271151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// If the target pattern is matched, two pieces of information are extracted
2281151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// from 'tagged' instructions returned by the matcher:
2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//
2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// *) 'tuple_index':
2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//    *) The loop induction variable tuple_index from the GetTupleElement
2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//       instruction of the matched computation.
2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//    *) Used in subsequent matching passes of while init operand and body
2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//       computations to select loop induction variable tuple element.
2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//
2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// *) 'loop_limit':
2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//    *) The integral value from Constant root operand in matched computation.
2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//    *) Used as the constant for the loop limit.
2391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//
2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass WhileConditionComputationMatcher : public MatcherBase {
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
2426882effb863dcd0da00d3287959deac46734a0b2A. Unique TensorFlower  explicit WhileConditionComputationMatcher(const HloComputation* computation)
2431151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      : computation_(computation) {
2441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    expr_trees_.emplace_back(BuildCondExprTree());
2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2471151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  int64 loop_limit() const { return loop_limit_; }
2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  int64 tuple_index() const { return tuple_index_; }
2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private:
2511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Builds expression tree for the following condition computation:
2521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //
2531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //     Const  Parameter
2541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //        \     /
2551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //         Fusion ------------> FusionParam FusionParam
2561151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                                  \          /
2571151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                                  GTE       /
2581151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                                    \      /
2591151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                                    LessThan (fused root)
2601151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //
2611151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree BuildCondExprTree() {
2621151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Build ExprTree for fused instructions.
2631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    ExprTree fused_root(
2641151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        HloOpcode::kLt,
2651151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        ExprTree(HloOpcode::kGetTupleElement, "gte",
2661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                 ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")),
2671151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        ExprTree(HloOpcode::kParameter));
2681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
2691151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Build top-level computation.
2701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    ExprTree root(HloOpcode::kFusion,
2711151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                  ExprTree(HloOpcode::kConstant, "loop_limit"),
2721151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                  ExprTree(HloOpcode::kParameter, "param0"));
2731151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
2741151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    root.SetFusedRoot(fused_root);
2751151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    return root;
2761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  }
2771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
2781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  Status MatchExprTree(const ExprTree& expr_tree) override {
279fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower    VLOG(2) << "MATCHING while condition";
2801151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    ExprTree::TaggedInstructionMap tagged_instructions;
2811151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(),
2821151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                                       &tagged_instructions));
2831151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
2841151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Get tagged GTE instruction and set 'tuple_index_'.
2851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_ASSIGN_OR_RETURN(const HloInstruction* gte,
2861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                        GetTaggedInstruction("gte", tagged_instructions));
2871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    tuple_index_ = gte->tuple_index();
2881151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
2891151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Get tagged Constant instruction and parse 'loop_limit_'.
2901151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_ASSIGN_OR_RETURN(
2911151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        const HloInstruction* const_hlo,
2921151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        GetTaggedInstruction("loop_limit", tagged_instructions));
2931151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_));
2941151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
2951151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Get tagged "param0" instruction, and check that it matches
2961151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // 'computation_' parameter 0.
2971151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_ASSIGN_OR_RETURN(const HloInstruction* param0,
2981151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                        GetTaggedInstruction("param0", tagged_instructions));
2991151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (param0 != computation_->parameter_instruction(0)) {
3001151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      return InvalidArgument("Unexpected Parameter0 instruction : %s",
3011151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                             param0->name().c_str());
3021151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    }
3031151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
3041151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Get tagged 'gte.fusion_param.param0', find its associated fusion operand,
3051151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // and compare it to 'computation_' parameter0.
3061151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_ASSIGN_OR_RETURN(
3071151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        const HloInstruction* gte_fusion_param0,
3081151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions));
3091151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode());
3101151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    CHECK(gte_fusion_param0->IsFused());
311e565d1f1fced69789feb10f1ea1241157ec95f93A. Unique TensorFlower    if (gte_fusion_param0->parent()->FusionInstruction()->operand(
3121151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower            gte_fusion_param0->parameter_number()) !=
3131151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        computation_->parameter_instruction(0)) {
3141151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      return InvalidArgument("Could not match fusion param: %s",
3151151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                             gte_fusion_param0->name().c_str());
3161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
3171151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
3181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return tensorflow::Status::OK();
3191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3211151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  const HloComputation* computation_;
3221151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
3231151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  int64 loop_limit_ = -1;
3241151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  int64 tuple_index_ = -1;
3251151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
3261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher);
3271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
3281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3291151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// WhileInitOperandMatcher matches a target computation pattern of the
3301151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// while instructions 'init' operand, indexing the tuple at 'tuple_index'.
3311151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// On success, parses constant 'loop_start' which represents the loop induction
3321151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// variable start values, then returns OK.
3331151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Returns error status otherwise.
3341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass WhileInitOperandMatcher : public MatcherBase {
3351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
3361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  WhileInitOperandMatcher(const HloInstruction* while_hlo,
3371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                          const int64 tuple_index)
3381151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      : while_hlo_(while_hlo), tuple_index_(tuple_index) {
3391151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    expr_trees_.emplace_back(BuildInitExprTree());
3401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3421151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  int64 loop_start() const { return loop_start_; }
3431151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
3441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower private:
3451151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Builds expression tree for the following while init operand subcomputation:
3461151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //
3471151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //             Const
3481151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //               |
3491151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //             Copy
3501151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //               |
3511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //             Tuple0
3521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //               |
3531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //             While
3541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //
3551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree BuildInitExprTree() {
356fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower    return ExprTree(
357fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower        HloOpcode::kWhile, "while",
358fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower        ExprTree(HloOpcode::kTuple, tuple_index_,
359fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower                 ExprTree(HloOpcode::kCopy,
360fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower                          ExprTree(HloOpcode::kConstant, "loop_start"))));
3611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  Status MatchExprTree(const ExprTree& expr_tree) override {
364fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower    VLOG(2) << "MATCHING while init";
3651151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    ExprTree::TaggedInstructionMap tagged_instructions;
3661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions));
3671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Get tagged while instruction check against 'while_hlo_'.
3691151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo,
3701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                        GetTaggedInstruction("while", tagged_instructions));
3711151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    if (while_hlo != while_hlo_) {
3721151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      return InvalidArgument("Expected While for instruction : %s",
3731151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                             while_hlo->name().c_str());
3741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
3751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Get tagged Constant instruction and parse 'loop_start_'.
3771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_ASSIGN_OR_RETURN(
3781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        const HloInstruction* const_hlo,
3791151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        GetTaggedInstruction("loop_start", tagged_instructions));
3801151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_));
3811151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
3821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return tensorflow::Status::OK();
3831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  const HloInstruction* while_hlo_;
3861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  const int64 tuple_index_;
3871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
3881151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  int64 loop_start_ = -1;
3891151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
3901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher);
3911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
3921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3931151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// WhileBodyComputationMatcher matches a target computation pattern for
3941151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// the loop induction variable update. Matching proceeds from the while body
3951151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// computation root[tuple_index] to param[tuple_index], where 'tuple_index'
3961151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// If the target pattern is matched, parses a constant which represents the
3971151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// loop induction variable increment value, then returns status OK.
3981151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower// Returns error status otherwise.
3991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass WhileBodyComputationMatcher : public MatcherBase {
4001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
4011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  WhileBodyComputationMatcher(const HloComputation* computation,
4021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                              const int64 tuple_index)
4031151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      : computation_(computation), tuple_index_(tuple_index) {
4041151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    expr_trees_.emplace_back(BuildBodyExprTree(0, 1));
4051151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    expr_trees_.emplace_back(BuildBodyExprTree(1, 0));
4061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4081151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  int64 loop_increment() const { return loop_increment_; }
4091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private:
4111151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  // Builds expression tree for the following while body computation:
4121151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //
4131151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //
4141151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                               FusionParam FusionParam
4151151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                                     \      /
4161151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                  Const Param         \   GTE1
4171151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                     \  /              \  /
4181151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                    Fusion -----------> Add
4191151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                      |
4201151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                     Copy
4211151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                      |
4221151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //                     Tuple0
4231151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  //
4241151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) {
4251151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Build ExprTree for fused instructions.
4261151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    ExprTree gte1 =
4271151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        ExprTree(HloOpcode::kGetTupleElement, "gte",
4281151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                 ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0"));
4291151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    ExprTree fused_root(HloOpcode::kAdd, const_index,
4301151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                        ExprTree(HloOpcode::kParameter), gte_index, gte1);
4311151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
4321151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Build fusion instruction (and set fused root).
4331151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    ExprTree fusion(HloOpcode::kFusion, 0,
4341151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                    ExprTree(HloOpcode::kConstant, "loop_increment"), 1,
4351151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                    ExprTree(HloOpcode::kParameter, "param0"));
4361151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    fusion.SetFusedRoot(fused_root);
4371151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
4381151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    // Build top-level computation.
439fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower    ExprTree tuple0(HloOpcode::kTuple, tuple_index_,
440fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower                    ExprTree(HloOpcode::kCopy, fusion));
4411151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    return tuple0;
4421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4441151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  Status MatchExprTree(const ExprTree& expr_tree) override {
445fc197e6c77e336700a22e04df2b1f20e0fc72fd5A. Unique TensorFlower    VLOG(2) << "MATCHING while body";
4461151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    ExprTree::TaggedInstructionMap tagged_instructions;
4471151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(),
4481151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                                       &tagged_instructions));
4491151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
4501151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower    for (const auto& pair : tagged_instructions) {
4511151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      const auto& tag = pair.first;
4521151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      const auto& inst = pair.second;
4531151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
4541151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      if (tag == "gte" && inst->tuple_index() != tuple_index_) {
4551151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        // Check that the matched GTE instruction is at the 'tuple_index' we
4561151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        // matched in the while condition computation.
4571151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        return InvalidArgument("Unexpected tuple index instruction : %s",
4581151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                               inst->name().c_str());
4591151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      } else if (tag == "loop_increment") {
4601151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        // Parse the constant which represents the loop induction variable
4611151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        // increment value.
4621151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_));
4631151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      } else if (tag == "param0" &&
4641151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                 inst != computation_->parameter_instruction(0)) {
4651151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        // Check that the matched parameter == parameter 0 from 'computation_'.
4661151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        return InvalidArgument("Unexpected Parameter0 instruction : %s",
4671151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                               inst->name().c_str());
4681151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      } else if (tag == "gte.fusion_param.param0") {
4691151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        // Fusion parameter: lookup and compare with associated fusion operand.
4701151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        CHECK_EQ(HloOpcode::kParameter, inst->opcode());
4711151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        CHECK(inst->IsFused());
472e565d1f1fced69789feb10f1ea1241157ec95f93A. Unique TensorFlower        if (inst->parent()->FusionInstruction()->operand(
473e565d1f1fced69789feb10f1ea1241157ec95f93A. Unique TensorFlower                inst->parameter_number()) !=
4741151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower            computation_->parameter_instruction(0)) {
4751151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower          return InvalidArgument("Could not match fusion param: %s",
4761151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower                                 inst->name().c_str());
4771151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower        }
4781151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower      }
4791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
4801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return tensorflow::Status::OK();
4811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4831151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  const HloComputation* computation_;
4841151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  const int64 tuple_index_;
4851151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
4861151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower  int64 loop_increment_ = -1;
4871151644f5a918d76ef846bdd46d898c8b4f7aa03A. Unique TensorFlower
4881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher);
4891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
4901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
4921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsStatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor(
4941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const HloInstruction* while_hlo) {
4951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (while_hlo->opcode() != HloOpcode::kWhile) {
4961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return InvalidArgument("Expected While instruction.");
4971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
4981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
4991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition());
5001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TF_RETURN_IF_ERROR(cond_matcher.Run());
5011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index());
5031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TF_RETURN_IF_ERROR(init_matcher.Run());
5041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  WhileBodyComputationMatcher body_matcher(while_hlo->while_body(),
5061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                           cond_matcher.tuple_index());
5071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TF_RETURN_IF_ERROR(body_matcher.Run());
5081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Check for valid For loop parameters.
5101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (init_matcher.loop_start() >= cond_matcher.loop_limit()) {
5111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return InvalidArgument("Loop start must be less than loop limit.");
5121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (body_matcher.loop_increment() <= 0) {
5141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return InvalidArgument("Loop increment must greater than zero.");
5151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
5161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(),
5171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                         body_matcher.loop_increment());
5181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
5191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace gpu
5211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
522