1# Copyright (c) 2001-2006 Twisted Matrix Laboratories.
2#
3# Permission is hereby granted, free of charge, to any person obtaining
4# a copy of this software and associated documentation files (the
5# "Software"), to deal in the Software without restriction, including
6# without limitation the rights to use, copy, modify, merge, publish,
7# distribute, sublicense, and/or sell copies of the Software, and to
8# permit persons to whom the Software is furnished to do so, subject to
9# the following conditions:
10#
11# The above copyright notice and this permission notice shall be
12# included in all copies or substantial portions of the Software.
13#
14# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21"""
22Tests for epoll wrapper.
23"""
24import errno
25import os
26import select
27import socket
28import time
29import unittest
30
# Skip the whole module when this platform's select module has no epoll.
if not hasattr(select, "epoll"):
    raise unittest.SkipTest("test works only on Linux 2.6")

# The wrapper may exist while the running kernel still rejects the call
# with ENOSYS.  Probe once, closing the probe object deterministically
# instead of leaving its descriptor for the garbage collector.
try:
    with select.epoll():
        pass
except OSError as e:
    if e.errno == errno.ENOSYS:
        raise unittest.SkipTest("kernel doesn't support epoll()")
    raise
40
class TestEPoll(unittest.TestCase):
    """Tests for ``select.epoll`` driven by loopback TCP sockets.

    Every socket a test creates is appended to ``self.connections`` so that
    ``tearDown`` closes it.  epoll objects are closed by the tests themselves
    (``with`` blocks or ``addCleanup``) rather than left to the garbage
    collector.
    """

    def setUp(self):
        # Listening socket on an ephemeral loopback port; _connected_pair()
        # connects to it.
        self.serverSocket = socket.socket()
        self.serverSocket.bind(('127.0.0.1', 0))
        self.serverSocket.listen()
        self.connections = [self.serverSocket]

    def tearDown(self):
        for skt in self.connections:
            skt.close()

    def _connected_pair(self):
        """Return a connected ``(client, server)`` socket pair.

        The client is non-blocking; the server end is the socket returned by
        ``accept()`` on the listening socket.  Both are closed in tearDown().
        """
        client = socket.socket()
        # Register for cleanup first so the socket is closed even if the
        # connect assertions below fail.
        self.connections.append(client)
        client.setblocking(False)
        try:
            client.connect(('127.0.0.1', self.serverSocket.getsockname()[1]))
        except OSError as e:
            # A non-blocking connect() reports the handshake as in progress
            # instead of completing synchronously.
            self.assertEqual(e.errno, errno.EINPROGRESS)
        else:
            self.fail("Connect should have raised EINPROGRESS")
        server, addr = self.serverSocket.accept()
        self.connections.append(server)
        return client, server

    def test_create(self):
        try:
            ep = select.epoll(16)
        except OSError as e:
            raise AssertionError(str(e))
        self.assertGreater(ep.fileno(), 0, ep.fileno())
        self.assertFalse(ep.closed)
        ep.close()
        self.assertTrue(ep.closed)
        self.assertRaises(ValueError, ep.fileno)
        if hasattr(select, "EPOLL_CLOEXEC"):
            # The first positional parameter is sizehint, not flags, so the
            # flag must go second (or by keyword) to exercise `flags` at all.
            select.epoll(-1, select.EPOLL_CLOEXEC).close()
            select.epoll(flags=select.EPOLL_CLOEXEC).close()
            select.epoll(flags=0).close()
            self.assertRaises(OSError, select.epoll, flags=12356)

    def test_badcreate(self):
        self.assertRaises(TypeError, select.epoll, 1, 2, 3)
        self.assertRaises(TypeError, select.epoll, 'foo')
        self.assertRaises(TypeError, select.epoll, None)
        self.assertRaises(TypeError, select.epoll, ())
        self.assertRaises(TypeError, select.epoll, ['foo'])
        self.assertRaises(TypeError, select.epoll, {})

    def test_context_manager(self):
        with select.epoll(16) as ep:
            self.assertGreater(ep.fileno(), 0)
            self.assertFalse(ep.closed)
        self.assertTrue(ep.closed)
        self.assertRaises(ValueError, ep.fileno)

    def test_add(self):
        client, server = self._connected_pair()

        # Registering by raw file descriptor.
        with select.epoll(2) as ep:
            ep.register(server.fileno(), select.EPOLLIN | select.EPOLLOUT)
            ep.register(client.fileno(), select.EPOLLIN | select.EPOLLOUT)

        # adding by object w/ fileno works, too.
        with select.epoll(2) as ep:
            ep.register(server, select.EPOLLIN | select.EPOLLOUT)
            ep.register(client, select.EPOLLIN | select.EPOLLOUT)

        with select.epoll(2) as ep:
            # TypeError: argument must be an int, or have a fileno() method.
            self.assertRaises(TypeError, ep.register, object(),
                select.EPOLLIN | select.EPOLLOUT)
            self.assertRaises(TypeError, ep.register, None,
                select.EPOLLIN | select.EPOLLOUT)
            # ValueError: file descriptor cannot be a negative integer (-1)
            self.assertRaises(ValueError, ep.register, -1,
                select.EPOLLIN | select.EPOLLOUT)
            # OSError: [Errno 9] Bad file descriptor
            self.assertRaises(OSError, ep.register, 10000,
                select.EPOLLIN | select.EPOLLOUT)
            # registering twice also raises an exception
            ep.register(server, select.EPOLLIN | select.EPOLLOUT)
            self.assertRaises(OSError, ep.register, server,
                select.EPOLLIN | select.EPOLLOUT)

    def test_fromfd(self):
        client, server = self._connected_pair()

        with select.epoll(2) as ep:
            ep2 = select.epoll.fromfd(ep.fileno())

            ep2.register(server.fileno(), select.EPOLLIN | select.EPOLLOUT)
            ep2.register(client.fileno(), select.EPOLLIN | select.EPOLLOUT)

            # Registrations made through ep2 are visible through ep as well.
            events = ep.poll(1, 4)
            events2 = ep2.poll(0.9, 4)
            self.assertEqual(len(events), 2)
            self.assertEqual(len(events2), 2)

        # Closing ep also invalidates ep2: both wrap the same descriptor.
        try:
            ep2.poll(1, 4)
        except OSError as e:
            self.assertEqual(e.args[0], errno.EBADF, e)
        else:
            self.fail("epoll on closed fd didn't raise EBADF")

    def test_control_and_wait(self):
        client, server = self._connected_pair()

        ep = select.epoll(16)
        self.addCleanup(ep.close)
        ep.register(server.fileno(),
                    select.EPOLLIN | select.EPOLLOUT | select.EPOLLET)
        ep.register(client.fileno(),
                    select.EPOLLIN | select.EPOLLOUT | select.EPOLLET)

        # Both freshly connected sockets are reported writable at once.
        now = time.monotonic()
        events = ep.poll(1, 4)
        then = time.monotonic()
        self.assertFalse(then - now > 0.1, then - now)

        events.sort()
        expected = [(client.fileno(), select.EPOLLOUT),
                    (server.fileno(), select.EPOLLOUT)]
        expected.sort()
        self.assertEqual(events, expected)

        # Edge-triggered registration: with no state change since the poll
        # above, nothing is reported.  A short timeout demonstrates that
        # just as well as a multi-second one, without slowing the suite.
        events = ep.poll(timeout=0.1, maxevents=4)
        self.assertFalse(events)

        client.send(b"Hello!")
        server.send(b"world!!!")

        # Each side now has data to read in addition to being writable.
        now = time.monotonic()
        events = ep.poll(1, 4)
        then = time.monotonic()
        self.assertFalse(then - now > 0.01)

        events.sort()
        expected = [(client.fileno(), select.EPOLLIN | select.EPOLLOUT),
                    (server.fileno(), select.EPOLLIN | select.EPOLLOUT)]
        expected.sort()
        self.assertEqual(events, expected)

        # After unregistering the client and narrowing the server's mask,
        # only the server's EPOLLOUT is reported.
        ep.unregister(client.fileno())
        ep.modify(server.fileno(), select.EPOLLOUT)
        now = time.monotonic()
        events = ep.poll(1, 4)
        then = time.monotonic()
        self.assertFalse(then - now > 0.01)

        expected = [(server.fileno(), select.EPOLLOUT)]
        self.assertEqual(events, expected)

    def test_errors(self):
        self.assertRaises(ValueError, select.epoll, -2)
        # Use a named, closed-on-exit instance instead of a throwaway one.
        with select.epoll() as ep:
            self.assertRaises(ValueError, ep.register, -1, select.EPOLLIN)

    def test_unregister_closed(self):
        client, server = self._connected_pair()
        fd = server.fileno()
        ep = select.epoll(16)
        self.addCleanup(ep.close)
        ep.register(server)

        now = time.monotonic()
        ep.poll(1, 4)
        then = time.monotonic()
        self.assertFalse(then - now > 0.01)

        server.close()
        # Interpreters before 3.9 silently ignored EBADF here; 3.9+
        # (bpo-39239) reports it.  Accept either, but no other error.
        try:
            ep.unregister(fd)
        except OSError as e:
            self.assertEqual(e.errno, errno.EBADF, e)

    def test_close(self):
        open_file = open(__file__, "rb")
        self.addCleanup(open_file.close)
        fd = open_file.fileno()
        epoll = select.epoll()

        # test fileno() method and closed attribute
        self.assertIsInstance(epoll.fileno(), int)
        self.assertFalse(epoll.closed)

        # test close()
        epoll.close()
        self.assertTrue(epoll.closed)
        self.assertRaises(ValueError, epoll.fileno)

        # close() can be called more than once
        epoll.close()

        # operations must fail with ValueError("I/O operation on closed ...")
        self.assertRaises(ValueError, epoll.modify, fd, select.EPOLLIN)
        self.assertRaises(ValueError, epoll.poll, 1.0)
        self.assertRaises(ValueError, epoll.register, fd, select.EPOLLIN)
        self.assertRaises(ValueError, epoll.unregister, fd)

    def test_fd_non_inheritable(self):
        epoll = select.epoll()
        self.addCleanup(epoll.close)
        self.assertEqual(os.get_inheritable(epoll.fileno()), False)
253
if __name__ == "__main__":
    # Executed directly rather than through a test runner: collect and run
    # the TestCase classes defined in this module.
    unittest.main(module="__main__")
256