nir_search.c revision 7e0ee3a38b033aad12a0c19dc5437bc9c011437a
/*
 * Copyright © 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
22 * 23 * Authors: 24 * Jason Ekstrand (jason@jlekstrand.net) 25 * 26 */ 27 28#include <inttypes.h> 29#include "nir_search.h" 30 31struct match_state { 32 bool inexact_match; 33 bool has_exact_alu; 34 unsigned variables_seen; 35 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES]; 36}; 37 38static bool 39match_expression(const nir_search_expression *expr, nir_alu_instr *instr, 40 unsigned num_components, const uint8_t *swizzle, 41 struct match_state *state); 42 43static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 }; 44 45static bool alu_instr_is_bool(nir_alu_instr *instr); 46 47static bool 48src_is_bool(nir_src src) 49{ 50 if (!src.is_ssa) 51 return false; 52 if (src.ssa->parent_instr->type != nir_instr_type_alu) 53 return false; 54 return alu_instr_is_bool(nir_instr_as_alu(src.ssa->parent_instr)); 55} 56 57static bool 58alu_instr_is_bool(nir_alu_instr *instr) 59{ 60 switch (instr->op) { 61 case nir_op_iand: 62 case nir_op_ior: 63 case nir_op_ixor: 64 return src_is_bool(instr->src[0].src) && src_is_bool(instr->src[1].src); 65 case nir_op_inot: 66 return src_is_bool(instr->src[0].src); 67 default: 68 return (nir_alu_type_get_base_type(nir_op_infos[instr->op].output_type) 69 == nir_type_bool); 70 } 71} 72 73static bool 74match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, 75 unsigned num_components, const uint8_t *swizzle, 76 struct match_state *state) 77{ 78 uint8_t new_swizzle[4]; 79 80 /* If the source is an explicitly sized source, then we need to reset 81 * both the number of components and the swizzle. 
82 */ 83 if (nir_op_infos[instr->op].input_sizes[src] != 0) { 84 num_components = nir_op_infos[instr->op].input_sizes[src]; 85 swizzle = identity_swizzle; 86 } 87 88 for (unsigned i = 0; i < num_components; ++i) 89 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]]; 90 91 /* If the value has a specific bit size and it doesn't match, bail */ 92 if (value->bit_size && 93 nir_src_bit_size(instr->src[src].src) != value->bit_size) 94 return false; 95 96 switch (value->type) { 97 case nir_search_value_expression: 98 if (!instr->src[src].src.is_ssa) 99 return false; 100 101 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) 102 return false; 103 104 return match_expression(nir_search_value_as_expression(value), 105 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr), 106 num_components, new_swizzle, state); 107 108 case nir_search_value_variable: { 109 nir_search_variable *var = nir_search_value_as_variable(value); 110 assert(var->variable < NIR_SEARCH_MAX_VARIABLES); 111 112 if (state->variables_seen & (1 << var->variable)) { 113 if (!nir_srcs_equal(state->variables[var->variable].src, 114 instr->src[src].src)) 115 return false; 116 117 assert(!instr->src[src].abs && !instr->src[src].negate); 118 119 for (unsigned i = 0; i < num_components; ++i) { 120 if (state->variables[var->variable].swizzle[i] != new_swizzle[i]) 121 return false; 122 } 123 124 return true; 125 } else { 126 if (var->is_constant && 127 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) 128 return false; 129 130 if (var->type != nir_type_invalid) { 131 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) 132 return false; 133 134 nir_alu_instr *src_alu = 135 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr); 136 137 if (nir_alu_type_get_base_type(nir_op_infos[src_alu->op].output_type) != 138 var->type && 139 !(nir_alu_type_get_base_type(var->type) == nir_type_bool && 140 alu_instr_is_bool(src_alu))) 141 return false; 142 } 143 144 
state->variables_seen |= (1 << var->variable); 145 state->variables[var->variable].src = instr->src[src].src; 146 state->variables[var->variable].abs = false; 147 state->variables[var->variable].negate = false; 148 149 for (unsigned i = 0; i < 4; ++i) { 150 if (i < num_components) 151 state->variables[var->variable].swizzle[i] = new_swizzle[i]; 152 else 153 state->variables[var->variable].swizzle[i] = 0; 154 } 155 156 return true; 157 } 158 } 159 160 case nir_search_value_constant: { 161 nir_search_constant *const_val = nir_search_value_as_constant(value); 162 163 if (!instr->src[src].src.is_ssa) 164 return false; 165 166 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) 167 return false; 168 169 nir_load_const_instr *load = 170 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr); 171 172 switch (const_val->type) { 173 case nir_type_float: 174 for (unsigned i = 0; i < num_components; ++i) { 175 double val; 176 switch (load->def.bit_size) { 177 case 32: 178 val = load->value.f32[new_swizzle[i]]; 179 break; 180 case 64: 181 val = load->value.f64[new_swizzle[i]]; 182 break; 183 default: 184 unreachable("unknown bit size"); 185 } 186 187 if (val != const_val->data.d) 188 return false; 189 } 190 return true; 191 192 case nir_type_int: 193 for (unsigned i = 0; i < num_components; ++i) { 194 int64_t val; 195 switch (load->def.bit_size) { 196 case 32: 197 val = load->value.i32[new_swizzle[i]]; 198 break; 199 case 64: 200 val = load->value.i64[new_swizzle[i]]; 201 break; 202 default: 203 unreachable("unknown bit size"); 204 } 205 206 if (val != const_val->data.i) 207 return false; 208 } 209 return true; 210 211 case nir_type_uint: 212 case nir_type_bool32: 213 for (unsigned i = 0; i < num_components; ++i) { 214 uint64_t val; 215 switch (load->def.bit_size) { 216 case 32: 217 val = load->value.u32[new_swizzle[i]]; 218 break; 219 case 64: 220 val = load->value.u64[new_swizzle[i]]; 221 break; 222 default: 223 unreachable("unknown bit 
size"); 224 } 225 226 if (val != const_val->data.u) 227 return false; 228 } 229 return true; 230 231 default: 232 unreachable("Invalid alu source type"); 233 } 234 } 235 236 default: 237 unreachable("Invalid search value type"); 238 } 239} 240 241static bool 242match_expression(const nir_search_expression *expr, nir_alu_instr *instr, 243 unsigned num_components, const uint8_t *swizzle, 244 struct match_state *state) 245{ 246 if (instr->op != expr->opcode) 247 return false; 248 249 assert(instr->dest.dest.is_ssa); 250 251 if (expr->value.bit_size && 252 instr->dest.dest.ssa.bit_size != expr->value.bit_size) 253 return false; 254 255 state->inexact_match = expr->inexact || state->inexact_match; 256 state->has_exact_alu = instr->exact || state->has_exact_alu; 257 if (state->inexact_match && state->has_exact_alu) 258 return false; 259 260 assert(!instr->dest.saturate); 261 assert(nir_op_infos[instr->op].num_inputs > 0); 262 263 /* If we have an explicitly sized destination, we can only handle the 264 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid 265 * expression, we don't have the information right now to propagate that 266 * swizzle through. We can only properly propagate swizzles if the 267 * instruction is vectorized. 268 */ 269 if (nir_op_infos[instr->op].output_size != 0) { 270 for (unsigned i = 0; i < num_components; i++) { 271 if (swizzle[i] != i) 272 return false; 273 } 274 } 275 276 /* Stash off the current variables_seen bitmask. This way we can 277 * restore it prior to matching in the commutative case below. 
278 */ 279 unsigned variables_seen_stash = state->variables_seen; 280 281 bool matched = true; 282 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) { 283 if (!match_value(expr->srcs[i], instr, i, num_components, 284 swizzle, state)) { 285 matched = false; 286 break; 287 } 288 } 289 290 if (matched) 291 return true; 292 293 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) { 294 assert(nir_op_infos[instr->op].num_inputs == 2); 295 296 /* Restore the variables_seen bitmask. If we don't do this, then we 297 * could end up with an erroneous failure due to variables found in the 298 * first match attempt above not matching those in the second. 299 */ 300 state->variables_seen = variables_seen_stash; 301 302 if (!match_value(expr->srcs[0], instr, 1, num_components, 303 swizzle, state)) 304 return false; 305 306 return match_value(expr->srcs[1], instr, 0, num_components, 307 swizzle, state); 308 } else { 309 return false; 310 } 311} 312 313typedef struct bitsize_tree { 314 unsigned num_srcs; 315 struct bitsize_tree *srcs[4]; 316 317 unsigned common_size; 318 bool is_src_sized[4]; 319 bool is_dest_sized; 320 321 unsigned dest_size; 322 unsigned src_size[4]; 323} bitsize_tree; 324 325static bitsize_tree * 326build_bitsize_tree(void *mem_ctx, struct match_state *state, 327 const nir_search_value *value) 328{ 329 bitsize_tree *tree = ralloc(mem_ctx, bitsize_tree); 330 331 switch (value->type) { 332 case nir_search_value_expression: { 333 nir_search_expression *expr = nir_search_value_as_expression(value); 334 nir_op_info info = nir_op_infos[expr->opcode]; 335 tree->num_srcs = info.num_inputs; 336 tree->common_size = 0; 337 for (unsigned i = 0; i < info.num_inputs; i++) { 338 tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]); 339 if (tree->is_src_sized[i]) 340 tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]); 341 tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]); 342 } 343 
tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type); 344 if (tree->is_dest_sized) 345 tree->dest_size = nir_alu_type_get_type_size(info.output_type); 346 break; 347 } 348 349 case nir_search_value_variable: { 350 nir_search_variable *var = nir_search_value_as_variable(value); 351 tree->num_srcs = 0; 352 tree->is_dest_sized = true; 353 tree->dest_size = nir_src_bit_size(state->variables[var->variable].src); 354 break; 355 } 356 357 case nir_search_value_constant: { 358 tree->num_srcs = 0; 359 tree->is_dest_sized = false; 360 tree->common_size = 0; 361 break; 362 } 363 } 364 365 if (value->bit_size) { 366 assert(!tree->is_dest_sized || tree->dest_size == value->bit_size); 367 tree->common_size = value->bit_size; 368 } 369 370 return tree; 371} 372 373static unsigned 374bitsize_tree_filter_up(bitsize_tree *tree) 375{ 376 for (unsigned i = 0; i < tree->num_srcs; i++) { 377 unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]); 378 if (src_size == 0) 379 continue; 380 381 if (tree->is_src_sized[i]) { 382 assert(src_size == tree->src_size[i]); 383 } else if (tree->common_size != 0) { 384 assert(src_size == tree->common_size); 385 tree->src_size[i] = src_size; 386 } else { 387 tree->common_size = src_size; 388 tree->src_size[i] = src_size; 389 } 390 } 391 392 if (tree->num_srcs && tree->common_size) { 393 if (tree->dest_size == 0) 394 tree->dest_size = tree->common_size; 395 else if (!tree->is_dest_sized) 396 assert(tree->dest_size == tree->common_size); 397 398 for (unsigned i = 0; i < tree->num_srcs; i++) { 399 if (!tree->src_size[i]) 400 tree->src_size[i] = tree->common_size; 401 } 402 } 403 404 return tree->dest_size; 405} 406 407static void 408bitsize_tree_filter_down(bitsize_tree *tree, unsigned size) 409{ 410 if (tree->dest_size) 411 assert(tree->dest_size == size); 412 else 413 tree->dest_size = size; 414 415 if (!tree->is_dest_sized) { 416 if (tree->common_size) 417 assert(tree->common_size == size); 418 else 419 tree->common_size = size; 420 
} 421 422 for (unsigned i = 0; i < tree->num_srcs; i++) { 423 if (!tree->src_size[i]) { 424 assert(tree->common_size); 425 tree->src_size[i] = tree->common_size; 426 } 427 bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]); 428 } 429} 430 431static nir_alu_src 432construct_value(const nir_search_value *value, 433 unsigned num_components, bitsize_tree *bitsize, 434 struct match_state *state, 435 nir_instr *instr, void *mem_ctx) 436{ 437 switch (value->type) { 438 case nir_search_value_expression: { 439 const nir_search_expression *expr = nir_search_value_as_expression(value); 440 441 if (nir_op_infos[expr->opcode].output_size != 0) 442 num_components = nir_op_infos[expr->opcode].output_size; 443 444 nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode); 445 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, 446 bitsize->dest_size, NULL); 447 alu->dest.write_mask = (1 << num_components) - 1; 448 alu->dest.saturate = false; 449 450 /* We have no way of knowing what values in a given search expression 451 * map to a particular replacement value. Therefore, if the 452 * expression we are replacing has any exact values, the entire 453 * replacement should be exact. 454 */ 455 alu->exact = state->has_exact_alu; 456 457 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) { 458 /* If the source is an explicitly sized source, then we need to reset 459 * the number of components to match. 
460 */ 461 if (nir_op_infos[alu->op].input_sizes[i] != 0) 462 num_components = nir_op_infos[alu->op].input_sizes[i]; 463 464 alu->src[i] = construct_value(expr->srcs[i], 465 num_components, bitsize->srcs[i], 466 state, instr, mem_ctx); 467 } 468 469 nir_instr_insert_before(instr, &alu->instr); 470 471 nir_alu_src val; 472 val.src = nir_src_for_ssa(&alu->dest.dest.ssa); 473 val.negate = false; 474 val.abs = false, 475 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle); 476 477 return val; 478 } 479 480 case nir_search_value_variable: { 481 const nir_search_variable *var = nir_search_value_as_variable(value); 482 assert(state->variables_seen & (1 << var->variable)); 483 484 nir_alu_src val = { NIR_SRC_INIT }; 485 nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx); 486 487 assert(!var->is_constant); 488 489 return val; 490 } 491 492 case nir_search_value_constant: { 493 const nir_search_constant *c = nir_search_value_as_constant(value); 494 nir_load_const_instr *load = 495 nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size); 496 497 switch (c->type) { 498 case nir_type_float: 499 load->def.name = ralloc_asprintf(load, "%f", c->data.d); 500 switch (bitsize->dest_size) { 501 case 32: 502 load->value.f32[0] = c->data.d; 503 break; 504 case 64: 505 load->value.f64[0] = c->data.d; 506 break; 507 default: 508 unreachable("unknown bit size"); 509 } 510 break; 511 512 case nir_type_int: 513 load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i); 514 switch (bitsize->dest_size) { 515 case 32: 516 load->value.i32[0] = c->data.i; 517 break; 518 case 64: 519 load->value.i64[0] = c->data.i; 520 break; 521 default: 522 unreachable("unknown bit size"); 523 } 524 break; 525 526 case nir_type_uint: 527 load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u); 528 switch (bitsize->dest_size) { 529 case 32: 530 load->value.u32[0] = c->data.u; 531 break; 532 case 64: 533 load->value.u64[0] = c->data.u; 534 break; 535 default: 536 
unreachable("unknown bit size"); 537 } 538 break; 539 540 case nir_type_bool32: 541 load->value.u32[0] = c->data.u; 542 break; 543 default: 544 unreachable("Invalid alu source type"); 545 } 546 547 nir_instr_insert_before(instr, &load->instr); 548 549 nir_alu_src val; 550 val.src = nir_src_for_ssa(&load->def); 551 val.negate = false; 552 val.abs = false, 553 memset(val.swizzle, 0, sizeof val.swizzle); 554 555 return val; 556 } 557 558 default: 559 unreachable("Invalid search value type"); 560 } 561} 562 563nir_alu_instr * 564nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search, 565 const nir_search_value *replace, void *mem_ctx) 566{ 567 uint8_t swizzle[4] = { 0, 0, 0, 0 }; 568 569 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i) 570 swizzle[i] = i; 571 572 assert(instr->dest.dest.is_ssa); 573 574 struct match_state state; 575 state.inexact_match = false; 576 state.has_exact_alu = false; 577 state.variables_seen = 0; 578 579 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components, 580 swizzle, &state)) 581 return NULL; 582 583 void *bitsize_ctx = ralloc_context(NULL); 584 bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace); 585 bitsize_tree_filter_up(tree); 586 bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size); 587 588 /* Inserting a mov may be unnecessary. However, it's much easier to 589 * simply let copy propagation clean this up than to try to go through 590 * and rewrite swizzles ourselves. 
591 */ 592 nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov); 593 mov->dest.write_mask = instr->dest.write_mask; 594 nir_ssa_dest_init(&mov->instr, &mov->dest.dest, 595 instr->dest.dest.ssa.num_components, 596 instr->dest.dest.ssa.bit_size, NULL); 597 598 mov->src[0] = construct_value(replace, 599 instr->dest.dest.ssa.num_components, tree, 600 &state, &instr->instr, mem_ctx); 601 nir_instr_insert_before(&instr->instr, &mov->instr); 602 603 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, 604 nir_src_for_ssa(&mov->dest.dest.ssa)); 605 606 /* We know this one has no more uses because we just rewrote them all, 607 * so we can remove it. The rest of the matched expression, however, we 608 * don't know so much about. We'll just let dead code clean them up. 609 */ 610 nir_instr_remove(&instr->instr); 611 612 ralloc_free(bitsize_ctx); 613 614 return mov; 615} 616