You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/09/07 15:04:54 UTC

[tvm] branch unity updated: [Unity] Added known tir.Expr to relax.PrimValue (#15577)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 755af1fec7 [Unity] Added known tir.Expr to relax.PrimValue (#15577)
755af1fec7 is described below

commit 755af1fec722d7efa0b3b00fbd07eaac168a16b2
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 7 08:04:46 2023 -0700

    [Unity] Added known tir.Expr to relax.PrimValue (#15577)
    
    Prior to this commit, a `relax.PrimValue` could have a datatype, but
    couldn't have a corresponding `tir.PrimExpr`.  As a result, it could
    not be used to specify tensor shapes.  This makes some expressions
    require fallback to `R.Tensor(ndim=ndim)`, even though the shape could
    still be inferred.
    
    ```python
    @R.function
    def func(
        A: R.Tensor([16, 16]),
        first_n_rows: R.Prim("int64"),
    ) -> R.Tensor([first_n_rows, 16]):
        #          ^^^^^^^^^^^^
        #          R.Tensor requires a PrimExpr, not relax.Expr
        #
        #                               Operations may require PrimExpr
        #                                                     vvvvvvvvvvvv
        out = R.op.strided_slice(A, axes=[0], begin=[0], end=[first_n_rows])
        return out
    ```
    
    This commit adds an `Optional<PrimExpr> value` field to the
    `PrimStructInfo`.  This field acts similarly to the `PrimExpr` fields
    already used in `ShapeStructInfo`, and may contain symbolic variables.
    
    ```python
    @R.function
    def func(
        A: R.Tensor([16, 16]),
    
        # TIR definitions in the signature allow in-line definitions,
        # similar to R.Tensor and R.Shape.  R.Prim takes a `dtype` or
        # `value` keyword argument to distinguish between an in-line
        # symbolic variable and a string representation of the dtype.
        first_n_rows: R.Prim(value="first_n_rows_tir"),
    ) -> R.Tensor(["first_n_rows_tir", 16]):
    
        # Body contains a TIR variable definition, which may be used
        # in function calls and in inferred shape annotations.
        first_n_rows_tir = T.int64()
        out = R.op.strided_slice(A, axes=[0], begin=[0], end=[first_n_rows_tir])
        return out
    ```
    
    Use distinct PrimStructInfo arguments for dtype/value
    
    Update TVMScript printer
    
    Parser updates, Support R.Prim(value=...) annotations in function signature
    
    * Added unit tests for new functionality in API, parser, printer
    
    * Add unit tests for bind_symbolic_vars
    
    * Add test cases to validate bind_symbolic_vars
---
 include/tvm/relax/struct_info.h                    | 15 +++-
 python/tvm/relax/struct_info.py                    | 63 +++++++++++++++-
 python/tvm/script/parser/relax/entry.py            | 43 ++++++++---
 src/relax/analysis/struct_info_analysis.cc         |  3 +
 src/relax/ir/expr.cc                               |  2 +-
 src/relax/ir/expr_functor.cc                       |  3 +
 src/relax/ir/struct_info.cc                        | 15 +++-
 src/relax/ir/struct_info_functor.cc                | 17 ++++-
 src/script/printer/relax/struct_info.cc            | 23 ++++--
 .../relax/test_analysis_struct_info_analysis.py    | 43 ++++++++++-
 tests/python/relax/test_bind_symbolic_vars.py      | 87 +++++++++++++++++++++-
 tests/python/relax/test_expr.py                    | 17 +++++
 tests/python/relax/test_struct_info.py             | 18 ++++-
 tests/python/relax/test_tvmscript_parser.py        | 46 ++++++++++--
 tests/python/relax/test_tvmscript_printer_relax.py |  4 +-
 15 files changed, 365 insertions(+), 34 deletions(-)

diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index d2bf525225..2e224f1830 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -60,19 +60,26 @@ class ObjectStructInfo : public StructInfo {
   */
 class PrimStructInfoNode : public StructInfoNode {
  public:
+  /*! \brief Underlying primitive value, if known */
+  Optional<PrimExpr> value;
+
   /*! \brief Underlying data type of the primitive value */
   DataType dtype;
 
   void VisitAttrs(AttrVisitor* v) {
+    v->Visit("value", &value);
     v->Visit("dtype", &dtype);
     v->Visit("span", &span);
   }
 
   bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const {
-    return equal(dtype, other->dtype);
+    return equal(value, other->value) && equal(dtype, other->dtype);
   }
 
-  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(value);
+    hash_reduce(dtype);
+  }
 
   static constexpr const char* _type_key = "relax.PrimStructInfo";
   TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode);
@@ -84,8 +91,12 @@ class PrimStructInfoNode : public StructInfoNode {
   */
 class PrimStructInfo : public StructInfo {
  public:
+  /*! \brief Construct a PrimStructInfo with a known dtype, but unknown value */
   TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());
 
+  /*! \brief Construct a PrimStructInfo with a known value; dtype is taken from the value */
+  TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span());
+
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode);
 };
 
diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py
index e78e1cf69a..38a4ab7491 100644
--- a/python/tvm/relax/struct_info.py
+++ b/python/tvm/relax/struct_info.py
@@ -23,6 +23,7 @@ import tvm
 
 from tvm.ir import Span, EnvFunc, Array, VDevice
 from tvm.tir import PrimExpr
+from tvm.runtime import DataType
 from .expr import StructInfo, Expr, ShapeExpr
 
 from . import _ffi_api, ty, expr
@@ -42,14 +43,68 @@ class PrimStructInfo(StructInfo):
 
     Parameters
     ----------
-    dtype : str
-       The data type of the prim value.
+    dtype : Optional[Union[str, DataType]]
+       The data type of the prim value.  May be omitted when `value` is given.
+    value : Optional[Union[int, float, PrimExpr]]
+       A known value for the prim value, which may contain symbolic variables.
     """
 
+    value: Optional[PrimExpr]
     dtype: str
 
-    def __init__(self, dtype: str, span: Span = None) -> None:
-        self.__init_handle_by_constructor__(_ffi_api.PrimStructInfo, dtype, span)  # type: ignore
+    def __init__(
+        self,
+        dtype: Optional[Union[str, DataType]] = None,
+        value: Optional[Union[int, float, PrimExpr]] = None,
+        span: Span = None,
+    ) -> None:
+        # Guard against incorrect usage.  For backwards compatibility,
+        # `dtype` comes first, which is the opposite of the order used
+        # in most call sites.  PrimStructInfo could instead accept a
+        # single positional argument and dispatch on its type, but that
+        # would diverge from TVMScript's PrimProxy, which cannot do so:
+        # PrimProxy receives strings both for the datatype and for
+        # in-line symbolic variable definitions in a function
+        # signature, so it needs separate arguments to distinguish
+        # the two cases.
+        if isinstance(dtype, (PrimExpr, int, float)):
+            raise TypeError(
+                "The first positional argument of PrimStructInfo must be the datatype, "
+                f"but received {type(dtype)}.  "
+                "The value can be specified as a keyword argument "
+                "without needing to specify the dtype: "
+                "PrimStructInfo(value=arg)."
+            )
+
+        if dtype is None and value is None:
+            raise TypeError(
+                "PrimStructInfo.__init__ missing required argument.  "
+                "Must provide either 'dtype' or 'value'"
+            )
+
+        if dtype is not None:
+            if isinstance(value, PrimExpr):
+                assert value.dtype == dtype, (
+                    "When providing both 'value' and 'dtype' to PrimStructInfo.__init__, "
+                    "they must be consistent with each other.  "
+                    f"However, the value {value} has dtype {value.dtype}, "
+                    f"but the specified dtype was {dtype}."
+                )
+            elif isinstance(value, (int, float)):
+                value = tvm.tir.const(value, dtype)
+
+        # Use relax's default integer type if not otherwise specified.
+        if isinstance(value, int):
+            value = tvm.tir.IntImm("int64", value)
+
+        if value is None:
+            self.__init_handle_by_constructor__(
+                _ffi_api.PrimStructInfoFromDtype, dtype, span
+            )  # type: ignore
+        else:
+            self.__init_handle_by_constructor__(
+                _ffi_api.PrimStructInfoFromValue, value, span
+            )  # type: ignore
 
 
 @tvm._ffi.register_object("relax.ShapeStructInfo")
diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py
index 85f641e838..d3e2342750 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -359,7 +359,7 @@ class ShapeProxy(StructInfoProxy):
 
     def get_symbolic_vars(self) -> Set[str]:
         if self.values is None:
-            return {}
+            return set()
         else:
             return {v for v in self.values if isinstance(v, str) and v.isidentifier()}
 
@@ -405,27 +405,52 @@ def Object() -> ObjectProxy:
 
 
 class PrimProxy(StructInfoProxy):
-    dtype: str
-    """The type of shape values.
+    """The type of TIR-representable values.
 
     Parameters
     ----------
-    dtype : str
+    dtype : Optional[str]
        The data type.
+
+    value : Optional[Union[int, float, str, PrimExpr]]
+       The known value.  A string names an in-line symbolic variable or expression.
     """
+
+    dtype: Optional[str]
+    value: Optional[Union[int, float, str, PrimExpr]]
 
-    def __init__(self, dtype: str) -> None:
+    def __init__(
+        self,
+        dtype: Optional[str] = None,
+        value: Optional[Union[int, float, str, PrimExpr]] = None,
+    ) -> None:
+        if dtype is None and value is None:
+            raise TypeError(
+                "R.Prim missing required argument.  Must provide either 'dtype' or 'value'"
+            )
+
         self.dtype = dtype
+        self.value = value
 
     def get_symbolic_vars(self) -> Set[str]:
-        return set()
+        if isinstance(self.value, str) and self.value.isidentifier():
+            return {self.value}
+        else:
+            return set()
 
-    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
-        return PrimStructInfo(self.dtype)
+    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> PrimStructInfo:
+        if self.value is None:
+            return PrimStructInfo(dtype=self.dtype)
+        else:
+            value = _eval_shape(self.value, dict_globals)
+            return PrimStructInfo(dtype=self.dtype, value=value)
 
 
-def Prim(dtype: str) -> PrimProxy:
-    return PrimProxy(dtype)
+def Prim(
+    dtype: Optional[str] = None,
+    value: Optional[Union[int, float, str, PrimExpr]] = None,
+) -> PrimProxy:
+    return PrimProxy(dtype, value)
 
 
 ############################ R.match_cast #############################
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 4a633e9df4..ddb3fdb5c1 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -1062,6 +1062,9 @@ class SymbolicVarCollector : public relax::ExprVisitor,
         this->VisitStructInfoExprField(val);
       }
     }
+    if (auto prim_value = expr.as<relax::PrimValue>()) {
+      this->VisitStructInfoExprField(prim_value.value()->value);
+    }
   }
 
   void VisitStructInfoExprField(const PrimExpr& expr) final {
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index ac04096aaf..ac3e532289 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -311,7 +311,7 @@ TVM_REGISTER_GLOBAL("relax.Constant")
 PrimValue::PrimValue(PrimExpr value, Span span) {
   ObjectPtr<PrimValueNode> n = make_object<PrimValueNode>();
   n->checked_type_ = PrimType(value.dtype());
-  n->struct_info_ = PrimStructInfo(value.dtype());
+  n->struct_info_ = PrimStructInfo(value);
   n->value = std::move(value);
   n->span = std::move(span);
   data_ = std::move(n);
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index f0f0d29b51..0174308802 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -225,6 +225,9 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) {
 
 void ExprVisitor::VisitExpr_(const PrimValueNode* op) {
   this->VisitPrimExpr(op->value);
+  if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
+    this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
+  }
   this->VisitSpan(op->span);
 }
 
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
index 31784af000..9b635bb479 100644
--- a/src/relax/ir/struct_info.cc
+++ b/src/relax/ir/struct_info.cc
@@ -42,19 +42,32 @@ TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) {
 });
 
 // Prim
+PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) {
+  ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
+  n->dtype = value->dtype;
+  n->value = std::move(value);
+  n->span = span;
+  data_ = std::move(n);
+}
+
 PrimStructInfo::PrimStructInfo(DataType dtype, Span span) {
   ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
   n->dtype = dtype;
+  n->value = NullOpt;
   n->span = span;
   data_ = std::move(n);
 }
 
 TVM_REGISTER_NODE_TYPE(PrimStructInfoNode);
 
-TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) {
+TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype").set_body_typed([](DataType dtype, Span span) {
   return PrimStructInfo(dtype, span);
 });
 
+TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromValue").set_body_typed([](PrimExpr value, Span span) {
+  return PrimStructInfo(value, span);
+});
+
 // Shape
 ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
   ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc
index d7929e0f1a..ea8f1da8f0 100644
--- a/src/relax/ir/struct_info_functor.cc
+++ b/src/relax/ir/struct_info_functor.cc
@@ -28,7 +28,11 @@ namespace relax {
 
 void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {}
 
-void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {}
+void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {
+  if (op->value.defined()) {
+    this->VisitStructInfoExprField(op->value.value());
+  }
+}
 
 void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) {
   if (op->values.defined()) {
@@ -68,7 +72,16 @@ StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) {
 }
 
 StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) {
-  return GetRef<StructInfo>(op);
+  if (!op->value.defined()) {
+    return GetRef<StructInfo>(op);
+  }
+
+  auto new_expr = VisitStructInfoExprField(op->value.value());
+  if (new_expr.same_as(op->value)) {
+    return GetRef<StructInfo>(op);
+  } else {
+    return PrimStructInfo(new_expr, op->span);
+  }
 }
 
 StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) {
diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc
index cccf9ed08b..11f987c368 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/script/printer/relax/struct_info.cc
@@ -30,12 +30,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           return Relax(d, "Object");
         });
 
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch<relax::PrimStructInfo>(
-        "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
-          return Relax(d, "Prim")->Call({LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))});
-        });
-
 ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) {
   ExprDoc expr_doc = d->AsDoc<ExprDoc>(e, e_p);
   // Step 1. Find if `func_vars` are being collected
@@ -66,6 +60,23 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifie
   return expr_doc;
 }
 
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::PrimStructInfo>(
+        "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          Array<ExprDoc, void> args;
+          Array<String> kwargs_keys;
+          Array<ExprDoc, void> kwargs_values;
+
+          if (n->value.defined()) {
+            kwargs_keys.push_back("value");
+            kwargs_values.push_back(PrintShapeVar(n->value.value(), n_p->Attr("value"), d));
+          } else {
+            args.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")));
+          }
+
+          return Relax(d, "Prim")->Call(args, kwargs_keys, kwargs_values);
+        });
+
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<relax::ShapeStructInfo>(
         "", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index 879194037c..1b1ea2e53e 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -636,7 +636,7 @@ def test_tir_vars_in_struct_info():
     tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func), [n, m])
 
 
-def test_symbolic_var_collector():
+def test_collect_symbolic_var_from_tensor_shape():
     n, m, k, q, p = (
         tir.Var("n", "int64"),
         tir.Var("m", "int64"),
@@ -658,5 +658,46 @@ def test_symbolic_var_collector():
     assert free_vars == {n, p, q}
 
 
+param_type = tvm.testing.parameter("shape_expr", "prim_value")
+param_order = tvm.testing.parameter("definition_first", "usage_first")
+
+
+def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order):
+    tir_n = tir.Var("n", "int64")
+    tir_m = tir.Var("m", "int64")
+
+    bb = rx.BlockBuilder()
+    arg = rx.Var("arg", rx.TensorStructInfo([tir_n * tir_m]))
+
+    if param_type == "shape_expr":
+        extra_params = [
+            rx.Var("shape_expr", rx.ShapeStructInfo([tir_n, tir_m])),
+        ]
+    elif param_type == "prim_value":
+        extra_params = [
+            rx.Var("n", rx.PrimStructInfo(value=tir_n)),
+            rx.Var("m", rx.PrimStructInfo(value=tir_m)),
+        ]
+    else:
+        raise ValueError(f"Unknown param_type: {param_type}")
+
+    if param_order == "definition_first":
+        params = [*extra_params, arg]
+    elif param_order == "usage_first":
+        params = [arg, *extra_params]
+    else:
+        raise ValueError(f"Unknown param_order: {param_order}")
+
+    with bb.function("main", params=params):
+        out = rx.op.reshape(arg, [tir_n, tir_m])
+        bb.emit_func_output(out)
+    func = bb.get()["main"]
+
+    defined_vars = set(rx.analysis.defined_symbolic_vars(func))
+    free_vars = set(rx.analysis.free_symbolic_vars(func))
+    assert defined_vars == {tir_n, tir_m}
+    assert free_vars == set()
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_bind_symbolic_vars.py b/tests/python/relax/test_bind_symbolic_vars.py
index 1dc1189a67..82798c56df 100644
--- a/tests/python/relax/test_bind_symbolic_vars.py
+++ b/tests/python/relax/test_bind_symbolic_vars.py
@@ -163,7 +163,7 @@ def test_replacements_may_produce_new_symbolic_vars():
     tvm.ir.assert_structural_equal(expected, after)
 
 
-def test_bind_symbolic_vars_in_shape():
+def test_bind_symbolic_vars_in_tensor_shape():
     """The bound variable should be replaced when appearing in struct info"""
 
     @R.function(private=True)
@@ -183,6 +183,91 @@ def test_bind_symbolic_vars_in_shape():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+def test_bind_symbolic_vars_in_shape_expr():
+    """The bound variable should be replaced when appearing in R.Shape"""
+
+    @R.function(private=True)
+    def before(A: R.Tensor(["M * N"]), x: R.Shape(["M", "N"])):
+        M = T.int64()
+        N = T.int64()
+        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
+        return B
+
+    @R.function(private=True)
+    def expected(A: R.Tensor(["M * 16"]), x: R.Shape(["M", 16])):
+        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
+        return B
+
+    after = before.bind_symbolic_vars({"N": 16})
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_defining_of_symbolic_vars_in_prim_value():
+    """R.Prim may define symbolic variables
+
+    This case is a bit odd, because it always results in a
+    fully-constrained parameter at the relax level.  After binding in
+    this test case, we have a function that accepts three parameters,
+    and the third parameter must always be the number 16.
+
+    However, this provides the most consistent behavior with other
+    uses of `relax.Function.bind_symbolic_vars`, which restricts the
+    allowed values for each parameter, but does not alter the number
+    of parameters.  This is in contrast to the `BindParams` pass,
+    which provides a known value for relax parameters, removing them
+    from the function signature.
+
+    This convention also prevents surprise changes to the function
+    signature, such as shown in
+    `test_bind_symbolic_vars_with_expr_in_prim_value`.
+    """
+
+    @R.function(private=True)
+    def before(A: R.Tensor(["M * N"]), x: R.Prim(value="M"), y: R.Prim(value="N")):
+        M = T.int64()
+        N = T.int64()
+        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
+        return B
+
+    @R.function(private=True)
+    def expected(A: R.Tensor(["M * 16"]), x: R.Prim(value="M"), y: R.Prim(value=16)):
+        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
+        return B
+
+    after = before.bind_symbolic_vars({"N": 16})
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_usage_of_symbolic_vars_in_prim_value():
+    """R.Prim may use symbolic variables defined by other parameters
+
+    Like test_bind_defining_of_symbolic_vars_in_prim_value, but with
+    R.Prim using a symbolic variable rather than defining it.
+
+    This also demonstrates why we should not remove fully-constrained
+    R.Prim function parameters.  In this case, we have a function that
+    accepts two parameters, and we have specialized the shape of the
+    first parameter.  It would be unexpected for specialization of the
+    first parameter to result in removal of a different parameter
+    altogether.
+    """
+
+    @R.function(private=True)
+    def before(A: R.Tensor(["M", "N"]), x: R.Prim(value="M*N")):
+        M = T.int64()
+        N = T.int64()
+        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * N]))
+        return B
+
+    @R.function(private=True)
+    def expected(A: R.Tensor([16, 16]), x: R.Prim(value=256)):
+        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([512]))
+        return B
+
+    after = before.bind_symbolic_vars({"M": 16, "N": 16})
+    tvm.ir.assert_structural_equal(expected, after)
+
+
 def test_bind_strided_slice():
     """relax.op.strided_slice stores PrimExpr attributes"""
 
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index 902c478561..fbd37b307e 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -238,6 +238,23 @@ def test_prim_value():
     _check_json_roundtrip(pv)
 
 
+def test_prim_value_with_var():
+    n = tir.Var("n", "int64")
+    pv = rx.PrimValue(n)
+    assert pv.value.same_as(n)
+    tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n))
+    _check_equal(pv, rx.PrimValue(n))
+    _check_json_roundtrip(pv)
+
+
+def test_prim_value_with_expr():
+    n = tir.Var("n", "int64")
+    pv = rx.PrimValue(n + 1)
+    tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n + 1))
+    _check_equal(pv, rx.PrimValue(n + 1))
+    _check_json_roundtrip(pv)
+
+
 def test_string_imm():
     s0 = rx.StringImm("hello")
     s1 = rx.StringImm("hello")
diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py
index 80ebc3cb18..33dcd7e9d7 100644
--- a/tests/python/relax/test_struct_info.py
+++ b/tests/python/relax/test_struct_info.py
@@ -86,7 +86,23 @@ def test_prim_struct_info():
 
     # wrong API constructors
     with pytest.raises(TVMError):
-        rx.PrimStructInfo(1)
+        rx.PrimStructInfo([1])
+
+
+def test_prim_struct_info_with_expr():
+    n = tir.Var("n", "int64")
+    sinfo = rx.PrimStructInfo(value=n + 1)
+
+    _check_equal(sinfo, rx.PrimStructInfo(value=n + 1))
+    assert not tvm.ir.structural_equal(sinfo, rx.PrimStructInfo(dtype=n.dtype))
+
+    # can turn into str
+    str(sinfo)
+
+    assert isinstance(sinfo, rx.PrimStructInfo)
+    _check_json_roundtrip(sinfo)
+
+    assert sinfo.dtype == "int64"
 
 
 def test_shape_struct_info():
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 0295b6f1c2..d86b1e0108 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1190,8 +1190,9 @@ def test_empty_tuple():
     _check(foo, bb.get()["foo"])
 
 
-def test_symbolic_shape_computing():
-    # Tensor Case 1
+def test_symbolic_vars_in_tensor_shape_with_usage_first():
+    """First param may use symbolic variable defined in second param"""
+
     @R.function
     def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")):
         z = R.add(x, y)
@@ -1207,7 +1208,10 @@ def test_symbolic_shape_computing():
 
     _check(foo, bb.get()["foo"])
 
-    # Tensor Case 2
+
+def test_symbolic_vars_in_tensor_shape_with_definition_first():
+    """Second param may use symbolic variable defined in first param"""
+
     @R.function
     def bar(
         x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32")
@@ -1230,7 +1234,10 @@ def test_symbolic_shape_computing():
 
     _check(bar, bb.get()["bar"])
 
-    # Shape Case
+
+def test_symbolic_vars_in_shape():
+    """Symbolic variable may be defined in R.Shape"""
+
     @R.function
     def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")):
         m = T.int64()
@@ -1247,7 +1254,36 @@ def test_symbolic_shape_computing():
 
     _check(baz, bb.get()["baz"])
 
-    # Error Case
+
+def test_symbolic_vars_in_prim_value():
+    """Symbolic variable may be defined in R.Prim"""
+
+    @R.function
+    def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")):
+        m = T.int64()
+        z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32"))
+        return z
+
+    m = tir.Var("m", "int64")
+    x = relax.Var("x", relax.PrimStructInfo(value=m))
+    y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("baz", (x, y)):
+        z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32")))
+        bb.emit_func_output(z)
+
+    _check(baz, bb.get()["baz"])
+
+
+def test_undefined_symbolic_var_raises_error():
+    """An undefined symbolic variable is an error
+
+    A symbolic variable is defined at the first site where it appears
+    as a shape parameter without any modification.  TVMScript does not
+    support solving for a symbolic variable in terms of the argument
+    shape.  That is, this test case raises an error, and will not
+    attempt to define `m` as either `x.shape[0]-1` or `x.shape[1]//2`.
+    """
     with pytest.raises(tvm.error.DiagnosticError):
 
         @R.function
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 9f4ffd9acd..2e4218b2ab 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -205,6 +205,7 @@ def test_func_struct_info():
             relax.PrimStructInfo("float32"),
             relax.ObjectStructInfo(),
             relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]),
+            relax.PrimStructInfo(value=tir.Var("b", "int64")),
         ],
         ret=relax.TensorStructInfo(
             shape=relax.ShapeExpr([1, 2, 3]),
@@ -214,7 +215,8 @@ def test_func_struct_info():
     _assert_print(
         obj,
         "a = T.int64()\n"
-        'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), '
+        "b = T.int64()\n"
+        'R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3]), R.Prim(value=b)), '
         'R.Tensor((1, 2, 3), dtype="float32"), True)',
     )