You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text rendering).
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/28 08:36:21 UTC
[tvm] branch main updated: [OpenCL] Avoid SelectNode ambiguous overloading (#11488)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dd2897cb69 [OpenCL] Avoid SelectNode ambiguous overloading (#11488)
dd2897cb69 is described below
commit dd2897cb69d56b36a2d15daf9a43cc22f116b4c7
Author: mhyang-pllab <75...@users.noreply.github.com>
AuthorDate: Sat May 28 16:36:16 2022 +0800
[OpenCL] Avoid SelectNode ambiguous overloading (#11488)
* [OpenCL] Avoid SelectNode ambiguous overloading
* Revert "[OpenCL] Avoid SelectNode ambiguous overloading"
This reverts commit 60f68d2e7f750a0f8e62536da7b3327d1f5f29c1.
* [OpenCL] Avoid SelectNode ambiguous codegen
---
src/target/source/codegen_opencl.cc | 20 +++++++++++++++++---
1 file changed, 17 insertions(+), 3 deletions(-)
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index 1fdf1e7bed..5d04d00339 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -541,12 +541,26 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) {
}
void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) {
+ std::ostringstream oss; // per-operand scratch buffer, so each operand can be wrapped in a cast
 os << "select(";
- PrintExpr(op->false_value, os);
+ PrintExpr(op->false_value, oss);
+ os << CastFromTo(oss.str(), op->false_value.dtype(), op->dtype);
+ oss.str("");
 os << ", ";
- PrintExpr(op->true_value, os);
+ PrintExpr(op->true_value, oss);
+ os << CastFromTo(oss.str(), op->true_value.dtype(), op->dtype);
+ oss.str("");
 os << ", ";
- PrintExpr(op->condition, os);
+ PrintExpr(op->condition, oss);
+ // select(a, b, c) needs an int/uint c with the element width of a/b; bool (1 bit) is cast.
+ const DataType cond_t = op->condition.dtype();
+ if (!op->dtype.is_float()) {
+ os << CastFromTo(oss.str(), cond_t, op->dtype);
+ } else if ((cond_t.is_int() || cond_t.is_uint()) && cond_t.bits() == op->dtype.bits()) {
+ os << oss.str();
+ } else {
+ os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes()));
+ }
 os << ")";
}