Posted to commits@tvm.apache.org by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org> on 2023/03/01 23:48:24 UTC

[GitHub] [tvm] AndrewZhaoLuo opened a new pull request, #14167: [CUDA][Schedule] Better Layout Transform Schedules

AndrewZhaoLuo opened a new pull request, #14167:
URL: https://github.com/apache/tvm/pull/14167

   CUDA: Improved Layout Transform Schedule
   
   ALL numbers are on an RTX 3070 unless stated otherwise.
   
   ## Motivation
   The default implementation of layout transform performs poorly in some scenarios. A motivating factor behind this work was that layout transform operators took up significant time in some Stable Diffusion models we tried. NHWC layout is required for tensor-core use, so we had to convert the layout of conv operators. The layout transforms this introduced, in both their fused and unfused versions, were extremely slow. I improved latency by at least 2x for all of these operators:
   
   *** TABLE HERE ***
   
   ## Algorithm 
   Currently, layout transform relies on the AutoBind schedule rule, which seems to guarantee contiguous writes to the output buffer when assigning thread and block indices to loops. However, when the innermost dimensions of `src_layout` and `dst_layout` do not match, it is impossible to have both contiguous writes and contiguous reads in the same operation. 
   
   We ignore the case where the innermost dimensions match (though I believe the new schedule might be faster in some of those scenarios as well).
   
   ### Simple Case: Transpose
   The way around this is to use a small amount of shared memory and tile the loads so that reads from global memory into shared memory can be coalesced. We load elements carefully so that writes from shared memory to global memory are also coalesced. An example is transposing a [2048, 1024] matrix to [1024, 2048]. We can read 32 x 32 tiles of the src matrix. If rows of our shared memory tile correspond to coalesced accesses of the src matrix, then columns of the tile correspond to coalesced accesses of the dst matrix. By doing this, we can maximize the memory throughput of the operation.
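As a rough illustration of the tiling idea, here is a pure-Python emulation of the transpose described above. It is a sketch of the index pattern only, not the TVM schedule itself: each `(tile_r, tile_c)` pair stands in for a thread block, and the list `smem` stands in for its shared memory tile. The read phase walks a src row in its innermost loop, and the write phase walks a dst row (a column of `smem`) in its innermost loop, so both sides touch consecutive addresses.

```python
def tiled_transpose(src, rows, cols, tile=32):
    """Transpose a row-major rows x cols matrix stored as a flat list.

    Each (tile_r, tile_c) pair plays the role of one thread block and `smem`
    plays the role of its [tile, tile] shared memory buffer.
    """
    dst = [None] * (rows * cols)  # row-major cols x rows result
    for tile_r in range(0, rows, tile):
        for tile_c in range(0, cols, tile):
            smem = [[None] * tile for _ in range(tile)]
            # Read phase: the innermost index j walks along a src row, so
            # consecutive "threads" read consecutive src addresses.
            for i in range(tile):
                for j in range(tile):
                    r, c = tile_r + i, tile_c + j
                    if r < rows and c < cols:  # guard ragged edge tiles
                        smem[i][j] = src[r * cols + c]
            # Write phase: the innermost index i walks down a column of smem,
            # which is a row of dst, so consecutive writes are contiguous too.
            for j in range(tile):
                for i in range(tile):
                    r, c = tile_r + i, tile_c + j
                    if r < rows and c < cols:
                        dst[c * rows + r] = smem[i][j]
    return dst


# Usage: a shape that is not a multiple of 32, to exercise the edge tiles.
rows, cols = 70, 45
src = list(range(rows * cols))
out = tiled_transpose(src, rows, cols)
```

On a GPU the two phases would run on different threads, so a real kernel needs a `__syncthreads()` barrier between them; this sketch only shows which addresses each phase touches.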
   
   ### General Case
   While a simple transpose is easy to reason about, how do we guarantee this behavior for general layout transforms, where dimensions can be arbitrarily permuted? 
   
   The answer is that we can use the analysis tools in TVM scheduling! Specifically, if we read from global memory and write into shared memory correctly, structuring the loops so that the outer loop will eventually be bound to `blockIdx`, we can then use `compute_at` or `reverse_compute_at` to automatically generate the proper extents for the writing loops!
   
   We therefore only need to get the loop structure of the read block right so that the eventual write block also has coalesced memory access. There are four constraints to think about here:
   
   - Reads from global memory can be coalesced as much as possible, up to tile_size
   - Writes to global memory can be coalesced as much as possible, up to tile_size
   - Each thread will read up to tile_size elements from global memory
   - Each thread will write up to tile_size elements into global memory
   
   We can guarantee the proper number of reads and writes by tiling our loops like so:
   
   ```
   for block_idx, thread_idx, inner_loop in T.grid(remainder, tile_size, tile_size):
       # each (block_idx, thread_idx, inner_loop) combination reads exactly one element,
       # so each thread reads tile_size elements via inner_loop
       ...
   ```
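A small pure-Python check of the counting argument behind this grid follows. It rests on a simplifying assumption made only for illustration: the iteration `(block_idx, thread_idx, inner_loop)` handles the flat element `(block_idx * tile_size + thread_idx) * tile_size + inner_loop`. The hypothetical helper `grid_work` confirms that every element is visited exactly once and that no thread handles more than tile_size elements.

```python
import math
from collections import Counter


def grid_work(total, tile_size):
    """Count the elements each (block_idx, thread_idx) pair touches.

    Assumption for illustration: the element handled by an iteration is
    (block_idx * tile_size + thread_idx) * tile_size + inner_loop.
    """
    remainder = math.ceil(total / (tile_size * tile_size))  # number of blocks
    per_thread = Counter()
    visits = Counter()
    for block_idx in range(remainder):
        for thread_idx in range(tile_size):
            for inner_loop in range(tile_size):
                flat = (block_idx * tile_size + thread_idx) * tile_size + inner_loop
                if flat < total:  # the padded tail of the last block is skipped
                    per_thread[(block_idx, thread_idx)] += 1
                    visits[flat] += 1
    return remainder, per_thread, visits


remainder, per_thread, visits = grid_work(total=5000, tile_size=32)
print(remainder, max(per_thread.values()))  # 5 32
```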
   
   Guaranteeing coalesced reads is also simple: the innermost loops must access the innermost dimensions of src_layout. For example, say we have a tensor of shape
   [2_d, 128, 64, 32, 8, 4, 2, 2_s] (_s and _d mark the innermost dimension of the src and dst layout respectively) and we want to do a layout transform from “ABCDEFGH” -> “HBCDFEGA” (A and H swapped, E and F swapped). Then for reading we want our inner loop of extent tile_size to correspond to the innermost dimensions of the source layout. E.g. we divide the dimension of 8 into 4 x 2 and combine the loops of extent 2, 4, 2, 2_s into a single loop of extent 32. 
   
   Guaranteeing coalesced writes is done similarly, except we must consider things from the point of view of the dst_layout:
   [2_s, 128, 64, 32, 4, 8, 2, 2_d] (first and last dimension swapped, as are the 4 and the 8). 
   
   We can apply the same logic and combine the dimensions [8, 2, 2_d] to get our tile_size. However, notice that we already used some of these loops to guarantee coalesced reads! We cannot reuse the same factors, or else the loop will read and write a different number of elements. So instead we take the remaining factors until we also reach our tile_size. 
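The greedy factor selection for this running example can be sketched in plain Python. This is a simplified model, not the PR's implementation: the hypothetical `pick_tile_factors` takes `gcd(remaining extent, factor still needed)` from each dimension instead of padding, which is enough when the extents divide evenly as they do here (with odd extents it could fall short of tile_size, which is why the real schedule pads). The read tile consumes factors walking src_layout from its innermost dimension; the write tile then walks dst_layout from its innermost dimension using only what is left.

```python
import math


def pick_tile_factors(shape, src_layout, dst_layout, tile_size):
    """Return (read_factors, write_factors) as {dim_letter: factor} dicts."""
    remaining = dict(zip(src_layout, shape))  # extents not yet factored out

    def take(order):
        needed, factors = tile_size, {}
        for dim in order:
            if needed == 1:
                break
            f = math.gcd(remaining[dim], needed)
            if f > 1:
                factors[dim] = f
                remaining[dim] //= f  # these factors cannot be reused later
                needed //= f
        return factors

    read_factors = take(reversed(src_layout))   # coalesced reads
    write_factors = take(reversed(dst_layout))  # coalesced writes, from leftovers
    return read_factors, write_factors


read, write = pick_tile_factors(
    [2, 128, 64, 32, 8, 4, 2, 2], "ABCDEFGH", "HBCDFEGA", tile_size=32
)
print(read)   # {'H': 2, 'G': 2, 'F': 4, 'E': 2}
print(write)  # {'A': 2, 'E': 4, 'D': 4}
```

The write tile ends up with 2 from A, the 4 left over from E (G was fully consumed by the read tile), and 4 from D, again multiplying to 32.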
   
   We have now tiled the read and write dimensions so that these constraints are met as closely as possible. To handle awkward shapes which don't divide evenly into tile_size, we pad some dimensions until their extents evenly divide tile_size. Incidentally, the above example with a dtype of “float16” and a tile_size of 32 has a runtime of 667 µs vs. 993 µs for the default, and reaches 90% memory throughput as reported by NVIDIA's compute profiler!
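The padding rule can be checked with a few lines of plain Python. `padded_extent` is a hypothetical helper that mirrors only the arithmetic of `pad_dimension_to_factor_of_tile_size` in the diff (no TVM loop handling): an extent smaller than tile_size grows until it evenly divides tile_size, and extents of at least tile_size are left alone.

```python
def padded_extent(extent, tile_size):
    """Smallest size >= extent that evenly divides tile_size (if extent < tile_size)."""
    if extent >= tile_size or tile_size % extent == 0:
        return extent  # already a divisor, or too large to pad
    size = extent
    while tile_size % size != 0:
        size += 1
    return size


for extent, tile in [(5, 32), (5, 36), (8, 32), (33, 32), (3, 32)]:
    print(extent, tile, "->", padded_extent(extent, tile))
```

Padding is not free: growing an extent of 5 to 8, for instance, means 3 of every 8 iterations in that dimension fall outside the real tensor.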
   
   ## Choice of search space
   The most natural tile sizes are those aligned with the global-memory coalesced transaction sizes of 32, 64, and 128 bytes. A tile size of 128 (i.e. 128 threads per block) is too large, since each thread would also have to do 128 elements of work, though it would reach the 128-byte transaction size with 1-byte datatypes. 64 is therefore a natural upper limit on tile size. Due to the nature of factoring, other tile sizes between 1 and 64 might be better for various shapes. It is a relatively small search space, so I have elected to try all tile sizes from 1 to 64 inclusive, plus the default AutoBind implementation.
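To make the arithmetic concrete: one coalesced row of a tile covers `tile_size * dtype_bytes` bytes. The sketch below (plain Python; `tile_for_transaction` and the `"autobind"` label are hypothetical) shows which tile size fills a 32-, 64-, or 128-byte transaction for common dtype widths, and builds the search space described above.

```python
DTYPE_BYTES = {"int8": 1, "float16": 2, "float32": 4}
TRANSACTION_BYTES = (32, 64, 128)
MAX_TILE = 64  # upper limit on tile size chosen in the text


def tile_for_transaction(dtype, transaction_bytes):
    """Tile size whose row spans exactly one transaction of the given size."""
    return transaction_bytes // DTYPE_BYTES[dtype]


for dtype in DTYPE_BYTES:
    sizes = [tile_for_transaction(dtype, t) for t in TRANSACTION_BYTES]
    print(dtype, sizes)
# int8 [32, 64, 128]
# float16 [16, 32, 64]
# float32 [8, 16, 32]

# Every tile size from 1 to 64 inclusive, plus a placeholder for the default AutoBind schedule.
search_space = list(range(1, MAX_TILE + 1)) + ["autobind"]
```

Only 1-byte dtypes need a tile of 128 to fill a 128-byte transaction; 2-byte and wider dtypes already do so at a tile size of 64 or less.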
   
   ## Known defects:
   ### High Memory Use from Common Factors
    
   Analysis when using `compute_at` seems to fail (or I am missing something) if we sample factors from the same dimensions for both dim0 and dim1 tiling. This *sometimes* leads to excessive shared memory use, which can cause failures or performance issues. Often these schedules are still faster than the default ones, but occasionally their shared memory use is much higher than I expect. These are a very small fraction of the tested cases, however, and we always try both the new schedule and AutoBind, so it should be OK for now.
   
   I am not sure why this happens; it would need more investigation. An example is transposing a [1209, 9] tensor of any dtype with a tile size of 32. 
   
   ### Shared Memory Bank Conflicts
   
   Shared memory bank conflicts exist and are common for the strategy used. Consider a simple transpose of a [2048, 1024] matrix to [1024, 2048]. With a tile size of 32, the shared memory buffer will have shape [32, 32]. We might read from global memory into rows of shared memory, and then write columns of shared memory to global memory for contiguous access. However, note that with 4-byte elements, all elements of a column of this buffer lie in the same memory bank! 
   
   A common solution is to pad the innermost dimension of the shared memory buffer, e.g. to [32, 33], which makes accesses along columns bank-conflict free.
   
   I plan to address this in a future PR via a new scheduling rule, since it is a general problem across all CUDA-generated schedules. To give an idea of the impact, applying this padding on top of the optimized layout transform described in this PR took a [1024, 2048] transpose from 14.34 µs to 12.35 µs.
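The bank arithmetic can be checked directly. Assuming 32 shared memory banks, each 4 bytes wide (as on recent NVIDIA GPUs), the hypothetical helper below computes which bank each of 32 threads hits when thread `r` reads `smem[r][col]` from a row-major buffer, and reports the worst-case number of threads sharing one bank. It ignores the broadcast case in which several threads read the same 4-byte word.

```python
from collections import Counter

NUM_BANKS = 32  # shared memory banks
BANK_BYTES = 4  # width of each bank in bytes


def column_conflict_degree(row_width, elem_bytes=4, rows=32, col=0):
    """Max number of threads hitting one bank when thread r reads smem[r][col]."""
    banks = Counter()
    for r in range(rows):
        byte_addr = (r * row_width + col) * elem_bytes
        banks[(byte_addr // BANK_BYTES) % NUM_BANKS] += 1
    return max(banks.values())


print(column_conflict_degree(32))  # 32: the whole column sits in one bank
print(column_conflict_degree(33))  # 1: padding spreads the column over all banks
print(column_conflict_degree(32, elem_bytes=2))  # 16 with 2-byte elements
print(column_conflict_degree(33, elem_bytes=2))  # 1
```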
   
   ### Non-Aligned Inner Dimensions
   
   This is an issue I did not think of when writing this schedule. The schedule is designed to maximize coalesced memory access in global memory. However, one small detail is that coalesced memory accesses must be aligned to the transaction size. That is, if we have a coalesced access of 64 bytes (e.g. 32 float16's), then every address accessed must lie on the same 64-byte line (i.e. only the last 6 bits of the addresses may differ). 
   
   Consider a layout transform where dimensions are prime numbers, e.g. [3, 1024, 1024, 7] -> [7, 1024, 1024, 3]. The current strategy will read 7-element-wide chunks at a time. However, most of these accesses will straddle coalescing boundaries, resulting in two memory transactions instead of one. E.g. say coalesced accesses must be 8-byte aligned and we are dealing with a 1-byte datatype. The first read of 7 elements covers addresses 0x00 … 0x06 and the next covers 0x07 … 0x0D. For the second access, 0x07 belongs to the first 8-byte line, while 0x08 … 0x0D belong to the second 8-byte line, requiring two memory transactions.
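The transaction count in this toy example follows from simple address arithmetic. The hypothetical helper below counts how many aligned segments a contiguous access touches; with 7-byte chunks and 8-byte segments, 6 of every 8 consecutive chunks straddle a boundary.

```python
def segments_touched(start_byte, length_bytes, segment_bytes):
    """Number of aligned segment_bytes-sized lines covered by [start, start + length)."""
    first = start_byte // segment_bytes
    last = (start_byte + length_bytes - 1) // segment_bytes
    return last - first + 1


CHUNK = 7    # one 7-element chunk of a 1-byte dtype
SEGMENT = 8  # alignment of a coalesced transaction in this toy example

print(segments_touched(0x00, CHUNK, SEGMENT))  # 1: bytes 0x00..0x06
print(segments_touched(0x07, CHUNK, SEGMENT))  # 2: 0x07 | 0x08..0x0D

counts = [segments_touched(k * CHUNK, CHUNK, SEGMENT) for k in range(8)]
print(counts)  # [1, 2, 2, 2, 2, 2, 2, 1]
```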
   
   One possible workaround is to treat the array as flattened and simply access it in coalesced chunks. I am not sure about the details; guaranteeing good access for both src and dst will require some thinking, though it might be possible. 
   
   For example, in this case we could do the no-op reshape into [3, 1024, 32, 32, 7], then into [3, 1024, 32, 32 * 7], and then into [3, 1024, 32, 7, 32]. Then things become obvious (the innermost dimension is now 32). However, trying something like this initially led to strange calculated bounds in the `compute_at` step and excessive shared memory usage, since we must also consider the dst_layout.
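A quick way to convince yourself that this reshape chain only reinterprets memory: the row-major offset of every element is unchanged, and each innermost row of 32 elements starts at an element offset that is a multiple of 32. The sketch below checks this for the last two dimensions only (the leading [3, 1024, 32] dimensions are untouched, and 32 * 7 = 7 * 32, so whole slices keep the same size); `row_major_offset` is a hypothetical helper.

```python
def row_major_offset(index, shape):
    """Flat element offset of `index` in a row-major tensor of `shape`."""
    offset = 0
    for i, extent in zip(index, shape):
        assert 0 <= i < extent
        offset = offset * extent + i
    return offset


OLD = (32, 7)  # trailing dims of [3, 1024, 32, 32, 7]
NEW = (7, 32)  # trailing dims of [3, 1024, 32, 7, 32]

mapping = {}
for d in range(OLD[0]):
    for e in range(OLD[1]):
        flat = d * OLD[1] + e        # position in the fused 32 * 7 dimension
        q, r = divmod(flat, NEW[1])  # the same position viewed as [7, 32]
        assert row_major_offset((d, e), OLD) == row_major_offset((q, r), NEW)
        mapping[(d, e)] = (q, r)

# Every innermost row of the new view begins at an element offset divisible by 32.
row_starts = [row_major_offset((q, 0), NEW) for q in range(NEW[0])]
```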
   
   ## Results:
   Good stuff 
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] AndrewZhaoLuo commented on pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on PR #14167:
URL: https://github.com/apache/tvm/pull/14167#issuecomment-1472819633

   cc @tkonolige @vinx13 
   
   This is now ready for re-review.
   
   There is a lint error because our version of `black` is out of date; unfortunately, fixing it might take a while, since it requires updating CI. 
   
   The main work was handling implicit reshapes in the layout transform (e.g. NCHW --> NCHW4c) and adding tests. The tests consist of some manual cases plus some autogenerated cases. The autogenerated cases also try to fuse compatible ops into the layout transform task, and mainly check for correctness.
   
   I tested ~8000 autogenerated cases offline for correctness of the schedule; normal runs test ~9 autogenerated cases, which takes about a minute on my computer.
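For readers unfamiliar with the implicit reshape: NCHW --> NCHW4c splits C into C // 4 outer channels and an innermost block of 4, so it is a reshape followed by a transpose. The pure-Python sketch below (hypothetical helper `nchw_to_nchw4c`, assuming C is divisible by 4) shows where each input element lands. The innermost output dimension (4c) comes from C, which is not innermost in the input, so this is a case where the innermost dimensions of src and dst differ.

```python
def nchw_to_nchw4c(data, n, c, h, w, block=4):
    """Convert a flat row-major NCHW buffer into a flat NCHW{block}c buffer."""
    assert c % block == 0, "this sketch assumes C is divisible by the block size"
    outer = c // block
    out = [None] * (n * c * h * w)
    for ni in range(n):
        for ci in range(c):
            co, cb = divmod(ci, block)  # outer channel, position inside the block
            for hi in range(h):
                for wi in range(w):
                    src = ((ni * c + ci) * h + hi) * w + wi
                    dst = (((ni * outer + co) * h + hi) * w + wi) * block + cb
                    out[dst] = data[src]
    return out


n, c, h, w = 1, 8, 2, 3
data = list(range(n * c * h * w))
out = nchw_to_nchw4c(data, n, c, h, w)
print(out[:4])  # channels 0..3 at (h=0, w=0): [0, 6, 12, 18]
```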




[GitHub] [tvm] vinx13 commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1125019185


##########
include/tvm/topi/transform.h:
##########
@@ -1596,6 +1596,7 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
  */
 inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
                                const std::string& dst_layout,
+                               const std::string schedule_rule = "None",

Review Comment:
   Update the doc comment above to include this param.



##########
python/tvm/meta_schedule/schedule/cuda/layout_transform.py:
##########
@@ -0,0 +1,321 @@
+import tvm
+from tvm import topi
+import math
+from typing import List, Sequence, Tuple
+
+from tvm.tir.schedule import BlockRV, ExprRV, LoopRV
+from collections import deque
+
+
+def tile_layout_transform(
+    sch: tvm.tir.Schedule,
+    block_write: BlockRV,
+    src_layout: str,
+    dst_layout: str,
+    input_shape: List[int],
+    tile_size: ExprRV,
+):
+    """
+    High level tiling for layout transform block.
+    """
+    
+    ## Tiling layout transforms:
+    # Assume we have an input shape of [A, B, C, D] and want to layout transform
+    # ABCD --> DBAC so the output shape would be [D, B, A, C].
+    #
+    # Consider reading from the input buffer in a cache-friendly fashion on CPU. We would
+    # expect a loop structure like:
+    # lAr, lBr, lCr, lDr = T.grid(A, B, C, D)
+    #
+    # Meanwhile consider writing to the output buffer in a cache-friendly fashion on CPU:
+    # lDw, lBw, lAw, lCw = T.grid(D, B, A, C)
+    #
+    # Clearly in many scenarios it is impossible to guarantee contiguous writes and reads
+    # within a single loop due to non-adjacent dimensions. Instead we transpose some small
+    # sub-tensor of our input by writing it to and then reading it from shared memory. We
+    # must construct this sub-tensor so that both reading and writing can be done with
+    # contiguous access in global memory.
+    #
+    # Consider the case of a 2D transpose. For example [1024, 2048] -> [2048, 1024].
+    # Note that if we work on a submatrix of shape [32, 32] whose dimensions correspond
+    # to the dimensions of our input tensor, then rows of the submatrix are contiguous
+    # in the input tensor. Meanwhile, columns of our submatrix are contiguous in our
+    # output tensor. Therefore, with this tile shape we can read contiguously from the
+    # input tensor into shared memory, and then write contiguously from shared memory
+    # to the output tensor.
+    #
+    # The multidimensional case is analogous. We want to allocate shared
+    # memory per block of [`tile_size`, `tile_size`]. We want the inner most dimension
+    # of our shared memory to correspond to contiguous reads from the input tensor and
+    # the outer dimension to correspond to contiguous writes into the output tensor.
+    # 
+    # In terms of the loop structure reading from the input tensor, the inner most loops
+    # of our tile must correspond to the inner most dimensions of the input shape,  
+    # while the outer dimensions correspond to the inner most dimensions of the output shape.
+    # To obtain an inner tile with this loop structure we factor out a contiguous `tile_size`
+    # chunk of our loop in the shape of interest. 
+    # 
+    # An example is probably best to show this idea:
+    # Let's say we want a layout transform of ABCD --> DCAB. With shape
+    # [1024_a, 2_b, 32_c, 8_d] --> [8_d, 32_c, 1024_a, 2_b]
+    #
+    # And tile size 32.
+    #
+    # Then we initially have a coalesced-read loop pattern of:
+    # T.grid(1024_a, 2_b, 32_c, 8_d)
+    # 
+    # To obtain an inner tile of 32, we factor 4 from 32_c and 8 from 8_d:
+    # T.grid(1024_a, 2_b, 8_c1, 1_d1, 4_c2t, 8_d2t)
+    # T.grid(1024_a, 2_b, 8_cr, 1_dr, 32_dim1)
+    #
+    # To obtain an outer tile of 32, we factor from B then A to follow contiguous write 
+    # pattern:
+    #
+    # T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1)
+    # T.grid(64_ar, 1_br, 8_cr, 1_dr, 32_dim0, 32_dim1)
+    #
+    # Which allows us to read a tile with our wanted properties.
+    # For writing we use the existing analysis infrastructure to generate the proper structure for writing.
+
+    def pad_dimension_to_at_least_number(loop: LoopRV, requested_size: int):
+        """E.g. if loop has extent of 8 but we want 10, returns size 10 loop with padding."""
+        l1, l2 = sch.split(loop, [None, requested_size])
+        return sch.fuse(l1, l2)
+
+    def pad_dimension_to_factor_of_tile_size(
+        loop: LoopRV, initial_size: int, tile_size: int = tile_size
+    ) -> Tuple[LoopRV, int]:
+        """
+        Pads loop of given size until its size evenly divides tile_size.
+        If the given size of the loop is at least tile_size, do not pad.
+
+        example, loop_size = 5, tile_size = 32. loop_size --> 8
+                loop_size = 5, tile_size = 36. loop_size --> 6
+                loop_size = 8, tile_size = 32. loop_size --> 8
+                loop_size = 33, tile_size = 32. loop_size --> 33
+
+        Returns padded loopRV and the new size
+        """
+        if tile_size % initial_size == 0:
+            return loop, int(initial_size)
+
+        if initial_size >= tile_size:
+            return loop, int(initial_size)
+
+        # Grow size until it evenly divides tile_size
+        size = initial_size
+        while tile_size % size != 0:
+            size += 1
+
+        return pad_dimension_to_at_least_number(loop, size), int(size)
+
+    def spin_out_factor(
+        loops: List[LoopRV], loop_extants: List[int], index: int, factor_needed: int
+    ) -> Tuple[List[LoopRV], List[int], int]:
+        """
+        Factor out loop dimensions to reach the requested factor. Updates the schedule in-place.
+
+        E.g. say we want factors which eventually multiply to 32 (factor_needed).
+
+        Say the index we chose is a loop with an extent of 8.
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 3
+            - 8 divides into 32 so we just split up the loop into two loops with extants 1 and 8.
+            - we then keep the 1-loop in place and move the new 8-loop to back of the list of loops
+            - ending loops / loop_extants = [3, 32, 6, 1, 8], remaining_factor_needed = 32 / 8 = 4
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed=32, index = 0
+            - 3 does not divide 32, so we pad until the extant divides 32, e.g. 4
+            - we then split up the loop into extants 1 and 4, moving the 4 to the back
+            - ending loops / loop_extants = [1, 32, 6, 8, 4], remaining_factor_needed = 32 / 4 = 8
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed=5, index = 3
+            - 8 is larger than 5 so we immediately do the splitting routine.
+            - the 8 extent loop becomes loops with extents 2 and 5
+            - ending loops / loop_extants = [3, 32, 6, 2, 5], remaining_factor_needed = 5 / 5 = 1
+
+        After updating loop ordering in place, returns the new list of loops, extants, and the
+        remaining factor needed.
+        """
+        cur_loop = loops[index]
+        cur_extant = loop_extants[index]
+
+        # Pad loops to divide evenly for factors needed, and split
+        new_loop, new_size = pad_dimension_to_factor_of_tile_size(
+            cur_loop, cur_extant, tile_size=factor_needed
+        )
+
+        split_factor = min(new_size, factor_needed)
+        new_loop_split, factored_loop = sch.split(new_loop, [None, split_factor])
+        factor_needed = factor_needed // split_factor
+
+        # update caching
+        loops[index] = new_loop_split
+        loops.append(factored_loop)
+
+        loop_extants[index] = math.ceil(new_size / split_factor)
+        loop_extants.append(split_factor)
+
+        sch.reorder(*loops)
+        return loops, loop_extants, factor_needed
+
+    def factor_dim_in_order(
+        indices: Sequence[int],
+        loops: List[LoopRV],
+        cur_loop_extants: List[int],
+        work_needed_inner_loop: int = tile_size,
+    ):
+        """Factors out the loops in the order of indices until we reach needed work.
+        
+        Adds new loop factors to the back in reverse order of access.
+        """
+        for i in indices:
+            loops, cur_loop_extants, work_needed_inner_loop = spin_out_factor(
+                loops, cur_loop_extants, i, work_needed_inner_loop
+            )
+            if work_needed_inner_loop == 1:
+                break
+        return loops, cur_loop_extants
+
+    def get_high_level_loop_structure(block):
+        """Runs the factorization described above."""
+        # index 0 ... rank - 1 will always correspond to original loops
+        # perhaps after they have been factored.
+        loops = sch.get_loops(block)
+        cur_loop_extants = list(input_shape)
+
+        # Factor dim0 tile size and fuse things together
+        loops, cur_loop_extants = factor_dim_in_order(
+            range(rank - 1, -1, -1),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        # The factors which multiply to tile_size are now at the back of our
+        # list of loops. However, because we added them by traversing the inner
+        # dimensions first, they are in reverse order; to guarantee the best access
+        # pattern, reorder them before fusing.
+        loops = loops[:rank] + loops[rank:][::-1]
+        cur_loop_extants = cur_loop_extants[:rank] + cur_loop_extants[rank:][::-1]
+        sch.reorder(*loops)
+        dim0_loop_tiled = sch.fuse(*loops[rank:])
+        loops = loops[:rank]
+        loops.append(dim0_loop_tiled)
+        cur_loop_extants = cur_loop_extants[:rank]
+        cur_loop_extants.append(tile_size)
+
+        # Same thing with dim1
+        # [:rank + 1], since we placed dim0_loop_tiled in the end which we want to keep
+        loops, cur_loop_extants = factor_dim_in_order(
+            (
+                src_layout.index(dst_layout[loop_index_dst])
+                for loop_index_dst in range(rank - 1, -1, -1)
+            ),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        loops = loops[: rank + 1] + loops[rank + 1 :][::-1]
+        cur_loop_extants = cur_loop_extants[: rank + 1] + cur_loop_extants[rank + 1 :][::-1]
+        sch.reorder(*loops)
+        dim1_loop_tiled = sch.fuse(*loops[rank + 1 :])
+        loops = loops[: rank + 1]
+        loops.append(dim1_loop_tiled)
+        cur_loop_extants = cur_loop_extants[: rank + 1]
+        cur_loop_extants.append(tile_size)
+
+    rank = len(src_layout)
+
+    # Outer loop structure of read block matches that of src_layout
+    # E.g. if input_shape is [4, 6, 8]. Loops for read block will be
+    # for i, j, k in T.grid(4, 6, 8):
+    #     ...
+    # Read block will read from global memory coalesced at the start
+    # Assume write to output global memory is coalesced in block_write
+    block_read = sch.cache_read(block_write, 0, "shared")
+
+    # Here we have [loop1, loop2, loop3 ... dim0_tiled, dim1_tiled]
+    get_high_level_loop_structure(block_read)
+    loops = sch.get_loops(block_read)
+
+    # If there are insufficient elements, then dim1_tiled or dim0_tiled might be too small
+    # In all likelihood you should use a smaller tile, but I don't want things to crash.
+    loops[-1] = pad_dimension_to_at_least_number(loops[-1], tile_size)
+    loops[-2] = pad_dimension_to_at_least_number(loops[-2], tile_size)
+
+    # We want the dim0 and dim1 parent loops to be the innermost. Right now dim1 is innermost
+    # and we just need to move dim0 in (last dimension of dst).
+    # Recall right now structure is at least [l1 l2 ... ln, dim0_tiled, dim1_tiled]
+    # where n >= 2.
+    dim0_loop_index = src_layout.index(dst_layout[-1])
+    dim0_loop = loops.pop(dim0_loop_index)
+    loops = loops[:-3] + [dim0_loop, loops[-3]] + loops[-2:]
+    sch.reorder(*loops)
+
+    # After this: [outer_loop (block binding), dim0_tiled, dim1_tiled]
+    outer_loop = sch.fuse(*loops[:-2])
+
+    # Now that we have the high level loop structure, we can use reverse_compute_at magic
+    # To get the proper loop structure for writing! This is also as coalesced as possible
+    # already.
+    sch.reverse_compute_at(block_write, outer_loop)
+
+    # Fuse all inner loops for the write into 2 loops, grab inner loops for both read
+    # and write block which have locality (we will bind these to threadIdx)
+    fused_write_loop = sch.fuse(*sch.get_loops(block_write)[1:])
+    _, inner_write_loop = sch.split(fused_write_loop, [None, tile_size])
+    inner_read_loop = sch.get_loops(block_read)[-2]
+
+    sch.bind(loop=outer_loop, thread_axis="blockIdx.x")
+    sch.bind(loop=inner_write_loop, thread_axis="threadIdx.x")
+    sch.bind(loop=inner_read_loop, thread_axis="threadIdx.x")
+
+def auto_inline(sch, start_block):
+    # Autoinlines given block into consumers, and repeats process for consumer of block
+    # Done by default for injective schedules.
+    fringe = deque([start_block])
+    visited = set()
+    while len(fringe) > 0:
+        cur_block = fringe.popleft()
+        if cur_block in visited:
+            continue
+        else:
+            visited.add(cur_block)
+
+        consumer_blocks = sch.get_consumers(cur_block)
+        if len(consumer_blocks) >= 1:
+            fringe.extend(consumer_blocks)
+            sch.compute_inline(cur_block)
+        else:
+            # Found output block, no more inlining needed
+            return cur_block
+
+
+@tvm.register_func("meta_schedule.cuda.layout_transform")
+def cuda_layout_transform_schedule_rule(sch, block):
+    # params: input_buffer, output_buffer
+    params = sch.mod["main"].params
+    input_buffer = sch.mod["main"].buffer_map[params[0]]
+    output_buffer = sch.mod["main"].buffer_map[params[1]]
+    
+    # Info needed for tiling
+    input_shape = [int(dim) for dim in input_buffer.shape]
+    output_shape = [int(dim) for dim in output_buffer.shape]
+    src_layout = sch.get_sref(block).stmt.annotations["src_layout"]
+    dst_layout = sch.get_sref(block).stmt.annotations["dst_layout"]
+
+    # For each schedule we also want to inline each stage as would be done in normal circumstances
+    # to prevent extraneous memory access.
+    block = auto_inline(sch, block)

Review Comment:
   Is it for the consumer blocks other than `layout_transform` itself? Will the `AutoInline` meta schedule rule be applied automatically without calling this? cc @zxybazh 





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1136006711


##########
python/tvm/meta_schedule/schedule/cuda/layout_transform.py:
##########
@@ -0,0 +1,321 @@
+import tvm
+from tvm import topi
+import math
+from typing import List, Sequence, Tuple
+
+from tvm.tir.schedule import BlockRV, ExprRV, LoopRV
+from collections import deque
+
+
+def tile_layout_transform(
+    sch: tvm.tir.Schedule,
+    block_write: BlockRV,
+    src_layout: str,
+    dst_layout: str,
+    input_shape: List[int],
+    tile_size: ExprRV,
+):
+    """
+    High level tiling for layout transform block.
+    """
+    
+    ## Tiling layout transforms:
+    # Assume we have an input shape of [A, B, C, D] and want to layout transform
+    # ABCD --> DBAC so the output shape would be [D, B, A, C].
+    #
+    # Consider reading from the input buffer in a cache-friendly fashion on CPU. We would
+    # expect a loop structure like:
+    # lAr, lBr, lCr, lDr = T.grid(A, B, C, D)
+    #
+    # Meanwhile consider writing to the output buffer in a cache-friendly fashion on CPU:
+    # lDw, lBw, lAw, lCw = T.grid(D, B, A, C)
+    #
+    # Clearly in many scenarios it is impossible to guarantee contiguous writes and reads
+    # within a single loop due to non-adjacent dimensions. Instead we transpose some small
+    # sub-tensor of our input by writing it to and then reading it from shared memory. We
+    # must construct this sub-tensor so that both reading and writing can be done with
+    # contiguous access in global memory.
+    #
+    # Consider the case of a 2D transpose. For example [1024, 2048] -> [2048, 1024].
+    # Note that if we work on a submatrix of shape [32, 32] whose dimensions correspond
+    # to the dimensions of our input tensor, then rows of the submatrix are contiguous
+    # in the input tensor. Meanwhile, columns of our submatrix are contiguous in our
+    # output tensor. Therefore, with this tile shape we can read contiguously from the
+    # input tensor into shared memory, and then write contiguously from shared memory
+    # to the output tensor.
+    #
+    # The multidimensional case is analogous. We want to allocate shared
+    # memory per block of [`tile_size`, `tile_size`]. We want the inner most dimension
+    # of our shared memory to correspond to contiguous reads from the input tensor and
+    # the outer dimension to correspond to contiguous writes into the output tensor.
+    # 
+    # In terms of the loop structure reading from the input tensor, the inner most loops
+    # of our tile must correspond to the inner most dimensions of the input shape,  
+    # while the outer dimensions correspond to the inner most dimensions of the output shape.
+    # To obtain an inner tile with this loop structure we factor out a contiguous `tile_size`
+    # chunk of our loop in the shape of interest. 
+    # 
+    # An example is probably best to show this idea:
+    # Let's say we want a layout transform of ABCD --> DCAB. With shape
+    # [1024_a, 2_b, 32_c, 8_d] --> [8_d, 32_c, 1024_a, 2_b]
+    #
+    # And tile size 32.
+    #
+    # Then we initially have a coalesced-read loop pattern of:
+    # T.grid(1024_a, 2_b, 32_c, 8_d)
+    # 
+    # To obtain an inner tile of 32, we factor 4 from 32_c and 8 from 8_d:
+    # T.grid(1024_a, 2_b, 8_c1, 1_d1, 4_c2t, 8_d2t)
+    # T.grid(1024_a, 2_b, 8_cr, 1_dr, 32_dim1)
+    #
+    # To obtain an outer tile of 32, we factor from B then A to follow contiguous write 
+    # pattern:
+    #
+    # T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1)
+    # T.grid(64_ar, 1_br, 8_cr, 1_dr, 32_dim0, 32_dim1)
+    #
+    # Which allows us to read a tile with our wanted properties.
+    # For writing we use the existing analysis infrastructure to generate the proper structure for writing.
+
+    def pad_dimension_to_at_least_number(loop: LoopRV, requested_size: int):
+        """E.g. if loop has extent of 8 but we want 10, returns size 10 loop with padding."""
+        l1, l2 = sch.split(loop, [None, requested_size])
+        return sch.fuse(l1, l2)
+
+    def pad_dimension_to_factor_of_tile_size(
+        loop: LoopRV, initial_size: int, tile_size: int = tile_size
+    ) -> Tuple[LoopRV, int]:
+        """
+        Pads loop of given size until its size evenly divides tile_size.
+        If the given size of the loop is at least tile_size, do not pad.
+
+        example, loop_size = 5, tile_size = 32. loop_size --> 8
+                loop_size = 5, tile_size = 36. loop_size --> 6
+                loop_size = 8, tile_size = 32. loop_size --> 8
+                loop_size = 33, tile_size = 32. loop_size --> 33
+
+        Returns padded loopRV and the new size
+        """
+        if tile_size % initial_size == 0:
+            return loop, int(initial_size)
+
+        if initial_size >= tile_size:
+            return loop, int(initial_size)
+
+        # Grow size until it evenly divides tile_size
+        size = initial_size
+        while tile_size % size != 0:
+            size += 1
+
+        return pad_dimension_to_at_least_number(loop, size), int(size)
+
+    def spin_out_factor(
+        loops: List[LoopRV], loop_extants: List[int], index: int, factor_needed: int
+    ) -> Tuple[List[LoopRV], List[int], int]:
+        """
+        Factor out loop dimensions to reach the requested factor. Updates the schedule in-place.
+
+        E.g. say we want factors which eventually multiply to 32 (factor_needed).
+
+        Say the index we chose is a loop with an extent of 8.
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 3
+            - 8 divides into 32 so we just split up the loop into two loops with extants 1 and 8.
+            - we then keep the 1-loop in place and move the new 8-loop to back of the list of loops
+            - ending loops / loop_extants = [3, 32, 6, 1, 8], remaining_factor_needed = 32 / 8 = 4
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed=32, index = 0
+            - 3 does not divide 32, so we pad until the extant divides 32, e.g. 4
+            - we then split up the loop into extants 1 and 4, moving the 4 to the back
+            - ending loops / loop_extants = [1, 32, 6, 8, 4], remaining_factor_needed = 32 / 4 = 8
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed=5, index = 3
+            - 8 is larger than 5 so we immediately do the splitting routine.
+            - the 8 extant loop becomes loops with extants 2 and 5
+            - ending loops / loop_extants = [1, 32, 6, 2, 5], remaining_factor_needed = 5 / 5 = 1
+
+        After updating loop ordering in place, returns the new list of loops, extants, and the
+        remaining factor needed.
+        """
+        cur_loop = loops[index]
+        cur_extant = loop_extants[index]
+
+        # Pad loops to divide evenly for factors needed, and split
+        new_loop, new_size = pad_dimension_to_factor_of_tile_size(
+            cur_loop, cur_extant, tile_size=factor_needed
+        )
+
+        split_factor = min(new_size, factor_needed)
+        new_loop_split, factored_loop = sch.split(new_loop, [None, split_factor])
+        factor_needed = factor_needed // split_factor
+
+        # update caching
+        loops[index] = new_loop_split
+        loops.append(factored_loop)
+
+        loop_extants[index] = math.ceil(new_size / split_factor)
+        loop_extants.append(split_factor)
+
+        sch.reorder(*loops)
+        return loops, loop_extants, factor_needed
+
+    def factor_dim_in_order(
+        indices: Sequence[int],
+        loops: List[LoopRV],
+        cur_loop_extants: List[int],
+        work_needed_inner_loop: int = tile_size,
+    ):
+        """Factors out the loops in the order of indices until we reach needed work.
+        
+        Adds new loop factors to the back in reverse order of access.
+        """
+        for i in indices:
+            loops, cur_loop_extants, work_needed_inner_loop = spin_out_factor(
+                loops, cur_loop_extants, i, work_needed_inner_loop
+            )
+            if work_needed_inner_loop == 1:
+                break
+        return loops, cur_loop_extants
+
+    def get_high_level_loop_structure(block):
+        """Runs the factorization described above."""
+        # index 0 ... rank - 1 will always correspond to original loops
+        # perhaps after they have been factored.
+        loops = sch.get_loops(block)
+        cur_loop_extants = list(input_shape)
+
+        # Factor dim0 tile size and fuse things together
+        loops, cur_loop_extants = factor_dim_in_order(
+            range(rank - 1, -1, -1),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        # The factors which multiply to tile_size are now at the back of our
+        # list of loops. However, because we added them by traversing the inner
+        # dimensions first, they are in reverse order; to guarantee the best access
+        # pattern, reorder them before fusing.
+        loops = loops[:rank] + loops[rank:][::-1]
+        cur_loop_extants = cur_loop_extants[:rank] + cur_loop_extants[rank:][::-1]
+        sch.reorder(*loops)
+        dim0_loop_tiled = sch.fuse(*loops[rank:])
+        loops = loops[:rank]
+        loops.append(dim0_loop_tiled)
+        cur_loop_extants = cur_loop_extants[:rank]
+        cur_loop_extants.append(tile_size)
+
+        # Same thing with dim1
+        # [:rank + 1], since we placed dim0_loop_tiled at the end, which we want to keep
+        loops, cur_loop_extants = factor_dim_in_order(
+            (
+                src_layout.index(dst_layout[loop_index_dst])
+                for loop_index_dst in range(rank - 1, -1, -1)
+            ),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        loops = loops[: rank + 1] + loops[rank + 1 :][::-1]
+        cur_loop_extants = cur_loop_extants[: rank + 1] + cur_loop_extants[rank + 1 :][::-1]
+        sch.reorder(*loops)
+        dim1_loop_tiled = sch.fuse(*loops[rank + 1 :])
+        loops = loops[: rank + 1]
+        loops.append(dim1_loop_tiled)
+        cur_loop_extants = cur_loop_extants[: rank + 1]
+        cur_loop_extants.append(tile_size)
+
+    rank = len(src_layout)
+
+    # Outer loop structure of read block matches that of src_layout
+    # E.g. if input_shape is [4, 6, 8]. Loops for read block will be
+    # for i, j, k in T.grid(4, 6, 8):
+    #     ...
+    # Read block will read from global memory coalesced at the start
+    # Assume write to output global memory is coalesced in block_write
+    block_read = sch.cache_read(block_write, 0, "shared")
+
+    # Here we have [loop1, loop2, loop3 ... dim0_tiled, dim1_tiled]
+    get_high_level_loop_structure(block_read)
+    loops = sch.get_loops(block_read)
+
+    # If there are insufficient elements, then dim1_tiled or dim0_tiled might be too small.
+    # In all likelihood you should use a smaller tile, but I don't want things to crash.
+    loops[-1] = pad_dimension_to_at_least_number(loops[-1], tile_size)
+    loops[-2] = pad_dimension_to_at_least_number(loops[-2], tile_size)
+
+    # We want the dim0 and dim1 parent loops to be the innermost. Right now dim1 is innermost
+    # and we just need to move dim0 in (last dimension of dst).
+    # Recall that the structure is currently at least [l1 l2 ... ln, dim0_tiled, dim1_tiled],
+    # where n >= 2.
+    dim0_loop_index = src_layout.index(dst_layout[-1])
+    dim0_loop = loops.pop(dim0_loop_index)
+    loops = loops[:-3] + [dim0_loop, loops[-3]] + loops[-2:]
+    sch.reorder(*loops)
+
+    # After this: [outer_loop (block binding), dim0_tiled, dim1_tiled]
+    outer_loop = sch.fuse(*loops[:-2])
+
+    # Now that we have the high-level loop structure, we can use reverse_compute_at
+    # to get the proper loop structure for writing. This is already as coalesced as
+    # possible.
+    sch.reverse_compute_at(block_write, outer_loop)
+
+    # Fuse all inner loops of the write block and split the result into 2 loops, then grab
+    # the inner loops of the read and write blocks, which have locality (bound to threadIdx).
+    fused_write_loop = sch.fuse(*sch.get_loops(block_write)[1:])
+    _, inner_write_loop = sch.split(fused_write_loop, [None, tile_size])
+    inner_read_loop = sch.get_loops(block_read)[-2]
+
+    sch.bind(loop=outer_loop, thread_axis="blockIdx.x")
+    sch.bind(loop=inner_write_loop, thread_axis="threadIdx.x")
+    sch.bind(loop=inner_read_loop, thread_axis="threadIdx.x")
+
+def auto_inline(sch, start_block):
+    # Inlines the given block into its consumers, then repeats the process for each consumer.
+    # This is done by default for injective schedules.
+    fringe = deque([start_block])
+    visited = set()
+    while len(fringe) > 0:
+        cur_block = fringe.popleft()
+        if cur_block in visited:
+            continue
+        else:
+            visited.add(cur_block)
+
+        consumer_blocks = sch.get_consumers(cur_block)
+        if len(consumer_blocks) >= 1:
+            fringe.extend(consumer_blocks)
+            sch.compute_inline(cur_block)
+        else:
+            # Found output block, no more inlining needed
+            return cur_block
+
+
+@tvm.register_func("meta_schedule.cuda.layout_transform")
+def cuda_layout_transform_schedule_rule(sch, block):
+    # params: input_buffer, output_buffer
+    params = sch.mod["main"].params
+    input_buffer = sch.mod["main"].buffer_map[params[0]]
+    output_buffer = sch.mod["main"].buffer_map[params[1]]
+    
+    # Info needed for tiling
+    input_shape = [int(dim) for dim in input_buffer.shape]
+    output_shape = [int(dim) for dim in output_buffer.shape]
+    src_layout = sch.get_sref(block).stmt.annotations["src_layout"]
+    dst_layout = sch.get_sref(block).stmt.annotations["dst_layout"]
+
+    # For each schedule we also want to inline each stage as would be done in normal circumstances
+    # to prevent extraneous memory access.
+    block = auto_inline(sch, block)

Review Comment:
   It appears that, when generating the design space, new blocks created in the schedule may not have rules applied to them. Not sure how to handle this.
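The `auto_inline` helper quoted above walks the consumer graph breadth-first: every block that still has consumers is inlined, and the first block without consumers is returned as the output block. As a minimal sketch, the traversal can be replayed on a plain dictionary instead of a `tvm.tir.Schedule`. The function name `auto_inline_order`, the block names, and the `consumers` mapping are hypothetical; no TVM calls are made, and appending to `inlined` stands in for `sch.compute_inline`.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple


def auto_inline_order(
    consumers: Dict[str, List[str]], start_block: str
) -> Tuple[List[str], Optional[str]]:
    """Replays auto_inline's traversal on a toy (hypothetical) consumer graph.

    Returns the blocks that would be passed to compute_inline, in order,
    and the block without consumers at which the traversal stops.
    """
    inlined: List[str] = []
    fringe = deque([start_block])
    visited = set()
    while fringe:
        cur_block = fringe.popleft()
        if cur_block in visited:
            continue
        visited.add(cur_block)

        consumer_blocks = consumers.get(cur_block, [])
        if consumer_blocks:
            # Stand-in for sch.compute_inline(cur_block)
            fringe.extend(consumer_blocks)
            inlined.append(cur_block)
        else:
            # Found an output block; no more inlining needed
            return inlined, cur_block
    # Fringe exhausted without finding a block that has no consumers
    return inlined, None


# Hypothetical fused chain: layout transform -> add -> relu (output)
graph = {"T_layout_trans": ["T_add"], "T_add": ["T_relu"], "T_relu": []}
print(auto_inline_order(graph, "T_layout_trans"))
# (['T_layout_trans', 'T_add'], 'T_relu')
```

As in the quoted code, the walk stops at the first block without consumers, so if several output blocks are reachable, only the first one found is returned.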



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1136006711


##########
python/tvm/meta_schedule/schedule/cuda/layout_transform.py:
##########
@@ -0,0 +1,321 @@
+import tvm
+from tvm import topi
+import math
+from typing import List, Sequence, Tuple
+
+from tvm.tir.schedule import BlockRV, ExprRV, LoopRV
+from collections import deque
+
+
+def tile_layout_transform(
+    sch: tvm.tir.Schedule,
+    block_write: BlockRV,
+    src_layout: str,
+    dst_layout: str,
+    input_shape: List[int],
+    tile_size: ExprRV,
+):
+    """
+    High level tiling for layout transform block.
+    """
+    
+    ## Tiling layout transforms:
+    # Assume we have an input shape of [A, B, C, D] and want to layout transform
+    # ABCD --> DBAC so the output shape would be [D, B, A, C].
+    #
+    # Consider reading from the input buffer in a cache-friendly fashion on CPU. We would
+    # expect a loop structure like:
+    # lAr, lBr, lCr, lDr = T.grid(A, B, C, D)
+    #
+    # Meanwhile consider writing to the output buffer in a cache-friendly fashion on CPU:
+    # lDw, lBw, lAw, lCw = T.grid(D, B, A, C)
+    #
+    # Clearly, in many scenarios it is impossible to guarantee both contiguous writes and reads
+    # within a single loop nest, due to non-adjacent dimensions. Instead, we transpose some
+    # small sub-tensor of our input by writing it to and then reading it from shared memory.
+    # We must construct this sub-tensor so that reading and writing can both be done with
+    # some contiguous access in global memory.
+    #
+    # Consider the case of a 2D transpose, for example [1024, 2048] -> [2048, 1024].
+    # If we work with a [32, 32] submatrix whose dimensions follow those of the input
+    # tensor, then rows of the submatrix are contiguous in the input tensor, while
+    # columns of the submatrix are contiguous in the output tensor. Therefore, with
+    # this tile shape we can read contiguously from the input tensor into shared memory,
+    # and then write contiguously from shared memory to the output tensor.
+    #
+    # The multidimensional case is analogous. We want to allocate shared memory of
+    # shape [`tile_size`, `tile_size`] per block. We want the innermost dimension
+    # of our shared memory to correspond to contiguous reads from the input tensor and
+    # the outer dimension to correspond to contiguous writes into the output tensor.
+    #
+    # In terms of the loop structure that reads from the input tensor, the innermost loops
+    # of our tile must correspond to the innermost dimensions of the input shape,
+    # while the outer loops of the tile correspond to the innermost dimensions of the output shape.
+    # To obtain a tile with this loop structure, we factor a contiguous `tile_size`
+    # chunk out of the loops over the dimensions of interest.
+    # 
+    # An example is probably the best way to show this idea.
+    # Let's say we want a layout transform of ABCD --> DCAB, with shapes
+    # [1024_a, 2_b, 32_c, 8_d] --> [8_d, 32_c, 1024_a, 2_b]
+    # and tile size 32.
+    #
+    # Then we initially have a coalesced-read loop pattern of:
+    # T.grid(1024_a, 2_b, 32_c, 8_d)
+    # 
+    # To obtain an inner tile of 32, we factor 4 from 32_c and 8 from 8_d:
+    # T.grid(1024_a, 2_b, 8_c1, 1_d1, 4_c2t, 8_d2t)
+    # T.grid(1024_a, 2_b, 8_cr, 1_dr, 32_dim1)
+    #
+    # To obtain an outer tile of 32, we factor from B and then A to follow the contiguous
+    # write pattern:
+    #
+    # T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1)
+    # T.grid(64_ar, 1_br, 8_cr, 1_dr, 32_dim0, 32_dim1)
+    #
+    # This lets us read a tile with the desired properties.
+    # For writing, we use the existing analysis infrastructure to generate the proper loop structure.
+
+    def pad_dimension_to_at_least_number(loop: LoopRV, requested_size: int):
+        """E.g. if the loop has extent 8 but we want 10, returns a padded loop of extent 10."""
+        l1, l2 = sch.split(loop, [None, requested_size])
+        return sch.fuse(l1, l2)
+
+    def pad_dimension_to_factor_of_tile_size(
+        loop: LoopRV, initial_size: int, tile_size: int = tile_size
+    ) -> Tuple[LoopRV, int]:
+        """
+        Pads the loop of the given size until its size evenly divides tile_size.
+        If the loop size is greater than or equal to tile_size, do not pad.
+
+        Examples:
+            loop_size = 5, tile_size = 32: loop_size --> 8
+            loop_size = 5, tile_size = 36: loop_size --> 6
+            loop_size = 8, tile_size = 32: loop_size --> 8
+            loop_size = 33, tile_size = 32: loop_size --> 33
+
+        Returns padded loopRV and the new size
+        """
+        if tile_size % initial_size == 0:
+            return loop, int(initial_size)
+
+        if initial_size >= tile_size:
+            return loop, int(initial_size)
+
+        # Otherwise, grow the size until it evenly divides tile_size
+        size = initial_size
+        while tile_size % size != 0:
+            size += 1
+
+        return pad_dimension_to_at_least_number(loop, size), int(size)
+
+    def spin_out_factor(
+        loops: List[LoopRV], loop_extants: List[int], index: int, factor_needed: int
+    ) -> Tuple[List[LoopRV], List[int], int]:
+        """
+        Factor out loop dimensions to reach the requested factor. Updates the schedule in-place.
+
+        E.g. say we want factors which eventually multiply to 32 (factor_needed).
+
+        Say the loop at the chosen index has an extent of 8.
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 3
+            - 8 divides 32, so we just split the loop into two loops with extents 1 and 8.
+            - we then keep the 1-loop in place and move the new 8-loop to the back of the list of loops
+            - ending loops / loop_extants = [3, 32, 6, 1, 8], remaining_factor_needed = 32 / 8 = 4
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 0
+            - 3 does not divide 32, so we pad until the extent divides 32, i.e. 4
+            - we then split the loop into extents 1 and 4, moving the 4 to the back
+            - ending loops / loop_extants = [1, 32, 6, 8, 4], remaining_factor_needed = 32 / 4 = 8
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 5, index = 3
+            - 8 is larger than 5, so we split immediately without padding.
+            - the loop of extent 8 becomes loops with extents 2 (= ceil(8 / 5)) and 5
+            - ending loops / loop_extants = [3, 32, 6, 2, 5], remaining_factor_needed = 5 / 5 = 1
+
+        After updating loop ordering in place, returns the new list of loops, extants, and the
+        remaining factor needed.
+        """
+        cur_loop = loops[index]
+        cur_extant = loop_extants[index]
+
+        # Pad loops to divide evenly for factors needed, and split
+        new_loop, new_size = pad_dimension_to_factor_of_tile_size(
+            cur_loop, cur_extant, tile_size=factor_needed
+        )
+
+        split_factor = min(new_size, factor_needed)
+        new_loop_split, factored_loop = sch.split(new_loop, [None, split_factor])
+        factor_needed = factor_needed // split_factor
+
+        # update caching
+        loops[index] = new_loop_split
+        loops.append(factored_loop)
+
+        loop_extants[index] = math.ceil(new_size / split_factor)
+        loop_extants.append(split_factor)
+
+        sch.reorder(*loops)
+        return loops, loop_extants, factor_needed
+
+    def factor_dim_in_order(
+        indices: Sequence[int],
+        loops: List[LoopRV],
+        cur_loop_extants: List[int],
+        work_needed_inner_loop: int = tile_size,
+    ):
+        """Factors out the loops in the order of indices until we reach needed work.
+        
+        Adds new loop factors to the back in reverse order of access.
+        """
+        for i in indices:
+            loops, cur_loop_extants, work_needed_inner_loop = spin_out_factor(
+                loops, cur_loop_extants, i, work_needed_inner_loop
+            )
+            if work_needed_inner_loop == 1:
+                break
+        return loops, cur_loop_extants
+
+    def get_high_level_loop_structure(block):
+        """Runs the factorization described above."""
+        # index 0 ... rank - 1 will always correspond to original loops
+        # perhaps after they have been factored.
+        loops = sch.get_loops(block)
+        cur_loop_extants = list(input_shape)
+
+        # Factor dim0 tile size and fuse things together
+        loops, cur_loop_extants = factor_dim_in_order(
+            range(rank - 1, -1, -1),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        # The factors which multiply to tile_size are now at the back of our
+        # list of loops. However, because we added them by traversing the inner
+        # dimensions first, they are in reverse order; to guarantee the best access
+        # pattern, reorder them before fusing.
+        loops = loops[:rank] + loops[rank:][::-1]
+        cur_loop_extants = cur_loop_extants[:rank] + cur_loop_extants[rank:][::-1]
+        sch.reorder(*loops)
+        dim0_loop_tiled = sch.fuse(*loops[rank:])
+        loops = loops[:rank]
+        loops.append(dim0_loop_tiled)
+        cur_loop_extants = cur_loop_extants[:rank]
+        cur_loop_extants.append(tile_size)
+
+        # Same thing with dim1
+        # [:rank + 1], since we placed dim0_loop_tiled at the end, which we want to keep
+        loops, cur_loop_extants = factor_dim_in_order(
+            (
+                src_layout.index(dst_layout[loop_index_dst])
+                for loop_index_dst in range(rank - 1, -1, -1)
+            ),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        loops = loops[: rank + 1] + loops[rank + 1 :][::-1]
+        cur_loop_extants = cur_loop_extants[: rank + 1] + cur_loop_extants[rank + 1 :][::-1]
+        sch.reorder(*loops)
+        dim1_loop_tiled = sch.fuse(*loops[rank + 1 :])
+        loops = loops[: rank + 1]
+        loops.append(dim1_loop_tiled)
+        cur_loop_extants = cur_loop_extants[: rank + 1]
+        cur_loop_extants.append(tile_size)
+
+    rank = len(src_layout)
+
+    # Outer loop structure of read block matches that of src_layout
+    # E.g. if input_shape is [4, 6, 8]. Loops for read block will be
+    # for i, j, k in T.grid(4, 6, 8):
+    #     ...
+    # Read block will read from global memory coalesced at the start
+    # Assume write to output global memory is coalesced in block_write
+    block_read = sch.cache_read(block_write, 0, "shared")
+
+    # Here we have [loop1, loop2, loop3 ... dim0_tiled, dim1_tiled]
+    get_high_level_loop_structure(block_read)
+    loops = sch.get_loops(block_read)
+
+    # If there are insufficient elements, then dim1_tiled or dim0_tiled might be too small.
+    # In all likelihood you should use a smaller tile, but I don't want things to crash.
+    loops[-1] = pad_dimension_to_at_least_number(loops[-1], tile_size)
+    loops[-2] = pad_dimension_to_at_least_number(loops[-2], tile_size)
+
+    # We want the dim0 and dim1 parent loops to be the innermost. Right now dim1 is innermost
+    # and we just need to move dim0 in (last dimension of dst).
+    # Recall that the structure is currently at least [l1 l2 ... ln, dim0_tiled, dim1_tiled],
+    # where n >= 2.
+    dim0_loop_index = src_layout.index(dst_layout[-1])
+    dim0_loop = loops.pop(dim0_loop_index)
+    loops = loops[:-3] + [dim0_loop, loops[-3]] + loops[-2:]
+    sch.reorder(*loops)
+
+    # After this: [outer_loop (block binding), dim0_tiled, dim1_tiled]
+    outer_loop = sch.fuse(*loops[:-2])
+
+    # Now that we have the high-level loop structure, we can use reverse_compute_at
+    # to get the proper loop structure for writing. This is already as coalesced as
+    # possible.
+    sch.reverse_compute_at(block_write, outer_loop)
+
+    # Fuse all inner loops of the write block and split the result into 2 loops, then grab
+    # the inner loops of the read and write blocks, which have locality (bound to threadIdx).
+    fused_write_loop = sch.fuse(*sch.get_loops(block_write)[1:])
+    _, inner_write_loop = sch.split(fused_write_loop, [None, tile_size])
+    inner_read_loop = sch.get_loops(block_read)[-2]
+
+    sch.bind(loop=outer_loop, thread_axis="blockIdx.x")
+    sch.bind(loop=inner_write_loop, thread_axis="threadIdx.x")
+    sch.bind(loop=inner_read_loop, thread_axis="threadIdx.x")
+
+def auto_inline(sch, start_block):
+    # Inlines the given block into its consumers, then repeats the process for each consumer.
+    # This is done by default for injective schedules.
+    fringe = deque([start_block])
+    visited = set()
+    while len(fringe) > 0:
+        cur_block = fringe.popleft()
+        if cur_block in visited:
+            continue
+        else:
+            visited.add(cur_block)
+
+        consumer_blocks = sch.get_consumers(cur_block)
+        if len(consumer_blocks) >= 1:
+            fringe.extend(consumer_blocks)
+            sch.compute_inline(cur_block)
+        else:
+            # Found output block, no more inlining needed
+            return cur_block
+
+
+@tvm.register_func("meta_schedule.cuda.layout_transform")
+def cuda_layout_transform_schedule_rule(sch, block):
+    # params: input_buffer, output_buffer
+    params = sch.mod["main"].params
+    input_buffer = sch.mod["main"].buffer_map[params[0]]
+    output_buffer = sch.mod["main"].buffer_map[params[1]]
+    
+    # Info needed for tiling
+    input_shape = [int(dim) for dim in input_buffer.shape]
+    output_shape = [int(dim) for dim in output_buffer.shape]
+    src_layout = sch.get_sref(block).stmt.annotations["src_layout"]
+    dst_layout = sch.get_sref(block).stmt.annotations["dst_layout"]
+
+    # For each schedule we also want to inline each stage as would be done in normal circumstances
+    # to prevent extraneous memory access.
+    block = auto_inline(sch, block)

Review Comment:
   It appears that, when generating the design space, new blocks created in the schedule may not have rules applied to them. I'm investigating.
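The 2D-transpose argument in the `tile_layout_transform` comment quoted above (read rows of a tile, write its columns) can be checked with a CPU-side simulation. In this minimal sketch (plain Python; the name `tiled_transpose` is hypothetical and this is not the PR's implementation), each `(bi, bj)` tile plays the role of one thread block, `shared` stands in for the `[tile_size, tile_size]` shared-memory buffer, and `tx` stands in for `threadIdx.x`. Edge tiles are not handled.

```python
from typing import List


def tiled_transpose(src: List[List[int]], tile: int = 32) -> List[List[int]]:
    """Transposes `src` one [tile, tile] block at a time via a staging buffer."""
    rows, cols = len(src), len(src[0])
    # Sketch only: assume both dims are multiples of the tile size.
    assert rows % tile == 0 and cols % tile == 0
    dst = [[0] * rows for _ in range(cols)]
    for bi in range(rows // tile):  # each (bi, bj) pair ~ one thread block
        for bj in range(cols // tile):
            shared = [[0] * tile for _ in range(tile)]  # ~ shared memory
            # Read phase: for tile row r, lanes tx = 0..tile-1 load consecutive
            # elements of one src row, so the global reads are contiguous.
            for r in range(tile):
                for tx in range(tile):
                    shared[r][tx] = src[bi * tile + r][bj * tile + tx]
            # Write phase: for dst row r, lane tx takes column r of the staged
            # tile and stores consecutive elements of one dst row, so the
            # global writes are contiguous as well.
            for r in range(tile):
                for tx in range(tile):
                    dst[bj * tile + r][bi * tile + tx] = shared[tx][r]
    return dst


m = [[i * 8 + j for j in range(8)] for i in range(4)]  # 4 x 8 input
t = tiled_transpose(m, tile=4)
print(len(t), len(t[0]))  # 8 4
print(t[1])  # [1, 9, 17, 25]
```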

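The extent bookkeeping done by `pad_dimension_to_factor_of_tile_size`, `spin_out_factor`, and the two `factor_dim_in_order` passes in `get_high_level_loop_structure` can be replayed on plain integers. This sketch tracks only loop extents (no `LoopRV`s or schedule primitives), the helper names are hypothetical, and it omits the final step where each tile is padded up to `tile_size` when there are too few elements. It assumes Python 3.8+ for `math.prod`.

```python
import math
from typing import List, Sequence, Tuple


def padded_extent(initial_size: int, tile_size: int) -> int:
    """Smallest extent >= initial_size that evenly divides tile_size.

    Extents that already divide tile_size, or are >= tile_size, are unchanged.
    """
    if initial_size >= tile_size or tile_size % initial_size == 0:
        return initial_size
    size = initial_size
    while tile_size % size != 0:
        size += 1
    return size


def spin_out_factor(
    extents: List[int], index: int, factor_needed: int
) -> Tuple[List[int], int]:
    """Splits extents[index] into (outer, factor); the factor goes to the back."""
    new_size = padded_extent(extents[index], factor_needed)
    split_factor = min(new_size, factor_needed)
    extents = list(extents)
    extents[index] = math.ceil(new_size / split_factor)
    extents.append(split_factor)
    return extents, factor_needed // split_factor


def factor_tile(
    extents: List[int], indices: Sequence[int], tile_size: int
) -> Tuple[List[int], int]:
    """Peels factors off the dims in `indices` order until they multiply to
    tile_size, then "fuses" the peeled factors into a single tile extent."""
    rank = len(extents)
    needed = tile_size
    for i in indices:
        extents, needed = spin_out_factor(extents, i, needed)
        if needed == 1:
            break
    return extents[:rank], math.prod(extents[rank:])


def tile_extents(
    shape: List[int], src_layout: str, dst_layout: str, tile_size: int
) -> Tuple[List[int], int, int]:
    rank = len(shape)
    # First pass: innermost dims of the source layout (coalesced reads)
    outer, src_tile = factor_tile(list(shape), range(rank - 1, -1, -1), tile_size)
    # Second pass: innermost dims of the destination layout, mapped back to
    # source-loop positions (coalesced writes)
    dst_order = [src_layout.index(dst_layout[i]) for i in range(rank - 1, -1, -1)]
    outer, dst_tile = factor_tile(outer, dst_order, tile_size)
    return outer, src_tile, dst_tile


print(spin_out_factor([3, 32, 6, 8], 3, 32))  # ([3, 32, 6, 1, 8], 4)
print(spin_out_factor([3, 32, 6, 8], 0, 32))  # ([1, 32, 6, 8, 4], 8)
print(spin_out_factor([3, 32, 6, 8], 3, 5))   # ([3, 32, 6, 2, 5], 1)
print(tile_extents([1024, 2, 32, 8], "ABCD", "DCAB", 32))  # ([64, 1, 8, 1], 32, 32)
```

The last call replays the ABCD --> DCAB walkthrough from the comment: the outer extents end up as [64, 1, 8, 1], plus two tiles of extent 32.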




[GitHub] [tvm] tkonolige commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "tkonolige (via GitHub)" <gi...@apache.org>.
tkonolige commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1139431118


##########
src/meta_schedule/postproc/rewrite_cooperative_fetch.cc:
##########
@@ -39,7 +39,17 @@ Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& ins
   if (thread_axis != axis) {
     return NullOpt;
   }
-  return Downcast<Integer>(sch->Get(Downcast<LoopRV>(inst->inputs[0]))->extent);
+
+  try {
+    return Downcast<Integer>(sch->Get(Downcast<LoopRV>(inst->inputs[0]))->extent);
+  } catch (const std::exception& e) {

Review Comment:
   Can you be a little more specific about which exception you are catching here? It should probably be at least a `TVMError`.





[GitHub] [tvm] tkonolige commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "tkonolige (via GitHub)" <gi...@apache.org>.
tkonolige commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1126835722


##########
python/tvm/meta_schedule/schedule/cuda/layout_transform.py:
##########
@@ -0,0 +1,322 @@
+import tvm
+from tvm import topi
+import math
+from typing import List, Sequence, Tuple
+
+from tvm.tir.schedule import BlockRV, ExprRV, LoopRV
+from collections import deque
+
+
+def tile_layout_transform(
+    sch: tvm.tir.Schedule,
+    block_write: BlockRV,
+    src_layout: str,
+    dst_layout: str,
+    input_shape: List[int],
+    tile_size: ExprRV,
+):
+    """
+    High level tiling for layout transform block.
+    """
+
+    ## Tiling layout transforms:
+    # Assume we have an input shape of [A, B, C, D] and want to layout transform
+    # ABCD --> DBAC so the output shape would be [D, B, A, C].
+    #
+    # Consider reading from the input buffer in a cache-friendly fashion on CPU. We would
+    # expect a loop structure like:
+    # lAr, lBr, lCr, lDr = T.grid(A, B, C, D)
+    #
+    # Meanwhile consider writing to the output buffer in a cache-friendly fashion on CPU:
+    # lDw, lBw, lAw, lCw = T.grid(D, B, A, C)
+    #
+    # Clearly, in many scenarios it is impossible to guarantee both contiguous writes and reads
+    # within a single loop nest, due to non-adjacent dimensions. Instead, we transpose some
+    # small sub-tensor of our input by writing it to and then reading it from shared memory.
+    # We must construct this sub-tensor so that reading and writing can both be done with
+    # some contiguous access in global memory.
+    #
+    # Consider the case of a 2D transpose, for example [1024, 2048] -> [2048, 1024].
+    # If we work with a [32, 32] submatrix whose dimensions follow those of the input
+    # tensor, then rows of the submatrix are contiguous in the input tensor, while
+    # columns of the submatrix are contiguous in the output tensor. Therefore, with
+    # this tile shape we can read contiguously from the input tensor into shared memory,
+    # and then write contiguously from shared memory to the output tensor.
+    #
+    # The multidimensional case is analogous. We want to allocate shared memory of
+    # shape [`tile_size`, `tile_size`] per block. We want the innermost dimension
+    # of our shared memory to correspond to contiguous reads from the input tensor and
+    # the outer dimension to correspond to contiguous writes into the output tensor.
+    #
+    # In terms of the loop structure that reads from the input tensor, the innermost loops
+    # of our tile must correspond to the innermost dimensions of the input shape,
+    # while the outer loops of the tile correspond to the innermost dimensions of the output shape.
+    # To obtain a tile with this loop structure, we factor a contiguous `tile_size`
+    # chunk out of the loops over the dimensions of interest.
+    #
+    # An example is probably the best way to show this idea.
+    # Let's say we want a layout transform of ABCD --> DCAB, with shapes
+    # [1024_a, 2_b, 32_c, 8_d] --> [8_d, 32_c, 1024_a, 2_b]
+    # and tile size 32.
+    #
+    # Then we initially have a coalesced-read loop pattern of:
+    # T.grid(1024_a, 2_b, 32_c, 8_d)
+    #
+    # To obtain an inner tile of 32, we factor 4 from 32_c and 8 from 8_d:
+    # T.grid(1024_a, 2_b, 8_c1, 1_d1, 4_c2t, 8_d2t)
+    # T.grid(1024_a, 2_b, 8_cr, 1_dr, 32_dim1)
+    #
+    # To obtain an outer tile of 32, we factor from B and then A to follow the contiguous
+    # write pattern:
+    #
+    # T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1)
+    # T.grid(64_ar, 1_br, 8_cr, 1_dr, 32_dim0, 32_dim1)
+    #
+    # This lets us read a tile with the desired properties.
+    # For writing, we use the existing analysis infrastructure to generate the proper loop structure.
+
+    def pad_dimension_to_at_least_number(loop: LoopRV, requested_size: int):
+        """E.g. if the loop has extent 8 but we want 10, returns a padded loop of extent 10."""
+        l1, l2 = sch.split(loop, [None, requested_size])
+        return sch.fuse(l1, l2)
+
+    def pad_dimension_to_factor_of_tile_size(
+        loop: LoopRV, initial_size: int, tile_size: int = tile_size
+    ) -> Tuple[LoopRV, int]:
+        """
+        Pads the loop of the given size until its size evenly divides tile_size.
+        If the loop size is greater than or equal to tile_size, do not pad.
+
+        Examples:
+            loop_size = 5, tile_size = 32: loop_size --> 8
+            loop_size = 5, tile_size = 36: loop_size --> 6
+            loop_size = 8, tile_size = 32: loop_size --> 8
+            loop_size = 33, tile_size = 32: loop_size --> 33
+
+        Returns padded loopRV and the new size
+        """
+        if tile_size % initial_size == 0:
+            return loop, int(initial_size)
+
+        if initial_size >= tile_size:
+            return loop, int(initial_size)
+
+        # Otherwise, grow the size until it evenly divides tile_size
+        size = initial_size
+        while tile_size % size != 0:
+            size += 1
+
+        return pad_dimension_to_at_least_number(loop, size), int(size)
+
+    def spin_out_factor(
+        loops: List[LoopRV], loop_extants: List[int], index: int, factor_needed: int
+    ) -> Tuple[List[LoopRV], List[int], int]:
+        """
+        Factor out loop dimensions to reach the requested factor. Updates the schedule in-place.
+
+        E.g. say we want factors which eventually multiply to 32 (factor_needed).
+
+        Say the loop at the chosen index has an extent of 8.
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 3
+            - 8 divides 32, so we just split the loop into two loops with extents 1 and 8.
+            - we then keep the 1-loop in place and move the new 8-loop to the back of the list of loops
+            - ending loops / loop_extants = [3, 32, 6, 1, 8], remaining_factor_needed = 32 / 8 = 4
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 0
+            - 3 does not divide 32, so we pad until the extent divides 32, i.e. 4
+            - we then split the loop into extents 1 and 4, moving the 4 to the back
+            - ending loops / loop_extants = [1, 32, 6, 8, 4], remaining_factor_needed = 32 / 4 = 8
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 5, index = 3
+            - 8 is larger than 5, so we split immediately without padding.
+            - the loop of extent 8 becomes loops with extents 2 (= ceil(8 / 5)) and 5
+            - ending loops / loop_extants = [3, 32, 6, 2, 5], remaining_factor_needed = 5 / 5 = 1
+
+        After updating loop ordering in place, returns the new list of loops, extants, and the
+        remaining factor needed.
+        """
+        cur_loop = loops[index]
+        cur_extant = loop_extants[index]
+
+        # Pad loops to divide evenly for factors needed, and split
+        new_loop, new_size = pad_dimension_to_factor_of_tile_size(
+            cur_loop, cur_extant, tile_size=factor_needed
+        )
+
+        split_factor = min(new_size, factor_needed)
+        new_loop_split, factored_loop = sch.split(new_loop, [None, split_factor])
+        factor_needed = factor_needed // split_factor
+
+        # update caching
+        loops[index] = new_loop_split
+        loops.append(factored_loop)
+
+        loop_extants[index] = math.ceil(new_size / split_factor)
+        loop_extants.append(split_factor)
+
+        sch.reorder(*loops)
+        return loops, loop_extants, factor_needed
+
+    def factor_dim_in_order(
+        indices: Sequence[int],
+        loops: List[LoopRV],
+        cur_loop_extants: List[int],
+        work_needed_inner_loop: int = tile_size,
+    ):
+        """Factors out the loops in the order of indices until we reach needed work.
+
+        Adds new loop factors to the back in reverse order of access.
+        """
+        for i in indices:
+            loops, cur_loop_extants, work_needed_inner_loop = spin_out_factor(
+                loops, cur_loop_extants, i, work_needed_inner_loop
+            )
+            if work_needed_inner_loop == 1:
+                break
+        return loops, cur_loop_extants
+
+    def get_high_level_loop_structure(block):
+        """Runs the factorization described above."""
+        # index 0 ... rank - 1 will always correspond to original loops
+        # perhaps after they have been factored.
+        loops = sch.get_loops(block)
+        cur_loop_extants = list(input_shape)
+
+        # Factor dim0 tile size and fuse things together
+        loops, cur_loop_extants = factor_dim_in_order(
+            range(rank - 1, -1, -1),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        # The factors which multiply to tile_size are now at the back of our
+        # list of loops. However, because we added them by traversing the inner
+        # dimensions first, they are in reverse order, so reorder them to guarantee
+        # the best access pattern before fusing.
+        loops = loops[:rank] + loops[rank:][::-1]
+        cur_loop_extants = cur_loop_extants[:rank] + cur_loop_extants[rank:][::-1]
+        sch.reorder(*loops)
+        dim0_loop_tiled = sch.fuse(*loops[rank:])
+        loops = loops[:rank]
+        loops.append(dim0_loop_tiled)
+        cur_loop_extants = cur_loop_extants[:rank]
+        cur_loop_extants.append(tile_size)
+
+        # Same thing with dim1
+        # [:rank + 1], since we placed dim0_loop_tiled in the end which we want to keep
+        loops, cur_loop_extants = factor_dim_in_order(
+            (
+                src_layout.index(dst_layout[loop_index_dst])
+                for loop_index_dst in range(rank - 1, -1, -1)
+            ),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        loops = loops[: rank + 1] + loops[rank + 1 :][::-1]
+        cur_loop_extants = cur_loop_extants[: rank + 1] + cur_loop_extants[rank + 1 :][::-1]
+        sch.reorder(*loops)
+        dim1_loop_tiled = sch.fuse(*loops[rank + 1 :])
+        loops = loops[: rank + 1]
+        loops.append(dim1_loop_tiled)
+        cur_loop_extants = cur_loop_extants[: rank + 1]
+        cur_loop_extants.append(tile_size)
+
+    rank = len(src_layout)
+
+    # Outer loop structure of read block matches that of src_layout
+    # E.g. if input_shape is [4, 6, 8]. Loops for read block will be
+    # for i, j, k in T.grid(4, 6, 8):
+    #     ...
+    # Read block will read from global memory coalesced at the start
+    # Assume write to output global memory is coalesced in block_write
+    block_read = sch.cache_read(block_write, 0, "shared")
+
+    # Here we have [loop1, loop2, loop3 ... dim0_tiled, dim1_tiled]
+    get_high_level_loop_structure(block_read)
+    loops = sch.get_loops(block_read)
+
+    # If there are insufficient elements, then dim1_tiled or dim0_tiled might be too small.
+    # In all likelihood you should use a smaller tile, but pad here so scheduling does not crash.
+    loops[-1] = pad_dimension_to_at_least_number(loops[-1], tile_size)
+    loops[-2] = pad_dimension_to_at_least_number(loops[-2], tile_size)
+
+    # We want the dim0 and dim1 parent loops to be the innermost. Right now dim1 is innermost
+    # and we just need to move dim0 in (last dimension of dst).
+    # Recall right now structure is at least [l1 l2 ... ln, dim0_tiled, dim1_tiled]
+    # where n >= 2.
+    dim0_loop_index = src_layout.index(dst_layout[-1])
+    dim0_loop = loops.pop(dim0_loop_index)
+    loops = loops[:-3] + [dim0_loop, loops[-3]] + loops[-2:]
+    sch.reorder(*loops)
+
+    # After this: [outer_loop (block binding), dim0_tiled, dim1_tiled]
+    outer_loop = sch.fuse(*loops[:-2])
+
+    # Now that we have the high level loop structure, we can use reverse_compute_at
+    # to get the proper loop structure for writing. This is also already as coalesced
+    # as possible.
+    sch.reverse_compute_at(block_write, outer_loop)
+
+    # Fuse all inner loops for the write into 2 loops, grab inner loops for both read
+    # and write block which have locality (we will bind these to threadIdx)
+    fused_write_loop = sch.fuse(*sch.get_loops(block_write)[1:])
+    _, inner_write_loop = sch.split(fused_write_loop, [None, tile_size])
+    inner_read_loop = sch.get_loops(block_read)[-2]
+
+    sch.bind(loop=outer_loop, thread_axis="blockIdx.x")
+    sch.bind(loop=inner_write_loop, thread_axis="threadIdx.x")
+    sch.bind(loop=inner_read_loop, thread_axis="threadIdx.x")
+
+
+def auto_inline(sch, start_block):
+    # Autoinlines given block into consumers, and repeats process for consumer of block
+    # Done by default for injective schedules.
+    fringe = deque([start_block])
+    visited = set()
+    while len(fringe) > 0:
+        cur_block = fringe.popleft()
+        if cur_block in visited:
+            continue
+        else:
+            visited.add(cur_block)
+
+        consumer_blocks = sch.get_consumers(cur_block)
+        if len(consumer_blocks) >= 1:
+            fringe.extend(consumer_blocks)
+            sch.compute_inline(cur_block)
+        else:
+            # Found output block, no more inlining needed
+            return cur_block
+
+
+@tvm.register_func("meta_schedule.cuda.layout_transform")
+def cuda_layout_transform_schedule_rule(sch, block):
+    # params: input_buffer, output_buffer
+    params = sch.mod["main"].params
+    input_buffer = sch.mod["main"].buffer_map[params[0]]
+    output_buffer = sch.mod["main"].buffer_map[params[1]]
+
+    # Info needed for tiling
+    input_shape = [int(dim) for dim in input_buffer.shape]
+    output_shape = [int(dim) for dim in output_buffer.shape]
+    src_layout = sch.get_sref(block).stmt.annotations["src_layout"]
+    dst_layout = sch.get_sref(block).stmt.annotations["dst_layout"]
+
+    # For each schedule we also want to inline each stage as would be done in normal circumstances
+    # to prevent extraneous memory access.
+    block = auto_inline(sch, block)
+
+    schedules = []
+
+    # Always include the default schedule, which will be handled by the AutoBind schedule rule
+    schedules.append(sch)
+
+    # Try tile sizes 2, 3, ..., 32; a tile size of 1 gives no coalescing.
+    for tile_size in range(2, 33):

Review Comment:
   Maybe make this dependent on warp size?
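On the warp-size question: as I read the diff above, the loops bound to `threadIdx.x` at the end of `tile_layout_transform` have extent `tile_size`, so the tile edge is also the number of lanes issuing each row access. The pure-Python model below (not TVM code; `tiled_transpose` and `is_contiguous` are hypothetical helpers) stages a 2D transpose through a `tile` x `tile` buffer and records the flat addresses each simulated lane touches. It is a sketch assuming a row-major layout and dimensions divisible by `tile`, and it shows that both the global loads and the global stores hit consecutive addresses. With `tile` equal to the warp size (32 on current NVIDIA GPUs), each such step would correspond to one full warp.

```python
from typing import List, Tuple


def tiled_transpose(
    src: List[int], rows: int, cols: int, tile: int
) -> Tuple[List[int], List[List[int]], List[List[int]]]:
    """Transpose row-major [rows, cols] into row-major [cols, rows] via a [tile, tile] buffer.

    Assumes rows and cols are multiples of tile. Besides the result, returns the flat
    global addresses touched by the `tile` simulated lanes at every load and store step.
    """
    assert rows % tile == 0 and cols % tile == 0, "sketch does not pad partial tiles"
    dst = [0] * (rows * cols)
    loads: List[List[int]] = []
    stores: List[List[int]] = []
    for r0 in range(0, rows, tile):
        for c0 in range(0, cols, tile):
            shared = [[0] * tile for _ in range(tile)]  # stands in for shared memory
            # Load: lane t reads src[r0 + i][c0 + t], i.e. consecutive source addresses.
            for i in range(tile):
                addrs = [(r0 + i) * cols + (c0 + t) for t in range(tile)]
                loads.append(addrs)
                for t, a in enumerate(addrs):
                    shared[i][t] = src[a]
            # Store: lane t writes dst[c0 + j][r0 + t] = src[r0 + t][c0 + j],
            # i.e. consecutive destination addresses, reading a column of `shared`.
            for j in range(tile):
                addrs = [(c0 + j) * rows + (r0 + t) for t in range(tile)]
                stores.append(addrs)
                for t, a in enumerate(addrs):
                    dst[a] = shared[t][j]
    return dst, loads, stores


def is_contiguous(addrs: List[int]) -> bool:
    """True if the lanes touch strictly consecutive flat addresses."""
    return all(b - a == 1 for a, b in zip(addrs, addrs[1:]))
```

Note that the store phase reads a column of `shared` (lane `t` touches `shared[t][j]`), which is the access pattern behind the shared-memory bank conflicts discussed later in this thread.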



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] elvin-n commented on a diff in pull request #14167: [DRAFT][CUDA][Schedule] Better Layout Transform Schedules

Posted by "elvin-n (via GitHub)" <gi...@apache.org>.
elvin-n commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1122713183


##########
python/tvm/relay/op/_transform.py:
##########
@@ -94,7 +94,7 @@ def compute_strided_set(attrs, inputs, output_type):
 _reg.register_injective_schedule("strided_set")
 
 # layout_transform
-_reg.register_injective_schedule("layout_transform")
+_reg.register_strategy("layout_transform", strategy.layout_transform_strategy)

Review Comment:
   Does this mean that layout_transform on all platforms will stop being scheduled through the injective schedule?





[GitHub] [tvm] tvm-bot commented on pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #14167:
URL: https://github.com/apache/tvm/pull/14167#issuecomment-1451040767

   <!---bot-comment-->
   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #14167: [DRAFT][CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1122796544


##########
python/tvm/relay/op/_transform.py:
##########
@@ -94,7 +94,7 @@ def compute_strided_set(attrs, inputs, output_type):
 _reg.register_injective_schedule("strided_set")
 
 # layout_transform
-_reg.register_injective_schedule("layout_transform")
+_reg.register_strategy("layout_transform", strategy.layout_transform_strategy)

Review Comment:
   Hey @elvin-n. The plan is to make it so all implementations will have an injective-like schedule.
   
   For non-CUDA targets it will still use an injective schedule via the generic fallback strategy.
   
   For CUDA it will use the new schedule in this PR after applying the auto-inline injective schedule (I still need to figure out this part for CUDA).





[GitHub] [tvm] AndrewZhaoLuo merged pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo merged PR #14167:
URL: https://github.com/apache/tvm/pull/14167




[GitHub] [tvm] AndrewZhaoLuo commented on pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on PR #14167:
URL: https://github.com/apache/tvm/pull/14167#issuecomment-1461054673

   > > Shared memory — Bank Conflicts
   > 
   > Have you tried using the `storage_align` sch primitive? It achieves something similar to shmem padding for power-of-two-sized shmem.
   
   Ah yes, this might be exactly what I am looking for, thanks for bringing it to my attention.




[GitHub] [tvm] AndrewZhaoLuo commented on pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on PR #14167:
URL: https://github.com/apache/tvm/pull/14167#issuecomment-1476758980

   Will depend on https://github.com/apache/tvm/pull/14346 before merge. 
   
   Will keep it open a few more days for additional comments.




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1125027414


##########
python/tvm/relay/op/_transform.py:
##########
@@ -94,7 +94,7 @@ def compute_strided_set(attrs, inputs, output_type):
 _reg.register_injective_schedule("strided_set")
 
 # layout_transform
-_reg.register_injective_schedule("layout_transform")
+_reg.register_strategy("layout_transform", strategy.layout_transform_strategy)

Review Comment:
   It should now auto-inline and then apply the new layout transform schedule rule, though it would probably need some additional testing.
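The flow described here (auto-inline first, then apply the layout-transform rule to the remaining block) can be traced on a toy graph. The sketch below mirrors the breadth-first traversal of `auto_inline` in the diff, but runs on a plain dict of hypothetical block names instead of a `tvm.tir.Schedule`, and records which blocks would be inlined rather than calling `compute_inline`.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple


def auto_inline_order(
    consumers: Dict[str, List[str]], start: str
) -> Tuple[List[str], Optional[str]]:
    """Mimic the BFS in `auto_inline` on a toy dataflow graph.

    Returns the blocks that would be passed to compute_inline, in order, and the first
    block found with no consumers (the block the layout-transform rule then schedules).
    """
    fringe = deque([start])
    visited = set()
    inlined: List[str] = []
    while fringe:
        cur = fringe.popleft()
        if cur in visited:
            continue
        visited.add(cur)
        nxt = consumers.get(cur, [])
        if nxt:
            fringe.extend(nxt)
            inlined.append(cur)  # stands in for sch.compute_inline(cur)
        else:
            return inlined, cur  # output block: stop here, as auto_inline does
    return inlined, None  # auto_inline implicitly returns None in this case
```

Like the original, it returns as soon as it reaches a block with no consumers, so in a graph with several output blocks only the first one reached is returned.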





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1134674873


##########
python/tvm/meta_schedule/schedule/cuda/layout_transform.py:
##########
@@ -0,0 +1,322 @@
+import tvm
+from tvm import topi
+import math
+from typing import List, Sequence, Tuple
+
+from tvm.tir.schedule import BlockRV, ExprRV, LoopRV
+from collections import deque
+
+
+def tile_layout_transform(
+    sch: tvm.tir.Schedule,
+    block_write: BlockRV,
+    src_layout: str,
+    dst_layout: str,
+    input_shape: List[int],
+    tile_size: ExprRV,
+):
+    """
+    High level tiling for layout transform block.
+    """
+
+    ## Tiling layout transforms:
+    # Assume we have an input shape of [A, B, C, D] and want to layout transform
+    # ABCD --> DBAC so the output shape would be [D, B, A, C].
+    #
+    # Consider reading from the input buffer in a cache-friendly fashion on CPU. We would
+    # expect a loop structure like:
+    # lAr, lBr, lCr, lDr = T.grid(A, B, C, D)
+    #
+    # Meanwhile consider writing to the output buffer in a cache-friendly fashion on CPU:
+    # lDw, lBw, lAw, lCw = T.grid(D, B, A, C)
+    #
+    # Clearly in many scenarios it is impossible to guarantee contiguous writes and reads
+    # within a single loop, because dimensions adjacent in one layout are generally not
+    # adjacent in the other. Instead we transpose some small sub-tensor of our input by
+    # writing it to shared memory and then reading it back. We must construct this
+    # sub-tensor so that reading and writing can both be done with some contiguous
+    # access in global memory.
+    #
+    # Consider the case of a 2D transpose. For example [1024, 2048] -> [2048, 1024].
+    # If we work on a submatrix of shape [32, 32] whose axes follow the dimensions of
+    # our input tensor, then rows of the submatrix are contiguous in the input tensor,
+    # while columns of the submatrix are contiguous in the output tensor. Therefore,
+    # with this tile shape we can read contiguously from the input tensor into shared
+    # memory, and write contiguously from shared memory to the output tensor.
+    #
+    # The multi-dimensional case is analogous. We want to allocate shared memory per
+    # block of [`tile_size`, `tile_size`]. We want the innermost dimension of our shared
+    # memory to correspond to contiguous reads from the input tensor and the outer
+    # dimension to correspond to contiguous writes into the output tensor.
+    #
+    # In terms of the loop structure reading from the input tensor, the innermost loops
+    # of our tile must correspond to the innermost dimensions of the input shape,
+    # while the outer dimensions correspond to the innermost dimensions of the output shape.
+    # To obtain an inner tile with this loop structure we factor out a contiguous `tile_size`
+    # chunk of our loop in the shape of interest.
+    #
+    # An example is probably best to show this idea:
+    # Let's say we want a layout transform of ABCD --> DCAB. With shape
+    # [1024_a, 2_b, 32_c, 8_d] --> [8_d, 32_c, 1024_a, 2_b]
+    #
+    # And tile size 32.
+    #
+    # Then we initially have a coalesced-read loop pattern of:
+    # T.grid(1024_a, 2_b, 32_c, 8_d)
+    #
+    # To obtain an inner tile of 32, we factor 4 from 32_c and 8 from 8_d:
+    # T.grid(1024_a, 2_b, 8_c1, 1_d1, 4_c2t, 8_d2t)
+    # T.grid(1024_a, 2_b, 8_cr, 1_dr, 32_dim1)
+    #
+    # To obtain an outer tile of 32, we factor from B then A to follow contiguous write
+    # pattern:
+    #
+    # T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1)
+    # T.grid(64_ar, 1_br, 8_cr, 1_dr, 32_dim0, 32_dim1)
+    #
+    # Which allows us to read a tile with our wanted properties.
+    # For writing we use the existing analysis infrastructure to generate the proper structure for writing.
+
+    def pad_dimension_to_at_least_number(loop: LoopRV, requested_size: int):
+        """Pads loop to a multiple of requested_size, e.g. extent 8, requested 10 -> extent 10."""
+        l1, l2 = sch.split(loop, [None, requested_size])
+        return sch.fuse(l1, l2)
+
+    def pad_dimension_to_factor_of_tile_size(
+        loop: LoopRV, initial_size: int, tile_size: int = tile_size
+    ) -> Tuple[LoopRV, int]:
+        """
+        Pads the loop until its extent evenly divides tile_size.
+        If the extent is already greater than or equal to tile_size, do not pad.
+
+        Examples: loop_size = 5, tile_size = 32. loop_size --> 8
+                loop_size = 5, tile_size = 36. loop_size --> 6
+                loop_size = 8, tile_size = 32. loop_size --> 8
+                loop_size = 33, tile_size = 32. loop_size --> 33
+
+        Returns padded loopRV and the new size
+        """
+        if tile_size % initial_size == 0:
+            return loop, int(initial_size)
+
+        # If initial_size >= tile_size, return without change (no padding).
+        if initial_size >= tile_size:
+            return loop, int(initial_size)
+
+        # Otherwise grow the extent until it evenly divides tile_size.
+        size = initial_size
+        while tile_size % size > 0:
+            size += 1
+
+        return pad_dimension_to_at_least_number(loop, size), int(size)
+
+    def spin_out_factor(
+        loops: List[LoopRV], loop_extants: List[int], index: int, factor_needed: int
+    ) -> Tuple[List[LoopRV], List[int], int]:
+        """
+        Factor out loop dimensions to reach the requested factor. Updates the schedule in-place.
+
+        E.g. say we want factors which eventually multiply to 32 (factor_needed).
+
+        Say the loop at the chosen index has an extent of 8.
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 3
+            - 8 divides 32, so we just split the loop into two loops with extents 1 and 8.
+            - we then keep the 1-loop in place and move the new 8-loop to the back of the loops
+            - ending loops / loop_extants = [3, 32, 6, 1, 8], remaining_factor_needed = 32 / 8 = 4
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 0
+            - 3 does not divide 32, so we pad until the extent divides 32, e.g. 4
+            - we then split the loop into extents 1 and 4, moving the 4 to the back
+            - ending loops / loop_extants = [1, 32, 6, 8, 4], remaining_factor_needed = 32 / 4 = 8
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 5, index = 3
+            - 8 is larger than 5, so no padding is needed and we split immediately.
+            - the extent-8 loop becomes loops with extents 2 (= ceil(8 / 5)) and 5
+            - ending loops / loop_extants = [3, 32, 6, 2, 5], remaining_factor_needed = 5 / 5 = 1
+
+        After updating loop ordering in place, returns the new list of loops, extants, and the
+        remaining factor needed.
+        """
+        cur_loop = loops[index]
+        cur_extant = loop_extants[index]
+
+        # Pad loops to divide evenly for factors needed, and split
+        new_loop, new_size = pad_dimension_to_factor_of_tile_size(
+            cur_loop, cur_extant, tile_size=factor_needed
+        )
+
+        split_factor = min(new_size, factor_needed)
+        new_loop_split, factored_loop = sch.split(new_loop, [None, split_factor])
+        factor_needed = factor_needed // split_factor
+
+        # update caching
+        loops[index] = new_loop_split
+        loops.append(factored_loop)
+
+        loop_extants[index] = math.ceil(new_size / split_factor)
+        loop_extants.append(split_factor)
+
+        sch.reorder(*loops)
+        return loops, loop_extants, factor_needed
+
+    def factor_dim_in_order(
+        indices: Sequence[int],
+        loops: List[LoopRV],
+        cur_loop_extants: List[int],
+        work_needed_inner_loop: int = tile_size,
+    ):
+        """Factors out the loops in the order of indices until we reach needed work.
+
+        Adds new loop factors to the back in reverse order of access.
+        """
+        for i in indices:
+            loops, cur_loop_extants, work_needed_inner_loop = spin_out_factor(
+                loops, cur_loop_extants, i, work_needed_inner_loop
+            )
+            if work_needed_inner_loop == 1:
+                break
+        return loops, cur_loop_extants
+
+    def get_high_level_loop_structure(block):
+        """Runs the factorization described above."""
+        # index 0 ... rank - 1 will always correspond to original loops
+        # perhaps after they have been factored.
+        loops = sch.get_loops(block)
+        cur_loop_extants = list(input_shape)
+
+        # Factor dim0 tile size and fuse things together
+        loops, cur_loop_extants = factor_dim_in_order(
+            range(rank - 1, -1, -1),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        # The factors which multiply to tile_size are now at the back of our
+        # list of loops. However, because we added them by traversing the inner
+        # dimensions first, they are in reverse order, so reorder them to guarantee
+        # the best access pattern before fusing.
+        loops = loops[:rank] + loops[rank:][::-1]
+        cur_loop_extants = cur_loop_extants[:rank] + cur_loop_extants[rank:][::-1]
+        sch.reorder(*loops)
+        dim0_loop_tiled = sch.fuse(*loops[rank:])
+        loops = loops[:rank]
+        loops.append(dim0_loop_tiled)
+        cur_loop_extants = cur_loop_extants[:rank]
+        cur_loop_extants.append(tile_size)
+
+        # Same thing with dim1
+        # [:rank + 1], since we placed dim0_loop_tiled in the end which we want to keep
+        loops, cur_loop_extants = factor_dim_in_order(
+            (
+                src_layout.index(dst_layout[loop_index_dst])
+                for loop_index_dst in range(rank - 1, -1, -1)
+            ),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        loops = loops[: rank + 1] + loops[rank + 1 :][::-1]
+        cur_loop_extants = cur_loop_extants[: rank + 1] + cur_loop_extants[rank + 1 :][::-1]
+        sch.reorder(*loops)
+        dim1_loop_tiled = sch.fuse(*loops[rank + 1 :])
+        loops = loops[: rank + 1]
+        loops.append(dim1_loop_tiled)
+        cur_loop_extants = cur_loop_extants[: rank + 1]
+        cur_loop_extants.append(tile_size)
+
+    rank = len(src_layout)
+
+    # Outer loop structure of read block matches that of src_layout
+    # E.g. if input_shape is [4, 6, 8]. Loops for read block will be
+    # for i, j, k in T.grid(4, 6, 8):
+    #     ...
+    # Read block will read from global memory coalesced at the start
+    # Assume write to output global memory is coalesced in block_write
+    block_read = sch.cache_read(block_write, 0, "shared")
+
+    # Here we have [loop1, loop2, loop3 ... dim0_tiled, dim1_tiled]
+    get_high_level_loop_structure(block_read)
+    loops = sch.get_loops(block_read)
+
+    # If there are insufficient elements, then dim1_tiled or dim0_tiled might be too small.
+    # In all likelihood you should use a smaller tile, but pad here so scheduling does not crash.
+    loops[-1] = pad_dimension_to_at_least_number(loops[-1], tile_size)
+    loops[-2] = pad_dimension_to_at_least_number(loops[-2], tile_size)
+
+    # We want the dim0 and dim1 parent loops to be the innermost. Right now dim1 is innermost
+    # and we just need to move dim0 in (last dimension of dst).
+    # Recall right now structure is at least [l1 l2 ... ln, dim0_tiled, dim1_tiled]
+    # where n >= 2.
+    dim0_loop_index = src_layout.index(dst_layout[-1])
+    dim0_loop = loops.pop(dim0_loop_index)
+    loops = loops[:-3] + [dim0_loop, loops[-3]] + loops[-2:]
+    sch.reorder(*loops)
+
+    # After this: [outer_loop (block binding), dim0_tiled, dim1_tiled]
+    outer_loop = sch.fuse(*loops[:-2])
+
+    # Now that we have the high level loop structure, we can use reverse_compute_at
+    # to get the proper loop structure for writing. This is also already as coalesced
+    # as possible.
+    sch.reverse_compute_at(block_write, outer_loop)
+
+    # Fuse all inner loops for the write into 2 loops, grab inner loops for both read
+    # and write block which have locality (we will bind these to threadIdx)
+    fused_write_loop = sch.fuse(*sch.get_loops(block_write)[1:])
+    _, inner_write_loop = sch.split(fused_write_loop, [None, tile_size])
+    inner_read_loop = sch.get_loops(block_read)[-2]
+
+    sch.bind(loop=outer_loop, thread_axis="blockIdx.x")
+    sch.bind(loop=inner_write_loop, thread_axis="threadIdx.x")
+    sch.bind(loop=inner_read_loop, thread_axis="threadIdx.x")
+
+
+def auto_inline(sch, start_block):
+    # Autoinlines given block into consumers, and repeats process for consumer of block
+    # Done by default for injective schedules.
+    fringe = deque([start_block])
+    visited = set()
+    while len(fringe) > 0:
+        cur_block = fringe.popleft()
+        if cur_block in visited:
+            continue
+        else:
+            visited.add(cur_block)
+
+        consumer_blocks = sch.get_consumers(cur_block)
+        if len(consumer_blocks) >= 1:
+            fringe.extend(consumer_blocks)
+            sch.compute_inline(cur_block)
+        else:
+            # Found output block, no more inlining needed
+            return cur_block
+
+
+@tvm.register_func("meta_schedule.cuda.layout_transform")
+def cuda_layout_transform_schedule_rule(sch, block):
+    # params: input_buffer, output_buffer
+    params = sch.mod["main"].params
+    input_buffer = sch.mod["main"].buffer_map[params[0]]
+    output_buffer = sch.mod["main"].buffer_map[params[1]]
+
+    # Info needed for tiling
+    input_shape = [int(dim) for dim in input_buffer.shape]
+    output_shape = [int(dim) for dim in output_buffer.shape]
+    src_layout = sch.get_sref(block).stmt.annotations["src_layout"]
+    dst_layout = sch.get_sref(block).stmt.annotations["dst_layout"]
+
+    # For each schedule we also want to inline each stage as would be done in normal circumstances
+    # to prevent extraneous memory access.
+    block = auto_inline(sch, block)
+
+    schedules = []
+
+    # Always include the default schedule, which will be handled by the AutoBind schedule rule
+    schedules.append(sch)
+
+    # Try tile sizes 2, 3, ..., 32; a tile size of 1 gives no coalescing.
+    for tile_size in range(2, 33):

Review Comment:
   Done
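To check the ABCD --> DCAB walk-through in the comment block above by hand, here is an integer-only model of `pad_dimension_to_factor_of_tile_size`, `spin_out_factor`, and the two `factor_dim_in_order` passes in `get_high_level_loop_structure`. It is a sketch under stated assumptions: it tracks only loop extents (no `LoopRV`s, no `sch.split`/`sch.fuse`), all helper names are hypothetical, it does not model the final padding for tiles that cannot reach `tile_size`, and it omits the fused tile loop the real code appends between the two passes (that loop sits past index `rank - 1`, so it does not affect the indices being factored).

```python
import math
from typing import List, Sequence, Tuple


def pad_to_factor_of_tile(extent: int, tile_size: int) -> int:
    """Integer-only mirror of pad_dimension_to_factor_of_tile_size."""
    if tile_size % extent == 0 or extent >= tile_size:
        return extent
    size = extent
    while tile_size % size > 0:  # grow until size divides tile_size
        size += 1
    return size


def spin_out_factor_extents(
    extents: List[int], index: int, factor_needed: int
) -> Tuple[List[int], int]:
    """Integer-only mirror of spin_out_factor: split extents[index], append the factor."""
    new_size = pad_to_factor_of_tile(extents[index], factor_needed)
    split_factor = min(new_size, factor_needed)
    extents = list(extents)  # copy so the caller's list is not mutated
    extents[index] = math.ceil(new_size / split_factor)
    extents.append(split_factor)
    return extents, factor_needed // split_factor


def factor_in_order(
    extents: List[int], indices: Sequence[int], tile_size: int
) -> Tuple[List[int], List[int]]:
    """Peel factors from extents[i], i in indices, until their product reaches tile_size.

    Returns (remaining outer extents, tile factors in outer-to-inner order).
    """
    keep = len(extents)
    need = tile_size
    for i in indices:
        extents, need = spin_out_factor_extents(extents, i, need)
        if need == 1:
            break
    return extents[:keep], extents[keep:][::-1]


def tile_read_extents(
    shape: List[int], src_layout: str, dst_layout: str, tile_size: int
) -> Tuple[List[int], List[int], List[int]]:
    """Returns (outer extents, dst-contiguous tile factors, src-contiguous tile factors)."""
    rank = len(src_layout)
    # Source-contiguous tile: peel from the innermost source dimension outwards.
    outer, src_tile = factor_in_order(list(shape), range(rank - 1, -1, -1), tile_size)
    # Destination-contiguous tile: peel source dims in innermost-destination order.
    dst_order = [src_layout.index(dst_layout[j]) for j in range(rank - 1, -1, -1)]
    outer, dst_tile = factor_in_order(outer, dst_order, tile_size)
    return outer, dst_tile, src_tile
```

For the example above, `tile_read_extents([1024, 2, 32, 8], "ABCD", "DCAB", 32)` gives outer extents `[64, 1, 8, 1]`, destination-contiguous factors `[16, 2]` (A then B), and source-contiguous factors `[4, 8]` (C then D), matching `T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1)`.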





[GitHub] [tvm] masahi commented on pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on PR #14167:
URL: https://github.com/apache/tvm/pull/14167#issuecomment-1454171002

   > Shared memory — Bank Conflicts
   
   Have you tried using the `storage_align` sch primitive? It achieves something similar to shmem padding for power-of-two-sized shmem.
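For context on the bank-conflict point: with a `[32][32]` shared tile of 4-byte elements, the column read in the store phase has lane `t` touching word `t * 32 + col`, so every lane maps to the same bank. Padding each row by one element (row stride 33) spreads the lanes across all 32 banks. The sketch below only computes bank indices in plain Python and assumes 32 banks, each 4 bytes wide (so it does not model 2-byte fp16 elements sharing a bank). As far as I understand, `storage_align` reaches the same padding by forcing a buffer axis stride to satisfy `stride % factor == offset` (e.g. `factor=32, offset=1`); check the `tvm.tir.Schedule.storage_align` docs for the exact semantics.

```python
from collections import Counter
from typing import List

NUM_BANKS = 32  # assumed: 32 shared-memory banks, each 4 bytes wide


def column_banks(tile: int, row_stride: int, col: int = 0) -> List[int]:
    """Bank touched by each of `tile` lanes when lane t reads element [t][col]
    of a row-major shared tile of 4-byte elements with the given row stride."""
    return [(t * row_stride + col) % NUM_BANKS for t in range(tile)]


def conflict_degree(banks: List[int]) -> int:
    """Worst-case number of lanes hitting the same bank (1 means conflict-free)."""
    return max(Counter(banks).values())
```

With `row_stride=32` the degree is 32 (a 32-way conflict); with `row_stride=33` it drops to 1.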




[GitHub] [tvm] AndrewZhaoLuo commented on pull request #14167: [DRAFT][CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on PR #14167:
URL: https://github.com/apache/tvm/pull/14167#issuecomment-1454134573

   This is now ready for review. cc @tkonolige @masahi 




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1130178154


##########
python/tvm/meta_schedule/schedule/cuda/layout_transform.py:
##########
@@ -0,0 +1,321 @@
+import tvm
+from tvm import topi
+import math
+from typing import List, Sequence, Tuple
+
+from tvm.tir.schedule import BlockRV, ExprRV, LoopRV
+from collections import deque
+
+
+def tile_layout_transform(
+    sch: tvm.tir.Schedule,
+    block_write: BlockRV,
+    src_layout: str,
+    dst_layout: str,
+    input_shape: List[int],
+    tile_size: ExprRV,
+):
+    """
+    High level tiling for layout transform block.
+    """
+    
+    ## Tiling layout transforms:
+    # Assume we have an input shape of [A, B, C, D] and want to layout transform
+    # ABCD --> DBAC so the output shape would be [D, B, A, C].
+    #
+    # Consider reading from the input buffer in a cache-friendly fashion on CPU. We would
+    # expect a loop structure like:
+    # lAr, lBr, lCr, lDr = T.grid(A, B, C, D)
+    #
+    # Meanwhile consider writing to the output buffer in a cache-friendly fashion on CPU:
+    # lDw, lBw, lAw, lCw = T.grid(D, B, A, C)
+    #
+    # Clearly in many scenarios it is impossible to guarantee contiguous writes and reads
+    # within a single loop due to non-adjacent dimensions. Instead, we transpose a small
+    # sub-tensor of the input at a time: we read it into shared memory and then write it
+    # out from shared memory. We must construct this sub-tensor so that both the reads
+    # and the writes have some contiguous access in global memory.
+    #
+    # Consider the case of a 2D transpose, e.g. [1024, 2048] -> [2048, 1024].
+    # If we work on a [32, 32] submatrix of the input tensor, then rows of the
+    # submatrix are contiguous in the input tensor, while columns of the submatrix
+    # are contiguous in the output tensor. With this tile shape we can read
+    # contiguously from the input tensor into shared memory, and write
+    # contiguously from shared memory to the output tensor.
+    #
+    # The multi-dimensional case is analogous. We want to allocate shared memory
+    # per block of [`tile_size`, `tile_size`]. We want the innermost dimension
+    # of our shared memory to correspond to contiguous reads from the input tensor and
+    # the outer dimension to correspond to contiguous writes into the output tensor.
+    #
+    # In terms of the loop structure reading from the input tensor, the innermost loops
+    # of our tile must correspond to the innermost dimensions of the input shape,
+    # while the outer loops correspond to the innermost dimensions of the output shape.
+    # To obtain an inner tile with this loop structure, we factor out a contiguous
+    # `tile_size` chunk of our loops in the shape of interest.
+    #
+    # An example is probably best to show this idea:
+    # Let's say we want a layout transform of ABCD --> DCAB, with shape
+    # [1024_a, 2_b, 32_c, 8_d] --> [8_d, 32_c, 1024_a, 2_b]
+    #
+    # And tile size 32.
+    #
+    # Then we initially have a coalesced-read loop pattern of:
+    # T.grid(1024_a, 2_b, 32_c, 8_d)
+    # 
+    # To obtain an inner tile of 32, we factor 4 from 32_c and 8 from 8_d:
+    # T.grid(1024_a, 2_b, 8_c1, 1_d1, 4_c2t, 8_d2t)
+    # T.grid(1024_a, 2_b, 8_cr, 1_dr, 32_dim1)
+    #
+    # To obtain an outer tile of 32, we factor from B then A to follow contiguous write 
+    # pattern:
+    #
+    # T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1)
+    # T.grid(64_ar, 1_br, 8_cr, 1_dr, 32_dim0, 32_dim1)
+    #
+    # This allows us to read a tile with the desired properties.
+    # For the write side, we use the existing analysis infrastructure (reverse_compute_at)
+    # to generate the proper loop structure.
+
+    def pad_dimension_to_at_least_number(loop: LoopRV, requested_size: int):
+        """Pads loop up to a multiple of requested_size (so at least requested_size).
+
+        E.g. if loop has an extent of 8 but we want 10, returns a loop of extent 10 with padding.
+        """
+        l1, l2 = sch.split(loop, [None, requested_size])
+        return sch.fuse(l1, l2)
+
+    def pad_dimension_to_factor_of_tile_size(
+        loop: LoopRV, initial_size: int, tile_size: int = tile_size
+    ) -> Tuple[LoopRV, int]:
+        """
+        Pads loop of given size until its size evenly divides tile_size.
+        If the given size of the loop is greater than or equal to tile_size, do not pad.
+
+        Examples: loop_size = 5, tile_size = 32. loop_size --> 8
+                  loop_size = 5, tile_size = 36. loop_size --> 6
+                  loop_size = 8, tile_size = 32. loop_size --> 8
+                  loop_size = 33, tile_size = 32. loop_size --> 33
+
+        Returns the padded LoopRV and its new size.
+        """
+        # Already a divisor of tile_size, or at least as large as tile_size: no padding
+        if tile_size % initial_size == 0 or initial_size >= tile_size:
+            return loop, int(initial_size)
+
+        # Otherwise grow the size until it evenly divides tile_size
+        size = initial_size
+        while tile_size % size != 0:
+            size += 1
+
+        return pad_dimension_to_at_least_number(loop, size), int(size)
+
+    def spin_out_factor(
+        loops: List[LoopRV], loop_extants: List[int], index: int, factor_needed: int
+    ) -> Tuple[List[LoopRV], List[int], int]:
+        """
+        Factor out loop dimensions to reach the requested factor. Updates the schedule in-place.
+
+        E.g. say we want factors which eventually multiply to 32 (factor_needed).
+
+        Say the loop at the chosen index has an extent of 8.
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed = 32, index = 3
+            - 8 divides 32, so we just split the loop into two loops with extents 1 and 8.
+            - we then keep the 1-loop in place and move the new 8-loop to the back of the list of loops
+            - ending loops / loop_extants = [3, 32, 6, 1, 8], remaining_factor_needed = 32 / 8 = 4
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed=32, index = 0
+            - 3 does not divide 32, so we pad until the extent divides 32, e.g. 4
+            - we then split the loop into extents 1 and 4, moving the 4 to the back
+            - ending loops / loop_extants = [1, 32, 6, 8, 4], remaining_factor_needed = 32 / 4 = 8
+
+        E.g. loops / loop_extants = [3, 32, 6, 8], factor_needed=5, index = 3
+            - 8 is larger than 5, so we immediately do the splitting routine.
+            - the loop of extent 8 becomes loops with extents 2 and 5 (ceil(8 / 5) = 2)
+            - ending loops / loop_extants = [3, 32, 6, 2, 5], remaining_factor_needed = 5 / 5 = 1
+
+        After updating loop ordering in place, returns the new list of loops, extants, and the
+        remaining factor needed.
+        """
+        cur_loop = loops[index]
+        cur_extant = loop_extants[index]
+
+        # Pad loops to divide evenly for factors needed, and split
+        new_loop, new_size = pad_dimension_to_factor_of_tile_size(
+            cur_loop, cur_extant, tile_size=factor_needed
+        )
+
+        split_factor = min(new_size, factor_needed)
+        new_loop_split, factored_loop = sch.split(new_loop, [None, split_factor])
+        factor_needed = factor_needed // split_factor
+
+        # update caching
+        loops[index] = new_loop_split
+        loops.append(factored_loop)
+
+        loop_extants[index] = math.ceil(new_size / split_factor)
+        loop_extants.append(split_factor)
+
+        sch.reorder(*loops)
+        return loops, loop_extants, factor_needed
+
+    def factor_dim_in_order(
+        indices: Sequence[int],
+        loops: List[LoopRV],
+        cur_loop_extants: List[int],
+        work_needed_inner_loop: int = tile_size,
+    ):
+        """Factors out the loops in the order of indices until we reach needed work.
+
+        Adds new loop factors to the back in reverse order of access.
+        """
+        for i in indices:
+            loops, cur_loop_extants, work_needed_inner_loop = spin_out_factor(
+                loops, cur_loop_extants, i, work_needed_inner_loop
+            )
+            if work_needed_inner_loop == 1:
+                break
+        return loops, cur_loop_extants
+
+    def get_high_level_loop_structure(block):
+        """Runs the factorization described above."""
+        # index 0 ... rank - 1 will always correspond to original loops
+        # perhaps after they have been factored.
+        loops = sch.get_loops(block)
+        cur_loop_extants = list(input_shape)
+
+        # Factor dim0 tile size and fuse things together
+        loops, cur_loop_extants = factor_dim_in_order(
+            range(rank - 1, -1, -1),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        # The factors which multiply to tile_size are now at the back of our
+        # list of loops. However, because we added them by traversing the inner
+        # dimensions first, they are in reverse order; to guarantee the best access
+        # pattern, reorder them before fusing.
+        loops = loops[:rank] + loops[rank:][::-1]
+        cur_loop_extants = cur_loop_extants[:rank] + cur_loop_extants[rank:][::-1]
+        sch.reorder(*loops)
+        dim0_loop_tiled = sch.fuse(*loops[rank:])
+        loops = loops[:rank]
+        loops.append(dim0_loop_tiled)
+        cur_loop_extants = cur_loop_extants[:rank]
+        cur_loop_extants.append(tile_size)
+
+        # Same thing with dim1
+        # [:rank + 1], since we placed dim0_loop_tiled in the end which we want to keep
+        loops, cur_loop_extants = factor_dim_in_order(
+            (
+                src_layout.index(dst_layout[loop_index_dst])
+                for loop_index_dst in range(rank - 1, -1, -1)
+            ),
+            loops,
+            cur_loop_extants,
+            work_needed_inner_loop=tile_size,
+        )
+        loops = loops[: rank + 1] + loops[rank + 1 :][::-1]
+        cur_loop_extants = cur_loop_extants[: rank + 1] + cur_loop_extants[rank + 1 :][::-1]
+        sch.reorder(*loops)
+        dim1_loop_tiled = sch.fuse(*loops[rank + 1 :])
+        loops = loops[: rank + 1]
+        loops.append(dim1_loop_tiled)
+        cur_loop_extants = cur_loop_extants[: rank + 1]
+        cur_loop_extants.append(tile_size)
+
+    rank = len(src_layout)
+
+    # Outer loop structure of read block matches that of src_layout
+    # E.g. if input_shape is [4, 6, 8]. Loops for read block will be
+    # for i, j, k in T.grid(4, 6, 8):
+    #     ...
+    # Read block will read from global memory coalesced at the start
+    # Assume write to output global memory is coalesced in block_write
+    block_read = sch.cache_read(block_write, 0, "shared")
+
+    # Here we have [loop1, loop2, loop3 ... dim0_tiled, dim1_tiled]
+    get_high_level_loop_structure(block_read)
+    loops = sch.get_loops(block_read)
+
+    # If there are insufficient elements, then dim1_tiled or dim0_tiled might be too small.
+    # In all likelihood you should use a smaller tile, but I don't want things to crash.
+    loops[-1] = pad_dimension_to_at_least_number(loops[-1], tile_size)
+    loops[-2] = pad_dimension_to_at_least_number(loops[-2], tile_size)
+
+    # We want the dim0 and dim1 parent loops to be the innermost. Right now dim1 is innermost
+    # and we just need to move dim0 in (last dimension of dst).
+    # Recall right now structure is at least [l1 l2 ... ln, dim0_tiled, dim1_tiled]
+    # where n >= 2.
+    dim0_loop_index = src_layout.index(dst_layout[-1])
+    dim0_loop = loops.pop(dim0_loop_index)
+    loops = loops[:-3] + [dim0_loop, loops[-3]] + loops[-2:]
+    sch.reorder(*loops)
+
+    # After this: [outer_loop (block binding), dim0_tiled, dim1_tiled]
+    outer_loop = sch.fuse(*loops[:-2])
+
+    # Now that we have the high-level loop structure, reverse_compute_at gives us
+    # the proper loop structure for writing. This is also already as coalesced
+    # as possible.
+    sch.reverse_compute_at(block_write, outer_loop)
+
+    # Fuse all inner loops for the write into 2 loops, grab inner loops for both read
+    # and write block which have locality (we will bind these to threadIdx)
+    fused_write_loop = sch.fuse(*sch.get_loops(block_write)[1:])
+    _, inner_write_loop = sch.split(fused_write_loop, [None, tile_size])
+    inner_read_loop = sch.get_loops(block_read)[-2]
+
+    sch.bind(loop=outer_loop, thread_axis="blockIdx.x")
+    sch.bind(loop=inner_write_loop, thread_axis="threadIdx.x")
+    sch.bind(loop=inner_read_loop, thread_axis="threadIdx.x")
+
+
+def auto_inline(sch, start_block):
+    """Inlines the given block into its consumers, then repeats the process for those consumers.
+
+    This is done by default for injective schedules. Returns the output block
+    (the first block found without consumers).
+    """
+    fringe = deque([start_block])
+    visited = set()
+    while len(fringe) > 0:
+        cur_block = fringe.popleft()
+        if cur_block in visited:
+            continue
+        else:
+            visited.add(cur_block)
+
+        consumer_blocks = sch.get_consumers(cur_block)
+        if len(consumer_blocks) >= 1:
+            fringe.extend(consumer_blocks)
+            sch.compute_inline(cur_block)
+        else:
+            # Found output block, no more inlining needed
+            return cur_block
+
+
+@tvm.register_func("meta_schedule.cuda.layout_transform")
+def cuda_layout_transform_schedule_rule(sch, block):
+    # params: input_buffer, output_buffer
+    params = sch.mod["main"].params
+    input_buffer = sch.mod["main"].buffer_map[params[0]]
+    output_buffer = sch.mod["main"].buffer_map[params[1]]
+
+    # Info needed for tiling
+    input_shape = [int(dim) for dim in input_buffer.shape]
+    output_shape = [int(dim) for dim in output_buffer.shape]
+    src_layout = sch.get_sref(block).stmt.annotations["src_layout"]
+    dst_layout = sch.get_sref(block).stmt.annotations["dst_layout"]
+
+    # For each schedule we also want to inline each stage as would be done in normal circumstances
+    # to prevent extraneous memory access.
+    block = auto_inline(sch, block)

Review Comment:
   It does not appear that the AutoInline meta schedule rule gets applied automatically, so I did it manually. It does seem to apply to operators fused before the layout transform, if that makes sense.
   
   That is, if you have
   
   a = x + y * z
   b = layout_transform(a)
   c = b * b + b
   
   then a's operations will be fused into b, but c will not be fused into b.
   
   Upon closer examination, it appears some postprocessors (RewriteCooperativeFetch) expect thread bindings to be the last instructions in the trace (otherwise they may look up a loop that does not exist in the final schedule), so I am unsure of the best approach here. I would expect auto-inlining to be automatic, though.
   
   For now, I have relaxed the RewriteCooperativeFetch behavior and will add more tests to make sure fusion works as intended.
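The extent arithmetic in `tile_layout_transform` (the ABCD --> DCAB walkthrough in the diff's comments) can be checked without TVM. The sketch below is minimal, pure Python, and mirrors only the extent bookkeeping of `pad_dimension_to_factor_of_tile_size` and `spin_out_factor`. It creates no `LoopRV`s and makes no schedule calls. The helper names `padded_size`, `factor_in_order`, and `tile_extents` are hypothetical and are not part of the PR.

```python
import math
from typing import List, Sequence, Tuple


def padded_size(initial_size: int, tile_size: int) -> int:
    """Smallest size >= initial_size that evenly divides tile_size.

    Sizes that already divide tile_size, or are >= tile_size, are returned
    unchanged (mirroring pad_dimension_to_factor_of_tile_size in the diff).
    """
    if tile_size % initial_size == 0 or initial_size >= tile_size:
        return initial_size
    size = initial_size
    while tile_size % size != 0:
        size += 1
    return size


def factor_in_order(
    extents: Sequence[int], indices: Sequence[int], tile_size: int
) -> Tuple[List[int], List[int]]:
    """Peel factors off extents[i], for i in `indices`, until they multiply to tile_size.

    Hypothetical helper mirroring spin_out_factor's extent updates. Returns the
    remaining outer extents and the peeled factors, innermost-first.
    """
    outer = list(extents)
    factors: List[int] = []
    needed = tile_size
    for i in indices:
        new_size = padded_size(outer[i], needed)
        split = min(new_size, needed)
        outer[i] = math.ceil(new_size / split)  # extent of the outer loop after the split
        factors.append(split)
        needed //= split  # exact: split is either `needed` or a divisor of it
        if needed == 1:
            break
    return outer, factors


def tile_extents(
    input_shape: Sequence[int], src_layout: str, dst_layout: str, tile_size: int
) -> Tuple[List[int], List[int], List[int]]:
    """Returns (outer extents, write-tile factors, read-tile factors) in loop order."""
    rank = len(src_layout)
    # Read tile: peel from the innermost source dimensions (contiguous reads).
    outer, read_tile = factor_in_order(input_shape, range(rank - 1, -1, -1), tile_size)
    # Write tile: peel from source dims in the order of the innermost destination dims.
    write_order = [src_layout.index(dst_layout[k]) for k in range(rank - 1, -1, -1)]
    outer, write_tile = factor_in_order(outer, write_order, tile_size)
    # Factors were peeled innermost-first; reverse them into outer-to-inner loop order.
    return outer, write_tile[::-1], read_tile[::-1]


print(tile_extents([1024, 2, 32, 8], "ABCD", "DCAB", 32))
# ([64, 1, 8, 1], [16, 2], [4, 8])
```

This result corresponds to `T.grid(64_a1, 1_b1, 8_cr, 1_dr, 16_a2t, 2_b2t, 32_dim1)` in the walkthrough. There, `32_dim1` is the fused `4_c2t x 8_d2t` read tile and `16_a2t x 2_b2t` is the write tile.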





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1130176681


##########
include/tvm/topi/transform.h:
##########
@@ -1596,6 +1596,7 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
  */
 inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
                                const std::string& dst_layout,
+                               const std::string schedule_rule = "None",

Review Comment:
   Done





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #14167: [CUDA][Schedule] Better Layout Transform Schedules

Posted by "AndrewZhaoLuo (via GitHub)" <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #14167:
URL: https://github.com/apache/tvm/pull/14167#discussion_r1142502071


##########
src/meta_schedule/postproc/rewrite_cooperative_fetch.cc:
##########
@@ -39,7 +39,17 @@ Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& ins
   if (thread_axis != axis) {
     return NullOpt;
   }
-  return Downcast<Integer>(sch->Get(Downcast<LoopRV>(inst->inputs[0]))->extent);
+
+  try {
+    return Downcast<Integer>(sch->Get(Downcast<LoopRV>(inst->inputs[0]))->extent);
+  } catch (const std::exception& e) {

Review Comment:
   I have removed this.
   
   Previously, the issue was that I bound loops while tiling and then applied fusion, which could destroy the loops I had originally bound; the pass does not expect this.
   
   Now I apply fusion first and then tile, so the bound loops always exist in the final schedule and this workaround is no longer needed.


