You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/10/26 15:16:05 UTC

[tvm] branch unity updated: [Unity] Include LegalizeOps in the default relax.build lowering flow (#15864)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ebbe38f328 [Unity] Include LegalizeOps in the default relax.build lowering flow (#15864)
ebbe38f328 is described below

commit ebbe38f3281776cfda4fce0b188892ab1c5c7572
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Oct 26 10:15:59 2023 -0500

    [Unity] Include LegalizeOps in the default relax.build lowering flow (#15864)
    
    Prior to this commit, `relax.transform.LegalizeOps` needed to be
    called explicitly before `relax.build`.  This commit adds `LegalizeOps`
    to the default lowering flow of `relax.build`, simplifying the steps
    an end-user must take.  If the `IRModule` contains no legalizable
    operators (for example, because the user already applied
    `LegalizeOps`), the second legalization pass has no effect.
    
    Some test cases relied on the previous behavior as an implicit
    assertion that operator fusion patterns were applied.  That is,
    because `LegalizeOps` was omitted, `relax.build` would only compile
    successfully if all legalizable operators had already been removed,
    so an incorrect fusion pattern would cause the build to fail.  While
    these tests would be better expressed by comparing against an
    expected fused pattern, updating the tests is outside the scope of
    this PR.  To allow these tests to keep their implicit assertions, the
    `"relax.transform.apply_legalize_ops"` PassContext config option
    (default: true) can be set to `False` to disable the `LegalizeOps`
    pass.
---
 python/tvm/relax/vm_build.py                          |  1 +
 src/relax/transform/legalize_ops.cc                   |  9 ++++++++-
 tests/python/relax/test_codegen_cublas.py             | 10 ++++++----
 tests/python/relax/test_codegen_cudnn.py              | 10 ++++++----
 tests/python/relax/test_codegen_cutlass.py            | 10 ++++++----
 tests/python/relax/test_codegen_dnnl.py               |  6 ++----
 tests/python/relax/test_codegen_tensorrt.py           |  6 ++----
 tests/python/relax/test_codegen_tir_cutlass.py        |  1 -
 tests/python/relax/test_dataflow_pattern.py           |  2 --
 tests/python/relax/test_e2e_op_dynamic.py             |  6 ++----
 tests/python/relax/test_frontend_stablehlo.py         |  4 ----
 tests/python/relax/test_op_gradient_numeric.py        |  6 ++----
 tests/python/relax/test_training_optimizer_numeric.py |  4 +---
 tests/python/relax/test_transform_gradient_numeric.py |  4 +---
 tests/python/relax/test_vm_execbuilder.py             |  3 +--
 15 files changed, 38 insertions(+), 44 deletions(-)

diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 8b33379957..a54c0154fc 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -307,6 +307,7 @@ def build(
 
     lowering_passes = tvm.transform.Sequential(
         [
+            relax.transform.LegalizeOps(),
             relax.transform.RewriteDataflowReshape(),
             relax.transform.ToNonDataflow(),
             relax.transform.RemovePurityChecking(),
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index 170967d282..a557a41f8e 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -31,6 +31,8 @@
 namespace tvm {
 namespace relax {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.transform.apply_legalize_ops", Bool);
+
 /*!
  * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose
  * values are all known.
@@ -206,7 +208,12 @@ namespace transform {
 Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
                                                                             PassContext pc) {
-    return LegalizeMutator(mod, cmap, enable_warning).Transform();
+    bool apply_legalize_ops =
+        pc->GetConfig<Bool>("relax.transform.apply_legalize_ops").value_or(Bool(true))->value;
+    if (apply_legalize_ops) {
+      mod = LegalizeMutator(mod, cmap, enable_warning).Transform();
+    }
+    return mod;
   };
   return CreateModulePass(/*pass_function=*/pass_func,
                           /*opt_level=*/0,
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index fc2256531e..6c8f6bc335 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -42,11 +42,13 @@ pytestmark = [cublas_enabled]
 
 
 def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py
index 5ba638c11c..c913559232 100644
--- a/tests/python/relax/test_codegen_cudnn.py
+++ b/tests/python/relax/test_codegen_cudnn.py
@@ -110,11 +110,13 @@ def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False):
 
 
 def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index f07a0dfcbb..bd647486d6 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -86,10 +86,12 @@ pytestmark = [cutlass_enabled]
 
 
 def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)  # For cpu reference, nop for cutlass.
-
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
 
     dev = tvm.device(target, 0)
diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py
index 66f442f165..fe4590f85a 100644
--- a/tests/python/relax/test_codegen_dnnl.py
+++ b/tests/python/relax/test_codegen_dnnl.py
@@ -52,14 +52,12 @@ pytestmark = [dnnl_enabled]
 
 
 def build_and_run(mod, inputs, legalize=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     target = tvm.target.Target("llvm")
     dev = tvm.cpu()
     inputs = [tvm.nd.array(inp, dev) for inp in inputs]
 
-    ex = relax.build(mod, target)
+    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
+        ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
     return f(*inputs).numpy()
diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py
index 595103bc5f..23dc7d887f 100644
--- a/tests/python/relax/test_codegen_tensorrt.py
+++ b/tests/python/relax/test_codegen_tensorrt.py
@@ -53,11 +53,9 @@ pytestmark = [tensorrt_enabled]
 
 
 def build_and_run(mod, inputs_np, target, legalize=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    ex = relax.build(mod, target)
+    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
+        ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
     inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
diff --git a/tests/python/relax/test_codegen_tir_cutlass.py b/tests/python/relax/test_codegen_tir_cutlass.py
index 9c960ed355..a14ca7ac36 100644
--- a/tests/python/relax/test_codegen_tir_cutlass.py
+++ b/tests/python/relax/test_codegen_tir_cutlass.py
@@ -65,7 +65,6 @@ def build(mod):
 
 
 def build_and_run_reference(mod, inputs_np):
-    mod = relax.transform.LegalizeOps()(mod)
     dev = tvm.device("llvm", 0)
     ex = relax.build(mod, "llvm")
     vm = relax.VirtualMachine(ex, dev)
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index a8b71aa5eb..520fb87322 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1181,7 +1181,6 @@ def test_combine_matmul_emit_order():
         # make sure it builds
         mod = tvm.IRModule()
         mod["main"] = rewritten
-        mod = rx.transform.LegalizeOps()(mod)
 
         rx.build(mod, target="llvm")
 
@@ -1279,7 +1278,6 @@ def test_combine_transposed_matmul_twice():
         # make sure it builds
         mod = tvm.IRModule()
         mod["main"] = rewritten
-        mod = rx.transform.LegalizeOps()(mod)
 
         rx.build(mod, target="llvm")
 
diff --git a/tests/python/relax/test_e2e_op_dynamic.py b/tests/python/relax/test_e2e_op_dynamic.py
index 63c71b7979..641469172f 100644
--- a/tests/python/relax/test_e2e_op_dynamic.py
+++ b/tests/python/relax/test_e2e_op_dynamic.py
@@ -49,8 +49,7 @@ def test_dynamic_strided_slice(begin, end, strides):
             gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
             return gv
     # fmt: on
-    mod = LegalizeOps()(DynamicStridedSlice)
-    vm = build(mod)
+    vm = build(DynamicStridedSlice)
 
     x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
     data_nd = tvm.nd.array(x_np, dev)
@@ -83,8 +82,7 @@ def test_dynamic_strided_slice_symbolic(begin, end, strides):
             gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
             return gv
     # fmt: on
-    mod = LegalizeOps()(DynamicStridedSlice)
-    vm = build(mod)
+    vm = build(DynamicStridedSlice)
 
     x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
     data_nd = tvm.nd.array(x_np, dev)
diff --git a/tests/python/relax/test_frontend_stablehlo.py b/tests/python/relax/test_frontend_stablehlo.py
index 4152d50b8d..d3068f29c7 100644
--- a/tests/python/relax/test_frontend_stablehlo.py
+++ b/tests/python/relax/test_frontend_stablehlo.py
@@ -115,9 +115,6 @@ def check_correctness(
     # Run the jax jitted model with the input jax numpy data
     jax_output = jax_jit_mod(*inputs_jnp)
 
-    # Legalize the Relax Operators into TensorIR
-    # TODO (relax-team): add LegalizeOps in default seq in vm_build
-    ir_mod = relax.transform.LegalizeOps()(ir_mod)
     # TODO (yongwww): support multiple targets,
     # "llvm" should be good for this check
     target = tvm.target.Target("llvm", host="llvm")
@@ -157,7 +154,6 @@ def get_vm_res(
     out: Union[tvm.nd.NDArray, List[tvm.nd.NDArray]]
         inference result
     """
-    ir_mod = relax.transform.LegalizeOps()(ir_mod)
     target = tvm.target.Target("llvm", host="llvm")
     # Compile and run
     ex = relax.build(ir_mod, target)
diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py
index 4b4c5cabc4..bc5cb0f5be 100644
--- a/tests/python/relax/test_op_gradient_numeric.py
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -122,8 +122,7 @@ def relax_check_gradients(
             out = forward_bb.emit_output(call)
         forward_bb.emit_func_output(out)
     forward_mod = forward_bb.get()
-    forward_lower_mod = LegalizeOps()(forward_mod)
-    forward_ex = relax.build(forward_lower_mod, target)
+    forward_ex = relax.build(forward_mod, target)
     forward_vm = relax.VirtualMachine(forward_ex, dev)
 
     # Generate weights
@@ -187,8 +186,7 @@ def relax_check_gradients(
         grad_bb.emit_func_output(out)
 
     grad_mod = grad_bb.get()
-    grad_lower_mod = LegalizeOps()(grad_mod)
-    grad_ex = relax.build(grad_lower_mod, target)
+    grad_ex = relax.build(grad_mod, target)
     grad_vm = relax.VirtualMachine(grad_ex, dev)
 
     # tvm.runtime.NDArray inputs
diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py
index 23db8987f1..8acf7ad66b 100644
--- a/tests/python/relax/test_training_optimizer_numeric.py
+++ b/tests/python/relax/test_training_optimizer_numeric.py
@@ -23,15 +23,13 @@ import tvm.testing
 from tvm import relax
 from tvm import IRModule
 from tvm.relax.training.optimizer import Adam, SGD, MomentumSGD
-from tvm.relax.transform import LegalizeOps
 from tvm.script.parser import relax as R
 from tvm.runtime.relax_vm import VirtualMachine
 from tvm.testing import assert_allclose
 
 
 def _legalize_and_build(mod: IRModule, target, dev):
-    lowered_mod = LegalizeOps()(mod)
-    ex = relax.build(lowered_mod, target)
+    ex = relax.build(mod, target)
     vm = VirtualMachine(ex, dev)
     return vm
 
diff --git a/tests/python/relax/test_transform_gradient_numeric.py b/tests/python/relax/test_transform_gradient_numeric.py
index 7585ecf1f6..38a63406e8 100644
--- a/tests/python/relax/test_transform_gradient_numeric.py
+++ b/tests/python/relax/test_transform_gradient_numeric.py
@@ -22,12 +22,10 @@ from tvm.relay.testing import rand
 from tvm.testing import assert_allclose
 from tvm.testing.utils import check_numerical_grads
 from tvm.script.parser import ir as I, relax as R
-from tvm.relax.transform import LegalizeOps
 
 
 def _legalize_and_build(mod, target, dev):
-    lowered_mod = LegalizeOps()(mod)
-    ex = relax.build(lowered_mod, target)
+    ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     return vm
 
diff --git a/tests/python/relax/test_vm_execbuilder.py b/tests/python/relax/test_vm_execbuilder.py
index 4c15d8013b..b2d9edd346 100644
--- a/tests/python/relax/test_vm_execbuilder.py
+++ b/tests/python/relax/test_vm_execbuilder.py
@@ -277,8 +277,7 @@ def test_vm_stack_restore_after_failure():
                 R.output(gv)
             return gv
 
-    mod = relax.transform.LegalizeOps()(Module)
-    ex = relax.build(mod, "llvm")
+    ex = relax.build(Module, "llvm")
     vm = relax.VirtualMachine(ex, tvm.cpu())
 
     correct_input = tvm.nd.array(np.random.normal(size=(10, 10)).astype("float32"))