1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
17#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
18
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
31
32namespace xla {
33
34// Responsible for evaluating HLO and obtain literal as the evaluation results.
35//
36// This class is not thread-safe.
37class HloEvaluator : public DfsHloVisitorWithDefault {
38 public:
39  HloEvaluator();
40  // Evaluates an HLO module and an array of pointers to literals.
41  // Returns the evaluated result as a literal if successful.
42  // Precondition: The indices of arg_literals correspond to the parameter
43  // numbers of the HLO parameters in the computation. See comment below for an
44  // example.
45  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
46  // type.
47  template <typename LiteralPtr>
48  StatusOr<std::unique_ptr<Literal>> Evaluate(
49      const HloModule& module,
50      tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
51
52  // Evaluates an HLO computation and an array of pointers to literals.
53  // Returns the evaluated result as a literal if successful.
54  // Precondition: The indices of arg_literals correspond to the parameter
55  // numbers of the HLO parameters in the computation. For e.g., consider the
56  // following graph:
57  //
58  //                *
59  //            /       \
60  //            +     Parameter1
61  //        /      \
62  //       /        \
63  //    Parameter0  Constant
64  //
65  // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
66  // 1 in this computation. The input literals array will then have its first
67  // literal map to Parameter0 and the second map to Parameter1.
68  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
69  // type.
70  template <typename LiteralPtr>
71  StatusOr<std::unique_ptr<Literal>> Evaluate(
72      const HloComputation& computation,
73      tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
74
75  // Evaluates a single HLO instruction and an array of pointers to literals.
76  // Return the evaluated result as literal if successful.
77  // Precondition:
78  // 1. argument literals correspond to the input instruction's parameters in
79  // their post-ordering.
80  // 2. the instruction's operands must be of either Parameter or Constant type.
81  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
82  // type.
83  template <typename LiteralPtr>
84  StatusOr<std::unique_ptr<Literal>> Evaluate(
85      HloInstruction* instruction,
86      tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
87
88  // Evaluates a single HLO instruction with constant operands.
89  // Returns the evaluated result as literal if successful.
90  // Precondition:
91  // 1. all operands of the input instruction are constants.
92  // 2. the instruction is not a Parameter operation.
93  StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
94
95  // Same as Evaluate, except returning nullptr on error.
96  std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
97
98  // Evaluates a single HLO instruction, substituting the given literals for
99  // some of the instruction's operands.
100  //
101  // For example, given instruction = op(A, B, C) and the map
102  // {A = x, C = y}, this evaluates op(x, B, y).
103  StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
104      const HloInstruction* instruction,
105      const std::unordered_map<const HloInstruction*, const Literal*>&
106          substitutions);
107
108 protected:
109  // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
110  // literal type of each evaluated Handle* method of a TypedVisitor.
111  // There are however a few notable exceptions to this rule, notably:
112  // - HandleCompare and HandleIsFinite: where the resulting literal type is
113  // always boolean.
114  // These operations are handled outside of the parent HloEvaluator handlers
115  // instead of from within TypedVisitor.
116  //
117  // Type params:
118  //   - ReturnT: The type of input and output of each operation.
119  //   - ElementwiseT: The type in which internal computation are done.
120  template <typename ReturnT, typename ElementwiseT = ReturnT>
121  class TypedVisitor;
122
123  // Wraps around instruction handling to infer types before dispatching to
124  // the corresponding typed Visitor.
125  Status DefaultAction(HloInstruction* hlo) override {
126    return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
127  }
128
129  Status Preprocess(HloInstruction* hlo) override;
130
131  Status Postprocess(HloInstruction* hlo) override;
132
133  // Operations that are type-agnostic or always return a specific type, such as
134  // HandleIsFinite where boolean is always returned.
135  //
136  Status HandleParameter(HloInstruction* parameter) override;
137
138  Status HandleConstant(HloInstruction* constant) override;
139
140  Status HandleConcatenate(HloInstruction* concatenate) override;
141
142  Status HandleReshape(HloInstruction* reshape) override;
143
144  Status HandleTranspose(HloInstruction* transpose) override;
145
146  Status HandleIsFinite(HloInstruction* is_finite) override;
147
148  Status HandleCompare(HloInstruction* compare) override;
149
150  Status HandleTuple(HloInstruction* tuple) override;
151
152  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
153
154  Status HandleCopy(HloInstruction* copy) override;
155
156 private:
157  // Returns the already-evaluated literal result for the instruction.
158  // A Constant instruction is considered evaluated and its literal will be
159  // returned directly without looking up the cache.
160  // Crash with log if the given instruction has not been evaluated previously.
161  const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
162    if (hlo->IsConstant()) {
163      return hlo->literal();
164    }
165    auto it = evaluated_.find(hlo);
166    CHECK(it != evaluated_.end())
167        << "could not find evaluated value for: " << hlo->ToString();
168    return *(it->second);
169  }
170
171  // Map from a primitive type to its associated (templated) DfsHloVisitor.
172  // Note: the hash function here is only needed because current gcc std::hash
173  // does not specialize for enum types. This should however be fixed in the
174  // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5
175  tensorflow::gtl::FlatMap<PrimitiveType, std::unique_ptr<DfsHloVisitor>,
176                           std::hash<int>>
177      typed_visitors_;
178
179  // Tracks the HLO instruction and its evaluated literal result.
180  // TODO(b/35950897): have better memory management here to free instructions
181  // that are no longer a parent for any other subsequent instruction in
182  // post-orderring.
183  // Must be cleared for each evaluation.
184  tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
185      evaluated_;
186
187  // Caches pointers to input literals, assuming they are in post-order.
188  // Literals are not owned by this class, and they must outlive the lifetime of
189  // each invocation to the Evaluate* method.
190  // Must be cleared for each evaluation.
191  std::vector<const Literal*> arg_literals_;
192
193  TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
194};
195
196}  // namespace xla
197
198#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
199