nir_search.c revision e8328e55e7ac26bbf3b3a47a1bb1cae4ab9130a2
1/*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 *    Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28#include <inttypes.h>
29#include "nir_search.h"
30
31struct match_state {
32   bool inexact_match;
33   bool has_exact_alu;
34   unsigned variables_seen;
35   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
36};
37
38static bool
39match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
40                 unsigned num_components, const uint8_t *swizzle,
41                 struct match_state *state);
42
43static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
44
45/**
46 * Check if a source produces a value of the given type.
47 *
48 * Used for satisfying 'a@type' constraints.
49 */
50static bool
51src_is_type(nir_src src, nir_alu_type type)
52{
53   assert(type != nir_type_invalid);
54
55   if (!src.is_ssa)
56      return false;
57
58   /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
59   if (nir_alu_type_get_base_type(type) == nir_type_bool)
60      type = nir_type_bool;
61
62   if (src.ssa->parent_instr->type == nir_instr_type_alu) {
63      nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
64      nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
65
66      if (type == nir_type_bool) {
67         switch (src_alu->op) {
68         case nir_op_iand:
69         case nir_op_ior:
70         case nir_op_ixor:
71            return src_is_type(src_alu->src[0].src, nir_type_bool) &&
72                   src_is_type(src_alu->src[1].src, nir_type_bool);
73         case nir_op_inot:
74            return src_is_type(src_alu->src[0].src, nir_type_bool);
75         default:
76            break;
77         }
78      }
79
80      return nir_alu_type_get_base_type(output_type) == type;
81   } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
82      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
83
84      if (type == nir_type_bool) {
85         return intr->intrinsic == nir_intrinsic_load_front_face ||
86                intr->intrinsic == nir_intrinsic_load_helper_invocation;
87      }
88   }
89
90   /* don't know */
91   return false;
92}
93
94static bool
95match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
96            unsigned num_components, const uint8_t *swizzle,
97            struct match_state *state)
98{
99   uint8_t new_swizzle[4];
100
101   /* Searching only works on SSA values because, if it's not SSA, we can't
102    * know if the value changed between one instance of that value in the
103    * expression and another.  Also, the replace operation will place reads of
104    * that value right before the last instruction in the expression we're
105    * replacing so those reads will happen after the original reads and may
106    * not be valid if they're register reads.
107    */
108   if (!instr->src[src].src.is_ssa)
109      return false;
110
111   /* If the source is an explicitly sized source, then we need to reset
112    * both the number of components and the swizzle.
113    */
114   if (nir_op_infos[instr->op].input_sizes[src] != 0) {
115      num_components = nir_op_infos[instr->op].input_sizes[src];
116      swizzle = identity_swizzle;
117   }
118
119   for (unsigned i = 0; i < num_components; ++i)
120      new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
121
122   /* If the value has a specific bit size and it doesn't match, bail */
123   if (value->bit_size &&
124       nir_src_bit_size(instr->src[src].src) != value->bit_size)
125      return false;
126
127   switch (value->type) {
128   case nir_search_value_expression:
129      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
130         return false;
131
132      return match_expression(nir_search_value_as_expression(value),
133                              nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
134                              num_components, new_swizzle, state);
135
136   case nir_search_value_variable: {
137      nir_search_variable *var = nir_search_value_as_variable(value);
138      assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
139
140      if (state->variables_seen & (1 << var->variable)) {
141         if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
142            return false;
143
144         assert(!instr->src[src].abs && !instr->src[src].negate);
145
146         for (unsigned i = 0; i < num_components; ++i) {
147            if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
148               return false;
149         }
150
151         return true;
152      } else {
153         if (var->is_constant &&
154             instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
155            return false;
156
157         if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
158            return false;
159
160         if (var->type != nir_type_invalid &&
161             !src_is_type(instr->src[src].src, var->type))
162            return false;
163
164         state->variables_seen |= (1 << var->variable);
165         state->variables[var->variable].src = instr->src[src].src;
166         state->variables[var->variable].abs = false;
167         state->variables[var->variable].negate = false;
168
169         for (unsigned i = 0; i < 4; ++i) {
170            if (i < num_components)
171               state->variables[var->variable].swizzle[i] = new_swizzle[i];
172            else
173               state->variables[var->variable].swizzle[i] = 0;
174         }
175
176         return true;
177      }
178   }
179
180   case nir_search_value_constant: {
181      nir_search_constant *const_val = nir_search_value_as_constant(value);
182
183      if (!instr->src[src].src.is_ssa)
184         return false;
185
186      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
187         return false;
188
189      nir_load_const_instr *load =
190         nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
191
192      switch (const_val->type) {
193      case nir_type_float:
194         for (unsigned i = 0; i < num_components; ++i) {
195            double val;
196            switch (load->def.bit_size) {
197            case 32:
198               val = load->value.f32[new_swizzle[i]];
199               break;
200            case 64:
201               val = load->value.f64[new_swizzle[i]];
202               break;
203            default:
204               unreachable("unknown bit size");
205            }
206
207            if (val != const_val->data.d)
208               return false;
209         }
210         return true;
211
212      case nir_type_int:
213         for (unsigned i = 0; i < num_components; ++i) {
214            int64_t val;
215            switch (load->def.bit_size) {
216            case 32:
217               val = load->value.i32[new_swizzle[i]];
218               break;
219            case 64:
220               val = load->value.i64[new_swizzle[i]];
221               break;
222            default:
223               unreachable("unknown bit size");
224            }
225
226            if (val != const_val->data.i)
227               return false;
228         }
229         return true;
230
231      case nir_type_uint:
232      case nir_type_bool32:
233         for (unsigned i = 0; i < num_components; ++i) {
234            uint64_t val;
235            switch (load->def.bit_size) {
236            case 32:
237               val = load->value.u32[new_swizzle[i]];
238               break;
239            case 64:
240               val = load->value.u64[new_swizzle[i]];
241               break;
242            default:
243               unreachable("unknown bit size");
244            }
245
246            if (val != const_val->data.u)
247               return false;
248         }
249         return true;
250
251      default:
252         unreachable("Invalid alu source type");
253      }
254   }
255
256   default:
257      unreachable("Invalid search value type");
258   }
259}
260
261static bool
262match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
263                 unsigned num_components, const uint8_t *swizzle,
264                 struct match_state *state)
265{
266   if (expr->cond && !expr->cond(instr))
267      return false;
268
269   if (instr->op != expr->opcode)
270      return false;
271
272   assert(instr->dest.dest.is_ssa);
273
274   if (expr->value.bit_size &&
275       instr->dest.dest.ssa.bit_size != expr->value.bit_size)
276      return false;
277
278   state->inexact_match = expr->inexact || state->inexact_match;
279   state->has_exact_alu = instr->exact || state->has_exact_alu;
280   if (state->inexact_match && state->has_exact_alu)
281      return false;
282
283   assert(!instr->dest.saturate);
284   assert(nir_op_infos[instr->op].num_inputs > 0);
285
286   /* If we have an explicitly sized destination, we can only handle the
287    * identity swizzle.  While dot(vec3(a, b, c).zxy) is a valid
288    * expression, we don't have the information right now to propagate that
289    * swizzle through.  We can only properly propagate swizzles if the
290    * instruction is vectorized.
291    */
292   if (nir_op_infos[instr->op].output_size != 0) {
293      for (unsigned i = 0; i < num_components; i++) {
294         if (swizzle[i] != i)
295            return false;
296      }
297   }
298
299   /* Stash off the current variables_seen bitmask.  This way we can
300    * restore it prior to matching in the commutative case below.
301    */
302   unsigned variables_seen_stash = state->variables_seen;
303
304   bool matched = true;
305   for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
306      if (!match_value(expr->srcs[i], instr, i, num_components,
307                       swizzle, state)) {
308         matched = false;
309         break;
310      }
311   }
312
313   if (matched)
314      return true;
315
316   if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
317      assert(nir_op_infos[instr->op].num_inputs == 2);
318
319      /* Restore the variables_seen bitmask.  If we don't do this, then we
320       * could end up with an erroneous failure due to variables found in the
321       * first match attempt above not matching those in the second.
322       */
323      state->variables_seen = variables_seen_stash;
324
325      if (!match_value(expr->srcs[0], instr, 1, num_components,
326                       swizzle, state))
327         return false;
328
329      return match_value(expr->srcs[1], instr, 0, num_components,
330                         swizzle, state);
331   } else {
332      return false;
333   }
334}
335
typedef struct bitsize_tree bitsize_tree;

/**
 * Mirror of a replacement value tree used to work out the bit size of every
 * value that will be constructed.  A size of 0 means "not known yet".
 */
struct bitsize_tree {
   /* Number of valid entries in srcs[], is_src_sized[] and src_size[]. */
   unsigned num_srcs;
   struct bitsize_tree *srcs[4];

   /* Size shared by every source (and the destination) that does not carry
    * an explicit size of its own.
    */
   unsigned common_size;

   /* Whether the opcode's type for that source/destination has an explicit
    * bit size.
    */
   bool is_src_sized[4];
   bool is_dest_sized;

   /* Resolved sizes. */
   unsigned dest_size;
   unsigned src_size[4];
};
347
348static bitsize_tree *
349build_bitsize_tree(void *mem_ctx, struct match_state *state,
350                   const nir_search_value *value)
351{
352   bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
353
354   switch (value->type) {
355   case nir_search_value_expression: {
356      nir_search_expression *expr = nir_search_value_as_expression(value);
357      nir_op_info info = nir_op_infos[expr->opcode];
358      tree->num_srcs = info.num_inputs;
359      tree->common_size = 0;
360      for (unsigned i = 0; i < info.num_inputs; i++) {
361         tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
362         if (tree->is_src_sized[i])
363            tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
364         tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
365      }
366      tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
367      if (tree->is_dest_sized)
368         tree->dest_size = nir_alu_type_get_type_size(info.output_type);
369      break;
370   }
371
372   case nir_search_value_variable: {
373      nir_search_variable *var = nir_search_value_as_variable(value);
374      tree->num_srcs = 0;
375      tree->is_dest_sized = true;
376      tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
377      break;
378   }
379
380   case nir_search_value_constant: {
381      tree->num_srcs = 0;
382      tree->is_dest_sized = false;
383      tree->common_size = 0;
384      break;
385   }
386   }
387
388   if (value->bit_size) {
389      assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
390      tree->common_size = value->bit_size;
391   }
392
393   return tree;
394}
395
396static unsigned
397bitsize_tree_filter_up(bitsize_tree *tree)
398{
399   for (unsigned i = 0; i < tree->num_srcs; i++) {
400      unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
401      if (src_size == 0)
402         continue;
403
404      if (tree->is_src_sized[i]) {
405         assert(src_size == tree->src_size[i]);
406      } else if (tree->common_size != 0) {
407         assert(src_size == tree->common_size);
408         tree->src_size[i] = src_size;
409      } else {
410         tree->common_size = src_size;
411         tree->src_size[i] = src_size;
412      }
413   }
414
415   if (tree->num_srcs && tree->common_size) {
416      if (tree->dest_size == 0)
417         tree->dest_size = tree->common_size;
418      else if (!tree->is_dest_sized)
419         assert(tree->dest_size == tree->common_size);
420
421      for (unsigned i = 0; i < tree->num_srcs; i++) {
422         if (!tree->src_size[i])
423            tree->src_size[i] = tree->common_size;
424      }
425   }
426
427   return tree->dest_size;
428}
429
430static void
431bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
432{
433   if (tree->dest_size)
434      assert(tree->dest_size == size);
435   else
436      tree->dest_size = size;
437
438   if (!tree->is_dest_sized) {
439      if (tree->common_size)
440         assert(tree->common_size == size);
441      else
442         tree->common_size = size;
443   }
444
445   for (unsigned i = 0; i < tree->num_srcs; i++) {
446      if (!tree->src_size[i]) {
447         assert(tree->common_size);
448         tree->src_size[i] = tree->common_size;
449      }
450      bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
451   }
452}
453
454static nir_alu_src
455construct_value(const nir_search_value *value,
456                unsigned num_components, bitsize_tree *bitsize,
457                struct match_state *state,
458                nir_instr *instr, void *mem_ctx)
459{
460   switch (value->type) {
461   case nir_search_value_expression: {
462      const nir_search_expression *expr = nir_search_value_as_expression(value);
463
464      if (nir_op_infos[expr->opcode].output_size != 0)
465         num_components = nir_op_infos[expr->opcode].output_size;
466
467      nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
468      nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
469                        bitsize->dest_size, NULL);
470      alu->dest.write_mask = (1 << num_components) - 1;
471      alu->dest.saturate = false;
472
473      /* We have no way of knowing what values in a given search expression
474       * map to a particular replacement value.  Therefore, if the
475       * expression we are replacing has any exact values, the entire
476       * replacement should be exact.
477       */
478      alu->exact = state->has_exact_alu;
479
480      for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
481         /* If the source is an explicitly sized source, then we need to reset
482          * the number of components to match.
483          */
484         if (nir_op_infos[alu->op].input_sizes[i] != 0)
485            num_components = nir_op_infos[alu->op].input_sizes[i];
486
487         alu->src[i] = construct_value(expr->srcs[i],
488                                       num_components, bitsize->srcs[i],
489                                       state, instr, mem_ctx);
490      }
491
492      nir_instr_insert_before(instr, &alu->instr);
493
494      nir_alu_src val;
495      val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
496      val.negate = false;
497      val.abs = false,
498      memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
499
500      return val;
501   }
502
503   case nir_search_value_variable: {
504      const nir_search_variable *var = nir_search_value_as_variable(value);
505      assert(state->variables_seen & (1 << var->variable));
506
507      nir_alu_src val = { NIR_SRC_INIT };
508      nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
509
510      assert(!var->is_constant);
511
512      return val;
513   }
514
515   case nir_search_value_constant: {
516      const nir_search_constant *c = nir_search_value_as_constant(value);
517      nir_load_const_instr *load =
518         nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size);
519
520      switch (c->type) {
521      case nir_type_float:
522         load->def.name = ralloc_asprintf(load, "%f", c->data.d);
523         switch (bitsize->dest_size) {
524         case 32:
525            load->value.f32[0] = c->data.d;
526            break;
527         case 64:
528            load->value.f64[0] = c->data.d;
529            break;
530         default:
531            unreachable("unknown bit size");
532         }
533         break;
534
535      case nir_type_int:
536         load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i);
537         switch (bitsize->dest_size) {
538         case 32:
539            load->value.i32[0] = c->data.i;
540            break;
541         case 64:
542            load->value.i64[0] = c->data.i;
543            break;
544         default:
545            unreachable("unknown bit size");
546         }
547         break;
548
549      case nir_type_uint:
550         load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u);
551         switch (bitsize->dest_size) {
552         case 32:
553            load->value.u32[0] = c->data.u;
554            break;
555         case 64:
556            load->value.u64[0] = c->data.u;
557            break;
558         default:
559            unreachable("unknown bit size");
560         }
561         break;
562
563      case nir_type_bool32:
564         load->value.u32[0] = c->data.u;
565         break;
566      default:
567         unreachable("Invalid alu source type");
568      }
569
570      nir_instr_insert_before(instr, &load->instr);
571
572      nir_alu_src val;
573      val.src = nir_src_for_ssa(&load->def);
574      val.negate = false;
575      val.abs = false,
576      memset(val.swizzle, 0, sizeof val.swizzle);
577
578      return val;
579   }
580
581   default:
582      unreachable("Invalid search value type");
583   }
584}
585
586nir_alu_instr *
587nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
588                  const nir_search_value *replace, void *mem_ctx)
589{
590   uint8_t swizzle[4] = { 0, 0, 0, 0 };
591
592   for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
593      swizzle[i] = i;
594
595   assert(instr->dest.dest.is_ssa);
596
597   struct match_state state;
598   state.inexact_match = false;
599   state.has_exact_alu = false;
600   state.variables_seen = 0;
601
602   if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
603                         swizzle, &state))
604      return NULL;
605
606   void *bitsize_ctx = ralloc_context(NULL);
607   bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
608   bitsize_tree_filter_up(tree);
609   bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
610
611   /* Inserting a mov may be unnecessary.  However, it's much easier to
612    * simply let copy propagation clean this up than to try to go through
613    * and rewrite swizzles ourselves.
614    */
615   nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
616   mov->dest.write_mask = instr->dest.write_mask;
617   nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
618                     instr->dest.dest.ssa.num_components,
619                     instr->dest.dest.ssa.bit_size, NULL);
620
621   mov->src[0] = construct_value(replace,
622                                 instr->dest.dest.ssa.num_components, tree,
623                                 &state, &instr->instr, mem_ctx);
624   nir_instr_insert_before(&instr->instr, &mov->instr);
625
626   nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
627                            nir_src_for_ssa(&mov->dest.dest.ssa));
628
629   /* We know this one has no more uses because we just rewrote them all,
630    * so we can remove it.  The rest of the matched expression, however, we
631    * don't know so much about.  We'll just let dead code clean them up.
632    */
633   nir_instr_remove(&instr->instr);
634
635   ralloc_free(bitsize_ctx);
636
637   return mov;
638}
639