//===------- RPCUtils.h - Utilities for building RPC APIs -------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Utilities to support construction of simple RPC APIs.
11//
12// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
13// programmers, high performance, low memory overhead, and efficient use of the
14// communications channel.
15//
16//===----------------------------------------------------------------------===//
17
18#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
20
21#include <map>
22#include <thread>
23#include <vector>
24
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ExecutionEngine/Orc/OrcError.h"
27#include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
28
29#include <future>
30
31namespace llvm {
32namespace orc {
33namespace rpc {
34
35/// Base class of all fatal RPC errors (those that necessarily result in the
36/// termination of the RPC session).
37class RPCFatalError : public ErrorInfo<RPCFatalError> {
38public:
39  static char ID;
40};
41
42/// RPCConnectionClosed is returned from RPC operations if the RPC connection
43/// has already been closed due to either an error or graceful disconnection.
44class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
45public:
46  static char ID;
47  std::error_code convertToErrorCode() const override;
48  void log(raw_ostream &OS) const override;
49};
50
51/// BadFunctionCall is returned from handleOne when the remote makes a call with
52/// an unrecognized function id.
53///
54/// This error is fatal because Orc RPC needs to know how to parse a function
55/// call to know where the next call starts, and if it doesn't recognize the
56/// function id it cannot parse the call.
57template <typename FnIdT, typename SeqNoT>
58class BadFunctionCall
59  : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60public:
61  static char ID;
62
63  BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64      : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65
66  std::error_code convertToErrorCode() const override {
67    return orcError(OrcErrorCode::UnexpectedRPCCall);
68  }
69
70  void log(raw_ostream &OS) const override {
71    OS << "Call to invalid RPC function id '" << FnId << "' with "
72          "sequence number " << SeqNo;
73  }
74
75private:
76  FnIdT FnId;
77  SeqNoT SeqNo;
78};
79
80template <typename FnIdT, typename SeqNoT>
81char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
82
83/// InvalidSequenceNumberForResponse is returned from handleOne when a response
84/// call arrives with a sequence number that doesn't correspond to any in-flight
85/// function call.
86///
87/// This error is fatal because Orc RPC needs to know how to parse the rest of
88/// the response call to know where the next call starts, and if it doesn't have
89/// a result parser for this sequence number it can't do that.
90template <typename SeqNoT>
91class InvalidSequenceNumberForResponse
92    : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
93public:
94  static char ID;
95
96  InvalidSequenceNumberForResponse(SeqNoT SeqNo)
97      : SeqNo(std::move(SeqNo)) {}
98
99  std::error_code convertToErrorCode() const override {
100    return orcError(OrcErrorCode::UnexpectedRPCCall);
101  };
102
103  void log(raw_ostream &OS) const override {
104    OS << "Response has unknown sequence number " << SeqNo;
105  }
106private:
107  SeqNoT SeqNo;
108};
109
110template <typename SeqNoT>
111char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
112
113/// This non-fatal error will be passed to asynchronous result handlers in place
114/// of a result if the connection goes down before a result returns, or if the
115/// function to be called cannot be negotiated with the remote.
116class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
117public:
118  static char ID;
119
120  std::error_code convertToErrorCode() const override;
121  void log(raw_ostream &OS) const override;
122};
123
124/// This error is returned if the remote does not have a handler installed for
125/// the given RPC function.
126class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
127public:
128  static char ID;
129
130  CouldNotNegotiate(std::string Signature);
131  std::error_code convertToErrorCode() const override;
132  void log(raw_ostream &OS) const override;
133  const std::string &getSignature() const { return Signature; }
134private:
135  std::string Signature;
136};
137
138template <typename DerivedFunc, typename FnT> class Function;
139
140// RPC Function class.
141// DerivedFunc should be a user defined class with a static 'getName()' method
142// returning a const char* representing the function's name.
143template <typename DerivedFunc, typename RetT, typename... ArgTs>
144class Function<DerivedFunc, RetT(ArgTs...)> {
145public:
146  /// User defined function type.
147  using Type = RetT(ArgTs...);
148
149  /// Return type.
150  using ReturnType = RetT;
151
152  /// Returns the full function prototype as a string.
153  static const char *getPrototype() {
154    std::lock_guard<std::mutex> Lock(NameMutex);
155    if (Name.empty())
156      raw_string_ostream(Name)
157          << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
158          << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
159    return Name.data();
160  }
161
162private:
163  static std::mutex NameMutex;
164  static std::string Name;
165};
166
167template <typename DerivedFunc, typename RetT, typename... ArgTs>
168std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
169
170template <typename DerivedFunc, typename RetT, typename... ArgTs>
171std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
172
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
/// static T getInvalidId():
///   A reserved id used to represent missing functions during
///   autonegotiation.
///
/// static T getResponseId():
///   A reserved id used to send function responses (return values).
///
/// static T getNegotiateId():
///   A reserved id for the negotiate function, which is used to negotiate ids
///   for user-defined functions.
///
/// template <typename Func> T allocate():
///   Allocates a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default allocator for integral id types. Ids 0, 1 and 2 are reserved
/// (invalid, response, negotiate). User functions are numbered
/// sequentially starting at 3.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  /// Returns the next unused id. Func is ignored: ids depend only on the
  /// order of allocation.
  template <typename Func> T allocate() {
    T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  /// Next id to hand out; starts just past the reserved ids.
  T NextId = 3;
};
207
208namespace detail {
209
// FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation
//        supports classes without default constructors.
#ifdef _MSC_VER

namespace msvc_hacks {

// Work around MSVC's future implementation's use of default constructors:
// A default constructed value in the promise will be overwritten when the
// real error is set - so the default constructed Error has to be checked
// already.
class MSVCPError : public Error {
public:
  // Construct a success value and immediately test it, which marks it as
  // checked so that overwriting it later is allowed.
  MSVCPError() { (void)!!*this; }

  MSVCPError(MSVCPError &&Other) : Error(std::move(Other)) {}

  MSVCPError &operator=(MSVCPError Other) {
    Error::operator=(std::move(Other));
    return *this;
  }

  MSVCPError(Error Err) : Error(std::move(Err)) {}
};

// Work around MSVC's future implementation, similar to MSVCPError.
template <typename T> class MSVCPExpected : public Expected<T> {
public:
  // Construct holding a placeholder error and consume it right away, so the
  // default-constructed object may be overwritten later.
  MSVCPExpected()
      : Expected<T>(make_error<StringError>("", inconvertibleErrorCode())) {
    consumeError(this->takeError());
  }

  MSVCPExpected(MSVCPExpected &&Other) : Expected<T>(std::move(Other)) {}

  MSVCPExpected &operator=(MSVCPExpected &&Other) {
    Expected<T>::operator=(std::move(Other));
    return *this;
  }

  MSVCPExpected(Error Err) : Expected<T>(std::move(Err)) {}

  // Converting constructor from any value convertible to T. Val is a
  // forwarding reference, so it must be forwarded rather than moved:
  // std::move would silently move out of an lvalue argument.
  template <typename OtherT>
  MSVCPExpected(
      OtherT &&Val,
      typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * =
          nullptr)
      : Expected<T>(std::forward<OtherT>(Val)) {}

  // Implicit conversion from Expected<OtherT> when OtherT converts to T.
  template <class OtherT>
  MSVCPExpected(
      Expected<OtherT> &&Other,
      typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * =
          nullptr)
      : Expected<T>(std::move(Other)) {}

  // Explicit conversion from Expected<OtherT> otherwise.
  template <class OtherT>
  explicit MSVCPExpected(
      Expected<OtherT> &&Other,
      typename std::enable_if<!std::is_convertible<OtherT, T>::value>::type * =
          nullptr)
      : Expected<T>(std::move(Other)) {}
};

} // end namespace msvc_hacks

#endif // _MSC_VER
276
/// Maps a function type RetT(ArgTs...) to std::tuple of its argument types,
/// with references and cv-qualifiers removed (std::decay applied to each).
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay already performs remove_reference as its first step.
  using Type = std::tuple<typename std::decay<ArgTs>::type...>;
};
286
287// ResultTraits provides typedefs and utilities specific to the return type
288// of functions.
289template <typename RetT> class ResultTraits {
290public:
291  // The return type wrapped in llvm::Expected.
292  using ErrorReturnType = Expected<RetT>;
293
294#ifdef _MSC_VER
295  // The ErrorReturnType wrapped in a std::promise.
296  using ReturnPromiseType = std::promise<msvc_hacks::MSVCPExpected<RetT>>;
297
298  // The ErrorReturnType wrapped in a std::future.
299  using ReturnFutureType = std::future<msvc_hacks::MSVCPExpected<RetT>>;
300#else
301  // The ErrorReturnType wrapped in a std::promise.
302  using ReturnPromiseType = std::promise<ErrorReturnType>;
303
304  // The ErrorReturnType wrapped in a std::future.
305  using ReturnFutureType = std::future<ErrorReturnType>;
306#endif
307
308  // Create a 'blank' value of the ErrorReturnType, ready and safe to
309  // overwrite.
310  static ErrorReturnType createBlankErrorReturnValue() {
311    return ErrorReturnType(RetT());
312  }
313
314  // Consume an abandoned ErrorReturnType.
315  static void consumeAbandoned(ErrorReturnType RetOrErr) {
316    consumeError(RetOrErr.takeError());
317  }
318};
319
320// ResultTraits specialization for void functions.
321template <> class ResultTraits<void> {
322public:
323  // For void functions, ErrorReturnType is llvm::Error.
324  using ErrorReturnType = Error;
325
326#ifdef _MSC_VER
327  // The ErrorReturnType wrapped in a std::promise.
328  using ReturnPromiseType = std::promise<msvc_hacks::MSVCPError>;
329
330  // The ErrorReturnType wrapped in a std::future.
331  using ReturnFutureType = std::future<msvc_hacks::MSVCPError>;
332#else
333  // The ErrorReturnType wrapped in a std::promise.
334  using ReturnPromiseType = std::promise<ErrorReturnType>;
335
336  // The ErrorReturnType wrapped in a std::future.
337  using ReturnFutureType = std::future<ErrorReturnType>;
338#endif
339
340  // Create a 'blank' value of the ErrorReturnType, ready and safe to
341  // overwrite.
342  static ErrorReturnType createBlankErrorReturnValue() {
343    return ErrorReturnType::success();
344  }
345
346  // Consume an abandoned ErrorReturnType.
347  static void consumeAbandoned(ErrorReturnType Err) {
348    consumeError(std::move(Err));
349  }
350};
351
352// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
353// handlers for void RPC functions to return either void (in which case they
354// implicitly succeed) or Error (in which case their error return is
355// propagated). See usage in HandlerTraits::runHandlerHelper.
356template <> class ResultTraits<Error> : public ResultTraits<void> {};
357
358// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
359// handlers for RPC functions returning a T to return either a T (in which
360// case they implicitly succeed) or Expected<T> (in which case their error
361// return is propagated). See usage in HandlerTraits::runHandlerHelper.
362template <typename RetT>
363class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
364
365// Determines whether an RPC function's defined error return type supports
366// error return value.
367template <typename T>
368class SupportsErrorReturn {
369public:
370  static const bool value = false;
371};
372
373template <>
374class SupportsErrorReturn<Error> {
375public:
376  static const bool value = true;
377};
378
379template <typename T>
380class SupportsErrorReturn<Expected<T>> {
381public:
382  static const bool value = true;
383};
384
385// RespondHelper packages return values based on whether or not the declared
386// RPC function return type supports error returns.
387template <bool FuncSupportsErrorReturn>
388class RespondHelper;
389
390// RespondHelper specialization for functions that support error returns.
391template <>
392class RespondHelper<true> {
393public:
394
395  // Send Expected<T>.
396  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
397            typename FunctionIdT, typename SequenceNumberT>
398  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
399                          SequenceNumberT SeqNo,
400                          Expected<HandlerRetT> ResultOrErr) {
401    if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
402      return ResultOrErr.takeError();
403
404    // Open the response message.
405    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
406      return Err;
407
408    // Serialize the result.
409    if (auto Err =
410        SerializationTraits<ChannelT, WireRetT,
411                            Expected<HandlerRetT>>::serialize(
412                                                     C, std::move(ResultOrErr)))
413      return Err;
414
415    // Close the response message.
416    return C.endSendMessage();
417  }
418
419  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
420  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
421                          SequenceNumberT SeqNo, Error Err) {
422    if (Err && Err.isA<RPCFatalError>())
423      return Err;
424    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
425      return Err2;
426    if (auto Err2 = serializeSeq(C, std::move(Err)))
427      return Err2;
428    return C.endSendMessage();
429  }
430
431};
432
433// RespondHelper specialization for functions that do not support error returns.
434template <>
435class RespondHelper<false> {
436public:
437
438  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
439            typename FunctionIdT, typename SequenceNumberT>
440  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
441                          SequenceNumberT SeqNo,
442                          Expected<HandlerRetT> ResultOrErr) {
443    if (auto Err = ResultOrErr.takeError())
444      return Err;
445
446    // Open the response message.
447    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
448      return Err;
449
450    // Serialize the result.
451    if (auto Err =
452        SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
453                                                               C, *ResultOrErr))
454      return Err;
455
456    // Close the response message.
457    return C.endSendMessage();
458  }
459
460  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
461  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
462                          SequenceNumberT SeqNo, Error Err) {
463    if (Err)
464      return Err;
465    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
466      return Err2;
467    return C.endSendMessage();
468  }
469
470};
471
472
473// Send a response of the given wire return type (WireRetT) over the
474// channel, with the given sequence number.
475template <typename WireRetT, typename HandlerRetT, typename ChannelT,
476          typename FunctionIdT, typename SequenceNumberT>
477Error respond(ChannelT &C, const FunctionIdT &ResponseId,
478              SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
479  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
480    template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
481}
482
483// Send an empty response message on the given channel to indicate that
484// the handler ran.
485template <typename WireRetT, typename ChannelT, typename FunctionIdT,
486          typename SequenceNumberT>
487Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
488              Error Err) {
489  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
490    sendResult(C, ResponseId, SeqNo, std::move(Err));
491}
492
493// Converts a given type to the equivalent error return type.
494template <typename T> class WrappedHandlerReturn {
495public:
496  using Type = Expected<T>;
497};
498
499template <typename T> class WrappedHandlerReturn<Expected<T>> {
500public:
501  using Type = Expected<T>;
502};
503
504template <> class WrappedHandlerReturn<void> {
505public:
506  using Type = Error;
507};
508
509template <> class WrappedHandlerReturn<Error> {
510public:
511  using Type = Error;
512};
513
514template <> class WrappedHandlerReturn<ErrorSuccess> {
515public:
516  using Type = Error;
517};
518
519// Traits class that strips the response function from the list of handler
520// arguments.
521template <typename FnT> class AsyncHandlerTraits;
522
523template <typename ResultT, typename... ArgTs>
524class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
525public:
526  using Type = Error(ArgTs...);
527  using ResultType = Expected<ResultT>;
528};
529
530template <typename... ArgTs>
531class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
532public:
533  using Type = Error(ArgTs...);
534  using ResultType = Error;
535};
536
537template <typename ResponseHandlerT, typename... ArgTs>
538class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
539    public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
540                                    ArgTs...)> {};
541
// HandlerTraits exposes type information about RPC function handlers, plus
// helpers for invoking them. This primary template handles function objects
// and lambdas (non-function types). It inherits everything from the
// specialization for the type of the object's call operator, i.e. one of the
// member-function-pointer specializations.
template <typename HandlerT>
class HandlerTraits
    : public HandlerTraits<decltype(
          &std::remove_reference<HandlerT>::type::operator())> {};
550
// Traits for handlers with a given function type RetT(ArgTs...). The other
// HandlerTraits forms (functors, function pointers, member functions) all
// end up inheriting from this specialization.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(ArgTs...)> {
public:
  // Function type of the handler.
  using Type = RetT(ArgTs...);

  // Return type of the handler.
  using ReturnType = RetT;

  // Call the given handler with the values stored in Args (each element is
  // moved out of the tuple). The result is wrapped by WrappedHandlerReturn:
  // void/Error become Error, other T become Expected<T>.
  template <typename HandlerT, typename... TArgTs>
  static typename WrappedHandlerReturn<RetT>::Type
  unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
    return unpackAndRunHelper(Handler, Args,
                              llvm::index_sequence_for<TArgTs...>());
  }

  // Call the given asynchronous handler with Responder as its first argument,
  // followed by the values stored in Args (moved out of the tuple).
  template <typename HandlerT, typename ResponderT, typename... TArgTs>
  static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
                                 std::tuple<TArgTs...> &Args) {
    return unpackAndRunAsyncHelper(Handler, Responder, Args,
                                   llvm::index_sequence_for<TArgTs...>());
  }

  // Call the given handler with the given arguments. This overload is
  // selected (via enable_if) when HandlerT's own return type is void: the
  // handler is invoked for its side effects and success is reported.
  // Arguments must bind to ArgTs&& exactly, i.e. be rvalues of this
  // specialization's argument types.
  template <typename HandlerT>
  static typename std::enable_if<
      std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      Error>::type
  run(HandlerT &Handler, ArgTs &&... Args) {
    Handler(std::move(Args)...);
    return Error::success();
  }

  // Non-void HandlerT: arguments are taken by value (types deduced as
  // TArgTs), moved into the handler, and the handler's return value is
  // returned unchanged.
  template <typename HandlerT, typename... TArgTs>
  static typename std::enable_if<
      !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      typename HandlerTraits<HandlerT>::ReturnType>::type
  run(HandlerT &Handler, TArgTs... Args) {
    return Handler(std::move(Args)...);
  }

  // Serialize arguments to the channel, using this function's declared ArgTs
  // as the wire types and CArgTs as the concrete caller-side types.
  // NOTE(review): CArgs are taken by value, so each argument is copied before
  // serialization.
  template <typename ChannelT, typename... CArgTs>
  static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
  }

  // Deserialize arguments from the channel into the elements of Args, using
  // this function's declared ArgTs as the wire types.
  template <typename ChannelT, typename... CArgTs>
  static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
    return deserializeArgsHelper(C, Args,
                                 llvm::index_sequence_for<CArgTs...>());
  }

private:
  // The index_sequence parameter (named '_', otherwise unused) only supplies
  // the Indexes pack used to expand std::get over the tuple.
  template <typename ChannelT, typename... CArgTs, size_t... Indexes>
  static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
                                     llvm::index_sequence<Indexes...> _) {
    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
        C, std::get<Indexes>(Args)...);
  }

  // Expand the tuple into moved arguments and dispatch to the matching run()
  // overload.
  template <typename HandlerT, typename ArgTuple, size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
                     llvm::index_sequence<Indexes...>) {
    return run(Handler, std::move(std::get<Indexes>(Args))...);
  }


  // As unpackAndRunHelper, but passes Responder ahead of the unpacked
  // arguments.
  template <typename HandlerT, typename ResponderT, typename ArgTuple,
            size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
                          ArgTuple &Args,
                          llvm::index_sequence<Indexes...>) {
    return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
  }
};
635
636// Handler traits for free functions.
637template <typename RetT, typename... ArgTs>
638class HandlerTraits<RetT(*)(ArgTs...)>
639  : public HandlerTraits<RetT(ArgTs...)> {};
640
641// Handler traits for class methods (especially call operators for lambdas).
642template <typename Class, typename RetT, typename... ArgTs>
643class HandlerTraits<RetT (Class::*)(ArgTs...)>
644    : public HandlerTraits<RetT(ArgTs...)> {};
645
646// Handler traits for const class methods (especially call operators for
647// lambdas).
648template <typename Class, typename RetT, typename... ArgTs>
649class HandlerTraits<RetT (Class::*)(ArgTs...) const>
650    : public HandlerTraits<RetT(ArgTs...)> {};
651
652// Utility to peel the Expected wrapper off a response handler error type.
653template <typename HandlerT> class ResponseHandlerArg;
654
655template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
656public:
657  using ArgType = Expected<ArgT>;
658  using UnwrappedArgType = ArgT;
659};
660
661template <typename ArgT>
662class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
663public:
664  using ArgType = Expected<ArgT>;
665  using UnwrappedArgType = ArgT;
666};
667
668template <> class ResponseHandlerArg<Error(Error)> {
669public:
670  using ArgType = Error;
671};
672
673template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
674public:
675  using ArgType = Error;
676};
677
678// ResponseHandler represents a handler for a not-yet-received function call
679// result.
680template <typename ChannelT> class ResponseHandler {
681public:
682  virtual ~ResponseHandler() {}
683
684  // Reads the function result off the wire and acts on it. The meaning of
685  // "act" will depend on how this method is implemented in any given
686  // ResponseHandler subclass but could, for example, mean running a
687  // user-specified handler or setting a promise value.
688  virtual Error handleResponse(ChannelT &C) = 0;
689
690  // Abandons this outstanding result.
691  virtual void abandon() = 0;
692
693  // Create an error instance representing an abandoned response.
694  static Error createAbandonedResponseError() {
695    return make_error<ResponseAbandoned>();
696  }
697};
698
699// ResponseHandler subclass for RPC functions with non-void returns.
700template <typename ChannelT, typename FuncRetT, typename HandlerT>
701class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
702public:
703  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
704
705  // Handle the result by deserializing it from the channel then passing it
706  // to the user defined handler.
707  Error handleResponse(ChannelT &C) override {
708    using UnwrappedArgType = typename ResponseHandlerArg<
709        typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
710    UnwrappedArgType Result;
711    if (auto Err =
712            SerializationTraits<ChannelT, FuncRetT,
713                                UnwrappedArgType>::deserialize(C, Result))
714      return Err;
715    if (auto Err = C.endReceiveMessage())
716      return Err;
717    return Handler(std::move(Result));
718  }
719
720  // Abandon this response by calling the handler with an 'abandoned response'
721  // error.
722  void abandon() override {
723    if (auto Err = Handler(this->createAbandonedResponseError())) {
724      // Handlers should not fail when passed an abandoned response error.
725      report_fatal_error(std::move(Err));
726    }
727  }
728
729private:
730  HandlerT Handler;
731};
732
733// ResponseHandler subclass for RPC functions with void returns.
734template <typename ChannelT, typename HandlerT>
735class ResponseHandlerImpl<ChannelT, void, HandlerT>
736    : public ResponseHandler<ChannelT> {
737public:
738  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
739
740  // Handle the result (no actual value, just a notification that the function
741  // has completed on the remote end) by calling the user-defined handler with
742  // Error::success().
743  Error handleResponse(ChannelT &C) override {
744    if (auto Err = C.endReceiveMessage())
745      return Err;
746    return Handler(Error::success());
747  }
748
749  // Abandon this response by calling the handler with an 'abandoned response'
750  // error.
751  void abandon() override {
752    if (auto Err = Handler(this->createAbandonedResponseError())) {
753      // Handlers should not fail when passed an abandoned response error.
754      report_fatal_error(std::move(Err));
755    }
756  }
757
758private:
759  HandlerT Handler;
760};
761
762template <typename ChannelT, typename FuncRetT, typename HandlerT>
763class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
764    : public ResponseHandler<ChannelT> {
765public:
766  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
767
768  // Handle the result by deserializing it from the channel then passing it
769  // to the user defined handler.
770  Error handleResponse(ChannelT &C) override {
771    using HandlerArgType = typename ResponseHandlerArg<
772        typename HandlerTraits<HandlerT>::Type>::ArgType;
773    HandlerArgType Result((typename HandlerArgType::value_type()));
774
775    if (auto Err =
776            SerializationTraits<ChannelT, Expected<FuncRetT>,
777                                HandlerArgType>::deserialize(C, Result))
778      return Err;
779    if (auto Err = C.endReceiveMessage())
780      return Err;
781    return Handler(std::move(Result));
782  }
783
784  // Abandon this response by calling the handler with an 'abandoned response'
785  // error.
786  void abandon() override {
787    if (auto Err = Handler(this->createAbandonedResponseError())) {
788      // Handlers should not fail when passed an abandoned response error.
789      report_fatal_error(std::move(Err));
790    }
791  }
792
793private:
794  HandlerT Handler;
795};
796
797template <typename ChannelT, typename HandlerT>
798class ResponseHandlerImpl<ChannelT, Error, HandlerT>
799    : public ResponseHandler<ChannelT> {
800public:
801  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
802
803  // Handle the result by deserializing it from the channel then passing it
804  // to the user defined handler.
805  Error handleResponse(ChannelT &C) override {
806    Error Result = Error::success();
807    if (auto Err =
808            SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
809      return Err;
810    if (auto Err = C.endReceiveMessage())
811      return Err;
812    return Handler(std::move(Result));
813  }
814
815  // Abandon this response by calling the handler with an 'abandoned response'
816  // error.
817  void abandon() override {
818    if (auto Err = Handler(this->createAbandonedResponseError())) {
819      // Handlers should not fail when passed an abandoned response error.
820      report_fatal_error(std::move(Err));
821    }
822  }
823
824private:
825  HandlerT Handler;
826};
827
828// Create a ResponseHandler from a given user handler.
829template <typename ChannelT, typename FuncRetT, typename HandlerT>
830std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
831  return llvm::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
832      std::move(H));
833}
834
// Callable object that binds a member function to an instance, so a method
// can be installed where a functor is expected (e.g. as a result handler).
// Only a reference to the instance is stored: it must outlive the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}

  // Invoke the bound method, moving the (rvalue) arguments through.
  RetT operator()(ArgTs &&... Args) {
    ClassT &Target = Instance;
    return (Target.*Method)(std::move(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};
851
852// Helper that provides a Functor for deserializing arguments.
853template <typename... ArgTs> class ReadArgs {
854public:
855  Error operator()() { return Error::success(); }
856};
857
858template <typename ArgT, typename... ArgTs>
859class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
860public:
861  ReadArgs(ArgT &Arg, ArgTs &... Args)
862      : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
863
864  Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
865    this->Arg = std::move(ArgVal);
866    return ReadArgs<ArgTs...>::operator()(ArgVals...);
867  }
868
869private:
870  ArgT &Arg;
871};
872
// Mutex-protected allocator of sequence numbers. Released numbers are reused
// (most recently released first) before new ones are handed out.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Forget all released numbers and restart numbering at 0.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Return a sequence number that is not currently in use.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    return NextSequenceNumber++;
  }

  // Make SequenceNumber available for reuse.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
905
// Checks that predicate P holds for each corresponding pair of type arguments
// from T1 and T2 tuple.
//
// Only tuples of equal length are supported: the empty/empty specialization
// terminates the recursion, and no specialization exists for mismatched
// lengths.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Base case: two empty tuples trivially satisfy P.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static constexpr bool value = true;
};

// Recursive case: P must hold for the heads, and recursively for the tails.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
  using Tail =
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static constexpr bool value = P<T, U>::value && Tail::value;
};
925
926template <template <class, class> class P, typename T1Sig, typename T2Sig>
927class RPCArgTypeCheck {
928public:
929  using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
930  using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
931
932  static_assert(std::tuple_size<T1Tuple>::value >=
933                    std::tuple_size<T2Tuple>::value,
934                "Too many arguments to RPC call");
935  static_assert(std::tuple_size<T1Tuple>::value <=
936                    std::tuple_size<T2Tuple>::value,
937                "Too few arguments to RPC call");
938
939  static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
940};
941
942template <typename ChannelT, typename WireT, typename ConcreteT>
943class CanSerialize {
944private:
945  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
946
947  template <typename T>
948  static std::true_type
949  check(typename std::enable_if<
950        std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
951                                           std::declval<const ConcreteT &>())),
952                     Error>::value,
953        void *>::type);
954
955  template <typename> static std::false_type check(...);
956
957public:
958  static const bool value = decltype(check<S>(0))::value;
959};
960
961template <typename ChannelT, typename WireT, typename ConcreteT>
962class CanDeserialize {
963private:
964  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
965
966  template <typename T>
967  static std::true_type
968  check(typename std::enable_if<
969        std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
970                                             std::declval<ConcreteT &>())),
971                     Error>::value,
972        void *>::type);
973
974  template <typename> static std::false_type check(...);
975
976public:
977  static const bool value = decltype(check<S>(0))::value;
978};
979
/// Contains primitive utilities for defining, calling and handling calls to
/// remote procedures. ChannelT is a bidirectional stream conforming to the
/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
/// identifier type that must be serializable on ChannelT, and SequenceNumberT
/// is an integral type that will be used to number in-flight function calls.
///
/// ImplT is the concrete endpoint type deriving from this class (CRTP). The
/// base casts itself to ImplT to issue blocking negotiation calls through
/// ImplT::callB (see getRemoteFunctionId), so ImplT must provide callB.
///
/// These utilities support the construction of very primitive RPC utilities.
/// Their intent is to ensure correct serialization and deserialization of
/// procedure arguments, and to keep the client and server's view of the API in
/// sync.
template <typename ImplT, typename ChannelT, typename FunctionIdT,
          typename SequenceNumberT>
class RPCEndpointBase {
protected:
  // Built-in function named "__orc_rpc$invalid". It is not referenced by the
  // code in this class. NOTE(review): presumably reserved alongside the
  // invalid function id -- confirm before relying on it.
  class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
  public:
    static const char *getName() { return "__orc_rpc$invalid"; }
  };

  // Built-in function whose id (cached in ResponseId) tags response messages.
  // handleOne routes any message carrying this id to handleResponse.
  class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
  public:
    static const char *getName() { return "__orc_rpc$response"; }
  };

  // Built-in function used for id negotiation: given a function prototype
  // string it returns the peer's FunctionIdT for that function, or the
  // invalid id if the peer has no handler for it (see handleNegotiate).
  class OrcRPCNegotiate
      : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
  public:
    static const char *getName() { return "__orc_rpc$negotiate"; }
  };

  // Helper predicate for testing for the presence of SerializeTraits
  // serializers. Used as the predicate of detail::RPCArgTypeCheck in
  // appendCallAsync; it static_asserts (rather than merely yielding false) so
  // a missing serializer is reported as a readable compile error.
  template <typename WireT, typename ConcreteT>
  class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
  public:
    using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;

    static_assert(value, "Missing serializer for argument (Can't serialize the "
                         "first template type argument of CanSerializeCheck "
                         "from the second)");
  };

  // Helper predicate for testing for the presence of SerializeTraits
  // deserializers. Used as the predicate of detail::RPCArgTypeCheck in
  // addHandlerImpl/addAsyncHandlerImpl, with the same static_assert behavior
  // as CanSerializeCheck.
  template <typename WireT, typename ConcreteT>
  class CanDeserializeCheck
      : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
  public:
    using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;

    static_assert(value, "Missing deserializer for argument (Can't deserialize "
                         "the second template type argument of "
                         "CanDeserializeCheck from the first)");
  };

public:
  /// Construct an RPC instance on a channel.
  ///
  /// \param C The channel used for all traffic. Only a reference is stored,
  ///          so the channel must outlive this endpoint.
  /// \param LazyAutoNegotiation If true, appendCallAsync negotiates an id for
  ///          a function whose remote id is not yet known; if false, such a
  ///          call fails with CouldNotNegotiate.
  RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
      : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
    // Hold ResponseId in a special variable, since we expect Response to be
    // called relatively frequently, and want to avoid the map lookup.
    ResponseId = FnIdAllocator.getResponseId();
    RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;

    // Register the negotiate function id and handler. Like ResponseId, the
    // negotiate id comes directly from FnIdAllocator rather than from
    // negotiation.
    auto NegotiateId = FnIdAllocator.getNegotiateId();
    RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
    Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
        [this](const std::string &Name) { return handleNegotiate(Name); });
  }


  /// Negotiate a function id for Func with the other end of the channel.
  ///
  /// Does nothing if a valid remote id for Func is already known. If Retry is
  /// true, a previously negotiated invalid id is renegotiated; otherwise that
  /// case returns a CouldNotNegotiate error.
  template <typename Func> Error negotiateFunction(bool Retry = false) {
    return getRemoteFunctionId<Func>(true, Retry).takeError();
  }

  /// Append a call Func, does not call send on the channel.
  /// The first argument specifies a user-defined handler to be run when the
  /// function returns. The handler should take an Expected<Func::ReturnType>,
  /// or an Error (if Func::ReturnType is void). The handler will be called
  /// with an error if the return value is abandoned due to a channel error.
  ///
  /// If no remote id for Func can be obtained, the handler is called with a
  /// ResponseAbandoned error and the negotiation error is returned. If
  /// writing the message fails, all pending responses (including this one)
  /// are abandoned and the channel error is returned.
  template <typename Func, typename HandlerT, typename... ArgTs>
  Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {

    // Compile-time check that the argument count matches Func's and that
    // every argument can be serialized as the corresponding parameter type.
    static_assert(
        detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
                                void(ArgTs...)>::value,
        "");

    // Look up the function ID.
    FunctionIdT FnId;
    if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
      FnId = *FnIdOrErr;
    else {
      // Negotiation failed. Notify the handler then return the negotiate-failed
      // error. cantFail: the handler must return success for an abandoned
      // response.
      cantFail(Handler(make_error<ResponseAbandoned>()));
      return FnIdOrErr.takeError();
    }

    SequenceNumberT SeqNo; // initialized in locked scope below.
    {
      // Lock the pending responses map and sequence number manager.
      std::lock_guard<std::mutex> Lock(ResponsesMutex);

      // Allocate a sequence number.
      SeqNo = SequenceNumberMgr.getSequenceNumber();
      assert(!PendingResponses.count(SeqNo) &&
             "Sequence number already allocated");

      // Install the user handler.
      PendingResponses[SeqNo] =
        detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
            std::move(Handler));
    }

    // Open the function call message.
    if (auto Err = C.startSendMessage(FnId, SeqNo)) {
      abandonPendingResponses();
      return Err;
    }

    // Serialize the call arguments.
    if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
            C, Args...)) {
      abandonPendingResponses();
      return Err;
    }

    // Close the function call message.
    if (auto Err = C.endSendMessage()) {
      abandonPendingResponses();
      return Err;
    }

    return Error::success();
  }

  /// Flush calls appended with appendCallAsync by calling send() on the
  /// channel.
  Error sendAppendedCalls() { return C.send(); };

  /// Like appendCallAsync, but also calls send() on the channel once the call
  /// has been appended successfully.
  template <typename Func, typename HandlerT, typename... ArgTs>
  Error callAsync(HandlerT Handler, const ArgTs &... Args) {
    if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
      return Err;
    return C.send();
  }

  /// Handle one incoming call.
  ///
  /// Reads the next message header from the channel. Responses (messages
  /// tagged with ResponseId) are passed to handleResponse; any other message
  /// is dispatched to the handler registered for its function id. If the
  /// header cannot be read, all pending responses are abandoned. If no
  /// handler is registered for the id, a BadFunctionCall error is returned.
  Error handleOne() {
    FunctionIdT FnId;
    SequenceNumberT SeqNo;
    if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
      abandonPendingResponses();
      return Err;
    }
    if (FnId == ResponseId)
      return handleResponse(SeqNo);
    auto I = Handlers.find(FnId);
    if (I != Handlers.end())
      return I->second(C, SeqNo);

    // else: No handler found. Report error to client?
    // NOTE(review): pending responses are not abandoned on this path.
    return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
                                                                     SeqNo);
  }

  /// Helper for handling setter procedures - this method returns a functor that
  /// sets the variables referred to by Args... to values deserialized from the
  /// channel. Calling the returned detail::ReadArgs functor with one value per
  /// variable move-assigns each value into the corresponding variable and
  /// returns Error::success().
  ///
  /// NOTE(review): the example below predates the current
  /// Function<DerivedFunc, FnT> signature and uses an `expect` helper that is
  /// not part of this class; confirm current usage before copying it.
  /// E.g.
  ///
  ///   typedef Function<0, bool, int> Func1;
  ///
  ///   ...
  ///   bool B;
  ///   int I;
  ///   if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
  ///     /* Handle Args */ ;
  ///
  template <typename... ArgTs>
  static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
    return detail::ReadArgs<ArgTs...>(Args...);
  }

  /// Abandon all outstanding result handlers.
  ///
  /// This will call all currently registered result handlers to receive an
  /// "abandoned" error as their argument. This is used internally by the RPC
  /// in error situations, but can also be called directly by clients who are
  /// disconnecting from the remote and don't or can't expect responses to their
  /// outstanding calls. (Especially for outstanding blocking calls, calling
  /// this function may be necessary to avoid dead threads).
  ///
  /// It also resets the sequence number manager, so all sequence numbers
  /// become available for reuse. Note that each handler's abandon() runs
  /// while ResponsesMutex is held.
  void abandonPendingResponses() {
    // Lock the pending responses map and sequence number manager.
    std::lock_guard<std::mutex> Lock(ResponsesMutex);

    for (auto &KV : PendingResponses)
      KV.second->abandon();
    PendingResponses.clear();
    SequenceNumberMgr.reset();
  }

  /// Remove the handler for the given function.
  /// A handler must currently be registered for this function.
  ///
  /// Only the Handlers entry is erased; Func's entry in LocalFunctionIds is
  /// kept, so negotiation still returns its old id and later calls to it
  /// make handleOne return BadFunctionCall.
  template <typename Func>
  void removeHandler() {
    auto IdItr = LocalFunctionIds.find(Func::getPrototype());
    assert(IdItr != LocalFunctionIds.end() &&
           "Function does not have a registered handler");
    auto HandlerItr = Handlers.find(IdItr->second);
    assert(HandlerItr != Handlers.end() &&
           "Function does not have a registered handler");
    Handlers.erase(HandlerItr);
  }

  /// Clear all handlers.
  ///
  /// This also removes the negotiate handler installed by the constructor.
  /// LocalFunctionIds is left unchanged.
  void clearHandlers() {
    Handlers.clear();
  }

protected:

  // The id that means "no such function" (returned by handleNegotiate for
  // unknown names and treated as not-yet-valid by getRemoteFunctionId).
  FunctionIdT getInvalidFunctionId() const {
    return FnIdAllocator.getInvalidId();
  }

  /// Add the given handler to the handler map and make it available for
  /// autonegotiation and execution.
  ///
  /// A fresh local id is allocated for Func and recorded in LocalFunctionIds
  /// under Func's prototype, which is what handleNegotiate consults. The
  /// static_assert checks that each of Func's argument types can be
  /// deserialized into the handler's corresponding parameter type.
  template <typename Func, typename HandlerT>
  void addHandlerImpl(HandlerT Handler) {

    static_assert(detail::RPCArgTypeCheck<
                      CanDeserializeCheck, typename Func::Type,
                      typename detail::HandlerTraits<HandlerT>::Type>::value,
                  "");

    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
    LocalFunctionIds[Func::getPrototype()] = NewFnId;
    Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
  }

  // Asynchronous counterpart of addHandlerImpl: the handler is wrapped with
  // wrapAsyncHandler, and the deserializability check uses the argument list
  // produced by detail::AsyncHandlerTraits (NOTE(review): presumably the
  // handler's parameters minus the responder -- confirm).
  template <typename Func, typename HandlerT>
  void addAsyncHandlerImpl(HandlerT Handler) {

    static_assert(detail::RPCArgTypeCheck<
                      CanDeserializeCheck, typename Func::Type,
                      typename detail::AsyncHandlerTraits<
                        typename detail::HandlerTraits<HandlerT>::Type
                      >::Type>::value,
                  "");

    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
    LocalFunctionIds[Func::getPrototype()] = NewFnId;
    Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
  }

  // Deliver the response with sequence number SeqNo to its pending handler.
  // The handler is taken out of PendingResponses and its sequence number
  // released while ResponsesMutex is held; the handler itself then runs
  // without the lock. An unknown SeqNo, or a handler failure, abandons all
  // remaining pending responses and returns the error.
  Error handleResponse(SequenceNumberT SeqNo) {
    using Handler = typename decltype(PendingResponses)::mapped_type;
    Handler PRHandler;

    {
      // Lock the pending responses map and sequence number manager.
      std::unique_lock<std::mutex> Lock(ResponsesMutex);
      auto I = PendingResponses.find(SeqNo);

      if (I != PendingResponses.end()) {
        PRHandler = std::move(I->second);
        PendingResponses.erase(I);
        SequenceNumberMgr.releaseSequenceNumber(SeqNo);
      } else {
        // Unlock the pending results map to prevent recursive lock.
        Lock.unlock();
        abandonPendingResponses();
        return make_error<
                 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
      }
    }

    assert(PRHandler &&
           "If we didn't find a response handler we should have bailed out");

    if (auto Err = PRHandler->handleResponse(C)) {
      abandonPendingResponses();
      return Err;
    }

    return Error::success();
  }

  // Handler for OrcRPCNegotiate: look up the local id registered for the
  // function prototype Name, returning the invalid id if there is none.
  FunctionIdT handleNegotiate(const std::string &Name) {
    auto I = LocalFunctionIds.find(Name);
    if (I == LocalFunctionIds.end())
      return getInvalidFunctionId();
    return I->second;
  }

  // Find the remote FunctionId for the given function.
  //
  // A valid id already in RemoteFunctionIds is returned directly. Otherwise a
  // negotiation round-trip (a blocking ImplT::callB of OrcRPCNegotiate) is
  // attempted if permitted: NegotiateIfNotInMap governs functions with no
  // entry yet, NegotiateIfInvalid governs functions whose recorded id is the
  // invalid id. The negotiated id is recorded even when it is invalid, in
  // which case CouldNotNegotiate is returned.
  //
  // NOTE(review): RemoteFunctionIds is keyed on the const char * returned by
  // Func::getPrototype(), i.e. by pointer identity, and is read and written
  // here without holding any lock.
  template <typename Func>
  Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
                                            bool NegotiateIfInvalid) {
    bool DoNegotiate;

    // Check if we already have a function id...
    auto I = RemoteFunctionIds.find(Func::getPrototype());
    if (I != RemoteFunctionIds.end()) {
      // If it's valid there's nothing left to do.
      if (I->second != getInvalidFunctionId())
        return I->second;
      DoNegotiate = NegotiateIfInvalid;
    } else
      DoNegotiate = NegotiateIfNotInMap;

    // We don't have a function id for Func yet, but we're allowed to try to
    // negotiate one.
    if (DoNegotiate) {
      auto &Impl = static_cast<ImplT &>(*this);
      if (auto RemoteIdOrErr =
          Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
        RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
        if (*RemoteIdOrErr == getInvalidFunctionId())
          return make_error<CouldNotNegotiate>(Func::getPrototype());
        return *RemoteIdOrErr;
      } else
        return RemoteIdOrErr.takeError();
    }

    // No key was available in the map and we weren't allowed to try to
    // negotiate one, so return an unknown function error.
    return make_error<CouldNotNegotiate>(Func::getPrototype());
  }

  // Type-erased wrapped handler: invoked by handleOne with the channel and
  // the incoming message's sequence number.
  using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;

  // Wrap the given user handler in the necessary argument-deserialization code
  // and result-serialization code.
  //
  // The returned function deserializes Func's arguments into a tuple of the
  // handler's parameter types, ends the receive, runs the handler
  // synchronously, and passes its result to detail::respond together with
  // ResponseId and the incoming SeqNo.
  template <typename Func, typename HandlerT>
  WrappedHandlerFn wrapHandler(HandlerT Handler) {
    return [this, Handler](ChannelT &Channel,
                           SequenceNumberT SeqNo) mutable -> Error {
      // Start by deserializing the arguments.
      using ArgsTuple =
          typename detail::FunctionArgsTuple<
            typename detail::HandlerTraits<HandlerT>::Type>::Type;
      auto Args = std::make_shared<ArgsTuple>();

      if (auto Err =
              detail::HandlerTraits<typename Func::Type>::deserializeArgs(
                  Channel, *Args))
        return Err;

      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
      // for Args. Void cast Args to work around this for now.
      // FIXME: Remove this workaround once we can assume a working GCC version.
      (void)Args;

      // End receive message, unlocking the channel for reading.
      if (auto Err = Channel.endReceiveMessage())
        return Err;

      using HTraits = detail::HandlerTraits<HandlerT>;
      using FuncReturn = typename Func::ReturnType;
      return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
                                         HTraits::unpackAndRun(Handler, *Args));
    };
  }

  // Wrap the given asynchronous user handler in the necessary
  // argument-deserialization code and result-serialization code.
  //
  // Like wrapHandler, except that instead of responding with the handler's
  // return value, the handler is given a Responder callback that passes its
  // argument to detail::respond when invoked; the wrapper returns whatever
  // unpackAndRunAsync returns. Responder captures `this`, so the endpoint
  // must outlive any deferred invocation of it.
  template <typename Func, typename HandlerT>
  WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
    return [this, Handler](ChannelT &Channel,
                           SequenceNumberT SeqNo) mutable -> Error {
      // Start by deserializing the arguments.
      using AHTraits = detail::AsyncHandlerTraits<
                         typename detail::HandlerTraits<HandlerT>::Type>;
      using ArgsTuple =
          typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
      auto Args = std::make_shared<ArgsTuple>();

      if (auto Err =
              detail::HandlerTraits<typename Func::Type>::deserializeArgs(
                  Channel, *Args))
        return Err;

      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
      // for Args. Void cast Args to work around this for now.
      // FIXME: Remove this workaround once we can assume a working GCC version.
      (void)Args;

      // End receive message, unlocking the channel for reading.
      if (auto Err = Channel.endReceiveMessage())
        return Err;

      using HTraits = detail::HandlerTraits<HandlerT>;
      using FuncReturn = typename Func::ReturnType;
      auto Responder =
        [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
          return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
                                             std::move(RetVal));
        };

      return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
    };
  }

  // The channel all messages are sent and received on (not owned).
  ChannelT &C;

  // Whether appendCallAsync may negotiate ids for functions that have no
  // entry in RemoteFunctionIds yet.
  bool LazyAutoNegotiation;

  RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;

  // Cached id of OrcRPCResponse, used to tag and recognize responses.
  FunctionIdT ResponseId;
  // Prototype -> id for functions with a local handler; consulted by
  // handleNegotiate.
  std::map<std::string, FunctionIdT> LocalFunctionIds;
  // Prototype -> remote id, either preassigned (response, negotiate) or
  // negotiated (possibly the invalid id).
  std::map<const char *, FunctionIdT> RemoteFunctionIds;

  // Local function id -> wrapped handler, used by handleOne.
  std::map<FunctionIdT, WrappedHandlerFn> Handlers;

  // Guards SequenceNumberMgr and PendingResponses.
  std::mutex ResponsesMutex;
  detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
  // Sequence number -> handler awaiting the response to that call.
  std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
      PendingResponses;
};
1403
1404} // end namespace detail
1405
1406template <typename ChannelT, typename FunctionIdT = uint32_t,
1407          typename SequenceNumberT = uint32_t>
1408class MultiThreadedRPCEndpoint
1409    : public detail::RPCEndpointBase<
1410          MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1411          ChannelT, FunctionIdT, SequenceNumberT> {
1412private:
1413  using BaseClass =
1414      detail::RPCEndpointBase<
1415        MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1416        ChannelT, FunctionIdT, SequenceNumberT>;
1417
1418public:
1419  MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1420      : BaseClass(C, LazyAutoNegotiation) {}
1421
1422  /// Add a handler for the given RPC function.
1423  /// This installs the given handler functor for the given RPC Function, and
1424  /// makes the RPC function available for negotiation/calling from the remote.
1425  template <typename Func, typename HandlerT>
1426  void addHandler(HandlerT Handler) {
1427    return this->template addHandlerImpl<Func>(std::move(Handler));
1428  }
1429
1430  /// Add a class-method as a handler.
1431  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1432  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1433    addHandler<Func>(
1434      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1435  }
1436
1437  template <typename Func, typename HandlerT>
1438  void addAsyncHandler(HandlerT Handler) {
1439    return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1440  }
1441
1442  /// Add a class-method as a handler.
1443  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1444  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1445    addAsyncHandler<Func>(
1446      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1447  }
1448
1449  /// Return type for non-blocking call primitives.
1450  template <typename Func>
1451  using NonBlockingCallResult = typename detail::ResultTraits<
1452      typename Func::ReturnType>::ReturnFutureType;
1453
1454  /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1455  /// of a future result and the sequence number assigned to the result.
1456  ///
1457  /// This utility function is primarily used for single-threaded mode support,
1458  /// where the sequence number can be used to wait for the corresponding
1459  /// result. In multi-threaded mode the appendCallNB method, which does not
1460  /// return the sequence numeber, should be preferred.
1461  template <typename Func, typename... ArgTs>
1462  Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1463    using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1464    using ErrorReturn = typename RTraits::ErrorReturnType;
1465    using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1466
1467    // FIXME: Stack allocate and move this into the handler once LLVM builds
1468    //        with C++14.
1469    auto Promise = std::make_shared<ErrorReturnPromise>();
1470    auto FutureResult = Promise->get_future();
1471
1472    if (auto Err = this->template appendCallAsync<Func>(
1473            [Promise](ErrorReturn RetOrErr) {
1474              Promise->set_value(std::move(RetOrErr));
1475              return Error::success();
1476            },
1477            Args...)) {
1478      RTraits::consumeAbandoned(FutureResult.get());
1479      return std::move(Err);
1480    }
1481    return std::move(FutureResult);
1482  }
1483
1484  /// The same as appendCallNBWithSeq, except that it calls C.send() to
1485  /// flush the channel after serializing the call.
1486  template <typename Func, typename... ArgTs>
1487  Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1488    auto Result = appendCallNB<Func>(Args...);
1489    if (!Result)
1490      return Result;
1491    if (auto Err = this->C.send()) {
1492      this->abandonPendingResponses();
1493      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1494          std::move(Result->get()));
1495      return std::move(Err);
1496    }
1497    return Result;
1498  }
1499
1500  /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1501  /// for void functions or an Expected<T> for functions returning a T.
1502  ///
1503  /// This function is for use in threaded code where another thread is
1504  /// handling responses and incoming calls.
1505  template <typename Func, typename... ArgTs,
1506            typename AltRetT = typename Func::ReturnType>
1507  typename detail::ResultTraits<AltRetT>::ErrorReturnType
1508  callB(const ArgTs &... Args) {
1509    if (auto FutureResOrErr = callNB<Func>(Args...))
1510      return FutureResOrErr->get();
1511    else
1512      return FutureResOrErr.takeError();
1513  }
1514
1515  /// Handle incoming RPC calls.
1516  Error handlerLoop() {
1517    while (true)
1518      if (auto Err = this->handleOne())
1519        return Err;
1520    return Error::success();
1521  }
1522};
1523
1524template <typename ChannelT, typename FunctionIdT = uint32_t,
1525          typename SequenceNumberT = uint32_t>
1526class SingleThreadedRPCEndpoint
1527    : public detail::RPCEndpointBase<
1528          SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1529          ChannelT, FunctionIdT, SequenceNumberT> {
1530private:
1531  using BaseClass =
1532      detail::RPCEndpointBase<
1533        SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1534        ChannelT, FunctionIdT, SequenceNumberT>;
1535
1536public:
1537  SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1538      : BaseClass(C, LazyAutoNegotiation) {}
1539
1540  template <typename Func, typename HandlerT>
1541  void addHandler(HandlerT Handler) {
1542    return this->template addHandlerImpl<Func>(std::move(Handler));
1543  }
1544
1545  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1546  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1547    addHandler<Func>(
1548        detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1549  }
1550
1551  template <typename Func, typename HandlerT>
1552  void addAsyncHandler(HandlerT Handler) {
1553    return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1554  }
1555
1556  /// Add a class-method as a handler.
1557  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1558  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1559    addAsyncHandler<Func>(
1560      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1561  }
1562
1563  template <typename Func, typename... ArgTs,
1564            typename AltRetT = typename Func::ReturnType>
1565  typename detail::ResultTraits<AltRetT>::ErrorReturnType
1566  callB(const ArgTs &... Args) {
1567    bool ReceivedResponse = false;
1568    using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1569    auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1570
1571    // We have to 'Check' result (which we know is in a success state at this
1572    // point) so that it can be overwritten in the async handler.
1573    (void)!!Result;
1574
1575    if (auto Err = this->template appendCallAsync<Func>(
1576            [&](ResultType R) {
1577              Result = std::move(R);
1578              ReceivedResponse = true;
1579              return Error::success();
1580            },
1581            Args...)) {
1582      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1583          std::move(Result));
1584      return std::move(Err);
1585    }
1586
1587    while (!ReceivedResponse) {
1588      if (auto Err = this->handleOne()) {
1589        detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1590            std::move(Result));
1591        return std::move(Err);
1592      }
1593    }
1594
1595    return Result;
1596  }
1597};
1598
1599/// Asynchronous dispatch for a function on an RPC endpoint.
1600template <typename RPCClass, typename Func>
1601class RPCAsyncDispatch {
1602public:
1603  RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1604
1605  template <typename HandlerT, typename... ArgTs>
1606  Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1607    return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1608  }
1609
1610private:
1611  RPCClass &Endpoint;
1612};
1613
1614/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1615template <typename Func, typename RPCEndpointT>
1616RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1617  return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1618}
1619
1620/// \brief Allows a set of asynchrounous calls to be dispatched, and then
1621///        waited on as a group.
1622class ParallelCallGroup {
1623public:
1624
1625  ParallelCallGroup() = default;
1626  ParallelCallGroup(const ParallelCallGroup &) = delete;
1627  ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1628
1629  /// \brief Make as asynchronous call.
1630  template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
1631  Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1632             const ArgTs &... Args) {
1633    // Increment the count of outstanding calls. This has to happen before
1634    // we invoke the call, as the handler may (depending on scheduling)
1635    // be run immediately on another thread, and we don't want the decrement
1636    // in the wrapped handler below to run before the increment.
1637    {
1638      std::unique_lock<std::mutex> Lock(M);
1639      ++NumOutstandingCalls;
1640    }
1641
1642    // Wrap the user handler in a lambda that will decrement the
1643    // outstanding calls count, then poke the condition variable.
1644    using ArgType = typename detail::ResponseHandlerArg<
1645        typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1646    // FIXME: Move handler into wrapped handler once we have C++14.
1647    auto WrappedHandler = [this, Handler](ArgType Arg) {
1648      auto Err = Handler(std::move(Arg));
1649      std::unique_lock<std::mutex> Lock(M);
1650      --NumOutstandingCalls;
1651      CV.notify_all();
1652      return Err;
1653    };
1654
1655    return AsyncDispatch(std::move(WrappedHandler), Args...);
1656  }
1657
1658  /// \brief Blocks until all calls have been completed and their return value
1659  ///        handlers run.
1660  void wait() {
1661    std::unique_lock<std::mutex> Lock(M);
1662    while (NumOutstandingCalls > 0)
1663      CV.wait(Lock);
1664  }
1665
1666private:
1667  std::mutex M;
1668  std::condition_variable CV;
1669  uint32_t NumOutstandingCalls = 0;
1670};
1671
1672/// @brief Convenience class for grouping RPC Functions into APIs that can be
1673///        negotiated as a block.
1674///
1675template <typename... Funcs>
1676class APICalls {
1677public:
1678
1679  /// @brief Test whether this API contains Function F.
1680  template <typename F>
1681  class Contains {
1682  public:
1683    static const bool value = false;
1684  };
1685
1686  /// @brief Negotiate all functions in this API.
1687  template <typename RPCEndpoint>
1688  static Error negotiate(RPCEndpoint &R) {
1689    return Error::success();
1690  }
1691};
1692
1693template <typename Func, typename... Funcs>
1694class APICalls<Func, Funcs...> {
1695public:
1696
1697  template <typename F>
1698  class Contains {
1699  public:
1700    static const bool value = std::is_same<F, Func>::value |
1701                              APICalls<Funcs...>::template Contains<F>::value;
1702  };
1703
1704  template <typename RPCEndpoint>
1705  static Error negotiate(RPCEndpoint &R) {
1706    if (auto Err = R.template negotiateFunction<Func>())
1707      return Err;
1708    return APICalls<Funcs...>::negotiate(R);
1709  }
1710
1711};
1712
1713template <typename... InnerFuncs, typename... Funcs>
1714class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1715public:
1716
1717  template <typename F>
1718  class Contains {
1719  public:
1720    static const bool value =
1721      APICalls<InnerFuncs...>::template Contains<F>::value |
1722      APICalls<Funcs...>::template Contains<F>::value;
1723  };
1724
1725  template <typename RPCEndpoint>
1726  static Error negotiate(RPCEndpoint &R) {
1727    if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1728      return Err;
1729    return APICalls<Funcs...>::negotiate(R);
1730  }
1731
1732};
1733
1734} // end namespace rpc
1735} // end namespace orc
1736} // end namespace llvm
1737
1738#endif
1739