[Beignet] [PATCH] Add support for vector types in printf.
Zhigang Gong
zhigang.gong at linux.intel.com
Tue Jun 24 17:12:58 PDT 2014
LGTM, pushed, thanks.
On Tue, Jun 24, 2014 at 04:35:58PM +0800, junyan.he at inbox.com wrote:
> From: Junyan He <junyan.he at linux.intel.com>
>
> Signed-off-by: Junyan He <junyan.he at linux.intel.com>
> ---
> backend/src/ir/printf.cpp | 144 +++++++++++++++++---------------
> backend/src/ir/printf.hpp | 11 ++-
> backend/src/llvm/llvm_printf_parser.cpp | 83 ++++++++++++++++--
> kernels/test_printf.cl | 10 ++-
> 4 files changed, 167 insertions(+), 81 deletions(-)
>
> diff --git a/backend/src/ir/printf.cpp b/backend/src/ir/printf.cpp
> index 58711e2..68b2ce4 100644
> --- a/backend/src/ir/printf.cpp
> +++ b/backend/src/ir/printf.cpp
> @@ -84,8 +84,6 @@ namespace gbe
> str += num_str;
> }
>
> - // TODO: Handle the vector here.
> -
> switch (state.length_modifier) {
> case PRINTF_LM_HH:
> str += "hh";
> @@ -97,7 +95,7 @@ namespace gbe
> str += "l";
> break;
> case PRINTF_LM_HL:
> - str += "hl";
> + str += "";
> break;
> default:
> assert(state.length_modifier == PRINTF_LM_NONE);
> @@ -105,12 +103,12 @@ namespace gbe
> }
>
> #define PRINT_SOMETHING(target_ty, conv) do { \
> - pf_str = pf_str + std::string(#conv); \
> + if (!vec_i) \
> + pf_str = pf_str + std::string(#conv); \
> printf(pf_str.c_str(), \
> ((target_ty *)((char *)buf_addr + slot.state->out_buf_sizeof_offset * \
> global_wk_sz0 * global_wk_sz1 * global_wk_sz2)) \
> - [k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i]); \
> - pf_str = ""; \
> + [(k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i) * vec_num + vec_i]);\
> } while (0)
>
>
> @@ -126,80 +124,88 @@ namespace gbe
> for (i = 0; i < global_wk_sz0; i++) {
> for (j = 0; j < global_wk_sz1; j++) {
> for (k = 0; k < global_wk_sz2; k++) {
> - int flag = ((int *)index_addr)[stmt*global_wk_sz0*global_wk_sz1*global_wk_sz2 + k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i];
> +
> + int flag = ((int *)index_addr)[stmt*global_wk_sz0*global_wk_sz1*global_wk_sz2
> + + k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i];
> if (flag) {
> - pf_str = "";
> for (auto &slot : pf) {
> + pf_str = "";
> + int vec_num;
> +
> if (slot.type == PRINTF_SLOT_TYPE_STRING) {
> - pf_str = pf_str + std::string(slot.str);
> + printf("%s", slot.str);
> continue;
> }
> assert(slot.type == PRINTF_SLOT_TYPE_STATE);
>
> generatePrintfFmtString(*slot.state, pf_str);
>
> - switch (slot.state->conversion_specifier) {
> - case PRINTF_CONVERSION_D:
> - case PRINTF_CONVERSION_I:
> - PRINT_SOMETHING(int, d);
> - break;
> -
> - case PRINTF_CONVERSION_O:
> - PRINT_SOMETHING(int, o);
> - break;
> - case PRINTF_CONVERSION_U:
> - PRINT_SOMETHING(int, u);
> - break;
> - case PRINTF_CONVERSION_X:
> - PRINT_SOMETHING(int, X);
> - break;
> - case PRINTF_CONVERSION_x:
> - PRINT_SOMETHING(int, x);
> - break;
> -
> - case PRINTF_CONVERSION_C:
> - PRINT_SOMETHING(char, c);
> - break;
> -
> - case PRINTF_CONVERSION_F:
> - PRINT_SOMETHING(float, F);
> - break;
> - case PRINTF_CONVERSION_f:
> - PRINT_SOMETHING(float, f);
> - break;
> - case PRINTF_CONVERSION_E:
> - PRINT_SOMETHING(float, E);
> - break;
> - case PRINTF_CONVERSION_e:
> - PRINT_SOMETHING(float, e);
> - break;
> - case PRINTF_CONVERSION_G:
> - PRINT_SOMETHING(float, G);
> - break;
> - case PRINTF_CONVERSION_g:
> - PRINT_SOMETHING(float, g);
> - break;
> - case PRINTF_CONVERSION_A:
> - PRINT_SOMETHING(float, A);
> - break;
> - case PRINTF_CONVERSION_a:
> - PRINT_SOMETHING(float, a);
> - break;
> -
> - case PRINTF_CONVERSION_S:
> - pf_str = pf_str + "s";
> - printf(pf_str.c_str(), slot.state->str.c_str());
> - pf_str = "";
> - break;
> -
> - default:
> - assert(0);
> - return;
> + vec_num = slot.state->vector_n > 0 ? slot.state->vector_n : 1;
> +
> + for (int vec_i = 0; vec_i < vec_num; vec_i++) {
> + if (vec_i)
> + printf(",");
> +
> + switch (slot.state->conversion_specifier) {
> + case PRINTF_CONVERSION_D:
> + case PRINTF_CONVERSION_I:
> + PRINT_SOMETHING(int, d);
> + break;
> +
> + case PRINTF_CONVERSION_O:
> + PRINT_SOMETHING(int, o);
> + break;
> + case PRINTF_CONVERSION_U:
> + PRINT_SOMETHING(int, u);
> + break;
> + case PRINTF_CONVERSION_X:
> + PRINT_SOMETHING(int, X);
> + break;
> + case PRINTF_CONVERSION_x:
> + PRINT_SOMETHING(int, x);
> + break;
> +
> + case PRINTF_CONVERSION_C:
> + PRINT_SOMETHING(char, c);
> + break;
> +
> + case PRINTF_CONVERSION_F:
> + PRINT_SOMETHING(float, F);
> + break;
> + case PRINTF_CONVERSION_f:
> + PRINT_SOMETHING(float, f);
> + break;
> + case PRINTF_CONVERSION_E:
> + PRINT_SOMETHING(float, E);
> + break;
> + case PRINTF_CONVERSION_e:
> + PRINT_SOMETHING(float, e);
> + break;
> + case PRINTF_CONVERSION_G:
> + PRINT_SOMETHING(float, G);
> + break;
> + case PRINTF_CONVERSION_g:
> + PRINT_SOMETHING(float, g);
> + break;
> + case PRINTF_CONVERSION_A:
> + PRINT_SOMETHING(float, A);
> + break;
> + case PRINTF_CONVERSION_a:
> + PRINT_SOMETHING(float, a);
> + break;
> +
> + case PRINTF_CONVERSION_S:
> + pf_str = pf_str + "s";
> + printf(pf_str.c_str(), slot.state->str.c_str());
> + break;
> +
> + default:
> + assert(0);
> + return;
> + }
> }
> - }
>
> - if (pf_str != "") {
> - printf("%s", pf_str.c_str());
> + pf_str = "";
> }
> }
> }
> diff --git a/backend/src/ir/printf.hpp b/backend/src/ir/printf.hpp
> index 8b759d4..680b8e6 100644
> --- a/backend/src/ir/printf.hpp
> +++ b/backend/src/ir/printf.hpp
> @@ -182,6 +182,13 @@ namespace gbe
>
> uint32_t getPrintfBufferElementSize(uint32_t i) {
> PrintfSlot* slot = slots[i];
> + int vec_num = 1;
> + if (slot->state->vector_n > 0) {
> + vec_num = slot->state->vector_n;
> + }
> +
> + assert(vec_num > 0 && vec_num <= 16);
> +
> switch (slot->state->conversion_specifier) {
> case PRINTF_CONVERSION_I:
> case PRINTF_CONVERSION_D:
> @@ -191,7 +198,7 @@ namespace gbe
> case PRINTF_CONVERSION_x:
> /* Char will be aligned to sizeof(int) here. */
> case PRINTF_CONVERSION_C:
> - return (uint32_t)sizeof(int);
> + return (uint32_t)(sizeof(int) * vec_num);
> case PRINTF_CONVERSION_E:
> case PRINTF_CONVERSION_e:
> case PRINTF_CONVERSION_F:
> @@ -200,7 +207,7 @@ namespace gbe
> case PRINTF_CONVERSION_g:
> case PRINTF_CONVERSION_A:
> case PRINTF_CONVERSION_a:
> - return (uint32_t)sizeof(float);
> + return (uint32_t)(sizeof(float) * vec_num);
> case PRINTF_CONVERSION_S:
> return (uint32_t)0;
> default:
> diff --git a/backend/src/llvm/llvm_printf_parser.cpp b/backend/src/llvm/llvm_printf_parser.cpp
> index dcad036..ff8e259 100644
> --- a/backend/src/llvm/llvm_printf_parser.cpp
> +++ b/backend/src/llvm/llvm_printf_parser.cpp
> @@ -98,7 +98,7 @@ namespace gbe
> return -1;
>
> #define FMT_PLUS_PLUS do { \
> - if (fmt + 1 < end) fmt++; \
> + if (fmt + 1 <= end) fmt++; \
> else { \
> printf("Error, line: %d, fmt > end\n", __LINE__); \
> return -1; \
> @@ -627,20 +627,21 @@ error:
> conversion need to be applied. */
> switch (arg->getType()->getTypeID()) {
> case Type::IntegerTyID: {
> + bool sign = false;
> switch (slot.state->conversion_specifier) {
> case PRINTF_CONVERSION_I:
> case PRINTF_CONVERSION_D:
> - /* Int to Int, just store. */
> - dst_type = Type::getInt32PtrTy(module->getContext(), 1);
> - sizeof_size = sizeof(int);
> - return true;
> -
> + sign = true;
> case PRINTF_CONVERSION_O:
> case PRINTF_CONVERSION_U:
> case PRINTF_CONVERSION_x:
> case PRINTF_CONVERSION_X:
> - /* To uint, add a conversion. */
> - arg = builder->CreateIntCast(arg, Type::getInt32Ty(module->getContext()), true);
> + /* If the bits change, we need to consider the signed. */
> + if (arg->getType() != Type::getInt32Ty(module->getContext())) {
> + arg = builder->CreateIntCast(arg, Type::getInt32Ty(module->getContext()), sign);
> + }
> +
> + /* Int to Int, just store. */
> dst_type = Type::getInt32PtrTy(module->getContext(), 1);
> sizeof_size = sizeof(int);
> return true;
> @@ -745,6 +746,72 @@ error:
> }
> }
>
> + case Type::VectorTyID: {
> + Type* vect_type = arg->getType();
> + Type* elt_type = vect_type->getVectorElementType();
> + int vec_num = vect_type->getVectorNumElements();
> + bool sign = false;
> +
> + if (vec_num != slot.state->vector_n) {
> + return false;
> + }
> +
> + switch (slot.state->conversion_specifier) {
> + case PRINTF_CONVERSION_I:
> + case PRINTF_CONVERSION_D:
> + sign = true;
> + case PRINTF_CONVERSION_O:
> + case PRINTF_CONVERSION_U:
> + case PRINTF_CONVERSION_x:
> + case PRINTF_CONVERSION_X:
> + if (elt_type->getTypeID() != Type::IntegerTyID)
> + return false;
> +
> + /* If the bits change, we need to consider the signed. */
> + if (elt_type != Type::getInt32Ty(elt_type->getContext())) {
> + Value *II = NULL;
> + for (int i = 0; i < vec_num; i++) {
> + Value *vec = II ? II : UndefValue::get(VectorType::get(Type::getInt32Ty(elt_type->getContext()), vec_num));
> + Value *cv = ConstantInt::get(Type::getInt32Ty(elt_type->getContext()), i);
> + Value *org = builder->CreateExtractElement(arg, cv);
> + Value *cvt = builder->CreateIntCast(org, Type::getInt32Ty(module->getContext()), sign);
> + II = builder->CreateInsertElement(vec, cvt, cv);
> + }
> + arg = II;
> + }
> +
> + dst_type = arg->getType()->getPointerTo(1);
> + sizeof_size = sizeof(int) * vec_num;
> + return true;
> +
> + case PRINTF_CONVERSION_F:
> + case PRINTF_CONVERSION_f:
> + case PRINTF_CONVERSION_E:
> + case PRINTF_CONVERSION_e:
> + case PRINTF_CONVERSION_G:
> + case PRINTF_CONVERSION_g:
> + case PRINTF_CONVERSION_A:
> + case PRINTF_CONVERSION_a:
> + if (elt_type->getTypeID() != Type::DoubleTyID && elt_type->getTypeID() != Type::FloatTyID)
> + return false;
> +
> + if (elt_type->getTypeID() != Type::FloatTyID) {
> + Value *II = NULL;
> + for (int i = 0; i < vec_num; i++) {
> + Value *vec = II ? II : UndefValue::get(VectorType::get(Type::getFloatTy(elt_type->getContext()), vec_num));
> + Value *cv = ConstantInt::get(Type::getInt32Ty(elt_type->getContext()), i);
> + Value *org = builder->CreateExtractElement(arg, cv);
> + Value* cvt = builder->CreateFPCast(org, Type::getFloatTy(module->getContext()));
> + II = builder->CreateInsertElement(vec, cvt, cv);
> + }
> + arg = II;
> + }
> + }
> + dst_type = arg->getType()->getPointerTo(1);
> + sizeof_size = sizeof(int) * vec_num;
> + return true;
> + }
> +
> default:
> return false;
> }
> diff --git a/kernels/test_printf.cl b/kernels/test_printf.cl
> index c21ee98..84bb478 100644
> --- a/kernels/test_printf.cl
> +++ b/kernels/test_printf.cl
> @@ -6,6 +6,10 @@ test_printf(void)
> int z = (int)get_global_id(2);
> uint a = 'x';
> float f = 5.0f;
> + int3 vec;
> + vec.x = x;
> + vec.y = y;
> + vec.z = z;
>
> if (x == 0 && y == 0 && z == 0) {
> printf("--- Welcome to the printf test of %s ---\n", "Intel Beignet");
> @@ -16,8 +20,8 @@ test_printf(void)
> if (x % 15 == 0)
> if (y % 3 == 0)
> if (z % 7 == 0)
> - printf("######## global_id(x, y, z) = (%d, %d, %d), global_size(d0, d1, d3) = (%d, %d, %d)\n",
> - x, y, z, get_global_size(0), get_global_size(1), get_global_size(2));
> + printf("######## global_id(x, y, z) = %v3d, global_size(d0, d1, d3) = (%d, %d, %d)\n",
> + vec, get_global_size(0), get_global_size(1), get_global_size(2));
>
> if (x == 1)
> if (y == 0) {
> @@ -26,7 +30,9 @@ test_printf(void)
> else
> printf("#### output a float to int is %d\n", f);
> }
> +
> if (x == 0 && y == 0 && z == 0) {
> printf("--- End to the printf test ---\n");
> }
> +
> }
> --
> 1.8.3.2
>
> _______________________________________________
> Beignet mailing list
> Beignet at lists.freedesktop.org
> http://lists.freedesktop.org/mailman/listinfo/beignet
More information about the Beignet mailing list