Posted to commits@tvm.apache.org by lm...@apache.org on 2020/12/25 00:59:05 UTC
[tvm] branch main updated: [AutoScheduler] Add layout rewrite support for dense and batch matmul on CPU (#7161)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7dcafb0 [AutoScheduler] Add layout rewrite support for dense and batch matmul on CPU (#7161)
7dcafb0 is described below
commit 7dcafb017a05ac0d5ecd7cfe8d8741d33a24bbad
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Thu Dec 24 16:57:24 2020 -0800
[AutoScheduler] Add layout rewrite support for dense and batch matmul on CPU (#7161)
* [AutoScheduler] Add layout rewrite for dense and batch_matmul
* Fix test & Address comments
* Fix shape inference
* fix test
---
include/tvm/auto_scheduler/compute_dag.h | 8 ++++
include/tvm/relay/attrs/nn.h | 10 ++++-
python/tvm/auto_scheduler/__init__.py | 2 +-
python/tvm/auto_scheduler/compute_dag.py | 18 +++++++++
python/tvm/relay/op/strategy/generic.py | 17 +++++---
python/tvm/relay/op/strategy/x86.py | 15 +++++--
python/tvm/testing.py | 18 +++++++++
python/tvm/topi/nn/batch_matmul.py | 30 +++++++++++---
python/tvm/topi/nn/conv2d.py | 37 +++++------------
python/tvm/topi/nn/dense.py | 30 +++++++++++---
src/auto_scheduler/compute_dag.cc | 26 ++++++++++++
src/relay/op/make_op.h | 2 +
src/relay/op/nn/nn.cc | 34 +++++++++++-----
src/relay/op/nn/nn.h | 10 ++++-
.../transforms/auto_scheduler_layout_rewrite.cc | 16 +++++++-
.../transforms/combine_parallel_batch_matmul.cc | 7 ++--
src/relay/transforms/combine_parallel_dense.cc | 16 ++++++++
src/relay/transforms/combine_parallel_op_batch.h | 2 +-
.../relay/test_auto_scheduler_layout_rewrite.py | 47 +++++++++++++++++++++-
.../relay/test_pass_combine_parallel_dense.py | 2 -
.../python/unittest/test_auto_scheduler_common.py | 18 ---------
.../unittest/test_auto_scheduler_search_policy.py | 3 +-
22 files changed, 276 insertions(+), 92 deletions(-)
diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index bdb6489..1e3f097 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -303,6 +303,14 @@ class ComputeDAG : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
};
+/*!
+ * \brief Get the original shape from a rewritten layout string.
+ * \param rewritten_layout The layout after auto-scheduler's layout rewrite.
+ * \param axis_names Specify the names of the axes.
+ * \return shape The original shape.
+ */
+Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names);
+
} // namespace auto_scheduler
} // namespace tvm
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 5ffca99..7bfd580 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -120,7 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
- std::string auto_scheduler_rewritten_layout;
+ tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
@@ -924,6 +924,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
+ tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;
TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
@@ -936,6 +937,13 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};
+/*! \brief Attributes for batch matmul operator */
+struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
+ tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
+
+ TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") {}
+};
+
/*! \brief Attributes for sparse_dense operator */
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
bool sparse_lhs;
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 4926b88..a03e156 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -31,7 +31,7 @@ from . import utils
from . import workload_registry
# Shortcut
-from .compute_dag import ComputeDAG, LayoutRewriteOption
+from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest
from .measure import (
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index 94cb640..d8a2422 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -234,3 +234,21 @@ class ComputeDAG(Object):
# Since we always use tensors to recover the ComputeDAG, we do not support
# (de)serialization of the ComputeDAG constructed by a schedule.
self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, LoadJSON(state["tensors"]), None)
+
+
+def get_shape_from_rewritten_layout(rewritten_layout, axis_names):
+ """Get the orginal shape from a rewritten layout string.
+
+ Parameters
+ ----------
+ rewritten_layout: str
+ The layout after rewrite
+ axis_names: List[str]
+ Specify the order of the axes by name
+
+ Returns
+ -------
+ shape: List[PrimExpr]
+ The original shape
+ """
+ return _ffi_api.GetShapeFromRewrittenLayout(rewritten_layout, axis_names)
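For illustration, a minimal sketch of how the new helper is meant to be called. This snippet is not part of the commit; the factor-then-axis-name layout string format shown here is an assumption based on topi's parse_auto_scheduler_layout, which the helper uses internally.

# Hypothetical example: suppose the rewrite split "j" into 32 * 8 and
# "k" into 16 * 4. Factors that share an axis name are multiplied back
# together to recover the original extents.
from tvm import auto_scheduler

shape = auto_scheduler.get_shape_from_rewritten_layout("32j16k8j4k", ["j", "k"])
print(shape)  # expected: [256, 64]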
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 9fc6089..95b5d6a 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -199,7 +199,6 @@ def wrap_compute_conv2d(
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
- auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs)
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilation]
if has_groups:
@@ -210,7 +209,7 @@ def wrap_compute_conv2d(
args.append(out_layout)
args.append(out_dtype)
if need_auto_scheduler_layout:
- args.append(auto_scheduler_rewritten_layout)
+ args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]
return _compute_conv2d
@@ -684,14 +683,17 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
# dense
-def wrap_compute_dense(topi_compute):
+def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
"""wrap dense topi compute"""
def _compute_dense(attrs, inputs, out_type):
"""Compute definition of dense"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
- return [topi_compute(inputs[0], inputs[1], None, out_dtype)]
+ args = [inputs[0], inputs[1], None, out_dtype]
+ if need_auto_scheduler_layout:
+ args.append(get_auto_scheduler_rewritten_layout(attrs))
+ return [topi_compute(*args)]
return _compute_dense
@@ -710,11 +712,14 @@ def dense_strategy(attrs, inputs, out_type, target):
# batch_matmul
-def wrap_compute_batch_matmul(topi_compute):
+def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False):
"""wrap batch_matmul topi compute"""
def _compute_batch_matmul(attrs, inputs, out_type):
- return [topi_compute(inputs[0], inputs[1], out_type.shape)]
+ args = [inputs[0], inputs[1], out_type.shape]
+ if need_auto_scheduler_layout:
+ args.append(get_auto_scheduler_rewritten_layout(attrs))
+ return [topi_compute(*args)]
return _compute_batch_matmul
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 5dfeca6..841213a 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -325,6 +325,15 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
name="dense_nopack.x86",
plevel=10,
)
+
+ if is_auto_scheduler_enabled():
+ strategy.add_implementation(
+ wrap_compute_dense(topi.nn.dense, need_auto_scheduler_layout=True),
+ naive_schedule,
+ name="dense.generic",
+ plevel=11,
+ )
+
if "cblas" in target.libs:
with SpecializedCondition(same_type and dtype in ["float32", "float64"]):
strategy.add_implementation(
@@ -350,7 +359,7 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
plevel=15,
)
with SpecializedCondition(m >= 16):
- # this implementation may not be well-optimized, so use plevel=8 for now.
+ # this implementation may not be well-optimized, so use plevel=5 for now.
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_pack),
wrap_topi_schedule(topi.x86.schedule_dense_pack),
@@ -364,9 +373,9 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
"""batch_matmul x86 strategy"""
strategy = _op.OpStrategy()
- if is_dynamic(out_type):
+ if is_dynamic(out_type) or is_auto_scheduler_enabled():
strategy.add_implementation(
- wrap_compute_batch_matmul(topi.nn.batch_matmul),
+ wrap_compute_batch_matmul(topi.nn.batch_matmul, need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
name="batch_matmul.generic",
plevel=10,
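The plevel mechanics above are worth spelling out: implementations registered on a strategy compete by priority level, so with the auto-scheduler enabled the plevel=11 "dense.generic" entry outranks dense_nopack at plevel=10. A minimal sketch, assuming a placeholder tuning log "dense.json" (mirroring the tune_and_check pattern in the tests below):

# Sketch only: enable the auto-scheduler integration so that
# is_auto_scheduler_enabled() returns True during strategy selection.
import tvm
from tvm import auto_scheduler, relay

# Placeholder workload: a single dense layer.
d = relay.var("data", shape=(64, 128), dtype="float32")
w = relay.var("weight", shape=(256, 128), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([d, w], relay.nn.dense(d, w)))

with auto_scheduler.ApplyHistoryBest("dense.json"):  # placeholder log file
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        lib = relay.build(mod, target="llvm")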
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 8311a63..32307a9 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -58,6 +58,7 @@ import logging
import os
import sys
import time
+import threading
import pytest
import numpy as np
import tvm
@@ -742,4 +743,21 @@ def terminate_self():
sys.exit(-1)
+class PropagatingThread(threading.Thread):
+ """A thread that propagates the exection to the main thread"""
+
+ def run(self):
+ self.exc = None
+ try:
+ self.ret = self._target(*self._args, **self._kwargs)
+ except BaseException as e:
+ self.exc = e
+
+ def join(self, timeout=None):
+ super(PropagatingThread, self).join(timeout)
+ if self.exc:
+ raise self.exc
+ return self.ret
+
+
tvm._ffi._init_api("testing", __name__)
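A short usage sketch of the new helper class (not part of this commit): unlike a plain threading.Thread, PropagatingThread re-raises a worker exception in the caller's join() and also returns the target's result.

# Hypothetical example built on the class added above.
from tvm.testing import PropagatingThread

def work(x):
    if x < 0:
        raise ValueError("negative input")
    return x * 2

t = PropagatingThread(target=work, args=(21,))
t.start()
print(t.join())  # 42; an exception raised in work() would re-raise here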
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 98acc2d..9ca2df7 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -14,13 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Binary Neural Network (BNN) Operators"""
+"""Batch matrix multiplication"""
# pylint: disable=invalid-name
-from tvm import te
+from tvm import te, auto_scheduler
from ..utils import get_const_tuple
-def batch_matmul(x, y, oshape=None):
+def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch. Supports broadcasting for batch dimension.
@@ -36,14 +36,25 @@ def batch_matmul(x, y, oshape=None):
Explicit intended output shape of the computation. Can be useful in cases
with dynamic input shapes.
+ auto_scheduler_rewritten_layout: str = ""
+ The layout after auto-scheduler's layout rewrite pass.
+
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
- assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
x_shape = get_const_tuple(x.shape)
- y_shape = get_const_tuple(y.shape)
+ if auto_scheduler_rewritten_layout:
+ # Infer shape for the rewritten layout
+ y_shape = auto_scheduler.get_shape_from_rewritten_layout(
+ auto_scheduler_rewritten_layout, ["b", "j", "k"]
+ )
+ auto_scheduler.remove_index_check(y)
+ else:
+ y_shape = get_const_tuple(y.shape)
+ assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul"
+
XB = x_shape[0]
YB = y_shape[0]
_, M, K = x.shape
@@ -54,8 +65,15 @@ def batch_matmul(x, y, oshape=None):
batch = te.max(XB, YB)
N = y.shape[1]
oshape = (batch, M, N)
- return te.compute(
+
+ output = te.compute(
oshape,
lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
tag="batch_matmul",
+ attrs={"layout_free_placeholders": [y]},
)
+
+ if auto_scheduler_rewritten_layout:
+ output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)
+
+ return output
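For reference, a minimal sketch (not part of this commit) of calling the updated compute with static shapes; the default empty rewritten layout keeps the original behavior, while `y` is declared layout-free so the auto-scheduler may rewrite its layout.

from tvm import te, topi

x = te.placeholder((4, 128, 64), name="x", dtype="float32")
y = te.placeholder((4, 96, 64), name="y", dtype="float32")
out = topi.nn.batch_matmul(x, y)  # 3-D output [batch, M, N] = (4, 128, 96)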
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index ead9f16..e2384c4 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -361,6 +361,12 @@ def conv2d_nhwc(
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
+ out_dtype: str = "float32",
+ The type of output tensor
+
+ auto_scheduler_rewritten_layout: str = ""
+ The layout after auto-scheduler's layout rewrite pass.
+
Returns
-------
output : tvm.te.Tensor
@@ -381,34 +387,9 @@ def conv2d_nhwc(
if auto_scheduler_rewritten_layout:
# Infer shape for the rewritten layout
- # todo(merrymercy): wrap this with a more general interface.
- if len(Filter.shape) == 17:
- # For mali.
- # GPU tile structure is SSSRRSRS
- # You could refer function comment of DoMultiLevelTiling
- # in the utils.h to see more detail explanation.
- kernel_h = Filter.shape[6] * Filter.shape[9] * Filter.shape[13]
- kernel_w = Filter.shape[7] * Filter.shape[10] * Filter.shape[14]
- channel = Filter.shape[8] * Filter.shape[11] * Filter.shape[15]
- num_filter = Filter.shape[12] * Filter.shape[16]
- for i in range(6):
- num_filter *= Filter.shape[i]
- elif len(Filter.shape) >= 10:
- # For cpu tile structure SSRSRS
- base = len(Filter.shape) - 10
- kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base]
- kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base]
- channel = Filter.shape[4 + base] * Filter.shape[8 + base]
- num_filter = Filter.shape[5 + base] * Filter.shape[9 + base]
- for i in range(base + 2):
- num_filter *= Filter.shape[i]
- elif len(Filter.shape) == 4:
- num_filter, kernel_h, kernel_w, channel = Filter.shape
- else:
- raise ValueError(
- "Don't know how to infer the layout for filter shape: %s. "
- "Please add a new branch to handle this case." % str(Filter)
- )
+ kernel_h, kernel_w, channel, num_filter = auto_scheduler.get_shape_from_rewritten_layout(
+ auto_scheduler_rewritten_layout, ["ry", "rx", "rc", "ff"]
+ )
auto_scheduler.remove_index_check(Filter)
else:
kernel_h, kernel_w, channel, num_filter = Filter.shape
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 0ce0f9e..474fea4 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""TVM operator fully connected compute."""
-from tvm import te
+from tvm import te, auto_scheduler
from .. import tag
-def dense(data, weight, bias=None, out_dtype=None):
+def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
"""The default implementation of dense in topi.
Parameters
@@ -30,30 +30,44 @@ def dense(data, weight, bias=None, out_dtype=None):
weight : tvm.te.Tensor
2-D with shape [out_dim, in_dim]
- bias : tvm.te.Tensor, optional
+ bias : Optional[tvm.te.Tensor]
1-D with shape [out_dim]
- out_dtype : str
+ out_dtype : Optional[str]
The output type. This is used for mixed precision.
+ auto_scheduler_rewritten_layout: str = ""
+ The layout after auto-scheduler's layout rewrite pass.
+
Returns
-------
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
- assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
+ assert len(data.shape) == 2, "only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = data.shape
- out_dim, _ = weight.shape
+
+ if auto_scheduler_rewritten_layout:
+ # Infer shape for the rewritten layout
+ out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
+ auto_scheduler_rewritten_layout, ["j", "k"]
+ )
+ auto_scheduler.remove_index_check(weight)
+ else:
+ out_dim, red_dim = weight.shape
+ assert in_dim == red_dim
+
k = te.reduce_axis((0, in_dim), name="k")
matmul = te.compute(
(batch, out_dim),
lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
name="T_dense",
tag="dense",
+ attrs={"layout_free_placeholders": [weight]},
)
if bias is not None:
matmul = te.compute(
@@ -61,4 +75,8 @@ def dense(data, weight, bias=None, out_dtype=None):
lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST,
)
+
+ if auto_scheduler_rewritten_layout:
+ matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)
+
return matmul
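Analogously for dense, a minimal sketch (not part of this commit): with no rewritten layout the shape check against the weight runs as before, and the weight is marked as a layout-free placeholder.

from tvm import te, topi

data = te.placeholder((16, 64), name="data", dtype="float32")
weight = te.placeholder((32, 64), name="weight", dtype="float32")
out = topi.nn.dense(data, weight)  # 2-D output [batch, out_dim] = (16, 32)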
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index af45f2d..64114c8 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -33,6 +33,7 @@
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/topi/transform.h>
#include <algorithm>
#include <cstdint>
@@ -1410,6 +1411,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ss.str();
});
+Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names) {
+ Array<PrimExpr> shape;
+ std::vector<std::string> extracted_names;
+ topi::parse_auto_scheduler_layout(rewritten_layout, &shape, &extracted_names);
+
+ Array<PrimExpr> ret(axis_names.size(), 1);
+
+ size_t ct = 0;
+ for (size_t i = 0; i < axis_names.size(); ++i) {
+ for (size_t j = 0; j < extracted_names.size(); ++j) {
+ if (axis_names[i] == extracted_names[j]) {
+ ret.Set(i, ret[i] * shape[j]);
+ ct++;
+ }
+ }
+ }
+
+ CHECK_EQ(ct, extracted_names.size()) << "The number or names of axes do not match";
+
+ return ret;
+}
+
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")
.set_body_typed([](Optional<Array<te::Tensor>> tensors, Optional<te::Schedule> sch) {
if (sch) {
@@ -1452,5 +1475,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
return index_rewriter.Rewrite(body);
});
+TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
+ .set_body_typed(GetShapeFromRewrittenLayout);
+
} // namespace auto_scheduler
} // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index d2fb6aa..2b05290 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -46,6 +46,8 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);
+Expr MakeBatchMatmul(Expr lhs, Expr rhs);
+
Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 816b980..fbb6204 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -24,6 +24,7 @@
#include "nn.h"
+#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/relay/attrs/image.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
@@ -845,37 +846,49 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,).
.add_type_rel("GroupNorm", GroupNormRel);
// relay.nn.batch_matmul
+TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs);
+
bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 3);
const auto* x = types[0].as<TensorTypeNode>();
const auto* y = types[1].as<TensorTypeNode>();
if (x == nullptr || y == nullptr) return false;
- ICHECK(x->shape.size() == 3 && y->shape.size() == 3);
+
+ const auto* param = attrs.as<BatchMatmulAttrs>();
+ Array<PrimExpr> y_shape;
+ if (param->auto_scheduler_rewritten_layout.size() == 0) {
+ y_shape = y->shape;
+ } else {
+ y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
+ {"b", "j", "k"});
+ }
+
+ ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
bool is_dyn = false;
Array<tvm::PrimExpr> oshape;
for (size_t i = 0; i < 3; ++i) {
- if (x->shape[i].as<tir::AnyNode>() != nullptr || y->shape[i].as<tir::AnyNode>() != nullptr) {
+ if (x->shape[i].as<tir::AnyNode>() != nullptr || y_shape[i].as<tir::AnyNode>() != nullptr) {
is_dyn = true;
oshape.push_back(Any());
} else {
if (i == 0) {
- oshape.push_back(max(x->shape[i], y->shape[i]));
+ oshape.push_back(max(x->shape[i], y_shape[i]));
} else {
oshape.push_back(x->shape[i]);
}
}
}
if (!is_dyn) {
- ICHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
- reporter->AssertEQ(y->shape[0], 1))
+ ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
+ reporter->AssertEQ(y_shape[0], 1))
<< "BatchDot: batch dimensions don't match, "
- << " x shape=" << x->shape << ", y shape=" << y->shape;
- ICHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
+ << " x shape=" << x->shape << ", y shape=" << y_shape;
+ ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
- << " x shape=" << x->shape << ", y shape=" << y->shape;
+ << " x shape=" << x->shape << ", y shape=" << y_shape;
- oshape.Set(2, y->shape[1]);
+ oshape.Set(2, y_shape[1]);
}
// assign output type
@@ -885,8 +898,9 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
// Positional relay function to create batch_matmul operator used by frontend FFI.
Expr MakeBatchMatmul(Expr x, Expr y) {
+ auto attrs = make_object<BatchMatmulAttrs>();
static const Op& op = Op::Get("nn.batch_matmul");
- return Call(op, {x, y}, Attrs(), {});
+ return Call(op, {x, y}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul);
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 30ef307..9b9cff2 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -57,7 +57,15 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// data dtype as the weight dtype. However if weight dtype is explicitly
// present we will use that.
auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
- reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+ if (param->auto_scheduler_rewritten_layout.size() == 0) {
+ // Normal case: assign result to reporter
+ reporter->Assign(types[1], TensorType(wshape, weight_dtype));
+ } else {
+ // If the layout is rewritten by auto-scheduler,
+ // we just forcibly apply the layout provided by auto-scheduler and
+ // skip the normal inference logic.
+ {} // do nothing
+ }
oshape.Set((oshape.size() - 1), param->units);
} else {
if (weight == nullptr) return false;
diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
index c9875ef..53e7a02 100644
--- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc
+++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc
@@ -83,6 +83,12 @@ class FuncMutator : public ExprMutator {
Attrs updated_attrs;
if (auto pattr = call->attrs.as<Conv2DAttrs>()) {
updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+ } else if (auto pattr = call->attrs.as<DenseAttrs>()) {
+ updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+ } else if (auto pattr = call->attrs.as<BatchMatmulAttrs>()) {
+ updated_attrs = CopyAttrsWithNewLayout(pattr, new_layout);
+ } else {
+ LOG(FATAL) << "Unhandled attribute: " << call->attrs;
}
new_n = Call(call->op, updated_args, updated_attrs);
}
@@ -93,7 +99,7 @@ class FuncMutator : public ExprMutator {
std::deque<std::string> ori_layouts_queue_;
std::deque<std::string> new_layouts_queue_;
- std::vector<std::string> target_ops_{"nn.conv2d"};
+ std::vector<std::string> target_ops_{"nn.conv2d", "nn.dense", "nn.batch_matmul"};
};
Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
@@ -150,8 +156,14 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
.set_body_typed([](const Attrs& attrs) {
if (attrs->IsInstance<Conv2DAttrs>()) {
return attrs.as<Conv2DAttrs>()->auto_scheduler_rewritten_layout;
+ } else if (attrs->IsInstance<DenseAttrs>()) {
+ return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout;
+ } else if (attrs->IsInstance<BatchMatmulAttrs>()) {
+ return attrs.as<BatchMatmulAttrs>()->auto_scheduler_rewritten_layout;
+ } else {
+ LOG(FATAL) << "Unhandled attribute: " << attrs;
}
- return std::string();
+ return tvm::String();
});
} // namespace transform
diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc
index 5b56504..20a7c7f 100644
--- a/src/relay/transforms/combine_parallel_batch_matmul.cc
+++ b/src/relay/transforms/combine_parallel_batch_matmul.cc
@@ -70,16 +70,15 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner {
}
Call MakeCombinedOp(const Group& branches) {
- const Op& batch_matmul = Op::Get("nn.batch_matmul");
Expr data = branches[0][0]->args[0];
Array<Expr> weights;
for (const auto& branch : branches) {
- auto batch_matmul = branch[0];
- weights.push_back(batch_matmul->args[1]);
+ auto call = branch[0];
+ weights.push_back(call->args[1]);
}
Expr new_weight = MakeConcatenate(Tuple(weights), 1);
- return Call(batch_matmul, {data, new_weight}, {}, {});
+ return Downcast<Call>(MakeBatchMatmul(data, new_weight));
}
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; }
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index 6d4c8c0..d9ca4bf 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -57,6 +57,22 @@ class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner {
: ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}
protected:
+ Call MakeCombinedOp(const Group& branches) {
+ Array<Expr> new_args;
+ size_t num_args = branches[0][0]->args.size();
+ for (size_t i = 0; i < num_args; i++) {
+ Array<Expr> arg_from_all_branches;
+ for (const auto& branch : branches) {
+ arg_from_all_branches.push_back(branch[0]->args[i]);
+ }
+
+ new_args.push_back(MakeStack(Tuple(arg_from_all_branches), 0));
+ }
+
+ CHECK_EQ(num_args, 2);
+ return Downcast<Call>(MakeBatchMatmul(new_args[0], new_args[1]));
+ }
+
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
StructuralEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
diff --git a/src/relay/transforms/combine_parallel_op_batch.h b/src/relay/transforms/combine_parallel_op_batch.h
index 7a518e9..db4734b 100644
--- a/src/relay/transforms/combine_parallel_op_batch.h
+++ b/src/relay/transforms/combine_parallel_op_batch.h
@@ -95,7 +95,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner {
* \param branches branches that are to be combined
* \return new call with branches combined as batch op by stacking args
*/
- Call MakeCombinedOp(const Group& branches) final;
+ virtual Call MakeCombinedOp(const Group& branches);
/*
* \brief Checks if argument of op following combined ops are able to be combined
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite.py b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
index 299fcb8..66d40ba 100644
--- a/tests/python/relay/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite.py
@@ -23,6 +23,7 @@ import tvm
from tvm import relay, auto_scheduler
from tvm.contrib import graph_runtime
import tvm.testing
+from tvm.testing import PropagatingThread
def get_np_array(var, dtype):
@@ -70,6 +71,28 @@ def get_relay_conv2d(
return mod, data, weight
+def get_relay_dense(m=128, n=128, k=128):
+ dtype = "float32"
+ d = relay.var("data", shape=(m, k), dtype=dtype)
+ w = relay.var("weight", shape=(n, k), dtype=dtype)
+ y = relay.nn.dense(d, w, units=n)
+ mod = tvm.IRModule()
+ mod["main"] = relay.Function([d, w], y)
+ data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
+ return mod, data, weight
+
+
+def get_relay_batchmm(batch=4, m=128, n=128, k=128):
+ dtype = "float32"
+ d = relay.var("data", shape=(batch, m, k), dtype=dtype)
+ w = relay.var("weight", shape=(batch, n, k), dtype=dtype)
+ y = relay.nn.batch_matmul(d, w)
+ mod = tvm.IRModule()
+ mod["main"] = relay.Function([d, w], y)
+ data, weight = get_np_array(d, dtype), get_np_array(w, dtype)
+ return mod, data, weight
+
+
def tune_and_check(mod, data, weight):
# Extract tasks from a relay program
target = tvm.target.Target("llvm")
@@ -109,13 +132,33 @@ def tune_and_check(mod, data, weight):
actual_output = compile_and_run()
expected_output = compile_and_run(disabled_pass={"AutoSchedulerLayoutRewrite"})
- tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4)
+ tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4)
def test_conv2d():
+ # wrap the search in a new thread to avoid the conflict
+ # between python's multiprocessing and tvm's thread pool
mod, data, weight = get_relay_conv2d(kh=1, kw=1)
- tune_and_check(mod, data, weight)
+ t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+ t.start()
+ t.join()
+
+
+def test_dense():
+ mod, data, weight = get_relay_dense()
+ t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+ t.start()
+ t.join()
+
+
+def test_batch_matmul():
+ mod, data, weight = get_relay_batchmm()
+ t = PropagatingThread(target=tune_and_check, args=(mod, data, weight))
+ t.start()
+ t.join()
if __name__ == "__main__":
test_conv2d()
+ test_dense()
+ test_batch_matmul()
diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py
index a8c9782..cd946ab 100644
--- a/tests/python/relay/test_pass_combine_parallel_dense.py
+++ b/tests/python/relay/test_pass_combine_parallel_dense.py
@@ -286,8 +286,6 @@ def test_combine_parallel_dense_flat_biasadd():
y = run_opt_pass(y_before, combine_pass)
y_expected = expected(x, w1, w2, b1, b2, j, bias_shape1, bias_shape2)
y_expected = run_opt_pass(y_expected, transform.InferType())
- print(y.astext(False))
- print(y_expected.astext(False))
tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
check(3, 5, 4, (), ())
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 87814f2..a037b68 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -16,9 +16,6 @@
# under the License.
"""Common functions for auto_scheduler test cases"""
-
-import threading
-
import tvm
from tvm import te, auto_scheduler
from tvm import topi
@@ -251,18 +248,3 @@ def get_tiled_matmul():
)
return dag, s0
-
-
-class PropagatingThread(threading.Thread):
- def run(self):
- self.exc = None
- try:
- self.ret = self._target(*self._args, **self._kwargs)
- except BaseException as e:
- self.exc = e
-
- def join(self):
- super(PropagatingThread, self).join()
- if self.exc:
- raise self.exc
- return self.ret
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 6d4fb68..5bc7c2a 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -24,9 +24,10 @@ import tempfile
import tvm
import tvm.testing
+from tvm.testing import PropagatingThread
from tvm import auto_scheduler
-from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingThread
+from test_auto_scheduler_common import matmul_auto_scheduler_test
import multiprocessing