Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 19:38:14 UTC

[tvm] 02/28: fix new imports

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/rebase-09192022-autotensorization
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 90b986cc4ee347c3d48e0c3063be009c0208be34
Author: Andrew Luo <an...@gmail.com>
AuthorDate: Wed Aug 17 21:05:08 2022 -0700

    fix new imports
---
 python/tvm/meta_schedule/default_config.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index 8907d0bc9d..eaa026e3b4 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -24,7 +24,7 @@ from tvm._ffi.registry import register_func
 from tvm.contrib import nvcc
 from tvm.ir import IRModule
 from tvm.target import Target
-from tvm.tir import PrimFunc, tensor_intrin
+from tvm.tir import PrimFunc
 
 from .builder import Builder, LocalBuilder
 from .cost_model import CostModel, XGBModel
@@ -311,6 +311,7 @@ class _DefaultLLVMVNNI:
     @staticmethod
     def schedule_rules() -> List[ScheduleRule]:
         from tvm.meta_schedule import schedule_rule as M
+        from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
 
         logger.info("Using schedule rule: LLVM VNNI")
 
@@ -326,7 +327,7 @@ class _DefaultLLVMVNNI:
             ),
             M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
             M.MultiLevelTilingWithIntrin(
-                tensor_intrin.VNNI_DOT_16x4_INTRIN,
+                VNNI_DOT_16x4_INTRIN,
                 structure="SSRSRS",
                 tile_binds=None,
                 max_innermost_factor=64,
@@ -459,7 +460,7 @@ class _DefaultCUDATensorCore:
         return [
             M.MultiLevelTilingTensorCore(
                 intrin_groups=[
-                    tensor_intrin.get_wmma_intrin_group(
+                    get_wmma_intrin_group(
                         store_scope="shared",
                         in_dtype=in_dtype,
                         out_dtype=out_dtype,
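
For readers following the import change: the commit stops importing the whole `tensor_intrin` module from `tvm.tir` and instead imports the needed intrinsic names directly, inside the functions that build the schedule rules. Below is a minimal sketch of the resulting pattern, assuming the `tvm.tir.tensor_intrin.x86` module layout at this revision; the surrounding function is illustrative only and not the actual contents of default_config.py.

    # Illustrative sketch (not the real file): the x86 VNNI intrinsic name is
    # imported lazily, inside the function that needs it, from the
    # target-specific submodule rather than via `tvm.tir.tensor_intrin`.
    def schedule_rules():
        from tvm.meta_schedule import schedule_rule as M
        from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN

        return [
            M.MultiLevelTilingWithIntrin(
                VNNI_DOT_16x4_INTRIN,  # registered intrinsic name from the x86 submodule
                structure="SSRSRS",
                tile_binds=None,
                max_innermost_factor=64,
            ),
        ]

Keeping these imports local avoids pulling in every target's tensor-intrinsic registrations at module import time; only the code path that actually requests the LLVM VNNI rules pays for the x86 registration.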