[Beignet] [PATCH] Add the support for vector type in printf.

Zhigang Gong zhigang.gong at linux.intel.com
Tue Jun 24 17:12:58 PDT 2014


LGTM, pushed, thanks.

On Tue, Jun 24, 2014 at 04:35:58PM +0800, junyan.he at inbox.com wrote:
> From: Junyan He <junyan.he at linux.intel.com>
> 
> Signed-off-by: Junyan He <junyan.he at linux.intel.com>
> ---
>  backend/src/ir/printf.cpp               | 144 +++++++++++++++++---------------
>  backend/src/ir/printf.hpp               |  11 ++-
>  backend/src/llvm/llvm_printf_parser.cpp |  83 ++++++++++++++++--
>  kernels/test_printf.cl                  |  10 ++-
>  4 files changed, 167 insertions(+), 81 deletions(-)
> 
> diff --git a/backend/src/ir/printf.cpp b/backend/src/ir/printf.cpp
> index 58711e2..68b2ce4 100644
> --- a/backend/src/ir/printf.cpp
> +++ b/backend/src/ir/printf.cpp
> @@ -84,8 +84,6 @@ namespace gbe
>          str += num_str;
>        }
>  
> -      // TODO:  Handle the vector here.
> -
>        switch (state.length_modifier) {
>          case PRINTF_LM_HH:
>            str += "hh";
> @@ -97,7 +95,7 @@ namespace gbe
>            str += "l";
>            break;
>          case PRINTF_LM_HL:
> -          str += "hl";
> +          str += "";
>            break;
>          default:
>            assert(state.length_modifier == PRINTF_LM_NONE);
> @@ -105,12 +103,12 @@ namespace gbe
>      }
>  
>  #define PRINT_SOMETHING(target_ty, conv)  do {                          \
> -      pf_str = pf_str + std::string(#conv);                             \
> +      if (!vec_i)                                                       \
> +        pf_str = pf_str + std::string(#conv);                           \
>        printf(pf_str.c_str(),                                            \
>               ((target_ty *)((char *)buf_addr + slot.state->out_buf_sizeof_offset * \
>                              global_wk_sz0 * global_wk_sz1 * global_wk_sz2)) \
> -             [k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i]);    \
> -      pf_str = "";                                                      \
> +             [(k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i) * vec_num + vec_i]);\
>      } while (0)
>  
>  
> @@ -126,80 +124,88 @@ namespace gbe
>          for (i = 0; i < global_wk_sz0; i++) {
>            for (j = 0; j < global_wk_sz1; j++) {
>              for (k = 0; k < global_wk_sz2; k++) {
> -              int flag = ((int *)index_addr)[stmt*global_wk_sz0*global_wk_sz1*global_wk_sz2 + k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i];
> +
> +              int flag = ((int *)index_addr)[stmt*global_wk_sz0*global_wk_sz1*global_wk_sz2
> +                                             + k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i];
>                if (flag) {
> -                pf_str = "";
>                  for (auto &slot : pf) {
> +                  pf_str = "";
> +                  int vec_num;
> +
>                    if (slot.type == PRINTF_SLOT_TYPE_STRING) {
> -                    pf_str = pf_str + std::string(slot.str);
> +                    printf("%s", slot.str);
>                      continue;
>                    }
>                    assert(slot.type == PRINTF_SLOT_TYPE_STATE);
>  
>                    generatePrintfFmtString(*slot.state, pf_str);
>  
> -                  switch (slot.state->conversion_specifier) {
> -                    case PRINTF_CONVERSION_D:
> -                    case PRINTF_CONVERSION_I:
> -                      PRINT_SOMETHING(int, d);
> -                      break;
> -
> -                    case PRINTF_CONVERSION_O:
> -                      PRINT_SOMETHING(int, o);
> -                      break;
> -                    case PRINTF_CONVERSION_U:
> -                      PRINT_SOMETHING(int, u);
> -                      break;
> -                    case PRINTF_CONVERSION_X:
> -                      PRINT_SOMETHING(int, X);
> -                      break;
> -                    case PRINTF_CONVERSION_x:
> -                      PRINT_SOMETHING(int, x);
> -                      break;
> -
> -                    case PRINTF_CONVERSION_C:
> -                      PRINT_SOMETHING(char, c);
> -                      break;
> -
> -                    case PRINTF_CONVERSION_F:
> -                      PRINT_SOMETHING(float, F);
> -                      break;
> -                    case PRINTF_CONVERSION_f:
> -                      PRINT_SOMETHING(float, f);
> -                      break;
> -                    case PRINTF_CONVERSION_E:
> -                      PRINT_SOMETHING(float, E);
> -                      break;
> -                    case PRINTF_CONVERSION_e:
> -                      PRINT_SOMETHING(float, e);
> -                      break;
> -                    case PRINTF_CONVERSION_G:
> -                      PRINT_SOMETHING(float, G);
> -                      break;
> -                    case PRINTF_CONVERSION_g:
> -                      PRINT_SOMETHING(float, g);
> -                      break;
> -                    case PRINTF_CONVERSION_A:
> -                      PRINT_SOMETHING(float, A);
> -                      break;
> -                    case PRINTF_CONVERSION_a:
> -                      PRINT_SOMETHING(float, a);
> -                      break;
> -
> -                    case PRINTF_CONVERSION_S:
> -                      pf_str = pf_str + "s";
> -                      printf(pf_str.c_str(), slot.state->str.c_str());
> -                      pf_str = "";
> -                      break;
> -
> -                    default:
> -                      assert(0);
> -                      return;
> +                  vec_num = slot.state->vector_n > 0 ? slot.state->vector_n : 1;
> +
> +                  for (int vec_i = 0; vec_i < vec_num; vec_i++) {
> +                    if (vec_i)
> +                      printf(",");
> +
> +                    switch (slot.state->conversion_specifier) {
> +                      case PRINTF_CONVERSION_D:
> +                      case PRINTF_CONVERSION_I:
> +                        PRINT_SOMETHING(int, d);
> +                        break;
> +
> +                      case PRINTF_CONVERSION_O:
> +                        PRINT_SOMETHING(int, o);
> +                        break;
> +                      case PRINTF_CONVERSION_U:
> +                        PRINT_SOMETHING(int, u);
> +                        break;
> +                      case PRINTF_CONVERSION_X:
> +                        PRINT_SOMETHING(int, X);
> +                        break;
> +                      case PRINTF_CONVERSION_x:
> +                        PRINT_SOMETHING(int, x);
> +                        break;
> +
> +                      case PRINTF_CONVERSION_C:
> +                        PRINT_SOMETHING(char, c);
> +                        break;
> +
> +                      case PRINTF_CONVERSION_F:
> +                        PRINT_SOMETHING(float, F);
> +                        break;
> +                      case PRINTF_CONVERSION_f:
> +                        PRINT_SOMETHING(float, f);
> +                        break;
> +                      case PRINTF_CONVERSION_E:
> +                        PRINT_SOMETHING(float, E);
> +                        break;
> +                      case PRINTF_CONVERSION_e:
> +                        PRINT_SOMETHING(float, e);
> +                        break;
> +                      case PRINTF_CONVERSION_G:
> +                        PRINT_SOMETHING(float, G);
> +                        break;
> +                      case PRINTF_CONVERSION_g:
> +                        PRINT_SOMETHING(float, g);
> +                        break;
> +                      case PRINTF_CONVERSION_A:
> +                        PRINT_SOMETHING(float, A);
> +                        break;
> +                      case PRINTF_CONVERSION_a:
> +                        PRINT_SOMETHING(float, a);
> +                        break;
> +
> +                      case PRINTF_CONVERSION_S:
> +                        pf_str = pf_str + "s";
> +                        printf(pf_str.c_str(), slot.state->str.c_str());
> +                        break;
> +
> +                      default:
> +                        assert(0);
> +                        return;
> +                    }
>                    }
> -                }
>  
> -                if (pf_str != "") {
> -                  printf("%s", pf_str.c_str());
> +                  pf_str = "";
>                  }
>                }
>              }
> diff --git a/backend/src/ir/printf.hpp b/backend/src/ir/printf.hpp
> index 8b759d4..680b8e6 100644
> --- a/backend/src/ir/printf.hpp
> +++ b/backend/src/ir/printf.hpp
> @@ -182,6 +182,13 @@ namespace gbe
>  
>        uint32_t getPrintfBufferElementSize(uint32_t i) {
>          PrintfSlot* slot = slots[i];
> +        int vec_num = 1;
> +        if (slot->state->vector_n > 0) {
> +          vec_num = slot->state->vector_n;
> +        }
> +
> +        assert(vec_num > 0 && vec_num <= 16);
> +
>          switch (slot->state->conversion_specifier) {
>            case PRINTF_CONVERSION_I:
>            case PRINTF_CONVERSION_D:
> @@ -191,7 +198,7 @@ namespace gbe
>            case PRINTF_CONVERSION_x:
>              /* Char will be aligned to sizeof(int) here. */
>            case PRINTF_CONVERSION_C:
> -            return (uint32_t)sizeof(int);
> +            return (uint32_t)(sizeof(int) * vec_num);
>            case PRINTF_CONVERSION_E:
>            case PRINTF_CONVERSION_e:
>            case PRINTF_CONVERSION_F:
> @@ -200,7 +207,7 @@ namespace gbe
>            case PRINTF_CONVERSION_g:
>            case PRINTF_CONVERSION_A:
>            case PRINTF_CONVERSION_a:
> -            return (uint32_t)sizeof(float);
> +            return (uint32_t)(sizeof(float) * vec_num);
>            case PRINTF_CONVERSION_S:
>              return (uint32_t)0;
>            default:
> diff --git a/backend/src/llvm/llvm_printf_parser.cpp b/backend/src/llvm/llvm_printf_parser.cpp
> index dcad036..ff8e259 100644
> --- a/backend/src/llvm/llvm_printf_parser.cpp
> +++ b/backend/src/llvm/llvm_printf_parser.cpp
> @@ -98,7 +98,7 @@ namespace gbe
>        return -1;
>  
>  #define FMT_PLUS_PLUS do {                                  \
> -      if (fmt + 1 < end) fmt++;                             \
> +      if (fmt + 1 <= end) fmt++;                             \
>        else {                                                \
>          printf("Error, line: %d, fmt > end\n", __LINE__);   \
>          return -1;                                          \
> @@ -627,20 +627,21 @@ error:
>         conversion need to be applied. */
>      switch (arg->getType()->getTypeID()) {
>        case Type::IntegerTyID: {
> +        bool sign = false;
>          switch (slot.state->conversion_specifier) {
>            case PRINTF_CONVERSION_I:
>            case PRINTF_CONVERSION_D:
> -            /* Int to Int, just store. */
> -            dst_type = Type::getInt32PtrTy(module->getContext(), 1);
> -            sizeof_size = sizeof(int);
> -            return true;
> -
> +            sign = true;
>            case PRINTF_CONVERSION_O:
>            case PRINTF_CONVERSION_U:
>            case PRINTF_CONVERSION_x:
>            case PRINTF_CONVERSION_X:
> -            /* To uint, add a conversion. */
> -            arg = builder->CreateIntCast(arg, Type::getInt32Ty(module->getContext()), true);
> +            /* If the bit width changes, the cast must take signedness into account. */
> +            if (arg->getType() != Type::getInt32Ty(module->getContext())) {
> +              arg = builder->CreateIntCast(arg, Type::getInt32Ty(module->getContext()), sign);
> +            }
> +
> +            /* Int to Int, just store. */
>              dst_type = Type::getInt32PtrTy(module->getContext(), 1);
>              sizeof_size = sizeof(int);
>              return true;
> @@ -745,6 +746,72 @@ error:
>            }
>          }
>  
> +      case Type::VectorTyID: {
> +        Type* vect_type = arg->getType();
> +        Type* elt_type = vect_type->getVectorElementType();
> +        int vec_num = vect_type->getVectorNumElements();
> +        bool sign = false;
> +
> +        if (vec_num != slot.state->vector_n) {
> +          return false;
> +        }
> +
> +        switch (slot.state->conversion_specifier) {
> +          case PRINTF_CONVERSION_I:
> +          case PRINTF_CONVERSION_D:
> +            sign = true;
> +          case PRINTF_CONVERSION_O:
> +          case PRINTF_CONVERSION_U:
> +          case PRINTF_CONVERSION_x:
> +          case PRINTF_CONVERSION_X:
> +            if (elt_type->getTypeID() != Type::IntegerTyID)
> +              return false;
> +
> +            /* If the element bit width changes, the cast must take signedness into account. */
> +            if (elt_type != Type::getInt32Ty(elt_type->getContext())) {
> +              Value *II = NULL;
> +              for (int i = 0; i < vec_num; i++) {
> +                Value *vec = II ? II : UndefValue::get(VectorType::get(Type::getInt32Ty(elt_type->getContext()), vec_num));
> +                Value *cv = ConstantInt::get(Type::getInt32Ty(elt_type->getContext()), i);
> +                Value *org = builder->CreateExtractElement(arg, cv);
> +                Value *cvt = builder->CreateIntCast(org, Type::getInt32Ty(module->getContext()), sign);
> +                II = builder->CreateInsertElement(vec, cvt, cv);
> +              }
> +              arg = II;
> +            }
> +
> +            dst_type = arg->getType()->getPointerTo(1);
> +            sizeof_size = sizeof(int) * vec_num;
> +            return true;
> +
> +          case PRINTF_CONVERSION_F:
> +          case PRINTF_CONVERSION_f:
> +          case PRINTF_CONVERSION_E:
> +          case PRINTF_CONVERSION_e:
> +          case PRINTF_CONVERSION_G:
> +          case PRINTF_CONVERSION_g:
> +          case PRINTF_CONVERSION_A:
> +          case PRINTF_CONVERSION_a:
> +            if (elt_type->getTypeID() != Type::DoubleTyID && elt_type->getTypeID() != Type::FloatTyID)
> +              return false;
> +
> +            if (elt_type->getTypeID() != Type::FloatTyID) {
> +              Value *II = NULL;
> +              for (int i = 0; i < vec_num; i++) {
> +                Value *vec = II ? II : UndefValue::get(VectorType::get(Type::getFloatTy(elt_type->getContext()), vec_num));
> +                Value *cv = ConstantInt::get(Type::getInt32Ty(elt_type->getContext()), i);
> +                Value *org = builder->CreateExtractElement(arg, cv);
> +                Value* cvt  = builder->CreateFPCast(org, Type::getFloatTy(module->getContext()));
> +                II = builder->CreateInsertElement(vec, cvt, cv);
> +              }
> +              arg = II;
> +            }
> +        }
> +        dst_type = arg->getType()->getPointerTo(1);
> +        sizeof_size = sizeof(int) * vec_num;
> +        return true;
> +      }
> +
>        default:
>          return false;
>      }
> diff --git a/kernels/test_printf.cl b/kernels/test_printf.cl
> index c21ee98..84bb478 100644
> --- a/kernels/test_printf.cl
> +++ b/kernels/test_printf.cl
> @@ -6,6 +6,10 @@ test_printf(void)
>    int z = (int)get_global_id(2);
>    uint a = 'x';
>    float f = 5.0f;
> +  int3 vec;
> +  vec.x = x;
> +  vec.y = y;
> +  vec.z = z;
>  
>    if (x == 0 && y == 0 && z == 0) {
>      printf("--- Welcome to the printf test of %s ---\n", "Intel Beignet");
> @@ -16,8 +20,8 @@ test_printf(void)
>    if (x % 15 == 0)
>      if (y % 3 == 0)
>        if (z % 7 == 0)
> -        printf("######## global_id(x, y, z) = (%d, %d, %d), global_size(d0, d1, d3) = (%d, %d, %d)\n",
> -                x, y, z, get_global_size(0), get_global_size(1), get_global_size(2));
> +        printf("######## global_id(x, y, z) = %v3d, global_size(d0, d1, d3) = (%d, %d, %d)\n",
> +                vec, get_global_size(0), get_global_size(1), get_global_size(2));
>  
>    if (x == 1)
>      if (y == 0) {
> @@ -26,7 +30,9 @@ test_printf(void)
>        else
>            printf("#### output a float to int is %d\n", f);
>      }
> +
>    if (x == 0 && y == 0 && z == 0) {
>      printf("--- End to the printf test ---\n");
>    }
> +
>  }
> -- 
> 1.8.3.2
> 
> _______________________________________________
> Beignet mailing list
> Beignet at lists.freedesktop.org
> http://lists.freedesktop.org/mailman/listinfo/beignet


More information about the Beignet mailing list