1/**
2 * markupsafe._speedups
3 * ~~~~~~~~~~~~~~~~~~~~
4 *
5 * This module implements functions for automatic escaping in C for better
6 * performance.
7 *
8 * :copyright: (c) 2010 by Armin Ronacher.
9 * :license: BSD.
10 */
11
12#include <Python.h>
13
/* Number of slots in the escape lookup tables: one per code point from 0
   up to and including '>' (62), the largest character that is escaped. */
#define ESCAPED_CHARS_TABLE_SIZE 63
/* Decode the ASCII C string `x` into a new unicode object and yield a
   pointer to its Py_UNICODE buffer.  The unicode object is never released,
   so the returned pointer stays valid for the rest of the process.
   Expands to a single expression (no trailing semicolon), so it can be used
   anywhere an expression is allowed.  NOTE: it does not check whether
   decoding failed; callers that need that must decode explicitly. */
#define UNICHR(x) (PyUnicode_AS_UNICODE((PyUnicodeObject*)PyUnicode_DecodeASCII((x), strlen(x), NULL)))
16
/* Interpreters older than 2.5 lack Py_ssize_t and its limit macros;
   substitute plain int and the matching <limits.h> bounds. */
#if !defined(PY_SSIZE_T_MIN) && PY_VERSION_HEX < 0x02050000
typedef int Py_ssize_t;
#define PY_SSIZE_T_MIN INT_MIN
#define PY_SSIZE_T_MAX INT_MAX
#endif
22
23
24static PyObject* markup;
25static Py_ssize_t escaped_chars_delta_len[ESCAPED_CHARS_TABLE_SIZE];
26static Py_UNICODE *escaped_chars_repl[ESCAPED_CHARS_TABLE_SIZE];
27
28static int
29init_constants(void)
30{
31	PyObject *module;
32	/* happing of characters to replace */
33	escaped_chars_repl['"'] = UNICHR("&#34;");
34	escaped_chars_repl['\''] = UNICHR("&#39;");
35	escaped_chars_repl['&'] = UNICHR("&amp;");
36	escaped_chars_repl['<'] = UNICHR("&lt;");
37	escaped_chars_repl['>'] = UNICHR("&gt;");
38
39	/* lengths of those characters when replaced - 1 */
40	memset(escaped_chars_delta_len, 0, sizeof (escaped_chars_delta_len));
41	escaped_chars_delta_len['"'] = escaped_chars_delta_len['\''] = \
42		escaped_chars_delta_len['&'] = 4;
43	escaped_chars_delta_len['<'] = escaped_chars_delta_len['>'] = 3;
44
45	/* import markup type so that we can mark the return value */
46	module = PyImport_ImportModule("markupsafe");
47	if (!module)
48		return 0;
49	markup = PyObject_GetAttrString(module, "Markup");
50	Py_DECREF(module);
51
52	return 1;
53}
54
55static PyObject*
56escape_unicode(PyUnicodeObject *in)
57{
58	PyUnicodeObject *out;
59	Py_UNICODE *inp = PyUnicode_AS_UNICODE(in);
60	const Py_UNICODE *inp_end = PyUnicode_AS_UNICODE(in) + PyUnicode_GET_SIZE(in);
61	Py_UNICODE *next_escp;
62	Py_UNICODE *outp;
63	Py_ssize_t delta=0, erepl=0, delta_len=0;
64
65	/* First we need to figure out how long the escaped string will be */
66	while (*(inp) || inp < inp_end) {
67		if (*inp < ESCAPED_CHARS_TABLE_SIZE) {
68			delta += escaped_chars_delta_len[*inp];
69			erepl += !!escaped_chars_delta_len[*inp];
70		}
71		++inp;
72	}
73
74	/* Do we need to escape anything at all? */
75	if (!erepl) {
76		Py_INCREF(in);
77		return (PyObject*)in;
78	}
79
80	out = (PyUnicodeObject*)PyUnicode_FromUnicode(NULL, PyUnicode_GET_SIZE(in) + delta);
81	if (!out)
82		return NULL;
83
84	outp = PyUnicode_AS_UNICODE(out);
85	inp = PyUnicode_AS_UNICODE(in);
86	while (erepl-- > 0) {
87		/* look for the next substitution */
88		next_escp = inp;
89		while (next_escp < inp_end) {
90			if (*next_escp < ESCAPED_CHARS_TABLE_SIZE &&
91			    (delta_len = escaped_chars_delta_len[*next_escp])) {
92				++delta_len;
93				break;
94			}
95			++next_escp;
96		}
97
98		if (next_escp > inp) {
99			/* copy unescaped chars between inp and next_escp */
100			Py_UNICODE_COPY(outp, inp, next_escp-inp);
101			outp += next_escp - inp;
102		}
103
104		/* escape 'next_escp' */
105		Py_UNICODE_COPY(outp, escaped_chars_repl[*next_escp], delta_len);
106		outp += delta_len;
107
108		inp = next_escp + 1;
109	}
110	if (inp < inp_end)
111		Py_UNICODE_COPY(outp, inp, PyUnicode_GET_SIZE(in) - (inp - PyUnicode_AS_UNICODE(in)));
112
113	return (PyObject*)out;
114}
115
116
117static PyObject*
118escape(PyObject *self, PyObject *text)
119{
120	PyObject *s = NULL, *rv = NULL, *html;
121
122	/* we don't have to escape integers, bools or floats */
123	if (PyLong_CheckExact(text) ||
124#if PY_MAJOR_VERSION < 3
125	    PyInt_CheckExact(text) ||
126#endif
127	    PyFloat_CheckExact(text) || PyBool_Check(text) ||
128	    text == Py_None)
129		return PyObject_CallFunctionObjArgs(markup, text, NULL);
130
131	/* if the object has an __html__ method that performs the escaping */
132	html = PyObject_GetAttrString(text, "__html__");
133	if (html) {
134		rv = PyObject_CallObject(html, NULL);
135		Py_DECREF(html);
136		return rv;
137	}
138
139	/* otherwise make the object unicode if it isn't, then escape */
140	PyErr_Clear();
141	if (!PyUnicode_Check(text)) {
142#if PY_MAJOR_VERSION < 3
143		PyObject *unicode = PyObject_Unicode(text);
144#else
145		PyObject *unicode = PyObject_Str(text);
146#endif
147		if (!unicode)
148			return NULL;
149		s = escape_unicode((PyUnicodeObject*)unicode);
150		Py_DECREF(unicode);
151	}
152	else
153		s = escape_unicode((PyUnicodeObject*)text);
154
155	/* convert the unicode string into a markup object. */
156	rv = PyObject_CallFunctionObjArgs(markup, (PyObject*)s, NULL);
157	Py_DECREF(s);
158	return rv;
159}
160
161
162static PyObject*
163escape_silent(PyObject *self, PyObject *text)
164{
165	if (text != Py_None)
166		return escape(self, text);
167	return PyObject_CallFunctionObjArgs(markup, NULL);
168}
169
170
171static PyObject*
172soft_unicode(PyObject *self, PyObject *s)
173{
174	if (!PyUnicode_Check(s))
175#if PY_MAJOR_VERSION < 3
176		return PyObject_Unicode(s);
177#else
178		return PyObject_Str(s);
179#endif
180	Py_INCREF(s);
181	return s;
182}
183
184
185static PyMethodDef module_methods[] = {
186	{"escape", (PyCFunction)escape, METH_O,
187	 "escape(s) -> markup\n\n"
188	 "Convert the characters &, <, >, ', and \" in string s to HTML-safe\n"
189	 "sequences.  Use this if you need to display text that might contain\n"
190	 "such characters in HTML.  Marks return value as markup string."},
191	{"escape_silent", (PyCFunction)escape_silent, METH_O,
192	 "escape_silent(s) -> markup\n\n"
193	 "Like escape but converts None to an empty string."},
194	{"soft_unicode", (PyCFunction)soft_unicode, METH_O,
195	 "soft_unicode(object) -> string\n\n"
196         "Make a string unicode if it isn't already.  That way a markup\n"
197         "string is not converted back to unicode."},
198	{NULL, NULL, 0, NULL}		/* Sentinel */
199};
200
201
202#if PY_MAJOR_VERSION < 3
203
204#ifndef PyMODINIT_FUNC	/* declarations for DLL import/export */
205#define PyMODINIT_FUNC void
206#endif
207PyMODINIT_FUNC
208init_speedups(void)
209{
210	if (!init_constants())
211		return;
212
213	Py_InitModule3("markupsafe._speedups", module_methods, "");
214}
215
216#else /* Python 3.x module initialization */
217
218static struct PyModuleDef module_definition = {
219        PyModuleDef_HEAD_INIT,
220	"markupsafe._speedups",
221	NULL,
222	-1,
223	module_methods,
224	NULL,
225	NULL,
226	NULL,
227	NULL
228};
229
230PyMODINIT_FUNC
231PyInit__speedups(void)
232{
233	if (!init_constants())
234		return NULL;
235
236	return PyModule_Create(&module_definition);
237}
238
239#endif
240