You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/09/07 15:04:54 UTC
[tvm] branch unity updated: [Unity] Added known tir.Expr to relax.PrimValue (#15577)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 755af1fec7 [Unity] Added known tir.Expr to relax.PrimValue (#15577)
755af1fec7 is described below
commit 755af1fec722d7efa0b3b00fbd07eaac168a16b2
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 7 08:04:46 2023 -0700
[Unity] Added known tir.Expr to relax.PrimValue (#15577)
Prior to this commit, a `relax.PrimValue` could have a datatype, but
couldn't have a corresponding `tir.PrimExpr`. As a result, it could
not be used to specify tensor shapes. This makes some expressions
require fallback to `R.Tensor(ndim=ndim)`, even though the shape could
still be inferred.
```python
@R.function
def func(
A: R.Tensor([16, 16]),
first_n_rows: R.Prim("int64"),
) -> R.Tensor([first_n_rows, 16]):
# ^^^^^^^^^^^^
# R.Tensor requires a PrimExpr, not relax.Expr
#
# Operations may require PrimExpr
# vvvvvvvvvvvv
out = R.op.strided_slice(A, axes=[0], begin=[0], end=[first_n_rows])
return out
```
This commit adds an `Optional<PrimExpr> value` field to the
`PrimStructInfo`. This field acts similarly to the `PrimExpr` fields
already used in `ShapeStructInfo`, and may contain symbolic variables.
```python
@R.function
def func(
A: R.Tensor([16, 16]),
# TIR definitions in signature allow in-line definitions,
# similar to R.Tensor and R.Shape. R.Prim takes `dtype` or
# `value` kwarg to distinguish between in-line symbolic variable
# and string representation of dtype.
first_n_rows: R.Prim(value="first_n_rows_tir"),
) -> R.Tensor(["first_n_rows_tir", 16]):
# Body contains a TIR variable definition, which may be used
# in function calls, inferred shape annotations.
first_n_rows_tir = T.int64()
out = R.op.strided_slice(A, axes=[0], begin=[0], end=[first_n_rows_tir])
return out
```
* Use distinct PrimStructInfo arguments for dtype/value
* Update TVMScript printer
* Parser updates: support R.Prim(value=...) annotations in function signatures
* Added unit tests for new functionality in API, parser, printer
* Add unit tests for bind_symbolic_vars
* Add test cases to validate bind_symbolic_vars
---
include/tvm/relax/struct_info.h | 15 +++-
python/tvm/relax/struct_info.py | 63 +++++++++++++++-
python/tvm/script/parser/relax/entry.py | 43 ++++++++---
src/relax/analysis/struct_info_analysis.cc | 3 +
src/relax/ir/expr.cc | 2 +-
src/relax/ir/expr_functor.cc | 3 +
src/relax/ir/struct_info.cc | 15 +++-
src/relax/ir/struct_info_functor.cc | 17 ++++-
src/script/printer/relax/struct_info.cc | 23 ++++--
.../relax/test_analysis_struct_info_analysis.py | 43 ++++++++++-
tests/python/relax/test_bind_symbolic_vars.py | 87 +++++++++++++++++++++-
tests/python/relax/test_expr.py | 17 +++++
tests/python/relax/test_struct_info.py | 18 ++++-
tests/python/relax/test_tvmscript_parser.py | 46 ++++++++++--
tests/python/relax/test_tvmscript_printer_relax.py | 4 +-
15 files changed, 365 insertions(+), 34 deletions(-)
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index d2bf525225..2e224f1830 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -60,19 +60,26 @@ class ObjectStructInfo : public StructInfo {
*/
class PrimStructInfoNode : public StructInfoNode {
public:
+ /*! \brief Underlying primitive value, if known */
+ Optional<PrimExpr> value;
+
/*! \brief Underlying data type of the primitive value */
DataType dtype;
void VisitAttrs(AttrVisitor* v) {
+ v->Visit("value", &value);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}
bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const {
- return equal(dtype, other->dtype);
+ return equal(value, other->value) && equal(dtype, other->dtype);
}
- void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(value);
+ hash_reduce(dtype);
+ }
static constexpr const char* _type_key = "relax.PrimStructInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode);
@@ -84,8 +91,12 @@ class PrimStructInfoNode : public StructInfoNode {
*/
class PrimStructInfo : public StructInfo {
public:
+ /* Construct a PrimStructInfo with a known dtype, but unknown value */
TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());
+ /* Construct a PrimStructInfo with a known value */
+ TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span());
+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode);
};
diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py
index e78e1cf69a..38a4ab7491 100644
--- a/python/tvm/relax/struct_info.py
+++ b/python/tvm/relax/struct_info.py
@@ -23,6 +23,7 @@ import tvm
from tvm.ir import Span, EnvFunc, Array, VDevice
from tvm.tir import PrimExpr
+from tvm.runtime import DataType
from .expr import StructInfo, Expr, ShapeExpr
from . import _ffi_api, ty, expr
@@ -42,14 +43,68 @@ class PrimStructInfo(StructInfo):
Parameters
----------
- dtype : str
- The data type of the prim value.
+ dtype : Optional[Union[str, DataType]]
+ The data type of the prim value. Must be provided if `value`
+ is not provided. If both are provided, they must agree.
+
+ value : Optional[Union[int, float, PrimExpr]]
+ The known value of the prim value, if any. A Python int or
+ float is converted to a TIR constant of `dtype` when `dtype`
+ is given; otherwise a Python int becomes an "int64" constant.
+
+ span : Optional[Span]
+ The source span of this struct info.
"""
+ value: Optional[PrimExpr]
dtype: str
- def __init__(self, dtype: str, span: Span = None) -> None:
- self.__init_handle_by_constructor__(_ffi_api.PrimStructInfo, dtype, span) # type: ignore
+ def __init__(
+ self,
+ dtype: Optional[Union[str, DataType]] = None,
+ value: Optional[Union[int, float, PrimExpr]] = None,
+ span: Span = None,
+ ) -> None:
+ # Guard against incorrect usage. For backwards compatibility,
+ # the dtype and value are in the opposite order from most
+ # usages. While PrimStructInfo could take a single positional
+ # argument and check its type, that would make its API differ
+ # from TVMScript's PrimProxy, which cannot do so.
+ # (PrimProxy uses string arguments for datatype, and also for
+ # inline variable definitions when used in a function
+ # signature, and requires separate arguments to distinguish
+ # the two cases.)
+ if isinstance(dtype, (PrimExpr, int, float)):
+ raise TypeError(
+ "The first positional argument of PrimStructInfo must be the datatype, "
+ f"but received {type(dtype)}. "
+ "The value can be specified as a keyword argument "
+ "without specifying the dtype: "
+ "PrimStructInfo(value=arg)."
+ )
+
+ if dtype is None and value is None:
+ raise TypeError(
+ "PrimStructInfo.__init__ missing required argument. "
+ "Must provide either 'dtype' or 'value'"
+ )
+
+ if dtype is not None:
+ # Normalize to a string, so that `str` and `DataType`
+ # arguments are handled identically below.
+ dtype = str(dtype)
+ if isinstance(value, PrimExpr):
+ # Explicit check rather than `assert`, so the validation
+ # is not stripped when Python runs with -O.
+ if value.dtype != dtype:
+ raise ValueError(
+ "When providing both 'value' and 'dtype' to PrimStructInfo.__init__, "
+ "they must be consistent with each other. "
+ f"However, the value {value} has dtype {value.dtype}, "
+ f"but the specified dtype was {dtype}."
+ )
+ elif isinstance(value, (int, float)):
+ value = tvm.tir.const(value, dtype)
+
+ # Use relax's default integer type if not otherwise specified.
+ if isinstance(value, int):
+ value = tvm.tir.IntImm("int64", value)
+
+ if value is None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.PrimStructInfoFromDtype, dtype, span
+ ) # type: ignore
+ else:
+ self.__init_handle_by_constructor__(
+ _ffi_api.PrimStructInfoFromValue, value, span
+ ) # type: ignore
@tvm._ffi.register_object("relax.ShapeStructInfo")
diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py
index 85f641e838..d3e2342750 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -359,7 +359,7 @@ class ShapeProxy(StructInfoProxy):
def get_symbolic_vars(self) -> Set[str]:
if self.values is None:
- return {}
+ return set()
else:
return {v for v in self.values if isinstance(v, str) and v.isidentifier()}
@@ -405,27 +405,52 @@ def Object() -> ObjectProxy:
class PrimProxy(StructInfoProxy):
- dtype: str
- """The type of shape values.
+ dtype: Optional[str]
+ value: Optional[Union[int, float, str, PrimExpr]]
+
+ """The type of TIR-representable values.
Parameters
----------
- dtype : str
+ dtype : Optional[str]
The data type.
+
+ value: Optional[Union[int, float, str, PrimExpr]]
+ The known value
"""
- def __init__(self, dtype: str) -> None:
+ def __init__(
+ self,
+ dtype: Optional[str] = None,
+ value: Optional[Union[int, float, str, PrimExpr]] = None,
+ ) -> None:
+ if dtype is None and value is None:
+ raise TypeError(
+ "R.Prim missing required argument. " "Must provide either 'dtype' or 'value'"
+ )
+
self.dtype = dtype
+ self.value = value
def get_symbolic_vars(self) -> Set[str]:
- return set()
+ if isinstance(self.value, str) and self.value.isidentifier():
+ return {self.value}
+ else:
+ return set()
- def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
+ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> PrimStructInfo:
- return PrimStructInfo(self.dtype)
+ # Without a known value, only the dtype is recorded.
+ if self.value is None:
+ return PrimStructInfo(dtype=self.dtype)
+ # `self.value` may be a string such as "M" or "M*N"; resolve it
+ # with `_eval_shape` against the provided globals before building
+ # the struct info.
+ value = _eval_shape(self.value, dict_globals)
+ return PrimStructInfo(dtype=self.dtype, value=value)
-def Prim(dtype: str) -> PrimProxy:
- return PrimProxy(dtype)
+def Prim(
+ dtype: Optional[str] = None,
+ value: Optional[Union[int, float, str, PrimExpr]] = None,
+) -> PrimProxy:
+ return PrimProxy(dtype, value)
############################ R.match_cast #############################
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 4a633e9df4..ddb3fdb5c1 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -1062,6 +1062,9 @@ class SymbolicVarCollector : public relax::ExprVisitor,
this->VisitStructInfoExprField(val);
}
}
+ if (auto prim_value = expr.as<relax::PrimValue>()) {
+ this->VisitStructInfoExprField(prim_value.value()->value);
+ }
}
void VisitStructInfoExprField(const PrimExpr& expr) final {
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index ac04096aaf..ac3e532289 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -311,7 +311,7 @@ TVM_REGISTER_GLOBAL("relax.Constant")
PrimValue::PrimValue(PrimExpr value, Span span) {
ObjectPtr<PrimValueNode> n = make_object<PrimValueNode>();
n->checked_type_ = PrimType(value.dtype());
- n->struct_info_ = PrimStructInfo(value.dtype());
+ n->struct_info_ = PrimStructInfo(value);
n->value = std::move(value);
n->span = std::move(span);
data_ = std::move(n);
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index f0f0d29b51..0174308802 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -225,6 +225,9 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) {
void ExprVisitor::VisitExpr_(const PrimValueNode* op) {
this->VisitPrimExpr(op->value);
+ if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
+ this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
+ }
this->VisitSpan(op->span);
}
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
index 31784af000..9b635bb479 100644
--- a/src/relax/ir/struct_info.cc
+++ b/src/relax/ir/struct_info.cc
@@ -42,19 +42,32 @@ TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) {
});
// Prim
+PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) {
+ ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
+ n->dtype = value->dtype;
+ n->value = std::move(value);
+ n->span = span;
+ data_ = std::move(n);
+}
+
PrimStructInfo::PrimStructInfo(DataType dtype, Span span) {
ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
n->dtype = dtype;
+ n->value = NullOpt;
n->span = span;
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(PrimStructInfoNode);
-TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) {
+TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype").set_body_typed([](DataType dtype, Span span) {
return PrimStructInfo(dtype, span);
});
+TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromValue").set_body_typed([](PrimExpr value, Span span) {
+ return PrimStructInfo(value, span);
+});
+
// Shape
ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc
index d7929e0f1a..ea8f1da8f0 100644
--- a/src/relax/ir/struct_info_functor.cc
+++ b/src/relax/ir/struct_info_functor.cc
@@ -28,7 +28,11 @@ namespace relax {
void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {}
-void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {}
+void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {
+ if (op->value.defined()) {
+ this->VisitStructInfoExprField(op->value.value());
+ }
+}
void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) {
if (op->values.defined()) {
@@ -68,7 +72,16 @@ StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) {
}
StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) {
- return GetRef<StructInfo>(op);
+ // A PrimStructInfo without a known value has nothing to mutate.
+ if (!op->value.defined()) {
+ return GetRef<StructInfo>(op);
+ }
+
+ auto new_expr = VisitStructInfoExprField(op->value.value());
+ if (new_expr.same_as(op->value)) {
+ return GetRef<StructInfo>(op);
+ } else {
+ // Keep the original span; the dtype is taken from the rewritten value.
+ return PrimStructInfo(new_expr, op->span);
+ }
}
StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) {
diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc
index cccf9ed08b..11f987c368 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/script/printer/relax/struct_info.cc
@@ -30,12 +30,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return Relax(d, "Object");
});
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<relax::PrimStructInfo>(
- "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
- return Relax(d, "Prim")->Call({LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))});
- });
-
ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) {
ExprDoc expr_doc = d->AsDoc<ExprDoc>(e, e_p);
// Step 1. Find if `func_vars` are being collected
@@ -66,6 +60,23 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifie
return expr_doc;
}
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::PrimStructInfo>(
+ "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ Array<ExprDoc, void> args;
+ Array<String> kwargs_keys;
+ Array<ExprDoc, void> kwargs_values;
+
+ if (n->value.defined()) {
+ kwargs_keys.push_back("value");
+ kwargs_values.push_back(PrintShapeVar(n->value.value(), n_p->Attr("value"), d));
+ } else {
+ args.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")));
+ }
+
+ return Relax(d, "Prim")->Call(args, kwargs_keys, kwargs_values);
+ });
+
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::ShapeStructInfo>(
"", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index 879194037c..1b1ea2e53e 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -636,7 +636,7 @@ def test_tir_vars_in_struct_info():
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func), [n, m])
-def test_symbolic_var_collector():
+def test_collect_symbolic_var_from_tensor_shape():
n, m, k, q, p = (
tir.Var("n", "int64"),
tir.Var("m", "int64"),
@@ -658,5 +658,46 @@ def test_symbolic_var_collector():
assert free_vars == {n, p, q}
+param_type = tvm.testing.parameter("shape_expr", "prim_value")
+param_order = tvm.testing.parameter("definition_first", "usage_first")
+
+
+def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order):
+ tir_n = tir.Var("n", "int64")
+ tir_m = tir.Var("m", "int64")
+
+ bb = rx.BlockBuilder()
+ arg = rx.Var("arg", rx.TensorStructInfo([tir_n * tir_m]))
+
+ if param_type == "shape_expr":
+ extra_params = [
+ rx.Var("shape_expr", rx.ShapeStructInfo([tir_n, tir_m])),
+ ]
+ elif param_type == "prim_value":
+ extra_params = [
+ rx.Var("n", rx.PrimStructInfo(value=tir_n)),
+ rx.Var("m", rx.PrimStructInfo(value=tir_m)),
+ ]
+ else:
+ raise ValueError(f"Unknown param_type: {param_type}")
+
+ if param_order == "definition_first":
+ params = [*extra_params, arg]
+ elif param_order == "usage_first":
+ params = [arg, *extra_params]
+ else:
+ raise ValueError(f"Unknown param_order: {param_order}")
+
+ with bb.function("main", params=params):
+ out = rx.op.reshape(arg, [tir_n, tir_m])
+ bb.emit_func_output(out)
+ func = bb.get()["main"]
+
+ defined_vars = set(rx.analysis.defined_symbolic_vars(func))
+ free_vars = set(rx.analysis.free_symbolic_vars(func))
+ assert defined_vars == {tir_n, tir_m}
+ assert free_vars == set()
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_bind_symbolic_vars.py b/tests/python/relax/test_bind_symbolic_vars.py
index 1dc1189a67..82798c56df 100644
--- a/tests/python/relax/test_bind_symbolic_vars.py
+++ b/tests/python/relax/test_bind_symbolic_vars.py
@@ -163,7 +163,7 @@ def test_replacements_may_produce_new_symbolic_vars():
tvm.ir.assert_structural_equal(expected, after)
-def test_bind_symbolic_vars_in_shape():
+def test_bind_symbolic_vars_in_tensor_shape():
"""The bound variable should be replaced when appearing in struct info"""
@R.function(private=True)
@@ -183,6 +183,91 @@ def test_bind_symbolic_vars_in_shape():
tvm.ir.assert_structural_equal(expected, after)
+def test_bind_symbolic_vars_in_shape_expr():
+ """The bound variable should be replaced when appearing in R.Shape"""
+
+ @R.function(private=True)
+ def before(A: R.Tensor(["M * N"]), x: R.Shape(["M", "N"])):
+ M = T.int64()
+ N = T.int64()
+ B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
+ return B
+
+ @R.function(private=True)
+ def expected(A: R.Tensor(["M * 16"]), x: R.Shape(["M", 16])):
+ B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
+ return B
+
+ after = before.bind_symbolic_vars({"N": 16})
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_defining_of_symbolic_vars_in_prim_value():
+ """R.Prim may define symbolic variables
+
+ This case is a bit odd, because it always results in a
+ fully-constrained parameter at the relax level. After binding in
+ this test case, we have a function that accepts three parameters,
+ and the third parameter must always be the number 16.
+
+ However, this provides the most consistent behavior with other
+ uses of `relax.Function.bind_symbolic_vars`, which restricts the
+ allowed values for each parameter, but does not alter the number
+ of parameters. This is in contrast to the `BindParams` pass,
+ which provides a known value for relax parameters, removing them
+ from the function signature.
+
+ This convention also prevents surprise changes to the function
+ signature, such as shown in
+ `test_bind_symbolic_vars_with_expr_in_prim_value`.
+ """
+
+ @R.function(private=True)
+ def before(A: R.Tensor(["M * N"]), x: R.Prim(value="M"), y: R.Prim(value="N")):
+ M = T.int64()
+ N = T.int64()
+ B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
+ return B
+
+ @R.function(private=True)
+ def expected(A: R.Tensor(["M * 16"]), x: R.Prim(value="M"), y: R.Prim(value=16)):
+ B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
+ return B
+
+ after = before.bind_symbolic_vars({"N": 16})
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_usage_of_symbolic_vars_in_prim_value():
+ """R.Prim may use symbolic variables defined by other parameters
+
+ Like test_bind_defining_of_symbolic_vars_in_prim_value, but with
+ R.Prim using a symbolic variable rather than defining it.
+
+ This also demonstrates why we should not remove fully-constrained
+ R.Prim function parameters. In this case, we have a function that
+ accepts two parameters, and we have specialized the shape of the
+ first parameter. It would be unexpected for specialization of the
+ first parameter to result in removal of a different parameter
+ altogether.
+ """
+
+ @R.function(private=True)
+ def before(A: R.Tensor(["M", "N"]), x: R.Prim(value="M*N")):
+ M = T.int64()
+ N = T.int64()
+ B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
+ return B
+
+ @R.function(private=True)
+ def expected(A: R.Tensor([16, 16]), x: R.Prim(value=256)):
+ B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([512]))
+ return B
+
+ after = before.bind_symbolic_vars({"M": 16, "N": 16})
+ tvm.ir.assert_structural_equal(expected, after)
+
+
def test_bind_strided_slice():
"""relax.op.strided_slice stores PrimExpr attributes"""
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index 902c478561..fbd37b307e 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -238,6 +238,23 @@ def test_prim_value():
_check_json_roundtrip(pv)
+def test_prim_value_with_var():
+ n = tir.Var("n", "int64")
+ pv = rx.PrimValue(n)
+ assert pv.value.same_as(n)
+ tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n))
+ _check_equal(pv, rx.PrimValue(n))
+ _check_json_roundtrip(pv)
+
+
+def test_prim_value_with_expr():
+ n = tir.Var("n", "int64")
+ pv = rx.PrimValue(n + 1)
+ tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n + 1))
+ _check_equal(pv, rx.PrimValue(n + 1))
+ _check_json_roundtrip(pv)
+
+
def test_string_imm():
s0 = rx.StringImm("hello")
s1 = rx.StringImm("hello")
diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py
index 80ebc3cb18..33dcd7e9d7 100644
--- a/tests/python/relax/test_struct_info.py
+++ b/tests/python/relax/test_struct_info.py
@@ -86,7 +86,23 @@ def test_prim_struct_info():
# wrong API constructors
with pytest.raises(TVMError):
- rx.PrimStructInfo(1)
+ rx.PrimStructInfo([1])
+
+
+def test_prim_struct_info_with_expr():
+ n = tir.Var("n", "int64")
+ sinfo = rx.PrimStructInfo(value=n + 1)
+
+ _check_equal(sinfo, rx.PrimStructInfo(value=n + 1))
+ assert not tvm.ir.structural_equal(sinfo, rx.PrimStructInfo(dtype=n.dtype))
+
+ # can turn into str
+ str(sinfo)
+
+ assert isinstance(sinfo, rx.PrimStructInfo)
+ _check_json_roundtrip(sinfo)
+
+ assert sinfo.dtype == "int64"
def test_shape_struct_info():
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 0295b6f1c2..d86b1e0108 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1190,8 +1190,9 @@ def test_empty_tuple():
_check(foo, bb.get()["foo"])
-def test_symbolic_shape_computing():
- # Tensor Case 1
+def test_symbolic_vars_in_tensor_shape_with_usage_first():
+ """First param may use symbolic variable defined in second param"""
+
@R.function
def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")):
z = R.add(x, y)
@@ -1207,7 +1208,10 @@ def test_symbolic_shape_computing():
_check(foo, bb.get()["foo"])
- # Tensor Case 2
+
+def test_symbolic_vars_in_tensor_shape_with_definition_first():
+ """Second param may use symbolic variable defined in first param"""
+
@R.function
def bar(
x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32")
@@ -1230,7 +1234,10 @@ def test_symbolic_shape_computing():
_check(bar, bb.get()["bar"])
- # Shape Case
+
+def test_symbolic_vars_in_shape():
+ """Symbolic variable may be defined in R.Shape"""
+
@R.function
def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")):
m = T.int64()
@@ -1247,7 +1254,36 @@ def test_symbolic_shape_computing():
_check(baz, bb.get()["baz"])
- # Error Case
+
+def test_symbolic_vars_in_prim_value():
+ """Symbolic variable may be defined in R.Prim"""
+
+ @R.function
+ def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")):
+ m = T.int64()
+ z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32"))
+ return z
+
+ m = tir.Var("m", "int64")
+ x = relax.Var("x", relax.PrimStructInfo(value=m))
+ y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("baz", (x, y)):
+ z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32")))
+ bb.emit_func_output(z)
+
+ _check(baz, bb.get()["baz"])
+
+
+def test_undefined_symbolic_var_raises_error():
+ """An undefined symbolic variable in an error
+
+ A symbolic variables is defined at the first site where it appears
+ as a shape parameter without any modification. TVMScript does not
+ support solving for a symbolic variable in terms of the argument
+ shape. That is, this test case raises an error, and will not
+ attempt to define `m` as either `x.shape[0]-1` or `x.shape[1]//2`.
+ """
with pytest.raises(tvm.error.DiagnosticError):
@R.function
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 9f4ffd9acd..2e4218b2ab 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -205,6 +205,7 @@ def test_func_struct_info():
relax.PrimStructInfo("float32"),
relax.ObjectStructInfo(),
relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]),
+ relax.PrimStructInfo(value=tir.Var("b", "int64")),
],
ret=relax.TensorStructInfo(
shape=relax.ShapeExpr([1, 2, 3]),
@@ -214,7 +215,8 @@ def test_func_struct_info():
_assert_print(
obj,
"a = T.int64()\n"
- 'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), '
+ "b = T.int64()\n"
+ 'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3]), R.Prim(value=b)), '
'R.Tensor((1, 2, 3), dtype="float32"), True)',
)