"""RPC Implementation, originally written for the Python Idle IDE
2
3For security reasons, GvR requested that Idle's Python execution server process
4connect to the Idle process, which listens for the connection.  Since Idle has
5only one client per server, this was not a limitation.
6
7   +---------------------------------+ +-------------+
8   | SocketServer.BaseRequestHandler | | SocketIO    |
9   +---------------------------------+ +-------------+
10                   ^                   | register()  |
11                   |                   | unregister()|
12                   |                   +-------------+
13                   |                      ^  ^
14                   |                      |  |
15                   | + -------------------+  |
16                   | |                       |
17   +-------------------------+        +-----------------+
18   | RPCHandler              |        | RPCClient       |
19   | [attribute of RPCServer]|        |                 |
20   +-------------------------+        +-----------------+
21
22The RPCServer handler class is expected to provide register/unregister methods.
23RPCHandler inherits the mix-in class SocketIO, which provides these methods.
24
25See the Idle run.main() docstring for further information on how this was
26accomplished in Idle.
27
28"""
29
30import sys
31import os
32import socket
33import select
34import SocketServer
35import struct
36import cPickle as pickle
37import threading
38import Queue
39import traceback
40import copy_reg
41import types
42import marshal
43
44
def unpickle_code(ms):
    """Rebuild a code object from the marshalled string *ms*.

    This is the reconstructor half of the code-object pickling support;
    the payload must unmarshal to a code object.
    """
    code = marshal.loads(ms)
    assert isinstance(code, types.CodeType)
    return code
49
def pickle_code(co):
    """Reduce the code object *co* for pickling.

    Returns a ``(callable, args)`` pair: calling ``unpickle_code`` with
    the marshalled form of *co* recreates the code object.
    """
    assert isinstance(co, types.CodeType)
    return unpickle_code, (marshal.dumps(co),)
54
55# XXX KBK 24Aug02 function pickling capability not used in Idle
56#  def unpickle_function(ms):
57#      return ms
58
59#  def pickle_function(fn):
60#      assert isinstance(fn, type.FunctionType)
61#      return repr(fn)
62
# Register the marshal-based reducer so that pickle can handle code objects.
copy_reg.pickle(types.CodeType, pickle_code, constructor_ob=unpickle_code)
64# copy_reg.pickle(types.FunctionType, pickle_function, unpickle_function)
65
# Upper bound, in bytes, on what is passed to a single send()/recv() call.
BUFSIZE = 8 * 1024
# RPCClient.accept() only keeps connections whose peer address is this host.
LOCALHOST = "127.0.0.1"
68
class RPCServer(SocketServer.TCPServer):
    """TCPServer whose connection direction is reversed.

    Instead of binding and listening, the server connects out to the
    address it is given and then serves that single connected socket.
    """

    def __init__(self, addr, handlerclass=None):
        """Create the server; *handlerclass* defaults to RPCHandler."""
        SocketServer.TCPServer.__init__(
            self, addr, RPCHandler if handlerclass is None else handlerclass)

    def server_bind(self):
        "Override TCPServer method, no bind() phase for connecting entity"

    def server_activate(self):
        """Override TCPServer method, connect() instead of listen()

        Due to the reversed connection, self.server_address is actually the
        address of the Idle Client to which we are connecting.

        """
        self.socket.connect(self.server_address)

    def get_request(self):
        "Override TCPServer method, return already connected socket"
        return (self.socket, self.server_address)

    def handle_error(self, request, client_address):
        """Override TCPServer method

        SystemExit is passed through untouched.  Any other exception is
        reported on sys.__stderr__ together with the thread name, client
        address and request, after which the process is terminated with
        os._exit(0).

        """
        try:
            # Re-raise the exception currently being handled so that the
            # except clauses below can sort it by type.
            raise
        except SystemExit:
            raise
        except:
            err = sys.__stderr__
            thread_name = threading.currentThread().getName()
            rule = '-'*40
            print>>err, '\n' + rule
            print>>err, 'Unhandled server exception!'
            print>>err, 'Thread: %s' % thread_name
            print>>err, 'Client Address: ', client_address
            print>>err, 'Request: ', repr(request)
            traceback.print_exc(file=err)
            print>>err, '\n*** Unrecoverable, server exiting!'
            print>>err, rule
            os._exit(0)
116
117#----------------- end class RPCServer --------------------
118
# Default registry of oid -> object, used by SocketIO instances that are
# not given their own table (and by remoteref()).
objecttable = dict()
# Unbounded queues: SocketIO.localcall() feeds 'QUEUE' requests into
# request_queue, and SocketIO.pollresponse() drains response_queue.
request_queue = Queue.Queue(maxsize=0)
response_queue = Queue.Queue(maxsize=0)
122
123
class SocketIO(object):
    """RPC endpoint that exchanges pickled messages over a socket.

    Every message on the wire is a 4-byte little-endian length followed by
    a pickled ``(seq, payload)`` tuple.  A payload is either a request,
    ``(how, (oid, methodname, args, kwargs))`` with *how* being 'CALL' or
    'QUEUE', or a response, ``(how, what)``.

    The thread that constructs the instance (self.sockthread) is the only
    one that reads from the socket.  Other threads that issue requests
    wait on a per-request condition variable until that thread stores
    their response in self.responses.
    """

    nextseq = 0

    # Fallbacks for the attributes read by debug().  RPCHandler and
    # RPCClient override both; without these a plain SocketIO would fail
    # with AttributeError on its first debug() call.
    debugging = False
    location = "#?"

    def __init__(self, sock, objtable=None, debugging=None):
        """Wrap the connected socket *sock*.

        *objtable* maps object ids to the local objects callable by the
        peer; the module-level objecttable is used when it is None.
        *debugging*, when not None, overrides the class-level flag.
        """
        self.sockthread = threading.currentThread()
        if debugging is not None:
            self.debugging = debugging
        self.sock = sock
        if objtable is None:
            objtable = objecttable
        self.objtable = objtable
        self.responses = {}
        self.cvars = {}

    def close(self):
        "Close the socket, if still open, and forget it."
        sock = self.sock
        self.sock = None
        if sock is not None:
            sock.close()

    def exithook(self):
        "override for specific exit action"
        os._exit(0)

    def debug(self, *args):
        "Write a trace line to sys.__stderr__ when self.debugging is set."
        if not self.debugging:
            return
        parts = [self.location, str(threading.currentThread().getName())]
        parts.extend([str(a) for a in args])
        print>>sys.__stderr__, " ".join(parts)

    def register(self, oid, object):
        "Make *object* callable by the peer under the id *oid*."
        self.objtable[oid] = object

    def unregister(self, oid):
        "Remove *oid* from the object table; unknown ids are ignored."
        try:
            del self.objtable[oid]
        except KeyError:
            pass

    def localcall(self, seq, request):
        """Execute or queue a request received from the peer.

        Returns a ``(how, what)`` response tuple whose *how* is one of
        "OK", "QUEUED", "ERROR" or "EXCEPTION".  SystemExit and
        socket.error raised by the called method propagate; any other
        exception is reported on sys.__stderr__ and turned into
        ("EXCEPTION", None).
        """
        self.debug("localcall:", request)
        try:
            how, (oid, methodname, args, kwargs) = request
        except (TypeError, ValueError):
            # TypeError: a component is not iterable.  ValueError: a tuple
            # has the wrong number of elements.
            return ("ERROR", "Bad request format")
        if oid not in self.objtable:
            return ("ERROR", "Unknown object id: %r" % (oid,))
        obj = self.objtable[oid]
        # Two pseudo-methods let the peer introspect the object.
        if methodname == "__methods__":
            methods = {}
            _getmethods(obj, methods)
            return ("OK", methods)
        if methodname == "__attributes__":
            attributes = {}
            _getattributes(obj, attributes)
            return ("OK", attributes)
        if not hasattr(obj, methodname):
            return ("ERROR", "Unsupported method name: %r" % (methodname,))
        method = getattr(obj, methodname)
        try:
            if how == 'CALL':
                ret = method(*args, **kwargs)
                if isinstance(ret, RemoteObject):
                    # Ship a reference rather than the object itself.
                    ret = remoteref(ret)
                return ("OK", ret)
            elif how == 'QUEUE':
                request_queue.put((seq, (method, args, kwargs)))
                return("QUEUED", None)
            else:
                return ("ERROR", "Unsupported message type: %s" % how)
        except SystemExit:
            raise
        except socket.error:
            raise
        except:
            msg = "*** Internal Error: rpc.py:SocketIO.localcall()\n\n"\
                  " Object: %s \n Method: %s \n Args: %s\n"
            print>>sys.__stderr__, msg % (oid, method, args)
            traceback.print_exc(file=sys.__stderr__)
            return ("EXCEPTION", None)

    def remotecall(self, oid, methodname, args, kwargs):
        "Send a 'CALL' request to the peer and return its decoded result."
        self.debug("remotecall:asynccall: ", oid, methodname)
        seq = self.asynccall(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def remotequeue(self, oid, methodname, args, kwargs):
        "Send a 'QUEUE' request to the peer and return its decoded result."
        self.debug("remotequeue:asyncqueue: ", oid, methodname)
        seq = self.asyncqueue(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def _sendrequest(self, how, label, oid, methodname, args, kwargs):
        "Send one request of type *how*; return its sequence number."
        request = (how, (oid, methodname, args, kwargs))
        seq = self.newseq()
        if threading.currentThread() != self.sockthread:
            # Register the condition variable before sending, so that the
            # socket thread can find it when the response arrives.
            self.cvars[seq] = threading.Condition()
        self.debug(("%s:%d:" % (label, seq)), oid, methodname, args, kwargs)
        self.putmessage((seq, request))
        return seq

    def asynccall(self, oid, methodname, args, kwargs):
        "Send a 'CALL' request without waiting; return its sequence number."
        return self._sendrequest("CALL", "asynccall",
                                 oid, methodname, args, kwargs)

    def asyncqueue(self, oid, methodname, args, kwargs):
        "Send a 'QUEUE' request without waiting; return its sequence number."
        return self._sendrequest("QUEUE", "asyncqueue",
                                 oid, methodname, args, kwargs)

    def asyncreturn(self, seq):
        "Wait for the response to request *seq* and return its decoded value."
        self.debug("asyncreturn:%d:call getresponse(): " % seq)
        response = self.getresponse(seq, wait=0.05)
        self.debug(("asyncreturn:%d:response: " % seq), response)
        return self.decoderesponse(response)

    def decoderesponse(self, response):
        """Turn a ``(how, what)`` response into a return value or exception.

        "OK" returns *what*; "QUEUED" and "EXCEPTION" return None; "EOF"
        calls decode_interrupthook() and returns None if that returns;
        "ERROR" raises RuntimeError; anything else raises SystemError.
        """
        how, what = response
        if how == "OK":
            return what
        if how == "QUEUED":
            return None
        if how == "EXCEPTION":
            self.debug("decoderesponse: EXCEPTION")
            return None
        if how == "EOF":
            self.debug("decoderesponse: EOF")
            self.decode_interrupthook()
            return None
        if how == "ERROR":
            self.debug("decoderesponse: Internal ERROR:", what)
            raise RuntimeError(what)
        raise SystemError(how, what)

    def decode_interrupthook(self):
        "Called when an 'EOF' response is decoded; raises EOFError."
        raise EOFError

    def mainloop(self):
        """Listen on socket until I/O not ready or EOF

        pollresponse() will loop looking for seq number None, which
        never comes, and exit on EOFError.

        """
        try:
            self.getresponse(myseq=None, wait=0.05)
        except EOFError:
            self.debug("mainloop:return")
            return

    def getresponse(self, myseq, wait):
        "Return the response for *myseq*, proxifying the value of an 'OK'."
        response = self._getresponse(myseq, wait)
        if response is not None:
            how, what = response
            if how == "OK":
                response = how, self._proxify(what)
        return response

    def _proxify(self, obj):
        "Replace RemoteProxy markers, also inside lists, by RPCProxy objects."
        if isinstance(obj, RemoteProxy):
            return RPCProxy(self, obj.oid)
        if isinstance(obj, list):
            return [self._proxify(item) for item in obj]
        # XXX Check for other types -- not currently needed
        return obj

    def _getresponse(self, myseq, wait):
        "Obtain the raw response for *myseq*, by polling or by waiting."
        self.debug("_getresponse:myseq:", myseq)
        if threading.currentThread() is self.sockthread:
            # this thread does all reading of requests or responses
            while 1:
                response = self.pollresponse(myseq, wait)
                if response is not None:
                    return response
        else:
            # wait for notification from socket handling thread
            cvar = self.cvars[myseq]
            cvar.acquire()
            try:
                while myseq not in self.responses:
                    cvar.wait()
                response = self.responses[myseq]
                self.debug("_getresponse:%s: thread woke up: response: %s" %
                           (myseq, response))
                del self.responses[myseq]
                del self.cvars[myseq]
            finally:
                # Release the lock even if the bookkeeping above fails.
                cvar.release()
            return response

    def newseq(self):
        "Advance the sequence counter by two and return the new value."
        self.nextseq = seq = self.nextseq + 2
        return seq

    def putmessage(self, message):
        """Pickle *message* and write it, length-prefixed, to the socket.

        Raises IOError when the socket is gone (self.sock is None or
        otherwise unusable).  pickle.PicklingError is reported on
        sys.__stderr__ and re-raised; socket.error propagates unchanged.
        """
        self.debug("putmessage:%d:" % message[0])
        try:
            s = pickle.dumps(message)
        except pickle.PicklingError:
            print >>sys.__stderr__, "Cannot pickle:", repr(message)
            raise
        s = struct.pack("<i", len(s)) + s
        while len(s) > 0:
            try:
                # Block until the socket is writable, then send one slice.
                select.select([], [self.sock], [])
                n = self.sock.send(s[:BUFSIZE])
            except (AttributeError, TypeError):
                raise IOError("socket no longer exists")
            s = s[n:]

    buffer = ""
    bufneed = 4
    bufstate = 0 # meaning: 0 => reading count; 1 => reading data

    def pollpacket(self, wait):
        """Return the next complete packet, or None if none is available.

        Waits at most *wait* seconds for the socket to become readable.
        Raises EOFError when the peer has closed the link or recv() fails.
        """
        self._stage0()
        if len(self.buffer) < self.bufneed:
            r, w, x = select.select([self.sock.fileno()], [], [], wait)
            if not r:
                return None
            try:
                s = self.sock.recv(BUFSIZE)
            except socket.error:
                raise EOFError
            if not s:
                raise EOFError
            self.buffer += s
            self._stage0()
        return self._stage1()

    def _stage0(self):
        "Consume the 4-byte length header once it is fully buffered."
        if self.bufstate == 0 and len(self.buffer) >= 4:
            s = self.buffer[:4]
            self.buffer = self.buffer[4:]
            self.bufneed = struct.unpack("<i", s)[0]
            self.bufstate = 1

    def _stage1(self):
        "Return the packet body once fully buffered, else None."
        if self.bufstate == 1 and len(self.buffer) >= self.bufneed:
            packet = self.buffer[:self.bufneed]
            self.buffer = self.buffer[self.bufneed:]
            self.bufneed = 4
            self.bufstate = 0
            return packet

    def pollmessage(self, wait):
        "Return the next unpickled message, or None if none is available."
        packet = self.pollpacket(wait)
        if packet is None:
            return None
        try:
            # NOTE(security): unpickling executes whatever the peer sent;
            # this is only acceptable because the peer is trusted.
            message = pickle.loads(packet)
        except pickle.UnpicklingError:
            print >>sys.__stderr__, "-----------------------"
            print >>sys.__stderr__, "cannot unpickle packet:", repr(packet)
            traceback.print_stack(file=sys.__stderr__)
            print >>sys.__stderr__, "-----------------------"
            raise
        return message

    def pollresponse(self, myseq, wait):
        """Handle messages received on the socket.

        Some messages received may be asynchronous 'call' or 'queue' requests,
        and some may be responses for other threads.

        'call' requests are passed to self.localcall() with the expectation of
        immediate execution, during which time the socket is not serviced.

        'queue' requests are used for tasks (which may block or hang) to be
        processed in a different thread.  These requests are fed into
        request_queue by self.localcall().  Responses to queued requests are
        taken from response_queue and sent across the link with the associated
        sequence numbers.  Messages in the queues are (sequence_number,
        request/response) tuples and code using this module removing messages
        from the request_queue is responsible for returning the correct
        sequence number in the response_queue.

        pollresponse() will loop until a response message with the myseq
        sequence number is received, and will save other responses in
        self.responses and notify the owning thread.

        Returns None when no message is ready, after EOF handling, or when
        the socket has gone away.

        """
        while 1:
            # send queued response if there is one available
            try:
                qmsg = response_queue.get(0)
            except Queue.Empty:
                pass
            else:
                seq, response = qmsg
                message = (seq, ('OK', response))
                self.putmessage(message)
            # poll for message on link
            try:
                message = self.pollmessage(wait)
                if message is None:  # socket not ready
                    return None
            except EOFError:
                self.handle_EOF()
                return None
            except AttributeError:
                return None
            seq, resq = message
            how = resq[0]
            self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
            # process or queue a request
            if how in ("CALL", "QUEUE"):
                self.debug("pollresponse:%d:localcall:call:" % seq)
                response = self.localcall(seq, resq)
                self.debug("pollresponse:%d:localcall:response:%s"
                           % (seq, response))
                # Only 'CALL' is answered here; a 'QUEUE' request is not
                # acknowledged, its result travels via response_queue.
                if how == "CALL":
                    self.putmessage((seq, response))
                continue
            # return if completed message transaction
            elif seq == myseq:
                return resq
            # must be a response for a different thread:
            else:
                cv = self.cvars.get(seq, None)
                # response involving unknown sequence number is discarded,
                # probably intended for prior incarnation of server
                if cv is not None:
                    cv.acquire()
                    try:
                        self.responses[seq] = resq
                        cv.notify()
                    finally:
                        cv.release()
                continue

    def handle_EOF(self):
        "action taken upon link being closed by peer"
        self.EOFhook()
        self.debug("handle_EOF")
        # Iterate over a snapshot: a woken waiter deletes its own entry
        # from self.cvars, which would break iteration over the live dict.
        for key, cv in list(self.cvars.items()):
            cv.acquire()
            try:
                self.responses[key] = ('EOF', None)
                cv.notify()
            finally:
                cv.release()
        # call our (possibly overridden) exit function
        self.exithook()

    def EOFhook(self):
        "Classes using rpc client/server can override to augment EOF action"
        pass
478
479#----------------- end class SocketIO --------------------
480
class RemoteObject(object):
    """Token mix-in class.

    SocketIO.localcall() converts return values that are instances of
    this class into remote references via remoteref().
    """
484
def remoteref(obj):
    """Register *obj* in the module object table and return a proxy marker.

    The object is stored under its id(), and the returned RemoteProxy
    carries that same id.
    """
    proxy = RemoteProxy(id(obj))
    objecttable[proxy.oid] = obj
    return proxy
489
class RemoteProxy(object):
    """Marker carrying the object id of a remotely held object.

    SocketIO._proxify() replaces instances of this class by an RPCProxy
    for the same oid.
    """

    def __init__(self, oid):
        "Remember the object id *oid*."
        self.oid = oid
494
class RPCHandler(SocketServer.BaseRequestHandler, SocketIO):
    """Request handler that runs the SocketIO main loop for a connection."""

    location = "#S"  # Server
    debugging = False

    def __init__(self, sock, addr, svr):
        "Publish this handler on the server, then start handling."
        svr.current_handler = self ## cgt xxx
        # SocketIO is initialised first, before the base handler is.
        SocketIO.__init__(self, sock)
        SocketServer.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        "handle() method required by SocketServer"
        self.mainloop()

    def get_remote_proxy(self, oid):
        "Return an RPCProxy for the peer's object registered as *oid*."
        return RPCProxy(self, oid)
511
class RPCClient(SocketIO):
    """Listening side of the link: waits for the server to connect.

    SocketIO initialisation is deferred until accept() has obtained a
    connection from LOCALHOST.
    """

    debugging = False
    location = "#C"  # Client

    nextseq = 1 # Requests coming from the client are odd numbered

    def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
        """Create a socket listening on *address* with a backlog of one.

        If bind() or listen() fails the socket is closed before the
        exception propagates, so that no descriptor is leaked.
        """
        self.listening_sock = socket.socket(family, type)
        try:
            self.listening_sock.bind(address)
            self.listening_sock.listen(1)
        except:
            self.listening_sock.close()
            raise

    def accept(self):
        """Accept one connection and initialise SocketIO on it.

        A connection whose peer address is not LOCALHOST is reported on
        sys.__stderr__, closed, and rejected by raising socket.error.
        """
        working_sock, address = self.listening_sock.accept()
        if self.debugging:
            print>>sys.__stderr__, "****** Connection request from ", address
        if address[0] != LOCALHOST:
            print>>sys.__stderr__, "** Invalid host: ", address
            # Do not leave the rejected connection open.
            working_sock.close()
            raise socket.error
        SocketIO.__init__(self, working_sock)

    def get_remote_proxy(self, oid):
        "Return an RPCProxy for the peer's object registered as *oid*."
        return RPCProxy(self, oid)
536
class RPCProxy(object):
    """Local stand-in for an object held by the peer.

    The remote object's method and attribute names are fetched on first
    use and then kept on the instance.
    """

    __methods = None
    __attributes = None

    def __init__(self, sockio, oid):
        "Bind the proxy to the link *sockio* and the remote object id *oid*."
        self.sockio = sockio
        self.oid = oid

    def __getattr__(self, name):
        """Resolve *name* against the remote object.

        Method names yield a MethodProxy, attribute names are fetched
        from the peer, and unknown names raise AttributeError.
        """
        if self.__methods is None:
            self.__getmethods()
        if self.__methods.get(name):
            return MethodProxy(self.sockio, self.oid, name)
        if self.__attributes is None:
            self.__getattributes()
        if name not in self.__attributes:
            raise AttributeError(name)
        return self.sockio.remotecall(self.oid, '__getattribute__',
                                      (name,), {})

    def __getattributes(self):
        "Fetch and store the remote object's attribute names."
        names = self.sockio.remotecall(self.oid, "__attributes__", (), {})
        self.__attributes = names

    def __getmethods(self):
        "Fetch and store the remote object's method names."
        names = self.sockio.remotecall(self.oid, "__methods__", (), {})
        self.__methods = names
567
568def _getmethods(obj, methods):
569    # Helper to get a list of methods from an object
570    # Adds names to dictionary argument 'methods'
571    for name in dir(obj):
572        attr = getattr(obj, name)
573        if hasattr(attr, '__call__'):
574            methods[name] = 1
575    if type(obj) == types.InstanceType:
576        _getmethods(obj.__class__, methods)
577    if type(obj) == types.ClassType:
578        for super in obj.__bases__:
579            _getmethods(super, methods)
580
581def _getattributes(obj, attributes):
582    for name in dir(obj):
583        attr = getattr(obj, name)
584        if not hasattr(attr, '__call__'):
585            attributes[name] = 1
586
class MethodProxy(object):
    """Callable that forwards its arguments to one method of a remote object."""

    def __init__(self, sockio, oid, name):
        "Bind to method *name* of remote object *oid* on the link *sockio*."
        self.sockio = sockio
        self.oid = oid
        self.name = name

    def __call__(self, *args, **kwargs):
        "Perform the remote call and return its result."
        return self.sockio.remotecall(self.oid, self.name, args, kwargs)
597
598
599# XXX KBK 09Sep03  We need a proper unit test for this module.  Previously
600#                  existing test code was removed at Rev 1.27 (r34098).
601