You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ke...@apache.org on 2020/04/30 17:00:35 UTC
[incubator-tvm] branch master updated: [Fix] Add ConstantNode to IsAtomic (#5457)
This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new ae89afe [Fix] Add ConstantNode to IsAtomic (#5457)
ae89afe is described below
commit ae89afe0f09db85d11d92d75e5a6ca34b22fb323
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Thu Apr 30 10:00:27 2020 -0700
[Fix] Add ConstantNode to IsAtomic (#5457)
* Add ConstantNode to IsAtomic
* Add ToANormalForm to FoldConstant
---
src/relay/transforms/fold_constant.cc | 1 +
tests/python/relay/test_pass_fold_constant.py | 19 +++++++++++++++++++
2 files changed, 20 insertions(+)
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index a52f420..fab184c 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -203,6 +203,7 @@ class ConstantFolder : public ExprMutator {
// Constant evaluate a expression.
Expr ConstEvaluate(Expr expr) {
std::vector<transform::Pass> passes = {transform::FuseOps(0),
+ transform::ToANormalForm(),
transform::InferType()};
Function func;
if (expr.as<FunctionNode>()) {
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index b212b26..a981667 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass):
return entry if isinstance(expr, relay.Function) else entry.body
+def test_concatenate_const():  # FoldConstant should collapse concatenate([c, c]) of a constant into a single constant
+    def before():  # unfolded input: a nullary function concatenating a 3-element constant with itself along axis 0
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        concat = relay.op.concatenate([const, const], axis=0)
+        func = relay.Function([], concat)
+        return func
+
+    def expected():  # folded result: the 6-element constant that the concatenation evaluates to
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        func = relay.Function([], const)
+        return func
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())  # nothing to fold here; presumably run so both sides are type-inferred before comparing
+    assert tvm.ir.structural_equal(zz, zexpected)
+
+
def test_fold_const():
c_data = np.array([1, 2, 3]).astype("float32")
t = relay.TensorType([1, 2, 3], "float32")