1"""
Test cases for checking dbshelve (DBShelf) objects.
3"""
4
5import os, string, sys
6import random
7import unittest
8
9
10from test_all import db, dbshelve, test_support, verbose, \
11        get_new_environment_path, get_new_database_path
12
13
14
15
16
17#----------------------------------------------------------------------
18
19# We want the objects to be comparable so we can test dbshelve.values
20# later on.
class DataClass:
    """Picklable payload stored in the shelf by ``populateDB``.

    Kept as an old-style class on purpose: under Python 2 ``checkrec``
    asserts that unpickled instances have type ``types.InstanceType``.
    """

    def __init__(self):
        self.value = random.random()

    def __repr__(self) :  # For Python 3.0 comparison
        return "DataClass %f" % self.value

    def __cmp__(self, other):  # For Python 2.x comparison
        """Three-way compare on ``value``: return -1, 0 or 1.

        Another DataClass is compared by its ``value``; anything else is
        compared directly against ``self.value``.
        """
        # Compare value-to-value explicitly instead of comparing the float
        # against the other instance and relying on reflected dispatch.
        if isinstance(other, DataClass):
            other = other.value
        # Portable spelling of cmp(); the cmp() builtin is Python 2 only.
        return (self.value > other) - (self.value < other)
30
31
class DBShelveTestCase(unittest.TestCase):
    """Core shelf tests, run against a shelf created by ``dbshelve.open``.

    Subclasses change how the shelf is opened and closed by overriding
    ``do_open``/``do_close``, and override ``mk``/``checkrec`` when the
    database uses a different key type (see RecNoShelveTestCase).
    """

    if (sys.version_info < (2, 7)) or ((sys.version_info >= (3, 0)) and
            (sys.version_info < (3, 2))) :
        # Fallback for unittest versions that lack TestCase.assertIn.
        def assertIn(self, a, b, msg=None) :
            return self.assertTrue(a in b, msg=msg)


    def setUp(self):
        if sys.version_info[0] >= 3 :
            # Call do_proxy_db_py3k(False) and keep its return value so that
            # tearDown can hand it back (presumably the previous setting).
            from test_all import do_proxy_db_py3k
            self._flag_proxy_db_py3k = do_proxy_db_py3k(False)
        self.filename = get_new_database_path()
        self.do_open()

    def tearDown(self):
        if sys.version_info[0] >= 3 :
            from test_all import do_proxy_db_py3k
            do_proxy_db_py3k(self._flag_proxy_db_py3k)
        # Remove the database file even if closing the shelf raises.
        try:
            self.do_close()
        finally:
            test_support.unlink(self.filename)

    def mk(self, key):
        """Turn key into an appropriate key type for this db.

        Python 2 uses the str unchanged; Python 3 encodes it to bytes with
        ISO-8859-1 (one byte per 8-bit character).  Override in a child
        class for RECNO.
        """
        if sys.version_info[0] < 3 :
            return key
        else :
            return bytes(key, "iso8859-1")  # 8 bits

    def populateDB(self, d):
        """Store four records in *d* for every ASCII letter x.

        'S'+x -> 10*x (str), 'I'+x -> ord(x) (int), 'L'+x -> [x]*10 (list)
        and 'O'+x -> a DataClass carrying the same three values as its
        S, I and L attributes.  Keys pass through self.mk(); checkrec()
        knows how to verify each kind.
        """
        # ascii_letters, unlike string.letters, is locale-independent and
        # also exists on Python 3.
        for x in string.ascii_letters:
            d[self.mk('S' + x)] = 10 * x           # add a string
            d[self.mk('I' + x)] = ord(x)           # add an integer
            d[self.mk('L' + x)] = [x] * 10         # add a list

            inst = DataClass()            # add an instance
            inst.S = 10 * x
            inst.I = ord(x)
            inst.L = [x] * 10
            d[self.mk('O' + x)] = inst


    # overridable in derived classes to affect how the shelf is created/opened
    def do_open(self):
        self.d = dbshelve.open(self.filename)

    # and closed...
    def do_close(self):
        self.d.close()



    def test01_basics(self):
        if verbose:
            print('\n' + '-=' * 30)
            print("Running %s.test01_basics..." % self.__class__.__name__)

        self.populateDB(self.d)
        self.d.sync()
        # Close and reopen so everything below is read back from the file.
        self.do_close()
        self.do_open()
        d = self.d

        l = len(d)
        k = d.keys()
        s = d.stat()
        d.fd()  # only checks that fd() can be called; result is not used

        if verbose:
            print("length: %s" % (l,))
            print("keys: %s" % (k,))
            print("stats: %s" % (s,))

        self.assertEqual(0, d.has_key(self.mk('bad key')))
        self.assertEqual(1, d.has_key(self.mk('IA')))
        self.assertEqual(1, d.has_key(self.mk('OA')))

        d.delete(self.mk('IA'))
        del d[self.mk('OA')]
        self.assertEqual(0, d.has_key(self.mk('IA')))
        self.assertEqual(0, d.has_key(self.mk('OA')))
        self.assertEqual(len(d), l-2)

        values = []
        for key in d.keys():
            value = d[key]
            values.append(value)
            if verbose:
                print("%s: %s" % (key, value))
            self.checkrec(key, value)

        dbvalues = d.values()
        self.assertEqual(len(dbvalues), len(d.keys()))
        if sys.version_info < (2, 6) :
            values.sort()
            dbvalues.sort()
            self.assertEqual(values, dbvalues)
        else :  # XXX: Convert all to strings. Please, improve
            # The values have mixed types, so order and compare them by
            # their string form.
            values.sort(key=str)
            dbvalues.sort(key=str)
            self.assertEqual(repr(values), repr(dbvalues))

        items = d.items()
        self.assertEqual(len(items), len(values))

        for key, value in items:
            self.checkrec(key, value)

        self.assertEqual(d.get(self.mk('bad key')), None)
        self.assertEqual(d.get(self.mk('bad key'), None), None)
        self.assertEqual(d.get(self.mk('bad key'), 'a string'), 'a string')
        self.assertEqual(d.get(self.mk('bad key'), [1, 2, 3]), [1, 2, 3])

        # With get-returns-none switched off a missing key must raise.
        d.set_get_returns_none(0)
        self.assertRaises(db.DBNotFoundError, d.get, self.mk('bad key'))
        d.set_get_returns_none(1)

        d.put(self.mk('new key'), 'new data')
        self.assertEqual(d.get(self.mk('new key')), 'new data')
        self.assertEqual(d[self.mk('new key')], 'new data')



    def test02_cursors(self):
        if verbose:
            print('\n' + '-=' * 30)
            print("Running %s.test02_cursors..." % self.__class__.__name__)

        self.populateDB(self.d)
        d = self.d

        # Forward walk: every record must be visited exactly once.
        count = 0
        c = d.cursor()
        rec = c.first()
        while rec is not None:
            count = count + 1
            if verbose:
                print(rec)
            key, value = rec
            self.checkrec(key, value)
            # Hack to avoid conversion by 2to3 tool
            rec = getattr(c, "next")()
        del c

        self.assertEqual(count, len(d))

        # Backward walk with a fresh cursor.
        count = 0
        c = d.cursor()
        rec = c.last()
        while rec is not None:
            count = count + 1
            if verbose:
                print(rec)
            key, value = rec
            self.checkrec(key, value)
            rec = c.prev()

        self.assertEqual(count, len(d))

        c.set(self.mk('SS'))
        key, value = c.current()
        self.checkrec(key, value)
        del c


    def test03_append(self):
        # NOTE: this is overridden in RECNO subclass, don't change its name.
        if verbose:
            print('\n' + '-=' * 30)
            print("Running %s.test03_append..." % self.__class__.__name__)

        # append() is expected to be rejected for non-RECNO databases.
        self.assertRaises(dbshelve.DBShelveError,
                          self.d.append, 'unit test was here')


    def test04_iterable(self) :
        self.populateDB(self.d)
        d = self.d
        keys = d.keys()
        keyset = set(keys)
        self.assertEqual(len(keyset), len(keys))

        # Iterating the shelf must yield each key exactly once.
        for key in d :
            self.assertIn(key, keyset)
            keyset.remove(key)
        self.assertEqual(len(keyset), 0)

    def checkrec(self, key, value):
        """Assert that *value* is what populateDB() stored under *key*.

        The key's first character selects the record kind ('S', 'I', 'L'
        or 'O') and its second character is the letter the record was
        built from.  Override this in a subclass if the key type is
        different.
        """
        if sys.version_info[0] >= 3 :
            if isinstance(key, bytes) :
                key = key.decode("iso8859-1")  # 8 bits

        x = key[1]
        if key[0] == 'S':
            self.assertEqual(type(value), str)
            self.assertEqual(value, 10 * x)

        elif key[0] == 'I':
            self.assertEqual(type(value), int)
            self.assertEqual(value, ord(x))

        elif key[0] == 'L':
            self.assertEqual(type(value), list)
            self.assertEqual(value, [x] * 10)

        elif key[0] == 'O':
            if sys.version_info[0] < 3 :
                from types import InstanceType
                self.assertEqual(type(value), InstanceType)
            else :
                self.assertEqual(type(value), DataClass)

            self.assertEqual(value.S, 10 * x)
            self.assertEqual(value.I, ord(x))
            self.assertEqual(value.L, [x] * 10)

        else:
            self.fail('Unknown key type, fix the test')
252
253#----------------------------------------------------------------------
254
class BasicShelveTestCase(DBShelveTestCase):
    """Open a bare ``dbshelve.DBShelf`` using the subclass's dbtype/dbflags."""

    def do_open(self):
        # self.d is bound before open() runs, so the attribute is set even
        # if open() raises.
        shelf = self.d = dbshelve.DBShelf()
        shelf.open(self.filename, self.dbtype, self.dbflags)

    def do_close(self):
        shelf = self.d
        shelf.close()
262
263
class BTreeShelveTestCase(BasicShelveTestCase):
    """Bare DBShelf over a DB_BTREE database."""
    dbflags = db.DB_CREATE
    dbtype = db.DB_BTREE
267
268
class HashShelveTestCase(BasicShelveTestCase):
    """Bare DBShelf over a DB_HASH database."""
    dbflags = db.DB_CREATE
    dbtype = db.DB_HASH
272
273
class ThreadBTreeShelveTestCase(BasicShelveTestCase):
    """Bare DBShelf over a DB_BTREE database opened with DB_THREAD."""
    dbflags = db.DB_THREAD | db.DB_CREATE
    dbtype = db.DB_BTREE
277
278
class ThreadHashShelveTestCase(BasicShelveTestCase):
    """Bare DBShelf over a DB_HASH database opened with DB_THREAD."""
    dbflags = db.DB_THREAD | db.DB_CREATE
    dbtype = db.DB_HASH
282
283
284#----------------------------------------------------------------------
285
class BasicEnvShelveTestCase(DBShelveTestCase):
    """Run the shelf tests with the database opened inside a DBEnv.

    Subclasses supply ``envflags``, ``dbtype`` and ``dbflags``.  The
    environment lives in ``self.homeDir``, which tearDown removes entirely.
    """

    def do_open(self):
        self.env = db.DBEnv()
        self.env.open(self.homeDir,
                self.envflags | db.DB_INIT_MPOOL | db.DB_CREATE)

        # Keep only the base name of the path; presumably Berkeley DB
        # resolves it relative to the environment home (homeDir).
        self.filename = os.path.split(self.filename)[1]
        self.d = dbshelve.DBShelf(self.env)
        self.d.open(self.filename, self.dbtype, self.dbflags)


    def do_close(self):
        # Close the environment even if closing the shelf raises.
        try:
            self.d.close()
        finally:
            self.env.close()


    def setUp(self) :
        # do_open() (called from DBShelveTestCase.setUp) reads homeDir, so
        # it has to be assigned first.
        self.homeDir = get_new_environment_path()
        DBShelveTestCase.setUp(self)

    def tearDown(self):
        # DBShelveTestCase.tearDown is not reused: self.filename is only a
        # base name here, and the whole environment directory is removed.
        if sys.version_info[0] >= 3 :
            from test_all import do_proxy_db_py3k
            do_proxy_db_py3k(self._flag_proxy_db_py3k)
        # Remove the environment directory even if closing raises.
        try:
            self.do_close()
        finally:
            test_support.rmtree(self.homeDir)
312
313
class EnvBTreeShelveTestCase(BasicEnvShelveTestCase):
    """DB_BTREE shelf inside a DBEnv opened without extra env flags."""
    dbtype, dbflags = db.DB_BTREE, db.DB_CREATE
    envflags = 0
318
319
class EnvHashShelveTestCase(BasicEnvShelveTestCase):
    """DB_HASH shelf inside a DBEnv opened without extra env flags."""
    dbtype, dbflags = db.DB_HASH, db.DB_CREATE
    envflags = 0
324
325
class EnvThreadBTreeShelveTestCase(BasicEnvShelveTestCase):
    """DB_BTREE shelf with DB_THREAD set on both the DBEnv and the database."""
    dbtype, dbflags = db.DB_BTREE, db.DB_THREAD | db.DB_CREATE
    envflags = db.DB_THREAD
330
331
class EnvThreadHashShelveTestCase(BasicEnvShelveTestCase):
    """DB_HASH shelf with DB_THREAD set on both the DBEnv and the database."""
    dbtype, dbflags = db.DB_HASH, db.DB_THREAD | db.DB_CREATE
    envflags = db.DB_THREAD
336
337
338#----------------------------------------------------------------------
339# test cases for a DBShelf in a RECNO DB.
340
class RecNoShelveTestCase(BasicShelveTestCase):
    """Shelf tests against a DB_RECNO database, whose keys are integers.

    ``mk`` assigns a record number to each string key the base tests use,
    and ``checkrec`` maps the number back before validating the value.
    """
    dbflags = db.DB_CREATE
    dbtype = db.DB_RECNO

    def setUp(self):
        BasicShelveTestCase.setUp(self)

        # Record numbers not yet handed out by mk(): 1..4999, taken in order.
        self.key_pool = list(range(1, 5000))
        # string key -> record number, and the reverse mapping.
        self.key_map = {}
        self.intkey_map = {}

    def mk(self, key):
        recno = self.key_map.get(key)
        if recno is None:
            # First sighting of this key: take the next free record number.
            recno = self.key_pool.pop(0)
            self.key_map[key] = recno
            self.intkey_map[recno] = key
        return recno

    def checkrec(self, intkey, value):
        # Translate the record number back to the string key mk() was given.
        BasicShelveTestCase.checkrec(self, self.intkey_map[intkey], value)

    def test03_append(self):
        if verbose:
            print('\n' + '-=' * 30)
            print("Running %s.test03_append..." % self.__class__.__name__)

        # Records 1 and 5 exist; the assertions expect append() to continue
        # after the highest record number rather than fill 2-4.
        for recno, data in ((1, 'spam'), (5, 'eggs')):
            self.d[recno] = data
        self.assertEqual(6, self.d.append('spam'))
        self.assertEqual(7, self.d.append('baked beans'))
        for recno, expected in ((6, 'spam'), (1, 'spam'),
                                (7, 'baked beans'), (5, 'eggs')):
            self.assertEqual(expected, self.d.get(recno))
376
377
378#----------------------------------------------------------------------
379
def test_suite():
    """Return a TestSuite holding the tests of every shelf test-case class."""
    # loadTestsFromTestCase is what unittest.makeSuite wrapped (same 'test'
    # prefix and method ordering); makeSuite was deprecated in Python 3.11
    # and removed in 3.13.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case in (DBShelveTestCase,
                 BTreeShelveTestCase,
                 HashShelveTestCase,
                 ThreadBTreeShelveTestCase,
                 ThreadHashShelveTestCase,
                 EnvBTreeShelveTestCase,
                 EnvHashShelveTestCase,
                 EnvThreadBTreeShelveTestCase,
                 EnvThreadHashShelveTestCase,
                 RecNoShelveTestCase):
        suite.addTest(loader.loadTestsFromTestCase(case))

    return suite
395
396
if __name__ == '__main__':
    # Direct execution: let unittest's runner build and run test_suite().
    unittest.main(
        defaultTest='test_suite',
    )
399