You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/09/08 04:22:16 UTC
[tvm] branch unity updated: [Unity][Analysis] Check for usage of DataflowVar in all_vars() (#15698)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7df7b7b1b0 [Unity][Analysis] Check for usage of DataflowVar in all_vars() (#15698)
7df7b7b1b0 is described below
commit 7df7b7b1b0f0d6ffaeac5e0981355d3ea9bd478f
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 7 21:22:11 2023 -0700
[Unity][Analysis] Check for usage of DataflowVar in all_vars() (#15698)
Prior to this commit, the `VarVisitor` used in the implementation of
`all_vars` and `free_vars` only collected variable usage sites where
the variable was a `const VarNode*`, and ignored usage sites of a
`const DataflowVarNode*`. When analyzing an entire function, these
variables were still collected at their binding sites (in
`VisitVarBinding`), so the existing tests passed. However, when
analyzing a single expression, these variables were erroneously excluded.
This commit adds a `VisitExpr_(const DataflowVarNode*)` implementation
in `VarVisitor`, to collect variable usage regardless of the type of
variable.
---
src/relax/analysis/analysis.cc | 2 ++
tests/python/relax/test_analysis.py | 9 +++++++++
2 files changed, 11 insertions(+)
diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
index 108fe69372..7875a517a1 100644
--- a/src/relax/analysis/analysis.cc
+++ b/src/relax/analysis/analysis.cc
@@ -94,6 +94,8 @@ class VarVisitor : protected ExprVisitor {
void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
+ // Record usages of DataflowVar too. A DataflowVarNode usage is visited through
+ // this separate overload, not the VarNode overload above. Without it, a
+ // DataflowVar used inside a standalone expression (as opposed to a whole
+ // function, where its binding is also visited) would be missed by
+ // all_vars()/free_vars().
+ void VisitExpr_(const DataflowVarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
+
void VisitExpr_(const FunctionNode* op) final {
for (const auto& param : op->params) {
MarkBounded(param);
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 40bd5146ba..d5545a0a56 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -282,6 +282,15 @@ def test_all_vars():
assert var_names == {"_", "x", "y", "z", "p", "q", "r", "s"}
+def test_all_vars_from_expr_using_dataflow():
+ """all_vars() should return all Var, including DataflowVar

+ Calls all_vars() on a single expression rather than a whole function,
+ so the variable is seen only at its usage site, never at its binding.
+ """
+ func = VarExample["main"]
+ # Value of the second binding (index 1) in the second binding block (index 1).
+ # Per the assertion below, it references exactly one variable, `q`.
+ # NOTE(review): `q` is expected to be a DataflowVar (see test name);
+ # confirm against the VarExample module definition.
+ cls_func_q = func.body.blocks[1].bindings[1].value
+
+ var_names = var_name_set(all_vars(cls_func_q))
+ assert var_names == {"q"}
+
+
def test_bound_vars():
vars = bound_vars(VarExample["func"])
assert len(vars) == 2