1/*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 * DEALINGS IN THE SOFTWARE.
22 */
23
24#include "nir.h"
25#include "nir_builder.h"
26#include "nir_control_flow.h"
27#include "nir_loop_analyze.h"
28
/* Prepare this loop for unrolling by first converting to lcssa and then
 * converting the phis from the loop's first block and the block that follows
 * the loop into regs.  Partially converting out of SSA allows us to unroll
 * the loop without having to keep track of and update phis along the way
 * which gets tricky and doesn't add much value over converting to regs.
 *
 * The loop may have a continue instruction at the end of the loop which does
 * nothing.  Once we're out of SSA, we can safely delete it so we don't have
 * to deal with it later.
 */
39static void
40loop_prepare_for_unroll(nir_loop *loop)
41{
42   nir_convert_loop_to_lcssa(loop);
43
44   nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
45
46   nir_block *block_after_loop =
47      nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
48
49   nir_lower_phis_to_regs_block(block_after_loop);
50
51   nir_instr *last_instr = nir_block_last_instr(nir_loop_last_block(loop));
52   if (last_instr && last_instr->type == nir_instr_type_jump) {
53      assert(nir_instr_as_jump(last_instr)->type == nir_jump_continue);
54      nir_instr_remove(last_instr);
55   }
56}
57
58static void
59get_first_blocks_in_terminator(nir_loop_terminator *term,
60                               nir_block **first_break_block,
61                               nir_block **first_continue_block)
62{
63   if (term->continue_from_then) {
64      *first_continue_block = nir_if_first_then_block(term->nif);
65      *first_break_block = nir_if_first_else_block(term->nif);
66   } else {
67      *first_continue_block = nir_if_first_else_block(term->nif);
68      *first_break_block = nir_if_first_then_block(term->nif);
69   }
70}
71
72/**
73 * Unroll a loop where we know exactly how many iterations there are and there
74 * is only a single exit point.  Note here we can unroll loops with multiple
75 * theoretical exits that only have a single terminating exit that we always
76 * know is the "real" exit.
77 *
78 *     loop {
79 *         ...instrs...
80 *     }
81 *
82 * And the iteration count is 3, the output will be:
83 *
84 *     ...instrs... ...instrs... ...instrs...
85 */
86static void
87simple_unroll(nir_loop *loop)
88{
89   nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
90   assert(nir_is_trivial_loop_if(limiting_term->nif,
91                                 limiting_term->break_block));
92
93   loop_prepare_for_unroll(loop);
94
95   /* Skip over loop terminator and get the loop body. */
96   list_for_each_entry(nir_loop_terminator, terminator,
97                       &loop->info->loop_terminator_list,
98                       loop_terminator_link) {
99
100      /* Remove all but the limiting terminator as we know the other exit
101       * conditions can never be met. Note we need to extract any instructions
102       * in the continue from branch and insert then into the loop body before
103       * removing it.
104       */
105      if (terminator->nif != limiting_term->nif) {
106         nir_block *first_break_block;
107         nir_block *first_continue_block;
108         get_first_blocks_in_terminator(terminator, &first_break_block,
109                                        &first_continue_block);
110
111         assert(nir_is_trivial_loop_if(terminator->nif,
112                                       terminator->break_block));
113
114         nir_cf_list continue_from_lst;
115         nir_cf_extract(&continue_from_lst,
116                        nir_before_block(first_continue_block),
117                        nir_after_block(terminator->continue_from_block));
118         nir_cf_reinsert(&continue_from_lst,
119                         nir_after_cf_node(&terminator->nif->cf_node));
120
121         nir_cf_node_remove(&terminator->nif->cf_node);
122      }
123   }
124
125   nir_block *first_break_block;
126   nir_block *first_continue_block;
127   get_first_blocks_in_terminator(limiting_term, &first_break_block,
128                                  &first_continue_block);
129
130   /* Pluck out the loop header */
131   nir_block *header_blk = nir_loop_first_block(loop);
132   nir_cf_list lp_header;
133   nir_cf_extract(&lp_header, nir_before_block(header_blk),
134                  nir_before_cf_node(&limiting_term->nif->cf_node));
135
136   /* Add the continue from block of the limiting terminator to the loop body
137    */
138   nir_cf_list continue_from_lst;
139   nir_cf_extract(&continue_from_lst, nir_before_block(first_continue_block),
140                  nir_after_block(limiting_term->continue_from_block));
141   nir_cf_reinsert(&continue_from_lst,
142                   nir_after_cf_node(&limiting_term->nif->cf_node));
143
144   /* Pluck out the loop body */
145   nir_cf_list loop_body;
146   nir_cf_extract(&loop_body, nir_after_cf_node(&limiting_term->nif->cf_node),
147                  nir_after_block(nir_loop_last_block(loop)));
148
149   struct hash_table *remap_table =
150      _mesa_hash_table_create(NULL, _mesa_hash_pointer,
151                              _mesa_key_pointer_equal);
152
153   /* Clone the loop header */
154   nir_cf_list cloned_header;
155   nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
156                     remap_table);
157
158   /* Insert cloned loop header before the loop */
159   nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
160
161   /* Temp list to store the cloned loop body as we unroll */
162   nir_cf_list unrolled_lp_body;
163
164   /* Clone loop header and append to the loop body */
165   for (unsigned i = 0; i < loop->info->trip_count; i++) {
166      /* Clone loop body */
167      nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
168                        remap_table);
169
170      /* Insert unrolled loop body before the loop */
171      nir_cf_reinsert(&unrolled_lp_body, nir_before_cf_node(&loop->cf_node));
172
173      /* Clone loop header */
174      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
175                        remap_table);
176
177      /* Insert loop header after loop body */
178      nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
179   }
180
181   /* Remove the break from the loop terminator and add instructions from
182    * the break block after the unrolled loop.
183    */
184   nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
185   nir_instr_remove(break_instr);
186   nir_cf_list break_list;
187   nir_cf_extract(&break_list, nir_before_block(first_break_block),
188                  nir_after_block(limiting_term->break_block));
189
190   /* Clone so things get properly remapped */
191   nir_cf_list cloned_break_list;
192   nir_cf_list_clone(&cloned_break_list, &break_list, loop->cf_node.parent,
193                     remap_table);
194
195   nir_cf_reinsert(&cloned_break_list, nir_before_cf_node(&loop->cf_node));
196
197   /* Remove the loop */
198   nir_cf_node_remove(&loop->cf_node);
199
200   /* Delete the original loop body, break block & header */
201   nir_cf_delete(&lp_header);
202   nir_cf_delete(&loop_body);
203   nir_cf_delete(&break_list);
204
205   _mesa_hash_table_destroy(remap_table, NULL);
206}
207
208static void
209move_cf_list_into_loop_term(nir_cf_list *lst, nir_loop_terminator *term)
210{
211   /* Move the rest of the loop inside the continue-from-block */
212   nir_cf_reinsert(lst, nir_after_block(term->continue_from_block));
213
214   /* Remove the break */
215   nir_instr_remove(nir_block_last_instr(term->break_block));
216}
217
218static nir_cursor
219get_complex_unroll_insert_location(nir_cf_node *node, bool continue_from_then)
220{
221   if (node->type == nir_cf_node_loop) {
222      return nir_before_cf_node(node);
223   } else {
224      nir_if *if_stmt = nir_cf_node_as_if(node);
225      if (continue_from_then) {
226         return nir_after_block(nir_if_last_then_block(if_stmt));
227      } else {
228         return nir_after_block(nir_if_last_else_block(if_stmt));
229      }
230   }
231}
232
/**
 * Unroll a loop with two exits when the trip count of one of the exits is
 * unknown.  If continue_from_then is true, the loop is repeated only when the
 * "then" branch of the if is taken; otherwise it is repeated only
 * when the "else" branch of the if is taken.
238 *
239 * For example, if the input is:
240 *
241 *      loop {
242 *         ...phis/condition...
243 *         if condition {
244 *            ...then instructions...
245 *         } else {
246 *            ...continue instructions...
247 *            break
248 *         }
249 *         ...body...
250 *      }
251 *
252 * And the iteration count is 3, and unlimit_term->continue_from_then is true,
253 * then the output will be:
254 *
255 *      ...condition...
256 *      if condition {
257 *         ...then instructions...
258 *         ...body...
259 *         if condition {
260 *            ...then instructions...
261 *            ...body...
262 *            if condition {
263 *               ...then instructions...
264 *               ...body...
265 *            } else {
266 *               ...continue instructions...
267 *            }
268 *         } else {
269 *            ...continue instructions...
270 *         }
271 *      } else {
272 *         ...continue instructions...
273 *      }
274 */
275static void
276complex_unroll(nir_loop *loop, nir_loop_terminator *unlimit_term,
277               bool limiting_term_second)
278{
279   assert(nir_is_trivial_loop_if(unlimit_term->nif,
280                                 unlimit_term->break_block));
281
282   nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
283   assert(nir_is_trivial_loop_if(limiting_term->nif,
284                                 limiting_term->break_block));
285
286   loop_prepare_for_unroll(loop);
287
288   nir_block *header_blk = nir_loop_first_block(loop);
289
290   nir_cf_list lp_header;
291   nir_cf_list limit_break_list;
292   unsigned num_times_to_clone;
293   if (limiting_term_second) {
294      /* Pluck out the loop header */
295      nir_cf_extract(&lp_header, nir_before_block(header_blk),
296                     nir_before_cf_node(&unlimit_term->nif->cf_node));
297
298      /* We need some special handling when its the second terminator causing
299       * us to exit the loop for example:
300       *
301       *   for (int i = 0; i < uniform_lp_count; i++) {
302       *      colour = vec4(0.0, 1.0, 0.0, 1.0);
303       *
304       *      if (i == 1) {
305       *         break;
306       *      }
307       *      ... any further code is unreachable after i == 1 ...
308       *   }
309       */
310      nir_cf_list after_lt;
311      nir_if *limit_if = limiting_term->nif;
312      nir_cf_extract(&after_lt, nir_after_cf_node(&limit_if->cf_node),
313                     nir_after_block(nir_loop_last_block(loop)));
314      move_cf_list_into_loop_term(&after_lt, limiting_term);
315
316      /* Because the trip count is the number of times we pass over the entire
317       * loop before hitting a break when the second terminator is the
318       * limiting terminator we can actually execute code inside the loop when
319       * trip count == 0 e.g. the code above the break.  So we need to bump
320       * the trip_count in order for the code below to clone anything.  When
321       * trip count == 1 we execute the code above the break twice and the
322       * code below it once so we need clone things twice and so on.
323       */
324      num_times_to_clone = loop->info->trip_count + 1;
325   } else {
326      /* Pluck out the loop header */
327      nir_cf_extract(&lp_header, nir_before_block(header_blk),
328                     nir_before_cf_node(&limiting_term->nif->cf_node));
329
330      nir_block *first_break_block;
331      nir_block *first_continue_block;
332      get_first_blocks_in_terminator(limiting_term, &first_break_block,
333                                     &first_continue_block);
334
335      /* Remove the break then extract instructions from the break block so we
336       * can insert them in the innermost else of the unrolled loop.
337       */
338      nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
339      nir_instr_remove(break_instr);
340      nir_cf_extract(&limit_break_list, nir_before_block(first_break_block),
341                     nir_after_block(limiting_term->break_block));
342
343      nir_cf_list continue_list;
344      nir_cf_extract(&continue_list, nir_before_block(first_continue_block),
345                     nir_after_block(limiting_term->continue_from_block));
346
347      nir_cf_reinsert(&continue_list,
348                      nir_after_cf_node(&limiting_term->nif->cf_node));
349
350      nir_cf_node_remove(&limiting_term->nif->cf_node);
351
352      num_times_to_clone = loop->info->trip_count;
353   }
354
355   /* In the terminator that we have no trip count for move everything after
356    * the terminator into the continue from branch.
357    */
358   nir_cf_list loop_end;
359   nir_cf_extract(&loop_end, nir_after_cf_node(&unlimit_term->nif->cf_node),
360                  nir_after_block(nir_loop_last_block(loop)));
361   move_cf_list_into_loop_term(&loop_end, unlimit_term);
362
363   /* Pluck out the loop body. */
364   nir_cf_list loop_body;
365   nir_cf_extract(&loop_body, nir_before_block(nir_loop_first_block(loop)),
366                  nir_after_block(nir_loop_last_block(loop)));
367
368   struct hash_table *remap_table =
369      _mesa_hash_table_create(NULL, _mesa_hash_pointer,
370                              _mesa_key_pointer_equal);
371
372   /* Set unroll_loc to the loop as we will insert the unrolled loop before it
373    */
374   nir_cf_node *unroll_loc = &loop->cf_node;
375
376   /* Temp lists to store the cloned loop as we unroll */
377   nir_cf_list unrolled_lp_body;
378   nir_cf_list cloned_header;
379
380   for (unsigned i = 0; i < num_times_to_clone; i++) {
381      /* Clone loop header */
382      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
383                        remap_table);
384
385      nir_cursor cursor =
386         get_complex_unroll_insert_location(unroll_loc,
387                                            unlimit_term->continue_from_then);
388
389      /* Insert cloned loop header */
390      nir_cf_reinsert(&cloned_header, cursor);
391
392      cursor =
393         get_complex_unroll_insert_location(unroll_loc,
394                                            unlimit_term->continue_from_then);
395
396      /* Clone loop body */
397      nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
398                        remap_table);
399
400      unroll_loc = exec_node_data(nir_cf_node,
401                                  exec_list_get_tail(&unrolled_lp_body.list),
402                                  node);
403      assert(unroll_loc->type == nir_cf_node_block &&
404             exec_list_is_empty(&nir_cf_node_as_block(unroll_loc)->instr_list));
405
406      /* Get the unrolled if node */
407      unroll_loc = nir_cf_node_prev(unroll_loc);
408
409      /* Insert unrolled loop body */
410      nir_cf_reinsert(&unrolled_lp_body, cursor);
411   }
412
413   if (!limiting_term_second) {
414      assert(unroll_loc->type == nir_cf_node_if);
415
416      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
417                        remap_table);
418
419      nir_cursor cursor =
420         get_complex_unroll_insert_location(unroll_loc,
421                                            unlimit_term->continue_from_then);
422
423      /* Insert cloned loop header */
424      nir_cf_reinsert(&cloned_header, cursor);
425
426      /* Clone so things get properly remapped, and insert break block from
427       * the limiting terminator.
428       */
429      nir_cf_list cloned_break_blk;
430      nir_cf_list_clone(&cloned_break_blk, &limit_break_list,
431                        loop->cf_node.parent, remap_table);
432
433      cursor =
434         get_complex_unroll_insert_location(unroll_loc,
435                                            unlimit_term->continue_from_then);
436
437      nir_cf_reinsert(&cloned_break_blk, cursor);
438      nir_cf_delete(&limit_break_list);
439   }
440
441   /* The loop has been unrolled so remove it. */
442   nir_cf_node_remove(&loop->cf_node);
443
444   /* Delete the original loop header and body */
445   nir_cf_delete(&lp_header);
446   nir_cf_delete(&loop_body);
447
448   _mesa_hash_table_destroy(remap_table, NULL);
449}
450
451static bool
452is_loop_small_enough_to_unroll(nir_shader *shader, nir_loop_info *li)
453{
454   unsigned max_iter = shader->options->max_unroll_iterations;
455
456   if (li->trip_count > max_iter)
457      return false;
458
459   if (li->force_unroll)
460      return true;
461
462   bool loop_not_too_large =
463      li->num_instructions * li->trip_count <= max_iter * 25;
464
465   return loop_not_too_large;
466}
467
468static bool
469process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *innermost_loop)
470{
471   bool progress = false;
472   nir_loop *loop;
473
474   switch (cf_node->type) {
475   case nir_cf_node_block:
476      return progress;
477   case nir_cf_node_if: {
478      nir_if *if_stmt = nir_cf_node_as_if(cf_node);
479      foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->then_list)
480         progress |= process_loops(sh, nested_node, innermost_loop);
481      foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->else_list)
482         progress |= process_loops(sh, nested_node, innermost_loop);
483      return progress;
484   }
485   case nir_cf_node_loop: {
486      loop = nir_cf_node_as_loop(cf_node);
487      foreach_list_typed_safe(nir_cf_node, nested_node, node, &loop->body)
488         progress |= process_loops(sh, nested_node, innermost_loop);
489      break;
490   }
491   default:
492      unreachable("unknown cf node type");
493   }
494
495   if (*innermost_loop) {
496      /* Don't attempt to unroll outer loops or a second inner loop in
497       * this pass wait until the next pass as we have altered the cf.
498       */
499      *innermost_loop = false;
500
501      if (loop->info->limiting_terminator == NULL)
502         return progress;
503
504      if (!is_loop_small_enough_to_unroll(sh, loop->info))
505         return progress;
506
507      if (loop->info->is_trip_count_known) {
508         simple_unroll(loop);
509         progress = true;
510      } else {
511         /* Attempt to unroll loops with two terminators. */
512         unsigned num_lt = list_length(&loop->info->loop_terminator_list);
513         if (num_lt == 2) {
514            bool limiting_term_second = true;
515            nir_loop_terminator *terminator =
516               list_last_entry(&loop->info->loop_terminator_list,
517                                nir_loop_terminator, loop_terminator_link);
518
519
520            if (terminator->nif == loop->info->limiting_terminator->nif) {
521               limiting_term_second = false;
522               terminator =
523                  list_first_entry(&loop->info->loop_terminator_list,
524                                  nir_loop_terminator, loop_terminator_link);
525            }
526
527            /* If the first terminator has a trip count of zero and is the
528             * limiting terminator just do a simple unroll as the second
529             * terminator can never be reached.
530             */
531            if (loop->info->trip_count == 0 && !limiting_term_second) {
532               simple_unroll(loop);
533            } else {
534               complex_unroll(loop, terminator, limiting_term_second);
535            }
536            progress = true;
537         }
538      }
539   }
540
541   return progress;
542}
543
544static bool
545nir_opt_loop_unroll_impl(nir_function_impl *impl,
546                         nir_variable_mode indirect_mask)
547{
548   bool progress = false;
549   nir_metadata_require(impl, nir_metadata_loop_analysis, indirect_mask);
550   nir_metadata_require(impl, nir_metadata_block_index);
551
552   foreach_list_typed_safe(nir_cf_node, node, node, &impl->body) {
553      bool innermost_loop = true;
554      progress |= process_loops(impl->function->shader, node,
555                                &innermost_loop);
556   }
557
558   if (progress)
559      nir_lower_regs_to_ssa_impl(impl);
560
561   return progress;
562}
563
564bool
565nir_opt_loop_unroll(nir_shader *shader, nir_variable_mode indirect_mask)
566{
567   bool progress = false;
568
569   nir_foreach_function(function, shader) {
570      if (function->impl) {
571         progress |= nir_opt_loop_unroll_impl(function->impl, indirect_mask);
572      }
573   }
574   return progress;
575}
576