1#-*- coding: ISO-8859-1 -*-
2# pysqlite2/test/factory.py: tests for the various factories in pysqlite
3#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
5#
6# This file is part of pysqlite.
7#
8# This software is provided 'as-is', without any express or implied
9# warranty.  In no event will the authors be held liable for any damages
10# arising from the use of this software.
11#
12# Permission is granted to anyone to use this software for any purpose,
13# including commercial applications, and to alter it and redistribute it
14# freely, subject to the following restrictions:
15#
16# 1. The origin of this software must not be misrepresented; you must not
17#    claim that you wrote the original software. If you use this software
18#    in a product, an acknowledgment in the product documentation would be
19#    appreciated but is not required.
20# 2. Altered source versions must be plainly marked as such, and must not be
21#    misrepresented as being the original software.
22# 3. This notice may not be removed or altered from any source distribution.
23
24import unittest
25import sqlite3 as sqlite
26from collections import Sequence
27
class MyConnection(sqlite.Connection):
    """Minimal Connection subclass passed as ``factory`` to ``sqlite.connect``.

    It adds no behavior; it only lets the tests check that ``connect`` returns
    an instance of the requested class.
    """

    def __init__(self, *args, **kwargs):
        # Forward every argument unchanged to the base Connection initializer.
        super(MyConnection, self).__init__(*args, **kwargs)
31
def dict_factory(cursor, row):
    """Row factory that builds a dict mapping column name -> value.

    Column names are taken from the first item of each entry in
    ``cursor.description``; values are looked up in ``row`` by position.
    If two columns share a name, the later column's value wins.
    """
    return {
        column_info[0]: row[position]
        for position, column_info in enumerate(cursor.description)
    }
37
class MyCursor(sqlite.Cursor):
    """Cursor subclass that returns each fetched row as a dict.

    It installs ``dict_factory`` as this cursor's own ``row_factory`` right
    after the base Cursor has been initialized.
    """

    def __init__(self, *args, **kwargs):
        super(MyCursor, self).__init__(*args, **kwargs)
        # Set on the cursor instance, not on the connection.
        self.row_factory = dict_factory
42
class ConnectionFactoryTests(unittest.TestCase):
    """Checks that ``connect(factory=...)`` builds the requested class."""

    def setUp(self):
        connection_class = MyConnection
        self.con = sqlite.connect(":memory:", factory=connection_class)

    def tearDown(self):
        self.con.close()

    def CheckIsInstance(self):
        """The connection object is an instance of the factory class."""
        self.assertIsInstance(self.con, MyConnection)
52
class CursorFactoryTests(unittest.TestCase):
    """Checks the ``factory`` argument of ``Connection.cursor()``."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def CheckIsInstance(self):
        """Default, positional and keyword factories yield the right class."""
        self.assertIsInstance(self.con.cursor(), sqlite.Cursor)
        self.assertIsInstance(self.con.cursor(MyCursor), MyCursor)
        self.assertIsInstance(
            self.con.cursor(factory=lambda con: MyCursor(con)), MyCursor)

    def CheckInvalidFactory(self):
        """Unusable factories are rejected with TypeError."""
        bad_factories = (
            None,               # not a callable at all
            lambda: None,       # callable, but does not take exactly one argument
            lambda con: None,   # callable, but returns something that is not a cursor
        )
        for bad_factory in bad_factories:
            self.assertRaises(TypeError, self.con.cursor, bad_factory)
75
class RowFactoryTestsBackwardsCompat(unittest.TestCase):
    """A row_factory set by a cursor subclass shapes the rows it fetches."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def CheckIsProducedByFactory(self):
        """Rows fetched through MyCursor are dicts."""
        dict_cursor = self.con.cursor(factory=MyCursor)
        dict_cursor.execute("select 4+5 as foo")
        fetched = dict_cursor.fetchone()
        self.assertIsInstance(fetched, dict)
        dict_cursor.close()
89
90class RowFactoryTests(unittest.TestCase):
91    def setUp(self):
92        self.con = sqlite.connect(":memory:")
93
94    def CheckCustomFactory(self):
95        self.con.row_factory = lambda cur, row: list(row)
96        row = self.con.execute("select 1, 2").fetchone()
97        self.assertIsInstance(row, list)
98
99    def CheckSqliteRowIndex(self):
100        self.con.row_factory = sqlite.Row
101        row = self.con.execute("select 1 as a, 2 as b").fetchone()
102        self.assertIsInstance(row, sqlite.Row)
103
104        col1, col2 = row["a"], row["b"]
105        self.assertEqual(col1, 1, "by name: wrong result for column 'a'")
106        self.assertEqual(col2, 2, "by name: wrong result for column 'a'")
107
108        col1, col2 = row["A"], row["B"]
109        self.assertEqual(col1, 1, "by name: wrong result for column 'A'")
110        self.assertEqual(col2, 2, "by name: wrong result for column 'B'")
111
112        self.assertEqual(row[0], 1, "by index: wrong result for column 0")
113        self.assertEqual(row[0L], 1, "by index: wrong result for column 0")
114        self.assertEqual(row[1], 2, "by index: wrong result for column 1")
115        self.assertEqual(row[1L], 2, "by index: wrong result for column 1")
116        self.assertEqual(row[-1], 2, "by index: wrong result for column -1")
117        self.assertEqual(row[-1L], 2, "by index: wrong result for column -1")
118        self.assertEqual(row[-2], 1, "by index: wrong result for column -2")
119        self.assertEqual(row[-2L], 1, "by index: wrong result for column -2")
120
121        with self.assertRaises(IndexError):
122            row['c']
123        with self.assertRaises(IndexError):
124            row[2]
125        with self.assertRaises(IndexError):
126            row[2L]
127        with self.assertRaises(IndexError):
128            row[-3]
129        with self.assertRaises(IndexError):
130            row[-3L]
131        with self.assertRaises(IndexError):
132            row[2**1000]
133
134    def CheckSqliteRowIter(self):
135        """Checks if the row object is iterable"""
136        self.con.row_factory = sqlite.Row
137        row = self.con.execute("select 1 as a, 2 as b").fetchone()
138        for col in row:
139            pass
140
141    def CheckSqliteRowAsTuple(self):
142        """Checks if the row object can be converted to a tuple"""
143        self.con.row_factory = sqlite.Row
144        row = self.con.execute("select 1 as a, 2 as b").fetchone()
145        t = tuple(row)
146        self.assertEqual(t, (row['a'], row['b']))
147
148    def CheckSqliteRowAsDict(self):
149        """Checks if the row object can be correctly converted to a dictionary"""
150        self.con.row_factory = sqlite.Row
151        row = self.con.execute("select 1 as a, 2 as b").fetchone()
152        d = dict(row)
153        self.assertEqual(d["a"], row["a"])
154        self.assertEqual(d["b"], row["b"])
155
156    def CheckSqliteRowHashCmp(self):
157        """Checks if the row object compares and hashes correctly"""
158        self.con.row_factory = sqlite.Row
159        row_1 = self.con.execute("select 1 as a, 2 as b").fetchone()
160        row_2 = self.con.execute("select 1 as a, 2 as b").fetchone()
161        row_3 = self.con.execute("select 1 as a, 3 as b").fetchone()
162
163        self.assertEqual(row_1, row_1)
164        self.assertEqual(row_1, row_2)
165        self.assertTrue(row_2 != row_3)
166
167        self.assertFalse(row_1 != row_1)
168        self.assertFalse(row_1 != row_2)
169        self.assertFalse(row_2 == row_3)
170
171        self.assertEqual(row_1, row_2)
172        self.assertEqual(hash(row_1), hash(row_2))
173        self.assertNotEqual(row_1, row_3)
174        self.assertNotEqual(hash(row_1), hash(row_3))
175
176    def CheckSqliteRowAsSequence(self):
177        """ Checks if the row object can act like a sequence """
178        self.con.row_factory = sqlite.Row
179        row = self.con.execute("select 1 as a, 2 as b").fetchone()
180
181        as_tuple = tuple(row)
182        self.assertEqual(list(reversed(row)), list(reversed(as_tuple)))
183        self.assertIsInstance(row, Sequence)
184
185    def CheckFakeCursorClass(self):
186        # Issue #24257: Incorrect use of PyObject_IsInstance() caused
187        # segmentation fault.
188        # Issue #27861: Also applies for cursor factory.
189        class FakeCursor(str):
190            __class__ = sqlite.Cursor
191        self.con.row_factory = sqlite.Row
192        self.assertRaises(TypeError, self.con.cursor, FakeCursor)
193        self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
194
195    def tearDown(self):
196        self.con.close()
197
class TextFactoryTests(unittest.TestCase):
    """Tests for the connection's ``text_factory`` attribute."""

    # "Oesterreich" spelled with a capital O-umlaut (U+00D6). It is written as
    # an escape so the value does not depend on the file's source encoding.
    AUSTRIA = u"\xd6sterreich"

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def CheckUnicode(self):
        """Without a custom text_factory, text comes back as unicode."""
        row = self.con.execute("select ?", (self.AUSTRIA,)).fetchone()
        self.assertEqual(type(row[0]), unicode, "type of row[0] must be unicode")

    def CheckString(self):
        """text_factory=str returns the UTF-8 encoded bytes."""
        self.con.text_factory = str
        row = self.con.execute("select ?", (self.AUSTRIA,)).fetchone()
        self.assertEqual(type(row[0]), str, "type of row[0] must be str")
        self.assertEqual(row[0], self.AUSTRIA.encode("utf-8"), "column must equal original data in UTF-8")

    def CheckCustom(self):
        """A custom callable receives the raw bytes and decides how to decode them."""
        self.con.text_factory = lambda x: unicode(x, "utf-8", "ignore")
        # The latin-1 bytes are not valid UTF-8. The "ignore" handler may drop
        # undecodable bytes, so only the ASCII tail is checked below.
        row = self.con.execute("select ?", (self.AUSTRIA.encode("latin1"),)).fetchone()
        self.assertEqual(type(row[0]), unicode, "type of row[0] must be unicode")
        self.assertTrue(row[0].endswith(u"reich"), "column must contain original data")

    def CheckOptimizedUnicode(self):
        """OptimizedUnicode yields str for ASCII-only text, unicode otherwise."""
        self.con.text_factory = sqlite.OptimizedUnicode
        germany = unicode("Deutchland")
        a_row = self.con.execute("select ?", (self.AUSTRIA,)).fetchone()
        d_row = self.con.execute("select ?", (germany,)).fetchone()
        self.assertEqual(type(a_row[0]), unicode, "type of non-ASCII row must be unicode")
        self.assertEqual(type(d_row[0]), str, "type of ASCII-only row must be str")

    def tearDown(self):
        self.con.close()
232
class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
    """Text containing NUL bytes must round-trip intact.

    This is checked under the default, a custom, and the OptimizedUnicode
    text factories.
    """

    # Non-ASCII text with an embedded NUL (a-umlaut, NUL, o-umlaut). It is
    # written with escapes so the value does not depend on the file's
    # source encoding.
    NON_ASCII_WITH_NUL = u"\xe4\x00\xf6"

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.execute("create table test (value text)")
        self.con.execute("insert into test (value) values (?)", ("a\x00b",))

    def CheckString(self):
        # text_factory defaults to unicode
        row = self.con.execute("select value from test").fetchone()
        self.assertIs(type(row[0]), unicode)
        self.assertEqual(row[0], "a\x00b")

    def CheckCustom(self):
        # A custom factory should receive a str argument
        self.con.text_factory = lambda x: x
        row = self.con.execute("select value from test").fetchone()
        self.assertIs(type(row[0]), str)
        self.assertEqual(row[0], "a\x00b")

    def CheckOptimizedUnicodeAsString(self):
        # ASCII -> str argument
        self.con.text_factory = sqlite.OptimizedUnicode
        row = self.con.execute("select value from test").fetchone()
        self.assertIs(type(row[0]), str)
        self.assertEqual(row[0], "a\x00b")

    def CheckOptimizedUnicodeAsUnicode(self):
        # Non-ASCII -> unicode argument
        self.con.text_factory = sqlite.OptimizedUnicode
        self.con.execute("delete from test")
        self.con.execute("insert into test (value) values (?)", (self.NON_ASCII_WITH_NUL,))
        row = self.con.execute("select value from test").fetchone()
        self.assertIs(type(row[0]), unicode)
        self.assertEqual(row[0], self.NON_ASCII_WITH_NUL)

    def tearDown(self):
        self.con.close()
270
def suite():
    """Return one TestSuite with the ``Check*`` methods of every factory test case."""
    case_classes = (
        ConnectionFactoryTests,
        CursorFactoryTests,
        RowFactoryTestsBackwardsCompat,
        RowFactoryTests,
        TextFactoryTests,
        TextFactoryTestsWithEmbeddedZeroBytes,
    )
    # These tests use the "Check" method prefix instead of unittest's default "test".
    return unittest.TestSuite(
        [unittest.makeSuite(case_class, "Check") for case_class in case_classes])
279
def test():
    """Run the factory test suite with a text runner; the result is not returned."""
    unittest.TextTestRunner().run(suite())
283
if __name__ == "__main__":
    # Run the whole factory test suite when this module is executed directly.
    test()
286