1# coding=utf-8
2#
3# Copyright © 2011 Intel Corporation
4#
5# Permission is hereby granted, free of charge, to any person obtaining a
6# copy of this software and associated documentation files (the "Software"),
7# to deal in the Software without restriction, including without limitation
8# the rights to use, copy, modify, merge, publish, distribute, sublicense,
9# and/or sell copies of the Software, and to permit persons to whom the
10# Software is furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice (including the next
13# paragraph) shall be included in all copies or substantial portions of the
14# Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22# DEALINGS IN THE SOFTWARE.
23
24import os
25import os.path
26import re
27import subprocess
28import sys
29
30sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) # For access to sexps.py, which is in parent dir
31from sexps import *
32
def make_test_case(f_name, ret_type, body):
    """Create a simple optimization test case consisting of a single
    function with the given name, return type, and body.

    Global declarations are automatically created for any undeclared
    variables that are referenced by the function.  All undeclared
    variables are assumed to be floats.  A variable that is read via
    (var_ref ...) is declared 'in'; a variable that is the target of an
    (assign ...) is declared 'out'.  If the same variable is both read
    and assigned, the reference visited last decides its declaration.

    Returns a list of top-level s-expressions: the generated global
    declarations followed by the function definition.
    """
    check_sexp(body)
    declarations = {}

    def make_declarations(sexp, already_declared = ()):
        # Walk the s-expression tree, recording a global declaration for
        # every variable referenced but not declared in an enclosing list.
        if not isinstance(sexp, list):
            return
        if len(sexp) == 2 and sexp[0] == 'var_ref':
            if sexp[1] not in already_declared:
                declarations[sexp[1]] = [
                    'declare', ['in'], 'float', sexp[1]]
        elif len(sexp) == 4 and sexp[0] == 'assign':
            assert sexp[2][0] == 'var_ref'
            if sexp[2][1] not in already_declared:
                declarations[sexp[2][1]] = [
                    'declare', ['out'], 'float', sexp[2][1]]
            # The write mask (sexp[1]) is skipped and the target was
            # handled above, so only the assigned value is scanned.
            make_declarations(sexp[3], already_declared)
        else:
            # Work on a copy so declarations found in this list do not
            # leak into enclosing lists or their other children.
            already_declared = set(already_declared)
            for s in sexp:
                if isinstance(s, list) and len(s) >= 4 and \
                        s[0] == 'declare':
                    already_declared.add(s[3])
                else:
                    make_declarations(s, already_declared)

    make_declarations(body)
    # dict.values() is a view in Python 3 and cannot be concatenated with
    # a list, so materialize it first (also valid in Python 2).
    return list(declarations.values()) + \
        [['function', f_name, ['signature', ret_type, ['parameters'], body]]]
66
67
68# The following functions can be used to build expressions.
69
def const_float(value):
    """Return a float constant expression holding *value*.

    The value is rendered with exactly six digits after the decimal
    point, e.g. const_float(1) gives ['constant', 'float', ['1.000000']].
    """
    rendered = format(value, '.6f')
    return ['constant', 'float', [rendered]]
73
def const_bool(value):
    """Return a bool constant expression for *value*.

    Any value is accepted and interpreted by its truthiness, so
    const_bool(1) is the same as const_bool(True).  True is spelled '1'
    and false is spelled '0'.
    """
    digit = '1' if value else '0'
    return ['constant', 'bool', [digit]]
81
def gt_zero(var_name):
    """Construct the boolean expression var_name > 0.

    The comparison is against the float constant 0.0, matching
    make_test_case's default of declaring undeclared variables as floats.
    """
    return ['expression', 'bool', '>', ['var_ref', var_name], const_float(0)]
85
86
87# The following functions can be used to build complex control flow
88# statements.  All of these functions return statement lists (even
89# those which only create a single statement), so that statements can
90# be sequenced together using the '+' operator.
91
def return_(value=None):
    """Create a return statement, returning *value* when one is given.

    Omitting value (or passing None) produces a void return.  The result
    is a one-element statement list.
    """
    statement = ['return']
    if value is not None:
        statement.append(value)
    return [statement]
98
def break_():
    """Create a break statement.

    The result is a statement list holding the bare symbol 'break' (not
    a nested ['break'] list).  The IR s-expression syntax spells loop
    jumps as bare symbols, so concatenating this with other statement
    lists places 'break' directly in the sequence.
    """
    keyword = 'break'
    return [keyword]
102
def continue_():
    """Create a continue statement.

    As with break_(), the result is a statement list holding the bare
    symbol 'continue' rather than a nested list.
    """
    keyword = 'continue'
    return [keyword]
106
def simple_if(var_name, then_statements, else_statements=None):
    """Create a statement of the form

    if (var_name > 0.0) {
       <then_statements>
    } else {
       <else_statements>
    }

    When else_statements is omitted (None), the else branch is empty.
    Both branches are validated with check_sexp.  Returns a one-element
    statement list.
    """
    then_branch = then_statements
    else_branch = [] if else_statements is None else else_statements
    check_sexp(then_branch)
    check_sexp(else_branch)
    condition = gt_zero(var_name)
    return [['if', condition, then_branch, else_branch]]
123
def loop(statements):
    """Create a loop whose body is the given statement list.

    The four empty lists between 'loop' and the body fill the loop's
    other fields, which these tests always leave empty.  They are
    presumably the IR loop's optional header slots; confirm against the
    IR reader.  Returns a one-element statement list.
    """
    check_sexp(statements)
    empty_fields = [[] for _ in range(4)]
    return [['loop'] + empty_fields + [statements]]
130
def declare_temp(var_type, var_name):
    """Create a declaration of the form

    (declare (temporary) <var_type> <var_name>)

    Returns a one-element statement list.
    """
    return [['declare', ['temporary'], var_type, var_name]]
137
def assign_x(var_name, value):
    """Create a statement that stores *value* into the variable
    *var_name* using the write mask (x).

    value is validated with check_sexp.  Returns a one-element statement
    list.
    """
    check_sexp(value)
    write_mask = ['x']
    target = ['var_ref', var_name]
    return [['assign', write_mask, target, value]]
144
def complex_if(var_prefix, statements):
    """Create a statement of the form

    if (<var_prefix>a > 0.0) {
       if (<var_prefix>b > 0.0) {
          <statements>
       }
    }

    This nesting is useful when testing jump lowering.  If <statements>
    ends in a jump, lower_jumps.cpp won't try to combine this construct
    with the code that follows it, as it might for a simple if.

    Both condition variables are named by prefixing var_prefix, which
    can be used to keep them unique.
    """
    check_sexp(statements)
    inner = simple_if(var_prefix + 'b', statements)
    return simple_if(var_prefix + 'a', inner)
164
def declare_execute_flag():
    """Return the statements lower_jumps.cpp emits to declare the
    temporary boolean execute_flag and initialize it to true.
    """
    declaration = declare_temp('bool', 'execute_flag')
    initialization = assign_x('execute_flag', const_bool(True))
    return declaration + initialization
171
def declare_return_flag():
    """Return the statements lower_jumps.cpp emits to declare the
    temporary boolean return_flag and initialize it to false.
    """
    declaration = declare_temp('bool', 'return_flag')
    initialization = assign_x('return_flag', const_bool(False))
    return declaration + initialization
178
def declare_return_value():
    """Create the statement that lower_jumps.cpp uses to declare the
    temporary variable return_value.  Assume that return_value is a
    float.

    Unlike the flag helpers, this generates no initializing assignment.
    The result contains only the declaration.
    """
    return declare_temp('float', 'return_value')
185
def declare_break_flag():
    """Return the statements lower_jumps.cpp emits to declare the
    temporary boolean break_flag and initialize it to false.
    """
    declaration = declare_temp('bool', 'break_flag')
    initialization = assign_x('break_flag', const_bool(False))
    return declaration + initialization
192
def lowered_return_simple(value = None):
    """Create the statements that lower_jumps.cpp lowers a return
    statement to, in situations where it does not need to clear the
    execute flag.

    If value is given, it is first stored into return_value.  Then
    return_flag is set to true.  As with return_(), value=None means a
    void return; the check is against None rather than truthiness so
    the two helpers agree.
    """
    if value is not None:
        result = assign_x('return_value', value)
    else:
        result = []
    return result + assign_x('return_flag', const_bool(True))
203
def lowered_return(value=None):
    """Return the statements a return is lowered to when lower_jumps.cpp
    must also clear the execute flag.

    This is the lowered_return_simple() sequence followed by an
    assignment of false to execute_flag.
    """
    statements = lowered_return_simple(value)
    clear_execute = assign_x('execute_flag', const_bool(False))
    return statements + clear_execute
211
def lowered_continue():
    """Return the statement lower_jumps.cpp lowers a continue to: an
    assignment of false to execute_flag.
    """
    cleared = const_bool(False)
    return assign_x('execute_flag', cleared)
217
def lowered_break_simple():
    """Return the lowered form of a break that does not have to clear
    the execute flag: a single assignment of true to break_flag.
    """
    flag_set = const_bool(True)
    return assign_x('break_flag', flag_set)
224
def lowered_break():
    """Return the lowered form of a break that must also clear the
    execute flag.

    break_flag is set to true, then execute_flag is set to false.
    """
    set_break = lowered_break_simple()
    clear_execute = assign_x('execute_flag', const_bool(False))
    return set_break + clear_execute
231
def if_execute_flag(statements):
    """Guard statements so they only run while execute_flag is true.

    The result is an if on execute_flag with the statements in the
    then-branch and an empty else-branch.
    """
    check_sexp(statements)
    condition = ['var_ref', 'execute_flag']
    when_set, when_clear = statements, []
    return [['if', condition, when_set, when_clear]]
238
def if_not_return_flag(statements):
    """Guard statements so they only run while return_flag is false.

    The result is an if on return_flag with an empty then-branch and
    the statements in the else-branch.
    """
    check_sexp(statements)
    condition = ['var_ref', 'return_flag']
    when_set, when_clear = [], statements
    return [['if', condition, when_set, when_clear]]
245
def final_return():
    """Return the statement lower_jumps.cpp appends to a function when
    lowering returns: a return of the return_value temporary.
    """
    result = ['var_ref', 'return_value']
    return [['return', result]]
251
def final_break():
    """Create the conditional break statement that lower_jumps.cpp
    places at the end of a loop body when lowering breaks.

    The statement has the form

    if (break_flag) {
       break
    }

    Returns a one-element statement list.
    """
    return [['if', ['var_ref', 'break_flag'], break_(), []]]
257
def bash_quote(*args):
    """Join the arguments into one bash command-line string.

    Each argument is quoted, if necessary, so that bash reads it back as
    exactly one word:

    - an argument made only of letters, digits and the characters
      @%_-+=:,./ is emitted unchanged;
    - an empty argument becomes '';
    - anything else is wrapped in single quotes, with each embedded
      single quote spelled '"'"' (close the quote, emit a double-quoted
      quote, reopen the quote).
    """
    safe_punctuation = '@%_-+=:,./'

    def is_safe(ch):
        return ch.isalpha() or ch.isdigit() or ch in safe_punctuation

    def quote_word(word):
        if all(is_safe(ch) for ch in word):
            # Nothing needs quoting, except the empty word, which would
            # otherwise vanish from the command line.
            return word or "''"
        return "'{0}'".format(word.replace("'", "'\"'\"'"))

    return ' '.join([quote_word(word) for word in args])
272
def create_test_case(doc_string, input_sexp, expected_sexp, test_name,
                     pull_out_jumps=False, lower_sub_return=False,
                     lower_main_return=False, lower_continue=False,
                     lower_break=False):
    """Create a test case that verifies that do_lower_jumps transforms
    the given code in the expected way.

    Two files are written to the current directory:

    - <test_name>.opt_test: a bash script (made executable) that feeds
      input_sexp to `../../glsl_test optpass --quiet --input-ir` running
      do_lower_jumps with the given boolean flags.  Each non-blank line
      of doc_string is embedded as a '#' comment.
    - <test_name>.opt_test.expected: the expected output, rendered from
      expected_sexp.
    """
    doc_lines = [line.strip() for line in doc_string.splitlines()]
    doc_string = ''.join('# {0}\n'.format(line) for line in doc_lines if line != '')
    check_sexp(input_sexp)
    check_sexp(expected_sexp)
    # sort_decls (from sexps) presumably canonicalizes declaration order
    # so the generated text does not depend on dict ordering.
    input_str = sexp_to_string(sort_decls(input_sexp))
    expected_output = sexp_to_string(sort_decls(expected_sexp))

    # The flags are formatted with {:d}, so each boolean becomes 0 or 1.
    optimization = (
        'do_lower_jumps({0:d}, {1:d}, {2:d}, {3:d}, {4:d})'.format(
            pull_out_jumps, lower_sub_return, lower_main_return,
            lower_continue, lower_break))
    args = ['../../glsl_test', 'optpass', '--quiet', '--input-ir', optimization]
    test_file = '{0}.opt_test'.format(test_name)
    with open(test_file, 'w') as f:
        f.write('#!/bin/bash\n#\n# This file was generated by create_test_cases.py.\n#\n')
        f.write(doc_string)
        # NOTE(review): the EOF delimiter is unquoted, so bash expands
        # $, ` and \ inside the here-document.  This is only safe while
        # input_str contains none of those characters.
        f.write('{0} <<EOF\n'.format(bash_quote(*args)))
        f.write('{0}\nEOF\n'.format(input_str))
    # rwxrwxr--: lets the generated script be executed directly.  The 0o
    # prefix is valid in Python 2.6+ and Python 3; a bare 0774 is a
    # syntax error in Python 3.
    os.chmod(test_file, 0o774)
    expected_file = '{0}.opt_test.expected'.format(test_name)
    with open(expected_file, 'w') as f:
        f.write('{0}\n'.format(expected_output))
302
def test_lower_returns_main():
    """Emit lower_returns_main_{true,false}: lowering a nested return in
    main() is controlled by the lower_main_return flag."""
    description = """Test that do_lower_jumps respects the lower_main_return
    flag in deciding whether to lower returns in the main
    function.
    """
    before = make_test_case('main', 'void', complex_if('', return_()))
    after = make_test_case('main', 'void',
                           declare_execute_flag() +
                           declare_return_flag() +
                           complex_if('', lowered_return()))
    # With the flag off, the function must come through unchanged.
    for suffix, enabled, expected in (('true', True, after),
                                      ('false', False, before)):
        create_test_case(description, before, expected,
                         'lower_returns_main_' + suffix,
                         lower_main_return=enabled)
320
def test_lower_returns_sub():
    """Emit lower_returns_sub_{true,false}: lowering a nested return in a
    subroutine is controlled by the lower_sub_return flag."""
    description = """Test that do_lower_jumps respects the lower_sub_return flag
    in deciding whether to lower returns in subroutines.
    """
    before = make_test_case('sub', 'void', complex_if('', return_()))
    after = make_test_case('sub', 'void',
                           declare_execute_flag() +
                           declare_return_flag() +
                           complex_if('', lowered_return()))
    # With the flag off, the function must come through unchanged.
    for suffix, enabled, expected in (('true', True, after),
                                      ('false', False, before)):
        create_test_case(description, before, expected,
                         'lower_returns_sub_' + suffix,
                         lower_sub_return=enabled)
337
def test_lower_returns_1():
    """Emit lower_returns_1: a trailing void return is removed."""
    description = """Test that a void return at the end of a function is
    eliminated.
    """
    before = make_test_case('main', 'void',
                            assign_x('a', const_float(1)) + return_())
    after = make_test_case('main', 'void', assign_x('a', const_float(1)))
    create_test_case(description, before, after, 'lower_returns_1',
                     lower_main_return=True)
351
def test_lower_returns_2():
    """Emit lower_returns_2: a trailing non-void return is left alone."""
    description = """Test that lowering is not performed on a non-void return at
    the end of subroutine.
    """
    unchanged = make_test_case('sub', 'float',
                               assign_x('a', const_float(1)) +
                               return_(const_float(1)))
    # The input doubles as the expected output.
    create_test_case(description, unchanged, unchanged, 'lower_returns_2',
                     lower_sub_return=True)
362
def test_lower_returns_3():
    """Emit lower_returns_3: a nested return forces the trailing return
    to be lowered too."""
    description = """Test lowering of returns when there is one nested inside a
    complex structure of ifs, and one at the end of a function.

    In this case, the latter return needs to be lowered because it
    will not be at the end of the function once the final return
    is inserted.
    """
    before = make_test_case('sub', 'float',
                            complex_if('', return_(const_float(1))) +
                            return_(const_float(2)))
    prologue = (declare_execute_flag() +
                declare_return_value() +
                declare_return_flag())
    after = make_test_case('sub', 'float',
                           prologue +
                           complex_if('', lowered_return(const_float(1))) +
                           if_execute_flag(lowered_return(const_float(2))) +
                           final_return())
    create_test_case(description, before, after, 'lower_returns_3',
                     lower_sub_return=True)
385
def test_lower_returns_4():
    """Emit lower_returns_4: returns in both branches of an if."""
    description = """Test that returns are properly lowered when they occur in
    both branches of an if-statement.
    """
    before = make_test_case('sub', 'float',
                            simple_if('a', return_(const_float(1)),
                                      return_(const_float(2))))
    branches = simple_if('a', lowered_return(const_float(1)),
                         lowered_return(const_float(2)))
    after = make_test_case('sub', 'float',
                           declare_execute_flag() +
                           declare_return_value() +
                           declare_return_flag() +
                           branches +
                           final_return())
    create_test_case(description, before, after, 'lower_returns_4',
                     lower_sub_return=True)
404
def test_lower_unified_returns():
    """Emit lower_unified_returns: matching returns in both branches of
    an inner if are pulled out and lowered in the same pass."""
    description = """If both branches of an if statement end in a return, and
    pull_out_jumps is True, then those returns should be lifted
    outside the if and then properly lowered.

    Verify that this lowering occurs during the same pass as the
    lowering of other returns by checking that extra temporary
    variables aren't generated.
    """
    before = make_test_case('main', 'void',
                            complex_if('a', return_()) +
                            simple_if('b', simple_if('c', return_(), return_())))
    # The inner if keeps empty branches; the single pulled-out return
    # follows it inside the 'b' branch.
    pulled_out = simple_if('c', [], []) + lowered_return()
    after = make_test_case('main', 'void',
                           declare_execute_flag() +
                           declare_return_flag() +
                           complex_if('a', lowered_return()) +
                           if_execute_flag(simple_if('b', pulled_out)))
    create_test_case(description, before, after, 'lower_unified_returns',
                     lower_main_return=True, pull_out_jumps=True)
427
def test_lower_pulled_out_jump():
    """Emit lower_pulled_out_jump: a jump is lifted out of an if whose
    other branch cannot fall through, and is lowered in the same pass."""
    description = """If one branch of an if ends in a jump, and control cannot
    fall out the bottom of the other branch, and pull_out_jumps is
    True, then the jump is lifted outside the if.

    Verify that this lowering occurs during the same pass as the
    lowering of other jumps by checking that extra temporary
    variables aren't generated.
    """
    loop_body = simple_if('b', simple_if('c', break_(), continue_()),
                          return_())
    before = make_test_case('main', 'void',
                            complex_if('a', return_()) +
                            loop(loop_body) +
                            assign_x('d', const_float(1)))
    # The optimization has two further effects: the break is lifted out
    # of the if statements, and the code after the loop is guarded so it
    # only executes while the return flag is clear.
    lowered_loop = loop(simple_if('b', simple_if('c', [], continue_()),
                                  lowered_return_simple()) +
                        break_())
    guarded_tail = if_not_return_flag(assign_x('d', const_float(1)))
    after = make_test_case('main', 'void',
                           declare_execute_flag() +
                           declare_return_flag() +
                           complex_if('a', lowered_return()) +
                           if_execute_flag(lowered_loop + guarded_tail))
    create_test_case(description, before, after, 'lower_pulled_out_jump',
                     lower_main_return=True, pull_out_jumps=True)
459
def test_lower_breaks_1():
    """Emit lower_breaks_1: an unconditional break ending a loop is kept."""
    description = """If a loop contains an unconditional break at the bottom of
    it, it should not be lowered."""
    unchanged = make_test_case('main', 'void',
                               loop(assign_x('a', const_float(1)) + break_()))
    # The input doubles as the expected output.
    create_test_case(description, unchanged, unchanged, 'lower_breaks_1',
                     lower_break=True)
469
def test_lower_breaks_2():
    """Emit lower_breaks_2: a trailing conditional break in the
    then-clause is kept."""
    description = """If a loop contains a conditional break at the bottom of it,
    it should not be lowered if it is in the then-clause.
    """
    body = assign_x('a', const_float(1)) + simple_if('b', break_())
    unchanged = make_test_case('main', 'void', loop(body))
    # The input doubles as the expected output.
    create_test_case(description, unchanged, unchanged, 'lower_breaks_2',
                     lower_break=True)
480
def test_lower_breaks_3():
    """Emit lower_breaks_3: like lower_breaks_2, with a statement before
    the break."""
    description = """If a loop contains a conditional break at the bottom of it,
    it should not be lowered if it is in the then-clause, even if
    there are statements preceding the break.
    """
    then_branch = assign_x('c', const_float(1)) + break_()
    body = assign_x('a', const_float(1)) + simple_if('b', then_branch)
    unchanged = make_test_case('main', 'void', loop(body))
    # The input doubles as the expected output.
    create_test_case(description, unchanged, unchanged, 'lower_breaks_3',
                     lower_break=True)
493
def test_lower_breaks_4():
    """Emit lower_breaks_4: a trailing conditional break in the
    else-clause is kept."""
    description = """If a loop contains a conditional break at the bottom of it,
    it should not be lowered if it is in the else-clause.
    """
    body = assign_x('a', const_float(1)) + simple_if('b', [], break_())
    unchanged = make_test_case('main', 'void', loop(body))
    # The input doubles as the expected output.
    create_test_case(description, unchanged, unchanged, 'lower_breaks_4',
                     lower_break=True)
504
def test_lower_breaks_5():
    """Emit lower_breaks_5: like lower_breaks_4, with a statement before
    the break."""
    description = """If a loop contains a conditional break at the bottom of it,
    it should not be lowered if it is in the else-clause, even if
    there are statements preceding the break.
    """
    else_branch = assign_x('c', const_float(1)) + break_()
    body = assign_x('a', const_float(1)) + simple_if('b', [], else_branch)
    unchanged = make_test_case('main', 'void', loop(body))
    # The input doubles as the expected output.
    create_test_case(description, unchanged, unchanged, 'lower_breaks_5',
                     lower_break=True)
517
def test_lower_breaks_6():
    """Emit lower_breaks_6: a trailing unconditional break is lowered
    when the loop also contains conditional breaks and continues."""
    description = """If a loop contains conditional breaks and continues, and
    ends in an unconditional break, then the unconditional break
    needs to be lowered, because it will no longer be at the end
    of the loop after the final break is added.
    """
    jumps = complex_if('b', continue_()) + complex_if('c', break_())
    before = make_test_case('main', 'void',
                            loop(simple_if('a', jumps) + break_()))
    lowered_jumps = (complex_if('b', lowered_continue()) +
                     if_execute_flag(complex_if('c', lowered_break())))
    after = make_test_case('main', 'void',
                           declare_break_flag() +
                           loop(declare_execute_flag() +
                                simple_if('a', lowered_jumps) +
                                if_execute_flag(lowered_break_simple()) +
                                final_break()))
    create_test_case(description, before, after, 'lower_breaks_6',
                     lower_break=True, lower_continue=True)
542
def test_lower_guarded_conditional_break():
    """Emit lower_guarded_conditional_break: a trailing conditional break
    is lowered once continue-lowering wraps it in if(execute_flag)."""
    description = """Normally a conditional break at the end of a loop isn't
    lowered, however if the conditional break gets placed inside
    an if(execute_flag) because of earlier lowering of continues,
    then the break needs to be lowered.
    """
    before = make_test_case('main', 'void',
                            loop(complex_if('a', continue_()) +
                                 simple_if('b', break_())))
    guarded_break = if_execute_flag(simple_if('b', lowered_break()))
    after = make_test_case('main', 'void',
                           declare_break_flag() +
                           loop(declare_execute_flag() +
                                complex_if('a', lowered_continue()) +
                                guarded_break +
                                final_break()))
    create_test_case(description, before, after,
                     'lower_guarded_conditional_break',
                     lower_break=True, lower_continue=True)
562
def test_remove_continue_at_end_of_loop():
    """Emit remove_continue_at_end_of_loop: a redundant trailing continue
    is dropped even with no lowering flags set."""
    description = """Test that a redundant continue-statement at the end of a
    loop is removed.
    """
    before = make_test_case('main', 'void',
                            loop(assign_x('a', const_float(1)) + continue_()))
    after = make_test_case('main', 'void',
                           loop(assign_x('a', const_float(1))))
    create_test_case(description, before, after,
                     'remove_continue_at_end_of_loop')
575
def test_lower_return_void_at_end_of_loop():
    """Emit the return_void_at_end_of_loop_* cases: a void return ending
    a loop body, under three flag combinations."""
    description = """Test that a return of void at the end of a loop is properly
    lowered.
    """
    before = make_test_case('main', 'void',
                            loop(assign_x('a', const_float(1)) + return_()) +
                            assign_x('b', const_float(2)))
    after = make_test_case('main', 'void',
                           declare_return_flag() +
                           loop(assign_x('a', const_float(1)) +
                                lowered_return_simple() +
                                break_()) +
                           if_not_return_flag(assign_x('b', const_float(2))))
    # (test name, expected output, lowering flags)
    variants = (
        ('return_void_at_end_of_loop_lower_nothing', before, {}),
        ('return_void_at_end_of_loop_lower_return', after,
         dict(lower_main_return=True)),
        ('return_void_at_end_of_loop_lower_return_and_break', after,
         dict(lower_main_return=True, lower_break=True)),
    )
    for name, expected, flags in variants:
        create_test_case(description, before, expected, name, **flags)
597
def test_lower_return_non_void_at_end_of_loop():
    """Emit the return_non_void_at_end_of_loop_* cases: a non-void return
    ending a loop body, under three flag combinations."""
    description = """Test that a non-void return at the end of a loop is
    properly lowered.
    """
    before = make_test_case('sub', 'float',
                            loop(assign_x('a', const_float(1)) +
                                 return_(const_float(2))) +
                            assign_x('b', const_float(3)) +
                            return_(const_float(4)))
    lowered_loop = loop(assign_x('a', const_float(1)) +
                        lowered_return_simple(const_float(2)) +
                        break_())
    guarded_tail = if_not_return_flag(assign_x('b', const_float(3)) +
                                      lowered_return(const_float(4)))
    after = make_test_case('sub', 'float',
                           declare_execute_flag() +
                           declare_return_value() +
                           declare_return_flag() +
                           lowered_loop +
                           guarded_tail +
                           final_return())
    # (test name, expected output, lowering flags)
    variants = (
        ('return_non_void_at_end_of_loop_lower_nothing', before, {}),
        ('return_non_void_at_end_of_loop_lower_return', after,
         dict(lower_sub_return=True)),
        ('return_non_void_at_end_of_loop_lower_return_and_break', after,
         dict(lower_sub_return=True, lower_break=True)),
    )
    for name, expected, flags in variants:
        create_test_case(description, before, expected, name, **flags)
624
if __name__ == '__main__':
    # Each generator writes its own <name>.opt_test and
    # <name>.opt_test.expected files into the current directory.
    for generate in (test_lower_returns_main,
                     test_lower_returns_sub,
                     test_lower_returns_1,
                     test_lower_returns_2,
                     test_lower_returns_3,
                     test_lower_returns_4,
                     test_lower_unified_returns,
                     test_lower_pulled_out_jump,
                     test_lower_breaks_1,
                     test_lower_breaks_2,
                     test_lower_breaks_3,
                     test_lower_breaks_4,
                     test_lower_breaks_5,
                     test_lower_breaks_6,
                     test_lower_guarded_conditional_break,
                     test_remove_continue_at_end_of_loop,
                     test_lower_return_void_at_end_of_loop,
                     test_lower_return_non_void_at_end_of_loop):
        generate()
644