You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/07/25 04:09:18 UTC
[incubator-tvm] branch master updated: [TOPI] Fix CUDA Library Tuning (#6132)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 00578b7 [TOPI] Fix CUDA Library Tuning (#6132)
00578b7 is described below
commit 00578b77530d5272397f76c2966c125888c4ed94
Author: Cody Yu <co...@gmail.com>
AuthorDate: Fri Jul 24 21:09:06 2020 -0700
[TOPI] Fix CUDA Library Tuning (#6132)
---
python/tvm/autotvm/task/space.py | 7 +++++--
topi/python/topi/cuda/conv2d.py | 7 ++++++-
2 files changed, 11 insertions(+), 3 deletions(-)
diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py
index fbf474f..4937661 100644
--- a/python/tvm/autotvm/task/space.py
+++ b/python/tvm/autotvm/task/space.py
@@ -33,6 +33,7 @@ from collections import namedtuple, OrderedDict
import numpy as np
from tvm.te import schedule, thread_axis
+from tvm.tir import expr
from tvm.autotvm.util import get_const_int
Axis = namedtuple('Axis', ['space', 'index'])
@@ -733,10 +734,12 @@ class ConfigSpace(object):
Parameters
---------
- flop: int or float
+ flop: int or float or IntImm or FloatImm
number of float operations
"""
- self.flop += flop
+ if isinstance(flop, (expr.IntImm, expr.FloatImm)):
+ flop = flop.value
+ self.flop += float(flop)
def raise_error(self, msg):
"""register error in config
diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py
index d98d630..973c216 100644
--- a/topi/python/topi/cuda/conv2d.py
+++ b/topi/python/topi/cuda/conv2d.py
@@ -18,6 +18,7 @@
"""Compute definition for conv2d with cuda backend"""
from tvm import te
from tvm import autotvm
+from tvm.autotvm.task.space import OtherOptionEntity
from tvm.contrib import cudnn
from .. import nn, generic
@@ -99,6 +100,10 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1,
else:
dtype = data.dtype
+ cfg.define_knob('algo', range(8))
+ if cfg.is_fallback: # Let CUDNN choose the best algo
+ cfg['algo'] = OtherOptionEntity(-1)
+
return cudnn.conv_forward(data,
kernel,
[pt, pl], # cudnn padding pt, pl on both sides of input
@@ -106,7 +111,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1,
[dilation_h, dilation_w],
conv_mode=1,
tensor_format=tensor_format,
- algo=-1, # let CUDNN choose the best algo
+ algo=cfg['algo'].val,
conv_dtype=dtype,
groups=groups)