You are viewing a plain-text version of this content. (The canonical link to the original message is not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/30 18:43:44 UTC
(tvm) branch unity updated: [Unity] Allow Pipeline Registration (#16008)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 53ccf18625 [Unity] Allow Pipeline Registration (#16008)
53ccf18625 is described below
commit 53ccf1862539fa7bef7d94560926ab7c9b2db47a
Author: Junru Shao <ju...@apache.org>
AuthorDate: Mon Oct 30 11:43:38 2023 -0700
[Unity] Allow Pipeline Registration (#16008)
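This commit adds `tvm.relax.register_pipeline`, a decorator that allows external
registration of compilation pipelines. It also removes the out-of-bounds warning
logged by `topi::take` in "fast" mode, changes the `fields` annotation of
`TuplePattern.__init__` from `Array` to `list`, and makes minor formatting cleanups.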
---
include/tvm/topi/transform.h | 2 --
python/tvm/relax/__init__.py | 1 +
python/tvm/relax/dpl/pattern.py | 3 +--
python/tvm/relax/pipeline.py | 24 +++++++++++++++++++++---
4 files changed, 23 insertions(+), 7 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 15f755df59..009f8cdd30 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1123,8 +1123,6 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
name, tag);
}
} else if (mode == "fast") {
- LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
- "Make sure input indices are in bound";
return compute(
out_shape,
[&](const Array<Var>& out_index) {
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 718e40eb7f..09b5b965ea 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -85,6 +85,7 @@ from .struct_info import (
# pipeline
from .pipeline import get_pipeline
+from .pipeline import register_pipeline
# Import submodules in the last to avoid dependency
from . import exec_builder
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index b72cb73b5f..4fb08c4635 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -466,7 +466,7 @@ class TuplePattern(DFPattern):
The fields in the tuple.
"""
- def __init__(self, fields: Array):
+ def __init__(self, fields: list):
self.__init_handle_by_constructor__(ffi.TuplePattern, fields) # type: ignore
def __getitem__(self, index: Optional[int]) -> "TupleGetItemPattern":
@@ -1041,7 +1041,6 @@ class PatternSeq(Node):
return _used_by(self, other, index)
def only_used_by(self, other: Union[DFPattern, "PatternSeq"], index=-1) -> "PatternSeq":
-
"""
Assuming the right-most pattern must be **ONLY** used by the `other` pattern as a producer
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 74ba7a5520..367c1ede0e 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -23,6 +23,7 @@ as it is or serves as a basis to do further composition.
# pylint: disable=unused-argument
import tvm
from tvm import meta_schedule as ms
+
from . import transform
@@ -93,9 +94,26 @@ def get_pipeline(name: str = "zero", **kwargs) -> tvm.transform.Pass:
The transformation pipeline.
"""
- if name in PIPELINE_MAP:
- return PIPELINE_MAP[name](**kwargs)
- else:
+ if name not in PIPELINE_MAP:
raise ValueError(
f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}"
)
+ return PIPELINE_MAP[name](**kwargs)
+
+
+def register_pipeline(name: str):
+ """Register a new pipeline
+
+ Parameters
+ ----------
+ name : str
+ Name of the pipeline
+ """
+ if name in PIPELINE_MAP:
+ raise ValueError(f"Pipeline {name} has already been registered")
+
+ def _register(func):
+ PIPELINE_MAP[name] = func
+ return func
+
+ return _register