Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/06/09 15:54:22 UTC

[GitHub] [incubator-tvm] t-vi opened a new pull request #5752: Make batch matrix multiplication on GPU tunable

t-vi opened a new pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752


   Here is a PR making the hardcoded splits in the GPU batch matrix
   multiplication tunable through autotvm.
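For context, the hardcoded splits being replaced can be modeled in plain Python. This is a rough sketch, not TVM code. `get_max_power2_factor` is reimplemented here on the assumption that it returns the largest power of two dividing `n`, capped at `max_value`. `hardcoded_tiling` is a hypothetical helper that mirrors the removed lines in the diff: a block tile of at most 64, spread over at most 8 threads per axis.

```python
def get_max_power2_factor(n, max_value=None):
    """Largest power of two dividing n, optionally capped at max_value (assumed behavior)."""
    x = 1
    while n % 2 == 0:
        if max_value is not None and x * 2 > max_value:
            break
        x *= 2
        n //= 2
    return x


def hardcoded_tiling(extent):
    """Hypothetical helper mirroring the removed lines for one output axis.

    Returns (block_tile, threads, iterations_per_thread).
    """
    bn = get_max_power2_factor(extent, 64)   # y_bn / x_bn in the old code
    nthreads = min(bn, 8)                    # y_nthreads / x_nthreads
    work = -(-bn // nthreads)                # ceil(bn / nthreads) from split(nparts=...)
    return bn, nthreads, work


for m in (128, 96, 100, 7):
    print(m, hardcoded_tiling(m))
```

Under this model, an axis of extent 100 gets a block tile of only 4 with 4 threads, and a prime extent such as 7 is not tiled at all. This shows how a fixed heuristic can produce very small tiles for some shapes.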
   
   (This is primarily aimed at the AMD GPU backend and done as part
   of a project for AMD, but should work for all users of the GPU
   schedule like the CUDA/nvptx backend as well.)
   
   I get reasonably good results for the workloads I tried: from on par to 1.5x
   slower for batch matmul alone, but quite a speedup when integrated into
   BERT.
   
   This is my first PR adding autotvm tunings to the public repo, so please bear with me.
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437993129



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       What else can I do to make it prettier?




----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437701346



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       As much fun as geeking out over this is, I don't think I need the fallback because:
   - this schedule works without it,
   - other schedules work without it, and
   - define_knob sets the fallback to the first option by setting `cfg._entity_map`, which is what `__getitem__` queries.
   
   I cannot move the definition of `tile_k` to before `k` is defined.
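The third point can be illustrated with a toy model. The classes below are hypothetical stand-ins, not TVM's actual config classes. They only mimic the described behavior: `define_knob` records its first candidate in `_entity_map`, and `__getitem__` reads from that map, so a knob has a usable value even when no tuned record exists.

```python
class ToyEntity:
    """Holds the value currently selected for one knob."""

    def __init__(self, val):
        self.val = val


class ToyConfig:
    """Hypothetical stand-in for an autotvm config; not TVM's real class."""

    def __init__(self):
        self.space = {}        # knob name -> list of candidates
        self._entity_map = {}  # knob name -> selected ToyEntity

    def define_knob(self, name, candidates):
        self.space[name] = list(candidates)
        # Defining a knob also selects its first candidate, so reading it
        # works without an explicit fallback entry.
        self._entity_map[name] = ToyEntity(candidates[0])

    def __getitem__(self, name):
        return self._entity_map[name]


cfg = ToyConfig()
cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
cfg.define_knob("unroll_explicit", [0, 1])
print(cfg["auto_unroll_max_step"].val, cfg["unroll_explicit"].val)  # 8 0
```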




----------------------------------------------------------------



[GitHub] [incubator-tvm] comaniac commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437627640



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       Hmm, interesting. That means this parameter never takes effect in those schedules. Others, such as conv2d_direct, use it like this:
   
   `s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)`
   
   You could also fix those other schedules if you prefer.
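For reference, here is a toy sketch of what the `unroll_explicit` setting chooses between, written in plain Python rather than TVM. Explicit unrolling replicates the loop body in the generated code. Non-explicit unrolling keeps the loop and attaches an unroll hint for the downstream compiler, which, per the comment in the diff, LLVM-based backends cannot use. `emit_loop` is a hypothetical helper that emits pseudo-C for a single loop, and it simplifies `auto_unroll_max_step` to a bound on that one loop's extent. In the real schedule the knob only has an effect once its value is passed to `pragma` as in the line above.

```python
def emit_loop(extent, body, auto_unroll_max_step, unroll_explicit):
    """Emit pseudo-C lines for a loop over i in [0, extent).

    `body` is a format string that refers to the loop index as {i}.
    """
    loop = f"for (int i = 0; i < {extent}; ++i) {{ {body.format(i='i')} }}"
    if extent > auto_unroll_max_step:
        # Too many steps: leave the loop as it is.
        return [loop]
    if unroll_explicit:
        # Explicit: replicate the body once per iteration.
        return [body.format(i=n) for n in range(extent)]
    # Non-explicit: keep the loop and only attach an unroll hint.
    return ["#pragma unroll", loop]


body = "acc += A[{i}] * B[{i}];"
for line in emit_loop(4, body, auto_unroll_max_step=8, unroll_explicit=1):
    print(line)
print(emit_loop(4, body, auto_unroll_max_step=8, unroll_explicit=0))
```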




----------------------------------------------------------------



[GitHub] [incubator-tvm] tqchen commented on pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#issuecomment-642797574


   Thanks @t-vi for the contribution, and @comaniac for reviewing!


----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437993129



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       What else do you think you need to happen?




----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437701346



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       As much fun as geeking out over this is,
   - I don't think I need the fallback because
     - this schedule works without it,
     - other schedules work without it, and
     - define_knob sets the fallback to the first option by setting `cfg._entity_map`, which is what `__getitem__` queries.
   
   I cannot move the definition of `tile_k` to before `k` is defined.




----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437650025



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -463,9 +463,9 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
+        name="batch_matmul.gpu",

Review comment:
       Yeah, I changed it to CUDA for now.




----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#issuecomment-641488953


   > Could you provide some evaluation results for this tunable template? For example, the tuning space size, the best performance after 3,000 trials (or less if the tuning space is smaller than 3,000), and the performance of fallback config.
   
   After 2,000 trials I get roughly on par with PyTorch (using rocBLAS) on a Radeon VII for both configurations mentioned in the thread linked above. On a GTX 1080 Ti I get to about 1.5x slower than PyTorch, down from "a lot more", though I didn't measure as thoroughly there.
   For the end-to-end BERT model that I'm running with a few upcoming PRs, I go from uncompetitive performance to significantly faster than vanilla PyTorch on the first example from the PyTorch JIT tutorial.
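To put a number on the tuning-space-size question quoted above, here is a rough sketch for estimating the size of the space defined in the diff. It rests on assumptions that do not come from the PR: that `define_split` with its default policy enumerates every ordered factorization of the axis extent into `num_outputs` factors, that `tile_k` (not shown in the quoted diff) is a two-way split, and that the workload is a hypothetical 128x128x128 matmul.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def num_ordered_factorizations(n, parts):
    """Number of ordered tuples of `parts` positive ints whose product is n."""
    if parts == 1:
        return 1
    return sum(num_ordered_factorizations(n // d, parts - 1)
               for d in range(1, n + 1) if n % d == 0)


def batch_matmul_space_size(M, N, K, target_name):
    """Estimated number of configs; assumes the split policy described above."""
    tile_y = num_ordered_factorizations(M, 3)   # define_split("tile_y", y, num_outputs=3)
    tile_x = num_ordered_factorizations(N, 3)   # define_split("tile_x", x, num_outputs=3)
    tile_k = num_ordered_factorizations(K, 2)   # assumed two-way split of k
    auto_unroll = 4                             # candidates [8, 16, 32, 64]
    # LLVM-based backends only get unroll_explicit=[1]; others get [0, 1].
    unroll_explicit = 1 if target_name in ("nvptx", "rocm") else 2
    return tile_y * tile_x * tile_k * auto_unroll * unroll_explicit


print(batch_matmul_space_size(128, 128, 128, "rocm"))  # 41472
print(batch_matmul_space_size(128, 128, 128, "cuda"))  # 82944
```

Under these assumptions the space is far larger than the 2,000 trials mentioned above, so the tuner samples the space rather than exhausting it.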


----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437616265



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -463,9 +463,9 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
+        name="batch_matmul.gpu",

Review comment:
       I don't insist on it, but we recently (with the great(!) new softmax schedule) had some confusion about CUDA schedules being used on non-CUDA targets.




----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437621243



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:

Review comment:
       See, one might easily think so, but I know that ROCm uses this schedule. My impression is that calling the shared schedules "gpu" would make that clearer.
   

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       Good catch, thank you!




----------------------------------------------------------------



[GitHub] [incubator-tvm] comaniac commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437624392



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:

Review comment:
       Yeah, I totally agree. That's why I suggested a separate RFC/PR to fix all cases. On the other hand, simply specifying `gpu` in `topi/cuda` to cover both CUDA and ROCm is also confusing to me.




----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437650755



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       I think it should be OK now.
   




----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#issuecomment-641237043


   @tqchen @comaniac This is related to the [(batch) matmul discussion on the forums](https://discuss.tvm.ai/t/optimizing-matrix-multiplication-for-gpu/4212/18) you participated in.
   @masahi 
   


----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437718505



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       I moved `k` and `tile_k` up; it appears to work and looks much prettier indeed. Thank you for insisting on it.




----------------------------------------------------------------



[GitHub] [incubator-tvm] comaniac commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437696668



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       Sorry, I just found that `unroll_explicit` is still missing in the fallback. Also, would you mind moving L103 (the definition of `tile_k`) up together with the other tuning parameters?




----------------------------------------------------------------



[GitHub] [incubator-tvm] comaniac commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437706120



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       If the first two points matter, we should just remove that parameter to make the tuning space more efficient. Meanwhile, I appreciate the third point, which I hadn't realized before.
   
   For `tile_k`, I think you should be able to move the definition of `k` up as well. It should be safe because tuning parameters must be static, so they won't depend on other parameters. This way we can also put the fallback configs together to make it clearer.
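A toy illustration of why a knob the schedule never reads makes tuning less efficient. Every configuration is still a separate point the tuner may measure, but configurations that differ only in the ignored knob produce the same program. The space values and `build_schedule` below are made up for illustration and are not taken from the PR.

```python
from itertools import product

# Made-up candidate values for three knobs.
space = {
    "tile_x": [(1, 8, 16), (2, 4, 16), (4, 4, 8)],
    "auto_unroll_max_step": [8, 16, 32, 64],
    "unroll_explicit": [0, 1],
}


def build_schedule(cfg):
    """Hypothetical schedule that never reads unroll_explicit."""
    return (cfg["tile_x"], cfg["auto_unroll_max_step"])


# Enumerate every config in the space (keys and values share insertion order).
configs = [dict(zip(space, values)) for values in product(*space.values())]
distinct = {build_schedule(c) for c in configs}
print(len(configs), len(distinct))  # 24 12
```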
   




----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi edited a comment on pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi edited a comment on pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#issuecomment-641488953


   Thank you for the feedback, @comaniac. I changed everything per your suggestions except for the fallback config for `unroll_explicit`, which seems to be implicit (ha!).
   
   > Could you provide some evaluation results for this tunable template? For example, the tuning space size, the best performance after 3,000 trials (or less if the tuning space is smaller than 3,000), and the performance of fallback config.
   
   After 2,000 trials I get roughly on par with PyTorch (using rocBLAS) on a Radeon VII for both configurations mentioned in the thread linked above. On a GTX 1080 Ti I get to about 1.5x slower than PyTorch, down from "a lot more", though I didn't measure as thoroughly there.
   For the end-to-end BERT model that I'm running with a few upcoming PRs, I go from uncompetitive performance to significantly faster than vanilla PyTorch on the first example from the PyTorch JIT tutorial.


----------------------------------------------------------------



[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437624803



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       Looking at this again, I should actually use `unroll_explicit`, not just define it.
   What I'm unsure about is whether I also need to set it in the fallback: conv2d_direct doesn't, and it seems to work well.
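A minimal pure-Python model of the knob logic in the hunk above, assuming a plain dict stands in for autotvm's config entity (this is not the autotvm API, and the helper names are hypothetical). It mirrors the target check that restricts `unroll_explicit` to `[1]` on `nvptx`/`rocm`, and shows the knob being used (turned into the two unroll pragmas) rather than only defined. Other TVM GPU schedules such as `conv2d_direct` typically apply these with `s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)`.

```python
# Hypothetical, simplified model of the tunable unroll knobs.
# A plain dict stands in for autotvm's ConfigEntity; this is not the autotvm API.
LLVM_GPU_TARGETS = ("nvptx", "rocm")


def unroll_explicit_candidates(target_name):
    """Candidate values for the "unroll_explicit" knob, as in the hunk above."""
    if target_name in LLVM_GPU_TARGETS:
        # llvm-based backends cannot do non-explicit unrolling
        return [1]
    return [0, 1]


def knob_space(target_name):
    """Knob candidates the schedule defines (the two splits are omitted)."""
    return {
        "auto_unroll_max_step": [8, 16, 32, 64],
        "unroll_explicit": unroll_explicit_candidates(target_name),
    }


def unroll_pragmas(cfg, target_name):
    """Turn chosen knob values into the (pragma, value) pairs a schedule applies.

    Raises KeyError if cfg never set a knob, and ValueError if a value is not
    among the defined candidates (e.g. unroll_explicit=0 on rocm).
    """
    space = knob_space(target_name)
    pragmas = []
    for name in ("auto_unroll_max_step", "unroll_explicit"):
        value = cfg[name]  # KeyError if this knob was never set
        if value not in space[name]:
            raise ValueError(f"{name}={value} is not valid for {target_name}")
        pragmas.append((name, value))
    return pragmas
```

For example, `unroll_pragmas({"auto_unroll_max_step": 16, "unroll_explicit": 1}, "rocm")` yields both pragmas, while the same call with `unroll_explicit` set to 0 is rejected for `rocm` and `nvptx`.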







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437618804



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -463,9 +463,9 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
+        name="batch_matmul.gpu",

Review comment:
       Yeah, I realized that ROCm uses many CUDA schedules, which causes confusion: lots of CUDA schedules contain `if target == 'rocm'`. However, it would be better to correct this in a separate PR, if everyone agrees.







[GitHub] [incubator-tvm] t-vi commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
t-vi commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437632153



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:

Review comment:
       Yeah, I'm sticking to cuda for now to avoid creating two confusions where there was only one. :slightly_smiling_face:
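The removed lines in the hunk above fixed a single tiling per shape. The sketch below re-creates that choice in plain Python, assuming `get_max_power2_factor(n, max_value)` returns the largest power of two that divides `n` without exceeding `max_value` (re-implemented here for illustration, not imported from TVM). It reports how the old code split an axis of length `m`: into blocks of `y_bn` elements, each handled by `min(y_bn, 8)` threads. The tunable `tile_y`/`tile_x` splits replace exactly this fixed choice.

```python
def get_max_power2_factor(n, max_value=None):
    """Largest power of two dividing n, capped at max_value if given."""
    x = 1
    while n % 2 == 0:
        if max_value is not None and max_value < x * 2:
            break
        x *= 2
        n //= 2
    return x


def old_fixed_tiling(m):
    """Tiling the removed code hardcoded for an axis of length m.

    Returns (blocks, threads, inner): the axis is split into `blocks` chunks
    of y_bn elements, and each chunk into `threads` parts of `inner` elements.
    """
    y_bn = get_max_power2_factor(m, 64)
    nthreads = min(y_bn, 8)
    # y_bn divides m and nthreads divides y_bn (both powers of two), so
    # these integer divisions are exact.
    return m // y_bn, nthreads, y_bn // nthreads
```

For `m = 128` this gives 2 blocks of 8 threads with 8 elements each, while an odd length such as 7 ends up with a single thread per block, which illustrates how rigid the hardcoded choice is.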







[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
icemelon9 commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437695904



##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       I think you still need to add `"unroll_explicit"` in the fallback cfg.
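One way to catch the omission pointed out here is to compare the knobs the schedule defines with the ones its fallback branch sets explicitly. A minimal sketch over plain lists of names; `missing_fallback_knobs` is a hypothetical helper, not part of autotvm.

```python
def missing_fallback_knobs(defined, fallback_set):
    """Return, sorted, knob names defined in the space but not set in the fallback."""
    return sorted(set(defined) - set(fallback_set))


# Knob names defined by the schedule in the hunk above.
DEFINED = ["tile_y", "tile_x", "auto_unroll_max_step", "unroll_explicit"]
```

With a fallback that only sets `tile_y`, `tile_x` and `auto_unroll_max_step`, the helper reports `["unroll_explicit"]`.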







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437597643



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -463,9 +463,9 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
+        name="batch_matmul.gpu",

Review comment:
       Why change to `gpu`?

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
       `unroll_explicit` is missing from the fallback schedule. Based on your comment, this might cause problems on LLVM-based backends.
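To get a feel for the size of the `tile_y`/`tile_x` spaces defined in the hunk above, this sketch enumerates every ordered way to write an axis length as a product of three factors. It assumes this matches what `define_split(..., num_outputs=3)` explores under its default factor-based policy; it is a simplified model, and autotvm's split space also supports other policies and filters.

```python
def divisors(n):
    """All positive divisors of n in increasing order."""
    return [d for d in range(1, n + 1) if n % d == 0]


def three_way_splits(length):
    """Every (outer, middle, inner) with outer * middle * inner == length."""
    splits = []
    for inner in divisors(length):
        for middle in divisors(length // inner):
            outer = length // (inner * middle)
            splits.append((outer, middle, inner))
    return splits
```

Under this model an axis of length 128 has 36 candidate splits, so `tile_y` and `tile_x` together give 1296 combinations for M = N = 128 before the unroll knobs are multiplied in.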

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -14,13 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
 """cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
 from tvm import te
 from tvm.contrib import cublas
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
+from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
-def schedule_batch_matmul(outs):
+@autotvm.register_topi_compute("batch_matmul.gpu")
+def batch_matmul(cfg, x, y):
+    """Compute conv2d with NCHW layout"""
+    return nn.batch_matmul(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul.gpu")

Review comment:
       Ditto (same question about the `.gpu` name).

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:

Review comment:
       This reminds me: you may also need to add a batch_matmul op strategy for ROCm; otherwise ROCm will never use this schedule.
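The point here is that a schedule is only reachable if the op strategy selected for the target points to it. A toy model of target-keyed dispatch with a generic default (all names are hypothetical; TVM's real mechanism is the generic-function registration under `python/tvm/relay/op/strategy/`), in which a target whose keys match no registration gets the generic implementation.

```python
class ToyGenericFunc:
    """Toy target-keyed dispatcher: try each target key in order, else default."""

    def __init__(self, default):
        self._default = default
        self._registry = {}

    def register(self, keys):
        """Decorator registering func for every key in keys."""
        def decorator(func):
            for key in keys:
                self._registry[key] = func
            return func
        return decorator

    def __call__(self, target_keys):
        for key in target_keys:
            if key in self._registry:
                return self._registry[key]()
        return self._default()


toy_batch_matmul_strategy = ToyGenericFunc(lambda: "batch_matmul.generic")


@toy_batch_matmul_strategy.register(["cuda"])
def _toy_cuda_strategy():
    return "batch_matmul.cuda"
```

In this model, dispatching for `["rocm"]` returns the generic implementation until a `"rocm"` entry is registered that returns the CUDA one, which is the analogue of the strategy suggested above.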

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -14,13 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
 """cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
 from tvm import te
 from tvm.contrib import cublas
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
+from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
-def schedule_batch_matmul(outs):
+@autotvm.register_topi_compute("batch_matmul.gpu")

Review comment:
       We use `XXX.cuda` everywhere else, so I'd suggest changing this to `batch_matmul.cuda`.







[GitHub] [incubator-tvm] tqchen merged pull request #5752: Make batch matrix multiplication on GPU tunable

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752


   

