You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/06/15 13:13:09 UTC

[tvm] branch main updated: [TIR] Update primfunc host attachment to include host (#15102)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dd6fcccc45 [TIR] Update primfunc host attachment to include host (#15102)
dd6fcccc45 is described below

commit dd6fcccc459364562b468c3f17f3b53185094618
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Thu Jun 15 09:12:58 2023 -0400

    [TIR] Update primfunc host attachment to include host (#15102)
    
    This PR updates the target attached to host functions so that it
    includes the host attribute, allowing them to be lowered through MakePackedAPI.
---
 src/tir/transforms/primfunc_utils.cc        | 3 ++-
 tests/python/unittest/test_tir_host_func.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc
index f844b51f53..8a5317a3c8 100644
--- a/src/tir/transforms/primfunc_utils.cc
+++ b/src/tir/transforms/primfunc_utils.cc
@@ -46,7 +46,8 @@ transform::Pass BindTarget(Target target) {
         func = WithAttr(std::move(func), tvm::attr::kTarget, new_target);
       }
     } else if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
-      func = WithAttr(std::move(func), tvm::attr::kTarget, target_host);
+      func =
+          WithAttr(std::move(func), tvm::attr::kTarget, Target::WithHost(target_host, target_host));
     } else if (is_externally_exposed) {
       func = WithAttr(std::move(func), tvm::attr::kTarget, target);
     } else {
diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py
index ea0ad7ba4a..ed04985bdd 100644
--- a/tests/python/unittest/test_tir_host_func.py
+++ b/tests/python/unittest/test_tir_host_func.py
@@ -22,6 +22,7 @@ from tvm.meta_schedule.testing import te_workload
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring
 # fmt: off
 
+
 @I.ir_module
 class Module:
     @T.prim_func
@@ -33,7 +34,7 @@ class Module:
         T.func_attr(
             {
                 "global_symbol": "test",
-                "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
+                "target": tvm.target.Target("llvm", host="llvm"),
                 "tir.noalias": True,
             }
         )