Posted to commits@tvm.apache.org by ru...@apache.org on 2023/08/02 17:27:23 UTC
[tvm] branch unity updated: [Unity][Op] Conv1dTranspose (#15456)
This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c3f26858ba [Unity][Op] Conv1dTranspose (#15456)
c3f26858ba is described below
commit c3f26858ba325a055d399c69168474fc606fa71e
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Wed Aug 2 10:27:16 2023 -0700
[Unity][Op] Conv1dTranspose (#15456)
This PR introduces Conv1dTranspose to relax.
---
include/tvm/relax/attrs/nn.h | 45 +++
python/tvm/relax/op/nn/nn.py | 91 ++++++
python/tvm/relax/transform/legalize_ops/nn.py | 40 +++
src/relax/op/nn/convolution.cc | 126 ++++++++
src/relax/op/nn/convolution.h | 11 +
tests/python/relax/test_op_nn_convolution.py | 351 ++++++++++++++++++++++
tests/python/relax/test_tvmscript_parser_op_nn.py | 18 ++
7 files changed, 682 insertions(+)
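
For a quick sense of the new API, here is a minimal usage sketch; it is only a sketch, reusing the BlockBuilder calls exercised in the tests below, with illustrative shapes and an arbitrary function name:

    from tvm import relax
    from tvm.script import relax as R

    # data in NCW layout, weight in IOW layout (the defaults for this op)
    x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
    w = relax.Var("w", R.Tensor((3, 4, 3), "float32"))

    bb = relax.BlockBuilder()
    with bb.function("main", [x, w]):  # "main" is an arbitrary name
        # inferred output width: out_w = (28 - 1) * 1 + 3 - 2 * 0 + 0 = 30,
        # so gv has struct info R.Tensor((2, 4, 30), "float32")
        gv = bb.emit(relax.op.nn.conv1d_transpose(x, w))
        bb.emit_func_output(gv)
    mod = bb.get()  # IRModule containing the relax.nn.conv1d_transpose call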
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index a59cf5e71f..2dc610f654 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -117,6 +117,51 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
}
}; // struct Conv2dAttrs
+/*! \brief Attributes used in Conv1DTranspose operator */
+struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
+ Array<IntImm> strides;
+ Array<IntImm> padding;
+ Array<IntImm> output_padding;
+ Array<IntImm> dilation;
+ int groups;
+ String data_layout;
+ String kernel_layout;
+ String out_layout;
+ DataType out_dtype;
+
+ TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relax.attrs.Conv1DTransposeAttrs") {
+ TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution.");
+  TVM_ATTR_FIELD(padding).describe(
+      "If padding is non-zero, then the input is implicitly zero-padded. "
+      "Padding supports both symmetric and asymmetric modes: "
+      "one int: the same padding is used on both sides; "
+      "two ints: padding widths in the order (left, right).");
+ TVM_ATTR_FIELD(output_padding).describe("Used to disambiguate the output shape.");
+ TVM_ATTR_FIELD(dilation).describe(
+ "Specifies the dilation rate to use for dilated convolution.");
+ TVM_ATTR_FIELD(groups).describe(
+ "Number of groups to split the input into for grouped convolution. The number of input and "
+ "output channels should be divisible by the number of groups.");
+ TVM_ATTR_FIELD(data_layout)
+ .describe(
+ "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, width"
+ "dimensions respectively. Convolution is applied on the 'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .describe(
+ "Dimension ordering of weight. Can be 'OIW', 'IOW', etc."
+ "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .describe(
+ "Dimension ordering of output. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Default to be same as input layout.");
+  TVM_ATTR_FIELD(out_dtype).describe(
+      "Output data type; set to an explicit type under the mixed-precision setting.");
+ }
+}; // struct Conv1DTransposeAttrs
+
/*! \brief Attributes used in Conv2d operator */
struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
Array<IntImm> strides;
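
From Python, these attributes are visible on the constructed call node; ConvertIntImmToInt64 (used in convolution.cc below) normalizes the integer fields to int64, which the int64 test near the bottom of this diff also checks. A small sketch, assuming x and w as in the example above:

    conv = relax.op.nn.conv1d_transpose(x, w, strides=2, padding=1)
    assert conv.attrs.strides[0].dtype == "int64"  # normalized to an int64 IntImm
    assert conv.attrs.strides[0].value == 2
    assert conv.attrs.data_layout == "NCW"         # default layouts
    assert conv.attrs.kernel_layout == "IOW"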
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index aa5c7b18f6..1a4c3cceae 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -220,6 +220,97 @@ def conv2d(
)
+def conv1d_transpose(
+ data: Expr,
+ weight: Expr,
+ strides: Union[int, Tuple[int]] = 1,
+ padding: Union[int, Tuple[int, ...]] = 0,
+ output_padding: Union[int, Tuple[int]] = 0,
+ dilation: Union[int, Tuple[int]] = 1,
+ groups: int = 1,
+ data_layout: str = "NCW",
+ kernel_layout: str = "IOW",
+ out_layout: Optional[str] = None,
+ out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+ r"""1D transposed convolution operator.
+
+ This operator can be seen as the gradient operator of conv1d.
+
+    The output shape can be explained in the simple case when `data_layout == "NCW"`,
+    `kernel_layout == "IOW"`, and `dilation == 1`. Suppose `data` has shape
+    `(N, in_channel, in_w)` and `weight` has shape `(in_channel, out_channel, weight_w)`;
+    we need to ensure that `in_channel % groups == 0`. The shape of the output will be
+    `(N, out_channel * groups, out_w)`, where
+
+    - `out_w = (in_w - 1) * strides[0] + weight_w - 2 * padding[0] + output_padding[0]`
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ weight : relax.Expr
+ The weight expressions.
+
+ strides : Union[int, Tuple[int]]
+ The strides of convolution. It is required to have length 1.
+
+ padding : Union[int, Tuple[int, ...]]
+        The padding applied to both sides of the input before convolution.
+ It is required to have length either 1 or 2.
+
+    output_padding : Union[int, Tuple[int]], optional
+        Used to disambiguate the output shape; it is required to be smaller than strides.
+
+ dilation : Union[int, Tuple[int]]
+ Specifies the dilation rate to be used for dilated convolution.
+        It is required to have length 1.
+
+ groups : int
+ Number of groups to split the input into for grouped convolution.
+ The number of input and output channels should be divisible by the number of groups.
+
+ data_layout : str
+ Layout of the input.
+
+ kernel_layout : str
+ Layout of the weight.
+
+ out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout.
+
+ out_dtype : Optional[Union[str, DataType]]
+        Specifies the output data type for mixed-precision conv1d_transpose.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(strides, int):
+ strides = (strides,)
+ if isinstance(dilation, int):
+ dilation = (dilation,)
+ if isinstance(padding, int):
+ padding = (padding, padding)
+ if isinstance(output_padding, int):
+ output_padding = (output_padding,)
+
+ return _ffi_api.conv1d_transpose( # type: ignore
+ data,
+ weight,
+ strides,
+ padding,
+ output_padding,
+ dilation,
+ groups,
+ data_layout,
+ kernel_layout,
+ out_layout,
+ out_dtype,
+ )
+
+
def conv2d_transpose(
data: Expr,
weight: Expr,
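
As a sanity check, the docstring formula (with the dilation term from the C++ shape inference added back in) reproduces the output widths asserted in the tests at the end of this diff; the out_w helper below is purely illustrative:

    # out_w = (in_w - 1) * stride - (pad_left + pad_right)
    #         + dilation * (weight_w - 1) + output_padding + 1
    def out_w(in_w, weight_w, stride=1, pad=(0, 0), dilation=1, output_padding=0):
        return (in_w - 1) * stride - sum(pad) + dilation * (weight_w - 1) + output_padding + 1

    assert out_w(28, 3) == 30                              # (2, 3, 28) x (3, 4, 3) -> (2, 4, 30)
    assert out_w(28, 3, pad=(1, 1)) == 28                  # padding=1
    assert out_w(28, 3, pad=(1, 3)) == 26                  # padding=[1, 3]
    assert out_w(28, 3, stride=3, output_padding=1) == 85  # strides=3, output_padding=1
    assert out_w(28, 3, dilation=2) == 32                  # dilation=2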
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 5386fbf7cb..e4e608e769 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -108,6 +108,46 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.conv1d_transpose")
+def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.data_layout:
+ logging.info(
+ "TOPI conv1d_transpose does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+ if call.attrs.data_layout != "NCW" or call.attrs.kernel_layout != "IOW":
+ logging.info(
+ "TOPI conv1d_transpose does not support input layout other than NCW, "
+ "and kernel layout other than IOW, so cannot be legalized by TOPI"
+ )
+ return call
+ dilation = call.attrs.dilation
+ if len(dilation) != 1 or dilation[0] != 1:
+ logging.info(
+ "TOPI conv1d_transpose does not support dilations other than 1, "
+ "and thus cannot be legalized by TOPI"
+ )
+ return call
+ if call.attrs.groups != 1:
+ logging.info(
+ "TOPI conv1d_transpose does not support groups other than 1, "
+ "and thus cannot be legalized by TOPI"
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.conv1d_transpose_ncw,
+ call.args[0],
+ call.args[1],
+ stride=call.attrs.strides,
+ padding=call.attrs.padding,
+ out_dtype=call.struct_info.dtype,
+ output_padding=call.attrs.output_padding,
+ primfunc_name_hint="conv1d_transpose",
+ )
+
+
@register_legalize("relax.nn.conv2d_transpose")
def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.data_layout:
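
With the registration above, a module containing the op is lowered by the standard relax LegalizeOps pass. A minimal sketch, assuming the mod built in the first example (default NCW/IOW layouts, dilation 1, groups 1, so none of the early-return branches fire):

    from tvm import relax

    # The relax.nn.conv1d_transpose call is replaced by a call_tir into a
    # PrimFunc generated from topi.nn.conv1d_transpose_ncw.
    legalized = relax.transform.LegalizeOps()(mod)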
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index d698cf9757..96fc7e8464 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -352,6 +352,132 @@ TVM_REGISTER_OP("relax.nn.conv2d")
.set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv2d)
.set_attr<Bool>("FPurity", Bool(true));
+TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
+
+Expr conv1d_transpose(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
+ Array<IntImm> output_padding, Array<IntImm> dilation, int groups,
+ String data_layout, String kernel_layout, Optional<String> out_layout,
+ DataType out_dtype) {
+ padding = GetCompletePadding1D(std::move(padding));
+
+ CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
+ "the given number of groups is "
+ << groups;
+ CHECK_EQ(output_padding.size(), 1) << "The input output_padding length is expected to be 1. "
+ "However, the given output_padding is "
+ << output_padding;
+ CHECK_EQ(strides.size(), 1)
+ << "The input strides length is expected to be 1. However, the given strides is " << strides;
+ CHECK_EQ(dilation.size(), 1)
+ << "The input dilation length is expected to be 1. However, the given dilation is "
+ << dilation;
+
+ auto attrs = make_object<Conv1DTransposeAttrs>();
+ attrs->strides = ConvertIntImmToInt64(strides);
+ attrs->padding = ConvertIntImmToInt64(padding);
+ attrs->output_padding = ConvertIntImmToInt64(output_padding);
+ attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->groups = groups;
+ attrs->data_layout = data_layout;
+ attrs->kernel_layout = std::move(kernel_layout);
+ attrs->out_layout = std::move(out_layout.value_or(data_layout));
+ attrs->out_dtype = std::move(out_dtype);
+ const Op& op = Op::Get("relax.nn.conv1d_transpose");
+ return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.conv1d_transpose").set_body_typed(conv1d_transpose);
+
+StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& ctx) {
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+ TensorStructInfo data_sinfo = input_sinfo[0];
+ TensorStructInfo weight_sinfo = input_sinfo[1];
+
+ const auto* attrs = call->attrs.as<Conv1DTransposeAttrs>();
+ auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->data_layout, //
+ /*tgt_layout=*/"NCW", //
+ /*tensor_name=*/"data");
+ auto [weight_layout, weight2IOW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, //
+ /*tgt_layout=*/"IOW", //
+ /*tensor_name=*/"kernel");
+ auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout, //
+ /*tgt_layout=*/"NCW", //
+ /*tensor_name=*/"output");
+ Optional<ShapeExpr> data_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+ Optional<ShapeExpr> weight_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+ DataType out_dtype = attrs->out_dtype.is_void()
+ ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
+ : attrs->out_dtype;
+ if (!data_shape.defined() || !weight_shape.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim());
+ }
+
+ Array<PrimExpr> data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values);
+ Array<PrimExpr> weight_IOW_shape = weight2IOW.ForwardShape(weight_shape.value()->values);
+
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ PrimExpr input_channel_data = data_NCW_shape[1];
+ PrimExpr input_channel_kernel = weight_IOW_shape[0];
+ if (analyzer->CanProve(input_channel_data != input_channel_kernel)) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "Conv1dTranspose expects the channel size of the data should equal to the input channel "
+ "size of the weight. However, the data channel size is "
+ << input_channel_data << " while the weight input channel size is "
+ << input_channel_kernel);
+ } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+ if (analyzer->CanProve(floormod(input_channel_kernel, attrs->groups) != 0)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Conv1dTranspose expects the number of input channels to be divisible by "
+ "the number of groups. However, the number of input channels is "
+ << input_channel_kernel << " while the number of groups is " << attrs->groups);
+ } else if (!analyzer->CanProveEqual(floormod(input_channel_kernel, attrs->groups), 0)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+ if (analyzer->CanProve(attrs->output_padding[0]->value >= attrs->strides[0]->value)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Conv1dTranspose expects the output padding less than the strides, but the "
+ "output padding is"
+ << attrs->output_padding << " while the strides are" << attrs->strides);
+ } else if (!analyzer->CanProve(attrs->output_padding[0]->value < attrs->strides[0]->value)) {
+ // Todo(relax-team): Trust the input padding at this moment, and revisit
+ // this condition with runtime shape check
+ }
+
+ PrimExpr input_w = data_NCW_shape[2];
+ PrimExpr kernel_w = weight_IOW_shape[2];
+ PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+
+ std::vector<PrimExpr> out_NCW_shape;
+ out_NCW_shape.resize(3);
+ out_NCW_shape[0] = data_NCW_shape[0];
+ out_NCW_shape[1] = weight_IOW_shape[1] * attrs->groups;
+
+ PrimExpr out_w = (input_w - 1) * attrs->strides[0] - padding_w +
+ attrs->dilation[0] * (kernel_w - 1) + attrs->output_padding[0] + 1;
+ out_NCW_shape[2] = analyzer->Simplify(out_w);
+
+ Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
+}
+
+// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv1d_transpose
+// and unit test for mixed_precision
+TVM_REGISTER_OP("relax.nn.conv1d_transpose")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_attrs_type<Conv1DTransposeAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv1dTranspose)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.nn.conv2d_transpose */
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
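
To connect the out_w expression above with the symbolic test in this diff: with strides=2, padding=1 on each side, dilation=2, and output_padding=1,

    out_w = (iw - 1) * 2 - (1 + 1) + 2 * (kw - 1) + 1 + 1
          = 2 * iw + 2 * kw - 4

which is exactly the (n, ko, iw * 2 + kw * 2 - 4) shape asserted in test_conv1d_transpose_infer_struct_info_shape_symbolic.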
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index 833e730ee9..536d2af371 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -62,6 +62,17 @@ Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding
Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
Optional<String> out_layout, DataType out_dtype);
+/*!
+ * \brief One dimensional transposed convolution operator.
+ *
+ * This operator is intended to be the backward operator of conv1d. It can be used to calculate the
+ * gradient of the result of conv1d w.r.t. the input of conv1d.
+ */
+Expr conv1d_transpose(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
+ Array<IntImm> output_padding, Array<IntImm> dilation, int groups,
+ String data_layout, String kernel_layout, Optional<String> out_layout,
+ DataType out_dtype);
+
/*!
* \brief Two dimensional transposed convolution operator.
*
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
index d1d604429e..2ec451c132 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -27,6 +27,7 @@ def test_conv1d_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
assert relax.op.nn.conv1d(x, w).op == Op.get("relax.nn.conv1d")
+ assert relax.op.nn.conv1d_transpose(x, w).op == Op.get("relax.nn.conv1d_transpose")
def test_conv2d_op_correctness():
@@ -411,6 +412,356 @@ def test_conv1d_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.nn.conv1d(x1, w0))
+def test_conv1d_transpose_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 28, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x3 = relax.Var("x", R.Tensor("float32"))
+ x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 4, 28, 16), "float32"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
+ w2 = relax.Var("w", R.Tensor("float32", ndim=3))
+ w3 = relax.Var("w", R.Tensor("float32"))
+ w4 = relax.Var("w", R.Tensor((4, 48, 3, 16), "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float16"),
+ relax.TensorStructInfo((2, 4, 30), "float16"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, padding=1),
+ relax.TensorStructInfo((2, 4, 28), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, padding=[1, 3]),
+ relax.TensorStructInfo((2, 4, 26), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, strides=3, output_padding=1),
+ relax.TensorStructInfo((2, 4, 85), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, strides=2),
+ relax.TensorStructInfo((2, 4, 57), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, dilation=2),
+ relax.TensorStructInfo((2, 4, 32), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, dilation=(2,)),
+ relax.TensorStructInfo((2, 4, 32), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x1, w0, data_layout="NWC"),
+ relax.TensorStructInfo((2, 30, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, out_layout="NWC"),
+ relax.TensorStructInfo((2, 30, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w1, kernel_layout="OIW"),
+ relax.TensorStructInfo((2, 4, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(
+ x5, w4, data_layout="NCW16c", kernel_layout="IOW16i", out_layout="NWC16c"
+ ),
+ relax.TensorStructInfo((2, 30, 3, 16), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x4, w0), relax.TensorStructInfo(dtype="", ndim=3)
+ )
+
+
+def test_conv1d_transpose_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ c16 = tir.Var("c16", "int64")
+ iw = tir.Var("iw", "int64")
+ ki = tir.Var("ki", "int64")
+ ko = tir.Var("ko", "int64")
+ kw = tir.Var("kw", "int64")
+ x0 = relax.Var("x", R.Tensor((n, c, iw), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32"))
+ w0 = relax.Var("w", R.Tensor((ki, ko, kw), "float32"))
+ w1 = relax.Var("w", R.Tensor((c, ko, kw), "float32"))
+ w2 = relax.Var("w", R.Tensor((c, ko, kw, c16), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0),
+ relax.TensorStructInfo((n, ko, iw + kw - 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w1),
+ relax.TensorStructInfo((n, ko, iw + kw - 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(
+ x1, w2, data_layout="NCW16c", kernel_layout="IOW16i", out_layout="NCW"
+ ),
+ relax.TensorStructInfo((n, ko, iw + kw - 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, strides=2, padding=1, dilation=2, output_padding=1),
+ relax.TensorStructInfo(
+ (n, ko, iw * 2 + kw * 2 - 4),
+ "float32",
+ ),
+ )
+
+
+def test_conv1d_transpose_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+ s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+ s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+ s3 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32"))
+ w = relax.Var("w", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.conv1d_transpose(x0, w), relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x1, w, data_layout="NCW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d(x0, w, out_layout="NCW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+        relax.op.nn.conv1d_transpose(x2, w),
+ relax.TensorStructInfo(dtype="float32", ndim=3),
+ )
+
+
+def test_conv1d_transpose_infer_struct_info_groups():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 8, 28, 16), "float32"))
+ w0 = relax.Var("w", R.Tensor((128, 6, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((16, 6, 3, 8), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, groups=8),
+ relax.TensorStructInfo((2, 48, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w1, kernel_layout="IOW8i", groups=8),
+ relax.TensorStructInfo((2, 48, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x1, w0, data_layout="NCW16c", groups=8),
+ relax.TensorStructInfo((2, 3, 30, 16), "float32"),
+ )
+
+
+def test_conv1d_transpose_infer_struct_info_symbolic_groups():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ ic = tir.Var("c", "int64")
+ oc = tir.Var("oc", "int64")
+ x = relax.Var("x", R.Tensor((n, ic * 4, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((ic, oc, 3), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x, w0, groups=4),
+ relax.TensorStructInfo((n, oc * 4, 30), "float32"),
+ )
+
+
+def test_conv1d_transpose_infer_struct_info_input_channel_group_incompatible():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ ic = tir.Var("c", "int64")
+ oc = tir.Var("oc", "int64")
+ x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((128, 20, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, ic, 28), "float32"))
+ w1 = relax.Var("w", R.Tensor((ic - 1, oc, 3), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x0, w0, groups=6))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x1, w1, groups=6))
+
+
+def test_conv1d_transpose_non_positive_group():
+ x = relax.Var("x", R.Tensor((2, 128, 28), "float32"))
+ w = relax.Var("w", R.Tensor((128, 16, 3), "float32"))
+
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d_transpose(x, w, groups=0)
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d_transpose(x, w, groups=-2)
+
+
+def test_conv1d_transpose_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28), "float64"))
+ w1 = relax.Var("w", R.Tensor((3, 4, 3), "float64"))
+ x2 = relax.Var("x", R.Tensor((2, 3, 28), "int8"))
+ w2 = relax.Var("w", R.Tensor((3, 4, 3), "int8"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 28), "int32"))
+ w3 = relax.Var("w", R.Tensor((3, 4, 3), "int32"))
+
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30), "float16")
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x1, w1), relax.TensorStructInfo((2, 4, 30), "float64")
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x2, w2), relax.TensorStructInfo((2, 4, 30), "int8")
+ )
+ _check_inference(
+ bb, relax.op.nn.conv1d_transpose(x3, w3), relax.TensorStructInfo((2, 4, 30), "int32")
+ )
+
+
+def test_conv1d_transpose_unequal_input_channel():
+ bb = relax.BlockBuilder()
+ ic = tir.Var("ic", "int64")
+ x0 = relax.Var("x", R.Tensor([2, 3, 28], "float32"))
+ w0 = relax.Var("w", R.Tensor([4, 3, 3], "float32"))
+ x1 = relax.Var("x", R.Tensor([2, ic, 28], "float32"))
+ w1 = relax.Var("w", R.Tensor([ic + 2, 4, 3], "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x0, w0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x1, w1))
+
+
+def test_conv1d_transpose_wrong_output_padding():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor([2, 3, 28], "float32"))
+ w0 = relax.Var("w", R.Tensor([3, 4, 3], "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x0, w0, strides=2, output_padding=2))
+
+
+def test_conv1d_transpose_stride_padding_dilation_int64():
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
+ conv1d = relax.op.nn.conv1d_transpose(x, w, strides=1, padding=1, dilation=1)
+
+ assert conv1d.attrs.strides[0].dtype == "int64"
+ assert conv1d.attrs.padding[0].dtype == "int64"
+ assert conv1d.attrs.dilation[0].dtype == "int64"
+
+
+def test_conv1d_transpose_wrong_strides_padding_dilation_length():
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d_transpose(x, w, strides=(1, 2))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d_transpose(x, w, padding=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv1d_transpose(x, w, dilation=(1, 2))
+
+
+def test_conv1d_transpose_infer_struct_info_wrong_layout_string():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x, w, data_layout="IOW"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x, w, kernel_layout="NWC"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x, w, out_layout="OWI"))
+
+
+def test_conv1d_transpose_dtype_mismatch():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ w = relax.Var("w", R.Tensor((3, 4, 3), "int8"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x, w))
+
+
+def test_conv1d_transpose_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=2))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((3, 4, 6, 3), "float32"))
+ w2 = relax.Var("w", R.Tensor("float32", ndim=5))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x0, w1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x0, w1, data_layout="NCW16c"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x0, w2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x1, w0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x2, w0))
+
+
+def test_conv1d_transpose_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
+ x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28)))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
+ w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((3, 4, 3), "float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x0, w1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv1d_transpose(x1, w0))
+
+
def test_conv2d_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py
index 014d524751..bba08d4d84 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -53,6 +53,24 @@ def test_conv1d():
_check(foo, bb.get()["foo"])
+def test_conv1d_transpose():
+ @R.function
+ def foo(
+ x: R.Tensor((2, 3, 228), "float16"), w: R.Tensor((3, 16, 5), "float16")
+ ) -> R.Tensor((2, 16, 232), "float16"):
+ gv: R.Tensor((2, 16, 232), "float16") = R.nn.conv1d_transpose(x, w, out_dtype="float16")
+ return gv
+
+ x = relax.Var("x", R.Tensor([2, 3, 228], "float16"))
+ w = relax.Var("w", R.Tensor([3, 16, 5], "float16"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x, w]):
+ gv = bb.emit(relax.op.nn.conv1d_transpose(x, w, out_dtype="float16"))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_conv2d():
@R.function
def foo(