1import errno
2import os
3import random
4import selectors
5import signal
6import socket
7import sys
8from test import support
9from time import sleep
10import unittest
11import unittest.mock
12import tempfile
13from time import monotonic as time
14try:
15    import resource
16except ImportError:
17    resource = None
18
19
if not hasattr(socket, 'socketpair'):
    def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
        """Fallback for platforms whose socket module lacks socketpair().

        A listening socket is bound to ``support.HOST`` on an ephemeral port,
        a client connects to it, and the accepted server side is paired with
        the client.  Returns ``(client, server)``.  If anything fails with
        OSError the client socket is closed and the error re-raised.
        """
        with socket.socket(family, type, proto) as listener:
            listener.bind((support.HOST, 0))
            listener.listen()
            client = socket.socket(family, type, proto)
            try:
                client.connect(listener.getsockname())
                expected_peer = client.getsockname()
                # Keep accepting until the connection that came from our own
                # client shows up; any unrelated connection is closed.
                while True:
                    server, peer = listener.accept()
                    if peer == expected_peer:
                        return client, server
                    server.close()
            except OSError:
                client.close()
                raise
else:
    socketpair = socket.socketpair
40
41
def find_ready_matching(ready, flag):
    """Return the file objects from ``ready`` whose event mask includes ``flag``.

    ``ready`` is an iterable of ``(key, events)`` pairs, as returned by a
    selector's ``select()``; order is preserved and duplicates are kept.
    """
    return [key.fileobj for key, events in ready if events & flag]
48
49
class BaseSelectorTestCase(unittest.TestCase):
    """Behavioral tests shared by every selector implementation.

    Concrete subclasses set ``SELECTOR`` to the selector class under test.
    This base class leaves it as ``None`` and skips every test in
    ``setUp()``; otherwise a test collector that also picks up this base
    class would fail each test with an AttributeError on ``self.SELECTOR``.
    """

    # Selector class under test; concrete subclasses override this.
    SELECTOR = None

    def setUp(self):
        super().setUp()
        if self.SELECTOR is None:
            self.skipTest("no SELECTOR set: abstract base test case")

    def make_socketpair(self):
        """Return a connected (rd, wr) socket pair closed at test cleanup."""
        rd, wr = socketpair()
        self.addCleanup(rd.close)
        self.addCleanup(wr.close)
        return rd, wr

    def test_register(self):
        """register() returns a populated SelectorKey and rejects bad input."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIsInstance(key, selectors.SelectorKey)
        self.assertEqual(key.fileobj, rd)
        self.assertEqual(key.fd, rd.fileno())
        self.assertEqual(key.events, selectors.EVENT_READ)
        self.assertEqual(key.data, "data")

        # register an unknown event
        self.assertRaises(ValueError, s.register, 0, 999999)

        # register an invalid FD
        self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ)

        # register twice
        self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ)

        # register the same FD, but with a different object
        self.assertRaises(KeyError, s.register, rd.fileno(),
                          selectors.EVENT_READ)

    def test_unregister(self):
        """unregister() raises KeyError for unknown or already-removed objects."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.unregister(rd)

        # unregister an unknown file obj
        self.assertRaises(KeyError, s.unregister, 999999)

        # unregister twice
        self.assertRaises(KeyError, s.unregister, rd)

    def test_unregister_after_fd_close(self):
        """Raw fds can be unregistered after their sockets were closed."""
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(r)
        s.unregister(w)

    @unittest.skipUnless(os.name == 'posix', "requires posix")
    def test_unregister_after_fd_close_and_reuse(self):
        """Unregistering works even after the fd numbers were reused (dup2)."""
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd2, wr2 = self.make_socketpair()
        rd.close()
        wr.close()
        # Make the old fd numbers refer to the new sockets.
        os.dup2(rd2.fileno(), r)
        os.dup2(wr2.fileno(), w)
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        s.unregister(r)
        s.unregister(w)

    def test_unregister_after_socket_close(self):
        """Socket objects can be unregistered after they were closed."""
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(rd)
        s.unregister(wr)

    def test_modify(self):
        """modify() can change the events and/or the data of a registration."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ)

        # modify events
        key2 = s.modify(rd, selectors.EVENT_WRITE)
        self.assertNotEqual(key.events, key2.events)
        self.assertEqual(key2, s.get_key(rd))

        s.unregister(rd)

        # modify data
        d1 = object()
        d2 = object()

        key = s.register(rd, selectors.EVENT_READ, d1)
        key2 = s.modify(rd, selectors.EVENT_READ, d2)
        self.assertEqual(key.events, key2.events)
        self.assertNotEqual(key.data, key2.data)
        self.assertEqual(key2, s.get_key(rd))
        self.assertEqual(key2.data, d2)

        # modify unknown file obj
        self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ)

        # modify() with unchanged events must take a shortcut: it may not
        # go through register()/unregister()
        d3 = object()
        s.register = unittest.mock.Mock()
        s.unregister = unittest.mock.Mock()

        s.modify(rd, selectors.EVENT_READ, d3)
        self.assertFalse(s.register.called)
        self.assertFalse(s.unregister.called)

    def test_close(self):
        """After close(), get_key() raises and the old map no longer resolves."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        mapping = s.get_map()
        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)

        s.close()
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)
        self.assertRaises(KeyError, mapping.__getitem__, rd)
        self.assertRaises(KeyError, mapping.__getitem__, wr)

    def test_get_key(self):
        """get_key() returns the registered key, KeyError for unknown objects."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertEqual(key, s.get_key(rd))

        # unknown file obj
        self.assertRaises(KeyError, s.get_key, 999999)

    def test_get_map(self):
        """get_map() is a live, read-only mapping keyed by fd."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        keys = s.get_map()
        self.assertFalse(keys)
        self.assertEqual(len(keys), 0)
        self.assertEqual(list(keys), [])
        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIn(rd, keys)
        self.assertEqual(key, keys[rd])
        self.assertEqual(len(keys), 1)
        self.assertEqual(list(keys), [rd.fileno()])
        self.assertEqual(list(keys.values()), [key])

        # unknown file obj
        with self.assertRaises(KeyError):
            keys[999999]

        # Read-only mapping
        with self.assertRaises(TypeError):
            del keys[rd]

    def test_select(self):
        """select() reports (key, events) pairs; only the writer is ready here."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        wr_key = s.register(wr, selectors.EVENT_WRITE)

        result = s.select()
        for key, events in result:
            self.assertIsInstance(key, selectors.SelectorKey)
            self.assertTrue(events)
            self.assertFalse(events & ~(selectors.EVENT_READ |
                                        selectors.EVENT_WRITE))

        self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result)

    def test_context_manager(self):
        """Leaving a ``with`` block closes the selector."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        with s as sel:
            sel.register(rd, selectors.EVENT_READ)
            sel.register(wr, selectors.EVENT_WRITE)

        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)

    def test_fileno(self):
        """If the selector exposes fileno(), it returns a non-negative int."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        if hasattr(s, 'fileno'):
            fd = s.fileno()
            self.assertIsInstance(fd, int)
            self.assertGreaterEqual(fd, 0)

    def test_selector(self):
        """Route one message per socket pair and check it reaches its reader."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        NUM_SOCKETS = 12
        MSG = b" This is a test."
        MSG_LEN = len(MSG)
        readers = []
        writers = []
        r2w = {}
        w2r = {}

        for _ in range(NUM_SOCKETS):
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            s.register(wr, selectors.EVENT_WRITE)
            readers.append(rd)
            writers.append(wr)
            r2w[rd] = wr
            w2r[wr] = rd

        bufs = []

        while writers:
            ready = s.select()
            ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE)
            if not ready_writers:
                self.fail("no sockets ready for writing")
            wr = random.choice(ready_writers)
            wr.send(MSG)

            for i in range(10):
                ready = s.select()
                ready_readers = find_ready_matching(ready,
                                                    selectors.EVENT_READ)
                if ready_readers:
                    break
                # there might be a delay between the write to the write end
                # and the moment the read end is reported ready
                sleep(0.1)
            else:
                self.fail("no sockets ready for reading")
            self.assertEqual([w2r[wr]], ready_readers)
            rd = ready_readers[0]
            buf = rd.recv(MSG_LEN)
            self.assertEqual(len(buf), MSG_LEN)
            bufs.append(buf)
            s.unregister(r2w[rd])
            s.unregister(rd)
            writers.remove(r2w[rd])

        self.assertEqual(bufs, [MSG] * NUM_SOCKETS)

    @unittest.skipIf(sys.platform == 'win32',
                     'select.select() cannot be used with empty fd sets')
    def test_empty_select(self):
        # Issue #23009: Make sure EpollSelector.select() works when no FD is
        # registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        self.assertEqual(s.select(timeout=0), [])

    def test_timeout(self):
        """select() honours zero, negative and positive timeouts."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(wr, selectors.EVENT_WRITE)
        t = time()
        self.assertEqual(1, len(s.select(0)))
        self.assertEqual(1, len(s.select(-1)))
        self.assertLess(time() - t, 0.5)

        s.unregister(wr)
        s.register(rd, selectors.EVENT_READ)
        t = time()
        self.assertFalse(s.select(0))
        self.assertFalse(s.select(-1))
        self.assertLess(time() - t, 0.5)

        t0 = time()
        self.assertFalse(s.select(1))
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertGreaterEqual(dt, 0.8)
        self.assertLessEqual(dt, 2.0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_exc(self):
        """An exception raised by a signal handler propagates out of select()."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        class InterruptSelect(Exception):
            pass

        def handler(*args):
            raise InterruptSelect

        orig_alrm_handler = signal.signal(signal.SIGALRM, handler)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
        self.addCleanup(signal.alarm, 0)

        signal.alarm(1)

        s.register(rd, selectors.EVENT_READ)
        t = time()
        # select() is interrupted by a signal which raises an exception
        with self.assertRaises(InterruptSelect):
            s.select(30)
        # select() was interrupted before the timeout of 30 seconds
        self.assertLess(time() - t, 5.0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_noraise(self):
        """A signal whose handler returns normally does not cut select() short."""
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
        self.addCleanup(signal.alarm, 0)

        signal.alarm(1)

        s.register(rd, selectors.EVENT_READ)
        t = time()
        # select() is interrupted by a signal, but the signal handler doesn't
        # raise an exception, so select() should be retried with a recomputed
        # timeout
        self.assertFalse(s.select(1.5))
        self.assertGreaterEqual(time() - t, 1.0)
408
409
class ScalableSelectorMixIn:
    """Extra test for selectors that are not bounded by FD_SETSIZE.

    Meant to be mixed into a BaseSelectorTestCase subclass: it relies on
    ``self.SELECTOR`` and ``self.make_socketpair()`` being provided there.
    """

    # see issue #18963 for why it's skipped on older OS X versions
    @support.requires_mac_ver(10, 5)
    @unittest.skipUnless(resource, "Test needs resource module")
    def test_above_fd_setsize(self):
        # A scalable implementation should have no problem with more than
        # FD_SETSIZE file descriptors. Since we don't know the value, we just
        # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling.
        max_fds = 2**16

        def capped(limit):
            # RLIM_INFINITY is a platform-specific sentinel (it can be
            # negative); feeding it to min() would yield a negative count,
            # register nothing and leave the final select() blocking forever.
            if limit == resource.RLIM_INFINITY:
                return max_fds
            return min(limit, max_fds)

        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
            self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE,
                            (soft, hard))
            NUM_FDS = capped(hard)
        except (OSError, ValueError):
            NUM_FDS = capped(soft)

        # guard for already allocated FDs (stdin, stdout...)
        NUM_FDS -= 32
        # With no socket pair registered, select() without a timeout would
        # block forever instead of returning.
        if NUM_FDS < 2:
            self.skipTest("FD limit too low")

        s = self.SELECTOR()
        self.addCleanup(s.close)

        for i in range(NUM_FDS // 2):
            try:
                rd, wr = self.make_socketpair()
            except OSError:
                # too many FDs, skip - note that we should only catch EMFILE
                # here, but apparently *BSD and Solaris can fail upon connect()
                # or bind() with EADDRNOTAVAIL, so let's be safe
                self.skipTest("FD limit reached")

            try:
                s.register(rd, selectors.EVENT_READ)
                s.register(wr, selectors.EVENT_WRITE)
            except OSError as e:
                if e.errno == errno.ENOSPC:
                    # this can be raised by epoll if we go over
                    # fs.epoll.max_user_watches sysctl
                    self.skipTest("FD limit reached")
                raise

        # Nothing was written, so exactly the writer of every pair is ready.
        self.assertEqual(NUM_FDS // 2, len(s.select()))
454
455
class DefaultSelectorTestCase(BaseSelectorTestCase):
    """Run the shared selector tests against selectors.DefaultSelector."""

    SELECTOR = selectors.DefaultSelector
459
460
class SelectSelectorTestCase(BaseSelectorTestCase):
    """Run the shared selector tests against selectors.SelectSelector."""

    SELECTOR = selectors.SelectSelector
464
465
@unittest.skipUnless(hasattr(selectors, 'PollSelector'),
                     "Test needs selectors.PollSelector")
class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Shared and scalability selector tests for selectors.PollSelector."""

    # Fall back to None where PollSelector does not exist so that the class
    # body still executes; the skip decorator above then skips the class.
    try:
        SELECTOR = selectors.PollSelector
    except AttributeError:
        SELECTOR = None
471
472
@unittest.skipUnless(hasattr(selectors, 'EpollSelector'),
                     "Test needs selectors.EpollSelector")
class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Shared and scalability selector tests for selectors.EpollSelector."""

    SELECTOR = getattr(selectors, 'EpollSelector', None)

    def test_register_file(self):
        # epoll(7) returns EPERM when given a file to watch
        s = self.SELECTOR()
        # Close the selector (and its fd) even if an assertion below fails.
        self.addCleanup(s.close)
        with tempfile.NamedTemporaryFile() as f:
            with self.assertRaises(OSError):
                s.register(f, selectors.EVENT_READ)
            # the SelectorKey has been removed
            with self.assertRaises(KeyError):
                s.get_key(f)
488
489
@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'),
                     "Test needs selectors.KqueueSelector")
class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Shared and scalability selector tests for selectors.KqueueSelector."""

    SELECTOR = getattr(selectors, 'KqueueSelector', None)

    def test_register_bad_fd(self):
        # a file descriptor that's been closed should raise an OSError
        # with EBADF
        s = self.SELECTOR()
        # Close the selector (and its fd) even if an assertion below fails.
        self.addCleanup(s.close)
        bad_f = support.make_bad_fd()
        with self.assertRaises(OSError) as cm:
            s.register(bad_f, selectors.EVENT_READ)
        self.assertEqual(cm.exception.errno, errno.EBADF)
        # the SelectorKey has been removed
        with self.assertRaises(KeyError):
            s.get_key(bad_f)
507
508
@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'),
                     "Test needs selectors.DevpollSelector")
class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Shared and scalability selector tests for selectors.DevpollSelector."""

    # Fall back to None where DevpollSelector does not exist so that the
    # class body still executes; the skip decorator above skips the class.
    try:
        SELECTOR = selectors.DevpollSelector
    except AttributeError:
        SELECTOR = None
514
515
516
def test_main():
    """Run every selector test case, then call support.reap_children()."""
    support.run_unittest(
        DefaultSelectorTestCase,
        SelectSelectorTestCase,
        PollSelectorTestCase,
        EpollSelectorTestCase,
        KqueueSelectorTestCase,
        DevpollSelectorTestCase,
    )
    support.reap_children()
523
524
if __name__ == "__main__":
    # Executing the module directly runs all selector test cases.
    test_main()
527