You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/04/28 13:00:39 UTC

[tvm] branch main updated: [TIR] Get read/write access precisely for opaque access. (#11110)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7710dfd557 [TIR] Get read/write access precisely for opaque access. (#11110)
7710dfd557 is described below

commit 7710dfd557752e2a980b90d1b19a66d8dcefb929
Author: albert qing <26...@qq.com>
AuthorDate: Thu Apr 28 21:00:29 2022 +0800

    [TIR] Get read/write access precisely for opaque access. (#11110)
    
    * [TIR] Get read/write access precisely for opaque access.
    
    When an opaque access is wrapped with tvm_access_ptr, BlockReadWriteDetector can read
    the access_mask from the tvm_access_ptr call and put the access into read_regions
    or write_regions according to that mask; an access with both the read and write bits
    set is still recorded as an opaque access.
    
    * [TIR] Add an optional extent parameter to access_ptr.
    
    Co-authored-by: sqing <qi...@intellif.com>
---
 include/tvm/tir/buffer.h                           |  5 +-
 python/tvm/tir/buffer.py                           | 10 ++-
 src/tir/analysis/block_access_region_detector.cc   | 28 +++++++
 src/tir/ir/buffer.cc                               |  8 +-
 .../test_tir_analysis_get_block_access_region.py   | 29 +++++++
 tests/python/unittest/test_tir_buffer.py           |  6 ++
 .../unittest/test_tir_schedule_compute_inline.py   | 97 ++++++++++++++--------
 7 files changed, 142 insertions(+), 41 deletions(-)

diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index aef82ae368..ca7faf1cde 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -186,10 +186,11 @@ class Buffer : public ObjectRef {
    * \param ptr_type The type of the pointer.
    * \param content_lanes The number of lanes for the (data) type.
    * \param offset The offset of ptr.
+   * \param input_extent The extent of ptr.
    */
   TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
-                              int content_lanes = 1,
-                              PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
+                              int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
+                              Optional<PrimExpr> input_extent = NullOpt) const;
   /*!
    * \brief Create an Expr that does a vector load at begin index.
    * \param begin The beginning index
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index e36a99339e..d9b0aec76a 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -42,7 +42,7 @@ class Buffer(Object):
     READ = 1
     WRITE = 2
 
-    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
+    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, extent=None):
         """Get an access pointer to the head of buffer.
 
         This is the recommended method to get buffer data
@@ -66,6 +66,9 @@ class Buffer(Object):
             The offset of pointer. We can use it to offset by
             the number of elements from the address of ptr.
 
+        extent: Expr, optional
+            The extent of pointer.
+
         Examples
         --------
         .. code-block:: python
@@ -78,6 +81,8 @@ class Buffer(Object):
           buffer.access_ptr("rw")
           # Get access ptr for read with offset
           buffer.access_ptr("r", offset = 100)
+          # Get access ptr for read with extent
+          buffer.access_ptr("r", extent = 100)
         """
         if isinstance(access_mask, string_types):
             mask = 0
@@ -90,8 +95,9 @@ class Buffer(Object):
                     raise ValueError("Unknown access_mask %s" % access_mask)
             access_mask = mask
         offset = convert(offset)
+        extent = convert(extent)
         return _ffi_api.BufferAccessPtr(
-            self, access_mask, ptr_type, content_lanes, offset  # type: ignore
+            self, access_mask, ptr_type, content_lanes, offset, extent  # type: ignore
         )
 
     def vload(self, begin, dtype=None):
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index ffe0c75294..c65a422ed3 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -181,6 +181,34 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
 }
 
 void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::tvm_access_ptr())) {
+    const VarNode* buffer_var = op->args[1].as<VarNode>();
+    const IntImmNode* access_mask = op->args[4].as<IntImmNode>();
+    if (buffer_var && access_mask) {
+      auto it = buffer_var_map_.find(GetRef<Var>(buffer_var));
+      if (it != buffer_var_map_.end()) {
+        const Buffer& buffer = (*it).second;
+        const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
+        const Region& region = buffer_region->region;
+        std::vector<arith::IntSet> int_set;
+        int_set.reserve(region.size());
+        for (const Range& range : region) {
+          int_set.push_back(arith::EvalSet(range, dom_map_));
+        }
+        // read access, write access or opaque access
+        if ((access_mask->value & 1) && (access_mask->value & 2)) {
+          Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
+        } else if (access_mask->value & 1) {
+          Update(&read_buffers_, &read_regions_, buffer, int_set);
+        } else if (access_mask->value & 2) {
+          Update(&writes_buffers_, &write_regions_, buffer, int_set);
+        }
+      }
+    } else {
+      StmtExprVisitor::VisitExpr_(op);
+    }
+    return;
+  }
   if (op->op.same_as(builtin::if_then_else())) {
     VisitExpr(op->args[0]);
     {
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 9cc92bd17e..ccf186634b 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -495,8 +495,8 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
   return slice;
 }
 
-PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
-                            PrimExpr offset) const {
+PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset,
+                            Optional<PrimExpr> input_extent) const {
   const BufferNode* self = operator->();
   ICHECK(self != nullptr);
   PrimExpr e_dtype;
@@ -519,6 +519,10 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
   } else {
     e_dtype = tir::TypeAnnotation(self->dtype);
   }
+
+  if (input_extent.defined()) {
+    extent = input_extent.value();
+  }
   Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
                            make_const(DataType::Int(32), access_mask)};
   return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index 463f2a7f0e..8a10cbd072 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -105,6 +105,19 @@ def opaque_access_func() -> None:
             )
 
 
+@T.prim_func
+def opaque_access_with_tvm_access_ptr_func() -> None:
+    A = T.alloc_buffer([1024])
+    B = T.alloc_buffer([1024])
+    C = T.alloc_buffer([1024])
+    with T.block("opaque"):
+        T.reads(A[0:1024], C[0:1024])
+        T.writes(B[0:1024], C[0:1024])
+        T.evaluate(A.access_ptr("r"))
+        T.evaluate(B.access_ptr("w"))
+        T.evaluate(C.access_ptr("rw"))
+
+
 @T.prim_func
 def access_in_if_then_else_func() -> None:
     A = T.alloc_buffer([8])
@@ -235,6 +248,21 @@ def test_opaque_access():
         tvm.ir.assert_structural_equal(ret0[1], ret1[1])
 
 
+def test_opaque_access_with_tvm_access_ptr():
+    block = opaque_access_with_tvm_access_ptr_func.body.block.body.block
+    alloc_buffers = opaque_access_with_tvm_access_ptr_func.body.block.alloc_buffers
+    buffer_var_map = {buf.data: buf for buf in alloc_buffers}
+
+    ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map)
+    ret1 = tir.analysis.get_block_access_region(block, buffer_var_map)
+    tvm.ir.assert_structural_equal(block.reads, ret0[0])
+    tvm.ir.assert_structural_equal(block.writes, ret0[1])
+    with pytest.raises(ValueError):
+        tvm.ir.assert_structural_equal(ret0[0], ret1[0])
+    with pytest.raises(ValueError):
+        tvm.ir.assert_structural_equal(ret0[1], ret1[1])
+
+
 def test_match_buffer():
     root_block = match_buffer_func.body.block
     block = root_block.body.body.body.block
@@ -333,6 +361,7 @@ if __name__ == "__main__":
     test_block_access_region_detector()
     test_opaque_block()
     test_opaque_access()
+    test_opaque_access_with_tvm_access_ptr()
     test_match_buffer()
     test_access_in_if_then_else_func()
     test_access_in_branch_func()
diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py
index 990d0a22c8..337f9cbc07 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -76,6 +76,12 @@ def test_buffer_access_ptr_extent():
     aptr = Ab.access_ptr("rw", offset=100)
     assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100)
 
+    # Test extent from input params
+    aptr = Ab.access_ptr("rw", extent=200)
+    assert tvm.ir.structural_equal(aptr.args[3], 200)
+    aptr = Ab.access_ptr("rw", offset=100, extent=100)
+    assert tvm.ir.structural_equal(aptr.args[3], 100)
+
 
 def test_buffer_vload():
     m = te.size_var("m")
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index 1259219a39..8894cd4d9f 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -183,11 +183,7 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None:
             vi, vj = T.axis.remap("SS", [i, j])
             T.reads(B[0:128, 0:128])
             T.writes(C[0:128, 0:128])
-            T.evaluate(
-                T.tvm_access_ptr(
-                    T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle"
-                )
-            )
+            T.evaluate(B.access_ptr("r", extent=128))
             C[vi, vj] = B[vi, vj] + 1.0
 
 
@@ -205,16 +201,8 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None:
             vi, vj = T.axis.remap("SS", [i, j])
             T.reads(B[0:128, 0:128])
             T.writes(C[0:128, 0:128])
-            T.evaluate(
-                T.tvm_access_ptr(
-                    T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle"
-                )
-            )
-            T.evaluate(
-                T.tvm_access_ptr(
-                    T.type_annotation(dtype="float32"), C.data, 0, 128, "w", dtype="handle"
-                )
-            )
+            T.evaluate(B.access_ptr("r", extent=128))
+            T.evaluate(C.access_ptr("w", extent=128))
             C[vi, vj] = B[vi, vj] + 1.0
 
 
@@ -296,16 +284,8 @@ def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
         # annotated opaque partial access
         T.reads(A[0:512])
         T.writes(A_cache[0:512])
-        T.evaluate(
-            T.tvm_access_ptr(
-                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
-            )
-        )
-        T.evaluate(
-            T.tvm_access_ptr(
-                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
-            )
-        )
+        T.evaluate(A.access_ptr("r", extent=512))
+        T.evaluate(A_cache.access_ptr("w", extent=512))
     for i in range(512):
         with T.block("BB"):
             vi = T.axis.remap("S", [i])
@@ -325,16 +305,8 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
         # annotated opaque partial access should be kept
         T.reads(A[0:512])
         T.writes([A_cache[0:512]])
-        T.evaluate(
-            T.tvm_access_ptr(
-                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
-            )
-        )
-        T.evaluate(
-            T.tvm_access_ptr(
-                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
-            )
-        )
+        T.evaluate(A.access_ptr("r", extent=512))
+        T.evaluate(A_cache.access_ptr("w", extent=512))
     for i in T.serial(0, 512):
         with T.block("B"):
             vi = T.axis.spatial(512, i)
@@ -402,6 +374,51 @@ def inline_block_with_init(
                 )
 
 
+@T.prim_func
+def exp_exp_opaque_access_with_tvm_access_ptr(
+    lookup_table: T.Buffer[(1024,), "int8"],
+    x: T.Buffer[(16,), "float16"],
+    compute: T.Buffer[(16,), "float16"],
+) -> None:
+    compute_1 = T.alloc_buffer([16], dtype="float16")
+    for i0 in T.serial(16):
+        with T.block("compute"):
+            i0_1 = T.axis.spatial(16, i0)
+            T.reads(x[i0_1])
+            T.writes(compute_1[i0_1])
+            compute_1[i0_1] = T.exp(x[i0_1], dtype="float16")
+    for i0 in T.serial(16):
+        with T.block("compute_1"):
+            i0_2 = T.axis.spatial(16, i0)
+            T.reads(compute_1[i0_2], lookup_table[0:1024])
+            T.writes(compute[i0_2])
+            compute[i0_2] = T.exp(
+                compute_1[i0_2],
+                lookup_table.access_ptr("r"),
+                dtype="float16",
+            )
+
+
+@T.prim_func
+def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
+    lookup_table: T.Buffer[(1024,), "int8"],
+    x: T.Buffer[(16,), "float16"],
+    compute: T.Buffer[(16,), "float16"],
+) -> None:
+    for i0 in T.serial(16):
+        with T.block("compute_1"):
+            i0_1 = T.axis.spatial(16, i0)
+            # Do not put the opaque access to new write region when opaque access
+            # wrapped with a tvm_access_ptr and the access mask set to "read only"
+            T.reads(x[i0_1], lookup_table[0:1024])
+            T.writes(compute[i0_1])
+            compute[i0_1] = T.exp(
+                T.exp(x[i0_1], dtype="float16"),
+                lookup_table.access_ptr("r"),
+                dtype="float16",
+            )
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
 
 
@@ -569,5 +586,15 @@ def test_inline_block_with_init():
         sch.compute_inline(block=block)
 
 
+def test_compute_inline_opaque_access_with_tvm_access_ptr():
+    """Test opaque access with tvm_access_ptr after compute inline"""
+    sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all")
+    compute = sch.get_block("compute")
+    sch.compute_inline(compute)
+    tvm.ir.assert_structural_equal(
+        exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"]
+    )
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))