Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/09 01:45:46 UTC

[GitHub] [incubator-tvm] Light-of-Hers edited a comment on issue #6884: BUG in CodeGenC::PrintSSAAssign

Light-of-Hers edited a comment on issue #6884:
URL: https://github.com/apache/incubator-tvm/issues/6884#issuecomment-723707565


   @tqchen Here is an example. A vectorized `int8` scatter operation can trigger this problem.
   ```python
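   import tvm
   from tvm import te
   
   N = 128
   # A is N x 8N int8; B[i, j] reads every 8th element of row i (strided access).
   A = te.placeholder((N, N * 8), name="A", dtype="int8")
   B = te.compute((N, N), lambda i, j: A[i, j * 8], name="B")
   s = te.create_schedule(B.op)
   
   i, j = s[B].op.axis
   s[B].bind(i, te.thread_axis("blockIdx.x"))
   # Split j (extent 128) into jo (extent 2) and ji (extent 64), make jo the
   # innermost axis and vectorize it. The two lanes are 64 elements apart in B
   # and 512 elements apart in A, so they are not contiguous in memory.
   jo, ji = s[B].split(j, nparts=2)
   s[B].reorder(i, ji, jo)
   s[B].vectorize(jo)
   
   mod = tvm.lower(s, [A, B])
   func = tvm.build(mod, target="opencl")
   # Print the generated OpenCL kernel source.
   print(func.imported_modules[0].get_source())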
   ```
   The generated OpenCL code is:
   ```opencl
   __kernel void main_kernel0(__global char* restrict B, __global char* restrict A) {
     for (int j_inner = 0; j_inner < 64; ++j_inner) {
         int2 _1 = (int2)((((((int)get_group_id(0)) * 128) + j_inner))+(64*0), (((((int)get_group_id(0)) * 128) + j_inner))+(64*1));
         int2 _2 = (int2)((((((int)get_group_id(0)) * 1024) + (j_inner * 8)))+(512*0), (((((int)get_group_id(0)) * 1024) + (j_inner * 8)))+(512*1));
         char2 _3 = (0x000000ff << 0) & (A[_2.s0] << 0))|((0x000000ff << 8) & (A[_2.s1] << 8);
         B[_1.s0] = _3.s0;
         B[_1.s1] = _3.s1;
     }
   }
   ```
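   The line `char2 _3 = (0x000000ff << 0) & (A[_2.s0] << 0))|((0x000000ff << 8) & (A[_2.s1] << 8);` has unbalanced parentheses. The expression passed to `PrintSSAAssign` is the OR of two parenthesized sub-expressions, `((...) & (...))|((...) & (...))`. Because it starts with `(` and ends with `)`, `PrintSSAAssign` drops both characters, even though they belong to two different sub-expressions rather than to a single pair wrapping the whole expression.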
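
To make the failure concrete, here is a small Python sketch that mimics the check in `PrintSSAAssign`. The helper names `strip_outer_parens_naive` and `is_balanced` are hypothetical, written only for illustration. It rebuilds the original `src` by wrapping the emitted right-hand side of `_3` back in the two characters that were removed, and shows that the naive check turns a balanced expression into an unbalanced one.

```python
# Hypothetical helper mirroring the branch in CodeGenC::PrintSSAAssign:
# drop the first and last characters whenever they are '(' and ')',
# without checking that they form a matching pair.
def strip_outer_parens_naive(src: str) -> str:
    if len(src) > 3 and src[0] == "(" and src[-1] == ")":
        return src[1:-1]
    return src


# Hypothetical helper: True if every ')' closes an earlier '(' and no '(' is
# left open at the end.
def is_balanced(expr: str) -> bool:
    depth = 0
    for ch in expr:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                return False
    return depth == 0


# The emitted right-hand side of `_3`, wrapped back in the two characters
# that the naive check removed.
emitted = "(0x000000ff << 0) & (A[_2.s0] << 0))|((0x000000ff << 8) & (A[_2.s1] << 8)"
src = "(" + emitted + ")"

assert strip_outer_parens_naive(src) == emitted
print(is_balanced(src), is_balanced(emitted))  # → True False
```
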
   A simple fix is to remove the parenthesis-stripping branch in `CodeGenC::PrintSSAAssign` and always print `src` unchanged, for example:
   ```C++
   void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) {
     PrintType(t, stream);
     stream << ' ' << target << " = ";
     // if (src.length() > 3 && src[0] == '(' && src[src.length() - 1] == ')') {
     //   stream << src.substr(1, src.length() - 2);
     // } else {
     //   stream << src;
     // }
     stream << src;
     stream << ";\n";
   }
   ```
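
Removing the branch costs at most one redundant pair of parentheses per SSA assignment, which does not change the meaning of the generated code. If the cosmetic stripping is worth keeping, an alternative is to strip only when the leading `(` is closed by the trailing `)`. The Python sketch below shows that check; `strip_outer_parens_safe` is a hypothetical helper, not code taken from TVM, and it assumes `src` contains no parentheses inside string or character literals. A C++ port inside `PrintSSAAssign` would be a single depth-counting loop over `src`.

```python
def strip_outer_parens_safe(src: str) -> str:
    """Remove one pair of outer parentheses only if the '(' at index 0 is
    matched by the ')' at the last index. Hypothetical helper; assumes src
    has no parentheses inside string or character literals."""
    if len(src) <= 3 or src[0] != "(" or src[-1] != ")":
        return src
    depth = 0
    for i, ch in enumerate(src):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            # The leading '(' closes before the last character, so the outer
            # '(' and ')' belong to different sub-expressions: keep src as is.
            if depth == 0 and i != len(src) - 1:
                return src
    # depth != 0 means src itself is unbalanced; leave it untouched.
    return src[1:-1] if depth == 0 else src


print(strip_outer_parens_safe("(a + b)"))    # → a + b
print(strip_outer_parens_safe("(a) | (b)"))  # → (a) | (b)
```

With this check, the expression from the OpenCL example above is emitted unchanged, while a genuinely wrapped expression such as `(a + b)` still loses its outer pair.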

