You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/21 07:41:01 UTC

[GitHub] [tvm] abhikran-quic commented on a change in pull request #9186: [ONNX] Add MatMulInteger16 contrib op

abhikran-quic commented on a change in pull request #9186:
URL: https://github.com/apache/tvm/pull/9186#discussion_r753760519



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -212,6 +212,77 @@ def get_scalar(x, params, dtype="float32"):
     return _op.cast(x, dtype)
 
 
+def matmul_out_dtype(inputs, out_dtype):
+    """Common function to handle MatMul and MatMulInteger16"""
+    a_shape = shape_of(inputs[0])
+    a_rank = infer_shape(a_shape)[0]
+    b_shape = shape_of(inputs[1])
+    b_rank = infer_shape(b_shape)[0]
+    if a_rank > 2 or b_rank > 2:
+
+        def flatten_to_nd(x, x_shape, nd=3):
+            ndims = infer_shape(x_shape)[0]
+            if ndims == nd:
+                return x
+            newshape = _op.concatenate(
+                [
+                    _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
+                    _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+                ],
+                0,
+            )
+            out = _op.reshape(x, fold_constant(newshape))
+            return out
+
+        b_type = infer_type(inputs[1])
+        # Convert to dense if the second matrix is 2d and non-dynamic
+        if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
+            a = flatten_to_nd(inputs[0], a_shape, 2)
+            b = _op.transpose(inputs[1])
+            output = _op.nn.dense(a, b, out_dtype=out_dtype)
+        else:
+            # Convert a and b into 3 dimensional tensors.
+            a = flatten_to_nd(inputs[0], a_shape, 3)
+            b = flatten_to_nd(inputs[1], b_shape, 3)
+            # Perform a NN batch matmul.
+            output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
+        # Determine the output batch dimension.
+        if a_rank > b_rank:
+            out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+        elif a_rank < b_rank:
+            out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
+        # If it's unclear how broadcasting should be applied, the output
+        # shape is determined by choosing the maximum value from each input.
+        else:
+            out_batch = _op.concatenate(
+                [
+                    _op.maximum(
+                        _op.strided_slice(a_shape, [i], [i + 1]),
+                        _op.strided_slice(b_shape, [i], [i + 1]),
+                    )
+                    for i in range(a_rank - 2)
+                ],
+                0,
+            )
+        # Reshape output to original dimensions.
+        final_shape = _op.concatenate(
+            [
+                out_batch,
+                _op.strided_slice(
+                    a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
+                ),
+                _op.strided_slice(
+                    b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
+                ),

Review comment:
       Sure. This is resolved.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org