You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ho...@apache.org on 2023/04/05 02:18:20 UTC

[tvm] branch unity updated: [Unity] LiftTransformParams with symbolic shape robustness (#14500)

This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1eea30d54a [Unity] LiftTransformParams with symbolic shape robustness (#14500)
1eea30d54a is described below

commit 1eea30d54ae6cedb4b7a562a1c51415927750e82
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Tue Apr 4 22:18:13 2023 -0400

    [Unity] LiftTransformParams with symbolic shape robustness (#14500)
    
    This PR ensures that the LiftTransformParams pass does not lift bindings
    whose struct info involves symbolic shapes, because computations on
    symbolic shapes cannot be performed at compilation time.
---
 src/relax/transform/lift_transform_params.cc       | 53 +++++++++++-
 .../relax/test_transform_lift_transform_params.py  | 96 ++++++++++++++++++++++
 2 files changed, 148 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index 88939bd1f5..e0296e6ae5 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -132,6 +132,46 @@ class TransformParamsFuncBuilder : public ExprMutator {
   std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
 };
 
+/*!
+ * \brief Whether \p sinfo contains a shape value that is not an IntImm, e.g. a symbolic var.
+ */
+bool SInfoContainsSymVar(const StructInfo& sinfo) {
+  struct SymVarDetector : public StructInfoVisitor {
+    // Stop descending once a non-constant shape value has been found.
+    void VisitStructInfo(const StructInfo& sinfo) final {
+      if (!contains_sym_var) StructInfoVisitor::VisitStructInfo(sinfo);
+    }
+
+    // A shape is static only if every dimension is an integer literal.
+    void CheckShape(const Array<PrimExpr>& shape) {
+      for (const PrimExpr& value : shape) {
+        if (value.as<IntImmNode>() == nullptr) {
+          contains_sym_var = true;
+          return;
+        }
+      }
+    }
+
+    void VisitStructInfo_(const ShapeStructInfoNode* shape_sinfo) final {
+      if (shape_sinfo->values.defined()) {
+        CheckShape(shape_sinfo->values.value());
+      }
+    }
+
+    // A tensor's shape is checked via the struct info of its shape expression.
+    void VisitStructInfo_(const TensorStructInfoNode* tensor_sinfo) final {
+      if (tensor_sinfo->shape.defined()) {
+        VisitStructInfo(GetStructInfo(tensor_sinfo->shape.value()));
+      }
+    }
+    bool contains_sym_var = false;
+  };
+
+  SymVarDetector detector;
+  detector(sinfo);
+  return detector.contains_sym_var;
+}
+
 /*!
  * \brief Visitor that creates the plan of lifting transform params.
  *
@@ -165,9 +205,13 @@ class LiftTransformParamsPlanner : public ExprVisitor {
   void VisitBinding_(const VarBindingNode* binding) final {
     std::vector<const VarNode*> producers;
     bool can_lift = true;
+
+    // Cond 1. Do not lift bindings outside dataflow blocks.
     if (!is_in_dataflow_block_) {
       can_lift = false;
     }
+
+    // Cond 2. Do not lift calls to the "relax.builtin.stop_lift_params" op.
     if (const auto* call = binding->value.as<CallNode>()) {
       static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params");
       if (call->op.same_as(stop_lift_params_op)) {
@@ -175,6 +219,7 @@ class LiftTransformParamsPlanner : public ExprVisitor {
       }
     }
 
+    // Cond 3. Do not lift when involving Vars that are not liftable.
     PostOrderVisit(binding->value, [&](const ObjectRef& obj) {
       if (const VarNode* var = obj.as<VarNode>()) {
         producers.push_back(var);
@@ -183,6 +228,12 @@ class LiftTransformParamsPlanner : public ExprVisitor {
         }
       }
     });
+
+    // Cond 4. Do not lift when its struct info contains symbolic variables.
+    if (SInfoContainsSymVar(GetStructInfo(binding->var))) {
+      can_lift = false;
+    }
+
     if (can_lift) {
       lifted_bindings_.insert(binding->var);
       builder_.AddBinding(GetRef<VarBinding>(binding));
@@ -256,7 +307,7 @@ class TransformParamsLifter : public ExprMutator {
     for (const auto& [var, index] : lift_plan_.output_to_index) {
       param_remap_[var] = TupleGetItem(params, index);
     }
-    auto new_body = VisitExpr(func->body);
+    auto new_body = VisitWithNewScope(func->body, new_params);
 
     // Step 3.3: Remove function attributes that are not needed
     auto new_attrs = func->attrs;
diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py
index b618948840..2a045e9acb 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -440,5 +440,101 @@ def test_stop_lifting():
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_symbolic_var():
+    @tvm.script.ir_module
+    class Before1:
+        @R.function
+        def main(shape: R.Shape(["n"])):
+            R.func_attr({"num_input": 1})
+            n = T.int64()
+            with R.dataflow():
+                zeros = R.zeros((n, n), "float32")
+            return shape
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main_transform_params(params: R.Tuple) -> R.Tuple:
+            with R.dataflow():
+                gv: R.Tuple = R.tuple()
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(shape: R.Shape(["n"]), params: R.Tuple) -> R.Shape(["n"]):
+            n = T.int64()
+            with R.dataflow():
+                zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, n]), dtype="float32")
+                R.output()
+            return shape
+
+    @I.ir_module
+    class Before2:
+        @T.prim_func
+        def zeros(var_T_full: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n = T.int64()
+            T_full = T.match_buffer(var_T_full, (n, n))
+            for ax0, ax1 in T.grid(n, n):
+                with T.block("T_full"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads()
+                    T.writes(T_full[v_ax0, v_ax1])
+                    T_full[v_ax0, v_ax1] = T.float32(0)
+
+        @R.function
+        def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
+            n = T.int64()
+            R.func_attr({"num_input": 1})
+            cls = Before2
+            with R.dataflow():
+                zeros = R.call_tir(
+                    cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), dtype="float32")
+                )
+                R.output()
+            return shape
+
+    @I.ir_module
+    class Expected2:
+        @T.prim_func
+        def zeros(var_T_full: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n = T.int64()
+            T_full = T.match_buffer(var_T_full, (n, n))
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(n, n):
+                with T.block("T_full"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads()
+                    T.writes(T_full[v_ax0, v_ax1])
+                    T_full[v_ax0, v_ax1] = T.float32(0)
+
+        @R.function
+        def main_transform_params(params: R.Tuple) -> R.Tuple:
+            with R.dataflow():
+                gv: R.Tuple = R.tuple()
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(shape: R.Shape(["n"]), params: R.Tuple) -> R.Shape(["n"]):
+            n = T.int64()
+            cls = Expected2
+            with R.dataflow():
+                zeros = R.call_tir(
+                    cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), dtype="float32")
+                )
+                R.output()
+            return shape
+
+    mod = Before1
+    after = relax.transform.LiftTransformParams()(mod)
+    tvm.ir.assert_structural_equal(after, Expected1)
+
+    mod = Before2
+    after = relax.transform.LiftTransformParams()(mod)
+    tvm.ir.assert_structural_equal(after, Expected2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()