You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/09/16 22:01:17 UTC
[tvm] branch main updated: [Testing] Add decorator tvm.testing.requires_cuda_compute_version (#12778)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new aded9d43ba [Testing] Add decorator tvm.testing.requires_cuda_compute_version (#12778)
aded9d43ba is described below
commit aded9d43ba1e798031900911cca4613487db84fe
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Sep 16 17:01:11 2022 -0500
[Testing] Add decorator tvm.testing.requires_cuda_compute_version (#12778)
* [Testing] Add decorator tvm.testing.requires_cuda_compute_version
Previously, individual unit tests would call
`tvm.contrib.nvcc.get_target_compute_version` and return early. This
was repeated boilerplate in many tests, and incorrectly reported a
test as `PASSED` if the required infrastructure wasn't present.
This commit introduces `tvm.testing.requires_cuda_compute_version`, a
decorator that checks the CUDA compute version and applies
`pytest.mark.skipif`. If required infrastructure isn't present, a
test will be reported as `SKIPPED`.
* requires_cuda_compute_version skips test when no GPU is present
---
python/tvm/testing/utils.py | 44 +++++++
tests/python/unittest/test_tir_ptx_cp_async.py | 7 +-
tests/python/unittest/test_tir_ptx_ldmatrix.py | 8 +-
tests/python/unittest/test_tir_ptx_mma.py | 146 +++------------------
tests/python/unittest/test_tir_ptx_mma_sp.py | 14 +-
.../test_tir_schedule_tensorize_ldmatrix_mma.py | 13 +-
6 files changed, 71 insertions(+), 161 deletions(-)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 37a27a4213..ad1e003d6e 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1058,6 +1058,50 @@ def requires_nvcc_version(major_version, minor_version=0, release_version=0):
return inner
+def requires_cuda_compute_version(major_version, minor_version=0):
+ """Mark a test as requiring at least a compute architecture
+
+ Unit tests marked with this decorator will run only if the CUDA
+ compute architecture of the GPU is at least `(major_version,
+ minor_version)`.
+
+ This also marks the test as requiring CUDA support.
+
+ Parameters
+ ----------
+ major_version: int
+
+ The major version of the (major,minor) version tuple.
+
+ minor_version: int
+
+ The minor version of the (major,minor) version tuple.
+ """
+ min_version = (major_version, minor_version)
+ try:
+ arch = tvm.contrib.nvcc.get_target_compute_version()
+ compute_version = tvm.contrib.nvcc.parse_compute_version(arch)
+ except ValueError:
+ # No GPU present. This test will be skipped from the
+ # requires_cuda() marks as well.
+ compute_version = (0, 0)
+
+ min_version_str = ".".join(str(v) for v in min_version)
+ compute_version_str = ".".join(str(v) for v in compute_version)
+ requires = [
+ pytest.mark.skipif(
+ compute_version < min_version,
+ reason=f"Requires CUDA compute >= {min_version_str}, but have {compute_version_str}",
+ ),
+ *requires_cuda.marks(),
+ ]
+
+ def inner(func):
+ return _compose([func], requires)
+
+ return inner
+
+
def skip_if_32bit(reason):
def decorator(*args):
if "32bit" in platform.architecture()[0]:
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py
index 5e6535f295..dc521f3c47 100644
--- a/tests/python/unittest/test_tir_ptx_cp_async.py
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -47,14 +47,9 @@ def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "floa
B[tx, i] = A_shared[tx, i]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_ptx_cp_async():
f = ptx_cp_async
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
mod = tvm.build(f, target="cuda")
A_np = np.random.rand(32, 128).astype("float16")
diff --git a/tests/python/unittest/test_tir_ptx_ldmatrix.py b/tests/python/unittest/test_tir_ptx_ldmatrix.py
index f718082ff8..f652be4421 100644
--- a/tests/python/unittest/test_tir_ptx_ldmatrix.py
+++ b/tests/python/unittest/test_tir_ptx_ldmatrix.py
@@ -56,15 +56,11 @@ def ptx_ldmatrix(
B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(7, 5)
def test_ptx_ldmatrix():
f = ptx_ldmatrix
_, _, param_num, param_trans = f.params
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major * 10 + minor < 75:
- # Require at least SM75
- return
+
for num in [1, 2, 4]:
for trans in [False, True]:
mod = tvm.build(f.specialize({param_num: num, param_trans: trans}), target="cuda")
diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py
index bee9b7b480..cc9eec3a69 100644
--- a/tests/python/unittest/test_tir_ptx_mma.py
+++ b/tests/python/unittest/test_tir_ptx_mma.py
@@ -66,14 +66,9 @@ def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle):
C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m8n8k4_row_col_fp64pf64fp64():
sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_col_fp64pf64fp64)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-1, 1, [8, 4]).astype("float64")
@@ -147,14 +142,9 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(7)
def test_gemm_mma_m8n8k4_row_row_fp16fp16fp16():
sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_row_fp16fp16fp16)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 7:
- # Require at least SM70
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-1, 1, [16, 4]).astype("float16")
@@ -235,14 +225,9 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(7)
def test_gemm_mma_m8n8k4_row_row_fp16fp16fp32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_row_fp16fp16fp32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 7:
- # Require at least SM70
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-1, 1, [16, 4]).astype("float16")
@@ -311,14 +296,9 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(7, 5)
def test_gemm_mma_m8n8k16_row_col_s8s8s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8s8s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major * 10 + minor < 75:
- # Require at least SM75
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-10, 10, [8, 16]).astype("int8")
@@ -387,14 +367,9 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(7, 5)
def test_gemm_mma_m8n8k16_row_col_s8u8s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8u8s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major * 10 + minor < 75:
- # Require at least SM75
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-10, 10, [8, 16]).astype("int8")
@@ -463,14 +438,9 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle):
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(7, 5)
def test_gemm_mma_m8n8k32_row_col_s4s4s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4s4s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major * 10 + minor < 75:
- # Require at least SM75
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
ctx = tvm.cuda()
@@ -531,14 +501,9 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(7, 5)
def test_gemm_mma_m8n8k32_row_col_s4u4s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4u4s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major * 10 + minor < 75:
- # Require at least SM75
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
ctx = tvm.cuda()
@@ -601,14 +566,9 @@ def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle)
]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k8_row_col_fp16fp16fp32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k8_row_col_fp16fp16fp32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
@@ -682,15 +642,9 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k16_row_col_fp16fp16fp16():
sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_fp16fp16fp16)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16")
@@ -764,15 +718,9 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k16_row_col_fp16fp16fp32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_fp16fp16fp32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16")
@@ -846,15 +794,9 @@ def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k16_row_col_s8s8s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_s8s8s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-10, 10, [16, 16]).astype("int8")
@@ -928,15 +870,9 @@ def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k16_row_col_s8u8s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_s8u8s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-10, 10, [16, 16]).astype("int8")
@@ -1010,15 +946,9 @@ def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k32_row_col_s8s8s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k32_row_col_s8s8s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-10, 10, [16, 32]).astype("int8")
@@ -1092,15 +1022,9 @@ def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k32_row_col_s8u8s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k32_row_col_s8u8s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-10, 10, [16, 32]).astype("int8")
@@ -1174,15 +1098,9 @@ def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k64_row_col_s4s4s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k64_row_col_s4s4s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
ctx = tvm.cuda()
@@ -1248,15 +1166,9 @@ def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k64_row_col_s4u4s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k64_row_col_s4u4s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
ctx = tvm.cuda()
@@ -1323,15 +1235,9 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k256_row_col_b1b1s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k256_row_col_b1b1s32)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Require at least SM80
- return
- cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")
ctx = tvm.cuda()
@@ -1345,20 +1251,4 @@ def test_gemm_mma_m16n8k256_row_col_b1b1s32():
if __name__ == "__main__":
- test_gemm_mma_m8n8k4_row_col_fp64pf64fp64()
- test_gemm_mma_m8n8k4_row_row_fp16fp16fp16()
- test_gemm_mma_m8n8k4_row_row_fp16fp16fp32()
- test_gemm_mma_m8n8k16_row_col_s8s8s32()
- test_gemm_mma_m8n8k16_row_col_s8u8s32()
- test_gemm_mma_m8n8k32_row_col_s4s4s32()
- test_gemm_mma_m8n8k32_row_col_s4u4s32()
- test_gemm_mma_m16n8k8_row_col_fp16fp16fp32()
- test_gemm_mma_m16n8k16_row_col_fp16fp16fp16()
- test_gemm_mma_m16n8k16_row_col_fp16fp16fp32()
- test_gemm_mma_m16n8k16_row_col_s8s8s32()
- test_gemm_mma_m16n8k16_row_col_s8u8s32()
- test_gemm_mma_m16n8k32_row_col_s8s8s32()
- test_gemm_mma_m16n8k32_row_col_s8u8s32()
- test_gemm_mma_m16n8k64_row_col_s4s4s32()
- test_gemm_mma_m16n8k64_row_col_s4u4s32()
- test_gemm_mma_m16n8k256_row_col_b1b1s32()
+ tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py
index 24170b4898..0b5073864a 100644
--- a/tests/python/unittest/test_tir_ptx_mma_sp.py
+++ b/tests/python/unittest/test_tir_ptx_mma_sp.py
@@ -255,7 +255,7 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_mma_sp_m16n8k16_f16():
def get_meta_m16n8k16_half(mask):
assert mask.shape == (16, 4, 2)
@@ -273,11 +273,6 @@ def test_mma_sp_m16n8k16_f16():
for out_dtype in ["float16", "float32"]:
func = mma_sp_m16n8k16_f16f16f16 if out_dtype == "float16" else mma_sp_m16n8k16_f16f16f32
sch = tvm.tir.Schedule(func)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Requires SM80+
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
@@ -297,7 +292,7 @@ def test_mma_sp_m16n8k16_f16():
tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3)
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_mma_sp_m16n8k32_f16():
def get_meta_m16n8k32_half(mask):
assert mask.shape == (16, 8, 2)
@@ -317,11 +312,6 @@ def test_mma_sp_m16n8k32_f16():
for out_dtype in ["float16", "float32"]:
func = mma_sp_m16n8k32_f16f16f16 if out_dtype == "float16" else mma_sp_m16n8k32_f16f16f32
sch = tvm.tir.Schedule(func)
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
- if major < 8:
- # Requires SM80+
- return
cuda_mod = tvm.build(sch.mod, target="cuda")
A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16")
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
index 32c1625653..2eda2b9ec4 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -111,9 +111,6 @@ def run_test(
mma_store_intrin,
)
- if not tvm.testing.is_ampere_or_newer():
- return None
-
f = tvm.build(sch.mod["main"], target="cuda", name="dense")
dev = tvm.device("cuda", 0)
@@ -155,7 +152,7 @@ def run_test(
return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_f16f16f32_m16n16k16():
def index_map(i, j):
return (
@@ -212,7 +209,7 @@ def test_f16f16f32_m16n16k16():
print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_f16f16f16_m16n16k16():
def index_map(i, j):
return (
@@ -269,7 +266,7 @@ def test_f16f16f16_m16n16k16():
print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))
-@tvm.testing.requires_cuda
+@tvm.testing.requires_cuda_compute_version(8)
def test_i8i8i32_m16n16k32():
def index_map_A(i, j):
return (
@@ -341,6 +338,4 @@ def test_i8i8i32_m16n16k32():
if __name__ == "__main__":
- test_f16f16f32_m16n16k16()
- test_f16f16f16_m16n16k16()
- test_i8i8i32_m16n16k32()
+ tvm.testing.main()