1//===- llvm/ExecutionEngine/Orc/RPCSerialization.h --------------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
11#define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
12
13#include "OrcError.h"
14#include "llvm/Support/thread.h"
15#include <map>
16#include <mutex>
17#include <sstream>
18
19namespace llvm {
20namespace orc {
21namespace rpc {
22
/// RPCTypeName<T>::getName() yields a printable name for T, used when
/// rendering RPC signatures. The primary template is deliberately left
/// undefined: only explicitly specialized types can be named.
template <typename T> class RPCTypeName;

/// RPCTypeNameSequence<Ts...> is an empty tag type. Streaming an instance of
/// it renders the names of Ts..., separated by ", ".
template <typename... ArgTs> class RPCTypeNameSequence {};

/// Streaming an empty sequence renders nothing.
template <typename OStream>
OStream &operator<<(OStream &OS, const RPCTypeNameSequence<> &) {
  return OS;
}

/// Streaming a one-element sequence renders only that element's name.
template <typename OStream, typename ArgT>
OStream &operator<<(OStream &OS, const RPCTypeNameSequence<ArgT> &) {
  const char *Rendered = RPCTypeName<ArgT>::getName();
  OS << Rendered;
  return OS;
}

/// Streaming a sequence of two or more elements renders the head's name and
/// a ", " separator, then recurses on the remaining elements.
template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs>
OStream &operator<<(OStream &OS,
                    const RPCTypeNameSequence<ArgT1, ArgT2, ArgTs...> &) {
  OS << RPCTypeName<ArgT1>::getName();
  OS << ", ";
  OS << RPCTypeNameSequence<ArgT2, ArgTs...>();
  return OS;
}

// Names for the built-in / primitive types. Each getName() returns a string
// literal, so the pointer is valid for the whole program.

template <> class RPCTypeName<void> {
public:
  static const char *getName() { return "void"; }
};

template <> class RPCTypeName<int8_t> {
public:
  static const char *getName() { return "int8_t"; }
};

template <> class RPCTypeName<uint8_t> {
public:
  static const char *getName() { return "uint8_t"; }
};

template <> class RPCTypeName<int16_t> {
public:
  static const char *getName() { return "int16_t"; }
};

template <> class RPCTypeName<uint16_t> {
public:
  static const char *getName() { return "uint16_t"; }
};

template <> class RPCTypeName<int32_t> {
public:
  static const char *getName() { return "int32_t"; }
};

template <> class RPCTypeName<uint32_t> {
public:
  static const char *getName() { return "uint32_t"; }
};

template <> class RPCTypeName<int64_t> {
public:
  static const char *getName() { return "int64_t"; }
};

template <> class RPCTypeName<uint64_t> {
public:
  static const char *getName() { return "uint64_t"; }
};

template <> class RPCTypeName<bool> {
public:
  static const char *getName() { return "bool"; }
};

template <> class RPCTypeName<std::string> {
public:
  static const char *getName() { return "std::string"; }
};
117
118template <>
119class RPCTypeName<Error> {
120public:
121  static const char* getName() { return "Error"; }
122};
123
124template <typename T>
125class RPCTypeName<Expected<T>> {
126public:
127  static const char* getName() {
128    std::lock_guard<std::mutex> Lock(NameMutex);
129    if (Name.empty())
130      raw_string_ostream(Name) << "Expected<"
131                               << RPCTypeNameSequence<T>()
132                               << ">";
133    return Name.data();
134  }
135
136private:
137  static std::mutex NameMutex;
138  static std::string Name;
139};
140
141template <typename T>
142std::mutex RPCTypeName<Expected<T>>::NameMutex;
143
144template <typename T>
145std::string RPCTypeName<Expected<T>>::Name;
146
147template <typename T1, typename T2>
148class RPCTypeName<std::pair<T1, T2>> {
149public:
150  static const char* getName() {
151    std::lock_guard<std::mutex> Lock(NameMutex);
152    if (Name.empty())
153      raw_string_ostream(Name) << "std::pair<" << RPCTypeNameSequence<T1, T2>()
154                               << ">";
155    return Name.data();
156  }
157private:
158  static std::mutex NameMutex;
159  static std::string Name;
160};
161
162template <typename T1, typename T2>
163std::mutex RPCTypeName<std::pair<T1, T2>>::NameMutex;
164template <typename T1, typename T2>
165std::string RPCTypeName<std::pair<T1, T2>>::Name;
166
167template <typename... ArgTs>
168class RPCTypeName<std::tuple<ArgTs...>> {
169public:
170  static const char* getName() {
171    std::lock_guard<std::mutex> Lock(NameMutex);
172    if (Name.empty())
173      raw_string_ostream(Name) << "std::tuple<"
174                               << RPCTypeNameSequence<ArgTs...>() << ">";
175    return Name.data();
176  }
177private:
178  static std::mutex NameMutex;
179  static std::string Name;
180};
181
182template <typename... ArgTs>
183std::mutex RPCTypeName<std::tuple<ArgTs...>>::NameMutex;
184template <typename... ArgTs>
185std::string RPCTypeName<std::tuple<ArgTs...>>::Name;
186
187template <typename T>
188class RPCTypeName<std::vector<T>> {
189public:
190  static const char*getName() {
191    std::lock_guard<std::mutex> Lock(NameMutex);
192    if (Name.empty())
193      raw_string_ostream(Name) << "std::vector<" << RPCTypeName<T>::getName()
194                               << ">";
195    return Name.data();
196  }
197
198private:
199  static std::mutex NameMutex;
200  static std::string Name;
201};
202
203template <typename T>
204std::mutex RPCTypeName<std::vector<T>>::NameMutex;
205template <typename T>
206std::string RPCTypeName<std::vector<T>>::Name;
207
208
/// The SerializationTraits<ChannelT, WireType, ConcreteType> class describes
/// how to serialize a value of type ConcreteType to an abstract channel of
/// type ChannelT so that it is transmitted as a WireType. It also describes
/// how to deserialize a WireType from such a channel into a ConcreteType.
/// ConcreteType defaults to WireType. The printable names of wire types used
/// in RPC signatures are provided separately, by RPCTypeName, not by this
/// class.
///
/// Specializations of this class should provide the following functions:
///
///   @code{.cpp}
///
///   static Error serialize(ChannelT&, const ConcreteType&);
///   static Error deserialize(ChannelT&, ConcreteType&);
///
///   @endcode
///
/// For move-only concrete types, serialize may instead take its argument by
/// rvalue reference (as the Error and Expected specializations do).
///
/// The fourth template argument is intended to support SFINAE. E.g.:
///
///   @code{.cpp}
///
///   class MyVirtualChannel { ... };
///
///   template <typename DerivedChannelT>
///   class SerializationTraits<DerivedChannelT, bool, bool,
///         typename std::enable_if<
///           std::is_base_of<MyVirtualChannel, DerivedChannelT>::value
///         >::type> {
///   public:
///     static Error serialize(DerivedChannelT &C, const bool &V) { ... }
///     static Error deserialize(DerivedChannelT &C, bool &V) { ... }
///   };
///
///   @endcode
template <typename ChannelT, typename WireType,
          typename ConcreteType = WireType, typename = void>
class SerializationTraits;
244
245template <typename ChannelT>
246class SequenceTraits {
247public:
248  static Error emitSeparator(ChannelT &C) { return Error::success(); }
249  static Error consumeSeparator(ChannelT &C) { return Error::success(); }
250};
251
252/// Utility class for serializing sequences of values of varying types.
253/// Specializations of this class contain 'serialize' and 'deserialize' methods
254/// for the given channel. The ArgTs... list will determine the "over-the-wire"
255/// types to be serialized. The serialize and deserialize methods take a list
256/// CArgTs... ("caller arg types") which must be the same length as ArgTs...,
257/// but may be different types from ArgTs, provided that for each CArgT there
258/// is a SerializationTraits specialization
259/// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the
260/// caller argument to over-the-wire value.
261template <typename ChannelT, typename... ArgTs>
262class SequenceSerialization;
263
264template <typename ChannelT>
265class SequenceSerialization<ChannelT> {
266public:
267  static Error serialize(ChannelT &C) { return Error::success(); }
268  static Error deserialize(ChannelT &C) { return Error::success(); }
269};
270
271template <typename ChannelT, typename ArgT>
272class SequenceSerialization<ChannelT, ArgT> {
273public:
274
275  template <typename CArgT>
276  static Error serialize(ChannelT &C, CArgT &&CArg) {
277    return SerializationTraits<ChannelT, ArgT,
278                               typename std::decay<CArgT>::type>::
279             serialize(C, std::forward<CArgT>(CArg));
280  }
281
282  template <typename CArgT>
283  static Error deserialize(ChannelT &C, CArgT &CArg) {
284    return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg);
285  }
286};
287
288template <typename ChannelT, typename ArgT, typename... ArgTs>
289class SequenceSerialization<ChannelT, ArgT, ArgTs...> {
290public:
291
292  template <typename CArgT, typename... CArgTs>
293  static Error serialize(ChannelT &C, CArgT &&CArg,
294                         CArgTs &&... CArgs) {
295    if (auto Err =
296        SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>::
297          serialize(C, std::forward<CArgT>(CArg)))
298      return Err;
299    if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C))
300      return Err;
301    return SequenceSerialization<ChannelT, ArgTs...>::
302             serialize(C, std::forward<CArgTs>(CArgs)...);
303  }
304
305  template <typename CArgT, typename... CArgTs>
306  static Error deserialize(ChannelT &C, CArgT &CArg,
307                           CArgTs &... CArgs) {
308    if (auto Err =
309        SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg))
310      return Err;
311    if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C))
312      return Err;
313    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...);
314  }
315};
316
317template <typename ChannelT, typename... ArgTs>
318Error serializeSeq(ChannelT &C, ArgTs &&... Args) {
319  return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>::
320           serialize(C, std::forward<ArgTs>(Args)...);
321}
322
323template <typename ChannelT, typename... ArgTs>
324Error deserializeSeq(ChannelT &C, ArgTs &... Args) {
325  return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...);
326}
327
328template <typename ChannelT>
329class SerializationTraits<ChannelT, Error> {
330public:
331
332  using WrappedErrorSerializer =
333    std::function<Error(ChannelT &C, const ErrorInfoBase&)>;
334
335  using WrappedErrorDeserializer =
336    std::function<Error(ChannelT &C, Error &Err)>;
337
338  template <typename ErrorInfoT, typename SerializeFtor,
339            typename DeserializeFtor>
340  static void registerErrorType(std::string Name, SerializeFtor Serialize,
341                                DeserializeFtor Deserialize) {
342    assert(!Name.empty() &&
343           "The empty string is reserved for the Success value");
344
345    const std::string *KeyName = nullptr;
346    {
347      // We're abusing the stability of std::map here: We take a reference to the
348      // key of the deserializers map to save us from duplicating the string in
349      // the serializer. This should be changed to use a stringpool if we switch
350      // to a map type that may move keys in memory.
351      std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex);
352      auto I =
353        Deserializers.insert(Deserializers.begin(),
354                             std::make_pair(std::move(Name),
355                                            std::move(Deserialize)));
356      KeyName = &I->first;
357    }
358
359    {
360      assert(KeyName != nullptr && "No keyname pointer");
361      std::lock_guard<std::recursive_mutex> Lock(SerializersMutex);
362      // FIXME: Move capture Serialize once we have C++14.
363      Serializers[ErrorInfoT::classID()] =
364          [KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error {
365        assert(EIB.dynamicClassID() == ErrorInfoT::classID() &&
366               "Serializer called for wrong error type");
367        if (auto Err = serializeSeq(C, *KeyName))
368          return Err;
369        return Serialize(C, static_cast<const ErrorInfoT &>(EIB));
370      };
371    }
372  }
373
374  static Error serialize(ChannelT &C, Error &&Err) {
375    std::lock_guard<std::recursive_mutex> Lock(SerializersMutex);
376
377    if (!Err)
378      return serializeSeq(C, std::string());
379
380    return handleErrors(std::move(Err),
381                        [&C](const ErrorInfoBase &EIB) {
382                          auto SI = Serializers.find(EIB.dynamicClassID());
383                          if (SI == Serializers.end())
384                            return serializeAsStringError(C, EIB);
385                          return (SI->second)(C, EIB);
386                        });
387  }
388
389  static Error deserialize(ChannelT &C, Error &Err) {
390    std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex);
391
392    std::string Key;
393    if (auto Err = deserializeSeq(C, Key))
394      return Err;
395
396    if (Key.empty()) {
397      ErrorAsOutParameter EAO(&Err);
398      Err = Error::success();
399      return Error::success();
400    }
401
402    auto DI = Deserializers.find(Key);
403    assert(DI != Deserializers.end() && "No deserializer for error type");
404    return (DI->second)(C, Err);
405  }
406
407private:
408
409  static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) {
410    std::string ErrMsg;
411    {
412      raw_string_ostream ErrMsgStream(ErrMsg);
413      EIB.log(ErrMsgStream);
414    }
415    return serialize(C, make_error<StringError>(std::move(ErrMsg),
416                                                inconvertibleErrorCode()));
417  }
418
419  static std::recursive_mutex SerializersMutex;
420  static std::recursive_mutex DeserializersMutex;
421  static std::map<const void*, WrappedErrorSerializer> Serializers;
422  static std::map<std::string, WrappedErrorDeserializer> Deserializers;
423};
424
425template <typename ChannelT>
426std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex;
427
428template <typename ChannelT>
429std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex;
430
431template <typename ChannelT>
432std::map<const void*,
433         typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer>
434SerializationTraits<ChannelT, Error>::Serializers;
435
436template <typename ChannelT>
437std::map<std::string,
438         typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer>
439SerializationTraits<ChannelT, Error>::Deserializers;
440
441/// Registers a serializer and deserializer for the given error type on the
442/// given channel type.
443template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor,
444          typename DeserializeFtor>
445void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize,
446                                DeserializeFtor &&Deserialize) {
447  SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>(
448    std::move(Name),
449    std::forward<SerializeFtor>(Serialize),
450    std::forward<DeserializeFtor>(Deserialize));
451}
452
453/// Registers serialization/deserialization for StringError.
454template <typename ChannelT>
455void registerStringError() {
456  static bool AlreadyRegistered = false;
457  if (!AlreadyRegistered) {
458    registerErrorSerialization<ChannelT, StringError>(
459      "StringError",
460      [](ChannelT &C, const StringError &SE) {
461        return serializeSeq(C, SE.getMessage());
462      },
463      [](ChannelT &C, Error &Err) -> Error {
464        ErrorAsOutParameter EAO(&Err);
465        std::string Msg;
466        if (auto E2 = deserializeSeq(C, Msg))
467          return E2;
468        Err =
469          make_error<StringError>(std::move(Msg),
470                                  orcError(
471                                    OrcErrorCode::UnknownErrorCodeFromRemote));
472        return Error::success();
473      });
474    AlreadyRegistered = true;
475  }
476}
477
478/// SerializationTraits for Expected<T1> from an Expected<T2>.
479template <typename ChannelT, typename T1, typename T2>
480class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> {
481public:
482
483  static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) {
484    if (ValOrErr) {
485      if (auto Err = serializeSeq(C, true))
486        return Err;
487      return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr);
488    }
489    if (auto Err = serializeSeq(C, false))
490      return Err;
491    return serializeSeq(C, ValOrErr.takeError());
492  }
493
494  static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) {
495    ExpectedAsOutParameter<T2> EAO(&ValOrErr);
496    bool HasValue;
497    if (auto Err = deserializeSeq(C, HasValue))
498      return Err;
499    if (HasValue)
500      return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr);
501    Error Err = Error::success();
502    if (auto E2 = deserializeSeq(C, Err))
503      return E2;
504    ValOrErr = std::move(Err);
505    return Error::success();
506  }
507};
508
509/// SerializationTraits for Expected<T1> from a T2.
510template <typename ChannelT, typename T1, typename T2>
511class SerializationTraits<ChannelT, Expected<T1>, T2> {
512public:
513
514  static Error serialize(ChannelT &C, T2 &&Val) {
515    return serializeSeq(C, Expected<T2>(std::forward<T2>(Val)));
516  }
517};
518
519/// SerializationTraits for Expected<T1> from an Error.
520template <typename ChannelT, typename T>
521class SerializationTraits<ChannelT, Expected<T>, Error> {
522public:
523
524  static Error serialize(ChannelT &C, Error &&Err) {
525    return serializeSeq(C, Expected<T>(std::move(Err)));
526  }
527};
528
529/// SerializationTraits default specialization for std::pair.
530template <typename ChannelT, typename T1, typename T2>
531class SerializationTraits<ChannelT, std::pair<T1, T2>> {
532public:
533  static Error serialize(ChannelT &C, const std::pair<T1, T2> &V) {
534    return serializeSeq(C, V.first, V.second);
535  }
536
537  static Error deserialize(ChannelT &C, std::pair<T1, T2> &V) {
538    return deserializeSeq(C, V.first, V.second);
539  }
540};
541
542/// SerializationTraits default specialization for std::tuple.
543template <typename ChannelT, typename... ArgTs>
544class SerializationTraits<ChannelT, std::tuple<ArgTs...>> {
545public:
546
547  /// RPC channel serialization for std::tuple.
548  static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) {
549    return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
550  }
551
552  /// RPC channel deserialization for std::tuple.
553  static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) {
554    return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
555  }
556
557private:
558  // Serialization helper for std::tuple.
559  template <size_t... Is>
560  static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V,
561                                    llvm::index_sequence<Is...> _) {
562    return serializeSeq(C, std::get<Is>(V)...);
563  }
564
565  // Serialization helper for std::tuple.
566  template <size_t... Is>
567  static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V,
568                                      llvm::index_sequence<Is...> _) {
569    return deserializeSeq(C, std::get<Is>(V)...);
570  }
571};
572
573/// SerializationTraits default specialization for std::vector.
574template <typename ChannelT, typename T>
575class SerializationTraits<ChannelT, std::vector<T>> {
576public:
577
578  /// Serialize a std::vector<T> from std::vector<T>.
579  static Error serialize(ChannelT &C, const std::vector<T> &V) {
580    if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size())))
581      return Err;
582
583    for (const auto &E : V)
584      if (auto Err = serializeSeq(C, E))
585        return Err;
586
587    return Error::success();
588  }
589
590  /// Deserialize a std::vector<T> to a std::vector<T>.
591  static Error deserialize(ChannelT &C, std::vector<T> &V) {
592    uint64_t Count = 0;
593    if (auto Err = deserializeSeq(C, Count))
594      return Err;
595
596    V.resize(Count);
597    for (auto &E : V)
598      if (auto Err = deserializeSeq(C, E))
599        return Err;
600
601    return Error::success();
602  }
603};
604
605} // end namespace rpc
606} // end namespace orc
607} // end namespace llvm
608
609#endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
610