nir_search.c revision 3a9e6102b4baae3f50956e5f150c9e59138f4cc0
1/*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 *    Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28#include <inttypes.h>
29#include "nir_search.h"
30
/* Scratch state threaded through one attempt to match a search expression. */
struct match_state {
   /* Set once any search expression visited so far is flagged inexact. */
   bool inexact_match;
   /* Set once any ALU instruction visited so far is flagged exact. */
   bool has_exact_alu;
   /* Bitmask: bit N is set once search variable N has been bound. */
   unsigned variables_seen;
   /* Source (and swizzle) bound to each search variable, indexed by its
    * variable number.  Only entries whose bit is set in variables_seen are
    * meaningful.
    */
   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
};
37
/* Forward declaration: match_value() and match_expression() call each
 * other recursively.
 */
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state);

/* Swizzle that selects component i for channel i. */
static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
44
45/**
46 * Check if a source produces a value of the given type.
47 *
48 * Used for satisfying 'a@type' constraints.
49 */
50static bool
51src_is_type(nir_src src, nir_alu_type type)
52{
53   assert(type != nir_type_invalid);
54
55   if (!src.is_ssa)
56      return false;
57
58   /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
59   if (nir_alu_type_get_base_type(type) == nir_type_bool)
60      type = nir_type_bool;
61
62   if (src.ssa->parent_instr->type == nir_instr_type_alu) {
63      nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
64      nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
65
66      if (type == nir_type_bool) {
67         switch (src_alu->op) {
68         case nir_op_iand:
69         case nir_op_ior:
70         case nir_op_ixor:
71            return src_is_type(src_alu->src[0].src, nir_type_bool) &&
72                   src_is_type(src_alu->src[1].src, nir_type_bool);
73         case nir_op_inot:
74            return src_is_type(src_alu->src[0].src, nir_type_bool);
75         default:
76            break;
77         }
78      }
79
80      return nir_alu_type_get_base_type(output_type) == type;
81   } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
82      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
83
84      if (type == nir_type_bool) {
85         return intr->intrinsic == nir_intrinsic_load_front_face ||
86                intr->intrinsic == nir_intrinsic_load_helper_invocation;
87      }
88   }
89
90   /* don't know */
91   return false;
92}
93
94static bool
95match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
96            unsigned num_components, const uint8_t *swizzle,
97            struct match_state *state)
98{
99   uint8_t new_swizzle[4];
100
101   /* If the source is an explicitly sized source, then we need to reset
102    * both the number of components and the swizzle.
103    */
104   if (nir_op_infos[instr->op].input_sizes[src] != 0) {
105      num_components = nir_op_infos[instr->op].input_sizes[src];
106      swizzle = identity_swizzle;
107   }
108
109   for (unsigned i = 0; i < num_components; ++i)
110      new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
111
112   /* If the value has a specific bit size and it doesn't match, bail */
113   if (value->bit_size &&
114       nir_src_bit_size(instr->src[src].src) != value->bit_size)
115      return false;
116
117   switch (value->type) {
118   case nir_search_value_expression:
119      if (!instr->src[src].src.is_ssa)
120         return false;
121
122      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
123         return false;
124
125      return match_expression(nir_search_value_as_expression(value),
126                              nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
127                              num_components, new_swizzle, state);
128
129   case nir_search_value_variable: {
130      nir_search_variable *var = nir_search_value_as_variable(value);
131      assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
132
133      if (state->variables_seen & (1 << var->variable)) {
134         if (!nir_srcs_equal(state->variables[var->variable].src,
135                             instr->src[src].src))
136            return false;
137
138         assert(!instr->src[src].abs && !instr->src[src].negate);
139
140         for (unsigned i = 0; i < num_components; ++i) {
141            if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
142               return false;
143         }
144
145         return true;
146      } else {
147         if (var->is_constant &&
148             instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
149            return false;
150
151         if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
152            return false;
153
154         if (var->type != nir_type_invalid &&
155             !src_is_type(instr->src[src].src, var->type))
156            return false;
157
158         state->variables_seen |= (1 << var->variable);
159         state->variables[var->variable].src = instr->src[src].src;
160         state->variables[var->variable].abs = false;
161         state->variables[var->variable].negate = false;
162
163         for (unsigned i = 0; i < 4; ++i) {
164            if (i < num_components)
165               state->variables[var->variable].swizzle[i] = new_swizzle[i];
166            else
167               state->variables[var->variable].swizzle[i] = 0;
168         }
169
170         return true;
171      }
172   }
173
174   case nir_search_value_constant: {
175      nir_search_constant *const_val = nir_search_value_as_constant(value);
176
177      if (!instr->src[src].src.is_ssa)
178         return false;
179
180      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
181         return false;
182
183      nir_load_const_instr *load =
184         nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
185
186      switch (const_val->type) {
187      case nir_type_float:
188         for (unsigned i = 0; i < num_components; ++i) {
189            double val;
190            switch (load->def.bit_size) {
191            case 32:
192               val = load->value.f32[new_swizzle[i]];
193               break;
194            case 64:
195               val = load->value.f64[new_swizzle[i]];
196               break;
197            default:
198               unreachable("unknown bit size");
199            }
200
201            if (val != const_val->data.d)
202               return false;
203         }
204         return true;
205
206      case nir_type_int:
207         for (unsigned i = 0; i < num_components; ++i) {
208            int64_t val;
209            switch (load->def.bit_size) {
210            case 32:
211               val = load->value.i32[new_swizzle[i]];
212               break;
213            case 64:
214               val = load->value.i64[new_swizzle[i]];
215               break;
216            default:
217               unreachable("unknown bit size");
218            }
219
220            if (val != const_val->data.i)
221               return false;
222         }
223         return true;
224
225      case nir_type_uint:
226      case nir_type_bool32:
227         for (unsigned i = 0; i < num_components; ++i) {
228            uint64_t val;
229            switch (load->def.bit_size) {
230            case 32:
231               val = load->value.u32[new_swizzle[i]];
232               break;
233            case 64:
234               val = load->value.u64[new_swizzle[i]];
235               break;
236            default:
237               unreachable("unknown bit size");
238            }
239
240            if (val != const_val->data.u)
241               return false;
242         }
243         return true;
244
245      default:
246         unreachable("Invalid alu source type");
247      }
248   }
249
250   default:
251      unreachable("Invalid search value type");
252   }
253}
254
255static bool
256match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
257                 unsigned num_components, const uint8_t *swizzle,
258                 struct match_state *state)
259{
260   if (instr->op != expr->opcode)
261      return false;
262
263   assert(instr->dest.dest.is_ssa);
264
265   if (expr->value.bit_size &&
266       instr->dest.dest.ssa.bit_size != expr->value.bit_size)
267      return false;
268
269   state->inexact_match = expr->inexact || state->inexact_match;
270   state->has_exact_alu = instr->exact || state->has_exact_alu;
271   if (state->inexact_match && state->has_exact_alu)
272      return false;
273
274   assert(!instr->dest.saturate);
275   assert(nir_op_infos[instr->op].num_inputs > 0);
276
277   /* If we have an explicitly sized destination, we can only handle the
278    * identity swizzle.  While dot(vec3(a, b, c).zxy) is a valid
279    * expression, we don't have the information right now to propagate that
280    * swizzle through.  We can only properly propagate swizzles if the
281    * instruction is vectorized.
282    */
283   if (nir_op_infos[instr->op].output_size != 0) {
284      for (unsigned i = 0; i < num_components; i++) {
285         if (swizzle[i] != i)
286            return false;
287      }
288   }
289
290   /* Stash off the current variables_seen bitmask.  This way we can
291    * restore it prior to matching in the commutative case below.
292    */
293   unsigned variables_seen_stash = state->variables_seen;
294
295   bool matched = true;
296   for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
297      if (!match_value(expr->srcs[i], instr, i, num_components,
298                       swizzle, state)) {
299         matched = false;
300         break;
301      }
302   }
303
304   if (matched)
305      return true;
306
307   if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
308      assert(nir_op_infos[instr->op].num_inputs == 2);
309
310      /* Restore the variables_seen bitmask.  If we don't do this, then we
311       * could end up with an erroneous failure due to variables found in the
312       * first match attempt above not matching those in the second.
313       */
314      state->variables_seen = variables_seen_stash;
315
316      if (!match_value(expr->srcs[0], instr, 1, num_components,
317                       swizzle, state))
318         return false;
319
320      return match_value(expr->srcs[1], instr, 0, num_components,
321                         swizzle, state);
322   } else {
323      return false;
324   }
325}
326
/* Mirror of a replacement value tree, used to work out the bit size of
 * every value in it.  A size of 0 means "not known yet".
 */
typedef struct bitsize_tree {
   /* Number of valid entries in srcs[]; 0 for variables and constants. */
   unsigned num_srcs;
   struct bitsize_tree *srcs[4];

   /* Bit size shared by the sources (and destination) that have no
    * explicit size of their own.
    */
   unsigned common_size;
   /* Whether the opcode's type fixes the size of each source / the dest. */
   bool is_src_sized[4];
   bool is_dest_sized;

   /* Resolved bit size of the destination and of each source. */
   unsigned dest_size;
   unsigned src_size[4];
} bitsize_tree;
338
339static bitsize_tree *
340build_bitsize_tree(void *mem_ctx, struct match_state *state,
341                   const nir_search_value *value)
342{
343   bitsize_tree *tree = ralloc(mem_ctx, bitsize_tree);
344
345   switch (value->type) {
346   case nir_search_value_expression: {
347      nir_search_expression *expr = nir_search_value_as_expression(value);
348      nir_op_info info = nir_op_infos[expr->opcode];
349      tree->num_srcs = info.num_inputs;
350      tree->common_size = 0;
351      for (unsigned i = 0; i < info.num_inputs; i++) {
352         tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
353         if (tree->is_src_sized[i])
354            tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
355         tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
356      }
357      tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
358      if (tree->is_dest_sized)
359         tree->dest_size = nir_alu_type_get_type_size(info.output_type);
360      break;
361   }
362
363   case nir_search_value_variable: {
364      nir_search_variable *var = nir_search_value_as_variable(value);
365      tree->num_srcs = 0;
366      tree->is_dest_sized = true;
367      tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
368      break;
369   }
370
371   case nir_search_value_constant: {
372      tree->num_srcs = 0;
373      tree->is_dest_sized = false;
374      tree->common_size = 0;
375      break;
376   }
377   }
378
379   if (value->bit_size) {
380      assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
381      tree->common_size = value->bit_size;
382   }
383
384   return tree;
385}
386
387static unsigned
388bitsize_tree_filter_up(bitsize_tree *tree)
389{
390   for (unsigned i = 0; i < tree->num_srcs; i++) {
391      unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
392      if (src_size == 0)
393         continue;
394
395      if (tree->is_src_sized[i]) {
396         assert(src_size == tree->src_size[i]);
397      } else if (tree->common_size != 0) {
398         assert(src_size == tree->common_size);
399         tree->src_size[i] = src_size;
400      } else {
401         tree->common_size = src_size;
402         tree->src_size[i] = src_size;
403      }
404   }
405
406   if (tree->num_srcs && tree->common_size) {
407      if (tree->dest_size == 0)
408         tree->dest_size = tree->common_size;
409      else if (!tree->is_dest_sized)
410         assert(tree->dest_size == tree->common_size);
411
412      for (unsigned i = 0; i < tree->num_srcs; i++) {
413         if (!tree->src_size[i])
414            tree->src_size[i] = tree->common_size;
415      }
416   }
417
418   return tree->dest_size;
419}
420
421static void
422bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
423{
424   if (tree->dest_size)
425      assert(tree->dest_size == size);
426   else
427      tree->dest_size = size;
428
429   if (!tree->is_dest_sized) {
430      if (tree->common_size)
431         assert(tree->common_size == size);
432      else
433         tree->common_size = size;
434   }
435
436   for (unsigned i = 0; i < tree->num_srcs; i++) {
437      if (!tree->src_size[i]) {
438         assert(tree->common_size);
439         tree->src_size[i] = tree->common_size;
440      }
441      bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
442   }
443}
444
445static nir_alu_src
446construct_value(const nir_search_value *value,
447                unsigned num_components, bitsize_tree *bitsize,
448                struct match_state *state,
449                nir_instr *instr, void *mem_ctx)
450{
451   switch (value->type) {
452   case nir_search_value_expression: {
453      const nir_search_expression *expr = nir_search_value_as_expression(value);
454
455      if (nir_op_infos[expr->opcode].output_size != 0)
456         num_components = nir_op_infos[expr->opcode].output_size;
457
458      nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
459      nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
460                        bitsize->dest_size, NULL);
461      alu->dest.write_mask = (1 << num_components) - 1;
462      alu->dest.saturate = false;
463
464      /* We have no way of knowing what values in a given search expression
465       * map to a particular replacement value.  Therefore, if the
466       * expression we are replacing has any exact values, the entire
467       * replacement should be exact.
468       */
469      alu->exact = state->has_exact_alu;
470
471      for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
472         /* If the source is an explicitly sized source, then we need to reset
473          * the number of components to match.
474          */
475         if (nir_op_infos[alu->op].input_sizes[i] != 0)
476            num_components = nir_op_infos[alu->op].input_sizes[i];
477
478         alu->src[i] = construct_value(expr->srcs[i],
479                                       num_components, bitsize->srcs[i],
480                                       state, instr, mem_ctx);
481      }
482
483      nir_instr_insert_before(instr, &alu->instr);
484
485      nir_alu_src val;
486      val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
487      val.negate = false;
488      val.abs = false,
489      memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
490
491      return val;
492   }
493
494   case nir_search_value_variable: {
495      const nir_search_variable *var = nir_search_value_as_variable(value);
496      assert(state->variables_seen & (1 << var->variable));
497
498      nir_alu_src val = { NIR_SRC_INIT };
499      nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
500
501      assert(!var->is_constant);
502
503      return val;
504   }
505
506   case nir_search_value_constant: {
507      const nir_search_constant *c = nir_search_value_as_constant(value);
508      nir_load_const_instr *load =
509         nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size);
510
511      switch (c->type) {
512      case nir_type_float:
513         load->def.name = ralloc_asprintf(load, "%f", c->data.d);
514         switch (bitsize->dest_size) {
515         case 32:
516            load->value.f32[0] = c->data.d;
517            break;
518         case 64:
519            load->value.f64[0] = c->data.d;
520            break;
521         default:
522            unreachable("unknown bit size");
523         }
524         break;
525
526      case nir_type_int:
527         load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i);
528         switch (bitsize->dest_size) {
529         case 32:
530            load->value.i32[0] = c->data.i;
531            break;
532         case 64:
533            load->value.i64[0] = c->data.i;
534            break;
535         default:
536            unreachable("unknown bit size");
537         }
538         break;
539
540      case nir_type_uint:
541         load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u);
542         switch (bitsize->dest_size) {
543         case 32:
544            load->value.u32[0] = c->data.u;
545            break;
546         case 64:
547            load->value.u64[0] = c->data.u;
548            break;
549         default:
550            unreachable("unknown bit size");
551         }
552         break;
553
554      case nir_type_bool32:
555         load->value.u32[0] = c->data.u;
556         break;
557      default:
558         unreachable("Invalid alu source type");
559      }
560
561      nir_instr_insert_before(instr, &load->instr);
562
563      nir_alu_src val;
564      val.src = nir_src_for_ssa(&load->def);
565      val.negate = false;
566      val.abs = false,
567      memset(val.swizzle, 0, sizeof val.swizzle);
568
569      return val;
570   }
571
572   default:
573      unreachable("Invalid search value type");
574   }
575}
576
577nir_alu_instr *
578nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
579                  const nir_search_value *replace, void *mem_ctx)
580{
581   uint8_t swizzle[4] = { 0, 0, 0, 0 };
582
583   for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
584      swizzle[i] = i;
585
586   assert(instr->dest.dest.is_ssa);
587
588   struct match_state state;
589   state.inexact_match = false;
590   state.has_exact_alu = false;
591   state.variables_seen = 0;
592
593   if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
594                         swizzle, &state))
595      return NULL;
596
597   void *bitsize_ctx = ralloc_context(NULL);
598   bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
599   bitsize_tree_filter_up(tree);
600   bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
601
602   /* Inserting a mov may be unnecessary.  However, it's much easier to
603    * simply let copy propagation clean this up than to try to go through
604    * and rewrite swizzles ourselves.
605    */
606   nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
607   mov->dest.write_mask = instr->dest.write_mask;
608   nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
609                     instr->dest.dest.ssa.num_components,
610                     instr->dest.dest.ssa.bit_size, NULL);
611
612   mov->src[0] = construct_value(replace,
613                                 instr->dest.dest.ssa.num_components, tree,
614                                 &state, &instr->instr, mem_ctx);
615   nir_instr_insert_before(&instr->instr, &mov->instr);
616
617   nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
618                            nir_src_for_ssa(&mov->dest.dest.ssa));
619
620   /* We know this one has no more uses because we just rewrote them all,
621    * so we can remove it.  The rest of the matched expression, however, we
622    * don't know so much about.  We'll just let dead code clean them up.
623    */
624   nir_instr_remove(&instr->instr);
625
626   ralloc_free(bitsize_ctx);
627
628   return mov;
629}
630