Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/26 01:39:09 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7496: [Frontend,TOPI] Improve dynamism for BatchMatmul and Dense

comaniac commented on a change in pull request #7496:
URL: https://github.com/apache/tvm/pull/7496#discussion_r583324189



##########
File path: python/tvm/topi/cuda/batch_matmul.py
##########
@@ -161,7 +161,8 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None):
     """
     b, m, k = x.shape
     b, n, k = y.shape
-    cfg.add_flop(b * m * k * n * 2)
+    if isinstance(b, int) and isinstance(m, int) and isinstance(n, int) and isinstance(k, int):

Review comment:
       ```suggestion
       if all([isinstance(s, int) for s in [b, m, n, k]]):
       ```
   And it seems possible for them to be `tir.IntImm` as well?
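   A minimal sketch of what that might look like (my own illustration, not part of the PR), assuming `tir.IntImm` converts to a Python int via `int()`:

   ```python
   from tvm import tir

   # Inside batch_matmul_cublas: treat Python ints and constant tir.IntImm
   # dims as static; skip the FLOP hint only for symbolic (tir.Var) dims.
   if all(isinstance(s, (int, tir.IntImm)) for s in (b, m, n, k)):
       cfg.add_flop(int(b) * int(m) * int(n) * int(k) * 2)
   ```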

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -44,6 +44,28 @@
 __all__ = ["from_tensorflow"]
 
 
+def check_symbolic_shape(shape):
+    return not all([isinstance(dim, (int, tvm.tir.IntImm)) for dim in shape])
+
+
+def list_shape_of(tensor, ndim):
+    shape_tensor = _op.shape_of(tensor)
+    return [
+        _op.strided_slice(shape_tensor, begin=[i], end=[i + 1], strides=[1]) for i in range(ndim)
+    ]
+
+
+def concat_dynamic_shape(shape_list):

Review comment:
       I didn't see this function being used anywhere?

##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -61,9 +62,17 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
     _, M, K = x.shape
     k = te.reduce_axis((0, K), name="k")
     if oshape is None:
-        assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
-        assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"
-        batch = te.max(XB, YB)
+        if isinstance(XB, int) and isinstance(YB, int):
+            assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
+            batch = max(XB, YB)
+        elif isinstance(XB, tir.expr.Var):
+            batch = XB
+        else:
+            batch = YB
+
+        if isinstance(x_shape[2], int) and isinstance(y_shape[2], int):
+            assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"

Review comment:
       ```suggestion
               assert x_shape[2] == y_shape[2], "shapes of x and y are inconsistent"
       ```

##########
File path: python/tvm/topi/cuda/dense.py
##########
@@ -42,11 +42,8 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
     matmul = cublas.matmul(data, weight, False, True)
-    if isinstance(batch, int):
+    if isinstance(batch, int) and isinstance(in_dim, int) and isinstance(out_dim, int):

Review comment:
       What if `batch` has `tir.IntImm` type?
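   One way to cover both cases uniformly would be a small predicate (a sketch of my own; `_is_static_dim` is a hypothetical helper, not an existing TOPI utility):

   ```python
   from tvm import tir

   def _is_static_dim(dim):
       # A dim is "static" if it is a Python int or a constant tir.IntImm;
       # symbolic tir.Var dims are not.
       return isinstance(dim, (int, tir.IntImm))

   # Inside dense_cublas: record the FLOP hint only when every dim is static.
   if all(_is_static_dim(d) for d in (batch, in_dim, out_dim)):
       cfg.add_flop(int(batch) * int(in_dim) * int(out_dim) * 2)
   ```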

##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -61,9 +62,17 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
     _, M, K = x.shape
     k = te.reduce_axis((0, K), name="k")
     if oshape is None:
-        assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
-        assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"
-        batch = te.max(XB, YB)
+        if isinstance(XB, int) and isinstance(YB, int):
+            assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
+            batch = max(XB, YB)
+        elif isinstance(XB, tir.expr.Var):
+            batch = XB
+        else:
+            batch = YB

Review comment:
       Could you clarify the reasoning behind this logic?
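   For reference, here is my reading of the branch selection, written as comments (an assumption on my part, not the author's stated rationale):

   ```python
   from tvm import tir

   if isinstance(XB, int) and isinstance(YB, int):
       # Static-static: the usual broadcast rule over the batch dimension.
       assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
       batch = max(XB, YB)
   elif isinstance(XB, tir.expr.Var):
       # XB is symbolic: take it as the output batch, implicitly assuming
       # YB broadcasts against it (equals XB or is 1); this is unchecked.
       batch = XB
   else:
       # Otherwise take YB, with the symmetric implicit assumption about XB.
       batch = YB
   ```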




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org