You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ho...@apache.org on 2023/04/05 02:18:20 UTC
[tvm] branch unity updated: [Unity] LiftTransformParams with symbolic shape robustness (#14500)
This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1eea30d54a [Unity] LiftTransformParams with symbolic shape robustness (#14500)
1eea30d54a is described below
commit 1eea30d54ae6cedb4b7a562a1c51415927750e82
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Tue Apr 4 22:18:13 2023 -0400
[Unity] LiftTransformParams with symbolic shape robustness (#14500)
This PR ensures that the LiftTransformParams pass does not lift bindings
that involve symbolic shapes, because symbolic shape computations cannot
be performed at compile time.
---
src/relax/transform/lift_transform_params.cc | 53 +++++++++++-
.../relax/test_transform_lift_transform_params.py | 96 ++++++++++++++++++++++
2 files changed, 148 insertions(+), 1 deletion(-)
diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index 88939bd1f5..e0296e6ae5 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -132,6 +132,46 @@ class TransformParamsFuncBuilder : public ExprMutator {
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
};
+/*!
+ * \brief Check whether a struct info refers to symbolic (non-constant) shape values.
+ *
+ * A ShapeStructInfo is reported as symbolic when any of its values is not an IntImm.
+ * A TensorStructInfo is checked through the struct info of its shape expression, when the
+ * shape is defined. All other struct info kinds are traversed with the default
+ * StructInfoVisitor rules. Shapes whose values are undefined are not reported as symbolic.
+ *
+ * \param sinfo The struct info to inspect.
+ * \return True if a non-constant shape value was found, false otherwise.
+ */
+bool SInfoContainsSymVar(const StructInfo& sinfo) {
+  struct SymVarDetector : public StructInfoVisitor {
+    void VisitStructInfo(const StructInfo& sinfo) final {
+      // Stop descending as soon as one symbolic value has been found.
+      if (contains_sym_var) {
+        return;
+      }
+      StructInfoVisitor::VisitStructInfo(sinfo);
+    }
+
+    /*! \brief Set `contains_sym_var` if any entry of `shape` is not a constant integer. */
+    void CheckShape(const Array<PrimExpr>& shape) {
+      for (const PrimExpr& value : shape) {
+        if (value.as<IntImmNode>() == nullptr) {
+          contains_sym_var = true;
+          return;
+        }
+      }
+    }
+
+    void VisitStructInfo_(const ShapeStructInfoNode* shape_sinfo) final {
+      // Undefined values leave nothing to check; they are not treated as symbolic here.
+      if (shape_sinfo->values.defined()) {
+        CheckShape(shape_sinfo->values.value());
+      }
+    }
+
+    void VisitStructInfo_(const TensorStructInfoNode* tensor_sinfo) final {
+      // Inspect the tensor's shape via the struct info of its shape expression.
+      if (tensor_sinfo->shape.defined()) {
+        VisitStructInfo(GetStructInfo(tensor_sinfo->shape.value()));
+      }
+    }
+
+    /*! \brief Whether a non-constant shape value has been encountered so far. */
+    bool contains_sym_var = false;
+  };
+
+  SymVarDetector detector;
+  detector(sinfo);
+  return detector.contains_sym_var;
+}
+
/*!
* \brief Visitor that creates the plan of lifting transform params.
*
@@ -165,9 +205,13 @@ class LiftTransformParamsPlanner : public ExprVisitor {
void VisitBinding_(const VarBindingNode* binding) final {
std::vector<const VarNode*> producers;
bool can_lift = true;
+
+ // Cond 1. Do not lift bindings outside dataflow blocks.
if (!is_in_dataflow_block_) {
can_lift = false;
}
+
+ // Cond 2. Do not lift regarding the "builtin.stop_lift_params" op.
if (const auto* call = binding->value.as<CallNode>()) {
static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params");
if (call->op.same_as(stop_lift_params_op)) {
@@ -175,6 +219,7 @@ class LiftTransformParamsPlanner : public ExprVisitor {
}
}
+ // Cond 3. Do not lift when involving Vars that are not liftable.
PostOrderVisit(binding->value, [&](const ObjectRef& obj) {
if (const VarNode* var = obj.as<VarNode>()) {
producers.push_back(var);
@@ -183,6 +228,12 @@ class LiftTransformParamsPlanner : public ExprVisitor {
}
}
});
+
+ // Cond 4. Do not lift when its struct info contains symbolic variables.
+ if (SInfoContainsSymVar(GetStructInfo(binding->var))) {
+ can_lift = false;
+ }
+
if (can_lift) {
lifted_bindings_.insert(binding->var);
builder_.AddBinding(GetRef<VarBinding>(binding));
@@ -256,7 +307,7 @@ class TransformParamsLifter : public ExprMutator {
for (const auto& [var, index] : lift_plan_.output_to_index) {
param_remap_[var] = TupleGetItem(params, index);
}
- auto new_body = VisitExpr(func->body);
+ auto new_body = VisitWithNewScope(func->body, new_params);
// Step 3.3: Remove function attributes that are not needed
auto new_attrs = func->attrs;
diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py
index b618948840..2a045e9acb 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -440,5 +440,101 @@ def test_stop_lifting():
tvm.ir.assert_structural_equal(after, Expected)
+def test_symbolic_var():
+ @tvm.script.ir_module
+ class Before1:
+ @R.function
+ def main(shape: R.Shape(["n"])):
+ R.func_attr({"num_input": 1})
+ n = T.int64()
+ with R.dataflow():
+ zeros = R.zeros((n, n), "float32")
+ return shape
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main_transform_params(params: R.Tuple) -> R.Tuple:
+ with R.dataflow():
+ gv: R.Tuple = R.tuple()
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(shape: R.Shape(["n"]), params: R.Tuple) -> R.Shape(["n"]):
+ n = T.int64()
+ with R.dataflow():
+ zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, n]), dtype="float32")
+ R.output()
+ return shape
+
+ @I.ir_module
+ class Before2:
+ @T.prim_func
+ def zeros(var_T_full: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int64()
+ T_full = T.match_buffer(var_T_full, (n, n))
+ for ax0, ax1 in T.grid(n, n):
+ with T.block("T_full"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads()
+ T.writes(T_full[v_ax0, v_ax1])
+ T_full[v_ax0, v_ax1] = T.float32(0)
+
+ @R.function
+ def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
+ n = T.int64()
+ R.func_attr({"num_input": 1})
+ cls = Before2
+ with R.dataflow():
+ zeros = R.call_tir(
+ cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), dtype="float32")
+ )
+ R.output()
+ return shape
+
+ @I.ir_module
+ class Expected2:
+ @T.prim_func
+ def zeros(var_T_full: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int64()
+ T_full = T.match_buffer(var_T_full, (n, n))
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(n, n):
+ with T.block("T_full"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads()
+ T.writes(T_full[v_ax0, v_ax1])
+ T_full[v_ax0, v_ax1] = T.float32(0)
+
+ @R.function
+ def main_transform_params(params: R.Tuple) -> R.Tuple:
+ with R.dataflow():
+ gv: R.Tuple = R.tuple()
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(shape: R.Shape(["n"]), params: R.Tuple) -> R.Shape(["n"]):
+ n = T.int64()
+ cls = Expected2
+ with R.dataflow():
+ zeros = R.call_tir(
+ cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), dtype="float32")
+ )
+ R.output()
+ return shape
+
+ mod = Before1
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected1)
+
+ mod = Before2
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected2)
+
+
if __name__ == "__main__":
tvm.testing.main()