1#! /usr/bin/env python
2#
3# Copyright (C) 2014 Intel Corporation
4#
5# Permission is hereby granted, free of charge, to any person obtaining a
6# copy of this software and associated documentation files (the "Software"),
7# to deal in the Software without restriction, including without limitation
8# the rights to use, copy, modify, merge, publish, distribute, sublicense,
9# and/or sell copies of the Software, and to permit persons to whom the
10# Software is furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice (including the next
13# paragraph) shall be included in all copies or substantial portions of the
14# Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22# IN THE SOFTWARE.
23#
24# Authors:
25#    Jason Ekstrand (jason@jlekstrand.net)
26
27from __future__ import print_function
28import ast
29import itertools
30import struct
31import sys
32import mako.template
33import re
34import traceback
35
36from nir_opcodes import opcodes
37
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit size spelled out in a NIR type name, or 0 if unsized.

   type_str is a name such as "float32", "uint" or "bool32".  The base type
   (int/uint/bool/float) at the front is mandatory; a missing one is an
   assertion failure.
   """
   parsed = _type_re.match(type_str)
   assert parsed.group('type')

   size = parsed.group('bits')
   return 0 if size is None else int(size)
48
class VarSet(object):
   """A set of variable names, each mapped to a unique integer id.

   Ids are handed out in first-use order starting at 0.  Once lock() has
   been called, looking up a name that has not been seen before is an
   assertion failure.  SearchAndReplace locks the set after parsing the
   search pattern, so a replacement cannot mention unknown variables.
   """

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      """Return the id for name, allocating a new one if allowed."""
      if name not in self.names:
         assert not self.immutable, "Unknown replacement variable: " + name
         # The next() builtin works on Python 2.6+ and 3; .next() is Py2-only.
         self.names[name] = next(self.ids)

      return self.names[name]

   def lock(self):
      """Forbid allocating ids for names that have not been seen yet."""
      self.immutable = True
65
class Value(object):
   """Base class of the nodes in a search or replace pattern.

   Each node is emitted as one static nir_search_* C struct named self.name.
   The subclasses (Constant, Variable, Expression) supply the fields that
   render() substitutes into the shared template below.
   """

   @staticmethod
   def create(val, name_base, varset):
      """Convert a Python pattern description into a Value node.

      Tuples become Expressions, strings become Variables and bool/int/float
      literals become Constants.  An existing Value is returned unchanged.
      name_base is the C identifier for the new node; varset maps variable
      names to ids.

      Raises TypeError for any other kind of object.
      """
      try:
         string_types = (str, unicode)
         number_types = (bool, int, long, float)
      except NameError:
         # Python 3 has no separate unicode/long types.
         string_types = (str,)
         number_types = (bool, int, float)

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Value):
         return val
      elif isinstance(val, string_types):
         return Variable(val, name_base, varset)
      elif isinstance(val, number_types):
         return Constant(val, name_base)

      # Fail here with a clear message rather than handing None to callers.
      raise TypeError("Unsupported search/replace value: {0!r}".format(val))

   # The Constant initializer calls val.__hex__() explicitly.  Python 2's
   # hex() just forwards to __hex__, but Python 3's hex() only accepts
   # integers and never consults __hex__.
   __template = mako.template.Template("""
#include "compiler/nir/nir_search_helpers.h"
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.__hex__()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def __init__(self, name, type_str):
      # name: C identifier of the generated struct.
      # type_str: node kind suffix ("constant", "variable" or "expression").
      self.name = name
      self.type_str = type_str

   @property
   def type_enum(self):
      """C enumerator naming this node's kind, e.g. nir_search_value_constant."""
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      """C struct type of the generated definition, e.g. nir_search_constant."""
      return "nir_search_" + self.type_str

   @property
   def c_ptr(self):
      """C expression for a pointer to this node's embedded .value member."""
      return "&{0}.value".format(self.name)

   def render(self):
      """Return the C definition of this node."""
      return self.__template.render(val=self,
                                    Constant=Constant,
                                    Variable=Variable,
                                    Expression=Expression)
118
# Constant strings look like "<python literal>[@<bits>]".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal in a search or replace pattern.

   val is either a Python bool/int/float, or a string of the form
   "<literal>[@bits]" whose literal part is parsed with ast.literal_eval.
   Boolean constants are always given a bit size of 32.
   """

   # Python 2 has separate long/unicode types.  Python 3 merged them into
   # int/str, so naming the old types there raises NameError.
   try:
      _integer_types = (int, long)
      _string_types = (str, unicode)
   except NameError:
      _integer_types = (int,)
      _string_types = (str,)

   def __init__(self, val, name):
      Value.__init__(self, name, "constant")

      if isinstance(val, self._string_types):
         m = _constant_re.match(val)
         assert m, "Malformed constant: " + val
         self.value = ast.literal_eval(m.group('value'))
         self.bit_size = int(m.group('bits')) if m.group('bits') else 0
      else:
         self.value = val
         self.bit_size = 0

      if isinstance(self.value, bool):
         assert self.bit_size == 0 or self.bit_size == 32
         self.bit_size = 32

   def __hex__(self):
      """Return the C initializer text for this constant's value.

      Booleans become NIR_TRUE/NIR_FALSE and integers their hex literal.
      Floats become the hex literal of the value's bits when packed as a
      C double.
      """
      if isinstance(self.value, bool):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      if isinstance(self.value, self._integer_types):
         return hex(self.value)
      elif isinstance(self.value, float):
         # Reinterpret the packed double's bytes as an unsigned 64-bit int.
         return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
      else:
         assert False, "Unsupported constant value: " + repr(self.value)

   def type(self):
      """Return the nir_alu_type name for the value, or None if unknown."""
      if isinstance(self.value, bool):
         return "nir_type_bool32"
      elif isinstance(self.value, self._integer_types):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"
154
# Variable strings look like "[#]name[@[type][bits]][(cond)]".
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
   """A named placeholder in a search or replace pattern.

   The string form is "[#]name[@[type][bits]][(cond)]":

    - a leading '#' sets is_constant;
    - '@type' and/or '@bits' set required_type and bit_size ('bool' forces
      a bit size of 32);
    - '(cond)' is kept verbatim, parentheses included, in cond and is
      emitted into the generated C struct.

   All occurrences of the same name share one index taken from varset.
   """

   def __init__(self, val, name, varset):
      Value.__init__(self, name, "variable")

      parsed = _var_name_re.match(val)
      assert parsed and parsed.group('name') is not None

      size = parsed.group('bits')
      self.var_name = parsed.group('name')
      self.is_constant = parsed.group('const') is not None
      self.cond = parsed.group('cond')
      self.required_type = parsed.group('type')
      self.bit_size = int(size) if size else 0

      if self.required_type == 'bool':
         assert self.bit_size in (0, 32)
         self.bit_size = 32

      assert self.required_type in (None, 'float', 'bool', 'int', 'uint')

      # Looked up last so a malformed name never reserves an id.
      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_alu_type name implied by required_type, or None."""
      return {
         'bool': "nir_type_bool32",
         'int': "nir_type_int",
         'uint': "nir_type_int",
         'float': "nir_type_float",
      }.get(self.required_type)
188
# Opcode strings look like "[~]opcode[@bits][(cond)]".
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An ALU operation in a search or replace pattern.

   expr is a tuple ("[~]opcode[@bits][(cond)]", src0, src1, ...).  A leading
   '~' sets inexact.  Each source is converted with Value.create and given
   the C name "<name_base>_<i>".
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, name_base, "expression")
      assert isinstance(expr, tuple)

      head, operands = expr[0], expr[1:]
      parsed = _opcode_re.match(head)
      assert parsed and parsed.group('opcode') is not None

      size = parsed.group('bits')
      self.opcode = parsed.group('opcode')
      self.bit_size = int(size) if size else 0
      self.inexact = parsed.group('inexact') is not None
      self.cond = parsed.group('cond')

      self.sources = []
      for position, operand in enumerate(operands):
         child_name = "{0}_{1}".format(name_base, position)
         self.sources.append(Value.create(operand, child_name, varset))

   def render(self):
      """Return C definitions for all sources followed by this node's own."""
      children = [source.render() for source in self.sources]
      return "\n".join(children) + super(Expression, self).render()
210
class IntEquivalenceRelation(object):
   """A class representing an equivalence relation on integers.

   Each integer has a canonical form which is the maximum integer to which it
   is equivalent.  Two integers are equivalent precisely when they have the
   same canonical form.

   The convention of maximum is explicitly chosen to make using it in
   BitSizeValidator easier because it means that an actual bit_size (if any)
   will always be the canonical form.
   """
   def __init__(self):
      # Maps a non-canonical integer to a strictly larger equivalent one.
      self._remap = {}

   def get_canonical(self, x):
      """Get the canonical integer corresponding to x."""
      while x in self._remap:
         x = self._remap[x]
      return x

   def add_equiv(self, a, b):
      """Add an equivalence and return the canonical form.

      The two classes are merged by linking their canonical representatives,
      not a and b themselves.  That way every integer already equivalent to
      a or b also becomes equivalent to the merged canonical form.
      """
      canon_a = self.get_canonical(a)
      canon_b = self.get_canonical(b)
      c = max(canon_a, canon_b)

      # Links always point to a strictly larger integer, so no cycles form.
      for root in (canon_a, canon_b):
         if root != c:
            self._remap[root] = c

      return c
244
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   The basic operation of the validator is very similar to the bitsize_tree in
   nir_search only a little more subtle.  Instead of simply tracking bit
   sizes, it tracks "bit classes" where each class is represented by an
   integer.  A value of 0 means we don't know anything yet, positive values
   are actual bit-sizes, and negative values are used to track equivalence
   classes of sizes that must be the same but have yet to receive an actual
   size.  The first stage uses the bitsize_tree algorithm to assign bit
   classes to each variable.  If it ever comes across an inconsistency, it
   assert-fails.  Then the second stage uses that information to prove that
   the resulting expression can always validly be constructed.
   """

   def __init__(self, varset):
      self._num_classes = 0
      # Bit class of each variable, indexed by VarSet id; 0 means unassigned.
      self._var_classes = [0] * len(varset.names)
      self._class_relation = IntEquivalenceRelation()

   def validate(self, search, replace):
      """Assert-fail unless replace is well-sized wherever search matches."""
      dst_class = self._propagate_bit_size_up(search)
      if dst_class == 0:
         dst_class = self._new_class()
      self._propagate_bit_class_down(search, dst_class)

      # The down pass may have unified dst_class with other classes or with
      # an actual bit size (e.g. a variable under the root also appears in a
      # sized source).  _validate_bit_class_up only returns canonical classes
      # (or 0), so compare against the canonical form as well.
      dst_class = self._class_relation.get_canonical(dst_class)

      validate_dst_class = self._validate_bit_class_up(replace)
      assert validate_dst_class == 0 or validate_dst_class == dst_class
      self._validate_bit_class_down(replace, dst_class)

   def _new_class(self):
      """Allocate a fresh negative (not yet sized) bit class."""
      self._num_classes += 1
      return -self._num_classes

   def _set_var_bit_class(self, var_id, bit_class):
      """Record that variable var_id must have bit class bit_class."""
      assert bit_class != 0
      var_class = self._var_classes[var_id]
      if var_class == 0:
         self._var_classes[var_id] = bit_class
         return

      canon_var = self._class_relation.get_canonical(var_class)
      canon_bit = self._class_relation.get_canonical(bit_class)
      # The two classes may be merged unless both have already resolved to
      # different actual bit sizes.  Comparing canonical forms also accepts
      # a class that was unified with this same size earlier in the pass.
      assert canon_var < 0 or canon_bit < 0 or canon_var == canon_bit
      var_class = self._class_relation.add_equiv(var_class, bit_class)
      self._var_classes[var_id] = var_class

   def _get_var_bit_class(self, var_id):
      """Return the canonical bit class of a variable (0 if unassigned)."""
      return self._class_relation.get_canonical(self._var_classes[var_id])

   def _propagate_bit_size_up(self, val):
      """Bottom-up pass over the search tree.

      Returns the bit size val is known to have from explicit sizes and sized
      opcode types, or 0 if unknown.  For expressions, stores the size shared
      by the unsized sources/destination in val.common_size.
      """
      if isinstance(val, (Constant, Variable)):
         return val.bit_size

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_size = 0
         for i in range(nir_op.num_inputs):
            src_bits = self._propagate_bit_size_up(val.sources[i])
            if src_bits == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               assert src_bits == src_type_bits
            else:
               assert val.common_size == 0 or src_bits == val.common_size
               val.common_size = src_bits

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            if val.common_size != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_size
            else:
               val.common_size = val.bit_size
            return val.common_size

   def _propagate_bit_class_down(self, val, bit_class):
      """Top-down pass over the search tree assigning classes to variables.

      bit_class is the class the parent requires of val.  Unsized operands
      whose size is still unknown are given a fresh negative class.
      """
      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class
         self._set_var_bit_class(val.index, bit_class)

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == 0 or bit_class == dst_type_bits
         else:
            assert val.common_size == 0 or val.common_size == bit_class
            val.common_size = bit_class

         if val.common_size:
            common_class = val.common_size
         elif nir_op.num_inputs:
            # If we got here then we have no idea what the actual size is.
            # Instead, we use a generic class
            common_class = self._new_class()

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._propagate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._propagate_bit_class_down(val.sources[i], common_class)

   def _validate_bit_class_up(self, val):
      """Bottom-up pass over the replace tree.

      Mirrors _propagate_bit_size_up but works on bit classes, taking each
      variable's canonical class from the search pass.  For expressions,
      stores the class of the unsized sources/destination in
      val.common_class.
      """
      if isinstance(val, Constant):
         return val.bit_size

      elif isinstance(val, Variable):
         var_class = self._get_var_bit_class(val.index)
         # By the time we get to validation, every variable should have a class
         assert var_class != 0

         # If we have an explicit size provided by the user, the variable
         # *must* exactly match the search.  It cannot be implicitly sized
         # because otherwise we could end up with a conflict at runtime.
         assert val.bit_size == 0 or val.bit_size == var_class

         return var_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_class = 0
         for i in range(nir_op.num_inputs):
            src_class = self._validate_bit_class_up(val.sources[i])
            if src_class == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               assert src_class == src_type_bits
            else:
               assert val.common_class == 0 or src_class == val.common_class
               val.common_class = src_class

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            if val.common_class != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_class
            else:
               val.common_class = val.bit_size
            return val.common_class

   def _validate_bit_class_down(self, val, bit_class):
      """Top-down pass over the replace tree checking it against bit_class."""
      # At this point, everything *must* have a bit class.  Otherwise, we have
      # a value we don't know how to define.
      assert bit_class != 0

      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == dst_type_bits
         else:
            assert val.common_class == 0 or val.common_class == bit_class
            val.common_class = bit_class

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._validate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._validate_bit_class_down(val.sources[i], val.common_class)
465
# Source of the unique ids used to name each transform's generated C structs.
_optimization_ids = itertools.count()

# Every distinct condition expression used by any transform, in first-seen
# order.  Index 0 is the always-true condition.  The generated pass evaluates
# entry i into condition_flags[i].
condition_list = ['true']

class SearchAndReplace(object):
   """One parsed (search, replace[, condition]) transform.

   The search pattern is parsed first.  Its VarSet is then locked so the
   replacement may only refer to variables bound by the search.  Bit sizes
   are checked with BitSizeValidator before the transform is accepted.
   """
   def __init__(self, transform):
      # The next() builtin works on Python 2.6+ and 3; .next() is Py2-only.
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         # NOTE(review): a pre-built Expression was parsed against its own
         # VarSet, so this fresh, empty one cannot size or resolve its
         # variables -- confirm this path is actually usable.
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset)

      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

      BitSizeValidator(varset).validate(self.search, self.replace)
499
# Mako template for the generated C file.  It emits one static transform
# table per root opcode, plus a pass that walks blocks and instructions in
# reverse and tries each enabled transform on a matching ALU instruction
# until one applies.  xform_dict.items() (not the Py2-only iteritems()) keeps
# the template working on both Python 2 and Python 3.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
602
class AlgebraicPass(object):
   """A named NIR pass built from a list of transforms.

   transforms may contain SearchAndReplace objects or raw tuples accepted by
   SearchAndReplace.  Every malformed transform is reported on stderr, and
   the script then exits with status 1, so all errors are shown at once.
   """
   def __init__(self, pass_name, transforms):
      # Root search opcode -> list of SearchAndReplace objects, in input order.
      self.xform_dict = {}
      self.pass_name = pass_name

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except Exception:
               # Report and keep going so every bad transform gets listed.
               # Catching Exception (not everything) lets Ctrl-C and
               # sys.exit() propagate normally.
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         # The generated pass switches on the opcode at the search root.
         self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

      if error:
         sys.exit(1)

   def render(self):
      """Return the generated C source for this pass."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xform_dict=self.xform_dict,
                                             condition_list=condition_list)
634