Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/14 07:09:33 UTC

[GitHub] [tvm] tsupei opened a new issue #7102: [TOPI] batch_matmul cannot use and / or / not operator to Expr

tsupei opened a new issue #7102:
URL: https://github.com/apache/tvm/issues/7102


   I am trying to use `topi.nn.batch_matmul`, following the documentation [here](https://tvm.apache.org/docs/api/python/topi.html). It requires inputs of type `tvm.te.Tensor`. However, the following code raises an error.
   
   ```python
   import tvm
   from tvm import topi
   from tvm import te
   
   def main():
   
       # Symbolic (dynamic) dimensions
       batch = te.var("batch")
       n = te.var("n")
       d = te.var("d")
   
       a = te.placeholder((batch, n, d), dtype="float32", name="a")
       b = te.placeholder((batch, n, d), dtype="float32", name="b")
   
       print(type(a))
   
       c = topi.nn.batch_matmul(a, b)
       print(c)
   
   if __name__ == "__main__":
       main()
   ```
   ```text
   Traceback (most recent call last):
     File "bmm.py", line 19, in <module>
       main()
     File "bmm.py", line 15, in main
       c = topi.nn.batch_matmul(a, b)
     File "/home/jojo6174/tvm-installation/tvm/python/tvm/topi/nn/batch_matmul.py", line 54, in batch_matmul
       batch = max(XB, YB)
     File "/home/jojo6174/tvm-installation/tvm/python/tvm/tir/expr.py", line 176, in __bool__
       return self.__nonzero__()
     File "/home/jojo6174/tvm-installation/tvm/python/tvm/tir/expr.py", line 172, in __nonzero__
       + "use tvm.tir.all / tvm.tir.any instead"
   ValueError: Cannot use and / or / not operator to Expr, hint: use tvm.tir.all / tvm.tir.any instead
   ```
   
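   I found that `tvm/python/tvm/topi/nn/batch_matmul.py` calls Python's built-in `max` on the batch dimensions. When those dimensions are symbolic `Expr`s, `max` has to convert their comparison into a Python `bool`, which `Expr` does not allow (hence the `__bool__` frame in the traceback). Changing the call to `te.max` solved the problem. Am I misunderstanding something here? Thanks!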
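To illustrate why the built-in `max` fails, here is a minimal pure-Python sketch that does not use TVM at all. `SymExpr` and `sym_max` are hypothetical stand-ins that only mimic what the traceback shows: comparing two symbolic values yields another symbolic value, and converting that value to `bool` raises `ValueError`. A symbolic max instead builds a new expression node without deciding anything, which is the role `te.max` plays in the suggested fix.

```python
class SymExpr:
    """Toy symbolic expression (hypothetical; NOT tvm.tir.PrimExpr)."""

    def __init__(self, text):
        self.text = text

    def __gt__(self, other):
        # Comparing symbolic values yields another symbolic expression,
        # not a Python bool.
        return SymExpr(f"({self.text} > {_text(other)})")

    def __bool__(self):
        # Mirrors the ValueError in the traceback: the truth value of a
        # symbolic expression is unknown until the shapes are bound.
        raise ValueError("Cannot use and / or / not operator to Expr")

    def __repr__(self):
        return self.text


def _text(value):
    return value.text if isinstance(value, SymExpr) else str(value)


def sym_max(a, b):
    """Build a symbolic max node instead of deciding now (like te.max)."""
    return SymExpr(f"max({_text(a)}, {_text(b)})")


XB, YB = SymExpr("XB"), SymExpr("YB")

try:
    max(XB, YB)  # built-in max evaluates `YB > XB` and needs a bool
except ValueError as err:
    print("built-in max failed:", err)

print("symbolic max:", sym_max(XB, YB))
```

Running it prints the same `ValueError` message as the traceback for the built-in `max`, and `max(XB, YB)` as an unevaluated expression for the symbolic version.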
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] zhiics closed issue #7102: [TOPI] batch_matmul cannot use and / or / not operator to Expr

Posted by GitBox <gi...@apache.org>.
zhiics closed issue #7102:
URL: https://github.com/apache/tvm/issues/7102


   





[GitHub] [tvm] insop commented on issue #7102: [TOPI] batch_matmul cannot use and / or / not operator to Expr

insop commented on issue #7102:
URL: https://github.com/apache/tvm/issues/7102#issuecomment-745977528


   @tsupei, @zhiics
   
   How about this? With this patch, we can test `batch_matmul` with dynamic (symbolic) shapes.
   The patch is [here](https://github.com/insop/incubator-tvm/commit/9b7b1589f4a05a739e19652b7a2705e4bfc7bfcb.patch).
   
   ```
   From 9b7b1589f4a05a739e19652b7a2705e4bfc7bfcb Mon Sep 17 00:00:00 2001
   From: Insop Song <in...@gmail.com>
   Date: Wed, 16 Dec 2020 01:19:04 -0800
   Subject: [PATCH] Add test to dynamic batch matmul
   
   ---
    .../topi/python/test_topi_batch_matmul.py     | 38 ++++++++++++++++---
    1 file changed, 32 insertions(+), 6 deletions(-)
   
   diff --git a/tests/python/topi/python/test_topi_batch_matmul.py b/tests/python/topi/python/test_topi_batch_matmul.py
   index e939f6c21..be6552f03 100644
   --- a/tests/python/topi/python/test_topi_batch_matmul.py
   +++ b/tests/python/topi/python/test_topi_batch_matmul.py
   @@ -32,10 +32,24 @@ _batch_matmul_implement = {
    }
    
    
   -def verify_batch_matmul(x_batch, y_batch, M, N, K):
   -    x = te.placeholder((x_batch, M, K), name="x")
   -    y = te.placeholder((y_batch, N, K), name="y")
   -    dtype = x.dtype
   +def verify_batch_matmul(x_batch, y_batch, M, N, K, dynamic=False, debug=False):
   +
   +    if not dynamic:
   +        x = te.placeholder((x_batch, M, K), name="x")
   +        y = te.placeholder((y_batch, N, K), name="y")
   +        dtype = x.dtype
   +    else:
   +        assert x_batch == y_batch or x_batch == 1 or y_batch == 1
   +        batch_size = max(x_batch, y_batch)
   +        dynamic_batch_size = te.var("dynamic_batch_size")
   +        dynamic_M = te.var("dynamic_M")
   +        dynamic_N = te.var("dynamic_N")
   +        dynamic_K = te.var("dynamic_K")
   +
   +        x = te.placeholder((dynamic_batch_size, dynamic_M, dynamic_K), name="x")
   +        y = te.placeholder((dynamic_batch_size, dynamic_N, dynamic_K), name="y")
   +        dtype = x.dtype
   +
    
        # use memoize to pickle the test data for next time use
        @memoize("topi.tests.test_topi_batch_matmul")
   @@ -53,10 +67,19 @@ def verify_batch_matmul(x_batch, y_batch, M, N, K):
            with tvm.target.Target(device):
                fcompute, fschedule = tvm.topi.testing.dispatch(device, _batch_matmul_implement)
                out = fcompute(x, y)
   -            s = fschedule([out])
   +            if not dynamic:
   +                s = fschedule([out])
   +                out_shape = out.shape
   +            else:
   +                s = te.create_schedule(out.op)
   +                out_shape = (batch_size, M, N)
   +
   +            if debug:
   +                print(tvm.lower(s, [x, y, out], simple_mode=True))
   +
            a = tvm.nd.array(a_np, ctx)
            b = tvm.nd.array(b_np, ctx)
   -        c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), ctx)
   +        c = tvm.nd.array(np.zeros(get_const_tuple(out_shape), dtype=dtype), ctx)
            f = tvm.build(s, [x, y, out], device, name="dense")
            f(a, b, c)
            tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
   @@ -75,6 +98,9 @@ def test_batch_matmul():
        verify_batch_matmul(1, 5, 16, 16, 32)
        verify_batch_matmul(5, 1, 16, 16, 32)
    
   +    # Test dynamic batch
   +    verify_batch_matmul(1, 1, 16, 16, 32, dynamic=True, debug=True)
   +    verify_batch_matmul(5, 5, 16, 16, 32, dynamic=True)
    
    if __name__ == "__main__":
        test_batch_matmul()
   -- 
   2.28.0
   ```
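For reference, `topi.nn.batch_matmul` computes `out[b, i, j] = sum_k x[b, i, k] * y[b, j, k]` (note that `y` is laid out as `(batch, N, K)`), which matches the loop nest in the lowered output of the test results. A batch size of 1 on either side is broadcast, which is why the test asserts `x_batch == y_batch or x_batch == 1 or y_batch == 1` and uses `max(x_batch, y_batch)` for the output batch. The NumPy sketch below is an illustrative reference only; `batch_matmul_ref` is a hypothetical name, not the helper the TVM test suite uses. It applies the built-in `max` to plain Python ints, which is safe.

```python
import numpy as np


def batch_matmul_ref(x, y):
    """Hypothetical reference: out[b, i, j] = sum_k x[b, i, k] * y[b, j, k].

    x has shape (XB, M, K) and y has shape (YB, N, K); a batch size of 1
    on either side is broadcast against the other.
    """
    xb, m, k = x.shape
    yb, n, k2 = y.shape
    assert k == k2, "reduction dimensions must match"
    assert xb == yb or xb == 1 or yb == 1, "batch sizes must match or be 1"
    batch = max(xb, yb)  # built-in max is fine here: xb and yb are ints
    # Broadcast the batch dimension explicitly, then contract over k.
    x = np.broadcast_to(x, (batch, m, k))
    y = np.broadcast_to(y, (batch, n, k))
    return np.einsum("bik,bjk->bij", x, y)


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 16, 32)).astype("float32")
y = rng.standard_normal((5, 20, 32)).astype("float32")
out = batch_matmul_ref(x, y)
print(out.shape)  # (5, 16, 20)
```

Using `einsum` with `"bik,bjk->bij"` keeps the transposed layout of `y` explicit; `np.matmul(x, np.swapaxes(y, 1, 2))` gives the same result.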
   
   Test results
   
   ```
   $ python ./test_topi_batch_matmul.py 
   Running on target: llvm -device=arm_cpu
   Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (1, 16, 32), 'float32'), ('TENSOR', (1, 16, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm
   Cannot find config for target=llvm -keys=cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (1, 16, 32), 'float32'), ('TENSOR', (1, 16, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm -device=arm_cpu
   Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (5, 16, 32), 'float32'), ('TENSOR', (5, 16, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm
   Cannot find config for target=llvm -keys=cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (5, 16, 32), 'float32'), ('TENSOR', (5, 16, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm -device=arm_cpu
   Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (5, 16, 32), 'float32'), ('TENSOR', (5, 20, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm
   Cannot find config for target=llvm -keys=cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (5, 16, 32), 'float32'), ('TENSOR', (5, 20, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm -device=arm_cpu
   Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (30, 16, 32), 'float32'), ('TENSOR', (30, 20, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm
   Cannot find config for target=llvm -keys=cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (30, 16, 32), 'float32'), ('TENSOR', (30, 20, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm -device=arm_cpu
   Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (1, 16, 32), 'float32'), ('TENSOR', (5, 16, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm
   Cannot find config for target=llvm -keys=cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (1, 16, 32), 'float32'), ('TENSOR', (5, 16, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm -device=arm_cpu
   Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (5, 16, 32), 'float32'), ('TENSOR', (1, 16, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm
   Cannot find config for target=llvm -keys=cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (5, 16, 32), 'float32'), ('TENSOR', (1, 16, 32), 'float32')). A fallback configuration is used, which may bring great performance regression.
   Running on target: llvm -device=arm_cpu
   Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (dynamic_batch_size, dynamic_M, dynamic_K), 'float32'), ('TENSOR', (dynamic_batch_size, dynamic_N, dynamic_K), 'float32')). A fallback configuration is used, which may bring great performance regression.
   primfn(x_1: handle, y_1: handle, compute_1: handle) -> ()
     attr = {"global_symbol": "main", "tir.noalias": True}
     buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [dynamic_batch_size: int32, dynamic_M: int32, dynamic_N: int32], [stride: int32, stride_1: int32, stride_2: int32], type="auto"),
                y: Buffer(y_2: Pointer(float32), float32, [dynamic_batch_size, dynamic_N, dynamic_K: int32], [stride_3: int32, stride_4: int32, stride_5: int32], type="auto"),
                x: Buffer(x_2: Pointer(float32), float32, [dynamic_batch_size, dynamic_M, dynamic_K], [stride_6: int32, stride_7: int32, stride_8: int32], type="auto")}
     buffer_map = {x_1: x, y_1: y, compute_1: compute} {
     for (b: int32, 0, dynamic_batch_size) {
       for (i: int32, 0, dynamic_M) {
         for (j: int32, 0, dynamic_N) {
           compute_2[(((b*stride) + (i*stride_1)) + (j*stride_2))] = 0f32
           for (k: int32, 0, dynamic_K) {
             compute_2[(((b*stride) + (i*stride_1)) + (j*stride_2))] = ((float32*)compute_2[(((b*stride) + (i*stride_1)) + (j*stride_2))] + ((float32*)x_2[(((b*stride_6) + (i*stride_7)) + (k*stride_8))]*(float32*)y_2[(((b*stride_3) + (j*stride_4)) + (k*stride_5))]))
           }
         }
       }
     }
   }
   
   
   Running on target: llvm
   Cannot find config for target=llvm -keys=cpu -link-params=0, workload=('batch_matmul.x86', ('TENSOR', (dynamic_batch_size, dynamic_M, dynamic_K), 'float32'), ('TENSOR', (dynamic_batch_size, dynamic_N, dynamic_K), 'float32')). A fallback configuration is used, which may bring great performance regression.
   primfn(x_1: handle, y_1: handle, compute_1: handle) -> ()
     attr = {"global_symbol": "main", "tir.noalias": True}
     buffers = {y: Buffer(y_2: Pointer(float32), float32, [dynamic_batch_size: int32, dynamic_N: int32, dynamic_K: int32], [stride: int32, stride_1: int32, stride_2: int32], type="auto"),
                compute: Buffer(compute_2: Pointer(float32), float32, [dynamic_batch_size, dynamic_M: int32, dynamic_N], [stride_3: int32, stride_4: int32, stride_5: int32], type="auto"),
                x: Buffer(x_2: Pointer(float32), float32, [dynamic_batch_size, dynamic_M, dynamic_K], [stride_6: int32, stride_7: int32, stride_8: int32], type="auto")}
     buffer_map = {x_1: x, y_1: y, compute_1: compute} {
     for (b: int32, 0, dynamic_batch_size) {
       for (i: int32, 0, dynamic_M) {
         for (j: int32, 0, dynamic_N) {
           compute_2[(((b*stride_3) + (i*stride_4)) + (j*stride_5))] = 0f32
           for (k: int32, 0, dynamic_K) {
             compute_2[(((b*stride_3) + (i*stride_4)) + (j*stride_5))] = ((float32*)compute_2[(((b*stride_3) + (i*stride_4)) + (j*stride_5))] + ((float32*)x_2[(((b*stride_6) + (i*stride_7)) + (k*stride_8))]*(float32*)y_2[(((b*stride) + (j*stride_1)) + (k*stride_2))]))
           }
         }
       }
     }
   }
   
   ```
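The lowered `primfn` output above is a plain loop nest over `b`, `i` and `j` with an inner reduction over `k`, where each buffer is addressed through its own symbolic strides (`stride`, `stride_1`, ...). The pure-Python sketch below mirrors that loop nest on flat lists. The function names and the choice of compact row-major strides are assumptions for illustration, since the lowered code leaves the strides symbolic, and the sketch does not handle batch broadcasting.

```python
def row_major_strides(shape):
    """Element strides of a compact row-major buffer (assumed layout)."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def lowered_batch_matmul(x_flat, y_flat, batch, M, N, K):
    """Hypothetical Python mirror of the lowered loop nest."""
    sx = row_major_strides((batch, M, K))
    sy = row_major_strides((batch, N, K))
    sc = row_major_strides((batch, M, N))
    compute = [0.0] * (batch * M * N)
    for b in range(batch):
        for i in range(M):
            for j in range(N):
                idx = b * sc[0] + i * sc[1] + j * sc[2]
                compute[idx] = 0.0  # matches the `= 0f32` initialization
                for k in range(K):
                    compute[idx] += (
                        x_flat[b * sx[0] + i * sx[1] + k * sx[2]]
                        * y_flat[b * sy[0] + j * sy[1] + k * sy[2]]
                    )
    return compute


# batch=1, M=2, N=2, K=3: x = [[1, 2, 3], [4, 5, 6]], y = [[1, 0, 0], [0, 1, 0]]
x = [1, 2, 3, 4, 5, 6]
y = [1, 0, 0, 0, 1, 0]
print(lowered_batch_matmul(x, y, 1, 2, 2, 3))  # [1.0, 2.0, 4.0, 5.0]
```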





[GitHub] [tvm] zhiics commented on issue #7102: [TOPI] batch_matmul cannot use and / or / not operator to Expr

zhiics commented on issue #7102:
URL: https://github.com/apache/tvm/issues/7102#issuecomment-748382078


   closed by #7111 





[GitHub] [tvm] tsupei commented on issue #7102: [TOPI] batch_matmul cannot use and / or / not operator to Expr

tsupei commented on issue #7102:
URL: https://github.com/apache/tvm/issues/7102#issuecomment-745008162


   @zhiics Sure! I will send one later.





[GitHub] [tvm] zhiics commented on issue #7102: [TOPI] batch_matmul cannot use and / or / not operator to Expr

zhiics commented on issue #7102:
URL: https://github.com/apache/tvm/issues/7102#issuecomment-744637958


   Thanks for reporting the issue. I think it should be `te.max`. Could you send a PR?





[GitHub] [tvm] insop edited a comment on issue #7102: [TOPI] batch_matmul cannot use and / or / not operator to Expr

insop edited a comment on issue #7102:
URL: https://github.com/apache/tvm/issues/7102#issuecomment-745977528


   Moved the post to https://github.com/apache/tvm/pull/7111 instead

