connection.py revision c562ca4625992836919b5f787bd9cb288226dae7
1#
2# A higher level module for using sockets (or Windows named pipes)
3#
4# multiprocessing/connection.py
5#
6# Copyright (c) 2006-2008, R Oudkerk --- see COPYING.txt
7#
8
# Names exported by `from multiprocessing.connection import *`.
__all__ = [
    'Client',
    'Listener',
    'Pipe',
]
10
11import os
12import sys
13import socket
14import errno
15import time
16import tempfile
17import itertools
18
19import _multiprocessing
20from multiprocessing import current_process, AuthenticationError
21from multiprocessing.util import get_temp_dir, Finalize, sub_debug, debug
22from multiprocessing.forking import duplicate, close
23
24
25#
26#
27#
28
BUFSIZE = 8192
# A very generous timeout when it comes to local connections...
CONNECTION_TIMEOUT = 20.

# Source of unique sequence numbers embedded in generated pipe names.
_mmap_counter = itertools.count()

# Address families usable on this platform.  Each supported family is
# appended in turn and the last one appended is used as the default.
families = ['AF_INET']
if hasattr(socket, 'AF_UNIX'):
    families += ['AF_UNIX']
if sys.platform == 'win32':
    families += ['AF_PIPE']
default_family = families[-1]
45
46
def _init_timeout(timeout=CONNECTION_TIMEOUT):
    '''
    Return a deadline: the `time.time()` value `timeout` seconds from now.
    '''
    deadline = time.time() + timeout
    return deadline
49
def _check_timeout(t):
    '''
    Return True once the deadline `t` (a `time.time()` value) has passed.
    '''
    now = time.time()
    return t < now
52
53#
54#
55#
56
def arbitrary_address(family):
    '''
    Return an arbitrary free address for the given family

    `family` is 'AF_INET', 'AF_UNIX' or 'AF_PIPE'.  For 'AF_INET' the
    result is ('localhost', 0); the other two families get a freshly
    generated name.  Raises ValueError for any other family.
    '''
    if family == 'AF_INET':
        # Port 0 lets the operating system pick an unused port at bind time.
        return ('localhost', 0)
    elif family == 'AF_UNIX':
        return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
    elif family == 'AF_PIPE':
        # dir='' keeps mktemp from joining the name onto the temp directory:
        # a named-pipe path must begin with \\.\pipe\ exactly as written.
        # next() (rather than .next()) works on every Python from 2.6 on.
        return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
                               (os.getpid(), next(_mmap_counter)), dir='')
    else:
        raise ValueError('unrecognized family')
70
71
def address_type(address):
    '''
    Return the type of the address

    This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE'.  Tuples (including
    tuple subclasses such as namedtuples) are AF_INET, strings starting
    with two backslashes are AF_PIPE and other strings are AF_UNIX.
    Raises ValueError for anything else.
    '''
    if isinstance(address, tuple):
        return 'AF_INET'
    elif isinstance(address, str) and address.startswith('\\\\'):
        return 'AF_PIPE'
    elif isinstance(address, str):
        return 'AF_UNIX'
    else:
        # Wrap in a 1-tuple so '%' never tries to spread the address itself
        # over the format string.
        raise ValueError('address type of %r unrecognized' % (address,))
86
87#
88# Public functions
89#
90
class Listener(object):
    '''
    Returns a listener object.

    This is a wrapper for a bound socket which is 'listening' for
    connections, or for a Windows named pipe.
    '''
    def __init__(self, address=None, family=None, backlog=1, authkey=None):
        # Validate before binding anything so that a bad authkey cannot
        # leave a socket or named pipe (or an AF_UNIX socket file) behind.
        if authkey is not None and not isinstance(authkey, bytes):
            raise TypeError('authkey should be a byte string')

        family = family or (address and address_type(address)) \
                 or default_family
        address = address or arbitrary_address(family)

        if family == 'AF_PIPE':
            self._listener = PipeListener(address, backlog)
        else:
            self._listener = SocketListener(address, family, backlog)

        self._authkey = authkey

    def accept(self):
        '''
        Accept a connection on the bound socket or named pipe of `self`.

        Returns a `Connection` object.  When an authkey was given, the
        challenge/response handshake is run in both directions first; if it
        fails, the connection is closed and the error is re-raised.
        '''
        c = self._listener.accept()
        # Test against None, exactly as Client() does, so an empty authkey
        # still runs the handshake here instead of leaving the client
        # blocked waiting for a challenge that is never sent.
        if self._authkey is not None:
            try:
                deliver_challenge(c, self._authkey)
                answer_challenge(c, self._authkey)
            except BaseException:
                c.close()
                raise
        return c

    def close(self):
        '''
        Close the bound socket or named pipe of `self`.
        '''
        return self._listener.close()

    # Address actually bound, and peer address of the last accepted
    # connection, as recorded by the underlying listener.
    address = property(lambda self: self._listener._address)
    last_accepted = property(lambda self: self._listener._last_accepted)
133
134
def Client(address, family=None, authkey=None):
    '''
    Returns a connection to the address of a `Listener`

    Raises TypeError, before any connection is attempted, if `authkey` is
    given but is not a byte string.  With an authkey the challenge/response
    handshake is run in both directions; if it fails, the connection is
    closed and the error is re-raised.
    '''
    if authkey is not None and not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')

    family = family or address_type(address)
    if family == 'AF_PIPE':
        c = PipeClient(address)
    else:
        c = SocketClient(address)

    if authkey is not None:
        try:
            answer_challenge(c, authkey)
            deliver_challenge(c, authkey)
        except BaseException:
            # Don't hand back (or leak) a connection that failed to
            # authenticate.
            c.close()
            raise

    return c
153
154
if sys.platform != 'win32':

    def Pipe(duplex=True):
        '''
        Returns pair of connection objects at either end of a pipe

        With `duplex` true both ends can send and receive (built on a
        socketpair); otherwise the first end is read-only and the second
        write-only (built on os.pipe()).
        '''
        if duplex:
            s1, s2 = socket.socketpair()
            try:
                # The connections own duplicated descriptors, so the
                # original socket objects are closed in every case.
                c1 = _multiprocessing.Connection(os.dup(s1.fileno()))
                c2 = _multiprocessing.Connection(os.dup(s2.fileno()))
            finally:
                s1.close()
                s2.close()
        else:
            fd1, fd2 = os.pipe()
            c1 = _multiprocessing.Connection(fd1, writable=False)
            c2 = _multiprocessing.Connection(fd2, readable=False)

        return c1, c2

else:

    # `_multiprocessing` is a top-level extension module (it is imported
    # that way at the top of this file), not a submodule of this package,
    # so the import must be absolute.
    from _multiprocessing import win32

    def Pipe(duplex=True):
        '''
        Returns pair of connection objects at either end of a pipe

        With `duplex` true both ends can send and receive; otherwise the
        first end is read-only and the second write-only.
        '''
        address = arbitrary_address('AF_PIPE')
        if duplex:
            openmode = win32.PIPE_ACCESS_DUPLEX
            access = win32.GENERIC_READ | win32.GENERIC_WRITE
            obsize, ibsize = BUFSIZE, BUFSIZE
        else:
            openmode = win32.PIPE_ACCESS_INBOUND
            access = win32.GENERIC_WRITE
            obsize, ibsize = 0, BUFSIZE

        h1 = win32.CreateNamedPipe(
            address, openmode,
            win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
            win32.PIPE_WAIT,
            1, obsize, ibsize, win32.NMPWAIT_WAIT_FOREVER, win32.NULL
            )
        h2 = None
        try:
            h2 = win32.CreateFile(
                address, access, 0, win32.NULL, win32.OPEN_EXISTING, 0,
                win32.NULL
                )
            win32.SetNamedPipeHandleState(
                h2, win32.PIPE_READMODE_MESSAGE, None, None
                )
            try:
                win32.ConnectNamedPipe(h1, win32.NULL)
            except WindowsError as e:
                # The client end was opened above, so "already connected"
                # is the expected outcome rather than an error.
                if e.args[0] != win32.ERROR_PIPE_CONNECTED:
                    raise
        except BaseException:
            # Release the raw handles if the pipe could not be set up.
            close(h1)
            if h2 is not None:
                close(h2)
            raise

        c1 = _multiprocessing.PipeConnection(h1, writable=duplex)
        c2 = _multiprocessing.PipeConnection(h2, readable=duplex)

        return c1, c2
215
216#
217# Definitions for connections based on sockets
218#
219
class SocketListener(object):
    '''
    Representation of a socket which is bound to an address and listening
    '''
    def __init__(self, address, family, backlog=1):
        self._socket = socket.socket(getattr(socket, family))
        try:
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self._socket.bind(address)
            self._socket.listen(backlog)
            self._address = self._socket.getsockname()
        except BaseException:
            # Don't leak the descriptor when the address cannot be bound.
            self._socket.close()
            raise
        self._family = family
        self._last_accepted = None

        if family == 'AF_UNIX':
            # A bound AF_UNIX socket leaves a file at `address`; close()
            # removes it through this Finalize (the exitpriority presumably
            # also schedules it at interpreter shutdown -- see util.Finalize).
            self._unlink = Finalize(
                self, os.unlink, args=(address,), exitpriority=0
                )
        else:
            self._unlink = None

    def accept(self):
        '''
        Accept one incoming connection and return it as a
        `_multiprocessing.Connection` over a duplicated descriptor.  The
        peer address is stored in `_last_accepted`.
        '''
        s, self._last_accepted = self._socket.accept()
        try:
            fd = duplicate(s.fileno())
            conn = _multiprocessing.Connection(fd)
        finally:
            # The Connection owns its own duplicate, so the accepted socket
            # object is always closed here.
            s.close()
        return conn

    def close(self):
        '''
        Close the listening socket and, for AF_UNIX, remove its socket file.
        '''
        self._socket.close()
        if self._unlink is not None:
            self._unlink()
251
252
def SocketClient(address):
    '''
    Return a connection object connected to the socket given by `address`

    A refused connection is retried every 10ms until the deadline from
    `_init_timeout()` passes; any other socket error is raised at once.
    '''
    family = getattr(socket, address_type(address))
    t = _init_timeout()

    while 1:
        # After a failed connect() the socket's state is unspecified, so a
        # fresh socket is created for every attempt and the old one closed.
        s = socket.socket(family)
        try:
            s.connect(address)
        except socket.error as e:
            s.close()
            if e.args[0] != errno.ECONNREFUSED or _check_timeout(t):
                debug('failed to connect to address %s', address)
                raise
            time.sleep(0.01)
        else:
            break

    try:
        fd = duplicate(s.fileno())
        conn = _multiprocessing.Connection(fd)
    finally:
        # The Connection owns its own duplicate descriptor.
        s.close()
    return conn
278
279#
280# Definitions for connections based on named pipes
281#
282
if sys.platform == 'win32':

    class PipeListener(object):
        '''
        Representation of a named pipe
        '''
        def __init__(self, address, backlog=None):
            # `backlog` is accepted for parity with SocketListener but is
            # not used.
            self._address = address
            handle = win32.CreateNamedPipe(
                address, win32.PIPE_ACCESS_DUPLEX,
                win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
                win32.PIPE_WAIT,
                win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
                win32.NMPWAIT_WAIT_FOREVER, win32.NULL
                )
            self._handle_queue = [handle]
            self._last_accepted = None

            sub_debug('listener created with address=%r', self._address)

            # close() closes every handle still waiting in the queue.
            self.close = Finalize(
                self, PipeListener._finalize_pipe_listener,
                args=(self._handle_queue, self._address), exitpriority=0
                )

        def accept(self):
            '''
            Wait for a client on the oldest queued pipe instance and return
            it as a `_multiprocessing.PipeConnection`.
            '''
            # A replacement instance is queued before one is handed out,
            # presumably so an instance of the pipe name always exists for
            # the next client.
            newhandle = win32.CreateNamedPipe(
                self._address, win32.PIPE_ACCESS_DUPLEX,
                win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
                win32.PIPE_WAIT,
                win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
                win32.NMPWAIT_WAIT_FOREVER, win32.NULL
                )
            self._handle_queue.append(newhandle)
            handle = self._handle_queue.pop(0)
            try:
                win32.ConnectNamedPipe(handle, win32.NULL)
            except WindowsError as e:
                if e.args[0] != win32.ERROR_PIPE_CONNECTED:
                    # The handle has left the queue, so the finalizer would
                    # never close it; release it here.
                    close(handle)
                    raise
            return _multiprocessing.PipeConnection(handle)

        @staticmethod
        def _finalize_pipe_listener(queue, address):
            sub_debug('closing listener with address=%r', address)
            for handle in queue:
                close(handle)

    def PipeClient(address):
        '''
        Return a connection object connected to the pipe given by `address`

        Semaphore timeouts and busy pipes are retried until the deadline
        from `_init_timeout()` passes; other errors are raised at once.
        '''
        t = _init_timeout()
        while 1:
            try:
                win32.WaitNamedPipe(address, 1000)
                h = win32.CreateFile(
                    address, win32.GENERIC_READ | win32.GENERIC_WRITE,
                    0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
                    )
            except WindowsError as e:
                if e.args[0] not in (win32.ERROR_SEM_TIMEOUT,
                                     win32.ERROR_PIPE_BUSY) or _check_timeout(t):
                    raise
            else:
                break

        try:
            win32.SetNamedPipeHandleState(
                h, win32.PIPE_READMODE_MESSAGE, None, None
                )
        except BaseException:
            close(h)
            raise
        return _multiprocessing.PipeConnection(h)
356
357#
358# Authentication stuff
359#
360
MESSAGE_LENGTH = 20

CHALLENGE = b'#CHALLENGE#'
WELCOME = b'#WELCOME#'
FAILURE = b'#FAILURE#'

def _digests_equal(a, b):
    '''
    Compare two byte strings without an early exit on the first mismatch.

    Uses `hmac.compare_digest` when this Python provides it, otherwise an
    equivalent XOR-accumulate loop.
    '''
    import hmac
    compare = getattr(hmac, 'compare_digest', None)
    if compare is not None:
        return compare(a, b)
    if len(a) != len(b):
        return False
    result = 0
    for x, y in zip(bytearray(a), bytearray(b)):
        result |= x ^ y
    return result == 0

def deliver_challenge(connection, authkey):
    '''
    Check that the peer on `connection` knows the shared `authkey`.

    Sends CHALLENGE followed by MESSAGE_LENGTH random bytes, expects the
    HMAC-MD5 of those bytes back, and replies WELCOME or FAILURE.

    Raises TypeError if `authkey` is not a byte string and
    AuthenticationError if the returned digest is wrong.
    '''
    import hmac
    import hashlib
    # A real check: an assert would vanish under `python -O`.
    if not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')
    message = os.urandom(MESSAGE_LENGTH)
    connection.send_bytes(CHALLENGE + message)
    # MD5 was hmac.new's implicit default; it is spelled out because newer
    # Pythons no longer supply a default and both ends must agree.
    digest = hmac.new(authkey, message, hashlib.md5).digest()
    response = connection.recv_bytes(256)        # reject large message
    # The response comes from the peer: compare without leaking timing.
    if _digests_equal(response, digest):
        connection.send_bytes(WELCOME)
    else:
        connection.send_bytes(FAILURE)
        raise AuthenticationError('digest received was wrong')

def answer_challenge(connection, authkey):
    '''
    Prove knowledge of `authkey` to the peer on `connection`.

    Reads a CHALLENGE message, sends back the HMAC-MD5 of its payload and
    expects WELCOME in return.

    Raises TypeError if `authkey` is not a byte string and
    AuthenticationError if the challenge is malformed or the digest is
    rejected.
    '''
    import hmac
    import hashlib
    if not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')
    message = connection.recv_bytes(256)         # reject large message
    # Peer-supplied data: validate explicitly rather than with an assert.
    if message[:len(CHALLENGE)] != CHALLENGE:
        raise AuthenticationError(
            'expected a challenge message, got %r' % (message,))
    message = message[len(CHALLENGE):]
    digest = hmac.new(authkey, message, hashlib.md5).digest()
    connection.send_bytes(digest)
    response = connection.recv_bytes(256)        # reject large message
    if response != WELCOME:
        raise AuthenticationError('digest sent was rejected')
391
392#
393# Support for using xmlrpclib for serialization
394#
395
class ConnectionWrapper(object):
    '''
    Wrap a byte-oriented connection so that `send`/`recv` move whole objects
    serialized with the supplied `dumps`/`loads` functions.

    `fileno`, `close`, `poll`, `recv_bytes` and `send_bytes` are the wrapped
    connection's own bound methods, copied onto the wrapper.
    '''
    def __init__(self, conn, dumps, loads):
        self._conn, self._dumps, self._loads = conn, dumps, loads
        passthrough = ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes')
        for name in passthrough:
            setattr(self, name, getattr(conn, name))

    def send(self, obj):
        '''Serialize `obj` with `dumps` and send it as one message.'''
        self._conn.send_bytes(self._dumps(obj))

    def recv(self):
        '''Receive one message and deserialize it with `loads`.'''
        return self._loads(self._conn.recv_bytes())
410
def _xml_dumps(obj):
    '''
    Marshal `obj` into an XML-RPC params document (None allowed) and return
    it as UTF-8 encoded bytes.
    '''
    data = xmlrpclib.dumps((obj,), None, None, None, 1)
    # Python 2's xmlrpclib already returns an encoded byte string; calling
    # .encode() on it would first decode it as ASCII and fail on non-ASCII
    # content.  Only text needs encoding.
    if not isinstance(data, bytes):
        data = data.encode('utf8')
    return data
413
def _xml_loads(s):
    '''
    Decode UTF-8 bytes `s` and unmarshal the single XML-RPC param it holds.
    '''
    text = s.decode('utf8')
    [obj], _method = xmlrpclib.loads(text)
    return obj
417
class XmlListener(Listener):
    '''
    A `Listener` whose accepted connections are wrapped in a
    `ConnectionWrapper` that marshals objects with xmlrpclib.
    '''
    def accept(self):
        # Bind the module-level name that _xml_dumps/_xml_loads look up.
        global xmlrpclib
        import xmlrpclib
        conn = Listener.accept(self)
        return ConnectionWrapper(conn, _xml_dumps, _xml_loads)
424
def XmlClient(*args, **kwds):
    '''
    Like `Client` (same arguments), but the returned connection is wrapped
    in a `ConnectionWrapper` that marshals objects with xmlrpclib.
    '''
    # Bind the module-level name that _xml_dumps/_xml_loads look up.
    global xmlrpclib
    import xmlrpclib
    conn = Client(*args, **kwds)
    return ConnectionWrapper(conn, _xml_dumps, _xml_loads)
429