Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/17 23:56:04 UTC

[GitHub] [tvm] yzh119 opened a new issue #9953: [Bug] New TIR syntax printer failed to handle dynamic shape.

yzh119 opened a new issue #9953:
URL: https://github.com/apache/tvm/issues/9953


   The current TIR syntax printer (introduced in #9680) fails when the script contains dynamic shapes:
   
   ```python
   from tvm.script import tir as T
   
   @T.prim_func
   def f(a: T.handle, b: T.handle, c: T.handle):
       N = T.var("int32")
       M = T.var("int32")
       K = T.var("int32")
       A = T.match_buffer(a, (N, K), "float32")
       B = T.match_buffer(b, (K, M), "float32")
       C = T.match_buffer(c, (N, M), "float32")
       for i, j, k in T.grid(N, M, K):
           with T.block("gemm"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   C[vi, vj] = 0.
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   
   print(f.script())
   ```
   
   ### Expected behavior
   The printed script should be equivalent to the input.
   
   ### Actual behavior
   
   `M`, `N`, and `K` are used in the buffer annotations of the function signature before they are declared in the body:
   ```python
   # from tvm.script import tir as T
   @T.prim_func
   def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
       K = T.var("int32")
       M = T.var("int32")
       N = T.var("int32")
       # body
       # with T.block("root")
       for i, j, k in T.grid(N, M, K):
           with T.block("gemm"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
               T.writes(C[vi, vj])
               with T.init():
                   C[vi, vj] = T.float32(0)
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   ```
   
   The same problem occurs if the shapes are passed as scalar parameters:
   ```python
   from tvm.script import tir as T
   
   @T.prim_func
   def f(a: T.handle, b: T.handle, c: T.handle, N: T.int32, M: T.int32, K: T.int32):
       A = T.match_buffer(a, (N, K), "float32")
       B = T.match_buffer(b, (K, M), "float32")
       C = T.match_buffer(c, (N, M), "float32")
       for i, j, k in T.grid(N, M, K):
           with T.block("gemm"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   C[vi, vj] = 0.
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   ```
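The failure in the "Actual behavior" output can be checked mechanically at the Python syntax level. The following is a syntax-only sketch, not part of TVM. The helper `names_used_before_declaration` and the trimmed `PRINTED`/`INPUT` strings are hypothetical illustrations. The helper parses a script with the standard `ast` module and reports names that appear in parameter annotations but are only assigned inside the body. It does not run TVMScript's own parser.

```python
import ast

# "Actual behavior" output from the issue, trimmed to the signature and declarations.
PRINTED = '''
@T.prim_func
def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
    K = T.var("int32")
    M = T.var("int32")
    N = T.var("int32")
'''

# The input script from the issue, trimmed the same way.
INPUT = '''
@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle):
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
'''


def names_used_before_declaration(source: str) -> set:
    """Names referenced in parameter annotations that are only assigned in the body."""
    func = ast.parse(source).body[0]  # the decorated FunctionDef
    used_in_signature = set()
    for arg in func.args.args:
        if arg.annotation is not None:
            # Collect every bare name inside the annotation (T, N, K, ...).
            for node in ast.walk(arg.annotation):
                if isinstance(node, ast.Name):
                    used_in_signature.add(node.id)
    assigned_in_body = set()
    for stmt in func.body:
        if isinstance(stmt, ast.Assign):
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    assigned_in_body.add(target.id)
    return used_in_signature & assigned_in_body


print(sorted(names_used_before_declaration(PRINTED)))
print(sorted(names_used_before_declaration(INPUT)))
```

The printed form flags `K`, `M`, and `N`. The input form flags nothing, because its signature only mentions `T.handle`.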
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on issue #9953: [Bug] New TIR syntax printer failed to handle dynamic shape.

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on issue #9953:
URL: https://github.com/apache/tvm/issues/9953#issuecomment-1015015613


   @yzh119 The bug was introduced in the printer because it doesn't check whether a buffer's shape consists only of constant values before printing that buffer as a `T.Buffer[...]` annotation in the function signature.
   
   CC @shingjan: this is a printer bug we should fix.
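Here is a hypothetical toy model of the missing check, not TVM's actual printer code. A buffer parameter gets the `T.Buffer[...]` signature sugar only when every extent is a constant. Otherwise the printer falls back to a `T.handle` parameter plus a `T.match_buffer` call in the body, emitted after the `T.var` declarations it depends on. `BufferParam`, `is_constant_shape`, `fmt_shape`, and `print_prim_func` are illustrative names. The fallback strategy is an assumption about how such a fix could work, not a description of the change that closed this issue.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union

# Hypothetical model: an int is a constant extent, a str names a symbolic T.var.
Dim = Union[int, str]


@dataclass
class BufferParam:
    handle: str              # T.handle parameter name, e.g. "a"
    name: str                # matched buffer name, e.g. "A"
    shape: Tuple[Dim, ...]
    dtype: str


def is_constant_shape(shape: Tuple[Dim, ...]) -> bool:
    """The check described as missing: every extent is a constant."""
    return all(isinstance(d, int) for d in shape)


def fmt_shape(shape: Tuple[Dim, ...]) -> str:
    inner = ", ".join(str(d) for d in shape)
    return f"({inner},)" if len(shape) == 1 else f"({inner})"


def print_prim_func(func_name: str, buffers: List[BufferParam]) -> str:
    params: List[str] = []
    symbols: List[str] = []  # symbolic extents, in first-use order
    matches: List[str] = []
    for b in buffers:
        if is_constant_shape(b.shape):
            # Sugar is safe: the annotation references no variables.
            params.append(f'{b.name}: T.Buffer[{fmt_shape(b.shape)}, "{b.dtype}"]')
        else:
            # Fallback: keep the handle and bind the buffer in the body,
            # after the symbolic extents it uses have been declared.
            params.append(f"{b.handle}: T.handle")
            matches.append(
                f'    {b.name} = T.match_buffer({b.handle}, {fmt_shape(b.shape)}, "{b.dtype}")'
            )
            for d in b.shape:
                if isinstance(d, str) and d not in symbols:
                    symbols.append(d)
    lines = ["@T.prim_func", f"def {func_name}({', '.join(params)}) -> None:"]
    lines += [f'    {s} = T.var("int32")' for s in symbols]
    lines += matches
    lines.append("    T.evaluate(0)  # placeholder for the real body")
    return "\n".join(lines)


gemm = [
    BufferParam("a", "A", ("N", "K"), "float32"),
    BufferParam("b", "B", ("K", "M"), "float32"),
    BufferParam("c", "C", ("N", "M"), "float32"),
]
print(print_prim_func("func", gemm))
```

For the GEMM buffers from the issue, this keeps `a`, `b`, and `c` as `T.handle` parameters. It also declares `N`, `K`, and `M` before the `T.match_buffer` calls, which mirrors the structure of the original input script. A buffer with a fully constant shape such as `(128, 128)` would still be printed with the signature sugar.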





[GitHub] [tvm] masahi closed issue #9953: [Bug] New TIR syntax printer failed to handle dynamic shape.

Posted by GitBox <gi...@apache.org>.
masahi closed issue #9953:
URL: https://github.com/apache/tvm/issues/9953


   

