1"""Test the secrets module.
2
3As most of the functions in secrets are thin wrappers around functions
4defined elsewhere, we don't need to test them exhaustively.
5"""
6
7
8import secrets
9import unittest
10import string
11
12
13# === Unit tests ===
14
class Compare_Digest_Tests(unittest.TestCase):
    """Test secrets.compare_digest function.

    compare_digest accepts either two ASCII-only str objects or two
    bytes-like objects; mixing text and bytes is a TypeError.
    """

    def test_equal(self):
        # Test compare_digest functionality with equal (byte/text) strings.
        for s in ("a", "bcd", "xyz123"):
            a = s*100
            b = s*100
            with self.subTest(s=s):
                self.assertTrue(secrets.compare_digest(a, b))
                self.assertTrue(secrets.compare_digest(a.encode('utf-8'),
                                                       b.encode('utf-8')))

    def test_unequal(self):
        # Test compare_digest functionality with unequal (byte/text) strings:
        # first differing lengths, then equal lengths differing at the end.
        self.assertFalse(secrets.compare_digest("abc", "abcd"))
        self.assertFalse(secrets.compare_digest(b"abc", b"abcd"))
        for s in ("x", "mn", "a1b2c3"):
            a = s*100 + "q"
            b = s*100 + "k"
            with self.subTest(s=s):
                self.assertFalse(secrets.compare_digest(a, b))
                self.assertFalse(secrets.compare_digest(a.encode('utf-8'),
                                                        b.encode('utf-8')))

    def test_bad_types(self):
        # Test that compare_digest raises with mixed types.
        a = 'abcde'
        b = a.encode('utf-8')
        # Check the fixtures with unittest assertions rather than bare
        # ``assert``, which is stripped when Python runs with -O.
        self.assertIsInstance(a, str)
        self.assertIsInstance(b, bytes)
        self.assertRaises(TypeError, secrets.compare_digest, a, b)
        self.assertRaises(TypeError, secrets.compare_digest, b, a)

    def test_non_ascii_str(self):
        # Text arguments must be ASCII-only: a non-ASCII str raises
        # TypeError even when both arguments are equal.
        a = 'caf\xe9'
        self.assertRaises(TypeError, secrets.compare_digest, a, a)

    def test_bytes_like(self):
        # Any two bytes-like objects may be compared with each other.
        self.assertTrue(secrets.compare_digest(b"abc", bytearray(b"abc")))
        self.assertFalse(secrets.compare_digest(bytearray(b"abc"), b"abd"))

    def test_bool(self):
        # Test that compare_digest returns a bool for text and bytes input.
        self.assertIsInstance(secrets.compare_digest("abc", "abc"), bool)
        self.assertIsInstance(secrets.compare_digest("abc", "xyz"), bool)
        self.assertIsInstance(secrets.compare_digest(b"abc", b"abc"), bool)
        self.assertIsInstance(secrets.compare_digest(b"abc", b"xyz"), bool)
49
50
class Random_Tests(unittest.TestCase):
    """Test wrappers around SystemRandom methods."""

    def test_randbits(self):
        # Test randbits: every result lies in [0, 2**numbits).
        errmsg = "randbits(%d) returned %d"
        for numbits in (3, 12, 30):
            for _ in range(6):
                n = secrets.randbits(numbits)
                self.assertTrue(0 <= n < 2**numbits, errmsg % (numbits, n))

    def test_choice(self):
        # Test choice: results come from the sequence; empty input raises.
        items = [1, 2, 4, 8, 16, 32, 64]
        for _ in range(10):
            self.assertIn(secrets.choice(items), items)
        self.assertRaises(IndexError, secrets.choice, [])

    def test_randbelow(self):
        # Test randbelow: results lie in range(i); non-positive bounds raise.
        for i in range(2, 10):
            self.assertIn(secrets.randbelow(i), range(i))
        # With an exclusive upper bound of 1 the only possible result is 0.
        self.assertEqual(secrets.randbelow(1), 0)
        self.assertRaises(ValueError, secrets.randbelow, 0)
        self.assertRaises(ValueError, secrets.randbelow, -1)
74
75
class Token_Tests(unittest.TestCase):
    """Test token functions."""

    def test_token_defaults(self):
        # Test that token_* functions handle default size correctly.
        for func in (secrets.token_bytes, secrets.token_hex,
                     secrets.token_urlsafe):
            with self.subTest(func=func):
                name = func.__name__
                try:
                    func()
                except TypeError:
                    self.fail("%s cannot be called with no argument" % name)
                try:
                    func(None)
                except TypeError:
                    self.fail("%s cannot be called with None" % name)
        size = secrets.DEFAULT_ENTROPY
        # Omitting the argument and passing None must both use the default.
        self.assertEqual(len(secrets.token_bytes()), size)
        self.assertEqual(len(secrets.token_bytes(None)), size)
        self.assertEqual(len(secrets.token_hex(None)), 2*size)
        # Unpadded Base64 encodes n bytes as ceil(4*n/3) characters.
        self.assertEqual(len(secrets.token_urlsafe(None)), (4*size + 2)//3)

    def test_token_bytes(self):
        # Test token_bytes, including the empty (n=0) case.
        for n in (0, 1, 8, 17, 100):
            with self.subTest(n=n):
                tok = secrets.token_bytes(n)
                self.assertIsInstance(tok, bytes)
                self.assertEqual(len(tok), n)

    def test_token_hex(self):
        # Test token_hex: two hex digits per random byte.
        for n in (0, 1, 12, 25, 90):
            with self.subTest(n=n):
                s = secrets.token_hex(n)
                self.assertIsInstance(s, str)
                self.assertEqual(len(s), 2*n)
                self.assertTrue(all(c in string.hexdigits for c in s))

    def test_token_urlsafe(self):
        # Test token_urlsafe: URL-safe Base64 alphabet with no '=' padding.
        legal = string.ascii_letters + string.digits + '-_'
        for n in (0, 1, 11, 28, 76):
            with self.subTest(n=n):
                s = secrets.token_urlsafe(n)
                self.assertIsInstance(s, str)
                self.assertTrue(all(c in legal for c in s))
                # Unpadded Base64 encodes n bytes as ceil(4*n/3) characters.
                self.assertEqual(len(s), (4*n + 2)//3)
121
122
# Run this module's test cases when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
125