1//===------- RPCUTils.h - Utilities for building RPC APIs -------*- C++ -*-===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// Utilities to support construction of simple RPC APIs. 11// 12// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ 13// programmers, high performance, low memory overhead, and efficient use of the 14// communications channel. 15// 16//===----------------------------------------------------------------------===// 17 18#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H 19#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H 20 21#include <map> 22#include <thread> 23#include <vector> 24 25#include "llvm/ADT/STLExtras.h" 26#include "llvm/ExecutionEngine/Orc/OrcError.h" 27#include "llvm/ExecutionEngine/Orc/RPCSerialization.h" 28 29#include <future> 30 31namespace llvm { 32namespace orc { 33namespace rpc { 34 35/// Base class of all fatal RPC errors (those that necessarily result in the 36/// termination of the RPC session). 37class RPCFatalError : public ErrorInfo<RPCFatalError> { 38public: 39 static char ID; 40}; 41 42/// RPCConnectionClosed is returned from RPC operations if the RPC connection 43/// has already been closed due to either an error or graceful disconnection. 44class ConnectionClosed : public ErrorInfo<ConnectionClosed> { 45public: 46 static char ID; 47 std::error_code convertToErrorCode() const override; 48 void log(raw_ostream &OS) const override; 49}; 50 51/// BadFunctionCall is returned from handleOne when the remote makes a call with 52/// an unrecognized function id. 53/// 54/// This error is fatal because Orc RPC needs to know how to parse a function 55/// call to know where the next call starts, and if it doesn't recognize the 56/// function id it cannot parse the call. 
57template <typename FnIdT, typename SeqNoT> 58class BadFunctionCall 59 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { 60public: 61 static char ID; 62 63 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) 64 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} 65 66 std::error_code convertToErrorCode() const override { 67 return orcError(OrcErrorCode::UnexpectedRPCCall); 68 } 69 70 void log(raw_ostream &OS) const override { 71 OS << "Call to invalid RPC function id '" << FnId << "' with " 72 "sequence number " << SeqNo; 73 } 74 75private: 76 FnIdT FnId; 77 SeqNoT SeqNo; 78}; 79 80template <typename FnIdT, typename SeqNoT> 81char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; 82 83/// InvalidSequenceNumberForResponse is returned from handleOne when a response 84/// call arrives with a sequence number that doesn't correspond to any in-flight 85/// function call. 86/// 87/// This error is fatal because Orc RPC needs to know how to parse the rest of 88/// the response call to know where the next call starts, and if it doesn't have 89/// a result parser for this sequence number it can't do that. 90template <typename SeqNoT> 91class InvalidSequenceNumberForResponse 92 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> { 93public: 94 static char ID; 95 96 InvalidSequenceNumberForResponse(SeqNoT SeqNo) 97 : SeqNo(std::move(SeqNo)) {} 98 99 std::error_code convertToErrorCode() const override { 100 return orcError(OrcErrorCode::UnexpectedRPCCall); 101 }; 102 103 void log(raw_ostream &OS) const override { 104 OS << "Response has unknown sequence number " << SeqNo; 105 } 106private: 107 SeqNoT SeqNo; 108}; 109 110template <typename SeqNoT> 111char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; 112 113/// This non-fatal error will be passed to asynchronous result handlers in place 114/// of a result if the connection goes down before a result returns, or if the 115/// function to be called cannot be negotiated with the remote. 
116class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { 117public: 118 static char ID; 119 120 std::error_code convertToErrorCode() const override; 121 void log(raw_ostream &OS) const override; 122}; 123 124/// This error is returned if the remote does not have a handler installed for 125/// the given RPC function. 126class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { 127public: 128 static char ID; 129 130 CouldNotNegotiate(std::string Signature); 131 std::error_code convertToErrorCode() const override; 132 void log(raw_ostream &OS) const override; 133 const std::string &getSignature() const { return Signature; } 134private: 135 std::string Signature; 136}; 137 138template <typename DerivedFunc, typename FnT> class Function; 139 140// RPC Function class. 141// DerivedFunc should be a user defined class with a static 'getName()' method 142// returning a const char* representing the function's name. 143template <typename DerivedFunc, typename RetT, typename... ArgTs> 144class Function<DerivedFunc, RetT(ArgTs...)> { 145public: 146 /// User defined function type. 147 using Type = RetT(ArgTs...); 148 149 /// Return type. 150 using ReturnType = RetT; 151 152 /// Returns the full function prototype as a string. 153 static const char *getPrototype() { 154 std::lock_guard<std::mutex> Lock(NameMutex); 155 if (Name.empty()) 156 raw_string_ostream(Name) 157 << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName() 158 << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")"; 159 return Name.data(); 160 } 161 162private: 163 static std::mutex NameMutex; 164 static std::string Name; 165}; 166 167template <typename DerivedFunc, typename RetT, typename... ArgTs> 168std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex; 169 170template <typename DerivedFunc, typename RetT, typename... ArgTs> 171std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; 172 173/// Allocates RPC function ids during autonegotiation. 
174/// Specializations of this class must provide four members: 175/// 176/// static T getInvalidId(): 177/// Should return a reserved id that will be used to represent missing 178/// functions during autonegotiation. 179/// 180/// static T getResponseId(): 181/// Should return a reserved id that will be used to send function responses 182/// (return values). 183/// 184/// static T getNegotiateId(): 185/// Should return a reserved id for the negotiate function, which will be used 186/// to negotiate ids for user defined functions. 187/// 188/// template <typename Func> T allocate(): 189/// Allocate a unique id for function Func. 190template <typename T, typename = void> class RPCFunctionIdAllocator; 191 192/// This specialization of RPCFunctionIdAllocator provides a default 193/// implementation for integral types. 194template <typename T> 195class RPCFunctionIdAllocator< 196 T, typename std::enable_if<std::is_integral<T>::value>::type> { 197public: 198 static T getInvalidId() { return T(0); } 199 static T getResponseId() { return T(1); } 200 static T getNegotiateId() { return T(2); } 201 202 template <typename Func> T allocate() { return NextId++; } 203 204private: 205 T NextId = 3; 206}; 207 208namespace detail { 209 210// FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation 211// supports classes without default constructors. 212#ifdef _MSC_VER 213 214namespace msvc_hacks { 215 216// Work around MSVC's future implementation's use of default constructors: 217// A default constructed value in the promise will be overwritten when the 218// real error is set - so the default constructed Error has to be checked 219// already. 
220class MSVCPError : public Error { 221public: 222 MSVCPError() { (void)!!*this; } 223 224 MSVCPError(MSVCPError &&Other) : Error(std::move(Other)) {} 225 226 MSVCPError &operator=(MSVCPError Other) { 227 Error::operator=(std::move(Other)); 228 return *this; 229 } 230 231 MSVCPError(Error Err) : Error(std::move(Err)) {} 232}; 233 234// Work around MSVC's future implementation, similar to MSVCPError. 235template <typename T> class MSVCPExpected : public Expected<T> { 236public: 237 MSVCPExpected() 238 : Expected<T>(make_error<StringError>("", inconvertibleErrorCode())) { 239 consumeError(this->takeError()); 240 } 241 242 MSVCPExpected(MSVCPExpected &&Other) : Expected<T>(std::move(Other)) {} 243 244 MSVCPExpected &operator=(MSVCPExpected &&Other) { 245 Expected<T>::operator=(std::move(Other)); 246 return *this; 247 } 248 249 MSVCPExpected(Error Err) : Expected<T>(std::move(Err)) {} 250 251 template <typename OtherT> 252 MSVCPExpected( 253 OtherT &&Val, 254 typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * = 255 nullptr) 256 : Expected<T>(std::move(Val)) {} 257 258 template <class OtherT> 259 MSVCPExpected( 260 Expected<OtherT> &&Other, 261 typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * = 262 nullptr) 263 : Expected<T>(std::move(Other)) {} 264 265 template <class OtherT> 266 explicit MSVCPExpected( 267 Expected<OtherT> &&Other, 268 typename std::enable_if<!std::is_convertible<OtherT, T>::value>::type * = 269 nullptr) 270 : Expected<T>(std::move(Other)) {} 271}; 272 273} // end namespace msvc_hacks 274 275#endif // _MSC_VER 276 277/// Provides a typedef for a tuple containing the decayed argument types. 278template <typename T> class FunctionArgsTuple; 279 280template <typename RetT, typename... 
ArgTs> 281class FunctionArgsTuple<RetT(ArgTs...)> { 282public: 283 using Type = std::tuple<typename std::decay< 284 typename std::remove_reference<ArgTs>::type>::type...>; 285}; 286 287// ResultTraits provides typedefs and utilities specific to the return type 288// of functions. 289template <typename RetT> class ResultTraits { 290public: 291 // The return type wrapped in llvm::Expected. 292 using ErrorReturnType = Expected<RetT>; 293 294#ifdef _MSC_VER 295 // The ErrorReturnType wrapped in a std::promise. 296 using ReturnPromiseType = std::promise<msvc_hacks::MSVCPExpected<RetT>>; 297 298 // The ErrorReturnType wrapped in a std::future. 299 using ReturnFutureType = std::future<msvc_hacks::MSVCPExpected<RetT>>; 300#else 301 // The ErrorReturnType wrapped in a std::promise. 302 using ReturnPromiseType = std::promise<ErrorReturnType>; 303 304 // The ErrorReturnType wrapped in a std::future. 305 using ReturnFutureType = std::future<ErrorReturnType>; 306#endif 307 308 // Create a 'blank' value of the ErrorReturnType, ready and safe to 309 // overwrite. 310 static ErrorReturnType createBlankErrorReturnValue() { 311 return ErrorReturnType(RetT()); 312 } 313 314 // Consume an abandoned ErrorReturnType. 315 static void consumeAbandoned(ErrorReturnType RetOrErr) { 316 consumeError(RetOrErr.takeError()); 317 } 318}; 319 320// ResultTraits specialization for void functions. 321template <> class ResultTraits<void> { 322public: 323 // For void functions, ErrorReturnType is llvm::Error. 324 using ErrorReturnType = Error; 325 326#ifdef _MSC_VER 327 // The ErrorReturnType wrapped in a std::promise. 328 using ReturnPromiseType = std::promise<msvc_hacks::MSVCPError>; 329 330 // The ErrorReturnType wrapped in a std::future. 331 using ReturnFutureType = std::future<msvc_hacks::MSVCPError>; 332#else 333 // The ErrorReturnType wrapped in a std::promise. 334 using ReturnPromiseType = std::promise<ErrorReturnType>; 335 336 // The ErrorReturnType wrapped in a std::future. 
337 using ReturnFutureType = std::future<ErrorReturnType>; 338#endif 339 340 // Create a 'blank' value of the ErrorReturnType, ready and safe to 341 // overwrite. 342 static ErrorReturnType createBlankErrorReturnValue() { 343 return ErrorReturnType::success(); 344 } 345 346 // Consume an abandoned ErrorReturnType. 347 static void consumeAbandoned(ErrorReturnType Err) { 348 consumeError(std::move(Err)); 349 } 350}; 351 352// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows 353// handlers for void RPC functions to return either void (in which case they 354// implicitly succeed) or Error (in which case their error return is 355// propagated). See usage in HandlerTraits::runHandlerHelper. 356template <> class ResultTraits<Error> : public ResultTraits<void> {}; 357 358// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows 359// handlers for RPC functions returning a T to return either a T (in which 360// case they implicitly succeed) or Expected<T> (in which case their error 361// return is propagated). See usage in HandlerTraits::runHandlerHelper. 362template <typename RetT> 363class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; 364 365// Determines whether an RPC function's defined error return type supports 366// error return value. 367template <typename T> 368class SupportsErrorReturn { 369public: 370 static const bool value = false; 371}; 372 373template <> 374class SupportsErrorReturn<Error> { 375public: 376 static const bool value = true; 377}; 378 379template <typename T> 380class SupportsErrorReturn<Expected<T>> { 381public: 382 static const bool value = true; 383}; 384 385// RespondHelper packages return values based on whether or not the declared 386// RPC function return type supports error returns. 387template <bool FuncSupportsErrorReturn> 388class RespondHelper; 389 390// RespondHelper specialization for functions that support error returns. 
391template <> 392class RespondHelper<true> { 393public: 394 395 // Send Expected<T>. 396 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 397 typename FunctionIdT, typename SequenceNumberT> 398 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 399 SequenceNumberT SeqNo, 400 Expected<HandlerRetT> ResultOrErr) { 401 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) 402 return ResultOrErr.takeError(); 403 404 // Open the response message. 405 if (auto Err = C.startSendMessage(ResponseId, SeqNo)) 406 return Err; 407 408 // Serialize the result. 409 if (auto Err = 410 SerializationTraits<ChannelT, WireRetT, 411 Expected<HandlerRetT>>::serialize( 412 C, std::move(ResultOrErr))) 413 return Err; 414 415 // Close the response message. 416 return C.endSendMessage(); 417 } 418 419 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> 420 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 421 SequenceNumberT SeqNo, Error Err) { 422 if (Err && Err.isA<RPCFatalError>()) 423 return Err; 424 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) 425 return Err2; 426 if (auto Err2 = serializeSeq(C, std::move(Err))) 427 return Err2; 428 return C.endSendMessage(); 429 } 430 431}; 432 433// RespondHelper specialization for functions that do not support error returns. 434template <> 435class RespondHelper<false> { 436public: 437 438 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 439 typename FunctionIdT, typename SequenceNumberT> 440 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 441 SequenceNumberT SeqNo, 442 Expected<HandlerRetT> ResultOrErr) { 443 if (auto Err = ResultOrErr.takeError()) 444 return Err; 445 446 // Open the response message. 447 if (auto Err = C.startSendMessage(ResponseId, SeqNo)) 448 return Err; 449 450 // Serialize the result. 
451 if (auto Err = 452 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( 453 C, *ResultOrErr)) 454 return Err; 455 456 // Close the response message. 457 return C.endSendMessage(); 458 } 459 460 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> 461 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 462 SequenceNumberT SeqNo, Error Err) { 463 if (Err) 464 return Err; 465 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) 466 return Err2; 467 return C.endSendMessage(); 468 } 469 470}; 471 472 473// Send a response of the given wire return type (WireRetT) over the 474// channel, with the given sequence number. 475template <typename WireRetT, typename HandlerRetT, typename ChannelT, 476 typename FunctionIdT, typename SequenceNumberT> 477Error respond(ChannelT &C, const FunctionIdT &ResponseId, 478 SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { 479 return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: 480 template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr)); 481} 482 483// Send an empty response message on the given channel to indicate that 484// the handler ran. 485template <typename WireRetT, typename ChannelT, typename FunctionIdT, 486 typename SequenceNumberT> 487Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, 488 Error Err) { 489 return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: 490 sendResult(C, ResponseId, SeqNo, std::move(Err)); 491} 492 493// Converts a given type to the equivalent error return type. 
// (Primary template.) A handler returning plain T is reported as Expected<T>.
template <typename T> class WrappedHandlerReturn {
public:
  using Type = Expected<T>;
};

// A handler that already returns Expected<T> keeps that type.
template <typename T> class WrappedHandlerReturn<Expected<T>> {
public:
  using Type = Expected<T>;
};

// void, Error and ErrorSuccess handlers are all reported as Error.
template <> class WrappedHandlerReturn<void> {
public:
  using Type = Error;
};

template <> class WrappedHandlerReturn<Error> {
public:
  using Type = Error;
};

template <> class WrappedHandlerReturn<ErrorSuccess> {
public:
  using Type = Error;
};

// Traits class that strips the response function from the list of handler
// arguments.
template <typename FnT> class AsyncHandlerTraits;

// The responder takes Expected<ResultT>: the result type is Expected<ResultT>.
template <typename ResultT, typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Expected<ResultT>;
};

// The responder takes Error: the result type is Error (void function).
template <typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Fallback: decay the responder parameter type (e.g. strip const& from a
// 'const std::function<...> &' parameter) and retry the specializations above.
template <typename ResponseHandlerT, typename... ArgTs>
class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
    public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
                                    ArgTs...)> {};

// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types) and inherits from the appropriate
// specialization for the given non-function type's call operator.
template <typename HandlerT>
class HandlerTraits : public HandlerTraits<decltype(
                          &std::remove_reference<HandlerT>::type::operator())> {
};

// Traits for handlers with a given function type.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(ArgTs...)> {
public:
  // Function type of the handler.
  using Type = RetT(ArgTs...);

  // Return type of the handler.
  using ReturnType = RetT;

  // Call the given handler with the given arguments.
  // Each tuple element is moved into the call (see unpackAndRunHelper).
  template <typename HandlerT, typename... TArgTs>
  static typename WrappedHandlerReturn<RetT>::Type
  unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
    return unpackAndRunHelper(Handler, Args,
                              llvm::index_sequence_for<TArgTs...>());
  }

  // Call the given handler with the given arguments.
  // Responder is passed as the handler's first argument, ahead of the
  // unpacked tuple elements (see unpackAndRunAsyncHelper).
  template <typename HandlerT, typename ResponderT, typename... TArgTs>
  static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
                                 std::tuple<TArgTs...> &Args) {
    return unpackAndRunAsyncHelper(Handler, Responder, Args,
                                   llvm::index_sequence_for<TArgTs...>());
  }

  // Call the given handler with the given arguments.
  // Overload for handlers returning void: run the handler, then report
  // success.
  template <typename HandlerT>
  static typename std::enable_if<
      std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      Error>::type
  run(HandlerT &Handler, ArgTs &&... Args) {
    Handler(std::move(Args)...);
    return Error::success();
  }

  // Overload for non-void handlers: return whatever the handler returns.
  template <typename HandlerT, typename... TArgTs>
  static typename std::enable_if<
      !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      typename HandlerTraits<HandlerT>::ReturnType>::type
  run(HandlerT &Handler, TArgTs... Args) {
    return Handler(std::move(Args)...);
  }

  // Serialize arguments to the channel.
  // ArgTs (the declared signature) selects the wire types, and CArgTs are
  // the concrete argument types.
  // NOTE(review): CArgs are taken by value, so each argument is copied once
  // here. Consider const references if argument copies show up in profiles.
  template <typename ChannelT, typename... CArgTs>
  static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
  }

  // Deserialize arguments from the channel.
  // Each tuple element of Args is filled in place.
  template <typename ChannelT, typename... CArgTs>
  static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
    return deserializeArgsHelper(C, Args,
                                 llvm::index_sequence_for<CArgTs...>());
  }

private:
  // The index_sequence parameter exists only to deduce Indexes; its value is
  // unused.
  template <typename ChannelT, typename... CArgTs, size_t... Indexes>
  static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
                                     llvm::index_sequence<Indexes...> _) {
    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
        C, std::get<Indexes>(Args)...);
  }

  // Moves every tuple element into the handler call. Args is left in a
  // moved-from state afterwards.
  template <typename HandlerT, typename ArgTuple, size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
                     llvm::index_sequence<Indexes...>) {
    return run(Handler, std::move(std::get<Indexes>(Args))...);
  }


  // Same as unpackAndRunHelper, but with Responder passed first.
  template <typename HandlerT, typename ResponderT, typename ArgTuple,
            size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
                          ArgTuple &Args,
                          llvm::index_sequence<Indexes...>) {
    return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
  }
};

// Handler traits for free functions.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(*)(ArgTs...)>
    : public HandlerTraits<RetT(ArgTs...)> {};

// Handler traits for class methods (especially call operators for lambdas).
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...)>
    : public HandlerTraits<RetT(ArgTs...)> {};

// Handler traits for const class methods (especially call operators for
// lambdas).
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...) const>
    : public HandlerTraits<RetT(ArgTs...)> {};

// Utility to peel the Expected wrapper off a response handler error type.
// Specializations cover handlers returning Error or ErrorSuccess, taking
// either Expected<ArgT> (which exposes UnwrappedArgType) or Error.
template <typename HandlerT> class ResponseHandlerArg;

template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
public:
  using ArgType = Expected<ArgT>;
  using UnwrappedArgType = ArgT;
};

template <typename ArgT>
class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
public:
  using ArgType = Expected<ArgT>;
  using UnwrappedArgType = ArgT;
};

template <> class ResponseHandlerArg<Error(Error)> {
public:
  using ArgType = Error;
};

template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
public:
  using ArgType = Error;
};

// ResponseHandler represents a handler for a not-yet-received function call
// result.
template <typename ChannelT> class ResponseHandler {
public:
  // Virtual because instances are owned through
  // std::unique_ptr<ResponseHandler<ChannelT>> (see createResponseHandler).
  virtual ~ResponseHandler() {}

  // Reads the function result off the wire and acts on it. The meaning of
  // "act" will depend on how this method is implemented in any given
  // ResponseHandler subclass but could, for example, mean running a
  // user-specified handler or setting a promise value.
  virtual Error handleResponse(ChannelT &C) = 0;

  // Abandons this outstanding result.
  virtual void abandon() = 0;

  // Create an error instance representing an abandoned response.
  static Error createAbandonedResponseError() {
    return make_error<ResponseAbandoned>();
  }
};

// ResponseHandler subclass for RPC functions with non-void returns.
template <typename ChannelT, typename FuncRetT, typename HandlerT>
class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result by deserializing it from the channel then passing it
  // to the user defined handler.
  // The handler takes Expected<T>; the plain T is deserialized here and
  // converted when it is passed to the handler.
  Error handleResponse(ChannelT &C) override {
    using UnwrappedArgType = typename ResponseHandlerArg<
        typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
    UnwrappedArgType Result;
    if (auto Err =
            SerializationTraits<ChannelT, FuncRetT,
                                UnwrappedArgType>::deserialize(C, Result))
      return Err;
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(std::move(Result));
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};

// ResponseHandler subclass for RPC functions with void returns.
template <typename ChannelT, typename HandlerT>
class ResponseHandlerImpl<ChannelT, void, HandlerT>
    : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result (no actual value, just a notification that the function
  // has completed on the remote end) by calling the user-defined handler with
  // Error::success().
  Error handleResponse(ChannelT &C) override {
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(Error::success());
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};

// ResponseHandler subclass for RPC functions declared to return Expected<T>.
template <typename ChannelT, typename FuncRetT, typename HandlerT>
class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
    : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result by deserializing it from the channel then passing it
  // to the user defined handler.
  Error handleResponse(ChannelT &C) override {
    using HandlerArgType = typename ResponseHandlerArg<
        typename HandlerTraits<HandlerT>::Type>::ArgType;
    // Result starts out holding a default-constructed value and is then
    // overwritten by deserialize. The extra parentheses stop this line from
    // being parsed as a function declaration (most vexing parse).
    HandlerArgType Result((typename HandlerArgType::value_type()));

    if (auto Err =
            SerializationTraits<ChannelT, Expected<FuncRetT>,
                                HandlerArgType>::deserialize(C, Result))
      return Err;
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(std::move(Result));
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};

// ResponseHandler subclass for RPC functions declared to return Error.
template <typename ChannelT, typename HandlerT>
class ResponseHandlerImpl<ChannelT, Error, HandlerT>
    : public ResponseHandler<ChannelT> {
public:
  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}

  // Handle the result by deserializing it from the channel then passing it
  // to the user defined handler.
  // Result starts as success and is overwritten by deserialize.
  Error handleResponse(ChannelT &C) override {
    Error Result = Error::success();
    if (auto Err =
            SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
      return Err;
    if (auto Err = C.endReceiveMessage())
      return Err;
    return Handler(std::move(Result));
  }

  // Abandon this response by calling the handler with an 'abandoned response'
  // error.
  void abandon() override {
    if (auto Err = Handler(this->createAbandonedResponseError())) {
      // Handlers should not fail when passed an abandoned response error.
      report_fatal_error(std::move(Err));
    }
  }

private:
  HandlerT Handler;
};

// Create a ResponseHandler from a given user handler.
// FuncRetT selects the ResponseHandlerImpl specialization (void, Error,
// Expected<T>, or the primary template for plain T).
template <typename ChannelT, typename FuncRetT, typename HandlerT>
std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
  return llvm::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
      std::move(H));
}

// Helper for wrapping member functions up as functors. This is useful for
// installing methods as result handlers.
// Instance is held by reference, so the caller must keep it alive for as long
// as the wrapper is used.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);
  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}
  RetT operator()(ArgTs &&... Args) {
    return (Instance.*Method)(std::move(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};

// Helper that provides a Functor for deserializing arguments.
// Base case (no arguments left): succeed.
template <typename... ArgTs> class ReadArgs {
public:
  Error operator()() { return Error::success(); }
};

// Holds references to the caller's variables. operator() moves each supplied
// value into the corresponding referenced variable, one level per argument.
template <typename ArgT, typename... ArgTs>
class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
public:
  ReadArgs(ArgT &Arg, ArgTs &... Args)
      : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}

  Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
    this->Arg = std::move(ArgVal);
    return ReadArgs<ArgTs...>::operator()(ArgVals...);
  }

private:
  ArgT &Arg;
};

// Manage sequence numbers.
// Every member function takes SeqNoLock, so an instance may be shared across
// threads. Released numbers are reused last-in, first-out.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Reset, making all sequence numbers available.
  void reset() {
    std::lock_guard<std::mutex> Lock(SeqNoLock);
    NextSequenceNumber = 0;
    FreeSequenceNumbers.clear();
  }

  // Get the next available sequence number. Will re-use numbers that have
  // been released.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Lock(SeqNoLock);
    if (FreeSequenceNumbers.empty())
      return NextSequenceNumber++;
    auto SequenceNumber = FreeSequenceNumbers.back();
    FreeSequenceNumbers.pop_back();
    return SequenceNumber;
  }

  // Release a sequence number, making it available for re-use.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Lock(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};

// Checks that predicate P holds for each corresponding pair of type arguments
// from T1 and T2 tuple.
// Only equal-length tuples match a specialization. RPCArgTypeCheck guards
// arity with static_asserts.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
public:
  static const bool value =
      P<T, U>::value &&
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value;
};

// Applies P pairwise to the decayed argument types of T1Sig and T2Sig.
// Judging by the diagnostics, T1Sig is the declared RPC signature and T2Sig
// carries the arguments actually supplied to the call.
template <template <class, class> class P, typename T1Sig, typename T2Sig>
class RPCArgTypeCheck {
public:
  using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
  using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;

  static_assert(std::tuple_size<T1Tuple>::value >=
                    std::tuple_size<T2Tuple>::value,
                "Too many arguments to RPC call");
  static_assert(std::tuple_size<T1Tuple>::value <=
                    std::tuple_size<T2Tuple>::value,
                "Too few arguments to RPC call");

  static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
};

// value is true iff SerializationTraits<ChannelT, WireT, ConcreteT> has a
// static serialize(ChannelT&, const ConcreteT&) returning exactly Error
// (detected via SFINAE on check<S>).
template <typename ChannelT, typename WireT, typename ConcreteT>
class CanSerialize {
private:
  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;

  template <typename T>
  static std::true_type
  check(typename std::enable_if<
        std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
                                           std::declval<const ConcreteT &>())),
                     Error>::value,
        void *>::type);

  template <typename> static std::false_type check(...);

public:
  static const bool value = decltype(check<S>(0))::value;
};

// value is true iff SerializationTraits<ChannelT, WireT, ConcreteT> has a
// static deserialize(ChannelT&, ConcreteT&) returning exactly Error.
template <typename ChannelT, typename WireT, typename ConcreteT>
class CanDeserialize {
private:
  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;

  template <typename T>
  static std::true_type
  check(typename std::enable_if<
        std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
                                             std::declval<ConcreteT &>())),
                     Error>::value,
        void *>::type);

  template <typename> static std::false_type check(...);

public:
  static const bool value = decltype(check<S>(0))::value;
};

/// Contains primitive utilities for defining, calling and
handling calls to 981/// remote procedures. ChannelT is a bidirectional stream conforming to the 982/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure 983/// identifier type that must be serializable on ChannelT, and SequenceNumberT 984/// is an integral type that will be used to number in-flight function calls. 985/// 986/// These utilities support the construction of very primitive RPC utilities. 987/// Their intent is to ensure correct serialization and deserialization of 988/// procedure arguments, and to keep the client and server's view of the API in 989/// sync. 990template <typename ImplT, typename ChannelT, typename FunctionIdT, 991 typename SequenceNumberT> 992class RPCEndpointBase { 993protected: 994 class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> { 995 public: 996 static const char *getName() { return "__orc_rpc$invalid"; } 997 }; 998 999 class OrcRPCResponse : public Function<OrcRPCResponse, void()> { 1000 public: 1001 static const char *getName() { return "__orc_rpc$response"; } 1002 }; 1003 1004 class OrcRPCNegotiate 1005 : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> { 1006 public: 1007 static const char *getName() { return "__orc_rpc$negotiate"; } 1008 }; 1009 1010 // Helper predicate for testing for the presence of SerializeTraits 1011 // serializers. 1012 template <typename WireT, typename ConcreteT> 1013 class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { 1014 public: 1015 using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; 1016 1017 static_assert(value, "Missing serializer for argument (Can't serialize the " 1018 "first template type argument of CanSerializeCheck " 1019 "from the second)"); 1020 }; 1021 1022 // Helper predicate for testing for the presence of SerializeTraits 1023 // deserializers. 
1024 template <typename WireT, typename ConcreteT> 1025 class CanDeserializeCheck 1026 : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { 1027 public: 1028 using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; 1029 1030 static_assert(value, "Missing deserializer for argument (Can't deserialize " 1031 "the second template type argument of " 1032 "CanDeserializeCheck from the first)"); 1033 }; 1034 1035public: 1036 /// Construct an RPC instance on a channel. 1037 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) 1038 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { 1039 // Hold ResponseId in a special variable, since we expect Response to be 1040 // called relatively frequently, and want to avoid the map lookup. 1041 ResponseId = FnIdAllocator.getResponseId(); 1042 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; 1043 1044 // Register the negotiate function id and handler. 1045 auto NegotiateId = FnIdAllocator.getNegotiateId(); 1046 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; 1047 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( 1048 [this](const std::string &Name) { return handleNegotiate(Name); }); 1049 } 1050 1051 1052 /// Negotiate a function id for Func with the other end of the channel. 1053 template <typename Func> Error negotiateFunction(bool Retry = false) { 1054 return getRemoteFunctionId<Func>(true, Retry).takeError(); 1055 } 1056 1057 /// Append a call Func, does not call send on the channel. 1058 /// The first argument specifies a user-defined handler to be run when the 1059 /// function returns. The handler should take an Expected<Func::ReturnType>, 1060 /// or an Error (if Func::ReturnType is void). The handler will be called 1061 /// with an error if the return value is abandoned due to a channel error. 1062 template <typename Func, typename HandlerT, typename... ArgTs> 1063 Error appendCallAsync(HandlerT Handler, const ArgTs &... 
Args) { 1064 1065 static_assert( 1066 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, 1067 void(ArgTs...)>::value, 1068 ""); 1069 1070 // Look up the function ID. 1071 FunctionIdT FnId; 1072 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) 1073 FnId = *FnIdOrErr; 1074 else { 1075 // Negotiation failed. Notify the handler then return the negotiate-failed 1076 // error. 1077 cantFail(Handler(make_error<ResponseAbandoned>())); 1078 return FnIdOrErr.takeError(); 1079 } 1080 1081 SequenceNumberT SeqNo; // initialized in locked scope below. 1082 { 1083 // Lock the pending responses map and sequence number manager. 1084 std::lock_guard<std::mutex> Lock(ResponsesMutex); 1085 1086 // Allocate a sequence number. 1087 SeqNo = SequenceNumberMgr.getSequenceNumber(); 1088 assert(!PendingResponses.count(SeqNo) && 1089 "Sequence number already allocated"); 1090 1091 // Install the user handler. 1092 PendingResponses[SeqNo] = 1093 detail::createResponseHandler<ChannelT, typename Func::ReturnType>( 1094 std::move(Handler)); 1095 } 1096 1097 // Open the function call message. 1098 if (auto Err = C.startSendMessage(FnId, SeqNo)) { 1099 abandonPendingResponses(); 1100 return Err; 1101 } 1102 1103 // Serialize the call arguments. 1104 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( 1105 C, Args...)) { 1106 abandonPendingResponses(); 1107 return Err; 1108 } 1109 1110 // Close the function call messagee. 1111 if (auto Err = C.endSendMessage()) { 1112 abandonPendingResponses(); 1113 return Err; 1114 } 1115 1116 return Error::success(); 1117 } 1118 1119 Error sendAppendedCalls() { return C.send(); }; 1120 1121 template <typename Func, typename HandlerT, typename... ArgTs> 1122 Error callAsync(HandlerT Handler, const ArgTs &... Args) { 1123 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) 1124 return Err; 1125 return C.send(); 1126 } 1127 1128 /// Handle one incoming call. 
1129 Error handleOne() { 1130 FunctionIdT FnId; 1131 SequenceNumberT SeqNo; 1132 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { 1133 abandonPendingResponses(); 1134 return Err; 1135 } 1136 if (FnId == ResponseId) 1137 return handleResponse(SeqNo); 1138 auto I = Handlers.find(FnId); 1139 if (I != Handlers.end()) 1140 return I->second(C, SeqNo); 1141 1142 // else: No handler found. Report error to client? 1143 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, 1144 SeqNo); 1145 } 1146 1147 /// Helper for handling setter procedures - this method returns a functor that 1148 /// sets the variables referred to by Args... to values deserialized from the 1149 /// channel. 1150 /// E.g. 1151 /// 1152 /// typedef Function<0, bool, int> Func1; 1153 /// 1154 /// ... 1155 /// bool B; 1156 /// int I; 1157 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) 1158 /// /* Handle Args */ ; 1159 /// 1160 template <typename... ArgTs> 1161 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { 1162 return detail::ReadArgs<ArgTs...>(Args...); 1163 } 1164 1165 /// Abandon all outstanding result handlers. 1166 /// 1167 /// This will call all currently registered result handlers to receive an 1168 /// "abandoned" error as their argument. This is used internally by the RPC 1169 /// in error situations, but can also be called directly by clients who are 1170 /// disconnecting from the remote and don't or can't expect responses to their 1171 /// outstanding calls. (Especially for outstanding blocking calls, calling 1172 /// this function may be necessary to avoid dead threads). 1173 void abandonPendingResponses() { 1174 // Lock the pending responses map and sequence number manager. 1175 std::lock_guard<std::mutex> Lock(ResponsesMutex); 1176 1177 for (auto &KV : PendingResponses) 1178 KV.second->abandon(); 1179 PendingResponses.clear(); 1180 SequenceNumberMgr.reset(); 1181 } 1182 1183 /// Remove the handler for the given function. 
1184 /// A handler must currently be registered for this function. 1185 template <typename Func> 1186 void removeHandler() { 1187 auto IdItr = LocalFunctionIds.find(Func::getPrototype()); 1188 assert(IdItr != LocalFunctionIds.end() && 1189 "Function does not have a registered handler"); 1190 auto HandlerItr = Handlers.find(IdItr->second); 1191 assert(HandlerItr != Handlers.end() && 1192 "Function does not have a registered handler"); 1193 Handlers.erase(HandlerItr); 1194 } 1195 1196 /// Clear all handlers. 1197 void clearHandlers() { 1198 Handlers.clear(); 1199 } 1200 1201protected: 1202 1203 FunctionIdT getInvalidFunctionId() const { 1204 return FnIdAllocator.getInvalidId(); 1205 } 1206 1207 /// Add the given handler to the handler map and make it available for 1208 /// autonegotiation and execution. 1209 template <typename Func, typename HandlerT> 1210 void addHandlerImpl(HandlerT Handler) { 1211 1212 static_assert(detail::RPCArgTypeCheck< 1213 CanDeserializeCheck, typename Func::Type, 1214 typename detail::HandlerTraits<HandlerT>::Type>::value, 1215 ""); 1216 1217 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); 1218 LocalFunctionIds[Func::getPrototype()] = NewFnId; 1219 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); 1220 } 1221 1222 template <typename Func, typename HandlerT> 1223 void addAsyncHandlerImpl(HandlerT Handler) { 1224 1225 static_assert(detail::RPCArgTypeCheck< 1226 CanDeserializeCheck, typename Func::Type, 1227 typename detail::AsyncHandlerTraits< 1228 typename detail::HandlerTraits<HandlerT>::Type 1229 >::Type>::value, 1230 ""); 1231 1232 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); 1233 LocalFunctionIds[Func::getPrototype()] = NewFnId; 1234 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); 1235 } 1236 1237 Error handleResponse(SequenceNumberT SeqNo) { 1238 using Handler = typename decltype(PendingResponses)::mapped_type; 1239 Handler PRHandler; 1240 1241 { 1242 // Lock the pending 
responses map and sequence number manager. 1243 std::unique_lock<std::mutex> Lock(ResponsesMutex); 1244 auto I = PendingResponses.find(SeqNo); 1245 1246 if (I != PendingResponses.end()) { 1247 PRHandler = std::move(I->second); 1248 PendingResponses.erase(I); 1249 SequenceNumberMgr.releaseSequenceNumber(SeqNo); 1250 } else { 1251 // Unlock the pending results map to prevent recursive lock. 1252 Lock.unlock(); 1253 abandonPendingResponses(); 1254 return make_error< 1255 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo); 1256 } 1257 } 1258 1259 assert(PRHandler && 1260 "If we didn't find a response handler we should have bailed out"); 1261 1262 if (auto Err = PRHandler->handleResponse(C)) { 1263 abandonPendingResponses(); 1264 return Err; 1265 } 1266 1267 return Error::success(); 1268 } 1269 1270 FunctionIdT handleNegotiate(const std::string &Name) { 1271 auto I = LocalFunctionIds.find(Name); 1272 if (I == LocalFunctionIds.end()) 1273 return getInvalidFunctionId(); 1274 return I->second; 1275 } 1276 1277 // Find the remote FunctionId for the given function. 1278 template <typename Func> 1279 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, 1280 bool NegotiateIfInvalid) { 1281 bool DoNegotiate; 1282 1283 // Check if we already have a function id... 1284 auto I = RemoteFunctionIds.find(Func::getPrototype()); 1285 if (I != RemoteFunctionIds.end()) { 1286 // If it's valid there's nothing left to do. 1287 if (I->second != getInvalidFunctionId()) 1288 return I->second; 1289 DoNegotiate = NegotiateIfInvalid; 1290 } else 1291 DoNegotiate = NegotiateIfNotInMap; 1292 1293 // We don't have a function id for Func yet, but we're allowed to try to 1294 // negotiate one. 
1295 if (DoNegotiate) { 1296 auto &Impl = static_cast<ImplT &>(*this); 1297 if (auto RemoteIdOrErr = 1298 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { 1299 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; 1300 if (*RemoteIdOrErr == getInvalidFunctionId()) 1301 return make_error<CouldNotNegotiate>(Func::getPrototype()); 1302 return *RemoteIdOrErr; 1303 } else 1304 return RemoteIdOrErr.takeError(); 1305 } 1306 1307 // No key was available in the map and we weren't allowed to try to 1308 // negotiate one, so return an unknown function error. 1309 return make_error<CouldNotNegotiate>(Func::getPrototype()); 1310 } 1311 1312 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; 1313 1314 // Wrap the given user handler in the necessary argument-deserialization code, 1315 // result-serialization code, and call to the launch policy (if present). 1316 template <typename Func, typename HandlerT> 1317 WrappedHandlerFn wrapHandler(HandlerT Handler) { 1318 return [this, Handler](ChannelT &Channel, 1319 SequenceNumberT SeqNo) mutable -> Error { 1320 // Start by deserializing the arguments. 1321 using ArgsTuple = 1322 typename detail::FunctionArgsTuple< 1323 typename detail::HandlerTraits<HandlerT>::Type>::Type; 1324 auto Args = std::make_shared<ArgsTuple>(); 1325 1326 if (auto Err = 1327 detail::HandlerTraits<typename Func::Type>::deserializeArgs( 1328 Channel, *Args)) 1329 return Err; 1330 1331 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning 1332 // for RPCArgs. Void cast RPCArgs to work around this for now. 1333 // FIXME: Remove this workaround once we can assume a working GCC version. 1334 (void)Args; 1335 1336 // End receieve message, unlocking the channel for reading. 
1337 if (auto Err = Channel.endReceiveMessage()) 1338 return Err; 1339 1340 using HTraits = detail::HandlerTraits<HandlerT>; 1341 using FuncReturn = typename Func::ReturnType; 1342 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, 1343 HTraits::unpackAndRun(Handler, *Args)); 1344 }; 1345 } 1346 1347 // Wrap the given user handler in the necessary argument-deserialization code, 1348 // result-serialization code, and call to the launch policy (if present). 1349 template <typename Func, typename HandlerT> 1350 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { 1351 return [this, Handler](ChannelT &Channel, 1352 SequenceNumberT SeqNo) mutable -> Error { 1353 // Start by deserializing the arguments. 1354 using AHTraits = detail::AsyncHandlerTraits< 1355 typename detail::HandlerTraits<HandlerT>::Type>; 1356 using ArgsTuple = 1357 typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type; 1358 auto Args = std::make_shared<ArgsTuple>(); 1359 1360 if (auto Err = 1361 detail::HandlerTraits<typename Func::Type>::deserializeArgs( 1362 Channel, *Args)) 1363 return Err; 1364 1365 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning 1366 // for RPCArgs. Void cast RPCArgs to work around this for now. 1367 // FIXME: Remove this workaround once we can assume a working GCC version. 1368 (void)Args; 1369 1370 // End receieve message, unlocking the channel for reading. 
1371 if (auto Err = Channel.endReceiveMessage()) 1372 return Err; 1373 1374 using HTraits = detail::HandlerTraits<HandlerT>; 1375 using FuncReturn = typename Func::ReturnType; 1376 auto Responder = 1377 [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error { 1378 return detail::respond<FuncReturn>(C, ResponseId, SeqNo, 1379 std::move(RetVal)); 1380 }; 1381 1382 return HTraits::unpackAndRunAsync(Handler, Responder, *Args); 1383 }; 1384 } 1385 1386 ChannelT &C; 1387 1388 bool LazyAutoNegotiation; 1389 1390 RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; 1391 1392 FunctionIdT ResponseId; 1393 std::map<std::string, FunctionIdT> LocalFunctionIds; 1394 std::map<const char *, FunctionIdT> RemoteFunctionIds; 1395 1396 std::map<FunctionIdT, WrappedHandlerFn> Handlers; 1397 1398 std::mutex ResponsesMutex; 1399 detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; 1400 std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> 1401 PendingResponses; 1402}; 1403 1404} // end namespace detail 1405 1406template <typename ChannelT, typename FunctionIdT = uint32_t, 1407 typename SequenceNumberT = uint32_t> 1408class MultiThreadedRPCEndpoint 1409 : public detail::RPCEndpointBase< 1410 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1411 ChannelT, FunctionIdT, SequenceNumberT> { 1412private: 1413 using BaseClass = 1414 detail::RPCEndpointBase< 1415 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1416 ChannelT, FunctionIdT, SequenceNumberT>; 1417 1418public: 1419 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) 1420 : BaseClass(C, LazyAutoNegotiation) {} 1421 1422 /// Add a handler for the given RPC function. 1423 /// This installs the given handler functor for the given RPC Function, and 1424 /// makes the RPC function available for negotiation/calling from the remote. 
1425 template <typename Func, typename HandlerT> 1426 void addHandler(HandlerT Handler) { 1427 return this->template addHandlerImpl<Func>(std::move(Handler)); 1428 } 1429 1430 /// Add a class-method as a handler. 1431 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1432 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1433 addHandler<Func>( 1434 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1435 } 1436 1437 template <typename Func, typename HandlerT> 1438 void addAsyncHandler(HandlerT Handler) { 1439 return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); 1440 } 1441 1442 /// Add a class-method as a handler. 1443 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1444 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1445 addAsyncHandler<Func>( 1446 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1447 } 1448 1449 /// Return type for non-blocking call primitives. 1450 template <typename Func> 1451 using NonBlockingCallResult = typename detail::ResultTraits< 1452 typename Func::ReturnType>::ReturnFutureType; 1453 1454 /// Call Func on Channel C. Does not block, does not call send. Returns a pair 1455 /// of a future result and the sequence number assigned to the result. 1456 /// 1457 /// This utility function is primarily used for single-threaded mode support, 1458 /// where the sequence number can be used to wait for the corresponding 1459 /// result. In multi-threaded mode the appendCallNB method, which does not 1460 /// return the sequence numeber, should be preferred. 1461 template <typename Func, typename... ArgTs> 1462 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... 
Args) { 1463 using RTraits = detail::ResultTraits<typename Func::ReturnType>; 1464 using ErrorReturn = typename RTraits::ErrorReturnType; 1465 using ErrorReturnPromise = typename RTraits::ReturnPromiseType; 1466 1467 // FIXME: Stack allocate and move this into the handler once LLVM builds 1468 // with C++14. 1469 auto Promise = std::make_shared<ErrorReturnPromise>(); 1470 auto FutureResult = Promise->get_future(); 1471 1472 if (auto Err = this->template appendCallAsync<Func>( 1473 [Promise](ErrorReturn RetOrErr) { 1474 Promise->set_value(std::move(RetOrErr)); 1475 return Error::success(); 1476 }, 1477 Args...)) { 1478 RTraits::consumeAbandoned(FutureResult.get()); 1479 return std::move(Err); 1480 } 1481 return std::move(FutureResult); 1482 } 1483 1484 /// The same as appendCallNBWithSeq, except that it calls C.send() to 1485 /// flush the channel after serializing the call. 1486 template <typename Func, typename... ArgTs> 1487 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) { 1488 auto Result = appendCallNB<Func>(Args...); 1489 if (!Result) 1490 return Result; 1491 if (auto Err = this->C.send()) { 1492 this->abandonPendingResponses(); 1493 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1494 std::move(Result->get())); 1495 return std::move(Err); 1496 } 1497 return Result; 1498 } 1499 1500 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error 1501 /// for void functions or an Expected<T> for functions returning a T. 1502 /// 1503 /// This function is for use in threaded code where another thread is 1504 /// handling responses and incoming calls. 1505 template <typename Func, typename... ArgTs, 1506 typename AltRetT = typename Func::ReturnType> 1507 typename detail::ResultTraits<AltRetT>::ErrorReturnType 1508 callB(const ArgTs &... 
Args) { 1509 if (auto FutureResOrErr = callNB<Func>(Args...)) 1510 return FutureResOrErr->get(); 1511 else 1512 return FutureResOrErr.takeError(); 1513 } 1514 1515 /// Handle incoming RPC calls. 1516 Error handlerLoop() { 1517 while (true) 1518 if (auto Err = this->handleOne()) 1519 return Err; 1520 return Error::success(); 1521 } 1522}; 1523 1524template <typename ChannelT, typename FunctionIdT = uint32_t, 1525 typename SequenceNumberT = uint32_t> 1526class SingleThreadedRPCEndpoint 1527 : public detail::RPCEndpointBase< 1528 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1529 ChannelT, FunctionIdT, SequenceNumberT> { 1530private: 1531 using BaseClass = 1532 detail::RPCEndpointBase< 1533 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1534 ChannelT, FunctionIdT, SequenceNumberT>; 1535 1536public: 1537 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) 1538 : BaseClass(C, LazyAutoNegotiation) {} 1539 1540 template <typename Func, typename HandlerT> 1541 void addHandler(HandlerT Handler) { 1542 return this->template addHandlerImpl<Func>(std::move(Handler)); 1543 } 1544 1545 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1546 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1547 addHandler<Func>( 1548 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1549 } 1550 1551 template <typename Func, typename HandlerT> 1552 void addAsyncHandler(HandlerT Handler) { 1553 return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); 1554 } 1555 1556 /// Add a class-method as a handler. 1557 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1558 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1559 addAsyncHandler<Func>( 1560 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1561 } 1562 1563 template <typename Func, typename... 
ArgTs, 1564 typename AltRetT = typename Func::ReturnType> 1565 typename detail::ResultTraits<AltRetT>::ErrorReturnType 1566 callB(const ArgTs &... Args) { 1567 bool ReceivedResponse = false; 1568 using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType; 1569 auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue(); 1570 1571 // We have to 'Check' result (which we know is in a success state at this 1572 // point) so that it can be overwritten in the async handler. 1573 (void)!!Result; 1574 1575 if (auto Err = this->template appendCallAsync<Func>( 1576 [&](ResultType R) { 1577 Result = std::move(R); 1578 ReceivedResponse = true; 1579 return Error::success(); 1580 }, 1581 Args...)) { 1582 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1583 std::move(Result)); 1584 return std::move(Err); 1585 } 1586 1587 while (!ReceivedResponse) { 1588 if (auto Err = this->handleOne()) { 1589 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1590 std::move(Result)); 1591 return std::move(Err); 1592 } 1593 } 1594 1595 return Result; 1596 } 1597}; 1598 1599/// Asynchronous dispatch for a function on an RPC endpoint. 1600template <typename RPCClass, typename Func> 1601class RPCAsyncDispatch { 1602public: 1603 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} 1604 1605 template <typename HandlerT, typename... ArgTs> 1606 Error operator()(HandlerT Handler, const ArgTs &... Args) const { 1607 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); 1608 } 1609 1610private: 1611 RPCClass &Endpoint; 1612}; 1613 1614/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. 
1615template <typename Func, typename RPCEndpointT> 1616RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { 1617 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); 1618} 1619 1620/// \brief Allows a set of asynchrounous calls to be dispatched, and then 1621/// waited on as a group. 1622class ParallelCallGroup { 1623public: 1624 1625 ParallelCallGroup() = default; 1626 ParallelCallGroup(const ParallelCallGroup &) = delete; 1627 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; 1628 1629 /// \brief Make as asynchronous call. 1630 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> 1631 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, 1632 const ArgTs &... Args) { 1633 // Increment the count of outstanding calls. This has to happen before 1634 // we invoke the call, as the handler may (depending on scheduling) 1635 // be run immediately on another thread, and we don't want the decrement 1636 // in the wrapped handler below to run before the increment. 1637 { 1638 std::unique_lock<std::mutex> Lock(M); 1639 ++NumOutstandingCalls; 1640 } 1641 1642 // Wrap the user handler in a lambda that will decrement the 1643 // outstanding calls count, then poke the condition variable. 1644 using ArgType = typename detail::ResponseHandlerArg< 1645 typename detail::HandlerTraits<HandlerT>::Type>::ArgType; 1646 // FIXME: Move handler into wrapped handler once we have C++14. 1647 auto WrappedHandler = [this, Handler](ArgType Arg) { 1648 auto Err = Handler(std::move(Arg)); 1649 std::unique_lock<std::mutex> Lock(M); 1650 --NumOutstandingCalls; 1651 CV.notify_all(); 1652 return Err; 1653 }; 1654 1655 return AsyncDispatch(std::move(WrappedHandler), Args...); 1656 } 1657 1658 /// \brief Blocks until all calls have been completed and their return value 1659 /// handlers run. 
1660 void wait() { 1661 std::unique_lock<std::mutex> Lock(M); 1662 while (NumOutstandingCalls > 0) 1663 CV.wait(Lock); 1664 } 1665 1666private: 1667 std::mutex M; 1668 std::condition_variable CV; 1669 uint32_t NumOutstandingCalls = 0; 1670}; 1671 1672/// @brief Convenience class for grouping RPC Functions into APIs that can be 1673/// negotiated as a block. 1674/// 1675template <typename... Funcs> 1676class APICalls { 1677public: 1678 1679 /// @brief Test whether this API contains Function F. 1680 template <typename F> 1681 class Contains { 1682 public: 1683 static const bool value = false; 1684 }; 1685 1686 /// @brief Negotiate all functions in this API. 1687 template <typename RPCEndpoint> 1688 static Error negotiate(RPCEndpoint &R) { 1689 return Error::success(); 1690 } 1691}; 1692 1693template <typename Func, typename... Funcs> 1694class APICalls<Func, Funcs...> { 1695public: 1696 1697 template <typename F> 1698 class Contains { 1699 public: 1700 static const bool value = std::is_same<F, Func>::value | 1701 APICalls<Funcs...>::template Contains<F>::value; 1702 }; 1703 1704 template <typename RPCEndpoint> 1705 static Error negotiate(RPCEndpoint &R) { 1706 if (auto Err = R.template negotiateFunction<Func>()) 1707 return Err; 1708 return APICalls<Funcs...>::negotiate(R); 1709 } 1710 1711}; 1712 1713template <typename... InnerFuncs, typename... Funcs> 1714class APICalls<APICalls<InnerFuncs...>, Funcs...> { 1715public: 1716 1717 template <typename F> 1718 class Contains { 1719 public: 1720 static const bool value = 1721 APICalls<InnerFuncs...>::template Contains<F>::value | 1722 APICalls<Funcs...>::template Contains<F>::value; 1723 }; 1724 1725 template <typename RPCEndpoint> 1726 static Error negotiate(RPCEndpoint &R) { 1727 if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) 1728 return Err; 1729 return APICalls<Funcs...>::negotiate(R); 1730 } 1731 1732}; 1733 1734} // end namespace rpc 1735} // end namespace orc 1736} // end namespace llvm 1737 1738#endif 1739