Posted to commits@tvm.apache.org by lm...@apache.org on 2020/12/25 00:59:05 UTC

[tvm] branch main updated: [AutoScheduler] Add layout rewrite support for dense and batch matmul on CPU (#7161)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7dcafb0  [AutoScheduler] Add layout rewrite support for dense and batch matmul on CPU (#7161)
7dcafb0 is described below

commit 7dcafb017a05ac0d5ecd7cfe8d8741d33a24bbad
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Thu Dec 24 16:57:24 2020 -0800

    [AutoScheduler] Add layout rewrite support for dense and batch matmul on CPU (#7161)
    
    * [AutoScheduler] Add layout rewrite for dense and batch_matmul
    
    * Fix test & Address comments
    
    * Fix shape inference
    
    * fix test
---
 include/tvm/auto_scheduler/compute_dag.h           |  8 ++++
 include/tvm/relay/attrs/nn.h                       | 10 ++++-
 python/tvm/auto_scheduler/__init__.py              |  2 +-
 python/tvm/auto_scheduler/compute_dag.py           | 18 +++++++++
 python/tvm/relay/op/strategy/generic.py            | 17 +++++---
 python/tvm/relay/op/strategy/x86.py                | 15 +++++--
 python/tvm/testing.py                              | 18 +++++++++
 python/tvm/topi/nn/batch_matmul.py                 | 30 +++++++++++---
 python/tvm/topi/nn/conv2d.py                       | 37 +++++------------
 python/tvm/topi/nn/dense.py                        | 30 +++++++++++---
 src/auto_scheduler/compute_dag.cc                  | 26 ++++++++++++
 src/relay/op/make_op.h                             |  2 +
 src/relay/op/nn/nn.cc                              | 34 +++++++++++-----
 src/relay/op/nn/nn.h                               | 10 ++++-
 .../transforms/auto_scheduler_layout_rewrite.cc    | 16 +++++++-
 .../transforms/combine_parallel_batch_matmul.cc    |  7 ++--
 src/relay/transforms/combine_parallel_dense.cc     | 16 ++++++++
 src/relay/transforms/combine_parallel_op_batch.h   |  2 +-
 .../relay/test_auto_scheduler_layout_rewrite.py    | 47 +++++++++++++++++++++-
 .../relay/test_pass_combine_parallel_dense.py      |  2 -
 .../python/unittest/test_auto_scheduler_common.py  | 18 ---------
 .../unittest/test_auto_scheduler_search_policy.py  |  3 +-
 22 files changed, 276 insertions(+), 92 deletions(-)

diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index bdb6489..1e3f097 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -303,6 +303,14 @@ class ComputeDAG : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
 };
 
+/*!
+ *  \brief Get the original shape from a rewritten layout string.
+ *  \param rewritten_layout The layout after auto-scheduler's layout rewrite.
+ *  \param axis_names Specify the names of axes.
+ *  \return shape The original shape.
+ */
+Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names);
+
 }  // namespace auto_scheduler
 }  // namespace tvm
 
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 5ffca99..7bfd580 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -120,7 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
   tvm::String data_layout;
   tvm::String kernel_layout;
   tvm::String out_layout;
-  std::string auto_scheduler_rewritten_layout;
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
@@ -924,6 +924,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
 /*! \brief Attributes for dense operator */
 struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   IndexExpr units;
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
   DataType out_dtype;
 
   TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
@@ -936,6 +937,13 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   }
 };
 
+/*! \brief Attributes for batch matmul operator */
+struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
+  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
+
+  TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {}
+};
+
 /*! \brief Attributes for sparse_dense operator */
 struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
   bool sparse_lhs;
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 4926b88..a03e156 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -31,7 +31,7 @@ from . import utils
 from . import workload_registry
 
 # Shortcut
-from .compute_dag import ComputeDAG, LayoutRewriteOption
+from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
 from .cost_model import RandomModel, XGBModel
 from .dispatcher import DispatchContext, ApplyHistoryBest
 from .measure import (
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index 94cb640..d8a2422 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -234,3 +234,21 @@ class ComputeDAG(Object):
         # Since we always use tensors to recover the ComputeDAG, we do not support
         # (de)serialization of the ComputeDAG constructed by a schedule.
         self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, LoadJSON(state["tensors"]), None)
+
+
+def get_shape_from_rewritten_layout(rewritten_layout, axis_names):
+    """Get the orginal shape from a rewritten layout string.
+
+    Parameters
+    ----------
+    rewritten_layout: str
+        The layout after rewrite
+    axis_names: List[str]
+        Specify the order of axes by names
+
+    Returns
+    -------
+    shape: List[PrimExpr]
+        The original shape
+    """
+    return _ffi_api.GetShapeFromRewrittenLayout(rewritten_layout, axis_names)
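For illustration (not part of this commit), a minimal sketch of the new Python
helper. The layout string below is a made-up example in the
"<extent><axis-name>" token format consumed by topi's
parse_auto_scheduler_layout; extents that share an axis name are multiplied
back together to recover the original extent:

    from tvm import auto_scheduler

    # "2j_8k_4j" splits axis j into extents 2 and 4 and leaves axis k at 8,
    # so the original shape is recovered as [2 * 4, 8] = [8, 8].
    shape = auto_scheduler.get_shape_from_rewritten_layout("2j_8k_4j", ["j", "k"])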
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 9fc6089..95b5d6a 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -199,7 +199,6 @@ def wrap_compute_conv2d(
         data_layout = attrs.get_str("data_layout")
         out_layout = attrs.get_str("out_layout")
         out_dtype = attrs.out_dtype
-        auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs)
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
         args = [inputs[0], inputs[1], strides, padding, dilation]
         if has_groups:
@@ -210,7 +209,7 @@ def wrap_compute_conv2d(
             args.append(out_layout)
         args.append(out_dtype)
         if need_auto_scheduler_layout:
-            args.append(auto_scheduler_rewritten_layout)
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
         return [topi_compute(*args)]
 
     return _compute_conv2d
@@ -684,14 +683,17 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
 
 
 # dense
-def wrap_compute_dense(topi_compute):
+def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
     """wrap dense topi compute"""
 
     def _compute_dense(attrs, inputs, out_type):
         """Compute definition of dense"""
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-        return [topi_compute(inputs[0], inputs[1], None, out_dtype)]
+        args = [inputs[0], inputs[1], None, out_dtype]
+        if need_auto_scheduler_layout:
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
+        return [topi_compute(*args)]
 
     return _compute_dense
 
@@ -710,11 +712,14 @@ def dense_strategy(attrs, inputs, out_type, target):
 
 
 # batch_matmul
-def wrap_compute_batch_matmul(topi_compute):
+def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False):
     """wrap batch_matmul topi compute"""
 
     def _compute_batch_matmul(attrs, inputs, out_type):
-        return [topi_compute(inputs[0], inputs[1], out_type.shape)]
+        args = [inputs[0], inputs[1], out_type.shape]
+        if need_auto_scheduler_layout:
+            args.append(get_auto_scheduler_rewritten_layout(attrs))
+        return [topi_compute(*args)]
 
     return _compute_batch_matmul
 
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 5dfeca6..841213a 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -325,6 +325,15 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
         name="dense_nopack.x86",
         plevel=10,
     )
+
+    if is_auto_scheduler_enabled():
+        strategy.add_implementation(
+            wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True),
+            naive_schedule,
+            name="dense.generic",
+            plevel=11,
+        )
+
     if "cblas" in target.libs:
         with SpecializedCondition(same_type and dtype in ["float32", "float64"]):
             strategy.add_implementation(
@@ -350,7 +359,7 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
                 plevel=15,
             )
     with SpecializedCondition(m >= 16):
-        # this implementation may not be well-optimized, so use plevel=8 for now.
+        # this implementation may not be well-optimized, so use plevel=5 for now.
         strategy.add_implementation(
             wrap_compute_dense(topi.x86.dense_pack),
             wrap_topi_schedule(topi.x86.schedule_dense_pack),
@@ -364,9 +373,9 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
 def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     """batch_matmul x86 strategy"""
     strategy = _op.OpStrategy()
-    if is_dynamic(out_type):
+    if is_dynamic(out_type) or is_auto_scheduler_enabled():
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.nn.batch_matmul),
+            wrap_compute_batch_matmul(topi.nn.batch_matmul, need_auto_scheduler_layout=True),
             wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
             name="batch_matmul.generic",
             plevel=10,
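For context (illustrative, not part of this commit): the generic
implementations registered above only win when auto-scheduler is enabled,
which is what is_auto_scheduler_enabled() checks via the PassContext config.
A sketch of the intended compile flow, assuming `mod` is a Relay module and
`log_file` holds previously collected tuning records:

    import tvm
    from tvm import relay, auto_scheduler

    # Replay tuned schedules; the config key switches relay's op strategy
    # over to the auto-scheduler implementations registered above.
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
        ):
            lib = relay.build(mod, target="llvm")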
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 8311a63..32307a9 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -58,6 +58,7 @@ import logging
 import os
 import sys
 import time
+import threading
 import pytest
 import numpy as np
 import tvm
@@ -742,4 +743,21 @@ def terminate_self():
     sys.exit(-1)
 
 
+class PropagatingThread(threading.Thread):
+    """A thread that propagates the exection to the main thread"""
+
+    def run(self):
+        self.exc = None
+        try:
+            self.ret = self._target(*self._args, **self._kwargs)
+        except BaseException as e:
+            self.exc = e
+
+    def join(self, timeout=None):
+        super(PropagatingThread, self).join(timeout)
+        if self.exc:
+            raise self.exc
+        return self.ret
+
+
 tvm._ffi._init_api("testing", __name__)
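A quick usage sketch for the relocated helper (illustrative only): unlike
threading.Thread, join() re-raises any exception thrown inside the thread.

    from tvm.testing import PropagatingThread

    def work():
        raise ValueError("boom")

    t = PropagatingThread(target=work)
    t.start()
    t.join()  # re-raises the ValueError in the calling thread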
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 98acc2d..9ca2df7 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -14,13 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Binary Neural Network (BNN) Operators"""
+"""Batch matrix multiplication"""
 # pylint: disable=invalid-name
-from tvm import te
+from tvm import te, auto_scheduler
 from ..utils import get_const_tuple
 
 
-def batch_matmul(x, y, oshape=None):
+def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch. Supports broadcasting for batch dimension.
 
@@ -36,14 +36,25 @@ def batch_matmul(x, y, oshape=None):
         Explicit intended output shape of the computation. Can be useful in cases
         with dynamic input shapes.
 
+    auto_scheduler_rewritten_layout: str = ""
+        The layout after auto-scheduler's layout rewrite pass.
+
     Returns
     -------
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
     x_shape = get_const_tuple(x.shape)
-    y_shape = get_const_tuple(y.shape)
+    if auto_scheduler_rewritten_layout:
+        # Infer shape for the rewritten layout
+        y_shape = auto_scheduler.get_shape_from_rewritten_layout(
+            auto_scheduler_rewritten_layout, ["b", "j", "k"]
+        )
+        auto_scheduler.remove_index_check(y)
+    else:
+        y_shape = get_const_tuple(y.shape)
+    assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul"
+
     XB = x_shape[0]
     YB = y_shape[0]
     _, M, K = x.shape
@@ -54,8 +65,15 @@ def batch_matmul(x, y, oshape=None):
         batch = te.max(XB, YB)
         N = y.shape[1]
         oshape = (batch, M, N)
-    return te.compute(
+
+    output = te.compute(
         oshape,
         lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
         tag="batch_matmul",
+        attrs={"layout_free_placeholders": [y]},
     )
+
+    if auto_scheduler_rewritten_layout:
+        output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)
+
+    return output
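The default (non-rewritten) path is unchanged for callers; a minimal sketch
(illustrative, not part of this commit):

    from tvm import te, topi

    x = te.placeholder((4, 128, 64), name="x")  # [batch, M, K]
    y = te.placeholder((4, 32, 64), name="y")   # [batch, N, K]
    # With no rewritten layout string, shapes are read from the tensors
    # directly and y is only marked as a layout-free placeholder.
    out = topi.nn.batch_matmul(x, y)            # shape [4, 128, 32]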
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index ead9f16..e2384c4 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -361,6 +361,12 @@ def conv2d_nhwc(
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
 
+    out_dtype: str = "float32"
+        The type of output tensor
+
+    auto_scheduler_rewritten_layout: str = ""
+        The layout after auto-scheduler's layout rewrite pass.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -381,34 +387,9 @@ def conv2d_nhwc(
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
-        # todo(merrymercy): wrap this with a more general interface.
-        if len(Filter.shape) == 17:
-            # For mali.
-            # GPU tile structure is SSSRRSRS
-            # You could refer function comment of DoMultiLevelTiling
-            # in the utils.h to see more detail explanation.
-            kernel_h = Filter.shape[6] * Filter.shape[9] * Filter.shape[13]
-            kernel_w = Filter.shape[7] * Filter.shape[10] * Filter.shape[14]
-            channel = Filter.shape[8] * Filter.shape[11] * Filter.shape[15]
-            num_filter = Filter.shape[12] * Filter.shape[16]
-            for i in range(6):
-                num_filter *= Filter.shape[i]
-        elif len(Filter.shape) >= 10:
-            # For cpu tile structure SSRSRS
-            base = len(Filter.shape) - 10
-            kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base]
-            kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base]
-            channel = Filter.shape[4 + base] * Filter.shape[8 + base]
-            num_filter = Filter.shape[5 + base] * Filter.shape[9 + base]
-            for i in range(base + 2):
-                num_filter *= Filter.shape[i]
-        elif len(Filter.shape) == 4:
-            num_filter, kernel_h, kernel_w, channel = Filter.shape
-        else:
-            raise ValueError(
-                "Don't know how to infer the layout for filter shape: %s. "
-                "Please add a new branch to handle this case." % str(Filter)
-            )
+        kernel_h, kernel_w, channel, num_filter = auto_scheduler.get_shape_from_rewritten_layout(
+            auto_scheduler_rewritten_layout, ["ry", "rx", "rc", "ff"]
+        )
         auto_scheduler.remove_index_check(Filter)
     else:
         kernel_h, kernel_w, channel, num_filter = Filter.shape
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 0ce0f9e..474fea4 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM operator fully connected compute."""
-from tvm import te
+from tvm import te, auto_scheduler
 from .. import tag
 
 
-def dense(data, weight, bias=None, out_dtype=None):
+def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
     """The default implementation of dense in topi.
 
     Parameters
@@ -30,30 +30,44 @@ def dense(data, weight, bias=None, out_dtype=None):
     weight : tvm.te.Tensor
         2-D with shape [out_dim, in_dim]
 
-    bias : tvm.te.Tensor, optional
+    bias : Optional[tvm.te.Tensor]
         1-D with shape [out_dim]
 
-    out_dtype : str
+    out_dtype : Optional[str]
         The output type. This is used for mixed precision.
 
+    auto_scheduler_rewritten_layout: str = ""
+        The layout after auto-scheduler's layout rewrite pass.
+
     Returns
     -------
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
+    assert len(data.shape) == 2, "only support 2-dim dense"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
     batch, in_dim = data.shape
-    out_dim, _ = weight.shape
+
+    if auto_scheduler_rewritten_layout:
+        # Infer shape for the rewritten layout
+        out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
+            auto_scheduler_rewritten_layout, ["j", "k"]
+        )
+        auto_scheduler.remove_index_check(weight)
+    else:
+        out_dim, red_dim = weight.shape
+    assert in_dim == red_dim
+
     k = te.reduce_axis((0, in_dim), name="k")
     matmul = te.compute(
         (batch, out_dim),
         lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
         name="T_dense",
         tag="dense",
+        attrs={"layout_free_placeholders": [weight]},
     )
     if bias is not None:
         matmul = te.compute(
@@ -61,4 +75,8 @@ def dense(data, weight, bias=None, out_dtype=None):
             lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
             tag=tag.BROADCAST,
         )
+
+    if auto_scheduler_rewritten_layout:
+        matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)
+
     return matmul
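Likewise for dense, the default path keeps its existing behavior; a minimal
sketch (illustrative, not part of this commit):

    from tvm import te, topi

    data = te.placeholder((16, 64), name="data")      # [batch, in_dim]
    weight = te.placeholder((32, 64), name="weight")  # [out_dim, in_dim]
    # Default path: out_dim/red_dim still come from weight.shape; the new
    # layout_free_placeholders attr lets auto-scheduler re-lay-out weight.
    out = topi.nn.dense(data, weight)                 # shape [16, 32]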
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index af45f2d..64114c8 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -33,6 +33,7 @@
 #include <tvm/te/schedule_pass.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/topi/transform.h>
 
 #include <algorithm>
 #include <cstdint>
@@ -1410,6 +1411,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << ss.str();
     });
 
+Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names) {
+  Array<PrimExpr> shape;
+  std::vector<std::string> extracted_names;
+  topi::parse_auto_scheduler_layout(rewritten_layout, &shape, &extracted_names);
+
+  Array<PrimExpr> ret(axis_names.size(), 1);
+
+  size_t ct = 0;
+  for (size_t i = 0; i < axis_names.size(); ++i) {
+    for (size_t j = 0; j < extracted_names.size(); ++j) {
+      if (axis_names[i] == extracted_names[j]) {
+        ret.Set(i, ret[i] * shape[j]);
+        ct++;
+      }
+    }
+  }
+
+  CHECK_EQ(ct, extracted_names.size()) << "The number or names of axes do not match";
+
+  return ret;
+}
+
 TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")
     .set_body_typed([](Optional<Array<te::Tensor>> tensors, Optional<te::Schedule> sch) {
       if (sch) {
@@ -1452,5 +1475,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
       return index_rewriter.Rewrite(body);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
+    .set_body_typed(GetShapeFromRewrittenLayout);
+
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index d2fb6aa..2b05290 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -46,6 +46,8 @@ Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);
 
+Expr MakeBatchMatmul(Expr lhs, Expr rhs);
+
 Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
 
 Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 816b980..fbb6204 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -24,6 +24,7 @@
 
 #include "nn.h"
 
+#include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/relay/attrs/image.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/op.h>
@@ -845,37 +846,49 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,).
     .add_type_rel("GroupNorm", GroupNormRel);
 
 // relay.nn.batch_matmul
+TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs);
+
 bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 3);
   const auto* x = types[0].as<TensorTypeNode>();
   const auto* y = types[1].as<TensorTypeNode>();
   if (x == nullptr || y == nullptr) return false;
-  ICHECK(x->shape.size() == 3 && y->shape.size() == 3);
+
+  const auto* param = attrs.as<BatchMatmulAttrs>();
+  Array<PrimExpr> y_shape;
+  if (param->auto_scheduler_rewritten_layout.size() == 0) {
+    y_shape = y->shape;
+  } else {
+    y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
+                                                          {"b", "j", "k"});
+  }
+
+  ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
   bool is_dyn = false;
   Array<tvm::PrimExpr> oshape;
   for (size_t i = 0; i < 3; ++i) {
-    if (x->shape[i].as<tir::AnyNode>() != nullptr || y->shape[i].as<tir::AnyNode>() != nullptr) {
+    if (x->shape[i].as<tir::AnyNode>() != nullptr || y_shape[i].as<tir::AnyNode>() != nullptr) {
       is_dyn = true;
       oshape.push_back(Any());
     } else {
       if (i == 0) {
-        oshape.push_back(max(x->shape[i], y->shape[i]));
+        oshape.push_back(max(x->shape[i], y_shape[i]));
       } else {
         oshape.push_back(x->shape[i]);
       }
     }
   }
   if (!is_dyn) {
-    ICHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
-           reporter->AssertEQ(y->shape[0], 1))
+    ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
+           reporter->AssertEQ(y_shape[0], 1))
         << "BatchDot: batch dimensions don't match, "
-        << " x shape=" << x->shape << ", y shape=" << y->shape;
-    ICHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
+        << " x shape=" << x->shape << ", y shape=" << y_shape;
+    ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2]))
         << "BatchDot: shapes of x and y is inconsistent, "
-        << " x shape=" << x->shape << ", y shape=" << y->shape;
+        << " x shape=" << x->shape << ", y shape=" << y_shape;
 
-    oshape.Set(2, y->shape[1]);
+    oshape.Set(2, y_shape[1]);
   }
 
   // assign output type
@@ -885,8 +898,9 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
 
 // Positional relay function to create batch_matmul operator used by frontend FFI.
 Expr MakeBatchMatmul(Expr x, Expr y) {
+  auto attrs = make_object<BatchMatmulAttrs>();
   static const Op& op = Op::Get("nn.batch_matmul");
-  return Call(op, {x, y}, Attrs(), {});
+  return Call(op, {x, y}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul);
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 30ef307..9b9cff2 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -57,7 +57,15 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     // data dtype as the weight dtype. However if weight dtype is explicitly
     // present we will use that.
     auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
-    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+    if (param->auto_scheduler_rewritten_layout.size() == 0) {
+      // Normal case: assign result to reporter
+      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+    } else {
+      // If the layout is rewritten by auto-scheduler,
+      // we just forcibly apply the layout provided by auto-scheduler and
+      // skip the normal inference logic.
+      {}  // do nothing
+    }
     oshape.Set((oshape.size() - 1), param->units);
   } else {
     if (weight == nullptr) return false;
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
index c9875ef..53e7a02 100644
--- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -83,6 +83,12 @@ class FuncMutator : public ExprMutator {
       Attrs updated_attrs;
       if (auto pattr = call->attrs.as<Conv2DAttrs>()) {
         updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+      } else if (auto pattr = call->attrs.as<DenseAttrs>()) {
+        updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+      } else if (auto pattr = call->attrs.as<BatchMatmulAttrs>()) {
+        updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+      } else {
+        LOG(FATAL) << "Unhandled attribute: " << call->attrs;
       }
       new_n = Call(call->op, updated_args, updated_attrs);
     }
@@ -93,7 +99,7 @@ class FuncMutator : public ExprMutator {
   std::deque<std::string> ori_layouts_queue_;
   std::deque<std::string> new_layouts_queue_;
 
-  std::vector<std::string> target_ops_{"nn.conv2d"};
+  std::vector<std::string> target_ops_{"nn.conv2d", "nn.dense", "nn.batch_matmul"};
 };
 
 Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
@@ -150,8 +156,14 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
     .set_body_typed([](const Attrs& attrs) {
       if (attrs->IsInstance<Conv2DAttrs>()) {
         return attrs.as<Conv2DAttrs>()->auto_scheduler_rewritten_layout;
+      } else if (attrs->IsInstance<DenseAttrs>()) {
+        return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout;
+      } else if (attrs->IsInstance<BatchMatmulAttrs>()) {
+        return attrs.as<BatchMatmulAttrs>()->auto_scheduler_rewritten_layout;
+      } else {
+        LOG(FATAL) << "Unhandled attribute: " << attrs;
       }
-      return std::string();
+      return tvm::String();
     });
 
 }  // namespace transform
diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc
index 5b56504..20a7c7f 100644
--- a/src/relay/transforms/combine_parallel_batch_matmul.cc
+++ b/src/relay/transforms/combine_parallel_batch_matmul.cc
@@ -70,16 +70,15 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner {
   }
 
   Call MakeCombinedOp(const Group& branches) {
-    const Op& batch_matmul = Op::Get("nn.batch_matmul");
     Expr data = branches[0][0]->args[0];
 
     Array<Expr> weights;
     for (const auto& branch : branches) {
-      auto batch_matmul = branch[0];
-      weights.push_back(batch_matmul->args[1]);
+      auto call = branch[0];
+      weights.push_back(call->args[1]);
     }
     Expr new_weight = MakeConcatenate(Tuple(weights), 1);
-    return Call(batch_matmul, {data, new_weight}, {}, {});
+    return Downcast<Call>(MakeBatchMatmul(data, new_weight));
   }
 
   bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; }
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index 6d4c8c0..d9ca4bf 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -57,6 +57,22 @@ class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner {
       : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}
 
  protected:
+  Call MakeCombinedOp(const Group& branches) {
+    Array<Expr> new_args;
+    size_t num_args = branches[0][0]->args.size();
+    for (size_t i = 0; i < num_args; i++) {
+      Array<Expr> arg_from_all_branches;
+      for (const auto& branch : branches) {
+        arg_from_all_branches.push_back(branch[0]->args[i]);
+      }
+
+      new_args.push_back(MakeStack(Tuple(arg_from_all_branches), 0));
+    }
+
+    CHECK_EQ(num_args, 2);
+    return Downcast<Call>(MakeBatchMatmul(new_args[0], new_args[1]));
+  }
+
   virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
     StructuralEqual eq;
     const auto* attrs_a = a->attrs.as<DenseAttrs>();
diff --git a/src/relay/transforms/combine_parallel_op_batch.h b/src/relay/transforms/combine_parallel_op_batch.h
index 7a518e9..db4734b 100644
--- a/src/relay/transforms/combine_parallel_op_batch.h
+++ b/src/relay/transforms/combine_parallel_op_batch.h
@@ -95,7 +95,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner {
    * \param branches branches that are to be combined
    * \return new call with branches combined as batch op by stacking args
    */
-  Call MakeCombinedOp(const Group& branches) final;
+  virtual Call MakeCombinedOp(const Group& branches);
 
   /*
    * \brief Checks if argument of op following combined ops are able to be combined
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
index 299fcb8..66d40ba 100644
--- a/tests/python/relay/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
@@ -23,6 +23,7 @@ import tvm
 from tvm import relay, auto_scheduler
 from tvm.contrib import graph_runtime
 import tvm.testing
+from tvm.testing import PropagatingThread
 
 
 def get_np_array(var, dtype):
@@ -70,6 +71,28 @@ def get_relay_conv2d(
     return mod, data, weight
 
 
+def get_relay_dense(m=128, n=128, k=128):
+    dtype = "float32"
+    d = relay.var("data", shape=(m, k), dtype=dtype)
+    w = relay.var("weight", shape=(n, k), dtype=dtype)
+    y = relay.nn.dense(d, w, units=n)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([d, w], y)
+    data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
+    return mod, data, weight
+
+
+def get_relay_batchmm(batch=4, m=128, n=128, k=128):
+    dtype = "float32"
+    d = relay.var("data", shape=(batch, m, k), dtype=dtype)
+    w = relay.var("weight", shape=(batch, n, k), dtype=dtype)
+    y = relay.nn.batch_matmul(d, w)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([d, w], y)
+    data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
+    return mod, data, weight
+
+
 def tune_and_check(mod, data, weight):
     # Extract tasks from a relay program
     target = tvm.target.Target("llvm")
@@ -109,13 +132,33 @@ def tune_and_check(mod, data, weight):
         actual_output = compile_and_run()
         expected_output = compile_and_run(disabled_pass={"AutoSchedulerLayoutRewrite"})
 
-        tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4)
+        tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4)
 
 
 def test_conv2d():
+    # wrap the search in a new thread to avoid the conflict
+    # between python's multiprocessing and tvm's thread pool
     mod, data, weight = get_relay_conv2d(kh=1, kw=1)
-    tune_and_check(mod, data, weight)
+    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+    t.start()
+    t.join()
+
+
+def test_dense():
+    mod, data, weight = get_relay_dense()
+    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+    t.start()
+    t.join()
+
+
+def test_batch_matmul():
+    mod, data, weight = get_relay_batchmm()
+    t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+    t.start()
+    t.join()
 
 
 if __name__ == "__main__":
     test_conv2d()
+    test_dense()
+    test_batch_matmul()
diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py
index a8c9782..cd946ab 100644
--- a/tests/python/relay/test_pass_combine_parallel_dense.py
+++ b/tests/python/relay/test_pass_combine_parallel_dense.py
@@ -286,8 +286,6 @@ def test_combine_parallel_dense_flat_biasadd():
         y = run_opt_pass(y_before, combine_pass)
         y_expected = expected(x, w1, w2, b1, b2, j, bias_shape1, bias_shape2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        print(y.astext(False))
-        print(y_expected.astext(False))
         tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
 
     check(3, 5, 4, (), ())
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 87814f2..a037b68 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -16,9 +16,6 @@
 # under the License.
 
 """Common functions for auto_scheduler test cases"""
-
-import threading
-
 import tvm
 from tvm import te, auto_scheduler
 from tvm import topi
@@ -251,18 +248,3 @@ def get_tiled_matmul():
     )
 
     return dag, s0
-
-
-class PropagatingThread(threading.Thread):
-    def run(self):
-        self.exc = None
-        try:
-            self.ret = self._target(*self._args, **self._kwargs)
-        except BaseException as e:
-            self.exc = e
-
-    def join(self):
-        super(PropagatingThread, self).join()
-        if self.exc:
-            raise self.exc
-        return self.ret
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 6d4fb68..5bc7c2a 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -24,9 +24,10 @@ import tempfile
 
 import tvm
 import tvm.testing
+from tvm.testing import PropagatingThread
 from tvm import auto_scheduler
 
-from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingThread
+from test_auto_scheduler_common import matmul_auto_scheduler_test
 import multiprocessing