Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/19 12:21:52 UTC

[tvm] branch main updated: Add FlattenAtrousConv transformation (#10996)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 312b552b11 Add FlattenAtrousConv transformation (#10996)
312b552b11 is described below

commit 312b552b11d4830a9c12193068471b3fe6ab325a
Author: Alexey Voronov <av...@gmail.com>
AuthorDate: Tue Apr 19 15:21:45 2022 +0300

    Add FlattenAtrousConv transformation (#10996)
---
 python/tvm/relay/transform/transform.py            |  27 ++
 src/relay/qnn/utils.h                              |   6 +
 src/relay/transforms/flatten_atrous_conv.cc        | 195 ++++++++++
 .../python/relay/test_pass_flatten_atrous_conv.py  | 427 +++++++++++++++++++++
 4 files changed, 655 insertions(+)

diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index e4ee14b629..566d0ffa2b 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1311,6 +1311,33 @@ def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
     return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat)
 
 
+def FlattenAtrousConv():
+    # pylint: disable=anomalous-backslash-in-string
+    """
+    The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd
+    operations:
+
+    .. code-block:: text
+
+      x     w
+      |     |
+      s2b   |
+       \\   /
+        conv2d
+         |
+         b2s
+
+    and convert it into a single convolution with modified "dilation" and recalculated
+    "padding" parameters.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered FlattenAtrousConv pass.
+    """
+    return _ffi_api.FlattenAtrousConv()
+
+
 def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
     """
     Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
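
A minimal usage sketch (not part of this commit) of the new pass, applied to a graph of the
shape shown in the docstring; the shapes and attributes mirror test_fac_block_shape_2 from the
tests added below. InferType has to run first, since FlattenAtrousConv declares it as a
required pass.

    import numpy as np
    import tvm
    from tvm import relay

    # s2b -> conv2d -> b2s pattern, as in the docstring diagram
    data = relay.var("data", shape=[1, 5, 5, 4], dtype="float32")
    weight = relay.const(np.ones([3, 3, 4, 1], dtype="float32"))
    s2b = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
    conv = relay.nn.conv2d(s2b, weight, padding=[0, 0, 0, 0], groups=4, channels=4,
                           kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWOI")
    expr = relay.nn.batch_to_space_nd(conv, block_shape=[2, 2], crops=[[0, 1], [0, 1]])

    mod = tvm.IRModule.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.FlattenAtrousConv()(mod)
    print(mod)  # a single conv2d with dilation=[2, 2] and padding=[2, 2, 2, 2]
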
diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h
index b4841c8ddd..18c592f2ed 100644
--- a/src/relay/qnn/utils.h
+++ b/src/relay/qnn/utils.h
@@ -270,6 +270,12 @@ static inline std::vector<float> GetFloatVectorFromConstant(const Expr& expr) {
   return vals;
 }
 
+Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point,
+                   Expr input_scale, Expr kernel_scale, Array<IndexExpr> strides,
+                   Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
+                   IndexExpr channels, Array<IndexExpr> kernel_size, String data_layout,
+                   String kernel_layout, String out_layout, DataType out_dtype);
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/transforms/flatten_atrous_conv.cc b/src/relay/transforms/flatten_atrous_conv.cc
new file mode 100644
index 0000000000..54e0f193cf
--- /dev/null
+++ b/src/relay/transforms/flatten_atrous_conv.cc
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/flatten_atrous_conv.cc
+ * \brief This transform flattens an atrous convolution, i.e. the operation sequence
+ * "space_to_batch_nd"->"conv2d"->"batch_to_space_nd", into a single dilated conv2d.
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/qnn/attrs.h>
+#include <tvm/relay/transform.h>
+#include <tvm/topi/broadcast.h>
+
+#include <array>
+#include <set>
+#include <unordered_map>
+
+#include "../qnn/utils.h"
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+/* Description of FlattenAtrousConv
+ *
+ * The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd
+ * operations:
+ *
+ *   x     w
+ *   |     |
+ *   s2b   |
+ *    \   /
+ *     conv2d
+ *      |
+ *      b2s
+ *
+ * and convert it into a single convolution with modified "dilation" and recalculated
+ * "padding" parameters.
+ */
+
+using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
+
+class FlattenAtrousConvSubgraphMutator {
+ public:
+  Expr MutateSubgraph(const Expr& expr) {
+    try {
+      const CallNode* b2s_node_ = expr.as<CallNode>();
+      const CallNode* conv2d_node_ = b2s_node_->args[0].as<CallNode>();
+      const CallNode* s2b_node_ = conv2d_node_->args[0].as<CallNode>();
+
+      ICHECK(b2s_node_ != nullptr);
+      const auto* b2s_attrs = b2s_node_->attrs.as<BatchToSpaceNDAttrs>();
+      ICHECK(b2s_attrs != nullptr);
+
+      Array<PrimExpr> dilation = {b2s_attrs->block_shape[0], b2s_attrs->block_shape[1]};
+
+      ICHECK(conv2d_node_ != nullptr);
+      const auto* conv2d_attrs = conv2d_node_->attrs.as<Conv2DAttrs>();
+      ICHECK(conv2d_attrs != nullptr);
+
+      Array<PrimExpr> kernel_shape = conv2d_attrs->kernel_size;
+      PrimExpr kernel_h = kernel_shape[0];
+      PrimExpr kernel_w = kernel_shape[1];
+
+      ICHECK(s2b_node_ != nullptr);
+      const auto* s2b_attrs = s2b_node_->attrs.as<SpaceToBatchNDAttrs>();
+      ICHECK(s2b_attrs != nullptr);
+
+      Expr data = s2b_node_->args[0];
+      ICHECK(conv2d_attrs->data_layout == "NHWC");
+      Array<PrimExpr> data_shape = transform::InferTypeLocal(data).as<TensorTypeNode>()->shape;
+      PrimExpr in_h = data_shape[1];
+      PrimExpr in_w = data_shape[2];
+
+      PrimExpr dilation_h = dilation[0];
+      PrimExpr dilation_w = dilation[1];
+
+      PrimExpr dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
+      PrimExpr dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;
+
+      Array<PrimExpr> strides = {1, 1};
+      PrimExpr stride_h = strides[0];
+      PrimExpr stride_w = strides[1];
+
+      auto _get_pad_pair = [](PrimExpr input1d, PrimExpr kernel1d,
+                              PrimExpr stride1d) -> Array<PrimExpr> {
+        PrimExpr out1d = truncdiv((input1d + stride1d - 1), stride1d);
+        PrimExpr pad = topi::maximum(((out1d - 1) * stride1d + kernel1d - input1d), 0);
+        PrimExpr pad_before = truncdiv(pad, 2);
+        PrimExpr pad_after = pad - pad_before;
+        return {pad_before, pad_after};
+      };
+
+      Array<PrimExpr> pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h);
+      Array<PrimExpr> pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w);
+
+      Array<IndexExpr> padding = {pad_v[0], pad_h[0], pad_v[1], pad_h[1]};
+
+      Expr weight = conv2d_node_->args[1];
+
+      if (conv2d_node_->op == Op::Get("nn.conv2d")) {
+        return Conv2D(data, weight, strides, padding, dilation, conv2d_attrs->groups,
+                      conv2d_attrs->channels, conv2d_attrs->kernel_size, conv2d_attrs->data_layout,
+                      conv2d_attrs->kernel_layout, conv2d_attrs->out_layout,
+                      conv2d_attrs->out_dtype);
+      }
+
+      if (conv2d_node_->op == Op::Get("qnn.conv2d")) {
+        Expr input_zero_point = conv2d_node_->args[2];
+        Expr kernel_zero_point = conv2d_node_->args[3];
+        Expr input_scale = conv2d_node_->args[4];
+        Expr kernel_scale = conv2d_node_->args[5];
+        return qnn::MakeQnnConv2D(data, weight, input_zero_point, kernel_zero_point, input_scale,
+                                  kernel_scale, strides, padding, dilation, conv2d_attrs->groups,
+                                  conv2d_attrs->channels, conv2d_attrs->kernel_size,
+                                  conv2d_attrs->data_layout, conv2d_attrs->kernel_layout,
+                                  conv2d_attrs->out_layout, conv2d_attrs->out_dtype);
+      }
+
+      DLOG(INFO) << "Ran into an unhandled convolution, skipping " << expr << std::endl;
+      return expr;
+    } catch (std::exception& e) {
+      DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping " << expr << " with "
+                 << e.what() << std::endl;
+      return expr;
+    }
+  }
+};
+
+class FlattenAtrousConvRewriter : public MixedModeMutator {
+ protected:
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    if (const CallNode* call_node = post.as<CallNode>()) {
+      if (ops_[op_iter_].count(call_node->op)) {
+        ++op_iter_;
+        if (op_iter_ == ops_.size()) {
+          op_iter_ = 0;
+          return FlattenAtrousConvSubgraphMutator().MutateSubgraph(post);
+        }
+      } else {
+        op_iter_ = 0;
+      }
+    }
+    return post;
+  }
+
+ private:
+  size_t op_iter_ = 0;
+  const std::array<ExprSet, 3> ops_ = {
+      ExprSet{Op::Get("nn.space_to_batch_nd")},
+      ExprSet{Op::Get("nn.conv2d"), Op::Get("qnn.conv2d")},
+      ExprSet{Op::Get("nn.batch_to_space_nd")},
+  };
+};
+
+Expr FlattenAtrousConv(const Expr& expr, const IRModule& mod) {
+  return FlattenAtrousConvRewriter().Mutate(expr);
+}
+
+namespace transform {
+
+Pass FlattenAtrousConv() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(FlattenAtrousConv(f, m));
+      };
+  return CreateFunctionPass(pass_func, 0, "FlattenAtrousConv", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FlattenAtrousConv").set_body_typed(FlattenAtrousConv);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
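
The "padding" recalculation in _get_pad_pair above follows TensorFlow-style SAME padding for
the dilated kernel. A small Python sketch of the same arithmetic (names are illustrative, not
from the commit), worked through for the block_shape=[2, 2] case used in the tests below:

    def get_pad_pair(in_size, dilated_kernel, stride=1):
        # ceil(in_size / stride) output elements, then pad whatever the window is missing
        out_size = (in_size + stride - 1) // stride
        pad = max((out_size - 1) * stride + dilated_kernel - in_size, 0)
        pad_before = pad // 2
        return pad_before, pad - pad_before

    # 5x5 NHWC input, 3x3 kernel, dilation 2 -> dilated kernel size (3 - 1) * 2 + 1 = 5
    print(get_pad_pair(5, 5))  # (2, 2), i.e. padding = [2, 2, 2, 2] as in test_fac_block_shape_2
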
diff --git a/tests/python/relay/test_pass_flatten_atrous_conv.py b/tests/python/relay/test_pass_flatten_atrous_conv.py
new file mode 100644
index 0000000000..f6b3718e40
--- /dev/null
+++ b/tests/python/relay/test_pass_flatten_atrous_conv.py
@@ -0,0 +1,427 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-wildcard-import
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+
+
+def compare_expected_fac(expr, expected_expr, args):
+    mod_def = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr))
+    mod_flat = tvm.relay.transform.FlattenAtrousConv()(mod_def)
+    mod_exp = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expected_expr))
+
+    assert expr is expected_expr or not tvm.ir.structural_equal(mod_def, mod_flat)
+    assert tvm.ir.structural_equal(mod_flat, mod_exp)
+
+    result_def = (
+        relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm")
+        .evaluate()(*args)
+        .numpy()
+    )
+    result_flat = (
+        relay.create_executor("vm", mod=mod_flat, device=tvm.cpu(), target="llvm")
+        .evaluate()(*args)
+        .numpy()
+    )
+    result_exp = (
+        relay.create_executor("vm", mod=mod_exp, device=tvm.cpu(), target="llvm")
+        .evaluate()(*args)
+        .numpy()
+    )
+
+    assert np.array_equal(result_def, result_flat)
+    assert np.array_equal(result_flat, result_exp)
+
+
+def test_fac_block_shape_2():
+    # pattern entry with block_shape=[2, 2]
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = relay.nn.conv2d(
+        data,
+        weight,
+        padding=[2, 2, 2, 2],
+        dilation=[2, 2],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_block_shape_4():
+    # pattern entry with block_shape=[4, 4]
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[4, 4], paddings=[[4, 7], [4, 7]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[4, 4], crops=[[0, 3], [0, 3]])
+
+    expected_expr = relay.nn.conv2d(
+        data,
+        weight,
+        padding=[4, 4, 4, 4],
+        dilation=[4, 4],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_quantize():
+    # quantize pattern entry
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="int8")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.qnn.op.conv2d(
+        op1,
+        weight,
+        input_zero_point=relay.const(0),
+        kernel_zero_point=relay.const(0),
+        input_scale=relay.const(2.0),
+        kernel_scale=relay.const(1.0),
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = relay.qnn.op.conv2d(
+        data,
+        weight,
+        input_zero_point=relay.const(0),
+        kernel_zero_point=relay.const(0),
+        input_scale=relay.const(2.0),
+        kernel_scale=relay.const(1.0),
+        padding=[2, 2, 2, 2],
+        dilation=[2, 2],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_surrounding():
+    # pattern entry surrounded by add operations
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op0 = relay.op.add(data, relay.const(1.0))
+    op1 = relay.nn.space_to_batch_nd(op0, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    op3 = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+    expr = relay.op.add(op3, relay.const(-1.0))
+
+    op0 = relay.op.add(data, relay.const(1.0))
+    op1 = relay.nn.conv2d(
+        op0,
+        weight,
+        padding=[2, 2, 2, 2],
+        dilation=[2, 2],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expected_expr = relay.op.add(op1, relay.const(-1.0))
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_several():
+    # several pattern entries
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    op3 = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+    op4 = relay.nn.space_to_batch_nd(op3, block_shape=[4, 4], paddings=[[4, 7], [4, 7]])
+    op5 = relay.nn.conv2d(
+        op4,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op5, block_shape=[4, 4], crops=[[0, 3], [0, 3]])
+
+    op1 = relay.nn.conv2d(
+        data,
+        weight,
+        padding=[2, 2, 2, 2],
+        dilation=[2, 2],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    expected_expr = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[4, 4, 4, 4],
+        dilation=[4, 4],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_only_s2b_conv():
+    # negative case, only operations space_to_batch_nd-conv2d
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    expr = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_only_s2b():
+    # negative case, only operation space_to_batch_nd
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    expr = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_only_conv_b2s():
+    # negative case, only operations conv2d-batch_to_space_nd
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.conv2d(
+        data,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op1, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_only_b2s():
+    # negative case, only operation batch_to_space_nd
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    expr = relay.nn.batch_to_space_nd(data, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_op_btwn_s2b_conv():
+    # negative case, add operation between space_to_batch_nd-conv2d
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op_1_5 = relay.op.add(op1, relay.const(1.0))
+    op2 = relay.nn.conv2d(
+        op_1_5,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_op_btwn_conv_b2s():
+    # negative case, add operation between conv2d-batch_to_space_nd
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    op_2_5 = relay.op.add(op2, relay.const(1.0))
+    expr = relay.nn.batch_to_space_nd(op_2_5, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))