Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/11 02:49:20 UTC

[tvm] branch main updated: [BYOC-DNNL] enable conv3d->bn folding (#10837)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 45f3d4a521 [BYOC-DNNL] enable conv3d->bn folding (#10837)
45f3d4a521 is described below

commit 45f3d4a521ec476cd9960e3d2de4f66bde61bf23
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Mon Apr 11 10:49:13 2022 +0800

    [BYOC-DNNL] enable conv3d->bn folding (#10837)
    
    * support conv3d bn folding
    
    * add test case for fold_scale_axis
    
    * modify lint
    
    * remove test cases
    
    * unify conv2d and conv3d impls, and add test cases.
---
 src/relay/transforms/fold_scale_axis.cc         | 108 ++++++++++----
 src/relay/transforms/pattern_utils.h            |  18 +--
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc   |   7 +-
 tests/python/relay/test_pass_fold_scale_axis.py | 178 ++++++++++++++++++++++++
 4 files changed, 272 insertions(+), 39 deletions(-)
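
For context: the patch generalizes the conv2d handlers in FoldScaleAxis to a templated form that also accepts Conv3DAttrs, so the per-channel scale left behind by a batch_norm next to a conv3d can now be folded into the convolution weight before the graph is offloaded to DNNL. A minimal sketch of that flow, with illustrative shapes and random constants (not taken from this patch):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import transform

    # conv3d followed by batch_norm, NCDHW data / OIDHW weight; shapes are for illustration only
    x = relay.var("x", shape=(1, 4, 8, 8, 8))
    w = relay.const(np.random.randn(8, 4, 3, 3, 3).astype("float32"))
    gamma = relay.const(np.random.rand(8).astype("float32"))
    beta = relay.const(np.random.rand(8).astype("float32"))
    moving_mean = relay.const(np.random.rand(8).astype("float32"))
    moving_var = relay.const((np.random.rand(8) + 0.5).astype("float32"))

    y = relay.nn.conv3d(x, w, channels=8, kernel_size=(3, 3, 3), padding=(1, 1, 1))
    y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var)[0]

    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.SimplifyInference(),  # rewrites batch_norm into multiply/add
            transform.FoldScaleAxis(),      # folds the per-channel scale into the conv3d weight
            transform.FoldConstant(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):  # FoldScaleAxis is registered at opt_level 3
        mod = seq(mod)
    print(mod)

After the sequence the batch_norm should be gone and the scale baked into the conv3d weight, which is the shape of graph the DNNL BYOC patterns expect.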

diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 4b94159fe3..f4f05badec 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -29,6 +29,7 @@
 #include <tvm/relay/transform.h>
 #include <tvm/tir/data_layout.h>
 
+#include "../backend/utils.h"
 #include "../op/tensor/transform.h"
 #include "pass_utils.h"
 #include "pattern_utils.h"
@@ -492,11 +493,11 @@ RELAY_REGISTER_OP("multiply")
     .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);
 
 // Consumer operators
-// Conv2D send out requirement of axis folding.
-Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
+// Conv send out requirement of axis folding.
+template <typename ATTRS>
+Array<Message> ConvForwardPrep(const Call& call, const ATTRS* param, const Message& out_message) {
   // TODO(tvm-team) support general data layout
   // by transforming weight
-  const auto* param = call->attrs.as<Conv2DAttrs>();
   ICHECK(param != nullptr);
   Layout data_layout(param->data_layout);
   Layout kernel_layout(param->kernel_layout);
@@ -512,8 +513,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
   //
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
-  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  if (param->groups == 1 || is_depthwise_conv2d) {
+  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+  if (param->groups == 1 || is_depthwise_conv) {
     auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
     auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
     if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
@@ -529,14 +530,14 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
 }
 
 // Conv2D consumes the scale axis during transformation.
-Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
-                          const Message& message) {
+template <typename ATTRS>
+Expr ConvForwardRewrite(const Call& ref_call, const ATTRS* param, const Array<Expr>& new_args,
+                        const Message& message) {
   // if data do not have scale, normal transform path.
   const auto* sdata = new_args[0].as<ScaledExprNode>();
   const auto* sweight = new_args[1].as<ScaledExprNode>();
   if (sdata == nullptr) return Expr();
   if (sweight != nullptr) return Expr();
-  const auto* param = ref_call->attrs.as<Conv2DAttrs>();
   ICHECK(param != nullptr);
   Layout data_layout(param->data_layout);
   Layout kernel_layout(param->kernel_layout);
@@ -552,13 +553,13 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
   ICHECK(is_simple || is_blocking);
 
   // Check it must be depthwise or full conv2d.
-  bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
-  ICHECK(param->groups == 1 || is_depthwise_conv2d);
+  bool is_depthwise_conv = IsDepthwiseConv(ref_call, param, kernel_layout);
+  ICHECK(param->groups == 1 || is_depthwise_conv);
 
   Expr weight = new_args[1];
 
   // match the ic_axis
-  if (is_depthwise_conv2d) {
+  if (is_depthwise_conv) {
     if (is_simple) {
       Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
       weight = Multiply(weight, scale);
@@ -580,14 +581,38 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
       if (!weight.defined()) return Expr();
     }
   }
-  // return transformed conv2d
+  // return transformed conv
   return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
 }
 
-RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
+Array<Message> PreConvForwardPrep(const Call& call, const Message& out_message) {
+  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+    const auto* param = call->attrs.as<Conv2DAttrs>();
+    return ConvForwardPrep(call, param, out_message);
+  }
+  const auto* param = call->attrs.as<Conv3DAttrs>();
+  return ConvForwardPrep(call, param, out_message);
+}
+
+Expr PreConvForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
+                           const Message& message) {
+  if (backend::IsOp(ref_call.as<CallNode>(), "nn.conv2d")) {
+    const auto* param = ref_call->attrs.as<Conv2DAttrs>();
+    return ConvForwardRewrite(ref_call, param, new_args, message);
+  }
+  const auto* param = ref_call->attrs.as<Conv3DAttrs>();
+  return ConvForwardRewrite(ref_call, param, new_args, message);
+}
+
+RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);
 
 RELAY_REGISTER_OP("nn.conv2d")
-    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
+    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
+
+RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);
+
+RELAY_REGISTER_OP("nn.conv3d")
+    .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
 
 // Dense send out requirement of axis folding.
 Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) {
@@ -937,9 +962,9 @@ RELAY_REGISTER_OP("multiply")
     .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);
 
 // Consumer operators
-// Conv2D send out requirement of axis folding.
-Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
-  const auto* param = call->attrs.as<Conv2DAttrs>();
+// Conv send out requirement of axis folding.
+template <typename ATTRS>
+Message ConvBackwardPrep(const Call& call, const ATTRS* param, const Array<Message>& in_messages) {
   ICHECK(param != nullptr);
   Layout kernel_layout(param->kernel_layout);
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -952,10 +977,10 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
   // By using a unified layout transformation.
   // We only need to change the Prep and Mutate function.
   //
-  // only handle depthwise or full conv2d.
+  // only handle depthwise or full conv.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
-  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  if (param->groups == 1 || is_depthwise_conv2d) {
+  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+  if (param->groups == 1 || is_depthwise_conv) {
     auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
     auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
     if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
@@ -970,13 +995,13 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
   return NullValue<Message>();
 }
 
-// Conv2D consumes the scale axis during transformation.
-Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
-                             const BackwardTransformer& transformer) {
+// Conv consumes the scale axis during transformation.
+template <typename ATTRS>
+Expr ConvBackwardTransform(const Call& call, const ATTRS* param, const Message& message,
+                           const Expr& scale, const BackwardTransformer& transformer) {
   if (!message.defined()) {
     return transformer->NormalCallTransform(call.operator->());
   }
-  const auto* param = call->attrs.as<Conv2DAttrs>();
   ICHECK(param != nullptr);
   Layout kernel_layout(param->kernel_layout);
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -988,9 +1013,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
   int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
   int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
   int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
-  // Check it must be depthwise or full conv2d.
-  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  ICHECK(param->groups == 1 || is_depthwise_conv2d);
+  // Check it must be depthwise or full conv.
+  bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
+  ICHECK(param->groups == 1 || is_depthwise_conv);
   bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
   bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
   ICHECK(is_simple || is_blocking);
@@ -1012,11 +1037,36 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
   return Call(call->op, {data, weight}, call->attrs, call->type_args);
 }
 
+Message PreConvBackwardPrep(const Call& call, const Array<Message>& in_messages) {
+  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+    const auto* param = call->attrs.as<Conv2DAttrs>();
+    return ConvBackwardPrep(call, param, in_messages);
+  }
+  const auto* param = call->attrs.as<Conv3DAttrs>();
+  return ConvBackwardPrep(call, param, in_messages);
+}
+
+Expr PreConvBackwardTransform(const Call& call, const Message& message, const Expr& scale,
+                              const BackwardTransformer& transformer) {
+  if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
+    const auto* param = call->attrs.as<Conv2DAttrs>();
+    return ConvBackwardTransform(call, param, message, scale, transformer);
+  }
+  const auto* param = call->attrs.as<Conv3DAttrs>();
+  return ConvBackwardTransform(call, param, message, scale, transformer);
+}
+
 RELAY_REGISTER_OP("nn.conv2d")
-    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
+    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);
 
 RELAY_REGISTER_OP("nn.conv2d")
-    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
+    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);
+
+RELAY_REGISTER_OP("nn.conv3d")
+    .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);
+
+RELAY_REGISTER_OP("nn.conv3d")
+    .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);
 
 Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) {
   const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>();
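
The PreConv* wrappers above only dispatch on the op name to recover the right attribute type (Conv2DAttrs vs. Conv3DAttrs) and then forward to the templated helpers, so both the forward and the backward folding paths now fire on nn.conv3d as well. A small forward-fold sketch, mirroring the new test_fold_fwd_conv3d but with illustrative shapes:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(1, 4, 8, 8, 8))        # NCDHW
    w = relay.var("w", shape=(8, 4, 3, 3, 3))        # OIDHW
    scale = relay.const(np.abs(np.random.randn(4, 1, 1, 1)).astype("float32"))

    y = relay.multiply(x, scale)                     # constant per-channel scale on the C axis
    y = relay.nn.conv3d(y, w, channels=8, kernel_size=(3, 3, 3), padding=(1, 1, 1))

    mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
    mod = transform.InferType()(mod)
    mod = transform.ForwardFoldScaleAxis()(mod)      # the multiply is absorbed into the weight
    print(mod)
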
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index cf97d2c25d..6a773d7f3c 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -44,6 +44,7 @@
 #include <utility>
 #include <vector>
 
+#include "../backend/utils.h"
 #include "../op/make_op.h"
 
 namespace tvm {
@@ -183,16 +184,17 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Intege
 }
 
 /*!
- * \brief Check if the call is depthwise conv2d.
+ * \brief Check if the call is a depthwise conv (conv2d or conv3d).
  *
- * \param call The conv2d call.
- * \param param The conv2d attributes.
- * \return Whether it is depthwise_conv2d.
+ * \param call The conv call.
+ * \param param The conv attributes.
+ * \return Whether it is a depthwise conv.
  */
-inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param,
-                              const Layout& kernel_layout) {
-  static const Layout kOIHW("OIHW");
-  const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW);
+template <typename ATTRS>
+inline bool IsDepthwiseConv(const Call& call, ATTRS param, const Layout& kernel_layout) {
+  static const Layout kOIXX =
+      backend::IsOp(call.as<CallNode>(), "nn.conv2d") ? Layout("OIHW") : Layout("OIDHW");
+  const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIXX);
   auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
   return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
 }
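
With the templated IsDepthwiseConv, the depthwise test is the same for both ranks: bring the kernel layout into canonical OIHW or OIDHW form, then require wshape[0] == groups and wshape[1] == 1. In Relay terms, a conv3d that this check classifies as depthwise would look roughly like the following (hypothetical shapes):

    from tvm import relay

    # depthwise conv3d: groups == input channels, OIDHW weight of shape (C, 1, kD, kH, kW)
    x = relay.var("x", shape=(1, 4, 8, 8, 8))
    w = relay.var("w", shape=(4, 1, 3, 3, 3))
    y = relay.nn.conv3d(
        x, w, channels=4, groups=4, kernel_size=(3, 3, 3), padding=(1, 1, 1)
    )
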
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 7067806142..dc2afecbaf 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -157,6 +157,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       {"IODHW8i8o", tag::any},
       {"ODHWI8o", tag::Odhwi8o},
       {"ODHWI16o", tag::Odhwi16o},
+      {"ODHWI32o", tag::Odhwi32o},
+      {"ODHWI48o", tag::Odhwi48o},
+      {"ODHWI64o", tag::Odhwi64o},
   };
 
   bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
@@ -342,7 +345,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
     if (layout_dict.find(kernel_layout) == layout_dict.end()) {
       layout_dict.insert({kernel_layout, tag::any});
-      LOG(WARNING) << "Unregistered kernel layout for conv: " << data_layout
+      LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout
                    << ", transfer to tag::any";
     }
 
@@ -382,7 +385,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
     auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
 
-    // Covn2d description.
+    // Conv description.
     auto conv_desc =
         has_bias ? dnnl::convolution_forward::desc(
                        dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
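
On the runtime side, the change registers the extra blocked kernel layouts ODHWI32o/ODHWI48o/ODHWI64o with their oneDNN memory tags and makes the warning report the offending kernel layout rather than the data layout. Once a module has been folded as in the sketches above, it can be handed to DNNL through the usual BYOC flow; a hedged sketch, assuming the partition_for_dnnl helper in tvm.relay.op.contrib.dnnl and a TVM build with DNNL enabled:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import dnnl   # assumed location of the DNNL BYOC helpers

    x = relay.var("x", shape=(1, 4, 8, 8, 8))
    w = relay.const(np.random.randn(8, 4, 3, 3, 3).astype("float32"))
    y = relay.nn.conv3d(x, w, channels=8, kernel_size=(3, 3, 3), padding=(1, 1, 1))
    mod = tvm.IRModule.from_expr(relay.Function([x], y))

    mod = dnnl.partition_for_dnnl(mod)         # assumed helper: annotates and partitions DNNL-supported ops
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")  # the partitions are executed by the DNNL JSON runtime
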
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index b5e7b1c816..12fc722d86 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -1028,6 +1028,182 @@ def test_fold_bwd_bias_add():
     check((2, 4, 10, 10), 4)
 
 
+def test_fold_fwd_conv3d():
+    """Conv3d testcase."""
+
+    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
+        args = [x, conv_weight, in_bias]
+        x = relay.multiply(x, in_scale)
+        x = relay.nn.relu(x)
+        x = relay.add(x, in_bias)
+        y = relay.nn.conv3d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3, 3),
+            padding=(1, 1, 1),
+            data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
+            kernel_layout="OIDHW2i{}o".format(blocking[1]) if blocking else "OIDHW",
+        )
+
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, in_bias, in_scale, in_channels, channels, blocking):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, in_bias]
+        if blocking:
+            squeezed_scale = relay.squeeze(in_scale, axis=[0, 2, 3, 4])
+            x = relay.nn.relu(x)
+            in_bias = relay.divide(
+                in_bias,
+                relay.reshape(
+                    squeezed_scale, (1, in_channels // blocking[0], 1, 1, 1, blocking[0])
+                ),
+            )  # NCDHWc
+            x = relay.add(x, in_bias)
+            conv_weight = relay.multiply(
+                conv_weight, relay.reshape(squeezed_scale, (1, in_channels // 2, 1, 1, 1, 2, 1))
+            )  # OIDHWio
+        else:
+            squeezed_scale = relay.squeeze(in_scale, axis=[1, 2, 3])
+            x = relay.nn.relu(x)
+            in_bias = relay.divide(
+                in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+            )
+            x = relay.add(x, in_bias)
+            conv_weight = relay.multiply(
+                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+            )
+
+        y = relay.nn.conv3d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3, 3),
+            padding=(1, 1, 1),
+            data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
+            kernel_layout="OIDHW2i{}o".format(blocking[1]) if blocking else "OIDHW",
+        )
+        return relay.Function(args, y)
+
+    def check(shape, channels, blocking):
+        x = relay.var("x", shape=shape)
+        weight = relay.var("weight")
+        if blocking:
+            in_channels = shape[1] * shape[-1]
+            in_bias = relay.var(
+                "in_bias", shape=(1, in_channels // blocking[0], 1, 1, 1, blocking[0])
+            )
+            in_scale = relay.const(
+                _get_positive_scale((1, in_channels // blocking[0], 1, 1, 1, blocking[0]))
+            )
+        else:
+            in_channels = shape[1]
+            in_bias = relay.var("in_bias", shape=(in_channels, 1, 1, 1))
+            in_scale = relay.const(_get_positive_scale((in_channels, 1, 1, 1)))
+        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
+        y1 = run_opt_pass(y1, transform.InferType())
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
+        y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking)
+
+        y1_folded = run_opt_pass(y1_folded, transform.InferType())
+        y1_expected = run_opt_pass(y1_expected, transform.InferType())
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10, 10), 2, None)
+    check((2, 2, 10, 10, 10, 2), 8, (2, 4))
+
+
+def test_fold_bwd_conv3d():
+    """Conv3d testcase."""
+
+    def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
+        args = [x, conv_weight, out_bias]
+        if blocking:
+            out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
+        else:
+            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=3)
+        y = relay.nn.conv3d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3, 3),
+            padding=(1, 1, 1),
+            data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
+            kernel_layout="OIDHW1i{}o".format(blocking[1]) if blocking else "OIDHW",
+        )
+        y = relay.add(y, out_bias)
+        y = relay.nn.relu(y)
+        if blocking:
+            out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
+        y = relay.multiply(y, out_scale)
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, out_bias]
+        if blocking:
+            out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
+            out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1]))
+            squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3, 4])
+            conv_weight = relay.multiply(
+                conv_weight,
+                relay.reshape(
+                    squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, 1, blocking[1])
+                ),
+            )
+        else:
+            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=3)
+            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2, 3])
+            conv_weight = relay.multiply(
+                conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=4)
+            )
+
+        y = relay.nn.conv3d(
+            x,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3, 3),
+            padding=(1, 1, 1),
+            data_layout="NCDHW{}c".format(blocking[0]) if blocking else "NCDHW",
+            kernel_layout="OIDHW1i{}o".format(blocking[1]) if blocking else "OIDHW",
+        )
+        if blocking:
+            out_bias = relay.multiply(
+                out_bias,
+                relay.reshape(squeezed_scale, (1, channels // blocking[1], 1, 1, 1, blocking[1])),
+            )
+        else:
+            out_bias = relay.multiply(
+                out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
+            )
+        y = relay.add(y, out_bias)
+        y = relay.nn.relu(y)
+        return relay.Function(args, y)
+
+    def check(shape, in_channels, channels, blocking):
+        x = relay.var("x", shape=shape)
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        if blocking:
+            out_scale = relay.const(_get_positive_scale((channels,)))
+        else:
+            out_scale = relay.const(_get_positive_scale((channels, 1, 1, 1)))
+        y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
+        y1 = run_opt_pass(y1, transform.InferType())
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
+        y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
+        y1_expected = run_opt_pass(y1_expected, transform.InferType())
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10, 10), 4, 8, None)
+    check((2, 2, 10, 10, 10, 16), 32, 64, (16, 16))
+
+
 if __name__ == "__main__":
     test_fold_fwd_simple()
     test_fold_fwd_dual_path()
@@ -1043,3 +1219,5 @@ if __name__ == "__main__":
     test_fold_bwd_negative_scale()
     test_fold_bwd_dense()
     test_fold_bwd_bias_add()
+    test_fold_fwd_conv3d()
+    test_fold_bwd_conv3d()