You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/17 21:11:53 UTC

[tvm] branch unity updated: [Unity] `enable_warning` option for LegalizeOps and MSApplyDatabase (#14634)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ff2f3c861a [Unity] `enable_warning` option for LegalizeOps and MSApplyDatabase (#14634)
ff2f3c861a is described below

commit ff2f3c861a90016c0770f808e4f69ece41d2e1ad
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Apr 17 17:11:45 2023 -0400

    [Unity] `enable_warning` option for LegalizeOps and MSApplyDatabase (#14634)
    
    This PR introduces the `enable_warning` argument for the LegalizeOps
    pass and the MetaScheduleApplyDatabase pass. With this change, both
    passes emit warnings only when `enable_warning` is true (it defaults
    to false).
    
    We introduce this because, in our recent development of [Web LLM](https://github.com/mlc-ai/web-llm),
    we use three passes to generate GPU code for the TIR functions in a
    given IRModule. In order, they are MetaScheduleApplyDatabase,
    DispatchTIROperator (a pass that substitutes a TIR function with a
    handwritten or hand-scheduled one), and DefaultGPUSchedule.
    
    In this case (and in most cases), the DefaultGPUSchedule pass always
    serves as a safeguard, and it is common for a TIR function to be absent
    from the MetaSchedule database. For this reason, always printing the
    warning is excessive and annoying for both users and developers.
    
    The same reasoning applies to LegalizeOps. In WebLLM, we bring in some
    high-level operators that are not intended to be legalized (and so
    have no legalization function). Always printing warnings for them is
    undesirable as well.
    
    Based on these experiences, we introduce a new argument that controls
    whether these warnings are emitted.
---
 include/tvm/relax/transform.h                 |  4 +-
 python/tvm/meta_schedule/relax_integration.py |  6 ++-
 python/tvm/relax/pipeline.py                  | 71 +++++++++++++++++----------
 python/tvm/relax/transform/transform.py       | 22 +++++++--
 src/relax/transform/legalize_ops.cc           | 23 ++++++---
 src/relax/transform/meta_schedule.cc          |  4 +-
 6 files changed, 89 insertions(+), 41 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index fec2ef0a04..9a9d1eb54e 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -195,9 +195,11 @@ TVM_DLL Pass FoldConstant();
  *
  * \param cmap The customized operator legalization function map. The customized function
  * will override the default one.
+ * \param enable_warning A boolean value indicating whether to print warnings for CallNodes
+ * whose op has no registered legalization function. Defaults to false.
  * \return The Pass.
  */
-TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap);
+TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning = false);
 
 /*!
  * \brief Lift transformation of the parameters of a function.
diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py
index c776d64763..55de6cf6d5 100644
--- a/python/tvm/meta_schedule/relax_integration.py
+++ b/python/tvm/meta_schedule/relax_integration.py
@@ -358,6 +358,7 @@ def compile_relax(
     mod: IRModule,
     target: Union[Target, str],
     params: Optional[Dict[str, NDArray]],
+    enable_warning: bool = False,
 ) -> "relax.Executable":
     """Compile a relax program with a MetaSchedule database.
 
@@ -371,6 +372,9 @@ def compile_relax(
         The compilation target
     params : Optional[Dict[str, tvm.runtime.NDArray]]
         The associated parameters of the program
+    enable_warning : bool
+        A boolean value indicating if to print warnings for TIR functions not
+        showing up in the database. By default we don't print warning.
 
     Returns
     -------
@@ -388,6 +392,6 @@ def compile_relax(
         mod = BindParams("main", params)(mod)
 
     with target, database, PassContext(opt_level=3):
-        relax_mod = MetaScheduleApplyDatabase()(mod)
+        relax_mod = MetaScheduleApplyDatabase(enable_warning=enable_warning)(mod)
         relax_ex = relax_build(relax_mod, target=target)
     return relax_ex
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index a5da15b76d..74ba7a5520 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -26,49 +26,66 @@ from tvm import meta_schedule as ms
 from . import transform
 
 
-@tvm.transform.module_pass(opt_level=0)
-def zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
-    """Pipeline that applies pre-tuned logs.
+def zero_pipeline(*, enable_warning: bool = False):
+    """Return the "zero" pipeline as a module pass configured with the given options.
 
     Parameters
     ----------
-    mod : tvm.ir.IRModule
-        Input IRModule.
+    enable_warning : bool
+        A boolean value indicating whether to print warnings
+        * in the LegalizeOps pass, for CallNodes whose op has no registered
+        legalization function, and
+        * in the MetaScheduleApplyDatabase pass, for TIR functions not found in
+        the database. By default, no warnings are printed.
+    """
 
-    ctx : tvm.transform.PassContext
-        The pass context
+    @tvm.transform.module_pass(opt_level=0)
+    def f_zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
+        """Pipeline that applies pre-tuned logs.
 
-    Returns
-    -------
-    mod: tvm.ir.IRModule
-        The result transformed module.
-    """
-    seq = tvm.transform.Sequential(
-        [
-            transform.LegalizeOps(),
-            transform.AnnotateTIROpPattern(),
-            transform.FoldConstant(),
-            transform.FuseOps(),
-            transform.FuseTIR(),
-        ]
-    )
-    mod = seq(mod)
-    if ms.Database.current():
-        mod = transform.MetaScheduleApplyDatabase()(mod)
-    return mod
+        Parameters
+        ----------
+        mod : tvm.ir.IRModule
+            Input IRModule.
+
+        ctx : tvm.transform.PassContext
+            The pass context
+
+        Returns
+        -------
+        mod: tvm.ir.IRModule
+            The result transformed module.
+        """
+        seq = tvm.transform.Sequential(
+            [
+                transform.LegalizeOps(enable_warning=enable_warning),
+                transform.AnnotateTIROpPattern(),
+                transform.FoldConstant(),
+                transform.FuseOps(),
+                transform.FuseTIR(),
+            ]
+        )
+        mod = seq(mod)
+        if ms.Database.current():
+            mod = transform.MetaScheduleApplyDatabase(enable_warning=enable_warning)(mod)
+        return mod
+
+    return f_zero_pipeline
 
 
 # global map of pre-built pipelines
 PIPELINE_MAP = {"zero": zero_pipeline}
 
 
-def get_pipeline(name: str = "zero") -> tvm.transform.Pass:
+def get_pipeline(name: str = "zero", **kwargs) -> tvm.transform.Pass:
     """Get pre-build pipeline by name
 
     Parameters
     ----------
     name : Optional[str]
         Name of the pipeline
+    kwargs : Dict[str, object]
+        Keyword args for configuring the pipeline.
 
     Returns
     -------
@@ -77,7 +94,7 @@ def get_pipeline(name: str = "zero") -> tvm.transform.Pass:
     """
 
     if name in PIPELINE_MAP:
-        return PIPELINE_MAP[name]
+        return PIPELINE_MAP[name](**kwargs)
     else:
         raise ValueError(
             f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}"
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index f0277151bb..b17f2fe62b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -631,7 +631,9 @@ def LiftTransformParams() -> tvm.ir.transform.Pass:
     return _ffi_api.LiftTransformParams()  # type: ignore
 
 
-def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None):
+def LegalizeOps(
+    customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, enable_warning: bool = False
+):
     """Legalize high-level operator calls in Relax functions to call_tir
     with corresponding low-level TIR PrimFuncs.
 
@@ -656,6 +658,11 @@ def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None
         The customized operator legalization function map. The customized function will override
         the default one.
 
+    enable_warning : bool
+        A boolean value indicating if to print warnings for CallNode whose op's
+        legalization function is not registered. By default we don't print
+        warnings.
+
     Returns
     -------
     ret : tvm.transform.Pass
@@ -730,22 +737,29 @@ def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None
                         T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1]
     """
 
-    return _ffi_api.LegalizeOps(customize_legalize_map)  # type: ignore
+    return _ffi_api.LegalizeOps(customize_legalize_map, enable_warning)  # type: ignore
 
 
 def MetaScheduleApplyDatabase(
-    work_dir: Optional[str] = None,
+    work_dir: Optional[str] = None, enable_warning: bool = False
 ) -> tvm.ir.transform.Pass:
     """Apply the best schedule from tuning database.
+
+    Parameters
+    ----------
     work_dir : Optional[str]
        work directory to deduce default database if database is not provided
        (it will be ignored when an user passes database)
+    enable_warning : bool
+        A boolean value indicating if to print warnings for TIR functions not
+        showing up in the database. By default we don't print warning.
+
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass
     """
-    return _ffi_api.MetaScheduleApplyDatabase(work_dir)  # type: ignore
+    return _ffi_api.MetaScheduleApplyDatabase(work_dir, enable_warning)  # type: ignore
 
 
 def MetaScheduleTuneTIR(
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index 7c5393c6ca..0953a8dacf 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -55,8 +55,12 @@ bool KnowAllShapeValues(const StructInfo& sinfo) {
 
 class LegalizeMutator : public ExprMutator {
  public:
-  explicit LegalizeMutator(const IRModule& mod, const Optional<Map<String, PackedFunc>>& cmap)
-      : ExprMutator(mod), mod_(std::move(mod)), cmap_(std::move(cmap)) {}
+  explicit LegalizeMutator(const IRModule& mod, const Optional<Map<String, PackedFunc>>& cmap,
+                           bool enable_warning)
+      : ExprMutator(mod),
+        mod_(std::move(mod)),
+        cmap_(std::move(cmap)),
+        enable_warning_(enable_warning) {}
 
   IRModule Transform() {
     for (const auto& [gv, func] : mod_->functions) {
@@ -107,7 +111,7 @@ class LegalizeMutator : public ExprMutator {
     }
 
     // No legalization.
-    if (op != call_tir_op && op != call_dps_packed_op) {
+    if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op) {
       LOG(WARNING) << "No legalization func for " << op->name << " is found.";
     }
     return visited_call;
@@ -117,13 +121,20 @@ class LegalizeMutator : public ExprMutator {
   IRModule mod_;
   /*! \brief The customized legalization function map. */
   Optional<Map<String, PackedFunc>> cmap_;
+  /*!
+   * \brief A boolean value indicating if to print warnings for CallNode whose op's
+   * legalization function is not registered.
+   */
+  bool enable_warning_;
 };
 
 namespace transform {
 
-Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap) {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-      [=](IRModule mod, PassContext pc) { return LegalizeMutator(mod, cmap).Transform(); };
+Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext pc) {
+    return LegalizeMutator(mod, cmap, enable_warning).Transform();
+  };
   return CreateModulePass(/*pass_function=*/pass_func,
                           /*opt_level=*/0,
                           /*pass_name=*/"LegalizeOps",
diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc
index c33a90ccb1..bb9c3579e7 100644
--- a/src/relax/transform/meta_schedule.cc
+++ b/src/relax/transform/meta_schedule.cc
@@ -86,7 +86,7 @@ class MetaScheduleTuner {
   const runtime::PackedFunc* normalize_mod_func_;
 };
 
-Pass MetaScheduleApplyDatabase(Optional<String> work_dir) {
+Pass MetaScheduleApplyDatabase(Optional<String> work_dir, bool enable_warning = false) {
   using tvm::meta_schedule::Database;
   Target target = Target::Current(false);
   const runtime::PackedFunc* normalize_mod_func_ =
@@ -140,7 +140,7 @@ Pass MetaScheduleApplyDatabase(Optional<String> work_dir) {
           new_prim_func = WithAttr(std::move(new_prim_func), tir::attr::kIsScheduled, Bool(true));
           result.Set(gv, new_prim_func);
           continue;
-        } else {
+        } else if (enable_warning) {
           LOG(WARNING) << "Tuning record is not found for primfunc: " << gv->name_hint;
         }
       }