Posted to commits@tvm.apache.org by tq...@apache.org on 2021/05/13 13:12:03 UTC

[tvm] branch main updated: Rename gpu to cuda, and bump dlpack to v0.5 (#8032)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 43c2ea7  Rename gpu to cuda, and bump dlpack to v0.5 (#8032)
43c2ea7 is described below

commit 43c2ea72bcd9968531441189e172372c1d8a64cb
Author: Yuchen Jin <yu...@cs.washington.edu>
AuthorDate: Thu May 13 06:11:40 2021 -0700

    Rename gpu to cuda, and bump dlpack to v0.5 (#8032)
---
 3rdparty/dlpack                                    |  2 +-
 apps/topi_recipe/broadcast/test_broadcast_map.py   | 10 ++---
 apps/topi_recipe/reduce/test_reduce_map.py         |  4 +-
 apps/topi_recipe/rnn/lstm.py                       |  2 +-
 apps/topi_recipe/rnn/matexp.py                     | 10 ++---
 docs/deploy/tensorrt.rst                           |  2 +-
 docs/dev/index.rst                                 |  6 +--
 golang/src/device.go                               | 20 +++++-----
 include/tvm/runtime/device_api.h                   |  8 ++--
 python/tvm/__init__.py                             |  2 +-
 python/tvm/_ffi/runtime_ctypes.py                  |  3 +-
 python/tvm/contrib/nvcc.py                         |  8 ++--
 python/tvm/runtime/__init__.py                     |  2 +-
 python/tvm/runtime/ndarray.py                      | 27 +++++++++++--
 python/tvm/testing.py                              | 10 ++---
 python/tvm/topi/cuda/conv2d_alter_op.py            |  2 +-
 python/tvm/topi/cuda/nms.py                        |  2 +-
 rust/tvm-sys/src/device.rs                         |  8 ++--
 rust/tvm-sys/src/value.rs                          |  2 +-
 src/auto_scheduler/search_policy/utils.h           |  4 +-
 src/auto_scheduler/search_task.cc                  |  6 +--
 src/contrib/tf_op/tvm_dso_op_kernels.cc            |  6 +--
 src/relay/backend/build_module.cc                  |  2 +-
 src/relay/backend/vm/compiler.cc                   |  2 +-
 src/runtime/contrib/cudnn/cudnn_utils.cc           |  2 +-
 src/runtime/contrib/tensorrt/tensorrt_builder.cc   |  4 +-
 src/runtime/contrib/tensorrt/tensorrt_runtime.cc   |  6 +--
 src/runtime/cuda/cuda_device_api.cc                | 20 +++++-----
 src/runtime/module.cc                              |  2 +-
 src/runtime/ndarray.cc                             |  4 +-
 src/target/target_kind.cc                          |  6 +--
 .../schedule_postproc_rewrite_for_tensor_core.cc   |  2 +-
 src/tir/analysis/verify_memory.cc                  |  2 +-
 tests/cpp/build_module_test.cc                     |  2 +-
 .../test_runtime_packed_func.py                    |  4 +-
 tests/python/contrib/test_cublas.py                |  6 +--
 tests/python/contrib/test_cudnn.py                 | 10 ++---
 tests/python/contrib/test_tensorrt.py              | 20 +++++-----
 tests/python/frontend/pytorch/test_forward.py      |  2 +-
 .../quantization/test_quantization_accuracy.py     |  2 +-
 tests/python/relay/test_any.py                     |  4 +-
 tests/python/relay/test_auto_scheduler_tuning.py   |  2 +-
 tests/python/relay/test_cpp_build_module.py        |  2 +-
 tests/python/relay/test_op_level1.py               | 10 ++---
 tests/python/relay/test_pass_context_analysis.py   | 42 ++++++++++-----------
 tests/python/topi/python/test_topi_relu.py         |  2 +-
 tests/python/topi/python/test_topi_tensor.py       |  2 +-
 .../unittest/test_runtime_graph_cuda_graph.py      |  2 +-
 .../test_runtime_module_based_interface.py         |  8 ++--
 tests/python/unittest/test_target_codegen_blob.py  |  4 +-
 tests/python/unittest/test_target_codegen_cuda.py  | 44 +++++++++++-----------
 tests/python/unittest/test_target_codegen_llvm.py  |  2 +-
 ...te_schedule_postproc_rewrite_for_tensor_core.py |  4 +-
 .../unittest/test_te_schedule_tensor_core.py       |  4 +-
 .../test_tir_transform_lower_warp_memory.py        | 12 +++---
 tutorials/auto_scheduler/tune_conv2d_layer_cuda.py |  2 +-
 tutorials/autotvm/tune_conv2d_cuda.py              |  2 +-
 tutorials/frontend/deploy_sparse.py                |  2 +-
 tutorials/frontend/from_caffe2.py                  |  2 +-
 tutorials/frontend/from_keras.py                   |  2 +-
 tutorials/frontend/from_mxnet.py                   |  2 +-
 tutorials/frontend/from_tensorflow.py              |  2 +-
 tutorials/get_started/relay_quick_start.py         |  2 +-
 tutorials/language/reduction.py                    |  2 +-
 tutorials/language/scan.py                         |  2 +-
 tutorials/optimize/opt_conv_cuda.py                |  2 +-
 tutorials/optimize/opt_conv_tensorcore.py          |  2 +-
 tutorials/topi/intro_topi.py                       |  2 +-
 68 files changed, 217 insertions(+), 197 deletions(-)

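For reviewers skimming the diff, a minimal usage sketch of the renamed Python API follows (a hypothetical session, assuming a TVM build at this commit with the CUDA runtime available); it only exercises calls that appear in the changes below:

    import tvm

    # tvm.cuda() is the new spelling for constructing a CUDA device
    dev = tvm.cuda(0)
    assert tvm.device("cuda", 0) == dev

    # tvm.gpu() is kept as an alias for now, but emits a warning pointing
    # users to tvm.cuda() (deprecation planned for 0.9.0)
    legacy = tvm.gpu(0)
    assert legacy == dev
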
diff --git a/3rdparty/dlpack b/3rdparty/dlpack
index a07f962..ddeb264 160000
--- a/3rdparty/dlpack
+++ b/3rdparty/dlpack
@@ -1 +1 @@
-Subproject commit a07f962d446b577adf4baef2b347a0f3a2a20617
+Subproject commit ddeb264880a1fa7e7be238ab3901a810324fbe5f
diff --git a/apps/topi_recipe/broadcast/test_broadcast_map.py b/apps/topi_recipe/broadcast/test_broadcast_map.py
index e7b5c3a..43a44af 100644
--- a/apps/topi_recipe/broadcast/test_broadcast_map.py
+++ b/apps/topi_recipe/broadcast/test_broadcast_map.py
@@ -65,8 +65,8 @@ def test_broadcast_to(in_shape, out_shape):
     data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
     out_npy = np.broadcast_to(data_npy, out_shape)
 
-    data_nd = tvm.nd.array(data_npy, tvm.gpu())
-    out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), tvm.gpu())
+    data_nd = tvm.nd.array(data_npy, tvm.cuda())
+    out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), tvm.cuda())
     for _ in range(2):
         fcuda(data_nd, out_nd)
     tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
@@ -116,9 +116,9 @@ def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
         out_npy = np.maximum(lhs_npy, rhs_npy)
     elif typ == "minimum":
         out_npy = np.minimum(lhs_npy, rhs_npy)
-    lhs_nd = tvm.nd.array(lhs_npy, tvm.gpu())
-    rhs_nd = tvm.nd.array(rhs_npy, tvm.gpu())
-    out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), tvm.gpu())
+    lhs_nd = tvm.nd.array(lhs_npy, tvm.cuda())
+    rhs_nd = tvm.nd.array(rhs_npy, tvm.cuda())
+    out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), tvm.cuda())
     for _ in range(2):
         fcuda(lhs_nd, rhs_nd, out_nd)
     tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
diff --git a/apps/topi_recipe/reduce/test_reduce_map.py b/apps/topi_recipe/reduce/test_reduce_map.py
index 0a78e5b..71ceb8f 100644
--- a/apps/topi_recipe/reduce/test_reduce_map.py
+++ b/apps/topi_recipe/reduce/test_reduce_map.py
@@ -78,8 +78,8 @@ def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0):
     else:
         raise NotImplementedError
 
-    data_tvm = tvm.nd.array(in_npy, device=tvm.gpu())
-    out_tvm = tvm.nd.empty(shape=out_npy.shape, device=tvm.gpu())
+    data_tvm = tvm.nd.array(in_npy, device=tvm.cuda())
+    out_tvm = tvm.nd.empty(shape=out_npy.shape, device=tvm.cuda())
 
     for _ in range(2):
         fcuda(data_tvm, out_tvm)
diff --git a/apps/topi_recipe/rnn/lstm.py b/apps/topi_recipe/rnn/lstm.py
index e4b7fba..cd45bff 100644
--- a/apps/topi_recipe/rnn/lstm.py
+++ b/apps/topi_recipe/rnn/lstm.py
@@ -171,7 +171,7 @@ def lstm():
     def check_device(target):
         num_step = n_num_step
         flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c], target)
-        dev = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
+        dev = tvm.cuda(0) if target == "cuda" else tvm.cl(0)
         # launch the kernel.
         scan_h_np = np.zeros((num_step, batch_size, num_hidden)).astype("float32")
         scan_c_np = np.zeros((num_step, batch_size, num_hidden)).astype("float32")
diff --git a/apps/topi_recipe/rnn/matexp.py b/apps/topi_recipe/rnn/matexp.py
index ecf868c..85e0d61 100644
--- a/apps/topi_recipe/rnn/matexp.py
+++ b/apps/topi_recipe/rnn/matexp.py
@@ -140,7 +140,7 @@ def rnn_matexp():
             }
         ):
             f = tvm.build(s, [s_scan, Whh], target)
-        dev = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
+        dev = tvm.cuda(0) if target == "cuda" else tvm.cl(0)
         # launch the kernel.
         res_np = np.zeros((n_num_step, n_batch_size, n_num_hidden)).astype("float32")
         Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
@@ -160,16 +160,16 @@ def rnn_matexp():
         print("Time cost=%g" % tgap)
         # correctness
         if not SKIP_CHECK:
-            res_gpu = res_a.asnumpy()
+            res_cuda = res_a.asnumpy()
             res_cmp = np.ones_like(res_np).astype("float64")
             Whh_np = Whh_np.astype("float64")
             for t in range(1, n_num_step):
                 res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
             for i in range(n_num_step):
                 for j in range(n_num_hidden):
-                    if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
-                        print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
-            tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)
+                    if abs(res_cmp[i, 0, j] - res_cuda[i, 0, j]) > 1e-5:
+                        print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_cuda[i, 0, j]))
+            tvm.testing.assert_allclose(res_cuda, res_cmp, rtol=1e-3)
 
     check_device("cuda")
 
diff --git a/docs/deploy/tensorrt.rst b/docs/deploy/tensorrt.rst
index 08addeb..a39d9c8 100644
--- a/docs/deploy/tensorrt.rst
+++ b/docs/deploy/tensorrt.rst
@@ -124,7 +124,7 @@ have to be built.
 
 .. code:: python
 
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     loaded_lib = tvm.runtime.load_module('compiled.so')
     gen_module = tvm.contrib.graph_executor.GraphModule(loaded_lib['default'](dev))
     input_data = np.random.uniform(0, 1, input_shape).astype(dtype)
diff --git a/docs/dev/index.rst b/docs/dev/index.rst
index ed0f1a1..5189ffd 100644
--- a/docs/dev/index.rst
+++ b/docs/dev/index.rst
@@ -144,7 +144,7 @@ The main goal of TVM's runtime is to provide a minimal API for loading and execu
     import tvm
     # Example runtime execution program in python, with type annotated
     mod: tvm.runtime.Module = tvm.runtime.load_module("compiled_artifact.so")
-    arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.gpu(0))
+    arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.cuda(0))
     fun: tvm.runtime.PackedFunc = mod["addone"]
     fun(a)
     print(a.asnumpy())
@@ -164,8 +164,8 @@ The above example only deals with a simple `addone` function. The code snippet b
    import tvm
    # Example runtime execution program in python, with types annotated
    factory: tvm.runtime.Module = tvm.runtime.load_module("resnet18.so")
-   # Create a stateful graph execution module for resnet18 on gpu(0)
-   gmod: tvm.runtime.Module = factory["resnet18"](tvm.gpu(0))
+   # Create a stateful graph execution module for resnet18 on cuda(0)
+   gmod: tvm.runtime.Module = factory["resnet18"](tvm.cuda(0))
    data: tvm.runtime.NDArray = get_input_data()
    # set input
    gmod["set_input"](0, data)
diff --git a/golang/src/device.go b/golang/src/device.go
index 6569e44..b2203a3 100644
--- a/golang/src/device.go
+++ b/golang/src/device.go
@@ -29,10 +29,10 @@ import "C"
 
 // KDLCPU is golang enum correspond to TVM device type kDLCPU.
 var KDLCPU                  = int32(C.kDLCPU)
-// KDLGPU is golang enum correspond to TVM device type kDLGPU.
-var KDLGPU                  = int32(C.kDLGPU)
-// KDLCPUPinned is golang enum correspond to TVM device type kDLCPUPinned.
-var KDLCPUPinned            = int32(C.kDLCPUPinned)
+// kDLCUDA is golang enum correspond to TVM device type kDLCUDA.
+var kDLCUDA                  = int32(C.kDLCUDA)
+// kDLCUDAHost is golang enum correspond to TVM device type kDLCUDAHost.
+var kDLCUDAHost            = int32(C.kDLCUDAHost)
 // KDLOpenCL is golang enum correspond to TVM device type kDLOpenCL.
 var KDLOpenCL               = int32(C.kDLOpenCL)
 // KDLMetal is golang enum correspond to TVM device type kDLMetal.
@@ -61,14 +61,14 @@ func CPU(index int32) Device {
     return Device{KDLCPU, index}
 }
 
-// GPU returns the Device object for GPU target on given index
-func GPU(index int32) Device {
-    return Device{KDLGPU, index}
+// CUDA returns the Device object for CUDA target on given index
+func CUDA(index int32) Device {
+    return Device{kDLCUDA, index}
 }
 
-// CPUPinned returns the Device object for CPUPinned target on given index
-func CPUPinned(index int32) Device {
-    return Device{KDLCPUPinned, index}
+// CUDAHost returns the Device object for CUDAHost target on given index
+func CUDAHost(index int32) Device {
+    return Device{kDLCUDAHost, index}
 }
 
 // OpenCL returns the Device object for OpenCL target on given index
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 57bf51d..c3527d8 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -231,10 +231,10 @@ inline const char* DeviceName(int type) {
   switch (type) {
     case kDLCPU:
       return "cpu";
-    case kDLGPU:
-      return "gpu";
-    case kDLCPUPinned:
-      return "cpu_pinned";
+    case kDLCUDA:
+      return "cuda";
+    case kDLCUDAHost:
+      return "cuda_host";
     case kDLOpenCL:
       return "opencl";
     case kDLSDAccel:
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 4643062..0adad82 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -30,7 +30,7 @@ from ._ffi import register_object, register_func, register_extension, get_global
 # top-level alias
 # tvm.runtime
 from .runtime.object import Object
-from .runtime.ndarray import device, cpu, gpu, opencl, cl, vulkan, metal, mtl
+from .runtime.ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
 from .runtime.ndarray import vpi, rocm, ext_dev, micro_dev, hexagon
 from .runtime import ndarray as nd
 
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 3e79801..4eda5e8 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -164,7 +164,7 @@ class Device(ctypes.Structure):
     _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
     MASK2STR = {
         1: "cpu",
-        2: "gpu",
+        2: "cuda",
         4: "opencl",
         5: "aocl",
         6: "sdaccel",
@@ -182,7 +182,6 @@ class Device(ctypes.Structure):
         "stackvm": 1,
         "cpu": 1,
         "c": 1,
-        "gpu": 2,
         "cuda": 2,
         "nvptx": 2,
         "cl": 4,
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 6a7c098..0124d00 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -249,8 +249,8 @@ def get_target_compute_version(target=None):
         return major + "." + minor
 
     # 3. GPU
-    if tvm.gpu(0).exist:
-        return tvm.gpu(0).compute_version
+    if tvm.cuda(0).exist:
+        return tvm.cuda(0).compute_version
 
     warnings.warn(
         "No CUDA architecture was specified or GPU detected."
@@ -331,8 +331,8 @@ def have_tensorcore(compute_version=None, target=None):
         isn't specified.
     """
     if compute_version is None:
-        if tvm.gpu(0).exist:
-            compute_version = tvm.gpu(0).compute_version
+        if tvm.cuda(0).exist:
+            compute_version = tvm.cuda(0).compute_version
         else:
             if target is None or "arch" not in target.attrs:
                 warnings.warn(
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index 54e75ba..265dedb 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -26,7 +26,7 @@ from .profiling import Report
 
 # function exposures
 from .object_generic import convert_to_object, convert, const
-from .ndarray import device, cpu, gpu, opencl, cl, vulkan, metal, mtl
+from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
 from .ndarray import vpi, rocm, ext_dev, micro_dev
 from .module import load_module, enabled, system_lib
 from .container import String
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index befe077..823b1cc 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name, unused-import, redefined-outer-name
 """Runtime NDArray API"""
 import ctypes
+import warnings
 import numpy as np
 import tvm._ffi
 
@@ -254,8 +255,7 @@ def device(dev_type, dev_id=0):
     .. code-block:: python
 
       assert tvm.device("cpu", 1) == tvm.cpu(1)
-      assert tvm.device("gpu", 0) == tvm.gpu(0)
-      assert tvm.device("cuda", 0) == tvm.gpu(0)
+      assert tvm.device("cuda", 0) == tvm.cuda(0)
     """
     if isinstance(dev_type, string_types):
         if "-device=micro_dev" in dev_type:
@@ -362,9 +362,27 @@ def cpu(dev_id=0):
     return Device(1, dev_id)
 
 
+def cuda(dev_id=0):
+    """Construct a CUDA GPU device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    dev : Device
+        The created device
+    """
+    return Device(2, dev_id)
+
+
 def gpu(dev_id=0):
-    """Construct a GPU device
+    """Construct a CUDA GPU device
 
+        deprecated:: 0.9.0
+        Use :py:func:`tvm.cuda` instead.
     Parameters
     ----------
     dev_id : int, optional
@@ -375,6 +393,9 @@ def gpu(dev_id=0):
     dev : Device
         The created device
     """
+    warnings.warn(
+        "Please use tvm.cuda() instead of tvm.gpu(). tvm.gpu() is going to be deprecated in 0.9.0",
+    )
     return Device(2, dev_id)
 
 
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index edcf4a6..e4cf62c 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -464,9 +464,9 @@ def _compose(args, decs):
 
 
 def uses_gpu(*args):
-    """Mark to differentiate tests that use the GPU is some capacity.
+    """Mark to differentiate tests that use the GPU in some capacity.
 
-    These tests will be run on CPU-only test nodes and on test nodes with GPUS.
+    These tests will be run on CPU-only test nodes and on test nodes with GPUs.
     To mark a test that must have a GPU present to run, use
     :py:func:`tvm.testing.requires_gpu`.
 
@@ -490,7 +490,7 @@ def requires_gpu(*args):
         Function to mark
     """
     _requires_gpu = [
-        pytest.mark.skipif(not tvm.gpu().exist, reason="No GPU present"),
+        pytest.mark.skipif(not tvm.cuda().exist, reason="No GPU present"),
         *uses_gpu(),
     ]
     return _compose(args, _requires_gpu)
@@ -499,7 +499,7 @@ def requires_gpu(*args):
 def requires_cuda(*args):
     """Mark a test as requiring the CUDA runtime.
 
-    This also marks the test as requiring a gpu.
+    This also marks the test as requiring a cuda gpu.
 
     Parameters
     ----------
@@ -618,7 +618,7 @@ def requires_tensorcore(*args):
     _requires_tensorcore = [
         pytest.mark.tensorcore,
         pytest.mark.skipif(
-            not tvm.gpu().exist or not nvcc.have_tensorcore(tvm.gpu(0).compute_version),
+            not tvm.cuda().exist or not nvcc.have_tensorcore(tvm.cuda(0).compute_version),
             reason="No tensorcore present",
         ),
         *requires_gpu(),
diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py
index 65bf9d1..067f272 100644
--- a/python/tvm/topi/cuda/conv2d_alter_op.py
+++ b/python/tvm/topi/cuda/conv2d_alter_op.py
@@ -225,7 +225,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
 
     if topi_tmpl == "conv2d_HWNCnc_tensorcore.cuda":
         assert data_layout == "HWNC" and kernel_layout == "HWOI"
-        assert float(tvm.gpu(0).compute_version) >= 7.5
+        assert float(tvm.cuda(0).compute_version) >= 7.5
         H, W, N, CI = get_const_tuple(data.shape)
         KH, KW, CO, _ = get_const_tuple(kernel.shape)
 
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index f064360..9a3b86d 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -878,7 +878,7 @@ def non_max_suppression(
         np_valid_count = np.array([4])
         s = topi.generic.schedule_nms(out)
         f = tvm.build(s, [data, valid_count, out], "cuda")
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         tvm_data = tvm.nd.array(np_data, dev)
         tvm_valid_count = tvm.nd.array(np_valid_count, dev)
         tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), dev)
diff --git a/rust/tvm-sys/src/device.rs b/rust/tvm-sys/src/device.rs
index 910cc59..7b659ef 100644
--- a/rust/tvm-sys/src/device.rs
+++ b/rust/tvm-sys/src/device.rs
@@ -66,7 +66,7 @@ use thiserror::Error;
 pub enum DeviceType {
     CPU = 1,
     GPU,
-    CPUPinned,
+    CUDAHost,
     OpenCL,
     Vulkan,
     Metal,
@@ -101,8 +101,8 @@ impl Display for DeviceType {
             "{}",
             match self {
                 DeviceType::CPU => "cpu",
-                DeviceType::GPU => "gpu",
-                DeviceType::CPUPinned => "cpu_pinned",
+                DeviceType::GPU => "cuda",
+                DeviceType::CUDAHost => "cuda_host",
                 DeviceType::OpenCL => "opencl",
                 DeviceType::Vulkan => "vulkan",
                 DeviceType::Metal => "metal",
@@ -210,7 +210,7 @@ macro_rules! impl_tvm_device {
 
 impl_tvm_device!(
     DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
-    DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+    DLDeviceType_kDLCUDA: [gpu, cuda, nvptx],
     DLDeviceType_kDLOpenCL: [cl],
     DLDeviceType_kDLMetal: [metal],
     DLDeviceType_kDLVPI: [vpi],
diff --git a/rust/tvm-sys/src/value.rs b/rust/tvm-sys/src/value.rs
index f939d51..1b4f773 100644
--- a/rust/tvm-sys/src/value.rs
+++ b/rust/tvm-sys/src/value.rs
@@ -86,7 +86,7 @@ macro_rules! impl_tvm_device {
 
 impl_tvm_device!(
     DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
-    DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+    DLDeviceType_kDLCUDA: [gpu, cuda, nvptx],
     DLDeviceType_kDLOpenCL: [cl],
     DLDeviceType_kDLMetal: [metal],
     DLDeviceType_kDLVPI: [vpi],
diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h
index eb2cd69..ffd4bf4 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -53,7 +53,7 @@ inline bool IsCPUTask(const SearchTask& task) {
 
 /*! \brief Return whether the search task is targeting a GPU. */
 inline bool IsGPUTask(const SearchTask& task) {
-  return (task)->target->kind->device_type == kDLGPU ||
+  return (task)->target->kind->device_type == kDLCUDA ||
          (task)->target->kind->device_type == kDLOpenCL ||
          (task)->target->kind->device_type == kDLVulkan ||
          (task)->target->kind->device_type == kDLMetal ||
@@ -63,7 +63,7 @@ inline bool IsGPUTask(const SearchTask& task) {
 
 /*! \brief Return whether the search task is targeting a CUDA GPU. */
 inline bool IsCUDATask(const SearchTask& task) {
-  return (task)->target->kind->device_type == kDLGPU;
+  return (task)->target->kind->device_type == kDLCUDA;
 }
 
 /*! \brief Return whether the search task is targeting a OpenCL GPU. */
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 80fb71d..03d880e 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -57,11 +57,11 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
   const auto device_type = target->kind->device_type;
   if (device_type == kDLCPU) {
     return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
-  } else if (device_type == kDLGPU || device_type == kDLROCM) {
+  } else if (device_type == kDLCUDA || device_type == kDLROCM) {
     auto dev = Device{static_cast<DLDeviceType>(device_type), 0};
-    auto device_name = device_type == kDLGPU ? "device_api.gpu" : "device_api.rocm";
+    auto device_name = device_type == kDLCUDA ? "device_api.cuda" : "device_api.rocm";
     auto func = tvm::runtime::Registry::Get(device_name);
-    ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
+    ICHECK(func != nullptr) << "Cannot find CUDA device_api in registry";
     auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
 
     tvm::runtime::TVMRetValue ret;
diff --git a/src/contrib/tf_op/tvm_dso_op_kernels.cc b/src/contrib/tf_op/tvm_dso_op_kernels.cc
index c816119..fb483ee 100644
--- a/src/contrib/tf_op/tvm_dso_op_kernels.cc
+++ b/src/contrib/tf_op/tvm_dso_op_kernels.cc
@@ -69,7 +69,7 @@ class TensorAsBuf {
     if (device_type == kDLCPU) {
       memcpy(origin_buf, buf + offset, size);
 #ifdef TF_TVMDSOOP_ENABLE_GPU
-    } else if (device_type == kDLGPU) {
+    } else if (device_type == kDLCUDA) {
       cudaMemcpy(origin_buf, buf + offset, size, cudaMemcpyDeviceToDevice);
 #endif
     } else {
@@ -85,7 +85,7 @@ class TensorAsBuf {
     if (device_type == kDLCPU) {
       memcpy(buf + offset, origin_buf, size);
 #ifdef TF_TVMDSOOP_ENABLE_GPU
-    } else if (device_type == kDLGPU) {
+    } else if (device_type == kDLCUDA) {
       cudaMemcpy(buf + offset, origin_buf, size, cudaMemcpyDeviceToDevice);
 #endif
     } else {
@@ -192,7 +192,7 @@ class TVMDSOOpTrait<CPUDevice> {
 template <>
 class TVMDSOOpTrait<GPUDevice> {
  public:
-  static const int device_type = kDLGPU;
+  static const int device_type = kDLCUDA;
 
   static int device_id(OpKernelContext* context) {
     auto device_base = context->device();
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 880407f..00b6fed 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -429,7 +429,7 @@ class RelayBuildModule : public runtime::ModuleNode {
   Target CreateDefaultTarget(int device_type) {
     std::string name = runtime::DeviceName(device_type);
     if (name == "cpu") return Target("llvm");
-    if (name == "gpu") return Target("cuda");
+    if (name == "cuda") return Target("cuda");
     return Target(name);
   }
 
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index afc01aa..832cc0e 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -234,7 +234,7 @@ std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
 Target CreateDefaultTarget(int device_type) {
   std::string name = runtime::DeviceName(device_type);
   if (name == "cpu") return Target("llvm");
-  if (name == "gpu") return Target("cuda");
+  if (name == "cuda") return Target("cuda");
   return Target(name);
 }
 
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index 006064e..da67c2e 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -96,7 +96,7 @@ const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) {
 
 CuDNNThreadEntry::CuDNNThreadEntry() {
   auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
-  auto func = runtime::Registry::Get("device_api.gpu");
+  auto func = runtime::Registry::Get("device_api.cuda");
   void* ret = (*func)();
   cuda_api = static_cast<runtime::DeviceAPI*>(ret);
   CUDNN_CALL(cudnnCreate(&handle));
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index e98413e..b8d6f6c 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -248,13 +248,13 @@ void TensorRTBuilder::CleanUp() {
 void TensorRTBuilder::AllocateDeviceBuffer(nvinfer1::ICudaEngine* engine, const std::string& name,
                                            std::vector<runtime::NDArray>* device_buffers) {
   const uint32_t entry_id = entry_id_map_[name];
-  if (data_entry_[entry_id]->device.device_type != kDLGPU) {
+  if (data_entry_[entry_id]->device.device_type != kDLCUDA) {
     const int binding_index = engine->getBindingIndex(name.c_str());
     ICHECK_NE(binding_index, -1);
     std::vector<int64_t> shape(data_entry_[entry_id]->shape,
                                data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim);
     device_buffers->at(binding_index) =
-        runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLGPU, 0});
+        runtime::NDArray::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0});
   }
 }
 
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 7efa5bf..e963594 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -135,7 +135,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
           const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j);
           int binding_index = engine->getBindingIndex(name.c_str());
           ICHECK_NE(binding_index, -1);
-          if (data_entry_[eid]->device.device_type == kDLGPU) {
+          if (data_entry_[eid]->device.device_type == kDLCUDA) {
             bindings[binding_index] = data_entry_[eid]->data;
           } else {
             device_buffers[binding_index].CopyFrom(data_entry_[eid]);
@@ -150,7 +150,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
       const std::string& name = engine_and_context.outputs[i];
       int binding_index = engine->getBindingIndex(name.c_str());
       ICHECK_NE(binding_index, -1);
-      if (data_entry_[eid]->device.device_type == kDLGPU) {
+      if (data_entry_[eid]->device.device_type == kDLCUDA) {
         bindings[binding_index] = data_entry_[eid]->data;
       } else {
         bindings[binding_index] = device_buffers[binding_index]->data;
@@ -173,7 +173,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
       const std::string& name = engine_and_context.outputs[i];
       int binding_index = engine->getBindingIndex(name.c_str());
       ICHECK_NE(binding_index, -1);
-      if (data_entry_[eid]->device.device_type != kDLGPU) {
+      if (data_entry_[eid]->device.device_type != kDLCUDA) {
         device_buffers[binding_index].CopyTo(const_cast<DLTensor*>(data_entry_[eid]));
       }
     }
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index d6c939b..47f038b 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -111,7 +111,7 @@ class CUDADeviceAPI final : public DeviceAPI {
   void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
     ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
     void* ret;
-    if (dev.device_type == kDLCPUPinned) {
+    if (dev.device_type == kDLCUDAHost) {
       CUDA_CALL(cudaMallocHost(&ret, nbytes));
     } else {
       CUDA_CALL(cudaSetDevice(dev.device_id));
@@ -121,7 +121,7 @@ class CUDADeviceAPI final : public DeviceAPI {
   }
 
   void FreeDataSpace(Device dev, void* ptr) final {
-    if (dev.device_type == kDLCPUPinned) {
+    if (dev.device_type == kDLCUDAHost) {
       CUDA_CALL(cudaFreeHost(ptr));
     } else {
       CUDA_CALL(cudaSetDevice(dev.device_id));
@@ -137,11 +137,11 @@ class CUDADeviceAPI final : public DeviceAPI {
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
 
-    if (dev_from.device_type == kDLCPUPinned) {
+    if (dev_from.device_type == kDLCUDAHost) {
       dev_from.device_type = kDLCPU;
     }
 
-    if (dev_to.device_type == kDLCPUPinned) {
+    if (dev_to.device_type == kDLCUDAHost) {
       dev_to.device_type = kDLCPU;
     }
 
@@ -151,17 +151,17 @@ class CUDADeviceAPI final : public DeviceAPI {
       return;
     }
 
-    if (dev_from.device_type == kDLGPU && dev_to.device_type == kDLGPU) {
+    if (dev_from.device_type == kDLCUDA && dev_to.device_type == kDLCUDA) {
       CUDA_CALL(cudaSetDevice(dev_from.device_id));
       if (dev_from.device_id == dev_to.device_id) {
         GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
       } else {
         cudaMemcpyPeerAsync(to, dev_to.device_id, from, dev_from.device_id, size, cu_stream);
       }
-    } else if (dev_from.device_type == kDLGPU && dev_to.device_type == kDLCPU) {
+    } else if (dev_from.device_type == kDLCUDA && dev_to.device_type == kDLCPU) {
       CUDA_CALL(cudaSetDevice(dev_from.device_id));
       GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
-    } else if (dev_from.device_type == kDLCPU && dev_to.device_type == kDLGPU) {
+    } else if (dev_from.device_type == kDLCPU && dev_to.device_type == kDLCUDA) {
       CUDA_CALL(cudaSetDevice(dev_to.device_id));
       GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
     } else {
@@ -231,16 +231,16 @@ class CUDADeviceAPI final : public DeviceAPI {
 
 typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
 
-CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {}
+CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {}
 
 CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); }
 
-TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("device_api.cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
   DeviceAPI* ptr = CUDADeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });
 
-TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("device_api.cuda_host").set_body([](TVMArgs args, TVMRetValue* rv) {
   DeviceAPI* ptr = CUDADeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index d84a821..15b9c0d 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -126,7 +126,7 @@ bool RuntimeEnabled(const std::string& target) {
   if (target == "cpu") {
     return true;
   } else if (target == "cuda" || target == "gpu") {
-    f_name = "device_api.gpu";
+    f_name = "device_api.cuda";
   } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
     f_name = "device_api.opencl";
   } else if (target == "mtl" || target == "metal") {
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 4b52a7d..3d3466b 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -231,8 +231,8 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str
   ICHECK_EQ(from_size, to_size) << "TVMArrayCopyFromTo: The size must exactly match";
 
   ICHECK(from->device.device_type == to->device.device_type || from->device.device_type == kDLCPU ||
-         to->device.device_type == kDLCPU || from->device.device_type == kDLCPUPinned ||
-         to->device.device_type == kDLCPUPinned)
+         to->device.device_type == kDLCPU || from->device.device_type == kDLCUDAHost ||
+         to->device.device_type == kDLCUDAHost)
       << "Can not copy across different device types directly";
 
   // Use the device that is *not* a cpu device to get the correct device
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 474b1b0..cc493b9 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -152,7 +152,7 @@ Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
   } else {
     // Use the compute version of the first CUDA GPU instead
     TVMRetValue version;
-    if (!DetectDeviceFlag({kDLGPU, 0}, runtime::kComputeVersion, &version)) {
+    if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
       LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_20\" instead";
       arch = 20;
     } else {
@@ -230,7 +230,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU)
     .add_attr_option<String>("executor")
     .set_default_keys({"cpu"});
 
-TVM_REGISTER_TARGET_KIND("cuda", kDLGPU)
+TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
     .add_attr_option<String>("mcpu")
     .add_attr_option<String>("arch")
     .add_attr_option<Bool>("system-lib")
@@ -241,7 +241,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLGPU)
     .add_attr_option<Integer>("max_threads_per_block")
     .set_default_keys({"cuda", "gpu"});
 
-TVM_REGISTER_TARGET_KIND("nvptx", kDLGPU)
+TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
     .add_attr_option<String>("mcpu")
     .add_attr_option<String>("mtriple")
     .add_attr_option<Bool>("system-lib")
diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index 377ad5c..951bd6c 100644
--- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
+++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -1089,7 +1089,7 @@ Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
   }
 
   // Check if current runtime support GPU CUDA
-  Device dev{kDLGPU, 0};
+  Device dev{kDLCUDA, 0};
   auto api = tvm::runtime::DeviceAPI::Get(dev, true);
   if (api == nullptr) {
     return stmt;
diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc
index 905384f..3c29e4e 100644
--- a/src/tir/analysis/verify_memory.cc
+++ b/src/tir/analysis/verify_memory.cc
@@ -149,7 +149,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
 
   /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
   static bool IsGPUDevice(int dev_type) {
-    return kDLGPU == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type ||
+    return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type ||
            kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type;
   }
   /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device.
diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc
index e937393..8cc5c4b 100644
--- a/tests/cpp/build_module_test.cc
+++ b/tests/cpp/build_module_test.cc
@@ -166,7 +166,7 @@ TEST(BuildModule, Heterogeneous) {
   // Initialize graph executor.
   int cpu_dev_ty = static_cast<int>(kDLCPU);
   int cpu_dev_id = 0;
-  int gpu_dev_ty = static_cast<int>(kDLGPU);
+  int gpu_dev_ty = static_cast<int>(kDLCUDA);
   int gpu_dev_id = 0;
 
   const runtime::PackedFunc* graph_executor =
diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
index 9318b0c..f905ef8 100644
--- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
+++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py
@@ -101,10 +101,10 @@ def test_empty_array():
 
 def test_device():
     def test_device_func(dev):
-        assert tvm.gpu(7) == dev
+        assert tvm.cuda(7) == dev
         return tvm.cpu(0)
 
-    x = test_device_func(tvm.gpu(7))
+    x = test_device_func(tvm.cuda(7))
     assert x == tvm.cpu(0)
     x = tvm.opencl(10)
     x = tvm.testing.device_test(x, x.device_type, x.device_id)
diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py
index c4e6f89..d871a38 100644
--- a/tests/python/contrib/test_cublas.py
+++ b/tests/python/contrib/test_cublas.py
@@ -35,7 +35,7 @@ def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
         if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
             print("skip because extern function is not available")
             return
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         f = tvm.build(s, [A, B, C], target)
         a = tvm.nd.array(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), dev)
         b = tvm.nd.array(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), dev)
@@ -70,7 +70,7 @@ def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
         if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True):
             print("skip because extern function is not available")
             return
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         f = tvm.build(s, [A, B, C], target)
         a_old = np.random.uniform(0, 128, size=(n, l))
         b_old = np.random.uniform(0, 128, size=(l, m))
@@ -126,7 +126,7 @@ def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
         if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
             print("skip because extern function is not available")
             return
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         f = tvm.build(s, [A, B, C], target)
         a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), dev)
         b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), dev)
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 690589c..d73f81b 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -41,7 +41,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
         print("skip because cudnn is not enabled...")
         return
-    if data_dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+    if data_dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
         print("Skip because gpu does not have fp16 support")
         return
 
@@ -71,7 +71,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     s = te.create_schedule(Y.op)
 
     # validation
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     f = tvm.build(s, [X, W, Y], "cuda --host=llvm", name="conv2d")
     x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype)
     w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
@@ -149,7 +149,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
     s = te.create_schedule(Y.op)
 
     # validation
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     f = tvm.build(s, [X, W, Y], target="cuda --host=llvm", name="conv3d")
     x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype)
     w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
@@ -177,7 +177,7 @@ def verify_softmax(shape, axis, dtype="float32"):
     B = cudnn.softmax(A, axis)
     s = te.create_schedule([B.op])
 
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     a_np = np.random.uniform(size=shape).astype(dtype)
     b_np = tvm.topi.testing.softmax_python(a_np)
     a = tvm.nd.array(a_np, dev)
@@ -192,7 +192,7 @@ def verify_softmax_4d(shape, dtype="float32"):
     B = cudnn.softmax(A, axis=1)
     s = te.create_schedule([B.op])
 
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     n, c, h, w = shape
     a_np = np.random.uniform(size=shape).astype(dtype)
     b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 9810759..52ee87e 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -35,7 +35,7 @@ from tvm.relay.op.contrib import tensorrt
 
 def skip_codegen_test():
     """Skip test if TensorRT and CUDA codegen are not present"""
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist:
         print("Skip because CUDA is not enabled.")
         return True
     if not tvm.get_global_func("relay.ext.tensorrt", True):
@@ -45,7 +45,7 @@ def skip_codegen_test():
 
 
 def skip_runtime_test():
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist:
         print("Skip because CUDA is not enabled.")
         return True
     if not tensorrt.is_tensorrt_runtime_enabled():
@@ -143,10 +143,10 @@ def run_and_verify_model(model):
             with tvm.transform.PassContext(
                 opt_level=3, config={"relay.ext.tensorrt.options": config}
             ):
-                exec = relay.create_executor(mode, mod=mod, device=tvm.gpu(0), target="cuda")
+                exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda")
         else:
             with tvm.transform.PassContext(opt_level=3):
-                exec = relay.create_executor(mode, mod=mod, device=tvm.gpu(0), target="cuda")
+                exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda")
 
         res = exec.evaluate()(i_data, **params) if not skip_runtime_test() else None
         return res
@@ -199,12 +199,12 @@ def test_tensorrt_simple():
                     opt_level=3, config={"relay.ext.tensorrt.options": config}
                 ):
                     relay_exec = relay.create_executor(
-                        mode, mod=mod, device=tvm.gpu(0), target="cuda"
+                        mode, mod=mod, device=tvm.cuda(0), target="cuda"
                     )
             else:
                 with tvm.transform.PassContext(opt_level=3):
                     relay_exec = relay.create_executor(
-                        mode, mod=mod, device=tvm.gpu(0), target="cuda"
+                        mode, mod=mod, device=tvm.cuda(0), target="cuda"
                     )
             if not skip_runtime_test():
                 result_dict[result_key] = relay_exec.evaluate()(x_data, y_data, z_data)
@@ -247,7 +247,7 @@ def test_tensorrt_not_compatible():
     mod, config = tensorrt.partition_for_tensorrt(mod)
     for mode in ["graph", "vm"]:
         with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-            exec = relay.create_executor(mode, mod=mod, device=tvm.gpu(0), target="cuda")
+            exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda")
             if not skip_runtime_test():
                 results = exec.evaluate()(x_data)
 
@@ -273,7 +273,7 @@ def test_tensorrt_serialize_graph_executor():
         return graph, lib, params
 
     def run_graph(graph, lib, params):
-        mod_ = graph_executor.create(graph, lib, device=tvm.gpu(0))
+        mod_ = graph_executor.create(graph, lib, device=tvm.cuda(0))
         mod_.load_params(params)
         mod_.run(data=i_data)
         res = mod_.get_output(0)
@@ -330,7 +330,7 @@ def test_tensorrt_serialize_vm():
 
     def run_vm(code, lib):
         vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
-        vm = VirtualMachine(vm_exec, tvm.gpu(0))
+        vm = VirtualMachine(vm_exec, tvm.cuda(0))
         result = vm.invoke("main", data=i_data)
         return result
 
@@ -1415,7 +1415,7 @@ def test_empty_subgraph():
     x_data = np.random.uniform(-1, 1, x_shape).astype("float32")
     for mode in ["graph", "vm"]:
         with tvm.transform.PassContext(opt_level=3):
-            exec = relay.create_executor(mode, mod=mod, device=tvm.gpu(0), target="cuda")
+            exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda")
             if not skip_runtime_test():
                 results = exec.evaluate()(x_data)
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f9f3bba..067af7f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1240,7 +1240,7 @@ def test_type_as():
         check_fp16 = False
         try:
             # Only check half precision on supported hardwares.
-            if have_fp16(tvm.gpu(0).compute_version):
+            if have_fp16(tvm.cuda(0).compute_version):
                 check_fp16 = True
         except Exception as e:
             # If GPU is not enabled in TVM, skip the fp16 test.
diff --git a/tests/python/nightly/quantization/test_quantization_accuracy.py b/tests/python/nightly/quantization/test_quantization_accuracy.py
index 57fa49e..5cf4dfe 100644
--- a/tests/python/nightly/quantization/test_quantization_accuracy.py
+++ b/tests/python/nightly/quantization/test_quantization_accuracy.py
@@ -93,7 +93,7 @@ def get_model(model_name, batch_size, qconfig, target=None, original=False, simu
 
 
 def eval_acc(
-    model, dataset, batch_fn, target=tvm.target.cuda(), device=tvm.gpu(), log_interval=100
+    model, dataset, batch_fn, target=tvm.target.cuda(), device=tvm.cuda(), log_interval=100
 ):
     with tvm.transform.PassContext(opt_level=3):
         graph, lib, params = relay.build(model, target)
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index fe5e048..7d1c577 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -508,7 +508,7 @@ def verify_any_conv2d(
 
     targets = None
     if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
-        targets = [("cuda -libs=cudnn", tvm.gpu(0))]
+        targets = [("cuda -libs=cudnn", tvm.cuda(0))]
 
     check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)
 
@@ -811,7 +811,7 @@ def verify_any_dense(
 
     targets = None
     if use_cublas and tvm.get_global_func("tvm.contrib.cublas.matmul", True):
-        targets = [("cuda -libs=cublas", tvm.gpu(0))]
+        targets = [("cuda -libs=cublas", tvm.cuda(0))]
 
     check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True, targets=targets)
 
diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py
index 13651e7..d3c54fd 100644
--- a/tests/python/relay/test_auto_scheduler_tuning.py
+++ b/tests/python/relay/test_auto_scheduler_tuning.py
@@ -69,7 +69,7 @@ def tune_network(network, target):
 
         # Check the correctness
         def get_output(data, lib):
-            dev = tvm.gpu()
+            dev = tvm.cuda()
             module = graph_executor.GraphModule(lib["default"](dev))
             module.set_input("data", data)
             module.run()
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index 7d2209a..0d98cc0 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -65,7 +65,7 @@ def test_basic_build():
 def test_fp16_build():
     dtype = "float16"
 
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     if dtype == "float16" and not have_fp16(dev.compute_version):
         print("skip because gpu does not support fp16")
         return
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 91b3713..aef3c3c 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -67,7 +67,7 @@ def test_unary_op():
                 if (
                     dtype == "float16"
                     and target == "cuda"
-                    and not have_fp16(tvm.gpu(0).compute_version)
+                    and not have_fp16(tvm.cuda(0).compute_version)
                 ):
                     continue
                 intrp = relay.create_executor("graph", device=dev, target=target)
@@ -129,7 +129,7 @@ def test_binary_op():
                 if (
                     dtype == "float16"
                     and target == "cuda"
-                    and not have_fp16(tvm.gpu(0).compute_version)
+                    and not have_fp16(tvm.cuda(0).compute_version)
                 ):
                     continue
                 intrp = relay.create_executor("graph", device=dev, target=target)
@@ -158,7 +158,7 @@ def test_expand_dims():
             if (
                 dtype == "float16"
                 and target == "cuda"
-                and not have_fp16(tvm.gpu(0).compute_version)
+                and not have_fp16(tvm.cuda(0).compute_version)
             ):
                 continue
             data = np.random.uniform(size=dshape).astype(dtype)
@@ -193,7 +193,7 @@ def test_bias_add():
             if (
                 dtype == "float16"
                 and target == "cuda"
-                and not have_fp16(tvm.gpu(0).compute_version)
+                and not have_fp16(tvm.cuda(0).compute_version)
             ):
                 continue
             intrp = relay.create_executor("graph", device=dev, target=target)
@@ -314,7 +314,7 @@ def test_concatenate():
             if (
                 dtype == "float16"
                 and target == "cuda"
-                and not have_fp16(tvm.gpu(0).compute_version)
+                and not have_fp16(tvm.cuda(0).compute_version)
             ):
                 continue
             intrp1 = relay.create_executor("graph", device=dev, target=target)
diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py
index e54682b..fe19c47 100644
--- a/tests/python/relay/test_pass_context_analysis.py
+++ b/tests/python/relay/test_pass_context_analysis.py
@@ -26,19 +26,19 @@ from tvm.relay.analysis import context_analysis
 
 
 def test_device_copy():
-    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
         return
 
     mod = tvm.IRModule()
     x = relay.var("x", shape=(2, 3))
-    copy = relay.op.device_copy(x, tvm.cpu(), tvm.gpu())
+    copy = relay.op.device_copy(x, tvm.cpu(), tvm.cuda())
     out = copy + relay.const(np.random.rand(2, 3))
     glb_var = relay.GlobalVar("main")
     mod[glb_var] = relay.Function([x], out)
     ca = context_analysis(mod, tvm.cpu())
 
     cpu_dev = tvm.cpu().device_type
-    gpu_dev = tvm.gpu().device_type
+    gpu_dev = tvm.cuda().device_type
     for expr, dev in ca.items():
         if isinstance(expr, _expr.Call):
             assert dev[0].value == gpu_dev
@@ -49,7 +49,7 @@ def test_device_copy():
 
 
 def test_shape_func():
-    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
         return
 
     mod = tvm.IRModule()
@@ -65,11 +65,11 @@ def test_shape_func():
     is_inputs = [False]
     shape_func = relay.op.vm.shape_func(fn, ins, outs, is_inputs)
     mod["main"] = relay.Function([x, out], shape_func)
-    ca = context_analysis(mod, tvm.gpu())
+    ca = context_analysis(mod, tvm.cuda())
     main = mod["main"]
 
     cpu_dev = tvm.cpu().device_type
-    gpu_dev = tvm.gpu().device_type
+    gpu_dev = tvm.cuda().device_type
     assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
     # The output of shape func should be on cpu.
     assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev
@@ -78,7 +78,7 @@ def test_shape_func():
 
 
 def test_vm_shape_of():
-    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
         return
 
     mod = tvm.IRModule()
@@ -86,17 +86,17 @@ def test_vm_shape_of():
     x = relay.var("x", shape=data_shape)
     y = relay.op.vm.shape_of(x)
     mod["main"] = relay.Function([x], y)
-    ca = context_analysis(mod, tvm.gpu())
+    ca = context_analysis(mod, tvm.cuda())
     main = mod["main"]
 
     cpu_dev = tvm.cpu().device_type
-    gpu_dev = tvm.gpu().device_type
+    gpu_dev = tvm.cuda().device_type
     assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
     assert main.body in ca and ca[main.body][0].value == cpu_dev
 
 
 def test_alloc_storage():
-    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
         return
 
     mod = tvm.IRModule()
@@ -104,14 +104,14 @@ def test_alloc_storage():
     size = relay.Var("size", relay.scalar_type("int64"))
     alignment = relay.Var("alignment", relay.scalar_type("int64"))
     # allocate a chunk on of memory on gpu.
-    sto = relay.op.memory.alloc_storage(size, alignment, tvm.gpu())
+    sto = relay.op.memory.alloc_storage(size, alignment, tvm.cuda())
     mod["main"] = relay.Function([size, alignment], sto)
-    ca = context_analysis(mod, tvm.gpu())
+    ca = context_analysis(mod, tvm.cuda())
     main = mod["main"]
     body = main.body
 
     cpu_dev = tvm.cpu().device_type
-    gpu_dev = tvm.gpu().device_type
+    gpu_dev = tvm.cuda().device_type
     # Inputs are unified with alloc storage inputs which are on cpu
     assert main.params[0] in ca and ca[main.params[0]][0].value == cpu_dev
     assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev
@@ -126,7 +126,7 @@ def test_alloc_storage():
 
 
 def test_alloc_tensor():
-    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
         return
 
     mod = tvm.IRModule()
@@ -136,12 +136,12 @@ def test_alloc_tensor():
     sh = relay.const(np.array([3, 2]), dtype="int64")
     at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh)
     mod["main"] = relay.Function([sto], at)
-    ca = context_analysis(mod, tvm.gpu())
+    ca = context_analysis(mod, tvm.cuda())
     main = mod["main"]
     body = main.body
 
     cpu_dev = tvm.cpu().device_type
-    gpu_dev = tvm.gpu().device_type
+    gpu_dev = tvm.cuda().device_type
     # Input of the function falls back to the default device gpu
     assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
 
@@ -155,7 +155,7 @@ def test_alloc_tensor():
 
 
 def test_vm_reshape_tensor():
-    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
         return
 
     x = relay.var("x", shape=(2, 8), dtype="float32")
@@ -163,12 +163,12 @@ def test_vm_reshape_tensor():
     y = relay.op.vm.reshape_tensor(x, shape, [2, 4, 2])
     mod = tvm.IRModule()
     mod["main"] = relay.Function([x], y)
-    ca = context_analysis(mod, tvm.gpu())
+    ca = context_analysis(mod, tvm.cuda())
     main = mod["main"]
     body = main.body
 
     cpu_dev = tvm.cpu().device_type
-    gpu_dev = tvm.gpu().device_type
+    gpu_dev = tvm.cuda().device_type
     # Input of the function falls back to the default device gpu
     assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
 
@@ -181,7 +181,7 @@ def test_vm_reshape_tensor():
 
 
 def test_dynamic_input():
-    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.cuda(0).exist:
         return
 
     mod = tvm.IRModule()
@@ -195,7 +195,7 @@ def test_dynamic_input():
     ca = context_analysis(mod, tvm.cpu())
     main = mod["main"]
 
-    gpu_dev = tvm.gpu().device_type
+    gpu_dev = tvm.cuda().device_type
     assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev
     assert main.params[1] in ca and ca[main.params[1]][0].value == gpu_dev
     assert main.body in ca and ca[main.body][0].value == gpu_dev
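
The tests in this file all open with the same guard; only its spelling changes with the rename: the CUDA backend must be compiled in and a device must actually be attached before context analysis runs. A minimal sketch of that pattern under the new API, using only calls that appear in the hunks above:

    import tvm
    import tvm.testing

    def cuda_ready():
        # Skip unless the CUDA backend is enabled at build time and a device exists.
        return tvm.testing.device_enabled("cuda") and tvm.cuda(0).exist

    if cuda_ready():
        cpu_dev = tvm.cpu().device_type   # integer device type used in the assertions
        gpu_dev = tvm.cuda().device_type  # tvm.cuda() replaces the old tvm.gpu()
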
diff --git a/tests/python/topi/python/test_topi_relu.py b/tests/python/topi/python/test_topi_relu.py
index 9acf98d..947a6ca 100644
--- a/tests/python/topi/python/test_topi_relu.py
+++ b/tests/python/topi/python/test_topi_relu.py
@@ -35,7 +35,7 @@ def verify_relu(m, n, dtype="float32"):
     b_np = a_np * (a_np > 0)
 
     def check_target(target, dev):
-        if dtype == "float16" and target == "cuda" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and target == "cuda" and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because %s does not have fp16 support" % target)
             return
         print("Running on target: %s" % target)
diff --git a/tests/python/topi/python/test_topi_tensor.py b/tests/python/topi/python/test_topi_tensor.py
index d395c0c..2d4eed3 100644
--- a/tests/python/topi/python/test_topi_tensor.py
+++ b/tests/python/topi/python/test_topi_tensor.py
@@ -95,7 +95,7 @@ def verify_vectorization(n, m, dtype):
         if not tvm.testing.device_enabled(targeta):
             print("Skip because %s is not enabled" % targeta)
             return
-        if dtype == "float16" and targeta == "cuda" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and targeta == "cuda" and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because gpu does not have fp16 support")
             return
         with tvm.target.Target(targeta):
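
Both TOPI tests also gate their float16 variants on the device's compute capability. A short sketch of that check, assuming have_fp16 is imported from tvm.contrib.nvcc as in these test files:

    import tvm
    from tvm.contrib.nvcc import have_fp16

    dev = tvm.cuda(0)  # formerly tvm.gpu(0)
    # compute_version is a string such as "7.0"; have_fp16 reports whether
    # the device can run the fp16 variant of the test.
    if not have_fp16(dev.compute_version):
        print("Skip because gpu does not have fp16 support")
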
diff --git a/tests/python/unittest/test_runtime_graph_cuda_graph.py b/tests/python/unittest/test_runtime_graph_cuda_graph.py
index ee7750e..fb0c736 100644
--- a/tests/python/unittest/test_runtime_graph_cuda_graph.py
+++ b/tests/python/unittest/test_runtime_graph_cuda_graph.py
@@ -73,7 +73,7 @@ def test_graph_simple():
 
     def check_verify():
         mlib = tvm.build(s, [A, B], "cuda", name="myadd")
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         try:
             mod = cuda_graph_executor.create(graph, mlib, dev)
         except ValueError:
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index f85edfc..3100414 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -97,7 +97,7 @@ def test_gpu():
     with relay.build_config(opt_level=3):
         complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
     data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
-    dev = tvm.gpu()
+    dev = tvm.cuda()
 
     # raw api
     gmod = complied_graph_lib["default"](dev)
@@ -190,7 +190,7 @@ def test_mod_export():
         # test the robustness w.r.t. parent module destruction
         def setup_gmod():
             loaded_lib = tvm.runtime.load_module(path_lib)
-            dev = tvm.gpu()
+            dev = tvm.cuda()
             return loaded_lib["default"](dev)
 
         gmod = setup_gmod()
@@ -378,7 +378,7 @@ def test_remove_package_params():
             fo.write(runtime.save_param_dict(complied_graph_lib.get_params()))
         loaded_lib = tvm.runtime.load_module(path_lib)
         data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
 
         # raw api
         gmod = loaded_lib["default"](dev)
@@ -559,7 +559,7 @@ def test_cuda_graph_executor():
         complied_graph_lib = relay.build_module.build(mod, "cuda", params=params)
     data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
 
-    dev = tvm.gpu()
+    dev = tvm.cuda()
     try:
         gmod = complied_graph_lib["cuda_graph_create"](dev)
     except:
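
The module-based interface tests repeat the same three steps with the renamed handle: build for "cuda", fetch the default executor factory, and drive it through the graph executor. A self-contained sketch of that flow, with a toy one-op network standing in for the test models (the network itself is an assumption; the API calls mirror the hunks above):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    # Toy network in place of the models built by the tests.
    x = relay.var("x", shape=(1, 4), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

    lib = relay.build_module.build(mod, "cuda")
    dev = tvm.cuda()  # formerly tvm.gpu()
    gmod = graph_executor.GraphModule(lib["default"](dev))
    gmod.set_input("x", np.random.uniform(-1, 1, size=(1, 4)).astype("float32"))
    gmod.run()
    out = gmod.get_output(0).asnumpy()
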
diff --git a/tests/python/unittest/test_target_codegen_blob.py b/tests/python/unittest/test_target_codegen_blob.py
index c769819..2a30989 100644
--- a/tests/python/unittest/test_target_codegen_blob.py
+++ b/tests/python/unittest/test_target_codegen_blob.py
@@ -57,7 +57,7 @@ def test_synthetic():
 
     loaded_lib = tvm.runtime.load_module(path_lib)
     data = np.random.uniform(-1, 1, size=input_shape).astype("float32")
-    dev = tvm.gpu()
+    dev = tvm.cuda()
     module = graph_executor.GraphModule(loaded_lib["default"](dev))
     module.set_input("data", data)
     module.run()
@@ -68,7 +68,7 @@ def test_synthetic():
 
 @tvm.testing.uses_gpu
 def test_cuda_lib():
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     for device in ["llvm", "cuda"]:
         if not tvm.testing.device_enabled(device):
             print("skip because %s is not enabled..." % device)
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index e639e6b..846bdcb 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -33,10 +33,10 @@ def test_cuda_vectorize_add():
     num_thread = 8
 
     def check_cuda(dtype, n, lanes):
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because gpu does not have fp16 support")
             return
-        if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
+        if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
             print("skip because gpu does not support int8")
             return
         A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
@@ -46,7 +46,7 @@ def test_cuda_vectorize_add():
         s[B].bind(xo, bx)
         s[B].bind(xi, tx)
         fun = tvm.build(s, [A, B], "cuda")
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes)))
         c = tvm.nd.empty((n,), B.dtype, dev)
         fun(a, c)
@@ -70,7 +70,7 @@ def test_cuda_vectorize_add():
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_cuda_bf16_vectorize_add():
-    if not have_bf16(tvm.gpu(0).compute_version):
+    if not have_bf16(tvm.cuda(0).compute_version):
         print("skip because gpu does not support bf16")
         return
     num_thread = 8
@@ -99,7 +99,7 @@ def test_cuda_bf16_vectorize_add():
             disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
         ):
             fun = tvm.build(s, [A, B], "cuda")
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         np_a = np.random.uniform(size=(n, lanes)).astype("float32")
         np_a = np_bf162np_float(np_float2np_bf16(np_a))
         a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_float2np_bf16(np_a))
@@ -120,7 +120,7 @@ def test_cuda_multiply_add():
     num_thread = 8
 
     def check_cuda(dtype, n, lanes):
-        if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
+        if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
             print("skip because gpu does not support int8")
             return
         A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
@@ -138,7 +138,7 @@ def test_cuda_multiply_add():
         np_b = np.random.randint(low=-128, high=127, size=(n, lanes))
         np_c = np.random.randint(low=0, high=127, size=(n,))
         np_d = [sum(x * y) + z for x, y, z in zip(np_a, np_b, np_c)]
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_a)
         b = tvm.nd.empty((n,), B.dtype, dev).copyfrom(np_b)
         c = tvm.nd.empty((n,), C.dtype, dev).copyfrom(np_c)
@@ -155,7 +155,7 @@ def test_cuda_vectorize_load():
     num_thread = 8
 
     def check_cuda(dtype, n, lanes):
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
         B = te.compute((n,), lambda i: A[i], name="B")
         s = te.create_schedule(B.op)
@@ -181,7 +181,7 @@ def test_cuda_vectorize_load():
 def test_cuda_make_int8():
     def check_cuda(n, value, lanes):
         dtype = "int8"
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
         s = te.create_schedule(A.op)
         y, x = s[A].op.axis
@@ -209,7 +209,7 @@ def test_cuda_make_int8():
 def test_cuda_make_int4():
     def check_cuda(n, value, lanes):
         dtype = "int4"
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
         s = te.create_schedule(A.op)
         y, x = s[A].op.axis
@@ -300,7 +300,7 @@ def test_cuda_shuffle():
         b_ = np.array((list(range(4))[::-1]) * 16, dtype="int32")
         c_ = np.zeros((64,), dtype="int32")
         ref = a_ + np.array((list(range(4))) * 16, dtype="int32")
-        nda, ndb, ndc = [tvm.nd.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
+        nda, ndb, ndc = [tvm.nd.array(i, tvm.cuda(0)) for i in [a_, b_, c_]]
         module(nda, ndb, ndc)
         tvm.testing.assert_allclose(ndc.asnumpy(), ref)
 
@@ -440,7 +440,7 @@ def test_cuda_const_float_to_half():
     s[c].bind(tx, te.thread_axis("threadIdx.x"))
 
     func = tvm.build(s, [a, c], "cuda")
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     a_np = np.random.uniform(size=shape).astype(a.dtype)
     c_np = np.zeros(shape=shape, dtype=c.dtype)
     a = tvm.nd.array(a_np, dev)
@@ -528,7 +528,7 @@ def test_cuda_floordiv_with_vectorization():
         s[B].bind(xio, tx)
         func = tvm.build(s, [A, B], "cuda")
 
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         a_np = np.random.uniform(size=(n,)).astype(A.dtype)
         b_np = np.array([a_np[i // k] for i in range(0, n)])
         a_nd = tvm.nd.array(a_np, dev)
@@ -554,7 +554,7 @@ def test_cuda_floormod_with_vectorization():
         s[B].bind(xio, tx)
         func = tvm.build(s, [A, B], "cuda")
 
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         a_np = np.random.uniform(size=(n,)).astype(A.dtype)
         b_np = np.array([a_np[i % k] for i in range(0, n)])
         a_nd = tvm.nd.array(a_np, dev)
@@ -567,7 +567,7 @@ def test_cuda_floormod_with_vectorization():
 @tvm.testing.requires_cuda
 def test_vectorized_casts():
     def check(t0, t1, factor):
-        if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.gpu(0).compute_version):
+        if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because gpu does not have fp16 support")
             return
 
@@ -585,7 +585,7 @@ def test_vectorized_casts():
         func = tvm.build(s, [A, B, C], "cuda")
 
         # correctness
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         low, high = (0, 20) if t0.startswith("u") or t1.startswith("u") else (-10, 10)
         a_np = np.random.randint(low, high, size=n).astype(A.dtype)
         b_np = np.random.randint(low, high, size=n).astype(B.dtype)
@@ -664,7 +664,7 @@ def test_vectorized_intrin1():
     ]
 
     def run_test(tvm_intrin, np_func, dtype):
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because gpu does not have fp16 support")
             return
         # set of intrinsics does not support fp16 yet.
@@ -686,7 +686,7 @@ def test_vectorized_intrin1():
         B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name="B")
         s = sched(B)
         f = tvm.build(s, [A, B], "cuda")
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
         b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev)
         f(a, b)
@@ -712,7 +712,7 @@ def test_vectorized_intrin2(dtype="float32"):
         B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name="B")
         s = sched(B)
         f = tvm.build(s, [A, B], "cuda")
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
         b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev)
         f(a, b)
@@ -738,7 +738,7 @@ def test_vectorized_popcount():
         B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name="B")
         s = sched(B)
         f = tvm.build(s, [A, B], "cuda")
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
         b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev)
         f(a, b)
@@ -753,11 +753,11 @@ def test_vectorized_popcount():
 @tvm.testing.requires_cuda
 def test_cuda_vectorize_load_permute_pad():
     def check_cuda(dtype, n, l, padding, lanes):
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because gpu does not have fp16 support")
             return
 
-        dev = tvm.gpu(0)
+        dev = tvm.cuda(0)
         A = tvm.te.placeholder((n, l), name="A", dtype=dtype)
         B = tvm.te.compute(
             (n // lanes, l + 2 * padding, lanes),
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 56a8514..96b67ea 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -810,7 +810,7 @@ def test_llvm_gpu_lower_atomic():
         s = tvm.te.create_schedule(C.op)
         f = tvm.build(s, [A], target="nvptx")
 
-        dev = tvm.gpu()
+        dev = tvm.cuda()
         a = tvm.nd.array(np.zeros((size,)).astype(A.dtype), dev)
         f(a)
         ref = np.zeros((size,)).astype(A.dtype)
diff --git a/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py b/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py
index e7a8469..0f97e49 100644
--- a/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py
+++ b/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py
@@ -100,7 +100,7 @@ def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
 
     func = tvm.build(s, [A, B, C], "cuda")
 
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
     b_np = np.random.uniform(size=(l, m)).astype(B.dtype)
     c_np = np.zeros((n, m), dtype=np.float32)
@@ -195,7 +195,7 @@ def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
 
     func = tvm.build(s, [A, B, C], "cuda")
 
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     a_np = np.random.uniform(size=(batch, n, l)).astype(A.dtype)
     b_np = np.random.uniform(size=(batch, l, m)).astype(B.dtype)
     c_np = np.zeros((batch, n, m), dtype=np.float32)
diff --git a/tests/python/unittest/test_te_schedule_tensor_core.py b/tests/python/unittest/test_te_schedule_tensor_core.py
index 9491425..e0cf583 100644
--- a/tests/python/unittest/test_te_schedule_tensor_core.py
+++ b/tests/python/unittest/test_te_schedule_tensor_core.py
@@ -256,7 +256,7 @@ def test_tensor_core_batch_matmal():
 
     func = tvm.build(s, [A, B, C], "cuda")
 
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     a_np = np.random.uniform(size=(batch_size, nn, ll, 32, 16)).astype(A.dtype)
     b_np = np.random.uniform(size=(batch_size, ll, mm, 16, 8)).astype(B.dtype)
     a = tvm.nd.array(a_np, dev)
@@ -432,7 +432,7 @@ def test_tensor_core_batch_conv():
 
     func = tvm.build(s, [A, W, Conv], "cuda")
 
-    dev = tvm.gpu(0)
+    dev = tvm.cuda(0)
     a_np = np.random.uniform(size=data_shape).astype(A.dtype)
     w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
     a = tvm.nd.array(a_np, dev)
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index ac72043..2a84078 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -92,7 +92,7 @@ def test_lower_warp_memory_correct_indices():
 @tvm.testing.requires_cuda
 def test_lower_warp_memory_cuda_end_to_end():
     def check_cuda(dtype):
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because gpu does not have fp16 support")
             return
 
@@ -114,7 +114,7 @@ def test_lower_warp_memory_cuda_end_to_end():
             xo, xi = s[AA].split(s[AA].op.axis[0], 32)
             s[AA].bind(xi, tx)
 
-            dev = tvm.gpu(0)
+            dev = tvm.cuda(0)
             func = tvm.build(s, [A, B], "cuda")
             A_np = np.array(list(range(m)), dtype=dtype)
             B_np = np.array(
@@ -141,7 +141,7 @@ def test_lower_warp_memory_cuda_end_to_end():
 @tvm.testing.requires_cuda
 def test_lower_warp_memory_cuda_half_a_warp():
     def check_cuda(dtype):
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because gpu does not have fp16 support")
             return
 
@@ -181,7 +181,7 @@ def test_lower_warp_memory_cuda_half_a_warp():
             _, x = AA.op.axis
             s[AA].bind(x, tx)
 
-            dev = tvm.gpu(0)
+            dev = tvm.cuda(0)
             func = tvm.build(s, [A, B], "cuda")
             A_np = np.array([list(range(i, m + i)) for i in range(n)], dtype=dtype)
             B_np = np.array([list(range(1 + i, m + i)) + [i] for i in range(n)], dtype=dtype)
@@ -198,7 +198,7 @@ def test_lower_warp_memory_cuda_half_a_warp():
 @tvm.testing.requires_cuda
 def test_lower_warp_memory_cuda_2_buffers():
     def check_cuda(dtype):
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
             print("Skip because gpu does not have fp16 support")
             return
 
@@ -228,7 +228,7 @@ def test_lower_warp_memory_cuda_2_buffers():
             s[BB].bind(xo, bx)
             s[BB].bind(xi, tx)
 
-            dev = tvm.gpu(0)
+            dev = tvm.cuda(0)
             func = tvm.build(s, [A, B, C], "cuda")
             AB_np = np.array(list(range(m)), dtype=dtype)
             C_np = np.array(list(range(1, m)) + [0], dtype=dtype) * 2
diff --git a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
index 41fdcbb..8664c86 100644
--- a/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
+++ b/tutorials/auto_scheduler/tune_conv2d_layer_cuda.py
@@ -145,7 +145,7 @@ bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
 conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
 out_np = np.maximum(conv_np + bias_np, 0.0)
 
-dev = tvm.gpu()
+dev = tvm.cuda()
 data_tvm = tvm.nd.array(data_np, device=dev)
 weight_tvm = tvm.nd.array(weight_np, device=dev)
 bias_tvm = tvm.nd.array(bias_np, device=dev)
diff --git a/tutorials/autotvm/tune_conv2d_cuda.py b/tutorials/autotvm/tune_conv2d_cuda.py
index d14f9c3..c46180d 100644
--- a/tutorials/autotvm/tune_conv2d_cuda.py
+++ b/tutorials/autotvm/tune_conv2d_cuda.py
@@ -230,7 +230,7 @@ a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
 w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
 c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
 
-dev = tvm.gpu()
+dev = tvm.cuda()
 a_tvm = tvm.nd.array(a_np, device=dev)
 w_tvm = tvm.nd.array(w_np, device=dev)
 c_tvm = tvm.nd.empty(c_np.shape, device=dev)
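
In both tuning tutorials only the handle used to place the buffers changes; timing the compiled kernel is untouched. A hedged sketch of that measurement step, with a trivial elementwise kernel in place of the tuned conv2d (the kernel and the run count are assumptions, and the standard time_evaluator helper is assumed to behave as elsewhere in TVM):

    import numpy as np
    import tvm
    from tvm import te

    # Trivial kernel standing in for the tuned operator.
    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))
    func = tvm.build(s, [A, B], "cuda")

    dev = tvm.cuda()  # formerly tvm.gpu()
    a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), dev)
    b = tvm.nd.empty((n,), B.dtype, dev)
    timer = func.time_evaluator(func.entry_name, dev, number=10)
    print("mean kernel time: %g s" % timer(a, b).mean)
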
diff --git a/tutorials/frontend/deploy_sparse.py b/tutorials/frontend/deploy_sparse.py
index 92f4511..eb9b4ee 100644
--- a/tutorials/frontend/deploy_sparse.py
+++ b/tutorials/frontend/deploy_sparse.py
@@ -105,7 +105,7 @@ seq_len = 128
 # TVM platform identifier. Note that best cpu performance can be achieved by setting -mcpu
 # appropriately for your specific machine. CUDA and ROCm are also supported.
 target = "llvm"
-# Which device to run on. Should be one of tvm.cpu() or tvm.gpu().
+# Which device to run on. Should be one of tvm.cpu() or tvm.cuda().
 dev = tvm.cpu()
 # If true, then a sparse variant of the network will be run and
 # benchmarked.
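
Since the comment ties the device handle to the chosen target string, a one-line way to keep the two in sync after the rename (a sketch, not part of the tutorial itself):

    import tvm

    target = "llvm"  # or "cuda" / "rocm", per the note above
    dev = tvm.cuda() if target == "cuda" else tvm.cpu()
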
diff --git a/tutorials/frontend/from_caffe2.py b/tutorials/frontend/from_caffe2.py
index a3378de..1c00f92 100644
--- a/tutorials/frontend/from_caffe2.py
+++ b/tutorials/frontend/from_caffe2.py
@@ -107,7 +107,7 @@ import tvm
 from tvm import te
 from tvm.contrib import graph_executor
 
-# context x86 CPU, use tvm.gpu(0) if you run on GPU
+# context x86 CPU, use tvm.cuda(0) if you run on GPU
 dev = tvm.cpu(0)
 # create a runtime executor module
 m = graph_executor.GraphModule(lib["default"](dev))
diff --git a/tutorials/frontend/from_keras.py b/tutorials/frontend/from_keras.py
index 5f39a24..8625465 100644
--- a/tutorials/frontend/from_keras.py
+++ b/tutorials/frontend/from_keras.py
@@ -96,7 +96,7 @@ shape_dict = {"input_1": data.shape}
 mod, params = relay.frontend.from_keras(keras_resnet50, shape_dict)
 # compile the model
 target = "cuda"
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
 with tvm.transform.PassContext(opt_level=3):
     executor = relay.build_module.create_executor("graph", mod, dev, target)
 
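
For context, the executor created here is then invoked on the input batch. A self-contained sketch of that create_executor/evaluate flow, with a toy module standing in for the converted ResNet50 (the module and input are assumptions; the call shape mirrors the tutorial):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 4), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], x * relay.const(2.0)))

    target = "cuda"
    dev = tvm.cuda(0)  # formerly tvm.gpu(0)
    with tvm.transform.PassContext(opt_level=3):
        executor = relay.build_module.create_executor("graph", mod, dev, target)
    out = executor.evaluate()(tvm.nd.array(np.ones((1, 4), dtype="float32")))
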
diff --git a/tutorials/frontend/from_mxnet.py b/tutorials/frontend/from_mxnet.py
index bfaac2c..da1bf4e 100644
--- a/tutorials/frontend/from_mxnet.py
+++ b/tutorials/frontend/from_mxnet.py
@@ -106,7 +106,7 @@ with tvm.transform.PassContext(opt_level=3):
 # Now, we would like to reproduce the same forward computation using TVM.
 from tvm.contrib import graph_executor
 
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
 dtype = "float32"
 m = graph_executor.GraphModule(lib["default"](dev))
 # set inputs
diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py
index 9c8d0f6..468caf5 100644
--- a/tutorials/frontend/from_tensorflow.py
+++ b/tutorials/frontend/from_tensorflow.py
@@ -72,7 +72,7 @@ label_map_url = os.path.join(repo_base, label_map)
 # Use these commented settings to build for cuda.
 # target = tvm.target.Target("cuda", host="llvm")
 # layout = "NCHW"
-# dev = tvm.gpu(0)
+# dev = tvm.cuda(0)
 target = tvm.target.Target("llvm", host="llvm")
 layout = None
 dev = tvm.cpu(0)
diff --git a/tutorials/get_started/relay_quick_start.py b/tutorials/get_started/relay_quick_start.py
index ffc9bbe..9bd3065 100644
--- a/tutorials/get_started/relay_quick_start.py
+++ b/tutorials/get_started/relay_quick_start.py
@@ -107,7 +107,7 @@ with tvm.transform.PassContext(opt_level=opt_level):
 # Now we can create the graph executor and run the module on an Nvidia GPU.
 
 # create random input
-dev = tvm.gpu()
+dev = tvm.cuda()
 data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
 # create module
 module = graph_executor.GraphModule(lib["default"](dev))
diff --git a/tutorials/language/reduction.py b/tutorials/language/reduction.py
index f782ac6..206848c 100644
--- a/tutorials/language/reduction.py
+++ b/tutorials/language/reduction.py
@@ -137,7 +137,7 @@ print(fcuda.imported_modules[0].get_source())
 # Verify the correctness of result kernel by comparing it to numpy.
 #
 nn = 128
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
 a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), dev)
 b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev)
 fcuda(a, b)
diff --git a/tutorials/language/scan.py b/tutorials/language/scan.py
index 8124b56..8876921 100644
--- a/tutorials/language/scan.py
+++ b/tutorials/language/scan.py
@@ -83,7 +83,7 @@ print(tvm.lower(s, [X, s_scan], simple_mode=True))
 # numpy to verify the correctness of the result.
 #
 fscan = tvm.build(s, [X, s_scan], "cuda", name="myscan")
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
 n = 1024
 m = 10
 a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
diff --git a/tutorials/optimize/opt_conv_cuda.py b/tutorials/optimize/opt_conv_cuda.py
index 0cecc82..0ac2c62 100644
--- a/tutorials/optimize/opt_conv_cuda.py
+++ b/tutorials/optimize/opt_conv_cuda.py
@@ -238,7 +238,7 @@ s[WW].vectorize(fi)  # vectorize memory load
 #
 
 func = tvm.build(s, [A, W, B], "cuda")
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
 a_np = np.random.uniform(size=(in_size, in_size, in_channel, batch)).astype(A.dtype)
 w_np = np.random.uniform(size=(kernel, kernel, in_channel, out_channel)).astype(W.dtype)
 a = tvm.nd.array(a_np, dev)
diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py
index 0a7798d..702e4a7 100644
--- a/tutorials/optimize/opt_conv_tensorcore.py
+++ b/tutorials/optimize/opt_conv_tensorcore.py
@@ -392,7 +392,7 @@ print(tvm.lower(s, [A, W, Conv], simple_mode=True))
 # Since TensorCores are only supported on NVIDIA GPUs with Compute Capability 7.0 or higher, this example may not
 # run on our build server
 
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
 if nvcc.have_tensorcore(dev.compute_version):
     with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}):
         func = tvm.build(s, [A, W, Conv], "cuda")
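
The gate above follows the same pattern as the fp16 checks in the tests: query the device's compute capability and only build when the hardware supports the feature. A compact sketch, assuming have_tensorcore from tvm.contrib.nvcc as imported earlier in this tutorial:

    import tvm
    from tvm.contrib import nvcc

    dev = tvm.cuda(0)  # formerly tvm.gpu(0)
    # Tensor Cores require compute capability 7.0 or higher.
    if nvcc.have_tensorcore(dev.compute_version):
        print("Tensor Core path enabled")
    else:
        print("Skip: compute capability below 7.0")
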
diff --git a/tutorials/topi/intro_topi.py b/tutorials/topi/intro_topi.py
index 1fefae5..5ddb878 100644
--- a/tutorials/topi/intro_topi.py
+++ b/tutorials/topi/intro_topi.py
@@ -99,7 +99,7 @@ print(sg.stages)
 # We can test the correctness by comparing with the :code:`numpy` result as follows
 #
 func = tvm.build(sg, [a, b, g], "cuda")
-dev = tvm.gpu(0)
+dev = tvm.cuda(0)
 a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
 b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
 g_np = np.sum(np.add(a_np + b_np, a_np * b_np) / 2.0)
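
The comparison that follows these lines in the tutorial checks the fused TOPI result against g_np. Continuing directly from the names in the hunk above, a hedged sketch of that check (the zero-initialised output buffer and the tolerance are assumptions):

    a_nd = tvm.nd.array(a_np, dev)
    b_nd = tvm.nd.array(b_np, dev)
    g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g.dtype), dev)
    func(a_nd, b_nd, g_nd)
    tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-5)
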