Module: Mesa
Branch: main
Commit: bf6c214b2584f074652ff54d68ee018888855acd
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=bf6c214b2584f074652ff54d68ee018888855acd

Author: Giancarlo Devich <[email protected]>
Date:   Fri Feb 24 13:31:59 2023 -0800

d3d12: Create varying structures as necessary, reference them

This changes instances of d3d12_varying_info to d3d12_varying_info*,
significantly reducing the size of the d3d12_shader_key,
d3d12_gs_variant_key, and d3d12_tcs_variant_key.

Associated changes to key fill, compare, hashing, and gs and tcs variant
maps significantly reduce the amount of time spent clearing and
comparing memory.

The biggest win here is not having to re-zero _or_ re-fill varyings in
d3d12_fill_shader_key, validate_geometry_shader_variant, and
validate_tess_ctrl_shader_variant.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21527>

---

 src/gallium/drivers/d3d12/d3d12_compiler.cpp    | 222 +++++++++++++++++-------
 src/gallium/drivers/d3d12/d3d12_compiler.h      |  28 ++-
 src/gallium/drivers/d3d12/d3d12_gs_variant.cpp  |  57 +++---
 src/gallium/drivers/d3d12/d3d12_tcs_variant.cpp |  19 +-
 4 files changed, 229 insertions(+), 97 deletions(-)

diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp
index edaac5131df..eff3ef567a8 100644
--- a/src/gallium/drivers/d3d12/d3d12_compiler.cpp
+++ b/src/gallium/drivers/d3d12/d3d12_compiler.cpp
@@ -96,6 +96,37 @@ compile_nir(struct d3d12_context *ctx, struct 
d3d12_shader_selector *sel,
    struct d3d12_screen *screen = d3d12_screen(ctx->base.screen);
    struct d3d12_shader *shader = rzalloc(sel, d3d12_shader);
    shader->key = *key;
+
+   if (key->required_varying_inputs != nullptr) {
+      shader->key.required_varying_inputs = ralloc(sel, struct 
d3d12_varying_info);
+      *shader->key.required_varying_inputs = *key->required_varying_inputs;
+   }
+   if (key->required_varying_outputs != nullptr) {
+      shader->key.required_varying_outputs = ralloc(sel, struct 
d3d12_varying_info);
+      *shader->key.required_varying_outputs = *key->required_varying_outputs;
+   }
+   
+
+   if (key->stage == PIPE_SHADER_TESS_CTRL &&
+         key->hs.required_patch_outputs != nullptr) {
+      shader->key.hs.required_patch_outputs = ralloc(sel, struct 
d3d12_varying_info);
+      *shader->key.hs.required_patch_outputs = *key->hs.required_patch_outputs;
+   }
+   if (shader->key.stage == PIPE_SHADER_TESS_EVAL &&
+         shader->key.ds.required_patch_inputs != nullptr) {
+      shader->key.ds.required_patch_inputs = ralloc(sel, struct 
d3d12_varying_info);
+      *shader->key.ds.required_patch_inputs = *key->ds.required_patch_inputs;
+   }
+
+   shader->output_vars_fs = nullptr;
+   shader->output_vars_gs = nullptr;
+   shader->output_vars_default = nullptr;
+
+   shader->input_vars_vs = nullptr;
+   shader->input_vars_default = nullptr;
+
+   shader->tess_eval_output_vars = nullptr;
+   shader->tess_ctrl_input_vars = nullptr;
    shader->nir = nir;
    sel->current = shader;
 
@@ -511,7 +542,7 @@ needs_vertex_reordering(struct d3d12_selection_context 
*sel_ctx, const struct pi
 }
 
 static nir_variable *
-create_varying_from_info(nir_shader *nir, struct d3d12_varying_info *info,
+create_varying_from_info(nir_shader *nir, const struct d3d12_varying_info 
*info,
                          unsigned slot, unsigned slot_frac, nir_variable_mode 
mode, bool patch)
 {
    nir_variable *var;
@@ -537,7 +568,7 @@ create_varying_from_info(nir_shader *nir, struct 
d3d12_varying_info *info,
 }
 
 void
-create_varyings_from_info(nir_shader *nir, struct d3d12_varying_info *info,
+create_varyings_from_info(nir_shader *nir, const struct d3d12_varying_info 
*info,
                           unsigned slot, nir_variable_mode mode, bool patch)
 {
    unsigned mask = info->slots[slot].location_frac_mask;
@@ -546,7 +577,7 @@ create_varyings_from_info(nir_shader *nir, struct 
d3d12_varying_info *info,
 }
 
 static void
-fill_varyings(struct d3d12_varying_info *info, nir_shader *s,
+fill_varyings(struct d3d12_varying_info *info, const nir_shader *s,
               nir_variable_mode modes, uint64_t mask, bool patch)
 {
    info->max = 0;
@@ -610,6 +641,12 @@ fill_flat_varyings(struct d3d12_gs_variant_key *key, 
d3d12_shader_selector *fs)
 bool
 d3d12_compare_varying_info(const d3d12_varying_info *expect, const 
d3d12_varying_info *have)
 {
+   if (expect == have)
+      return true;
+
+   if (expect == nullptr || have == nullptr)
+      return false;
+
    if (expect->mask != have->mask
       || expect->max != have->max)
       return false;
@@ -673,8 +710,12 @@ validate_geometry_shader_variant(struct 
d3d12_selection_context *sel_ctx)
    }
 
    if (variant_needed) {
-      fill_varyings(&key.varyings, vs->initial, nir_var_shader_out,
-                    vs->initial->info.outputs_written, false);
+      if (vs->initial_output_vars == nullptr) {
+         vs->initial_output_vars = ralloc(vs, struct d3d12_varying_info);
+         fill_varyings(vs->initial_output_vars, vs->initial, 
nir_var_shader_out,
+            vs->initial->info.outputs_written, false);
+      }
+      key.varyings = vs->initial_output_vars;
    }
 
    /* Find/create the proper variant and bind it */
@@ -700,8 +741,12 @@ validate_tess_ctrl_shader_variant(struct 
d3d12_selection_context *sel_ctx)
 
    /* Fill the variant key */
    if (variant_needed) {
-      fill_varyings(&key.varyings, vs->initial, nir_var_shader_out,
-                    vs->initial->info.outputs_written, false);
+      if (vs->initial_output_vars == nullptr) {
+         vs->initial_output_vars = ralloc(vs, struct d3d12_varying_info);
+         fill_varyings(vs->initial_output_vars, vs->initial, 
nir_var_shader_out,
+            vs->initial->info.outputs_written, false);
+      }
+      key.varyings = vs->initial_output_vars;
       key.vertices_out = ctx->patch_vertices;
    }
 
@@ -720,19 +765,6 @@ d3d12_compare_shader_keys(struct d3d12_selection_context* 
sel_ctx, const d3d12_s
    if (expect->hash != have->hash)
       return false;
 
-   /* Because we only add varyings we check that a shader has at least the 
expected in-
-    * and outputs. */
-
-   if (!d3d12_compare_varying_info(&expect->required_varying_inputs,
-                                   &have->required_varying_inputs) ||
-       expect->next_varying_inputs != have->next_varying_inputs)
-      return false;
-
-   if (!d3d12_compare_varying_info(&expect->required_varying_outputs,
-                                   &have->required_varying_outputs) ||
-       expect->prev_varying_outputs != have->prev_varying_outputs)
-      return false;
-
    if (expect->stage == PIPE_SHADER_GEOMETRY) {
       if (expect->gs.writes_psize) {
          if (!have->gs.writes_psize ||
@@ -767,12 +799,12 @@ d3d12_compare_shader_keys(struct d3d12_selection_context* 
sel_ctx, const d3d12_s
           expect->hs.point_mode != have->hs.point_mode ||
           expect->hs.spacing != have->hs.spacing ||
           expect->hs.patch_vertices_in != have->hs.patch_vertices_in ||
-          !d3d12_compare_varying_info(&expect->hs.required_patch_outputs, 
&have->hs.required_patch_outputs) ||
+          !d3d12_compare_varying_info(expect->hs.required_patch_outputs, 
have->hs.required_patch_outputs) ||
           expect->hs.next_patch_inputs != have->hs.next_patch_inputs)
          return false;
    } else if (expect->stage == PIPE_SHADER_TESS_EVAL) {
       if (expect->ds.tcs_vertices_out != have->ds.tcs_vertices_out ||
-          !d3d12_compare_varying_info(&expect->ds.required_patch_inputs, 
&have->ds.required_patch_inputs) ||
+          !d3d12_compare_varying_info(expect->ds.required_patch_inputs, 
have->ds.required_patch_inputs) ||
           expect->ds.prev_patch_outputs != have ->ds.prev_patch_outputs)
          return false;
    }
@@ -828,6 +860,19 @@ d3d12_compare_shader_keys(struct d3d12_selection_context* 
sel_ctx, const d3d12_s
    if (expect->stage == PIPE_SHADER_FRAGMENT && expect->fs.provoking_vertex != 
have->fs.provoking_vertex)
       return false;
 
+   /* Because we only add varyings we check that a shader has at least the 
expected in-
+    * and outputs. */
+
+   if (!d3d12_compare_varying_info(expect->required_varying_inputs,
+                                   have->required_varying_inputs) ||
+       expect->next_varying_inputs != have->next_varying_inputs)
+      return false;
+
+   if (!d3d12_compare_varying_info(expect->required_varying_outputs,
+                                   have->required_varying_outputs) ||
+       expect->prev_varying_outputs != have->prev_varying_outputs)
+      return false;
+
    return true;
 }
 
@@ -837,8 +882,10 @@ d3d12_shader_key_hash(const d3d12_shader_key *key)
    uint32_t hash;
 
    hash = (uint32_t)key->stage;
-   hash += key->required_varying_inputs.mask;
-   hash += key->required_varying_outputs.mask;
+   if (key->required_varying_inputs != nullptr)
+      hash += key->required_varying_inputs->mask + 
key->required_varying_inputs->max;
+   if (key->required_varying_outputs != nullptr)
+      hash += key->required_varying_outputs->mask + 
key->required_varying_outputs->max;
    hash += key->next_varying_inputs;
    hash += key->prev_varying_outputs;
    switch (key->stage) {
@@ -858,12 +905,14 @@ d3d12_shader_key_hash(const d3d12_shader_key *key)
       break;
    case PIPE_SHADER_TESS_CTRL:
       hash += key->hs.all;
-      hash += key->hs.required_patch_outputs.mask;
+      if (key->hs.required_patch_outputs)
+         hash += key->hs.required_patch_outputs->mask + 
key->hs.required_patch_outputs->max;
       break;
    case PIPE_SHADER_TESS_EVAL:
       hash += key->ds.tcs_vertices_out;
       hash += key->ds.prev_patch_outputs;
-      hash += key->ds.required_patch_inputs.mask;
+      if (key->ds.required_patch_inputs)
+         hash += key->ds.required_patch_inputs->mask + 
key->ds.required_patch_inputs->max;
       break;
    default:
       /* No type specific information to hash for other stages. */
@@ -890,11 +939,8 @@ d3d12_fill_shader_key(struct d3d12_selection_context 
*sel_ctx,
          VARYING_BIT_CLIP_DIST0 |
          VARYING_BIT_CLIP_DIST1;
 
-   key->hash = 0;
+   memset(key, 0, offsetof(d3d12_shader_key, vs));
    key->stage = stage;
-   key->required_varying_inputs.mask = 0;
-   key->required_varying_outputs.mask = 0;
-   memset(&key->next_varying_inputs, 0, offsetof(d3d12_shader_key, vs) - 
offsetof(d3d12_shader_key, next_varying_inputs));
 
    switch (stage)
    {
@@ -909,12 +955,12 @@ d3d12_fill_shader_key(struct d3d12_selection_context 
*sel_ctx,
       break;
    case PIPE_SHADER_TESS_CTRL:
       key->hs.all = 0;
-      key->hs.required_patch_outputs.mask = 0;
+      key->hs.required_patch_outputs = nullptr;
       break;
    case PIPE_SHADER_TESS_EVAL:
       key->ds.tcs_vertices_out = 0;
       key->ds.prev_patch_outputs = 0;
-      key->ds.required_patch_inputs.mask = 0;
+      key->ds.required_patch_inputs = nullptr;
       break;
    case PIPE_SHADER_COMPUTE:
       memset(key->cs.workgroup_size, 0, sizeof(key->cs.workgroup_size));
@@ -928,19 +974,45 @@ d3d12_fill_shader_key(struct d3d12_selection_context 
*sel_ctx,
    if (prev) {
       /* We require as inputs what the previous stage has written,
        * except certain system values */
-      if (stage == PIPE_SHADER_FRAGMENT || stage == PIPE_SHADER_GEOMETRY)
+
+      struct d3d12_varying_info **output_vars = nullptr;
+
+      switch (stage) {
+      case PIPE_SHADER_FRAGMENT:
+         system_out_values |= VARYING_BIT_POS | VARYING_BIT_PSIZ | 
VARYING_BIT_VIEWPORT | VARYING_BIT_LAYER;
+         output_vars = &prev->current->output_vars_fs;
+         break;
+      case PIPE_SHADER_GEOMETRY:
          system_out_values |= VARYING_BIT_POS;
-      if (stage == PIPE_SHADER_FRAGMENT)
-         system_out_values |= VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT | 
VARYING_BIT_LAYER;
+         output_vars = &prev->current->output_vars_gs;
+         break;
+      default:
+         output_vars = &prev->current->output_vars_default;
+         break;
+      }
+
       uint64_t mask = prev->current->nir->info.outputs_written & 
~system_out_values;
-      fill_varyings(&key->required_varying_inputs, prev->current->nir,
-                    nir_var_shader_out, mask, false);
+
+      if (*output_vars == nullptr) {
+         *output_vars = ralloc(prev, struct d3d12_varying_info);
+         fill_varyings(*output_vars, prev->current->nir,
+                       nir_var_shader_out, mask, false);
+      }
+
+      key->required_varying_inputs = *output_vars;
+
       key->prev_varying_outputs = prev->current->nir->info.outputs_written;
 
       if (stage == PIPE_SHADER_TESS_EVAL) {
          uint32_t patch_mask = prev->current->nir->info.patch_outputs_written;
-         fill_varyings(&key->ds.required_patch_inputs, prev->current->nir,
-                       nir_var_shader_out, patch_mask, true);
+
+         if (prev->current->tess_eval_output_vars == nullptr) {
+            prev->current->tess_eval_output_vars = ralloc(prev, struct 
d3d12_varying_info);
+            fill_varyings(prev->current->tess_eval_output_vars, 
prev->current->nir,
+                          nir_var_shader_out, patch_mask, true);
+         }
+
+         key->ds.required_patch_inputs = prev->current->tess_eval_output_vars;
          key->ds.prev_patch_outputs = patch_mask;
       }
 
@@ -963,16 +1035,35 @@ d3d12_fill_shader_key(struct d3d12_selection_context 
*sel_ctx,
     * except certain system values */
    if (next) {
       if (!next->is_variant) {
-         if (stage == PIPE_SHADER_VERTEX)
+
+         struct d3d12_varying_info **output_vars = 
&next->current->input_vars_default;
+
+         if (stage == PIPE_SHADER_VERTEX) {
             system_generated_in_values |= VARYING_BIT_POS;
+            output_vars = &next->current->input_vars_vs;
+         }
          uint64_t mask = next->current->nir->info.inputs_read & 
~system_generated_in_values;
-         fill_varyings(&key->required_varying_outputs, next->current->nir,
-                       nir_var_shader_in, mask, false);
+
+         
+         if (*output_vars == nullptr) {
+            *output_vars = ralloc(next, struct d3d12_varying_info);
+            fill_varyings(*output_vars, next->current->nir,
+                          nir_var_shader_in, mask, false);
+         }
+
+         key->required_varying_outputs = *output_vars;
+
 
          if (stage == PIPE_SHADER_TESS_CTRL) {
             uint32_t patch_mask = next->current->nir->info.patch_outputs_read;
-            fill_varyings(&key->hs.required_patch_outputs, prev->current->nir,
-                          nir_var_shader_in, patch_mask, true);
+
+            if (prev->current->tess_ctrl_input_vars == nullptr){
+               prev->current->tess_ctrl_input_vars = ralloc(next, struct 
d3d12_varying_info);
+               fill_varyings(prev->current->tess_ctrl_input_vars, 
prev->current->nir,
+                              nir_var_shader_in, patch_mask, true);
+            }
+
+            key->hs.required_patch_outputs = 
prev->current->tess_ctrl_input_vars;
             key->hs.next_patch_inputs = patch_mask;
          }
       }
@@ -1058,7 +1149,7 @@ d3d12_fill_shader_key(struct d3d12_selection_context 
*sel_ctx,
       }
    }
 
-   for (unsigned i = 0; i < sel_ctx->ctx->num_samplers[stage]; ++i) {
+   for (unsigned i = 0, e = sel_ctx->ctx->num_samplers[stage]; i < e; ++i) {
       if (!sel_ctx->ctx->samplers[stage][i] ||
           sel_ctx->ctx->samplers[stage][i]->filter == PIPE_TEX_FILTER_NEAREST)
          continue;
@@ -1244,19 +1335,21 @@ select_shader_variant(struct d3d12_selection_context 
*sel_ctx, d3d12_shader_sele
 
    /* Add the needed in and outputs, and re-sort */
    if (prev) {
-      uint64_t mask = key.required_varying_inputs.mask & 
~new_nir_variant->info.inputs_read;
-      new_nir_variant->info.inputs_read |= mask;
-      while (mask) {
-         int slot = u_bit_scan64(&mask);
-         create_varyings_from_info(new_nir_variant, 
&key.required_varying_inputs, slot, nir_var_shader_in, false);
+      if (key.required_varying_inputs != nullptr) {
+         uint64_t mask = key.required_varying_inputs->mask & 
~new_nir_variant->info.inputs_read;
+         new_nir_variant->info.inputs_read |= mask;
+         while (mask) {
+            int slot = u_bit_scan64(&mask);
+            create_varyings_from_info(new_nir_variant, 
key.required_varying_inputs, slot, nir_var_shader_in, false);
+         }
       }
 
       if (sel->stage == PIPE_SHADER_TESS_EVAL) {
-         uint32_t patch_mask = (uint32_t)key.ds.required_patch_inputs.mask & 
~new_nir_variant->info.patch_inputs_read;
+         uint32_t patch_mask = (uint32_t)key.ds.required_patch_inputs->mask & 
~new_nir_variant->info.patch_inputs_read;
          new_nir_variant->info.patch_inputs_read |= patch_mask;
          while (patch_mask) {
             int slot = u_bit_scan(&patch_mask);
-            create_varyings_from_info(new_nir_variant, 
&key.ds.required_patch_inputs, slot, nir_var_shader_in, true);
+            create_varyings_from_info(new_nir_variant, 
key.ds.required_patch_inputs, slot, nir_var_shader_in, true);
          }
       }
       dxil_reassign_driver_locations(new_nir_variant, nir_var_shader_in,
@@ -1265,19 +1358,22 @@ select_shader_variant(struct d3d12_selection_context 
*sel_ctx, d3d12_shader_sele
 
 
    if (next) {
-      uint64_t mask = key.required_varying_outputs.mask & 
~new_nir_variant->info.outputs_written;
-      new_nir_variant->info.outputs_written |= mask;
-      while (mask) {
-         int slot = u_bit_scan64(&mask);
-         create_varyings_from_info(new_nir_variant, 
&key.required_varying_outputs, slot, nir_var_shader_out, false);
+      if (key.required_varying_outputs != nullptr) {
+         uint64_t mask = key.required_varying_outputs->mask & 
~new_nir_variant->info.outputs_written;
+         new_nir_variant->info.outputs_written |= mask;
+         while (mask) {
+            int slot = u_bit_scan64(&mask);
+            create_varyings_from_info(new_nir_variant, 
key.required_varying_outputs, slot, nir_var_shader_out, false);
+         }
       }
 
-      if (sel->stage == PIPE_SHADER_TESS_CTRL) {
-         uint32_t patch_mask = (uint32_t)key.hs.required_patch_outputs.mask & 
~new_nir_variant->info.patch_outputs_written;
+      if (sel->stage == PIPE_SHADER_TESS_CTRL &&
+            key.hs.required_patch_outputs != nullptr) {
+         uint32_t patch_mask = (uint32_t)key.hs.required_patch_outputs->mask & 
~new_nir_variant->info.patch_outputs_written;
          new_nir_variant->info.patch_outputs_written |= patch_mask;
          while (patch_mask) {
             int slot = u_bit_scan(&patch_mask);
-            create_varyings_from_info(new_nir_variant, 
&key.hs.required_patch_outputs, slot, nir_var_shader_out, true);
+            create_varyings_from_info(new_nir_variant, 
key.hs.required_patch_outputs, slot, nir_var_shader_out, true);
          }
       }
       dxil_reassign_driver_locations(new_nir_variant, nir_var_shader_out,
@@ -1424,6 +1520,9 @@ d3d12_create_shader_impl(struct d3d12_context *ctx,
 
    /* Keep this initial shader as the blue print for possible variants */
    sel->initial = nir;
+   sel->initial_output_vars = nullptr;
+   sel->gs_key.varyings = nullptr;
+   sel->tcs_key.varyings = nullptr;
 
    /*
     * We must compile some shader here, because if the previous or a next 
shaders exists later
@@ -1590,6 +1689,7 @@ d3d12_shader_free(struct d3d12_shader_selector *sel)
       free(shader->bytecode);
       shader = shader->next_variant;
    }
-   ralloc_free(sel->initial);
+
+   ralloc_free((void*)sel->initial);
    ralloc_free(sel);
 }
diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.h b/src/gallium/drivers/d3d12/d3d12_compiler.h
index 957170928dc..f74ba2db0e2 100644
--- a/src/gallium/drivers/d3d12/d3d12_compiler.h
+++ b/src/gallium/drivers/d3d12/d3d12_compiler.h
@@ -91,8 +91,8 @@ struct d3d12_shader_key {
    uint32_t hash;
    enum pipe_shader_type stage;
 
-   struct d3d12_varying_info required_varying_inputs;
-   struct d3d12_varying_info required_varying_outputs;
+   struct d3d12_varying_info *required_varying_inputs;
+   struct d3d12_varying_info *required_varying_outputs;
    uint64_t next_varying_inputs;
    uint64_t prev_varying_outputs;
    unsigned last_vertex_processing_stage : 1;
@@ -136,13 +136,13 @@ struct d3d12_shader_key {
             };
             uint64_t all;
          };
-         struct d3d12_varying_info required_patch_outputs;
+         struct d3d12_varying_info *required_patch_outputs;
       } hs;
 
       struct {
          unsigned tcs_vertices_out;
          uint32_t prev_patch_outputs;
-         struct d3d12_varying_info required_patch_inputs;
+         struct d3d12_varying_info *required_patch_inputs;
       } ds;
 
       union {
@@ -179,6 +179,15 @@ struct d3d12_shader {
    size_t bytecode_length;
 
    nir_shader *nir;
+   struct d3d12_varying_info *output_vars_gs;
+   struct d3d12_varying_info *output_vars_fs;
+   struct d3d12_varying_info *output_vars_default;
+
+   struct d3d12_varying_info *input_vars_vs;
+   struct d3d12_varying_info *input_vars_default;
+
+   struct d3d12_varying_info *tess_eval_output_vars;
+   struct d3d12_varying_info *tess_ctrl_input_vars;
 
    struct {
       unsigned binding;
@@ -223,18 +232,20 @@ struct d3d12_gs_variant_key
    unsigned edge_flag_fix:1;
    unsigned flatshade_first:1;
    uint64_t flat_varyings;
-   struct d3d12_varying_info varyings;
+   struct d3d12_varying_info *varyings;
 };
 
 struct d3d12_tcs_variant_key
 {
    unsigned vertices_out;
-   struct d3d12_varying_info varyings;
+   struct d3d12_varying_info *varyings;
 };
 
 struct d3d12_shader_selector {
    enum pipe_shader_type stage;
-   nir_shader *initial;
+   const nir_shader *initial;
+   struct d3d12_varying_info *initial_output_vars;
+
    struct d3d12_shader *first;
    struct d3d12_shader *current;
 
@@ -297,6 +308,9 @@ missing_dual_src_outputs(struct d3d12_context* ctx);
 bool
 has_flat_varyings(struct d3d12_context* ctx);
 
+bool
+d3d12_compare_varying_info(const struct d3d12_varying_info *expect, const 
struct d3d12_varying_info *have);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/gallium/drivers/d3d12/d3d12_gs_variant.cpp b/src/gallium/drivers/d3d12/d3d12_gs_variant.cpp
index b1605623cf2..4ddc1d433f1 100644
--- a/src/gallium/drivers/d3d12/d3d12_gs_variant.cpp
+++ b/src/gallium/drivers/d3d12/d3d12_gs_variant.cpp
@@ -71,7 +71,7 @@ static d3d12_shader_selector*
 d3d12_make_passthrough_gs(struct d3d12_context *ctx, struct 
d3d12_gs_variant_key *key)
 {
    struct d3d12_shader_selector *gs;
-   uint64_t varyings = key->varyings.mask;
+   uint64_t varyings = key->varyings->mask;
    nir_shader *nir;
    struct pipe_shader_state templ;
 
@@ -94,32 +94,32 @@ d3d12_make_passthrough_gs(struct d3d12_context *ctx, struct 
d3d12_gs_variant_key
       char tmp[100];
       const int i = u_bit_scan64(&varyings);
 
-      unsigned frac_slots = key->varyings.slots[i].location_frac_mask;
+      unsigned frac_slots = key->varyings->slots[i].location_frac_mask;
       while (frac_slots) {
          nir_variable *in, *out;
          int j = u_bit_scan(&frac_slots);
 
-         snprintf(tmp, ARRAY_SIZE(tmp), "in_%d", 
key->varyings.slots[i].vars[j].driver_location);
+         snprintf(tmp, ARRAY_SIZE(tmp), "in_%d", 
key->varyings->slots[i].vars[j].driver_location);
          in = nir_variable_create(nir,
                                   nir_var_shader_in,
-                                  
glsl_array_type(key->varyings.slots[i].types[j], 1, false),
+                                  
glsl_array_type(key->varyings->slots[i].types[j], 1, false),
                                   tmp);
          in->data.location = i;
          in->data.location_frac = j;
-         in->data.driver_location = 
key->varyings.slots[i].vars[j].driver_location;
-         in->data.interpolation = key->varyings.slots[i].vars[j].interpolation;
-         in->data.compact = key->varyings.slots[i].vars[j].compact;
+         in->data.driver_location = 
key->varyings->slots[i].vars[j].driver_location;
+         in->data.interpolation = 
key->varyings->slots[i].vars[j].interpolation;
+         in->data.compact = key->varyings->slots[i].vars[j].compact;
 
-         snprintf(tmp, ARRAY_SIZE(tmp), "out_%d", 
key->varyings.slots[i].vars[j].driver_location);
+         snprintf(tmp, ARRAY_SIZE(tmp), "out_%d", 
key->varyings->slots[i].vars[j].driver_location);
          out = nir_variable_create(nir,
                                    nir_var_shader_out,
-                                   key->varyings.slots[i].types[j],
+                                   key->varyings->slots[i].types[j],
                                    tmp);
          out->data.location = i;
          out->data.location_frac = j;
-         out->data.driver_location = 
key->varyings.slots[i].vars[j].driver_location;
-         out->data.interpolation = 
key->varyings.slots[i].vars[j].interpolation;
-         out->data.compact = key->varyings.slots[i].vars[j].compact;
+         out->data.driver_location = 
key->varyings->slots[i].vars[j].driver_location;
+         out->data.interpolation = 
key->varyings->slots[i].vars[j].interpolation;
+         out->data.compact = key->varyings->slots[i].vars[j].compact;
 
          nir_deref_instr *in_value = nir_build_deref_array(&b, 
nir_build_deref_var(&b, in),
                                                                nir_imm_int(&b, 
0));
@@ -169,7 +169,7 @@ d3d12_begin_emit_primitives_gs(struct 
emit_primitives_context *emit_ctx,
    nir_builder *b = &emit_ctx->b;
    nir_variable *edgeflag_var = NULL;
    nir_variable *pos_var = NULL;
-   uint64_t varyings = key->varyings.mask;
+   uint64_t varyings = key->varyings->mask;
 
    emit_ctx->ctx = ctx;
 
@@ -191,19 +191,19 @@ d3d12_begin_emit_primitives_gs(struct 
emit_primitives_context *emit_ctx,
       char tmp[100];
       const int i = u_bit_scan64(&varyings);
 
-      unsigned frac_slots = key->varyings.slots[i].location_frac_mask;
+      unsigned frac_slots = key->varyings->slots[i].location_frac_mask;
       while (frac_slots) {
          int j = u_bit_scan(&frac_slots);
          snprintf(tmp, ARRAY_SIZE(tmp), "in_%d", emit_ctx->num_vars);
          emit_ctx->in[emit_ctx->num_vars] = nir_variable_create(nir,
                                                                 
nir_var_shader_in,
-                                                                
glsl_array_type(key->varyings.slots[i].types[j], 3, 0),
+                                                                
glsl_array_type(key->varyings->slots[i].types[j], 3, 0),
                                                                 tmp);
          emit_ctx->in[emit_ctx->num_vars]->data.location = i;
          emit_ctx->in[emit_ctx->num_vars]->data.location_frac = j;
-         emit_ctx->in[emit_ctx->num_vars]->data.driver_location = 
key->varyings.slots[i].vars[j].driver_location;
-         emit_ctx->in[emit_ctx->num_vars]->data.interpolation = 
key->varyings.slots[i].vars[j].interpolation;
-         emit_ctx->in[emit_ctx->num_vars]->data.compact = 
key->varyings.slots[i].vars[j].compact;
+         emit_ctx->in[emit_ctx->num_vars]->data.driver_location = 
key->varyings->slots[i].vars[j].driver_location;
+         emit_ctx->in[emit_ctx->num_vars]->data.interpolation = 
key->varyings->slots[i].vars[j].interpolation;
+         emit_ctx->in[emit_ctx->num_vars]->data.compact = 
key->varyings->slots[i].vars[j].compact;
 
          /* Don't create an output for the edge flag variable */
          if (i == VARYING_SLOT_EDGE) {
@@ -216,13 +216,13 @@ d3d12_begin_emit_primitives_gs(struct 
emit_primitives_context *emit_ctx,
          snprintf(tmp, ARRAY_SIZE(tmp), "out_%d", emit_ctx->num_vars);
          emit_ctx->out[emit_ctx->num_vars] = nir_variable_create(nir,
                                                                  
nir_var_shader_out,
-                                                                 
key->varyings.slots[i].types[j],
+                                                                 
key->varyings->slots[i].types[j],
                                                                  tmp);
          emit_ctx->out[emit_ctx->num_vars]->data.location = i;
          emit_ctx->out[emit_ctx->num_vars]->data.location_frac = j;
-         emit_ctx->out[emit_ctx->num_vars]->data.driver_location = 
key->varyings.slots[i].vars[j].driver_location;
-         emit_ctx->out[emit_ctx->num_vars]->data.interpolation = 
key->varyings.slots[i].vars[j].interpolation;
-         emit_ctx->out[emit_ctx->num_vars]->data.compact = 
key->varyings.slots[i].vars[j].compact;
+         emit_ctx->out[emit_ctx->num_vars]->data.driver_location = 
key->varyings->slots[i].vars[j].driver_location;
+         emit_ctx->out[emit_ctx->num_vars]->data.interpolation = 
key->varyings->slots[i].vars[j].interpolation;
+         emit_ctx->out[emit_ctx->num_vars]->data.compact = 
key->varyings->slots[i].vars[j].compact;
 
          emit_ctx->num_vars++;
       }
@@ -455,13 +455,18 @@ d3d12_emit_triangles(struct d3d12_context *ctx, struct 
d3d12_gs_variant_key *key
 static uint32_t
 hash_gs_variant_key(const void *key)
 {
-   return _mesa_hash_data(key, sizeof(struct d3d12_gs_variant_key));
+   d3d12_gs_variant_key *v = (d3d12_gs_variant_key*)key;
+   uint32_t hash = _mesa_hash_data(v, offsetof(d3d12_gs_variant_key, 
varyings));
+   if (v->varyings)
+      hash = _mesa_hash_data_with_seed(v->varyings->slots, 
sizeof(v->varyings->slots[0]) * v->varyings->max, hash);
+   return hash;
 }
 
 static bool
 equals_gs_variant_key(const void *a, const void *b)
 {
-   return memcmp(a, b, sizeof(struct d3d12_gs_variant_key)) == 0;
+   return memcmp(a, b, offsetof(d3d12_gs_variant_key, varyings)) == 0
+      &&  d3d12_compare_varying_info(((d3d12_gs_variant_key*)a)->varyings, 
((d3d12_gs_variant_key*)b)->varyings);
 }
 
 void
@@ -499,6 +504,10 @@ create_geometry_shader_variant(struct d3d12_context *ctx, 
struct d3d12_gs_varian
    if (gs) {
       gs->is_variant = true;
       gs->gs_key = *key;
+      if (key->varyings) {
+         gs->gs_key.varyings = ralloc(gs, struct d3d12_varying_info);
+         *gs->gs_key.varyings = *key->varyings;
+      }
    }
 
    return gs;
diff --git a/src/gallium/drivers/d3d12/d3d12_tcs_variant.cpp b/src/gallium/drivers/d3d12/d3d12_tcs_variant.cpp
index 00897f6438d..dbaa9a0346a 100644
--- a/src/gallium/drivers/d3d12/d3d12_tcs_variant.cpp
+++ b/src/gallium/drivers/d3d12/d3d12_tcs_variant.cpp
@@ -31,13 +31,18 @@
 static uint32_t
 hash_tcs_variant_key(const void *key)
 {
-   return _mesa_hash_data(key, sizeof(struct d3d12_tcs_variant_key));
+   d3d12_tcs_variant_key *v = (d3d12_tcs_variant_key*)key;
+   uint32_t hash = _mesa_hash_data(v, offsetof(d3d12_tcs_variant_key, 
varyings));
+   if (v->varyings)
+      hash = _mesa_hash_data_with_seed(v->varyings->slots, 
sizeof(v->varyings->slots[0]) * v->varyings->max, hash);
+   return hash;
 }
 
 static bool
 equals_tcs_variant_key(const void *a, const void *b)
 {
-   return memcmp(a, b, sizeof(struct d3d12_tcs_variant_key)) == 0;
+   return memcmp(a, b, offsetof(d3d12_tcs_variant_key, varyings)) == 0
+      &&  d3d12_compare_varying_info(((d3d12_tcs_variant_key*)a)->varyings, 
((d3d12_tcs_variant_key*)b)->varyings);
 }
 
 void
@@ -80,11 +85,11 @@ create_tess_ctrl_shader_variant(struct d3d12_context *ctx, 
struct d3d12_tcs_vari
    nir_shader *nir = b.shader;
 
    nir_ssa_def *invocation_id = nir_load_invocation_id(&b);
-   uint64_t varying_mask = key->varyings.mask;
+   uint64_t varying_mask = key->varyings->mask;
 
    while(varying_mask) {
       int var_idx = u_bit_scan64(&varying_mask);
-      auto slot = &key->varyings.slots[var_idx];
+      auto slot = &key->varyings->slots[var_idx];
       unsigned frac_mask = slot->location_frac_mask;
       while (frac_mask) {
          int frac = u_bit_scan(&frac_mask);
@@ -144,7 +149,11 @@ create_tess_ctrl_shader_variant(struct d3d12_context *ctx, 
struct d3d12_tcs_vari
    d3d12_shader_selector *tcs = d3d12_create_shader(ctx, 
PIPE_SHADER_TESS_CTRL, &templ);
    if (tcs) {
       tcs->is_variant = true;
-      memcpy(&tcs->tcs_key, key, sizeof(*key));
+      tcs->tcs_key = *key;
+      if (key->varyings) {
+         tcs->tcs_key.varyings = ralloc(tcs, struct d3d12_varying_info);
+         *tcs->tcs_key.varyings = *key->varyings;
+      }
    }
    return tcs;
 }

Reply via email to