1from __future__ import absolute_import
2import itertools
3
4import lit.util
5from lit.ShCommands import Command, Pipeline, Seq
6
class ShLexer:
    """
    ShLexer - Tokenizer for a small 'sh'-like command language.

    Iterating over lex() yields plain strings for arguments and tuples for
    operators: ``(op,)`` for a bare operator, or ``(op, fd)`` for a
    redirection written with a leading file descriptor number (e.g. '2>').

    data         - The command line to tokenize.
    win32Escapes - When true, a backslash outside of quotes is an ordinary
                   character instead of an escape.
    """

    # For each character that starts an operator, the characters which may
    # follow it to form a two-character operator:
    #   '||'  '&&' '&>'  '>&' '>>'  '<&' '<<'
    _OPERATOR_FOLLOWERS = {
        ';': '',
        '|': '|',
        '&': '&>',
        '>': '&>',
        '<': '&<',
    }

    def __init__(self, data, win32Escapes = False):
        self.data = data
        self.pos = 0
        self.end = len(data)
        self.win32Escapes = win32Escapes

    def eat(self):
        """
        eat() - Consume and return the next character. The caller must
        ensure the input is not exhausted. """
        c = self.data[self.pos]
        self.pos += 1
        return c

    def look(self):
        """
        look() - Return the next character without consuming it. The caller
        must ensure the input is not exhausted. """
        return self.data[self.pos]

    def maybe_eat(self, c):
        """
        maybe_eat(c) - Consume the character c if it is the next character,
        returning True if a character was consumed. """
        # An operator may be the very last character of the input, so the
        # end of the data has to be checked before peeking.
        if self.pos != self.end and self.data[self.pos] == c:
            self.pos += 1
            return True
        return False

    def lex_arg_fast(self, c):
        """
        lex_arg_fast(c) - Try to lex an argument that contains no quotes,
        escapes or operator characters. Returns the argument, or None
        (without moving the position) when the slow path is needed. """
        # c has already been consumed, so the argument starts at pos - 1.
        start = self.pos - 1

        # Get the leading whitespace free section.
        chunk = self.data[start:].split(None, 1)[0]

        # If it has special characters, the fast path failed.
        for special in '|&<>\'";\\':
            if special in chunk:
                return None

        self.pos = start + len(chunk)
        return chunk

    def lex_arg_slow(self, c):
        """
        lex_arg_slow(c) - Lex an argument whose first (already consumed)
        character is c, handling quoting and escapes. Returns the argument
        string, or an (op, fd) tuple when the argument turns out to be the
        file descriptor prefix of a redirection such as '2>'. """
        if c in "'\"":
            arg = self.lex_arg_quoted(c)
        else:
            arg = c
        while self.pos != self.end:
            c = self.look()
            if c.isspace() or c in "|&;":
                break
            elif c in '><':
                # This is an annoying case; we treat '2>' as a single token so
                # we don't have to track whitespace tokens.

                # If the parse string isn't an integer, do the usual thing.
                if not arg.isdigit():
                    break

                # Otherwise, lex the operator and convert to a redirection
                # token.
                num = int(arg)
                tok = self.lex_one_token()
                assert isinstance(tok, tuple) and len(tok) == 1
                return (tok[0], num)
            elif c == '"':
                self.eat()
                arg += self.lex_arg_quoted('"')
            elif c == "'":
                self.eat()
                arg += self.lex_arg_quoted("'")
            elif not self.win32Escapes and c == '\\':
                # Outside of a string, '\\' escapes everything.
                self.eat()
                if self.pos == self.end:
                    lit.util.warning(
                        "escape at end of quoted argument in: %r" % self.data)
                    return arg
                arg += self.eat()
            else:
                arg += self.eat()
        return arg

    def lex_arg_quoted(self, delim):
        """
        lex_arg_quoted(delim) - Lex the remainder of a quoted string whose
        opening delim has already been consumed, returning its contents
        without the quotes. A warning is issued if the input ends first. """
        chars = []
        while self.pos != self.end:
            c = self.eat()
            if c == delim:
                return ''.join(chars)
            elif c == '\\' and delim == '"':
                # Inside a '"' quoted string, '\\' only escapes the quote
                # character and backslash, otherwise it is preserved.
                if self.pos == self.end:
                    lit.util.warning(
                        "escape at end of quoted argument in: %r" % self.data)
                    return ''.join(chars)
                c = self.eat()
                if c == '"' or c == '\\':
                    chars.append(c)
                else:
                    chars.append('\\' + c)
            else:
                chars.append(c)
        lit.util.warning("missing quote character in %r" % self.data)
        return ''.join(chars)

    def lex_arg_checked(self, c):
        """
        lex_arg_checked(c) - Debugging variant of lex_arg() which runs both
        paths and raises ValueError if the fast path disagrees with the
        slow path on either the result or the final position. """
        pos = self.pos
        res = self.lex_arg_fast(c)
        end = self.pos

        # Rewind and lex again with the reference implementation.
        self.pos = pos
        reference = self.lex_arg_slow(c)
        if res is not None:
            if res != reference:
                raise ValueError("Fast path failure: %r != %r" % (
                        res, reference))
            if self.pos != end:
                raise ValueError("Fast path failure: %r != %r" % (
                        self.pos, end))
        return reference

    def lex_arg(self, c):
        """
        lex_arg(c) - Lex an argument starting with the already consumed
        character c, preferring the fast path. """
        arg = self.lex_arg_fast(c)
        if arg is None:
            arg = self.lex_arg_slow(c)
        return arg

    def lex_one_token(self):
        """
        lex_one_token - Lex a single 'sh' token. """

        c = self.eat()
        followers = self._OPERATOR_FOLLOWERS.get(c)
        if followers is None:
            # Not an operator character; this starts an argument.
            return self.lex_arg(c)

        for follower in followers:
            if self.maybe_eat(follower):
                return (c + follower,)
        return (c,)

    def lex(self):
        """
        lex() - Generate the tokens of the input, skipping whitespace
        between them. """
        while self.pos != self.end:
            if self.look().isspace():
                self.eat()
            else:
                yield self.lex_one_token()
168
169###
170
class ShParser:
    """
    ShParser - Parser turning a 'sh'-like command line into Command,
    Pipeline and Seq objects.

    data         - The command line to parse.
    win32Escapes - Forwarded to ShLexer.
    pipefail     - Passed through to every Pipeline that is created.
    """

    # One-token lookahead buffer filled by look(). The lexer only yields
    # strings and tuples, so None doubles as "nothing buffered".
    _lookahead = None

    def __init__(self, data, win32Escapes = False, pipefail = False):
        self.data = data
        self.pipefail = pipefail
        self.tokens = ShLexer(data, win32Escapes = win32Escapes).lex()
        self._lookahead = None

    def lex(self):
        """
        lex() - Consume and return the next token, or None at the end of
        the input. """
        tok = self._lookahead
        if tok is not None:
            self._lookahead = None
            return tok
        return next(self.tokens, None)

    def look(self):
        """
        look() - Return the next token without consuming it, or None at the
        end of the input. """
        # Buffer the token instead of re-wrapping the token stream, so the
        # cost of fetching a token does not grow with the number of peeks.
        if self._lookahead is None:
            self._lookahead = next(self.tokens, None)
        return self._lookahead

    def parse_command(self):
        """
        parse_command() - Parse one command: its arguments and redirections
        up to (but not including) the next terminator or the end of input.

        Raises ValueError when the command is empty, starts with an operator
        or contains a redirection without a valid target. """
        tok = self.lex()
        if not tok:
            raise ValueError("empty command!")
        if isinstance(tok, tuple):
            raise ValueError("syntax error near unexpected token %r" % tok[0])

        args = [tok]
        redirects = []
        while 1:
            tok = self.look()

            # EOF?
            if tok is None:
                break

            # If this is an argument, just add it to the current command.
            if isinstance(tok, str):
                args.append(self.lex())
                continue

            # Otherwise see if it is a terminator.
            assert isinstance(tok, tuple)
            if tok[0] in ('|',';','&','||','&&'):
                break

            # Otherwise it must be a redirection.
            op = self.lex()
            arg = self.lex()
            if not arg:
                raise ValueError("syntax error near token %r" % op[0])
            # The target of a redirection has to be an argument, never
            # another operator.
            if isinstance(arg, tuple):
                raise ValueError(
                    "syntax error near unexpected token %r" % arg[0])
            redirects.append((op, arg))

        return Command(args, redirects)

    def parse_pipeline(self):
        """
        parse_pipeline() - Parse one or more commands joined by '|'. """
        negate = False

        commands = [self.parse_command()]
        while self.look() == ('|',):
            self.lex()
            commands.append(self.parse_command())
        return Pipeline(commands, negate, self.pipefail)

    def parse(self):
        """
        parse() - Parse the whole input into a Pipeline, or a left-nested
        Seq of pipelines when list operators are present.

        Raises ValueError when an operator has no right-hand side. """
        lhs = self.parse_pipeline()

        while self.look():
            operator = self.lex()
            assert isinstance(operator, tuple) and len(operator) == 1

            if not self.look():
                raise ValueError(
                    "missing argument to operator %r" % operator[0])

            # FIXME: Operator precedence!!
            lhs = Seq(lhs, operator[0], self.parse_pipeline())

        return lhs
247
248###
249
250import unittest
251
class TestShLexer(unittest.TestCase):
    """Checks the token streams produced by ShLexer."""

    def lex(self, str, *args, **kwargs):
        """Tokenize str with ShLexer and return the tokens as a list."""
        lexer = ShLexer(str, *args, **kwargs)
        return list(lexer.lex())

    def test_basic(self):
        # Every single-character operator, with no whitespace in between.
        expected = ['a', ('|',), 'b', ('>',), 'c', ('&',), 'd',
                    ('<',), 'e', (';',), 'f']
        self.assertEqual(self.lex('a|b>c&d<e;f'), expected)

    def test_redirection_tokens(self):
        # A digit only acts as a file descriptor when it is a whole word.
        glued = self.lex('a2>c')
        self.assertEqual(glued, ['a2', ('>',), 'c'])

        separate = self.lex('a 2>c')
        self.assertEqual(separate, ['a', ('>',2), 'c'])

    def test_quoting(self):
        # (input, expected tokens) pairs lexed with the default escapes.
        cases = [
            (""" 'a' """, ['a']),
            (""" "hello\\"world" """, ['hello"world']),
            (""" "hello\\'world" """, ["hello\\'world"]),
            (""" "hello\\\\world" """, ["hello\\world"]),
            (""" he"llo wo"rld """, ["hello world"]),
            (""" a\\ b a\\\\b """, ["a b", "a\\b"]),
            (""" "" "" """, ["", ""]),
        ]
        for text, expected in cases:
            self.assertEqual(self.lex(text), expected)

        # With win32 escapes a backslash is an ordinary character.
        win32_tokens = self.lex(""" a\\ b """, win32Escapes = True)
        self.assertEqual(win32_tokens, ['a\\', 'b'])
284
class TestShParse(unittest.TestCase):
    """Checks the command trees produced by ShParser."""

    def parse(self, str):
        """Parse str with a default ShParser and return the result."""
        parser = ShParser(str)
        return parser.parse()

    def test_basic(self):
        def echo(argument):
            return Pipeline([Command(['echo', argument], [])], False)

        self.assertEqual(self.parse('echo hello'), echo('hello'))
        self.assertEqual(self.parse('echo ""'), echo(''))
        # Quotes glued to a word are removed, whichever kind they are.
        self.assertEqual(self.parse("""echo -DFOO='a'"""), echo('-DFOO=a'))
        self.assertEqual(self.parse('echo -DFOO="a"'), echo('-DFOO=a'))

    def test_redirection(self):
        def echo_hello(redirects):
            return Pipeline([Command(['echo', 'hello'], redirects)], False)

        self.assertEqual(self.parse('echo hello > c'),
                         echo_hello([(('>',), 'c')]))
        self.assertEqual(self.parse('echo hello > c >> d'),
                         echo_hello([(('>',), 'c'), (('>>',), 'd')]))

        # A file descriptor prefix is carried inside the operator tuple.
        dup_stderr = Pipeline([Command(['a'], [(('>&',2), '1')])], False)
        self.assertEqual(self.parse('a 2>&1'), dup_stderr)

    def test_pipeline(self):
        def piped(*names):
            return Pipeline([Command([name], []) for name in names], False)

        self.assertEqual(self.parse('a | b'), piped('a', 'b'))
        self.assertEqual(self.parse('a | b | c'), piped('a', 'b', 'c'))

    def test_list(self):
        def single(name):
            return Pipeline([Command([name], [])], False)

        # Each list operator joins two pipelines, with or without spaces.
        for text, operator in [('a ; b', ';'),
                               ('a & b', '&'),
                               ('a && b', '&&'),
                               ('a || b', '||')]:
            self.assertEqual(self.parse(text),
                             Seq(single('a'), operator, single('b')))

        # Operators are grouped from the left.
        left = Seq(single('a'), '&&', single('b'))
        self.assertEqual(self.parse('a && b || c'),
                         Seq(left, '||', single('c')))

        self.assertEqual(self.parse('a; b'),
                         Seq(single('a'), ';', single('b')))
353
# Run the lexer and parser unit tests when executed as a script.
if __name__ == "__main__":
    unittest.main()
356