1# Copyright 2014 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5"""Utility classes to handle sending and receiving messages."""
6
7
8import struct
9import weakref
10
11# pylint: disable=F0401
12import mojo.bindings.serialization as serialization
13import mojo.system as system
14
15
# Bit values combined into the |flags| field of a MessageHeader.
# No special handling requested.
NO_FLAG = 0
# The message is a request; the peer is expected to send back a response.
MESSAGE_EXPECTS_RESPONSE_FLAG = 0x1
# The message is the response to an earlier request.
MESSAGE_IS_RESPONSE_FLAG = 0x2
20
21
class MessageHeader(object):
  """The header of a mojo message.

  A header begins with four unsigned 32-bit fields packed as "=IIII": the
  header size in bytes, the version (number of fields), the message type and
  the flags. When the flags mark the message as expecting a response or as
  being a response, an unsigned 64-bit request id ("=Q") follows those fields
  and the version must be at least 3.
  """

  _SIMPLE_MESSAGE_NUM_FIELDS = 2
  _SIMPLE_MESSAGE_STRUCT = struct.Struct("=IIII")

  _REQUEST_ID_STRUCT = struct.Struct("=Q")
  _REQUEST_ID_OFFSET = _SIMPLE_MESSAGE_STRUCT.size

  _MESSAGE_WITH_REQUEST_ID_NUM_FIELDS = 3
  _MESSAGE_WITH_REQUEST_ID_SIZE = (
      _SIMPLE_MESSAGE_STRUCT.size + _REQUEST_ID_STRUCT.size)

  def __init__(self, message_type, flags, request_id=0, data=None):
    """
    Args:
      message_type: the message type stored in the header.
      flags: a bitwise combination of the MESSAGE_*_FLAG values.
      request_id: the request id; only accessible when the flags require one.
      data: optional buffer already holding this header in serialized form. If
        non-empty, it is returned by Serialize and updated in place when
        request_id is set.
    """
    self._message_type = message_type
    self._flags = flags
    self._request_id = request_id
    self._data = data

  @classmethod
  def Deserialize(cls, data):
    """
    Parse a MessageHeader from the start of |data|.

    Args:
      data: a buffer (e.g. a bytearray) starting with a serialized header. The
        returned header keeps a reference to this buffer.

    Returns:
      The parsed MessageHeader.

    Raises:
      serialization.DeserializationException: if |data| or the header's size
        field is too small for the header, or if the version is too low.
    """
    if len(data) < cls._SIMPLE_MESSAGE_STRUCT.size:
      raise serialization.DeserializationException('Header is too short.')
    # unpack_from reads any object supporting the buffer protocol directly, so
    # no (Python 2-only) buffer() wrapper is needed.
    (size, version, message_type, flags) = (
        cls._SIMPLE_MESSAGE_STRUCT.unpack_from(data))
    if version < cls._SIMPLE_MESSAGE_NUM_FIELDS:
      raise serialization.DeserializationException('Incorrect version.')
    # A header can never announce a size smaller than its fixed fields.
    if size < cls._SIMPLE_MESSAGE_STRUCT.size:
      raise serialization.DeserializationException('Header is too short.')
    request_id = 0
    if _HasRequestId(flags):
      if version < cls._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS:
        raise serialization.DeserializationException('Incorrect version.')
      if (size < cls._MESSAGE_WITH_REQUEST_ID_SIZE or
          len(data) < cls._MESSAGE_WITH_REQUEST_ID_SIZE):
        raise serialization.DeserializationException('Header is too short.')
      (request_id, ) = cls._REQUEST_ID_STRUCT.unpack_from(
          data, cls._REQUEST_ID_OFFSET)
    return MessageHeader(message_type, flags, request_id, data)

  @property
  def message_type(self):
    """The message type stored in the header."""
    return self._message_type

  # pylint: disable=E0202
  @property
  def request_id(self):
    """The request id; only valid when has_request_id is True."""
    assert self.has_request_id
    return self._request_id

  # pylint: disable=E0202
  @request_id.setter
  def request_id(self, request_id):
    assert self.has_request_id
    self._request_id = request_id
    # Keep an existing serialized buffer in sync. A header without a buffer
    # yet picks up the new value when Serialize() builds one, so writing into
    # a missing (None/empty) buffer must be skipped rather than crash.
    if self._data:
      self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET,
                                        request_id)

  @property
  def has_request_id(self):
    """Whether the flags require the header to carry a request id."""
    return _HasRequestId(self._flags)

  @property
  def expects_response(self):
    """Whether MESSAGE_EXPECTS_RESPONSE_FLAG is set."""
    return self._HasFlag(MESSAGE_EXPECTS_RESPONSE_FLAG)

  @property
  def is_response(self):
    """Whether MESSAGE_IS_RESPONSE_FLAG is set."""
    return self._HasFlag(MESSAGE_IS_RESPONSE_FLAG)

  @property
  def size(self):
    """The serialized size of this header in bytes, derived from the flags."""
    if self.has_request_id:
      return self._MESSAGE_WITH_REQUEST_ID_SIZE
    return self._SIMPLE_MESSAGE_STRUCT.size

  def Serialize(self):
    """
    Return the serialized header.

    If the header already has a non-empty buffer, that buffer is returned
    as-is; for a header produced by Deserialize it may also hold the message
    payload. Otherwise a new bytearray of exactly |size| bytes is built, kept
    for later calls, and returned.
    """
    if not self._data:
      self._data = bytearray(self.size)
      version = self._SIMPLE_MESSAGE_NUM_FIELDS
      size = self._SIMPLE_MESSAGE_STRUCT.size
      if self.has_request_id:
        version = self._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS
        size = self._MESSAGE_WITH_REQUEST_ID_SIZE
      self._SIMPLE_MESSAGE_STRUCT.pack_into(self._data, 0, size, version,
                                            self._message_type, self._flags)
      if self.has_request_id:
        self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET,
                                          self._request_id)
    return self._data

  def _HasFlag(self, flag):
    # In Python, & binds tighter than !=, so this is (flags & flag) != 0.
    return self._flags & flag != 0
114
115
class Message(object):
  """A message for a message pipe. This contains data and handles."""

  def __init__(self, data=None, handles=None):
    """
    Args:
      data: the raw message bytes, a header followed by the payload. It must
        be mutable (e.g. a bytearray) for SetRequestId to update it.
      handles: the handles sent along with the message.
    """
    self.data = data
    self.handles = handles
    self._header = None
    self._payload = None

  @property
  def header(self):
    """The MessageHeader parsed from |data|, computed on first access."""
    if self._header is None:
      self._header = MessageHeader.Deserialize(self.data)
    return self._header

  @property
  def payload(self):
    """
    A Message holding the bytes of |data| that follow the header, with the
    same handles, computed on first access.
    """
    if self._payload is None:
      self._payload = Message(self.data[self.header.size:], self.handles)
    return self._payload

  def SetRequestId(self, request_id):
    """
    Set the request id stored in this message's header, updating |data|.

    The header must carry a request id (see MessageHeader.has_request_id).
    """
    header = self.header
    header.request_id = request_id
    # MessageHeader.Serialize() returns a single buffer (not a tuple), and the
    # header length is its lowercase |size| property. Copy only the header
    # bytes back into |data|; when the header was parsed from |data| this is
    # the same buffer and the copy is a harmless no-op.
    serialized = header.Serialize()
    self.data[:header.size] = serialized[:header.size]
142
143
class MessageReceiver(object):
  """Interface for objects that consume Message instances."""

  def Accept(self, message):
    """
    Handle an incoming Message; implementations may modify it.

    Args:
      message: the Message to handle.

    Returns:
      True when the message was handled, False otherwise.
    """
    raise NotImplementedError
158
159
class MessageReceiverWithResponder(MessageReceiver):
  """
  A MessageReceiver that can additionally be handed a responder, which will be
  given the response produced for a message.
  """

  def AcceptWithResponder(self, message, responder):
    """
    Like Accept, but also registers |responder|, a MessageReceiver, to receive
    the response generated for |message|. The responder's Accept may run
    during this call or at any later time.

    Args:
      message: the Message to handle.
      responder: the MessageReceiver that will get the response.

    Returns:
      True when the message was handled, False otherwise.
    """
    raise NotImplementedError
181
182
class ConnectionErrorHandler(object):
  """
  Interface for objects told about errors that occur while the bindings use a
  message pipe.
  """

  def OnError(self, result):
    """Handle an error; |result| is the failing result code."""
    raise NotImplementedError
191
192
class Connector(MessageReceiver):
  """
  Bridges a message pipe handle and MessageReceiver objects.

  Messages passed to Accept are written to the owned handle. Messages read
  from the handle are dispatched to the receiver set with
  SetIncomingMessageReceiver. Reading only begins once Start has been called.
  """

  def __init__(self, handle):
    MessageReceiver.__init__(self)
    self._handle = handle
    # Callable cancelling the pending asynchronous wait; None when no wait is
    # pending.
    self._cancellable = None
    self._incoming_message_receiver = None
    self._error_handler = None

  def __del__(self):
    # Do not leave a wait registered for an object that is going away.
    cancel = self._cancellable
    if cancel:
      cancel()

  def SetIncomingMessageReceiver(self, message_receiver):
    """
    Choose the MessageReceiver that messages read from the owned pipe are
    dispatched to.
    """
    self._incoming_message_receiver = message_receiver

  def SetErrorHandler(self, error_handler):
    """
    Choose the ConnectionErrorHandler told about errors on the owned pipe.
    """
    self._error_handler = error_handler

  def Start(self):
    """Begin waiting for messages on the owned handle."""
    assert not self._cancellable
    self._RegisterAsyncWaiterForRead()

  def Accept(self, message):
    """Write |message| to the pipe; return True iff the write succeeded."""
    write_result = self._handle.WriteMessage(message.data, message.handles)
    return write_result == system.RESULT_OK

  def Close(self):
    """Cancel any pending wait, then close the owned handle."""
    cancel = self._cancellable
    if cancel:
      cancel()
      self._cancellable = None
    self._handle.Close()

  def _OnAsyncWaiterResult(self, result):
    # The wait that triggered this callback has completed.
    self._cancellable = None
    if result != system.RESULT_OK:
      self._OnError(result)
      return
    self._ReadOutstandingMessages()

  def _OnError(self, result):
    assert not self._cancellable
    handler = self._error_handler
    if handler:
      handler.OnError(result)

  def _RegisterAsyncWaiterForRead(self):
    assert not self._cancellable
    # A weak callback keeps the pending wait from holding this object alive.
    on_readable = _WeakCallback(self._OnAsyncWaiterResult)
    self._cancellable = self._handle.AsyncWait(
        system.HANDLE_SIGNAL_READABLE, system.DEADLINE_INDEFINITE, on_readable)

  def _ReadOutstandingMessages(self):
    # Drain every message currently available on the pipe.
    while True:
      read_result = _ReadAndDispatchMessage(self._handle,
                                            self._incoming_message_receiver)
      if read_result != system.RESULT_OK:
        break
    if read_result == system.RESULT_SHOULD_WAIT:
      # Pipe is empty for now: wait for the next message.
      self._RegisterAsyncWaiterForRead()
    else:
      self._OnError(read_result)
270
271
class Router(MessageReceiverWithResponder):
  """
  A Router will handle mojo message and forward those to a Connector. It deals
  with parsing of headers and adding of request ids in order to be able to match
  a response to a request.

  Outgoing requests expecting a response are stamped with a fresh request id
  and their responder is remembered. Incoming responses are routed to that
  responder by request id. Other incoming messages go to the receiver set with
  SetIncomingMessageReceiver.
  """

  def __init__(self, handle):
    MessageReceiverWithResponder.__init__(self)
    self._incoming_message_receiver = None
    self._next_request_id = 1
    # Responders waiting for a response, keyed by request id.
    self._responders = {}
    self._connector = Connector(handle)
    # Only a weak reference back to this Router is handed to the connector. A
    # strong bound method would create the cycle
    # Router -> Connector -> ForwardingMessageReceiver -> Router, and because
    # Connector defines __del__, Python 2's cycle collector cannot free such a
    # cycle: the Router, its Connector and the pipe handle would leak.
    self._connector.SetIncomingMessageReceiver(
        ForwardingMessageReceiver(_WeakCallback(self._HandleIncomingMessage)))

  def Start(self):
    """Start listening for incoming messages on the pipe."""
    self._connector.Start()

  def SetIncomingMessageReceiver(self, message_receiver):
    """
    Set the receiver for incoming messages that are not responses. Requests
    expecting a response are delivered through its AcceptWithResponder, with
    this Router as the responder.
    """
    self._incoming_message_receiver = message_receiver

  def SetErrorHandler(self, error_handler):
    """
    Set the ConnectionErrorHandler that will be notified of errors on the owned
    message pipe.
    """
    self._connector.SetErrorHandler(error_handler)

  def Accept(self, message):
    """Send |message| as is; return True iff it was written to the pipe."""
    # A message without responder is directly forwarded to the connector.
    return self._connector.Accept(message)

  def AcceptWithResponder(self, message, responder):
    """
    Send |message|, a request expecting a response, and register |responder|
    to receive the matching response.

    Returns:
      True if the message was written to the pipe. On failure, False is
      returned and the responder is not registered.
    """
    # The message must have a header.
    header = message.header
    assert header.expects_response
    request_id = self.NextRequestId()
    # The header was parsed from message.data, so this also writes the id into
    # the bytes about to be sent.
    header.request_id = request_id
    if not self._connector.Accept(message):
      return False
    self._responders[request_id] = responder
    return True

  def Close(self):
    """Close the connector and its message pipe."""
    self._connector.Close()

  def _HandleIncomingMessage(self, message):
    header = message.header
    if header.expects_response:
      if self._incoming_message_receiver:
        return self._incoming_message_receiver.AcceptWithResponder(
            message, self)
      # If we receive a request expecting a response when the client is not
      # listening, then we have no choice but to tear down the pipe.
      self.Close()
      return False
    if header.is_response:
      # Each responder handles at most one response, so remove it now.
      request_id = header.request_id
      responder = self._responders.pop(request_id, None)
      if responder is None:
        return False
      return responder.Accept(message)
    if self._incoming_message_receiver:
      return self._incoming_message_receiver.Accept(message)
    # Ok to drop the message
    return False

  def NextRequestId(self):
    """
    Return a request id that is non-zero and not used by any pending request.
    Ids are unsigned 64-bit values that wrap around.
    """
    request_id = self._next_request_id
    while request_id == 0 or request_id in self._responders:
      request_id = (request_id + 1) % (1 << 64)
    self._next_request_id = (request_id + 1) % (1 << 64)
    return request_id
350
class ForwardingMessageReceiver(MessageReceiver):
  """A MessageReceiver whose Accept simply delegates to a callable."""

  def __init__(self, callback):
    MessageReceiver.__init__(self)
    # Called with each accepted message; its result is returned by Accept.
    self._callback = callback

  def Accept(self, message):
    forward = self._callback
    return forward(message)
360
361
362def _WeakCallback(callback):
363  func = callback.im_func
364  self = callback.im_self
365  if not self:
366    return callback
367  weak_self = weakref.ref(self)
368  def Callback(*args, **kwargs):
369    self = weak_self()
370    if self:
371      return func(self, *args, **kwargs)
372  return Callback
373
374
def _ReadAndDispatchMessage(handle, message_receiver):
  """
  Read one message from |handle| and pass it to |message_receiver|.

  A first read without a buffer either succeeds (the message is empty) or
  fails with RESULT_RESOURCE_EXHAUSTED and reports the sizes the message
  needs. In the latter case the message is read again with a buffer and handle
  capacity of those sizes.

  Args:
    handle: the message pipe handle to read from.
    message_receiver: the MessageReceiver to dispatch to; if falsy, the message
      is read and dropped.

  Returns:
    The result code of the last read.
  """
  (result, _, sizes) = handle.ReadMessage()
  if result == system.RESULT_OK and message_receiver:
    message_receiver.Accept(Message(bytearray(), []))
  if result != system.RESULT_RESOURCE_EXHAUSTED:
    return result
  # |sizes| reports the needed byte count and handle count. Pass the handle
  # count as the handle capacity too; otherwise a message carrying handles
  # cannot fit (the capacity presumably defaults to 0 in mojo.system -- confirm)
  # and would keep failing with RESULT_RESOURCE_EXHAUSTED.
  (result, data, _) = handle.ReadMessage(bytearray(sizes[0]), sizes[1])
  if result == system.RESULT_OK and message_receiver:
    message_receiver.Accept(Message(data[0], data[1]))
  return result
385
def _HasRequestId(flags):
  """Return whether |flags| mark a message whose header carries a request id."""
  request_id_flags = MESSAGE_EXPECTS_RESPONSE_FLAG | MESSAGE_IS_RESPONSE_FLAG
  return (flags & request_id_flags) != 0
388