Posted to commits@tvm.apache.org by "masahi (via GitHub)" <gi...@apache.org> on 2023/07/14 08:46:58 UTC

[GitHub] [tvm] masahi opened a new pull request, #15318: [Unity] fp16 A x int B GEMM update - support int8, more bias shape

masahi opened a new pull request, #15318:
URL: https://github.com/apache/tvm/pull/15318

   (no comment)




[GitHub] [tvm] sunggg commented on a diff in pull request #15318: [Unity] fp16 A x int B GEMM update - support int8, more bias shape

Posted by "sunggg (via GitHub)" <gi...@apache.org>.
sunggg commented on code in PR #15318:
URL: https://github.com/apache/tvm/pull/15318#discussion_r1264878919


##########
python/tvm/contrib/cutlass/gen_tensor_op.py:
##########
@@ -792,6 +801,12 @@ def get_batch_on_arg(arg_name, arg_shape):
         headers.append("cutlass/layout/matrix.h")
         attrs = {"input": func_args[0], "gamma": func_args[1], "beta": func_args[2]}
         attrs.update(dict(annotations))
+
+        if isinstance(attrs["M"], tvm.tir.Var):

Review Comment:
   Does this mean we support batch?





[GitHub] [tvm] tvm-bot commented on pull request #15318: [Unity] fp16 A x int B GEMM update - support int8, more bias shape

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #15318:
URL: https://github.com/apache/tvm/pull/15318#issuecomment-1635521215

   Thanks for contributing to TVM! Please refer to the contributing guidelines (https://tvm.apache.org/docs/contribute/) for useful information and tips. Please request code reviews from Reviewers (https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.

    * cc @quic-sanirudh (see https://github.com/apache/tvm/issues/10317 for details)

   Generated by tvm-bot (https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)




[GitHub] [tvm] masahi commented on a diff in pull request #15318: [Unity] fp16 A x int B GEMM update - support int8, more bias shape

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on code in PR #15318:
URL: https://github.com/apache/tvm/pull/15318#discussion_r1265818442


##########
python/tvm/contrib/cutlass/gen_tensor_op.py:
##########
@@ -792,6 +801,12 @@ def get_batch_on_arg(arg_name, arg_shape):
         headers.append("cutlass/layout/matrix.h")
         attrs = {"input": func_args[0], "gamma": func_args[1], "beta": func_args[2]}
         attrs.update(dict(annotations))
+
+        if isinstance(attrs["M"], tvm.tir.Var):

Review Comment:
   This is unrelated to this PR, but it is needed for Dolly. It enables layer norm offload with a dynamic first dimension.
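
   As a minimal sketch of the idea (hypothetical helper name; the actual change lives in gen_tensor_op.py): when the row count M is a symbolic tir.Var rather than a constant, the generated kernel has to receive M at runtime instead of baking it in.

       from tvm import tir

       def m_for_kernel(attrs):
           """Decide how the row count M reaches the generated kernel."""
           M = attrs["M"]
           if isinstance(M, tir.Var):
               # Dynamic first dimension (e.g. a symbolic sequence length for
               # Dolly): pass M to the kernel as a runtime argument instead
               # of a compile-time constant.
               return M.name
           # Static shape: bake the constant into the generated source.
           return str(int(M))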





[GitHub] [tvm] vinx13 merged pull request #15318: [Unity] fp16 A x int B GEMM update - support int8, more bias shape

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 merged PR #15318:
URL: https://github.com/apache/tvm/pull/15318




[GitHub] [tvm] masahi commented on a diff in pull request #15318: [Unity] fp16 A x int B GEMM update - support int8, more bias shape

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on code in PR #15318:
URL: https://github.com/apache/tvm/pull/15318#discussion_r1263481377


##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -1488,6 +1491,153 @@ def split_transform_deploy_mod(mod):
         tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_fp16A_int8B_gemm():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def decode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            B: T.Buffer((T.int64(64),), "float16"),
+            decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i, j in T.grid(T.int64(64), T.int64(64)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(A[v_i, v_j], B[v_j])
+                    T.writes(decode_1[v_i, v_j])
+                    decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) * B[v_j]
+
+        @T.prim_func
+        def encode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "float16"),
+            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            compute: T.Buffer((T.int64(64),), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            max_abs_value = T.alloc_buffer((T.int64(64),), "float16")
+            scale = T.alloc_buffer((T.int64(64),))
+            for i, k in T.grid(T.int64(64), T.int64(64)):
+                with T.block("max_abs_value"):
+                    v_i, v_k = T.axis.remap("SR", [i, k])
+                    T.reads(A[v_i, v_k])
+                    T.writes(max_abs_value[v_i])
+                    with T.init():
+                        max_abs_value[v_i] = T.float16(-65504)
+                    max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k]))
+            for i in range(T.int64(64)):
+                with T.block("scale"):
+                    v_i = T.axis.spatial(T.int64(64), i)
+                    T.reads(max_abs_value[v_i])
+                    T.writes(scale[v_i])
+                    scale[v_i] = T.max(
+                        T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001)
+                    ) * T.float32(0.0078125)
+            for j, i in T.grid(T.int64(64), T.int64(64)):
+                with T.block("w_gathered"):
+                    v_j, v_i = T.axis.remap("SS", [j, i])
+                    T.reads(A[v_i, v_j], scale[v_i])
+                    T.writes(w_gathered[v_j, v_i])
+                    w_gathered[v_j, v_i] = T.Cast(
+                        "int8",
+                        T.min(
+                            T.max(
+                                T.round(T.Cast("float32", A[v_i, v_j]) / scale[v_i]),
+                                T.float32(-128),
+                            ),
+                            T.float32(127),
+                        ),
+                    )
+            for i0 in range(T.int64(64)):
+                with T.block("compute"):
+                    v_i0 = T.axis.spatial(T.int64(64), i0)
+                    T.reads(scale[v_i0])
+                    T.writes(compute[v_i0])
+                    compute[v_i0] = T.Cast("float16", scale[v_i0])
+
+        @R.function
+        def main(
+            x: R.Tensor((64, 64), dtype="float16"),
+            y: R.Tensor((64, 64), dtype="float16"),
+            bias: R.Tensor((64, 64), dtype="float16"),
+        ) -> R.Tensor((64, 64), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((64,), dtype="float16")],
+                )
+                lv1: R.Tensor((64, 64), dtype="int8") = lv[0]
+                lv2: R.Tensor((64, 64), dtype="int8") = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight",
+                    lv1,
+                    R.prim_value(80),
+                    R.prim_value(0),
+                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+                )
+                lv3: R.Tensor((64,), dtype="float16") = lv[1]
+                lv4: R.Tensor((64, 64), dtype="int8") = R.builtin.stop_lift_params(lv2)
+                lv5: R.Tensor((64,), dtype="float16") = R.builtin.stop_lift_params(lv3)
+                lv6 = R.call_tir(
+                    cls.decode, (lv4, lv5), out_sinfo=R.Tensor((64, 64), dtype="float16")
+                )
+                lv1_1: R.Tensor((64, 64), dtype="float16") = R.matmul(x, lv6, out_dtype="float16")
+                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, bias)
+                lv2_2: R.Tensor((64, 128), dtype="float16") = R.nn.gelu(lv2_1)

Review Comment:
   This test demonstrates that we are now supporting a bias shape like this, as well as gelu activation offloaded to FT.
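
   For context, a rough numpy reference for the pattern this test exercises (hypothetical function name and an assumed tanh-based gelu; not code from the PR): the fp16 weight is quantized to int8 with one max-abs scale per output channel, dequantized back to fp16, and the GEMM result gets a bias add followed by gelu.

       import numpy as np

       def ref_fp16A_int8B_gemm(x, w, bias):
           # Per-output-channel max-abs scale, symmetric int8 quantization
           # (the 1/128 factor matches the 0.0078125 in the encode prim_func).
           scale = np.maximum(np.abs(w).max(axis=0).astype("float32"), 1e-4) / 128.0
           w_i8 = np.clip(np.round(w.astype("float32") / scale), -128, 127).astype("int8")
           # Dequantize: roughly what the decode prim_func above computes.
           w_fp16 = (w_i8.astype("float32") * scale).astype("float16")
           out = x.astype("float32") @ w_fp16.astype("float32") + bias.astype("float32")
           # gelu epilogue, tanh approximation.
           g = 0.5 * out * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (out + 0.044715 * out ** 3)))
           return g.astype("float16")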


