1
2import os, string
3import unittest
4
5from test_all import db, dbobj, test_support, get_new_environment_path, \
6        get_new_database_path
7
8#----------------------------------------------------------------------
9
class dbobjTestCase(unittest.TestCase):
    """Exercise subclassing and the mapping interface of dbobj.DB/dbobj.DBEnv.

    Every test runs in a fresh environment directory obtained from
    get_new_environment_path(); tearDown drops whatever handles the test
    stored on the instance and then removes that directory.
    """
    db_name = 'test-dbobj.db'

    def setUp(self):
        self.homeDir = get_new_environment_path()

    def tearDown(self):
        # Release the database reference first, then the environment
        # reference, and only afterwards delete the on-disk directory.
        for attr_name in ('db', 'env'):
            if hasattr(self, attr_name):
                delattr(self, attr_name)
        test_support.rmtree(self.homeDir)

    def _open_env(self):
        # Open self.env in the per-test home directory with a memory pool.
        # Deliberately takes no handle argument so no extra reference to the
        # environment lingers in a traceback if open() fails.
        self.env.open(self.homeDir, db.DB_CREATE | db.DB_INIT_MPOOL)

    def test01_both(self):
        # A do-nothing environment subclass plus a database subclass whose
        # put() upper-cases the key before delegating to dbobj.DB.put.
        class TestDBEnv(dbobj.DBEnv):
            pass

        class TestDB(dbobj.DB):
            def put(self, key, *args, **kwargs):
                # Explicit base-class call with the upper-cased key.
                return dbobj.DB.put(self, key.upper(), *args, **kwargs)

        self.env = TestDBEnv()
        self._open_env()
        self.db = TestDB(self.env)
        self.db.open(self.db_name, db.DB_HASH, db.DB_CREATE)
        self.db.put('spam', 'eggs')

        # The original-case key must be absent; the upper-cased one present.
        expectations = (
            ('spam', None, "overridden dbobj.DB.put() method failed [1]"),
            ('SPAM', 'eggs', "overridden dbobj.DB.put() method failed [2]"),
        )
        for lookup_key, expected, message in expectations:
            self.assertEqual(self.db.get(lookup_key), expected, message)

        self.db.close()
        self.env.close()

    def test02_dbobj_dict_interface(self):
        self.env = dbobj.DBEnv()
        self._open_env()
        self.db = dbobj.DB(self.env)
        self.db.open(self.db_name + '02', db.DB_HASH, db.DB_CREATE)

        self.db['spam'] = 'eggs'                    # __setitem__
        self.assertEqual(len(self.db), 1)           # __len__
        self.assertEqual(self.db['spam'], 'eggs')   # __getitem__
        del self.db['spam']                         # __delitem__
        self.assertEqual(self.db.get('spam'), None, "dbobj __del__ failed")

        self.db.close()
        self.env.close()

    def test03_dbobj_type_before_open(self):
        # Asking a never-opened handle for its type must raise
        # DBInvalidArgError instead of crashing the interpreter.
        self.assertRaises(db.DBInvalidArgError, db.DB().type)
63
64#----------------------------------------------------------------------
65
def test_suite():
    """Return a TestSuite holding every ``test*`` method of dbobjTestCase.

    Built with ``unittest.TestLoader().loadTestsFromTestCase`` instead of
    ``unittest.makeSuite``: makeSuite is deprecated since Python 3.11 and
    removed in 3.13, while the loader method builds the same suite (default
    ``test`` prefix and method ordering) and also exists on older Pythons.
    """
    return unittest.TestLoader().loadTestsFromTestCase(dbobjTestCase)
68
if __name__ == '__main__':
    # When run as a script, execute only the suite returned by test_suite().
    # module='__main__' is unittest.main's default, written out explicitly.
    unittest.main(module='__main__', defaultTest='test_suite')
71