You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ec...@apache.org on 2023/05/26 14:22:53 UTC

[tvm] branch main updated: [METAL] Fix int8 vectorized cast (#14962)

This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6198c7fd8a [METAL] Fix int8 vectorized cast (#14962)
6198c7fd8a is described below

commit 6198c7fd8a75534d98efd0ef800b36fc4e3dc021
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Fri May 26 22:22:46 2023 +0800

    [METAL] Fix int8 vectorized cast (#14962)
    
    The current codegen output `(half4)*(device uint*)A` tries to create an `int32`
    number and then cast it to `half4`, which is not the expected behavior.
    
    As Metal supports the `uchar4` and `char4` types, we can use them directly to
    solve that problem.
---
 src/target/source/codegen_metal.cc                 |  5 ----
 tests/python/unittest/test_target_codegen_metal.py | 30 +++++++++++++++++-----
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index bd2b930166..b7105e4bcd 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -220,11 +220,6 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     if (t.is_uint()) {
       os << 'u';
     }
-    if (t.bits() == 8 && t.lanes() == 4) {
-      // directly 4 8 bit int in integer.
-      os << "int";
-      return;
-    }
     switch (t.bits()) {
       case 8:
         os << "char";
diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py
index 3b1cdb4422..dcbbba8c9c 100644
--- a/tests/python/unittest/test_target_codegen_metal.py
+++ b/tests/python/unittest/test_target_codegen_metal.py
@@ -14,12 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import tvm
-from tvm import te
 import numpy as np
 
-import tvm.testing
+import tvm
 import tvm.script
+import tvm.testing
+from tvm import te
 from tvm.script import tir as T
 
 
@@ -149,7 +149,25 @@ def test_select_vectorize():
     np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5)
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_metal
+def test_vectorized_uint8():
+    @T.prim_func
+    def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")):
+        for i in T.thread_binding(4, thread="threadIdx.x"):
+            for j in T.vectorized(4):
+                with T.block("block"):
+                    vi = T.axis.spatial(16, i * 4 + j)
+                    B[vi] = T.Cast("float32", A[vi])
+
+    dev = tvm.metal()
+    a = np.arange(16).astype("uint8")
+    a_nd = tvm.nd.array(a, dev)
+    b_nd = tvm.nd.empty((16,), "float32", dev)
+    f = tvm.build(func, target="metal")
+    f(a_nd, b_nd)
+    np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5)
+
+
 if __name__ == "__main__":
-    test_ramp()
-    test_metal_inf_nan()
-    test_metal_erf()
+    tvm.testing.main()