nir_search.c revision 3a9e6102b4baae3f50956e5f150c9e59138f4cc0
1/* 2 * Copyright © 2014 Intel Corporation 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 * 23 * Authors: 24 * Jason Ekstrand (jason@jlekstrand.net) 25 * 26 */ 27 28#include <inttypes.h> 29#include "nir_search.h" 30 31struct match_state { 32 bool inexact_match; 33 bool has_exact_alu; 34 unsigned variables_seen; 35 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES]; 36}; 37 38static bool 39match_expression(const nir_search_expression *expr, nir_alu_instr *instr, 40 unsigned num_components, const uint8_t *swizzle, 41 struct match_state *state); 42 43static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 }; 44 45/** 46 * Check if a source produces a value of the given type. 47 * 48 * Used for satisfying 'a@type' constraints. 49 */ 50static bool 51src_is_type(nir_src src, nir_alu_type type) 52{ 53 assert(type != nir_type_invalid); 54 55 if (!src.is_ssa) 56 return false; 57 58 /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */ 59 if (nir_alu_type_get_base_type(type) == nir_type_bool) 60 type = nir_type_bool; 61 62 if (src.ssa->parent_instr->type == nir_instr_type_alu) { 63 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr); 64 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type; 65 66 if (type == nir_type_bool) { 67 switch (src_alu->op) { 68 case nir_op_iand: 69 case nir_op_ior: 70 case nir_op_ixor: 71 return src_is_type(src_alu->src[0].src, nir_type_bool) && 72 src_is_type(src_alu->src[1].src, nir_type_bool); 73 case nir_op_inot: 74 return src_is_type(src_alu->src[0].src, nir_type_bool); 75 default: 76 break; 77 } 78 } 79 80 return nir_alu_type_get_base_type(output_type) == type; 81 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) { 82 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr); 83 84 if (type == nir_type_bool) { 85 return intr->intrinsic == nir_intrinsic_load_front_face || 86 intr->intrinsic == nir_intrinsic_load_helper_invocation; 87 } 88 } 89 90 /* don't know */ 91 return false; 92} 93 94static bool 95match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, 96 unsigned num_components, const uint8_t *swizzle, 97 struct match_state *state) 98{ 99 uint8_t new_swizzle[4]; 100 101 /* If the source is an explicitly sized source, then we need to reset 102 * both the number of components and the swizzle. 103 */ 104 if (nir_op_infos[instr->op].input_sizes[src] != 0) { 105 num_components = nir_op_infos[instr->op].input_sizes[src]; 106 swizzle = identity_swizzle; 107 } 108 109 for (unsigned i = 0; i < num_components; ++i) 110 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]]; 111 112 /* If the value has a specific bit size and it doesn't match, bail */ 113 if (value->bit_size && 114 nir_src_bit_size(instr->src[src].src) != value->bit_size) 115 return false; 116 117 switch (value->type) { 118 case nir_search_value_expression: 119 if (!instr->src[src].src.is_ssa) 120 return false; 121 122 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) 123 return false; 124 125 return match_expression(nir_search_value_as_expression(value), 126 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr), 127 num_components, new_swizzle, state); 128 129 case nir_search_value_variable: { 130 nir_search_variable *var = nir_search_value_as_variable(value); 131 assert(var->variable < NIR_SEARCH_MAX_VARIABLES); 132 133 if (state->variables_seen & (1 << var->variable)) { 134 if (!nir_srcs_equal(state->variables[var->variable].src, 135 instr->src[src].src)) 136 return false; 137 138 assert(!instr->src[src].abs && !instr->src[src].negate); 139 140 for (unsigned i = 0; i < num_components; ++i) { 141 if (state->variables[var->variable].swizzle[i] != new_swizzle[i]) 142 return false; 143 } 144 145 return true; 146 } else { 147 if (var->is_constant && 148 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) 149 return false; 150 151 if (var->cond && !var->cond(instr, src, num_components, new_swizzle)) 152 return false; 153 154 if (var->type != nir_type_invalid && 155 !src_is_type(instr->src[src].src, var->type)) 156 return false; 157 158 state->variables_seen |= (1 << var->variable); 159 state->variables[var->variable].src = instr->src[src].src; 160 state->variables[var->variable].abs = false; 161 state->variables[var->variable].negate = false; 162 163 for (unsigned i = 0; i < 4; ++i) { 164 if (i < num_components) 165 state->variables[var->variable].swizzle[i] = new_swizzle[i]; 166 else 167 state->variables[var->variable].swizzle[i] = 0; 168 } 169 170 return true; 171 } 172 } 173 174 case nir_search_value_constant: { 175 nir_search_constant *const_val = nir_search_value_as_constant(value); 176 177 if (!instr->src[src].src.is_ssa) 178 return false; 179 180 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) 181 return false; 182 183 nir_load_const_instr *load = 184 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr); 185 186 switch (const_val->type) { 187 case nir_type_float: 188 for (unsigned i = 0; i < num_components; ++i) { 189 double val; 190 switch (load->def.bit_size) { 191 case 32: 192 val = load->value.f32[new_swizzle[i]]; 193 break; 194 case 64: 195 val = load->value.f64[new_swizzle[i]]; 196 break; 197 default: 198 unreachable("unknown bit size"); 199 } 200 201 if (val != const_val->data.d) 202 return false; 203 } 204 return true; 205 206 case nir_type_int: 207 for (unsigned i = 0; i < num_components; ++i) { 208 int64_t val; 209 switch (load->def.bit_size) { 210 case 32: 211 val = load->value.i32[new_swizzle[i]]; 212 break; 213 case 64: 214 val = load->value.i64[new_swizzle[i]]; 215 break; 216 default: 217 unreachable("unknown bit size"); 218 } 219 220 if (val != const_val->data.i) 221 return false; 222 } 223 return true; 224 225 case nir_type_uint: 226 case nir_type_bool32: 227 for (unsigned i = 0; i < num_components; ++i) { 228 uint64_t val; 229 switch (load->def.bit_size) { 230 case 32: 231 val = load->value.u32[new_swizzle[i]]; 232 break; 233 case 64: 234 val = load->value.u64[new_swizzle[i]]; 235 break; 236 default: 237 unreachable("unknown bit size"); 238 } 239 240 if (val != const_val->data.u) 241 return false; 242 } 243 return true; 244 245 default: 246 unreachable("Invalid alu source type"); 247 } 248 } 249 250 default: 251 unreachable("Invalid search value type"); 252 } 253} 254 255static bool 256match_expression(const nir_search_expression *expr, nir_alu_instr *instr, 257 unsigned num_components, const uint8_t *swizzle, 258 struct match_state *state) 259{ 260 if (instr->op != expr->opcode) 261 return false; 262 263 assert(instr->dest.dest.is_ssa); 264 265 if (expr->value.bit_size && 266 instr->dest.dest.ssa.bit_size != expr->value.bit_size) 267 return false; 268 269 state->inexact_match = expr->inexact || state->inexact_match; 270 state->has_exact_alu = instr->exact || state->has_exact_alu; 271 if (state->inexact_match && state->has_exact_alu) 272 return false; 273 274 assert(!instr->dest.saturate); 275 assert(nir_op_infos[instr->op].num_inputs > 0); 276 277 /* If we have an explicitly sized destination, we can only handle the 278 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid 279 * expression, we don't have the information right now to propagate that 280 * swizzle through. We can only properly propagate swizzles if the 281 * instruction is vectorized. 282 */ 283 if (nir_op_infos[instr->op].output_size != 0) { 284 for (unsigned i = 0; i < num_components; i++) { 285 if (swizzle[i] != i) 286 return false; 287 } 288 } 289 290 /* Stash off the current variables_seen bitmask. This way we can 291 * restore it prior to matching in the commutative case below. 292 */ 293 unsigned variables_seen_stash = state->variables_seen; 294 295 bool matched = true; 296 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) { 297 if (!match_value(expr->srcs[i], instr, i, num_components, 298 swizzle, state)) { 299 matched = false; 300 break; 301 } 302 } 303 304 if (matched) 305 return true; 306 307 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) { 308 assert(nir_op_infos[instr->op].num_inputs == 2); 309 310 /* Restore the variables_seen bitmask. If we don't do this, then we 311 * could end up with an erroneous failure due to variables found in the 312 * first match attempt above not matching those in the second. 313 */ 314 state->variables_seen = variables_seen_stash; 315 316 if (!match_value(expr->srcs[0], instr, 1, num_components, 317 swizzle, state)) 318 return false; 319 320 return match_value(expr->srcs[1], instr, 0, num_components, 321 swizzle, state); 322 } else { 323 return false; 324 } 325} 326 327typedef struct bitsize_tree { 328 unsigned num_srcs; 329 struct bitsize_tree *srcs[4]; 330 331 unsigned common_size; 332 bool is_src_sized[4]; 333 bool is_dest_sized; 334 335 unsigned dest_size; 336 unsigned src_size[4]; 337} bitsize_tree; 338 339static bitsize_tree * 340build_bitsize_tree(void *mem_ctx, struct match_state *state, 341 const nir_search_value *value) 342{ 343 bitsize_tree *tree = ralloc(mem_ctx, bitsize_tree); 344 345 switch (value->type) { 346 case nir_search_value_expression: { 347 nir_search_expression *expr = nir_search_value_as_expression(value); 348 nir_op_info info = nir_op_infos[expr->opcode]; 349 tree->num_srcs = info.num_inputs; 350 tree->common_size = 0; 351 for (unsigned i = 0; i < info.num_inputs; i++) { 352 tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]); 353 if (tree->is_src_sized[i]) 354 tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]); 355 tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]); 356 } 357 tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type); 358 if (tree->is_dest_sized) 359 tree->dest_size = nir_alu_type_get_type_size(info.output_type); 360 break; 361 } 362 363 case nir_search_value_variable: { 364 nir_search_variable *var = nir_search_value_as_variable(value); 365 tree->num_srcs = 0; 366 tree->is_dest_sized = true; 367 tree->dest_size = nir_src_bit_size(state->variables[var->variable].src); 368 break; 369 } 370 371 case nir_search_value_constant: { 372 tree->num_srcs = 0; 373 tree->is_dest_sized = false; 374 tree->common_size = 0; 375 break; 376 } 377 } 378 379 if (value->bit_size) { 380 assert(!tree->is_dest_sized || tree->dest_size == value->bit_size); 381 tree->common_size = value->bit_size; 382 } 383 384 return tree; 385} 386 387static unsigned 388bitsize_tree_filter_up(bitsize_tree *tree) 389{ 390 for (unsigned i = 0; i < tree->num_srcs; i++) { 391 unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]); 392 if (src_size == 0) 393 continue; 394 395 if (tree->is_src_sized[i]) { 396 assert(src_size == tree->src_size[i]); 397 } else if (tree->common_size != 0) { 398 assert(src_size == tree->common_size); 399 tree->src_size[i] = src_size; 400 } else { 401 tree->common_size = src_size; 402 tree->src_size[i] = src_size; 403 } 404 } 405 406 if (tree->num_srcs && tree->common_size) { 407 if (tree->dest_size == 0) 408 tree->dest_size = tree->common_size; 409 else if (!tree->is_dest_sized) 410 assert(tree->dest_size == tree->common_size); 411 412 for (unsigned i = 0; i < tree->num_srcs; i++) { 413 if (!tree->src_size[i]) 414 tree->src_size[i] = tree->common_size; 415 } 416 } 417 418 return tree->dest_size; 419} 420 421static void 422bitsize_tree_filter_down(bitsize_tree *tree, unsigned size) 423{ 424 if (tree->dest_size) 425 assert(tree->dest_size == size); 426 else 427 tree->dest_size = size; 428 429 if (!tree->is_dest_sized) { 430 if (tree->common_size) 431 assert(tree->common_size == size); 432 else 433 tree->common_size = size; 434 } 435 436 for (unsigned i = 0; i < tree->num_srcs; i++) { 437 if (!tree->src_size[i]) { 438 assert(tree->common_size); 439 tree->src_size[i] = tree->common_size; 440 } 441 bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]); 442 } 443} 444 445static nir_alu_src 446construct_value(const nir_search_value *value, 447 unsigned num_components, bitsize_tree *bitsize, 448 struct match_state *state, 449 nir_instr *instr, void *mem_ctx) 450{ 451 switch (value->type) { 452 case nir_search_value_expression: { 453 const nir_search_expression *expr = nir_search_value_as_expression(value); 454 455 if (nir_op_infos[expr->opcode].output_size != 0) 456 num_components = nir_op_infos[expr->opcode].output_size; 457 458 nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode); 459 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, 460 bitsize->dest_size, NULL); 461 alu->dest.write_mask = (1 << num_components) - 1; 462 alu->dest.saturate = false; 463 464 /* We have no way of knowing what values in a given search expression 465 * map to a particular replacement value. Therefore, if the 466 * expression we are replacing has any exact values, the entire 467 * replacement should be exact. 468 */ 469 alu->exact = state->has_exact_alu; 470 471 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) { 472 /* If the source is an explicitly sized source, then we need to reset 473 * the number of components to match. 474 */ 475 if (nir_op_infos[alu->op].input_sizes[i] != 0) 476 num_components = nir_op_infos[alu->op].input_sizes[i]; 477 478 alu->src[i] = construct_value(expr->srcs[i], 479 num_components, bitsize->srcs[i], 480 state, instr, mem_ctx); 481 } 482 483 nir_instr_insert_before(instr, &alu->instr); 484 485 nir_alu_src val; 486 val.src = nir_src_for_ssa(&alu->dest.dest.ssa); 487 val.negate = false; 488 val.abs = false, 489 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle); 490 491 return val; 492 } 493 494 case nir_search_value_variable: { 495 const nir_search_variable *var = nir_search_value_as_variable(value); 496 assert(state->variables_seen & (1 << var->variable)); 497 498 nir_alu_src val = { NIR_SRC_INIT }; 499 nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx); 500 501 assert(!var->is_constant); 502 503 return val; 504 } 505 506 case nir_search_value_constant: { 507 const nir_search_constant *c = nir_search_value_as_constant(value); 508 nir_load_const_instr *load = 509 nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size); 510 511 switch (c->type) { 512 case nir_type_float: 513 load->def.name = ralloc_asprintf(load, "%f", c->data.d); 514 switch (bitsize->dest_size) { 515 case 32: 516 load->value.f32[0] = c->data.d; 517 break; 518 case 64: 519 load->value.f64[0] = c->data.d; 520 break; 521 default: 522 unreachable("unknown bit size"); 523 } 524 break; 525 526 case nir_type_int: 527 load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i); 528 switch (bitsize->dest_size) { 529 case 32: 530 load->value.i32[0] = c->data.i; 531 break; 532 case 64: 533 load->value.i64[0] = c->data.i; 534 break; 535 default: 536 unreachable("unknown bit size"); 537 } 538 break; 539 540 case nir_type_uint: 541 load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u); 542 switch (bitsize->dest_size) { 543 case 32: 544 load->value.u32[0] = c->data.u; 545 break; 546 case 64: 547 load->value.u64[0] = c->data.u; 548 break; 549 default: 550 unreachable("unknown bit size"); 551 } 552 break; 553 554 case nir_type_bool32: 555 load->value.u32[0] = c->data.u; 556 break; 557 default: 558 unreachable("Invalid alu source type"); 559 } 560 561 nir_instr_insert_before(instr, &load->instr); 562 563 nir_alu_src val; 564 val.src = nir_src_for_ssa(&load->def); 565 val.negate = false; 566 val.abs = false, 567 memset(val.swizzle, 0, sizeof val.swizzle); 568 569 return val; 570 } 571 572 default: 573 unreachable("Invalid search value type"); 574 } 575} 576 577nir_alu_instr * 578nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search, 579 const nir_search_value *replace, void *mem_ctx) 580{ 581 uint8_t swizzle[4] = { 0, 0, 0, 0 }; 582 583 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i) 584 swizzle[i] = i; 585 586 assert(instr->dest.dest.is_ssa); 587 588 struct match_state state; 589 state.inexact_match = false; 590 state.has_exact_alu = false; 591 state.variables_seen = 0; 592 593 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components, 594 swizzle, &state)) 595 return NULL; 596 597 void *bitsize_ctx = ralloc_context(NULL); 598 bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace); 599 bitsize_tree_filter_up(tree); 600 bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size); 601 602 /* Inserting a mov may be unnecessary. However, it's much easier to 603 * simply let copy propagation clean this up than to try to go through 604 * and rewrite swizzles ourselves. 605 */ 606 nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov); 607 mov->dest.write_mask = instr->dest.write_mask; 608 nir_ssa_dest_init(&mov->instr, &mov->dest.dest, 609 instr->dest.dest.ssa.num_components, 610 instr->dest.dest.ssa.bit_size, NULL); 611 612 mov->src[0] = construct_value(replace, 613 instr->dest.dest.ssa.num_components, tree, 614 &state, &instr->instr, mem_ctx); 615 nir_instr_insert_before(&instr->instr, &mov->instr); 616 617 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, 618 nir_src_for_ssa(&mov->dest.dest.ssa)); 619 620 /* We know this one has no more uses because we just rewrote them all, 621 * so we can remove it. The rest of the matched expression, however, we 622 * don't know so much about. We'll just let dead code clean them up. 623 */ 624 nir_instr_remove(&instr->instr); 625 626 ralloc_free(bitsize_ctx); 627 628 return mov; 629} 630