You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/05/13 13:12:03 UTC
[tvm] branch main updated: Rename gpu to cuda, and bump dlpack to v0.5 (#8032)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 43c2ea7 Rename gpu to cuda, and bump dlpack to v0.5 (#8032)
43c2ea7 is described below
commit 43c2ea72bcd9968531441189e172372c1d8a64cb
Author: Yuchen Jin <yu...@cs.washington.edu>
AuthorDate: Thu May 13 06:11:40 2021 -0700
Rename gpu to cuda, and bump dlpack to v0.5 (#8032)
---
3rdparty/dlpack | 2 +-
apps/topi_recipe/broadcast/test_broadcast_map.py | 10 ++---
apps/topi_recipe/reduce/test_reduce_map.py | 4 +-
apps/topi_recipe/rnn/lstm.py | 2 +-
apps/topi_recipe/rnn/matexp.py | 10 ++---
docs/deploy/tensorrt.rst | 2 +-
docs/dev/index.rst | 6 +--
golang/src/device.go | 20 +++++-----
include/tvm/runtime/device_api.h | 8 ++--
python/tvm/__init__.py | 2 +-
python/tvm/_ffi/runtime_ctypes.py | 3 +-
python/tvm/contrib/nvcc.py | 8 ++--
python/tvm/runtime/__init__.py | 2 +-
python/tvm/runtime/ndarray.py | 27 +++++++++++--
python/tvm/testing.py | 10 ++---
python/tvm/topi/cuda/conv2d_alter_op.py | 2 +-
python/tvm/topi/cuda/nms.py | 2 +-
rust/tvm-sys/src/device.rs | 8 ++--
rust/tvm-sys/src/value.rs | 2 +-
src/auto_scheduler/search_policy/utils.h | 4 +-
src/auto_scheduler/search_task.cc | 6 +--
src/contrib/tf_op/tvm_dso_op_kernels.cc | 6 +--
src/relay/backend/build_module.cc | 2 +-
src/relay/backend/vm/compiler.cc | 2 +-
src/runtime/contrib/cudnn/cudnn_utils.cc | 2 +-
src/runtime/contrib/tensorrt/tensorrt_builder.cc | 4 +-
src/runtime/contrib/tensorrt/tensorrt_runtime.cc | 6 +--
src/runtime/cuda/cuda_device_api.cc | 20 +++++-----
src/runtime/module.cc | 2 +-
src/runtime/ndarray.cc | 4 +-
src/target/target_kind.cc | 6 +--
.../schedule_postproc_rewrite_for_tensor_core.cc | 2 +-
src/tir/analysis/verify_memory.cc | 2 +-
tests/cpp/build_module_test.cc | 2 +-
.../test_runtime_packed_func.py | 4 +-
tests/python/contrib/test_cublas.py | 6 +--
tests/python/contrib/test_cudnn.py | 10 ++---
tests/python/contrib/test_tensorrt.py | 20 +++++-----
tests/python/frontend/pytorch/test_forward.py | 2 +-
.../quantization/test_quantization_accuracy.py | 2 +-
tests/python/relay/test_any.py | 4 +-
tests/python/relay/test_auto_scheduler_tuning.py | 2 +-
tests/python/relay/test_cpp_build_module.py | 2 +-
tests/python/relay/test_op_level1.py | 10 ++---
tests/python/relay/test_pass_context_analysis.py | 42 ++++++++++-----------
tests/python/topi/python/test_topi_relu.py | 2 +-
tests/python/topi/python/test_topi_tensor.py | 2 +-
.../unittest/test_runtime_graph_cuda_graph.py | 2 +-
.../test_runtime_module_based_interface.py | 8 ++--
tests/python/unittest/test_target_codegen_blob.py | 4 +-
tests/python/unittest/test_target_codegen_cuda.py | 44 +++++++++++-----------
tests/python/unittest/test_target_codegen_llvm.py | 2 +-
...te_schedule_postproc_rewrite_for_tensor_core.py | 4 +-
.../unittest/test_te_schedule_tensor_core.py | 4 +-
.../test_tir_transform_lower_warp_memory.py | 12 +++---
tutorials/auto_scheduler/tune_conv2d_layer_cuda.py | 2 +-
tutorials/autotvm/tune_conv2d_cuda.py | 2 +-
tutorials/frontend/deploy_sparse.py | 2 +-
tutorials/frontend/from_caffe2.py | 2 +-
tutorials/frontend/from_keras.py | 2 +-
tutorials/frontend/from_mxnet.py | 2 +-
tutorials/frontend/from_tensorflow.py | 2 +-
tutorials/get_started/relay_quick_start.py | 2 +-
tutorials/language/reduction.py | 2 +-
tutorials/language/scan.py | 2 +-
tutorials/optimize/opt_conv_cuda.py | 2 +-
tutorials/optimize/opt_conv_tensorcore.py | 2 +-
tutorials/topi/intro_topi.py | 2 +-
68 files changed, 217 insertions(+), 197 deletions(-)
diff --git a/3rdparty/dlpack b/3rdparty/dlpack
index a07f962..ddeb264 160000
--- a/3rdparty/dlpack
+++ b/3rdparty/dlpack
@@ -1 +1 @@
-Subproject commit a07f962d446b577adf4baef2b347a0f3a2a20617
+Subproject commit ddeb264880a1fa7e7be238ab3901a810324fbe5f
diff --git a/apps/topi_recipe/broadcast/test_broadcast_map.py b/apps/topi_recipe/broadcast/test_broadcast_map.py
index e7b5c3a..43a44af 100644
--- a/apps/topi_recipe/broadcast/test_broadcast_map.py
+++ b/apps/topi_recipe/broadcast/test_broadcast_map.py
@@ -65,8 +65,8 @@ def test_broadcast_to(in_shape, out_shape):
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)
- data_nd = tvm.nd.array(data_npy, tvm.gpu())
- out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), tvm.gpu())
+ data_nd = tvm.nd.array(data_npy, tvm.cuda())
+ out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), tvm.cuda())
for _ in range(2):
fcuda(data_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
@@ -116,9 +116,9 @@ def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
out_npy = np.maximum(lhs_npy, rhs_npy)
elif typ == "minimum":
out_npy = np.minimum(lhs_npy, rhs_npy)
- lhs_nd = tvm.nd.array(lhs_npy, tvm.gpu())
- rhs_nd = tvm.nd.array(rhs_npy, tvm.gpu())
- out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), tvm.gpu())
+ lhs_nd = tvm.nd.array(lhs_npy, tvm.cuda())
+ rhs_nd = tvm.nd.array(rhs_npy, tvm.cuda())
+ out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), tvm.cuda())
for _ in range(2):
fcuda(lhs_nd, rhs_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
diff --git a/apps/topi_recipe/reduce/test_reduce_map.py b/apps/topi_recipe/reduce/test_reduce_map.py
index 0a78e5b..71ceb8f 100644
--- a/apps/topi_recipe/reduce/test_reduce_map.py
+++ b/apps/topi_recipe/reduce/test_reduce_map.py
@@ -78,8 +78,8 @@ def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0):
else:
raise NotImplementedError
- data_tvm = tvm.nd.array(in_npy, device=tvm.gpu())
- out_tvm = tvm.nd.empty(shape=out_npy.shape, device=tvm.gpu())
+ data_tvm = tvm.nd.array(in_npy, device=tvm.cuda())
+ out_tvm = tvm.nd.empty(shape=out_npy.shape, device=tvm.cuda())
for _ in range(2):
fcuda(data_tvm, out_tvm)
diff --git a/apps/topi_recipe/rnn/lstm.py b/apps/topi_recipe/rnn/lstm.py
index e4b7fba..cd45bff 100644
--- a/apps/topi_recipe/rnn/lstm.py
+++ b/apps/topi_recipe/rnn/lstm.py
@@ -171,7 +171,7 @@ def lstm():
def check_device(target):
num_step = n_num_step
flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c], target)
- dev = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
+ dev = tvm.cuda(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
scan_h_np = np.zeros((num_step, batch_size, num_hidden)).astype("float32")
scan_c_np = np.zeros((num_step, batch_size, num_hidden)).astype("float32")
diff --git a/apps/topi_recipe/rnn/matexp.py b/apps/topi_recipe/rnn/matexp.py
index ecf868c..85e0d61 100644
--- a/apps/topi_recipe/rnn/matexp.py
+++ b/apps/topi_recipe/rnn/matexp.py
@@ -140,7 +140,7 @@ def rnn_matexp():
}
):
f = tvm.build(s, [s_scan, Whh], target)
- dev = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
+ dev = tvm.cuda(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
res_np = np.zeros((n_num_step, n_batch_size, n_num_hidden)).astype("float32")
Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
@@ -160,16 +160,16 @@ def rnn_matexp():
print("Time cost=%g" % tgap)
# correctness
if not SKIP_CHECK:
- res_gpu = res_a.asnumpy()
+ res_cuda = res_a.asnumpy()
res_cmp = np.ones_like(res_np).astype("float64")
Whh_np = Whh_np.astype("float64")
for t in range(1, n_num_step):
res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
for i in range(n_num_step):
for j in range(n_num_hidden):
- if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
- print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
- tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)
+ if abs(res_cmp[i, 0, j] - res_cuda[i, 0, j]) > 1e-5:
+ print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_cuda[i, 0, j]))
+ tvm.testing.assert_allclose(res_cuda, res_cmp, rtol=1e-3)
check_device("cuda")
diff --git a/docs/deploy/tensorrt.rst b/docs/deploy/tensorrt.rst
index 08addeb..a39d9c8 100644
--- a/docs/deploy/tensorrt.rst
+++ b/docs/deploy/tensorrt.rst
@@ -124,7 +124,7 @@ have to be built.
.. code:: python
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
loaded_lib = tvm.runtime.load_module('compiled.so')
gen_module = tvm.contrib.graph_executor.GraphModule(loaded_lib['default'](dev))
input_data = np.random.uniform(0, 1, input_shape).astype(dtype)
diff --git a/docs/dev/index.rst b/docs/dev/index.rst
index ed0f1a1..5189ffd 100644
--- a/docs/dev/index.rst
+++ b/docs/dev/index.rst
@@ -144,7 +144,7 @@ The main goal of TVM's runtime is to provide a minimal API for loading and execu
import tvm
# Example runtime execution program in python, with type annotated
mod: tvm.runtime.Module = tvm.runtime.load_module("compiled_artifact.so")
- arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.gpu(0))
+ arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.cuda(0))
fun: tvm.runtime.PackedFunc = mod["addone"]
fun(a)
print(a.asnumpy())
@@ -164,8 +164,8 @@ The above example only deals with a simple `addone` function. The code snippet b
import tvm
# Example runtime execution program in python, with types annotated
factory: tvm.runtime.Module = tvm.runtime.load_module("resnet18.so")
- # Create a stateful graph execution module for resnet18 on gpu(0)
- gmod: tvm.runtime.Module = factory["resnet18"](tvm.gpu(0))
+ # Create a stateful graph execution module for resnet18 on cuda(0)
+ gmod: tvm.runtime.Module = factory["resnet18"](tvm.cuda(0))
data: tvm.runtime.NDArray = get_input_data()
# set input
gmod["set_input"](0, data)
diff --git a/golang/src/device.go b/golang/src/device.go
index 6569e44..b2203a3 100644
--- a/golang/src/device.go
+++ b/golang/src/device.go
@@ -29,10 +29,10 @@ import "C"
// KDLCPU is golang enum correspond to TVM device type kDLCPU.
var KDLCPU = int32(C.kDLCPU)
-// KDLGPU is golang enum correspond to TVM device type kDLGPU.
-var KDLGPU = int32(C.kDLGPU)
-// KDLCPUPinned is golang enum correspond to TVM device type kDLCPUPinned.
-var KDLCPUPinned = int32(C.kDLCPUPinned)
+// kDLCUDA is golang enum correspond to TVM device type kDLCUDA.
+var kDLCUDA = int32(C.kDLCUDA)
+// kDLCUDAHost is golang enum correspond to TVM device type kDLCUDAHost.
+var kDLCUDAHost = int32(C.kDLCUDAHost)
// KDLOpenCL is golang enum correspond to TVM device type kDLOpenCL.
var KDLOpenCL = int32(C.kDLOpenCL)
// KDLMetal is golang enum correspond to TVM device type kDLMetal.
@@ -61,14 +61,14 @@ func CPU(index int32) Device {
return Device{KDLCPU, index}
}
-// GPU returns the Device object for GPU target on given index
-func GPU(index int32) Device {
- return Device{KDLGPU, index}
+// CUDA returns the Device object for CUDA target on given index
+func CUDA(index int32) Device {
+ return Device{kDLCUDA, index}
}
-// CPUPinned returns the Device object for CPUPinned target on given index
-func CPUPinned(index int32) Device {
- return Device{KDLCPUPinned, index}
+// CUDAHost returns the Device object for CUDAHost target on given index
+func CUDAHost(index int32) Device {
+ return Device{kDLCUDAHost, index}
}
// OpenCL returns the Device object for OpenCL target on given index
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 57bf51d..c3527d8 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -231,10 +231,10 @@ inline const char* DeviceName(int type) {
switch (type) {
case kDLCPU:
return "cpu";
- case kDLGPU:
- return "gpu";
- case kDLCPUPinned:
- return "cpu_pinned";
+ case kDLCUDA:
+ return "cuda";
+ case kDLCUDAHost:
+ return "cuda_host";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 4643062..0adad82 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -30,7 +30,7 @@ from ._ffi import register_object, register_func, register_extension, get_global
# top-level alias
# tvm.runtime
from .runtime.object import Object
-from .runtime.ndarray import device, cpu, gpu, opencl, cl, vulkan, metal, mtl
+from .runtime.ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, ext_dev, micro_dev, hexagon
from .runtime import ndarray as nd
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 3e79801..4eda5e8 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -164,7 +164,7 @@ class Device(ctypes.Structure):
_fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
MASK2STR = {
1: "cpu",
- 2: "gpu",
+ 2: "cuda",
4: "opencl",
5: "aocl",
6: "sdaccel",
@@ -182,7 +182,6 @@ class Device(ctypes.Structure):
"stackvm": 1,
"cpu": 1,
"c": 1,
- "gpu": 2,
"cuda": 2,
"nvptx": 2,
"cl": 4,
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 6a7c098..0124d00 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -249,8 +249,8 @@ def get_target_compute_version(target=None):
return major + "." + minor
# 3. GPU
- if tvm.gpu(0).exist:
- return tvm.gpu(0).compute_version
+ if tvm.cuda(0).exist:
+ return tvm.cuda(0).compute_version
warnings.warn(
"No CUDA architecture was specified or GPU detected."
@@ -331,8 +331,8 @@ def have_tensorcore(compute_version=None, target=None):
isn't specified.
"""
if compute_version is None:
- if tvm.gpu(0).exist:
- compute_version = tvm.gpu(0).compute_version
+ if tvm.cuda(0).exist:
+ compute_version = tvm.cuda(0).compute_version
else:
if target is None or "arch" not in target.attrs:
warnings.warn(
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index 54e75ba..265dedb 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -26,7 +26,7 @@ from .profiling import Report
# function exposures
from .object_generic import convert_to_object, convert, const
-from .ndarray import device, cpu, gpu, opencl, cl, vulkan, metal, mtl
+from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev, micro_dev
from .module import load_module, enabled, system_lib
from .container import String
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index befe077..823b1cc 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-import, redefined-outer-name
"""Runtime NDArray API"""
import ctypes
+import warnings
import numpy as np
import tvm._ffi
@@ -254,8 +255,7 @@ def device(dev_type, dev_id=0):
.. code-block:: python
assert tvm.device("cpu", 1) == tvm.cpu(1)
- assert tvm.device("gpu", 0) == tvm.gpu(0)
- assert tvm.device("cuda", 0) == tvm.gpu(0)
+ assert tvm.device("cuda", 0) == tvm.cuda(0)
"""
if isinstance(dev_type, string_types):
if "-device=micro_dev" in dev_type:
@@ -362,9 +362,27 @@ def cpu(dev_id=0):
return Device(1, dev_id)
+def cuda(dev_id=0):
+ """Construct a CUDA GPU device
+
+ Parameters
+ ----------
+ dev_id : int, optional
+ The integer device id
+
+ Returns
+ -------
+ dev : Device
+ The created device
+ """
+ return Device(2, dev_id)
+
+
def gpu(dev_id=0):
- """Construct a GPU device
+ """Construct a CUDA GPU device
+ deprecated:: 0.9.0
+ Use :py:func:`tvm.cuda` instead.
Parameters
----------
dev_id : int, optional
@@ -375,6 +393,9 @@ def gpu(dev_id=0):
dev : Device
The created device
"""
+ warnings.warn(
+ "Please use tvm.cuda() instead of tvm.gpu(). tvm.gpu() is going to be deprecated in 0.9.0",
+ )
return Device(2, dev_id)
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index edcf4a6..e4cf62c 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -464,9 +464,9 @@ def _compose(args, decs):
def uses_gpu(*args):
- """Mark to differentiate tests that use the GPU is some capacity.
+ """Mark to differentiate tests that use the GPU in some capacity.
- These tests will be run on CPU-only test nodes and on test nodes with GPUS.
+ These tests will be run on CPU-only test nodes and on test nodes with GPUs.
To mark a test that must have a GPU present to run, use
:py:func:`tvm.testing.requires_gpu`.
@@ -490,7 +490,7 @@ def requires_gpu(*args):
Function to mark
"""
_requires_gpu = [
- pytest.mark.skipif(not tvm.gpu().exist, reason="No GPU present"),
+ pytest.mark.skipif(not tvm.cuda().exist, reason="No GPU present"),
*uses_gpu(),
]
return _compose(args, _requires_gpu)
@@ -499,7 +499,7 @@ def requires_gpu(*args):
def requires_cuda(*args):
"""Mark a test as requiring the CUDA runtime.
- This also marks the test as requiring a gpu.
+ This also marks the test as requiring a cuda gpu.
Parameters
----------
@@ -618,7 +618,7 @@ def requires_tensorcore(*args):
_requires_tensorcore = [
pytest.mark.tensorcore,
pytest.mark.skipif(
- not tvm.gpu().exist or not nvcc.have_tensorcore(tvm.gpu(0).compute_version),
+ not tvm.cuda().exist or not nvcc.have_tensorcore(tvm.cuda(0).compute_version),
reason="No tensorcore present",
),
*requires_gpu(),
diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py
index 65bf9d1..067f272 100644
--- a/python/tvm/topi/cuda/conv2d_alter_op.py
+++ b/python/tvm/topi/cuda/conv2d_alter_op.py
@@ -225,7 +225,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
if topi_tmpl == "conv2d_HWNCnc_tensorcore.cuda":
assert data_layout == "HWNC" and kernel_layout == "HWOI"
- assert float(tvm.gpu(0).compute_version) >= 7.5
+ assert float(tvm.cuda(0).compute_version) >= 7.5
H, W, N, CI = get_const_tuple(data.shape)
KH, KW, CO, _ = get_const_tuple(kernel.shape)
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index f064360..9a3b86d 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -878,7 +878,7 @@ def non_max_suppression(
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
tvm_data = tvm.nd.array(np_data, dev)
tvm_valid_count = tvm.nd.array(np_valid_count, dev)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), dev)
diff --git a/rust/tvm-sys/src/device.rs b/rust/tvm-sys/src/device.rs
index 910cc59..7b659ef 100644
--- a/rust/tvm-sys/src/device.rs
+++ b/rust/tvm-sys/src/device.rs
@@ -66,7 +66,7 @@ use thiserror::Error;
pub enum DeviceType {
CPU = 1,
GPU,
- CPUPinned,
+ CUDAHost,
OpenCL,
Vulkan,
Metal,
@@ -101,8 +101,8 @@ impl Display for DeviceType {
"{}",
match self {
DeviceType::CPU => "cpu",
- DeviceType::GPU => "gpu",
- DeviceType::CPUPinned => "cpu_pinned",
+ DeviceType::GPU => "cuda",
+ DeviceType::CUDAHost => "cuda_host",
DeviceType::OpenCL => "opencl",
DeviceType::Vulkan => "vulkan",
DeviceType::Metal => "metal",
@@ -210,7 +210,7 @@ macro_rules! impl_tvm_device {
impl_tvm_device!(
DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
- DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+ DLDeviceType_kDLCUDA: [gpu, cuda, nvptx],
DLDeviceType_kDLOpenCL: [cl],
DLDeviceType_kDLMetal: [metal],
DLDeviceType_kDLVPI: [vpi],
diff --git a/rust/tvm-sys/src/value.rs b/rust/tvm-sys/src/value.rs
index f939d51..1b4f773 100644
--- a/rust/tvm-sys/src/value.rs
+++ b/rust/tvm-sys/src/value.rs
@@ -86,7 +86,7 @@ macro_rules! impl_tvm_device {
impl_tvm_device!(
DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
- DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+ DLDeviceType_kDLCUDA: [gpu, cuda, nvptx],
DLDeviceType_kDLOpenCL: [cl],
DLDeviceType_kDLMetal: [metal],
DLDeviceType_kDLVPI: [vpi],
diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h
index eb2cd69..ffd4bf4 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -53,7 +53,7 @@ inline bool IsCPUTask(const SearchTask& task) {
/*! \brief Return whether the search task is targeting a GPU. */
inline bool IsGPUTask(const SearchTask& task) {
- return (task)->target->kind->device_type == kDLGPU ||
+ return (task)->target->kind->device_type == kDLCUDA ||
(task)->target->kind->device_type == kDLOpenCL ||
(task)->target->kind->device_type == kDLVulkan ||
(task)->target->kind->device_type == kDLMetal ||
@@ -63,7 +63,7 @@ inline bool IsGPUTask(const SearchTask& task) {
/*! \brief Return whether the search task is targeting a CUDA GPU. */
inline bool IsCUDATask(const SearchTask& task) {
- return (task)->target->kind->device_type == kDLGPU;
+ return (task)->target->kind->device_type == kDLCUDA;
}
/*! \brief Return whether the search task is targeting a OpenCL GPU. */
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 80fb71d..03d880e 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -57,11 +57,11 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
const auto device_type = target->kind->device_type;
if (device_type == kDLCPU) {
return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
- } else if (device_type == kDLGPU || device_type == kDLROCM) {
+ } else if (device_type == kDLCUDA || device_type == kDLROCM) {
auto dev = Device{static_cast<DLDeviceType>(device_type), 0};
- auto device_name = device_type == kDLGPU ? "device_api.gpu" : "device_api.rocm";
+ auto device_name = device_type == kDLCUDA ? "device_api.cuda" : "device_api.rocm";
auto func = tvm::runtime::Registry::Get(device_name);
- ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
+ ICHECK(func != nullptr) << "Cannot find CUDA device_api in registry";
auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
tvm::runtime::TVMRetValue ret;
diff --git a/src/contrib/tf_op/tvm_dso_op_kernels.cc b/src/contrib/tf_op/tvm_dso_op_kernels.cc
index c816119..fb483ee 100644
--- a/src/contrib/tf_op/tvm_dso_op_kernels.cc
+++ b/src/contrib/tf_op/tvm_dso_op_kernels.cc
@@ -69,7 +69,7 @@ class TensorAsBuf {
if (device_type == kDLCPU) {
memcpy(origin_buf, buf + offset, size);
#ifdef TF_TVMDSOOP_ENABLE_GPU
- } else if (device_type == kDLGPU) {
+ } else if (device_type == kDLCUDA) {
cudaMemcpy(origin_buf, buf + offset, size, cudaMemcpyDeviceToDevice);
#endif
} else {
@@ -85,7 +85,7 @@ class TensorAsBuf {
if (device_type == kDLCPU) {
memcpy(buf + offset, origin_buf, size);
#ifdef TF_TVMDSOOP_ENABLE_GPU
- } else if (device_type == kDLGPU) {
+ } else if (device_type == kDLCUDA) {
cudaMemcpy(buf + offset, origin_buf, size, cudaMemcpyDeviceToDevice);
#endif
} else {
@@ -192,7 +192,7 @@ class TVMDSOOpTrait<CPUDevice> {
template <>
class TVMDSOOpTrait<GPUDevice> {
public:
- static const int device_type = kDLGPU;
+ static const int device_type = kDLCUDA;
static int device_id(OpKernelContext* context) {
auto device_base = context->device();
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 880407f..00b6fed 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -429,7 +429,7 @@ class RelayBuildModule : public runtime::ModuleNode {
Target CreateDefaultTarget(int device_type) {
std::string name = runtime::DeviceName(device_type);
if (name == "cpu") return Target("llvm");
- if (name == "gpu") return Target("cuda");
+ if (name == "cuda") return Target("cuda");
return Target(name);
}
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index afc01aa..832cc0e 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -234,7 +234,7 @@ std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
Target CreateDefaultTarget(int device_type) {
std::string name = runtime::DeviceName(device_type);
if (name == "cpu") return Target("llvm");
- if (name == "gpu") return Target("cuda");
+ if (name == "cuda") return Target("cuda");
return Target(name);
}
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index 006064e..da67c2e 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -96,7 +96,7 @@ const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) {
CuDNNThreadEntry::CuDNNThreadEntry() {
auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
- auto func = runtime::Registry::Get("device_api.gpu");
+ auto func = runtime::Registry::Get("device_api.cuda");
void* ret = (*func)();
cuda_api = static_cast<runtime::DeviceAPI*>(ret);
CUDNN_CALL(cudnnCreate(&handle));
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index e98413e..b8d6f6c 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -248,13 +248,13 @@ void TensorRTBuilder::CleanUp() {
void TensorRTBuilder::AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name,
std::vector<runtime::NDArray>* device_buffers) {
const uint32_t entry_id = entry_id_map_[name];
- if (data_entry_[entry_id]->device.device_type != kDLGPU) {
+ if (data_entry_[entry_id]->device.device_type != kDLCUDA) {
const int binding_index = engine->getBindingIndex(name.c_str());
ICHECK_NE(binding_index, -1);
std::vector<int64_t> shape(data_entry_[entry_id]->shape,
data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim);
device_buffers->at(binding_index) =
- runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLGPU, 0});
+ runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0});
}
}
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 7efa5bf..e963594 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -135,7 +135,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j);
int binding_index = engine->getBindingIndex(name.c_str());
ICHECK_NE(binding_index, -1);
- if (data_entry_[eid]->device.device_type == kDLGPU) {
+ if (data_entry_[eid]->device.device_type == kDLCUDA) {
bindings[binding_index] = data_entry_[eid]->data;
} else {
device_buffers[binding_index].CopyFrom(data_entry_[eid]);
@@ -150,7 +150,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
const std::string& name = engine_and_context.outputs[i];
int binding_index = engine->getBindingIndex(name.c_str());
ICHECK_NE(binding_index, -1);
- if (data_entry_[eid]->device.device_type == kDLGPU) {
+ if (data_entry_[eid]->device.device_type == kDLCUDA) {
bindings[binding_index] = data_entry_[eid]->data;
} else {
bindings[binding_index] = device_buffers[binding_index]->data;
@@ -173,7 +173,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
const std::string& name = engine_and_context.outputs[i];
int binding_index = engine->getBindingIndex(name.c_str());
ICHECK_NE(binding_index, -1);
- if (data_entry_[eid]->device.device_type != kDLGPU) {
+ if (data_entry_[eid]->device.device_type != kDLCUDA) {
device_buffers[binding_index].CopyTo(const_cast<DLTensor*>(data_entry_[eid]));
}
}
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index d6c939b..47f038b 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -111,7 +111,7 @@ class CUDADeviceAPI final : public DeviceAPI {
void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
void* ret;
- if (dev.device_type == kDLCPUPinned) {
+ if (dev.device_type == kDLCUDAHost) {
CUDA_CALL(cudaMallocHost(&ret, nbytes));
} else {
CUDA_CALL(cudaSetDevice(dev.device_id));
@@ -121,7 +121,7 @@ class CUDADeviceAPI final : public DeviceAPI {
}
void FreeDataSpace(Device dev, void* ptr) final {
- if (dev.device_type == kDLCPUPinned) {
+ if (dev.device_type == kDLCUDAHost) {
CUDA_CALL(cudaFreeHost(ptr));
} else {
CUDA_CALL(cudaSetDevice(dev.device_id));
@@ -137,11 +137,11 @@ class CUDADeviceAPI final : public DeviceAPI {
from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset;
- if (dev_from.device_type == kDLCPUPinned) {
+ if (dev_from.device_type == kDLCUDAHost) {
dev_from.device_type = kDLCPU;
}
- if (dev_to.device_type == kDLCPUPinned) {
+ if (dev_to.device_type == kDLCUDAHost) {
dev_to.device_type = kDLCPU;
}
@@ -151,17 +151,17 @@ class CUDADeviceAPI final : public DeviceAPI {
return;
}
- if (dev_from.device_type == kDLGPU && dev_to.device_type == kDLGPU) {
+ if (dev_from.device_type == kDLCUDA && dev_to.device_type == kDLCUDA) {
CUDA_CALL(cudaSetDevice(dev_from.device_id));
if (dev_from.device_id == dev_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
} else {
cudaMemcpyPeerAsync(to, dev_to.device_id, from, dev_from.device_id, size, cu_stream);
}
- } else if (dev_from.device_type == kDLGPU && dev_to.device_type == kDLCPU) {
+ } else if (dev_from.device_type == kDLCUDA && dev_to.device_type == kDLCPU) {
CUDA_CALL(cudaSetDevice(dev_from.device_id));
GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
- } else if (dev_from.device_type == kDLCPU && dev_to.device_type == kDLGPU) {
+ } else if (dev_from.device_type == kDLCPU && dev_to.device_type == kDLCUDA) {
CUDA_CALL(cudaSetDevice(dev_to.device_id));
GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
} else {
@@ -231,16 +231,16 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
-CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {}
+CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); }
-TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("device_api.cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
-TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("device_api.cuda_host").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index d84a821..15b9c0d 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -126,7 +126,7 @@ bool RuntimeEnabled(const std::string& target) {
if (target == "cpu") {
return true;
} else if (target == "cuda" || target == "gpu") {
- f_name = "device_api.gpu";
+ f_name = "device_api.cuda";
} else if (target == "cl" || target == "opencl" || target == "sdaccel") {
f_name = "device_api.opencl";
} else if (target == "mtl" || target == "metal") {
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 4b52a7d..3d3466b 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -231,8 +231,8 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str
ICHECK_EQ(from_size, to_size) << "TVMArrayCopyFromTo: The size must exactly match";
ICHECK(from->device.device_type == to->device.device_type || from->device.device_type == kDLCPU ||
- to->device.device_type == kDLCPU || from->device.device_type == kDLCPUPinned ||
- to->device.device_type == kDLCPUPinned)
+ to->device.device_type == kDLCPU || from->device.device_type == kDLCUDAHost ||
+ to->device.device_type == kDLCUDAHost)
<< "Can not copy across different device types directly";
// Use the device that is *not* a cpu device to get the correct device
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 474b1b0..cc493b9 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -152,7 +152,7 @@ Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
} else {
// Use the compute version of the first CUDA GPU instead
TVMRetValue version;
- if (!DetectDeviceFlag({kDLGPU, 0}, runtime::kComputeVersion, &version)) {
+ if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_20\" instead";
arch = 20;
} else {
@@ -230,7 +230,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU)
.add_attr_option<String>("executor")
.set_default_keys({"cpu"});
-TVM_REGISTER_TARGET_KIND("cuda", kDLGPU)
+TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("arch")
.add_attr_option<Bool>("system-lib")
@@ -241,7 +241,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLGPU)
.add_attr_option<Integer>("max_threads_per_block")
.set_default_keys({"cuda", "gpu"});
-TVM_REGISTER_TARGET_KIND("nvptx", kDLGPU)
+TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
.add_attr_option<Bool>("system-lib")
diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index 377ad5c..951bd6c 100644
--- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
+++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -1089,7 +1089,7 @@ Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
}
// Check if current runtime support GPU CUDA
- Device dev{kDLGPU, 0};
+ Device dev{kDLCUDA, 0};
auto api = tvm::runtime::DeviceAPI::Get(dev, true);
if (api == nullptr) {
return stmt;
diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc
index 905384f..3c29e4e 100644
--- a/src/tir/analysis/verify_memory.cc
+++ b/src/tir/analysis/verify_memory.cc
@@ -149,7 +149,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
static bool IsGPUDevice(int dev_type) {
- return kDLGPU == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type ||
+ return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type ||
kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type;
}
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device.
diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc
index e937393..8cc5c4b 100644
--- a/tests/cpp/build_module_test.cc
+++ b/tests/cpp/build_module_test.cc
@@ -166,7 +166,7 @@ TEST(BuildModule, Heterogeneous) {
// Initialize graph executor.
int cpu_dev_ty = static_cast<int>(kDLCPU);
int cpu_dev_id = 0;
- int gpu_dev_ty = static_cast<int>(kDLGPU);
+ int gpu_dev_ty = static_cast<int>(kDLCUDA);
int gpu_dev_id = 0;
const runtime::PackedFunc* graph_executor =
diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
index 9318b0c..f905ef8 100644
--- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
+++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
@@ -101,10 +101,10 @@ def test_empty_array():
def test_device():
def test_device_func(dev):
- assert tvm.gpu(7) == dev
+ assert tvm.cuda(7) == dev
return tvm.cpu(0)
- x = test_device_func(tvm.gpu(7))
+ x = test_device_func(tvm.cuda(7))
assert x == tvm.cpu(0)
x = tvm.opencl(10)
x = tvm.testing.device_test(x, x.device_type, x.device_id)
diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py
index c4e6f89..d871a38 100644
--- a/tests/python/contrib/test_cublas.py
+++ b/tests/python/contrib/test_cublas.py
@@ -35,7 +35,7 @@ def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
print("skip because extern function is not available")
return
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), dev)
@@ -70,7 +70,7 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True):
print("skip because extern function is not available")
return
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
f = tvm.build(s, [A, B, C], target)
a_old = np.random.uniform(0, 128, size=(n, l))
b_old = np.random.uniform(0, 128, size=(l, m))
@@ -126,7 +126,7 @@ def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
print("skip because extern function is not available")
return
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), dev)
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 690589c..d73f81b 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -41,7 +41,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
print("skip because cudnn is not enabled...")
return
- if data_dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+ if data_dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
@@ -71,7 +71,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
s = te.create_schedule(Y.op)
# validation
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
f = tvm.build(s, [X, W, Y], "cuda --host=llvm", name="conv2d")
x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype)
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
@@ -149,7 +149,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
s = te.create_schedule(Y.op)
# validation
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
f = tvm.build(s, [X, W, Y], target="cuda --host=llvm", name="conv3d")
x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype)
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
@@ -177,7 +177,7 @@ def verify_softmax(shape, axis, dtype="float32"):
B = cudnn.softmax(A, axis)
s = te.create_schedule([B.op])
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = tvm.topi.testing.softmax_python(a_np)
a = tvm.nd.array(a_np, dev)
@@ -192,7 +192,7 @@ def verify_softmax_4d(shape, dtype="float32"):
B = cudnn.softmax(A, axis=1)
s = te.create_schedule([B.op])
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
n, c, h, w = shape
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 9810759..52ee87e 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -35,7 +35,7 @@ from tvm.relay.op.contrib import tensorrt
def skip_codegen_test():
"""Skip test if TensorRT and CUDA codegen are not present"""
- if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist:
print("Skip because CUDA is not enabled.")
return True
if not tvm.get_global_func("relay.ext.tensorrt", True):
@@ -45,7 +45,7 @@ def skip_codegen_test():
def skip_runtime_test():
- if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist:
print("Skip because CUDA is not enabled.")
return True
if not tensorrt.is_tensorrt_runtime_enabled():
@@ -143,10 +143,10 @@ def run_and_verify_model(model):
with tvm.transform.PassContext(
opt_level=3, config={"relay.ext.tensorrt.options": config}
):
- exec = relay.create_executor(mode, mod=mod, device=tvm.gpu(0), target="cuda")
+ exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda")
else:
with tvm.transform.PassContext(opt_level=3):
- exec = relay.create_executor(mode, mod=mod, device=tvm.gpu(0), target="cuda")
+ exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda")
res = exec.evaluate()(i_data, **params) if not skip_runtime_test() else None
return res
@@ -199,12 +199,12 @@ def test_tensorrt_simple():
opt_level=3, config={"relay.ext.tensorrt.options": config}
):
relay_exec = relay.create_executor(
- mode, mod=mod, device=tvm.gpu(0), target="cuda"
+ mode, mod=mod, device=tvm.cuda(0), target="cuda"
)
else:
with tvm.transform.PassContext(opt_level=3):
relay_exec = relay.create_executor(
- mode, mod=mod, device=tvm.gpu(0), target="cuda"
+ mode, mod=mod, device=tvm.cuda(0), target="cuda"
)
if not skip_runtime_test():
result_dict[result_key] = relay_exec.evaluate()(x_data, y_data, z_data)
@@ -247,7 +247,7 @@ def test_tensorrt_not_compatible():
mod, config = tensorrt.partition_for_tensorrt(mod)
for mode in ["graph", "vm"]:
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
- exec = relay.create_executor(mode, mod=mod, device=tvm.gpu(0), target="cuda")
+ exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda")
if not skip_runtime_test():
results = exec.evaluate()(x_data)
@@ -273,7 +273,7 @@ def test_tensorrt_serialize_graph_executor():
return graph, lib, params
def run_graph(graph, lib, params):
- mod_ = graph_executor.create(graph, lib, device=tvm.gpu(0))
+ mod_ = graph_executor.create(graph, lib, device=tvm.cuda(0))
mod_.load_params(params)
mod_.run(data=i_data)
res = mod_.get_output(0)
@@ -330,7 +330,7 @@ def test_tensorrt_serialize_vm():
def run_vm(code, lib):
vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
- vm = VirtualMachine(vm_exec, tvm.gpu(0))
+ vm = VirtualMachine(vm_exec, tvm.cuda(0))
result = vm.invoke("main", data=i_data)
return result
@@ -1415,7 +1415,7 @@ def test_empty_subgraph():
x_data = np.random.uniform(-1, 1, x_shape).astype("float32")
for mode in ["graph", "vm"]:
with tvm.transform.PassContext(opt_level=3):
- exec = relay.create_executor(mode, mod=mod, device=tvm.gpu(0), target="cuda")
+ exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda")
if not skip_runtime_test():
results = exec.evaluate()(x_data)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f9f3bba..067af7f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1240,7 +1240,7 @@ def test_type_as():
check_fp16 = False
try:
# Only check half precision on supported hardwares.
- if have_fp16(tvm.gpu(0).compute_version):
+ if have_fp16(tvm.cuda(0).compute_version):
check_fp16 = True
except Exception as e:
# If GPU is not enabled in TVM, skip the fp16 test.
diff --git a/tests/python/nightly/quantization/test_quantization_accuracy.py b/tests/python/nightly/quantization/test_quantization_accuracy.py
index 57fa49e..5cf4dfe 100644
--- a/tests/python/nightly/quantization/test_quantization_accuracy.py
+++ b/tests/python/nightly/quantization/test_quantization_accuracy.py
@@ -93,7 +93,7 @@ def get_model(model_name, batch_size, qconfig, target=None, original=False, simu
def eval_acc(
- model, dataset, batch_fn, target=tvm.target.cuda(), device=tvm.gpu(), log_interval=100
+ model, dataset, batch_fn, target=tvm.target.cuda(), device=tvm.cuda(), log_interval=100
):
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(model, target)
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index fe5e048..7d1c577 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -508,7 +508,7 @@ def verify_any_conv2d(
targets = None
if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
- targets = [("cuda -libs=cudnn", tvm.gpu(0))]
+ targets = [("cuda -libs=cudnn", tvm.cuda(0))]
check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)
@@ -811,7 +811,7 @@ def verify_any_dense(
targets = None
if use_cublas and tvm.get_global_func("tvm.contrib.cublas.matmul", True):
- targets = [("cuda -libs=cublas", tvm.gpu(0))]
+ targets = [("cuda -libs=cublas", tvm.cuda(0))]
check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True, targets=targets)
diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py
index 13651e7..d3c54fd 100644
--- a/tests/python/relay/test_auto_scheduler_tuning.py
+++ b/tests/python/relay/test_auto_scheduler_tuning.py
@@ -69,7 +69,7 @@ def tune_network(network, target):
# Check the correctness
def get_output(data, lib):
- dev = tvm.gpu()
+ dev = tvm.cuda()
module = graph_executor.GraphModule(lib["default"](dev))
module.set_input("data", data)
module.run()
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index 7d2209a..0d98cc0 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -65,7 +65,7 @@ def test_basic_build():
def test_fp16_build():
dtype = "float16"
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
if dtype == "float16" and not have_fp16(dev.compute_version):
print("skip because gpu does not support fp16")
return
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 91b3713..aef3c3c 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -67,7 +67,7 @@ def test_unary_op():
if (
dtype == "float16"
and target == "cuda"
- and not have_fp16(tvm.gpu(0).compute_version)
+ and not have_fp16(tvm.cuda(0).compute_version)
):
continue
intrp = relay.create_executor("graph", device=dev, target=target)
@@ -129,7 +129,7 @@ def test_binary_op():
if (
dtype == "float16"
and target == "cuda"
- and not have_fp16(tvm.gpu(0).compute_version)
+ and not have_fp16(tvm.cuda(0).compute_version)
):
continue
intrp = relay.create_executor("graph", device=dev, target=target)
@@ -158,7 +158,7 @@ def test_expand_dims():
if (
dtype == "float16"
and target == "cuda"
- and not have_fp16(tvm.gpu(0).compute_version)
+ and not have_fp16(tvm.cuda(0).compute_version)
):
continue
data = np.random.uniform(size=dshape).astype(dtype)
@@ -193,7 +193,7 @@ def test_bias_add():
if (
dtype == "float16"
and target == "cuda"
- and not have_fp16(tvm.gpu(0).compute_version)
+ and not have_fp16(tvm.cuda(0).compute_version)
):
continue
intrp = relay.create_executor("graph", device=dev, target=target)
@@ -314,7 +314,7 @@ def test_concatenate():
if (
dtype == "float16"
and target == "cuda"
- and not have_fp16(tvm.gpu(0).compute_version)
+ and not have_fp16(tvm.cuda(0).compute_version)
):
continue
intrp1 = relay.create_executor("graph", device=dev, target=target)
diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py
index e54682b..fe19c47 100644
--- a/tests/python/relay/test_pass_context_analysis.py
+++ b/tests/python/relay/test_pass_context_analysis.py
@@ -26,19 +26,19 @@ from tvm.relay.analysis import context_analysis
def test_device_copy():
- if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
return
mod = tvm.IRModule()
x = relay.var("x", shape=(2, 3))
- copy = relay.op.device_copy(x, tvm.cpu(), tvm.gpu())
+ copy = relay.op.device_copy(x, tvm.cpu(), tvm.cuda())
out = copy + relay.const(np.random.rand(2, 3))
glb_var = relay.GlobalVar("main")
mod[glb_var] = relay.Function([x], out)
ca = context_analysis(mod, tvm.cpu())
cpu_dev = tvm.cpu().device_type
- gpu_dev = tvm.gpu().device_type
+ gpu_dev = tvm.cuda().device_type
for expr, dev in ca.items():
if isinstance(expr, _expr.Call):
assert dev[0].value == gpu_dev
@@ -49,7 +49,7 @@ def test_device_copy():
def test_shape_func():
- if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
return
mod = tvm.IRModule()
@@ -65,11 +65,11 @@ def test_shape_func():
is_inputs = [False]
shape_func = relay.op.vm.shape_func(fn, ins, outs, is_inputs)
mod["main"] = relay.Function([x, out], shape_func)
- ca = context_analysis(mod, tvm.gpu())
+ ca = context_analysis(mod, tvm.cuda())
main = mod["main"]
cpu_dev = tvm.cpu().device_type
- gpu_dev = tvm.gpu().device_type
+ gpu_dev = tvm.cuda().device_type
assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
# The output of shape func should be on cpu.
assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev
@@ -78,7 +78,7 @@ def test_shape_func():
def test_vm_shape_of():
- if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
return
mod = tvm.IRModule()
@@ -86,17 +86,17 @@ def test_vm_shape_of():
x = relay.var("x", shape=data_shape)
y = relay.op.vm.shape_of(x)
mod["main"] = relay.Function([x], y)
- ca = context_analysis(mod, tvm.gpu())
+ ca = context_analysis(mod, tvm.cuda())
main = mod["main"]
cpu_dev = tvm.cpu().device_type
- gpu_dev = tvm.gpu().device_type
+ gpu_dev = tvm.cuda().device_type
assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
assert main.body in ca and ca[main.body][0].value == cpu_dev
def test_alloc_storage():
- if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
return
mod = tvm.IRModule()
@@ -104,14 +104,14 @@ def test_alloc_storage():
size = relay.Var("size", relay.scalar_type("int64"))
alignment = relay.Var("alignment", relay.scalar_type("int64"))
# allocate a chunk on of memory on gpu.
- sto = relay.op.memory.alloc_storage(size, alignment, tvm.gpu())
+ sto = relay.op.memory.alloc_storage(size, alignment, tvm.cuda())
mod["main"] = relay.Function([size, alignment], sto)
- ca = context_analysis(mod, tvm.gpu())
+ ca = context_analysis(mod, tvm.cuda())
main = mod["main"]
body = main.body
cpu_dev = tvm.cpu().device_type
- gpu_dev = tvm.gpu().device_type
+ gpu_dev = tvm.cuda().device_type
# Inputs are unified with alloc storage inputs which are on cpu
assert main.params[0] in ca and ca[main.params[0]][0].value == cpu_dev
assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev
@@ -126,7 +126,7 @@ def test_alloc_storage():
def test_alloc_tensor():
- if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
return
mod = tvm.IRModule()
@@ -136,12 +136,12 @@ def test_alloc_tensor():
sh = relay.const(np.array([3, 2]), dtype="int64")
at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh)
mod["main"] = relay.Function([sto], at)
- ca = context_analysis(mod, tvm.gpu())
+ ca = context_analysis(mod, tvm.cuda())
main = mod["main"]
body = main.body
cpu_dev = tvm.cpu().device_type
- gpu_dev = tvm.gpu().device_type
+ gpu_dev = tvm.cuda().device_type
# Input of the function falls back to the default device gpu
assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
@@ -155,7 +155,7 @@ def test_alloc_tensor():
def test_vm_reshape_tensor():
- if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
return
x = relay.var("x", shape=(2, 8), dtype="float32")
@@ -163,12 +163,12 @@ def test_vm_reshape_tensor():
y = relay.op.vm.reshape_tensor(x, shape, [2, 4, 2])
mod = tvm.IRModule()
mod["main"] = relay.Function([x], y)
- ca = context_analysis(mod, tvm.gpu())
+ ca = context_analysis(mod, tvm.cuda())
main = mod["main"]
body = main.body
cpu_dev = tvm.cpu().device_type
- gpu_dev = tvm.gpu().device_type
+ gpu_dev = tvm.cuda().device_type
# Input of the function falls back to the default device gpu
assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
@@ -181,7 +181,7 @@ def test_vm_reshape_tensor():
def test_dynamic_input():
- if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+ if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
return
mod = tvm.IRModule()
@@ -195,7 +195,7 @@ def test_dynamic_input():
ca = context_analysis(mod, tvm.cpu())
main = mod["main"]
- gpu_dev = tvm.gpu().device_type
+ gpu_dev = tvm.cuda().device_type
assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
assert main.params[1] in ca and ca[main.params[1]][0].value == gpu_dev
assert main.body in ca and ca[main.body][0].value == gpu_dev
diff --git a/tests/python/topi/python/test_topi_relu.py b/tests/python/topi/python/test_topi_relu.py
index 9acf98d..947a6ca 100644
--- a/tests/python/topi/python/test_topi_relu.py
+++ b/tests/python/topi/python/test_topi_relu.py
@@ -35,7 +35,7 @@ def verify_relu(m, n, dtype="float32"):
b_np = a_np * (a_np > 0)
def check_target(target, dev):
- if dtype == "float16" and target == "cuda" and not have_fp16(tvm.gpu(0).compute_version):
+ if dtype == "float16" and target == "cuda" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because %s does not have fp16 support" % target)
return
print("Running on target: %s" % target)
diff --git a/tests/python/topi/python/test_topi_tensor.py b/tests/python/topi/python/test_topi_tensor.py
index d395c0c..2d4eed3 100644
--- a/tests/python/topi/python/test_topi_tensor.py
+++ b/tests/python/topi/python/test_topi_tensor.py
@@ -95,7 +95,7 @@ def verify_vectorization(n, m, dtype):
if not tvm.testing.device_enabled(targeta):
print("Skip because %s is not enabled" % targeta)
return
- if dtype == "float16" and targeta == "cuda" and not have_fp16(tvm.gpu(0).compute_version):
+ if dtype == "float16" and targeta == "cuda" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
with tvm.target.Target(targeta):
diff --git a/tests/python/unittest/test_runtime_graph_cuda_graph.py b/tests/python/unittest/test_runtime_graph_cuda_graph.py
index ee7750e..fb0c736 100644
--- a/tests/python/unittest/test_runtime_graph_cuda_graph.py
+++ b/tests/python/unittest/test_runtime_graph_cuda_graph.py
@@ -73,7 +73,7 @@ def test_graph_simple():
def check_verify():
mlib = tvm.build(s, [A, B], "cuda", name="myadd")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
try:
mod = cuda_graph_executor.create(graph, mlib, dev)
except ValueError:
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index f85edfc..3100414 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -97,7 +97,7 @@ def test_gpu():
with relay.build_config(opt_level=3):
complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
- dev = tvm.gpu()
+ dev = tvm.cuda()
# raw api
gmod = complied_graph_lib["default"](dev)
@@ -190,7 +190,7 @@ def test_mod_export():
# test the robustness wrt to parent module destruction
def setup_gmod():
loaded_lib = tvm.runtime.load_module(path_lib)
- dev = tvm.gpu()
+ dev = tvm.cuda()
return loaded_lib["default"](dev)
gmod = setup_gmod()
@@ -378,7 +378,7 @@ def test_remove_package_params():
fo.write(runtime.save_param_dict(complied_graph_lib.get_params()))
loaded_lib = tvm.runtime.load_module(path_lib)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
# raw api
gmod = loaded_lib["default"](dev)
@@ -559,7 +559,7 @@ def test_cuda_graph_executor():
complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
- dev = tvm.gpu()
+ dev = tvm.cuda()
try:
gmod = complied_graph_lib["cuda_graph_create"](dev)
except:
diff --git a/tests/python/unittest/test_target_codegen_blob.py b/tests/python/unittest/test_target_codegen_blob.py
index c769819..2a30989 100644
--- a/tests/python/unittest/test_target_codegen_blob.py
+++ b/tests/python/unittest/test_target_codegen_blob.py
@@ -57,7 +57,7 @@ def test_synthetic():
loaded_lib = tvm.runtime.load_module(path_lib)
data = np.random.uniform(-1, 1, size=input_shape).astype("float32")
- dev = tvm.gpu()
+ dev = tvm.cuda()
module = graph_executor.GraphModule(loaded_lib["default"](dev))
module.set_input("data", data)
module.run()
@@ -68,7 +68,7 @@ def test_synthetic():
@tvm.testing.uses_gpu
def test_cuda_lib():
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
for device in ["llvm", "cuda"]:
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled..." % device)
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index e639e6b..846bdcb 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -33,10 +33,10 @@ def test_cuda_vectorize_add():
num_thread = 8
def check_cuda(dtype, n, lanes):
- if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+ if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
- if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
+ if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
print("skip because gpu does not support int8")
return
A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
@@ -46,7 +46,7 @@ def test_cuda_vectorize_add():
s[B].bind(xo, bx)
s[B].bind(xi, tx)
fun = tvm.build(s, [A, B], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes)))
c = tvm.nd.empty((n,), B.dtype, dev)
fun(a, c)
@@ -70,7 +70,7 @@ def test_cuda_vectorize_add():
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_bf16_vectorize_add():
- if not have_bf16(tvm.gpu(0).compute_version):
+ if not have_bf16(tvm.cuda(0).compute_version):
print("skip because gpu does not support bf16")
return
num_thread = 8
@@ -99,7 +99,7 @@ def test_cuda_bf16_vectorize_add():
disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
):
fun = tvm.build(s, [A, B], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
np_a = np.random.uniform(size=(n, lanes)).astype("float32")
np_a = np_bf162np_float(np_float2np_bf16(np_a))
a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_float2np_bf16(np_a))
@@ -120,7 +120,7 @@ def test_cuda_multiply_add():
num_thread = 8
def check_cuda(dtype, n, lanes):
- if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
+ if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
print("skip because gpu does not support int8")
return
A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
@@ -138,7 +138,7 @@ def test_cuda_multiply_add():
np_b = np.random.randint(low=-128, high=127, size=(n, lanes))
np_c = np.random.randint(low=0, high=127, size=(n,))
np_d = [sum(x * y) + z for x, y, z in zip(np_a, np_b, np_c)]
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_a)
b = tvm.nd.empty((n,), B.dtype, dev).copyfrom(np_b)
c = tvm.nd.empty((n,), C.dtype, dev).copyfrom(np_c)
@@ -155,7 +155,7 @@ def test_cuda_vectorize_load():
num_thread = 8
def check_cuda(dtype, n, lanes):
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
B = te.compute((n,), lambda i: A[i], name="B")
s = te.create_schedule(B.op)
@@ -181,7 +181,7 @@ def test_cuda_vectorize_load():
def test_cuda_make_int8():
def check_cuda(n, value, lanes):
dtype = "int8"
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
s = te.create_schedule(A.op)
y, x = s[A].op.axis
@@ -209,7 +209,7 @@ def test_cuda_make_int8():
def test_cuda_make_int4():
def check_cuda(n, value, lanes):
dtype = "int4"
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
s = te.create_schedule(A.op)
y, x = s[A].op.axis
@@ -300,7 +300,7 @@ def test_cuda_shuffle():
b_ = np.array((list(range(4))[::-1]) * 16, dtype="int32")
c_ = np.zeros((64,), dtype="int32")
ref = a_ + np.array((list(range(4))) * 16, dtype="int32")
- nda, ndb, ndc = [tvm.nd.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
+ nda, ndb, ndc = [tvm.nd.array(i, tvm.cuda(0)) for i in [a_, b_, c_]]
module(nda, ndb, ndc)
tvm.testing.assert_allclose(ndc.asnumpy(), ref)
@@ -440,7 +440,7 @@ def test_cuda_const_float_to_half():
s[c].bind(tx, te.thread_axis("threadIdx.x"))
func = tvm.build(s, [a, c], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a_np = np.random.uniform(size=shape).astype(a.dtype)
c_np = np.zeros(shape=shape, dtype=c.dtype)
a = tvm.nd.array(a_np, dev)
@@ -528,7 +528,7 @@ def test_cuda_floordiv_with_vectorization():
s[B].bind(xio, tx)
func = tvm.build(s, [A, B], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a_np = np.random.uniform(size=(n,)).astype(A.dtype)
b_np = np.array([a_np[i // k] for i in range(0, n)])
a_nd = tvm.nd.array(a_np, dev)
@@ -554,7 +554,7 @@ def test_cuda_floormod_with_vectorization():
s[B].bind(xio, tx)
func = tvm.build(s, [A, B], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a_np = np.random.uniform(size=(n,)).astype(A.dtype)
b_np = np.array([a_np[i % k] for i in range(0, n)])
a_nd = tvm.nd.array(a_np, dev)
@@ -567,7 +567,7 @@ def test_cuda_floormod_with_vectorization():
@tvm.testing.requires_cuda
def test_vectorized_casts():
def check(t0, t1, factor):
- if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.gpu(0).compute_version):
+ if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
@@ -585,7 +585,7 @@ def test_vectorized_casts():
func = tvm.build(s, [A, B, C], "cuda")
# correctness
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
low, high = (0, 20) if t0.startswith("u") or t1.startswith("u") else (-10, 10)
a_np = np.random.randint(low, high, size=n).astype(A.dtype)
b_np = np.random.randint(low, high, size=n).astype(B.dtype)
@@ -664,7 +664,7 @@ def test_vectorized_intrin1():
]
def run_test(tvm_intrin, np_func, dtype):
- if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+ if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
# set of intrinsics does not support fp16 yet.
@@ -686,7 +686,7 @@ def test_vectorized_intrin1():
B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name="B")
s = sched(B)
f = tvm.build(s, [A, B], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev)
f(a, b)
@@ -712,7 +712,7 @@ def test_vectorized_intrin2(dtype="float32"):
B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name="B")
s = sched(B)
f = tvm.build(s, [A, B], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev)
f(a, b)
@@ -738,7 +738,7 @@ def test_vectorized_popcount():
B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name="B")
s = sched(B)
f = tvm.build(s, [A, B], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev)
f(a, b)
@@ -753,11 +753,11 @@ def test_vectorized_popcount():
@tvm.testing.requires_cuda
def test_cuda_vectorize_load_permute_pad():
def check_cuda(dtype, n, l, padding, lanes):
- if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+ if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
A = tvm.te.placeholder((n, l), name="A", dtype=dtype)
B = tvm.te.compute(
(n // lanes, l + 2 * padding, lanes),
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 56a8514..96b67ea 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -810,7 +810,7 @@ def test_llvm_gpu_lower_atomic():
s = tvm.te.create_schedule(C.op)
f = tvm.build(s, [A], target="nvptx")
- dev = tvm.gpu()
+ dev = tvm.cuda()
a = tvm.nd.array(np.zeros((size,)).astype(A.dtype), dev)
f(a)
ref = np.zeros((size,)).astype(A.dtype)
diff --git a/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py b/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py
index e7a8469..0f97e49 100644
--- a/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py
+++ b/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py
@@ -100,7 +100,7 @@ def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
func = tvm.build(s, [A, B, C], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
b_np = np.random.uniform(size=(l, m)).astype(B.dtype)
c_np = np.zeros((n, m), dtype=np.float32)
@@ -195,7 +195,7 @@ def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
func = tvm.build(s, [A, B, C], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a_np = np.random.uniform(size=(batch, n, l)).astype(A.dtype)
b_np = np.random.uniform(size=(batch, l, m)).astype(B.dtype)
c_np = np.zeros((batch, n, m), dtype=np.float32)
diff --git a/tests/python/unittest/test_te_schedule_tensor_core.py b/tests/python/unittest/test_te_schedule_tensor_core.py
index 9491425..e0cf583 100644
--- a/tests/python/unittest/test_te_schedule_tensor_core.py
+++ b/tests/python/unittest/test_te_schedule_tensor_core.py
@@ -256,7 +256,7 @@ def test_tensor_core_batch_matmal():
func = tvm.build(s, [A, B, C], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a_np = np.random.uniform(size=(batch_size, nn, ll, 32, 16)).astype(A.dtype)
b_np = np.random.uniform(size=(batch_size, ll, mm, 16, 8)).astype(B.dtype)
a = tvm.nd.array(a_np, dev)
@@ -432,7 +432,7 @@ def test_tensor_core_batch_conv():
func = tvm.build(s, [A, W, Conv], "cuda")
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
a_np = np.random.uniform(size=data_shape).astype(A.dtype)
w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
a = tvm.nd.array(a_np, dev)
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index ac72043..2a84078 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -92,7 +92,7 @@ def test_lower_warp_memory_correct_indices():
@tvm.testing.requires_cuda
def test_lower_warp_memory_cuda_end_to_end():
def check_cuda(dtype):
- if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+ if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
@@ -114,7 +114,7 @@ def test_lower_warp_memory_cuda_end_to_end():
xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
func = tvm.build(s, [A, B], "cuda")
A_np = np.array(list(range(m)), dtype=dtype)
B_np = np.array(
@@ -141,7 +141,7 @@ def test_lower_warp_memory_cuda_end_to_end():
@tvm.testing.requires_cuda
def test_lower_warp_memory_cuda_half_a_warp():
def check_cuda(dtype):
- if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+ if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
@@ -181,7 +181,7 @@ def test_lower_warp_memory_cuda_half_a_warp():
_, x = AA.op.axis
s[AA].bind(x, tx)
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
func = tvm.build(s, [A, B], "cuda")
A_np = np.array([list(range(i, m + i)) for i in range(n)], dtype=dtype)
B_np = np.array([list(range(1 + i, m + i)) + [i] for i in range(n)], dtype=dtype)
@@ -198,7 +198,7 @@ def test_lower_warp_memory_cuda_half_a_warp():
@tvm.testing.requires_cuda
def test_lower_warp_memory_cuda_2_buffers():
def check_cuda(dtype):
- if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+ if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
@@ -228,7 +228,7 @@ def test_lower_warp_memory_cuda_2_buffers():
s[BB].bind(xo, bx)
s[BB].bind(xi, tx)
- dev = tvm.gpu(0)
+ dev = tvm.cuda(0)
func = tvm.build(s, [A, B, C], "cuda")
AB_np = np.array(list(range(m)), dtype=dtype)
C_np = np.array(list(range(1, m)) + [0], dtype=dtype) * 2
diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index 41fdcbb..8664c86 100644
--- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -145,7 +145,7 @@ bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
out_np = np.maximum(conv_np + bias_np, 0.0)
-dev = tvm.gpu()
+dev = tvm.cuda()
data_tvm = tvm.nd.array(data_np, device=dev)
weight_tvm = tvm.nd.array(weight_np, device=dev)
bias_tvm = tvm.nd.array(bias_np, device=dev)
diff --git a/tutorials/autotvm/tune_conv2d_cuda.py b/tutorials/autotvm/tune_conv2d_cuda.py
index d14f9c3..c46180d 100644
--- a/tutorials/autotvm/tune_conv2d_cuda.py
+++ b/tutorials/autotvm/tune_conv2d_cuda.py
@@ -230,7 +230,7 @@ a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
-dev = tvm.gpu()
+dev = tvm.cuda()
a_tvm = tvm.nd.array(a_np, device=dev)
w_tvm = tvm.nd.array(w_np, device=dev)
c_tvm = tvm.nd.empty(c_np.shape, device=dev)
diff --git a/tutorials/frontend/deploy_sparse.py b/tutorials/frontend/deploy_sparse.py
index 92f4511..eb9b4ee 100644
--- a/tutorials/frontend/deploy_sparse.py
+++ b/tutorials/frontend/deploy_sparse.py
@@ -105,7 +105,7 @@ seq_len = 128
# TVM platform identifier. Note that best cpu performance can be achieved by setting -mcpu
# appropriately for your specific machine. CUDA and ROCm are also supported.
target = "llvm"
-# Which device to run on. Should be one of tvm.cpu() or tvm.gpu().
+# Which device to run on. Should be one of tvm.cpu() or tvm.cuda().
dev = tvm.cpu()
# If true, then a sparse variant of the network will be run and
# benchmarked.
diff --git a/tutorials/frontend/from_caffe2.py b/tutorials/frontend/from_caffe2.py
index a3378de..1c00f92 100644
--- a/tutorials/frontend/from_caffe2.py
+++ b/tutorials/frontend/from_caffe2.py
@@ -107,7 +107,7 @@ import tvm
from tvm import te
from tvm.contrib import graph_executor
-# context x86 CPU, use tvm.gpu(0) if you run on GPU
+# context x86 CPU, use tvm.cuda(0) if you run on GPU
dev = tvm.cpu(0)
# create a runtime executor module
m = graph_executor.GraphModule(lib["default"](dev))
diff --git a/tutorials/frontend/from_keras.py b/tutorials/frontend/from_keras.py
index 5f39a24..8625465 100644
--- a/tutorials/frontend/from_keras.py
+++ b/tutorials/frontend/from_keras.py
@@ -96,7 +96,7 @@ shape_dict = {"input_1": data.shape}
mod, params = relay.frontend.from_keras(keras_resnet50, shape_dict)
# compile the model
target = "cuda"
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
with tvm.transform.PassContext(opt_level=3):
executor = relay.build_module.create_executor("graph", mod, dev, target)
diff --git a/tutorials/frontend/from_mxnet.py b/tutorials/frontend/from_mxnet.py
index bfaac2c..da1bf4e 100644
--- a/tutorials/frontend/from_mxnet.py
+++ b/tutorials/frontend/from_mxnet.py
@@ -106,7 +106,7 @@ with tvm.transform.PassContext(opt_level=3):
# Now, we would like to reproduce the same forward computation using TVM.
from tvm.contrib import graph_executor
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# set inputs
diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py
index 9c8d0f6..468caf5 100644
--- a/tutorials/frontend/from_tensorflow.py
+++ b/tutorials/frontend/from_tensorflow.py
@@ -72,7 +72,7 @@ label_map_url = os.path.join(repo_base, label_map)
# Use these commented settings to build for cuda.
# target = tvm.target.Target("cuda", host="llvm")
# layout = "NCHW"
-# dev = tvm.gpu(0)
+# dev = tvm.cuda(0)
target = tvm.target.Target("llvm", host="llvm")
layout = None
dev = tvm.cpu(0)
diff --git a/tutorials/get_started/relay_quick_start.py b/tutorials/get_started/relay_quick_start.py
index ffc9bbe..9bd3065 100644
--- a/tutorials/get_started/relay_quick_start.py
+++ b/tutorials/get_started/relay_quick_start.py
@@ -107,7 +107,7 @@ with tvm.transform.PassContext(opt_level=opt_level):
# Now we can create graph executor and run the module on Nvidia GPU.
# create random input
-dev = tvm.gpu()
+dev = tvm.cuda()
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
# create module
module = graph_executor.GraphModule(lib["default"](dev))
diff --git a/tutorials/language/reduction.py b/tutorials/language/reduction.py
index f782ac6..206848c 100644
--- a/tutorials/language/reduction.py
+++ b/tutorials/language/reduction.py
@@ -137,7 +137,7 @@ print(fcuda.imported_modules[0].get_source())
# Verify the correctness of result kernel by comparing it to numpy.
#
nn = 128
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev)
fcuda(a, b)
diff --git a/tutorials/language/scan.py b/tutorials/language/scan.py
index 8124b56..8876921 100644
--- a/tutorials/language/scan.py
+++ b/tutorials/language/scan.py
@@ -83,7 +83,7 @@ print(tvm.lower(s, [X, s_scan], simple_mode=True))
# numpy to verify the correctness of the result.
#
fscan = tvm.build(s, [X, s_scan], "cuda", name="myscan")
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
n = 1024
m = 10
a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
diff --git a/tutorials/optimize/opt_conv_cuda.py b/tutorials/optimize/opt_conv_cuda.py
index 0cecc82..0ac2c62 100644
--- a/tutorials/optimize/opt_conv_cuda.py
+++ b/tutorials/optimize/opt_conv_cuda.py
@@ -238,7 +238,7 @@ s[WW].vectorize(fi) # vectorize memory load
#
func = tvm.build(s, [A, W, B], "cuda")
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
a_np = np.random.uniform(size=(in_size, in_size, in_channel, batch)).astype(A.dtype)
w_np = np.random.uniform(size=(kernel, kernel, in_channel, out_channel)).astype(W.dtype)
a = tvm.nd.array(a_np, dev)
diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py
index 0a7798d..702e4a7 100644
--- a/tutorials/optimize/opt_conv_tensorcore.py
+++ b/tutorials/optimize/opt_conv_tensorcore.py
@@ -392,7 +392,7 @@ print(tvm.lower(s, [A, W, Conv], simple_mode=True))
# Since TensorCores are only supported in NVIDIA GPU with Compute Capability 7.0 or higher, it may not
# be able to run on our build server
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
if nvcc.have_tensorcore(dev.compute_version):
with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}):
func = tvm.build(s, [A, W, Conv], "cuda")
diff --git a/tutorials/topi/intro_topi.py b/tutorials/topi/intro_topi.py
index 1fefae5..5ddb878 100644
--- a/tutorials/topi/intro_topi.py
+++ b/tutorials/topi/intro_topi.py
@@ -99,7 +99,7 @@ print(sg.stages)
# We can test the correctness by comparing with :code:`numpy` result as follows
#
func = tvm.build(sg, [a, b, g], "cuda")
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
g_np = np.sum(np.add(a_np + b_np, a_np * b_np) / 2.0)