You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/28 04:31:28 UTC

[tvm] branch main updated: Fold If when the condition is Constant (#7354)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8b84e33  Fold If when the condition is Constant (#7354)
8b84e33 is described below

commit 8b84e33679585082fd1817821eac8a7eae5830c6
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Wed Jan 27 21:31:18 2021 -0700

    Fold If when the condition is Constant (#7354)
---
 src/relay/transforms/fold_constant.cc         | 12 +++++++++
 tests/python/relay/test_pass_fold_constant.py | 39 +++++++++++++++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 48af31f..66f233b 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -120,6 +120,18 @@ class ConstantFolder : public MixedModeMutator {
     }
   }
 
+  Expr VisitExpr_(const IfNode* op) final {
+    // Fold an If whose (visited) condition is a constant into the selected branch.
+    Expr cond = ExprMutator::VisitExpr(op->cond);
+    const auto* folded = cond.as<ConstantNode>();
+    if (folded == nullptr) {
+      return ExprMutator::VisitExpr_(op);  // non-constant condition: keep the If
+    }
+    // Assumes a bool scalar condition: the first byte of its buffer decides.
+    const bool take_true = static_cast<uint8_t*>(folded->data->data)[0] != 0;
+    return ExprMutator::VisitExpr(take_true ? op->true_branch : op->false_branch);
+  }
+
   Expr Rewrite_(const CallNode* call, const Expr& post) final {
     if (inside_primitive) {
       return GetRef<Expr>(call);
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 549596d..76182d2 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -147,6 +147,45 @@ def test_fold_concat():
     assert tvm.ir.structural_equal(zz, zexpected)
 
 
+def test_fold_if():  # FoldConstant should fold If(const_cond, t, f) to the taken branch
+    cond_data = np.array(1).astype("bool")  # case 1: constant True selects x + y
+    x_data = np.array([[1, 2, 3]]).astype("float32")
+
+    def before():
+        a = relay.const(cond_data)
+        x = relay.const(x_data)
+        y = relay.const(x_data)
+        iff = relay.If(a, x + y, x - y)
+        return relay.Function([], iff)
+
+    def expected():
+        y_data = x_data + x_data
+        y = relay.const(y_data)
+        return relay.Function([], y)
+
+    zz = run_opt_pass(before(), transform.FoldConstant())  # expect a single constant body
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(zz, zexpected)
+    # Case 2: constant False selects the false branch, x - y (before() reads cond_data late).
+    cond_data = np.array(0).astype("bool")
+
+    def before():
+        a = relay.const(cond_data)
+        x = relay.const(x_data)
+        y = relay.const(x_data)
+        iff = relay.If(a, x + y, x - y)
+        return relay.Function([], iff)
+
+    def expected():
+        y_data = x_data - x_data
+        y = relay.const(y_data)
+        return relay.Function([], y)
+
+    zz = run_opt_pass(before(), transform.FoldConstant())  # expect a single constant body
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(zz, zexpected)
+
+
 def test_fold_shape_of():
     c_shape = (8, 9, 10)