1"""TestCases for multi-threaded access to a DB.
2"""
3
4import os
5import sys
6import time
7import errno
8from random import random
9
# Separator that makeData() places between the repeated copies of a key.
DASH = '-'

# WindowsError is only a builtin on Windows; elsewhere bind a plain
# Exception subclass under that name so it can still be referenced.
try:
    WindowsError
except NameError:
    WindowsError = type('WindowsError', (Exception,), {})
17
18import unittest
19from test_all import db, dbutils, test_support, verbose, have_threads, \
20        get_new_environment_path, get_new_database_path
21
22if have_threads :
23    from threading import Thread
24    if sys.version_info[0] < 3 :
25        from threading import currentThread
26    else :
27        from threading import current_thread as currentThread
28
29
30#----------------------------------------------------------------------
31
class BaseThreadedTestCase(unittest.TestCase):
    """Common fixture for the threaded DB tests.

    Every test gets a fresh DBEnv in a new home directory and a database
    file named after the concrete test class.  Derived classes customize
    the fixture through the class attributes below and setEnvOpts().
    """
    dbtype       = db.DB_UNKNOWN  # must be set in derived class
    dbopenflags  = 0              # extra flags OR-ed into DB.open()
    dbsetflags   = 0              # passed to DB.set_flags() when non-zero
    envflags     = 0              # extra flags OR-ed into DBEnv.open()

    def setUp(self):
        if verbose:
            # Presumably makes dbutils report deadlock retries on stdout.
            dbutils._deadlock_VerboseFile = sys.stdout

        self.homeDir = get_new_environment_path()
        self.env = db.DBEnv()
        self.setEnvOpts()  # hook: must run before the env is opened
        self.env.open(self.homeDir, self.envflags | db.DB_CREATE)

        self.filename = self.__class__.__name__ + '.db'
        self.d = db.DB(self.env)
        if self.dbsetflags:
            self.d.set_flags(self.dbsetflags)
        self.d.open(self.filename, self.dbtype, self.dbopenflags|db.DB_CREATE)

    def tearDown(self):
        """Close the database and the environment, then remove the home dir.

        Each step runs even if an earlier one raises, so a failing
        DB.close() does not leak the environment or leave files behind.
        """
        try:
            self.d.close()
        finally:
            try:
                self.env.close()
            finally:
                test_support.rmtree(self.homeDir)

    def setEnvOpts(self):
        """Hook for derived classes to configure self.env before it opens."""
        pass

    def makeData(self, key):
        """Return the value stored for *key*: five copies joined by DASH."""
        return DASH.join([key] * 5)
63
64
65#----------------------------------------------------------------------
66
67
68class ConcurrentDataStoreBase(BaseThreadedTestCase):
69    dbopenflags = db.DB_THREAD
70    envflags    = db.DB_THREAD | db.DB_INIT_CDB | db.DB_INIT_MPOOL
71    readers     = 0 # derived class should set
72    writers     = 0
73    records     = 1000
74
75    def test01_1WriterMultiReaders(self):
76        if verbose:
77            print '\n', '-=' * 30
78            print "Running %s.test01_1WriterMultiReaders..." % \
79                  self.__class__.__name__
80
81        keys=range(self.records)
82        import random
83        random.shuffle(keys)
84        records_per_writer=self.records//self.writers
85        readers_per_writer=self.readers//self.writers
86        self.assertEqual(self.records,self.writers*records_per_writer)
87        self.assertEqual(self.readers,self.writers*readers_per_writer)
88        self.assertTrue((records_per_writer%readers_per_writer)==0)
89        readers = []
90
91        for x in xrange(self.readers):
92            rt = Thread(target = self.readerThread,
93                        args = (self.d, x),
94                        name = 'reader %d' % x,
95                        )#verbose = verbose)
96            if sys.version_info[0] < 3 :
97                rt.setDaemon(True)
98            else :
99                rt.daemon = True
100            readers.append(rt)
101
102        writers=[]
103        for x in xrange(self.writers):
104            a=keys[records_per_writer*x:records_per_writer*(x+1)]
105            a.sort()  # Generate conflicts
106            b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
107            wt = Thread(target = self.writerThread,
108                        args = (self.d, a, b),
109                        name = 'writer %d' % x,
110                        )#verbose = verbose)
111            writers.append(wt)
112
113        for t in writers:
114            if sys.version_info[0] < 3 :
115                t.setDaemon(True)
116            else :
117                t.daemon = True
118            t.start()
119
120        for t in writers:
121            t.join()
122        for t in readers:
123            t.join()
124
125    def writerThread(self, d, keys, readers):
126        if sys.version_info[0] < 3 :
127            name = currentThread().getName()
128        else :
129            name = currentThread().name
130
131        if verbose:
132            print "%s: creating records %d - %d" % (name, start, stop)
133
134        count=len(keys)//len(readers)
135        count2=count
136        for x in keys :
137            key = '%04d' % x
138            dbutils.DeadlockWrap(d.put, key, self.makeData(key),
139                                 max_retries=12)
140            if verbose and x % 100 == 0:
141                print "%s: records %d - %d finished" % (name, start, x)
142
143            count2-=1
144            if not count2 :
145                readers.pop().start()
146                count2=count
147
148        if verbose:
149            print "%s: finished creating records" % name
150
151        if verbose:
152            print "%s: thread finished" % name
153
154    def readerThread(self, d, readerNum):
155        if sys.version_info[0] < 3 :
156            name = currentThread().getName()
157        else :
158            name = currentThread().name
159
160        for i in xrange(5) :
161            c = d.cursor()
162            count = 0
163            rec = c.first()
164            while rec:
165                count += 1
166                key, data = rec
167                self.assertEqual(self.makeData(key), data)
168                rec = c.next()
169            if verbose:
170                print "%s: found %d records" % (name, count)
171            c.close()
172
173        if verbose:
174            print "%s: thread finished" % name
175
176
class BTreeConcurrentDataStore(ConcurrentDataStoreBase):
    """Concurrent Data Store test on a B-tree: 2 writers, 10 readers."""
    records = 1000
    readers = 10
    writers = 2
    dbtype  = db.DB_BTREE
182
183
class HashConcurrentDataStore(ConcurrentDataStoreBase):
    """Concurrent Data Store test on a hash DB: 2 writers, 10 readers."""
    records = 1000
    readers = 10
    writers = 2
    dbtype  = db.DB_HASH
189
190
191#----------------------------------------------------------------------
192
193class SimpleThreadedBase(BaseThreadedTestCase):
194    dbopenflags = db.DB_THREAD
195    envflags    = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
196    readers = 10
197    writers = 2
198    records = 1000
199
200    def setEnvOpts(self):
201        self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
202
203    def test02_SimpleLocks(self):
204        if verbose:
205            print '\n', '-=' * 30
206            print "Running %s.test02_SimpleLocks..." % self.__class__.__name__
207
208
209        keys=range(self.records)
210        import random
211        random.shuffle(keys)
212        records_per_writer=self.records//self.writers
213        readers_per_writer=self.readers//self.writers
214        self.assertEqual(self.records,self.writers*records_per_writer)
215        self.assertEqual(self.readers,self.writers*readers_per_writer)
216        self.assertTrue((records_per_writer%readers_per_writer)==0)
217
218        readers = []
219        for x in xrange(self.readers):
220            rt = Thread(target = self.readerThread,
221                        args = (self.d, x),
222                        name = 'reader %d' % x,
223                        )#verbose = verbose)
224            if sys.version_info[0] < 3 :
225                rt.setDaemon(True)
226            else :
227                rt.daemon = True
228            readers.append(rt)
229
230        writers = []
231        for x in xrange(self.writers):
232            a=keys[records_per_writer*x:records_per_writer*(x+1)]
233            a.sort()  # Generate conflicts
234            b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
235            wt = Thread(target = self.writerThread,
236                        args = (self.d, a, b),
237                        name = 'writer %d' % x,
238                        )#verbose = verbose)
239            writers.append(wt)
240
241        for t in writers:
242            if sys.version_info[0] < 3 :
243                t.setDaemon(True)
244            else :
245                t.daemon = True
246            t.start()
247
248        for t in writers:
249            t.join()
250        for t in readers:
251            t.join()
252
253    def writerThread(self, d, keys, readers):
254        if sys.version_info[0] < 3 :
255            name = currentThread().getName()
256        else :
257            name = currentThread().name
258        if verbose:
259            print "%s: creating records %d - %d" % (name, start, stop)
260
261        count=len(keys)//len(readers)
262        count2=count
263        for x in keys :
264            key = '%04d' % x
265            dbutils.DeadlockWrap(d.put, key, self.makeData(key),
266                                 max_retries=12)
267
268            if verbose and x % 100 == 0:
269                print "%s: records %d - %d finished" % (name, start, x)
270
271            count2-=1
272            if not count2 :
273                readers.pop().start()
274                count2=count
275
276        if verbose:
277            print "%s: thread finished" % name
278
279    def readerThread(self, d, readerNum):
280        if sys.version_info[0] < 3 :
281            name = currentThread().getName()
282        else :
283            name = currentThread().name
284
285        c = d.cursor()
286        count = 0
287        rec = dbutils.DeadlockWrap(c.first, max_retries=10)
288        while rec:
289            count += 1
290            key, data = rec
291            self.assertEqual(self.makeData(key), data)
292            rec = dbutils.DeadlockWrap(c.next, max_retries=10)
293        if verbose:
294            print "%s: found %d records" % (name, count)
295        c.close()
296
297        if verbose:
298            print "%s: thread finished" % name
299
300
class BTreeSimpleThreaded(SimpleThreadedBase):
    """Run the SimpleThreadedBase locking test against a B-tree database."""
    dbtype = db.DB_BTREE
303
304
class HashSimpleThreaded(SimpleThreadedBase):
    """Run the SimpleThreadedBase locking test against a hash database."""
    dbtype = db.DB_HASH
307
308
309#----------------------------------------------------------------------
310
311
312class ThreadedTransactionsBase(BaseThreadedTestCase):
313    dbopenflags = db.DB_THREAD | db.DB_AUTO_COMMIT
314    envflags    = (db.DB_THREAD |
315                   db.DB_INIT_MPOOL |
316                   db.DB_INIT_LOCK |
317                   db.DB_INIT_LOG |
318                   db.DB_INIT_TXN
319                   )
320    readers = 0
321    writers = 0
322    records = 2000
323    txnFlag = 0
324
325    def setEnvOpts(self):
326        #self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
327        pass
328
329    def test03_ThreadedTransactions(self):
330        if verbose:
331            print '\n', '-=' * 30
332            print "Running %s.test03_ThreadedTransactions..." % \
333                  self.__class__.__name__
334
335        keys=range(self.records)
336        import random
337        random.shuffle(keys)
338        records_per_writer=self.records//self.writers
339        readers_per_writer=self.readers//self.writers
340        self.assertEqual(self.records,self.writers*records_per_writer)
341        self.assertEqual(self.readers,self.writers*readers_per_writer)
342        self.assertTrue((records_per_writer%readers_per_writer)==0)
343
344        readers=[]
345        for x in xrange(self.readers):
346            rt = Thread(target = self.readerThread,
347                        args = (self.d, x),
348                        name = 'reader %d' % x,
349                        )#verbose = verbose)
350            if sys.version_info[0] < 3 :
351                rt.setDaemon(True)
352            else :
353                rt.daemon = True
354            readers.append(rt)
355
356        writers = []
357        for x in xrange(self.writers):
358            a=keys[records_per_writer*x:records_per_writer*(x+1)]
359            b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
360            wt = Thread(target = self.writerThread,
361                        args = (self.d, a, b),
362                        name = 'writer %d' % x,
363                        )#verbose = verbose)
364            writers.append(wt)
365
366        dt = Thread(target = self.deadlockThread)
367        if sys.version_info[0] < 3 :
368            dt.setDaemon(True)
369        else :
370            dt.daemon = True
371        dt.start()
372
373        for t in writers:
374            if sys.version_info[0] < 3 :
375                t.setDaemon(True)
376            else :
377                t.daemon = True
378            t.start()
379
380        for t in writers:
381            t.join()
382        for t in readers:
383            t.join()
384
385        self.doLockDetect = False
386        dt.join()
387
388    def writerThread(self, d, keys, readers):
389        if sys.version_info[0] < 3 :
390            name = currentThread().getName()
391        else :
392            name = currentThread().name
393
394        count=len(keys)//len(readers)
395        while len(keys):
396            try:
397                txn = self.env.txn_begin(None, self.txnFlag)
398                keys2=keys[:count]
399                for x in keys2 :
400                    key = '%04d' % x
401                    d.put(key, self.makeData(key), txn)
402                    if verbose and x % 100 == 0:
403                        print "%s: records %d - %d finished" % (name, start, x)
404                txn.commit()
405                keys=keys[count:]
406                readers.pop().start()
407            except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
408                if verbose:
409                    if sys.version_info < (2, 6) :
410                        print "%s: Aborting transaction (%s)" % (name, val[1])
411                    else :
412                        print "%s: Aborting transaction (%s)" % (name,
413                                val.args[1])
414                txn.abort()
415
416        if verbose:
417            print "%s: thread finished" % name
418
419    def readerThread(self, d, readerNum):
420        if sys.version_info[0] < 3 :
421            name = currentThread().getName()
422        else :
423            name = currentThread().name
424
425        finished = False
426        while not finished:
427            try:
428                txn = self.env.txn_begin(None, self.txnFlag)
429                c = d.cursor(txn)
430                count = 0
431                rec = c.first()
432                while rec:
433                    count += 1
434                    key, data = rec
435                    self.assertEqual(self.makeData(key), data)
436                    rec = c.next()
437                if verbose: print "%s: found %d records" % (name, count)
438                c.close()
439                txn.commit()
440                finished = True
441            except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
442                if verbose:
443                    if sys.version_info < (2, 6) :
444                        print "%s: Aborting transaction (%s)" % (name, val[1])
445                    else :
446                        print "%s: Aborting transaction (%s)" % (name,
447                                val.args[1])
448                c.close()
449                txn.abort()
450
451        if verbose:
452            print "%s: thread finished" % name
453
454    def deadlockThread(self):
455        self.doLockDetect = True
456        while self.doLockDetect:
457            time.sleep(0.05)
458            try:
459                aborted = self.env.lock_detect(
460                    db.DB_LOCK_RANDOM, db.DB_LOCK_CONFLICT)
461                if verbose and aborted:
462                    print "deadlock: Aborted %d deadlocked transaction(s)" \
463                          % aborted
464            except db.DBError:
465                pass
466
467
class BTreeThreadedTransactions(ThreadedTransactionsBase):
    """Threaded transaction test on a B-tree: 2 writers, 10 readers."""
    records = 1000
    readers = 10
    writers = 2
    dbtype = db.DB_BTREE
473
class HashThreadedTransactions(ThreadedTransactionsBase):
    """Threaded transaction test on a hash DB: 2 writers, 10 readers."""
    records = 1000
    readers = 10
    writers = 2
    dbtype = db.DB_HASH
479
class BTreeThreadedNoWaitTransactions(ThreadedTransactionsBase):
    """Threaded transaction test on a B-tree, beginning with DB_TXN_NOWAIT."""
    txnFlag = db.DB_TXN_NOWAIT
    records = 1000
    readers = 10
    writers = 2
    dbtype = db.DB_BTREE
486
class HashThreadedNoWaitTransactions(ThreadedTransactionsBase):
    """Threaded transaction test on a hash DB, beginning with DB_TXN_NOWAIT."""
    txnFlag = db.DB_TXN_NOWAIT
    records = 1000
    readers = 10
    writers = 2
    dbtype = db.DB_HASH
493
494
495#----------------------------------------------------------------------
496
497def test_suite():
498    suite = unittest.TestSuite()
499
500    if have_threads:
501        suite.addTest(unittest.makeSuite(BTreeConcurrentDataStore))
502        suite.addTest(unittest.makeSuite(HashConcurrentDataStore))
503        suite.addTest(unittest.makeSuite(BTreeSimpleThreaded))
504        suite.addTest(unittest.makeSuite(HashSimpleThreaded))
505        suite.addTest(unittest.makeSuite(BTreeThreadedTransactions))
506        suite.addTest(unittest.makeSuite(HashThreadedTransactions))
507        suite.addTest(unittest.makeSuite(BTreeThreadedNoWaitTransactions))
508        suite.addTest(unittest.makeSuite(HashThreadedNoWaitTransactions))
509
510    else:
511        print "Threads not available, skipping thread tests."
512
513    return suite
514
515
# When run as a script, execute the suite built by test_suite().
if __name__ == '__main__':
    unittest.main(defaultTest='test_suite')
518