[Mesa-dev] [PATCH 06/25] amd/common: scan/reduce across waves of a workgroup

Haehnle, Nicolai Nicolai.Haehnle at amd.com
Fri Dec 7 14:32:24 UTC 2018


On 06.12.18 15:20, Connor Abbott wrote:
> Is this going to be used by an extension? If you don't have a use for
> it yet, it would probably be better to wait.

Well, I have been using it quite extensively in a branch I've been 
working on, but that's not quite ready yet.
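
In case it helps review in the meantime, the struct is meant to be
driven roughly like this (untested sketch; "value", "lds", "wave_id"
and "num_waves" stand for whatever the backend already has at hand):

    struct ac_wg_scan ws = {};
    ws.op = nir_op_iadd;
    ws.src = value;              /* per-thread input, clobbered */
    ws.enable_inclusive = true;
    ws.scratch = lds;            /* LDS array, >= maxwaves entries */
    ws.waveidx = wave_id;        /* index of this wave in the group */
    ws.numwaves = num_waves;     /* only needed for enable_reduce */
    ws.maxwaves = 16;            /* compile-time upper bound */

    ac_build_wg_scan(ctx, &ws);  /* implies an s_barrier */
    /* ws.result_inclusive now holds the ordered prefix sum */

The _top/_bottom variants exist so that a caller can fold the barrier
into one it needs anyway for other LDS traffic:

    ac_build_wg_scan_top(ctx, &ws);
    /* ... other LDS stores ... */
    ac_build_s_barrier(ctx);
    /* ... other LDS loads ... */
    ac_build_wg_scan_bottom(ctx, &ws);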

Cheers,
Nicolai


> On Thu, Dec 6, 2018 at 3:01 PM Nicolai Hähnle <nhaehnle at gmail.com> wrote:
>>
>> From: Nicolai Hähnle <nicolai.haehnle at amd.com>
>>
>> Order-aware scan/reduce can trade off LDS traffic for external atomic
>> memory traffic in producer/consumer compute shaders.
>> ---
>>   src/amd/common/ac_llvm_build.c | 195 ++++++++++++++++++++++++++++++++-
>>   src/amd/common/ac_llvm_build.h |  36 ++++++
>>   2 files changed, 227 insertions(+), 4 deletions(-)
>>
>> diff --git a/src/amd/common/ac_llvm_build.c b/src/amd/common/ac_llvm_build.c
>> index 68c8bad9e83..932f4bbdeef 100644
>> --- a/src/amd/common/ac_llvm_build.c
>> +++ b/src/amd/common/ac_llvm_build.c
>> @@ -3345,68 +3345,88 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
>>                                          _64bit ? ctx->f64 : ctx->f32,
>>                                          (LLVMValueRef[]){lhs, rhs}, 2, AC_FUNC_ATTR_READNONE);
>>          case nir_op_iand: return LLVMBuildAnd(ctx->builder, lhs, rhs, "");
>>          case nir_op_ior: return LLVMBuildOr(ctx->builder, lhs, rhs, "");
>>          case nir_op_ixor: return LLVMBuildXor(ctx->builder, lhs, rhs, "");
>>          default:
>>                  unreachable("bad reduction intrinsic");
>>          }
>>   }
>>
>> -/* TODO: add inclusive and excluse scan functions for SI chip class.  */
>> +/**
>> + * \param maxprefix specifies that the result only needs to be correct for a
>> + *     prefix of this many threads
>> + *
>> + * TODO: add inclusive and exclusive scan functions for SI chip class.
>> + */
>>   static LLVMValueRef
>> -ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity)
>> +ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
>> +             unsigned maxprefix)
>>   {
>>          LLVMValueRef result, tmp;
>>          result = src;
>> +       if (maxprefix <= 1)
>> +               return result;
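>> +       /* Steps 1-3 shift the original src by 1..3 lanes within each row of
>> +        * 16 and accumulate, so lanes 0..3 of each row end up with their
>> +        * complete prefix. */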
>>          tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);
>>          result = ac_build_alu_op(ctx, result, tmp, op);
>> +       if (maxprefix <= 2)
>> +               return result;
>>          tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(2), 0xf, 0xf, false);
>>          result = ac_build_alu_op(ctx, result, tmp, op);
>> +       if (maxprefix <= 3)
>> +               return result;
>>          tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(3), 0xf, 0xf, false);
>>          result = ac_build_alu_op(ctx, result, tmp, op);
>> +       if (maxprefix <= 4)
>> +               return result;
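>> +       /* From here on the shifted input is the partial result itself, so
>> +        * each step doubles the length of the correct prefix (4 -> 8 -> 16
>> +        * lanes per row); the bank masks (0xe, 0xc) leave the low lanes,
>> +        * whose prefix is already complete, untouched. */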
>>          tmp = ac_build_dpp(ctx, identity, result, dpp_row_sr(4), 0xf, 0xe, false);
>>          result = ac_build_alu_op(ctx, result, tmp, op);
>> +       if (maxprefix <= 8)
>> +               return result;
>>          tmp = ac_build_dpp(ctx, identity, result, dpp_row_sr(8), 0xf, 0xc, false);
>>          result = ac_build_alu_op(ctx, result, tmp, op);
>> +       if (maxprefix <= 16)
>> +               return result;
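>> +       /* row_shr cannot cross a 16-lane row boundary, so the last steps
>> +        * broadcast lane 15 of a row (resp. lane 31 of a half) into the
>> +        * following row (resp. upper half); the row masks 0xa and 0xc
>> +        * restrict the update to the receiving rows. */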
>>          tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false);
>>          result = ac_build_alu_op(ctx, result, tmp, op);
>> +       if (maxprefix <= 32)
>> +               return result;
>>          tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
>>          result = ac_build_alu_op(ctx, result, tmp, op);
>>          return result;
>>   }
>>
>>   LLVMValueRef
>>   ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op)
>>   {
>>          ac_build_optimization_barrier(ctx, &src);
>>          LLVMValueRef result;
>>          LLVMValueRef identity =
>>                  get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
>>          result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
>>                                    LLVMTypeOf(identity), "");
>> -       result = ac_build_scan(ctx, op, result, identity);
>> +       result = ac_build_scan(ctx, op, result, identity, 64);
>>
>>          return ac_build_wwm(ctx, result);
>>   }
>>
>>   LLVMValueRef
>>   ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op)
>>   {
>>          ac_build_optimization_barrier(ctx, &src);
>>          LLVMValueRef result;
>>          LLVMValueRef identity =
>>                  get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
>>          result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
>>                                    LLVMTypeOf(identity), "");
>>          result = ac_build_dpp(ctx, identity, result, dpp_wf_sr1, 0xf, 0xf, false);
>> -       result = ac_build_scan(ctx, op, result, identity);
>> +       result = ac_build_scan(ctx, op, result, identity, 64);
>>
>>          return ac_build_wwm(ctx, result);
>>   }
>>
>>   LLVMValueRef
>>   ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsigned cluster_size)
>>   {
>>          if (cluster_size == 1) return src;
>>          ac_build_optimization_barrier(ctx, &src);
>>          LLVMValueRef result, swap;
>> @@ -3450,20 +3470,187 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
>>                  result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 63, 0));
>>                  return ac_build_wwm(ctx, result);
>>          } else {
>>                  swap = ac_build_readlane(ctx, result, ctx->i32_0);
>>                  result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 32, 0));
>>                  result = ac_build_alu_op(ctx, result, swap, op);
>>                  return ac_build_wwm(ctx, result);
>>          }
>>   }
>>
>> +/**
>> + * "Top half" of a scan that reduces per-wave values across an entire
>> + * workgroup.
>> + *
>> + * The source value must be present in the highest lane of the wave, and the
>> + * highest lane must be live.
>> + */
>> +void
>> +ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
>> +{
>> +       if (ws->maxwaves <= 1)
>> +               return;
>> +
>> +       const LLVMValueRef i32_63 = LLVMConstInt(ctx->i32, 63, false);
>> +       LLVMBuilderRef builder = ctx->builder;
>> +       LLVMValueRef tid = ac_get_thread_id(ctx);
>> +       LLVMValueRef tmp;
>> +
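>> +       /* Only lane 63 (wave64 is assumed throughout) publishes the wave's
>> +        * value to the LDS slot for this wave. */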
>> +       tmp = LLVMBuildICmp(builder, LLVMIntEQ, tid, i32_63, "");
>> +       ac_build_ifcc(ctx, tmp, 1000);
>> +       LLVMBuildStore(builder, ws->src, LLVMBuildGEP(builder, ws->scratch, &ws->waveidx, 1, ""));
>> +       ac_build_endif(ctx, 1000);
>> +}
>> +
>> +/**
>> + * "Bottom half" of a scan that reduces per-wave values across an entire
>> + * workgroup.
>> + *
>> + * The caller must place a barrier between the top and bottom halves.
>> + */
>> +void
>> +ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
>> +{
>> +       const LLVMTypeRef type = LLVMTypeOf(ws->src);
>> +       const LLVMValueRef identity =
>> +               get_reduction_identity(ctx, ws->op, ac_get_type_size(type));
>> +
>> +       if (ws->maxwaves <= 1) {
>> +               ws->result_reduce = ws->src;
>> +               ws->result_inclusive = ws->src;
>> +               ws->result_exclusive = identity;
>> +               return;
>> +       }
>> +       assert(ws->maxwaves <= 32);
>> +
>> +       LLVMBuilderRef builder = ctx->builder;
>> +       LLVMValueRef tid = ac_get_thread_id(ctx);
>> +       LLVMBasicBlockRef bbs[2];
>> +       LLVMValueRef phivalues_scan[2];
>> +       LLVMValueRef tmp, tmp2;
>> +
>> +       bbs[0] = LLVMGetInsertBlock(builder);
>> +       phivalues_scan[0] = LLVMGetUndef(type);
>> +
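>> +       /* One lane per wave re-loads that wave's value from LDS; how many
>> +        * lanes must participate depends on which results were requested. */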
>> +       if (ws->enable_reduce)
>> +               tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, ws->numwaves, "");
>> +       else if (ws->enable_inclusive)
>> +               tmp = LLVMBuildICmp(builder, LLVMIntULE, tid, ws->waveidx, "");
>> +       else
>> +               tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, ws->waveidx, "");
>> +       ac_build_ifcc(ctx, tmp, 1001);
>> +       {
>> +               tmp = LLVMBuildLoad(builder, LLVMBuildGEP(builder, ws->scratch, &tid, 1, ""), "");
>> +
>> +               ac_build_optimization_barrier(ctx, &tmp);
>> +
>> +               bbs[1] = LLVMGetInsertBlock(builder);
>> +               phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves);
>> +       }
>> +       ac_build_endif(ctx, 1001);
>> +
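>> +       /* Lanes that skipped the scan contribute undef to the phi; the
>> +        * readlanes below only read lanes that participated (or, for wave 0,
>> +        * the result is discarded by the select). */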
>> +       const LLVMValueRef scan = ac_build_phi(ctx, type, 2, phivalues_scan, bbs);
>> +
>> +       if (ws->enable_reduce) {
>> +               tmp = LLVMBuildSub(builder, ws->numwaves, ctx->i32_1, "");
>> +               ws->result_reduce = ac_build_readlane(ctx, scan, tmp);
>> +       }
>> +       if (ws->enable_inclusive)
>> +               ws->result_inclusive = ac_build_readlane(ctx, scan, ws->waveidx);
>> +       if (ws->enable_exclusive) {
>> +               tmp = LLVMBuildSub(builder, ws->waveidx, ctx->i32_1, "");
>> +               tmp = ac_build_readlane(ctx, scan, tmp);
>> +               tmp2 = LLVMBuildICmp(builder, LLVMIntEQ, ws->waveidx, ctx->i32_0, "");
>> +               ws->result_exclusive = LLVMBuildSelect(builder, tmp2, identity, tmp, "");
>> +       }
>> +}
>> +
>> +/**
>> + * Inclusive scan of a per-wave value across an entire workgroup.
>> + *
>> + * This implies an s_barrier instruction.
>> + *
>> + * Unlike ac_build_inclusive_scan, the caller \em must ensure that all threads
>> + * of the workgroup are live. (This requirement cannot easily be relaxed in a
>> + * useful manner because of the barrier in the algorithm.)
>> + */
>> +void
>> +ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
>> +{
>> +       ac_build_wg_wavescan_top(ctx, ws);
>> +       ac_build_s_barrier(ctx);
>> +       ac_build_wg_wavescan_bottom(ctx, ws);
>> +}
>> +
>> +/**
>> + * "Top half" of a scan that reduces per-thread values across an entire
>> + * workgroup.
>> + *
>> + * All lanes must be active when this code runs.
>> + */
>> +void
>> +ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
>> +{
>> +       if (ws->enable_exclusive) {
>> +               ws->extra = ac_build_exclusive_scan(ctx, ws->src, ws->op);
>> +               ws->src = ac_build_alu_op(ctx, ws->extra, ws->src, ws->op);
>> +       } else {
>> +               ws->src = ac_build_inclusive_scan(ctx, ws->src, ws->op);
>> +       }
>> +
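>> +       /* Both per-thread results are reconstructed from the *exclusive*
>> +        * scan of the per-wave totals, so that is the only wavescan result
>> +        * we need here. */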
>> +       bool enable_inclusive = ws->enable_inclusive;
>> +       bool enable_exclusive = ws->enable_exclusive;
>> +       ws->enable_inclusive = false;
>> +       ws->enable_exclusive = ws->enable_exclusive || enable_inclusive;
>> +       ac_build_wg_wavescan_top(ctx, ws);
>> +       ws->enable_inclusive = enable_inclusive;
>> +       ws->enable_exclusive = enable_exclusive;
>> +}
>> +
>> +/**
>> + * "Bottom half" of a scan that reduces per-thread values across an entire
>> + * workgroup.
>> + *
>> + * The caller must place a barrier between the top and bottom halves.
>> + */
>> +void
>> +ac_build_wg_scan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
>> +{
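>> +       /* Same flag juggling as in ac_build_wg_scan_top. */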
>> +       bool enable_inclusive = ws->enable_inclusive;
>> +       bool enable_exclusive = ws->enable_exclusive;
>> +       ws->enable_inclusive = false;
>> +       ws->enable_exclusive = ws->enable_exclusive || enable_inclusive;
>> +       ac_build_wg_wavescan_bottom(ctx, ws);
>> +       ws->enable_inclusive = enable_inclusive;
>> +       ws->enable_exclusive = enable_exclusive;
>> +
>> +       /* ws->result_reduce is already the correct value */
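>> +       /* inclusive = (exclusive scan of the wave totals) op (this thread's
>> +        * in-wave inclusive scan, ws->src); exclusive likewise combines the
>> +        * in-wave exclusive scan saved in ws->extra. */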
>> +       if (ws->enable_inclusive)
>> +               ws->result_inclusive = ac_build_alu_op(ctx, ws->result_exclusive, ws->src, ws->op);
>> +       if (ws->enable_exclusive)
>> +               ws->result_exclusive = ac_build_alu_op(ctx, ws->result_exclusive, ws->extra, ws->op);
>> +}
>> +
>> +/**
>> + * A scan that reduces per-thread values across an entire workgroup.
>> + *
>> + * The caller must ensure that all lanes are active when this code runs
>> + * (WWM is insufficient!), because there is an implied barrier.
>> + */
>> +void
>> +ac_build_wg_scan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
>> +{
>> +       ac_build_wg_scan_top(ctx, ws);
>> +       ac_build_s_barrier(ctx);
>> +       ac_build_wg_scan_bottom(ctx, ws);
>> +}
>> +
>>   LLVMValueRef
>>   ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src,
>>                  unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3)
>>   {
>>          unsigned mask = dpp_quad_perm(lane0, lane1, lane2, lane3);
>>          if (ctx->chip_class >= VI) {
>>                  return ac_build_dpp(ctx, src, src, mask, 0xf, 0xf, false);
>>          } else {
>>                  return ac_build_ds_swizzle(ctx, src, (1 << 15) | mask);
>>          }
>> diff --git a/src/amd/common/ac_llvm_build.h b/src/amd/common/ac_llvm_build.h
>> index cf3e3cedf65..cad131768d2 100644
>> --- a/src/amd/common/ac_llvm_build.h
>> +++ b/src/amd/common/ac_llvm_build.h
>> @@ -519,20 +519,56 @@ ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask);
>>
>>   LLVMValueRef
>>   ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op);
>>
>>   LLVMValueRef
>>   ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op);
>>
>>   LLVMValueRef
>>   ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsigned cluster_size);
>>
>> +/**
>> + * Common arguments for a scan/reduce operation that accumulates per-wave
>> + * values across an entire workgroup, while respecting the order of waves.
>> + */
>> +struct ac_wg_scan {
>> +       bool enable_reduce;
>> +       bool enable_exclusive;
>> +       bool enable_inclusive;
>> +       nir_op op;
>> +       LLVMValueRef src; /* clobbered! */
>> +       LLVMValueRef result_reduce;
>> +       LLVMValueRef result_exclusive;
>> +       LLVMValueRef result_inclusive;
>> +       LLVMValueRef extra; /* internal: carried from the thread-scan top half to the bottom half */
>> +       LLVMValueRef waveidx; /* index of this wave within the workgroup */
>> +       LLVMValueRef numwaves; /* only needed for "reduce" operations */
>> +
>> +       /* addrspace(LDS) pointer to the same type as src, with at least maxwaves entries */
>> +       LLVMValueRef scratch;
>> +       unsigned maxwaves;
>> +};
>> +
>> +void
>> +ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
>> +void
>> +ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
>> +void
>> +ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
>> +
>> +void
>> +ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
>> +void
>> +ac_build_wg_scan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
>> +void
>> +ac_build_wg_scan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
>> +
>>   LLVMValueRef
>>   ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src,
>>                  unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3);
>>
>>   LLVMValueRef
>>   ac_build_shuffle(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef index);
>>
>>   #ifdef __cplusplus
>>   }
>>   #endif
>> --
>> 2.19.1
>>
>> _______________________________________________
>> mesa-dev mailing list
>> mesa-dev at lists.freedesktop.org
>> https://lists.freedesktop.org/mailman/listinfo/mesa-dev

