You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/03/11 14:42:05 UTC

[tvm] branch main updated: [Test] Add Test Case to Cover Bug Fix by PR#7432 (#7601)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 56feab9  [Test] Add Test Case to Cover Bug Fix by PR#7432 (#7601)
56feab9 is described below

commit 56feab9f4d97f310018d6a1df6ed4d5dd75e9178
Author: Qiang Zhang <jo...@163.com>
AuthorDate: Thu Mar 11 22:41:50 2021 +0800

    [Test] Add Test Case to Cover Bug Fix by PR#7432 (#7601)
---
 tests/python/relay/test_pass_auto_quantize.py | 34 +++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py
index 8a7c4cb..31f5ac6 100644
--- a/tests/python/relay/test_pass_auto_quantize.py
+++ b/tests/python/relay/test_pass_auto_quantize.py
@@ -307,6 +307,39 @@ def test_unquantizable_suffix_partition():
     verify_partition_fails(mod, params)
 
 
+def test_left_shift_negative():  # regression coverage for the fix in PR#7432
+    data = relay.var("data", shape=(1, 16, 64, 64))
+    weight = relay.const(np.full((16, 16, 3, 3), 256.0))  # large weights; presumably chosen to trigger the shift bug
+    conv2d = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16)
+    relu = relay.nn.relu(conv2d)
+
+    mod = tvm.IRModule.from_expr(relu)
+    # Quantize with a fixed global scale; skip_conv_layers=None presumably ensures the conv2d is quantized.
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(
+            calibrate_mode="global_scale", global_scale=8.0, skip_conv_layers=None
+        ):
+            qnn_mod = relay.quantize.quantize(mod)
+    # Visitor that records every Call whose operator name equals `op_name`.
+    class OpFinder(relay.ExprVisitor):
+        def __init__(self, op_name):
+            super(OpFinder, self).__init__()
+            self._op_name = op_name
+            self.ops = list()
+
+        def visit_call(self, call):
+            super().visit_call(call)
+            if call.op.name == self._op_name:  # NOTE(review): assumes every callee is an Op exposing `.name`
+                self.ops.append(call)
+
+    opf = OpFinder("left_shift")
+    opf.visit(qnn_mod["main"])  # collect every left_shift call in the quantized main function
+    assert len(opf.ops) > 0, 'Broken case, can\'t find any "left_shift" operators.'
+    for left_shift_op in opf.ops:
+        shift_amount = left_shift_op.args[1].data.asnumpy()  # assumes the shift operand is a relay Constant
+        assert shift_amount >= 0, "Shift amount must be non-negative."  # assumes a scalar shift; an array would need .all()
+
+
 if __name__ == "__main__":
     test_mul_rewrite()
     test_batch_flatten_rewrite()
@@ -320,3 +353,4 @@ if __name__ == "__main__":
     test_unquantizable_prefix_partition()
     test_unquantizable_core_partition()
     test_unquantizable_suffix_partition()
+    test_left_shift_negative()