1#!/usr/bin/env python
2#
3# test_multibytecodec_support.py
4#   Common Unittest Routines for CJK codecs
5#
6
7import codecs
8import os
9import re
10import sys
11import unittest
12from httplib import HTTPException
13from test import test_support
14from StringIO import StringIO
15
16class TestBase:
17    encoding        = ''   # codec name
18    codec           = None # codec tuple (with 4 elements)
19    tstring         = ''   # string to test StreamReader
20
21    codectests      = None # must set. codec test tuple
22    roundtriptest   = 1    # set if roundtrip is possible with unicode
23    has_iso10646    = 0    # set if this encoding contains whole iso10646 map
24    xmlcharnametest = None # string to test xmlcharrefreplace
25    unmappedunicode = u'\udeee' # a unicode codepoint that is not mapped.
26
27    def setUp(self):
28        if self.codec is None:
29            self.codec = codecs.lookup(self.encoding)
30        self.encode = self.codec.encode
31        self.decode = self.codec.decode
32        self.reader = self.codec.streamreader
33        self.writer = self.codec.streamwriter
34        self.incrementalencoder = self.codec.incrementalencoder
35        self.incrementaldecoder = self.codec.incrementaldecoder
36
37    def test_chunkcoding(self):
38        for native, utf8 in zip(*[StringIO(f).readlines()
39                                  for f in self.tstring]):
40            u = self.decode(native)[0]
41            self.assertEqual(u, utf8.decode('utf-8'))
42            if self.roundtriptest:
43                self.assertEqual(native, self.encode(u)[0])
44
45    def test_errorhandle(self):
46        for source, scheme, expected in self.codectests:
47            if isinstance(source, bytes):
48                func = self.decode
49            else:
50                func = self.encode
51            if expected:
52                result = func(source, scheme)[0]
53                if func is self.decode:
54                    self.assertTrue(type(result) is unicode, type(result))
55                    self.assertEqual(result, expected,
56                                     '%r.decode(%r, %r)=%r != %r'
57                                     % (source, self.encoding, scheme, result,
58                                        expected))
59                else:
60                    self.assertTrue(type(result) is bytes, type(result))
61                    self.assertEqual(result, expected,
62                                     '%r.encode(%r, %r)=%r != %r'
63                                     % (source, self.encoding, scheme, result,
64                                        expected))
65            else:
66                self.assertRaises(UnicodeError, func, source, scheme)
67
68    def test_xmlcharrefreplace(self):
69        if self.has_iso10646:
70            return
71
72        s = u"\u0b13\u0b23\u0b60 nd eggs"
73        self.assertEqual(
74            self.encode(s, "xmlcharrefreplace")[0],
75            "ଓଣୠ nd eggs"
76        )
77
78    def test_customreplace_encode(self):
79        if self.has_iso10646:
80            return
81
82        from htmlentitydefs import codepoint2name
83
84        def xmlcharnamereplace(exc):
85            if not isinstance(exc, UnicodeEncodeError):
86                raise TypeError("don't know how to handle %r" % exc)
87            l = []
88            for c in exc.object[exc.start:exc.end]:
89                if ord(c) in codepoint2name:
90                    l.append(u"&%s;" % codepoint2name[ord(c)])
91                else:
92                    l.append(u"&#%d;" % ord(c))
93            return (u"".join(l), exc.end)
94
95        codecs.register_error("test.xmlcharnamereplace", xmlcharnamereplace)
96
97        if self.xmlcharnametest:
98            sin, sout = self.xmlcharnametest
99        else:
100            sin = u"\xab\u211c\xbb = \u2329\u1234\u232a"
101            sout = "«ℜ» = ⟨ሴ⟩"
102        self.assertEqual(self.encode(sin,
103                                    "test.xmlcharnamereplace")[0], sout)
104
105    def test_callback_wrong_objects(self):
106        def myreplace(exc):
107            return (ret, exc.end)
108        codecs.register_error("test.cjktest", myreplace)
109
110        for ret in ([1, 2, 3], [], None, object(), 'string', ''):
111            self.assertRaises(TypeError, self.encode, self.unmappedunicode,
112                              'test.cjktest')
113
114    def test_callback_long_index(self):
115        def myreplace(exc):
116            return (u'x', long(exc.end))
117        codecs.register_error("test.cjktest", myreplace)
118        self.assertEqual(self.encode(u'abcd' + self.unmappedunicode + u'efgh',
119                                     'test.cjktest'), ('abcdxefgh', 9))
120
121        def myreplace(exc):
122            return (u'x', sys.maxint + 1)
123        codecs.register_error("test.cjktest", myreplace)
124        self.assertRaises(IndexError, self.encode, self.unmappedunicode,
125                          'test.cjktest')
126
127    def test_callback_None_index(self):
128        def myreplace(exc):
129            return (u'x', None)
130        codecs.register_error("test.cjktest", myreplace)
131        self.assertRaises(TypeError, self.encode, self.unmappedunicode,
132                          'test.cjktest')
133
134    def test_callback_backward_index(self):
135        def myreplace(exc):
136            if myreplace.limit > 0:
137                myreplace.limit -= 1
138                return (u'REPLACED', 0)
139            else:
140                return (u'TERMINAL', exc.end)
141        myreplace.limit = 3
142        codecs.register_error("test.cjktest", myreplace)
143        self.assertEqual(self.encode(u'abcd' + self.unmappedunicode + u'efgh',
144                                     'test.cjktest'),
145                ('abcdREPLACEDabcdREPLACEDabcdREPLACEDabcdTERMINALefgh', 9))
146
147    def test_callback_forward_index(self):
148        def myreplace(exc):
149            return (u'REPLACED', exc.end + 2)
150        codecs.register_error("test.cjktest", myreplace)
151        self.assertEqual(self.encode(u'abcd' + self.unmappedunicode + u'efgh',
152                                     'test.cjktest'), ('abcdREPLACEDgh', 9))
153
154    def test_callback_index_outofbound(self):
155        def myreplace(exc):
156            return (u'TERM', 100)
157        codecs.register_error("test.cjktest", myreplace)
158        self.assertRaises(IndexError, self.encode, self.unmappedunicode,
159                          'test.cjktest')
160
161    def test_incrementalencoder(self):
162        UTF8Reader = codecs.getreader('utf-8')
163        for sizehint in [None] + range(1, 33) + \
164                        [64, 128, 256, 512, 1024]:
165            istream = UTF8Reader(StringIO(self.tstring[1]))
166            ostream = StringIO()
167            encoder = self.incrementalencoder()
168            while 1:
169                if sizehint is not None:
170                    data = istream.read(sizehint)
171                else:
172                    data = istream.read()
173
174                if not data:
175                    break
176                e = encoder.encode(data)
177                ostream.write(e)
178
179            self.assertEqual(ostream.getvalue(), self.tstring[0])
180
181    def test_incrementaldecoder(self):
182        UTF8Writer = codecs.getwriter('utf-8')
183        for sizehint in [None, -1] + range(1, 33) + \
184                        [64, 128, 256, 512, 1024]:
185            istream = StringIO(self.tstring[0])
186            ostream = UTF8Writer(StringIO())
187            decoder = self.incrementaldecoder()
188            while 1:
189                data = istream.read(sizehint)
190                if not data:
191                    break
192                else:
193                    u = decoder.decode(data)
194                    ostream.write(u)
195
196            self.assertEqual(ostream.getvalue(), self.tstring[1])
197
198    def test_incrementalencoder_error_callback(self):
199        inv = self.unmappedunicode
200
201        e = self.incrementalencoder()
202        self.assertRaises(UnicodeEncodeError, e.encode, inv, True)
203
204        e.errors = 'ignore'
205        self.assertEqual(e.encode(inv, True), '')
206
207        e.reset()
208        def tempreplace(exc):
209            return (u'called', exc.end)
210        codecs.register_error('test.incremental_error_callback', tempreplace)
211        e.errors = 'test.incremental_error_callback'
212        self.assertEqual(e.encode(inv, True), 'called')
213
214        # again
215        e.errors = 'ignore'
216        self.assertEqual(e.encode(inv, True), '')
217
218    def test_streamreader(self):
219        UTF8Writer = codecs.getwriter('utf-8')
220        for name in ["read", "readline", "readlines"]:
221            for sizehint in [None, -1] + range(1, 33) + \
222                            [64, 128, 256, 512, 1024]:
223                istream = self.reader(StringIO(self.tstring[0]))
224                ostream = UTF8Writer(StringIO())
225                func = getattr(istream, name)
226                while 1:
227                    data = func(sizehint)
228                    if not data:
229                        break
230                    if name == "readlines":
231                        ostream.writelines(data)
232                    else:
233                        ostream.write(data)
234
235                self.assertEqual(ostream.getvalue(), self.tstring[1])
236
237    def test_streamwriter(self):
238        readfuncs = ('read', 'readline', 'readlines')
239        UTF8Reader = codecs.getreader('utf-8')
240        for name in readfuncs:
241            for sizehint in [None] + range(1, 33) + \
242                            [64, 128, 256, 512, 1024]:
243                istream = UTF8Reader(StringIO(self.tstring[1]))
244                ostream = self.writer(StringIO())
245                func = getattr(istream, name)
246                while 1:
247                    if sizehint is not None:
248                        data = func(sizehint)
249                    else:
250                        data = func()
251
252                    if not data:
253                        break
254                    if name == "readlines":
255                        ostream.writelines(data)
256                    else:
257                        ostream.write(data)
258
259                self.assertEqual(ostream.getvalue(), self.tstring[0])
260
261class TestBase_Mapping(unittest.TestCase):
262    pass_enctest = []
263    pass_dectest = []
264    supmaps = []
265    codectests = []
266
267    def __init__(self, *args, **kw):
268        unittest.TestCase.__init__(self, *args, **kw)
269        try:
270            self.open_mapping_file().close() # test it to report the error early
271        except (IOError, HTTPException):
272            self.skipTest("Could not retrieve "+self.mapfileurl)
273
274    def open_mapping_file(self):
275        return test_support.open_urlresource(self.mapfileurl)
276
277    def test_mapping_file(self):
278        if self.mapfileurl.endswith('.xml'):
279            self._test_mapping_file_ucm()
280        else:
281            self._test_mapping_file_plain()
282
283    def _test_mapping_file_plain(self):
284        _unichr = lambda c: eval("u'\\U%08x'" % int(c, 16))
285        unichrs = lambda s: u''.join(_unichr(c) for c in s.split('+'))
286        urt_wa = {}
287
288        with self.open_mapping_file() as f:
289            for line in f:
290                if not line:
291                    break
292                data = line.split('#')[0].strip().split()
293                if len(data) != 2:
294                    continue
295
296                csetval = eval(data[0])
297                if csetval <= 0x7F:
298                    csetch = chr(csetval & 0xff)
299                elif csetval >= 0x1000000:
300                    csetch = chr(csetval >> 24) + chr((csetval >> 16) & 0xff) + \
301                             chr((csetval >> 8) & 0xff) + chr(csetval & 0xff)
302                elif csetval >= 0x10000:
303                    csetch = chr(csetval >> 16) + \
304                             chr((csetval >> 8) & 0xff) + chr(csetval & 0xff)
305                elif csetval >= 0x100:
306                    csetch = chr(csetval >> 8) + chr(csetval & 0xff)
307                else:
308                    continue
309
310                unich = unichrs(data[1])
311                if unich == u'\ufffd' or unich in urt_wa:
312                    continue
313                urt_wa[unich] = csetch
314
315                self._testpoint(csetch, unich)
316
317    def _test_mapping_file_ucm(self):
318        with self.open_mapping_file() as f:
319            ucmdata = f.read()
320        uc = re.findall('<a u="([A-F0-9]{4})" b="([0-9A-F ]+)"/>', ucmdata)
321        for uni, coded in uc:
322            unich = unichr(int(uni, 16))
323            codech = ''.join(chr(int(c, 16)) for c in coded.split())
324            self._testpoint(codech, unich)
325
326    def test_mapping_supplemental(self):
327        for mapping in self.supmaps:
328            self._testpoint(*mapping)
329
330    def _testpoint(self, csetch, unich):
331        if (csetch, unich) not in self.pass_enctest:
332            try:
333                self.assertEqual(unich.encode(self.encoding), csetch)
334            except UnicodeError, exc:
335                self.fail('Encoding failed while testing %s -> %s: %s' % (
336                            repr(unich), repr(csetch), exc.reason))
337        if (csetch, unich) not in self.pass_dectest:
338            try:
339                self.assertEqual(csetch.decode(self.encoding), unich)
340            except UnicodeError, exc:
341                self.fail('Decoding failed while testing %s -> %s: %s' % (
342                            repr(csetch), repr(unich), exc.reason))
343
344    def test_errorhandle(self):
345        for source, scheme, expected in self.codectests:
346            if isinstance(source, bytes):
347                func = source.decode
348            else:
349                func = source.encode
350            if expected:
351                if isinstance(source, bytes):
352                    result = func(self.encoding, scheme)
353                    self.assertTrue(type(result) is unicode, type(result))
354                    self.assertEqual(result, expected,
355                                     '%r.decode(%r, %r)=%r != %r'
356                                     % (source, self.encoding, scheme, result,
357                                        expected))
358                else:
359                    result = func(self.encoding, scheme)
360                    self.assertTrue(type(result) is bytes, type(result))
361                    self.assertEqual(result, expected,
362                                     '%r.encode(%r, %r)=%r != %r'
363                                     % (source, self.encoding, scheme, result,
364                                        expected))
365            else:
366                self.assertRaises(UnicodeError, func, self.encoding, scheme)
367
def load_teststring(name):
    """Return the ``(encoded, utf8)`` byte strings of the sample *name*.

    Both files are read in binary mode from the ``cjkencodings`` directory
    next to this module: ``<name>.txt`` holds the encoded sample and
    ``<name>-utf8.txt`` its UTF-8 counterpart.
    """
    sample_dir = os.path.join(os.path.dirname(__file__), 'cjkencodings')

    def _slurp(filename):
        # Read one sample file completely, closing it afterwards.
        with open(os.path.join(sample_dir, filename), 'rb') as fp:
            return fp.read()

    return _slurp(name + '.txt'), _slurp(name + '-utf8.txt')
375