You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/09/08 04:22:16 UTC

[tvm] branch unity updated: [Unity][Analysis] Check for usage of DataflowVar in all_vars() (#15698)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7df7b7b1b0 [Unity][Analysis] Check for usage of DataflowVar in all_vars() (#15698)
7df7b7b1b0 is described below

commit 7df7b7b1b0f0d6ffaeac5e0981355d3ea9bd478f
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 7 21:22:11 2023 -0700

    [Unity][Analysis] Check for usage of DataflowVar in all_vars() (#15698)
    
    Prior to this commit, the `VarVisitor` used in the implementation of
    `all_vars` and `free_vars` only collected variable usage sites where
    the variable was a `const VarNode*`, and ignored usage sites of a
    `const DataflowVarNode*`.  When analyzing an entire function, these
    variables were still found while visiting their bindings
    (`VisitVarBinding`), so the existing tests passed.  However, when
    analyzing a single expression, these variables were erroneously
    excluded.
    
    This commit adds a `VisitExpr_(const DataflowVarNode*)` implementation
    in `VarVisitor`, to collect variable usage regardless of the type of
    variable.
---
 src/relax/analysis/analysis.cc      | 2 ++
 tests/python/relax/test_analysis.py | 9 +++++++++
 2 files changed, 11 insertions(+)

diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
index 108fe69372..7875a517a1 100644
--- a/src/relax/analysis/analysis.cc
+++ b/src/relax/analysis/analysis.cc
@@ -94,6 +94,8 @@ class VarVisitor : protected ExprVisitor {
 
   void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
 
+  void VisitExpr_(const DataflowVarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
+
   void VisitExpr_(const FunctionNode* op) final {
     for (const auto& param : op->params) {
       MarkBounded(param);
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 40bd5146ba..d5545a0a56 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -282,6 +282,15 @@ def test_all_vars():
     assert var_names == {"_", "x", "y", "z", "p", "q", "r", "s"}
 
 
+def test_all_vars_from_expr_using_dataflow():
+    """all_vars() should return all Var, including DataflowVar"""
+    func = VarExample["main"]
+    cls_func_q = func.body.blocks[1].bindings[1].value
+
+    var_names = var_name_set(all_vars(cls_func_q))
+    assert var_names == {"q"}
+
+
 def test_bound_vars():
     vars = bound_vars(VarExample["func"])
     assert len(vars) == 2