Mesa (master): amd/llvm: Add Subgroup Scan functions for SI

GitLab Mirror gitlab-mirror at kemper.freedesktop.org
Wed Nov 20 20:47:13 UTC 2019


Module: Mesa
Branch: master
Commit: 0cbcfc071e32fd5fc9950a5660adb7dafb7aaef0
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=0cbcfc071e32fd5fc9950a5660adb7dafb7aaef0

Author: Daniel Schürmann <daniel at schuermann.dev>
Date:   Wed Nov 20 12:40:07 2019 +0100

amd/llvm: Add Subgroup Scan functions for SI

The approach used here for GFX6/GFX7 (building the scan from ds_swizzle and
readlane steps instead of DPP) is taken from the ROCm Device Libs:
https://github.com/RadeonOpenCompute/ROCm-Device-Libs/blob/master/ockl/src/wfredscan.cl

Reviewed-by: Samuel Pitoiset <samuel.pitoiset at gmail.com>

---

 src/amd/llvm/ac_llvm_build.c | 81 ++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 75 insertions(+), 6 deletions(-)

diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c
index d418ee5ab71..12ee06c2678 100644
--- a/src/amd/llvm/ac_llvm_build.c
+++ b/src/amd/llvm/ac_llvm_build.c
@@ -4042,8 +4042,6 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
 /**
  * \param maxprefix specifies that the result only needs to be correct for a
  *     prefix of this many threads
- *
- * TODO: add inclusive and excluse scan functions for GFX6.
  */
 static LLVMValueRef
 ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
@@ -4051,13 +4049,84 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
 {
 	LLVMValueRef result, tmp;
 
-	if (ctx->chip_class >= GFX10) {
-		result = inclusive ? src : identity;
+	if (inclusive) {
+		result = src;
+	} else if (ctx->chip_class >= GFX10) {
+		result = identity;
+	} else if (ctx->chip_class >= GFX8) {
+		src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
+		result = src;
 	} else {
-		if (!inclusive)
-			src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
+		/* wavefront shift_right by 1 on SI/CI */
+		LLVMValueRef active, tmp1, tmp2;
+		LLVMValueRef tid = ac_get_thread_id(ctx);
+		tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
+		tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+				       LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
+				       LLVMConstInt(ctx->i32, 0x4, 0), "");
+		tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+		tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+				       LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
+				       LLVMConstInt(ctx->i32, 0x8, 0), "");
+		tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+		tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+				       LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
+				       LLVMConstInt(ctx->i32, 0x10, 0), "");
+		tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+		tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), "");
+		tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+		active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), "");
+		src = LLVMBuildSelect(ctx->builder, active, identity, tmp1, "");
 		result = src;
+        }
+
+	if (ctx->chip_class <= GFX7) {
+		assert(maxprefix == 64);
+		LLVMValueRef tid = ac_get_thread_id(ctx);
+		LLVMValueRef active;
+		tmp = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x1e, 0x00, 0x00));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+				       LLVMBuildAnd(ctx->builder, tid, ctx->i32_1, ""),
+				       ctx->i32_0, "");
+		tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+		result = ac_build_alu_op(ctx, result, tmp, op);
+		tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x1c, 0x01, 0x00));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+				       LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 2, 0), ""),
+				       ctx->i32_0, "");
+		tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+		result = ac_build_alu_op(ctx, result, tmp, op);
+		tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x18, 0x03, 0x00));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+				       LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 4, 0), ""),
+				       ctx->i32_0, "");
+		tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+		result = ac_build_alu_op(ctx, result, tmp, op);
+		tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x10, 0x07, 0x00));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+				       LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 8, 0), ""),
+				       ctx->i32_0, "");
+		tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+		result = ac_build_alu_op(ctx, result, tmp, op);
+		tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x00, 0x0f, 0x00));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+				       LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 16, 0), ""),
+				       ctx->i32_0, "");
+		tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+		result = ac_build_alu_op(ctx, result, tmp, op);
+		tmp = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, 0));
+		active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+				       LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 32, 0), ""),
+				       ctx->i32_0, "");
+		tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+		result = ac_build_alu_op(ctx, result, tmp, op);
+		return result;
 	}
+
 	if (maxprefix <= 1)
 		return result;
 	tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);




More information about the mesa-commit mailing list