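You are viewing a plain-text version of this message. The canonical (HTML) version is in the commits@tvm.apache.org mailing-list archive; its hyperlink was lost in this plain-text rendering.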
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/03/11 14:42:05 UTC
[tvm] branch main updated: [Test] Add Test Case to Cover Bug Fix by PR#7432 (#7601)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 56feab9 [Test] Add Test Case to Cover Bug Fix by PR#7432 (#7601)
56feab9 is described below
commit 56feab9f4d97f310018d6a1df6ed4d5dd75e9178
Author: Qiang Zhang <jo...@163.com>
AuthorDate: Thu Mar 11 22:41:50 2021 +0800
[Test] Add Test Case to Cover Bug Fix by PR#7432 (#7601)
---
tests/python/relay/test_pass_auto_quantize.py | 34 +++++++++++++++++++++++++++
1 file changed, 34 insertions(+)
diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py
index 8a7c4cb..31f5ac6 100644
--- a/tests/python/relay/test_pass_auto_quantize.py
+++ b/tests/python/relay/test_pass_auto_quantize.py
@@ -307,6 +307,39 @@ def test_unquantizable_suffix_partition():
verify_partition_fails(mod, params)
+def test_left_shift_negative():
+    """Regression test for PR #7432: no ``left_shift`` may use a negative shift amount."""
+    data = relay.var("data", shape=(1, 16, 64, 64))
+    weight = relay.const(np.full((16, 16, 3, 3), 256.0))
+    conv2d = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16)
+    relu = relay.nn.relu(conv2d)
+    mod = tvm.IRModule.from_expr(relu)
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(
+            calibrate_mode="global_scale", global_scale=8.0, skip_conv_layers=None
+        ):
+            qnn_mod = relay.quantize.quantize(mod)
+
+    class OpFinder(relay.ExprVisitor):
+        def __init__(self, op_name):
+            super().__init__()
+            self._op_name = op_name
+            self.ops = []
+
+        def visit_call(self, call):
+            super().visit_call(call)
+            # call.op may be a Function/GlobalVar with no ``name``; match primitive Ops only.
+            if isinstance(call.op, tvm.ir.Op) and call.op.name == self._op_name:
+                self.ops.append(call)
+
+    opf = OpFinder("left_shift")
+    opf.visit(qnn_mod["main"])
+    assert len(opf.ops) > 0, 'Broken case, can\'t find any "left_shift" operators.'
+    for left_shift_op in opf.ops:
+        shift_amount = left_shift_op.args[1].data.asnumpy()
+        assert np.all(shift_amount >= 0), "Shift amount must be non-negative."
+
+
if __name__ == "__main__":
    test_mul_rewrite()
    test_batch_flatten_rewrite()
@@ -320,3 +353,4 @@ if __name__ == "__main__":
    test_unquantizable_prefix_partition()
    test_unquantizable_core_partition()
    test_unquantizable_suffix_partition()
+    test_left_shift_negative()