You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/09/01 02:23:23 UTC
[tvm] branch main updated: [TIR] Allow string/buffer arguments to Schedule cache_read/write (#12661)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c6516a534f [TIR] Allow string/buffer arguments to Schedule cache_read/write (#12661)
c6516a534f is described below
commit c6516a534fded605ae24bf56e24ec871b68ca9e2
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Aug 31 19:23:15 2022 -0700
[TIR] Allow string/buffer arguments to Schedule cache_read/write (#12661)
Previously, the argument needed to be an integer specifying the index
into the read/write regions of a block. Now, the argument can be a
string specifying the name of the buffer, or the Buffer object itself.
This is a follow-up to https://github.com/apache/tvm/pull/11624.
---
python/tvm/tir/schedule/schedule.py | 42 ++++++++++++++++++----
.../unittest/test_tir_schedule_cache_read_write.py | 8 +++--
2 files changed, 42 insertions(+), 8 deletions(-)
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 04cc1bc26a..d1293371a0 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1014,7 +1014,7 @@ class Schedule(Object):
def cache_read(
self,
block: Union[BlockRV, str],
- read_buffer_index: int,
+ read_buffer_index: Union[int, str, Buffer],
storage_scope: str,
consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
) -> BlockRV:
@@ -1029,8 +1029,10 @@ class Schedule(Object):
block : Union[BlockRV, str]
The consumer block of the target buffer.
- read_buffer_index: int
- The index of the buffer in block's read region.
+ read_buffer_index: Union[int, str, Buffer]
+ The index of the buffer in block's read region, the unique
+ name of a read buffer in the block, or a Buffer object
+ that is within the block's read region.
storage_scope: str
The target storage scope.
@@ -1093,13 +1095,21 @@ class Schedule(Object):
# Convert any string block names into Block RVs.
consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks]
block = self._normalize_block_arg(block)
+
+ if not isinstance(read_buffer_index, int):
+ _, read_buffer_index, _ = self._normalize_buffer_arg(
+ block, read_buffer_index, required_buffer_type="read"
+ )
return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope, consumer_blocks
)
@type_checked
def cache_write(
- self, block: Union[BlockRV, str], write_buffer_index: int, storage_scope: str
+ self,
+ block: Union[BlockRV, str],
+ write_buffer_index: Union[int, str, Buffer],
+ storage_scope: str,
) -> BlockRV:
"""Create a block that reads a buffer region into a write cache. It requires:
@@ -1113,7 +1123,9 @@ class Schedule(Object):
The producer block of the target buffer.
write_buffer_index: int
- The index of the buffer in block's write region.
+ The index of the buffer in block's write region, the unique
+ name of a write buffer in the block, or a Buffer object
+ that is within the block's write region.
storage_scope: str
The target storage scope.
@@ -1168,6 +1180,11 @@ class Schedule(Object):
"""
block = self._normalize_block_arg(block)
+
+ if not isinstance(write_buffer_index, int):
+ _, write_buffer_index, _ = self._normalize_buffer_arg(
+ block, write_buffer_index, required_buffer_type="write"
+ )
return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member
self, block, write_buffer_index, storage_scope
)
@@ -2352,7 +2369,10 @@ class Schedule(Object):
return block
def _normalize_buffer_arg(
- self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer]
+ self,
+ block: BlockRV,
+ buffer: Union[Tuple[str, int], int, str, Buffer],
+ required_buffer_type=None,
) -> Tuple[str, int, Buffer]:
block_obj: Block = self.get(block)
@@ -2364,6 +2384,9 @@ class Schedule(Object):
for i, write in enumerate(block_obj.writes):
yield "write", i, write.buffer
+ if isinstance(buffer, int):
+ buffer = (required_buffer_type, buffer)
+
if isinstance(buffer, str):
possible_buffers = {}
# String lookup requires ensuring that the name is unique
@@ -2405,6 +2428,13 @@ class Schedule(Object):
else:
raise TypeError(f"Invalid type for argument 'buffer': {type(buffer)}")
+ if required_buffer_type is not None:
+ assert buffer_index_type == required_buffer_type, (
+ f"Expected buffer to be read buffer, "
+ f"but {buffer_obj.name} was a {buffer_index_type} buffer "
+ f"in the specified block"
+ )
+
return (buffer_index_type, buffer_index, buffer_obj)
@type_checked
diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py
index 255ca34118..cf4836e536 100644
--- a/tests/python/unittest/test_tir_schedule_cache_read_write.py
+++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -774,8 +774,12 @@ def test_cache_read_elementwise(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_c = sch.get_block("C")
- cached_a = sch.cache_read("B" if use_block_name else block_b, 0, "global")
- cached_b = sch.cache_read("C" if use_block_name else block_c, 0, "local")
+ if use_block_name:
+ cached_a = sch.cache_read("B", "A", "global")
+ cached_b = sch.cache_read("C", "B", "local")
+ else:
+ cached_a = sch.cache_read(block_b, 0, "global")
+ cached_b = sch.cache_read(block_c, 0, "local")
assert sch.get(cached_a) == sch.get(sch.get_block("A_global"))
assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
assert sch.get(block_b) == sch.get(sch.get_block("B"))