---
 src/amd/common/ac_nir_to_llvm.c | 133 +++++++++++++++++++++++++---------------
 1 file changed, 83 insertions(+), 50 deletions(-)

diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c
index 67945a353e8..cb011bd88bb 100644
--- a/src/amd/common/ac_nir_to_llvm.c
+++ b/src/amd/common/ac_nir_to_llvm.c
@@ -6428,7 +6428,8 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct 
ac_shader_abi *abi,
 
 static
 LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
-                                       struct nir_shader *nir,
+                                       struct nir_shader *const *shaders,
+                                       int shader_count,
                                        struct ac_shader_variant_info 
*shader_info,
                                        const struct ac_nir_compiler_options 
*options)
 {
@@ -6441,11 +6442,6 @@ LLVMModuleRef 
ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
 
        ac_llvm_context_init(&ctx.ac, ctx.context, options->chip_class);
        ctx.ac.module = ctx.module;
-
-       memset(shader_info, 0, sizeof(*shader_info));
-
-       ac_nir_shader_info_pass(nir, options, &shader_info->info);
-
        LLVMSetTarget(ctx.module, options->supports_spill ? 
"amdgcn-mesa-mesa3d" : "amdgcn--");
 
        LLVMTargetDataRef data_layout = LLVMCreateTargetDataLayout(tm);
@@ -6455,72 +6451,109 @@ LLVMModuleRef 
ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
        LLVMDisposeMessage(data_layout_str);
 
        setup_types(&ctx);
-
        ctx.builder = LLVMCreateBuilderInContext(ctx.context);
        ctx.ac.builder = ctx.builder;
-       ctx.stage = nir->stage;
-       ctx.max_workgroup_size = 
ac_nir_get_max_workgroup_size(ctx.options->chip_class, nir);
+
+       memset(shader_info, 0, sizeof(*shader_info));
+
+       for(int i = 0; i < shader_count; ++i)
+               ac_nir_shader_info_pass(shaders[i], options, 
&shader_info->info);
 
        for (i = 0; i < AC_UD_MAX_SETS; i++)
                shader_info->user_sgprs_locs.descriptor_sets[i].sgpr_idx = -1;
        for (i = 0; i < AC_UD_MAX_UD; i++)
                shader_info->user_sgprs_locs.shader_data[i].sgpr_idx = -1;
 
-       create_function(&ctx, nir->stage, false, MESA_SHADER_VERTEX);
+       ctx.max_workgroup_size = 
ac_nir_get_max_workgroup_size(ctx.options->chip_class, shaders[0]);
+
+       create_function(&ctx, shaders[shader_count - 1]->stage, shader_count >= 
2,
+                       shader_count >= 2 ? shaders[shader_count - 2]->stage  : 
MESA_SHADER_VERTEX);
 
-       if (nir->stage == MESA_SHADER_GEOMETRY) {
-               ctx.gs_next_vertex = ac_build_alloca(&ctx.ac, ctx.i32, 
"gs_next_vertex");
+       ctx.abi.inputs = &ctx.inputs[0];
+       ctx.abi.emit_outputs = handle_shader_outputs_post;
+       ctx.abi.load_ssbo = radv_load_ssbo;
+       ctx.abi.load_sampler_desc = radv_get_sampler_desc;
 
-               ctx.gs_max_out_vertices = nir->info.gs.vertices_out;
-       } else if (nir->stage == MESA_SHADER_TESS_EVAL) {
-               ctx.tes_primitive_mode = nir->info.tess.primitive_mode;
-       } else if (nir->stage == MESA_SHADER_VERTEX) {
-               if (shader_info->info.vs.needs_instance_id) {
-                       ctx.shader_info->vs.vgpr_comp_cnt =
-                               MAX2(3, ctx.shader_info->vs.vgpr_comp_cnt);
+       for(int i = 0; i < shader_count; ++i) {
+               ctx.stage = shaders[i]->stage;
+               ctx.output_mask = 0;
+               ctx.tess_outputs_written = 0;
+               ctx.num_output_clips = 
shaders[i]->info.clip_distance_array_size;
+               ctx.num_output_culls = 
shaders[i]->info.cull_distance_array_size;
+
+               if (shaders[i]->stage == MESA_SHADER_GEOMETRY) {
+                       ctx.gs_next_vertex = ac_build_alloca(&ctx.ac, ctx.i32, 
"gs_next_vertex");
+
+                       ctx.gs_max_out_vertices = 
shaders[i]->info.gs.vertices_out;
+               } else if (shaders[i]->stage == MESA_SHADER_TESS_EVAL) {
+                       ctx.tes_primitive_mode = 
shaders[i]->info.tess.primitive_mode;
+               } else if (shaders[i]->stage == MESA_SHADER_VERTEX) {
+                       if (shader_info->info.vs.needs_instance_id) {
+                               ctx.shader_info->vs.vgpr_comp_cnt =
+                                       MAX2(3, 
ctx.shader_info->vs.vgpr_comp_cnt);
+                       }
+               } else if (shaders[i]->stage == MESA_SHADER_FRAGMENT) {
+                       shader_info->fs.can_discard = 
shaders[i]->info.fs.uses_discard;
                }
-       } else if (nir->stage == MESA_SHADER_FRAGMENT) {
-               shader_info->fs.can_discard = nir->info.fs.uses_discard;
-       }
 
-       ac_setup_rings(&ctx);
+               if (i)
+                       emit_barrier(&ctx);
 
-       ctx.num_output_clips = nir->info.clip_distance_array_size;
-       ctx.num_output_culls = nir->info.cull_distance_array_size;
+               ac_setup_rings(&ctx);
 
-       if (nir->stage == MESA_SHADER_FRAGMENT)
-               handle_fs_inputs(&ctx, nir);
-       else if(nir->stage == MESA_SHADER_VERTEX)
-               handle_vs_inputs(&ctx, nir);
+               LLVMBasicBlockRef merge_block;
+               if (shader_count >= 2) {
+                       LLVMValueRef fn = 
LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx.ac.builder));
+                       LLVMBasicBlockRef then_block = 
LLVMAppendBasicBlockInContext(ctx.ac.context, fn, "");
+                       merge_block = 
LLVMAppendBasicBlockInContext(ctx.ac.context, fn, "");
 
-       ctx.abi.inputs = &ctx.inputs[0];
-       ctx.abi.emit_outputs = handle_shader_outputs_post;
-       ctx.abi.load_ssbo = radv_load_ssbo;
-       ctx.abi.load_sampler_desc = radv_get_sampler_desc;
+                       LLVMValueRef count = ac_build_bfe(&ctx.ac, 
ctx.merged_wave_info,
+                                                         
LLVMConstInt(ctx.ac.i32, 8 * i, false),
+                                                         
LLVMConstInt(ctx.ac.i32, 8, false), false);
+                       LLVMValueRef thread_id = ac_get_thread_id(&ctx.ac);
+                       LLVMValueRef cond = LLVMBuildICmp(ctx.ac.builder, 
LLVMIntULT,
+                                                         thread_id, count, "");
+                       LLVMBuildCondBr(ctx.ac.builder, cond, then_block, 
merge_block);
 
-       nir_foreach_variable(variable, &nir->outputs)
-               scan_shader_output_decl(&ctx, variable, nir, nir->stage);
+                       LLVMPositionBuilderAtEnd(ctx.ac.builder, then_block);
+               }
 
-       ac_nir_translate(&ctx.ac, &ctx.abi, nir, &ctx);
+               if (shaders[i]->stage == MESA_SHADER_FRAGMENT)
+                       handle_fs_inputs(&ctx, shaders[i]);
+               else if(shaders[i]->stage == MESA_SHADER_VERTEX)
+                       handle_vs_inputs(&ctx, shaders[i]);
 
-       LLVMBuildRetVoid(ctx.builder);
+               nir_foreach_variable(variable, &shaders[i]->outputs)
+                       scan_shader_output_decl(&ctx, variable, shaders[i], 
shaders[i]->stage);
 
-       ac_llvm_finalize_module(&ctx);
+               ac_nir_translate(&ctx.ac, &ctx.abi, shaders[i], &ctx);
 
-       ac_nir_eliminate_const_vs_outputs(&ctx);
+               if (shader_count >= 2) {
+                       LLVMBuildBr(ctx.ac.builder, merge_block);
+                       LLVMPositionBuilderAtEnd(ctx.ac.builder, merge_block);
+               }
 
-       if (nir->stage == MESA_SHADER_GEOMETRY) {
-               unsigned addclip = ctx.num_output_clips + ctx.num_output_culls 
> 4;
-               shader_info->gs.gsvs_vertex_size = 
(util_bitcount64(ctx.output_mask) + addclip) * 16;
-               shader_info->gs.max_gsvs_emit_size = 
shader_info->gs.gsvs_vertex_size *
-                       nir->info.gs.vertices_out;
-       } else if (nir->stage == MESA_SHADER_TESS_CTRL) {
-               shader_info->tcs.outputs_written = ctx.tess_outputs_written;
-               shader_info->tcs.patch_outputs_written = 
ctx.tess_patch_outputs_written;
-       } else if (nir->stage == MESA_SHADER_VERTEX && 
ctx.options->key.vs.as_ls) {
-               shader_info->vs.outputs_written = ctx.tess_outputs_written;
+               if (shaders[i]->stage == MESA_SHADER_GEOMETRY) {
+                       unsigned addclip = 
shaders[i]->info.clip_distance_array_size +
+                                       
shaders[i]->info.cull_distance_array_size > 4;
+                       shader_info->gs.gsvs_vertex_size = 
(util_bitcount64(ctx.output_mask) + addclip) * 16;
+                       shader_info->gs.max_gsvs_emit_size = 
shader_info->gs.gsvs_vertex_size *
+                               shaders[i]->info.gs.vertices_out;
+               } else if (shaders[i]->stage == MESA_SHADER_TESS_CTRL) {
+                       shader_info->tcs.outputs_written = 
ctx.tess_outputs_written;
+                       shader_info->tcs.patch_outputs_written = 
ctx.tess_patch_outputs_written;
+               } else if (shaders[i]->stage == MESA_SHADER_VERTEX && 
ctx.options->key.vs.as_ls) {
+                       shader_info->vs.outputs_written = 
ctx.tess_outputs_written;
+               }
        }
 
+       LLVMBuildRetVoid(ctx.builder);
+
+       ac_llvm_finalize_module(&ctx);
+
+       if (shader_count == 1)
+               ac_nir_eliminate_const_vs_outputs(&ctx);
+
        return ctx.module;
 }
 
@@ -6699,7 +6732,7 @@ void ac_compile_nir_shader(LLVMTargetMachineRef tm,
                           bool dump_shader)
 {
 
-       LLVMModuleRef llvm_module = ac_translate_nir_to_llvm(tm, nir[0], 
shader_info,
+       LLVMModuleRef llvm_module = ac_translate_nir_to_llvm(tm, nir, 
nir_count, shader_info,
                                                             options);
 
        ac_compile_llvm_module(tm, llvm_module, binary, config, shader_info, 
nir[0]->stage, dump_shader, options->supports_spill);
-- 
2.14.2

_______________________________________________
mesa-dev mailing list
mesa-dev@lists.freedesktop.org
https://lists.freedesktop.org/mailman/listinfo/mesa-dev

Reply via email to