You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ro...@apache.org on 2022/11/05 15:48:14 UTC
[pinot] branch master updated: [bugfix] fix case-when issue (#9702)
This is an automated email from the ASF dual-hosted git repository.
rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 4b36685daf [bugfix] fix case-when issue (#9702)
4b36685daf is described below
commit 4b36685daf9fcdffd09b332d8134b0c8d9cde5c0
Author: Rong Rong <ro...@apache.org>
AuthorDate: Sat Nov 5 08:48:08 2022 -0700
[bugfix] fix case-when issue (#9702)
* fix case-when issue
* fix backward compatibility issue
* address scalar function
Co-authored-by: Rong Rong <ro...@startree.ai>
---
.../function/scalar/ComparisonFunctions.java | 6 ---
.../common/function/scalar/ObjectFunctions.java | 41 ++++++++++++++++---
.../apache/pinot/sql/parsers/CalciteSqlParser.java | 8 ++--
.../pinot/sql/parsers/CalciteSqlCompilerTest.java | 12 +++---
.../transform/function/CaseTransformFunction.java | 46 +++++++++++++++++++---
.../apache/pinot/query/QueryCompilationTest.java | 3 ++
.../java/org/apache/pinot/query/QueryTestSet.java | 6 +++
.../query/runtime/QueryRunnerExceptionTest.java | 3 ++
8 files changed, 99 insertions(+), 26 deletions(-)
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java
index 643e90b7ac..e27ff13d6f 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java
@@ -19,7 +19,6 @@
package org.apache.pinot.common.function.scalar;
import org.apache.pinot.spi.annotations.ScalarFunction;
-import org.apache.pinot.spi.utils.BooleanUtils;
public class ComparisonFunctions {
@@ -64,9 +63,4 @@ public class ComparisonFunctions {
public static boolean between(double val, double a, double b) {
return val > a && val < b;
}
-
- @ScalarFunction
- public static Object caseWhen(Object comparisonResult, Object left, Object right) {
- return BooleanUtils.toBoolean(comparisonResult) ? left : right;
- }
}
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
index 88250df4fc..944684ae1d 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
@@ -20,6 +20,8 @@ package org.apache.pinot.common.function.scalar;
import javax.annotation.Nullable;
import org.apache.pinot.spi.annotations.ScalarFunction;
+import org.apache.pinot.spi.utils.BooleanUtils;
+
public class ObjectFunctions {
private ObjectFunctions() {
@@ -92,13 +94,40 @@ public class ObjectFunctions {
return null;
}
- @Nullable
- private static Object coalesce(Object... objects) {
- for (Object o : objects) {
- if (o != null) {
- return o;
+ @ScalarFunction
+ public static Object caseWhen(boolean c1, Object o1, Object oe) {
+ return caseWhenVar(c1, o1, oe);
+ }
+
+ @ScalarFunction
+ public static Object caseWhen(boolean c1, Object o1, boolean c2, Object o2, Object oe) {
+ return caseWhenVar(c1, o1, c2, o2, oe);
+ }
+
+ @ScalarFunction
+ public static Object caseWhen(boolean c1, Object o1, boolean c2, Object o2, boolean c3, Object o3, Object oe) {
+ return caseWhenVar(c1, o1, c2, o2, c3, o3, oe);
+ }
+
+ @ScalarFunction
+ public static Object caseWhen(boolean c1, Object o1, boolean c2, Object o2, boolean c3, Object o3, boolean c4,
+ Object o4, Object oe) {
+ return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, oe);
+ }
+
+ @ScalarFunction
+ public static Object caseWhen(boolean c1, Object o1, boolean c2, Object o2, boolean c3, Object o3, boolean c4,
+ Object o4, boolean c5, Object o5, Object oe) {
+ return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, oe);
+ }
+
+ private static Object caseWhenVar(Object... objs) {
+ for (int i = 0; i < objs.length - 1; i += 2) {
+ if (BooleanUtils.toBoolean(objs[i])) {
+ return objs[i + 1];
}
}
- return null;
+ // with or without else statement.
+ return objs.length % 2 == 0 ? null : objs[objs.length - 1];
}
}
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
index 18cbca3f41..d77461e978 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
@@ -19,6 +19,7 @@
package org.apache.pinot.sql.parsers;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
@@ -717,15 +718,16 @@ public class CalciteSqlParser {
SqlNodeList thenOperands = caseSqlNode.getThenOperands();
SqlNode elseOperand = caseSqlNode.getElseOperand();
Expression caseFuncExpr = RequestUtils.getFunctionExpression("case");
- for (SqlNode whenSqlNode : whenOperands.getList()) {
+ Preconditions.checkState(whenOperands.size() == thenOperands.size());
+ for (int i = 0; i < whenOperands.size(); i++) {
+ SqlNode whenSqlNode = whenOperands.get(i);
Expression whenExpression = toExpression(whenSqlNode);
if (isAggregateExpression(whenExpression)) {
throw new SqlCompilationException(
"Aggregation functions inside WHEN Clause is not supported - " + whenSqlNode);
}
caseFuncExpr.getFunctionCall().addToOperands(whenExpression);
- }
- for (SqlNode thenSqlNode : thenOperands.getList()) {
+ SqlNode thenSqlNode = thenOperands.get(i);
Expression thenExpression = toExpression(thenSqlNode);
if (isAggregateExpression(thenExpression)) {
throw new SqlCompilationException(
diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 2c1e9d26e3..d4afdc71b3 100644
--- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -93,11 +93,11 @@ public class CalciteSqlCompilerTest {
Assert.assertEquals(greatThanFunc.getOperator(), FilterKind.GREATER_THAN.name());
Assert.assertEquals(greatThanFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
Assert.assertEquals(greatThanFunc.getOperands().get(1).getLiteral().getFieldValue(), 30L);
- Function equalsFunc = caseFunc.getOperands().get(1).getFunctionCall();
+ Assert.assertEquals(caseFunc.getOperands().get(1).getLiteral().getFieldValue(), "The quantity is greater than 30");
+ Function equalsFunc = caseFunc.getOperands().get(2).getFunctionCall();
Assert.assertEquals(equalsFunc.getOperator(), FilterKind.EQUALS.name());
Assert.assertEquals(equalsFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
Assert.assertEquals(equalsFunc.getOperands().get(1).getLiteral().getFieldValue(), 30L);
- Assert.assertEquals(caseFunc.getOperands().get(2).getLiteral().getFieldValue(), "The quantity is greater than 30");
Assert.assertEquals(caseFunc.getOperands().get(3).getLiteral().getFieldValue(), "The quantity is 30");
Assert.assertEquals(caseFunc.getOperands().get(4).getLiteral().getFieldValue(), "The quantity is under 30");
@@ -124,16 +124,16 @@ public class CalciteSqlCompilerTest {
Assert.assertEquals(greatThanFunc.getOperator(), FilterKind.GREATER_THAN.name());
Assert.assertEquals(greatThanFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
Assert.assertEquals(greatThanFunc.getOperands().get(1).getLiteral().getFieldValue(), 30L);
- greatThanFunc = caseFunc.getOperands().get(1).getFunctionCall();
+ Assert.assertEquals(caseFunc.getOperands().get(1).getLiteral().getFieldValue(), 3L);
+ greatThanFunc = caseFunc.getOperands().get(2).getFunctionCall();
Assert.assertEquals(greatThanFunc.getOperator(), FilterKind.GREATER_THAN.name());
Assert.assertEquals(greatThanFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
Assert.assertEquals(greatThanFunc.getOperands().get(1).getLiteral().getFieldValue(), 20L);
- greatThanFunc = caseFunc.getOperands().get(2).getFunctionCall();
+ Assert.assertEquals(caseFunc.getOperands().get(3).getLiteral().getFieldValue(), 2L);
+ greatThanFunc = caseFunc.getOperands().get(4).getFunctionCall();
Assert.assertEquals(greatThanFunc.getOperator(), FilterKind.GREATER_THAN.name());
Assert.assertEquals(greatThanFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
Assert.assertEquals(greatThanFunc.getOperands().get(1).getLiteral().getFieldValue(), 10L);
- Assert.assertEquals(caseFunc.getOperands().get(3).getLiteral().getFieldValue(), 3L);
- Assert.assertEquals(caseFunc.getOperands().get(4).getLiteral().getFieldValue(), 2L);
Assert.assertEquals(caseFunc.getOperands().get(5).getLiteral().getFieldValue(), 1L);
Assert.assertEquals(caseFunc.getOperands().get(6).getLiteral().getFieldValue(), 0L);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
index ddb951eacf..15279ea86d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
@@ -84,17 +84,53 @@ public class CaseTransformFunction extends BaseTransformFunction {
}
int numWhenStatements = arguments.size() / 2;
_whenStatements = new ArrayList<>(numWhenStatements);
+ _elseThenStatements = new ArrayList<>(numWhenStatements + 1);
+ constructStatementList(arguments);
+ _selections = new boolean[_elseThenStatements.size()];
+ Collections.reverse(_elseThenStatements);
+ Collections.reverse(_whenStatements);
+ _resultMetadata = calculateResultMetadata();
+ }
+
+ private void constructStatementList(List<TransformFunction> arguments) {
+ int numWhenStatements = arguments.size() / 2;
+ boolean allBooleanFirstHalf = true;
+ boolean notAllBooleanOddHalf = false;
+ for (int i = 0; i < numWhenStatements; i++) {
+ if (arguments.get(i).getResultMetadata().getDataType() != DataType.BOOLEAN) {
+ allBooleanFirstHalf = false;
+ }
+ if (arguments.get(i * 2).getResultMetadata().getDataType() != DataType.BOOLEAN) {
+ notAllBooleanOddHalf = true;
+ }
+ }
+ if (allBooleanFirstHalf && notAllBooleanOddHalf) {
+ constructStatementListLegacy(arguments);
+ } else {
+ constructStatementListCalcite(arguments);
+ }
+ }
+
+ private void constructStatementListCalcite(List<TransformFunction> arguments) {
+ int numWhenStatements = arguments.size() / 2;
+ // alternating WHEN and THEN clause, last one ELSE
+ for (int i = 0; i < numWhenStatements; i++) {
+ _whenStatements.add(arguments.get(i * 2));
+ _elseThenStatements.add(arguments.get(i * 2 + 1));
+ }
+ _elseThenStatements.add(arguments.get(arguments.size() - 1));
+ }
+
+ // TODO: Legacy format, this is here for backward compatibility support, remove after release 0.12
+ private void constructStatementListLegacy(List<TransformFunction> arguments) {
+ int numWhenStatements = arguments.size() / 2;
+ // first half WHEN, second half THEN, last one ELSE
for (int i = 0; i < numWhenStatements; i++) {
_whenStatements.add(arguments.get(i));
}
- _elseThenStatements = new ArrayList<>(numWhenStatements + 1);
for (int i = numWhenStatements; i < numWhenStatements * 2 + 1; i++) {
_elseThenStatements.add(arguments.get(i));
}
- _selections = new boolean[_elseThenStatements.size()];
- Collections.reverse(_elseThenStatements);
- Collections.reverse(_whenStatements);
- _resultMetadata = calculateResultMetadata();
}
private TransformResultMetadata calculateResultMetadata() {
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 9c220b4438..87b3c940f2 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -259,6 +259,9 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
new Object[]{"SELECT a.col1 FROM a WHERE a.col1 IN ()", "Encountered \"\" at line"},
// AT TIME ZONE should fail
new Object[]{"SELECT a.col1 AT TIME ZONE 'PST' FROM a", "No match found for function signature AT_TIME_ZONE"},
+ // CASE WHEN with non-consolidated result type at compile time.
+ new Object[]{"SELECT SUM(CASE WHEN col3 > 10 THEN 1 WHEN col3 > 20 THEN 2 WHEN col3 > 30 THEN 3 "
+ + "WHEN col3 > 40 THEN 4 WHEN col3 > 50 THEN '5' ELSE 0 END) FROM a", "while converting CASE WHEN"},
};
}
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java
index f8f90a7f2c..575efcc903 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java
@@ -217,6 +217,12 @@ public class QueryTestSet {
// COALESCE function
new Object[]{"SELECT a.col1, COALESCE(b.col3, 0) FROM a LEFT JOIN b ON a.col1 = b.col2"},
new Object[]{"SELECT a.col1, COALESCE(a.col3, 0) FROM a WHERE COALESCE(a.col2, 'bar') = 'bar'"},
+
+ // CASE WHEN function
+ new Object[]{"SELECT MAX(CAST((CASE WHEN col3 > 0 THEN 1 WHEN col3 > 10 then 2 ELSE 0 END) AS INTEGER)) "
+ + " FROM a"},
+ new Object[]{"SELECT col2, CASE WHEN SUM(col3) > 0 THEN 1 WHEN SUM(col3) > 10 then 2 WHEN SUM(col3) > 100 "
+ + " THEN 3 ELSE 0 END FROM a GROUP BY col2"}
};
}
}
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerExceptionTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerExceptionTest.java
index 5b5c084ce5..15c981e206 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerExceptionTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerExceptionTest.java
@@ -73,6 +73,9 @@ public class QueryRunnerExceptionTest extends QueryRunnerTestBase {
"ArithmeticFunctions.least(double,double) with arguments"},
// Function that tries to cast String to Number should throw runtime exception
new Object[]{"SELECT a.col2, b.col1 FROM a JOIN b ON a.col1 = b.col3", "transform function: cast"},
+ // standard SqlOpTable function that runs out of signature list in actual impl throws not found exception
+ new Object[]{"SELECT CASE WHEN col3 > 10 THEN 1 WHEN col3 > 20 THEN 2 WHEN col3 > 30 THEN 3 "
+ + "WHEN col3 > 40 THEN 4 WHEN col3 > 50 THEN 5 WHEN col3 > 60 THEN '6' ELSE 0 END FROM a", "caseWhen"},
};
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org