connection.py revision f8d62d23e9b02c557b2bbe69f693fc14c2574281
1#
2# A higher level module for using sockets (or Windows named pipes)
3#
4# multiprocessing/connection.py
5#
6# Copyright (c) 2006-2008, R Oudkerk --- see COPYING.txt
7#
8
# Names exported by ``from multiprocessing.connection import *``
__all__ = ['Client', 'Listener', 'Pipe']
10
import os
import sys
import errno
import socket
import time
import tempfile
import itertools

import _multiprocessing
from multiprocessing import current_process, AuthenticationError
from multiprocessing.util import get_temp_dir, Finalize, sub_debug, debug
from multiprocessing.forking import duplicate, close
22
23
24#
25#
26#
27
# Buffer size used for the named pipes created in this file
BUFSIZE = 8192

# Source of per-process unique numbers used when naming pipes
_mmap_counter = itertools.count()

# Address families usable on this platform.  AF_INET is always available;
# AF_UNIX is added when the socket module provides it and AF_PIPE on
# Windows.  The last family added becomes the default.
families = ['AF_INET']
if hasattr(socket, 'AF_UNIX'):
    families.append('AF_UNIX')
if sys.platform == 'win32':
    families.append('AF_PIPE')
default_family = families[-1]
42
43#
44#
45#
46
def arbitrary_address(family):
    '''
    Return an arbitrary address for the given family

    'AF_INET' gives ('localhost', 0), i.e. port 0 so the real port is only
    known after binding.  'AF_UNIX' and 'AF_PIPE' give a name generated by
    tempfile.mktemp(); nothing is created on disk, so the name is only
    unused at the moment it is generated, not guaranteed free at bind time.

    Raises ValueError for an unrecognized family.
    '''
    if family == 'AF_INET':
        return ('localhost', 0)
    elif family == 'AF_UNIX':
        return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
    elif family == 'AF_PIPE':
        # The pid plus a per-process counter keeps pipe names distinct
        # across processes and across calls within one process.
        return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
                               (os.getpid(), next(_mmap_counter)))
    else:
        raise ValueError('unrecognized family')
60
61
def address_type(address):
    '''
    Return the type of the address

    This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE':

    - a tuple (e.g. a (host, port) pair) is 'AF_INET';
    - a string starting with two backslashes (a Windows pipe path such as
      r'\\.\pipe\name') is 'AF_PIPE';
    - any other string is an 'AF_UNIX' socket path.

    Raises ValueError for any other kind of address.
    '''
    if isinstance(address, tuple):
        return 'AF_INET'
    elif isinstance(address, str) and address.startswith('\\\\'):
        return 'AF_PIPE'
    elif isinstance(address, str):
        return 'AF_UNIX'
    else:
        # Pass the address inside a 1-tuple so %-formatting can never
        # treat the value itself as the argument tuple.
        raise ValueError('address type of %r unrecognized' % (address,))
76
77#
78# Public functions
79#
80
class Listener(object):
    '''
    Returns a listener object.

    This is a wrapper for a bound socket which is 'listening' for
    connections, or for a Windows named pipe.

    If `authkey` (a byte string) is given, every accepted connection must
    pass the challenge/response handshake keyed by it.
    '''
    def __init__(self, address=None, family=None, backlog=1, authkey=None):
        # Validate the key before binding anything, so that a bad argument
        # does not leave a bound socket or named pipe behind.
        if authkey is not None and not isinstance(authkey, bytes):
            raise TypeError('authkey should be a byte string')

        family = family or (address and address_type(address)) \
                 or default_family
        address = address or arbitrary_address(family)

        if family == 'AF_PIPE':
            self._listener = PipeListener(address, backlog)
        else:
            self._listener = SocketListener(address, family, backlog)

        self._authkey = authkey

    def accept(self):
        '''
        Accept a connection on the bound socket or named pipe of `self`.

        Returns a `Connection` object.  When an authkey was given, the
        handshake is run first (challenge the peer, then answer its
        challenge); if it fails, the connection is closed and the error
        is propagated.
        '''
        c = self._listener.accept()
        # Same test as Client(): any non-None key, even an empty one,
        # means authentication, so both ends always agree.
        if self._authkey is not None:
            try:
                deliver_challenge(c, self._authkey)
                answer_challenge(c, self._authkey)
            except BaseException:
                c.close()
                raise
        return c

    def close(self):
        '''
        Close the bound socket or named pipe of `self`.
        '''
        return self._listener.close()

    # Address actually being listened on
    address = property(lambda self: self._listener._address)
    # Peer address of the most recent accepted socket connection
    # (stays None for named pipes)
    last_accepted = property(lambda self: self._listener._last_accepted)
123
124
def Client(address, family=None, authkey=None):
    '''
    Returns a connection to the address of a `Listener`

    `family` defaults to the type inferred from `address`.  If `authkey`
    (a byte string) is given, the handshake is run after connecting: answer
    the listener's challenge, then challenge it back.  On handshake failure
    the connection is closed and the error is propagated.

    Raises TypeError if `authkey` is given but is not a byte string.
    '''
    # Check arguments before connecting, so a bad key does not leak an
    # open connection.
    if authkey is not None and not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')

    family = family or address_type(address)
    if family == 'AF_PIPE':
        c = PipeClient(address)
    else:
        c = SocketClient(address)

    if authkey is not None:
        try:
            answer_challenge(c, authkey)
            deliver_challenge(c, authkey)
        except BaseException:
            c.close()
            raise

    return c
143
144
if sys.platform != 'win32':

    def Pipe(duplex=True):
        '''
        Returns pair of connection objects at either end of a pipe

        With `duplex` both ends can send and receive.  Otherwise the first
        object can only receive and the second can only send.
        '''
        if duplex:
            # The connection objects get duplicated descriptors (os.dup),
            # so the socket objects are always closed, even on failure.
            s1, s2 = socket.socketpair()
            try:
                c1 = _multiprocessing.Connection(os.dup(s1.fileno()))
                c2 = _multiprocessing.Connection(os.dup(s2.fileno()))
            finally:
                s1.close()
                s2.close()
        else:
            fd1, fd2 = os.pipe()
            c1 = _multiprocessing.Connection(fd1, writable=False)
            c2 = _multiprocessing.Connection(fd2, readable=False)

        return c1, c2

else:

    # _multiprocessing is a top-level extension module (it is imported
    # absolutely at the top of this file), not a submodule of this
    # package, so a relative import of it cannot succeed.
    from _multiprocessing import win32

    def Pipe(duplex=True):
        '''
        Returns pair of connection objects at either end of a pipe

        With `duplex` both ends can send and receive.  Otherwise the first
        object (the server end of the named pipe) can only receive and the
        second (the client end) can only send.
        '''
        address = arbitrary_address('AF_PIPE')
        if duplex:
            openmode = win32.PIPE_ACCESS_DUPLEX
            access = win32.GENERIC_READ | win32.GENERIC_WRITE
            obsize, ibsize = BUFSIZE, BUFSIZE
        else:
            openmode = win32.PIPE_ACCESS_INBOUND
            access = win32.GENERIC_WRITE
            obsize, ibsize = 0, BUFSIZE

        h1 = win32.CreateNamedPipe(
            address, openmode,
            win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
            win32.PIPE_WAIT,
            1, obsize, ibsize, win32.NMPWAIT_WAIT_FOREVER, win32.NULL
            )
        try:
            h2 = win32.CreateFile(
                address, access, 0, win32.NULL, win32.OPEN_EXISTING, 0,
                win32.NULL
                )
        except BaseException:
            close(h1)
            raise

        try:
            win32.SetNamedPipeHandleState(
                h2, win32.PIPE_READMODE_MESSAGE, None, None
                )
            try:
                win32.ConnectNamedPipe(h1, win32.NULL)
            except WindowsError as e:
                # The client end was already opened above, so
                # ERROR_PIPE_CONNECTED is treated as success.
                if e.args[0] != win32.ERROR_PIPE_CONNECTED:
                    raise
        except BaseException:
            close(h1)
            close(h2)
            raise

        c1 = _multiprocessing.PipeConnection(h1, writable=duplex)
        c2 = _multiprocessing.PipeConnection(h2, readable=duplex)

        return c1, c2
205
206#
207# Definitions for connections based on sockets
208#
209
class SocketListener(object):
    '''
    Representation of a socket which is bound to an address and listening
    '''
    def __init__(self, address, family, backlog=1):
        self._socket = socket.socket(getattr(socket, family))
        try:
            self._socket.bind(address)
            self._socket.listen(backlog)
            self._address = self._socket.getsockname()
        except BaseException:
            # Do not leave the descriptor open when bind/listen fails
            self._socket.close()
            raise
        self._family = family
        self._last_accepted = None

        if family == 'AF_UNIX':
            # bind() created a filesystem entry for the socket.  close()
            # calls this finalizer to remove it; the exitpriority presumably
            # also makes it run at interpreter shutdown (see Finalize).
            self._unlink = Finalize(
                self, os.unlink, args=(address,), exitpriority=0
                )
        else:
            self._unlink = None

    def accept(self):
        '''
        Accept one connection and return it as a `Connection`.

        The peer address is recorded in `_last_accepted`.  The connection is
        built from a duplicate of the accepted socket's descriptor, so the
        socket object itself is always closed, even if duplication fails.
        '''
        s, self._last_accepted = self._socket.accept()
        try:
            fd = duplicate(s.fileno())
        finally:
            s.close()
        return _multiprocessing.Connection(fd)

    def close(self):
        '''Close the listening socket and remove an AF_UNIX socket file.'''
        self._socket.close()
        if self._unlink is not None:
            self._unlink()
240
241
def SocketClient(address, timeout=20.0):
    '''
    Return a connection object connected to the socket given by `address`

    A refused connection (typically: the listener is not up yet) is retried
    every 10 ms until `timeout` seconds have passed.  Any other socket
    error, or a refusal after the deadline, is logged and re-raised.
    '''
    family = getattr(socket, address_type(address))
    # errno.ECONNREFUSED is the portable code; 10061 (WSAECONNREFUSED) is
    # kept for Windows builds where the two might differ.
    refused = (errno.ECONNREFUSED, 10061)
    deadline = time.time() + timeout

    while True:
        # Use a fresh socket per attempt: after a failed connect() POSIX
        # leaves the socket's state unspecified.
        s = socket.socket(family)
        try:
            s.connect(address)
        except socket.error as e:
            s.close()
            if e.args[0] not in refused or time.time() > deadline:
                debug('failed to connect to address %s', address)
                raise
            time.sleep(0.01)
        else:
            break

    # The Connection owns a duplicate of the descriptor, so the socket
    # object is always closed here.
    try:
        fd = duplicate(s.fileno())
    finally:
        s.close()
    return _multiprocessing.Connection(fd)
266
267#
268# Definitions for connections based on named pipes
269#
270
if sys.platform == 'win32':

    class PipeListener(object):
        '''
        Representation of a named pipe
        '''
        def __init__(self, address, backlog=None):
            # `backlog` is accepted for parity with SocketListener; it is
            # not used for named pipes.
            self._address = address
            self._handle_queue = [self._new_handle()]
            self._last_accepted = None

            sub_debug('listener created with address=%r', self._address)

            # close() *is* the finalizer.  It receives the same list object
            # that accept() keeps mutating, so whatever handles are queued
            # at close time get closed.
            self.close = Finalize(
                self, PipeListener._finalize_pipe_listener,
                args=(self._handle_queue, self._address), exitpriority=0
                )

        def _new_handle(self):
            '''Create one more server-side instance of the named pipe.'''
            return win32.CreateNamedPipe(
                self._address, win32.PIPE_ACCESS_DUPLEX,
                win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
                win32.PIPE_WAIT,
                win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
                win32.NMPWAIT_WAIT_FOREVER, win32.NULL
                )

        def accept(self):
            '''Wait for a client on the oldest pipe instance and return it.'''
            # Queue the next instance before waiting on the oldest one, so
            # the queue always holds an instance for the following client.
            self._handle_queue.append(self._new_handle())
            handle = self._handle_queue.pop(0)
            try:
                win32.ConnectNamedPipe(handle, win32.NULL)
            except WindowsError as e:
                # ERROR_PIPE_CONNECTED is treated as success (presumably a
                # client connected before ConnectNamedPipe was called).
                if e.args[0] != win32.ERROR_PIPE_CONNECTED:
                    # The handle is no longer queued, so the finalizer
                    # would never close it.
                    close(handle)
                    raise
            return _multiprocessing.PipeConnection(handle)

        @staticmethod
        def _finalize_pipe_listener(queue, address):
            sub_debug('closing listener with address=%r', address)
            for handle in queue:
                close(handle)

    def PipeClient(address, timeout=20.0):
        '''
        Return a connection object connected to the pipe given by `address`

        While the pipe is busy, or the WaitNamedPipe() wait (1000 ms) times
        out, the attempt is retried until `timeout` seconds have passed.
        Other errors, and the retryable ones after the deadline, are raised.
        '''
        deadline = time.time() + timeout
        while True:
            try:
                win32.WaitNamedPipe(address, 1000)
                h = win32.CreateFile(
                    address, win32.GENERIC_READ | win32.GENERIC_WRITE,
                    0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
                    )
            except WindowsError as e:
                retryable = e.args[0] in (win32.ERROR_SEM_TIMEOUT,
                                          win32.ERROR_PIPE_BUSY)
                if not retryable or time.time() > deadline:
                    raise
            else:
                break

        try:
            win32.SetNamedPipeHandleState(
                h, win32.PIPE_READMODE_MESSAGE, None, None
                )
        except BaseException:
            close(h)
            raise
        return _multiprocessing.PipeConnection(h)
343
344#
345# Authentication stuff
346#
347
MESSAGE_LENGTH = 20

CHALLENGE = b'#CHALLENGE#'
WELCOME = b'#WELCOME#'
FAILURE = b'#FAILURE#'

def _digest_equal(a, b):
    '''
    Return True if byte strings `a` and `b` are equal, in time that does
    not depend on where they first differ.
    '''
    import hmac
    compare = getattr(hmac, 'compare_digest', None)
    if compare is not None:
        return compare(a, b)
    # Fallback for interpreters whose hmac module lacks compare_digest
    if len(a) != len(b):
        return False
    diff = 0
    for x, y in zip(bytearray(a), bytearray(b)):
        diff |= x ^ y
    return diff == 0

def deliver_challenge(connection, authkey):
    '''
    Challenge the peer on `connection` to prove it knows `authkey`.

    Sends CHALLENGE followed by MESSAGE_LENGTH random bytes and expects the
    HMAC of those bytes under `authkey` in reply.  Sends WELCOME on success.
    Otherwise sends FAILURE and raises AuthenticationError.
    '''
    import hmac
    import hashlib
    assert isinstance(authkey, bytes)
    message = os.urandom(MESSAGE_LENGTH)
    connection.send_bytes(CHALLENGE + message)
    # MD5 was hmac.new()'s implicit default; naming it keeps digests
    # unchanged and works where digestmod is mandatory (Python 3.8+).
    digest = hmac.new(authkey, message, hashlib.md5).digest()
    response = connection.recv_bytes(256)        # reject large message
    if _digest_equal(response, digest):
        connection.send_bytes(WELCOME)
    else:
        connection.send_bytes(FAILURE)
        raise AuthenticationError('digest received was wrong')

def answer_challenge(connection, authkey):
    '''
    Answer a challenge sent by `deliver_challenge` on the other end.

    Replies with the HMAC (under `authkey`) of the challenge bytes.
    Raises AuthenticationError if the message is not a challenge or if the
    peer does not reply WELCOME.
    '''
    import hmac
    import hashlib
    assert isinstance(authkey, bytes)
    message = connection.recv_bytes(256)         # reject large message
    # The peer's data is untrusted, so reject bad framing with a real
    # error rather than an assert (asserts are stripped under -O).
    if message[:len(CHALLENGE)] != CHALLENGE:
        raise AuthenticationError('message = %r' % (message,))
    message = message[len(CHALLENGE):]
    digest = hmac.new(authkey, message, hashlib.md5).digest()
    connection.send_bytes(digest)
    response = connection.recv_bytes(256)        # reject large message
    if response != WELCOME:
        raise AuthenticationError('digest sent was rejected')
378
379#
380# Support for using xmlrpclib for serialization
381#
382
class ConnectionWrapper(object):
    '''
    Wrap a connection so that `send`/`recv` go through caller-supplied
    serializers.

    Whatever `dumps(obj)` returns is passed to the connection's
    `send_bytes`.  Each received message is passed to `loads`.  `fileno`,
    `close`, `poll`, `recv_bytes` and `send_bytes` are exposed unchanged
    from the wrapped connection.
    '''
    def __init__(self, conn, dumps, loads):
        self._conn = conn
        self._dumps = dumps
        self._loads = loads
        # Re-export the raw connection methods as attributes of the wrapper
        for name in ['fileno', 'close', 'poll', 'recv_bytes', 'send_bytes']:
            setattr(self, name, getattr(conn, name))

    def send(self, obj):
        '''Serialize `obj` with `dumps` and send it as one message.'''
        self._conn.send_bytes(self._dumps(obj))

    def recv(self):
        '''Receive one message and deserialize it with `loads`.'''
        return self._loads(self._conn.recv_bytes())
397
398def _xml_dumps(obj):
399    return xmlrpclib.dumps((obj,), None, None, None, 1).encode('utf8')
400
401def _xml_loads(s):
402    (obj,), method = xmlrpclib.loads(s.decode('utf8'))
403    return obj
404
class XmlListener(Listener):
    '''
    A `Listener` whose accepted connections exchange objects serialized
    with XML-RPC (via `ConnectionWrapper`).
    '''
    def accept(self):
        # Publish xmlrpclib as a module global for any helper that refers
        # to it by that name (e.g. the _xml_* serializers).
        global xmlrpclib
        import xmlrpclib
        return ConnectionWrapper(Listener.accept(self), _xml_dumps, _xml_loads)
411
def XmlClient(*args, **kwds):
    '''
    Like `Client` (same arguments), but the returned connection exchanges
    objects serialized with XML-RPC (via `ConnectionWrapper`).
    '''
    # Publish xmlrpclib as a module global for any helper that refers to
    # it by that name (e.g. the _xml_* serializers).
    global xmlrpclib
    import xmlrpclib
    conn = Client(*args, **kwds)
    return ConnectionWrapper(conn, _xml_dumps, _xml_loads)
416