You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/04/28 13:00:39 UTC
[tvm] branch main updated: [TIR] Get read/write access precisely for opaque access. (#11110)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7710dfd557 [TIR] Get read/write access precisely for opaque access. (#11110)
7710dfd557 is described below
commit 7710dfd557752e2a980b90d1b19a66d8dcefb929
Author: albert qing <26...@qq.com>
AuthorDate: Thu Apr 28 21:00:29 2022 +0800
[TIR] Get read/write access precisely for opaque access. (#11110)
* [TIR] Get read/write access precisely for opaque access.
When an opaque access is wrapped in tvm_access_ptr, BlockReadWriteDetector can read the
access_mask argument of tvm_access_ptr and record the access in read_regions or
write_regions according to that mask (a read+write mask is still recorded as opaque).
* [TIR] Add an optional extent parameter to access_ptr.
Co-authored-by: sqing <qi...@intellif.com>
---
include/tvm/tir/buffer.h | 5 +-
python/tvm/tir/buffer.py | 10 ++-
src/tir/analysis/block_access_region_detector.cc | 28 +++++++
src/tir/ir/buffer.cc | 8 +-
.../test_tir_analysis_get_block_access_region.py | 29 +++++++
tests/python/unittest/test_tir_buffer.py | 6 ++
.../unittest/test_tir_schedule_compute_inline.py | 97 ++++++++++++++--------
7 files changed, 142 insertions(+), 41 deletions(-)
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index aef82ae368..ca7faf1cde 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -186,10 +186,11 @@ class Buffer : public ObjectRef {
* \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
+ * \param input_extent The extent of ptr.
*/
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
- int content_lanes = 1,
- PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
+ int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
+ Optional<PrimExpr> input_extent = NullOpt) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index e36a99339e..d9b0aec76a 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -42,7 +42,7 @@ class Buffer(Object):
READ = 1
WRITE = 2
- def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
+ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, extent=None):
"""Get an access pointer to the head of buffer.
This is the recommended method to get buffer data
@@ -66,6 +66,9 @@ class Buffer(Object):
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
+ extent: Expr, optional
+ The extent of pointer.
+
Examples
--------
.. code-block:: python
@@ -78,6 +81,8 @@ class Buffer(Object):
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
+ # Get access ptr for read with extent
+ buffer.access_ptr("r", extent = 100)
"""
if isinstance(access_mask, string_types):
mask = 0
@@ -90,8 +95,9 @@ class Buffer(Object):
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
offset = convert(offset)
+ extent = convert(extent)
return _ffi_api.BufferAccessPtr(
- self, access_mask, ptr_type, content_lanes, offset # type: ignore
+ self, access_mask, ptr_type, content_lanes, offset, extent # type: ignore
)
def vload(self, begin, dtype=None):
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index ffe0c75294..c65a422ed3 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -181,6 +181,34 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
}
void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
+ if (op->op.same_as(builtin::tvm_access_ptr())) {
+ const VarNode* buffer_var = op->args[1].as<VarNode>();
+ const IntImmNode* access_mask = op->args[4].as<IntImmNode>();
+ if (buffer_var && access_mask) {
+ auto it = buffer_var_map_.find(GetRef<Var>(buffer_var));
+ if (it != buffer_var_map_.end()) {
+ const Buffer& buffer = (*it).second;
+ const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
+ const Region& region = buffer_region->region;
+ std::vector<arith::IntSet> int_set;
+ int_set.reserve(region.size());
+ for (const Range& range : region) {
+ int_set.push_back(arith::EvalSet(range, dom_map_));
+ }
+ // read access, write access or opaque access
+ if ((access_mask->value & 1) && (access_mask->value & 2)) {
+ Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
+ } else if (access_mask->value & 1) {
+ Update(&read_buffers_, &read_regions_, buffer, int_set);
+ } else if (access_mask->value & 2) {
+ Update(&writes_buffers_, &write_regions_, buffer, int_set);
+ }
+ }
+ } else {
+ StmtExprVisitor::VisitExpr_(op);
+ }
+ return;
+ }
if (op->op.same_as(builtin::if_then_else())) {
VisitExpr(op->args[0]);
{
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 9cc92bd17e..ccf186634b 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -495,8 +495,8 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
return slice;
}
-PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
- PrimExpr offset) const {
+PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset,
+ Optional<PrimExpr> input_extent) const {
const BufferNode* self = operator->();
ICHECK(self != nullptr);
PrimExpr e_dtype;
@@ -519,6 +519,10 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
} else {
e_dtype = tir::TypeAnnotation(self->dtype);
}
+
+ if (input_extent.defined()) {
+ extent = input_extent.value();
+ }
Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
make_const(DataType::Int(32), access_mask)};
return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index 463f2a7f0e..8a10cbd072 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -105,6 +105,19 @@ def opaque_access_func() -> None:
)
+@T.prim_func
+def opaque_access_with_tvm_access_ptr_func() -> None:
+ A = T.alloc_buffer([1024])
+ B = T.alloc_buffer([1024])
+ C = T.alloc_buffer([1024])
+ with T.block("opaque"):
+ T.reads(A[0:1024], C[0:1024])
+ T.writes(B[0:1024], C[0:1024])
+ T.evaluate(A.access_ptr("r"))
+ T.evaluate(B.access_ptr("w"))
+ T.evaluate(C.access_ptr("rw"))
+
+
@T.prim_func
def access_in_if_then_else_func() -> None:
A = T.alloc_buffer([8])
@@ -235,6 +248,21 @@ def test_opaque_access():
tvm.ir.assert_structural_equal(ret0[1], ret1[1])
+def test_opaque_access_with_tvm_access_ptr():
+ block = opaque_access_with_tvm_access_ptr_func.body.block.body.block
+ alloc_buffers = opaque_access_with_tvm_access_ptr_func.body.block.alloc_buffers
+ buffer_var_map = {buf.data: buf for buf in alloc_buffers}
+
+ ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map)
+ ret1 = tir.analysis.get_block_access_region(block, buffer_var_map)
+ tvm.ir.assert_structural_equal(block.reads, ret0[0])
+ tvm.ir.assert_structural_equal(block.writes, ret0[1])
+ with pytest.raises(ValueError):
+ tvm.ir.assert_structural_equal(ret0[0], ret1[0])
+ with pytest.raises(ValueError):
+ tvm.ir.assert_structural_equal(ret0[1], ret1[1])
+
+
def test_match_buffer():
root_block = match_buffer_func.body.block
block = root_block.body.body.body.block
@@ -333,6 +361,7 @@ if __name__ == "__main__":
test_block_access_region_detector()
test_opaque_block()
test_opaque_access()
+ test_opaque_access_with_tvm_access_ptr()
test_match_buffer()
test_access_in_if_then_else_func()
test_access_in_branch_func()
diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py
index 990d0a22c8..337f9cbc07 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -76,6 +76,12 @@ def test_buffer_access_ptr_extent():
aptr = Ab.access_ptr("rw", offset=100)
assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100)
+ # Test extent from input params
+ aptr = Ab.access_ptr("rw", extent=200)
+ assert tvm.ir.structural_equal(aptr.args[3], 200)
+ aptr = Ab.access_ptr("rw", offset=100, extent=100)
+ assert tvm.ir.structural_equal(aptr.args[3], 100)
+
def test_buffer_vload():
m = te.size_var("m")
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index 1259219a39..8894cd4d9f 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -183,11 +183,7 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None:
vi, vj = T.axis.remap("SS", [i, j])
T.reads(B[0:128, 0:128])
T.writes(C[0:128, 0:128])
- T.evaluate(
- T.tvm_access_ptr(
- T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle"
- )
- )
+ T.evaluate(B.access_ptr("r", extent=128))
C[vi, vj] = B[vi, vj] + 1.0
@@ -205,16 +201,8 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None:
vi, vj = T.axis.remap("SS", [i, j])
T.reads(B[0:128, 0:128])
T.writes(C[0:128, 0:128])
- T.evaluate(
- T.tvm_access_ptr(
- T.type_annotation(dtype="float32"), B.data, 0, 128, "r", dtype="handle"
- )
- )
- T.evaluate(
- T.tvm_access_ptr(
- T.type_annotation(dtype="float32"), C.data, 0, 128, "w", dtype="handle"
- )
- )
+ T.evaluate(B.access_ptr("r", extent=128))
+ T.evaluate(C.access_ptr("w", extent=128))
C[vi, vj] = B[vi, vj] + 1.0
@@ -296,16 +284,8 @@ def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
# annotated opaque partial access
T.reads(A[0:512])
T.writes(A_cache[0:512])
- T.evaluate(
- T.tvm_access_ptr(
- T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
- )
- )
- T.evaluate(
- T.tvm_access_ptr(
- T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
- )
- )
+ T.evaluate(A.access_ptr("r", extent=512))
+ T.evaluate(A_cache.access_ptr("w", extent=512))
for i in range(512):
with T.block("BB"):
vi = T.axis.remap("S", [i])
@@ -325,16 +305,8 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
# annotated opaque partial access should be kept
T.reads(A[0:512])
T.writes([A_cache[0:512]])
- T.evaluate(
- T.tvm_access_ptr(
- T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
- )
- )
- T.evaluate(
- T.tvm_access_ptr(
- T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
- )
- )
+ T.evaluate(A.access_ptr("r", extent=512))
+ T.evaluate(A_cache.access_ptr("w", extent=512))
for i in T.serial(0, 512):
with T.block("B"):
vi = T.axis.spatial(512, i)
@@ -402,6 +374,51 @@ def inline_block_with_init(
)
+@T.prim_func
+def exp_exp_opaque_access_with_tvm_access_ptr(
+ lookup_table: T.Buffer[(1024,), "int8"],
+ x: T.Buffer[(16,), "float16"],
+ compute: T.Buffer[(16,), "float16"],
+) -> None:
+ compute_1 = T.alloc_buffer([16], dtype="float16")
+ for i0 in T.serial(16):
+ with T.block("compute"):
+ i0_1 = T.axis.spatial(16, i0)
+ T.reads(x[i0_1])
+ T.writes(compute_1[i0_1])
+ compute_1[i0_1] = T.exp(x[i0_1], dtype="float16")
+ for i0 in T.serial(16):
+ with T.block("compute_1"):
+ i0_2 = T.axis.spatial(16, i0)
+ T.reads(compute_1[i0_2], lookup_table[0:1024])
+ T.writes(compute[i0_2])
+ compute[i0_2] = T.exp(
+ compute_1[i0_2],
+ lookup_table.access_ptr("r"),
+ dtype="float16",
+ )
+
+
+@T.prim_func
+def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
+ lookup_table: T.Buffer[(1024,), "int8"],
+ x: T.Buffer[(16,), "float16"],
+ compute: T.Buffer[(16,), "float16"],
+) -> None:
+ for i0 in T.serial(16):
+ with T.block("compute_1"):
+ i0_1 = T.axis.spatial(16, i0)
+ # Do not put the opaque access to new write region when opaque access
+ # wrapped with a tvm_access_ptr and the access mask set to "read only"
+ T.reads(x[i0_1], lookup_table[0:1024])
+ T.writes(compute[i0_1])
+ compute[i0_1] = T.exp(
+ T.exp(x[i0_1], dtype="float16"),
+ lookup_table.access_ptr("r"),
+ dtype="float16",
+ )
+
+
# pylint: enable=no-member,invalid-name,unused-variable
@@ -569,5 +586,15 @@ def test_inline_block_with_init():
sch.compute_inline(block=block)
+def test_compute_inline_opaque_access_with_tvm_access_ptr():
+ """Test opaque access with tvm_access_ptr after compute inline"""
+ sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all")
+ compute = sch.get_block("compute")
+ sch.compute_inline(compute)
+ tvm.ir.assert_structural_equal(
+ exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"]
+ )
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))