/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO instructions are in DAG form and represent the computations that the
// user has built up via the XLA service interface. They are ultimately lowered
// in a platform-aware way by traversing the HLO DAG and emitting a lowered
// form; e.g. see DfsHloVisitor.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_

#include <functional>
#include <iosfwd>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloComputation;
class HloModule;

// A bunch of switches that control how the hlo text should be printed.
class HloPrintOptions {
 public:
  // Constructs the default print options: don't print large constants, don't
  // compact operands, no indentation.
  HloPrintOptions()
      : print_large_constants_(false),
        print_subcomputation_references_(true),
        print_metadata_(true),
        compact_operands_(false),
        print_operand_shape_(true),
        print_program_shape_(true),
        print_percent_(true),
        indent_amount_(0) {}

  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_references(true)
        .set_print_metadata(false)
        .set_print_operand_shape(false)
        .set_print_program_shape(false)
        .set_print_percent(false);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  // If false, the names of subcomputations (e.g. a fusion node's fused
  // computation) won't be printed, which makes the resulting text not
  // parsable.
  //
  // A CustomCall's call target is printed even if
  // print_subcomputation_references is false, because the call target isn't an
  // HloComputation.
  HloPrintOptions& set_print_subcomputation_references(bool value) {
    print_subcomputation_references_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, only a part of the operands will be printed out, and their names
  // will be omitted (note that in this case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  bool print_subcomputation_references() const {
    return print_subcomputation_references_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool compact_operands() const { return compact_operands_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  int indent_amount() const { return indent_amount_; }

 private:
  bool print_large_constants_;
  bool print_subcomputation_references_;
  bool print_metadata_;
  bool compact_operands_;
  bool print_operand_shape_;
  bool print_program_shape_;
  bool print_percent_;
  int indent_amount_;
};
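
// Example (illustrative sketch, not part of the original header): given a
// hypothetical HloInstruction* `instr`, the setters above can be chained to
// control how the instruction is rendered by HloInstruction::ToString():
//
//   HloPrintOptions options = HloPrintOptions::ShortParsable()
//                                 .set_print_metadata(true)
//                                 .set_indent_amount(2);
//   string text = instr->ToString(options);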

// HLO instructions are the IR used by the high-level compiler.
class HloInstruction {
 public:
  enum class FusionKind {
    kLoop,          // Fused into a loop.
    kInput,         // Op's input is fused into the op itself.
    kOutput,        // Op's output is fused into the op itself.
                    // REQUIRES: At least one operand buffer must be able
                    // to alias the output buffer.
    kTransposeDot,  // Fused into a dot with transposed operands.
    kCustom,        // Custom category for backend-specific fusions that
                    // do not match any of the more specific ones.
  };

  ~HloInstruction();

  // Creates an instruction from the given proto. Arguments:
  //
  //   module: the module which will contain the instruction. The newly created
  //     instruction is *not* added to the module or any computation, however.
  //   proto: the proto to convert from.
  //   instruction_map: a map from instruction name to HloInstruction*. This
  //     map must contain all operands of the newly constructed instruction.
  //   computation_map: a map from computation name to HloComputation*. This
  //     map must contain all computations which the newly constructed
  //     instruction calls.
  //   add_fused_computation: A function to call to add a fused
  //     computation. Used (clearly) when the instruction is a fusion
  //     instruction.
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      HloModule* module, const HloInstructionProto& proto,
      const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
      const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
      const std::function<void(std::unique_ptr<HloComputation>)>&
          add_fused_computation);

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(
      int64 parameter_number, const Shape& shape, const string& name);

  // Creates a literal constant instruction.
  static std::unique_ptr<HloInstruction> CreateConstant(
      std::unique_ptr<Literal> literal);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64 index);

  // Creates a trace instruction that logs the input operand in the
  // computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      tensorflow::gtl::ArraySlice<HloInstruction*> parameters);

  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);

  // Creates a ternary instruction (three operands).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index) with the same `static_operands`.
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloComputation* map_computation,
      tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers);

  // Creates an FFT op, of the type indicated by fft_type.
  static std::unique_ptr<HloInstruction> CreateFft(
      const Shape& shape, HloInstruction* operand, FftType fft_type,
      tensorflow::gtl::ArraySlice<int64> fft_length);

  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers);

  // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
  // of the LHS with dimension 0 of the RHS with no batch dimensions.  Both LHS
  // and the RHS must be of rank 2.
  static std::unique_ptr<HloInstruction> CreateCanonicalDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);

  // Creates a cross replica sum op.
  static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Creates a conversion instruction, where operand is the data to convert and
  // shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from
  // the Infeed interface of the device.
  static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
                                                      const string& config);

  // Creates an outfeed instruction, which outputs data.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& shape, HloInstruction* operand,
      tensorflow::StringPiece outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id.
  static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
                                                    int64 channel_id);

  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a unique
  // send instruction in another computation that has the same channel id.
  static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
                                                    int64 channel_id);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> start_indices,
      tensorflow::gtl::ArraySlice<int64> limit_indices,
      tensorflow::gtl::ArraySlice<int64> strides);
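
  // Example (illustrative sketch): slicing a f32[10] operand with
  // start_indices {2}, limit_indices {8}, and strides {2} keeps the elements
  // at indices 2, 4, and 6, so the result shape is f32[3].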

  // Creates a dynamic slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      tensorflow::gtl::ArraySlice<int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      HloInstruction* start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      int64 dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. That is, if f is the
  // function to apply (which takes either 2 [accumulator, value] or 3
  // [accumulator, index, value] arguments) and init is the initial value for
  // the reduction operator (for example, 0 for addition), then this operation
  // will compute:
  //   f(... f(f(init, [index0,] value0), [index1,] value1) ..., [indexN,] valueN)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
      HloComputation* reduce_computation);
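
  // Example (illustrative sketch; `add`, `zero`, and `operand_2d` are
  // hypothetical names not declared in this header): summing a f32[8,16]
  // operand over dimension 0, using a scalar-addition computation `add` and a
  // zero constant `zero` as the initial value, produces a f32[16] result:
  //   auto sum = HloInstruction::CreateReduce(
  //       ShapeUtil::MakeShape(F32, {16}), operand_2d, zero,
  //       /*dimensions_to_reduce=*/{0}, add);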

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64 feature_index);

  // Creates a batch-norm-inference instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64 feature_index);

  // Creates a batch-norm-grad instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64 feature_index);

  // Creates a select-and-scatter instruction, which scatters the `source`
  // array to the selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);

  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
  // returned as a unique_ptr for API consistency with other factory methods in
  // this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened in row-major
  // order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
  // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);
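
  // Example (illustrative sketch; `s32_shape`, `condition_computation`,
  // `body_computation`, and `one` are hypothetical names not declared in this
  // header): building the doubling loop described above.
  //   auto loop = HloInstruction::CreateWhile(
  //       s32_shape, condition_computation, body_computation, /*init=*/one);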

  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg, HloComputation* false_computation);

  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* gather_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      tensorflow::gtl::ArraySlice<int64> window_bounds);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloComputation* fusion_computation);

  // Creates a call instruction that applies the given computation on the given
  // operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call
  // target to the given operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      tensorflow::StringPiece custom_call_target);

  // Creates a HostCompute instruction, which records host-side control and
  // data dependencies for use in instruction scheduling.
  static std::unique_ptr<HloInstruction> CreateHostCompute(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);

  // Creates a tuple instruction with the given elements. This is a convenience
  // wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      tensorflow::gtl::ArraySlice<HloInstruction*> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Creates an instance of GatherDimensionNumbers.
  static GatherDimensionNumbers MakeGatherDimNumbers(
      tensorflow::gtl::ArraySlice<int64> output_window_dims,
      tensorflow::gtl::ArraySlice<int64> elided_window_dims,
      tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims);

  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }

  // Returns true if this instruction has a side effect. An instruction has a
  // side effect if it uses certain opcodes or calls a computation with a side
  // effect.
  bool HasSideEffect() const;

  // Returns the result shape of this instruction.
  const Shape& shape() const;

  // Returns the (mutable) result shape of this instruction.
  Shape* mutable_shape() { return &shape_; }

  // Returns the ith operand to this instruction.
  const HloInstruction* operand(int64 i) const;

  // Returns the ith operand to this instruction.
  HloInstruction* mutable_operand(int64 i);

  // Returns the number of operands to this instruction.
  int64 operand_count() const { return operands_.size(); }

  // Returns the vector of operands of this instruction.
  using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
  const InstructionVector& operands() const { return operands_; }

  // Returns the index of 'target' in the operands sequence.
  // Precondition: target must be an operand (or a fatal error will occur).
  int64 operand_index(const HloInstruction* target) const;

  // Returns the number of users of this instruction.
  int64 user_count() const { return users_.size(); }

  // Returns the users of this instruction.
  const std::vector<HloInstruction*>& users() const { return users_; }

  // Returns true if this instruction is a user of 'instruction'.
  bool IsUserOf(const HloInstruction* instruction) const {
    return ContainsKey(instruction->user_set_, this);
  }

  // Adds a control dependency from this instruction to the given
  // instruction. This instruction becomes a control predecessor of
  // 'instruction', and 'instruction' becomes a control successor of this
  // instruction. Returns an error status if the two instructions do not belong
  // to the same computation.
  //
  // This is used to enforce an additional ordering requirement that is not
  // captured by normal data dependencies, such as ordering among Send or Recv
  // operations to avoid deadlock.
  Status AddControlDependencyTo(HloInstruction* instruction);

  // Removes a previously added control dependency from this instruction to
  // 'instruction'.
  Status RemoveControlDependencyTo(HloInstruction* instruction);

  // Returns the set of control predecessors (successors) of this
  // instruction. Control predecessors (successors) must execute before (after)
  // the current instruction.
  const std::vector<HloInstruction*>& control_predecessors() const {
    return control_predecessors_;
  }
  const std::vector<HloInstruction*>& control_successors() const {
    return control_successors_;
  }

  // Returns true if "other" performs the same computation as this instruction.
  bool Identical(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    // An instruction is always identical to itself.
    if (this == &other) {
      return true;
    }

    // Identical instructions must have the same opcode, shape, and identical
    // operands.
    if (opcode() != other.opcode()) {
      return false;
    }
    using EqShapeFuncType = bool (*)(const Shape&, const Shape&);
    EqShapeFuncType eq_shapes =
        layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
    if (!eq_shapes(shape(), other.shape())) {
      return false;
    }
    if (operands().size() != other.operands().size()) {
      return false;
    }

    // Use an explicit loop rather than ContainerEquals, because copying around
    // std::functions may be too expensive in some cases.
    for (size_t i = 0; i < operands().size(); ++i) {
      if (!eq_operands(operand(i), other.operand(i))) {
        return false;
      }
    }

    return IdenticalSlowPath(other, eq_computations, eq_shapes);
  }
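
  // Example (illustrative sketch; `a` and `b` are hypothetical instruction
  // pointers): compare two instructions while matching operands by name
  // instead of by pointer identity.
  //   bool same = a->Identical(
  //       *b,
  //       [](const HloInstruction* x, const HloInstruction* y) {
  //         return x->name() == y->name();
  //       });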

  // Returns whether the instruction has a constant operand.
  bool HasConstantOperand() const;

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Replaces the use of this instruction in "user" with "new_producer". Note
  // that there might be multiple uses of this instruction in "user"; all will
  // be replaced.
  Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);

  // Replaces the specified operand with new_operand.
  Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);

  // Replaces all uses of this instruction with the new producer. If
  // new_producer is a user of this instruction then new_producer remains a
  // user of this instruction to avoid introducing cycles into the graph.
  //
  // If this instruction is the root of its computation, sets the computation's
  // root to new_producer.
  Status ReplaceAllUsesWith(HloInstruction* new_producer);
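
  // Example (illustrative sketch; `mul_by_one` is a hypothetical instruction
  // pointer): fold a multiply-by-one out of the graph by redirecting its users
  // to its first operand, assuming both live in the same computation.
  //   TF_RETURN_IF_ERROR(
  //       mul_by_one->ReplaceAllUsesWith(mul_by_one->mutable_operand(0)));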

  // Detaches an instruction from its operands. That is, removes the
  // instruction from each operand's user set. This should only be called prior
  // to deallocating the instruction.
  void DetachFromOperands();

  // Performs a postorder DFS visit using this node as the root. If
  // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
  // complete. If ignore_control_predecessors is true, instructions only
  // reachable via control dependencies will not be visited, and the postorder
  // will not take control dependencies into account. It is as if the control
  // dependencies didn't exist in the graph at all.
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                bool call_finish_visit = true,
                bool ignore_control_predecessors = false);
  Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
                bool ignore_control_predecessors = false) const {
    return const_cast<HloInstruction*>(this)->Accept(
        visitor, call_finish_visit, ignore_control_predecessors);
  }

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  using CompareFunction =
      std::function<bool(const HloInstruction*, const HloInstruction*)>;
  Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
                                const CompareFunction& operand_order,
                                bool call_finish_visit = true);

  // Performs a postorder DFS visit using this node as the root. Calls the
  // given visitor function at each instruction.
  Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
  Status Accept(
      const std::function<Status(const HloInstruction*)>& visitor_func) const;
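
  // Example (illustrative sketch; `root` and `count` are hypothetical names):
  // count the instructions reachable from `root` using the function-visitor
  // overload above.
  //   int64 count = 0;
  //   TF_RETURN_IF_ERROR(root->Accept([&count](HloInstruction*) {
  //     ++count;
  //     return Status::OK();
  //   }));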

  // Visits all instructions rooted at this instruction using the given visitor
  // in the given order. 'order' must contain at least the set of instructions
  // rooted at this node (ie, those accessible from a DFS traversal from this
  // instruction). Instructions contained in 'order' which are not in the set of
  // instructions rooted at this node are ignored. 'order' must also be a valid
  // topological sort of these instructions (defs appear before uses) though
  // need not be a DFS post-order.
  Status AcceptOrdered(DfsHloVisitor* visitor,
                       const std::vector<const HloInstruction*>& order);

  // Visit this instruction and only this instruction with the given visitor.
  template <typename HloInstructionPtr>
  Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);

  // Returns the literal associated with this instruction.
  //
  // Note: only constant and parameter opcodes have an associated literal.
  const Literal& literal() const;

  // Returns the parameter number associated with this instruction.
  //
  // Note: only parameter opcodes have an associated parameter number.
  int64 parameter_number() const {
    CHECK_EQ(HloOpcode::kParameter, opcode_);
    return parameter_number_;
  }

  // Returns the dimension sizes or numbers associated with this instruction.
  //
  // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
  // or reverse.
  const std::vector<int64>& dimensions() const;
  int64 dimensions(int64 index) const;

  // Accessor for the dimension in which a concatenate HLO should occur.
  // Precondition: opcode() == HloOpcode::kConcatenate
  int64 concatenate_dimension() const;

  // Returns the tuple index associated with this instruction.
  //
  // Precondition: opcode() == HloOpcode::kGetTupleElement
  int64 tuple_index() const;

  // Returns the first non-GetTupleElement ancestor instruction of this
  // instruction. If that ancestor is tuple-shaped, the returned ShapeIndex
  // holds the (possibly nested) tuple indices used on the path from the
  // ancestor to this instruction.
  std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
      const;

  std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
    auto rv =
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
    return {const_cast<HloInstruction*>(rv.first), rv.second};
  }

  // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
  const HloInstruction* LatestNonGteAncestor() const;

  HloInstruction* LatestNonGteAncestor() {
    return const_cast<HloInstruction*>(
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
  }
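
  // Example (illustrative sketch; `gte` and `tuple_op` are hypothetical
  // instruction pointers): if `gte` was created as
  // GetTupleElement(tuple_op, /*index=*/3), then
  //   gte->LatestNonGteAncestorAndIndex() yields {tuple_op, ShapeIndex({3})}
  //   gte->LatestNonGteAncestor() yields tuple_op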

  // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction has a valid to_apply_ field.
  HloComputation* to_apply() const;
  void set_to_apply(HloComputation* to_apply);

  // Returns the custom_call_target for CustomCall.
  // Precondition: opcode() == HloOpcode::kCustomCall
  const string& custom_call_target() const;

  // Returns the config for the Outfeed instruction.
  // Precondition: opcode() == HloOpcode::kOutfeed
  const string& outfeed_config() const;

  // Returns the shape for the Outfeed instruction.
  // Precondition: opcode() == HloOpcode::kOutfeed
  const Shape& outfeed_shape() const;

  // Gets/sets the while_condition or while_body HloComputation for While. The
  // setters should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a While instruction.
  HloComputation* while_condition() const;
  HloComputation* while_body() const;
  void set_while_condition(HloComputation* while_condition);
  void set_while_body(HloComputation* while_body);

  // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
  // setters should only be called by HloModule or HloComputation methods.
  //
  // Precondition: opcode() == HloOpcode::kSelectAndScatter.
  HloComputation* select() const;
  HloComputation* scatter() const;
  void set_select(HloComputation* select);
  void set_scatter(HloComputation* scatter);

  // Gets/sets the true and false HloComputation for Conditional. The setters
  // should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a Conditional instruction.
  HloComputation* true_computation() const;
  HloComputation* false_computation() const;
  void set_true_computation(HloComputation* true_computation);
  void set_false_computation(HloComputation* false_computation);

  // Returns a string for the signature of this instruction if considered as a
  // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
  string SignatureString() const;

  // Returns a debugging string that represents this instruction.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  //
  // TODO(b/73348663): Make ToString() adaptive to the size of the string by
  // default, backing off on providing full information for very large strings,
  // or provide a different name for a ToString-like function that does that.
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Components of the ToString() representation:

  // Returns a string representation of the operand list.
  string OperandsToString(const HloPrintOptions& options) const;

  // Returns string representation of op-specific attributes.
  std::vector<string> ExtraAttributesToString(
      const HloPrintOptions& options) const;

  // As ToString, but returns a shorter string.
  string ToShortString() const;

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const;

  // Returns a category for the HLO. This could be something like "convolution"
  // or "elementwise".
  string ToCategory() const;

  // Returns a logging instruction, if the output of this instruction is
  // logged.
  //
  // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
  HloInstruction* tracing() const;
  void set_tracing(HloInstruction* trace_instruction);

  // Returns the channel id associated with the instruction. The id is
  // shared between each Send/Recv pair and is globally unique to identify each
  // channel.
  //
  // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
  int64 channel_id() const { return channel_id_; }

  // Returns the feature_index field associated with the instruction. The index
  // represents the index of the feature dimension.
  //
  // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
  // or kBatchNormGrad.
  int64 feature_index() const { return feature_index_; }

  // Returns the epsilon value associated with the instruction. This is a small
  // number added to the variance to avoid divide-by-zero errors.
  //
  // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
  // or kBatchNormGrad.
  float epsilon() const { return epsilon_; }

  // Returns the infeed configuration string. The infeed configuration includes
  // any metadata needed for the backend compiler (e.g., infeed buffer address)
  // and is target-dependent.
  string infeed_config() const { return infeed_config_; }
  void set_infeed_config(const string& config) { infeed_config_ = config; }

  // Returns a tag to be used in tracing.
  //
  // Precondition: opcode() == HloOpcode::kTrace
  string TracingTag() const;

  // Returns whether the instruction is a constant.
  bool IsConstant() const;

  // Returns true if this instruction is fused, i.e. contained within a fusion
  // instruction.
  bool IsFused() const;

  // Returns the computation for this fused instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloComputation* fused_instructions_computation() const;

  // Returns true if this instruction can be legally fused into a fusion
  // instruction.
  bool IsFusable() const;

  // Returns the root instruction of the fused expression contained within this
  // fusion instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloInstruction* fused_expression_root() const;

  // Returns the list of fused instructions inside this fusion instruction. The
  // returned type is a range of HloInstruction*s.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  const tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  fused_instructions() const;

  const tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  fused_instructions();

  // Gets the number of instructions inside this fusion instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  int64 fused_instruction_count() const;

  // Returns the fused parameter instruction in this fusion instruction
  // corresponding to the given parameter number.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloInstruction* fused_parameter(int64 parameter_number) const;

  // Returns the vector of fused parameters inside this fusion instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  const std::vector<HloInstruction*>& fused_parameters() const;

  // Returns true if this instruction is a fusion instruction that generates
  // multiple outputs.
  bool IsMultiOutputFusion() const {
    return opcode() == HloOpcode::kFusion &&
           fused_expression_root()->opcode() == HloOpcode::kTuple;
  }

  FusionKind fusion_kind() const {
    CHECK_EQ(HloOpcode::kFusion, opcode_);
    return fusion_kind_;
  }

  void set_fusion_kind(FusionKind kind) {
    CHECK_EQ(HloOpcode::kFusion, opcode_);
    fusion_kind_ = kind;
  }

  // Returns the sharding applied to this operator.
  // REQUIRES: has_sharding() is true.
  const HloSharding& sharding() const {
    CHECK(has_sharding());
    return *sharding_;
  }
  // Returns the sharding applied to this operator, or default_ if none exists.
  const HloSharding& sharding_or_default(const HloSharding& default_) const {
    return sharding_ ? *sharding_ : default_;
  }
  // Sets the sharding of this operator. Should only be called by HloModule or
  // HloComputation methods.
  void set_sharding(const HloSharding& sharding) {
    sharding_ = MakeUnique<HloSharding>(sharding);
  }
  // Remove any sharding from this operator.
  void clear_sharding() { sharding_ = nullptr; }
  // Return true if this operator has a sharding assigned.
  bool has_sharding() const { return sharding_ != nullptr; }
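
  // Example (illustrative sketch; `instr` is a hypothetical instruction
  // pointer, and HloSharding::AssignDevice is assumed from hlo_sharding.h):
  // pin an instruction's output to device 0 and later read it back.
  //   instr->set_sharding(HloSharding::AssignDevice(0));
  //   if (instr->has_sharding()) {
  //     const HloSharding& assignment = instr->sharding();
  //   }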

  // Adds a new operand to the fusion instruction.
  HloInstruction* AddFusionOperand(HloInstruction* new_operand);

  // Merges the fused instructions from 'instruction_to_merge' into the
  // fused instruction set of 'this', updating operands as necessary.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  // Precondition: 'instruction_to_merge' must be an operand of 'this'.
  void MergeFusionInstruction(HloInstruction* instruction_to_merge);

  // Merges the fused instructions from instruction_to_merge into the fused
  // instruction set of 'this' and generates a multi-output fusion instruction.
  // All the users of instruction_to_merge will be redirected to 'this'
  // instruction. instruction_to_merge will be removed from its parent
  // computation.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  void MergeFusionInstructionIntoMultiOutput(
      HloInstruction* instruction_to_merge);

  // Fuses the given instruction into this fusion instruction.
  // instruction_to_fuse is cloned and the clone is placed in the fusion
  // instruction; instruction_to_fuse itself is unchanged. The instruction is
  // cloned rather than moved to cleanly handle the case where the instruction
  // has a use outside the fusion instruction. Moving such an instruction into
  // a fusion instruction would violate the single-result invariant of HLO
  // instructions and significantly complicate code generation.
  //
  // Precondition: this->opcode() == HloOpcode::kFusion
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
    return FuseInstructionInternal(instruction_to_fuse);
  }

  // Fuses the given instruction into this fusion instruction and generates a
  // multi-output fusion instruction. A clone of instruction_to_fuse becomes
  // part of the output of the fusion instruction. The users of
  // instruction_to_fuse are redirected to this fusion instruction, and
  // instruction_to_fuse is removed from its parent computation.
  //
  // Precondition: this->opcode() == HloOpcode::kFusion
  HloInstruction* FuseInstructionIntoMultiOutput(
      HloInstruction* instruction_to_fuse) {
    return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
  }

  // Returns the start index in the given dimension for a slice node.
  //
  // Precondition: opcode() == HloOpcode::kSlice
  int64 slice_starts(int64 dimension) const {
    CHECK_EQ(HloOpcode::kSlice, opcode_);
    return slice_starts_[dimension];
  }
  const std::vector<int64>& slice_starts() const { return slice_starts_; }

  // Returns the (exclusive) limit index in the given dimension for a slice
  // node.
  //
  // Precondition: opcode() == HloOpcode::kSlice
  int64 slice_limits(int64 dimension) const {
    CHECK_EQ(HloOpcode::kSlice, opcode_);
    return slice_limits_[dimension];
  }
  const std::vector<int64>& slice_limits() const {
    CHECK_EQ(HloOpcode::kSlice, opcode_);
    return slice_limits_;
  }

  // Returns the stride in the given dimension for a slice node.
  //
  // Precondition: opcode() == HloOpcode::kSlice
  int64 slice_strides(int64 dimension) const {
    CHECK_EQ(HloOpcode::kSlice, opcode_);
    return slice_strides_[dimension];
  }
  const std::vector<int64>& slice_strides() const { return slice_strides_; }

  // Returns the flag that describes whether a slice must be lowered into an
  // offset into the original operand.
  bool IsInPlaceSlice() const { return is_in_place_slice_; }

  // Sets and returns the flag that describes whether a slice must be lowered
  // into an offset into the original operand.
  bool SetIsInPlaceSlice(bool value) {
    is_in_place_slice_ = value;
    return value;
  }

  // Returns the size of the slice in the given dimension for a dynamic
  // slice node.
  //
  // Precondition: opcode() == HloOpcode::kDynamicSlice
  int64 slice_sizes(int64 dimension) const {
    CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
    return dynamic_slice_sizes_[dimension];
  }
  const std::vector<int64>& dynamic_slice_sizes() const {
    CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
    return dynamic_slice_sizes_;
  }

  // Returns the number of exponent bits for a reduce-precision node.
  //
  // Precondition: opcode() == HloOpcode::kReducePrecision
  int32 exponent_bits() const {
    CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
    return exponent_bits_;
  }

  // Returns the number of mantissa bits for a reduce-precision node.
  //
  // Precondition: opcode() == HloOpcode::kReducePrecision
  int32 mantissa_bits() const {
    CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
    return mantissa_bits_;
  }

  // Returns data on the window in a windowed operation such as
  // convolution.
  const Window& window() const {
    CHECK(window_ != nullptr);
    return *window_;
  }

  // Sets the window data in a windowed operation such as convolution.
  void set_window(const Window& window) {
    window_ = MakeUnique<Window>(window);
  }

  // Returns the padding configuration for a pad node.
  //
  // Precondition: opcode() == HloOpcode::kPad
  const PaddingConfig& padding_config() const {
    CHECK(padding_config_ != nullptr);
    return *padding_config_;
  }

  // Returns data on the dimension numbers used for a convolution operation,
  // which may be a kConvolution instruction or a kCustomCall that implements a
  // convolution.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
    CHECK(convolution_dimension_numbers_ != nullptr);
    return *convolution_dimension_numbers_;
  }

  // Sets the convolution dimension numbers on this instruction.  In general
  // you shouldn't need to call this; instead, specify the convolution
  // dimension numbers when you create the instruction.
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums) {
    convolution_dimension_numbers_ =
        MakeUnique<ConvolutionDimensionNumbers>(dnums);
  }

  FftType fft_type() const {
    CHECK_EQ(HloOpcode::kFft, opcode_);
    return fft_type_;
  }

  const std::vector<int64>& fft_length() const {
    CHECK_EQ(HloOpcode::kFft, opcode_);
    return fft_length_;
  }

  // Returns the dump string of the convolution dimension numbers.
  string ConvolutionDimensionNumbersToString() const;

  // Returns data on the dimension numbers used for a dot operation.
  const DotDimensionNumbers& dot_dimension_numbers() const {
    CHECK(dot_dimension_numbers_ != nullptr);
    return *dot_dimension_numbers_;
  }

  // Returns the dump string of the dot dimension numbers.
  string DotDimensionNumbersToString() const;

  const GatherDimensionNumbers& gather_dimension_numbers() const {
    CHECK(gather_dimension_numbers_ != nullptr);
    return *gather_dimension_numbers_;
  }

  tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
    CHECK_EQ(opcode(), HloOpcode::kGather);
    return gather_window_bounds_;
  }

  // Returns the dump string of the gather dimension numbers.
  string GatherDimensionNumbersToString() const;

  // Returns the random distribution for this rng node.
  //
  // Precondition: opcode() == HloOpcode::kRng
  RandomDistribution random_distribution() const;

  // Clones the HLO instruction. The clone will have the same opcode, shape,
  // and operands. After creation the clone has no uses. "this" (the
  // instruction cloned from) is not changed. Suffix is the string to append to
  // the name of the instruction to form the name of the cloned instruction.
  // If the module pointer is not nullptr, it will be the module where
  // the cloned computations will be added to (in order to support deep
  // cloning).
  std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
                                        HloModule* module = nullptr) const;

  // Clones the HLO instruction as above but with new shape and operands.
  // If the module pointer is not nullptr, it will be the module where
  // the cloned computations will be added to (in order to support deep
  // cloning).
  std::unique_ptr<HloInstruction> CloneWithNewOperands(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloModule* module = nullptr) const;

  // Returns the computations this instruction directly calls (if any).
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }

  // Replaces all called computations based on a map function. This is needed
  // when we clone hlo_computations and want the instructions to point
  // to the newly cloned nodes.
  void ReplaceCalledComputations(
      std::function<HloComputation*(HloComputation*)> map_function) {
    for (int64 i = 0; i < called_computations_.size(); ++i) {
      called_computations_[i] = map_function(called_computations_[i]);
    }
  }

  // Clears out the called computations.
  //
  // This is, in particular, necessary when inlining function bodies into their
  // caller. If there were side-effecting operations in the called computations,
  // the call itself is considered side-effecting and thus cannot be removed. By
  // clearing out the computations, we reflect the fact that all side-effecting
  // properties have been reflected in the caller, and make the call HLO
  // removable.
  void ClearCalledComputations() { called_computations_.clear(); }

  // Returns true if this instruction performs an elementwise operation on
  // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
  // after performing necessary implicit broadcast
  // (cs/IrArray::EmitArrayElementAddress), to compute the output at index
  // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is
  // the element at {i_0,i_1,...,i_n}.
  //
  // Note on performance: when this instruction is kFusion, this method, in the
  // worst case, scans all fused instructions. We could speed this up by
  // caching.
  bool IsElementwiseOnOperand(int64 operand_idx) const;

  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const;

  // Returns true if this elementwise instruction implicitly broadcasts operand
  // `operand_idx`.
  //
  // Precondition: this instruction should be an elementwise operation.
  bool ImplicitlyBroadcastsOperand(int64 operand_idx) const;

  // Returns true if this instruction is binary and elementwise.
  bool IsElementwiseBinary() const;

  // Returns whether this instruction may reuse elements of its `i`th operand.
  bool ReusesOperandElements(int64 i) const {
    return OperandElementUse(i) == UseKind::kReuse;
  }

  // Returns the indices at which the given operand appears in the operand list
  // of this instruction. Note that an instruction can use the same operand
  // multiple times.
  std::vector<int64> OperandIndices(const HloInstruction* operand) const;

  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
  // this reshape merely inserts or deletes 1-sized dimensions, return the input
  // indices of the deleted dimensions and the output indices of the inserted
  // dimensions.
  //
  // Precondition: this op must be a reshape.
  std::tuple<bool, std::vector<int64>, std::vector<int64>>
  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
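
  // Example (illustrative sketch): for a reshape from f32[2,1,3] to f32[2,3],
  // the expected result is {true, {1}, {}}: input dimension 1 (of size 1) was
  // deleted and no 1-sized output dimensions were inserted.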
1217
1218  // Gets/sets the string identifier for this instruction.
1219  const string& name() const { return name_; }
1220  void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); }
1221
1222  // Use the given NameUniquer to select a unique name for the instruction based
1223  // on the instruction's existing name.
1224  void UniquifyName(NameUniquer* name_uniquer);
1225
1226  // Set the unique id for this instruction to "id"
1227  void SetUniqueId(int id) {
1228    CHECK_EQ(unique_id_, -1);  // Should not be assigned already
1229    CHECK_GE(id, 0);
1230    unique_id_ = id;
1231  }
1232
1233  // Return the unique ID assigned to this node via SetUniqueId (or -1
1234  // if no id has been assigned yet).
1235  int unique_id() const { return unique_id_; }
1236
1237  // Sets the debug metadata for this instruction.
1238  void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
1239  const OpMetadata& metadata() const { return metadata_; }
1240
1241  // Sets/gets the computation containing this instruction. set_parent should
1242  // only be called by HloComputation methods that add instructions to or
1243  // remove them from computations.
1244  void set_parent(HloComputation* computation) { parent_ = computation; }
1245  const HloComputation* parent() const { return parent_; }
1246  HloComputation* parent() { return parent_; }
1247
1248  // Returns the module for this instruction.
1249  HloModule* GetModule() const;
1250
1251  // Returns whether we could assign input and output layouts to this
1252  // instruction to make it a bitcast.
1253  bool CouldBeBitcast() const;
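  // For example, a transpose can always be made a bitcast by choosing an output
  // layout that matches the permuted input layout, and a reshape that merely
  // inserts or deletes 1-sized dimensions can as well.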
1254
1255  // Get/Set the number of partitions per outer dimension (listed in order,
1256  // outer-most dimension first). Currently used by the parallel CPU backend to
1257  // partition HLOs into parallel tasks.
1258  // TODO(b/62783254) Replace these methods with a more general way to
1259  // annotate HLOs with backend-specific information.
1260  const std::vector<int64>& outer_dimension_partitions() const {
1261    return outer_dimension_partitions_;
1262  }
1263  void set_outer_dimension_partitions(
1264      const std::vector<int64>& outer_dimension_partitions);
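  // Example sketch (`instruction` and the partition counts are illustrative):
  //
  //   // Request 2 partitions of the outer-most dimension and 4 partitions of
  //   // the next outer dimension.
  //   instruction->set_outer_dimension_partitions({2, 4});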
1265
1266  // Changes the layout of a constant HLO instruction to match new_layout. For
1267  // tuple-shaped constants, shape_index is the path to the internal array
1268  // subshape whose layout needs to be changed.
1269  void RelayoutConstant(const Layout& new_layout,
1270                        const ShapeIndex& shape_index = {});
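  // Example sketch, assuming `constant` is a kConstant instruction with a
  // rank-2 array shape (LayoutUtil::MakeLayout, from layout_util.h, takes the
  // minor-to-major dimension order):
  //
  //   // Relayout the whole constant to column-major (dimension 0 minor-most).
  //   constant->RelayoutConstant(LayoutUtil::MakeLayout({0, 1}));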
1271
1272 private:
1273  enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
1274
1275  // Helper class for computing OperandElementUse for kFusion.
1276  class FusionReusesParamElements;
1277
1278  // See comments on Identical().
1279  // eq_shapes() is used to check shapes for equality, and would normally be
1280  // expected to be ShapeUtil::Equal or ShapeUtil::Compatible, depending on
1281  // whether we want a layout-sensitive check or not.
1282  bool IdenticalSlowPath(
1283      const HloInstruction& other,
1284      const std::function<bool(const HloComputation*, const HloComputation*)>&
1285          eq_computations,
1286      const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
1287
1288  // Creates an n-ary elementwise operation.
1289  static std::unique_ptr<HloInstruction> CreateNary(
1290      const Shape& shape, HloOpcode opcode,
1291      tensorflow::gtl::ArraySlice<HloInstruction*> operands);
1292
1293  // Appends operand to the list of operands and adds this instruction as a user
1294  // of the operand.
1295  void AppendOperand(HloInstruction* operand);
1296
1297  // Adds a user for this instruction.
1298  void AddUser(HloInstruction* user);
1299
1300  // Removes a user for this instruction.
1301  void RemoveUser(HloInstruction* user);
1302
1303  // Internal constructor for a given opcode/shape; other fields must be filled
1304  // in by factory methods.
1305  HloInstruction(HloOpcode opcode, const Shape& shape);
1306
1307  // Fuses the given instruction into this fusion instruction. When add_output
1308  // is false (which is the default), instruction_to_fuse is cloned and the
1309  // clone is placed in the fusion instruction. instruction_to_fuse is
1310  // unchanged.
1311  //
1312  // When add_output is true, a clone of instruction_to_fuse will be part of
1313  // the output of this fusion instruction. The users of instruction_to_fuse
1314  // will be redirected to this fusion instruction, and instruction_to_fuse
1315  // will be removed from its parent computation.
1316  //
1317  // Precondition: this->opcode() == HloOpcode::kFusion
1318  HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
1319                                          bool add_output = false);
1320
1321  // Clones the given instruction_to_fuse and inserts the clone into this fusion
1322  // instruction. If add_output is true, the clone will also be part of the
1323  // output of this fusion instruction (i.e. part of the tuple produced by the
1324  // fusion root).
1325  //
1326  // Precondition: opcode() == HloOpcode::kFusion
1327  HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
1328                                       bool add_output = false);
1329
1330  // Clones a fusion instruction with a new shape and operands.
1331  std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
1332      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1333      HloModule* module = nullptr) const;
1334
1335  // Returns true if this instruction can legally have the dimensions field
1336  // set. Used for checking precondition of dimensions field accessors.
1337  bool CanHaveDimensionsField() const;
1338
1339  // Returns how this instruction uses elements of its `i`th operand.
1340  UseKind OperandElementUse(int64 i) const;
1341
1342  int unique_id_;  // Unique to this HloInstruction within a HloModule
1343
1344  // Opcode for this instruction.
1345  HloOpcode opcode_;
1346
1347  // Instruction operands.
1348  InstructionVector operands_;
1349
1350  // The set of control predecessors of this instruction.
1351  std::vector<HloInstruction*> control_predecessors_;
1352
1353  // The users of this instruction. Users are HLOs where this instruction is an
1354  // operand. The vector users_ and the set user_set_ contain identical
1355  // members. The set enables fast membership testing and the vector enables
1356  // fast, stable iteration.
1357  std::vector<HloInstruction*> users_;
1358  std::unordered_set<const HloInstruction*> user_set_;
1359
1360  // The set of control successors of this instruction.
1361  std::vector<HloInstruction*> control_successors_;
1362
1363  // The computation in which this instruction is contained.
1364  HloComputation* parent_ = nullptr;
1365
1366  // Shape of outfeed request.
1367  Shape outfeed_shape_;
1368
1369  // Result shape of this instruction.
1370  Shape shape_;
1371
1372  // Literal, only present for kConstant.
1373  std::unique_ptr<Literal> literal_;
1374
1375  // Constant index, only present for kGetTupleElement.
1376  int64 tuple_index_ = -1;
1377
1378  // Dimensions present for some operations that require reshaping or
1379  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
1380  std::vector<int64> dimensions_;
1381
1382  // Describes the window in a windowed operation such as convolution.
1383  std::unique_ptr<Window> window_;
1384
1385  // Describes the dimension numbers used for a convolution.
1386  std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
1387
1388  // Describes the dimension numbers used for a dot.
1389  std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
1390
1391  std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
1392  std::vector<int64> gather_window_bounds_;
1393
1394  // Describes FFT type for an FFT instruction.
1395  FftType fft_type_ = FftType::FFT;
1396
1397  // Indicates the FFT length for an FFT instruction.
1398  std::vector<int64> fft_length_;
1399
1400  // Describes the [begin, end) index range for a slice.
1401  std::vector<int64> slice_starts_;
1402  std::vector<int64> slice_limits_;
1403  std::vector<int64> slice_strides_;
1404
1405  // Describes whether the slice can be lowered to an offset into the operand.
1406  bool is_in_place_slice_ = false;
1407
1408  // The bit sizes for a reduce-precision operation.
1409  int32 exponent_bits_ = 0;
1410  int32 mantissa_bits_ = 0;
1411
1412  // Describes the [start, start + size) range size for a dynamic slice
1413  // ('start' is specified dynamically in the second operand of the operation).
1414  std::vector<int64> dynamic_slice_sizes_;
1415
1416  // The padding configuration that describes the edge padding and interior
1417  // padding of this pad instruction. Only set for pad instructions.
1418  std::unique_ptr<PaddingConfig> padding_config_;
1419
1420  // The type of the fusion. Used by kFusion only.
1421  FusionKind fusion_kind_;
1422
1423  // The sharding, if one exists.
1424  std::unique_ptr<HloSharding> sharding_;
1425
1426  // For parameter instructions this field holds the parameter number.
1427  int64 parameter_number_ = 0;
1428
1429  // Name of a global symbol to call, only present for kCustomCall.
1430  string custom_call_target_;
1431
1432  // Name to use for host send/recv channels, only present for kHostCompute.
1433  string channel_name_;
1434
1435  // Estimate of the duration of a host computation in nanoseconds.
1436  int64 cost_estimate_ns_;
1437
1438  // Computations called by this instruction.
1439  std::vector<HloComputation*> called_computations_;
1440
1441  // Indices of computations in called_computations_ for instructions which call
1442  // multiple computations.
1443  enum {
1444    // kWhile computations.
1445    kBodyComputationIndex = 0,
1446    kConditionComputationIndex = 1,
1447
1448    // kSelectAndScatter computations.
1449    kSelectComputationIndex = 0,
1450    kScatterComputationIndex = 1,
1451
1452    // kConditional computations.
1453    kTrueComputationIndex = 0,
1454    kFalseComputationIndex = 1,
1455  };
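  // For example, for a kWhile instruction the loop body lives at
  // called_computations_[kBodyComputationIndex] and the loop condition at
  // called_computations_[kConditionComputationIndex].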
1456
1457  // Outfeed configuration information, only present for kOutfeed.
1458  string outfeed_config_;
1459
1460  // A trace instruction that consumes this instruction.
1461  //
1462  // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this as
1463  // an operand.
1464  HloInstruction* trace_instruction_ = nullptr;
1465
1466  // The distribution requested for random number generation.
1467  // Only present for kRng.
1468  RandomDistribution distribution_;
1469
1470  // A small float number added to the variance to avoid a divide-by-zero error.
1471  // Only present for kBatchNormTraining.
1472  float epsilon_ = 0.0f;
1473
1474  // An integer value representing the index of the feature dimension.
1475  // Only present for kBatchNormTraining.
1476  int64 feature_index_ = -1;
1477
1478  // Represents a unique identifier for each Send/Recv instruction pair.
1479  // Only present for kSend or kRecv.
1480  int64 channel_id_ = -1;
1481
1482  // The string representation of the infeed configuration.
1483  string infeed_config_;
1484
1485  // String identifier for instruction.
1486  string name_;
1487
1488  // Metadata for debugging.
1489  OpMetadata metadata_;
1490
1491  // The number of partitions per outer dimension (listed in order,
1492  // outer-most dimension first).
1493  std::vector<int64> outer_dimension_partitions_;
1494
1495  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
1496};
1497
1498string ToString(HloInstruction::FusionKind kind);
1499StatusOr<HloInstruction::FusionKind> StringToFusionKind(
1500    const string& kind_name);
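// Example sketch of round-tripping a fusion kind through its string form,
// assuming FusionKind::kLoop is among the kinds defined earlier in this header:
//
//   string kind_name = ToString(HloInstruction::FusionKind::kLoop);
//   StatusOr<HloInstruction::FusionKind> parsed = StringToFusionKind(kind_name);
//   CHECK(parsed.ok() &&
//         parsed.ValueOrDie() == HloInstruction::FusionKind::kLoop);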
1501
1502// Custom (de)stringification functions for protos that live inside
1503// HloInstruction.
1504string PaddingConfigToString(const PaddingConfig& padding);
1505string OpMetadataToString(const OpMetadata& metadata);
1506string RandomDistributionToString(const RandomDistribution& distribution);
1507StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
1508
1509std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
1510
1511// Map classes that guarantee a deterministic iteration order when the key is
1512// an HloInstruction* or a const HloInstruction*.
1513// To make the iteration order over the map deterministic, the comparator
1514// compares an intrinsic property of the hlo (its unique id) rather than the
1515// pointer values.
1516//
1517// Note that this cannot be used for HLO instructions across multiple modules
1518// since the ids of HLO instructions are only unique within each HLO module.
1519struct HloPtrComparator {
1520  bool operator()(const HloInstruction* const& lhs,
1521                  const HloInstruction* const& rhs) const {
1522    return lhs->unique_id() < rhs->unique_id();
1523  }
1524};
1525
1526template <typename ValueT>
1527using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
1528
1529template <typename ValueT>
1530using ConstHloInstructionMap =
1531    std::map<const HloInstruction*, ValueT, HloPtrComparator>;
1532
1533using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
1534using ConstHloInstructionSet =
1535    std::set<const HloInstruction*, HloPtrComparator>;
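// Example sketch: iterating an HloInstructionMap visits keys in increasing
// unique_id() order, independent of pointer values, which keeps the iteration
// deterministic across runs (`add` and `mul` are hypothetical instructions):
//
//   HloInstructionMap<int64> flop_counts;
//   flop_counts[add] = 128;
//   flop_counts[mul] = 64;
//   for (const auto& entry : flop_counts) {
//     VLOG(2) << entry.first->name() << ": " << entry.second;
//   }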
1536
1537}  // namespace xla
1538
1539#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
1540