Posted to commits@tvm.apache.org by lm...@apache.org on 2021/02/05 22:01:37 UTC

[tvm] 02/02: update

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch custom_tile_size
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b1fcd45d2f9c04d2ff544c390a4f0bb7dd43c692
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Fri Feb 5 22:01:11 2021 +0000

    update
---
 tutorials/autotvm/tune_conv2d_cuda.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/tutorials/autotvm/tune_conv2d_cuda.py b/tutorials/autotvm/tune_conv2d_cuda.py
index a00fe5f..aa7449e 100644
--- a/tutorials/autotvm/tune_conv2d_cuda.py
+++ b/tutorials/autotvm/tune_conv2d_cuda.py
@@ -100,9 +100,9 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     rc, ry, rx = s[conv].op.reduce_axis
 
     cfg = autotvm.get_config()
-    cfg.define_split("tile_f", f, num_outputs=4)          # filter / output channel
-    cfg.define_split("tile_y", y, num_outputs=4)          # height
-    cfg.define_split("tile_x", x, num_outputs=4)          # width
+    cfg.define_split("tile_f", f, num_outputs=4)          # filter / output channel -> blockIdx.z, vthread, threadIdx.z, thread_inner
+    cfg.define_split("tile_y", y, num_outputs=4)          # height                  -> blockIdx.y, vthread, threadIdx.y, thread_inner
+    cfg.define_split("tile_x", x, num_outputs=4)          # width                   -> blockIdx.x, vthread, threadIdx.x, thread_inner
     cfg.define_split("tile_rc", rc, num_outputs=3)        # input channel
     cfg.define_split("tile_ry", ry, num_outputs=3)        # kernel width
     cfg.define_split("tile_rx", rx, num_outputs=3)        # kernel height
@@ -110,6 +110,13 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     cfg.define_knob("unroll_explicit", [0])               # disable auto unroll
     ##### space definition end #####
 
+    # Constraints (read these from deviceQuery)
+    #  blockIdx.x < 2^31, blockIdx.y < 2^16, blockIdx.z < 2^16        (Max dimension size of a grid size)
+    #  threadIdx.x <= 1024, threadIdx.y <= 1024, threadIdx.z <= 64    (Max dimension size of a thread block)
+    #  threadIdx.x * threadIdx.y * threadIdx.z <= 1024                (Maximum number of threads per block)
+    #
+    #  input buffer + weight buffer in each block < 49152 bytes        (Total amount of shared memory per block)
+
     # inline padding
     pad_data = s[conv].op.input_tensors[0]
     s[pad_data].compute_inline()
@@ -221,6 +228,7 @@ best_config = dispatch_context.query(task.target, task.workload)
 
 # Plug in your own tile sizes
 #best_config._entity_map['tile_f'] = SplitEntity([-1, 2, 8, 8])
+#                                [-1, 2, 8, 8] will be mapped to [blockIdx, vthread, threadIdx, thread_inner]
 
 print("\nBest config:")
 print(best_config)
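
For readers who want to plug in their own tile sizes, below is a minimal standalone
sketch of how the commented-out override above can be used. It assumes the
tutorial's best_config (a ConfigEntity) is in scope; the tile values and the final
assert are illustrative only, mirroring the threads-per-block limit noted in the
constraints comment, and are not part of this commit:

    from tvm.autotvm.task.space import SplitEntity

    # Override the tuned splits with hand-picked tile sizes.
    # Each 4-way split maps to [blockIdx, vthread, threadIdx, thread_inner];
    # -1 means "derive this factor from the axis length and the other factors".
    best_config._entity_map["tile_f"] = SplitEntity([-1, 2, 8, 8])
    best_config._entity_map["tile_y"] = SplitEntity([-1, 1, 7, 1])  # hypothetical values
    best_config._entity_map["tile_x"] = SplitEntity([-1, 1, 7, 1])  # hypothetical values

    # Illustrative sanity check: the threadIdx extent is the third factor of
    # each spatial split, and their product must stay within 1024 threads/block.
    tz = best_config._entity_map["tile_f"].size[2]
    ty = best_config._entity_map["tile_y"].size[2]
    tx = best_config._entity_map["tile_x"].size[2]
    assert tz * ty * tx <= 1024, "tile sizes exceed max threads per block"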