Module: Mesa
Branch: main
Commit: c511b8968a28c0c98a36500f68881cde0d2104bd
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=c511b8968a28c0c98a36500f68881cde0d2104bd

Author: Konstantin Seurer <konstantin.seu...@gmail.com>
Date:   Thu Jan  4 17:38:34 2024 +0100

radv: Implement VK_KHR_ray_tracing_position_fetch

Reviewed-by: Friedrich Vock <friedrich.v...@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26895>

---

 docs/features.txt                               |  1 +
 docs/relnotes/new_features.txt                  |  1 +
 src/amd/vulkan/nir/radv_nir_lower_ray_queries.c | 10 ++++++++--
 src/amd/vulkan/nir/radv_nir_rt_common.c         | 14 ++++++++++++++
 src/amd/vulkan/nir/radv_nir_rt_common.h         |  3 +++
 src/amd/vulkan/nir/radv_nir_rt_shader.c         | 16 ++++++++++++----
 src/amd/vulkan/radv_physical_device.c           |  4 ++++
 src/amd/vulkan/radv_shader.c                    |  1 +
 8 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/docs/features.txt b/docs/features.txt
index 367ebd4a689..356f22cfa2a 100644
--- a/docs/features.txt
+++ b/docs/features.txt
@@ -529,6 +529,7 @@ Khronos extensions that are not part of any Vulkan version:
   VK_KHR_ray_query                                      DONE (anv/gfx12.5+, radv/gfx10.3+)
   VK_KHR_ray_tracing_maintenance1                       DONE (anv/gfx12.5+, radv/gfx10.3+)
   VK_KHR_ray_tracing_pipeline                           DONE (anv/gfx12.5+, radv/gfx10.3+)
+  VK_KHR_ray_tracing_position_fetch                     DONE (radv/gfx10.3+)
   VK_KHR_shader_clock                                   DONE (anv, hasvk, lvp, nvk, radv, vn)
   VK_KHR_shader_subgroup_uniform_control_flow           DONE (anv, hasvk, radv)
   VK_KHR_shared_presentable_image                       not started
diff --git a/docs/relnotes/new_features.txt b/docs/relnotes/new_features.txt
index 3f9dbe04660..44d7822f0f6 100644
--- a/docs/relnotes/new_features.txt
+++ b/docs/relnotes/new_features.txt
@@ -19,3 +19,4 @@ GL_ARB_cull_distance on Asahi
 VK_KHR_calibrated_timestamps on RADV
 VK_KHR_vertex_attribute_divisor on RADV
 VK_KHR_maintenance6 on RADV
+VK_KHR_ray_tracing_position_fetch on RADV
diff --git a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
index f6c97fdce3d..af871fa234a 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
@@ -394,7 +394,8 @@ lower_rq_initialize(nir_builder *b, nir_def *index, 
nir_intrinsic_instr *instr,
 }
 
 static nir_def *
-lower_rq_load(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, 
struct ray_query_vars *vars)
+lower_rq_load(struct radv_device *device, nir_builder *b, nir_def *index, 
nir_intrinsic_instr *instr,
+              struct ray_query_vars *vars)
 {
    bool committed = nir_intrinsic_committed(instr);
    struct ray_query_intersection_vars *intersection = committed ? 
&vars->closest : &vars->candidate;
@@ -482,6 +483,11 @@ lower_rq_load(nir_builder *b, nir_def *index, 
nir_intrinsic_instr *instr, struct
       return rq_load_var(b, index, vars->direction);
    case nir_ray_query_value_world_ray_origin:
       return rq_load_var(b, index, vars->origin);
+   case nir_ray_query_value_intersection_triangle_vertex_positions: {
+      nir_def *instance_node_addr = rq_load_var(b, index, 
intersection->instance_addr);
+      nir_def *primitive_id = rq_load_var(b, index, 
intersection->primitive_id);
+      return radv_load_vertex_position(device, b, instance_node_addr, 
primitive_id, nir_intrinsic_column(instr));
+   }
    default:
       unreachable("Invalid nir_ray_query_value!");
    }
@@ -707,7 +713,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, 
struct radv_device *device
                lower_rq_initialize(&builder, index, intrinsic, vars, 
device->instance);
                break;
             case nir_intrinsic_rq_load:
-               new_dest = lower_rq_load(&builder, index, intrinsic, vars);
+               new_dest = lower_rq_load(device, &builder, index, intrinsic, 
vars);
                break;
             case nir_intrinsic_rq_proceed:
                new_dest = lower_rq_proceed(&builder, index, intrinsic, vars, 
device);
diff --git a/src/amd/vulkan/nir/radv_nir_rt_common.c b/src/amd/vulkan/nir/radv_nir_rt_common.c
index 51f4d209448..9ae88c44ca2 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_common.c
+++ b/src/amd/vulkan/nir/radv_nir_rt_common.c
@@ -312,6 +312,20 @@ nir_build_wto_matrix_load(nir_builder *b, nir_def 
*instance_addr, nir_def **out)
    }
 }
 
+nir_def *
+radv_load_vertex_position(struct radv_device *device, nir_builder *b, nir_def 
*instance_addr, nir_def *primitive_id,
+                          uint32_t index)
+{
+   nir_def *bvh_addr_id =
+      nir_build_load_global(b, 1, 64, nir_iadd_imm(b, instance_addr, 
offsetof(struct radv_bvh_instance_node, bvh_ptr)));
+   nir_def *bvh_addr = build_node_to_addr(device, b, bvh_addr_id, true);
+
+   nir_def *offset = nir_imul_imm(b, primitive_id, sizeof(struct 
radv_bvh_triangle_node));
+   offset = nir_iadd_imm(b, offset, sizeof(struct radv_bvh_box32_node) + index 
* 3 * sizeof(float));
+
+   return nir_build_load_global(b, 3, 32, nir_iadd(b, bvh_addr, nir_u2u64(b, 
offset)));
+}
+
 /* When a hit is opaque the any_hit shader is skipped for this hit and the hit
  * is assumed to be an actual hit. */
 static nir_def *
diff --git a/src/amd/vulkan/nir/radv_nir_rt_common.h b/src/amd/vulkan/nir/radv_nir_rt_common.h
index 6973bb246c4..5ccf38ea1f0 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_common.h
+++ b/src/amd/vulkan/nir/radv_nir_rt_common.h
@@ -38,6 +38,9 @@ nir_def *nir_build_vec3_mat_mult(nir_builder *b, nir_def 
*vec, nir_def *matrix[]
 
 void nir_build_wto_matrix_load(nir_builder *b, nir_def *instance_addr, nir_def 
**out);
 
+nir_def *radv_load_vertex_position(struct radv_device *device, nir_builder *b, 
nir_def *instance_addr,
+                                   nir_def *primitive_id, uint32_t index);
+
 struct radv_ray_traversal_args;
 
 struct radv_ray_flags {
diff --git a/src/amd/vulkan/nir/radv_nir_rt_shader.c b/src/amd/vulkan/nir/radv_nir_rt_shader.c
index 4094cd9a6ff..d78dfecbae3 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_shader.c
+++ b/src/amd/vulkan/nir/radv_nir_rt_shader.c
@@ -185,6 +185,7 @@ lower_rt_derefs(nir_shader *shader)
  * Global variables for an RT pipeline
  */
 struct rt_variables {
+   struct radv_device *device;
    const VkPipelineCreateFlags2KHR flags;
 
    /* idx of the next shader to run in the next iteration of the main loop.
@@ -229,9 +230,10 @@ struct rt_variables {
 };
 
 static struct rt_variables
-create_rt_variables(nir_shader *shader, const VkPipelineCreateFlags2KHR flags)
+create_rt_variables(nir_shader *shader, struct radv_device *device, const 
VkPipelineCreateFlags2KHR flags)
 {
    struct rt_variables vars = {
+      .device = device,
       .flags = flags,
    };
    vars.idx = nir_variable_create(shader, nir_var_shader_temp, 
glsl_uint_type(), "idx");
@@ -660,6 +662,12 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr 
*instr, void *_data)
 
       break;
    }
+   case nir_intrinsic_load_ray_triangle_vertex_positions: {
+      nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
+      nir_def *primitive_id = nir_load_var(b, vars->primitive_id);
+      ret = radv_load_vertex_position(vars->device, b, instance_node_addr, 
primitive_id, nir_intrinsic_column(intr));
+      break;
+   }
    default:
       return false;
    }
@@ -782,7 +790,7 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct 
rt_variables *vars, ni
 
    nir_opt_dead_cf(shader);
 
-   struct rt_variables src_vars = create_rt_variables(shader, vars->flags);
+   struct rt_variables src_vars = create_rt_variables(shader, vars->device, 
vars->flags);
    map_rt_variables(var_remap, &src_vars, vars);
 
    NIR_PASS_V(shader, lower_rt_instructions, &src_vars, false);
@@ -1506,7 +1514,7 @@ radv_build_traversal_shader(struct radv_device *device, 
struct radv_ray_tracing_
    b.shader->info.workgroup_size[0] = 8;
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 
64 ? 8 : 4;
    b.shader->info.shared_size = device->physical_device->rt_wave_size * 
MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
-   struct rt_variables vars = create_rt_variables(b.shader, create_flags);
+   struct rt_variables vars = create_rt_variables(b.shader, device, 
create_flags);
 
    /* initialize trace_ray arguments */
    nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1);
@@ -1674,7 +1682,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const 
VkRayTracingPipelineCreateInfoKH
 
    const VkPipelineCreateFlagBits2KHR create_flags = 
vk_rt_pipeline_create_flags(pCreateInfo);
 
-   struct rt_variables vars = create_rt_variables(shader, create_flags);
+   struct rt_variables vars = create_rt_variables(shader, device, 
create_flags);
 
    if (monolithic)
       lower_rt_instructions_monolithic(shader, device, pipeline, pCreateInfo, 
&vars);
diff --git a/src/amd/vulkan/radv_physical_device.c b/src/amd/vulkan/radv_physical_device.c
index 2e3efeab11e..fe3ec275ef5 100644
--- a/src/amd/vulkan/radv_physical_device.c
+++ b/src/amd/vulkan/radv_physical_device.c
@@ -473,6 +473,7 @@ radv_physical_device_get_supported_extensions(const struct 
radv_physical_device
       .KHR_ray_query = radv_enable_rt(device, false),
       .KHR_ray_tracing_maintenance1 = radv_enable_rt(device, false),
       .KHR_ray_tracing_pipeline = radv_enable_rt(device, true),
+      .KHR_ray_tracing_position_fetch = radv_enable_rt(device, false),
       .KHR_relaxed_block_layout = true,
       .KHR_sampler_mirror_clamp_to_edge = true,
       .KHR_sampler_ycbcr_conversion = true,
@@ -946,6 +947,9 @@ radv_physical_device_get_features(const struct 
radv_physical_device *pdevice, st
       .rayTracingMaintenance1 = true,
       .rayTracingPipelineTraceRaysIndirect2 = radv_enable_rt(pdevice, true),
 
+      /* VK_KHR_ray_tracing_position_fetch */
+      .rayTracingPositionFetch = true,
+
       /* VK_EXT_vertex_input_dynamic_state */
       .vertexInputDynamicState = true,
 
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index a24933f0db0..b2753470291 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -434,6 +434,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const 
struct radv_shader_st
                .ray_cull_mask = true,
                .ray_query = true,
                .ray_tracing = true,
+               .ray_tracing_position_fetch = true,
                .ray_traversal_primitive_culling = true,
                .runtime_descriptor_array = true,
                .shader_clock = true,

Reply via email to