1# Copyright (c) 2001-2006 Twisted Matrix Laboratories.
2#
3# Permission is hereby granted, free of charge, to any person obtaining
4# a copy of this software and associated documentation files (the
5# "Software"), to deal in the Software without restriction, including
6# without limitation the rights to use, copy, modify, merge, publish,
7# distribute, sublicense, and/or sell copies of the Software, and to
8# permit persons to whom the Software is furnished to do so, subject to
9# the following conditions:
10#
11# The above copyright notice and this permission notice shall be
12# included in all copies or substantial portions of the Software.
13#
14# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21"""
22Tests for epoll wrapper.
23"""
24import socket
25import errno
26import time
27import select
28import unittest
29
30from test import test_support
# epoll is only provided on Linux (2.6+); skip the whole module when the
# select module does not expose it.
if not hasattr(select, "epoll"):
    raise unittest.SkipTest("test works only on Linux 2.6")

# select.epoll can exist while the running kernel still rejects the call
# with ENOSYS; skip in that case too.  Any other error is unexpected and is
# re-raised.  The probe object is closed right away so it does not linger.
try:
    select.epoll().close()
except IOError as e:
    if e.errno == errno.ENOSYS:
        raise unittest.SkipTest("kernel doesn't support epoll()")
    raise
40
class TestEPoll(unittest.TestCase):
    """Tests for ``select.epoll`` driven by a loopback TCP connection.

    ``setUp`` opens a listening socket on 127.0.0.1; every socket created
    during a test is recorded in ``self.connections`` and closed again in
    ``tearDown``.
    """

    def setUp(self):
        """Create the listening socket and the list of sockets to close."""
        self.serverSocket = socket.socket()
        # Port 0: the port actually bound is read back with getsockname()
        # when a client connects (see _connected_pair).
        self.serverSocket.bind(('127.0.0.1', 0))
        self.serverSocket.listen(1)
        self.connections = [self.serverSocket]

    def tearDown(self):
        """Close every socket recorded in ``self.connections``."""
        for skt in self.connections:
            skt.close()

    def _connected_pair(self):
        """Return ``(client, server)``: two connected TCP sockets.

        ``client`` is non-blocking and ``server`` is the socket returned by
        ``accept()`` on the listening socket.  Both are tracked for closing
        in ``tearDown``.
        """
        client = socket.socket()
        # Track the client immediately so it is closed even when one of the
        # steps below fails.
        self.connections.append(client)
        client.setblocking(False)
        try:
            client.connect(('127.0.0.1', self.serverSocket.getsockname()[1]))
        except socket.error as e:
            self.assertEqual(e.args[0], errno.EINPROGRESS)
        else:
            raise AssertionError("Connect should have raised EINPROGRESS")
        server, _ = self.serverSocket.accept()
        self.connections.append(server)
        return client, server

    def _make_epoll(self, *args):
        """Create ``select.epoll(*args)`` and close it when the test ends."""
        ep = select.epoll(*args)
        self.addCleanup(ep.close)
        return ep

    def _poll_within(self, ep, limit):
        """Poll *ep* and fail if the call took longer than *limit* seconds.

        Uses a 1 second timeout and at most 4 events; returns the events.
        """
        started = time.time()
        events = ep.poll(1, 4)
        elapsed = time.time() - started
        self.assertFalse(elapsed > limit, elapsed)
        return events

    def test_create(self):
        """An epoll object has a valid fileno until it is closed."""
        try:
            ep = select.epoll(16)
        except (IOError, OSError) as e:
            # Report a creation error as a test failure.
            raise AssertionError(str(e))
        # Make sure the descriptor is released even if an assertion below
        # fails before the explicit close().
        self.addCleanup(ep.close)
        self.assertTrue(ep.fileno() > 0, ep.fileno())
        self.assertTrue(not ep.closed)
        ep.close()
        self.assertTrue(ep.closed)
        self.assertRaises(ValueError, ep.fileno)

    def test_badcreate(self):
        """Constructor arguments of the wrong type or count are rejected."""
        self.assertRaises(TypeError, select.epoll, 1, 2, 3)
        self.assertRaises(TypeError, select.epoll, 'foo')
        self.assertRaises(TypeError, select.epoll, None)
        self.assertRaises(TypeError, select.epoll, ())
        self.assertRaises(TypeError, select.epoll, ['foo'])
        self.assertRaises(TypeError, select.epoll, {})

    def test_add(self):
        """register() takes fds or objects with fileno(); bad input raises."""
        client, server = self._connected_pair()
        rw = select.EPOLLIN | select.EPOLLOUT

        ep = select.epoll(2)
        try:
            ep.register(server.fileno(), rw)
            ep.register(client.fileno(), rw)
        finally:
            ep.close()

        # adding by object w/ fileno works, too.
        ep = select.epoll(2)
        try:
            ep.register(server, rw)
            ep.register(client, rw)
        finally:
            ep.close()

        ep = select.epoll(2)
        try:
            # TypeError: argument must be an int, or have a fileno() method.
            self.assertRaises(TypeError, ep.register, object(), rw)
            self.assertRaises(TypeError, ep.register, None, rw)
            # ValueError: file descriptor cannot be a negative integer (-1)
            self.assertRaises(ValueError, ep.register, -1, rw)
            # IOError: [Errno 9] Bad file descriptor
            self.assertRaises(IOError, ep.register, 10000, rw)
            # registering twice also raises an exception
            ep.register(server, rw)
            self.assertRaises(IOError, ep.register, server, rw)
        finally:
            ep.close()

    def test_fromfd(self):
        """An object made by epoll.fromfd() uses the original's descriptor."""
        client, server = self._connected_pair()
        rw = select.EPOLLIN | select.EPOLLOUT

        ep = select.epoll(2)
        try:
            ep2 = select.epoll.fromfd(ep.fileno())

            ep2.register(server.fileno(), rw)
            ep2.register(client.fileno(), rw)

            # Registrations made through ep2 are reported by both objects.
            events = ep.poll(1, 4)
            events2 = ep2.poll(0.9, 4)
            self.assertEqual(len(events), 2)
            self.assertEqual(len(events2), 2)
        finally:
            ep.close()

        # ep2 was built from ep's descriptor, which is closed now.
        try:
            ep2.poll(1, 4)
        except IOError as e:
            self.assertEqual(e.args[0], errno.EBADF, e)
        else:
            self.fail("epoll on closed fd didn't raise EBADF")

    def test_control_and_wait(self):
        """Edge-triggered register/modify/unregister report expected events."""
        client, server = self._connected_pair()
        edge_rw = select.EPOLLIN | select.EPOLLOUT | select.EPOLLET

        ep = self._make_epoll(16)
        ep.register(server.fileno(), edge_rw)
        ep.register(client.fileno(), edge_rw)

        # Both sockets are expected to be reported writable right away.
        events = self._poll_within(ep, 0.1)
        expected = [(client.fileno(), select.EPOLLOUT),
                    (server.fileno(), select.EPOLLOUT)]
        self.assertEqual(sorted(events), sorted(expected))

        # Registered with EPOLLET and nothing has happened since the last
        # poll, so this call is expected to time out without events.
        events = ep.poll(timeout=2.1, maxevents=4)
        self.assertFalse(events)

        # Byte literals: socket payloads are bytes.
        client.send(b"Hello!")
        server.send(b"world!!!")

        # Each side now has data to read as well.
        events = self._poll_within(ep, 0.01)
        expected = [(client.fileno(), select.EPOLLIN | select.EPOLLOUT),
                    (server.fileno(), select.EPOLLIN | select.EPOLLOUT)]
        self.assertEqual(sorted(events), sorted(expected))

        # Drop the client and narrow the server's mask; only the server's
        # EPOLLOUT is expected afterwards.
        ep.unregister(client.fileno())
        ep.modify(server.fileno(), select.EPOLLOUT)
        events = self._poll_within(ep, 0.01)

        expected = [(server.fileno(), select.EPOLLOUT)]
        self.assertEqual(events, expected)

    def test_errors(self):
        """Invalid size hints and negative descriptors raise ValueError."""
        self.assertRaises(ValueError, select.epoll, -2)
        ep = self._make_epoll()
        self.assertRaises(ValueError, ep.register, -1, select.EPOLLIN)

    def test_unregister_closed(self):
        """unregister() of an already-closed socket's fd must not raise."""
        client, server = self._connected_pair()
        fd = server.fileno()
        ep = self._make_epoll(16)
        ep.register(server)

        self._poll_within(ep, 0.01)

        # Keep the number: after close() the socket can no longer supply it.
        server.close()
        ep.unregister(fd)
214
def test_main():
    """Run every test case of this module through ``test_support``."""
    cases = (TestEPoll,)
    test_support.run_unittest(*cases)
217
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_main()
220