You are viewing a plain text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/28 04:31:28 UTC
[tvm] branch main updated: Fold If when the condition is Constant (#7354)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8b84e33 Fold If when the condition is Constant (#7354)
8b84e33 is described below
commit 8b84e33679585082fd1817821eac8a7eae5830c6
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Wed Jan 27 21:31:18 2021 -0700
Fold If when the condition is Constant (#7354)
---
src/relay/transforms/fold_constant.cc | 12 +++++++++
tests/python/relay/test_pass_fold_constant.py | 39 +++++++++++++++++++++++++++
2 files changed, 51 insertions(+)
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 48af31f..66f233b 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -120,6 +120,33 @@ class ConstantFolder : public MixedModeMutator {
}
}
+  /*!
+   * \brief Fold an If whose condition simplifies to a constant.
+   *
+   * The condition is visited first. If it becomes a scalar boolean constant, only the
+   * branch that is taken is visited and returned; the untaken branch is dropped.
+   * Otherwise the generic ExprMutator visitor rebuilds the If.
+   */
+  Expr VisitExpr_(const IfNode* op) final {
+    auto new_cond = ExprMutator::VisitExpr(op->cond);
+    if (auto const_cond = new_cond.as<ConstantNode>()) {
+      const auto& cond_data = const_cond->data;
+      const DLDataType cond_dtype = cond_data->dtype;
+      // Only a scalar bool can be decoded by reading one byte. Any other constant (e.g. a
+      // non-bool dtype reaching this pass before type checking) is left unfolded.
+      if (cond_data->ndim == 0 && cond_dtype.code == kDLUInt && cond_dtype.bits == 1 &&
+          cond_dtype.lanes == 1) {
+        // NOTE(review): assumes the constant's buffer is host-accessible; confirm for
+        // constants placed on a non-CPU device.
+        const auto* cond_bytes =
+            static_cast<const uint8_t*>(cond_data->data) + cond_data->byte_offset;
+        return cond_bytes[0] ? ExprMutator::VisitExpr(op->true_branch)
+                             : ExprMutator::VisitExpr(op->false_branch);
+      }
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (inside_primitive) {
return GetRef<Expr>(call);
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 549596d..76182d2 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -147,6 +147,45 @@ def test_fold_concat():
assert tvm.ir.structural_equal(zz, zexpected)
+def test_fold_if():
+ cond_data = np.array(1).astype("bool")
+ x_data = np.array([[1, 2, 3]]).astype("float32")
+
+ def before():
+ a = relay.const(cond_data)
+ x = relay.const(x_data)
+ y = relay.const(x_data)
+ iff = relay.If(a, x + y, x - y)
+ return relay.Function([], iff)
+
+ def expected():
+ y_data = x_data + x_data
+ y = relay.const(y_data)
+ return relay.Function([], y)
+
+ zz = run_opt_pass(before(), transform.FoldConstant())
+ zexpected = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(zz, zexpected)
+
+ cond_data = np.array(0).astype("bool")
+
+ def before():
+ a = relay.const(cond_data)
+ x = relay.const(x_data)
+ y = relay.const(x_data)
+ iff = relay.If(a, x + y, x - y)
+ return relay.Function([], iff)
+
+ def expected():
+ y_data = x_data - x_data
+ y = relay.const(y_data)
+ return relay.Function([], y)
+
+ zz = run_opt_pass(before(), transform.FoldConstant())
+ zexpected = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(zz, zexpected)
+
+
def test_fold_shape_of():
c_shape = (8, 9, 10)