/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
179bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
189bfa43625061ec62bd9623ab014db4851307e92dRohan Jain#include "tensorflow/core/distributed_runtime/worker_interface.h"
199bfa43625061ec62bd9623ab014db4851307e92dRohan Jain#include "tensorflow/core/distributed_runtime/worker_session.h"
209bfa43625061ec62bd9623ab014db4851307e92dRohan Jain#include "tensorflow/core/framework/function.h"
219bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
229bfa43625061ec62bd9623ab014db4851307e92dRohan Jainnamespace tensorflow {
239bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
249bfa43625061ec62bd9623ab014db4851307e92dRohan Jainstruct WorkerSession;
259bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
269bfa43625061ec62bd9623ab014db4851307e92dRohan Jain// ClusterFunctionLibraryRuntime contains methods to Instantiate and Run
279bfa43625061ec62bd9623ab014db4851307e92dRohan Jain// functions across processes by making RPCs.
289bfa43625061ec62bd9623ab014db4851307e92dRohan Jainclass ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
299bfa43625061ec62bd9623ab014db4851307e92dRohan Jain public:
309bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  ClusterFunctionLibraryRuntime(WorkerSession* worker_session)
319bfa43625061ec62bd9623ab014db4851307e92dRohan Jain      : worker_session_(worker_session) {}
329bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
339bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  ~ClusterFunctionLibraryRuntime() override;
349bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
359bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  Status Instantiate(const string& function_name,
369bfa43625061ec62bd9623ab014db4851307e92dRohan Jain                     const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
3732d138db751c541e951d1958cac4918214e9644eDerek Murray                     const FunctionLibraryRuntime::InstantiateOptions& options,
389bfa43625061ec62bd9623ab014db4851307e92dRohan Jain                     FunctionLibraryRuntime::LocalHandle* handle) override;
399bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
409bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  void Run(const FunctionLibraryRuntime::Options& opts,
419bfa43625061ec62bd9623ab014db4851307e92dRohan Jain           FunctionLibraryRuntime::LocalHandle handle,
429bfa43625061ec62bd9623ab014db4851307e92dRohan Jain           gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
439bfa43625061ec62bd9623ab014db4851307e92dRohan Jain           FunctionLibraryRuntime::DoneCallback done) override;
449bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
459bfa43625061ec62bd9623ab014db4851307e92dRohan Jain private:
4632d138db751c541e951d1958cac4918214e9644eDerek Murray  static Status ConstructFunctionGraph(
4732d138db751c541e951d1958cac4918214e9644eDerek Murray      const OpDef& sig, AttrSlice attrs,
4832d138db751c541e951d1958cac4918214e9644eDerek Murray      const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
4932d138db751c541e951d1958cac4918214e9644eDerek Murray      std::vector<string>* send_keys, std::vector<string>* recv_keys);
509bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  friend class ClusterFunctionLibraryRuntimeTest;
519bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
529bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  mutable mutex mu_;
539bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  WorkerSession* const worker_session_ = nullptr;  // not owned.
549bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
559bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  struct FunctionData {
569bfa43625061ec62bd9623ab014db4851307e92dRohan Jain    const string graph_handle;
579bfa43625061ec62bd9623ab014db4851307e92dRohan Jain    const string target;
589bfa43625061ec62bd9623ab014db4851307e92dRohan Jain    WorkerInterface* wi = nullptr;
599bfa43625061ec62bd9623ab014db4851307e92dRohan Jain    const std::vector<string> send_keys;
609bfa43625061ec62bd9623ab014db4851307e92dRohan Jain    const std::vector<string> recv_keys;
619bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
629bfa43625061ec62bd9623ab014db4851307e92dRohan Jain    FunctionData(const string& graph_handle, const string& target,
639bfa43625061ec62bd9623ab014db4851307e92dRohan Jain                 WorkerInterface* wi, const std::vector<string>& send_keys,
649bfa43625061ec62bd9623ab014db4851307e92dRohan Jain                 const std::vector<string>& recv_keys)
659bfa43625061ec62bd9623ab014db4851307e92dRohan Jain        : graph_handle(graph_handle),
669bfa43625061ec62bd9623ab014db4851307e92dRohan Jain          target(target),
679bfa43625061ec62bd9623ab014db4851307e92dRohan Jain          wi(wi),
689bfa43625061ec62bd9623ab014db4851307e92dRohan Jain          send_keys(send_keys),
699bfa43625061ec62bd9623ab014db4851307e92dRohan Jain          recv_keys(recv_keys) {}
709bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  };
719bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
729bfa43625061ec62bd9623ab014db4851307e92dRohan Jain  std::vector<FunctionData> function_data_ GUARDED_BY(mu_);
739bfa43625061ec62bd9623ab014db4851307e92dRohan Jain};
749bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
759bfa43625061ec62bd9623ab014db4851307e92dRohan Jain}  // namespace tensorflow
769bfa43625061ec62bd9623ab014db4851307e92dRohan Jain
77f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
78