You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/05 14:09:11 UTC

[tvm] branch unity updated: [Unity][Op] vm.alloc_tensor infer struct info (#14503)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e93ee9fc56 [Unity][Op] vm.alloc_tensor infer struct info (#14503)
e93ee9fc56 is described below

commit e93ee9fc562852bf8e00577ad64a54b56493e9b8
Author: Bohan Hou <32...@users.noreply.github.com>
AuthorDate: Wed Apr 5 07:09:00 2023 -0700

    [Unity][Op] vm.alloc_tensor infer struct info (#14503)
    
    ,
---
 src/relax/op/op.cc                 | 6 ++++++
 tests/python/relax/test_op_misc.py | 9 +++++++++
 2 files changed, 15 insertions(+)

diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index c641c45922..9d331e41dd 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -501,6 +501,12 @@ StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ct
   }
   if (const auto* output_shape = call->args[2].as<ShapeExprNode>()) {
     return TensorStructInfo(GetRef<Expr>(output_shape), out_dtype);
+  } else if (const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[2])) {
+    if (shape_sinfo->values.defined()) {
+      return TensorStructInfo(ShapeExpr(shape_sinfo->values.value()), out_dtype);
+    } else {
+      return TensorStructInfo(out_dtype, shape_sinfo->ndim);
+    }
   }
   return TensorStructInfo(out_dtype, kUnknownNDim);
 }
diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py
index d596c60196..87c2a58a82 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -103,6 +103,15 @@ def test_vm_alloc_tensor():
     tvm.ir.assert_structural_equal(alloc.struct_info, R.Tensor([4, 5], "float32"))
 
 
+def test_vm_alloc_tensor_infer_struct_info():  # shape passed as a Var, not a ShapeExpr
+    bb = rx.BlockBuilder()
+    s1 = rx.Var("s", R.Shape(ndim=3))  # ndim is known, shape values are not
+    storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32"))
+    alloc = rx.op.vm.alloc_tensor(storage, offset=0, shape=s1, dtype="float32")
+    ret = bb.normalize(alloc)  # struct_info is read from the normalized expr
+    tvm.ir.assert_structural_equal(ret.struct_info, R.Tensor(dtype="float32", ndim=3))
+
+
 def test_builtin_stop_lift_params():
     bb = rx.BlockBuilder()
     x = rx.Var("x", rx.TensorStructInfo(shape=[4, 5], dtype="float32"))