1# Copyright (c) 2001-2006 Twisted Matrix Laboratories.
2#
3# Permission is hereby granted, free of charge, to any person obtaining
4# a copy of this software and associated documentation files (the
5# "Software"), to deal in the Software without restriction, including
6# without limitation the rights to use, copy, modify, merge, publish,
7# distribute, sublicense, and/or sell copies of the Software, and to
8# permit persons to whom the Software is furnished to do so, subject to
9# the following conditions:
10#
11# The above copyright notice and this permission notice shall be
12# included in all copies or substantial portions of the Software.
13#
14# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21"""
22Tests for epoll wrapper.
23"""
24import socket
25import errno
26import time
27import select
28import unittest
29
30from test import test_support
# Skip the whole module when the select module has no epoll support.
if not hasattr(select, "epoll"):
    raise unittest.SkipTest("test works only on Linux 2.6")

# select.epoll can exist yet fail with ENOSYS when the running kernel lacks
# the syscall.  Probe once, and close the probe object straight away so that
# importing this module does not leak an epoll descriptor.
try:
    select.epoll().close()
except IOError as e:
    if e.errno == errno.ENOSYS:
        raise unittest.SkipTest("kernel doesn't support epoll()")
    raise
40
41class TestEPoll(unittest.TestCase):
42
43    def setUp(self):
44        self.serverSocket = socket.socket()
45        self.serverSocket.bind(('127.0.0.1', 0))
46        self.serverSocket.listen(1)
47        self.connections = [self.serverSocket]
48
49
50    def tearDown(self):
51        for skt in self.connections:
52            skt.close()
53
54    def _connected_pair(self):
55        client = socket.socket()
56        client.setblocking(False)
57        try:
58            client.connect(('127.0.0.1', self.serverSocket.getsockname()[1]))
59        except socket.error, e:
60            self.assertEqual(e.args[0], errno.EINPROGRESS)
61        else:
62            raise AssertionError("Connect should have raised EINPROGRESS")
63        server, addr = self.serverSocket.accept()
64
65        self.connections.extend((client, server))
66        return client, server
67
68    def test_create(self):
69        try:
70            ep = select.epoll(16)
71        except OSError, e:
72            raise AssertionError(str(e))
73        self.assertTrue(ep.fileno() > 0, ep.fileno())
74        self.assertTrue(not ep.closed)
75        ep.close()
76        self.assertTrue(ep.closed)
77        self.assertRaises(ValueError, ep.fileno)
78
79    def test_badcreate(self):
80        self.assertRaises(TypeError, select.epoll, 1, 2, 3)
81        self.assertRaises(TypeError, select.epoll, 'foo')
82        self.assertRaises(TypeError, select.epoll, None)
83        self.assertRaises(TypeError, select.epoll, ())
84        self.assertRaises(TypeError, select.epoll, ['foo'])
85        self.assertRaises(TypeError, select.epoll, {})
86
87    def test_add(self):
88        server, client = self._connected_pair()
89
90        ep = select.epoll(2)
91        try:
92            ep.register(server.fileno(), select.EPOLLIN | select.EPOLLOUT)
93            ep.register(client.fileno(), select.EPOLLIN | select.EPOLLOUT)
94        finally:
95            ep.close()
96
97        # adding by object w/ fileno works, too.
98        ep = select.epoll(2)
99        try:
100            ep.register(server, select.EPOLLIN | select.EPOLLOUT)
101            ep.register(client, select.EPOLLIN | select.EPOLLOUT)
102        finally:
103            ep.close()
104
105        ep = select.epoll(2)
106        try:
107            # TypeError: argument must be an int, or have a fileno() method.
108            self.assertRaises(TypeError, ep.register, object(),
109                select.EPOLLIN | select.EPOLLOUT)
110            self.assertRaises(TypeError, ep.register, None,
111                select.EPOLLIN | select.EPOLLOUT)
112            # ValueError: file descriptor cannot be a negative integer (-1)
113            self.assertRaises(ValueError, ep.register, -1,
114                select.EPOLLIN | select.EPOLLOUT)
115            # IOError: [Errno 9] Bad file descriptor
116            self.assertRaises(IOError, ep.register, 10000,
117                select.EPOLLIN | select.EPOLLOUT)
118            # registering twice also raises an exception
119            ep.register(server, select.EPOLLIN | select.EPOLLOUT)
120            self.assertRaises(IOError, ep.register, server,
121                select.EPOLLIN | select.EPOLLOUT)
122        finally:
123            ep.close()
124
125    def test_fromfd(self):
126        server, client = self._connected_pair()
127
128        ep = select.epoll(2)
129        ep2 = select.epoll.fromfd(ep.fileno())
130
131        ep2.register(server.fileno(), select.EPOLLIN | select.EPOLLOUT)
132        ep2.register(client.fileno(), select.EPOLLIN | select.EPOLLOUT)
133
134        events = ep.poll(1, 4)
135        events2 = ep2.poll(0.9, 4)
136        self.assertEqual(len(events), 2)
137        self.assertEqual(len(events2), 2)
138
139        ep.close()
140        try:
141            ep2.poll(1, 4)
142        except IOError, e:
143            self.assertEqual(e.args[0], errno.EBADF, e)
144        else:
145            self.fail("epoll on closed fd didn't raise EBADF")
146
147    def test_control_and_wait(self):
148        client, server = self._connected_pair()
149
150        ep = select.epoll(16)
151        ep.register(server.fileno(),
152                   select.EPOLLIN | select.EPOLLOUT | select.EPOLLET)
153        ep.register(client.fileno(),
154                   select.EPOLLIN | select.EPOLLOUT | select.EPOLLET)
155
156        now = time.time()
157        events = ep.poll(1, 4)
158        then = time.time()
159        self.assertFalse(then - now > 0.1, then - now)
160
161        events.sort()
162        expected = [(client.fileno(), select.EPOLLOUT),
163                    (server.fileno(), select.EPOLLOUT)]
164        expected.sort()
165
166        self.assertEqual(events, expected)
167        self.assertFalse(then - now > 0.01, then - now)
168
169        now = time.time()
170        events = ep.poll(timeout=2.1, maxevents=4)
171        then = time.time()
172        self.assertFalse(events)
173
174        client.send("Hello!")
175        server.send("world!!!")
176
177        now = time.time()
178        events = ep.poll(1, 4)
179        then = time.time()
180        self.assertFalse(then - now > 0.01)
181
182        events.sort()
183        expected = [(client.fileno(), select.EPOLLIN | select.EPOLLOUT),
184                    (server.fileno(), select.EPOLLIN | select.EPOLLOUT)]
185        expected.sort()
186
187        self.assertEqual(events, expected)
188
189        ep.unregister(client.fileno())
190        ep.modify(server.fileno(), select.EPOLLOUT)
191        now = time.time()
192        events = ep.poll(1, 4)
193        then = time.time()
194        self.assertFalse(then - now > 0.01)
195
196        expected = [(server.fileno(), select.EPOLLOUT)]
197        self.assertEqual(events, expected)
198
199    def test_errors(self):
200        self.assertRaises(ValueError, select.epoll, -2)
201        self.assertRaises(ValueError, select.epoll().register, -1,
202                          select.EPOLLIN)
203
204    def test_unregister_closed(self):
205        server, client = self._connected_pair()
206        fd = server.fileno()
207        ep = select.epoll(16)
208        ep.register(server)
209
210        now = time.time()
211        events = ep.poll(1, 4)
212        then = time.time()
213        self.assertFalse(then - now > 0.01)
214
215        server.close()
216        ep.unregister(fd)
217
def test_main():
    """Run the TestEPoll case through test_support.run_unittest."""
    cases = (TestEPoll,)
    test_support.run_unittest(*cases)


if "__main__" == __name__:
    # Run the suite when this file is executed as a script.
    test_main()
223