[Beignet] [PATCH] Add support for vector types in printf.

junyan.he at inbox.com junyan.he at inbox.com
Tue Jun 24 01:35:58 PDT 2014


From: Junyan He <junyan.he at linux.intel.com>

Signed-off-by: Junyan He <junyan.he at linux.intel.com>
---
 backend/src/ir/printf.cpp               | 144 +++++++++++++++++---------------
 backend/src/ir/printf.hpp               |  11 ++-
 backend/src/llvm/llvm_printf_parser.cpp |  83 ++++++++++++++++--
 kernels/test_printf.cl                  |  10 ++-
 4 files changed, 167 insertions(+), 81 deletions(-)

diff --git a/backend/src/ir/printf.cpp b/backend/src/ir/printf.cpp
index 58711e2..68b2ce4 100644
--- a/backend/src/ir/printf.cpp
+++ b/backend/src/ir/printf.cpp
@@ -84,8 +84,6 @@ namespace gbe
         str += num_str;
       }
 
-      // TODO:  Handle the vector here.
-
       switch (state.length_modifier) {
         case PRINTF_LM_HH:
           str += "hh";
@@ -97,7 +95,7 @@ namespace gbe
           str += "l";
           break;
         case PRINTF_LM_HL:
-          str += "hl";
+          str += "";
           break;
         default:
           assert(state.length_modifier == PRINTF_LM_NONE);
@@ -105,12 +103,12 @@ namespace gbe
     }
 
 #define PRINT_SOMETHING(target_ty, conv)  do {                          \
-      pf_str = pf_str + std::string(#conv);                             \
+      if (!vec_i)                                                       \
+        pf_str = pf_str + std::string(#conv);                           \
       printf(pf_str.c_str(),                                            \
              ((target_ty *)((char *)buf_addr + slot.state->out_buf_sizeof_offset * \
                             global_wk_sz0 * global_wk_sz1 * global_wk_sz2)) \
-             [k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i]);    \
-      pf_str = "";                                                      \
+             [(k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i) * vec_num + vec_i]);\
     } while (0)
 
 
@@ -126,80 +124,88 @@ namespace gbe
         for (i = 0; i < global_wk_sz0; i++) {
           for (j = 0; j < global_wk_sz1; j++) {
             for (k = 0; k < global_wk_sz2; k++) {
-              int flag = ((int *)index_addr)[stmt*global_wk_sz0*global_wk_sz1*global_wk_sz2 + k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i];
+
+              int flag = ((int *)index_addr)[stmt*global_wk_sz0*global_wk_sz1*global_wk_sz2
+                                             + k*global_wk_sz0*global_wk_sz1 + j*global_wk_sz0 + i];
               if (flag) {
-                pf_str = "";
                 for (auto &slot : pf) {
+                  pf_str = "";
+                  int vec_num;
+
                   if (slot.type == PRINTF_SLOT_TYPE_STRING) {
-                    pf_str = pf_str + std::string(slot.str);
+                    printf("%s", slot.str);
                     continue;
                   }
                   assert(slot.type == PRINTF_SLOT_TYPE_STATE);
 
                   generatePrintfFmtString(*slot.state, pf_str);
 
-                  switch (slot.state->conversion_specifier) {
-                    case PRINTF_CONVERSION_D:
-                    case PRINTF_CONVERSION_I:
-                      PRINT_SOMETHING(int, d);
-                      break;
-
-                    case PRINTF_CONVERSION_O:
-                      PRINT_SOMETHING(int, o);
-                      break;
-                    case PRINTF_CONVERSION_U:
-                      PRINT_SOMETHING(int, u);
-                      break;
-                    case PRINTF_CONVERSION_X:
-                      PRINT_SOMETHING(int, X);
-                      break;
-                    case PRINTF_CONVERSION_x:
-                      PRINT_SOMETHING(int, x);
-                      break;
-
-                    case PRINTF_CONVERSION_C:
-                      PRINT_SOMETHING(char, c);
-                      break;
-
-                    case PRINTF_CONVERSION_F:
-                      PRINT_SOMETHING(float, F);
-                      break;
-                    case PRINTF_CONVERSION_f:
-                      PRINT_SOMETHING(float, f);
-                      break;
-                    case PRINTF_CONVERSION_E:
-                      PRINT_SOMETHING(float, E);
-                      break;
-                    case PRINTF_CONVERSION_e:
-                      PRINT_SOMETHING(float, e);
-                      break;
-                    case PRINTF_CONVERSION_G:
-                      PRINT_SOMETHING(float, G);
-                      break;
-                    case PRINTF_CONVERSION_g:
-                      PRINT_SOMETHING(float, g);
-                      break;
-                    case PRINTF_CONVERSION_A:
-                      PRINT_SOMETHING(float, A);
-                      break;
-                    case PRINTF_CONVERSION_a:
-                      PRINT_SOMETHING(float, a);
-                      break;
-
-                    case PRINTF_CONVERSION_S:
-                      pf_str = pf_str + "s";
-                      printf(pf_str.c_str(), slot.state->str.c_str());
-                      pf_str = "";
-                      break;
-
-                    default:
-                      assert(0);
-                      return;
+                  vec_num = slot.state->vector_n > 0 ? slot.state->vector_n : 1;
+
+                  for (int vec_i = 0; vec_i < vec_num; vec_i++) {
+                    if (vec_i)
+                      printf(",");
+
+                    switch (slot.state->conversion_specifier) {
+                      case PRINTF_CONVERSION_D:
+                      case PRINTF_CONVERSION_I:
+                        PRINT_SOMETHING(int, d);
+                        break;
+
+                      case PRINTF_CONVERSION_O:
+                        PRINT_SOMETHING(int, o);
+                        break;
+                      case PRINTF_CONVERSION_U:
+                        PRINT_SOMETHING(int, u);
+                        break;
+                      case PRINTF_CONVERSION_X:
+                        PRINT_SOMETHING(int, X);
+                        break;
+                      case PRINTF_CONVERSION_x:
+                        PRINT_SOMETHING(int, x);
+                        break;
+
+                      case PRINTF_CONVERSION_C:
+                        PRINT_SOMETHING(char, c);
+                        break;
+
+                      case PRINTF_CONVERSION_F:
+                        PRINT_SOMETHING(float, F);
+                        break;
+                      case PRINTF_CONVERSION_f:
+                        PRINT_SOMETHING(float, f);
+                        break;
+                      case PRINTF_CONVERSION_E:
+                        PRINT_SOMETHING(float, E);
+                        break;
+                      case PRINTF_CONVERSION_e:
+                        PRINT_SOMETHING(float, e);
+                        break;
+                      case PRINTF_CONVERSION_G:
+                        PRINT_SOMETHING(float, G);
+                        break;
+                      case PRINTF_CONVERSION_g:
+                        PRINT_SOMETHING(float, g);
+                        break;
+                      case PRINTF_CONVERSION_A:
+                        PRINT_SOMETHING(float, A);
+                        break;
+                      case PRINTF_CONVERSION_a:
+                        PRINT_SOMETHING(float, a);
+                        break;
+
+                      case PRINTF_CONVERSION_S:
+                        pf_str = pf_str + "s";
+                        printf(pf_str.c_str(), slot.state->str.c_str());
+                        break;
+
+                      default:
+                        assert(0);
+                        return;
+                    }
                   }
-                }
 
-                if (pf_str != "") {
-                  printf("%s", pf_str.c_str());
+                  pf_str = "";
                 }
               }
             }
diff --git a/backend/src/ir/printf.hpp b/backend/src/ir/printf.hpp
index 8b759d4..680b8e6 100644
--- a/backend/src/ir/printf.hpp
+++ b/backend/src/ir/printf.hpp
@@ -182,6 +182,13 @@ namespace gbe
 
       uint32_t getPrintfBufferElementSize(uint32_t i) {
         PrintfSlot* slot = slots[i];
+        int vec_num = 1;
+        if (slot->state->vector_n > 0) {
+          vec_num = slot->state->vector_n;
+        }
+
+        assert(vec_num > 0 && vec_num <= 16);
+
         switch (slot->state->conversion_specifier) {
           case PRINTF_CONVERSION_I:
           case PRINTF_CONVERSION_D:
@@ -191,7 +198,7 @@ namespace gbe
           case PRINTF_CONVERSION_x:
             /* Char will be aligned to sizeof(int) here. */
           case PRINTF_CONVERSION_C:
-            return (uint32_t)sizeof(int);
+            return (uint32_t)(sizeof(int) * vec_num);
           case PRINTF_CONVERSION_E:
           case PRINTF_CONVERSION_e:
           case PRINTF_CONVERSION_F:
@@ -200,7 +207,7 @@ namespace gbe
           case PRINTF_CONVERSION_g:
           case PRINTF_CONVERSION_A:
           case PRINTF_CONVERSION_a:
-            return (uint32_t)sizeof(float);
+            return (uint32_t)(sizeof(float) * vec_num);
           case PRINTF_CONVERSION_S:
             return (uint32_t)0;
           default:
diff --git a/backend/src/llvm/llvm_printf_parser.cpp b/backend/src/llvm/llvm_printf_parser.cpp
index dcad036..ff8e259 100644
--- a/backend/src/llvm/llvm_printf_parser.cpp
+++ b/backend/src/llvm/llvm_printf_parser.cpp
@@ -98,7 +98,7 @@ namespace gbe
       return -1;
 
 #define FMT_PLUS_PLUS do {                                  \
-      if (fmt + 1 < end) fmt++;                             \
+      if (fmt + 1 <= end) fmt++;                             \
       else {                                                \
         printf("Error, line: %d, fmt > end\n", __LINE__);   \
         return -1;                                          \
@@ -627,20 +627,21 @@ error:
        conversion need to be applied. */
     switch (arg->getType()->getTypeID()) {
       case Type::IntegerTyID: {
+        bool sign = false;
         switch (slot.state->conversion_specifier) {
           case PRINTF_CONVERSION_I:
           case PRINTF_CONVERSION_D:
-            /* Int to Int, just store. */
-            dst_type = Type::getInt32PtrTy(module->getContext(), 1);
-            sizeof_size = sizeof(int);
-            return true;
-
+            sign = true; /* Fall through - only the signedness of the cast differs. */
           case PRINTF_CONVERSION_O:
           case PRINTF_CONVERSION_U:
           case PRINTF_CONVERSION_x:
           case PRINTF_CONVERSION_X:
-            /* To uint, add a conversion. */
-            arg = builder->CreateIntCast(arg, Type::getInt32Ty(module->getContext()), true);
+            /* If the bits change, we need to consider the signed. */
+            if (arg->getType() != Type::getInt32Ty(module->getContext())) {
+              arg = builder->CreateIntCast(arg, Type::getInt32Ty(module->getContext()), sign);
+            }
+
+            /* Int to Int, just store. */
             dst_type = Type::getInt32PtrTy(module->getContext(), 1);
             sizeof_size = sizeof(int);
             return true;
@@ -745,6 +746,72 @@ error:
           }
         }
 
+      case Type::VectorTyID: {
+        Type* vect_type = arg->getType();
+        Type* elt_type = vect_type->getVectorElementType();
+        int vec_num = vect_type->getVectorNumElements();
+        bool sign = false;
+
+        /* The vector argument must have exactly the element count given by "%vN". */
+        if (vec_num != slot.state->vector_n)
+          return false;
+
+        switch (slot.state->conversion_specifier) {
+          case PRINTF_CONVERSION_I:
+          case PRINTF_CONVERSION_D:
+            sign = true; /* Fall through - only the signedness of the cast differs. */
+          case PRINTF_CONVERSION_O:
+          case PRINTF_CONVERSION_U:
+          case PRINTF_CONVERSION_x:
+          case PRINTF_CONVERSION_X:
+            if (elt_type->getTypeID() != Type::IntegerTyID)
+              return false;
+            /* Cast every element to i32; "sign" selects sign- vs zero-extension. */
+            if (elt_type != Type::getInt32Ty(elt_type->getContext())) {
+              Value *II = NULL;
+              for (int i = 0; i < vec_num; i++) {
+                Value *vec = II ? II : UndefValue::get(VectorType::get(Type::getInt32Ty(elt_type->getContext()), vec_num));
+                Value *cv = ConstantInt::get(Type::getInt32Ty(elt_type->getContext()), i);
+                Value *org = builder->CreateExtractElement(arg, cv);
+                Value *cvt = builder->CreateIntCast(org, Type::getInt32Ty(module->getContext()), sign);
+                II = builder->CreateInsertElement(vec, cvt, cv);
+              }
+              arg = II;
+            }
+            dst_type = arg->getType()->getPointerTo(1);
+            sizeof_size = sizeof(int) * vec_num;
+            return true;
+
+          case PRINTF_CONVERSION_F:
+          case PRINTF_CONVERSION_f:
+          case PRINTF_CONVERSION_E:
+          case PRINTF_CONVERSION_e:
+          case PRINTF_CONVERSION_G:
+          case PRINTF_CONVERSION_g:
+          case PRINTF_CONVERSION_A:
+          case PRINTF_CONVERSION_a:
+            if (elt_type->getTypeID() != Type::DoubleTyID && elt_type->getTypeID() != Type::FloatTyID)
+              return false;
+            /* Narrow double elements to float, the element type the host printer reads back. */
+            if (elt_type->getTypeID() != Type::FloatTyID) {
+              Value *II = NULL;
+              for (int i = 0; i < vec_num; i++) {
+                Value *vec = II ? II : UndefValue::get(VectorType::get(Type::getFloatTy(elt_type->getContext()), vec_num));
+                Value *cv = ConstantInt::get(Type::getInt32Ty(elt_type->getContext()), i);
+                Value *org = builder->CreateExtractElement(arg, cv);
+                Value* cvt  = builder->CreateFPCast(org, Type::getFloatTy(module->getContext()));
+                II = builder->CreateInsertElement(vec, cvt, cv);
+              }
+              arg = II;
+            }
+            dst_type = arg->getType()->getPointerTo(1);
+            sizeof_size = sizeof(float) * vec_num;
+            return true;
+
+          default:
+            return false; /* %c, %s, etc. cannot take a vector: reject instead of storing it raw. */
+        }
+      }
       default:
         return false;
     }
diff --git a/kernels/test_printf.cl b/kernels/test_printf.cl
index c21ee98..84bb478 100644
--- a/kernels/test_printf.cl
+++ b/kernels/test_printf.cl
@@ -6,6 +6,10 @@ test_printf(void)
   int z = (int)get_global_id(2);
   uint a = 'x';
   float f = 5.0f;
+  int3 vec;
+  vec.x = x;
+  vec.y = y;
+  vec.z = z;
 
   if (x == 0 && y == 0 && z == 0) {
     printf("--- Welcome to the printf test of %s ---\n", "Intel Beignet");
@@ -16,8 +20,8 @@ test_printf(void)
   if (x % 15 == 0)
     if (y % 3 == 0)
       if (z % 7 == 0)
-        printf("######## global_id(x, y, z) = (%d, %d, %d), global_size(d0, d1, d3) = (%d, %d, %d)\n",
-                x, y, z, get_global_size(0), get_global_size(1), get_global_size(2));
+        printf("######## global_id(x, y, z) = (%v3d), global_size(d0, d1, d2) = (%d, %d, %d)\n",
+                vec, get_global_size(0), get_global_size(1), get_global_size(2)); /* %v3d prints the int3 as x,y,z */
 
   if (x == 1)
     if (y == 0) {
@@ -26,7 +30,9 @@ test_printf(void)
       else
           printf("#### output a float to int is %d\n", f);
     }
+
   if (x == 0 && y == 0 && z == 0) {
     printf("--- End to the printf test ---\n");
   }
+
 }
-- 
1.8.3.2



More information about the Beignet mailing list