You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/24 18:51:37 UTC

[GitHub] [tvm] anwang2009 commented on a change in pull request #8952: [TVM] Add importer for ONNX QLinearMatMul op

anwang2009 commented on a change in pull request #8952:
URL: https://github.com/apache/tvm/pull/8952#discussion_r715835129



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -3462,6 +3462,66 @@ def _impl_v10(cls, inputs, attr, params):
         return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
 
 
+class QLinearMatMul(OnnxOpConverter):
+    """Operator converter for QLinearMatMul from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        def get_scalar(x, dtype="float32"):
+            if isinstance(x, _expr.Var) and x.name_hint in params:
+                return _op.const(params[x.name_hint].numpy(), dtype)
+            rank = len(infer_shape(x))
+            assert rank <= 1, "QLinearMul scale and zero_point input must be scalars"
+            if rank == 1:
+                x = _op.squeeze(x, [0])
+            return _op.cast(x, dtype)
+
+        import pdb
+        pdb.set_trace()

Review comment:
       Remove this leftover `pdb` debugging breakpoint (`import pdb` / `pdb.set_trace()`).

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -3462,6 +3462,66 @@ def _impl_v10(cls, inputs, attr, params):
         return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
 
 
+class QLinearMatMul(OnnxOpConverter):
+    """Operator converter for QLinearMatMul from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        def get_scalar(x, dtype="float32"):
+            if isinstance(x, _expr.Var) and x.name_hint in params:
+                return _op.const(params[x.name_hint].numpy(), dtype)
+            rank = len(infer_shape(x))
+            assert rank <= 1, "QLinearMul scale and zero_point input must be scalars"
+            if rank == 1:
+                x = _op.squeeze(x, [0])
+            return _op.cast(x, dtype)
+
+        import pdb
+        pdb.set_trace()
+
+        a = inputs[0]
+        a_scale = get_scalar(inputs[1])
+        a_zero_point = get_scalar(inputs[2], "int32")
+
+        b = inputs[3]
+        b_scale = get_scalar(inputs[4])
+        b_zero_point = get_scalar(inputs[5], "int32")
+
+        y_scale = fold_constant(get_scalar(inputs[6]))
+        y_zero_point = get_scalar(inputs[7], "int32")
+
+        dtype = infer_type(a).checked_type.dtype
+
+        a_rank = len(infer_shape(a))
+        b_rank = len(infer_shape(b))
+
+        assert (a_rank == 2) and (
+            b_rank == 2
+        ), "QLinearMatMul importer currently requires both 'a' and 'b' tensors to be 2D, but rank(a)={}, rank(b)={}".format(
+            a_rank, b_rank
+        )
+
+        ## Note: The ONNX documentation for this op is fairly clear about acceptable overflow
+        ## behavior during the matmul operation:
+        ##   - The scalar multiplication ops MAY NOT overflow.
+        ##   - The scalar addition ops, which sum the results of the scalar multiplication,
+        ##     MAY overflow, but if they do so, it must behave as one would expect during
+        ##     32-bit integer-addition overflow.
+
+        ## As of this writing, Relay's nn.matmul operator doesn't expose a way for us to
+        ## express these constraints. So to ensure correct behavior, we'll play it safe by
+        ## converting the input tensors to int32 prior before performing matmul.
+
+        a_int32 = _op.cast(a, "int32")
+        b_int32 = _op.cast(b, "int32")
+        matmul_int32 = _op.nn.matmul(a_int32, b_int32)
+
+        a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point)
+        b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point)
+        out =

Review comment:
       Finish the thought here: the `out =` assignment is incomplete. Also, the `a` and `b` dequantize assignments just above are never read.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org