You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by yo...@apache.org on 2024/03/07 01:13:46 UTC

(tvm) branch main updated: [Relax] Remove the legalization of cumsum/cumprod (#16676)

This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ca2341460 [Relax] Remove the legalization of cumsum/cumprod (#16676)
6ca2341460 is described below

commit 6ca234146024f370e7713a2835dde8fe8f459da2
Author: Yong Wu <yo...@gmail.com>
AuthorDate: Wed Mar 6 17:13:39 2024 -0800

    [Relax] Remove the legalization of cumsum/cumprod (#16676)
    
    * [Relax] Remove the legalization of cumsum/cumprod
    
    * remove related tests
---
 .../relax/transform/legalize_ops/statistical.py    | 14 -----
 tests/python/relax/test_frontend_nn_op.py          |  1 -
 ...st_transform_legalize_ops_search_statistical.py | 69 ----------------------
 3 files changed, 84 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py
index 1181b3b2a7..bdb79126f0 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -85,17 +85,3 @@ register_legalize("relax.max", _statistical(topi.max))
 register_legalize("relax.min", _statistical(topi.min))
 register_legalize("relax.prod", _statistical(topi.prod))
 register_legalize("relax.sum", _statistical(topi.sum))
-
-
-@register_legalize("relax.cumsum")
-def _cumsum(bb: BlockBuilder, call: Call) -> Expr:
-    return bb.call_te(
-        topi.cumsum, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive
-    )
-
-
-@register_legalize("relax.cumprod")
-def _cumprod(bb: BlockBuilder, call: Call) -> Expr:
-    return bb.call_te(
-        topi.cumprod, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive
-    )
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 0d579163cd..eb1df67a8f 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -1161,7 +1161,6 @@ def test_renormalize_top_p_top_k_prob():
 
     target = tvm.target.Target("cuda -libs=thrust", host="llvm")
     with target:
-        mod = relax.backend.DispatchSortScan()(mod)
         mod = relax.transform.LegalizeOps()(mod)
         mod = tir.transform.DefaultGPUSchedule()(mod)
 
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index c6c53ff0b9..2a28151dbe 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -1066,74 +1066,5 @@ def test_variance_no_keepdims():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_cumsum():
-    # fmt: off
-    @I.ir_module
-    class Cumsum:
-        @R.function
-        def main(x: R.Tensor((3, 2, 3), "float32")):
-            gv = R.cumsum(x, axis=1, dtype="int32")
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func(private=True)
-        def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "int32")):
-            T.func_attr({"tir.noalias": True})
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(3)), offset_factor=1)
-            with T.block("cumsum_generic"):
-                for fused in T.parallel(T.int64(9)):
-                    out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)] = T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % [...]
-                    for _k in range(T.int64(1)):
-                        out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)] = out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k [...]
-
-        @R.function
-        def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((3, 2, 3), dtype="int32"):
-            cls = Expected
-            gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((3, 2, 3), dtype="int32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Cumsum)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_cumsum_symbolic():
-    # fmt: off
-    @I.ir_module
-    class Cumsum:
-        @R.function
-        def main(x: R.Tensor(("a", "b", "c"), "float32")):
-            gv = R.cumsum(x, axis=1, dtype="int32")
-            return gv
-
-    @I.ir_module
-    class Expected:
-        @T.prim_func(private=True)
-        def cumsum(var_rxplaceholder: T.handle, var_cumsum_generic: T.handle):
-            T.func_attr({"tir.noalias": True})
-            a, b, c = T.int64(), T.int64(), T.int64()
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c), offset_factor=1)
-            out_buf = T.match_buffer(var_cumsum_generic, (a, b, c), "int32")
-            with T.block("cumsum_generic"):
-                for fused in T.parallel(a * c):
-                    out_buf[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c] = T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c])
-                    for _k in range(b - T.int64(1)):
-                        out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c] = out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) % c] + T.Cast("int32", [...]
-
-        @R.function
-        def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="int32"):
-            a = T.int64()
-            b = T.int64()
-            c = T.int64()
-            cls = Expected
-            gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((a, b, c), dtype="int32"))
-            return gv
-    # fmt: on
-
-    mod = LegalizeOps()(Cumsum)
-    tvm.ir.assert_structural_equal(mod, Expected)
-
-
 if __name__ == "__main__":
     tvm.testing.main()