You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/11/16 04:04:06 UTC
[tvm] branch main updated: [TIR] Fix extern_primfunc buffer order bug (#13347)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6401d0ef62 [TIR] Fix extern_primfunc buffer order bug (#13347)
6401d0ef62 is described below
commit 6401d0ef62c8f22f187dfb9fde3dfe07f55df206
Author: Gavin Uberti <gu...@users.noreply.github.com>
AuthorDate: Wed Nov 16 13:03:58 2022 +0900
[TIR] Fix extern_primfunc buffer order bug (#13347)
This fixes #13330, which was blocking my work to write TIR schedules for microTVM.
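I originally thought I'd have to change the function signature of `DomainTouchedAccessMap`, but I couldn't think of a way to do that cleanly. Instead, I changed `extern_primfunc` to build its input and output buffer lists by iterating over `primfunc.params`, so the buffers come out in parameter order. Both the map returned by `DomainTouchedAccessMap` and `primfunc.buffer_map` are unordered, so iterating over either of them directly gave no guaranteed buffer order.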
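#13330 should have been caught by `test_tir_te_extern_primfunc.py`, but that test's helper function `create_input_tensors_for_primfunc` had the same bug as `extern_primfunc`: it also took the buffer order from the unordered access map. I've therefore removed the helper and modified `test_tir_te_extern_primfunc.py` to instantiate the input tensors from an explicit list of parameter names given for each test function, allowing it to catch regressions of this issue.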
---
python/tvm/te/operation.py | 11 ++---
.../python/unittest/test_tir_te_extern_primfunc.py | 50 ++++------------------
2 files changed, 15 insertions(+), 46 deletions(-)
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 5279c46aeb..846f88d389 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -398,11 +398,12 @@ def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimF
C = te.extern_primfunc([A, B], func)
"""
- access_map = {
- k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items()
- }
- in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
- out_buffers = [buf for buf, access in access_map.items() if len(access[1])]
+
+ # dt_access_map and primfunc.buffer_map are unordered, so use order from primfunc.params
+ dt_access_map = tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc)
+ ordered_buffers = [primfunc.buffer_map[param] for param in primfunc.params]
+ in_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][0])]
+ out_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][1])]
assert in_buffers, "PrimFunc has no input buffers"
assert out_buffers, "PrimFunc has no output buffers"
diff --git a/tests/python/unittest/test_tir_te_extern_primfunc.py b/tests/python/unittest/test_tir_te_extern_primfunc.py
index 2675214562..a622f77cc7 100644
--- a/tests/python/unittest/test_tir_te_extern_primfunc.py
+++ b/tests/python/unittest/test_tir_te_extern_primfunc.py
@@ -21,10 +21,8 @@ import numpy as np
import tvm
import tvm.testing
-from tvm import tir, te, TVMError
+from tvm import te
from tvm.script import tir as T
-from tvm.arith import _ffi_api as _ffi_arith_api
-from tvm.tir.schedule import _ffi_api as _ffi_schedule_api
# TODO(csullivan): Additional tests cases needed:
@@ -174,11 +172,11 @@ def verify_func_4(module):
class TestPrimFuncs:
- func, verify = tvm.testing.parameters(
- [func_1, verify_func_1],
- [func_2, verify_func_2],
- [func_3, verify_func_3],
- [func_4, verify_func_4],
+ func, params, verify = tvm.testing.parameters(
+ [func_1, ("A"), verify_func_1],
+ [func_2, ("C", "D"), verify_func_2],
+ [func_3, ("C", "A", "D", "E"), verify_func_3],
+ [func_4, ("C", "A", "D", "E"), verify_func_4],
)
def test_primfunc_call(self, func, verify):
@@ -186,11 +184,12 @@ class TestPrimFuncs:
func = tvm.build(func, target=target)
verify(func)
- def test_te_extern_call(self, func, verify):
+ def test_te_extern_call(self, func, params, verify):
ir_mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
prim_func = ir_mod["main"]
- input_tensors = create_input_tensors_for_primfunc(prim_func)
+ buf_name_map = {buf.name: buf for buf in func.buffer_map.values()}
+ input_tensors = [te.placeholder(buf_name_map[name].shape) for name in params]
output = te.extern_primfunc(input_tensors, prim_func)
rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func))
tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func))
@@ -222,36 +221,5 @@ def tensors_from_extern_op(extern, func):
return ordered_tensors
-def create_input_tensors_for_primfunc(primfunc):
- access_map = {k: tuple(v) for k, v in _ffi_arith_api.DomainTouchedAccessMap(primfunc).items()}
- in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
- out_buffers = [buf for buf, access in access_map.items() if len(access[1])]
- assert in_buffers, "PrimFunc has no input buffers"
- assert out_buffers, "PrimFunc has no output buffers"
-
- outputs = []
- inplace = []
- inputs = in_buffers
- for obuf in out_buffers:
- if obuf in in_buffers:
- inplace.append(obuf)
- else:
- outputs.append(obuf)
-
- if not outputs:
- iobuf = inplace.pop()
- inputs.remove(iobuf)
- outputs = [iobuf]
-
- def create_tensors(input_buffers):
- tensors = []
- for buf in input_buffers:
- t = te.placeholder(buf.shape, dtype=buf.dtype, name=buf.name + "_placeholder")
- tensors.append(t)
- return tensors
-
- return create_tensors(inputs)
-
-
if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))