You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this text copy.
Posted to commits@tvm.apache.org by co...@apache.org on 2022/11/15 01:45:46 UTC

[tvm] branch main updated: [Codegen] Fix CUDA codegen for int64 Ramp (#13382)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3aa16f72dd [Codegen] Fix CUDA codegen for int64 Ramp (#13382)
3aa16f72dd is described below

commit 3aa16f72dd3f1807b11ec61cf372af07d32099c4
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Mon Nov 14 17:45:39 2022 -0800

    [Codegen] Fix CUDA codegen for int64 Ramp (#13382)
---
 src/target/source/codegen_cuda.cc               | 4 +++-
 tests/python/topi/python/test_topi_transform.py | 1 +
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index d96e0cbc16..3ae74cc16d 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1005,7 +1005,9 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
 
 void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
   CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
-  os << "(make_int" << op->lanes << "(";
+  os << "(make_";
+  PrintType(op->dtype, os);
+  os << "(";
   for (int i = 0; i < op->lanes; i++) {
     os << "(" << PrintExpr(op->base) << ")"
        << "+(" << PrintExpr(op->stride) << "*" << i << ")";
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index dd5ad1b119..0f64b486f3 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -1040,6 +1040,7 @@ def test_gather():
     verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
     verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2)))
     verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10)))
+    verify_gather(np.random.randn(4, 7, 2), 0, np.random.randint(low=0, high=4, size=(4, 7, 2)))
 
 
 @tvm.testing.uses_gpu