Posted to commits@tvm.apache.org by tq...@apache.org on 2023/03/27 21:43:46 UTC
[tvm] branch unity updated: [Unity][Op] Conv1d (#14388)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9a3ec23b66 [Unity][Op] Conv1d (#14388)
9a3ec23b66 is described below
commit 9a3ec23b66fa86d37a130e60c83bff8e983c46ba
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Mon Mar 27 14:43:35 2023 -0700
[Unity][Op] Conv1d (#14388)
This PR implements Conv1d.
Unit tests are provided accordingly.
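For orientation, here is a minimal usage sketch of the new operator, modeled on the
unit tests added below (shapes and dtypes are illustrative, not part of this patch):

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    # (batch, in_channels, width) data and (out_channels, in_channels, kernel_w) weight
    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
    w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
    call = relax.op.nn.conv1d(x, w, strides=1, padding=0, dilation=1)
    # Struct info inference yields an NCW output of shape (2, 4, 26).
    print(bb.normalize(call).struct_info)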
---
include/tvm/relax/attrs/nn.h | 43 +++
python/tvm/relax/frontend/torch/fx_translator.py | 29 ++
python/tvm/relax/op/nn/nn.py | 98 ++++++
python/tvm/relax/transform/legalize_ops/nn.py | 40 +++
src/relax/op/nn/convolution.cc | 155 +++++++++
src/relax/op/nn/convolution.h | 5 +
src/relax/op/op_common.h | 20 ++
tests/python/relax/test_frontend_from_fx.py | 88 ++++-
tests/python/relax/test_op_nn_convolution.py | 378 ++++++++++++++++++++-
.../python/relax/test_transform_legalize_ops_nn.py | 177 ++++++++++
tests/python/relax/test_tvmscript_parser_op_nn.py | 30 +-
11 files changed, 1055 insertions(+), 8 deletions(-)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index f49cb6b121..3daa32fd76 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -29,6 +29,49 @@
namespace tvm {
namespace relax {
+/*! \brief Attributes used in Conv1d operator */
+struct Conv1DAttrs : public tvm::AttrsNode<Conv1DAttrs> {
+ Array<IntImm> strides;
+ Array<IntImm> padding;
+ Array<IntImm> dilation;
+ int groups;
+ String data_layout;
+ String kernel_layout;
+ String out_layout;
+ DataType out_dtype;
+
+ TVM_DECLARE_ATTRS(Conv1DAttrs, "relax.attrs.Conv1DAttrs") {
+ TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution.");
+ TVM_ATTR_FIELD(padding).describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on both sides"
+ "two int : padding width in the order of (left, right)");
+ TVM_ATTR_FIELD(dilation).describe(
+ "Specifies the dilation rate to use for dilated convolution.");
+ TVM_ATTR_FIELD(groups).describe(
+ "Number of groups to split the input into for grouped convolution. The number of input and "
+ "output channels should be divisible by the number of groups.");
+ TVM_ATTR_FIELD(data_layout)
+ .describe(
+ "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, width"
+ "dimensions respectively. Convolution is applied on the 'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .describe(
+ "Dimension ordering of weight. Can be 'OIW', 'IOW', etc."
+ "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .describe(
+ "Dimension ordering of output. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Default to be same as input layout.");
+ TVM_ATTR_FIELD(out_dtype).describe(
+ "Output data type, set to explicit type under mixed precision setting");
+ }
+}; // struct Conv1DAttrs
+
/*! \brief Attributes used in Conv2d operator */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IntImm> strides;
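To illustrate the layout fields above, a small sketch using the Python API added later
in this patch (shapes are illustrative; out_layout defaults to data_layout):

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x_nwc = relax.Var("x", R.Tensor((2, 28, 3), "float32"))  # NWC data layout
    w_oiw = relax.Var("w", R.Tensor((4, 3, 3), "float32"))   # OIW kernel layout
    call = relax.op.nn.conv1d(x_nwc, w_oiw, data_layout="NWC", kernel_layout="OIW")
    # The result stays in the data layout (NWC), i.e. shape (2, 26, 4).
    print(bb.normalize(call).struct_info)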
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index ef6793cc67..c65e94d691 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -637,6 +637,34 @@ class TorchFXImporter:
bias = None if module.bias is None else self.params[module.bias]
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
+ def _conv1d(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ weight = self.params[module.weight]
+
+ conv1d = self.block_builder.emit(
+ relax.op.nn.conv1d(
+ x,
+ weight,
+ strides=module.stride,
+ padding=module.padding,
+ dilation=module.dilation,
+ groups=module.groups,
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_dtype="float32",
+ )
+ )
+
+ if module.bias is None:
+ return conv1d
+
+ bias = self.params[module.bias]
+ assert len(self.shape_of(bias)) == 1
+ bias = relax.op.reshape(bias, (1, -1, 1))
+
+ return self.block_builder.emit(relax.op.add(conv1d, bias))
+
def _conv2d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -1001,6 +1029,7 @@ class TorchFXImporter:
self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = {
# call_module
nn.Linear: self._linear,
+ nn.Conv1d: self._conv1d,
nn.Conv2d: self._conv2d,
nn.MaxPool2d: self._max_pool2d,
nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
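The bias handling above relies on broadcasting: a length-C bias is reshaped to
(1, C, 1) so it broadcasts over the batch and width axes of the NCW convolution
output. A small NumPy sketch of the same idea (values are hypothetical):

    import numpy as np

    conv_out = np.zeros((1, 6, 4), dtype="float32")  # conv1d output in NCW layout
    bias = np.arange(6, dtype="float32")             # one bias value per output channel

    # Mirrors relax.op.reshape(bias, (1, -1, 1)) followed by relax.op.add above.
    out = conv_out + bias.reshape(1, -1, 1)
    print(out.shape)  # (1, 6, 4)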
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index c774bbc926..e1d41c6cdf 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -23,6 +23,104 @@ from . import _ffi_api
from ...expr import Expr
+def conv1d(
+ data: Expr,
+ weight: Expr,
+ strides: Union[int, Tuple[int]] = 1,
+ padding: Union[int, Tuple[int, ...]] = 0,
+ dilation: Union[int, Tuple[int]] = 1,
+ groups: int = 1,
+ data_layout: str = "NCW",
+ kernel_layout: str = "OIW",
+ out_layout: Optional[str] = None,
+ out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+ r"""1D convolution.
+
+ This operator takes the weight as the 1D convolution kernel
+ and convolves it with data to produce an output.
+
+
+ In the default case, where the data_layout is `NCW`
+ and kernel_layout is `OIW`, conv1d takes in
+ a data Tensor with shape `(batch_size, in_channels, width)`,
+ and a weight Tensor with shape `(channels, in_channels, kernel_w)`,
+ where `kernel_w` is the length of the `W` kernel dimension,
+ to produce an output Tensor with the following rule:
+
+ .. math::
+
+ \mbox{out}[b, c, x] = \sum_{dx, k}
+ \mbox{data}[b, k, \mbox{strides} * x + dx] *
+ \mbox{weight}[c, k, dx]
+
+ Padding and dilation are applied to data and weight respectively before the computation.
+ This operator accepts data layout specification.
+ Semantically, the operator will convert the layout to the canonical layout
+ (`NCW` for data and `OIW` for weight), perform the computation,
+ then convert to the out_layout.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ weight : relax.Expr
+ The weight expressions.
+
+ strides : Union[int, Tuple[int]]
+ The strides of convolution. It is required to have length 1.
+
+ padding : Union[int, Tuple[int, ...]]
+ The padding applied to both sides of the input before convolution.
+ It is required to have length either 1 or 2.
+
+ dilation : Union[int, Tuple[int]]
+ Specifies the dilation rate to be used for dilated convolution.
+ It is required to have length 1.
+
+ groups : int
+ Number of groups to split the input into for grouped convolution.
+ The number of input and output channels should be divisible by the number of groups.
+
+ data_layout : str
+ Layout of the input.
+
+ kernel_layout : str
+ Layout of the weight.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout.
+
+ out_dtype : Optional[Union[str, DataType]]
+ Specifies the output data type for mixed precision conv1d.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(strides, int):
+ strides = (strides,)
+ if isinstance(dilation, int):
+ dilation = (dilation,)
+ if isinstance(padding, int):
+ padding = (padding, padding)
+
+ return _ffi_api.conv1d( # type: ignore
+ data,
+ weight,
+ strides,
+ padding,
+ dilation,
+ groups,
+ data_layout,
+ kernel_layout,
+ out_layout,
+ out_dtype,
+ )
+
+
def conv2d(
data: Expr,
weight: Expr,
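The shape rule in the conv1d docstring above can be sanity-checked with a small
worked example; the helper below is an illustrative sketch, not part of this patch:

    def conv1d_out_width(in_w, kernel_w, stride=1, pad=(0, 0), dilation=1):
        # out_w = floor((in_w + pad_l + pad_r - dilation * (kernel_w - 1) - 1) / stride) + 1
        return (in_w + pad[0] + pad[1] - dilation * (kernel_w - 1) - 1) // stride + 1

    print(conv1d_out_width(28, 3))              # 26
    print(conv1d_out_width(28, 3, stride=2))    # 13
    print(conv1d_out_width(28, 3, pad=(1, 3)))  # 30
    print(conv1d_out_width(28, 3, dilation=2))  # 24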
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index bfc0544536..889e6e0941 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -24,6 +24,46 @@ from ...expr import Call, Expr
from .common import register_legalize, _call_topi_without_attr
+@register_legalize("relax.nn.conv1d")
+def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.data_layout:
+ logging.info(
+ "TOPI conv1d does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+ if len(call.attrs.data_layout) != 3 or len(call.attrs.kernel_layout) != 3:
+ logging.info(
+ "Conv1D where data layout or kernel layout have channel chunk "
+ "cannot be legalized by TOPI at this moment."
+ )
+ return call
+ if call.attrs.groups != 1:
+ data_layout = tir.layout(call.attrs.data_layout)
+ kernel_layout = tir.layout(call.attrs.kernel_layout)
+ ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")]
+ oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")]
+ if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm):
+ logging.info(
+ "Conv1D where number of groups is more than one and input or output "
+ "channel size is symbolic cannot be legalized by TOPI at this moment."
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.conv1d,
+ data=call.args[0],
+ kernel=call.args[1],
+ strides=call.attrs.strides,
+ padding=call.attrs.padding,
+ dilation=call.attrs.dilation,
+ data_layout=call.attrs.data_layout,
+ kernel_layout=call.attrs.kernel_layout,
+ out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None,
+ primfunc_name_hint="conv1d",
+ )
+
+
@register_legalize("relax.nn.conv2d")
def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.data_layout:
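A minimal sketch of exercising the new legalization rule end to end (shapes are
illustrative; the authoritative cases are in test_transform_legalize_ops_nn.py below):

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
    w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
    with bb.function("main", [x, w]):
        gv = bb.emit(relax.op.nn.conv1d(x, w))
        bb.emit_func_output(gv)

    # After LegalizeOps, relax.nn.conv1d becomes a call_tir into a TOPI-generated PrimFunc.
    mod = relax.transform.LegalizeOps()(bb.get())
    mod.show()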
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index e10d205b23..ae84409c2a 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -29,6 +29,161 @@
namespace tvm {
namespace relax {
+/* relax.nn.conv1d */
+TVM_REGISTER_NODE_TYPE(Conv1DAttrs);
+
+Expr conv1d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
+ Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
+ Optional<String> out_layout, DataType out_dtype) {
+ padding = GetCompletePadding1D(std::move(padding));
+
+ CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
+ "the given number of groups is "
+ << groups;
+ CHECK_EQ(strides.size(), 1)
+ << "The input strides length is expected to be 1. However, the given strides is " << strides;
+ CHECK_EQ(dilation.size(), 1)
+ << "The input dilation length is expected to be 1. However, the given dilation is "
+ << dilation;
+ return MakeConv<Conv1DAttrs>(std::move(data), std::move(weight), std::move(strides),
+ std::move(padding), std::move(dilation), groups, data_layout,
+ std::move(kernel_layout), out_layout.value_or(data_layout),
+ out_dtype, /*op_name=*/"relax.nn.conv1d");
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.conv1d").set_body_typed(conv1d);
+
+StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+ TensorStructInfo data_sinfo = input_sinfo[0];
+ TensorStructInfo weight_sinfo = input_sinfo[1];
+
+ const auto* attrs = call->attrs.as<Conv1DAttrs>();
+ auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->data_layout, //
+ /*tgt_layout=*/"NCW", //
+ /*tensor_name=*/"data");
+ auto [weight_layout, weight2OIW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, //
+ /*tgt_layout=*/"OIW", //
+ /*tensor_name=*/"kernel");
+ auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout, //
+ /*tgt_layout=*/"NCW", //
+ /*tensor_name=*/"output");
+
+ Optional<ShapeExpr> data_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+ Optional<ShapeExpr> weight_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+ DataType out_dtype = attrs->out_dtype.is_void()
+ ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
+ : attrs->out_dtype;
+ if (!data_shape.defined() || !weight_shape.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim());
+ }
+
+ Array<PrimExpr> data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values);
+ Array<PrimExpr> weight_OIW_shape = weight2OIW.ForwardShape(weight_shape.value()->values);
+
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ PrimExpr input_channel_data = data_NCW_shape[1];
+ PrimExpr input_channel_kernel = weight_OIW_shape[1];
+ if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "The channel size of the data should equal to the product of input channel size of the "
+ "weight and the number of groups. However, the data channel size is "
+ << input_channel_data << " while the weight input channel size and number of groups are "
+ << input_channel_kernel << " and " << attrs->groups);
+ } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+ if (analyzer->CanProve(floormod(weight_OIW_shape[0], attrs->groups) != 0)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Conv1d expects the number of output channels to be divisible by the "
+ "number of groups. However, the number of output channels is "
+ << weight_OIW_shape[0] << " while the number of groups is " << attrs->groups);
+ } else if (!analyzer->CanProveEqual(floormod(weight_OIW_shape[0], attrs->groups), 0)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+
+ PrimExpr input_w = data_NCW_shape[2];
+ PrimExpr kernel_w = weight_OIW_shape[2];
+ PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+
+ std::vector<PrimExpr> out_NCW_shape;
+ out_NCW_shape.resize(3);
+ out_NCW_shape[0] = data_NCW_shape[0];
+ out_NCW_shape[1] = weight_OIW_shape[0];
+
+ PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1;
+ out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
+
+ Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
+}
+
+InferLayoutOutput InferLayoutConv1d(const Call& call,
+ const Map<String, Array<String>>& desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ const auto& it = desired_layouts.find("relax.nn.conv1d");
+ const auto* attrs = call->attrs.as<Conv1DAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision data_layout, weight_layout, output_layout;
+ ObjectPtr<Conv1DAttrs> new_attrs = make_object<Conv1DAttrs>(*attrs);
+
+ if (it != desired_layouts.end()) {
+ // We have a desired layout for conv1d.
+ Layout desired_data_layout = (*it).second[0];
+ Layout desired_weight_layout = (*it).second[1];
+ Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+ ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
+ ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal())
+ << "Axis swap only";
+ ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal())
+ << "Axis swap only";
+ data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout);
+ weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout);
+ output_layout = TransposeLike(InitialLayout(3), attrs->out_layout, desired_output_layout);
+ new_attrs->data_layout = (*it).second[0];
+ new_attrs->kernel_layout = (*it).second[1];
+ new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+ } else {
+ // We don't have a desired layout for conv1d.
+ // We can just propagate the layout from the input.
+ data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+ output_layout = data_layout;
+ new_attrs->data_layout =
+ TransposeLike(attrs->data_layout, InitialLayout(3), data_layout->layout).name();
+ new_attrs->kernel_layout =
+ TransposeLike(attrs->kernel_layout, InitialLayout(3), weight_layout->layout).name();
+ new_attrs->out_layout =
+ TransposeLike(attrs->out_layout, InitialLayout(3), output_layout->layout).name();
+ }
+ return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv1d(const Call& call, const DataType& out_dtype) {
+ const auto* conv1d_attrs = call->attrs.as<Conv1DAttrs>();
+ return Downcast<Call>(conv1d(call->args[0], call->args[1], conv1d_attrs->strides,
+ conv1d_attrs->padding, conv1d_attrs->dilation, conv1d_attrs->groups,
+ conv1d_attrs->data_layout, conv1d_attrs->kernel_layout,
+ conv1d_attrs->out_layout, out_dtype));
+}
+
+TVM_REGISTER_OP("relax.nn.conv1d")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_attrs_type<Conv1DAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv1d)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv1d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+ .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv1d);
+
/* relax.nn.conv2d */
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
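The two group-related checks in InferStructInfoConv1d can be summarized in plain
Python (an illustrative sketch only):

    def check_conv1d_groups(data_channels, weight_out_channels, weight_in_channels, groups):
        # The data channel size must equal the weight's input channels times the number of groups.
        assert data_channels == weight_in_channels * groups, "channel/group mismatch"
        # The number of output channels must be divisible by the number of groups.
        assert weight_out_channels % groups == 0, "output channels not divisible by groups"

    # e.g. data (2, 128, 28) with weight (48, 16, 3) and groups=8 passes both checks
    check_conv1d_groups(128, 48, 16, groups=8)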
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index 7093c6a4d9..833e730ee9 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -52,6 +52,11 @@ inline Expr MakeConv(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm
return Call(op, {data, weight}, Attrs(attrs), {});
}
+/*! \brief 1D convolution */
+Expr conv1d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
+ Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
+ Optional<String> out_layout, DataType out_dtype);
+
/*! \brief 2D convolution */
Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index bd5f2cd4d5..616dded39e 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -240,6 +240,26 @@ inline Array<IntImm> ConvertIntImmToInt64(const Array<IntImm>& int_imms) {
/************ Utilities for NN operators ************/
+/*!
+ * \brief Complete the padding to a 2-length array.
+ * - If the padding length is 1, the same padding is used on both the left and right sides
+ * - If the padding length is 2, padding is in the order of (left, right)
+ * \param padding The given padding to be completed
+ * \return The completed padding.
+ * \throws Throws error if the input padding length is neither 1 nor 2.
+ */
+inline Array<IntImm> GetCompletePadding1D(Array<IntImm> padding) {
+ if (padding.size() == 1) {
+ return {padding[0], padding[0]};
+ } else if (padding.size() == 2) {
+ return padding;
+ }
+ LOG(FATAL) << "The input padding length is expected to be either 1 or 2. However, the given "
+ "padding is "
+ << padding;
+ throw;
+}
+
/*!
* \brief Complete the padding to a 4-length array.
* - If the padding length is 1, the same padding is used on all top/left/bottom/right sides
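For reference, the 1-D padding completion rule implemented by GetCompletePadding1D
above, restated as a small Python sketch (illustrative only):

    def complete_padding_1d(padding):
        if len(padding) == 1:  # a single value is applied to both sides
            return (padding[0], padding[0])
        if len(padding) == 2:  # explicit (left, right) padding
            return tuple(padding)
        raise ValueError("padding length must be 1 or 2, got %d" % len(padding))

    print(complete_padding_1d((2,)))    # (2, 2)
    print(complete_padding_1d((1, 3)))  # (1, 3)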
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index d201cb111c..9e07ff7b59 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -37,7 +37,93 @@ def verify_model(torch_model, input_info, binding, expected):
@tvm.testing.requires_gpu
-def test_conv():
+def test_conv1d():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ class Conv1D1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7), dtype="float32"),
+ w2: R.Tensor((6,), dtype="float32"),
+ ) -> R.Tensor((1, 6, 4), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0],
+ dilation=[1],
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_layout="NCW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+ lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
+ gv: R.Tensor((1, 6, 4), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ class Conv1D2(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7), dtype="float32"),
+ ) -> R.Tensor((1, 6, 4), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0],
+ dilation=[1],
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_layout="NCW",
+ out_dtype="float32",
+ )
+ gv: R.Tensor((1, 6, 4), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ input_info = [([1, 3, 10], "float32")]
+
+ model = Conv1D1()
+ binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()}
+ verify_model(model, input_info, binding, expected1)
+
+ model = Conv1D2()
+ binding = {"w1": model.conv.weight.numpy()}
+ verify_model(model, input_info, binding, expected2)
+
+
+@tvm.testing.requires_gpu
+def test_conv2d():
import torch
from torch.nn import Module
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
index 334f6977f7..d1d604429e 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -23,7 +23,13 @@ from tvm.ir import Op
from tvm.script import relax as R
-def test_op_correctness():
+def test_conv1d_op_correctness():
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
+ assert relax.op.nn.conv1d(x, w).op == Op.get("relax.nn.conv1d")
+
+
+def test_conv2d_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
assert relax.op.nn.conv2d(x, w).op == Op.get("relax.nn.conv2d")
@@ -35,6 +41,376 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+def test_conv1d_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 28, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x3 = relax.Var("x", R.Tensor("float32"))
+ x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 4, 28, 16), "float32"))
+ w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
+ w2 = relax.Var("w", R.Tensor("float32", ndim=3))
+ w3 = relax.Var("w", R.Tensor("float32"))
+ w4 = relax.Var("w", R.Tensor((48, 4, 3, 16), "float32"))
+
+ _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float32"))
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, out_dtype="float16"),
+ relax.TensorStructInfo((2, 4, 26), "float16"),
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 28), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, padding=[1, 3]),
+ relax.TensorStructInfo((2, 4, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, strides=2),
+ relax.TensorStructInfo((2, 4, 13), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, strides=(2,)),
+ relax.TensorStructInfo((2, 4, 13), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, dilation=2),
+ relax.TensorStructInfo((2, 4, 24), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, dilation=(2,)),
+ relax.TensorStructInfo((2, 4, 24), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x1, w0, data_layout="NWC"),
+ relax.TensorStructInfo((2, 26, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, out_layout="NWC"),
+ relax.TensorStructInfo((2, 26, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w1, kernel_layout="IOW"),
+ relax.TensorStructInfo((2, 4, 26), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(
+ x5, w4, data_layout="NCW16c", kernel_layout="OIW16i", out_layout="NWC16c"
+ ),
+ relax.TensorStructInfo((2, 26, 3, 16), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(bb, relax.op.nn.conv1d(x4, w0), relax.TensorStructInfo(dtype="", ndim=3))
+
+
+def test_conv1d_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ c16 = tir.Var("c16", "int64")
+ iw = tir.Var("iw", "int64")
+ ki = tir.Var("ki", "int64")
+ ko = tir.Var("ko", "int64")
+ kw = tir.Var("kw", "int64")
+ x0 = relax.Var("x", R.Tensor((n, c, iw), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32"))
+ w0 = relax.Var("w", R.Tensor((ko, ki, kw), "float32"))
+ w1 = relax.Var("w", R.Tensor((ko, c, kw), "float32"))
+ w2 = relax.Var("w", R.Tensor((ko, c, kw, c16), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0),
+ relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w1),
+ relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x1, w2, data_layout="NCW16c", kernel_layout="OIW16i", out_layout="NCW"),
+ relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, strides=2, padding=1, dilation=2),
+ relax.TensorStructInfo(
+ (n, ko, tvm.tir.floordiv(iw + 3, 2) + 1 - kw),
+ "float32",
+ ),
+ )
+
+
+def test_conv1d_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+ s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+ s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+ s3 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32"))
+ w = relax.Var("w", relax.TensorStructInfo(s2, "float32"))
+
+ _check_inference(bb, relax.op.nn.conv1d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=3))
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x1, w, data_layout="NCW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w, out_layout="NCW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x2, w),
+ relax.TensorStructInfo(dtype="float32", ndim=3),
+ )
+
+
+def test_conv1d_infer_struct_info_groups():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 8, 28, 16), "float32"))
+ w0 = relax.Var("w", R.Tensor((48, 16, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((48, 2, 3, 8), "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.conv1d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w1, kernel_layout="OIW8i", groups=8),
+ relax.TensorStructInfo((2, 48, 26), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x1, w0, data_layout="NCW16c", groups=8),
+ relax.TensorStructInfo((2, 3, 26, 16), "float32"),
+ )
+
+
+def test_conv1d_infer_struct_info_symbolic_groups():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ ic = tir.Var("c", "int64")
+ oc = tir.Var("oc", "int64")
+ x = relax.Var("x", R.Tensor((n, ic * 4, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((oc, ic, 3), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x, w0, groups=4),
+ relax.TensorStructInfo((n, oc * 4, 26), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26), "float32")
+ )
+
+
+def test_conv1d_infer_struct_info_input_channel_group_incompatible():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ ic = tir.Var("c", "int64")
+ oc = tir.Var("oc", "int64")
+ x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((48, 20, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, ic * 6, 28), "float32"))
+ w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x0, w0, groups=6))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x1, w1, groups=6))
+
+
+def test_conv1d_infer_struct_info_output_channel_group_incompatible():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ ic = tir.Var("c", "int64")
+ oc = tir.Var("oc", "int64")
+ x0 = relax.Var("x", R.Tensor((2, 120, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((128, 20, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, ic * 6, 28), "float32"))
+ w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x0, w0, groups=6))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x1, w1, groups=6))
+
+
+def test_conv1d_non_positive_group():
+ x = relax.Var("x", R.Tensor((2, 128, 28), "float32"))
+ w = relax.Var("w", R.Tensor((48, 16, 3), "float32"))
+
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d(x, w, groups=0)
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d(x, w, groups=-2)
+
+
+def test_conv1d_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16"))
+ w0 = relax.Var("w", R.Tensor((4, 3, 3), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28), "float64"))
+ w1 = relax.Var("w", R.Tensor((4, 3, 3), "float64"))
+ x2 = relax.Var("x", R.Tensor((2, 3, 28), "int8"))
+ w2 = relax.Var("w", R.Tensor((4, 3, 3), "int8"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 28), "int32"))
+ w3 = relax.Var("w", R.Tensor((4, 3, 3), "int32"))
+
+ _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float16"))
+ _check_inference(bb, relax.op.nn.conv1d(x1, w1), relax.TensorStructInfo((2, 4, 26), "float64"))
+ _check_inference(bb, relax.op.nn.conv1d(x2, w2), relax.TensorStructInfo((2, 4, 26), "int8"))
+ _check_inference(bb, relax.op.nn.conv1d(x3, w3), relax.TensorStructInfo((2, 4, 26), "int32"))
+
+
+def test_conv1d_infer_struct_info_mixed_precision():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16"))
+ w0 = relax.Var("w", R.Tensor((4, 3, 3), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28), "int8"))
+ w1 = relax.Var("w", R.Tensor((4, 3, 3), "int8"))
+ x2 = relax.Var("x", R.Tensor((2, 3, 28)))
+ w2 = relax.Var("w", R.Tensor((4, 3, 3)))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w0, out_dtype="float32"),
+ relax.TensorStructInfo((2, 4, 26), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x1, w1, out_dtype="int32"),
+ relax.TensorStructInfo((2, 4, 26), "int32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x2, w2, out_dtype="float32"),
+ relax.TensorStructInfo((2, 4, 26), "float32"),
+ )
+
+
+def test_conv1d_unequal_input_channel():
+ bb = relax.BlockBuilder()
+ ic = tir.Var("ic", "int64")
+ x0 = relax.Var("x", R.Tensor([2, 3, 28], "float32"))
+ w0 = relax.Var("w", R.Tensor([3, 4, 3], "float32"))
+ x1 = relax.Var("x", R.Tensor([2, ic, 28], "float32"))
+ w1 = relax.Var("w", R.Tensor([4, ic + 2, 3], "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x0, w0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x1, w1))
+
+
+def test_conv1d_stride_padding_dilation_int64():
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
+ conv1d = relax.op.nn.conv1d(x, w, strides=(1,), padding=(1, 1), dilation=(1,))
+
+ assert conv1d.attrs.strides[0].dtype == "int64"
+ assert conv1d.attrs.padding[0].dtype == "int64"
+ assert conv1d.attrs.padding[1].dtype == "int64"
+ assert conv1d.attrs.dilation[0].dtype == "int64"
+
+
+def test_conv1d_wrong_strides_padding_dilation_length():
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d(x, w, strides=(1, 2))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d(x, w, padding=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d(x, w, dilation=(1, 2))
+
+
+def test_conv1d_infer_struct_info_wrong_layout_string():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x, w, data_layout="OIW"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x, w, kernel_layout="NWC"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x, w, out_layout="OWI"))
+
+
+def test_conv1d_dtype_mismatch():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((4, 3, 3), "int8"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x, w))
+
+
+def test_conv1d_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=2))
+ w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((4, 3, 6, 3), "float32"))
+ w2 = relax.Var("w", R.Tensor("float32", ndim=5))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x0, w1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x0, w1, data_layout="NCW16c"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x0, w2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x1, w0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x2, w0))
+
+
+def test_conv1d_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28)))
+ w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
+ w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3), "float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x0, w1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d(x1, w0))
+
+
def test_conv2d_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index a1fe266d68..e944b8d76e 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -25,6 +25,183 @@ import tvm.testing
##################### Neural network #####################
+def test_conv1d():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv1d:
+ @R.function
+ def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16, 3), "float32")) -> R.Tensor((2, 64, 13), "float32"):
+ gv: R.Tensor((2, 64, 13), "float32") = R.nn.conv1d(x, w, strides=(2,), padding=(1,), dilation=(2,), groups=8)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dtype="float32")) -> R.Tensor((2, 64, 13), dtype="float32"):
+ gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 64, 13), dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30)))
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)):
+ with T.block("pad_temp"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)])
+ T.writes(pad_temp[v_i0, v_i1, v_i2])
+ pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0))
+ for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(128), T.int64(3)):
+ with T.block("conv1d_ncw"):
+ v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
+ T.reads(pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], rxplaceholder_1[v_ff, v_rc, v_ry])
+ T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
+ with T.init():
+ conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
+ conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * rxplaceholder_1[v_ff, v_rc, v_ry]
+ # fmt: on
+
+ mod = LegalizeOps()(Conv1d)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv1d_with_out_dtype():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv1d:
+ @R.function
+ def main(x: R.Tensor((2, 3, 28), "float32"), w: R.Tensor((4, 3, 3), "float32")) -> R.Tensor((2, 4, 26), "float16"):
+ gv: R.Tensor((2, 4, 26), "float16") = R.nn.conv1d(x, w, out_dtype="float16")
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 28), dtype="float32"), w: R.Tensor((4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 26), dtype="float16"):
+ gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 4, 26), dtype="float16"))
+ return gv
+
+ @T.prim_func
+ def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(4), T.int64(26)), "float16")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ pad_temp = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28)))
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(28)):
+ with T.block("pad_temp"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[v_i0, v_i1, v_i2])
+ T.writes(pad_temp[v_i0, v_i1, v_i2])
+ pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2]
+ for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(3), T.int64(3)):
+ with T.block("conv1d_ncw"):
+ v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
+ T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry])
+ T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
+ with T.init():
+ conv1d_ncw[v_nn, v_ff, v_yy] = T.float16(0)
+ conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + T.Cast("float16", pad_temp[v_nn, v_rc, v_yy + v_ry]) * T.Cast("float16", rxplaceholder_1[v_ff, v_rc, v_ry])
+ # fmt: on
+
+ mod = LegalizeOps()(Conv1d)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv1d_nwc():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv1d:
+ @R.function
+ def main(x: R.Tensor((2, 28, 128), "float32"), w: R.Tensor((64, 128, 3), "float32")) -> R.Tensor((2, 26, 64), "float32"):
+ gv: R.Tensor((2, 26, 64), "float32") = R.nn.conv1d(x, w, data_layout="NWC")
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 28, 128), dtype="float32"), w: R.Tensor((64, 128, 3), dtype="float32")) -> R.Tensor((2, 26, 64), dtype="float32"):
+ gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 26, 64), dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3)), "float32"), conv1d_nwc: T.Buffer((T.int64(2), T.int64(26), T.int64(64)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ pad_temp = T.alloc_buffer((T.int64(2), T.int64(28), T.int64(128)))
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(28), T.int64(128)):
+ with T.block("pad_temp"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[v_i0, v_i1, v_i2])
+ T.writes(pad_temp[v_i0, v_i1, v_i2])
+ pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2]
+ for nn, yy, ff, ry, rc in T.grid(T.int64(2), T.int64(26), T.int64(64), T.int64(3), T.int64(128)):
+ with T.block("conv1d_nwc"):
+ v_nn, v_yy, v_ff, v_ry, v_rc = T.axis.remap("SSSRR", [nn, yy, ff, ry, rc])
+ T.reads(pad_temp[v_nn, v_yy + v_ry, v_rc], rxplaceholder_1[v_ff, v_rc, v_ry])
+ T.writes(conv1d_nwc[v_nn, v_yy, v_ff])
+ with T.init():
+ conv1d_nwc[v_nn, v_yy, v_ff] = T.float32(0)
+ conv1d_nwc[v_nn, v_yy, v_ff] = conv1d_nwc[v_nn, v_yy, v_ff] + pad_temp[v_nn, v_yy + v_ry, v_rc] * rxplaceholder_1[v_ff, v_rc, v_ry]
+ # fmt: on
+
+ mod = LegalizeOps()(Conv1d)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv1d_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv1d:
+ @R.function
+ def main(x: R.Tensor(("n", "c", "w"), "float32"), kernel: R.Tensor(("f", "c", "kw"), "float32")) -> R.Tensor(("n", "f", "w - kw + 1"), "float32"):
+ n = T.int64()
+ w = T.int64()
+ f = T.int64()
+ kw = T.int64()
+ gv: R.Tensor((n, f, w - kw + 1), "float32") = R.nn.conv1d(x, kernel)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", "c", "kw"), dtype="float32")) -> R.Tensor(("n", "f", "w - kw + 1"), dtype="float32"):
+ n = T.int64()
+ f = T.int64()
+ w = T.int64()
+ kw = T.int64()
+ c = T.int64()
+ gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w - kw + 1), dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1d_ncw: T.handle):
+ T.func_attr({"tir.noalias": True})
+ n, c, w = T.int64(), T.int64(), T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w))
+ f, kw = T.int64(), T.int64()
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw))
+ conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w - kw + T.int64(1)))
+ # with T.block("root"):
+ pad_temp = T.alloc_buffer((n, c, w))
+ for i0, i1, i2 in T.grid(n, c, w):
+ with T.block("pad_temp"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder[v_i0, v_i1, v_i2])
+ T.writes(pad_temp[v_i0, v_i1, v_i2])
+ pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2]
+ for nn, ff, yy, rc, ry in T.grid(n, f, w + T.int64(1) - kw, c, kw):
+ with T.block("conv1d_ncw"):
+ v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry])
+ T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry])
+ T.writes(conv1d_ncw[v_nn, v_ff, v_yy])
+ with T.init():
+ conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0)
+ conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy + v_ry] * rxplaceholder_1[v_ff, v_rc, v_ry]
+ # fmt: on
+
+ mod = LegalizeOps()(Conv1d)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_conv2d():
# fmt: off
@tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py
index cfb454a578..a822fae719 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -35,16 +35,34 @@ def _check(
tvm.ir.assert_structural_equal(parsed, expect)
+def test_conv1d():
+ @R.function
+ def foo(
+ x: R.Tensor((2, 3, 228), "float16"), w: R.Tensor((16, 3, 5), "float16")
+ ) -> R.Tensor((2, 16, 224), "float16"):
+ gv: R.Tensor((2, 16, 224), "float16") = R.nn.conv1d(x, w, out_dtype="float16")
+ return gv
+
+ x = relax.Var("x", R.Tensor([2, 3, 228], "float16"))
+ w = relax.Var("w", R.Tensor([16, 3, 5], "float16"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x, w]):
+ gv = bb.emit(relax.op.nn.conv1d(x, w, out_dtype="float16"))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_conv2d():
@R.function
def foo(
- x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((16, 3, 5, 5), "float32")
+ x: R.Tensor((2, 3, 228, 228), "float16"), w: R.Tensor((16, 3, 5, 5), "float16")
) -> R.Tensor((2, 16, 224, 224), "float16"):
gv: R.Tensor((2, 16, 224, 224), "float16") = R.nn.conv2d(x, w, out_dtype="float16")
return gv
- x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32"))
- w = relax.Var("w", R.Tensor([16, 3, 5, 5], "float32"))
+ x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float16"))
+ w = relax.Var("w", R.Tensor([16, 3, 5, 5], "float16"))
bb = relax.BlockBuilder()
with bb.function("foo", [x, w]):
gv = bb.emit(relax.op.nn.conv2d(x, w, out_dtype="float16"))
@@ -56,15 +74,15 @@ def test_conv2d():
def test_conv2d_transpose():
@R.function
def foo(
- x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((3, 16, 5, 5), "float32")
+ x: R.Tensor((2, 3, 228, 228), "float16"), w: R.Tensor((3, 16, 5, 5), "float16")
) -> R.Tensor((2, 16, 232, 232), "float16"):
gv: R.Tensor((2, 16, 232, 232), "float16") = R.nn.conv2d_transpose(
x, w, out_dtype="float16"
)
return gv
- x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32"))
- w = relax.Var("w", R.Tensor([3, 16, 5, 5], "float32"))
+ x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float16"))
+ w = relax.Var("w", R.Tensor([3, 16, 5, 5], "float16"))
bb = relax.BlockBuilder()
with bb.function("foo", [x, w]):
gv = bb.emit(relax.op.nn.conv2d_transpose(x, w, out_dtype="float16"))