Module: Mesa
Branch: main
Commit: 94eff7ccd86658603155261c2fd59491786e7047
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=94eff7ccd86658603155261c2fd59491786e7047

Author: Pavel Ondračka <[email protected]>
Date:   Tue Jan 31 13:20:53 2023 +0100

nir: shrink phi nodes in nir_opt_shrink_vectors

While this change helps only a few shaders, the main benefit is
that it allows unrolling loops coming from nine+ttn on vec4
backends. D3D9 REP ... ENDREP loops are already unrolled now;
LOOP ... ENDLOOP loops need some nine changes that will come later.

r300 RV530 shader-db:
total instructions in shared programs: 132481 -> 132344 (-0.10%)
instructions in affected programs: 3532 -> 3395 (-3.88%)
helped: 13
HURT: 0

total temps in shared programs: 16961 -> 16957 (-0.02%)
temps in affected programs: 88 -> 84 (-4.55%)
helped: 4
HURT: 0

Reviewed-by: Emma Anholt <[email protected]>
Signed-off-by: Pavel Ondračka <[email protected]>
Partial fix for: https://gitlab.freedesktop.org/mesa/mesa/-/issues/8102
Partial fix for: https://gitlab.freedesktop.org/mesa/mesa/-/issues/7222

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21038>

---

 src/compiler/nir/nir_opt_shrink_vectors.c          | 100 ++++++
 .../nir/tests/opt_shrink_vectors_tests.cpp         | 336 +++++++++++++++++++++
 2 files changed, 436 insertions(+)

diff --git a/src/compiler/nir/nir_opt_shrink_vectors.c b/src/compiler/nir/nir_opt_shrink_vectors.c
index 90f44e3f301..80481989184 100644
--- a/src/compiler/nir/nir_opt_shrink_vectors.c
+++ b/src/compiler/nir/nir_opt_shrink_vectors.c
@@ -349,6 +349,103 @@ opt_shrink_vectors_ssa_undef(nir_ssa_undef_instr *instr)
    return shrink_dest_to_read_mask(&instr->def);
 }
 
+/* Returns true if the phi value read by source @src_idx of @alu can only
+ * reach the phi again in the very same channels it was read from, i.e. the
+ * instruction merely carries those channels around the loop.  Channels that
+ * are carried like this and read by nothing else form a dead cycle.
+ */
+static bool
+phi_alu_use_is_passthrough(nir_alu_instr *alu, unsigned src_idx)
+{
+   /* vecN: each source is a single component that lands in output channel
+    * src_idx.  mov is deliberately excluded even though nir_op_is_vec() may
+    * report it: a mov source covers every channel, so checking swizzle[0]
+    * alone would accept permutations such as .xzyw.
+    */
+   if (nir_op_is_vec(alu->op) && alu->op != nir_op_mov)
+      return alu->src[src_idx].swizzle[0] == src_idx;
+
+   /* Ops with a fixed output size (e.g. replicated dot products) may combine
+    * several input channels into every output channel, so no channel can be
+    * considered self-contained.
+    */
+   if (nir_op_infos[alu->op].output_size != 0)
+      return false;
+
+   /* Per-component op (including mov): output channel c is computed from
+    * source channel swizzle[c], so every channel has to map onto itself.
+    */
+   return nir_alu_src_is_trivial_ssa(alu, src_idx);
+}
+
+/* Returns true if @def is observed by anything other than @phi, including
+ * its use as an if condition.
+ */
+static bool
+alu_def_has_non_phi_use(nir_ssa_def *def, nir_phi_instr *phi)
+{
+   /* If-condition uses live in def->if_uses, which nir_foreach_use() does
+    * not visit, but they still read the value.
+    */
+   if (!list_is_empty(&def->if_uses))
+      return true;
+
+   nir_foreach_use(use_src, def) {
+      if (use_src->parent_instr != &phi->instr)
+         return true;
+   }
+
+   return false;
+}
+
+/* Shrink a vector phi down to the channels that are actually observed.
+ *
+ * Only phis whose uses are all ALU instructions are handled.  A channel is
+ * dropped when every ALU reading it feeds nothing but this same phi and
+ * carries the channel straight back into itself (a dead loop-carried cycle).
+ * Each phi source is then routed through a new mov that selects the surviving
+ * channels, and the ALU readers are reswizzled accordingly.
+ *
+ * Returns true if the phi was shrunk.
+ */
+static bool
+opt_shrink_vectors_phi(nir_builder *b, nir_phi_instr *instr)
+{
+   nir_ssa_def *def = &instr->dest.ssa;
+
+   /* early out if there's nothing to do. */
+   if (def->num_components == 1)
+      return false;
+
+   /* Ignore large vectors for now. */
+   if (def->num_components > 4)
+      return false;
+
+   /* Check the uses. */
+   nir_component_mask_t mask = 0;
+   nir_foreach_use(src, def) {
+      if (src->parent_instr->type != nir_instr_type_alu)
+         return false;
+
+      nir_alu_instr *alu = nir_instr_as_alu(src->parent_instr);
+
+      nir_alu_src *alu_src = exec_node_data(nir_alu_src, src, src);
+      unsigned src_idx = alu_src - &alu->src[0];
+      nir_component_mask_t src_read_mask = nir_alu_instr_src_read_mask(alu, src_idx);
+
+      /* The channels read here may only be ignored if the result goes
+       * nowhere but back into this phi (which happens in loops) AND the
+       * instruction maps each of those channels onto itself.
+       */
+      if (alu_def_has_non_phi_use(&alu->dest.dest.ssa, instr) ||
+          !phi_alu_use_is_passthrough(alu, src_idx))
+         mask |= src_read_mask;
+   }
+
+   /* DCE will handle this. */
+   if (mask == 0)
+      return false;
+
+   /* Nothing to shrink? */
+   if (BITFIELD_MASK(def->num_components) == mask)
+      return false;
+
+   /* Set up the reswizzles: reswizzle maps an old channel to its new index,
+    * src_reswizzle maps a new index back to the old channel.
+    */
+   unsigned num_components = 0;
+   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
+   uint8_t src_reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
+   for (unsigned i = 0; i < def->num_components; i++) {
+      if (!((mask >> i) & 0x1))
+         continue;
+      src_reswizzle[num_components] = i;
+      reswizzle[i] = num_components++;
+   }
+
+   /* Shrink the phi, this part is simple. */
+   def->num_components = num_components;
+
+   /* We can't swizzle phi sources directly so just insert an extra mov
+    * with the correct swizzle and let the other parts of
+    * nir_opt_shrink_vectors do their job on the original source instruction.
+    * If the original source was used only in the phi, the movs will
+    * disappear later after copy propagation.
+    */
+   nir_foreach_phi_src(phi_src, instr) {
+      b->cursor = nir_after_instr_and_phis(phi_src->src.ssa->parent_instr);
+
+      nir_alu_src alu_src = {
+         .src = nir_src_for_ssa(phi_src->src.ssa)
+      };
+
+      for (unsigned i = 0; i < num_components; i++)
+         alu_src.swizzle[i] = src_reswizzle[i];
+      nir_ssa_def *mov = nir_mov_alu(b, alu_src, num_components);
+
+      nir_instr_rewrite_src_ssa(&instr->instr, &phi_src->src, mov);
+   }
+   b->cursor = nir_before_instr(&instr->instr);
+
+   /* Reswizzle readers. */
+   reswizzle_alu_uses(def, reswizzle);
+
+   return true;
+}
+
 static bool
 opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr)
 {
@@ -367,6 +464,9 @@ opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr)
    case nir_instr_type_ssa_undef:
       return opt_shrink_vectors_ssa_undef(nir_instr_as_ssa_undef(instr));
 
+   case nir_instr_type_phi:
+      return opt_shrink_vectors_phi(b, nir_instr_as_phi(instr));
+
    default:
       return false;
    }
diff --git a/src/compiler/nir/tests/opt_shrink_vectors_tests.cpp b/src/compiler/nir/tests/opt_shrink_vectors_tests.cpp
index 49329115a20..a9b86d76218 100644
--- a/src/compiler/nir/tests/opt_shrink_vectors_tests.cpp
+++ b/src/compiler/nir/tests/opt_shrink_vectors_tests.cpp
@@ -278,3 +278,339 @@ TEST_F(nir_opt_shrink_vectors_test, opt_shrink_vectors_vec8)
 
    nir_validate_shader(bld.shader, NULL);
 }
+
+TEST_F(nir_opt_shrink_vectors_test, opt_shrink_phis_loop_simple)
+{
+   /* Test that the phi is shrunk in the following case.
+    *
+    *    v = vec4(0.0, 0.0, 0.0, 0.0);
+    *    while (v.y < 3) {
+    *       v.y += 1.0;
+    *    }
+    *
+    * This mimics nir for loops that come out of nine+ttn.
+    *
+    * Only v.y is ever read (by the fge and the fadd); x, z and w merely flow
+    * back into the phi unchanged through the vec4, so the phi is expected to
+    * shrink to a single channel.
+    */
+   nir_ssa_def *v = nir_imm_vec4(&bld, 0.0, 0.0, 0.0, 0.0);
+   nir_ssa_def *increment = nir_imm_float(&bld, 1.0);
+   nir_ssa_def *loop_max = nir_imm_float(&bld, 3.0);
+
+   /* The phi is created up front so the loop body can reference its def; it
+    * is only inserted into the loop header once both sources have been added.
+    */
+   nir_phi_instr *const phi = nir_phi_instr_create(bld.shader);
+   nir_ssa_def *phi_def = &phi->dest.ssa;
+
+   nir_loop *loop = nir_push_loop(&bld);
+
+   nir_ssa_dest_init(&phi->instr, &phi->dest,
+                     v->num_components, v->bit_size,
+                     NULL);
+
+   nir_phi_instr_add_src(phi, v->parent_instr->block,
+                         nir_src_for_ssa(v));
+
+   /* Turn the fge into a scalar comparison of phi.y. */
+   nir_ssa_def *fge = nir_fge(&bld, phi_def, loop_max);
+   nir_alu_instr *fge_alu_instr = nir_instr_as_alu(fge->parent_instr);
+   fge->num_components = 1;
+   fge_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fge_alu_instr->src[0].swizzle[0] = 1;
+
+   nir_if *nif = nir_push_if(&bld, fge);
+   {
+      nir_jump_instr *jump = nir_jump_instr_create(bld.shader, nir_jump_break);
+      nir_builder_instr_insert(&bld, &jump->instr);
+   }
+   nir_pop_if(&bld, nif);
+
+   /* Turn the fadd into a scalar phi.y + 1.0. */
+   nir_ssa_def *fadd = nir_fadd(&bld, phi_def, increment);
+   nir_alu_instr *fadd_alu_instr = nir_instr_as_alu(fadd->parent_instr);
+   fadd->num_components = 1;
+   fadd_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fadd_alu_instr->src[0].swizzle[0] = 1;
+
+   /* Back-edge value: vec4(phi.x, fadd, phi.z, phi.w). */
+   nir_ssa_scalar srcs[4] = {{0}};
+   for (unsigned i = 0; i < 4; i++) {
+      srcs[i] = nir_get_ssa_scalar(phi_def, i);
+   }
+   srcs[1] = nir_get_ssa_scalar(fadd, 0);
+   nir_ssa_def *vec = nir_vec_scalars(&bld, srcs, 4);
+
+   nir_phi_instr_add_src(phi, vec->parent_instr->block,
+                         nir_src_for_ssa(vec));
+
+   nir_pop_loop(&bld, loop);
+
+   bld.cursor = nir_before_block(nir_loop_first_block(loop));
+   nir_builder_instr_insert(&bld, &phi->instr);
+
+   /* Generated nir:
+    *
+    * impl main {
+    *         block block_0:
+    *         * preds: *
+    *         vec1 32 ssa_0 = deref_var &in (shader_in vec2)
+    *         vec2 32 ssa_1 = intrinsic load_deref (ssa_0) (access=0)
+    *         vec4 32 ssa_2 = load_const (0x00000000, 0x00000000, 0x00000000, 0x00000000) = (0.000000, 0.000000, 0.000000, 0.000000)
+    *         vec1 32 ssa_3 = load_const (0x3f800000 = 1.000000)
+    *         vec1 32 ssa_4 = load_const (0x40400000 = 3.000000)
+    *         * succs: block_1 *
+    *         loop {
+    *                 block block_1:
+    *                 * preds: block_0 block_4 *
+    *                 vec4 32 ssa_8 = phi block_0: ssa_2, block_4: ssa_7
+    *                 vec1  1 ssa_5 = fge ssa_8.y, ssa_4
+    *                 * succs: block_2 block_3 *
+    *                 if ssa_5 {
+    *                         block block_2:
+    *                         * preds: block_1 *
+    *                         break
+    *                         * succs: block_5 *
+    *                 } else {
+    *                         block block_3:
+    *                         * preds: block_1 *
+    *                         * succs: block_4 *
+    *                 }
+    *                 block block_4:
+    *                 * preds: block_3 *
+    *                 vec1 32 ssa_6 = fadd ssa_8.y, ssa_3
+    *                 vec4 32 ssa_7 = vec4 ssa_8.x, ssa_6, ssa_8.z, ssa_8.w
+    *                 * succs: block_1 *
+    *         }
+    *         block block_5:
+    *         * preds: block_2 *
+    *         * succs: block_6 *
+    *         block block_6:
+    * }
+    */
+
+   nir_validate_shader(bld.shader, NULL);
+
+   /* The single surviving channel is the old .y, which is now read as .x. */
+   ASSERT_TRUE(nir_opt_shrink_vectors(bld.shader));
+   ASSERT_TRUE(phi_def->num_components == 1);
+   check_swizzle(&fge_alu_instr->src[0], "x");
+   check_swizzle(&fadd_alu_instr->src[0], "x");
+
+   nir_validate_shader(bld.shader, NULL);
+}
+
+TEST_F(nir_opt_shrink_vectors_test, opt_shrink_phis_loop_swizzle)
+{
+   /* Test that the phi is shrunk properly in the following case where
+    * some swizzling happens in the channels.
+    *
+    *    v = vec4(0.0, 0.0, 0.0, 0.0);
+    *    while (v.z < 3) {
+    *       v = vec4(v.x, v.z + 1, v.y, v.w);
+    *    }
+    *
+    * Channels x and w only flow back into themselves and are dropped. z is
+    * read by the comparison and the add, and y feeds the next z, so both are
+    * kept: the phi becomes a vec2 and the old .z is read as .y.
+    */
+   nir_ssa_def *v = nir_imm_vec4(&bld, 0.0, 0.0, 0.0, 0.0);
+   nir_ssa_def *increment = nir_imm_float(&bld, 1.0);
+   nir_ssa_def *loop_max = nir_imm_float(&bld, 3.0);
+
+   nir_phi_instr *const phi = nir_phi_instr_create(bld.shader);
+   nir_ssa_def *phi_def = &phi->dest.ssa;
+
+   nir_loop *loop = nir_push_loop(&bld);
+
+   nir_ssa_dest_init(&phi->instr, &phi->dest,
+                     v->num_components, v->bit_size,
+                     NULL);
+
+   nir_phi_instr_add_src(phi, v->parent_instr->block,
+                         nir_src_for_ssa(v));
+
+   nir_ssa_def *fge = nir_fge(&bld, phi_def, loop_max);
+   nir_alu_instr *fge_alu_instr = nir_instr_as_alu(fge->parent_instr);
+   fge->num_components = 1;
+   fge_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fge_alu_instr->src[0].swizzle[0] = 2;
+
+   nir_if *nif = nir_push_if(&bld, fge);
+   {
+      nir_jump_instr *jump = nir_jump_instr_create(bld.shader, nir_jump_break);
+      nir_builder_instr_insert(&bld, &jump->instr);
+   }
+   nir_pop_if(&bld, nif);
+
+   nir_ssa_def *fadd = nir_fadd(&bld, phi_def, increment);
+   nir_alu_instr *fadd_alu_instr = nir_instr_as_alu(fadd->parent_instr);
+   fadd->num_components = 1;
+   fadd_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fadd_alu_instr->src[0].swizzle[0] = 2;
+
+   nir_ssa_scalar srcs[4] = {{0}};
+   srcs[0] = nir_get_ssa_scalar(phi_def, 0);
+   srcs[1] = nir_get_ssa_scalar(fadd, 0);
+   srcs[2] = nir_get_ssa_scalar(phi_def, 1);
+   srcs[3] = nir_get_ssa_scalar(phi_def, 3);
+   nir_ssa_def *vec = nir_vec_scalars(&bld, srcs, 4);
+
+   nir_phi_instr_add_src(phi, vec->parent_instr->block,
+                         nir_src_for_ssa(vec));
+
+   nir_pop_loop(&bld, loop);
+
+   bld.cursor = nir_before_block(nir_loop_first_block(loop));
+   nir_builder_instr_insert(&bld, &phi->instr);
+
+   /* Generated nir:
+    *
+    * impl main {
+    *         block block_0:
+    *         * preds: *
+    *         vec1 32 ssa_0 = deref_var &in (shader_in vec2)
+    *         vec2 32 ssa_1 = intrinsic load_deref (ssa_0) (access=0)
+    *         vec4 32 ssa_2 = load_const (0x00000000, 0x00000000, 0x00000000, 0x00000000) = (0.000000, 0.000000, 0.000000, 0.000000)
+    *         vec1 32 ssa_3 = load_const (0x3f800000 = 1.000000)
+    *         vec1 32 ssa_4 = load_const (0x40400000 = 3.000000)
+    *         * succs: block_1 *
+    *         loop {
+    *                 block block_1:
+    *                 * preds: block_0 block_4 *
+    *                 vec4 32 ssa_8 = phi block_0: ssa_2, block_4: ssa_7
+    *                 vec1  1 ssa_5 = fge ssa_8.z, ssa_4
+    *                 * succs: block_2 block_3 *
+    *                 if ssa_5 {
+    *                         block block_2:
+    *                         * preds: block_1 *
+    *                         break
+    *                         * succs: block_5 *
+    *                 } else {
+    *                         block block_3:
+    *                         * preds: block_1 *
+    *                         * succs: block_4 *
+    *                 }
+    *                 block block_4:
+    *                 * preds: block_3 *
+    *                 vec1 32 ssa_6 = fadd ssa_8.z, ssa_3
+    *                 vec4 32 ssa_7 = vec4 ssa_8.x, ssa_6, ssa_8.y, ssa_8.w
+    *                 * succs: block_1 *
+    *         }
+    *         block block_5:
+    *         * preds: block_2 *
+    *         * succs: block_6 *
+    *         block block_6:
+    * }
+    */
+
+   nir_validate_shader(bld.shader, NULL);
+
+   ASSERT_TRUE(nir_opt_shrink_vectors(bld.shader));
+   ASSERT_TRUE(phi_def->num_components == 2);
+
+   check_swizzle(&fge_alu_instr->src[0], "y");
+   check_swizzle(&fadd_alu_instr->src[0], "y");
+
+   nir_validate_shader(bld.shader, NULL);
+}
+
+TEST_F(nir_opt_shrink_vectors_test, opt_shrink_phis_loop_phi_out)
+{
+   /* Test that the phi is not shrunk when its full value is used by an
+    * intrinsic (here a store_deref after the loop).
+    *
+    *    v = vec4(0.0, 0.0, 0.0, 0.0);
+    *    while (v.y < 3) {
+    *       v.y += 1.0;
+    *    }
+    *    out = v;
+    */
+   nir_ssa_def *v = nir_imm_vec4(&bld, 0.0, 0.0, 0.0, 0.0);
+   nir_ssa_def *increment = nir_imm_float(&bld, 1.0);
+   nir_ssa_def *loop_max = nir_imm_float(&bld, 3.0);
+
+   nir_phi_instr *const phi = nir_phi_instr_create(bld.shader);
+   nir_ssa_def *phi_def = &phi->dest.ssa;
+
+   nir_loop *loop = nir_push_loop(&bld);
+
+   nir_ssa_dest_init(&phi->instr, &phi->dest,
+                     v->num_components, v->bit_size,
+                     NULL);
+
+   nir_phi_instr_add_src(phi, v->parent_instr->block,
+                         nir_src_for_ssa(v));
+
+   nir_ssa_def *fge = nir_fge(&bld, phi_def, loop_max);
+   nir_alu_instr *fge_alu_instr = nir_instr_as_alu(fge->parent_instr);
+   fge->num_components = 1;
+   fge_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fge_alu_instr->src[0].swizzle[0] = 1;
+
+   nir_if *nif = nir_push_if(&bld, fge);
+   {
+      nir_jump_instr *jump = nir_jump_instr_create(bld.shader, nir_jump_break);
+      nir_builder_instr_insert(&bld, &jump->instr);
+   }
+   nir_pop_if(&bld, nif);
+
+   nir_ssa_def *fadd = nir_fadd(&bld, phi_def, increment);
+   nir_alu_instr *fadd_alu_instr = nir_instr_as_alu(fadd->parent_instr);
+   fadd->num_components = 1;
+   fadd_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fadd_alu_instr->src[0].swizzle[0] = 1;
+
+   nir_ssa_scalar srcs[4] = {{0}};
+   for (unsigned i = 0; i < 4; i++) {
+      srcs[i] = nir_get_ssa_scalar(phi_def, i);
+   }
+   srcs[1] = nir_get_ssa_scalar(fadd, 0);
+   nir_ssa_def *vec = nir_vec_scalars(&bld, srcs, 4);
+
+   nir_phi_instr_add_src(phi, vec->parent_instr->block,
+                         nir_src_for_ssa(vec));
+
+   nir_pop_loop(&bld, loop);
+
+   /* After the loop, store all four channels of the phi to an output. */
+   out_var = nir_variable_create(bld.shader,
+                                 nir_var_shader_out,
+                                 glsl_vec_type(4), "out4");
+
+   nir_store_var(&bld, out_var, phi_def, BITFIELD_MASK(4));
+
+   bld.cursor = nir_before_block(nir_loop_first_block(loop));
+   nir_builder_instr_insert(&bld, &phi->instr);
+
+   /* Generated nir:
+    *
+    * impl main {
+    *         block block_0:
+    *         * preds: *
+    *         vec1 32 ssa_0 = deref_var &in (shader_in vec2)
+    *         vec2 32 ssa_1 = intrinsic load_deref (ssa_0) (access=0)
+    *         vec4 32 ssa_2 = load_const (0x00000000, 0x00000000, 0x00000000, 0x00000000) = (0.000000, 0.000000, 0.000000, 0.000000)
+    *         vec1 32 ssa_3 = load_const (0x3f800000 = 1.000000)
+    *         vec1 32 ssa_4 = load_const (0x40400000 = 3.000000)
+    *         * succs: block_1 *
+    *         loop {
+    *                 block block_1:
+    *                 * preds: block_0 block_4 *
+    *                 vec4 32 ssa_9 = phi block_0: ssa_2, block_4: ssa_7
+    *                 vec1  1 ssa_5 = fge ssa_9.y, ssa_4
+    *                 * succs: block_2 block_3 *
+    *                 if ssa_5 {
+    *                         block block_2:
+    *                         * preds: block_1 *
+    *                         break
+    *                         * succs: block_5 *
+    *                 } else {
+    *                         block block_3:
+    *                         * preds: block_1 *
+    *                         * succs: block_4 *
+    *                 }
+    *                 block block_4:
+    *                 * preds: block_3 *
+    *                 vec1 32 ssa_6 = fadd ssa_9.y, ssa_3
+    *                 vec4 32 ssa_7 = vec4 ssa_9.x, ssa_6, ssa_9.z, ssa_9.w
+    *                 * succs: block_1 *
+    *         }
+    *         block block_5:
+    *         * preds: block_2 *
+    *         vec1 32 ssa_8 = deref_var &out4 (shader_out vec4)
+    *         intrinsic store_deref (ssa_8, ssa_9) (wrmask=xyzw *15*, access=0)
+    *         * succs: block_6 *
+    *         block block_6:
+    * }
+    */
+
+   nir_validate_shader(bld.shader, NULL);
+
+   /* The store is a non-ALU use of the phi, so the pass must leave it alone. */
+   ASSERT_FALSE(nir_opt_shrink_vectors(bld.shader));
+   ASSERT_TRUE(phi_def->num_components == 4);
+}

Reply via email to