Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/04/24 09:27:55 UTC

[GitHub] [incubator-tvm] libaihong opened a new pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

libaihong opened a new pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428


   When I test vectorized load & store for char2, there is a CUDA compilation error:
   
   RuntimeError: Compilation error:
   /tmp/tmpa87wlle2/my_kernel.cu(3093): error: no operator "=" matches these operands
               operand types are: char2 = int
   /tmp/tmpa87wlle2/my_kernel.cu(3094): error: no operator "&" matches these operands
               operand types are: char2 & int
   /tmp/tmpa87wlle2/my_kernel.cu(3096): error: no operator ">>" matches these operands
               operand types are: char2 >> int
   /tmp/tmpa87wlle2/my_kernel.cu(3097): error: no operator ">>" matches these operands
               operand types are: char2 >> int
   
   
   The error occurs in the CUDA code generated by TVM, shown below:
   `char2 _66;`
   `_66=((signed char)(_67.x) << 0); `
   `_66=_66 & ~(0x000000ff << 8) |((signed char)(_67.y) << 8);`
   `((signed char*)T_cast)[_1.x] = ((char)(_66 >> 0));`
   `((signed char*)T_cast)[_1.y] = ((char)(_66 >> 8));`
   
   The "_66" variable is a char2 vector, but it is used as a scalar, which triggers the errors above. The root cause is that vectorized load/store for char2 is not supported. After fixing this problem, the generated CUDA code looks as follows:
   `char2 _66;`
   `_66.x=((signed char)(_67.x));`
   `_66.y=((signed char)(_67.y)); `
   `((signed char*)T_cast)[_1.x] = _66.x;`
   `((signed char*)T_cast)[_1.y] = _66.y;`
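
   For reference, a minimal standalone repro of the operator mismatch (my own sketch, not from the PR; `repro` is a hypothetical kernel): CUDA's built-in char2 is a plain struct of two signed chars with no arithmetic operators, so scalar-style shift/mask code fails to compile, while per-lane .x/.y access is fine.

   // char2 supports member access only; there is no operator "=", "&", or ">>" against int.
   __global__ void repro(char2* __restrict__ out, const char2* __restrict__ in) {
     char2 v = in[threadIdx.x];
     // char2 w = v >> 8;       // error: no operator ">>" matches these operands
     char2 w;
     w.x = (signed char)(v.x);  // per-lane access compiles fine
     w.y = (signed char)(v.y);
     out[threadIdx.x] = w;
   }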
   
   
   @vinx13 , could you please help review? Thanks!





[GitHub] [incubator-tvm] libaihong commented on pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
libaihong commented on pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#issuecomment-619465094


   > Could you please add a few tests. Do you mean loading uchar2x4? We could load/store them as uint16_t x 4. I do not see similar code in PrintType below:
   >
   > https://github.com/apache/incubator-tvm/blob/master/src/target/source/codegen_cuda.cc#L186
   
   I've added some unit tests.
   This modification only adds code for char2 when lanes is 2. When lanes is 4, it uses the old logic and does not load uchar4.





[GitHub] [incubator-tvm] boh-inspur commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
boh-inspur commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416455958



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       @wpan11nv , I've implemented the type with 'int', and there is something wrong.
   When the type is int8x2, the CUDA code looks correct and also builds correctly:
   
   #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
   #include <sm_61_intrinsics.h>
   #endif
   extern "C" __global__ void default_function_kernel0(void* __restrict__ B, void* __restrict__ A) {
     int _1;
     {
       int _2 = (( int*)(( signed char*)A + (((((int)blockIdx.x) * 16) + (((int)threadIdx.x) * 2)))))[0];
       int _3 = (int)16843009;
       _1=((((char)(_2 >> 0))+((char)(_3 >> 0))) << 0);
       _1=_1 & ~(0x000000ff << 8) |((((char)(_2 >> 8))+((char)(_3 >> 8))) << 8);
     }
     (( int*)(( signed char*)B + (((((int)blockIdx.x) * 16) + (((int)threadIdx.x) * 2)))))[0] = _1;
   }
   
   But there is a runtime error when copying from GPU to CPU memory: Check failed: e == cudaSuccess || e == cudaErrorCudartUnloading: misaligned address.
   Do you have any advice?
   
   And if we use int32_t when the type is int8x2 and the tensor is very large, it wastes a lot of memory unnecessarily, and runtime resources are more important. So I think that, based on the current code, only a small modification is needed to support int8x3 via char3, if char3 support is required. What is your opinion?
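
   For what it's worth, a hedged guess at the failure mode (my own illustration; `misaligned_repro` is a hypothetical kernel, not from the PR): the packed form above loads a 4-byte int at byte offset blockIdx.x * 16 + threadIdx.x * 2, which is only 2-byte aligned when threadIdx.x is odd, and CUDA requires accesses to be naturally aligned.

   __global__ void misaligned_repro(const signed char* __restrict__ A, int* __restrict__ out) {
     // For odd threadIdx.x the byte offset is 2 mod 4, so the int* below is
     // only 2-byte aligned; the 4-byte load through it is a misaligned access.
     out[threadIdx.x] = *(const int*)(A + threadIdx.x * 2);
   }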







[GitHub] [incubator-tvm] boh-inspur commented on pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
boh-inspur commented on pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#issuecomment-620311830


   > Do you see any issue to implement in a consistent way?
   >
   > uint_8x2 could be also stored as uint16_t, the same as uint8_t. If I am not mistaken, this allows us to support {u}int8_tx3 naturally. Right?
   
   Yes, storing as uint16_t is another way to do it. I think the author of the type logic may have been concerned about {u}int8_tx3. I had considered this case in my code before, but since {u}int8_tx3 seems unlikely to be used, the logic does not include that type; if needed, a small modification of this code can support {u}int8_tx3.
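
   To make the uint16_t idea concrete, a minimal sketch (my own; `copy_char2_as_u16` is a hypothetical kernel, not the PR's generated code): a char2 occupies 2 bytes, so it can be moved through memory as one uint16_t, which keeps every access 2-byte aligned.

   // Copy char2 elements through a single 16-bit word per thread.
   // Assumes src and dst are at least 2-byte aligned (true for cudaMalloc'd buffers).
   __global__ void copy_char2_as_u16(signed char* __restrict__ dst,
                                     const signed char* __restrict__ src) {
     unsigned short v = *(const unsigned short*)(src + threadIdx.x * 2);
     *(unsigned short*)(dst + threadIdx.x * 2) = v;
   }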





[GitHub] [incubator-tvm] wpan11nv commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
wpan11nv commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416251080



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       I meant you have two "if (t.lanes() == 1) return vec" branches. You could just move this logic out, as it is true for all inputs.
   
   The other part is a nice-to-have: if you can naturally support char2/3/4/8, that will help simplify the codegen logic.
   
   e.g. char2/3/4 can be stored as uint32_t; the only difference is that only the lower k bytes are useful.
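
   As a rough sketch of that packing scheme (my own illustration; `pack_chars` and `extract_char` are hypothetical helpers, not TVM APIs), assuming k <= 4 lanes:

   // Pack k (<= 4) signed chars into a uint32_t; only the lower k bytes carry data.
   __host__ __device__ inline unsigned int pack_chars(const signed char* v, int k) {
     unsigned int r = 0;
     for (int i = 0; i < k; ++i)
       r |= (unsigned int)(unsigned char)v[i] << (i * 8);
     return r;
   }

   // Extract lane i; mirrors the ((char)(vec >> i * 8)) pattern in the diff above.
   __host__ __device__ inline signed char extract_char(unsigned int packed, int i) {
     return (signed char)(packed >> (i * 8));
   }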
   
    







[GitHub] [incubator-tvm] wpan11nv commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
wpan11nv commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416007695



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       This "lanes == 1" logic could be promoted to cover all cases.







[GitHub] [incubator-tvm] boh-inspur commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
boh-inspur commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r417015283



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       Thanks, I've moved the logic out, modified the code to support char3, and added some unit tests for char3.
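
   For illustration (my own guess at the shape of the result, not the PR's actual output; `cast_char3` is a hypothetical kernel), char3 handling uses per-lane members just like char2:

   __global__ void cast_char3(char3* __restrict__ out, const uchar3* __restrict__ in) {
     uchar3 s = in[threadIdx.x];
     char3 d;                     // CUDA built-in struct of three signed chars
     d.x = (signed char)(s.x);
     d.y = (signed char)(s.y);
     d.z = (signed char)(s.z);
     out[threadIdx.x] = d;
   }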







[GitHub] [incubator-tvm] boh-inspur commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
boh-inspur commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416238967



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       When lanes is 1, the type is plain "char"; it doesn't need a shift operation or a type cast, so it is handled by separate logic.







[GitHub] [incubator-tvm] boh-inspur commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
boh-inspur commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416253058



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       That's a good suggestion, thanks a lot. I will modify the code and commit it later.







[GitHub] [incubator-tvm] wpan11nv commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
wpan11nv commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416240568



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       I know that; I meant this is true for all types, not just for char. You could common up the logic for both branches.







[GitHub] [incubator-tvm] boh-inspur commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
boh-inspur commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416249062



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       Do you mean storing int8x2 as int16_t?







[GitHub] [incubator-tvm] wpan11nv commented on pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
wpan11nv commented on pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#issuecomment-619214302


   Could you please add a few tests. Do you mean supporting loading uchar2x4? We could load/store them as uint16_t x 4. I do not see similar code in PrintType below:
   
   https://github.com/apache/incubator-tvm/blob/master/src/target/source/codegen_cuda.cc#L186








[GitHub] [incubator-tvm] tqchen commented on pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#issuecomment-621350989


   Thanks @wpan11nv @boh-inspur !





[GitHub] [incubator-tvm] boh-inspur commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
boh-inspur commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416245487



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       Yes, we can do that. The generated code may look like the following:
   `char _1;`
   `char _2 = ((char)((_1) << 0));`
   That's correct, but the code is a little strange, and may take more time at runtime? If that's OK, I think we can common up the logic for both branches; what is your opinion?







[GitHub] [incubator-tvm] wpan11nv commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
wpan11nv commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416246643



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       It should look like:
   
   if (t.lanes() == 1) {
     return vec;
   } else {
     // stored as int and extracted with bit ops
   }







[GitHub] [incubator-tvm] wpan11nv commented on a change in pull request #5428: [CODEGEN][CUDA] Fix a bug when vectorized load&store was involved for…

Posted by GitBox <gi...@apache.org>.
wpan11nv commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416785651



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       That is a good point I missed. We cannot use uint32_t as the storage type for char2 or char3. It is fine to keep it as it is. Thanks for looking into this!
   
   Also, please move the "if (t.lanes() == 1)" logic out; otherwise LGTM.









