You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/09/01 02:23:23 UTC

[tvm] branch main updated: [TIR] Allow string/buffer arguments to Schedule cache_read/write (#12661)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c6516a534f [TIR] Allow string/buffer arguments to Schedule cache_read/write (#12661)
c6516a534f is described below

commit c6516a534fded605ae24bf56e24ec871b68ca9e2
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Aug 31 19:23:15 2022 -0700

    [TIR] Allow string/buffer arguments to Schedule cache_read/write (#12661)
    
    Previously, the argument needed to be an integer specifying the index
    into the read/write regions of a block.  Now, the argument may also be
    a string specifying the name of the buffer, or the Buffer object itself.
    This is a follow-up to https://github.com/apache/tvm/pull/11624.
---
 python/tvm/tir/schedule/schedule.py                | 42 ++++++++++++++++++----
 .../unittest/test_tir_schedule_cache_read_write.py |  8 +++--
 2 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 04cc1bc26a..d1293371a0 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1014,7 +1014,7 @@ class Schedule(Object):
     def cache_read(
         self,
         block: Union[BlockRV, str],
-        read_buffer_index: int,
+        read_buffer_index: Union[int, str, Buffer],
         storage_scope: str,
         consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
     ) -> BlockRV:
@@ -1029,8 +1029,10 @@ class Schedule(Object):
         block : Union[BlockRV, str]
             The consumer block of the target buffer.
 
-        read_buffer_index: int
-            The index of the buffer in block's read region.
+        read_buffer_index: Union[int, str, Buffer]
+            The index of the buffer in block's read region, the unique
+            name of a read buffer in the block, or a Buffer object
+            that is within the block's read region.
 
         storage_scope: str
             The target storage scope.
@@ -1093,13 +1095,21 @@ class Schedule(Object):
         # Convert any string block names into Block RVs.
         consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks]
         block = self._normalize_block_arg(block)
+
+        if not isinstance(read_buffer_index, int):
+            _, read_buffer_index, _ = self._normalize_buffer_arg(
+                block, read_buffer_index, required_buffer_type="read"
+            )
         return _ffi_api.ScheduleCacheRead(  # type: ignore # pylint: disable=no-member
             self, block, read_buffer_index, storage_scope, consumer_blocks
         )
 
     @type_checked
     def cache_write(
-        self, block: Union[BlockRV, str], write_buffer_index: int, storage_scope: str
+        self,
+        block: Union[BlockRV, str],
+        write_buffer_index: Union[int, str, Buffer],
+        storage_scope: str,
     ) -> BlockRV:
         """Create a block that reads a buffer region into a write cache. It requires:
 
@@ -1113,7 +1123,9 @@ class Schedule(Object):
             The producer block of the target buffer.
 
         write_buffer_index: int
-            The index of the buffer in block's write region.
+            The index of the buffer in block's write region, the unique
+            name of a write buffer in the block, or a Buffer object
+            that is within the block's write region.
 
         storage_scope: str
             The target storage scope.
@@ -1168,6 +1180,11 @@ class Schedule(Object):
 
         """
         block = self._normalize_block_arg(block)
+
+        if not isinstance(write_buffer_index, int):
+            _, write_buffer_index, _ = self._normalize_buffer_arg(
+                block, write_buffer_index, required_buffer_type="write"
+            )
         return _ffi_api.ScheduleCacheWrite(  # type: ignore # pylint: disable=no-member
             self, block, write_buffer_index, storage_scope
         )
@@ -2352,7 +2369,10 @@ class Schedule(Object):
         return block
 
     def _normalize_buffer_arg(
-        self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer]
+        self,
+        block: BlockRV,
+        buffer: Union[Tuple[str, int], int, str, Buffer],
+        required_buffer_type=None,
     ) -> Tuple[str, int, Buffer]:
 
         block_obj: Block = self.get(block)
@@ -2364,6 +2384,9 @@ class Schedule(Object):
             for i, write in enumerate(block_obj.writes):
                 yield "write", i, write.buffer
 
+        if isinstance(buffer, int):
+            buffer = (required_buffer_type, buffer)
+
         if isinstance(buffer, str):
             possible_buffers = {}
             # String lookup requires ensuring that the name is unique
@@ -2405,6 +2428,13 @@ class Schedule(Object):
         else:
             raise TypeError(f"Invalid type for argument 'buffer': {type(buffer)}")
 
+        if required_buffer_type is not None:
+            assert buffer_index_type == required_buffer_type, (
+                f"Expected buffer to be read buffer, "
+                f"but {buffer_obj.name} was a {buffer_index_type} buffer "
+                f"in the specified block"
+            )
+
         return (buffer_index_type, buffer_index, buffer_obj)
 
     @type_checked
diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py
index 255ca34118..cf4836e536 100644
--- a/tests/python/unittest/test_tir_schedule_cache_read_write.py
+++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -774,8 +774,12 @@ def test_cache_read_elementwise(use_block_name):
     sch = tir.Schedule(elementwise, debug_mask="all")
     block_b = sch.get_block("B")
     block_c = sch.get_block("C")
-    cached_a = sch.cache_read("B" if use_block_name else block_b, 0, "global")
-    cached_b = sch.cache_read("C" if use_block_name else block_c, 0, "local")
+    if use_block_name:
+        cached_a = sch.cache_read("B", "A", "global")
+        cached_b = sch.cache_read("C", "B", "local")
+    else:
+        cached_a = sch.cache_read(block_b, 0, "global")
+        cached_b = sch.cache_read(block_c, 0, "local")
     assert sch.get(cached_a) == sch.get(sch.get_block("A_global"))
     assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
     assert sch.get(block_b) == sch.get(sch.get_block("B"))