Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/11 02:49:20 UTC
[tvm] branch main updated: [BYOC-DNNL] enable conv3d->bn folding (#10837)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 45f3d4a521 [BYOC-DNNL] enable conv3d->bn folding (#10837)
45f3d4a521 is described below
commit 45f3d4a521ec476cd9960e3d2de4f66bde61bf23
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Mon Apr 11 10:49:13 2022 +0800
[BYOC-DNNL] enable conv3d->bn folding (#10837)
* support conv3d bn folding
* add test case for fold_scale_axis
* modify lint
* remove test cases
* unify conv2d 3d impls, and add test cases.
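A quick illustration of what this enables (a minimal sketch, not part of the patch; the passes and APIs are real TVM ones, but the shapes and the standalone script are illustrative): a per-channel scale feeding nn.conv3d, i.e. the multiply that SimplifyInference leaves behind when decomposing BatchNorm, can now be folded into the conv3d weights by ForwardFoldScaleAxis, matching the existing conv2d behavior.

    import numpy as np
    import tvm
    from tvm import relay

    # BN-style per-channel scale in front of an NCDHW conv3d.
    x = relay.var("x", shape=(1, 4, 8, 8, 8))   # NCDHW
    w = relay.var("w", shape=(8, 4, 3, 3, 3))   # OIDHW
    scale = relay.const(np.random.uniform(0.5, 2.0, (4, 1, 1, 1)).astype("float32"))
    y = relay.nn.conv3d(relay.multiply(x, scale), w,
                        channels=8, kernel_size=(3, 3, 3), padding=(1, 1, 1))

    mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
    mod = relay.transform.InferType()(mod)
    # With this patch the scale should be folded into the conv3d weights
    # (previously only conv2d was handled), leaving a bare conv3d for the
    # DNNL BYOC path.
    mod = relay.transform.ForwardFoldScaleAxis()(mod)
    print(mod)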
---
src/relay/transforms/fold_scale_axis.cc | 108 ++++++++++----
src/relay/transforms/pattern_utils.h | 18 +--
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 7 +-
tests/python/relay/test_pass_fold_scale_axis.py | 178 ++++++++++++++++++++++++
4 files changed, 272 insertions(+), 39 deletions(-)
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 4b94159fe3..f4f05badec 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -29,6 +29,7 @@
#include <tvm/relay/transform.h>
#include <tvm/tir/data_layout.h>
+#include "../backend/utils.h"
#include "../op/tensor/transform.h"
#include "pass_utils.h"
#include "pattern_utils.h"
@@ -492,11 +493,11 @@ RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);
// Consumer operators
-// Conv2D send out requirement of axis folding.
-Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
+// Conv send out requirement of axis folding.
+template <typename ATTRS>
+Array<Message> ConvForwardPrep(const Call& call, const ATTRS* param, const Message& out_message) {
// TODO(tvm-team) support general data layout
// by transforming weight
- const auto* param = call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
@@ -512,8 +513,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
- bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
- if (param->groups == 1 || is_depthwise_conv2d) {
+ bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+ if (param->groups == 1 || is_depthwise_conv) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
@@ -529,14 +530,14 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
}
// Conv2D consumes the scale axis during transformation.
-Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
- const Message& message) {
+template <typename ATTRS>
+Expr ConvForwardRewrite(const Call& ref_call, const ATTRS* param, const Array<Expr>& new_args,
+ const Message& message) {
// if data do not have scale, normal transform path.
const auto* sdata = new_args[0].as<ScaledExprNode>();
const auto* sweight = new_args[1].as<ScaledExprNode>();
if (sdata == nullptr) return Expr();
if (sweight != nullptr) return Expr();
- const auto* param = ref_call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
@@ -552,13 +553,13 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
ICHECK(is_simple || is_blocking);
// Check it must be depthwise or full conv2d.
- bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
- ICHECK(param->groups == 1 || is_depthwise_conv2d);
+ bool is_depthwise_conv = IsDepthwiseConv(ref_call, param, kernel_layout);
+ ICHECK(param->groups == 1 || is_depthwise_conv);
Expr weight = new_args[1];
// match the ic_axis
- if (is_depthwise_conv2d) {
+ if (is_depthwise_conv) {
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
weight = Multiply(weight, scale);
@@ -580,14 +581,38 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
if (!weight.defined()) return Expr();
}
}
- // return transformed conv2d
+ // return transformed conv
return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}
-RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
+Array<Message> PreConvForwardPrep(const Call& call, const Message& out_message) {
+ if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+ const auto* param = call->attrs.as<Conv2DAttrs>();
+ return ConvForwardPrep(call, param, out_message);
+ }
+ const auto* param = call->attrs.as<Conv3DAttrs>();
+ return ConvForwardPrep(call, param, out_message);
+}
+
+Expr PreConvForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
+ const Message& message) {
+ if (backend::IsOp(ref_call.as<CallNode>(), "nn.conv2d")) {
+ const auto* param = ref_call->attrs.as<Conv2DAttrs>();
+ return ConvForwardRewrite(ref_call, param, new_args, message);
+ }
+ const auto* param = ref_call->attrs.as<Conv3DAttrs>();
+ return ConvForwardRewrite(ref_call, param, new_args, message);
+}
+
+RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);
RELAY_REGISTER_OP("nn.conv2d")
- .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
+ .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
+
+RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);
+
+RELAY_REGISTER_OP("nn.conv3d")
+ .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
// Dense send out requirement of axis folding.
Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) {
@@ -937,9 +962,9 @@ RELAY_REGISTER_OP("multiply")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);
// Consumer operators
-// Conv2D send out requirement of axis folding.
-Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
- const auto* param = call->attrs.as<Conv2DAttrs>();
+// Conv send out requirement of axis folding.
+template <typename ATTRS>
+Message ConvBackwardPrep(const Call& call, const ATTRS* param, const Array<Message>& in_messages) {
ICHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -952,10 +977,10 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// By using a unified layout transformation.
// We only need to change the Prep and Mutate function.
//
- // only handle depthwise or full conv2d.
+ // only handle depthwise or full conv.
// TODO(tvm-team) handle grouped conv by reshape + bcast
- bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
- if (param->groups == 1 || is_depthwise_conv2d) {
+ bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+ if (param->groups == 1 || is_depthwise_conv) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
@@ -970,13 +995,13 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
return NullValue<Message>();
}
-// Conv2D consumes the scale axis during transformation.
-Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
- const BackwardTransformer& transformer) {
+// Conv consumes the scale axis during transformation.
+template <typename ATTRS>
+Expr ConvBackwardTransform(const Call& call, const ATTRS* param, const Message& message,
+ const Expr& scale, const BackwardTransformer& transformer) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
- const auto* param = call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -988,9 +1013,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
- // Check it must be depthwise or full conv2d.
- bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
- ICHECK(param->groups == 1 || is_depthwise_conv2d);
+ // Check it must be depthwise or full conv.
+ bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+ ICHECK(param->groups == 1 || is_depthwise_conv);
bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
ICHECK(is_simple || is_blocking);
@@ -1012,11 +1037,36 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
return Call(call->op, {data, weight}, call->attrs, call->type_args);
}
+Message PreConvBackwardPrep(const Call& call, const Array<Message>& in_messages) {
+ if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+ const auto* param = call->attrs.as<Conv2DAttrs>();
+ return ConvBackwardPrep(call, param, in_messages);
+ }
+ const auto* param = call->attrs.as<Conv3DAttrs>();
+ return ConvBackwardPrep(call, param, in_messages);
+}
+
+Expr PreConvBackwardTransform(const Call& call, const Message& message, const Expr& scale,
+ const BackwardTransformer& transformer) {
+ if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+ const auto* param = call->attrs.as<Conv2DAttrs>();
+ return ConvBackwardTransform(call, param, message, scale, transformer);
+ }
+ const auto* param = call->attrs.as<Conv3DAttrs>();
+ return ConvBackwardTransform(call, param, message, scale, transformer);
+}
+
RELAY_REGISTER_OP("nn.conv2d")
- .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
+ .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);
RELAY_REGISTER_OP("nn.conv2d")
- .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
+ .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);
+
+RELAY_REGISTER_OP("nn.conv3d")
+ .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);
+
+RELAY_REGISTER_OP("nn.conv3d")
+ .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);
Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>();
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index cf97d2c25d..6a773d7f3c 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -44,6 +44,7 @@
#include <utility>
#include <vector>
+#include "../backend/utils.h"
#include "../op/make_op.h"
namespace tvm {
@@ -183,16 +184,17 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Intege
}
/*!
- * \brief Check if the call is depthwise conv2d.
+ * \brief Check if the call is a depthwise conv (conv2d or conv3d).
*
- * \param call The conv2d call.
- * \param param The conv2d attributes.
- * \return Whether it is depthwise_conv2d.
+ * \param call The conv call.
+ * \param param The conv attributes.
+ * \return Whether it is a depthwise conv.
*/
-inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param,
- const Layout& kernel_layout) {
- static const Layout kOIHW("OIHW");
- const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW);
+template <typename ATTRS>
+inline bool IsDepthwiseConv(const Call& call, ATTRS param, const Layout& kernel_layout) {
+ static const Layout kOIXX =
+ backend::IsOp(call.as<CallNode>(), "nn.conv2d") ? Layout("OIHW") : Layout("OIDHW");
+ const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIXX);
auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
}
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 7067806142..dc2afecbaf 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -157,6 +157,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{"IODHW8i8o", tag::any},
{"ODHWI8o", tag::Odhwi8o},
{"ODHWI16o", tag::Odhwi16o},
+ {"ODHWI32o", tag::Odhwi32o},
+ {"ODHWI48o", tag::Odhwi48o},
+ {"ODHWI64o", tag::Odhwi64o},
};
bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
@@ -342,7 +345,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (layout_dict.find(kernel_layout) == layout_dict.end()) {
layout_dict.insert({kernel_layout, tag::any});
- LOG(WARNING) << "Unregistered kernel layout for conv: " << data_layout
+ LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout
<< ", transfer to tag::any";
}
@@ -382,7 +385,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
- // Covn2d description.
+ // Conv description.
auto conv_desc =
has_bias ? dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index b5e7b1c816..12fc722d86 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -1028,6 +1028,182 @@ def test_fold_bwd_bias_add():
check((2, 4, 10, 10), 4)
+def test_fold_fwd_conv3d():
+    """Conv3d testcase."""
+
+    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
+        args = [x, conv_weight, in_bias]
+        x = relay.multiply(x, in_scale)
+        x = relay.nn.relu(x)
+        x = relay.add(x, in_bias)
+        y = relay.nn.conv3d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3, 3),
+            padding=(1, 1, 1),
+            data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
+            kernel_layout="OIDHW2i{}o".format(blocking[1]) if blocking else "OIDHW",
+        )
+
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, in_bias, in_scale, in_channels, channels, blocking):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, in_bias]
+        if blocking:
+            squeezed_scale = relay.squeeze(in_scale, axis=[0, 2, 3, 4])
+            x = relay.nn.relu(x)
+            in_bias = relay.divide(
+                in_bias,
+                relay.reshape(
+                    squeezed_scale, (1, in_channels // blocking[0], 1, 1, 1, blocking[0])
+                ),
+            )  # NCHWc
+            x = relay.add(x, in_bias)
+            conv_weight = relay.multiply(
+                conv_weight, relay.reshape(squeezed_scale, (1, in_channels // 2, 1, 1, 1, 2, 1))
+            )  # OIHWio
+        else:
+            squeezed_scale = relay.squeeze(in_scale, axis=[1, 2, 3])
+            x = relay.nn.relu(x)
+            in_bias = relay.divide(
+                in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+            )
+            x = relay.add(x, in_bias)
+            conv_weight = relay.multiply(
+                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+            )
+
+        y = relay.nn.conv3d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3, 3),
+            padding=(1, 1, 1),
+            data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
+            kernel_layout="OIDHW2i{}o".format(blocking[1]) if blocking else "OIDHW",
+        )
+        return relay.Function(args, y)
+
+    def check(shape, channels, blocking):
+        x = relay.var("x", shape=shape)
+        weight = relay.var("weight")
+        if blocking:
+            in_channels = shape[1] * shape[-1]
+            in_bias = relay.var(
+                "in_bias", shape=(1, in_channels // blocking[0], 1, 1, 1, blocking[0])
+            )
+            in_scale = relay.const(
+                _get_positive_scale((1, in_channels // blocking[0], 1, 1, 1, blocking[0]))
+            )
+        else:
+            in_channels = shape[1]
+            in_bias = relay.var("in_bias", shape=(in_channels, 1, 1, 1))
+            in_scale = relay.const(_get_positive_scale((in_channels, 1, 1, 1)))
+        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
+        y1 = run_opt_pass(y1, transform.InferType())
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
+        y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking)
+
+        y1_folded = run_opt_pass(y1_folded, transform.InferType())
+        y1_expected = run_opt_pass(y1_expected, transform.InferType())
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10, 10), 2, None)
+    check((2, 2, 10, 10, 10, 2), 8, (2, 4))
+
+
+def test_fold_bwd_conv3d():
+    """Conv3d testcase."""
+
+    def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
+        args = [x, conv_weight, out_bias]
+        if blocking:
+            out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
+        else:
+            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=3)
+        y = relay.nn.conv3d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3, 3),
+            padding=(1, 1, 1),
+            data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
+            kernel_layout="OIDHW1i{}o".format(blocking[1]) if blocking else "OIDHW",
+        )
+        y = relay.add(y, out_bias)
+        y = relay.nn.relu(y)
+        if blocking:
+            out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
+        y = relay.multiply(y, out_scale)
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, out_bias]
+        if blocking:
+            out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
+            out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
+            squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3, 4])
+            conv_weight = relay.multiply(
+                conv_weight,
+                relay.reshape(
+                    squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, 1, blocking[1])
+                ),
+            )
+        else:
+            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=3)
+            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2, 3])
+            conv_weight = relay.multiply(
+                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=4)
+            )
+
+        y = relay.nn.conv3d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3, 3),
+            padding=(1, 1, 1),
+            data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
+            kernel_layout="OIDHW1i{}o".format(blocking[1]) if blocking else "OIDHW",
+        )
+        if blocking:
+            out_bias = relay.multiply(
+                out_bias,
+                relay.reshape(squeezed_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1])),
+            )
+        else:
+            out_bias = relay.multiply(
+                out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+            )
+        y = relay.add(y, out_bias)
+        y = relay.nn.relu(y)
+        return relay.Function(args, y)
+
+    def check(shape, in_channels, channels, blocking):
+        x = relay.var("x", shape=shape)
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        if blocking:
+            out_scale = relay.const(_get_positive_scale((channels,)))
+        else:
+            out_scale = relay.const(_get_positive_scale((channels, 1, 1, 1)))
+        y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
+        y1 = run_opt_pass(y1, transform.InferType())
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
+        y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
+        y1_expected = run_opt_pass(y1_expected, transform.InferType())
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10, 10), 4, 8, None)
+    check((2, 2, 10, 10, 10, 16), 32, 64, (16, 16))
+
+
if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
@@ -1043,3 +1219,5 @@ if __name__ == "__main__":
test_fold_bwd_negative_scale()
test_fold_bwd_dense()
test_fold_bwd_bias_add()
+    test_fold_fwd_conv3d()
+    test_fold_bwd_conv3d()