Posted to commits@tvm.apache.org by ma...@apache.org on 2020/12/06 01:59:28 UTC
[tvm] branch main updated: [ROCm][Auto scheduler] Support Auto scheduler and NHWC convolution on ROCm (#7038)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3ec739c [ROCm][Auto scheduler] Support Auto scheduler and NHWC convolution on ROCm (#7038)
3ec739c is described below
commit 3ec739c650b2b49c15528f301242e98f00f0493e
Author: masahi <ma...@gmail.com>
AuthorDate: Sun Dec 6 10:59:09 2020 +0900
[ROCm][Auto scheduler] Support Auto scheduler and NHWC convolution on ROCm (#7038)
* add nhwc + winograd support to rocm strategy
* support rocm hw parameters in search task
* run analysis pass for rocm too
* run black
* pylint fix
* use IsGPUTask function
Co-authored-by: Masahiro Masuda <ma...@gmail.com>
---
python/tvm/relay/op/strategy/rocm.py | 52 +++++++++++++++++++++++++++++++-----
src/auto_scheduler/feature.cc | 3 ++-
src/auto_scheduler/search_task.cc | 13 +++++----
3 files changed, 55 insertions(+), 13 deletions(-)
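
This change wires the auto scheduler up to ROCm end to end: the Relay strategy gains NHWC (and winograd) implementations, feature extraction treats ROCm tasks as GPU tasks, and default hardware parameters are queried through the ROCm device API. A minimal sketch of what that enables, assuming a ROCm-enabled TVM build of this era and an AMD GPU (workload shapes are illustrative, and auto_scheduler API names have shifted across versions):

    import tvm
    from tvm import relay, auto_scheduler

    # An illustrative NHWC conv2d workload; shapes are arbitrary.
    data = relay.var("data", shape=(1, 56, 56, 64), dtype="float32")
    weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="float32")
    out = relay.nn.conv2d(data, weight, padding=(1, 1),
                          data_layout="NHWC", kernel_layout="HWIO")
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

    # Extract auto-scheduler tuning tasks for the ROCm target. Before this
    # commit, default hardware parameters were only derived for CPU, CUDA,
    # and Metal targets.
    target = tvm.target.Target("rocm")
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], {}, target)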
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index f52bbc3..c52da54 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -17,8 +17,10 @@
"""Definition of ROCm operator strategy."""
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
from tvm import topi
+from tvm.auto_scheduler import is_auto_scheduler_enabled
from .generic import *
from .. import op as _op
+from .cuda import judge_winograd, naive_schedule
@schedule_lrn.register("rocm")
@@ -67,6 +69,49 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
name="conv2d_nchw_winograd.cuda",
plevel=5,
)
+ elif layout == "NHWC":
+ assert kernel_layout == "HWIO"
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
+ wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
+ name="conv2d_nhwc.cuda",
+ )
+ N, H, W, _ = get_const_tuple(data.shape)
+ KH, KW, CI, CO = get_const_tuple(kernel.shape)
+
+ (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
+ N,
+ H,
+ W,
+ KH,
+ KW,
+ CI,
+ CO,
+ padding,
+ stride_h,
+ stride_w,
+ dilation_h,
+ dilation_w,
+ data.dtype,
+ kernel.dtype,
+ pre_flag=False,
+ )
+
+ if judge_winograd_autotvm:
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
+ wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
+ name="conv2d_nhwc_winograd_direct.cuda",
+ plevel=5,
+ )
+
+ if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
+ naive_schedule, # this implementation should never be picked by autotvm
+ name="conv2d_nhwc.winograd",
+ plevel=15,
+ )
elif layout == "HWCN":
assert kernel_layout == "HWIO"
strategy.add_implementation(
@@ -74,13 +119,6 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda",
)
- # TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
- # elif layout == "NHWC":
- # assert kernel_layout == "HWIO"
- # strategy.add_implementation(
- # wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
- # wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
- # name="conv2d_nhwc.cuda")
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
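
With the NHWC branch restored in the ROCm strategy, the gating works as follows: conv2d_nhwc.cuda is always registered, conv2d_nhwc_winograd_direct.cuda is added at plevel=5 when judge_winograd approves the shape for autotvm, and the auto-scheduler winograd implementation at plevel=15 is added only when is_auto_scheduler_enabled() returns true. A rough usage sketch, assuming the PassContext flag relay.backend.use_auto_scheduler is what is_auto_scheduler_enabled() reads:

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 56, 56, 64), dtype="float32")
    weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="float32")
    conv = relay.nn.conv2d(data, weight, padding=(1, 1),
                           data_layout="NHWC", kernel_layout="HWIO")
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        # With the flag set, the plevel=15 winograd implementation above is
        # preferred for eligible shapes; without it, autotvm falls back to
        # conv2d_nhwc.cuda or conv2d_nhwc_winograd_direct.cuda (plevel=5).
        lib = relay.build(mod, target="rocm")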
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 0a3d705..53287a0 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -41,6 +41,7 @@
#include <unordered_map>
#include <vector>
+#include "search_policy/utils.h"
#include "utils.h"
namespace tvm {
@@ -1296,7 +1297,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
}
auto mod = IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
- if (task->target->kind->device_type == kDLGPU) {
+ if (IsGPUTask(task)) {
auto pass_list = Array<tvm::transform::Pass>();
// Phase 0
pass_list.push_back(tir::transform::InjectPrefetch());
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 4c8cc6d..5a34755 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -22,6 +22,7 @@
* \brief Meta information and hardware parameters for a search task.
*/
+#include <dlpack/dlpack.h>
#include <tvm/auto_scheduler/search_task.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
@@ -52,11 +53,13 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l
HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
const Target& target_host) {
- if (target->kind->device_type == kDLCPU) {
+ const auto device_type = target->kind->device_type;
+ if (device_type == kDLCPU) {
return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
- } else if (target->kind->device_type == kDLGPU) {
- auto ctx = TVMContext{kDLGPU, 0};
- auto func = tvm::runtime::Registry::Get("device_api.gpu");
+ } else if (device_type == kDLGPU || device_type == kDLROCM) {
+ auto ctx = TVMContext{static_cast<DLDeviceType>(device_type), 0};
+ auto device_name = device_type == kDLGPU ? "device_api.gpu" : "device_api.rocm";
+ auto func = tvm::runtime::Registry::Get(device_name);
ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
@@ -77,7 +80,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
- } else if (target->kind->device_type == kDLMetal) {
+ } else if (device_type == kDLMetal) {
// Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
// This setting looks working for Metal GPUs later than A10
int max_shared_memory_per_block = 32 * 1024;
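
The new ROCm branch queries the same device attributes the CUDA branch does, just through device_api.rocm. Those attributes are also exposed on the Python TVMContext, which is a quick way to check what GetDefaultHardwareParams will see on a given AMD GPU; a small sketch, assuming a ROCm-enabled build with at least one visible device:

    import tvm

    ctx = tvm.rocm(0)
    # These mirror the DeviceAPI::GetAttr calls above.
    print("max_shared_memory_per_block:", ctx.max_shared_memory_per_block)
    print("max_threads_per_block:", ctx.max_threads_per_block)
    print("warp_size:", ctx.warp_size)
    # max_vthread_extent is then derived in C++ as warp_size / 4.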