/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_
17f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CC_OPS_WHILE_LOOP_H_
180fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne
#include <functional>
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"

220fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnenamespace tensorflow {
230fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnenamespace ops {
240fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne
250fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Function that takes cond graph inputs and returns cond graph boolean output.
260fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// 'output' need not be set if an error is returned.
270fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnetypedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
280fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne                             Output* output)>
290fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne    CondGraphBuilderFn;
300fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne
310fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Function that takes body graph inputs and returns body graph outputs.
320fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// 'outputs' need not be populated if an error is returned.
330fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnetypedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
340fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne                             std::vector<Output>* outputs)>
350fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne    BodyGraphBuilderFn;
360fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne
370fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Constructs a while loop.
380fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne//
390fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Arguments:
400fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * scope: used to construct the while loop.
410fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * inputs: the initial values of the loop variables. Must be non-empty.
420fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * cond: a function that builds the condition graph of the loop. Takes the
430fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne//     current loop variables as inputs and returns a scalar boolean Output
440fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne//     indicating whether the loop should continue.
450fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * body: a function that builds the body graph of the loop. Takes the current
460fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne//     loop variables as inputs and returns the updated loop variables.
470fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * frame_name: the frame name to use for this while loop. This should be a
480fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne//     unique name. This will be used as a prefix for created operations.
490fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * outputs: output param that returns final loop variable outputs in non-error
500fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne//     case. Must be non-null and empty.
5192362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne// * create_while_ctx: if true, a WhileContext is created and populated for this
52301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne//     loop. See core/graph/while_context.h for more details on
53301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne//     WhileContexts. This is set to false for loops used as part of gradient
54301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne//     computations, since they're part of the gradient for a loop in the
55301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne//     forward-pass.
56301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne//     TODO(skyewm): revisit this. Should we create WhileContexts for all loops,
57301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne//     even if we don't need them?
5892362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne// * cond_output: if non-null, the output of the predicate is returned. This
5992362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne//     will always be a LoopCond node.
600fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne//
610fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Returns an error if the while loop could not be fully constructed.
620fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne//
630fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// TODO(skyewm): clean up partially-constructed loop in error case
640fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// TODO(skyewm): create public interface to this method
650fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-MilneStatus BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
660fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne                      const CondGraphBuilderFn& cond,
670fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne                      const BodyGraphBuilderFn& body, const string& frame_name,
6892362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne                      OutputList* outputs, bool create_while_ctx = true,
6992362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne                      Output* cond_output = nullptr);
700fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne
710fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne}  // namespace ops
720fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne}  // namespace tensorflow
730fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne
74f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif  // TENSORFLOW_CC_OPS_WHILE_LOOP_H_
75