1#!/usr/bin/python
2"""
3Client for file transfer services offered by RSS (Remote Shell Server).
4
5@author: Michael Goldish (mgoldish@redhat.com)
6@copyright: 2008-2010 Red Hat Inc.
7"""
8
9import socket, struct, time, sys, os, glob
10
# Globals
# Size in bytes of each file-data packet.  It is also sent to the server
# during the handshake in FileTransferClient.__init__.
CHUNKSIZE = 64 * 1024

# Protocol message constants
# The magic number spells "RSS" in ASCII (0x52 = 'R', 0x53 = 'S').
RSS_MAGIC = 0x525353
# Message codes exchanged after the handshake.
RSS_OK = 1
RSS_ERROR = 2
RSS_UPLOAD = 3
RSS_DOWNLOAD = 4
RSS_SET_PATH = 5
RSS_CREATE_FILE = 6
RSS_CREATE_DIR = 7
RSS_LEAVE_DIR = 8
RSS_DONE = 9

# See rss.cpp for protocol details.
27
28
class FileTransferError(Exception):
    """
    Base class for all errors raised by the file transfer client.

    @param msg: Human-readable description of the failure
    @param e: Optional underlying error or detail (e.g. a socket.error)
    @param filename: Optional name of the file involved in the failure
    """

    def __init__(self, msg, e=None, filename=None):
        Exception.__init__(self, msg, e, filename)
        self.msg = msg
        self.e = e
        self.filename = filename

    def __str__(self):
        # Nothing to append: the message stands on its own.
        if not (self.e or self.filename):
            return self.msg
        if not self.filename:
            suffix = "    (%s)" % self.e
        elif not self.e:
            suffix = "    (filename: %s)" % self.filename
        else:
            suffix = "    (error: %s,    filename: %s)" % (self.e,
                                                             self.filename)
        return self.msg + suffix
45
46
class FileTransferConnectError(FileTransferError):
    """
    Raised when the client cannot connect to the server, or when the initial
    magic-number handshake times out or receives a wrong value.
    """
49
50
class FileTransferTimeoutError(FileTransferError):
    """
    Raised when a send or receive operation does not complete before its
    timeout expires.
    """
53
54
class FileTransferProtocolError(FileTransferError):
    """
    Raised when the server closes the connection unexpectedly or sends a
    message the client does not expect at that point.
    """
57
58
class FileTransferSocketError(FileTransferError):
    """
    Raised when a socket error (other than a timeout) occurs while sending
    data to or receiving data from the server.
    """
61
62
class FileTransferServerError(FileTransferError):
    """
    Raised when the server reports a failure with an RSS_ERROR message.

    @param errmsg: The error text received from the server; it is stored as
            the underlying error (self.e) and there is no client-side msg.
    """

    def __init__(self, errmsg):
        FileTransferError.__init__(self, None, errmsg)

    def __str__(self):
        text = "Server said: %r" % self.e
        if not self.filename:
            return text
        return text + "    (filename: %s)" % self.filename
72
73
class FileTransferNotFoundError(FileTransferError):
    """
    Raised when a source pattern does not match any files or directories
    that could be transferred.
    """
76
77
78class FileTransferClient(object):
79    """
80    Connect to a RSS (remote shell server) and transfer files.
81    """
82
83    def __init__(self, address, port, log_func=None, timeout=20):
84        """
85        Connect to a server.
86
87        @param address: The server's address
88        @param port: The server's port
89        @param log_func: If provided, transfer stats will be passed to this
90                function during the transfer
91        @param timeout: Time duration to wait for connection to succeed
92        @raise FileTransferConnectError: Raised if the connection fails
93        """
94        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
95        self._socket.settimeout(timeout)
96        try:
97            self._socket.connect((address, port))
98        except socket.error, e:
99            raise FileTransferConnectError("Cannot connect to server at "
100                                           "%s:%s" % (address, port), e)
101        try:
102            if self._receive_msg(timeout) != RSS_MAGIC:
103                raise FileTransferConnectError("Received wrong magic number")
104        except FileTransferTimeoutError:
105            raise FileTransferConnectError("Timeout expired while waiting to "
106                                           "receive magic number")
107        self._send(struct.pack("=i", CHUNKSIZE))
108        self._log_func = log_func
109        self._last_time = time.time()
110        self._last_transferred = 0
111        self.transferred = 0
112
113
114    def __del__(self):
115        self.close()
116
117
118    def close(self):
119        """
120        Close the connection.
121        """
122        self._socket.close()
123
124
125    def _send(self, str, timeout=60):
126        try:
127            if timeout <= 0:
128                raise socket.timeout
129            self._socket.settimeout(timeout)
130            self._socket.sendall(str)
131        except socket.timeout:
132            raise FileTransferTimeoutError("Timeout expired while sending "
133                                           "data to server")
134        except socket.error, e:
135            raise FileTransferSocketError("Could not send data to server", e)
136
137
138    def _receive(self, size, timeout=60):
139        strs = []
140        end_time = time.time() + timeout
141        try:
142            while size > 0:
143                timeout = end_time - time.time()
144                if timeout <= 0:
145                    raise socket.timeout
146                self._socket.settimeout(timeout)
147                data = self._socket.recv(size)
148                if not data:
149                    raise FileTransferProtocolError("Connection closed "
150                                                    "unexpectedly while "
151                                                    "receiving data from "
152                                                    "server")
153                strs.append(data)
154                size -= len(data)
155        except socket.timeout:
156            raise FileTransferTimeoutError("Timeout expired while receiving "
157                                           "data from server")
158        except socket.error, e:
159            raise FileTransferSocketError("Error receiving data from server",
160                                          e)
161        return "".join(strs)
162
163
164    def _report_stats(self, str):
165        if self._log_func:
166            dt = time.time() - self._last_time
167            if dt >= 1:
168                transferred = self.transferred / 1048576.
169                speed = (self.transferred - self._last_transferred) / dt
170                speed /= 1048576.
171                self._log_func("%s %.3f MB (%.3f MB/sec)" %
172                               (str, transferred, speed))
173                self._last_time = time.time()
174                self._last_transferred = self.transferred
175
176
177    def _send_packet(self, str, timeout=60):
178        self._send(struct.pack("=I", len(str)))
179        self._send(str, timeout)
180        self.transferred += len(str) + 4
181        self._report_stats("Sent")
182
183
184    def _receive_packet(self, timeout=60):
185        size = struct.unpack("=I", self._receive(4))[0]
186        str = self._receive(size, timeout)
187        self.transferred += len(str) + 4
188        self._report_stats("Received")
189        return str
190
191
192    def _send_file_chunks(self, filename, timeout=60):
193        if self._log_func:
194            self._log_func("Sending file %s" % filename)
195        f = open(filename, "rb")
196        try:
197            try:
198                end_time = time.time() + timeout
199                while True:
200                    data = f.read(CHUNKSIZE)
201                    self._send_packet(data, end_time - time.time())
202                    if len(data) < CHUNKSIZE:
203                        break
204            except FileTransferError, e:
205                e.filename = filename
206                raise
207        finally:
208            f.close()
209
210
211    def _receive_file_chunks(self, filename, timeout=60):
212        if self._log_func:
213            self._log_func("Receiving file %s" % filename)
214        f = open(filename, "wb")
215        try:
216            try:
217                end_time = time.time() + timeout
218                while True:
219                    data = self._receive_packet(end_time - time.time())
220                    f.write(data)
221                    if len(data) < CHUNKSIZE:
222                        break
223            except FileTransferError, e:
224                e.filename = filename
225                raise
226        finally:
227            f.close()
228
229
230    def _send_msg(self, msg, timeout=60):
231        self._send(struct.pack("=I", msg))
232
233
234    def _receive_msg(self, timeout=60):
235        s = self._receive(4, timeout)
236        return struct.unpack("=I", s)[0]
237
238
239    def _handle_transfer_error(self):
240        # Save original exception
241        e = sys.exc_info()
242        try:
243            # See if we can get an error message
244            msg = self._receive_msg()
245        except FileTransferError:
246            # No error message -- re-raise original exception
247            raise e[0], e[1], e[2]
248        if msg == RSS_ERROR:
249            errmsg = self._receive_packet()
250            raise FileTransferServerError(errmsg)
251        raise e[0], e[1], e[2]
252
253
class FileUploadClient(FileTransferClient):
    """
    Connect to a RSS (remote shell server) and upload files or directory trees.
    """

    def __init__(self, address, port, log_func=None, timeout=20):
        """
        Connect to a server and announce an upload session (RSS_UPLOAD).

        @param address: The server's address
        @param port: The server's port
        @param log_func: If provided, transfer stats will be passed to this
                function during the transfer
        @param timeout: Time duration to wait for connection to succeed
        @raise FileTransferConnectError: Raised if the connection fails or if
                an incorrect magic number is received
        @raise FileTransferSocketError: Raised if the RSS_UPLOAD message cannot
                be sent to the server
        @raise FileTransferTimeoutError: Raised if sending the RSS_UPLOAD
                message times out
        """
        super(FileUploadClient, self).__init__(address, port, log_func, timeout)
        self._send_msg(RSS_UPLOAD)


    def _upload_file(self, path, end_time):
        """
        Recursively send a single file or directory tree.

        A file is sent as RSS_CREATE_FILE, its basename, then its contents.
        A directory is sent as RSS_CREATE_DIR, its basename, its entries,
        then RSS_LEAVE_DIR.  Paths that are neither (per os.path.isfile /
        os.path.isdir) are skipped.

        @param path: Absolute local path to send
        @param end_time: Absolute time.time() deadline for the whole transfer;
                every send, including small control messages, uses the time
                remaining until it
        """
        if os.path.isfile(path):
            self._send_msg(RSS_CREATE_FILE, end_time - time.time())
            self._send_packet(os.path.basename(path), end_time - time.time())
            self._send_file_chunks(path, end_time - time.time())
        elif os.path.isdir(path):
            self._send_msg(RSS_CREATE_DIR, end_time - time.time())
            self._send_packet(os.path.basename(path), end_time - time.time())
            for filename in os.listdir(path):
                self._upload_file(os.path.join(path, filename), end_time)
            self._send_msg(RSS_LEAVE_DIR, end_time - time.time())


    def upload(self, src_pattern, dst_path, timeout=600):
        """
        Send files or directory trees to the server.
        The semantics of src_pattern and dst_path are similar to those of scp.
        For example, the following are OK:
            src_pattern='/tmp/foo.txt', dst_path='C:\\'
                (uploads a single file)
            src_pattern='/usr/', dst_path='C:\\Windows\\'
                (uploads a directory tree recursively)
            src_pattern='/usr/*', dst_path='C:\\Windows\\'
                (uploads all files and directory trees under /usr/)
        The following is not OK:
            src_pattern='/tmp/foo.txt', dst_path='C:\\Windows\\*'
                (wildcards are only allowed in src_pattern)

        @param src_pattern: A path or wildcard pattern specifying the files or
                directories to send to the server
        @param dst_path: A path in the server's filesystem where the files will
                be saved
        @param timeout: Time duration in seconds to wait for the transfer to
                complete; it bounds every send and receive of the transfer
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferServerError: Raised if something goes wrong and the
                server sends an informative error message to the client
        @raise FileTransferNotFoundError: Raised if src_pattern matches nothing
        @note: Other exceptions can be raised.  On any failure the connection
                is closed.
        """
        end_time = time.time() + timeout
        try:
            try:
                self._send_msg(RSS_SET_PATH, end_time - time.time())
                self._send_packet(dst_path, end_time - time.time())
                matches = glob.glob(src_pattern)
                for filename in matches:
                    self._upload_file(os.path.abspath(filename), end_time)
                self._send_msg(RSS_DONE, end_time - time.time())
            except FileTransferTimeoutError:
                # Out of time: don't wait any longer for an error message
                raise
            except FileTransferError:
                self._handle_transfer_error()
            else:
                # If nothing was transferred, raise an exception
                if not matches:
                    raise FileTransferNotFoundError("Pattern %s does not "
                                                    "match any files or "
                                                    "directories" %
                                                    src_pattern)
                # Look for RSS_OK or RSS_ERROR
                msg = self._receive_msg(end_time - time.time())
                if msg == RSS_OK:
                    return
                elif msg == RSS_ERROR:
                    errmsg = self._receive_packet()
                    raise FileTransferServerError(errmsg)
                else:
                    # Neither RSS_OK nor RSS_ERROR found
                    raise FileTransferProtocolError("Received unexpected msg")
        except:
            # In any case, if the transfer failed, close the connection
            self.close()
            raise
351
352
class FileDownloadClient(FileTransferClient):
    """
    Connect to a RSS (remote shell server) and download files or directory trees.
    """

    def __init__(self, address, port, log_func=None, timeout=20):
        """
        Connect to a server and announce a download session (RSS_DOWNLOAD).

        @param address: The server's address
        @param port: The server's port
        @param log_func: If provided, transfer stats will be passed to this
                function during the transfer
        @param timeout: Time duration to wait for connection to succeed
        @raise FileTransferConnectError: Raised if the connection fails or if
                an incorrect magic number is received
        @raise FileTransferSocketError: Raised if the RSS_DOWNLOAD message
                cannot be sent to the server
        @raise FileTransferTimeoutError: Raised if sending the RSS_DOWNLOAD
                message times out
        """
        super(FileDownloadClient, self).__init__(address, port, log_func, timeout)
        self._send_msg(RSS_DOWNLOAD)


    def download(self, src_pattern, dst_path, timeout=600):
        """
        Receive files or directory trees from the server.
        The semantics of src_pattern and dst_path are similar to those of scp.
        For example, the following are OK:
            src_pattern='C:\\foo.txt', dst_path='/tmp'
                (downloads a single file)
            src_pattern='C:\\Windows', dst_path='/tmp'
                (downloads a directory tree recursively)
            src_pattern='C:\\Windows\\*', dst_path='/tmp'
                (downloads all files and directory trees under C:\\Windows)
        The following is not OK:
            src_pattern='C:\\Windows', dst_path='/tmp/*'
                (wildcards are only allowed in src_pattern)

        @param src_pattern: A path or wildcard pattern specifying the files or
                directories, in the server's filesystem, that will be sent to
                the client
        @param dst_path: A path in the local filesystem where the files will
                be saved
        @param timeout: Time duration in seconds to wait for the transfer to
                complete; it bounds every send and receive of the transfer
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferServerError: Raised if something goes wrong and the
                server sends an informative error message to the client
        @raise FileTransferNotFoundError: Raised if nothing was downloaded
        @note: Other exceptions can be raised.  On any failure the connection
                is closed.
        """
        dst_path = os.path.abspath(dst_path)
        end_time = time.time() + timeout
        file_count = 0
        dir_count = 0
        try:
            try:
                self._send_msg(RSS_SET_PATH, end_time - time.time())
                self._send_packet(src_pattern, end_time - time.time())
            except FileTransferTimeoutError:
                # Out of time: don't wait any longer for an error message
                # (same policy as FileUploadClient.upload)
                raise
            except FileTransferError:
                self._handle_transfer_error()
            while True:
                msg = self._receive_msg(end_time - time.time())
                if msg == RSS_CREATE_FILE:
                    # Receive filename and file contents.
                    # NOTE(review): the filename comes from the server and is
                    # joined to the local path unchecked; an absolute name or
                    # one containing ".." could write outside dst_path.
                    filename = self._receive_packet(end_time - time.time())
                    if os.path.isdir(dst_path):
                        dst_path = os.path.join(dst_path, filename)
                    self._receive_file_chunks(dst_path, end_time - time.time())
                    dst_path = os.path.dirname(dst_path)
                    file_count += 1
                elif msg == RSS_CREATE_DIR:
                    # Receive dirname and create the directory
                    dirname = self._receive_packet(end_time - time.time())
                    if os.path.isdir(dst_path):
                        dst_path = os.path.join(dst_path, dirname)
                    if not os.path.isdir(dst_path):
                        os.mkdir(dst_path)
                    dir_count += 1
                elif msg == RSS_LEAVE_DIR:
                    # Return to parent dir
                    dst_path = os.path.dirname(dst_path)
                elif msg == RSS_DONE:
                    # Transfer complete
                    if not file_count and not dir_count:
                        raise FileTransferNotFoundError("Pattern %s does not "
                                                        "match any files or "
                                                        "directories that "
                                                        "could be downloaded" %
                                                        src_pattern)
                    break
                elif msg == RSS_ERROR:
                    # Receive error message and abort
                    errmsg = self._receive_packet()
                    raise FileTransferServerError(errmsg)
                else:
                    # Unexpected msg
                    raise FileTransferProtocolError("Received unexpected msg")
        except:
            # In any case, if the transfer failed, close the connection
            self.close()
            raise
455
456
def upload(address, port, src_pattern, dst_path, log_func=None, timeout=60,
           connect_timeout=20):
    """
    Connect to server and upload files.

    Convenience wrapper that opens a FileUploadClient, performs one upload
    and closes the connection.

    @param address: The server's address
    @param port: The server's port
    @param src_pattern: Local path or wildcard pattern to upload
    @param dst_path: Destination path in the server's filesystem
    @param log_func: If provided, transfer stats are passed to this function
    @param timeout: Time duration in seconds allowed for the transfer
    @param connect_timeout: Time duration to wait for the connection
    @see: FileUploadClient
    """
    uploader = FileUploadClient(address, port, log_func=log_func,
                                timeout=connect_timeout)
    uploader.upload(src_pattern, dst_path, timeout=timeout)
    uploader.close()
467
468
def download(address, port, src_pattern, dst_path, log_func=None, timeout=60,
             connect_timeout=20):
    """
    Connect to server and download files.

    Convenience wrapper that opens a FileDownloadClient, performs one
    download and closes the connection.

    @param address: The server's address
    @param port: The server's port
    @param src_pattern: Path or wildcard pattern in the server's filesystem
    @param dst_path: Local path where the files will be saved
    @param log_func: If provided, transfer stats are passed to this function
    @param timeout: Time duration in seconds allowed for the transfer
    @param connect_timeout: Time duration to wait for the connection
    @see: FileDownloadClient
    """
    client = FileDownloadClient(address, port, log_func, connect_timeout)
    client.download(src_pattern, dst_path, timeout)
    client.close()
479
480
481def main():
482    import optparse
483
484    usage = "usage: %prog [options] address port src_pattern dst_path"
485    parser = optparse.OptionParser(usage=usage)
486    parser.add_option("-d", "--download",
487                      action="store_true", dest="download",
488                      help="download files from server")
489    parser.add_option("-u", "--upload",
490                      action="store_true", dest="upload",
491                      help="upload files to server")
492    parser.add_option("-v", "--verbose",
493                      action="store_true", dest="verbose",
494                      help="be verbose")
495    parser.add_option("-t", "--timeout",
496                      type="int", dest="timeout", default=3600,
497                      help="transfer timeout")
498    options, args = parser.parse_args()
499    if options.download == options.upload:
500        parser.error("you must specify either -d or -u")
501    if len(args) != 4:
502        parser.error("incorrect number of arguments")
503    address, port, src_pattern, dst_path = args
504    port = int(port)
505
506    logger = None
507    if options.verbose:
508        def p(s):
509            print s
510        logger = p
511
512    if options.download:
513        download(address, port, src_pattern, dst_path, logger, options.timeout)
514    elif options.upload:
515        upload(address, port, src_pattern, dst_path, logger, options.timeout)
516
517
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not
    # when imported as a module.
    main()
520