/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
140fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne==============================================================================*/ 150fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 17f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 180fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 190fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#include "tensorflow/cc/framework/ops.h" 200fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne#include "tensorflow/cc/framework/scope.h" 210fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 220fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnenamespace tensorflow { 230fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnenamespace ops { 240fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 250fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Function that takes cond graph inputs and returns cond graph boolean output. 260fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// 'output' need not be set if an error is returned. 270fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnetypedef std::function<Status(const Scope&, const std::vector<Output>& inputs, 280fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne Output* output)> 290fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne CondGraphBuilderFn; 300fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 310fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Function that takes body graph inputs and returns body graph outputs. 320fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// 'outputs' need not be populated if an error is returned. 
330fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milnetypedef std::function<Status(const Scope&, const std::vector<Output>& inputs, 340fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne std::vector<Output>* outputs)> 350fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne BodyGraphBuilderFn; 360fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 370fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Constructs a while loop. 380fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// 390fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Arguments: 400fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * scope: used to construct the while loop. 410fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * inputs: the initial values of the loop variables. Must be non-empty. 420fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * cond: a function that builds the condition graph of the loop. Takes the 430fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// current loop variables as inputs and returns a scalar boolean Output 440fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// indicating whether the loop should continue. 450fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * body: a function that builds the body graph of the loop. Takes the current 460fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// loop variables as inputs and returns the updated loop variables. 470fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * frame_name: the frame name to use for this while loop. This should be a 480fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// unique name. This will be used as a prefix for created operations. 490fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// * outputs: output param that returns final loop variable outputs in non-error 500fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// case. 
Must be non-null and empty. 5192362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne// * create_while_ctx: if true, a WhileContext is created and populated for this 52301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne// loop. See core/graph/while_context.h for more details on 53301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne// WhileContexts. This is set to false for loops used as part of gradient 54301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne// computations, since they're part of the gradient for a loop in the 55301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne// forward-pass. 56301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne// TODO(skyewm): revisit this. Should we create WhileContexts for all loops, 57301b14c240fe99249dc2225132a7ebe5cbecbdc4Skye Wanderman-Milne// even if we don't need them? 5892362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne// * cond_output: if non-null, the output of the predicate is returned. This 5992362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne// will always be a LoopCond node. 600fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// 610fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// Returns an error if the while loop could not be fully constructed. 
620fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// 630fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// TODO(skyewm): clean up partially-constructed loop in error case 640fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne// TODO(skyewm): create public interface to this method 650fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-MilneStatus BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, 660fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const CondGraphBuilderFn& cond, 670fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne const BodyGraphBuilderFn& body, const string& frame_name, 6892362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne OutputList* outputs, bool create_while_ctx = true, 6992362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2aSkye Wanderman-Milne Output* cond_output = nullptr); 700fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 710fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne} // namespace ops 720fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne} // namespace tensorflow 730fd2a74120b86972441378f79fb5d03e86fed856Skye Wanderman-Milne 74f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif // TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 75