You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and was not preserved in this text version.
Posted to commits@tvm.apache.org by ke...@apache.org on 2020/04/30 17:00:35 UTC

[incubator-tvm] branch master updated: [Fix] Add ConstantNode to IsAtomic (#5457)

This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new ae89afe  [Fix] Add ConstantNode to IsAtomic (#5457)
ae89afe is described below

commit ae89afe0f09db85d11d92d75e5a6ca34b22fb323
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Thu Apr 30 10:00:27 2020 -0700

    [Fix] Add ConstantNode to IsAtomic (#5457)
    
    * add constantnode to atomic
    
    * Add ToANormalForm to FoldConstant
---
 src/relay/transforms/fold_constant.cc         |  1 +
 tests/python/relay/test_pass_fold_constant.py | 19 +++++++++++++++++++
 2 files changed, 20 insertions(+)

diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index a52f420..fab184c 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -203,6 +203,7 @@ class ConstantFolder : public ExprMutator {
   // Constant evaluate a expression.
   Expr ConstEvaluate(Expr expr) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0),
+                                           transform::ToANormalForm(),
                                            transform::InferType()};
     Function func;
     if (expr.as<FunctionNode>()) {
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index b212b26..a981667 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass):
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
+def test_concatenate_const():
+    def before():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        concat = relay.op.concatenate([const, const], axis=0)
+        func = relay.Function([], concat)
+        return func
+
+    def expected():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        func = relay.Function([], const)
+        return func
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(zz, zexpected)
+
+
 def test_fold_const():
     c_data = np.array([1, 2, 3]).astype("float32")
     t = relay.TensorType([1, 2, 3], "float32")