You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by yo...@apache.org on 2024/03/07 01:13:46 UTC
(tvm) branch main updated: [Relax] Remove the legalization of cumsum/cumprod (#16676)
This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6ca2341460 [Relax] Remove the legalization of cumsum/cumprod (#16676)
6ca2341460 is described below
commit 6ca234146024f370e7713a2835dde8fe8f459da2
Author: Yong Wu <yo...@gmail.com>
AuthorDate: Wed Mar 6 17:13:39 2024 -0800
[Relax] Remove the legalization of cumsum/cumprod (#16676)
* [Relax] Remove the legalization of cumsum/cumprod
* remove related tests
---
.../relax/transform/legalize_ops/statistical.py | 14 -----
tests/python/relax/test_frontend_nn_op.py | 1 -
...st_transform_legalize_ops_search_statistical.py | 69 ----------------------
3 files changed, 84 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py
index 1181b3b2a7..bdb79126f0 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -85,17 +85,3 @@ register_legalize("relax.max", _statistical(topi.max))
register_legalize("relax.min", _statistical(topi.min))
register_legalize("relax.prod", _statistical(topi.prod))
register_legalize("relax.sum", _statistical(topi.sum))
-
-
-@register_legalize("relax.cumsum")
-def _cumsum(bb: BlockBuilder, call: Call) -> Expr:
- return bb.call_te(
- topi.cumsum, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive
- )
-
-
-@register_legalize("relax.cumprod")
-def _cumprod(bb: BlockBuilder, call: Call) -> Expr:
- return bb.call_te(
- topi.cumprod, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive
- )
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 0d579163cd..eb1df67a8f 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -1161,7 +1161,6 @@ def test_renormalize_top_p_top_k_prob():
target = tvm.target.Target("cuda -libs=thrust", host="llvm")
with target:
- mod = relax.backend.DispatchSortScan()(mod)
mod = relax.transform.LegalizeOps()(mod)
mod = tir.transform.DefaultGPUSchedule()(mod)
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index c6c53ff0b9..2a28151dbe 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -1066,74 +1066,5 @@ def test_variance_no_keepdims():
tvm.ir.assert_structural_equal(mod, Expected)
-def test_cumsum():
- # fmt: off
- @I.ir_module
- class Cumsum:
- @R.function
- def main(x: R.Tensor((3, 2, 3), "float32")):
- gv = R.cumsum(x, axis=1, dtype="int32")
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func(private=True)
- def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "int32")):
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(3)), offset_factor=1)
- with T.block("cumsum_generic"):
- for fused in T.parallel(T.int64(9)):
- out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)] = T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % [...]
- for _k in range(T.int64(1)):
- out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)] = out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k [...]
-
- @R.function
- def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((3, 2, 3), dtype="int32"):
- cls = Expected
- gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((3, 2, 3), dtype="int32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Cumsum)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
-def test_cumsum_symbolic():
- # fmt: off
- @I.ir_module
- class Cumsum:
- @R.function
- def main(x: R.Tensor(("a", "b", "c"), "float32")):
- gv = R.cumsum(x, axis=1, dtype="int32")
- return gv
-
- @I.ir_module
- class Expected:
- @T.prim_func(private=True)
- def cumsum(var_rxplaceholder: T.handle, var_cumsum_generic: T.handle):
- T.func_attr({"tir.noalias": True})
- a, b, c = T.int64(), T.int64(), T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c), offset_factor=1)
- out_buf = T.match_buffer(var_cumsum_generic, (a, b, c), "int32")
- with T.block("cumsum_generic"):
- for fused in T.parallel(a * c):
- out_buf[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c] = T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c])
- for _k in range(b - T.int64(1)):
- out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c] = out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) % c] + T.Cast("int32", [...]
-
- @R.function
- def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="int32"):
- a = T.int64()
- b = T.int64()
- c = T.int64()
- cls = Expected
- gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((a, b, c), dtype="int32"))
- return gv
- # fmt: on
-
- mod = LegalizeOps()(Cumsum)
- tvm.ir.assert_structural_equal(mod, Expected)
-
-
if __name__ == "__main__":
tvm.testing.main()