You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ec...@apache.org on 2023/05/26 14:22:53 UTC
[tvm] branch main updated: [METAL] Fix int8 vectorized cast (#14962)
This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6198c7fd8a [METAL] Fix int8 vectorized cast (#14962)
6198c7fd8a is described below
commit 6198c7fd8a75534d98efd0ef800b36fc4e3dc021
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Fri May 26 22:22:46 2023 +0800
[METAL] Fix int8 vectorized cast (#14962)
Current codegen emits `(half4)*(device uint*)A`, which loads the four packed
8-bit values as a single 32-bit integer and then casts that one integer to
`half4`, instead of converting each 8-bit element individually.
Since Metal supports the `uchar4` and `char4` vector types, we can use them
directly to solve that problem.
---
src/target/source/codegen_metal.cc | 5 ----
tests/python/unittest/test_target_codegen_metal.py | 30 +++++++++++++++++-----
2 files changed, 24 insertions(+), 11 deletions(-)
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index bd2b930166..b7105e4bcd 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -220,11 +220,6 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
if (t.is_uint()) {
os << 'u';
}
- if (t.bits() == 8 && t.lanes() == 4) {
- // directly 4 8 bit int in integer.
- os << "int";
- return;
- }
switch (t.bits()) {
case 8:
os << "char";
diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py
index 3b1cdb4422..dcbbba8c9c 100644
--- a/tests/python/unittest/test_target_codegen_metal.py
+++ b/tests/python/unittest/test_target_codegen_metal.py
@@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import tvm
-from tvm import te
import numpy as np
-import tvm.testing
+import tvm
import tvm.script
+import tvm.testing
+from tvm import te
from tvm.script import tir as T
@@ -149,7 +149,25 @@ def test_select_vectorize():
np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5)
+@tvm.testing.requires_gpu
+@tvm.testing.requires_metal
+def test_vectorized_uint8():
+ @T.prim_func
+ def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")):
+ for i in T.thread_binding(4, thread="threadIdx.x"):
+ for j in T.vectorized(4):
+ with T.block("block"):
+ vi = T.axis.spatial(16, i * 4 + j)
+ B[vi] = T.Cast("float32", A[vi])
+
+ dev = tvm.metal()
+ a = np.arange(16).astype("uint8")
+ a_nd = tvm.nd.array(a, dev)
+ b_nd = tvm.nd.empty((16,), "float32", dev)
+ f = tvm.build(func, target="metal")
+ f(a_nd, b_nd)
+ np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5)
+
+
if __name__ == "__main__":
- test_ramp()
- test_metal_inf_nan()
- test_metal_erf()
+ tvm.testing.main()