You are viewing a plain-text version of this content. The canonical (linked) version is in the commits@tvm.apache.org mailing-list archive.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/03/29 07:33:58 UTC

[tvm] branch unity updated: [Unity][Fix] Copy over module attrs in FuseTIR (#14418)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 51b29ef494 [Unity][Fix] Copy over module attrs in FuseTIR (#14418)
51b29ef494 is described below

commit 51b29ef494fc9339d316f274bd59cb88ba6c03d7
Author: Prakalp Srivastava <pr...@octoml.ai>
AuthorDate: Wed Mar 29 03:33:50 2023 -0400

    [Unity][Fix] Copy over module attrs in FuseTIR (#14418)
---
 src/relax/transform/fuse_tir.cc               | 6 +++++-
 tests/python/relax/test_transform_fuse_tir.py | 4 ++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index e90d6e4bc1..f4a31853e3 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -604,7 +604,11 @@ class TIRFuseMutator : public ExprMutator {
         mutator.builder_->AddFunction(update_func, gv->name_hint);
       }
     }
-    return mutator.builder_->GetContextIRModule();
+
+    // Step 3. Copy over module attributes and return.
+    auto modified_mod = mutator.builder_->GetContextIRModule();
+    if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict);
+    return modified_mod;
   }
 
  private:
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 7a8aa4d39f..356e28d6e9 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -47,7 +47,7 @@ def test_simple():
                 gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0]))
             bb.emit_func_output(gv)
 
-        return bb.get()
+        return bb.get().with_attrs({"foo": "bar"})
 
     def expected():
         def fused_add_exp_squeeze(x, p0):
@@ -63,7 +63,7 @@ def test_simple():
             with bb.dataflow():
                 gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0))
             bb.emit_func_output(gv)
-        return bb.get()
+        return bb.get().with_attrs({"foo": "bar"})
 
     _check(before(), expected())