Module: Mesa Branch: main Commit: 4fad4c1d790f855bf93d4026371f175fff1d2c12 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=4fad4c1d790f855bf93d4026371f175fff1d2c12
Author: Emma Anholt <[email protected]> Date: Thu Feb 10 14:15:01 2022 -0800 gallivm/nir: Refactor out some repeated logic for SSBO/shared access. I needed to be able to get these pointers/limits from another location, and missing some of the repeated steps was giving me bugs. Reviewed-by: Dave Airlie <[email protected]> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14999> --- src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c | 105 ++++++++++++++----------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c index 824b3299e20..0213aac69ec 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c @@ -1065,6 +1065,51 @@ emit_load_const(struct lp_build_nir_context *bld_base, memset(&outval[instr->def.num_components], 0, NIR_MAX_VEC_COMPONENTS - instr->def.num_components); } +/** + * Get the base address of SSBO[@index] for the @invocation channel, returning + * the address and also the bounds (in units of the bit_size). 
+ */ +static LLVMValueRef +ssbo_base_pointer(struct lp_build_nir_context *bld_base, + unsigned bit_size, + LLVMValueRef index, LLVMValueRef invocation, LLVMValueRef *bounds) +{ + struct gallivm_state *gallivm = bld_base->base.gallivm; + struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; + uint32_t shift_val = bit_size_to_shift_size(bit_size); + + LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, invocation, ""); + LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, ssbo_idx); + LLVMValueRef ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, ssbo_idx); + if (bounds) + *bounds = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), ""); + + return ssbo_ptr; +} + +static LLVMValueRef +mem_access_base_pointer(struct lp_build_nir_context *bld_base, + struct lp_build_context *mem_bld, + unsigned bit_size, + LLVMValueRef index, LLVMValueRef invocation, LLVMValueRef *bounds) +{ + struct gallivm_state *gallivm = bld_base->base.gallivm; + struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; + LLVMValueRef ptr; + + if (index) { + ptr = ssbo_base_pointer(bld_base, bit_size, index, invocation, bounds); + } else { + ptr = bld->shared_ptr; + *bounds = NULL; + } + + /* Cast it to the pointer type of the access this instruction is doing.
*/ + if (bit_size == 32) + return ptr; + else + return LLVMBuildBitCast(gallivm->builder, ptr, LLVMPointerType(mem_bld->elem_type, 0), ""); +} static void emit_load_mem(struct lp_build_nir_context *bld_base, unsigned nc, @@ -1077,7 +1122,6 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base, struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder; struct lp_build_context *uint_bld = &bld_base->uint_bld; - LLVMValueRef ssbo_limit = NULL; struct lp_build_context *load_bld; uint32_t shift_val = bit_size_to_shift_size(bit_size); @@ -1101,16 +1145,9 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base, struct lp_build_if_state exec_ifthen; lp_build_if(&exec_ifthen, gallivm, loop_cond); - LLVMValueRef mem_ptr; - - if (index) { - LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, loop_state.counter, ""); - LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, ssbo_idx); - LLVMValueRef ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, ssbo_idx); - ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), ""); - mem_ptr = ssbo_ptr; - } else - mem_ptr = bld->shared_ptr; + LLVMValueRef ssbo_limit; + LLVMValueRef mem_ptr = mem_access_base_pointer(bld_base, load_bld, bit_size, index, + loop_state.counter, &ssbo_limit); for (unsigned c = 0; c < nc; c++) { LLVMValueRef loop_index = LLVMBuildAdd(builder, loop_offset, lp_build_const_int32(gallivm, c), ""); @@ -1126,12 +1163,7 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base, fetch_cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, do_fetch, lp_build_const_int32(gallivm, 0), ""); lp_build_if(&ifthen, gallivm, fetch_cond); - LLVMValueRef scalar; - if (bit_size != 32) { - LLVMValueRef mem_ptr2 = LLVMBuildBitCast(builder, mem_ptr, LLVMPointerType(load_bld->elem_type, 0), ""); - scalar = 
lp_build_pointer_get(builder, mem_ptr2, loop_index); - } else - scalar = lp_build_pointer_get(builder, mem_ptr, loop_index); + LLVMValueRef scalar = lp_build_pointer_get(builder, mem_ptr, loop_index); temp_res = LLVMBuildLoad(builder, result[c], ""); temp_res = LLVMBuildInsertElement(builder, temp_res, scalar, loop_state.counter, ""); @@ -1171,9 +1203,7 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base, struct gallivm_state *gallivm = bld_base->base.gallivm; struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder; - LLVMValueRef mem_ptr; struct lp_build_context *uint_bld = &bld_base->uint_bld; - LLVMValueRef ssbo_limit = NULL; struct lp_build_context *store_bld; uint32_t shift_val = bit_size_to_shift_size(bit_size); store_bld = get_int_bld(bld_base, true, bit_size); @@ -1190,14 +1220,9 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base, struct lp_build_if_state exec_ifthen; lp_build_if(&exec_ifthen, gallivm, loop_cond); - if (index) { - LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, loop_state.counter, ""); - LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, ssbo_idx); - LLVMValueRef ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, ssbo_idx); - ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), ""); - mem_ptr = ssbo_ptr; - } else - mem_ptr = bld->shared_ptr; + LLVMValueRef ssbo_limit; + LLVMValueRef mem_ptr = mem_access_base_pointer(bld_base, store_bld, bit_size, index, + loop_state.counter, &ssbo_limit); for (unsigned c = 0; c < nc; c++) { if (!(writemask & (1u << c))) @@ -1219,11 +1244,7 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base, store_cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, do_store, lp_build_const_int32(gallivm, 0), ""); lp_build_if(&ifthen, gallivm, store_cond); - if (bit_size != 32) { - 
LLVMValueRef mem_ptr2 = LLVMBuildBitCast(builder, mem_ptr, LLVMPointerType(store_bld->elem_type, 0), ""); - lp_build_pointer_set(builder, mem_ptr2, loop_index, value_ptr); - } else - lp_build_pointer_set(builder, mem_ptr, loop_index, value_ptr); + lp_build_pointer_set(builder, mem_ptr, loop_index, value_ptr); lp_build_endif(&ifthen); } @@ -1244,7 +1265,6 @@ static void emit_atomic_mem(struct lp_build_nir_context *bld_base, struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base; LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder; struct lp_build_context *uint_bld = &bld_base->uint_bld; - LLVMValueRef ssbo_limit = NULL; uint32_t shift_val = bit_size_to_shift_size(bit_size); struct lp_build_context *atomic_bld = get_int_bld(bld_base, true, bit_size); @@ -1262,15 +1282,9 @@ static void emit_atomic_mem(struct lp_build_nir_context *bld_base, struct lp_build_if_state exec_ifthen; lp_build_if(&exec_ifthen, gallivm, loop_cond); - LLVMValueRef mem_ptr; - if (index) { - LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, loop_state.counter, ""); - LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, ssbo_idx); - LLVMValueRef ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, ssbo_idx); - ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), ""); - mem_ptr = ssbo_ptr; - } else - mem_ptr = bld->shared_ptr; + LLVMValueRef ssbo_limit; + LLVMValueRef mem_ptr = mem_access_base_pointer(bld_base, atomic_bld, bit_size, index, + loop_state.counter, &ssbo_limit); LLVMValueRef do_fetch = lp_build_const_int32(gallivm, -1); if (ssbo_limit) { @@ -1282,12 +1296,7 @@ static void emit_atomic_mem(struct lp_build_nir_context *bld_base, loop_state.counter, ""); value_ptr = LLVMBuildBitCast(gallivm->builder, value_ptr, atomic_bld->elem_type, ""); - LLVMValueRef scalar_ptr; - if (bit_size != 32) { - LLVMValueRef mem_ptr2 = LLVMBuildBitCast(builder, mem_ptr, 
LLVMPointerType(atomic_bld->elem_type, 0), ""); - scalar_ptr = LLVMBuildGEP(builder, mem_ptr2, &loop_offset, 1, ""); - } else - scalar_ptr = LLVMBuildGEP(builder, mem_ptr, &loop_offset, 1, ""); + LLVMValueRef scalar_ptr = LLVMBuildGEP(builder, mem_ptr, &loop_offset, 1, ""); struct lp_build_if_state ifthen; LLVMValueRef inner_cond, temp_res;
