You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/03/29 07:33:58 UTC
[tvm] branch unity updated: [Unity][Fix] Copy over module attrs in FuseTIR (#14418)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 51b29ef494 [Unity][Fix] Copy over module attrs in FuseTIR (#14418)
51b29ef494 is described below
commit 51b29ef494fc9339d316f274bd59cb88ba6c03d7
Author: Prakalp Srivastava <pr...@octoml.ai>
AuthorDate: Wed Mar 29 03:33:50 2023 -0400
[Unity][Fix] Copy over module attrs in FuseTIR (#14418)
---
src/relax/transform/fuse_tir.cc | 6 +++++-
tests/python/relax/test_transform_fuse_tir.py | 4 ++--
2 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index e90d6e4bc1..f4a31853e3 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -604,7 +604,11 @@ class TIRFuseMutator : public ExprMutator {
         mutator.builder_->AddFunction(update_func, gv->name_hint);
       }
     }
-    return mutator.builder_->GetContextIRModule();
+
+    // Step 3. Copy over module attributes and return.
+    auto modified_mod = mutator.builder_->GetContextIRModule();
+    if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict);
+    return modified_mod;
   }
 
  private:
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 7a8aa4d39f..356e28d6e9 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -47,7 +47,7 @@ def test_simple():
gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0]))
bb.emit_func_output(gv)
- return bb.get()
+ return bb.get().with_attrs({"foo": "bar"})
def expected():
def fused_add_exp_squeeze(x, p0):
@@ -63,7 +63,7 @@ def test_simple():
with bb.dataflow():
gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0))
bb.emit_func_output(gv)
- return bb.get()
+ return bb.get().with_attrs({"foo": "bar"})
_check(before(), expected())