You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/11/16 04:04:06 UTC

[tvm] branch main updated: [TIR] Fix extern_primfunc buffer order bug (#13347)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6401d0ef62 [TIR] Fix extern_primfunc buffer order bug  (#13347)
6401d0ef62 is described below

commit 6401d0ef62c8f22f187dfb9fde3dfe07f55df206
Author: Gavin Uberti <gu...@users.noreply.github.com>
AuthorDate: Wed Nov 16 13:03:58 2022 +0900

    [TIR] Fix extern_primfunc buffer order bug (#13347)
    
    This fixes #13330, which was blocking my work to write TIR schedules for microTVM.
    
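    I originally thought I'd have to change the function signature of `DomainTouchedAccessMap`, but I couldn't think of a way to do that cleanly. The root cause is that `DomainTouchedAccessMap` (like `primfunc.buffer_map`) returns an unordered map. Building the input and output buffer lists by iterating over it therefore yields an arbitrary order, which need not match the PrimFunc's parameter order. Instead of changing that API, I changed `extern_primfunc` to walk `primfunc.params` and look up each parameter's buffer, so the buffer lists are created in the right order.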
    
    #13330 should have been caught by `test_tir_te_extern_primfunc.py`, but one of that test's helper functions, `create_input_tensors_for_primfunc`, had the same bug as `extern_primfunc`. I've therefore removed that helper. I've also modified `test_tir_te_extern_primfunc.py` to construct the input placeholders explicitly, in the parameter order listed for each test case, so that the test catches regressions of this issue.
---
 python/tvm/te/operation.py                         | 11 ++---
 .../python/unittest/test_tir_te_extern_primfunc.py | 50 ++++------------------
 2 files changed, 15 insertions(+), 46 deletions(-)

diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 5279c46aeb..846f88d389 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -398,11 +398,12 @@ def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimF
 
         C = te.extern_primfunc([A, B], func)
     """
-    access_map = {
-        k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items()
-    }
-    in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
-    out_buffers = [buf for buf, access in access_map.items() if len(access[1])]
+
+    # dt_access_map and primfunc.buffer_map are unordered, so use order from primfunc.params
+    dt_access_map = tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc)
+    ordered_buffers = [primfunc.buffer_map[param] for param in primfunc.params]
+    in_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][0])]
+    out_buffers = [buf for buf in ordered_buffers if len(dt_access_map[buf][1])]
     assert in_buffers, "PrimFunc has no input buffers"
     assert out_buffers, "PrimFunc has no output buffers"
 
diff --git a/tests/python/unittest/test_tir_te_extern_primfunc.py b/tests/python/unittest/test_tir_te_extern_primfunc.py
index 2675214562..a622f77cc7 100644
--- a/tests/python/unittest/test_tir_te_extern_primfunc.py
+++ b/tests/python/unittest/test_tir_te_extern_primfunc.py
@@ -21,10 +21,8 @@ import numpy as np
 
 import tvm
 import tvm.testing
-from tvm import tir, te, TVMError
+from tvm import te
 from tvm.script import tir as T
-from tvm.arith import _ffi_api as _ffi_arith_api
-from tvm.tir.schedule import _ffi_api as _ffi_schedule_api
 
 
 # TODO(csullivan): Additional tests cases needed:
@@ -174,11 +172,11 @@ def verify_func_4(module):
 
 
 class TestPrimFuncs:
-    func, verify = tvm.testing.parameters(
-        [func_1, verify_func_1],
-        [func_2, verify_func_2],
-        [func_3, verify_func_3],
-        [func_4, verify_func_4],
+    func, params, verify = tvm.testing.parameters(
+        [func_1, ("A"), verify_func_1],
+        [func_2, ("C", "D"), verify_func_2],
+        [func_3, ("C", "A", "D", "E"), verify_func_3],
+        [func_4, ("C", "A", "D", "E"), verify_func_4],
     )
 
     def test_primfunc_call(self, func, verify):
@@ -186,11 +184,12 @@ class TestPrimFuncs:
         func = tvm.build(func, target=target)
         verify(func)
 
-    def test_te_extern_call(self, func, verify):
+    def test_te_extern_call(self, func, params, verify):
         ir_mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
         prim_func = ir_mod["main"]
 
-        input_tensors = create_input_tensors_for_primfunc(prim_func)
+        buf_name_map = {buf.name: buf for buf in func.buffer_map.values()}
+        input_tensors = [te.placeholder(buf_name_map[name].shape) for name in params]
         output = te.extern_primfunc(input_tensors, prim_func)
         rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func))
         tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func))
@@ -222,36 +221,5 @@ def tensors_from_extern_op(extern, func):
     return ordered_tensors
 
 
-def create_input_tensors_for_primfunc(primfunc):
-    access_map = {k: tuple(v) for k, v in _ffi_arith_api.DomainTouchedAccessMap(primfunc).items()}
-    in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
-    out_buffers = [buf for buf, access in access_map.items() if len(access[1])]
-    assert in_buffers, "PrimFunc has no input buffers"
-    assert out_buffers, "PrimFunc has no output buffers"
-
-    outputs = []
-    inplace = []
-    inputs = in_buffers
-    for obuf in out_buffers:
-        if obuf in in_buffers:
-            inplace.append(obuf)
-        else:
-            outputs.append(obuf)
-
-    if not outputs:
-        iobuf = inplace.pop()
-        inputs.remove(iobuf)
-        outputs = [iobuf]
-
-    def create_tensors(input_buffers):
-        tensors = []
-        for buf in input_buffers:
-            t = te.placeholder(buf.shape, dtype=buf.dtype, name=buf.name + "_placeholder")
-            tensors.append(t)
-        return tensors
-
-    return create_tensors(inputs)
-
-
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))