You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/10 00:39:16 UTC

[GitHub] [tvm] jwfromm commented on a change in pull request #7074: Fix QNN type inference

jwfromm commented on a change in pull request #7074:
URL: https://github.com/apache/tvm/pull/7074#discussion_r539753995



##########
File path: src/relay/qnn/op/op_common.h
##########
@@ -171,6 +171,11 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
   ICHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes);
 
   // Check the scale and zero point types
+  for (size_t i = 0; i < 8; ++i) {

Review comment:
       Why does this one have so many scales and zero points? The others with 4 type checks make sense, but checking 9 types here is a clear outlier.

##########
File path: tests/python/frontend/pytorch/qnn_test.py
##########
@@ -32,17 +32,58 @@
 from tvm.relay.frontend.pytorch_utils import is_version_greater_than
 from tvm.contrib.download import download_testdata
 
+from tvm.relay.dataflow_pattern import wildcard, is_op
+from tvm.relay.op.contrib.register import register_pattern_table
+from tvm.relay.op.contrib.register import get_pattern_table
+
 
 def torch_version_check():
     from packaging import version
 
     return version.parse(torch.__version__) > version.parse("1.4.0")
 
 
+def make_qnn_add_pattern():
+    lhs = wildcard()
+    rhs = wildcard()
+    lhs_scale = wildcard()
+    lhs_zero_point = wildcard()
+    rhs_scale = wildcard()
+    rhs_zero_point = wildcard()
+    output_scale = wildcard()
+    output_zero_point = wildcard()
+    qadd = is_op("qnn.add")(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+    )
+    return qadd.optional(is_op("clip"))
+
+
+@register_pattern_table("test_table")
+def pattern_table():
+    return [
+        ("qnn_add", make_qnn_add_pattern()),
+    ]
+
+
 def get_tvm_runtime(script_module, input_name, ishape):
 
     input_shapes = [(input_name, ishape)]
     mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    pattern_table = get_pattern_table("test_table")

Review comment:
      Can you add a comment explaining what this block does?

##########
File path: src/relay/qnn/op/requantize.cc
##########
@@ -263,6 +263,14 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
 
+  if (types[0].as<IncompleteTypeNode>()) {
+    return false;
+  }
+  for (size_t i = 3; i < 5; ++i) {

Review comment:
       While we're adding a bunch of type checks, can you add a comment indicating what each input represents, something like `// Expected types: data, scale, zero_point, ...`




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org