Module: Mesa
Branch: main
Commit: d4b97bf3fa587d9636f2d78b54e998203dc1d680
URL:    
http://cgit.freedesktop.org/mesa/mesa/commit/?id=d4b97bf3fa587d9636f2d78b54e998203dc1d680

Author: Daniel Schürmann <[email protected]>
Date:   Wed Dec  1 17:34:48 2021 +0100

nir: add Continue Construct to nir_loop

The added continue_list corresponds to the SPIR-V
Continue Construct. It serves as a converged control-flow
construct that is executed after each continue statement
and before the next iteration of the loop body.

Also adds validation rules for loops with a Continue Construct.

Reviewed-by: Faith Ekstrand <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13962>

---

 src/compiler/nir/nir.c              | 28 ++++++++++++++++++++++------
 src/compiler/nir/nir.h              | 35 +++++++++++++++++++++++++++++++++++
 src/compiler/nir/nir_clone.c        |  4 ++++
 src/compiler/nir/nir_control_flow.c | 20 +++++++++++++++-----
 src/compiler/nir/nir_print.c        |  9 +++++++++
 src/compiler/nir/nir_serialize.c    | 12 ++++++++++++
 src/compiler/nir/nir_validate.c     | 31 ++++++++++++++++++++++++++-----
 7 files changed, 123 insertions(+), 16 deletions(-)

diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c
index 6fa1dfb5133..3199af67a4f 100644
--- a/src/compiler/nir/nir.c
+++ b/src/compiler/nir/nir.c
@@ -647,6 +647,8 @@ nir_loop_create(nir_shader *shader)
    body->successors[0] = body;
    _mesa_set_add(body->predecessors, body);
 
+   exec_list_make_empty(&loop->continue_list);
+
    return loop;
 }
 
@@ -1924,11 +1926,17 @@ nir_block_cf_tree_next(nir_block *block)
          return nir_if_first_else_block(if_stmt);
 
       assert(block == nir_if_last_else_block(if_stmt));
+      return nir_cf_node_as_block(nir_cf_node_next(parent));
    }
-   FALLTHROUGH;
 
-   case nir_cf_node_loop:
+   case nir_cf_node_loop: {
+      nir_loop *loop = nir_cf_node_as_loop(parent);
+      if (block == nir_loop_last_block(loop) &&
+          nir_loop_has_continue_construct(loop))
+         return nir_loop_first_continue_block(loop);
+
       return nir_cf_node_as_block(nir_cf_node_next(parent));
+   }
 
    case nir_cf_node_function:
       return NULL;
@@ -1962,12 +1970,17 @@ nir_block_cf_tree_prev(nir_block *block)
          return nir_if_last_then_block(if_stmt);
 
       assert(block == nir_if_first_then_block(if_stmt));
+      return nir_cf_node_as_block(nir_cf_node_prev(parent));
    }
-   FALLTHROUGH;
+   case nir_cf_node_loop: {
+      nir_loop *loop = nir_cf_node_as_loop(parent);
+      if (nir_loop_has_continue_construct(loop) &&
+          block == nir_loop_first_continue_block(loop))
+         return nir_loop_last_block(loop);
 
-   case nir_cf_node_loop:
+      assert(block == nir_loop_first_block(loop));
       return nir_cf_node_as_block(nir_cf_node_prev(parent));
-
+   }
    case nir_cf_node_function:
       return NULL;
 
@@ -2018,7 +2031,10 @@ nir_block *nir_cf_node_cf_tree_last(nir_cf_node *node)
 
    case nir_cf_node_loop: {
       nir_loop *loop = nir_cf_node_as_loop(node);
-      return nir_loop_last_block(loop);
+      if (nir_loop_has_continue_construct(loop))
+         return nir_loop_last_continue_block(loop);
+      else
+         return nir_loop_last_block(loop);
    }
 
    case nir_cf_node_block: {
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 344d604352c..18ede24f8c2 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -2974,6 +2974,7 @@ typedef struct {
    nir_cf_node cf_node;
 
    struct exec_list body; /** < list of nir_cf_node */
+   struct exec_list continue_list; /** < (optional) list of nir_cf_node */
 
    nir_loop_info *info;
    nir_loop_control control;
@@ -3209,6 +3210,40 @@ nir_loop_last_block(nir_loop *loop)
    return nir_cf_node_as_block(exec_node_data(nir_cf_node, tail, node));
 }
 
+static inline bool
+nir_loop_has_continue_construct(const nir_loop *loop)
+{
+   return !exec_list_is_empty(&loop->continue_list);
+}
+
+static inline nir_block *
+nir_loop_first_continue_block(nir_loop *loop)
+{
+   assert(nir_loop_has_continue_construct(loop));
+   struct exec_node *head = exec_list_get_head(&loop->continue_list);
+   return nir_cf_node_as_block(exec_node_data(nir_cf_node, head, node));
+}
+
+static inline nir_block *
+nir_loop_last_continue_block(nir_loop *loop)
+{
+   assert(nir_loop_has_continue_construct(loop));
+   struct exec_node *tail = exec_list_get_tail(&loop->continue_list);
+   return nir_cf_node_as_block(exec_node_data(nir_cf_node, tail, node));
+}
+
+/**
+ * Return the target block of a nir_jump_continue statement
+ */
+static inline nir_block *
+nir_loop_continue_target(nir_loop *loop)
+{
+   if (nir_loop_has_continue_construct(loop))
+      return nir_loop_first_continue_block(loop);
+   else
+      return nir_loop_first_block(loop);
+}
+
 /**
  * Return true if this list of cf_nodes contains a single empty block.
  */
diff --git a/src/compiler/nir/nir_clone.c b/src/compiler/nir/nir_clone.c
index 700f1b4ddb7..52682a1a185 100644
--- a/src/compiler/nir/nir_clone.c
+++ b/src/compiler/nir/nir_clone.c
@@ -598,6 +598,10 @@ clone_loop(clone_state *state, struct exec_list *cf_list, 
const nir_loop *loop)
    nir_cf_node_insert_end(cf_list, &nloop->cf_node);
 
    clone_cf_list(state, &nloop->body, &loop->body);
+   if (nir_loop_has_continue_construct(loop)) {
+      nir_loop_add_continue_construct(nloop);
+      clone_cf_list(state, &nloop->continue_list, &loop->continue_list);
+   }
 
    return nloop;
 }
diff --git a/src/compiler/nir/nir_control_flow.c 
b/src/compiler/nir/nir_control_flow.c
index 4973c60de5c..43d02f2e176 100644
--- a/src/compiler/nir/nir_control_flow.c
+++ b/src/compiler/nir/nir_control_flow.c
@@ -288,10 +288,16 @@ block_add_normal_succs(nir_block *block)
       } else if (parent->type == nir_cf_node_loop) {
          nir_loop *loop = nir_cf_node_as_loop(parent);
 
-         nir_block *head_block = nir_loop_first_block(loop);
+         nir_block *cont_block;
+         if (block == nir_loop_last_block(loop)) {
+            cont_block = nir_loop_continue_target(loop);
+         } else {
+            assert(block == nir_loop_last_continue_block(loop));
+            cont_block = nir_loop_first_block(loop);
+         }
 
-         link_blocks(block, head_block, NULL);
-         nir_insert_phi_undef(head_block, block);
+         link_blocks(block, cont_block, NULL);
+         nir_insert_phi_undef(cont_block, block);
       } else {
          nir_function_impl *impl = nir_cf_node_as_function(parent);
          link_blocks(block, impl->end_block, NULL);
@@ -482,8 +488,8 @@ nir_handle_add_jump(nir_block *block)
 
    case nir_jump_continue: {
       nir_loop *loop = nearest_loop(&block->cf_node);
-      nir_block *first_block = nir_loop_first_block(loop);
-      link_blocks(block, first_block, NULL);
+      nir_block *cont_block = nir_loop_continue_target(loop);
+      link_blocks(block, cont_block, NULL);
       break;
    }
 
@@ -665,6 +671,8 @@ cleanup_cf_node(nir_cf_node *node, nir_function_impl *impl)
       nir_loop *loop = nir_cf_node_as_loop(node);
       foreach_list_typed(nir_cf_node, child, node, &loop->body)
          cleanup_cf_node(child, impl);
+      foreach_list_typed(nir_cf_node, child, node, &loop->continue_list)
+         cleanup_cf_node(child, impl);
       break;
    }
    case nir_cf_node_function: {
@@ -780,6 +788,8 @@ relink_jump_halt_cf_node(nir_cf_node *node, nir_block 
*end_block)
       nir_loop *loop = nir_cf_node_as_loop(node);
       foreach_list_typed(nir_cf_node, child, node, &loop->body)
          relink_jump_halt_cf_node(child, end_block);
+      foreach_list_typed(nir_cf_node, child, node, &loop->continue_list)
+         relink_jump_halt_cf_node(child, end_block);
       break;
    }
 
diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c
index 25868c8f779..f63401be27d 100644
--- a/src/compiler/nir/nir_print.c
+++ b/src/compiler/nir/nir_print.c
@@ -1658,6 +1658,15 @@ print_loop(nir_loop *loop, print_state *state, unsigned 
tabs)
       print_cf_node(node, state, tabs + 1);
    }
    print_tabs(tabs, fp);
+
+   if (nir_loop_has_continue_construct(loop)) {
+      fprintf(fp, "} continue {\n");
+      foreach_list_typed(nir_cf_node, node, node, &loop->continue_list) {
+         print_cf_node(node, state, tabs + 1);
+      }
+      print_tabs(tabs, fp);
+   }
+
    fprintf(fp, "}\n");
 }
 
diff --git a/src/compiler/nir/nir_serialize.c b/src/compiler/nir/nir_serialize.c
index 5ee571a6898..ae490e07348 100644
--- a/src/compiler/nir/nir_serialize.c
+++ b/src/compiler/nir/nir_serialize.c
@@ -1892,7 +1892,13 @@ write_loop(write_ctx *ctx, nir_loop *loop)
 {
    blob_write_uint8(ctx->blob, loop->control);
    blob_write_uint8(ctx->blob, loop->divergent);
+   bool has_continue_construct = nir_loop_has_continue_construct(loop);
+   blob_write_uint8(ctx->blob, has_continue_construct);
+
    write_cf_list(ctx, &loop->body);
+   if (has_continue_construct) {
+      write_cf_list(ctx, &loop->continue_list);
+   }
 }
 
 static void
@@ -1904,7 +1910,13 @@ read_loop(read_ctx *ctx, struct exec_list *cf_list)
 
    loop->control = blob_read_uint8(ctx->blob);
    loop->divergent = blob_read_uint8(ctx->blob);
+   bool has_continue_construct = blob_read_uint8(ctx->blob);
+
    read_cf_list(ctx, &loop->body);
+   if (has_continue_construct) {
+      nir_loop_add_continue_construct(loop);
+      read_cf_list(ctx, &loop->continue_list);
+   }
 }
 
 static void
diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c
index 5e778109a73..82241c9f313 100644
--- a/src/compiler/nir/nir_validate.c
+++ b/src/compiler/nir/nir_validate.c
@@ -77,6 +77,9 @@ typedef struct {
    /* the current loop being visited */
    nir_loop *loop;
 
+   /* whether the loop continue construct is being visited */
+   bool in_loop_continue_construct;
+
    /* the parent of the current cf node being visited */
    nir_cf_node *parent_node;
 
@@ -1073,6 +1076,7 @@ validate_jump_instr(nir_jump_instr *instr, validate_state 
*state)
       validate_assert(state, block->successors[1] == NULL);
       validate_assert(state, instr->target == NULL);
       validate_assert(state, instr->else_target == NULL);
+      validate_assert(state, !state->in_loop_continue_construct);
       break;
 
    case nir_jump_break:
@@ -1092,12 +1096,13 @@ validate_jump_instr(nir_jump_instr *instr, 
validate_state *state)
       validate_assert(state, state->impl->structured);
       validate_assert(state, state->loop != NULL);
       if (state->loop) {
-         nir_block *first = nir_loop_first_block(state->loop);
-         validate_assert(state, block->successors[0] == first);
+         nir_block *cont_block = nir_loop_continue_target(state->loop);
+         validate_assert(state, block->successors[0] == cont_block);
       }
       validate_assert(state, block->successors[1] == NULL);
       validate_assert(state, instr->target == NULL);
       validate_assert(state, instr->else_target == NULL);
+      validate_assert(state, !state->in_loop_continue_construct);
       break;
 
    case nir_jump_goto:
@@ -1242,6 +1247,7 @@ collect_blocks(struct exec_list *cf_list, validate_state 
*state)
 
       case nir_cf_node_loop:
          collect_blocks(&nir_cf_node_as_loop(node)->body, state);
+         collect_blocks(&nir_cf_node_as_loop(node)->continue_list, state);
          break;
 
       default:
@@ -1310,8 +1316,15 @@ validate_block(nir_block *block, validate_state *state)
       if (next == NULL) {
          switch (state->parent_node->type) {
          case nir_cf_node_loop: {
-            nir_block *first = nir_loop_first_block(state->loop);
-            validate_assert(state, block->successors[0] == first);
+            if (block == nir_loop_last_block(state->loop)) {
+               nir_block *cont = nir_loop_continue_target(state->loop);
+               validate_assert(state, block->successors[0] == cont);
+            } else {
+               validate_assert(state, 
nir_loop_has_continue_construct(state->loop) &&
+                                      block == 
nir_loop_last_continue_block(state->loop));
+               nir_block *head = nir_loop_first_block(state->loop);
+               validate_assert(state, block->successors[0] == head);
+            }
             /* due to the hack for infinite loops, block->successors[1] may
              * point to the block after the loop.
              */
@@ -1421,14 +1434,21 @@ validate_loop(nir_loop *loop, validate_state *state)
    nir_cf_node *old_parent = state->parent_node;
    state->parent_node = &loop->cf_node;
    nir_loop *old_loop = state->loop;
+   bool old_continue_construct = state->in_loop_continue_construct;
    state->loop = loop;
+   state->in_loop_continue_construct = false;
 
    foreach_list_typed(nir_cf_node, cf_node, node, &loop->body) {
       validate_cf_node(cf_node, state);
    }
-
+   state->in_loop_continue_construct = true;
+   foreach_list_typed(nir_cf_node, cf_node, node, &loop->continue_list) {
+      validate_cf_node(cf_node, state);
+   }
+   state->in_loop_continue_construct = false;
    state->parent_node = old_parent;
    state->loop = old_loop;
+   state->in_loop_continue_construct = old_continue_construct;
 }
 
 static void
@@ -1742,6 +1762,7 @@ init_validate_state(validate_state *state)
    state->errors = _mesa_pointer_hash_table_create(state->mem_ctx);
 
    state->loop = NULL;
+   state->in_loop_continue_construct = false;
    state->instr = NULL;
    state->var = NULL;
 }

Reply via email to