1//===- llvm/ExecutionEngine/Orc/RawByteChannel.h ----------------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
11#define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
12
#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/Error.h"
#include <cstdint>
#include <limits>
#include <mutex>
#include <string>
#include <system_error>
#include <type_traits>
21
22namespace llvm {
23namespace orc {
24namespace rpc {
25
26/// Interface for byte-streams to be used with RPC.
27class RawByteChannel {
28public:
29  virtual ~RawByteChannel() = default;
30
31  /// Read Size bytes from the stream into *Dst.
32  virtual Error readBytes(char *Dst, unsigned Size) = 0;
33
34  /// Read size bytes from *Src and append them to the stream.
35  virtual Error appendBytes(const char *Src, unsigned Size) = 0;
36
37  /// Flush the stream if possible.
38  virtual Error send() = 0;
39
40  /// Notify the channel that we're starting a message send.
41  /// Locks the channel for writing.
42  template <typename FunctionIdT, typename SequenceIdT>
43  Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
44    writeLock.lock();
45    if (auto Err = serializeSeq(*this, FnId, SeqNo)) {
46      writeLock.unlock();
47      return Err;
48    }
49    return Error::success();
50  }
51
52  /// Notify the channel that we're ending a message send.
53  /// Unlocks the channel for writing.
54  Error endSendMessage() {
55    writeLock.unlock();
56    return Error::success();
57  }
58
59  /// Notify the channel that we're starting a message receive.
60  /// Locks the channel for reading.
61  template <typename FunctionIdT, typename SequenceNumberT>
62  Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
63    readLock.lock();
64    if (auto Err = deserializeSeq(*this, FnId, SeqNo)) {
65      readLock.unlock();
66      return Err;
67    }
68    return Error::success();
69  }
70
71  /// Notify the channel that we're ending a message receive.
72  /// Unlocks the channel for reading.
73  Error endReceiveMessage() {
74    readLock.unlock();
75    return Error::success();
76  }
77
78  /// Get the lock for stream reading.
79  std::mutex &getReadLock() { return readLock; }
80
81  /// Get the lock for stream writing.
82  std::mutex &getWriteLock() { return writeLock; }
83
84private:
85  std::mutex readLock, writeLock;
86};
87
88template <typename ChannelT, typename T>
89class SerializationTraits<
90    ChannelT, T, T,
91    typename std::enable_if<
92        std::is_base_of<RawByteChannel, ChannelT>::value &&
93        (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
94         std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
95         std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value ||
96         std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value ||
97         std::is_same<T, char>::value)>::type> {
98public:
99  static Error serialize(ChannelT &C, T V) {
100    support::endian::byte_swap<T, support::big>(V);
101    return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T));
102  };
103
104  static Error deserialize(ChannelT &C, T &V) {
105    if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T)))
106      return Err;
107    support::endian::byte_swap<T, support::big>(V);
108    return Error::success();
109  };
110};
111
112template <typename ChannelT>
113class SerializationTraits<ChannelT, bool, bool,
114                          typename std::enable_if<std::is_base_of<
115                              RawByteChannel, ChannelT>::value>::type> {
116public:
117  static Error serialize(ChannelT &C, bool V) {
118    uint8_t Tmp = V ? 1 : 0;
119    if (auto Err =
120          C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1))
121      return Err;
122    return Error::success();
123  }
124
125  static Error deserialize(ChannelT &C, bool &V) {
126    uint8_t Tmp = 0;
127    if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1))
128      return Err;
129    V = Tmp != 0;
130    return Error::success();
131  }
132};
133
134template <typename ChannelT>
135class SerializationTraits<ChannelT, std::string, StringRef,
136                          typename std::enable_if<std::is_base_of<
137                              RawByteChannel, ChannelT>::value>::type> {
138public:
139  /// RPC channel serialization for std::strings.
140  static Error serialize(RawByteChannel &C, StringRef S) {
141    if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size())))
142      return Err;
143    return C.appendBytes((const char *)S.data(), S.size());
144  }
145};
146
147template <typename ChannelT, typename T>
148class SerializationTraits<ChannelT, std::string, T,
149                          typename std::enable_if<
150                            std::is_base_of<RawByteChannel, ChannelT>::value &&
151                            (std::is_same<T, const char*>::value ||
152                             std::is_same<T, char*>::value)>::type> {
153public:
154  static Error serialize(RawByteChannel &C, const char *S) {
155    return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
156                                                                            S);
157  }
158};
159
160template <typename ChannelT>
161class SerializationTraits<ChannelT, std::string, std::string,
162                          typename std::enable_if<std::is_base_of<
163                              RawByteChannel, ChannelT>::value>::type> {
164public:
165  /// RPC channel serialization for std::strings.
166  static Error serialize(RawByteChannel &C, const std::string &S) {
167    return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
168                                                                            S);
169  }
170
171  /// RPC channel deserialization for std::strings.
172  static Error deserialize(RawByteChannel &C, std::string &S) {
173    uint64_t Count = 0;
174    if (auto Err = deserializeSeq(C, Count))
175      return Err;
176    S.resize(Count);
177    return C.readBytes(&S[0], Count);
178  }
179};
180
181} // end namespace rpc
182} // end namespace orc
183} // end namespace llvm
184
185#endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
186