You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/05 18:48:27 UTC

[GitHub] [incubator-tvm] soiferj commented on a change in pull request #4242: [AutoTVM] Add batch_matmul to tunable operations

soiferj commented on a change in pull request #4242: [AutoTVM] Add batch_matmul to tunable operations
URL: https://github.com/apache/incubator-tvm/pull/4242#discussion_r342734128
 
 

 ##########
 File path: topi/python/topi/x86/batch_matmul.py
 ##########
 @@ -18,43 +18,70 @@
 """x86 batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cblas
-from topi.nn import batch_matmul, batch_matmul_default
-from .. import generic
+from .. import generic, nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
-@batch_matmul.register(["cpu"])
-def batch_matmul_x86(x, y):
+
+@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
+def _declaration_batch_matmul_nopack(cfg, x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
     Parameters
     ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file
     x : tvm.Tensor
         3-D with shape [batch, M, K]
-
     y : tvm.Tensor
         3-D with shape [batch, N, K]
-
     Returns
     -------
     output : tvm.Tensor
         3-D with shape [batch, M, N]
     """
+    print("CFG TYPE: ", type(cfg))
     target = tvm.target.current_target()
     if "cblas" in target.libs:
         return cblas.batch_matmul(x, y, False, True)
-    return batch_matmul_default(x, y)
 
-@generic.schedule_batch_matmul.register(["cpu"])
-def schedule_batch_matmul(outs):
+    assert len(x.shape) == 3 and len(
+        y.shape) == 3, "only support 3-dim batch_matmul"
+    XB, M, XK = get_const_tuple(x.shape)
+    YB, N, YK = get_const_tuple(y.shape)
+    assert XB == YB, "batch dimension doesn't match"
+    assert XK == YK, "shapes of x and y is inconsistant"
+    B = XB
+    K = XK
+    # create tuning space
+    cfg.define_split("tile_y", M, num_outputs=2)
+    cfg.define_split("tile_x", N, num_outputs=2)
+    cfg.define_split("tile_k", K, num_outputs=2)
+    if cfg.is_fallback:
+        _default_batch_matmul_nopack_config(cfg, M, N, K)
 
 Review comment:
   +1 for moving this into the schedule.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services