You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/08/25 14:03:38 UTC
[tvm] branch unity updated: [Unity] UpdateVDevice pass and infer vdevice (#15570)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0443482f7e [Unity] UpdateVDevice pass and infer vdevice (#15570)
0443482f7e is described below
commit 0443482f7eb128d4daac9a56b0c1f48d59ac24fb
Author: Yong Wu <yo...@gmail.com>
AuthorDate: Fri Aug 25 07:03:30 2023 -0700
[Unity] UpdateVDevice pass and infer vdevice (#15570)
---
include/tvm/relax/struct_info.h | 8 +-
include/tvm/relax/transform.h | 8 +
python/tvm/relax/struct_info.py | 2 -
python/tvm/relax/transform/transform.py | 20 +-
src/relax/analysis/struct_info_analysis.cc | 41 +++--
src/relax/ir/expr.cc | 3 +-
src/relax/ir/struct_info_functor.cc | 5 +-
src/relax/op/image/resize.cc | 6 +
src/relax/op/nn/attention.cc | 3 +
src/relax/op/nn/convolution.cc | 28 +++
src/relax/op/nn/nn.cc | 38 +++-
src/relax/op/nn/pooling.cc | 12 ++
src/relax/op/op_common.h | 26 +++
src/relax/op/tensor/binary.cc | 16 ++
src/relax/op/tensor/create.cc | 3 +
src/relax/op/tensor/index.cc | 30 +++
src/relax/op/tensor/linear_algebra.cc | 46 +++++
src/relax/op/tensor/manipulate.cc | 203 ++++++++++++++++++++-
src/relax/op/tensor/search.cc | 43 +++++
src/relax/op/tensor/set.cc | 38 +++-
src/relax/op/tensor/statistical.cc | 27 +++
src/relax/op/tensor/ternary.cc | 25 ++-
src/relax/transform/alter_op_impl.cc | 4 +
src/relax/transform/convert_layout.cc | 5 +-
src/relax/transform/to_mixed_precision.cc | 2 +-
src/relax/transform/update_vdevice.cc | 114 ++++++++++++
src/script/ir_builder/ir/ir.cc | 2 +-
.../relax/test_analysis_struct_info_analysis.py | 60 +++++-
tests/python/relax/test_op_binary.py | 32 +++-
tests/python/relax/test_op_create.py | 17 +-
tests/python/relax/test_op_image.py | 9 +-
tests/python/relax/test_op_index.py | 17 +-
tests/python/relax/test_op_linear_algebra.py | 16 +-
tests/python/relax/test_op_manipulate.py | 105 ++++++++++-
tests/python/relax/test_op_nn.py | 21 ++-
tests/python/relax/test_op_nn_convolution.py | 30 ++-
tests/python/relax/test_op_nn_pooling.py | 19 +-
tests/python/relax/test_op_search.py | 14 +-
tests/python/relax/test_op_set.py | 11 +-
tests/python/relax/test_op_statistical.py | 12 +-
tests/python/relax/test_op_ternary.py | 9 +-
tests/python/relax/test_op_unary.py | 8 +-
.../python/relax/test_transform_update_vdevice.py | 128 +++++++++++++
tests/python/relax/test_tvmscript_parser.py | 2 +-
44 files changed, 1202 insertions(+), 66 deletions(-)
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index deda5d666e..d2bf525225 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -220,9 +220,7 @@ class TensorStructInfo : public StructInfo {
*
* \note shape must already be normalized.
*/
- TVM_DLL TensorStructInfo(Expr shape, DataType dtype,
- VDevice vdevice = VDevice(/*tgt*/ {}, /*dev_id*/ 0,
- /*mem_scope*/ "global"),
+ TVM_DLL TensorStructInfo(Expr shape, DataType dtype, VDevice vdevice = VDevice(),
Span span = Span());
/*!
@@ -232,9 +230,7 @@ class TensorStructInfo : public StructInfo {
* \param vdevice The virtual device.
* \param span The span of the AST.
*/
- TVM_DLL TensorStructInfo(DataType dtype, int ndim,
- VDevice vdevice = VDevice(/*tgt*/ {}, /*dev_id*/ 0,
- /*mem_scope*/ "global"),
+ TVM_DLL TensorStructInfo(DataType dtype, int ndim, VDevice vdevice = VDevice(),
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode);
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 05b26f0242..a6ef9ad7ad 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -258,6 +258,14 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_war
*/
TVM_DLL Pass LiftTransformParams();
+/*!
+ * \brief Update virtual device.
+ * \param new_vdevice The new virtual device.
+ * \param index The index of the virtual device to which the update is applied.
+ * \return The Pass.
+ */
+TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
+
/*!
* \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
* \note It is an auto-detect pass for "unscheduled prim_funcs", the op_pattern will be
diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py
index fe30c01d3a..e78e1cf69a 100644
--- a/python/tvm/relax/struct_info.py
+++ b/python/tvm/relax/struct_info.py
@@ -120,8 +120,6 @@ class TensorStructInfo(StructInfo):
) -> None:
if isinstance(shape, (list, tuple, Array)):
shape = ShapeExpr(shape)
- if vdevice is None:
- vdevice = VDevice(None, 0, "global")
self.__init_handle_by_constructor__(
_ffi_api.TensorStructInfo, shape, dtype, ndim, vdevice, span # type: ignore
)
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 438a6d1213..6c08a6fe68 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -249,7 +249,7 @@ def RemovePurityChecking() -> tvm.ir.transform.Pass:
return _ffi_api.RemovePurityChecking() # type: ignore
-def LambdaLift():
+def LambdaLift() -> tvm.ir.transform.Pass:
"""A pass that lifts local functions into global.
Returns
@@ -312,6 +312,24 @@ def EliminateCommonSubexpr(call_only=False) -> FunctionPass:
return _ffi_api.EliminateCommonSubexpr(call_only) # type: ignore
+def UpdateVDevice(new_vdevice: tvm.ir.VDevice, index: int) -> tvm.ir.transform.Pass:
+ """Update virtual device.
+
+ Parameters
+ ----------
+ new_vdevice : tvm.ir.VDevice
+ The new virtual device.
+ index : int
+ The index of the virtual device to which the update is applied.
+
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+ The registered pass that modifies the virtual device.
+ """
+ return _ffi_api.UpdateVDevice(new_vdevice, index) # type: ignore
+
+
def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
"""Convert all reshape-like call_tir to VM reshape operator call.
The VM reshape operator calls will be further lowered to a CreateView
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 9fae776279..4a633e9df4 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -149,10 +149,7 @@ class WellDefinedEraser : public StructInfoMutator,
std::swap(has_undefined_, has_undefined);
}
- VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
- if (op->vdevice.defined()) {
- vdev = op->vdevice.value();
- }
+ VDevice vdev = op->vdevice.value_or(VDevice());
// erase symbolic shape if we have undefined.
if (!has_undefined) {
@@ -346,6 +343,22 @@ class StructInfoBaseChecker
if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1;
return BaseCheckResult::kFailL0;
}
+
+ // vdevice mismatch
+ if (lhs->vdevice.defined() && !rhs->vdevice.defined()) return BaseCheckResult::kFailL1;
+ if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
+ VDevice lhs_vdevice = lhs->vdevice.value();
+ VDevice rhs_vdevice = rhs->vdevice.value();
+ if (lhs_vdevice->target.defined() && !rhs_vdevice->target.defined())
+ return BaseCheckResult::kFailL1;
+ // mismatch in either the target, vdevice_id, or memory_scope
+ if ((lhs_vdevice->target.defined() && rhs_vdevice->target.defined()) &&
+ (lhs_vdevice->target != rhs_vdevice->target ||
+ lhs_vdevice->vdevice_id != rhs_vdevice->vdevice_id ||
+ lhs_vdevice->memory_scope != rhs_vdevice->memory_scope))
+ return BaseCheckResult::kFailL0;
+ }
+
// lhs does not have defined shape and everything else matches
if (!lhs->shape.defined()) return BaseCheckResult::kPass;
// rhs does not have symbolic value but lhs don't
@@ -769,32 +782,28 @@ class StructInfoLCAFinder
auto* rhs = other.as<TensorStructInfoNode>();
if (rhs == nullptr) return ObjectStructInfo(lhs->span);
- // find the target dtype and ndim.
+ // find the target dtype, ndim, and vdevice.
DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void();
int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim;
- VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
- if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
- if (lhs->vdevice.value().same_as(lhs->vdevice.value())) {
- vdev = lhs->vdevice.value();
- }
- } else if (lhs->vdevice.defined()) {
+ VDevice vdev = VDevice();
+ if (lhs->vdevice.defined() && rhs->vdevice.defined() &&
+ lhs->vdevice.value() == rhs->vdevice.value()) {
vdev = lhs->vdevice.value();
- } else if (rhs->vdevice.defined()) {
- vdev = rhs->vdevice.value();
}
// if ndim mismatch or one side of shape is missing
// then we cannot keep in symbolic shape
if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() ||
!CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), analyzer_)) {
// reuse lhs when possible
- if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim) {
+ if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim &&
+ (!lhs->vdevice.defined() || vdev.defined())) {
return GetRef<StructInfo>(lhs);
} else {
return TensorStructInfo(dtype, ndim, vdev, lhs->span);
}
}
- // symbolic shape match but dtype mismatch
- if (lhs->dtype != dtype) {
+ // symbolic shape matches but dtype or vdevice mismatches
+ if (lhs->dtype != dtype || (lhs->vdevice.defined() && !vdev.defined())) {
return TensorStructInfo(lhs->shape.value(), dtype, vdev, lhs->span);
} else {
return GetRef<StructInfo>(lhs);
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 3808325670..ac04096aaf 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -292,8 +292,7 @@ Constant::Constant(runtime::NDArray data, Optional<StructInfo> struct_info_annot
n->struct_info_ = struct_info_annotation.value();
n->checked_type_ = GetStaticType(struct_info_annotation.value());
} else {
- TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(),
- VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global"), span);
+ TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span);
n->struct_info_ = tinfo;
n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype);
}
diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc
index 10babe4b06..d7929e0f1a 100644
--- a/src/relax/ir/struct_info_functor.cc
+++ b/src/relax/ir/struct_info_functor.cc
@@ -94,10 +94,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) {
shape = this->VisitStructInfoExprField(op->shape.value());
}
- VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
- if (op->vdevice.defined()) {
- vdev = op->vdevice.value();
- }
+ VDevice vdev = op->vdevice.value_or(VDevice());
if (shape.same_as(op->shape)) {
return GetRef<StructInfo>(op);
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 3c3bb15136..3a4cb26861 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -90,6 +90,9 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
Optional<ShapeExpr> data_shape =
CheckNdimPerLayoutAndGetShape(call, ctx, GetRef<TensorStructInfo>(data_sinfo), data_layout);
if (!data_shape.defined() || size_value == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice.value());
+ }
return TensorStructInfo(out_dtype, data_layout.ndim());
}
@@ -99,6 +102,9 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
out_NCHW_shape.Set(3, size_value->values[1]);
Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 55757552db..4f37e3a33c 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -108,6 +108,9 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
}
Array<PrimExpr> output_shape = {num_batches, num_queries, num_heads, head_dim_value};
+ if (q_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype, q_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype);
}
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 96fc7e8464..e8cb1916e8 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -77,7 +77,11 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
DataType out_dtype = attrs->out_dtype.is_void()
? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
: attrs->out_dtype;
+ Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
if (!data_shape.defined() || !weight_shape.defined()) {
+ if (vdevice.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice.value());
+ }
return TensorStructInfo(out_dtype, out_layout.ndim());
}
@@ -121,6 +125,9 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+ if (vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
@@ -239,7 +246,11 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
DataType out_dtype = attrs->out_dtype.is_void()
? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
: attrs->out_dtype;
+ Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
if (!data_shape.defined() || !weight_shape.defined()) {
+ if (vdevice.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice.value());
+ }
return TensorStructInfo(out_dtype, out_layout.ndim());
}
@@ -288,6 +299,9 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+ if (vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
@@ -411,7 +425,11 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
DataType out_dtype = attrs->out_dtype.is_void()
? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
: attrs->out_dtype;
+ Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
if (!data_shape.defined() || !weight_shape.defined()) {
+ if (vdevice.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice.value());
+ }
return TensorStructInfo(out_dtype, out_layout.ndim());
}
@@ -465,6 +483,9 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
out_NCW_shape[2] = analyzer->Simplify(out_w);
Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+ if (vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
@@ -548,7 +569,11 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
DataType out_dtype = attrs->out_dtype.is_void()
? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
: attrs->out_dtype;
+ Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
if (!data_shape.defined() || !weight_shape.defined()) {
+ if (vdevice.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice.value());
+ }
return TensorStructInfo(out_dtype, out_layout.ndim());
}
@@ -610,6 +635,9 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
out_NCHW_shape[3] = analyzer->Simplify(out_w);
Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+ if (vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 2e73297e2d..f95cc9f4d6 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -266,6 +266,12 @@ StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) {
DataType dtype = input_sinfo[0]->dtype;
if (unknown_shape) {
+ if (input_sinfo[0]->vdevice.defined()) {
+ VDevice vdev = input_sinfo[0]->vdevice.value();
+ return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim, vdev),
+ TensorStructInfo(dtype, /*ndim=*/1, vdev),
+ TensorStructInfo(dtype, /*ndim=*/1, vdev)});
+ }
return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim),
TensorStructInfo(dtype, /*ndim=*/1),
TensorStructInfo(dtype, /*ndim=*/1)});
@@ -331,6 +337,11 @@ StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) {
const auto* attrs = call->attrs.as<LayerNormAttrs>();
bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
+ if (input_sinfo[0]->vdevice.defined()) {
+ return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim,
+ input_sinfo[0]->vdevice.value())
+ : input_sinfo[0];
+ }
return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
: input_sinfo[0];
}
@@ -503,6 +514,11 @@ StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) {
const auto* attrs = call->attrs.as<RMSNormAttrs>();
bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
+ if (input_sinfo[0]->vdevice.defined()) {
+ return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim,
+ input_sinfo[0]->vdevice.value())
+ : input_sinfo[0];
+ }
return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
: input_sinfo[0];
}
@@ -578,6 +594,9 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx
// infer dtype
DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo);
+ // infer vdevice
+ Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_sinfo, label_sinfo);
+
// infer ndim
if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() &&
pred_sinfo->ndim != label_sinfo->ndim) {
@@ -610,6 +629,9 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx
}
}
}
+ if (vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(Array<PrimExpr>()), dtype, vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(Array<PrimExpr>()), dtype);
}
@@ -685,13 +707,17 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) {
<< call->args[1]->struct_info_->GetTypeKey());
}
- // infer dtype
+ // infer dtype, vdevice
DataType output_dtype;
+ Optional<VDevice> vdevice;
if (wgt_sinfo != nullptr) {
output_dtype = InferBinaryArithOpOutDtype(call, ctx, GetRef<TensorStructInfo>(pred_sinfo),
GetRef<TensorStructInfo>(wgt_sinfo));
+ vdevice = InferBinaryArithOpOutVDevice(call, ctx, GetRef<TensorStructInfo>(pred_sinfo),
+ GetRef<TensorStructInfo>(wgt_sinfo));
} else {
output_dtype = pred_sinfo->dtype;
+ vdevice = pred_sinfo->vdevice;
}
// the type of targets must be int/uint.
@@ -834,13 +860,23 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) {
if (reduction == "none") {
// () or (N,) or (N, d1, d2, ..., dk)
if (pred_sinfo->shape.as<ShapeExprNode>()) {
+ if (vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
} else {
int output_ndim = pred_sinfo->ndim == kUnknownNDim ? kUnknownNDim : pred_sinfo->ndim - 1;
+ if (vdevice.defined()) {
+ return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice.value());
+ }
return TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
}
} else {
// sum or mean. output is scalar
+ if (vdevice.defined()) {
+ return TensorStructInfo(/*shape=*/ShapeExpr(Array<PrimExpr>()), output_dtype,
+ vdevice.value());
+ }
return TensorStructInfo(/*shape=*/ShapeExpr(Array<PrimExpr>()), output_dtype);
}
}
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index bfbb4b4284..c26fae08c2 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -86,6 +86,9 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
Optional<ShapeExpr> data_shape =
CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
if (!data_shape.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, out_layout.ndim());
}
@@ -114,6 +117,9 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
}
@@ -204,6 +210,9 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder
!attrs->output_size.defined()) {
return data_sinfo;
} else {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, out_layout.ndim());
}
}
@@ -216,6 +225,9 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder
}
Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
}
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 0e9eee4f01..dd4c3ac173 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -201,6 +201,32 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder&
return x1_sinfo->dtype;
}
+/*!
+ * \brief Infer the output virtual device for binary arithmetic operators.
+ * \param call The context Call to the operator.
+ * \param ctx The error reporting context.
+ * \param x1_sinfo The struct info of the first operand
+ * \param x2_sinfo The struct info of the second operand
+ * \return The inferred output vdevice.
+ * \throw Throw an exception if both inputs have a vdevice with a defined target and the two vdevices differ.
+ */
+inline Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx,
+ const TensorStructInfo& x1_sinfo,
+ const TensorStructInfo& x2_sinfo) {
+ if (!x1_sinfo->vdevice.defined() || !x1_sinfo->vdevice.value()->target.defined()) {
+ return x2_sinfo->vdevice;
+ }
+ if (!x2_sinfo->vdevice.defined() || !x2_sinfo->vdevice.value()->target.defined()) {
+ return x1_sinfo->vdevice;
+ }
+ if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "VDevice " << x1_sinfo->vdevice.value() << " and "
+ << x2_sinfo->vdevice.value() << " must be equal for binary operators");
+ }
+ return x1_sinfo->vdevice;
+}
+
/*!
* \brief Infer the output shape for binary broadcast operators.
* \param call The context Call to the operator.
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 6483806182..87afc24397 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -39,6 +39,9 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx,
// DateType
DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo);
+ // VDevice
+ Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, x1_sinfo, x2_sinfo);
+
// ndims
int output_ndim;
if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
@@ -55,14 +58,27 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx,
Optional<Array<PrimExpr>> output_shape =
InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values);
if (!output_shape.defined()) {
+ if (vdevice.defined()) {
+ return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice.value());
+ }
return TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
+
} else {
ICHECK_EQ(static_cast<int>(output_shape.value().size()), output_ndim);
+ if (vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
}
} else if (x1_sinfo->shape.defined() && x1_sinfo->shape.same_as(x2_sinfo->shape)) {
+ if (vdevice.defined()) {
+ return TensorStructInfo(x1_sinfo->shape.value(), output_dtype, vdevice.value());
+ }
return TensorStructInfo(x1_sinfo->shape.value(), output_dtype);
} else {
+ if (vdevice.defined()) {
+ return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice.value());
+ }
return TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
}
}
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index dabf3155f0..3a4de79b11 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -77,6 +77,9 @@ StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) {
const auto* attrs = call->attrs.as<InitAttrs>();
DataType out_dtype = attrs->dtype.is_void() ? fill_value_sinfo->dtype : attrs->dtype;
+ if (fill_value_sinfo->vdevice.defined()) {
+ return TensorStructInfo(/*shape=*/call->args[0], out_dtype, fill_value_sinfo->vdevice.value());
+ }
return TensorStructInfo(/*shape=*/call->args[0], out_dtype);
}
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 5f1d5149b3..6d9dfc86ba 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -66,6 +66,9 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
<< data_sinfo->ndim);
}
if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
@@ -75,6 +78,10 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr || indices_shape == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1);
}
@@ -87,6 +94,10 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
output_shape.push_back(data_shape->values[i]);
}
}
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
@@ -180,12 +191,18 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx
}
if (data_sinfo->IsUnknownNdim()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes);
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
}
@@ -199,6 +216,9 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx
for (int i = 0; i < n_axis; ++i) {
const auto* int_stride = strides[i].as<IntImmNode>();
if (!int_stride) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
}
int_strides.push_back(int_stride->value);
@@ -211,6 +231,10 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx
output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i], int_strides[i],
data_shape->values[axes[i]], attrs->assume_inbound));
}
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
@@ -265,6 +289,9 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder&
LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes begin/end/strides "
"tensors are well-formed. It could produce runtime error when this assumption "
"turns out to be wrong.";
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
if (data_sinfo->IsUnknownDtype()) {
@@ -305,6 +332,9 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder&
// The output shape will depend on the runtime value in begin/end/strides tensors.
// TODO(tvm-team): Currently, it is unable to express partially-static shape. Revisit when
// PrimValue lands.
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, n_axis, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, n_axis);
} // namespace relax
diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc
index b05fbaa5d3..62ff189577 100644
--- a/src/relax/op/tensor/linear_algebra.cc
+++ b/src/relax/op/tensor/linear_algebra.cc
@@ -51,12 +51,26 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo x1_sinfo = input_sinfo[0];
TensorStructInfo x2_sinfo = input_sinfo[1];
+ VDevice vdev = VDevice();
+ if (x1_sinfo->vdevice.defined() && x2_sinfo->vdevice.defined()) {
+ if (x1_sinfo->vdevice.value() == x2_sinfo->vdevice.value()) {
+ vdev = x1_sinfo->vdevice.value();
+ }
+ } else if (x1_sinfo->vdevice.defined()) {
+ vdev = x1_sinfo->vdevice.value();
+ } else if (x2_sinfo->vdevice.defined()) {
+ vdev = x2_sinfo->vdevice.value();
+ }
+
const auto* attrs = call->attrs.as<MatmulAttrs>();
DataType out_dtype = attrs->out_dtype.is_void()
? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo)
: attrs->out_dtype;
if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
+ if (vdev.defined()) {
+ return TensorStructInfo(out_dtype, kUnknownNDim, vdev);
+ }
return TensorStructInfo(out_dtype, kUnknownNDim);
}
int x1_ndim = x1_sinfo->ndim;
@@ -82,6 +96,9 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
if (x1_shape == nullptr || x2_shape == nullptr) {
+ if (vdev.defined()) {
+ return TensorStructInfo(out_dtype, output_ndim, vdev);
+ }
return TensorStructInfo(out_dtype, output_ndim);
}
@@ -92,6 +109,9 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
Optional<Array<PrimExpr>> output_shape_prefix =
InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix);
if (!output_shape_prefix.defined()) {
+ if (vdev.defined()) {
+ return TensorStructInfo(out_dtype, output_ndim, vdev);
+ }
return TensorStructInfo(out_dtype, output_ndim);
}
@@ -113,6 +133,9 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
output_shape.push_back(x2_shape->values[x2_ndim - 1]);
}
ICHECK_EQ(static_cast<int>(output_shape.size()), output_ndim);
+ if (vdev.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), out_dtype, vdev);
+ }
return TensorStructInfo(ShapeExpr(output_shape), out_dtype);
}
@@ -156,6 +179,23 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) {
const auto* attrs = call->attrs.as<EinsumAttrs>();
+ bool vdevice_unknown = false;
+ VDevice vdev = VDevice();
+ for (TensorStructInfo sinfo : operands_tensor_sinfo) {
+ if (!vdevice_unknown) {
+ if (sinfo->vdevice.defined()) {
+ if (!vdev.defined()) {
+ vdev = sinfo->vdevice.value();
+ } else if (sinfo->vdevice.value()->target.defined()) {
+ // mismatch
+ if (sinfo->vdevice.value() != vdev) {
+ vdevice_unknown = true;
+ }
+ }
+ }
+ }
+ }
+
String subscripts = attrs->subscripts;
DataType operand_dtype = operands_tensor_sinfo[0]->dtype;
@@ -176,12 +216,18 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) {
if (shape_expr != nullptr) {
input_shapes.push_back(shape_expr->values);
} else {
+ if (!vdevice_unknown) {
+ return TensorStructInfo(operand_dtype, tensor_sinfo->ndim, vdev);
+ }
return TensorStructInfo(operand_dtype, tensor_sinfo->ndim);
}
}
// Calculate output shape using InferEinsumShape in topi
Array<PrimExpr> oshape = topi::InferEinsumShape(subscripts, input_shapes);
+ if (!vdevice_unknown) {
+ return TensorStructInfo(ShapeExpr(oshape), operand_dtype, vdev);
+ }
return TensorStructInfo(ShapeExpr(oshape), operand_dtype);
}
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 24d96af3bb..38b761d04f 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -71,10 +71,18 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx)
// Trust the input target shape when there is no possibility to do any compile-time check.
if (!data_sinfo->shape.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
}
ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(data_sinfo->shape.value()->struct_info_);
if (!shape_sinfo->values.defined() || !tgt_shape_sinfo->values.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
}
@@ -100,6 +108,10 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx)
// Todo(relax-team): revisit here for better check on if the tensor length
// is consistent with the length in the given shape.
}
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
}
@@ -190,8 +202,10 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
const auto* attrs = call->attrs.as<ConcatAttrs>();
int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1;
DataType output_dtype = DataType::Void();
+ VDevice vdev = VDevice();
bool shape_unknown = false;
bool is_void_dtype = false;
+ bool vdevice_unknown = false;
std::vector<Array<PrimExpr>> shape_values;
shape_values.reserve(tensor_sinfo.size());
@@ -220,6 +234,20 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
<< output_ndim << " and " << sinfo->ndim);
}
+ // Update the virtual device.
+ if (!vdevice_unknown) {
+ if (sinfo->vdevice.defined()) {
+ if (!vdev.defined()) {
+ vdev = sinfo->vdevice.value();
+ } else if (sinfo->vdevice.value()->target.defined()) {
+ // mismatch
+ if (sinfo->vdevice.value() != vdev) {
+ vdevice_unknown = true;
+ }
+ }
+ }
+ }
+
// Update the shape values for best effort check.
const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
if (shape_expr != nullptr) {
@@ -242,6 +270,10 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
output_dtype = DataType::Void();
}
if (output_ndim == kUnknownNDim) {
+ if (!vdevice_unknown) {
+ return tensor_sinfo.size() == 1 ? tensor_sinfo[0]
+ : TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
return tensor_sinfo.size() == 1 ? tensor_sinfo[0] : TensorStructInfo(output_dtype, output_ndim);
}
@@ -252,6 +284,9 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
return tensor_sinfo[0];
}
if (shape_values.empty()) {
+ if (!vdevice_unknown) {
+ return TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
return TensorStructInfo(output_dtype, output_ndim);
}
@@ -259,8 +294,14 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
Optional<Array<PrimExpr>> output_shape = CheckConcatOutputShape(call, ctx, shape_values, axis);
if (shape_unknown || !output_shape.defined()) {
+ if (!vdevice_unknown) {
+ return TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
return TensorStructInfo(output_dtype, output_ndim);
} else {
+ if (!vdevice_unknown) {
+ return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev);
+ }
return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
}
}
@@ -318,6 +359,9 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx)
}
if (data_sinfo->IsUnknownNdim()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
@@ -327,6 +371,9 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx)
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, output_ndim);
}
@@ -346,6 +393,10 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx)
++i_data_shape;
}
ICHECK_EQ(i_data_shape, data_sinfo->ndim);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
@@ -415,8 +466,14 @@ TVM_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten);
StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
if (data_sinfo->IsUnknownNdim()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1);
} else if (data_sinfo->ndim == 0) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr({1}), data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr({1}), data_sinfo->dtype);
} else if (data_sinfo->ndim == 1) {
return data_sinfo;
@@ -424,9 +481,16 @@ StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) {
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1);
}
PrimExpr shape_prod = ComputeShapeProduct(data_shape->values);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr({std::move(shape_prod)}), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr({std::move(shape_prod)}), data_sinfo->dtype);
}
@@ -471,6 +535,10 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
if (data_sinfo->IsUnknownNdim()) {
// Todo(relax-team): revisit here for better check on if the input tensor has desired ndim.
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(),
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
}
@@ -483,16 +551,28 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
}
if (!data_sinfo->shape.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(),
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
}
ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(data_sinfo->shape.value()->struct_info_);
if (!shape_sinfo->values.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size(),
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
}
arith::Analyzer analyzer;
Array<PrimExpr> output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
@@ -534,6 +614,9 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx)
// Todo(relax-team): revisit here for better check on if the input tensor has
// ndim same as the number of input axes.
if (!attrs->axes.defined() && data_sinfo->IsUnknownNdim()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
@@ -561,6 +644,9 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx)
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
}
std::vector<PrimExpr> new_shape;
@@ -568,6 +654,9 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx)
for (int i = 0; i < data_sinfo->ndim; ++i) {
new_shape.push_back(data_shape->values[axes[i]]);
}
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype);
}
@@ -759,8 +848,15 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) {
Expr target_shape = call->args[1];
// If shape values are defined, use them
if (target_shape->IsInstance<VarNode>() && new_shape_sinfo->values.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype);
}
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(target_shape, data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(target_shape, data_sinfo->dtype);
}
@@ -818,6 +914,11 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
}
// Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape.
if (data_shape == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TupleStructInfo(Array<StructInfo>(
+ p_indices->size() + 1,
+ TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value())));
+ }
return TupleStructInfo(Array<StructInfo>(
p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)));
}
@@ -826,6 +927,11 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
const auto* axis_length = data_shape->values[axis].as<IntImmNode>();
// Fall back to unknown shape when the input tensor shape at the given axis is symbolic.
if (axis_length == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TupleStructInfo(Array<StructInfo>(
+ p_indices->size() + 1,
+ TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value())));
+ }
return TupleStructInfo(Array<StructInfo>(
p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)));
}
@@ -844,7 +950,12 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
Array<PrimExpr> shape = data_shape->values;
shape.Set(axis, tvm::max(zero, r - l));
- output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+ if (data_sinfo->vdevice.defined()) {
+ output_sinfo.push_back(
+ TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice.value()));
+ } else {
+ output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+ }
}
return TupleStructInfo(output_sinfo);
} else if (const auto* p_n_section = attrs->indices_or_sections.as<IntImmNode>()) {
@@ -856,6 +967,11 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
}
// Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape.
if (data_shape == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TupleStructInfo(Array<StructInfo>(
+ n_section,
+ TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value())));
+ }
return TupleStructInfo(
Array<StructInfo>(n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)));
}
@@ -865,12 +981,22 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
// Construct struct info for tensors except the last one.
Array<PrimExpr> shape = data_shape->values;
shape.Set(axis, split_len);
+ if (data_sinfo->vdevice.defined()) {
+ std::vector<StructInfo> output_sinfo(
+ n_section - 1,
+ TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice.value()));
+ }
std::vector<StructInfo> output_sinfo(n_section - 1,
TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
// Construct struct info for the last tensor.
shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1));
- output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+ if (data_sinfo->vdevice.defined()) {
+ output_sinfo.push_back(
+ TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice.value()));
+ } else {
+ output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype));
+ }
return TupleStructInfo(output_sinfo);
}
ICHECK(false) << "Cannot reach here.";
@@ -928,6 +1054,9 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
}
if (data_sinfo->IsUnknownNdim()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
@@ -943,6 +1072,10 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
if (!shape_value.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size(),
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size());
}
for (int i = 0; i < static_cast<int>(axes.size()); ++i) {
@@ -965,12 +1098,18 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
// (https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html).
// Consider discourage usage later.
if (!shape_value.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
for (int i = 0; i < data_sinfo->ndim; ++i) {
// Whenever a dimension length is symbolic, fall back to unknown ndim.
const auto* int_len = shape_value.value()[i].as<IntImmNode>();
if (int_len == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
if (int_len->value == 1) {
@@ -991,11 +1130,22 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
if (static_cast<int>(output_shape.size()) == data_sinfo->ndim) {
return data_sinfo;
} else if (attrs->axis.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, output_shape.size(),
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, output_shape.size());
} else {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
} else {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
}
@@ -1132,8 +1282,16 @@ StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder&
}
if (collapse_target_sinfo->shape.defined()) {
+ if (collapse_target_sinfo->vdevice.defined()) {
+ return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype,
+ collapse_target_sinfo->vdevice.value());
+ }
return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype);
} else {
+ if (collapse_target_sinfo->vdevice.defined()) {
+ return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim,
+ collapse_target_sinfo->vdevice.value());
+ }
return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim);
}
}
@@ -1185,7 +1343,9 @@ StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ct
if (data_shape_value.defined() && shape_sinfo->values.defined()) {
CheckCollapseShape(call, ctx, data_shape_value.value(), shape_sinfo->values.value());
}
-
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(/*shape=*/call->args[1], output_dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(/*shape=*/call->args[1], output_dtype);
}
@@ -1234,9 +1394,15 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
// the shape does not changes
return data_sinfo;
} else {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
}
} else {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, 1, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, 1);
}
}
@@ -1244,12 +1410,19 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
if (!attrs->axis.defined()) {
PrimExpr new_shape =
analyzer->Simplify(ComputeShapeProduct(data_shape->values) * attrs->repeats);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(Array<PrimExpr>({new_shape})), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(Array<PrimExpr>({new_shape})), data_sinfo->dtype);
}
int axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value);
auto shape_array = data_shape->values;
shape_array.Set(axis, analyzer->Simplify(shape_array[axis] * attrs->repeats));
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype);
}
@@ -1284,13 +1457,23 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) {
if (data_shape == nullptr) {
if (data_sinfo->IsUnknownNdim()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
if (l > ndim) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, l, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, l);
} else {
for (auto i : attrs->repeats) {
if (!analyzer->CanProveEqual(i, 1)) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
}
}
@@ -1313,6 +1496,10 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) {
analyzer->Simplify(data_shape->values[i - ndim_delta] * attrs->repeats[i - l_delta]));
}
}
+
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
}
@@ -1395,6 +1582,9 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
if (data_sinfo->IsUnknownNdim()) {
// When `data` has unknown rank, assume rest of arguments are correct and proceed.
// If the assumption turns out to be wrong, runtime error will be triggered.
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}
@@ -1461,8 +1651,15 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
}
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype);
}
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
}
diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc
index e1d684916c..14fa287494 100644
--- a/src/relax/op/tensor/search.cc
+++ b/src/relax/op/tensor/search.cc
@@ -44,6 +44,21 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo x1_sinfo = input_sinfo[1];
TensorStructInfo x2_sinfo = input_sinfo[2];
+ VDevice vdev = VDevice();
+ for (int i = 0; i < 3; ++i) {
+ if (input_sinfo[i]->vdevice.defined()) {
+ if (!vdev.defined()) {
+ vdev = input_sinfo[i]->vdevice.value();
+ } else if (input_sinfo[i]->vdevice.value()->target.defined()) {
+ // mismatch
+ if (input_sinfo[i]->vdevice.value() != vdev) {
+ vdev = VDevice();
+ break;
+ }
+ }
+ }
+ }
+
if (!cond_sinfo->dtype.is_bool()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Where requires the input condition tensor to have boolean dtype. However, "
@@ -67,23 +82,38 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) {
Optional<Array<PrimExpr>> broadcasted_shape =
InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values);
if (!broadcasted_shape.defined()) {
+ if (vdev.defined()) {
+ return TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
return TensorStructInfo(output_dtype, output_ndim);
}
// Step 2. Compute the broadcasted shape of cond's and the previous broadcasted shape.
broadcasted_shape =
InferBinaryBroadcastShape(call, ctx, cond_shape->values, broadcasted_shape.value());
if (!broadcasted_shape.defined()) {
+ if (vdev.defined()) {
+ return TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
return TensorStructInfo(output_dtype, output_ndim);
}
ICHECK_EQ(static_cast<int>(broadcasted_shape.value().size()), output_ndim);
+ if (vdev.defined()) {
+ return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype, vdev);
+ }
return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype);
} else if (cond_sinfo->shape.defined() && //
x1_sinfo->shape.defined() && //
x2_sinfo->shape.defined() && //
cond_sinfo->shape.same_as(x1_sinfo->shape) && //
cond_sinfo->shape.same_as(x2_sinfo->shape)) {
+ if (vdev.defined()) {
+ return TensorStructInfo(cond_sinfo->shape.value(), output_dtype, vdev);
+ }
return TensorStructInfo(cond_sinfo->shape.value(), output_dtype);
} else {
+ if (vdev.defined()) {
+ return TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
return TensorStructInfo(output_dtype, output_ndim);
}
}
@@ -131,9 +161,19 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr) {
if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(
+ ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(out_dtype, /*value=*/1))), out_dtype,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(out_dtype, /*value=*/1))),
out_dtype);
} else {
+ if (data_sinfo->vdevice.defined()) {
+ return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), out_dtype,
+ data_sinfo->vdevice.value())
+ : TensorStructInfo(out_dtype, out_ndim, data_sinfo->vdevice.value());
+ }
return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), out_dtype)
: TensorStructInfo(out_dtype, out_ndim);
}
@@ -153,6 +193,9 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx
}
}
ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index cb6a332d49..3920cccadd 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -86,21 +86,45 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
// unique values
if (data_sinfo->ndim == 0) {
- output_sinfo.push_back(
- TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype));
+ if (data_sinfo->vdevice.defined()) {
+ output_sinfo.push_back(TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+ data_sinfo->dtype, data_sinfo->vdevice.value()));
+ } else {
+ output_sinfo.push_back(
+ TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype));
+ }
} else if (axis.defined()) {
- output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim));
+ if (data_sinfo->vdevice.defined()) {
+ output_sinfo.push_back(
+ TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice.value()));
+ } else {
+ output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim));
+ }
} else {
- output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1));
+ if (data_sinfo->vdevice.defined()) {
+ output_sinfo.push_back(
+ TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice.value()));
+ } else {
+ output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1));
+ }
}
// index, reverse and counts
TensorStructInfo int_return{nullptr};
if (data_sinfo->ndim == 0) {
- int_return =
- TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), DataType::Int(64));
+ if (data_sinfo->vdevice.defined()) {
+ int_return = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+ DataType::Int(64), data_sinfo->vdevice.value());
+ } else {
+ int_return =
+ TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), DataType::Int(64));
+ }
} else {
- int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1);
+ if (data_sinfo->vdevice.defined()) {
+ int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice.value());
+ } else {
+ int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1);
+ }
}
for (int i = 0; i < n_int_return; ++i) {
output_sinfo.push_back(int_return);
diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc
index 6d1cc86f0a..c450738a1d 100644
--- a/src/relax/op/tensor/statistical.cc
+++ b/src/relax/op/tensor/statistical.cc
@@ -61,10 +61,21 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx)
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr) {
if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(
+ ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64), /*value=*/1))),
+ data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(
ShapeExpr(Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64), /*value=*/1))),
data_sinfo->dtype);
} else {
+ if (data_sinfo->vdevice.defined()) {
+ return out_ndim == 0
+ ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), data_sinfo->dtype,
+ data_sinfo->vdevice.value())
+ : TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice.value());
+ }
return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array<PrimExpr>()), data_sinfo->dtype)
: TensorStructInfo(data_sinfo->dtype, out_ndim);
}
@@ -80,6 +91,9 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx)
}
}
ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
}
@@ -158,19 +172,32 @@ StructInfo InferStructInfoCumsum(const Call& call, const BlockBuilder& ctx) {
// flattened
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(out_type, data_sinfo->ndim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(out_type, data_sinfo->ndim);
} else {
PrimExpr flattened_d = 1;
for (const auto v : data_shape->values) {
flattened_d *= v;
}
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(Array<PrimExpr>({flattened_d})), out_type,
+ data_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(Array<PrimExpr>({flattened_d})), out_type);
}
}
if (data_sinfo->shape.defined()) {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(data_sinfo->shape.value(), out_type, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(data_sinfo->shape.value(), out_type);
} else {
+ if (data_sinfo->vdevice.defined()) {
+ return TensorStructInfo(out_type, data_sinfo->ndim, data_sinfo->vdevice.value());
+ }
return TensorStructInfo(out_type, data_sinfo->ndim);
}
}
diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc
index d1ff5b7863..cc265ad9a1 100644
--- a/src/relax/op/tensor/ternary.cc
+++ b/src/relax/op/tensor/ternary.cc
@@ -65,6 +65,21 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) {
output_dtype = t1->dtype;
}
+ VDevice vdev = VDevice();
+ for (int i = 0; i < 3; ++i) {
+ if (input_sinfo[i]->vdevice.defined()) {
+ if (!vdev.defined()) {
+ vdev = input_sinfo[i]->vdevice.value();
+ } else if (input_sinfo[i]->vdevice.value()->target.defined()) {
+ // mismatch
+ if (input_sinfo[i]->vdevice.value() != vdev) {
+ vdev = VDevice();
+ break;
+ }
+ }
+ }
+ }
+
auto* s1 = t1->shape.as<ShapeExprNode>();
auto* s2 = t2->shape.as<ShapeExprNode>();
auto* s3 = t3->shape.as<ShapeExprNode>();
@@ -82,11 +97,19 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) {
<< "The 3 arguments of EwiseFMA must have the same shape");
}
}
+ if (vdev.defined()) {
+ return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev);
+ }
return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
} else if (t1->shape.defined() && t1->shape.same_as(t2->shape) && t1->shape.same_as(t3->shape)) {
+ if (vdev.defined()) {
+ return TensorStructInfo(t1->shape.value(), output_dtype, vdev);
+ }
return TensorStructInfo(t1->shape.value(), output_dtype);
}
-
+ if (vdev.defined()) {
+ return TensorStructInfo(output_dtype, ndim, vdev);
+ }
return TensorStructInfo(output_dtype, ndim);
}
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index c303a2c8f0..9813c4ed24 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -264,6 +264,10 @@ class AlterOpImplMutator : public ExprMutator {
auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
arith::Analyzer analyzer;
auto new_shape = transform->MapShape(shape, &analyzer);
+ if (tensor_sinfo->vdevice.defined()) {
+ return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype,
+ tensor_sinfo->vdevice.value());
+ }
return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype);
}
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index 3c7be6959d..f91d221b40 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -267,10 +267,7 @@ class LayoutConvertMutator : public ExprMutator {
new_shape.push_back(
shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]);
}
- VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
- if (tsinfo->vdevice.defined()) {
- vdev = tsinfo->vdevice.value();
- }
+ VDevice vdev = tsinfo->vdevice.value_or(VDevice());
return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, vdev, tsinfo->span);
};
StructInfo new_struct_info = TransformTupleLeaf<LayoutDecision>(
diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index b864a65969..d12d1080b9 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -289,7 +289,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
if (fp16_input_names_.count(var->name_hint())) {
auto sinfo = GetStructInfo(var);
if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
- VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
+ VDevice vdev = VDevice();
if (tensor_sinfo->vdevice.defined()) {
vdev = tensor_sinfo->vdevice.value();
}
diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc
new file mode 100644
index 0000000000..a964f80ea1
--- /dev/null
+++ b/src/relax/transform/update_vdevice.cc
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/update_vdevice.cc
+ * \brief Update Virtual Device pass.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Mutator that moves tensors from one virtual device to another.
+ *
+ * The virtual device to replace is selected by its index in the module's
+ * "vdevice" global info list. Every tensor StructInfo annotated with that
+ * device is rewritten to use the new device, and the global info entry is
+ * replaced as well.
+ */
+class VDeviceMutator : public ExprMutator {
+ public:
+  /*!
+   * \brief Construct the mutator.
+   * \param mod The module to rewrite. It must carry a "vdevice" global info list.
+   * \param new_vdevice The virtual device that replaces the selected one.
+   * \param index Position of the device to replace in mod->global_infos["vdevice"].
+   */
+  VDeviceMutator(const IRModule& mod, VDevice new_vdevice, int64_t index)
+      : ExprMutator(mod), mod_(mod), new_vdevice_(new_vdevice) {
+    CHECK(mod->global_infos.count("vdevice"))
+        << "UpdateVDevice: the IRModule has no \"vdevice\" global info to update";
+    Array<GlobalInfo> vdevices = mod->global_infos["vdevice"];
+    CHECK(index >= 0 && index < static_cast<int64_t>(vdevices.size()))
+        << "UpdateVDevice: vdevice index " << index << " is out of range, the module has "
+        << vdevices.size() << " vdevice(s)";
+    old_vdevice_ = Downcast<VDevice>(vdevices[index]);
+  }
+
+  using ExprMutator::VisitExpr_;
+
+  /*!
+   * \brief Visit an expression and, when its StructInfo is a tensor placed on
+   *        the old virtual device, swap in the new virtual device.
+   * \note Virtual devices are compared by object identity (ObjectRef ==).
+   * \note The StructInfo is overwritten in place on the visited node.
+   *       NOTE(review): if ExprMutator returns the input node unchanged, this
+   *       also touches nodes of the input module -- confirm this is intended.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    Expr visited_expr = ExprMutator::VisitExpr(expr);
+    if (!visited_expr->struct_info_.defined()) {
+      return visited_expr;
+    }
+    const auto* tinfo = GetStructInfoAs<TensorStructInfoNode>(visited_expr);
+    if (tinfo == nullptr || !tinfo->vdevice.defined() || tinfo->vdevice.value() != old_vdevice_) {
+      return visited_expr;
+    }
+    // Keep shape (or ndim when the shape is unknown), dtype and span; replace only the vdevice.
+    if (tinfo->shape.defined()) {
+      visited_expr->struct_info_ =
+          TensorStructInfo(tinfo->shape.value(), tinfo->dtype, new_vdevice_, tinfo->span);
+    } else {
+      visited_expr->struct_info_ =
+          TensorStructInfo(tinfo->dtype, tinfo->ndim, new_vdevice_, tinfo->span);
+    }
+    return visited_expr;
+  }
+
+  /*!
+   * \brief Rewrite all Relax functions and replace the old device in the
+   *        module's "vdevice" global info with the new one.
+   * \return The updated IRModule.
+   */
+  IRModule Run() {
+    for (const auto& [gv, func] : mod_->functions) {
+      // Only Relax functions are rewritten; other functions are kept as-is.
+      if (func->IsInstance<relax::FunctionNode>()) {
+        relax::Function update_func = Downcast<Function>(VisitExpr(func));
+        builder_->UpdateFunction(gv, update_func);
+      }
+    }
+    Array<GlobalInfo> new_vdevices;
+    for (auto vdev : mod_->global_infos["vdevice"]) {
+      if (vdev == old_vdevice_) {
+        new_vdevices.push_back(new_vdevice_);
+      } else {
+        new_vdevices.push_back(vdev);
+      }
+    }
+    IRModule new_mod = builder_->GetContextIRModule();
+    new_mod->UpdateGlobalInfo("vdevice", new_vdevices);
+    return new_mod;
+  }
+
+ private:
+  /*! \brief Input IRModule */
+  IRModule mod_;
+  /*! \brief The new virtual device */
+  VDevice new_vdevice_;
+  /*! \brief The virtual device to be updated */
+  VDevice old_vdevice_;
+};
+
+namespace transform {
+
+// Module pass that runs a fresh VDeviceMutator on each module it is applied to.
+Pass UpdateVDevice(VDevice new_vdevice, int64_t index) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
+                                                                            PassContext pc) {
+    return relax::VDeviceMutator(m, new_vdevice, index).Run();
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"UpdateVDevice",
+                          /*required=*/{});
+}
+TVM_REGISTER_GLOBAL("relax.transform.UpdateVDevice").set_body_typed(UpdateVDevice);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index 4cb19fb48c..2f2785ca44 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -136,7 +136,7 @@ VDevice LookupVDevice(String target_kind, int device_index) {
}
}
LOG(WARNING) << "The annotated device was not found, please check your vdevice list.";
- return VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
+ return VDevice();
}
TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index d279b60b54..879194037c 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -23,7 +23,7 @@ import tvm
import tvm.testing
from tvm import TVMError
from tvm import relax as rx
-from tvm import tir
+from tvm import tir, ir
def test_get_static_type_basic():
@@ -218,6 +218,12 @@ def test_base_check():
shape3 = rx.ShapeStructInfo([1, 2, 3])
shape4 = rx.ShapeStructInfo([1, n, 3])
+ vdevice0 = ir.VDevice()
+ vdevice1 = ir.VDevice("llvm")
+ vdevice2 = ir.VDevice("cuda", 0)
+ vdevice3 = ir.VDevice("cuda", 2)
+ vdevice4 = ir.VDevice("cuda", 0, "")
+
tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32")
tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32")
tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32")
@@ -225,6 +231,16 @@ def test_base_check():
tensor4 = rx.TensorStructInfo([n, m], "int32")
tensor5 = rx.TensorStructInfo([n, m, 1], "int32")
tensor6 = rx.TensorStructInfo([n, m, 2], "int32")
+ tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0)
+ tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1)
+ tensor9 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice2)
+ tensor10 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice3)
+ tensor11 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice4)
+ tensor12 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0)
+ tensor13 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1)
+ tensor14 = rx.TensorStructInfo([n, m, 2], "int32", vdevice2)
+ tensor15 = rx.TensorStructInfo([n, m, 2], "int32", vdevice3)
+ tensor16 = rx.TensorStructInfo([n, m, 2], "int32", vdevice4)
# obj
assert bcheck(obj0, prim0) == BR.PASS
@@ -271,6 +287,14 @@ def test_base_check():
assert bcheck(tensor3, tensor4) == BR.FAIL_L0
assert bcheck(tensor1, tensor2) == BR.FAIL_L0
+ # vdevice mismatch
+ assert bcheck(tensor8, tensor9) == BR.FAIL_L0
+ assert bcheck(tensor9, tensor10) == BR.FAIL_L0
+ assert bcheck(tensor10, tensor11) == BR.FAIL_L0
+ assert bcheck(tensor13, tensor14) == BR.FAIL_L0
+ assert bcheck(tensor14, tensor15) == BR.FAIL_L0
+ assert bcheck(tensor15, tensor16) == BR.FAIL_L0
+
# ndim mismatch
assert bcheck(tensor2, tensor5) == BR.FAIL_L0
@@ -284,6 +308,10 @@ def test_base_check():
assert tensor0.is_base_of(tensor5)
assert tensor0.is_base_of(tensor6)
assert tensor2.is_base_of(tensor4)
+ assert tensor3.is_base_of(tensor7)
+ assert tensor3.is_base_of(tensor8)
+ assert tensor6.is_base_of(tensor12)
+ assert tensor6.is_base_of(tensor13)
assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32"))
# tuple
@@ -386,6 +414,22 @@ def test_derive_call_ret_struct_info():
with pytest.raises(TVMError):
_check_derive(bb, func0(2), [obj0], obj0)
+ # Tensor with vdevice
+ vdev = ir.VDevice("llvm")
+
+ def func1(c):
+ n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+ x = rx.TensorStructInfo([n, m], "float32", vdev)
+ z = rx.TensorStructInfo([m + c, n], "float32", vdev)
+ return rx.FuncStructInfo([x], z)
+
+ _check_derive(
+ bb,
+ func1(1),
+ [rx.TensorStructInfo([10, 11], "float32", vdev)],
+ rx.TensorStructInfo([12, 10], "float32", vdev),
+ )
+
# opaque derivation
fopaque0 = lambda: rx.FuncStructInfo.opaque_func()
fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0)
@@ -477,6 +521,9 @@ def test_struct_info_lca():
prim0 = rx.PrimStructInfo("int32")
prim1 = rx.PrimStructInfo("float32")
+ vdevice0 = ir.VDevice("llvm")
+ vdevice1 = ir.VDevice("cuda", 0)
+
shape0 = rx.ShapeStructInfo(ndim=-1)
shape1 = rx.ShapeStructInfo(ndim=2)
shape2 = rx.ShapeStructInfo(ndim=3)
@@ -490,6 +537,10 @@ def test_struct_info_lca():
tensor4 = rx.TensorStructInfo([n, m], "int32")
tensor5 = rx.TensorStructInfo([n, m, 1], "int32")
tensor6 = rx.TensorStructInfo([n, m, 2], "int32")
+ tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0)
+ tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1)
+ tensor9 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0)
+ tensor10 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1)
# obj
_check_lca(obj0, prim0, obj0)
@@ -510,6 +561,13 @@ def test_struct_info_lca():
_check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None))
_check_lca(tensor0, tensor2, tensor0)
_check_lca(tensor0, tensor4, tensor0)
+ _check_lca(tensor0, tensor4, tensor0)
+ _check_lca(tensor1, tensor3, tensor1)
+ _check_lca(tensor3, tensor7, tensor3)
+ _check_lca(tensor3, tensor8, tensor3)
+ _check_lca(tensor1, tensor8, tensor1)
+ _check_lca(tensor6, tensor9, tensor6)
+ _check_lca(tensor6, tensor10, tensor6)
_check_lca(tensor2, tensor4, tensor2)
_check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32"))
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index ce9e5d507e..a0ec08f0ab 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -20,7 +20,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -73,16 +73,22 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_binary_arith_infer_struct_info(binary_arith_op: Callable):
bb = relax.BlockBuilder()
+ vdevice0 = VDevice("llvm")
+ vdevice1 = VDevice("cuda", 0)
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor((1, 3), "float32"))
x2 = relax.Var("x", R.Tensor((3, 2, 3), "float32"))
x3 = relax.Var("x", R.Tensor((3, 1, 3), "float32"))
x4 = relax.Var("x", R.Tensor("float32", ndim=2))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor("float32", ndim=2, vdevice=vdevice0))
+ x7 = relax.Var("x", R.Tensor((2, 3), "float32", vdevice0))
y0 = relax.Var("y", R.Tensor((2, 3), "float32"))
y1 = relax.Var("y", R.Tensor((4, 3, 2, 1), "float32"))
y2 = relax.Var("y", R.Tensor("float32", ndim=2))
y3 = relax.Var("y", R.Tensor("float32", ndim=-1))
+ y4 = relax.Var("y", R.Tensor((2, 3), "float32", vdevice0))
+ y5 = relax.Var("y", R.Tensor("float32", ndim=2, vdevice=vdevice0))
_check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float32"))
_check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((2, 3), "float32"))
@@ -94,6 +100,19 @@ def test_binary_arith_infer_struct_info(binary_arith_op: Callable):
_check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=2))
_check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=-1))
_check_inference(bb, binary_arith_op(x5, y0), relax.TensorStructInfo(dtype="", ndim=-1))
+ _check_inference(
+ bb,
+ binary_arith_op(x6, y5),
+ relax.TensorStructInfo(dtype="float32", ndim=2, vdevice=vdevice0),
+ )
+ _check_inference(
+ bb,
+ binary_arith_op(x6, y2),
+ relax.TensorStructInfo(dtype="float32", ndim=2, vdevice=vdevice0),
+ )
+ _check_inference(
+ bb, binary_arith_op(x7, y4), relax.TensorStructInfo((2, 3), "float32", vdevice0)
+ )
(binary_cmp_op,) = tvm.testing.parameters(
@@ -108,15 +127,18 @@ def test_binary_arith_infer_struct_info(binary_arith_op: Callable):
def test_binary_cmp_infer_struct_info(binary_cmp_op: Callable):
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x = relax.Var("x", R.Tensor((2, 3), "float32"))
y0 = relax.Var("y", R.Tensor((2, 3), "float32"))
y1 = relax.Var("y", R.Tensor((2, 3), "int32"))
+ y2 = relax.Var("y", R.Tensor((2, 3), "float32", vdev0))
_check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool"))
_check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool"))
_check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool"))
_check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool"))
_check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool"))
_check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool"))
+ _check_inference(bb, binary_cmp_op(x, y2), relax.TensorStructInfo((2, 3), "bool", vdev0))
def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable):
@@ -198,6 +220,14 @@ def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable
bb.normalize(binary_arith_op(x, y))
+def test_binary_arith_infer_struct_info_vdevice_mismatch(binary_arith_op: Callable):
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3), "float32", VDevice("llvm")))
+    y = relax.Var("y", R.Tensor((2, 3), "float32", VDevice("cuda")))
+    with pytest.raises(TVMError):  # same shape and dtype: only the vdevice differs
+        bb.normalize(binary_arith_op(x, y))
+
+
def test_binary_wrong_input_number(binary_arith_op: Callable):
x = relax.Var("x", R.Tensor((2, 3), "float32"))
diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py
index 345e68b9b2..c0b4308529 100644
--- a/tests/python/relax/test_op_create.py
+++ b/tests/python/relax/test_op_create.py
@@ -19,7 +19,7 @@ import pytest
import tvm
import tvm.testing
from tvm import TVMError, relax, tir
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
from tvm.script import tir as T
@@ -45,10 +45,12 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_full_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
v0 = relax.Var("v", R.Tensor((), "float32"))
v1 = relax.Var("v", R.Tensor("float32", ndim=0))
v2 = relax.Var("v", R.Tensor(()))
v3 = relax.Var("v", R.Tensor(ndim=0))
+ v4 = relax.Var("v", R.Tensor((), "float32", vdev0))
s0 = relax.ShapeExpr((2, 3))
s1 = relax.Var("s", relax.ShapeStructInfo((2, 3)))
s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
@@ -62,6 +64,7 @@ def test_full_infer_struct_info():
bb, relax.op.full(s0, v0, "float16"), relax.TensorStructInfo((2, 3), "float16")
)
_check_inference(bb, relax.op.full(s0, v0), relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(bb, relax.op.full(s0, v4), relax.TensorStructInfo((2, 3), "float32", vdev0))
_check_inference(bb, relax.op.full(s1, v0, "float16"), relax.TensorStructInfo(s1, "float16"))
_check_inference(bb, relax.op.full(s1, v0), relax.TensorStructInfo(s1, "float32"))
_check_inference(bb, relax.op.full(s2, v0, "float16"), relax.TensorStructInfo(s2, "float16"))
@@ -300,6 +303,7 @@ def test_full_like_infer_struct_info_shape_symbolic():
def test_full_like_infer_struct_info_shape_var():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
s0 = relax.Var("s", relax.ShapeStructInfo((2, 3)))
s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
s2 = relax.Var("s", relax.ShapeStructInfo())
@@ -307,11 +311,13 @@ def test_full_like_infer_struct_info_shape_var():
x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
x3 = relax.Var("x", R.Tensor((2, 3), "float32"))
+ x4 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
sv0 = relax.Var("sv", relax.ShapeStructInfo(()))
sv1 = relax.Var("sv", relax.ShapeStructInfo(ndim=0))
v0 = relax.Var("v", relax.TensorStructInfo(sv0, "float16"))
v1 = relax.Var("v", relax.TensorStructInfo(sv1, "float16"))
v2 = relax.Var("v", R.Tensor((), "float16"))
+ v3 = relax.Var("v", relax.TensorStructInfo(sv1, "float16", vdev0))
_check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo(s0, "float32"))
_check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo(s0, "float32"))
@@ -324,6 +330,9 @@ def test_full_like_infer_struct_info_shape_var():
_check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(s2, "float32"))
_check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), "float32"))
_check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(
+ bb, relax.op.full_like(x4, v3), relax.TensorStructInfo((2, 3), "float32", vdev0)
+ )
def test_full_like_infer_struct_info_more_input_dtype():
@@ -605,12 +614,14 @@ def test_arange_infer_struct_info_shape_var():
def test_tril_triu_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 3, 4)))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
_check_inference(bb, relax.op.tril(x0, k=1), relax.TensorStructInfo((2, 3, 4), "float32"))
_check_inference(bb, relax.op.triu(x0, k=0), relax.TensorStructInfo((2, 3, 4), "float32"))
@@ -620,18 +631,22 @@ def test_tril_triu_infer_struct_info():
_check_inference(bb, relax.op.triu(x3), relax.TensorStructInfo((2, 3, 4), dtype=""))
_check_inference(bb, relax.op.tril(x4), relax.TensorStructInfo(dtype="", ndim=3))
_check_inference(bb, relax.op.triu(x5), relax.TensorStructInfo(dtype=""))
+ _check_inference(bb, relax.op.tril(x6), relax.TensorStructInfo((2, 3, 4), "float32", vdev0))
def test_tril_triu_infer_struct_info_shape_symbolic():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
a = tir.Var("a", "int64")
b = tir.Var("b", "int64")
c = tir.Var("c", "int64")
x0 = relax.Var("x", R.Tensor((a, b, c), "float32"))
x1 = relax.Var("x", R.Tensor((a, b, c)))
+ x2 = relax.Var("x", R.Tensor((a, b, c), "float32", vdev0))
_check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c), "float32"))
_check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c), dtype=""))
+ _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo((a, b, c), "float32", vdev0))
def test_tril_triu_infer_struct_info_shape_var():
diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py
index b06b51a2a1..251a30139b 100644
--- a/tests/python/relax/test_op_image.py
+++ b/tests/python/relax/test_op_image.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -35,6 +35,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_resize2d_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
x2 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
@@ -43,10 +44,16 @@ def test_resize2d_infer_struct_info():
x5 = relax.Var("x", R.Tensor("float32"))
x6 = relax.Var("x", R.Tensor(ndim=4))
x7 = relax.Var("x", R.Tensor())
+ x8 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0))
_check_inference(
bb, relax.op.image.resize2d(x0, (28, 28)), relax.TensorStructInfo((2, 3, 28, 28), "float32")
)
+ _check_inference(
+ bb,
+ relax.op.image.resize2d(x8, (28, 28)),
+ relax.TensorStructInfo((2, 3, 28, 28), "float32", vdev0),
+ )
_check_inference(
bb,
relax.op.image.resize2d(x0, size=28),
diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py
index fffd31f6f7..e3c9e4a596 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import ir as I, relax as R, tir as T
@@ -40,12 +40,14 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_take_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=2))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((4, 10)))
x4 = relax.Var("x", R.Tensor(ndim=2))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((4, 10), "float32", vdev0))
y0 = relax.Var("y", R.Tensor((10,), "float32"))
y1 = relax.Var("y", R.Tensor("float32", ndim=1))
y2 = relax.Var("y", R.Tensor((10,)))
@@ -58,8 +60,12 @@ def test_take_infer_struct_info():
idx5 = relax.Var("idx", R.Tensor("int64", ndim=2))
idx6 = relax.Var("idx", R.Tensor((6, 4)))
idx7 = relax.Var("idx", R.Tensor(ndim=2))
+ idx8 = relax.Var("idx", R.Tensor((6,), "int64", vdev0))
_check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float32"))
+ _check_inference(
+ bb, relax.op.take(x6, idx8, axis=1), relax.TensorStructInfo((4, 6), "float32", vdev0)
+ )
_check_inference(
bb, relax.op.take(x0, idx0, axis=-1), relax.TensorStructInfo((4, 6), "float32")
)
@@ -355,12 +361,14 @@ def test_take_infer_struct_info_wrong_input_type():
def test_strided_slice_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=4))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((8, 9, 10, 10)))
x4 = relax.Var("x", R.Tensor(ndim=4))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32", vdev0))
_check_inference(
bb,
@@ -369,6 +377,13 @@ def test_strided_slice_infer_struct_info():
),
relax.TensorStructInfo((4, 9, 10, 3), "float32"),
)
+ _check_inference(
+ bb,
+ relax.op.strided_slice(
+ x6, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3]
+ ),
+ relax.TensorStructInfo((4, 9, 10, 3), "float32", vdev0),
+ )
_check_inference(
bb,
relax.op.strided_slice(
diff --git a/tests/python/relax/test_op_linear_algebra.py b/tests/python/relax/test_op_linear_algebra.py
index ccea5f79eb..fc30ea8619 100644
--- a/tests/python/relax/test_op_linear_algebra.py
+++ b/tests/python/relax/test_op_linear_algebra.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -36,6 +36,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_matmul_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((3, 4), "float32"))
x1 = relax.Var("x", R.Tensor((4,), "float32"))
x2 = relax.Var("x", R.Tensor((2, 3, 5, 4), "float32"))
@@ -43,14 +44,17 @@ def test_matmul_infer_struct_info():
x4 = relax.Var("x", R.Tensor((2, 1, 4, 5)))
x5 = relax.Var("x", R.Tensor("float32"))
x6 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float16"))
+ x7 = relax.Var("x", R.Tensor((3, 4), "float32", vdev0))
y0 = relax.Var("y", R.Tensor((4, 5), "float32"))
y1 = relax.Var("y", R.Tensor((4,), "float32"))
y2 = relax.Var("y", R.Tensor((2, 3, 4, 5), "float32"))
y3 = relax.Var("y", R.Tensor((6, 1, 3, 5, 7), "float32"))
y4 = relax.Var("y", R.Tensor("float32", ndim=5))
y5 = relax.Var("y", R.Tensor())
+ y6 = relax.Var("y", R.Tensor((4, 5), "float32", vdev0))
_check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float32"))
+ _check_inference(bb, relax.op.matmul(x7, y6), relax.TensorStructInfo((3, 5), "float32", vdev0))
_check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32"))
_check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((2, 3, 5), "float32"))
_check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((2, 3, 5), "float32"))
@@ -208,19 +212,26 @@ def test_linear():
# Since linear is only a sugar for transpose + matmul + add,
# we only have brief tests here.
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
w1 = relax.Var("w", R.Tensor((5, 4), "float32"))
w2 = relax.Var("w", R.Tensor((4,), "float32"))
w3 = relax.Var("w", R.Tensor("float32"))
+ w4 = relax.Var("w", R.Tensor((5, 4), "float32", vdev0))
b1 = relax.Var("b", R.Tensor((5,), "float32"))
b2 = relax.Var("b", R.Tensor((), "float32"))
+ b3 = relax.Var("b", R.Tensor((5,), "float32", vdev0))
# Need a scope to normalize non-leaf nodes
with bb.function("func", [x1]):
_check_inference(
bb, relax.op.linear(x1, w1, b1), relax.TensorStructInfo((2, 3, 5), "float32")
)
+ _check_inference(
+ bb, relax.op.linear(x3, w4, b3), relax.TensorStructInfo((2, 3, 5), "float32", vdev0)
+ )
_check_inference(
bb, relax.op.linear(x1, w1, b2), relax.TensorStructInfo((2, 3, 5), "float32")
)
@@ -242,6 +253,7 @@ def test_linear():
def test_einsum_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x0", R.Tensor((), "float32"))
x1 = relax.Var("x1", R.Tensor((5,), "int32"))
x2 = relax.Var("x2", R.Tensor((5, 5), "int32"))
@@ -258,8 +270,10 @@ def test_einsum_infer_struct_info():
x13 = relax.Var("x13", R.Tensor((1, 1, 1, 3), "float16"))
x14 = relax.Var("x14", R.Tensor((1, 5, 3, 8, 4), "float32"))
x15 = relax.Var("x15", R.Tensor((2, 5, 3, 6, 4), "float32"))
+ x16 = relax.Var("x16", R.Tensor((5, 5), "int32", vdev0))
_check_inference(bb, relax.op.einsum((x2,), "ii"), relax.TensorStructInfo((), "int32"))
+ _check_inference(bb, relax.op.einsum((x16,), "ii"), relax.TensorStructInfo((), "int32", vdev0))
_check_inference(bb, relax.op.einsum((x2,), "ii->i"), relax.TensorStructInfo((5,), "int32"))
_check_inference(bb, relax.op.einsum([x2], "...j->..."), relax.TensorStructInfo((5,), "int32"))
_check_inference(
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 40993b5da5..6c0fbcf227 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -54,12 +54,14 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_reshape_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=4))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
x4 = relax.Var("x", R.Tensor(ndim=4))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
s0 = relax.Var("s", R.Shape((3, 8, 5)))
s1 = relax.Var("s", R.Shape(ndim=3))
s2 = relax.Var("s", R.Shape())
@@ -68,6 +70,9 @@ def test_reshape_infer_struct_info():
_check_inference(
bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
)
+ _check_inference(
+ bb, relax.op.reshape(x6, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32", vdev0)
+ )
_check_inference(
bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
)
@@ -319,6 +324,7 @@ def test_reshape_infer_struct_info_wrong_input_type():
def test_permute_dims_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=4))
x2 = relax.Var("x", R.Tensor("float32"))
@@ -327,10 +333,16 @@ def test_permute_dims_infer_struct_info():
x5 = relax.Var("x", R.Tensor())
x6 = relax.Var("x", R.Tensor((1,), "float32"))
x7 = relax.Var("x", R.Tensor((), "float32"))
+ x8 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32", vdev0))
_check_inference(
bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float32")
)
+ _check_inference(
+ bb,
+ relax.op.permute_dims(x8, [2, 3, 1, 0]),
+ relax.TensorStructInfo((3, 4, 2, 1), "float32", vdev0),
+ )
_check_inference(
bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo((4, 3, 2, 1), "float32")
)
@@ -529,16 +541,23 @@ def test_permute_dims_infer_struct_info_wrong_input_type():
def test_expand_dims_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 3, 4)))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
_check_inference(
bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float32")
)
+ _check_inference(
+ bb,
+ relax.op.expand_dims(x6, [1, 3]),
+ relax.TensorStructInfo((2, 1, 3, 1, 4), "float32", vdev0),
+ )
_check_inference(
bb,
relax.op.expand_dims(x0, [-1, 1, -6, 3, 5]),
@@ -690,7 +709,9 @@ def test_expand_dims_infer_struct_info_wrong_input_type():
def test_layout_transform_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))
+ x1 = relax.Var("x", R.Tensor((10, 20, 30), "float32", vdev0))
transpose_transform = lambda a, b, c: (a, c, b)
_check_inference(
@@ -698,6 +719,11 @@ def test_layout_transform_infer_struct_info():
relax.op.layout_transform(x, index_map=transpose_transform),
relax.TensorStructInfo((10, 30, 20), "float32"),
)
+ _check_inference(
+ bb,
+ relax.op.layout_transform(x1, index_map=transpose_transform),
+ relax.TensorStructInfo((10, 30, 20), "float32", vdev0),
+ )
tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2)
_check_inference(
@@ -812,16 +838,21 @@ def test_layout_transform_infer_struct_info_invalid_index_map():
def test_squeeze_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=6))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4)))
x4 = relax.Var("x", R.Tensor(ndim=6))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32", vdev0))
_check_inference(
bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32")
)
+ _check_inference(
+ bb, relax.op.squeeze(x6, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32", vdev0)
+ )
_check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float32"))
_check_inference(
bb, relax.op.squeeze(x1, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4)
@@ -983,6 +1014,7 @@ def test_squeeze_infer_struct_info_wrong_input_type():
def test_flatten_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
x1 = relax.Var("x", R.Tensor((3,), "float32"))
x2 = relax.Var("x", R.Tensor((), "float32"))
@@ -997,8 +1029,10 @@ def test_flatten_infer_struct_info():
x11 = relax.Var("x", R.Tensor(ndim=1))
x12 = relax.Var("x", R.Tensor(ndim=0))
x13 = relax.Var("x", R.Tensor())
+ x14 = relax.Var("x", R.Tensor((3, 4, 5), "float32", vdev0))
_check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float32"))
+ _check_inference(bb, relax.op.flatten(x14), relax.TensorStructInfo((60,), "float32", vdev0))
_check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((3,), "float32"))
_check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32"))
_check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1))
@@ -1083,28 +1117,42 @@ def test_flatten_wrong_input_number():
def test_concat_infer_struct_info_with_axis():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 3, 4)))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32"))
y1 = relax.Var("y", R.Tensor("float32", ndim=3))
y2 = relax.Var("y", R.Tensor("float32"))
y3 = relax.Var("y", R.Tensor((2, 4, 4)))
y4 = relax.Var("y", R.Tensor(ndim=3))
y5 = relax.Var("y", R.Tensor())
+ y6 = relax.Var("y", R.Tensor((2, 4, 4), "float32", vdev0))
z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32"))
z1 = relax.Var("z", R.Tensor("float32", ndim=3))
z2 = relax.Var("z", R.Tensor("float32"))
z3 = relax.Var("z", R.Tensor((2, 5, 4)))
z4 = relax.Var("z", R.Tensor(ndim=3))
z5 = relax.Var("z", R.Tensor())
+ z6 = relax.Var("z", R.Tensor((2, 5, 4), "float32", vdev0))
_check_inference(
bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), "float32")
)
+ _check_inference(
+ bb,
+ relax.op.concat([x6, y6, z6], axis=1),
+ relax.TensorStructInfo((2, 12, 4), "float32", vdev0),
+ )
+ _check_inference(
+ bb,
+ relax.op.concat([x6, y0, z0], axis=1),
+ relax.TensorStructInfo((2, 12, 4), "float32", vdev0),
+ )
_check_inference(
bb, relax.op.concat([x0, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), "float32")
)
@@ -1658,12 +1706,14 @@ def test_concat_infer_struct_info_input_tuple_field_not_tensor():
def test_split_infer_struct_info_by_indices():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 10, 4)))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
_check_inference(
bb,
@@ -1676,6 +1726,17 @@ def test_split_infer_struct_info_by_indices():
]
),
)
+ _check_inference(
+ bb,
+ relax.op.split(x6, [3, 7], axis=1),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 3, 4), "float32", vdev0),
+ relax.TensorStructInfo((2, 4, 4), "float32", vdev0),
+ relax.TensorStructInfo((2, 3, 4), "float32", vdev0),
+ ]
+ ),
+ )
_check_inference(
bb,
relax.op.split(x0, [3, 7], axis=-2),
@@ -2176,16 +2237,23 @@ def test_split_infer_struct_info_wrong_input_type():
def test_broadcast_to_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 1, 3)))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 1, 3), "float32", vdev0))
_check_inference(
bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
)
+ _check_inference(
+ bb,
+ relax.op.broadcast_to(x6, (4, 2, 5, 3)),
+ relax.TensorStructInfo((4, 2, 5, 3), "float32", vdev0),
+ )
_check_inference(
bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
)
@@ -2405,10 +2473,11 @@ def test_broadcast_to_infer_struct_info_wrong_input_type():
def test_collapse_sum_like_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
- x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
y0 = relax.Var("y", R.Tensor((3, 4), "float32"))
@@ -2417,10 +2486,14 @@ def test_collapse_sum_like_infer_struct_info():
y3 = relax.Var("y", R.Tensor((3, 4)))
y4 = relax.Var("y", R.Tensor(ndim=2))
y5 = relax.Var("y", R.Tensor((1, 4)))
+ y6 = relax.Var("y", R.Tensor((3, 4), "float32", vdev0))
_check_inference(
bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float32")
)
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x3, y6), relax.TensorStructInfo((3, 4), "float32", vdev0)
+ )
_check_inference(
bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=2)
)
@@ -2728,18 +2801,25 @@ def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var():
def test_repeat_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 10, 4)))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
_check_inference(
bb,
relax.op.repeat(x0, 2, axis=0),
relax.TensorStructInfo((4, 10, 4), "float32"),
)
+ _check_inference(
+ bb,
+ relax.op.repeat(x6, 2, axis=0),
+ relax.TensorStructInfo((4, 10, 4), "float32", vdev0),
+ )
_check_inference(
bb,
relax.op.repeat(x0, 2, axis=-2),
@@ -2853,18 +2933,25 @@ def test_repeat_infer_struct_info_wrong_input_type():
def test_tile_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 10, 4)))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
_check_inference(
bb,
relax.op.tile(x0, 2),
relax.TensorStructInfo((2, 10, 8), "float32"),
)
+ _check_inference(
+ bb,
+ relax.op.tile(x6, 2),
+ relax.TensorStructInfo((2, 10, 8), "float32", vdev0),
+ )
_check_inference(
bb,
relax.op.tile(x0, (3, 2)),
@@ -2971,13 +3058,18 @@ def test_tile_infer_struct_info_wrong_input_type():
def test_flip_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float16", ndim=3))
x2 = relax.Var("x", R.Tensor("int32"))
x3 = relax.Var("x", R.Tensor((2, 10, 4)))
x4 = relax.Var("x", R.Tensor(ndim=3))
+ x5 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
_check_inference(bb, relax.op.flip(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32"))
+ _check_inference(
+ bb, relax.op.flip(x5, axis=1), relax.TensorStructInfo((2, 10, 4), "float32", vdev0)
+ )
_check_inference(bb, relax.op.flip(x1, axis=0), R.Tensor("float16", ndim=3))
_check_inference(bb, relax.op.flip(x2, axis=0), R.Tensor("int32"))
_check_inference(bb, relax.op.flip(x3, axis=2), R.Tensor((2, 10, 4)))
@@ -3003,19 +3095,28 @@ def test_flip_infer_struct_info_wrong_inputs():
def test_scatter_elements_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
d0 = relax.Var("data", R.Tensor((4, 4), "float32"))
d1 = relax.Var("data", R.Tensor(dtype="float32", ndim=2))
d2 = relax.Var("data", R.Tensor("float32"))
+ d3 = relax.Var("data", R.Tensor((4, 4), "float32", vdev0))
i0 = relax.Var("indices", R.Tensor((2, 2), "int64"))
i1 = relax.Var("indices", R.Tensor((2, 2)))
i2 = relax.Var("indices", R.Tensor(dtype="int64", ndim=2))
i3 = relax.Var("indices", R.Tensor(ndim=2))
+ i4 = relax.Var("indices", R.Tensor((2, 2), "int64", vdev0))
u0 = relax.Var("updates", R.Tensor((2, 2), "float32"))
+ u1 = relax.Var("updates", R.Tensor((2, 2), "float32", vdev0))
_check_inference(
bb,
relax.op.scatter_elements(d0, i0, u0, 0, "updates"),
relax.TensorStructInfo((4, 4), dtype="float32"),
)
+ _check_inference(
+ bb,
+ relax.op.scatter_elements(d3, i4, u1, 0, "updates"),
+ relax.TensorStructInfo((4, 4), dtype="float32", vdevice=vdev0),
+ )
_check_inference(
bb,
relax.op.scatter_elements(d1, i0, u0, 0, "updates"),
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index de1cf079a5..7adfc84283 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -58,14 +58,17 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_linear_unit_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
x3 = relax.Var("x", R.Tensor((2, 3)))
x4 = relax.Var("x", R.Tensor())
x5 = relax.Var("x", R.Tensor((3, 4)))
+ x6 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
_check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(bb, relax.op.nn.relu(x6), relax.TensorStructInfo((2, 3), "float32", vdev0))
_check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3))
_check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32"))
_check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype=""))
@@ -133,13 +136,16 @@ def test_linear_unit_infer_struct_info_wrong_input_type():
def test_softmax_log_softmax_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
x3 = relax.Var("x", R.Tensor((2, 3)))
x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
_check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(bb, relax.op.nn.softmax(x5), relax.TensorStructInfo((2, 3), "float32", vdev0))
_check_inference(
bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3)
)
@@ -1096,11 +1102,13 @@ def test_group_norm_infer_struct_info_wrong_input_type():
def test_dropout_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
x3 = relax.Var("x", R.Tensor((2, 3)))
x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
_check_inference(
bb,
@@ -1109,6 +1117,16 @@ def test_dropout_infer_struct_info():
[relax.TensorStructInfo((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32")]
),
)
+ _check_inference(
+ bb,
+ relax.op.nn.dropout(x5),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 3), "float32", vdev0),
+ relax.TensorStructInfo((2, 3), "float32", vdev0),
+ ]
+ ),
+ )
_check_inference(
bb,
relax.op.nn.dropout(x1),
@@ -1220,6 +1238,7 @@ def test_dropout_infer_struct_info_wrong_input_type():
def test_cross_entropy_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x = relax.Var("x", R.Tensor((2, 3), "float32"))
y0 = relax.Var("y", R.Tensor((2, 3), "float32"))
y1 = relax.Var("y", R.Tensor("float32", ndim=2))
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
index 2ec451c132..6be1245fe2 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -44,19 +44,25 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_conv1d_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
x1 = relax.Var("x", R.Tensor((2, 28, 3), "float32"))
x2 = relax.Var("x", R.Tensor("float32", ndim=3))
x3 = relax.Var("x", R.Tensor("float32"))
x4 = relax.Var("x", R.Tensor())
x5 = relax.Var("x", R.Tensor((2, 4, 28, 16), "float32"))
+ x6 = relax.Var("x", R.Tensor((2, 3, 28), "float32", vdev0))
w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
w1 = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
w2 = relax.Var("w", R.Tensor("float32", ndim=3))
w3 = relax.Var("w", R.Tensor("float32"))
w4 = relax.Var("w", R.Tensor((48, 4, 3, 16), "float32"))
+ w5 = relax.Var("w", R.Tensor((4, 3, 3), "float32", vdev0))
_check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float32"))
+ _check_inference(
+ bb, relax.op.nn.conv1d(x6, w5), relax.TensorStructInfo((2, 4, 26), "float32", vdev0)
+ )
_check_inference(
bb,
relax.op.nn.conv1d(x0, w0, out_dtype="float16"),
@@ -414,21 +420,29 @@ def test_conv1d_infer_struct_info_wrong_input_type():
def test_conv1d_transpose_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
x1 = relax.Var("x", R.Tensor((2, 28, 3), "float32"))
x2 = relax.Var("x", R.Tensor("float32", ndim=3))
x3 = relax.Var("x", R.Tensor("float32"))
x4 = relax.Var("x", R.Tensor())
x5 = relax.Var("x", R.Tensor((2, 4, 28, 16), "float32"))
+ x6 = relax.Var("x", R.Tensor((2, 3, 28), "float32", vdev0))
w0 = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
w1 = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
w2 = relax.Var("w", R.Tensor("float32", ndim=3))
w3 = relax.Var("w", R.Tensor("float32"))
w4 = relax.Var("w", R.Tensor((4, 48, 3, 16), "float32"))
+ w5 = relax.Var("w", R.Tensor((3, 4, 3), "float32", vdev0))
_check_inference(
bb, relax.op.nn.conv1d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30), "float32")
)
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x6, w5),
+ relax.TensorStructInfo((2, 4, 30), "float32", vdev0),
+ )
_check_inference(
bb,
relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float16"),
@@ -764,21 +778,27 @@ def test_conv1d_transpose_infer_struct_info_wrong_input_type():
def test_conv2d_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32"))
x2 = relax.Var("x", R.Tensor("float32", ndim=4))
x3 = relax.Var("x", R.Tensor("float32"))
x4 = relax.Var("x", R.Tensor())
x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32"))
+ x6 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32", vdev0))
w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
w2 = relax.Var("w", R.Tensor("float32", ndim=4))
w3 = relax.Var("w", R.Tensor("float32"))
w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 16), "float32"))
+ w5 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32", vdev0))
_check_inference(
bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float32")
)
+ _check_inference(
+ bb, relax.op.nn.conv2d(x6, w5), relax.TensorStructInfo((2, 4, 26, 26), "float32", vdev0)
+ )
_check_inference(
bb,
relax.op.nn.conv2d(x0, w0, out_dtype="float16"),
@@ -1155,21 +1175,29 @@ def test_conv2d_infer_struct_info_wrong_input_type():
def test_conv2d_transpose_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32"))
x2 = relax.Var("x", R.Tensor("float32", ndim=4))
x3 = relax.Var("x", R.Tensor("float32"))
x4 = relax.Var("x", R.Tensor())
x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32"))
+ x6 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32", vdev0))
w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
w2 = relax.Var("w", R.Tensor("float32", ndim=4))
w3 = relax.Var("w", R.Tensor("float32"))
w4 = relax.Var("w", R.Tensor((4, 48, 3, 3, 16), "float32"))
+ w5 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32", vdev0))
_check_inference(
bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30, 30), "float32")
)
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x6, w5),
+ relax.TensorStructInfo((2, 4, 30, 30), "float32", vdev0),
+ )
_check_inference(
bb,
relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float16"),
diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py
index 2bd7747f31..2533a2fcad 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -37,6 +37,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_max_pool2d_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
x2 = relax.Var("x", R.Tensor("float32", ndim=4))
@@ -44,10 +45,14 @@ def test_max_pool2d_infer_struct_info():
x4 = relax.Var("x", R.Tensor(ndim=4))
x5 = relax.Var("x", R.Tensor())
x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+ x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0))
_check_inference(
bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
)
+ _check_inference(
+ bb, relax.op.nn.max_pool2d(x7), relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0)
+ )
_check_inference(
bb,
relax.op.nn.max_pool2d(x0, pool_size=3),
@@ -262,6 +267,7 @@ def test_max_pool2d_infer_struct_info_wrong_input_type():
def test_avg_pool2d_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
x2 = relax.Var("x", R.Tensor("float32", ndim=4))
@@ -269,10 +275,14 @@ def test_avg_pool2d_infer_struct_info():
x4 = relax.Var("x", R.Tensor(ndim=4))
x5 = relax.Var("x", R.Tensor())
x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+ x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0))
_check_inference(
bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
)
+ _check_inference(
+ bb, relax.op.nn.avg_pool2d(x7), relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0)
+ )
_check_inference(
bb,
relax.op.nn.avg_pool2d(x0, pool_size=3),
@@ -487,6 +497,7 @@ def test_avg_pool2d_infer_struct_info_wrong_input_type():
def test_adaptive_avg_pool2d_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
x2 = relax.Var("x", R.Tensor("float32", ndim=4))
@@ -494,10 +505,16 @@ def test_adaptive_avg_pool2d_infer_struct_info():
x4 = relax.Var("x", R.Tensor(ndim=4))
x5 = relax.Var("x", R.Tensor())
x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+ x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0))
_check_inference(
bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
)
+ _check_inference(
+ bb,
+ relax.op.nn.adaptive_avg_pool2d(x7),
+ relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0),
+ )
_check_inference(
bb,
relax.op.nn.adaptive_avg_pool2d(x0, output_size=30),
diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py
index ba78d11022..21f022d9eb 100644
--- a/tests/python/relax/test_op_search.py
+++ b/tests/python/relax/test_op_search.py
@@ -21,7 +21,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -41,25 +41,32 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_where_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool"))
cond1 = relax.Var("cond", R.Tensor("bool", ndim=5))
cond2 = relax.Var("cond", R.Tensor("bool"))
+ cond3 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool", vdev0))
x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=4))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((5, 1, 3, 2)))
x4 = relax.Var("x", R.Tensor(ndim=4))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32", vdev0))
y0 = relax.Var("y", R.Tensor((4, 3, 1), "float32"))
y1 = relax.Var("y", R.Tensor("float32", ndim=3))
y2 = relax.Var("y", R.Tensor("float32"))
y3 = relax.Var("y", R.Tensor((4, 3, 1)))
y4 = relax.Var("y", R.Tensor(ndim=3))
y5 = relax.Var("y", R.Tensor())
+ y6 = relax.Var("y", R.Tensor((4, 3, 1), "float32", vdev0))
_check_inference(
bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32")
)
+ _check_inference(
+ bb, relax.op.where(cond3, x6, y6), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32", vdev0)
+ )
_check_inference(
bb, relax.op.where(cond0, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5)
)
@@ -283,12 +290,17 @@ def test_where_infer_struct_info_wrong_input_type():
def test_argmax_argmin_infer_struct_info(argmax_argmin_op: Callable):
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=4))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+ x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
_check_inference(bb, argmax_argmin_op(x0, axis=1), relax.TensorStructInfo((2, 4, 5), "int64"))
+ _check_inference(
+ bb, argmax_argmin_op(x4, axis=1), relax.TensorStructInfo((2, 4, 5), "int64", vdev0)
+ )
_check_inference(
bb,
argmax_argmin_op(x0, axis=1, keepdims=True),
diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py
index 755d5e8f87..741d7869d5 100644
--- a/tests/python/relax/test_op_set.py
+++ b/tests/python/relax/test_op_set.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -35,10 +35,12 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_unique_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+ x4 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0))
_check_inference(
bb,
@@ -47,6 +49,13 @@ def test_unique_infer_struct_info():
),
relax.TensorStructInfo(dtype="float32", ndim=1),
)
+ _check_inference(
+ bb,
+ relax.op.unique(
+ x4, return_index=False, return_inverse=False, return_counts=False, axis=None
+ ),
+ relax.TensorStructInfo(dtype="float32", ndim=1, vdevice=vdev0),
+ )
_check_inference(
bb,
relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1),
diff --git a/tests/python/relax/test_op_statistical.py b/tests/python/relax/test_op_statistical.py
index 8c542b86a1..5c7d56556c 100644
--- a/tests/python/relax/test_op_statistical.py
+++ b/tests/python/relax/test_op_statistical.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -41,12 +41,17 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_statistical_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=4))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+ x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
_check_inference(bb, relax.op.sum(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32"))
+ _check_inference(
+ bb, relax.op.sum(x4, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32", vdev0)
+ )
_check_inference(
bb,
relax.op.sum(x0, axis=[1, 2], keepdims=True),
@@ -202,14 +207,19 @@ def test_statistical_infer_struct_info_wrong_input_type():
def test_cumsum_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 10, 4)))
x4 = relax.Var("x", R.Tensor(ndim=3))
x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0))
_check_inference(bb, relax.op.cumsum(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32"))
+ _check_inference(
+ bb, relax.op.cumsum(x6, axis=1), relax.TensorStructInfo((2, 10, 4), "float32", vdev0)
+ )
_check_inference(
bb, relax.op.cumsum(x1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=3)
)
diff --git a/tests/python/relax/test_op_ternary.py b/tests/python/relax/test_op_ternary.py
index 5ea7a01da7..120cb47c70 100644
--- a/tests/python/relax/test_op_ternary.py
+++ b/tests/python/relax/test_op_ternary.py
@@ -19,7 +19,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -37,14 +37,21 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
def test_ewise_fma_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor((2, 3)))
+ x2 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
y0 = relax.Var("y", R.Tensor((2, 3), "float32"))
y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2))
+ y2 = relax.Var("y", R.Tensor((2, 3), "float32", vdev0))
z0 = relax.Var("z", R.Tensor((2, 3), "float32"))
z1 = relax.Var("z", R.Tensor("float32"))
+ z2 = relax.Var("z", R.Tensor((2, 3), "float32", vdev0))
_check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(
+ bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorStructInfo((2, 3), "float32", vdev0)
+ )
_check_inference(
bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2)
)
diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py
index 3a6c14fd66..c66f23ae5c 100644
--- a/tests/python/relax/test_op_unary.py
+++ b/tests/python/relax/test_op_unary.py
@@ -20,7 +20,7 @@ import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
-from tvm.ir import Op
+from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -97,13 +97,16 @@ unary_arith_op, require_float_dtype = tvm.testing.parameters(
def test_unary_arith_infer_struct_info(unary_arith_op: Callable):
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
x3 = relax.Var("x", R.Tensor((2, 3)))
x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
_check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(bb, unary_arith_op(x5), relax.TensorStructInfo((2, 3), "float32", vdev0))
_check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(dtype="float32", ndim=3))
_check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo(dtype="float32"))
_check_inference(bb, unary_arith_op(x3), relax.TensorStructInfo((2, 3), dtype=""))
@@ -186,13 +189,16 @@ def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op: Callable
def test_clip_infer_struct_info():
bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
x3 = relax.Var("x", R.Tensor((2, 3)))
x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0))
_check_inference(bb, relax.op.clip(x0, 0, 6), relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorStructInfo((2, 3), "float32", vdev0))
_check_inference(bb, relax.op.clip(x1, 0, 6), relax.TensorStructInfo(dtype="float32", ndim=3))
_check_inference(bb, relax.op.clip(x2, 0, 6), relax.TensorStructInfo(dtype="float32"))
_check_inference(bb, relax.op.clip(x3, 0, 6), relax.TensorStructInfo((2, 3), dtype=""))
diff --git a/tests/python/relax/test_transform_update_vdevice.py b/tests/python/relax/test_transform_update_vdevice.py
new file mode 100644
index 0000000000..76caccb536
--- /dev/null
+++ b/tests/python/relax/test_transform_update_vdevice.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.ir import VDevice
+from tvm.relax.transform import UpdateVDevice
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+
+def verify(mod, new_vdevice, vdevice_index, expected):
+    """Apply UpdateVDevice to ``mod`` and check the result against ``expected``.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The module to transform. Named ``mod`` rather than ``input`` so the
+        ``input`` builtin is not shadowed.
+    new_vdevice : tvm.ir.VDevice
+        The virtual device that replaces the entry at ``vdevice_index``.
+    vdevice_index : int
+        Position in the module's ``"vdevice"`` global-info list to replace.
+    expected : tvm.IRModule
+        The module the transformed ``mod`` must be structurally equal to.
+    """
+    actual = UpdateVDevice(new_vdevice, vdevice_index)(mod)
+    tvm.ir.assert_structural_equal(actual, expected)
+
+
+def test_update():
+    # Replacement vdevices handed to UpdateVDevice. Only these two are used;
+    # the modules below declare their own "vdevice" global infos.
+    metal_1 = VDevice("metal", 1, "global")
+    llvm_1 = VDevice("llvm", 1)
+
+    # Case 1: replace entry 3 of the "vdevice" list ("cuda -arch=sm_80") with
+    # metal_1. "cuda:1" (presumably the second "cuda" entry) and "vdevice:3"
+    # both appear to name entry 3, so Expect1 places both parameters, the
+    # binding and the return annotation on "metal:1".
+    @I.ir_module
+    class Input1:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("cuda", 0),
+                    I.vdevice("metal", 0, "global"),
+                    I.vdevice("cuda -arch=sm_80", 0),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            a: R.Tensor((128, 128), "float32", "cuda:1"),  # noqa: F722
+            c: R.Tensor((128, 128), "float32", "vdevice:3"),  # noqa: F722
+        ) -> R.Tensor((128, 128), "float32"):
+            s = R.add(a, c)
+            return s
+
+    @I.ir_module
+    class Expect1:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("cuda", 0),
+                    I.vdevice("metal", 0, "global"),
+                    I.vdevice("metal", 1, "global"),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            a: R.Tensor((128, 128), dtype="float32", vdevice="metal:1"),  # noqa: F722
+            c: R.Tensor((128, 128), dtype="float32", vdevice="metal:1"),  # noqa: F722
+        ) -> R.Tensor((128, 128), dtype="float32", vdevice="metal:1"):  # noqa: F722
+            s: R.Tensor((128, 128), dtype="float32", vdevice="metal:1") = R.add(a, c)  # noqa: F722
+            return s
+
+    # Case 2: replace entry 1 ("cuda", 0) with llvm_1; every "cuda:0"
+    # reference in Input2 becomes "llvm:1" in Expect2.
+    @I.ir_module
+    class Input2:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("cuda", 0),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            a: R.Tensor((128, 128), "float32", "cuda:0"),  # noqa: F722
+            c: R.Tensor((128, 128), "float32", "cuda:0"),  # noqa: F722
+        ) -> R.Tensor((128, 128), "float32"):
+            s = R.add(a, c)
+            return s
+
+    @I.ir_module
+    class Expect2:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("llvm", 1),
+                ]
+            }
+        )
+
+        @R.function
+        def main(
+            a: R.Tensor((128, 128), "float32", "llvm:1"),  # noqa: F722
+            c: R.Tensor((128, 128), "float32", "llvm:1"),  # noqa: F722
+        ) -> R.Tensor((128, 128), "float32", "llvm:1"):  # noqa: F722
+            s: R.Tensor((128, 128), "float32", "llvm:1") = R.add(a, c)  # noqa: F722
+            return s
+
+    verify(Input1, metal_1, 3, Expect1)
+    verify(Input2, llvm_1, 1, Expect2)
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 39a4d33ca6..1d56bd29b3 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -791,7 +791,7 @@ def test_tensor_with_vdevice():
a: R.Tensor((128, 128), "float32", "cuda:1"), # noqa: F722
b: R.Tensor((128, 128), "float32", "llvm"),
c: R.Tensor((128, 128), "float32", "vdevice:3"), # noqa: F722
- ) -> R.Tensor((128, 128), "float32"):
+ ) -> R.Tensor((128, 128), "float32", "cuda:1"): # noqa: F722
s = R.add(a, c)
return s