You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/28 08:36:21 UTC

[tvm] branch main updated: [OpenCL] Avoid SelectNode ambiguous overloading (#11488)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dd2897cb69 [OpenCL] Avoid SelectNode ambiguous overloading (#11488)
dd2897cb69 is described below

commit dd2897cb69d56b36a2d15daf9a43cc22f116b4c7
Author: mhyang-pllab <75...@users.noreply.github.com>
AuthorDate: Sat May 28 16:36:16 2022 +0800

    [OpenCL] Avoid SelectNode ambiguous overloading (#11488)
    
    * [OpenCL] Avoid SelectNode ambiguous overloading
    
    * Revert "[OpenCL] Avoid SelectNode ambiguous overloading"
    
    This reverts commit 60f68d2e7f750a0f8e62536da7b3327d1f5f29c1.
    
    * [OpenCL] Avoid SelectNode ambiguous codegen
---
 src/target/source/codegen_opencl.cc | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index 1fdf1e7bed..5d04d00339 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -541,12 +541,26 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) {
 }
 
 void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) {
+  std::ostringstream oss;  // Scratch buffer: each select() operand is printed here first so its text can be wrapped in a cast before being written to os.
   os << "select(";
-  PrintExpr(op->false_value, os);
+  PrintExpr(op->false_value, oss);  // OpenCL select(a, b, c) yields c ? b : a, so the false value is emitted first.
+  os << CastFromTo(oss.str(), op->false_value.dtype(), op->dtype);  // Cast to the result dtype so the select() overload is unambiguous.
+  oss.str("");  // Clear the scratch buffer before printing the next operand.
   os << ", ";
-  PrintExpr(op->true_value, os);
+  PrintExpr(op->true_value, oss);  // Second argument: the value chosen where the condition holds.
+  os << CastFromTo(oss.str(), op->true_value.dtype(), op->dtype);  // Same cast-to-result-dtype treatment as the false value.
+  oss.str("");  // Clear again before printing the condition.
   os << ", ";
-  PrintExpr(op->condition, os);
+  PrintExpr(op->condition, oss);  // Third argument: the condition.
+  if (op->dtype.is_float()) {  // Float results: the condition is emitted as an integer type rather than cast to op->dtype.
+    if (op->condition.dtype().is_uint() || op->condition.dtype().is_int()) {  // NOTE(review): an integer condition is passed through unchanged even if its bit width differs from op->dtype.bits(); OpenCL select() expects matching widths — confirm.
+      os << oss.str();
+    } else {
+      os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes()));  // Any other condition type (presumably bool — confirm) becomes a signed int with the result's bit width and lane count. NOTE(review): for vectors select() tests each component's MSB — confirm the lowered condition is -1/0 rather than 1/0.
+    }
+  } else {
+    os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype);  // Non-float results: the condition is cast to the result dtype itself.
+  }
   os << ")";
 }