Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/15 05:03:11 UTC

[GitHub] [tvm] insop commented on pull request #7111: Fix a bug in batch_matmul that te.max should be used

insop commented on pull request #7111:
URL: https://github.com/apache/tvm/pull/7111#issuecomment-745054063


   1. @tsupei Good fix. You might want to `git am` [this patch file](https://github.com/insop/incubator-tvm/commit/79e8cf5b24e5170fbfc51461249425ad6b868a7c.patch) to apply the same fix in the two other locations (`topi/testing` and `topi/x86`).
   
   
   ```
   From 79e8cf5b24e5170fbfc51461249425ad6b868a7c Mon Sep 17 00:00:00 2001
   From: Insop Song <in...@gmail.com>
   Date: Mon, 14 Dec 2020 21:01:03 -0800
   Subject: [PATCH] Additional fix to batch_matmul
   
   - add to this PR, https://github.com/apache/tvm/pull/7111
   ---
    python/tvm/topi/testing/batch_matmul.py | 2 +-
    python/tvm/topi/x86/batch_matmul.py     | 2 +-
    2 files changed, 2 insertions(+), 2 deletions(-)
   
   diff --git a/python/tvm/topi/testing/batch_matmul.py b/python/tvm/topi/testing/batch_matmul.py
   index a48c92967..d24198e96 100644
   --- a/python/tvm/topi/testing/batch_matmul.py
   +++ b/python/tvm/topi/testing/batch_matmul.py
   @@ -37,7 +37,7 @@ def batch_matmul(x, y):
        """
        XB, M, _ = x.shape
        YB, N, _ = y.shape
   -    batch = max(XB, YB)
   +    batch = te.max(XB, YB)
        out = np.zeros((batch, M, N)).astype(x.dtype)
        for i in range(batch):
            out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T)
   diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
   index 166c79a4c..79b38de8c 100644
   --- a/python/tvm/topi/x86/batch_matmul.py
   +++ b/python/tvm/topi/x86/batch_matmul.py
   @@ -50,7 +50,7 @@ def batch_matmul(cfg, x, y, out_shape=None):
        YB, N, YK = get_const_tuple(y.shape)
        assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
        assert XK == YK, "shapes of x and y is inconsistant"
   -    B = max(XB, YB)
   +    B = te.max(XB, YB)
        K = XK
        if out_shape is not None:
            assert out_shape[0] == B, "got invalid output shape"
   -- 
   2.28.0
   
   
   ```
   
   
   
   2. @zhiics
   Like most of the other test files, [test_topi_batch_matmul.py](https://github.com/apache/tvm/blob/main/tests/python/topi/python/test_topi_batch_matmul.py) does not cover the `te.var` (symbolic shape) case, which is why this bug went unnoticed.
   Should we add those test cases as well? Let me know.
    
   - https://github.com/apache/tvm/blob/main/tests/python/topi/python/test_topi_batch_matmul.py
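Below is a minimal, TVM-free sketch of why the builtin `max` breaks once a batch dimension is a `te.var`. `Sym`, `sym_max`, and `builtin_max_fails` are hypothetical stand-ins, not TVM APIs. The sketch assumes, as with TVM's symbolic expressions, that comparing a symbolic value produces another expression rather than a Python `bool`, and that asking for its truth value raises `ValueError`. Python's builtin `max` needs exactly that truth value, so it fails when a dimension is symbolic. A shape-aware helper, which is the role `te.max` plays in the patch, can build a symbolic max node instead.

```python
class Sym:
    """Hypothetical symbolic integer, loosely mimicking a te.var shape dimension."""

    def __init__(self, text):
        self.text = text

    # Comparisons build a new expression instead of returning a Python bool.
    def __gt__(self, other):
        return Sym(f"({self.text} > {_text(other)})")

    def __lt__(self, other):
        return Sym(f"({self.text} < {_text(other)})")

    def __bool__(self):
        # The truth value of a symbolic comparison is unknown until runtime.
        raise ValueError("cannot convert a symbolic expression to bool")

    def __repr__(self):
        return self.text


def _text(value):
    """Render an int or a Sym as text."""
    return value.text if isinstance(value, Sym) else str(value)


def sym_max(a, b):
    """Hypothetical shape-aware max (stands in for te.max in this sketch)."""
    if isinstance(a, int) and isinstance(b, int):
        return max(a, b)  # both dimensions are static: plain max is fine
    return Sym(f"max({_text(a)}, {_text(b)})")  # defer to a symbolic max node


def builtin_max_fails(a, b):
    """Return True if Python's builtin max() cannot handle (a, b)."""
    try:
        max(a, b)
    except ValueError:
        return True
    return False


n = Sym("n")                    # symbolic batch size, like te.var("n")
print(sym_max(4, 1))            # 4
print(sym_max(n, 1))            # max(n, 1)
print(builtin_max_fails(4, 1))  # False
print(builtin_max_fails(n, 1))  # True: max() must call bool() on a symbolic comparison
```

In this sketch the all-static case still returns a plain `int`, so callers that pass concrete shapes see no change.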
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org