Module: Mesa
Branch: main
Commit: a74e98547c070568191f1800c821c6fd5257116a
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=a74e98547c070568191f1800c821c6fd5257116a

Author: Dave Airlie <[email protected]>
Date:   Mon Sep  4 11:26:47 2023 +1000

nir: don't inline linked functions

Don't inline linked functions here; let nir_inline_functions do the job
when we get to it.

Instead, just copy over the implementation and any other pieces needed.

Reviewed-by: Alyssa Rosenzweig <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24687>

---

 src/compiler/nir/nir_functions.c | 86 +++++++++++++++++++++++++++++++++-------
 1 file changed, 72 insertions(+), 14 deletions(-)

diff --git a/src/compiler/nir/nir_functions.c b/src/compiler/nir/nir_functions.c
index 5e7b5eadee4..2cdfcf0c1ce 100644
--- a/src/compiler/nir/nir_functions.c
+++ b/src/compiler/nir/nir_functions.c
@@ -264,6 +264,72 @@ struct lower_link_state {
    const nir_shader *link_shader;
 };
 
+/* Fix up one instruction of an impl cloned from the link shader so that it
+ * refers to objects owned by b->shader instead of the link shader.
+ *
+ *  - Variable derefs (deref_type == var) whose variable is not
+ *    function_temp are redirected through state->shader_var_remap. The
+ *    first time a variable is seen, it is cloned into b->shader, added to
+ *    it, and recorded in the map. function_temp variables are skipped
+ *    (NOTE(review): presumably because nir_function_impl_clone already
+ *    cloned the impl's locals - confirm).
+ *  - Calls with a named callee are retargeted to the same-named function in
+ *    b->shader if one exists. Otherwise, if the link shader defines such a
+ *    function, it is cloned into b->shader with nir_function_clone and used.
+ *    This function does not attach an impl to that clone (NOTE(review):
+ *    presumably a later round of function_link_pass does - confirm).
+ *
+ * Returns false for non-var derefs, function_temp derefs and unnamed
+ * callees. Returns true for every other instruction, including ones it
+ * leaves untouched (e.g. non-deref/non-call instructions, or calls whose
+ * name exists in neither shader), so the reported progress is conservative.
+ */
+static bool
+lower_calls_vars_instr(struct nir_builder *b,
+                       nir_instr *instr,
+                       void *cb_data)
+{
+   struct lower_link_state *state = cb_data;
+
+   switch (instr->type) {
+   case nir_instr_type_deref: {
+      nir_deref_instr *deref = nir_instr_as_deref(instr);
+      if (deref->deref_type != nir_deref_type_var)
+         return false;
+      if (deref->var->data.mode == nir_var_function_temp)
+         return false;
+
+      /* One clone per source variable: later derefs of the same variable
+       * find the existing entry and map to the same clone. */
+      assert(state->shader_var_remap);
+      struct hash_entry *entry =
+         _mesa_hash_table_search(state->shader_var_remap, deref->var);
+      if (entry == NULL) {
+         nir_variable *nvar = nir_variable_clone(deref->var, b->shader);
+         nir_shader_add_variable(b->shader, nvar);
+         entry = _mesa_hash_table_insert(state->shader_var_remap,
+                                         deref->var, nvar);
+      }
+      deref->var = entry->data;
+      break;
+   }
+   case nir_instr_type_call: {
+      nir_call_instr *ncall = nir_instr_as_call(instr);
+      if (!ncall->callee->name)
+         return false;
+
+      /* Prefer a function that already exists in the destination shader. */
+      nir_function *func = nir_shader_get_function_for_name(b->shader, ncall->callee->name);
+      if (func) {
+         ncall->callee = func;
+         break;
+      }
+
+      /* Otherwise import the declaration from the link shader, if present. */
+      nir_function *new_func;
+      new_func = nir_shader_get_function_for_name(state->link_shader, ncall->callee->name);
+      if (new_func)
+         ncall->callee = nir_function_clone(b->shader, new_func);
+      break;
+   }
+   default:
+      /* Not rewritten, but still counted as progress below. */
+      break;
+   }
+   return true;
+}
+
+/* Give `callee` (a function in b->shader with no impl) a copy of `impl`,
+ * which belongs to the link shader.
+ *
+ * The impl is cloned into b->shader, linked to `callee` in both directions,
+ * and then lower_calls_vars_instr runs over the copy. That pass repoints its
+ * non-function_temp variable derefs and named calls at objects in b->shader.
+ *
+ * Always returns true, because installing callee->impl changes the shader.
+ * The fix-up pass's own progress flag is not a reliable signal here.
+ * NOTE(review): presumably nir_function_instructions_pass reports progress
+ * only when its callback does, so a clone with no instructions (an empty
+ * body) would otherwise report "no progress" even though an impl was
+ * attached - confirm against nir_builder.h.
+ */
+static bool
+lower_call_function_impl(struct nir_builder *b,
+                         nir_function *callee,
+                         const nir_function_impl *impl,
+                         struct lower_link_state *state)
+{
+   nir_function_impl *copy = nir_function_impl_clone(b->shader, impl);
+   copy->function = callee;
+   callee->impl = copy;
+
+   /* Rewrite references inside the copy. Its progress result is ignored:
+    * the shader has already been modified by the assignment above. */
+   nir_function_instructions_pass(copy,
+                                  lower_calls_vars_instr,
+                                  nir_metadata_none,
+                                  state);
+   return true;
+}
+
 static bool
 function_link_pass(struct nir_builder *b,
                    nir_instr *instr,
@@ -280,24 +346,16 @@ function_link_pass(struct nir_builder *b,
    if (!call->callee->name)
       return false;
 
+   if (call->callee->impl)
+      return false;
+
    func = nir_shader_get_function_for_name(state->link_shader, 
call->callee->name);
    if (!func || !func->impl) {
       return false;
    }
-
-   nir_def **params = rzalloc_array(b->shader, nir_def*, call->num_params);
-
-   for (unsigned i = 0; i < call->num_params; i++) {
-      params[i] = nir_ssa_for_src(b, call->params[i],
-                                  call->callee->params[i].num_components);
-   }
-
-   b->cursor = nir_instr_remove(&call->instr);
-   nir_inline_function_impl(b, func->impl, params, state->shader_var_remap);
-
-   ralloc_free(params);
-
-   return true;
+   return lower_call_function_impl(b, call->callee,
+                                   func->impl,
+                                   state);
 }
 
 bool

Reply via email to