You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/03/31 16:37:59 UTC

[incubator-tvm] branch master updated: rocm: fix dense_rocblas in strategy, topi (#5191)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 9cb9a51  rocm: fix dense_rocblas in strategy, topi (#5191)
9cb9a51 is described below

commit 9cb9a51f37eaa9c7692f15f8c5ae52fa70394209
Author: Thomas Viehmann <tv...@beamnet.de>
AuthorDate: Tue Mar 31 18:37:51 2020 +0200

    rocm: fix dense_rocblas in strategy, topi (#5191)
---
 python/tvm/relay/op/strategy/rocm.py | 2 +-
 topi/python/topi/rocm/dense.py       | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index 0486f71..6cda346 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -129,7 +129,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
             wrap_compute_dense(topi.rocm.dense_rocblas),
-            wrap_topi_schedule(topi.rocm.dense_rocblas),
+            wrap_topi_schedule(topi.rocm.schedule_dense_rocblas),
             name="dense_rocblas.rocm",
             plevel=15)
     return strategy
diff --git a/topi/python/topi/rocm/dense.py b/topi/python/topi/rocm/dense.py
index 097120d..989cc2a 100644
--- a/topi/python/topi/rocm/dense.py
+++ b/topi/python/topi/rocm/dense.py
@@ -123,6 +123,8 @@ def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None):
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
+    if out_dtype is None:
+        out_dtype = data.dtype
     assert out_dtype == data.dtype, "Mixed precision not supported."
     matmul = rocblas.matmul(data, weight, False, True)
     batch, in_dim = data.shape