Posted to commits@tvm.apache.org by tq...@apache.org on 2023/08/25 14:03:38 UTC

[tvm] branch unity updated: [Unity] UpdateVDevice pass and infer vdevice (#15570)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0443482f7e [Unity] UpdateVDevice pass and infer vdevice (#15570)
0443482f7e is described below

commit 0443482f7eb128d4daac9a56b0c1f48d59ac24fb
Author: Yong Wu <yo...@gmail.com>
AuthorDate: Fri Aug 25 07:03:30 2023 -0700

    [Unity] UpdateVDevice pass and infer vdevice (#15570)
---
 include/tvm/relax/struct_info.h                    |   8 +-
 include/tvm/relax/transform.h                      |   8 +
 python/tvm/relax/struct_info.py                    |   2 -
 python/tvm/relax/transform/transform.py            |  20 +-
 src/relax/analysis/struct_info_analysis.cc         |  41 +++--
 src/relax/ir/expr.cc                               |   3 +-
 src/relax/ir/struct_info_functor.cc                |   5 +-
 src/relax/op/image/resize.cc                       |   6 +
 src/relax/op/nn/attention.cc                       |   3 +
 src/relax/op/nn/convolution.cc                     |  28 +++
 src/relax/op/nn/nn.cc                              |  38 +++-
 src/relax/op/nn/pooling.cc                         |  12 ++
 src/relax/op/op_common.h                           |  26 +++
 src/relax/op/tensor/binary.cc                      |  16 ++
 src/relax/op/tensor/create.cc                      |   3 +
 src/relax/op/tensor/index.cc                       |  30 +++
 src/relax/op/tensor/linear_algebra.cc              |  46 +++++
 src/relax/op/tensor/manipulate.cc                  | 203 ++++++++++++++++++++-
 src/relax/op/tensor/search.cc                      |  43 +++++
 src/relax/op/tensor/set.cc                         |  38 +++-
 src/relax/op/tensor/statistical.cc                 |  27 +++
 src/relax/op/tensor/ternary.cc                     |  25 ++-
 src/relax/transform/alter_op_impl.cc               |   4 +
 src/relax/transform/convert_layout.cc              |   5 +-
 src/relax/transform/to_mixed_precision.cc          |   2 +-
 src/relax/transform/update_vdevice.cc              | 114 ++++++++++++
 src/script/ir_builder/ir/ir.cc                     |   2 +-
 .../relax/test_analysis_struct_info_analysis.py    |  60 +++++-
 tests/python/relax/test_op_binary.py               |  32 +++-
 tests/python/relax/test_op_create.py               |  17 +-
 tests/python/relax/test_op_image.py                |   9 +-
 tests/python/relax/test_op_index.py                |  17 +-
 tests/python/relax/test_op_linear_algebra.py       |  16 +-
 tests/python/relax/test_op_manipulate.py           | 105 ++++++++++-
 tests/python/relax/test_op_nn.py                   |  21 ++-
 tests/python/relax/test_op_nn_convolution.py       |  30 ++-
 tests/python/relax/test_op_nn_pooling.py           |  19 +-
 tests/python/relax/test_op_search.py               |  14 +-
 tests/python/relax/test_op_set.py                  |  11 +-
 tests/python/relax/test_op_statistical.py          |  12 +-
 tests/python/relax/test_op_ternary.py              |   9 +-
 tests/python/relax/test_op_unary.py                |   8 +-
 .../python/relax/test_transform_update_vdevice.py  | 128 +++++++++++++
 tests/python/relax/test_tvmscript_parser.py        |   2 +-
 44 files changed, 1202 insertions(+), 66 deletions(-)

diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index deda5d666e..d2bf525225 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -220,9 +220,7 @@ class TensorStructInfo : public StructInfo {
    *
    * \note shape must already be normalized.
    */
-  TVM_DLL TensorStructInfo(Expr shape, DataType dtype,
-                           VDevice vdevice = VDevice(/*tgt*/ {}, /*dev_id*/ 0,
-                                                     /*mem_scope*/ "global"),
+  TVM_DLL TensorStructInfo(Expr shape, DataType dtype, VDevice vdevice = VDevice(),
                            Span span = Span());
 
   /*!
@@ -232,9 +230,7 @@ class TensorStructInfo : public StructInfo {
    * \param vdevice The virtual device.
    * \param span The span of the AST.
    */
-  TVM_DLL TensorStructInfo(DataType dtype, int ndim,
-                           VDevice vdevice = VDevice(/*tgt*/ {}, /*dev_id*/ 0,
-                                                     /*mem_scope*/ "global"),
+  TVM_DLL TensorStructInfo(DataType dtype, int ndim, VDevice vdevice = VDevice(),
                            Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode);
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 05b26f0242..a6ef9ad7ad 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -258,6 +258,14 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_war
  */
 TVM_DLL Pass LiftTransformParams();
 
+/*!
+ * \brief Update virtual device.
+ * \param new_vdevice The new virtual device.
+ * \param index The device index indicating the device on which the update will be performed.
+ * \return The Pass.
+ */
+TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
+
 /*!
  * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
  * \note It is an auto-detect pass for "unscheduled prim_funcs", the op_pattern will be
diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py
index fe30c01d3a..e78e1cf69a 100644
--- a/python/tvm/relax/struct_info.py
+++ b/python/tvm/relax/struct_info.py
@@ -120,8 +120,6 @@ class TensorStructInfo(StructInfo):
     ) -> None:
         if isinstance(shape, (list, tuple, Array)):
             shape = ShapeExpr(shape)
-        if vdevice is None:
-            vdevice = VDevice(None, 0, "global")
         self.__init_handle_by_constructor__(
             _ffi_api.TensorStructInfo, shape, dtype, ndim, vdevice, span  # type: ignore
         )
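
With the Python-side default removed, a TensorStructInfo built without a vdevice now
carries an undefined VDevice instead of a placeholder VDevice(None, 0, "global").
A minimal sketch, assuming a TVM build from the unity branch with this commit:

    from tvm import relax
    from tvm.ir import VDevice

    # No vdevice given: the field stays undefined (None on the Python side).
    plain = relax.TensorStructInfo((2, 3), "float32")
    assert plain.vdevice is None

    # Explicit placement: target "llvm", device id 0, "global" memory scope.
    placed = relax.TensorStructInfo((2, 3), "float32", vdevice=VDevice("llvm", 0, "global"))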
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 438a6d1213..6c08a6fe68 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -249,7 +249,7 @@ def RemovePurityChecking() -> tvm.ir.transform.Pass:
     return _ffi_api.RemovePurityChecking()  # type: ignore
 
 
-def LambdaLift():
+def LambdaLift() -> tvm.ir.transform.Pass:
     """A pass that lifts local functions into global.
 
     Returns
@@ -312,6 +312,24 @@ def EliminateCommonSubexpr(call_only=False) -> FunctionPass:
     return _ffi_api.EliminateCommonSubexpr(call_only)  # type: ignore
 
 
+def UpdateVDevice(new_vdevice: tvm.ir.VDevice, index: int) -> tvm.ir.transform.Pass:
+    """Update virtual device.
+
+    Parameters
+    ----------
+    new_vdevice : tvm.ir.VDevice
+        The new virtual device.
+    index : int
+        The device index indicating the device on which the update will be performed.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+        The registered pass that modifies the virtual device.
+    """
+    return _ffi_api.UpdateVDevice(new_vdevice, index)  # type: ignore
+
+
 def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
     """Convert all reshape-like call_tir to VM reshape operator call.
     The VM reshape operator calls will be further lowered to a CreateView
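
UpdateVDevice replaces the vdevice registered at the given index of the module's
"vdevice" global_infos and updates the tensors placed on it. A rough usage sketch
modeled on the conventions of the new test_transform_update_vdevice.py (assuming a
unity build; the exact TVMScript spelling of the annotations may differ):

    import tvm
    from tvm import relax
    from tvm.ir import VDevice
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        I.module_global_infos({"vdevice": [VDevice("llvm")]})

        @R.function
        def main(x: R.Tensor((2, 3), "float32", "llvm")) -> R.Tensor((2, 3), "float32", "llvm"):
            y: R.Tensor((2, 3), "float32", "llvm") = R.add(x, x)
            return y

    # Retarget the vdevice at index 0 (and every tensor placed on it) to CUDA.
    mod = relax.transform.UpdateVDevice(VDevice("cuda"), 0)(Module)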
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 9fae776279..4a633e9df4 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -149,10 +149,7 @@ class WellDefinedEraser : public StructInfoMutator,
       std::swap(has_undefined_, has_undefined);
     }
 
-    VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
-    if (op->vdevice.defined()) {
-      vdev = op->vdevice.value();
-    }
+    VDevice vdev = op->vdevice.value_or(VDevice());
 
     // erase symbolic shape if we have undefined.
     if (!has_undefined) {
@@ -346,6 +343,22 @@ class StructInfoBaseChecker
       if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1;
       return BaseCheckResult::kFailL0;
     }
+
+    // vdevice mismatch
+    if (lhs->vdevice.defined() && !rhs->vdevice.defined()) return BaseCheckResult::kFailL1;
+    if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
+      VDevice lhs_vdevice = lhs->vdevice.value();
+      VDevice rhs_vdevice = rhs->vdevice.value();
+      if (lhs_vdevice->target.defined() && !rhs_vdevice->target.defined())
+        return BaseCheckResult::kFailL1;
+      // mismatch in either the target, vdevice_id, or memory_scope
+      if ((lhs_vdevice->target.defined() && rhs_vdevice->target.defined()) &&
+          (lhs_vdevice->target != rhs_vdevice->target ||
+           lhs_vdevice->vdevice_id != rhs_vdevice->vdevice_id ||
+           lhs_vdevice->memory_scope != rhs_vdevice->memory_scope))
+        return BaseCheckResult::kFailL0;
+    }
+
     // lhs does not have defined shape and everything else matches
     if (!lhs->shape.defined()) return BaseCheckResult::kPass;
     // rhs does not have symbolic value but lhs don't
@@ -769,32 +782,28 @@ class StructInfoLCAFinder
     auto* rhs = other.as<TensorStructInfoNode>();
     if (rhs == nullptr) return ObjectStructInfo(lhs->span);
 
-    // find the target dtype and ndim.
+    // find the target dtype, ndim, and vdevice.
     DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void();
     int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim;
-    VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
-    if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
-      if (lhs->vdevice.value().same_as(lhs->vdevice.value())) {
-        vdev = lhs->vdevice.value();
-      }
-    } else if (lhs->vdevice.defined()) {
+    VDevice vdev = VDevice();
+    if (lhs->vdevice.defined() && rhs->vdevice.defined() &&
+        lhs->vdevice.value() == rhs->vdevice.value()) {
       vdev = lhs->vdevice.value();
-    } else if (rhs->vdevice.defined()) {
-      vdev = rhs->vdevice.value();
     }
     // if ndim mismatch or one side of shape is missing
     // then we cannot keep in symbolic shape
     if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() ||
         !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), analyzer_)) {
       // reuse lhs when possible
-      if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim) {
+      if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim &&
+          (!lhs->vdevice.defined() || vdev.defined())) {
         return GetRef<StructInfo>(lhs);
       } else {
         return TensorStructInfo(dtype, ndim, vdev, lhs->span);
       }
     }
-    // symbolic shape match but dtype mismatch
-    if (lhs->dtype != dtype) {
+    // symbolic shape and vdevice match but dtype mismatch
+    if (lhs->dtype != dtype || (lhs->vdevice.defined() && !vdev.defined())) {
       return TensorStructInfo(lhs->shape.value(), dtype, vdev, lhs->span);
     } else {
       return GetRef<StructInfo>(lhs);
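
The LCA therefore keeps a vdevice only when both sides agree. A hypothetical Python
paraphrase of the rule added to StructInfoLCAFinder (illustration only, not TVM API):

    def lca_vdevice(lhs_vdev, rhs_vdev):
        # Keep the vdevice only when both sides define one and they are equal;
        # a one-sided or conflicting annotation degrades to undefined VDevice().
        if lhs_vdev is not None and rhs_vdev is not None and lhs_vdev == rhs_vdev:
            return lhs_vdev
        return None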
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 3808325670..ac04096aaf 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -292,8 +292,7 @@ Constant::Constant(runtime::NDArray data, Optional<StructInfo> struct_info_annot
     n->struct_info_ = struct_info_annotation.value();
     n->checked_type_ = GetStaticType(struct_info_annotation.value());
   } else {
-    TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(),
-                           VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global"), span);
+    TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span);
     n->struct_info_ = tinfo;
     n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype);
   }
diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc
index 10babe4b06..d7929e0f1a 100644
--- a/src/relax/ir/struct_info_functor.cc
+++ b/src/relax/ir/struct_info_functor.cc
@@ -94,10 +94,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) {
     shape = this->VisitStructInfoExprField(op->shape.value());
   }
 
-  VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
-  if (op->vdevice.defined()) {
-    vdev = op->vdevice.value();
-  }
+  VDevice vdev = op->vdevice.value_or(VDevice());
 
   if (shape.same_as(op->shape)) {
     return GetRef<StructInfo>(op);
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 3c3bb15136..3a4cb26861 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -90,6 +90,9 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
   Optional<ShapeExpr> data_shape =
       CheckNdimPerLayoutAndGetShape(call, ctx, GetRef<TensorStructInfo>(data_sinfo), data_layout);
   if (!data_shape.defined() || size_value == nullptr) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(out_dtype, data_layout.ndim());
   }
 
@@ -99,6 +102,9 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
   out_NCHW_shape.Set(3, size_value->values[1]);
 
   Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
 }
 
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 55757552db..4f37e3a33c 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -108,6 +108,9 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
   }
 
   Array<PrimExpr> output_shape = {num_batches, num_queries, num_heads, head_dim_value};
+  if (q_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype, q_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype);
 }
 
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 96fc7e8464..e8cb1916e8 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -77,7 +77,11 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
   DataType out_dtype = attrs->out_dtype.is_void()
                            ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
                            : attrs->out_dtype;
+  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
   if (!data_shape.defined() || !weight_shape.defined()) {
+    if (vdevice.defined()) {
+      return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice.value());
+    }
     return TensorStructInfo(out_dtype, out_layout.ndim());
   }
 
@@ -121,6 +125,9 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
   out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
 
   Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+  if (vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
 }
 
@@ -239,7 +246,11 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
   DataType out_dtype = attrs->out_dtype.is_void()
                            ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
                            : attrs->out_dtype;
+  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
   if (!data_shape.defined() || !weight_shape.defined()) {
+    if (vdevice.defined()) {
+      return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice.value());
+    }
     return TensorStructInfo(out_dtype, out_layout.ndim());
   }
 
@@ -288,6 +299,9 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
   out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
 
   Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  if (vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
 }
 
@@ -411,7 +425,11 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
   DataType out_dtype = attrs->out_dtype.is_void()
                            ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
                            : attrs->out_dtype;
+  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
   if (!data_shape.defined() || !weight_shape.defined()) {
+    if (vdevice.defined()) {
+      return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice.value());
+    }
     return TensorStructInfo(out_dtype, out_layout.ndim());
   }
 
@@ -465,6 +483,9 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
   out_NCW_shape[2] = analyzer->Simplify(out_w);
 
   Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+  if (vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
 }
 
@@ -548,7 +569,11 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
   DataType out_dtype = attrs->out_dtype.is_void()
                            ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
                            : attrs->out_dtype;
+  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
   if (!data_shape.defined() || !weight_shape.defined()) {
+    if (vdevice.defined()) {
+      return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice.value());
+    }
     return TensorStructInfo(out_dtype, out_layout.ndim());
   }
 
@@ -610,6 +635,9 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
   out_NCHW_shape[3] = analyzer->Simplify(out_w);
 
   Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  if (vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
 }
 
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 2e73297e2d..f95cc9f4d6 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -266,6 +266,12 @@ StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) {
 
   DataType dtype = input_sinfo[0]->dtype;
   if (unknown_shape) {
+    if (input_sinfo[0]->vdevice.defined()) {
+      VDevice vdev = input_sinfo[0]->vdevice.value();
+      return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim, vdev),
+                              TensorStructInfo(dtype, /*ndim=*/1, vdev),
+                              TensorStructInfo(dtype, /*ndim=*/1, vdev)});
+    }
     return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim),
                             TensorStructInfo(dtype, /*ndim=*/1),
                             TensorStructInfo(dtype, /*ndim=*/1)});
@@ -331,6 +337,11 @@ StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) {
   const auto* attrs = call->attrs.as<LayerNormAttrs>();
   bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
 
+  if (input_sinfo[0]->vdevice.defined()) {
+    return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim,
+                                            input_sinfo[0]->vdevice.value())
+                         : input_sinfo[0];
+  }
   return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
                        : input_sinfo[0];
 }
@@ -503,6 +514,11 @@ StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) {
   const auto* attrs = call->attrs.as<RMSNormAttrs>();
   bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
 
+  if (input_sinfo[0]->vdevice.defined()) {
+    return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim,
+                                            input_sinfo[0]->vdevice.value())
+                         : input_sinfo[0];
+  }
   return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
                        : input_sinfo[0];
 }
@@ -578,6 +594,9 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx
   // infer dtype
   DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo);
 
+  // infer vdevice
+  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_sinfo, label_sinfo);
+
   // infer ndim
   if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() &&
       pred_sinfo->ndim != label_sinfo->ndim) {
@@ -610,6 +629,9 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx
       }
     }
   }
+  if (vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(Array<PrimExpr>()), dtype, vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(Array<PrimExpr>()), dtype);
 }
 
@@ -685,13 +707,17 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) {
         << call->args[1]->struct_info_->GetTypeKey());
   }
 
-  // infer dtype
+  // infer dtype, vdevice
   DataType output_dtype;
+  Optional<VDevice> vdevice;
   if (wgt_sinfo != nullptr) {
     output_dtype = InferBinaryArithOpOutDtype(call, ctx, GetRef<TensorStructInfo>(pred_sinfo),
                                               GetRef<TensorStructInfo>(wgt_sinfo));
+    vdevice = InferBinaryArithOpOutVDevice(call, ctx, GetRef<TensorStructInfo>(pred_sinfo),
+                                           GetRef<TensorStructInfo>(wgt_sinfo));
   } else {
     output_dtype = pred_sinfo->dtype;
+    vdevice = pred_sinfo->vdevice;
   }
 
   // the type of targets must be int/uint.
@@ -834,13 +860,23 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) {
   if (reduction == "none") {
     // () or (N,) or (N, d1, d2, ..., dk)
     if (pred_sinfo->shape.as<ShapeExprNode>()) {
+      if (vdevice.defined()) {
+        return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdevice.value());
+      }
       return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
     } else {
       int output_ndim = pred_sinfo->ndim == kUnknownNDim ? kUnknownNDim : pred_sinfo->ndim - 1;
+      if (vdevice.defined()) {
+        return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice.value());
+      }
       return TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
     }
   } else {
     // sum or mean. output is scalar
+    if (vdevice.defined()) {
+      return TensorStructInfo(/*shape=*/ShapeExpr(Array<PrimExpr>()), output_dtype,
+                              vdevice.value());
+    }
     return TensorStructInfo(/*shape=*/ShapeExpr(Array<PrimExpr>()), output_dtype);
   }
 }
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index bfbb4b4284..c26fae08c2 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -86,6 +86,9 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
   Optional<ShapeExpr> data_shape =
       CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
   if (!data_shape.defined()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, out_layout.ndim());
   }
 
@@ -114,6 +117,9 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
   out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
 
   Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
 }
 
@@ -204,6 +210,9 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder
         !attrs->output_size.defined()) {
       return data_sinfo;
     } else {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, out_layout.ndim());
     }
   }
@@ -216,6 +225,9 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder
   }
 
   Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
 }
 
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 0e9eee4f01..dd4c3ac173 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -201,6 +201,32 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder&
   return x1_sinfo->dtype;
 }
 
+/*!
+ * \brief Infer the output virtual device for binary arithmetic operators.
+ * \param call The context Call to the operator.
+ * \param ctx The error reporting context.
+ * \param x1_sinfo The struct info of the first operand.
+ * \param x2_sinfo The struct info of the second operand.
+ * \return The inferred output vdevice.
+ * \throw Throw exception if the vdevices of the two input TensorStructInfo don't match.
+ */
+inline Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx,
+                                                      const TensorStructInfo& x1_sinfo,
+                                                      const TensorStructInfo& x2_sinfo) {
+  if (!x1_sinfo->vdevice.defined() || !x1_sinfo->vdevice.value()->target.defined()) {
+    return x2_sinfo->vdevice;
+  }
+  if (!x2_sinfo->vdevice.defined() || !x2_sinfo->vdevice.value()->target.defined()) {
+    return x1_sinfo->vdevice;
+  }
+  if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "VDevice " << x1_sinfo->vdevice.value() << " and "
+                     << x2_sinfo->vdevice.value() << " must be equal for binary operators");
+  }
+  return x1_sinfo->vdevice;
+}
+
 /*!
  * \brief Infer the output shape for binary broadcast operators.
  * \param call The context Call to the operator.
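
InferBinaryArithOpOutVDevice is the shared helper behind the vdevice inference in the
convolution, nn, and binary operators elsewhere in this commit. A hypothetical Python
paraphrase (None stands for an undefined Optional<VDevice>; a vdevice whose target is
unset counts as unconstrained):

    def infer_binary_out_vdevice(x1_vdev, x2_vdev):
        # An absent or target-less vdevice defers to the other operand.
        if x1_vdev is None or x1_vdev.target is None:
            return x2_vdev
        if x2_vdev is None or x2_vdev.target is None:
            return x1_vdev
        # Both concrete: they must match, otherwise report a fatal diagnostic.
        if x1_vdev != x2_vdev:
            raise ValueError("VDevice mismatch in binary operator")
        return x1_vdev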
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 6483806182..87afc24397 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -39,6 +39,9 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx,
   // DateType
   DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo);
 
+  // VDevice
+  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, x1_sinfo, x2_sinfo);
+
   // ndims
   int output_ndim;
   if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
@@ -55,14 +58,27 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx,
     Optional<Array<PrimExpr>> output_shape =
         InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values);
     if (!output_shape.defined()) {
+      if (vdevice.defined()) {
+        return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice.value());
+      }
       return TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
+
     } else {
       ICHECK_EQ(static_cast<int>(output_shape.value().size()), output_ndim);
+      if (vdevice.defined()) {
+        return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdevice.value());
+      }
       return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
     }
   } else if (x1_sinfo->shape.defined() && x1_sinfo->shape.same_as(x2_sinfo->shape)) {
+    if (vdevice.defined()) {
+      return TensorStructInfo(x1_sinfo->shape.value(), output_dtype, vdevice.value());
+    }
     return TensorStructInfo(x1_sinfo->shape.value(), output_dtype);
   } else {
+    if (vdevice.defined()) {
+      return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice.value());
+    }
     return TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
   }
 }
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index dabf3155f0..3a4de79b11 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -77,6 +77,9 @@ StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) {
 
   const auto* attrs = call->attrs.as<InitAttrs>();
   DataType out_dtype = attrs->dtype.is_void() ? fill_value_sinfo->dtype : attrs->dtype;
+  if (fill_value_sinfo->vdevice.defined()) {
+    return TensorStructInfo(/*shape=*/call->args[0], out_dtype, fill_value_sinfo->vdevice.value());
+  }
   return TensorStructInfo(/*shape=*/call->args[0], out_dtype);
 }
 
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 5f1d5149b3..6d9dfc86ba 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -66,6 +66,9 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
                      << data_sinfo->ndim);
   }
   if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
   }
 
@@ -75,6 +78,10 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
   if (data_shape == nullptr || indices_shape == nullptr) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1,
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1);
   }
 
@@ -87,6 +94,10 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
       output_shape.push_back(data_shape->values[i]);
     }
   }
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+                            data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
 
@@ -180,12 +191,18 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx
   }
 
   if (data_sinfo->IsUnknownNdim()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
   }
 
   std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes);
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   if (data_shape == nullptr) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
   }
 
@@ -199,6 +216,9 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx
   for (int i = 0; i < n_axis; ++i) {
     const auto* int_stride = strides[i].as<IntImmNode>();
     if (!int_stride) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
     }
     int_strides.push_back(int_stride->value);
@@ -211,6 +231,10 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx
     output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i], int_strides[i],
                                         data_shape->values[axes[i]], attrs->assume_inbound));
   }
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+                            data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
 
@@ -265,6 +289,9 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder&
     LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes begin/end/strides "
                     "tensors are well-formed. It could produce runtime error when this assumption "
                     "turns out to be wrong.";
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
   }
   if (data_sinfo->IsUnknownDtype()) {
@@ -305,6 +332,9 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder&
   // The output shape will depend on the runtime value in begin/end/strides tensors.
   // TODO(tvm-team): Currently, it is unable to express partially-static shape. Revisit when
   // PrimValue lands.
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(data_sinfo->dtype, n_axis, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(data_sinfo->dtype, n_axis);
 }  // namespace relax
 
diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc
index b05fbaa5d3..62ff189577 100644
--- a/src/relax/op/tensor/linear_algebra.cc
+++ b/src/relax/op/tensor/linear_algebra.cc
@@ -51,12 +51,26 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo x1_sinfo = input_sinfo[0];
   TensorStructInfo x2_sinfo = input_sinfo[1];
 
+  VDevice vdev = VDevice();
+  if (x1_sinfo->vdevice.defined() && x2_sinfo->vdevice.defined()) {
+    if (x1_sinfo->vdevice.value() == x2_sinfo->vdevice.value()) {
+      vdev = x1_sinfo->vdevice.value();
+    }
+  } else if (x1_sinfo->vdevice.defined()) {
+    vdev = x1_sinfo->vdevice.value();
+  } else if (x2_sinfo->vdevice.defined()) {
+    vdev = x2_sinfo->vdevice.value();
+  }
+
   const auto* attrs = call->attrs.as<MatmulAttrs>();
   DataType out_dtype = attrs->out_dtype.is_void()
                            ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo)
                            : attrs->out_dtype;
 
   if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
+    if (vdev.defined()) {
+      return TensorStructInfo(out_dtype, kUnknownNDim, vdev);
+    }
     return TensorStructInfo(out_dtype, kUnknownNDim);
   }
   int x1_ndim = x1_sinfo->ndim;
@@ -82,6 +96,9 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
   const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
   const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
   if (x1_shape == nullptr || x2_shape == nullptr) {
+    if (vdev.defined()) {
+      return TensorStructInfo(out_dtype, output_ndim, vdev);
+    }
     return TensorStructInfo(out_dtype, output_ndim);
   }
 
@@ -92,6 +109,9 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
   Optional<Array<PrimExpr>> output_shape_prefix =
       InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix);
   if (!output_shape_prefix.defined()) {
+    if (vdev.defined()) {
+      return TensorStructInfo(out_dtype, output_ndim, vdev);
+    }
     return TensorStructInfo(out_dtype, output_ndim);
   }
 
@@ -113,6 +133,9 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
     output_shape.push_back(x2_shape->values[x2_ndim - 1]);
   }
   ICHECK_EQ(static_cast<int>(output_shape.size()), output_ndim);
+  if (vdev.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), out_dtype, vdev);
+  }
   return TensorStructInfo(ShapeExpr(output_shape), out_dtype);
 }
 
@@ -156,6 +179,23 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) {
 
   const auto* attrs = call->attrs.as<EinsumAttrs>();
 
+  bool vdevice_unknown = false;
+  VDevice vdev = VDevice();
+  for (TensorStructInfo sinfo : operands_tensor_sinfo) {
+    if (!vdevice_unknown) {
+      if (sinfo->vdevice.defined()) {
+        if (!vdev.defined()) {
+          vdev = sinfo->vdevice.value();
+        } else if (sinfo->vdevice.value()->target.defined()) {
+          // mismatch
+          if (sinfo->vdevice.value() != vdev) {
+            vdevice_unknown = true;
+          }
+        }
+      }
+    }
+  }
+
   String subscripts = attrs->subscripts;
 
   DataType operand_dtype = operands_tensor_sinfo[0]->dtype;
@@ -176,12 +216,18 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) {
     if (shape_expr != nullptr) {
       input_shapes.push_back(shape_expr->values);
     } else {
+      if (!vdevice_unknown) {
+        return TensorStructInfo(operand_dtype, tensor_sinfo->ndim, vdev);
+      }
       return TensorStructInfo(operand_dtype, tensor_sinfo->ndim);
     }
   }
   // Calculate output shape using InferEinsumShape in topi
   Array<PrimExpr> oshape = topi::InferEinsumShape(subscripts, input_shapes);
 
+  if (!vdevice_unknown) {
+    return TensorStructInfo(ShapeExpr(oshape), operand_dtype, vdev);
+  }
   return TensorStructInfo(ShapeExpr(oshape), operand_dtype);
 }
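
Einsum folds the vdevice across all of its operands rather than a fixed pair; the
concat inference in manipulate.cc below uses the same scheme. A hypothetical Python
paraphrase of the accumulation:

    def fold_vdevice(operand_vdevs):
        # Adopt the first defined vdevice; a later operand with a concrete
        # target that disagrees marks the result's vdevice as unknown.
        vdev, unknown = None, False
        for v in operand_vdevs:
            if unknown or v is None:
                continue
            if vdev is None:
                vdev = v
            elif v.target is not None and v != vdev:
                unknown = True
        return None if unknown else vdev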
 
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 24d96af3bb..38b761d04f 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -71,10 +71,18 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx)
 
   // Trust the input target shape when there is no possibility to do any compile-time check.
   if (!data_sinfo->shape.defined()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype,
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
   }
   ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(data_sinfo->shape.value()->struct_info_);
   if (!shape_sinfo->values.defined() || !tgt_shape_sinfo->values.defined()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype,
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
   }
 
@@ -100,6 +108,10 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx)
     // Todo(relax-team): revisit here for better check on if the tensor length
     // is consistent with the length in the given shape.
   }
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype,
+                            data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
 }
 
@@ -190,8 +202,10 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
   const auto* attrs = call->attrs.as<ConcatAttrs>();
   int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1;
   DataType output_dtype = DataType::Void();
+  VDevice vdev = VDevice();
   bool shape_unknown = false;
   bool is_void_dtype = false;
+  bool vdevice_unknown = false;
   std::vector<Array<PrimExpr>> shape_values;
   shape_values.reserve(tensor_sinfo.size());
 
@@ -220,6 +234,20 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
                        << output_ndim << " and " << sinfo->ndim);
     }
 
+    // Update the virtual device.
+    if (!vdevice_unknown) {
+      if (sinfo->vdevice.defined()) {
+        if (!vdev.defined()) {
+          vdev = sinfo->vdevice.value();
+        } else if (sinfo->vdevice.value()->target.defined()) {
+          // mismatch
+          if (sinfo->vdevice.value() != vdev) {
+            vdevice_unknown = true;
+          }
+        }
+      }
+    }
+
     // Update the shape values for best effort check.
     const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
     if (shape_expr != nullptr) {
@@ -242,6 +270,10 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
     output_dtype = DataType::Void();
   }
   if (output_ndim == kUnknownNDim) {
+    if (!vdevice_unknown) {
+      return tensor_sinfo.size() == 1 ? tensor_sinfo[0]
+                                      : TensorStructInfo(output_dtype, output_ndim, vdev);
+    }
     return tensor_sinfo.size() == 1 ? tensor_sinfo[0] : TensorStructInfo(output_dtype, output_ndim);
   }
 
@@ -252,6 +284,9 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
     return tensor_sinfo[0];
   }
   if (shape_values.empty()) {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(output_dtype, output_ndim, vdev);
+    }
     return TensorStructInfo(output_dtype, output_ndim);
   }
 
@@ -259,8 +294,14 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
   Optional<Array<PrimExpr>> output_shape = CheckConcatOutputShape(call, ctx, shape_values, axis);
 
   if (shape_unknown || !output_shape.defined()) {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(output_dtype, output_ndim, vdev);
+    }
     return TensorStructInfo(output_dtype, output_ndim);
   } else {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev);
+    }
     return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
   }
 }
@@ -318,6 +359,9 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx)
   }
 
   if (data_sinfo->IsUnknownNdim()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
   }
 
@@ -327,6 +371,9 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx)
 
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   if (data_shape == nullptr) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, output_ndim);
   }
 
@@ -346,6 +393,10 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx)
     ++i_data_shape;
   }
   ICHECK_EQ(i_data_shape, data_sinfo->ndim);
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+                            data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
 
@@ -415,8 +466,14 @@ TVM_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten);
 StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
   if (data_sinfo->IsUnknownNdim()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1);
   } else if (data_sinfo->ndim == 0) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(ShapeExpr({1}), data_sinfo->dtype, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(ShapeExpr({1}), data_sinfo->dtype);
   } else if (data_sinfo->ndim == 1) {
     return data_sinfo;
@@ -424,9 +481,16 @@ StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) {
 
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   if (data_shape == nullptr) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1);
   }
   PrimExpr shape_prod = ComputeShapeProduct(data_shape->values);
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr({std::move(shape_prod)}), data_sinfo->dtype,
+                            data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr({std::move(shape_prod)}), data_sinfo->dtype);
 }
 
@@ -471,6 +535,10 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
 
   if (data_sinfo->IsUnknownNdim()) {
     // Todo(relax-team): revisit here for better check on if the input tensor has desired ndim.
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(),
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
   }
 
@@ -483,16 +551,28 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
   }
 
   if (!data_sinfo->shape.defined()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(),
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
   }
 
   ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(data_sinfo->shape.value()->struct_info_);
   if (!shape_sinfo->values.defined()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(),
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
   }
 
   arith::Analyzer analyzer;
   Array<PrimExpr> output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer);
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+                            data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
 
@@ -534,6 +614,9 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx)
   // Todo(relax-team): revisit here for better check on if the input tensor has
   // ndim same as the number of input axes.
   if (!attrs->axes.defined() && data_sinfo->IsUnknownNdim()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
   }
 
@@ -561,6 +644,9 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx)
 
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   if (data_shape == nullptr) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
   }
   std::vector<PrimExpr> new_shape;
@@ -568,6 +654,9 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx)
   for (int i = 0; i < data_sinfo->ndim; ++i) {
     new_shape.push_back(data_shape->values[axes[i]]);
   }
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype);
 }
 
@@ -759,8 +848,15 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) {
   Expr target_shape = call->args[1];
   // If shape values are defined, use them
   if (target_shape->IsInstance<VarNode>() && new_shape_sinfo->values.defined()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype,
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype);
   }
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(target_shape, data_sinfo->dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(target_shape, data_sinfo->dtype);
 }
 
@@ -818,6 +914,11 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
     }
     // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape.
     if (data_shape == nullptr) {
+      if (data_sinfo->vdevice.defined()) {
+        return TupleStructInfo(Array<StructInfo>(
+            p_indices->size() + 1,
+            TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value())));
+      }
       return TupleStructInfo(Array<StructInfo>(
           p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)));
     }
@@ -826,6 +927,11 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
     const auto* axis_length = data_shape->values[axis].as<IntImmNode>();
     // Fall back to unknown shape when the input tensor shape at the given axis is symbolic.
     if (axis_length == nullptr) {
+      if (data_sinfo->vdevice.defined()) {
+        return TupleStructInfo(Array<StructInfo>(
+            p_indices->size() + 1,
+            TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value())));
+      }
       return TupleStructInfo(Array<StructInfo>(
           p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)));
     }
@@ -844,7 +950,12 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
 
       Array<PrimExpr> shape = data_shape->values;
       shape.Set(axis, tvm::max(zero, r - l));
-      output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+      if (data_sinfo->vdevice.defined()) {
+        output_sinfo.push_back(
+            TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice.value()));
+      } else {
+        output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+      }
     }
     return TupleStructInfo(output_sinfo);
   } else if (const auto* p_n_section = attrs->indices_or_sections.as<IntImmNode>()) {
@@ -856,6 +967,11 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
     }
     // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape.
     if (data_shape == nullptr) {
+      if (data_sinfo->vdevice.defined()) {
+        return TupleStructInfo(Array<StructInfo>(
+            n_section,
+            TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value())));
+      }
       return TupleStructInfo(
           Array<StructInfo>(n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)));
     }
@@ -865,12 +981,22 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
     // Construct struct info for tensors except the last one.
     Array<PrimExpr> shape = data_shape->values;
     shape.Set(axis, split_len);
-    std::vector<StructInfo> output_sinfo(n_section - 1,
-                                         TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+    std::vector<StructInfo> output_sinfo;
+    if (data_sinfo->vdevice.defined()) {
+      // Preserve the input vdevice on the first (n_section - 1) outputs.
+      output_sinfo.assign(n_section - 1,
+                          TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype,
+                                           data_sinfo->vdevice.value()));
+    } else {
+      output_sinfo.assign(n_section - 1,
+                          TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+    }
 
     // Construct struct info for the last tensor.
     shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1));
-    output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+    if (data_sinfo->vdevice.defined()) {
+      output_sinfo.push_back(
+          TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice.value()));
+    } else {
+      output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+    }
     return TupleStructInfo(output_sinfo);
   }
   ICHECK(false) << "Cannot reach here.";
@@ -928,6 +1054,9 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
   }
 
   if (data_sinfo->IsUnknownNdim()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
   }
 
@@ -943,6 +1072,10 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
     std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
 
     if (!shape_value.defined()) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size(),
+                                data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size());
     }
     for (int i = 0; i < static_cast<int>(axes.size()); ++i) {
@@ -965,12 +1098,18 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
     // (https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html).
     // Consider discourage usage later.
     if (!shape_value.defined()) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
     }
     for (int i = 0; i < data_sinfo->ndim; ++i) {
       // Whenever a dimension length is symbolic, fall back to unknown ndim.
       const auto* int_len = shape_value.value()[i].as<IntImmNode>();
       if (int_len == nullptr) {
+        if (data_sinfo->vdevice.defined()) {
+          return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+        }
         return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
       }
       if (int_len->value == 1) {
@@ -991,11 +1130,22 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
     if (static_cast<int>(output_shape.size()) == data_sinfo->ndim) {
       return data_sinfo;
     } else if (attrs->axis.defined()) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, output_shape.size(),
+                                data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, output_shape.size());
     } else {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
     }
   } else {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
   }
 }
@@ -1132,8 +1282,16 @@ StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder&
   }
 
   if (collapse_target_sinfo->shape.defined()) {
+    if (collapse_target_sinfo->vdevice.defined()) {
+      return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype,
+                              collapse_target_sinfo->vdevice.value());
+    }
     return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype);
   } else {
+    if (collapse_target_sinfo->vdevice.defined()) {
+      return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim,
+                              collapse_target_sinfo->vdevice.value());
+    }
     return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim);
   }
 }
@@ -1185,7 +1343,9 @@ StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ct
   if (data_shape_value.defined() && shape_sinfo->values.defined()) {
     CheckCollapseShape(call, ctx, data_shape_value.value(), shape_sinfo->values.value());
   }
-
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(/*shape=*/call->args[1], output_dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(/*shape=*/call->args[1], output_dtype);
 }
 
@@ -1234,9 +1394,15 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
        // the shape does not change
         return data_sinfo;
       } else {
+        if (data_sinfo->vdevice.defined()) {
+          return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+        }
         return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
       }
     } else {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, 1, data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, 1);
     }
   }
@@ -1244,12 +1410,19 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
   if (!attrs->axis.defined()) {
     PrimExpr new_shape =
         analyzer->Simplify(ComputeShapeProduct(data_shape->values) * attrs->repeats);
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(ShapeExpr(Array<PrimExpr>({new_shape})), data_sinfo->dtype,
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(ShapeExpr(Array<PrimExpr>({new_shape})), data_sinfo->dtype);
   }
 
   int axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value);
   auto shape_array = data_shape->values;
   shape_array.Set(axis, analyzer->Simplify(shape_array[axis] * attrs->repeats));
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype);
 }
 
@@ -1284,13 +1457,23 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) {
 
   if (data_shape == nullptr) {
     if (data_sinfo->IsUnknownNdim()) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
     }
     if (l > ndim) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(data_sinfo->dtype, l, data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(data_sinfo->dtype, l);
     } else {
       for (auto i : attrs->repeats) {
         if (!analyzer->CanProveEqual(i, 1)) {
+          if (data_sinfo->vdevice.defined()) {
+            return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim,
+                                    data_sinfo->vdevice.value());
+          }
           return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
         }
       }
@@ -1313,6 +1496,10 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) {
           analyzer->Simplify(data_shape->values[i - ndim_delta] * attrs->repeats[i - l_delta]));
     }
   }
+
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
 }
 
@@ -1395,6 +1582,9 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
   if (data_sinfo->IsUnknownNdim()) {
    // When `data` has unknown rank, assume the rest of the arguments are correct and proceed.
    // If the assumption turns out to be wrong, a runtime error will be triggered.
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
   }
 
@@ -1461,8 +1651,15 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
   }
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   if (data_shape) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype,
+                              data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype);
   }
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
 }
 
diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc
index e1d684916c..14fa287494 100644
--- a/src/relax/op/tensor/search.cc
+++ b/src/relax/op/tensor/search.cc
@@ -44,6 +44,21 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo x1_sinfo = input_sinfo[1];
   TensorStructInfo x2_sinfo = input_sinfo[2];
 
+  VDevice vdev = VDevice();
+  for (int i = 0; i < 3; ++i) {
+    if (input_sinfo[i]->vdevice.defined()) {
+      if (!vdev.defined()) {
+        vdev = input_sinfo[i]->vdevice.value();
+      } else if (input_sinfo[i]->vdevice.value()->target.defined()) {
+        // mismatch
+        if (input_sinfo[i]->vdevice.value() != vdev) {
+          vdev = VDevice();
+          break;
+        }
+      }
+    }
+  }
+
   if (!cond_sinfo->dtype.is_bool()) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Where requires the input condition tensor to have boolean dtype. However, "
@@ -67,23 +82,38 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) {
     Optional<Array<PrimExpr>> broadcasted_shape =
         InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values);
     if (!broadcasted_shape.defined()) {
+      if (vdev.defined()) {
+        return TensorStructInfo(output_dtype, output_ndim, vdev);
+      }
       return TensorStructInfo(output_dtype, output_ndim);
     }
     // Step 2. Broadcast cond's shape with the previously computed broadcasted shape.
     broadcasted_shape =
         InferBinaryBroadcastShape(call, ctx, cond_shape->values, broadcasted_shape.value());
     if (!broadcasted_shape.defined()) {
+      if (vdev.defined()) {
+        return TensorStructInfo(output_dtype, output_ndim, vdev);
+      }
       return TensorStructInfo(output_dtype, output_ndim);
     }
     ICHECK_EQ(static_cast<int>(broadcasted_shape.value().size()), output_ndim);
+    if (vdev.defined()) {
+      return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype, vdev);
+    }
     return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype);
   } else if (cond_sinfo->shape.defined() &&                 //
              x1_sinfo->shape.defined() &&                   //
              x2_sinfo->shape.defined() &&                   //
              cond_sinfo->shape.same_as(x1_sinfo->shape) &&  //
              cond_sinfo->shape.same_as(x2_sinfo->shape)) {
+    if (vdev.defined()) {
+      return TensorStructInfo(cond_sinfo->shape.value(), output_dtype, vdev);
+    }
     return TensorStructInfo(cond_sinfo->shape.value(), output_dtype);
   } else {
+    if (vdev.defined()) {
+      return TensorStructInfo(output_dtype, output_ndim, vdev);
+    }
     return TensorStructInfo(output_dtype, output_ndim);
   }
 }
@@ -131,9 +161,19 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   if (data_shape == nullptr) {
     if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(
+            ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(out_dtype, /*value=*/1))), out_dtype,
+            data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(out_dtype, /*value=*/1))),
                               out_dtype);
     } else {
+      if (data_sinfo->vdevice.defined()) {
+        return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), out_dtype,
+                                                data_sinfo->vdevice.value())
+                             : TensorStructInfo(out_dtype, out_ndim, data_sinfo->vdevice.value());
+      }
       return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), out_dtype)
                            : TensorStructInfo(out_dtype, out_ndim);
     }
@@ -153,6 +193,9 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx
     }
   }
   ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
 }
 
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index cb6a332d49..3920cccadd 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -86,21 +86,45 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
 
   // unique values
   if (data_sinfo->ndim == 0) {
-    output_sinfo.push_back(
-        TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype));
+    if (data_sinfo->vdevice.defined()) {
+      output_sinfo.push_back(TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+                                              data_sinfo->dtype, data_sinfo->vdevice.value()));
+    } else {
+      output_sinfo.push_back(
+          TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype));
+    }
   } else if (axis.defined()) {
-    output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim));
+    if (data_sinfo->vdevice.defined()) {
+      output_sinfo.push_back(
+          TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value()));
+    } else {
+      output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim));
+    }
   } else {
-    output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1));
+    if (data_sinfo->vdevice.defined()) {
+      output_sinfo.push_back(
+          TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice.value()));
+    } else {
+      output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1));
+    }
   }
 
   // index, reverse and counts
   TensorStructInfo int_return{nullptr};
   if (data_sinfo->ndim == 0) {
-    int_return =
-        TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), DataType::Int(64));
+    if (data_sinfo->vdevice.defined()) {
+      int_return = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+                                    DataType::Int(64), data_sinfo->vdevice.value());
+    } else {
+      int_return =
+          TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), DataType::Int(64));
+    }
   } else {
-    int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1);
+    if (data_sinfo->vdevice.defined()) {
+      int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice.value());
+    } else {
+      int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1);
+    }
   }
   for (int i = 0; i < n_int_return; ++i) {
     output_sinfo.push_back(int_return);
diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc
index 6d1cc86f0a..c450738a1d 100644
--- a/src/relax/op/tensor/statistical.cc
+++ b/src/relax/op/tensor/statistical.cc
@@ -61,10 +61,21 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx)
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   if (data_shape == nullptr) {
     if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(
+            ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64), /*value=*/1))),
+            data_sinfo->dtype, data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(
           ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64), /*value=*/1))),
           data_sinfo->dtype);
     } else {
+      if (data_sinfo->vdevice.defined()) {
+        return out_ndim == 0
+                   ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), data_sinfo->dtype,
+                                      data_sinfo->vdevice.value())
+                   : TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice.value());
+      }
       return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), data_sinfo->dtype)
                            : TensorStructInfo(data_sinfo->dtype, out_ndim);
     }
@@ -80,6 +91,9 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx)
     }
   }
   ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
+  if (data_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+  }
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
 }
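
As a concrete reading of the keepdims branch earlier in this function (a sketch,
not test code): a reduction over all axes of a rank-3 tensor with keepdims=true
always has shape (1, 1, 1), so a full shape expression can be returned even when
the input shape itself is unknown:

    Array<PrimExpr> ones(/*n=*/3, IntImm(DataType::Int(64), /*value=*/1));
    TensorStructInfo out(ShapeExpr(ones), DataType::Float(32));  // shape (1, 1, 1)
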
 
@@ -158,19 +172,32 @@ StructInfo InferStructInfoCumsum(const Call& call, const BlockBuilder& ctx) {
     // flattened
     const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
     if (data_shape == nullptr) {
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(out_type, data_sinfo->ndim, data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(out_type, data_sinfo->ndim);
     } else {
       PrimExpr flattened_d = 1;
       for (const auto v : data_shape->values) {
         flattened_d *= v;
       }
+      if (data_sinfo->vdevice.defined()) {
+        return TensorStructInfo(ShapeExpr(Array<PrimExpr>({flattened_d})), out_type,
+                                data_sinfo->vdevice.value());
+      }
       return TensorStructInfo(ShapeExpr(Array<PrimExpr>({flattened_d})), out_type);
     }
   }
 
   if (data_sinfo->shape.defined()) {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(data_sinfo->shape.value(), out_type, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(data_sinfo->shape.value(), out_type);
   } else {
+    if (data_sinfo->vdevice.defined()) {
+      return TensorStructInfo(out_type, data_sinfo->ndim, data_sinfo->vdevice.value());
+    }
     return TensorStructInfo(out_type, data_sinfo->ndim);
   }
 }
diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc
index d1ff5b7863..cc265ad9a1 100644
--- a/src/relax/op/tensor/ternary.cc
+++ b/src/relax/op/tensor/ternary.cc
@@ -65,6 +65,21 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) {
     output_dtype = t1->dtype;
   }
 
+  VDevice vdev = VDevice();
+  for (int i = 0; i < 3; ++i) {
+    if (input_sinfo[i]->vdevice.defined()) {
+      if (!vdev.defined()) {
+        vdev = input_sinfo[i]->vdevice.value();
+      } else if (input_sinfo[i]->vdevice.value()->target.defined()) {
+        // mismatch
+        if (input_sinfo[i]->vdevice.value() != vdev) {
+          vdev = VDevice();
+          break;
+        }
+      }
+    }
+  }
+
   auto* s1 = t1->shape.as<ShapeExprNode>();
   auto* s2 = t2->shape.as<ShapeExprNode>();
   auto* s3 = t3->shape.as<ShapeExprNode>();
@@ -82,11 +97,19 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) {
                          << "The 3 arguments of EwiseFMA must have the same shape");
       }
     }
+    if (vdev.defined()) {
+      return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev);
+    }
     return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
   } else if (t1->shape.defined() && t1->shape.same_as(t2->shape) && t1->shape.same_as(t3->shape)) {
+    if (vdev.defined()) {
+      return TensorStructInfo(t1->shape.value(), output_dtype, vdev);
+    }
     return TensorStructInfo(t1->shape.value(), output_dtype);
   }
-
+  if (vdev.defined()) {
+    return TensorStructInfo(output_dtype, ndim, vdev);
+  }
   return TensorStructInfo(output_dtype, ndim);
 }
 
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index c303a2c8f0..9813c4ed24 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -264,6 +264,10 @@ class AlterOpImplMutator : public ExprMutator {
     auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
     arith::Analyzer analyzer;
     auto new_shape = transform->MapShape(shape, &analyzer);
+    if (tensor_sinfo->vdevice.defined()) {
+      return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype,
+                              tensor_sinfo->vdevice.value());
+    }
     return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype);
   }
 
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index 3c7be6959d..f91d221b40 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -267,10 +267,7 @@ class LayoutConvertMutator : public ExprMutator {
         new_shape.push_back(
             shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]);
       }
-      VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
-      if (tsinfo->vdevice.defined()) {
-        vdev = tsinfo->vdevice.value();
-      }
+      VDevice vdev = tsinfo->vdevice.value_or(VDevice());
       return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, vdev, tsinfo->span);
     };
     StructInfo new_struct_info = TransformTupleLeaf<LayoutDecision>(
diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index b864a65969..d12d1080b9 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -289,7 +289,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
       if (fp16_input_names_.count(var->name_hint())) {
         auto sinfo = GetStructInfo(var);
         if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
-          VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
+          VDevice vdev = VDevice();
           if (tensor_sinfo->vdevice.defined()) {
             vdev = tensor_sinfo->vdevice.value();
           }
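
Together with the convert_layout.cc hunk above and the ir.cc hunk below, this
replaces the old placeholder VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/
"global") with a default-constructed VDevice() as the "no virtual device" value,
which the defined() checks added elsewhere in this patch rely on. A minimal
sketch of the convention (assuming the usual TVM logging macros):

    VDevice unset = VDevice();  // null reference: unset.defined() == false
    VDevice cpu(Target("llvm"), /*dev_id=*/0, /*mem_scope=*/"global");
    ICHECK(!unset.defined() && cpu.defined());
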
diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc
new file mode 100644
index 0000000000..a964f80ea1
--- /dev/null
+++ b/src/relax/transform/update_vdevice.cc
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/update_vdevice.cc
+ * \brief Update Virtual Device pass.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+class VDeviceMutator : public ExprMutator {
+ public:
+  VDeviceMutator(const IRModule& mod, VDevice new_vdevice, int64_t index)
+      : ExprMutator(mod), mod_(mod), new_vdevice_(new_vdevice) {
+    Array<GlobalInfo> vdevices = mod->global_infos["vdevice"];
+    old_vdevice_ = Downcast<VDevice>(vdevices[index]);
+  }
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr(const Expr& expr) final {
+    auto visited_expr = ExprMutator::VisitExpr(expr);
+    if (visited_expr->struct_info_.defined()) {
+      auto* tinfo = GetStructInfoAs<TensorStructInfoNode>(visited_expr);
+      bool needs_update = tinfo != nullptr && tinfo->vdevice.defined() &&
+                          tinfo->vdevice.value() == old_vdevice_;
+      if (needs_update) {
+        if (tinfo->shape.defined()) {
+          visited_expr->struct_info_ =
+              TensorStructInfo(tinfo->shape.value(), tinfo->dtype, new_vdevice_, tinfo->span);
+        } else {
+          visited_expr->struct_info_ =
+              TensorStructInfo(tinfo->dtype, tinfo->ndim, new_vdevice_, tinfo->span);
+        }
+      }
+    }
+    return visited_expr;
+  }
+
+  IRModule Run() {
+    for (const auto& [gv, func] : mod_->functions) {
+      if (func->IsInstance<relax::FunctionNode>()) {
+        relax::Function update_func = Downcast<Function>(VisitExpr(func));
+        builder_->UpdateFunction(gv, update_func);
+      }
+    }
+    Array<GlobalInfo> new_vdevices;
+    for (auto vdev : mod_->global_infos["vdevice"]) {
+      if (vdev == old_vdevice_) {
+        new_vdevices.push_back(new_vdevice_);
+      } else {
+        new_vdevices.push_back(vdev);
+      }
+    }
+    IRModule new_mod = builder_->GetContextIRModule();
+    new_mod->UpdateGlobalInfo("vdevice", new_vdevices);
+    return new_mod;
+  }
+
+ private:
+  /*! \brief Input IRModule */
+  IRModule mod_;
+  /*! \brief The new virtual device */
+  VDevice new_vdevice_;
+  /*! \brief The virtual device to be updated */
+  VDevice old_vdevice_;
+};
+
+namespace transform {
+
+Pass UpdateVDevice(VDevice new_vdevice, int64_t index) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
+                                                                            PassContext pc) {
+    return relax::VDeviceMutator(m, new_vdevice, index).Run();
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"UpdateVDevice",
+                          /*required=*/{});
+}
+TVM_REGISTER_GLOBAL("relax.transform.UpdateVDevice").set_body_typed(UpdateVDevice);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
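
In short, the pass rewrites the struct info of every expression annotated with
the selected vdevice and swaps the matching entry in the module's "vdevice"
global info. A minimal usage sketch from C++ (ReplaceSecondVDevice is a
hypothetical caller; it assumes `mod` carries a "vdevice" global-info list with
at least two entries, and the Python binding registered above is exercised in
tests/python/relax/test_transform_update_vdevice.py):

    #include <tvm/ir/global_info.h>
    #include <tvm/ir/module.h>
    #include <tvm/relax/transform.h>
    #include <tvm/target/target.h>

    tvm::IRModule ReplaceSecondVDevice(tvm::IRModule mod) {
      tvm::VDevice new_vdev(tvm::Target("cuda"), /*dev_id=*/0,
                            /*mem_scope=*/"global");
      tvm::transform::Pass pass =
          tvm::relax::transform::UpdateVDevice(new_vdev, /*index=*/1);
      return pass(mod);  // rewritten module with vdevice #1 replaced
    }
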
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index 4cb19fb48c..2f2785ca44 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -136,7 +136,7 @@ VDevice LookupVDevice(String target_kind, int device_index) {
     }
   }
   LOG(WARNING) << "The annotated device was not found, please check your vdevice list.";
-  return VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
+  return VDevice();
 }
 
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index d279b60b54..879194037c 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -23,7 +23,7 @@ import tvm
 import tvm.testing
 from tvm import TVMError
 from tvm import relax as rx
-from tvm import tir
+from tvm import tir, ir
 
 
 def test_get_static_type_basic():
@@ -218,6 +218,12 @@ def test_base_check():
     shape3 = rx.ShapeStructInfo([1, 2, 3])
     shape4 = rx.ShapeStructInfo([1, n, 3])
 
+    vdevice0 = ir.VDevice()
+    vdevice1 = ir.VDevice("llvm")
+    vdevice2 = ir.VDevice("cuda", 0)
+    vdevice3 = ir.VDevice("cuda", 2)
+    vdevice4 = ir.VDevice("cuda", 0, "")
+
     tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32")
     tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32")
     tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32")
@@ -225,6 +231,16 @@ def test_base_check():
     tensor4 = rx.TensorStructInfo([n, m], "int32")
     tensor5 = rx.TensorStructInfo([n, m, 1], "int32")
     tensor6 = rx.TensorStructInfo([n, m, 2], "int32")
+    tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0)
+    tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1)
+    tensor9 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice2)
+    tensor10 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice3)
+    tensor11 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice4)
+    tensor12 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0)
+    tensor13 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1)
+    tensor14 = rx.TensorStructInfo([n, m, 2], "int32", vdevice2)
+    tensor15 = rx.TensorStructInfo([n, m, 2], "int32", vdevice3)
+    tensor16 = rx.TensorStructInfo([n, m, 2], "int32", vdevice4)
 
     # obj
     assert bcheck(obj0, prim0) == BR.PASS
@@ -271,6 +287,14 @@ def test_base_check():
     assert bcheck(tensor3, tensor4) == BR.FAIL_L0
     assert bcheck(tensor1, tensor2) == BR.FAIL_L0
 
+    # vdevice mismatch
+    assert bcheck(tensor8, tensor9) == BR.FAIL_L0
+    assert bcheck(tensor9, tensor10) == BR.FAIL_L0
+    assert bcheck(tensor10, tensor11) == BR.FAIL_L0
+    assert bcheck(tensor13, tensor14) == BR.FAIL_L0
+    assert bcheck(tensor14, tensor15) == BR.FAIL_L0
+    assert bcheck(tensor15, tensor16) == BR.FAIL_L0
+
     # ndim mismatch
     assert bcheck(tensor2, tensor5) == BR.FAIL_L0
 
@@ -284,6 +308,10 @@ def test_base_check():
     assert tensor0.is_base_of(tensor5)
     assert tensor0.is_base_of(tensor6)
     assert tensor2.is_base_of(tensor4)
+    assert tensor3.is_base_of(tensor7)
+    assert tensor3.is_base_of(tensor8)
+    assert tensor6.is_base_of(tensor12)
+    assert tensor6.is_base_of(tensor13)
     assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32"))
 
     # tuple
@@ -386,6 +414,22 @@ def test_derive_call_ret_struct_info():
         with pytest.raises(TVMError):
             _check_derive(bb, func0(2), [obj0], obj0)
 
+        # Tensor with vdevice
+        vdev = ir.VDevice("llvm")
+
+        def func1(c):
+            n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+            x = rx.TensorStructInfo([n, m], "float32", vdev)
+            z = rx.TensorStructInfo([m + c, n], "float32", vdev)
+            return rx.FuncStructInfo([x], z)
+
+        _check_derive(
+            bb,
+            func1(1),
+            [rx.TensorStructInfo([10, 11], "float32", vdev)],
+            rx.TensorStructInfo([12, 10], "float32", vdev),
+        )
+
         # opaque derivation
         fopaque0 = lambda: rx.FuncStructInfo.opaque_func()
         fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0)
@@ -477,6 +521,9 @@ def test_struct_info_lca():
     prim0 = rx.PrimStructInfo("int32")
     prim1 = rx.PrimStructInfo("float32")
 
+    vdevice0 = ir.VDevice("llvm")
+    vdevice1 = ir.VDevice("cuda", 0)
+
     shape0 = rx.ShapeStructInfo(ndim=-1)
     shape1 = rx.ShapeStructInfo(ndim=2)
     shape2 = rx.ShapeStructInfo(ndim=3)
@@ -490,6 +537,10 @@ def test_struct_info_lca():
     tensor4 = rx.TensorStructInfo([n, m], "int32")
     tensor5 = rx.TensorStructInfo([n, m, 1], "int32")
     tensor6 = rx.TensorStructInfo([n, m, 2], "int32")
+    tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0)
+    tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1)
+    tensor9 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0)
+    tensor10 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1)
 
     # obj
     _check_lca(obj0, prim0, obj0)
@@ -510,6 +561,13 @@ def test_struct_info_lca():
     _check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None))
     _check_lca(tensor0, tensor2, tensor0)
     _check_lca(tensor0, tensor4, tensor0)
+    _check_lca(tensor1, tensor3, tensor1)
+    _check_lca(tensor3, tensor7, tensor3)
+    _check_lca(tensor3, tensor8, tensor3)
+    _check_lca(tensor1, tensor8, tensor1)
+    _check_lca(tensor6, tensor9, tensor6)
+    _check_lca(tensor6, tensor10, tensor6)
 
     _check_lca(tensor2, tensor4, tensor2)
     _check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32"))
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index ce9e5d507e..a0ec08f0ab 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -20,7 +20,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -73,16 +73,22 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_binary_arith_infer_struct_info(binary_arith_op: Callable):
     bb = relax.BlockBuilder()
+    vdevice0 = VDevice("llvm")
+    vdevice1 = VDevice("cuda", 0)
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
     x1 = relax.Var("x", R.Tensor((1, 3), "float32"))
     x2 = relax.Var("x", R.Tensor((3, 2, 3), "float32"))
     x3 = relax.Var("x", R.Tensor((3, 1, 3), "float32"))
     x4 = relax.Var("x", R.Tensor("float32", ndim=2))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor("float32", ndim=2, vdevice=vdevice0))
+    x7 = relax.Var("x", R.Tensor((2, 3), "float32", vdevice0))
     y0 = relax.Var("y", R.Tensor((2, 3), "float32"))
     y1 = relax.Var("y", R.Tensor((4, 3, 2, 1), "float32"))
     y2 = relax.Var("y", R.Tensor("float32", ndim=2))
     y3 = relax.Var("y", R.Tensor("float32", ndim=-1))
+    y4 = relax.Var("y", R.Tensor((2, 3), "float32", vdevice0))
+    y5 = relax.Var("y", R.Tensor("float32", ndim=2, vdevice=vdevice0))
 
     _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float32"))
     _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((2, 3), "float32"))
@@ -94,6 +100,19 @@ def test_binary_arith_infer_struct_info(binary_arith_op: Callable):
     _check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=2))
     _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=-1))
     _check_inference(bb, binary_arith_op(x5, y0), relax.TensorStructInfo(dtype="", ndim=-1))
+    _check_inference(
+        bb,
+        binary_arith_op(x6, y5),
+        relax.TensorStructInfo(dtype="float32", ndim=2, vdevice=vdevice0),
+    )
+    _check_inference(
+        bb,
+        binary_arith_op(x6, y2),
+        relax.TensorStructInfo(dtype="float32", ndim=2, vdevice=vdevice0),
+    )
+    _check_inference(
+        bb, binary_arith_op(x7, y4), relax.TensorStructInfo((2, 3), "float32", vdevice0)
+    )
 
 
 (binary_cmp_op,) = tvm.testing.parameters(
@@ -108,15 +127,18 @@ def test_binary_arith_infer_struct_info(binary_arith_op: Callable):
 
 def test_binary_cmp_infer_struct_info(binary_cmp_op: Callable):
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
     y0 = relax.Var("y", R.Tensor((2, 3), "float32"))
     y1 = relax.Var("y", R.Tensor((2, 3), "int32"))
+    y2 = relax.Var("y", R.Tensor((2, 3), "float32", vdev0))
     _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool"))
     _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool"))
     _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool"))
     _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool"))
     _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool"))
     _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool"))
+    _check_inference(bb, binary_cmp_op(x, y2), relax.TensorStructInfo((2, 3), "bool", vdev0))
 
 
 def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable):
@@ -198,6 +220,14 @@ def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable
         bb.normalize(binary_arith_op(x, y))
 
 
+def test_binary_arith_infer_struct_info_vdevice_mismatch(binary_arith_op: Callable):
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3), "float32", VDevice("llvm")))
+    y = relax.Var("y", R.Tensor((2, 3), "float32", VDevice("cuda")))
+    with pytest.raises(TVMError):
+        bb.normalize(binary_arith_op(x, y))
+
+
 def test_binary_wrong_input_number(binary_arith_op: Callable):
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
 
diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py
index 345e68b9b2..c0b4308529 100644
--- a/tests/python/relax/test_op_create.py
+++ b/tests/python/relax/test_op_create.py
@@ -19,7 +19,7 @@ import pytest
 import tvm
 import tvm.testing
 from tvm import TVMError, relax, tir
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 from tvm.script import tir as T
 
@@ -45,10 +45,12 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_full_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     v0 = relax.Var("v", R.Tensor((), "float32"))
     v1 = relax.Var("v", R.Tensor("float32", ndim=0))
     v2 = relax.Var("v", R.Tensor(()))
     v3 = relax.Var("v", R.Tensor(ndim=0))
+    v4 = relax.Var("v", R.Tensor((), "float32", vdev0))
     s0 = relax.ShapeExpr((2, 3))
     s1 = relax.Var("s", relax.ShapeStructInfo((2, 3)))
     s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
@@ -62,6 +64,7 @@ def test_full_infer_struct_info():
         bb, relax.op.full(s0, v0, "float16"), relax.TensorStructInfo((2, 3), "float16")
     )
     _check_inference(bb, relax.op.full(s0, v0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(bb, relax.op.full(s0, v4), relax.TensorStructInfo((2, 3), "float32", vdev0))
     _check_inference(bb, relax.op.full(s1, v0, "float16"), relax.TensorStructInfo(s1, "float16"))
     _check_inference(bb, relax.op.full(s1, v0), relax.TensorStructInfo(s1, "float32"))
     _check_inference(bb, relax.op.full(s2, v0, "float16"), relax.TensorStructInfo(s2, "float16"))
@@ -300,6 +303,7 @@ def test_full_like_infer_struct_info_shape_symbolic():
 
 def test_full_like_infer_struct_info_shape_var():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     s0 = relax.Var("s", relax.ShapeStructInfo((2, 3)))
     s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
     s2 = relax.Var("s", relax.ShapeStructInfo())
@@ -307,11 +311,13 @@ def test_full_like_infer_struct_info_shape_var():
     x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
     x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
     x3 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x4 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
     sv0 = relax.Var("sv", relax.ShapeStructInfo(()))
     sv1 = relax.Var("sv", relax.ShapeStructInfo(ndim=0))
     v0 = relax.Var("v", relax.TensorStructInfo(sv0, "float16"))
     v1 = relax.Var("v", relax.TensorStructInfo(sv1, "float16"))
     v2 = relax.Var("v", R.Tensor((), "float16"))
+    v3 = relax.Var("v", relax.TensorStructInfo(sv1, "float16", vdev0))
 
     _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo(s0, "float32"))
     _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo(s0, "float32"))
@@ -324,6 +330,9 @@ def test_full_like_infer_struct_info_shape_var():
     _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(s2, "float32"))
     _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), "float32"))
     _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(
+        bb, relax.op.full_like(x4, v3), relax.TensorStructInfo((2, 3), "float32", vdev0)
+    )
 
 
 def test_full_like_infer_struct_info_more_input_dtype():
@@ -605,12 +614,14 @@ def test_arange_infer_struct_info_shape_var():
 
 def test_tril_triu_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 3, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
 
     _check_inference(bb, relax.op.tril(x0, k=1), relax.TensorStructInfo((2, 3, 4), "float32"))
     _check_inference(bb, relax.op.triu(x0, k=0), relax.TensorStructInfo((2, 3, 4), "float32"))
@@ -620,18 +631,22 @@ def test_tril_triu_infer_struct_info():
     _check_inference(bb, relax.op.triu(x3), relax.TensorStructInfo((2, 3, 4), dtype=""))
     _check_inference(bb, relax.op.tril(x4), relax.TensorStructInfo(dtype="", ndim=3))
     _check_inference(bb, relax.op.triu(x5), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.tril(x6), relax.TensorStructInfo((2, 3, 4), "float32", vdev0))
 
 
 def test_tril_triu_infer_struct_info_shape_symbolic():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     a = tir.Var("a", "int64")
     b = tir.Var("b", "int64")
     c = tir.Var("c", "int64")
     x0 = relax.Var("x", R.Tensor((a, b, c), "float32"))
     x1 = relax.Var("x", R.Tensor((a, b, c)))
+    x2 = relax.Var("x", R.Tensor((a, b, c), "float32", vdev0))
 
     _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c), "float32"))
     _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c), dtype=""))
+    _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo((a, b, c), "float32", vdev0))
 
 
 def test_tril_triu_infer_struct_info_shape_var():
diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py
index b06b51a2a1..251a30139b 100644
--- a/tests/python/relax/test_op_image.py
+++ b/tests/python/relax/test_op_image.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -35,6 +35,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_resize2d_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
     x2 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
@@ -43,10 +44,16 @@ def test_resize2d_infer_struct_info():
     x5 = relax.Var("x", R.Tensor("float32"))
     x6 = relax.Var("x", R.Tensor(ndim=4))
     x7 = relax.Var("x", R.Tensor())
+    x8 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.image.resize2d(x0, (28, 28)), relax.TensorStructInfo((2, 3, 28, 28), "float32")
     )
+    _check_inference(
+        bb,
+        relax.op.image.resize2d(x8, (28, 28)),
+        relax.TensorStructInfo((2, 3, 28, 28), "float32", vdev0),
+    )
     _check_inference(
         bb,
         relax.op.image.resize2d(x0, size=28),
diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py
index fffd31f6f7..e3c9e4a596 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import ir as I, relax as R, tir as T
 
 
@@ -40,12 +40,14 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_take_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=2))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((4, 10)))
     x4 = relax.Var("x", R.Tensor(ndim=2))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((4, 10), "float32", vdev0))
     y0 = relax.Var("y", R.Tensor((10,), "float32"))
     y1 = relax.Var("y", R.Tensor("float32", ndim=1))
     y2 = relax.Var("y", R.Tensor((10,)))
@@ -58,8 +60,12 @@ def test_take_infer_struct_info():
     idx5 = relax.Var("idx", R.Tensor("int64", ndim=2))
     idx6 = relax.Var("idx", R.Tensor((6, 4)))
     idx7 = relax.Var("idx", R.Tensor(ndim=2))
+    idx8 = relax.Var("idx", R.Tensor((6,), "int64", vdev0))
 
     _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float32"))
+    _check_inference(
+        bb, relax.op.take(x6, idx8, axis=1), relax.TensorStructInfo((4, 6), "float32", vdev0)
+    )
     _check_inference(
         bb, relax.op.take(x0, idx0, axis=-1), relax.TensorStructInfo((4, 6), "float32")
     )
@@ -355,12 +361,14 @@ def test_take_infer_struct_info_wrong_input_type():
 
 def test_strided_slice_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=4))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((8, 9, 10, 10)))
     x4 = relax.Var("x", R.Tensor(ndim=4))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32", vdev0))
 
     _check_inference(
         bb,
@@ -369,6 +377,13 @@ def test_strided_slice_infer_struct_info():
         ),
         relax.TensorStructInfo((4, 9, 10, 3), "float32"),
     )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(
+            x6, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3]
+        ),
+        relax.TensorStructInfo((4, 9, 10, 3), "float32", vdev0),
+    )
     _check_inference(
         bb,
         relax.op.strided_slice(
diff --git a/tests/python/relax/test_op_linear_algebra.py b/tests/python/relax/test_op_linear_algebra.py
index ccea5f79eb..fc30ea8619 100644
--- a/tests/python/relax/test_op_linear_algebra.py
+++ b/tests/python/relax/test_op_linear_algebra.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -36,6 +36,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_matmul_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((3, 4), "float32"))
     x1 = relax.Var("x", R.Tensor((4,), "float32"))
     x2 = relax.Var("x", R.Tensor((2, 3, 5, 4), "float32"))
@@ -43,14 +44,17 @@ def test_matmul_infer_struct_info():
     x4 = relax.Var("x", R.Tensor((2, 1, 4, 5)))
     x5 = relax.Var("x", R.Tensor("float32"))
     x6 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float16"))
+    x7 = relax.Var("x", R.Tensor((3, 4), "float32", vdev0))
     y0 = relax.Var("y", R.Tensor((4, 5), "float32"))
     y1 = relax.Var("y", R.Tensor((4,), "float32"))
     y2 = relax.Var("y", R.Tensor((2, 3, 4, 5), "float32"))
     y3 = relax.Var("y", R.Tensor((6, 1, 3, 5, 7), "float32"))
     y4 = relax.Var("y", R.Tensor("float32", ndim=5))
     y5 = relax.Var("y", R.Tensor())
+    y6 = relax.Var("y", R.Tensor((4, 5), "float32", vdev0))
 
     _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float32"))
+    _check_inference(bb, relax.op.matmul(x7, y6), relax.TensorStructInfo((3, 5), "float32", vdev0))
     _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32"))
     _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((2, 3, 5), "float32"))
     _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((2, 3, 5), "float32"))
@@ -208,19 +212,26 @@ def test_linear():
     # Since linear is just sugar for transpose + matmul + add,
     # we only have brief tests here.
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
     x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
     w1 = relax.Var("w", R.Tensor((5, 4), "float32"))
     w2 = relax.Var("w", R.Tensor((4,), "float32"))
     w3 = relax.Var("w", R.Tensor("float32"))
+    w4 = relax.Var("w", R.Tensor((5, 4), "float32", vdev0))
     b1 = relax.Var("b", R.Tensor((5,), "float32"))
     b2 = relax.Var("b", R.Tensor((), "float32"))
+    b3 = relax.Var("b", R.Tensor((5,), "float32", vdev0))
 
     # Need a scope to normalize non-leaf nodes
     with bb.function("func", [x1]):
         _check_inference(
             bb, relax.op.linear(x1, w1, b1), relax.TensorStructInfo((2, 3, 5), "float32")
         )
+        _check_inference(
+            bb, relax.op.linear(x3, w4, b3), relax.TensorStructInfo((2, 3, 5), "float32", vdev0)
+        )
         _check_inference(
             bb, relax.op.linear(x1, w1, b2), relax.TensorStructInfo((2, 3, 5), "float32")
         )
@@ -242,6 +253,7 @@ def test_linear():
 
 def test_einsum_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x0", R.Tensor((), "float32"))
     x1 = relax.Var("x1", R.Tensor((5,), "int32"))
     x2 = relax.Var("x2", R.Tensor((5, 5), "int32"))
@@ -258,8 +270,10 @@ def test_einsum_infer_struct_info():
     x13 = relax.Var("x13", R.Tensor((1, 1, 1, 3), "float16"))
     x14 = relax.Var("x14", R.Tensor((1, 5, 3, 8, 4), "float32"))
     x15 = relax.Var("x15", R.Tensor((2, 5, 3, 6, 4), "float32"))
+    x16 = relax.Var("x16", R.Tensor((5, 5), "int32", vdev0))
 
     _check_inference(bb, relax.op.einsum((x2,), "ii"), relax.TensorStructInfo((), "int32"))
+    _check_inference(bb, relax.op.einsum((x16,), "ii"), relax.TensorStructInfo((), "int32", vdev0))
     _check_inference(bb, relax.op.einsum((x2,), "ii->i"), relax.TensorStructInfo((5,), "int32"))
     _check_inference(bb, relax.op.einsum([x2], "...j->..."), relax.TensorStructInfo((5,), "int32"))
     _check_inference(
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 40993b5da5..6c0fbcf227 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -54,12 +54,14 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_reshape_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=4))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
     x4 = relax.Var("x", R.Tensor(ndim=4))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
     s0 = relax.Var("s", R.Shape((3, 8, 5)))
     s1 = relax.Var("s", R.Shape(ndim=3))
     s2 = relax.Var("s", R.Shape())
@@ -68,6 +70,9 @@ def test_reshape_infer_struct_info():
     _check_inference(
         bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
     )
+    _check_inference(
+        bb, relax.op.reshape(x6, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32", vdev0)
+    )
     _check_inference(
         bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
     )
@@ -319,6 +324,7 @@ def test_reshape_infer_struct_info_wrong_input_type():
 
 def test_permute_dims_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=4))
     x2 = relax.Var("x", R.Tensor("float32"))
@@ -327,10 +333,16 @@ def test_permute_dims_infer_struct_info():
     x5 = relax.Var("x", R.Tensor())
     x6 = relax.Var("x", R.Tensor((1,), "float32"))
     x7 = relax.Var("x", R.Tensor((), "float32"))
+    x8 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float32")
     )
+    _check_inference(
+        bb,
+        relax.op.permute_dims(x8, [2, 3, 1, 0]),
+        relax.TensorStructInfo((3, 4, 2, 1), "float32", vdev0),
+    )
     _check_inference(
         bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo((4, 3, 2, 1), "float32")
     )
@@ -529,16 +541,23 @@ def test_permute_dims_infer_struct_info_wrong_input_type():
 
 def test_expand_dims_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 3, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float32")
     )
+    _check_inference(
+        bb,
+        relax.op.expand_dims(x6, [1, 3]),
+        relax.TensorStructInfo((2, 1, 3, 1, 4), "float32", vdev0),
+    )
     _check_inference(
         bb,
         relax.op.expand_dims(x0, [-1, 1, -6, 3, 5]),
@@ -690,7 +709,9 @@ def test_expand_dims_infer_struct_info_wrong_input_type():
 
 def test_layout_transform_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))
+    x1 = relax.Var("x", R.Tensor((10, 20, 30), "float32", vdev0))
 
     transpose_transform = lambda a, b, c: (a, c, b)
     _check_inference(
@@ -698,6 +719,11 @@ def test_layout_transform_infer_struct_info():
         relax.op.layout_transform(x, index_map=transpose_transform),
         relax.TensorStructInfo((10, 30, 20), "float32"),
     )
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x1, index_map=transpose_transform),
+        relax.TensorStructInfo((10, 30, 20), "float32", vdev0),
+    )
 
     tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2)
     _check_inference(
@@ -812,16 +838,21 @@ def test_layout_transform_infer_struct_info_invalid_index_map():
 
 def test_squeeze_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=6))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=6))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32")
     )
+    _check_inference(
+        bb, relax.op.squeeze(x6, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32", vdev0)
+    )
     _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float32"))
     _check_inference(
         bb, relax.op.squeeze(x1, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4)
@@ -983,6 +1014,7 @@ def test_squeeze_infer_struct_info_wrong_input_type():
 
 def test_flatten_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
     x1 = relax.Var("x", R.Tensor((3,), "float32"))
     x2 = relax.Var("x", R.Tensor((), "float32"))
@@ -997,8 +1029,10 @@ def test_flatten_infer_struct_info():
     x11 = relax.Var("x", R.Tensor(ndim=1))
     x12 = relax.Var("x", R.Tensor(ndim=0))
     x13 = relax.Var("x", R.Tensor())
+    x14 = relax.Var("x", R.Tensor((3, 4, 5), "float32", vdev0))
 
     _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float32"))
+    _check_inference(bb, relax.op.flatten(x14), relax.TensorStructInfo((60,), "float32", vdev0))
     _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((3,), "float32"))
     _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32"))
     _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1))
@@ -1083,28 +1117,42 @@ def test_flatten_wrong_input_number():
 
 def test_concat_infer_struct_info_with_axis():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 3, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
     y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32"))
     y1 = relax.Var("y", R.Tensor("float32", ndim=3))
     y2 = relax.Var("y", R.Tensor("float32"))
     y3 = relax.Var("y", R.Tensor((2, 4, 4)))
     y4 = relax.Var("y", R.Tensor(ndim=3))
     y5 = relax.Var("y", R.Tensor())
+    y6 = relax.Var("y", R.Tensor((2, 4, 4), "float32", vdev0))
     z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32"))
     z1 = relax.Var("z", R.Tensor("float32", ndim=3))
     z2 = relax.Var("z", R.Tensor("float32"))
     z3 = relax.Var("z", R.Tensor((2, 5, 4)))
     z4 = relax.Var("z", R.Tensor(ndim=3))
     z5 = relax.Var("z", R.Tensor())
+    z6 = relax.Var("z", R.Tensor((2, 5, 4), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), "float32")
     )
+    _check_inference(
+        bb,
+        relax.op.concat([x6, y6, z6], axis=1),
+        relax.TensorStructInfo((2, 12, 4), "float32", vdev0),
+    )
+    _check_inference(
+        bb,
+        relax.op.concat([x6, y0, z0], axis=1),
+        relax.TensorStructInfo((2, 12, 4), "float32", vdev0),
+    )
     _check_inference(
         bb, relax.op.concat([x0, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), "float32")
     )
@@ -1658,12 +1706,14 @@ def test_concat_infer_struct_info_input_tuple_field_not_tensor():
 
 def test_split_infer_struct_info_by_indices():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 10, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
 
     _check_inference(
         bb,
@@ -1676,6 +1726,17 @@ def test_split_infer_struct_info_by_indices():
             ]
         ),
     )
+    _check_inference(
+        bb,
+        relax.op.split(x6, [3, 7], axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 3, 4), "float32", vdev0),
+                relax.TensorStructInfo((2, 4, 4), "float32", vdev0),
+                relax.TensorStructInfo((2, 3, 4), "float32", vdev0),
+            ]
+        ),
+    )
     _check_inference(
         bb,
         relax.op.split(x0, [3, 7], axis=-2),
@@ -2176,16 +2237,23 @@ def test_split_infer_struct_info_wrong_input_type():
 
 def test_broadcast_to_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 1, 3)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 1, 3), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
     )
+    _check_inference(
+        bb,
+        relax.op.broadcast_to(x6, (4, 2, 5, 3)),
+        relax.TensorStructInfo((4, 2, 5, 3), "float32", vdev0),
+    )
     _check_inference(
         bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
     )
@@ -2405,10 +2473,11 @@ def test_broadcast_to_infer_struct_info_wrong_input_type():
 
 def test_collapse_sum_like_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
-    x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
     y0 = relax.Var("y", R.Tensor((3, 4), "float32"))
@@ -2417,10 +2486,14 @@ def test_collapse_sum_like_infer_struct_info():
     y3 = relax.Var("y", R.Tensor((3, 4)))
     y4 = relax.Var("y", R.Tensor(ndim=2))
     y5 = relax.Var("y", R.Tensor((1, 4)))
+    y6 = relax.Var("y", R.Tensor((3, 4), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float32")
     )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x3, y6), relax.TensorStructInfo((3, 4), "float32", vdev0)
+    )
     _check_inference(
         bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=2)
     )
@@ -2728,18 +2801,25 @@ def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var():
 
 def test_repeat_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 10, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
 
     _check_inference(
         bb,
         relax.op.repeat(x0, 2, axis=0),
         relax.TensorStructInfo((4, 10, 4), "float32"),
     )
+    _check_inference(
+        bb,
+        relax.op.repeat(x6, 2, axis=0),
+        relax.TensorStructInfo((4, 10, 4), "float32", vdev0),
+    )
     _check_inference(
         bb,
         relax.op.repeat(x0, 2, axis=-2),
@@ -2853,18 +2933,25 @@ def test_repeat_infer_struct_info_wrong_input_type():
 
 def test_tile_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 10, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
 
     _check_inference(
         bb,
         relax.op.tile(x0, 2),
         relax.TensorStructInfo((2, 10, 8), "float32"),
     )
+    _check_inference(
+        bb,
+        relax.op.tile(x6, 2),
+        relax.TensorStructInfo((2, 10, 8), "float32", vdev0),
+    )
     _check_inference(
         bb,
         relax.op.tile(x0, (3, 2)),
@@ -2971,13 +3058,18 @@ def test_tile_infer_struct_info_wrong_input_type():
 
 def test_flip_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float16", ndim=3))
     x2 = relax.Var("x", R.Tensor("int32"))
     x3 = relax.Var("x", R.Tensor((2, 10, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
 
     _check_inference(bb, relax.op.flip(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32"))
+    _check_inference(
+        bb, relax.op.flip(x5, axis=1), relax.TensorStructInfo((2, 10, 4), "float32", vdev0)
+    )
     _check_inference(bb, relax.op.flip(x1, axis=0), R.Tensor("float16", ndim=3))
     _check_inference(bb, relax.op.flip(x2, axis=0), R.Tensor("int32"))
     _check_inference(bb, relax.op.flip(x3, axis=2), R.Tensor((2, 10, 4)))
@@ -3003,19 +3095,28 @@ def test_flip_infer_struct_info_wrong_inputs():
 
 def test_scatter_elements_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     d0 = relax.Var("data", R.Tensor((4, 4), "float32"))
     d1 = relax.Var("data", R.Tensor(dtype="float32", ndim=2))
     d2 = relax.Var("data", R.Tensor("float32"))
+    d3 = relax.Var("data", R.Tensor((4, 4), "float32", vdev0))
     i0 = relax.Var("indices", R.Tensor((2, 2), "int64"))
     i1 = relax.Var("indices", R.Tensor((2, 2)))
     i2 = relax.Var("indices", R.Tensor(dtype="int64", ndim=2))
     i3 = relax.Var("indices", R.Tensor(ndim=2))
+    i4 = relax.Var("indices", R.Tensor((2, 2), "int64", vdev0))
     u0 = relax.Var("updates", R.Tensor((2, 2), "float32"))
+    u1 = relax.Var("updates", R.Tensor((2, 2), "float32", vdev0))
     _check_inference(
         bb,
         relax.op.scatter_elements(d0, i0, u0, 0, "updates"),
         relax.TensorStructInfo((4, 4), dtype="float32"),
     )
+    _check_inference(
+        bb,
+        relax.op.scatter_elements(d3, i4, u1, 0, "updates"),
+        relax.TensorStructInfo((4, 4), dtype="float32", vdevice=vdev0),
+    )
     _check_inference(
         bb,
         relax.op.scatter_elements(d1, i0, u0, 0, "updates"),
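
The concat checks above capture the inference rule that the rest of this
commit repeats op by op: a VDevice attached to any input's struct info
propagates to the inferred output, even when the other inputs carry no
annotation. A minimal standalone sketch of that behavior, assuming the
_check_inference helper used throughout these files reduces to
BlockBuilder.normalize plus a structural comparison:

    import tvm
    from tvm import relax
    from tvm.ir import VDevice
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")
    # Only the first input carries a virtual device annotation.
    x = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
    y = relax.Var("y", R.Tensor((2, 4, 4), "float32"))

    # Normalizing the call populates struct_info; one annotated input is
    # enough for the result to inherit vdev0.
    ret = bb.normalize(relax.op.concat([x, y], axis=1))
    tvm.ir.assert_structural_equal(
        ret.struct_info, relax.TensorStructInfo((2, 7, 4), "float32", vdev0)
    )
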
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index de1cf079a5..7adfc84283 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -58,14 +58,17 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_linear_unit_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
     x3 = relax.Var("x", R.Tensor((2, 3)))
     x4 = relax.Var("x", R.Tensor())
     x5 = relax.Var("x", R.Tensor((3, 4)))
+    x6 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
 
     _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(bb, relax.op.nn.relu(x6), relax.TensorStructInfo((2, 3), "float32", vdev0))
     _check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3))
     _check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32"))
     _check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype=""))
@@ -133,13 +136,16 @@ def test_linear_unit_infer_struct_info_wrong_input_type():
 
 def test_softmax_log_softmax_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
     x3 = relax.Var("x", R.Tensor((2, 3)))
     x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
 
     _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(bb, relax.op.nn.softmax(x5), relax.TensorStructInfo((2, 3), "float32", vdev0))
     _check_inference(
         bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3)
     )
@@ -1096,11 +1102,13 @@ def test_group_norm_infer_struct_info_wrong_input_type():
 
 def test_dropout_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
     x3 = relax.Var("x", R.Tensor((2, 3)))
     x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
 
     _check_inference(
         bb,
@@ -1109,6 +1117,16 @@ def test_dropout_infer_struct_info():
             [relax.TensorStructInfo((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32")]
         ),
     )
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x5),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 3), "float32", vdev0),
+                relax.TensorStructInfo((2, 3), "float32", vdev0),
+            ]
+        ),
+    )
     _check_inference(
         bb,
         relax.op.nn.dropout(x1),
@@ -1220,6 +1238,7 @@ def test_dropout_infer_struct_info_wrong_input_type():
 
 def test_cross_entropy_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
     y0 = relax.Var("y", R.Tensor((2, 3), "float32"))
     y1 = relax.Var("y", R.Tensor("float32", ndim=2))
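
The dropout case is worth singling out because the op returns a pair:
when the input carries a VDevice, each field of the inferred
TupleStructInfo carries it as well. A sketch under the same assumptions
as the concat sketch above:

    import tvm
    from tvm import relax
    from tvm.ir import VDevice
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")
    x = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))

    # dropout yields (output, mask); both fields inherit vdev0.
    ret = bb.normalize(relax.op.nn.dropout(x))
    tvm.ir.assert_structural_equal(
        ret.struct_info,
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo((2, 3), "float32", vdev0),
                relax.TensorStructInfo((2, 3), "float32", vdev0),
            ]
        ),
    )
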
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
index 2ec451c132..6be1245fe2 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -44,19 +44,25 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_conv1d_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 28, 3), "float32"))
     x2 = relax.Var("x", R.Tensor("float32", ndim=3))
     x3 = relax.Var("x", R.Tensor("float32"))
     x4 = relax.Var("x", R.Tensor())
     x5 = relax.Var("x", R.Tensor((2, 4, 28, 16), "float32"))
+    x6 = relax.Var("x", R.Tensor((2, 3, 28), "float32", vdev0))
     w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
     w1 = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
     w2 = relax.Var("w", R.Tensor("float32", ndim=3))
     w3 = relax.Var("w", R.Tensor("float32"))
     w4 = relax.Var("w", R.Tensor((48, 4, 3, 16), "float32"))
+    w5 = relax.Var("w", R.Tensor((4, 3, 3), "float32", vdev0))
 
     _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float32"))
+    _check_inference(
+        bb, relax.op.nn.conv1d(x6, w5), relax.TensorStructInfo((2, 4, 26), "float32", vdev0)
+    )
     _check_inference(
         bb,
         relax.op.nn.conv1d(x0, w0, out_dtype="float16"),
@@ -414,21 +420,29 @@ def test_conv1d_infer_struct_info_wrong_input_type():
 
 def test_conv1d_transpose_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 28, 3), "float32"))
     x2 = relax.Var("x", R.Tensor("float32", ndim=3))
     x3 = relax.Var("x", R.Tensor("float32"))
     x4 = relax.Var("x", R.Tensor())
     x5 = relax.Var("x", R.Tensor((2, 4, 28, 16), "float32"))
+    x6 = relax.Var("x", R.Tensor((2, 3, 28), "float32", vdev0))
     w0 = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
     w1 = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
     w2 = relax.Var("w", R.Tensor("float32", ndim=3))
     w3 = relax.Var("w", R.Tensor("float32"))
     w4 = relax.Var("w", R.Tensor((4, 48, 3, 16), "float32"))
+    w5 = relax.Var("w", R.Tensor((3, 4, 3), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.nn.conv1d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30), "float32")
     )
+    _check_inference(
+        bb,
+        relax.op.nn.conv1d_transpose(x6, w5),
+        relax.TensorStructInfo((2, 4, 30), "float32", vdev0),
+    )
     _check_inference(
         bb,
         relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float16"),
@@ -764,21 +778,27 @@ def test_conv1d_transpose_infer_struct_info_wrong_input_type():
 
 def test_conv2d_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32"))
     x2 = relax.Var("x", R.Tensor("float32", ndim=4))
     x3 = relax.Var("x", R.Tensor("float32"))
     x4 = relax.Var("x", R.Tensor())
     x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32"))
+    x6 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32", vdev0))
     w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
     w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
     w2 = relax.Var("w", R.Tensor("float32", ndim=4))
     w3 = relax.Var("w", R.Tensor("float32"))
     w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 16), "float32"))
+    w5 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float32")
     )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x6, w5), relax.TensorStructInfo((2, 4, 26, 26), "float32", vdev0)
+    )
     _check_inference(
         bb,
         relax.op.nn.conv2d(x0, w0, out_dtype="float16"),
@@ -1155,21 +1175,29 @@ def test_conv2d_infer_struct_info_wrong_input_type():
 
 def test_conv2d_transpose_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32"))
     x2 = relax.Var("x", R.Tensor("float32", ndim=4))
     x3 = relax.Var("x", R.Tensor("float32"))
     x4 = relax.Var("x", R.Tensor())
     x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32"))
+    x6 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32", vdev0))
     w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
     w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
     w2 = relax.Var("w", R.Tensor("float32", ndim=4))
     w3 = relax.Var("w", R.Tensor("float32"))
     w4 = relax.Var("w", R.Tensor((4, 48, 3, 3, 16), "float32"))
+    w5 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30, 30), "float32")
     )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d_transpose(x6, w5),
+        relax.TensorStructInfo((2, 4, 30, 30), "float32", vdev0),
+    )
     _check_inference(
         bb,
         relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float16"),
diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py
index 2bd7747f31..2533a2fcad 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -37,6 +37,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_max_pool2d_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
     x2 = relax.Var("x", R.Tensor("float32", ndim=4))
@@ -44,10 +45,14 @@ def test_max_pool2d_infer_struct_info():
     x4 = relax.Var("x", R.Tensor(ndim=4))
     x5 = relax.Var("x", R.Tensor())
     x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+    x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
     )
+    _check_inference(
+        bb, relax.op.nn.max_pool2d(x7), relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0)
+    )
     _check_inference(
         bb,
         relax.op.nn.max_pool2d(x0, pool_size=3),
@@ -262,6 +267,7 @@ def test_max_pool2d_infer_struct_info_wrong_input_type():
 
 def test_avg_pool2d_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
     x2 = relax.Var("x", R.Tensor("float32", ndim=4))
@@ -269,10 +275,14 @@ def test_avg_pool2d_infer_struct_info():
     x4 = relax.Var("x", R.Tensor(ndim=4))
     x5 = relax.Var("x", R.Tensor())
     x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+    x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
     )
+    _check_inference(
+        bb, relax.op.nn.avg_pool2d(x7), relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0)
+    )
     _check_inference(
         bb,
         relax.op.nn.avg_pool2d(x0, pool_size=3),
@@ -487,6 +497,7 @@ def test_avg_pool2d_infer_struct_info_wrong_input_type():
 
 def test_adaptive_avg_pool2d_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
     x2 = relax.Var("x", R.Tensor("float32", ndim=4))
@@ -494,10 +505,16 @@ def test_adaptive_avg_pool2d_infer_struct_info():
     x4 = relax.Var("x", R.Tensor(ndim=4))
     x5 = relax.Var("x", R.Tensor())
     x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+    x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
     )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x7),
+        relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0),
+    )
     _check_inference(
         bb,
         relax.op.nn.adaptive_avg_pool2d(x0, output_size=30),
diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py
index ba78d11022..21f022d9eb 100644
--- a/tests/python/relax/test_op_search.py
+++ b/tests/python/relax/test_op_search.py
@@ -21,7 +21,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -41,25 +41,32 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_where_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool"))
     cond1 = relax.Var("cond", R.Tensor("bool", ndim=5))
     cond2 = relax.Var("cond", R.Tensor("bool"))
+    cond3 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool", vdev0))
     x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=4))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((5, 1, 3, 2)))
     x4 = relax.Var("x", R.Tensor(ndim=4))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32", vdev0))
     y0 = relax.Var("y", R.Tensor((4, 3, 1), "float32"))
     y1 = relax.Var("y", R.Tensor("float32", ndim=3))
     y2 = relax.Var("y", R.Tensor("float32"))
     y3 = relax.Var("y", R.Tensor((4, 3, 1)))
     y4 = relax.Var("y", R.Tensor(ndim=3))
     y5 = relax.Var("y", R.Tensor())
+    y6 = relax.Var("y", R.Tensor((4, 3, 1), "float32", vdev0))
 
     _check_inference(
         bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32")
     )
+    _check_inference(
+        bb, relax.op.where(cond3, x6, y6), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32", vdev0)
+    )
     _check_inference(
         bb, relax.op.where(cond0, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5)
     )
@@ -283,12 +290,17 @@ def test_where_infer_struct_info_wrong_input_type():
 
 def test_argmax_argmin_infer_struct_info(argmax_argmin_op: Callable):
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=4))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+    x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
 
     _check_inference(bb, argmax_argmin_op(x0, axis=1), relax.TensorStructInfo((2, 4, 5), "int64"))
+    _check_inference(
+        bb, argmax_argmin_op(x4, axis=1), relax.TensorStructInfo((2, 4, 5), "int64", vdev0)
+    )
     _check_inference(
         bb,
         argmax_argmin_op(x0, axis=1, keepdims=True),
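
One detail in the argmax/argmin checks above: the inferred dtype switches
to int64 while the VDevice annotation survives, so the virtual device is
tracked independently of dtype and shape. Sketch, under the same
assumptions as before:

    import tvm
    from tvm import relax
    from tvm.ir import VDevice
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")
    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))

    # The result dtype is int64, but vdev0 is still attached.
    ret = bb.normalize(relax.op.argmax(x, axis=1))
    tvm.ir.assert_structural_equal(
        ret.struct_info, relax.TensorStructInfo((2, 4, 5), "int64", vdev0)
    )
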
diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py
index 755d5e8f87..741d7869d5 100644
--- a/tests/python/relax/test_op_set.py
+++ b/tests/python/relax/test_op_set.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -35,10 +35,12 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_unique_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+    x4 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
 
     _check_inference(
         bb,
@@ -47,6 +49,13 @@ def test_unique_infer_struct_info():
         ),
         relax.TensorStructInfo(dtype="float32", ndim=1),
     )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x4, return_index=False, return_inverse=False, return_counts=False, axis=None
+        ),
+        relax.TensorStructInfo(dtype="float32", ndim=1, vdevice=vdev0),
+    )
     _check_inference(
         bb,
         relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1),
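
The unique check above also shows the keyword form of the constructor,
TensorStructInfo(dtype=..., ndim=..., vdevice=...), which matters when
the result shape is unknown at compile time: the vdevice still rides on
a shape-free struct info. Sketch, same assumptions:

    import tvm
    from tvm import relax
    from tvm.ir import VDevice
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")
    x = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))

    # The flattened unique values have unknown length, so only dtype and
    # ndim are known; vdev0 persists on the result.
    ret = bb.normalize(
        relax.op.unique(
            x, return_index=False, return_inverse=False, return_counts=False, axis=None
        )
    )
    tvm.ir.assert_structural_equal(
        ret.struct_info,
        relax.TensorStructInfo(dtype="float32", ndim=1, vdevice=vdev0),
    )
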
diff --git a/tests/python/relax/test_op_statistical.py b/tests/python/relax/test_op_statistical.py
index 8c542b86a1..5c7d56556c 100644
--- a/tests/python/relax/test_op_statistical.py
+++ b/tests/python/relax/test_op_statistical.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -41,12 +41,17 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_statistical_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=4))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+    x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
 
     _check_inference(bb, relax.op.sum(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32"))
+    _check_inference(
+        bb, relax.op.sum(x4, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32", vdev0)
+    )
     _check_inference(
         bb,
         relax.op.sum(x0, axis=[1, 2], keepdims=True),
@@ -202,14 +207,19 @@ def test_statistical_infer_struct_info_wrong_input_type():
 
 def test_cumsum_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32"))
     x3 = relax.Var("x", R.Tensor((2, 10, 4)))
     x4 = relax.Var("x", R.Tensor(ndim=3))
     x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
 
     _check_inference(bb, relax.op.cumsum(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32"))
+    _check_inference(
+        bb, relax.op.cumsum(x6, axis=1), relax.TensorStructInfo((2, 10, 4), "float32", vdev0)
+    )
     _check_inference(
         bb, relax.op.cumsum(x1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
     )
diff --git a/tests/python/relax/test_op_ternary.py b/tests/python/relax/test_op_ternary.py
index 5ea7a01da7..120cb47c70 100644
--- a/tests/python/relax/test_op_ternary.py
+++ b/tests/python/relax/test_op_ternary.py
@@ -19,7 +19,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -37,14 +37,21 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
 
 def test_ewise_fma_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
     x1 = relax.Var("x", R.Tensor((2, 3)))
+    x2 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
     y0 = relax.Var("y", R.Tensor((2, 3), "float32"))
     y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2))
+    y2 = relax.Var("y", R.Tensor((2, 3), "float32", vdev0))
     z0 = relax.Var("z", R.Tensor((2, 3), "float32"))
     z1 = relax.Var("z", R.Tensor("float32"))
+    z2 = relax.Var("z", R.Tensor((2, 3), "float32", vdev0))
 
     _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(
+        bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorStructInfo((2, 3), "float32", vdev0)
+    )
     _check_inference(
         bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2)
     )
diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py
index 3a6c14fd66..c66f23ae5c 100644
--- a/tests/python/relax/test_op_unary.py
+++ b/tests/python/relax/test_op_unary.py
@@ -20,7 +20,7 @@ import tvm
 import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
 from tvm.script import relax as R
 
 
@@ -97,13 +97,16 @@ unary_arith_op, require_float_dtype = tvm.testing.parameters(
 
 def test_unary_arith_infer_struct_info(unary_arith_op: Callable):
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
     x3 = relax.Var("x", R.Tensor((2, 3)))
     x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
 
     _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(bb, unary_arith_op(x5), relax.TensorStructInfo((2, 3), "float32", vdev0))
     _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(dtype="float32", ndim=3))
     _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo(dtype="float32"))
     _check_inference(bb, unary_arith_op(x3), relax.TensorStructInfo((2, 3), dtype=""))
@@ -186,13 +189,16 @@ def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op: Callable
 
 def test_clip_infer_struct_info():
     bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
     x1 = relax.Var("x", R.Tensor("float32", ndim=3))
     x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
     x3 = relax.Var("x", R.Tensor((2, 3)))
     x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
 
     _check_inference(bb, relax.op.clip(x0, 0, 6), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorStructInfo((2, 3), "float32", vdev0))
     _check_inference(bb, relax.op.clip(x1, 0, 6), relax.TensorStructInfo(dtype="float32", ndim=3))
     _check_inference(bb, relax.op.clip(x2, 0, 6), relax.TensorStructInfo(dtype="float32"))
     _check_inference(bb, relax.op.clip(x3, 0, 6), relax.TensorStructInfo((2, 3), dtype=""))
diff --git a/tests/python/relax/test_transform_update_vdevice.py b/tests/python/relax/test_transform_update_vdevice.py
new file mode 100644
index 0000000000..76caccb536
--- /dev/null
+++ b/tests/python/relax/test_transform_update_vdevice.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.ir import VDevice
+from tvm.relax.transform import UpdateVDevice
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+
+def verify(input, new_vdevice, vdevice_index, expected):
+    tvm.ir.assert_structural_equal(UpdateVDevice(new_vdevice, vdevice_index)(input), expected)
+
+
+def test_update():
+    vdevices = [
+        VDevice("llvm"),
+        VDevice("cuda", 0),
+        VDevice("metal", 0, "global"),
+        VDevice("cuda -arch=sm_80", 0),
+        VDevice("metal", 1, "global"),
+        VDevice("llvm", 1),
+    ]
+
+    @I.ir_module
+    class Input1:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("cuda", 0),
+                    I.vdevice("metal", 0, "global"),
+                    I.vdevice("cuda -arch=sm_80", 0),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            a: R.Tensor((128, 128), "float32", "cuda:1"),  # noqa: F722
+            c: R.Tensor((128, 128), "float32", "vdevice:3"),  # noqa: F722
+        ) -> R.Tensor((128, 128), "float32"):
+            s = R.add(a, c)
+            return s
+
+    @I.ir_module
+    class Expect1:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("cuda", 0),
+                    I.vdevice("metal", 0, "global"),
+                    I.vdevice("metal", 1, "global"),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            a: R.Tensor((128, 128), dtype="float32", vdevice="metal:1"),  # noqa: F722
+            c: R.Tensor((128, 128), dtype="float32", vdevice="metal:1"),  # noqa: F722
+        ) -> R.Tensor((128, 128), dtype="float32", vdevice="metal:1"):  # noqa: F722
+            s: R.Tensor((128, 128), dtype="float32", vdevice="metal:1") = R.add(a, c)  # noqa: F722
+            return s
+
+    @I.ir_module
+    class Input2:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("cuda", 0),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            a: R.Tensor((128, 128), "float32", "cuda:0"),  # noqa: F722
+            c: R.Tensor((128, 128), "float32", "cuda:0"),  # noqa: F722
+        ) -> R.Tensor((128, 128), "float32"):
+            s = R.add(a, c)
+            return s
+
+    @I.ir_module
+    class Expect2:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("llvm", 1),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            a: R.Tensor((128, 128), "float32", "llvm:1"),  # noqa: F722
+            c: R.Tensor((128, 128), "float32", "llvm:1"),  # noqa: F722
+        ) -> R.Tensor((128, 128), "float32", "llvm:1"):  # noqa: F722
+            s: R.Tensor((128, 128), "float32", "llvm:1") = R.add(a, c)  # noqa: F722
+            return s
+
+    verify(Input1, vdevices[4], 3, Expect1)
+    verify(Input2, vdevices[5], 1, Expect2)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
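
Taken together, the two cases show what the pass does:
UpdateVDevice(new_vdevice, index) replaces entry index of the module's
"vdevice" global info and rewrites every annotation that resolved to the
replaced entry, across parameters, return types, and bindings. Note how
Input1's "cuda:1" (which appears to denote the second cuda-kind entry in
the table, i.e. global index 3) and "vdevice:3" both become "metal:1" in
Expect1. Invoking the pass directly, as the verify helper does:

    import tvm
    from tvm.ir import VDevice
    from tvm.relax.transform import UpdateVDevice

    # Swap entry 3 of the vdevice table for metal:1; annotations that
    # pointed at the old entry follow it.
    mod = UpdateVDevice(VDevice("metal", 1, "global"), 3)(Input1)
    tvm.ir.assert_structural_equal(mod, Expect1)
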
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 39a4d33ca6..1d56bd29b3 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -791,7 +791,7 @@ def test_tensor_with_vdevice():
             a: R.Tensor((128, 128), "float32", "cuda:1"),  # noqa: F722
             b: R.Tensor((128, 128), "float32", "llvm"),
             c: R.Tensor((128, 128), "float32", "vdevice:3"),  # noqa: F722
-        ) -> R.Tensor((128, 128), "float32"):
+        ) -> R.Tensor((128, 128), "float32", "cuda:1"):  # noqa: F722
             s = R.add(a, c)
             return s
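
The final hunk tightens test_tensor_with_vdevice so that the declared
return annotation matches what struct info inference now produces for
R.add(a, c). The signature also exercises the three spellings a vdevice
annotation can take in TVMScript. A condensed sketch, assuming the same
four-entry vdevice table used in the update-vdevice test; the comments
mark how each annotation appears to resolve:

    from tvm.script import ir as I, relax as R

    @I.ir_module
    class AnnotationForms:
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                    I.vdevice("cuda", 0),
                    I.vdevice("metal", 0, "global"),
                    I.vdevice("cuda -arch=sm_80", 0),
                ]
            }
        )

        @R.function
        def main(
            a: R.Tensor((128, 128), "float32", "cuda:1"),  # second cuda-kind entry
            b: R.Tensor((128, 128), "float32", "llvm"),  # device kind alone
            c: R.Tensor((128, 128), "float32", "vdevice:3"),  # index into the table
        ) -> R.Tensor((128, 128), "float32", "cuda:1"):
            s = R.add(a, c)
            return s
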