You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/05/17 02:59:31 UTC

[GitHub] [tvm] vinx13 commented on a diff in pull request #11269: [Draft][TIR][Schedule] Transform layout quality of life

vinx13 commented on code in PR #11269:
URL: https://github.com/apache/tvm/pull/11269#discussion_r874317075


##########
python/tvm/tir/schedule/schedule.py:
##########
@@ -2114,25 +2114,111 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
 
     ########## Schedule: Layout transformation ##########
 
-    @type_checked
+    def _normalize_block_arg(self, block: Union[BlockRV, str]) -> BlockRV:
+        if isinstance(block, str):
+            return self.get_block(block)
+
+        return block
+
+    def _normalize_buffer_arg(
+        self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer]
+    ) -> Tuple[str, int, Buffer]:
+
+        block_name = self.get(block).name_hint
+
+        def iter_buffers():
+            block_obj = self.get(block)
+            for i, read in enumerate(block_obj.reads):
+                yield "read", i, read.buffer
+            for i, write in enumerate(block_obj.writes):
+                yield "write", i, write.buffer
+
+        if isinstance(buffer, str):
+            possible_buffers = {}
+            # String lookup requires ensuring that the name is unique
+            for buffer_index, buffer_index_type, buf in iter_buffers():
+                if buf.name == buffer:
+                    possible_buffers[buf] = (buffer_index_type, buffer_index)
+
+            assert possible_buffers, f"Could not find buffer '{buffer}' in block '{block_name}'"
+            assert (
+                len(possible_buffers) == 1
+            ), f"Multiple buffers named '{buffer}' in block '{block_name}'"
+            buffer_obj, (buffer_index, buffer_index_type) = next(iter(possible_buffers.items()))
+
+        elif isinstance(buffer, Buffer):
+            # Buffer lookup has unique id, can break out early
+            found = False
+            for buffer_index, buffer_index_type, buffer_obj in iter_buffers():
+                if buffer_obj.same_as(buffer):
+                    found = True
+                    break
+
+            assert found, "Could not find buffer '{buffer.name}' in block '{block_name}'"
+
+        elif isinstance(buffer, tuple):
+            buffer_index_type, buffer_index = buffer
+            assert buffer_index_type in ["read", "write",], (
+                f"Invalid buffer_index_type.  "
+                f"Expected 'read' or 'write', "
+                f"but received {buffer_index_type}"
+            )
+            buffer_list = (
+                self.get(block).reads if buffer_index_type == "read" else self.get(block).writes
+            )
+            assert 0 <= buffer_index < len(buffer_list), (
+                f"Invalid buffer_index {buffer_index}.  "
+                f"Block {block_name} has only "
+                f"{len(buffer_list)} {buffer_index_type} buffers."
+            )
+            buffer_obj = buffer_list[buffer_index].buffer
+
+        else:
+            raise TypeError(f"Invalid type for argument 'buffer': {type(buffer)}")
+
+        return (buffer_index_type, buffer_index, buffer_obj)
+
+    # @type_checked
     def transform_layout(
         self,
-        block: BlockRV,
-        buffer_index: int,
-        buffer_index_type: str,
+        block: Union[BlockRV, str],
+        buffer: Union[Tuple[str, int], str, Buffer],

Review Comment:
   This will break the trace-to-Python conversion, which is implemented here: https://github.com/apache/tvm/blob/main/src/tir/schedule/primitive/layout_transformation.cc#L288.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org