Mesa (main): gallivm/nir: Refactor out some repeated logic for SSBO/shared access.

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Apr 6 00:37:10 UTC 2022


Module: Mesa
Branch: main
Commit: 4fad4c1d790f855bf93d4026371f175fff1d2c12
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=4fad4c1d790f855bf93d4026371f175fff1d2c12

Author: Emma Anholt <emma at anholt.net>
Date:   Thu Feb 10 14:15:01 2022 -0800

gallivm/nir: Refactor out some repeated logic for SSBO/shared access.

I needed to be able to get these pointers/limits from another location,
and missing some of the repeated steps was giving me bugs.

Reviewed-by: Dave Airlie <airlied at redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14999>

---

 src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c | 105 ++++++++++++++-----------
 1 file changed, 57 insertions(+), 48 deletions(-)

diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
index 824b3299e20..0213aac69ec 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
@@ -1065,6 +1065,64 @@ emit_load_const(struct lp_build_nir_context *bld_base,
    memset(&outval[instr->def.num_components], 0, NIR_MAX_VEC_COMPONENTS - instr->def.num_components);
 }
 
+/**
+ * Get the base address of SSBO[@index] for the @invocation channel, returning
+ * the address and also the bounds (in units of the bit_size).
+ *
+ * @bounds may be NULL if the caller does not need the bounds.
+ */
+static LLVMValueRef
+ssbo_base_pointer(struct lp_build_nir_context *bld_base,
+                  unsigned bit_size,
+                  LLVMValueRef index, LLVMValueRef invocation, LLVMValueRef *bounds)
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
+   uint32_t shift_val = bit_size_to_shift_size(bit_size);
+
+   LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, invocation, "");
+   LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, ssbo_idx);
+   LLVMValueRef ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, ssbo_idx);
+   if (bounds)
+      *bounds = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), "");
+
+   return ssbo_ptr;
+}
+
+/**
+ * Get the base address for a @bit_size memory access by the @invocation
+ * channel: SSBO[@index] if @index is non-NULL, otherwise the shared memory
+ * pointer.  For bit sizes other than 32 the pointer is cast to point at
+ * mem_bld->elem_type.
+ *
+ * If @bounds is non-NULL it receives the SSBO bounds in units of @bit_size,
+ * or NULL for shared memory, which has no bounds here.
+ */
+static LLVMValueRef
+mem_access_base_pointer(struct lp_build_nir_context *bld_base,
+                        struct lp_build_context *mem_bld,
+                        unsigned bit_size,
+                        LLVMValueRef index, LLVMValueRef invocation, LLVMValueRef *bounds)
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
+   LLVMValueRef ptr;
+
+   if (index) {
+      ptr = ssbo_base_pointer(bld_base, bit_size, index, invocation, bounds);
+   } else {
+      ptr = bld->shared_ptr;
+      /* Honor a NULL @bounds the same way ssbo_base_pointer() does. */
+      if (bounds)
+         *bounds = NULL;
+   }
+
+   /* Cast it to the pointer type of the access this instruction is doing. */
+   if (bit_size == 32)
+      return ptr;
+   else
+      return LLVMBuildBitCast(gallivm->builder, ptr, LLVMPointerType(mem_bld->elem_type, 0), "");
+}

 static void emit_load_mem(struct lp_build_nir_context *bld_base,
                           unsigned nc,
@@ -1077,7 +1122,6 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base,
    struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
    LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder;
    struct lp_build_context *uint_bld = &bld_base->uint_bld;
-   LLVMValueRef ssbo_limit = NULL;
    struct lp_build_context *load_bld;
    uint32_t shift_val = bit_size_to_shift_size(bit_size);
 
@@ -1101,16 +1145,9 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base,
    struct lp_build_if_state exec_ifthen;
    lp_build_if(&exec_ifthen, gallivm, loop_cond);
 
-   LLVMValueRef mem_ptr;
-
-   if (index) {
-      LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, loop_state.counter, "");
-      LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, ssbo_idx);
-      LLVMValueRef ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, ssbo_idx);
-      ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), "");
-      mem_ptr = ssbo_ptr;
-   } else
-      mem_ptr = bld->shared_ptr;
+   LLVMValueRef ssbo_limit;
+   LLVMValueRef mem_ptr = mem_access_base_pointer(bld_base, load_bld, bit_size, index,
+                                                  loop_state.counter, &ssbo_limit);
 
    for (unsigned c = 0; c < nc; c++) {
       LLVMValueRef loop_index = LLVMBuildAdd(builder, loop_offset, lp_build_const_int32(gallivm, c), "");
@@ -1126,12 +1163,7 @@ static void emit_load_mem(struct lp_build_nir_context *bld_base,
       fetch_cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, do_fetch, lp_build_const_int32(gallivm, 0), "");
 
       lp_build_if(&ifthen, gallivm, fetch_cond);
-      LLVMValueRef scalar;
-      if (bit_size != 32) {
-         LLVMValueRef mem_ptr2 = LLVMBuildBitCast(builder, mem_ptr, LLVMPointerType(load_bld->elem_type, 0), "");
-         scalar = lp_build_pointer_get(builder, mem_ptr2, loop_index);
-      } else
-         scalar = lp_build_pointer_get(builder, mem_ptr, loop_index);
+      LLVMValueRef scalar = lp_build_pointer_get(builder, mem_ptr, loop_index);
 
       temp_res = LLVMBuildLoad(builder, result[c], "");
       temp_res = LLVMBuildInsertElement(builder, temp_res, scalar, loop_state.counter, "");
@@ -1171,9 +1203,7 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base,
    struct gallivm_state *gallivm = bld_base->base.gallivm;
    struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
    LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder;
-   LLVMValueRef mem_ptr;
    struct lp_build_context *uint_bld = &bld_base->uint_bld;
-   LLVMValueRef ssbo_limit = NULL;
    struct lp_build_context *store_bld;
    uint32_t shift_val = bit_size_to_shift_size(bit_size);
    store_bld = get_int_bld(bld_base, true, bit_size);
@@ -1190,14 +1220,9 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base,
    struct lp_build_if_state exec_ifthen;
    lp_build_if(&exec_ifthen, gallivm, loop_cond);
 
-   if (index) {
-      LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, loop_state.counter, "");
-      LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, ssbo_idx);
-      LLVMValueRef ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, ssbo_idx);
-      ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), "");
-      mem_ptr = ssbo_ptr;
-   } else
-      mem_ptr = bld->shared_ptr;
+   LLVMValueRef ssbo_limit;
+   LLVMValueRef mem_ptr = mem_access_base_pointer(bld_base, store_bld, bit_size, index,
+                                                  loop_state.counter, &ssbo_limit);
 
    for (unsigned c = 0; c < nc; c++) {
       if (!(writemask & (1u << c)))
@@ -1219,11 +1244,7 @@ static void emit_store_mem(struct lp_build_nir_context *bld_base,
 
       store_cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, do_store, lp_build_const_int32(gallivm, 0), "");
       lp_build_if(&ifthen, gallivm, store_cond);
-      if (bit_size != 32) {
-         LLVMValueRef mem_ptr2 = LLVMBuildBitCast(builder, mem_ptr, LLVMPointerType(store_bld->elem_type, 0), "");
-         lp_build_pointer_set(builder, mem_ptr2, loop_index, value_ptr);
-      } else
-         lp_build_pointer_set(builder, mem_ptr, loop_index, value_ptr);
+      lp_build_pointer_set(builder, mem_ptr, loop_index, value_ptr);
       lp_build_endif(&ifthen);
    }
 
@@ -1244,7 +1265,6 @@ static void emit_atomic_mem(struct lp_build_nir_context *bld_base,
    struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
    LLVMBuilderRef builder = bld->bld_base.base.gallivm->builder;
    struct lp_build_context *uint_bld = &bld_base->uint_bld;
-   LLVMValueRef ssbo_limit = NULL;
    uint32_t shift_val = bit_size_to_shift_size(bit_size);
    struct lp_build_context *atomic_bld = get_int_bld(bld_base, true, bit_size);
 
@@ -1262,15 +1282,9 @@ static void emit_atomic_mem(struct lp_build_nir_context *bld_base,
    struct lp_build_if_state exec_ifthen;
    lp_build_if(&exec_ifthen, gallivm, loop_cond);
 
-   LLVMValueRef mem_ptr;
-   if (index) {
-      LLVMValueRef ssbo_idx = LLVMBuildExtractElement(gallivm->builder, index, loop_state.counter, "");
-      LLVMValueRef ssbo_size_ptr = lp_build_array_get(gallivm, bld->ssbo_sizes_ptr, ssbo_idx);
-      LLVMValueRef ssbo_ptr = lp_build_array_get(gallivm, bld->ssbo_ptr, ssbo_idx);
-      ssbo_limit = LLVMBuildAShr(gallivm->builder, ssbo_size_ptr, lp_build_const_int32(gallivm, shift_val), "");
-      mem_ptr = ssbo_ptr;
-   } else
-      mem_ptr = bld->shared_ptr;
+   LLVMValueRef ssbo_limit;
+   LLVMValueRef mem_ptr = mem_access_base_pointer(bld_base, atomic_bld, bit_size, index,
+                                                  loop_state.counter, &ssbo_limit);
 
    LLVMValueRef do_fetch = lp_build_const_int32(gallivm, -1);
    if (ssbo_limit) {
@@ -1282,12 +1296,7 @@ static void emit_atomic_mem(struct lp_build_nir_context *bld_base,
                                                     loop_state.counter, "");
    value_ptr = LLVMBuildBitCast(gallivm->builder, value_ptr, atomic_bld->elem_type, "");
 
-   LLVMValueRef scalar_ptr;
-   if (bit_size != 32) {
-      LLVMValueRef mem_ptr2 = LLVMBuildBitCast(builder, mem_ptr, LLVMPointerType(atomic_bld->elem_type, 0), "");
-      scalar_ptr = LLVMBuildGEP(builder, mem_ptr2, &loop_offset, 1, "");
-   } else
-      scalar_ptr = LLVMBuildGEP(builder, mem_ptr, &loop_offset, 1, "");
+   LLVMValueRef scalar_ptr = LLVMBuildGEP(builder, mem_ptr, &loop_offset, 1, "");
 
    struct lp_build_if_state ifthen;
    LLVMValueRef inner_cond, temp_res;



More information about the mesa-commit mailing list