You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/30 18:43:44 UTC

(tvm) branch unity updated: [Unity] Allow Pipeline Registration (#16008)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 53ccf18625 [Unity] Allow Pipeline Registration (#16008)
53ccf18625 is described below

commit 53ccf1862539fa7bef7d94560926ab7c9b2db47a
Author: Junru Shao <ju...@apache.org>
AuthorDate: Mon Oct 30 11:43:38 2023 -0700

    [Unity] Allow Pipeline Registration (#16008)
    
    This commit adds `tvm.relax.register_pipeline`, which allows external
    registration of compilation pipelines.
---
 include/tvm/topi/transform.h    |  2 --
 python/tvm/relax/__init__.py    |  1 +
 python/tvm/relax/dpl/pattern.py |  3 +--
 python/tvm/relax/pipeline.py    | 24 +++++++++++++++++++++---
 4 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 15f755df59..009f8cdd30 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1123,8 +1123,6 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
           name, tag);
     }
   } else if (mode == "fast") {
-    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
-                    "Make sure input indices are in bound";
     return compute(
         out_shape,
         [&](const Array<Var>& out_index) {
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 718e40eb7f..09b5b965ea 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -85,6 +85,7 @@ from .struct_info import (
 
 # pipeline
 from .pipeline import get_pipeline
+from .pipeline import register_pipeline
 
 # Import submodules in the last to avoid dependency
 from . import exec_builder
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index b72cb73b5f..4fb08c4635 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -466,7 +466,7 @@ class TuplePattern(DFPattern):
         The fields in the tuple.
     """
 
-    def __init__(self, fields: Array):
+    def __init__(self, fields: list):
         self.__init_handle_by_constructor__(ffi.TuplePattern, fields)  # type: ignore
 
     def __getitem__(self, index: Optional[int]) -> "TupleGetItemPattern":
@@ -1041,7 +1041,6 @@ class PatternSeq(Node):
         return _used_by(self, other, index)
 
     def only_used_by(self, other: Union[DFPattern, "PatternSeq"], index=-1) -> "PatternSeq":
-
         """
         Assuming the right-most pattern must be **ONLY** used by the `other` pattern as a producer
 
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 74ba7a5520..367c1ede0e 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -23,6 +23,7 @@ as it is or serves as a basis to do further composition.
 # pylint: disable=unused-argument
 import tvm
 from tvm import meta_schedule as ms
+
 from . import transform
 
 
@@ -93,9 +94,26 @@ def get_pipeline(name: str = "zero", **kwargs) -> tvm.transform.Pass:
        The transformation pipeline.
     """
 
-    if name in PIPELINE_MAP:
-        return PIPELINE_MAP[name](**kwargs)
-    else:
+    if name not in PIPELINE_MAP:
         raise ValueError(
             f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}"
         )
+    return PIPELINE_MAP[name](**kwargs)
+
+
+def register_pipeline(name: str):
+    """Register a new pipeline
+
+    Parameters
+    ----------
+    name : str
+        Name of the pipeline
+    """
+    if name in PIPELINE_MAP:
+        raise ValueError(f"Pipeline {name} has already been registered")
+
+    def _register(func):
+        PIPELINE_MAP[name] = func
+        return func
+
+    return _register