You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ro...@apache.org on 2022/11/05 15:48:14 UTC

[pinot] branch master updated: [bugfix] fix case-when issue (#9702)

This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b36685daf [bugfix] fix case-when issue (#9702)
4b36685daf is described below

commit 4b36685daf9fcdffd09b332d8134b0c8d9cde5c0
Author: Rong Rong <ro...@apache.org>
AuthorDate: Sat Nov 5 08:48:08 2022 -0700

    [bugfix] fix case-when issue (#9702)
    
    * fix case-when issue
    * fix backward compatibility issue
    * address scalar function
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../function/scalar/ComparisonFunctions.java       |  6 ---
 .../common/function/scalar/ObjectFunctions.java    | 41 ++++++++++++++++---
 .../apache/pinot/sql/parsers/CalciteSqlParser.java |  8 ++--
 .../pinot/sql/parsers/CalciteSqlCompilerTest.java  | 12 +++---
 .../transform/function/CaseTransformFunction.java  | 46 +++++++++++++++++++---
 .../apache/pinot/query/QueryCompilationTest.java   |  3 ++
 .../java/org/apache/pinot/query/QueryTestSet.java  |  6 +++
 .../query/runtime/QueryRunnerExceptionTest.java    |  3 ++
 8 files changed, 99 insertions(+), 26 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java
index 643e90b7ac..e27ff13d6f 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java
@@ -19,7 +19,6 @@
 package org.apache.pinot.common.function.scalar;
 
 import org.apache.pinot.spi.annotations.ScalarFunction;
-import org.apache.pinot.spi.utils.BooleanUtils;
 
 
 public class ComparisonFunctions {
@@ -64,9 +63,4 @@ public class ComparisonFunctions {
   public static boolean between(double val, double a, double b) {
     return val > a && val < b;
   }
-
-  @ScalarFunction
-  public static Object caseWhen(Object comparisonResult, Object left, Object right) {
-    return BooleanUtils.toBoolean(comparisonResult) ? left : right;
-  }
 }
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
index 88250df4fc..944684ae1d 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java
@@ -20,6 +20,8 @@ package org.apache.pinot.common.function.scalar;
 
 import javax.annotation.Nullable;
 import org.apache.pinot.spi.annotations.ScalarFunction;
+import org.apache.pinot.spi.utils.BooleanUtils;
+
 
 public class ObjectFunctions {
   private ObjectFunctions() {
@@ -92,13 +94,40 @@ public class ObjectFunctions {
     return null;
   }
 
-  @Nullable
-  private static Object coalesce(Object... objects) {
-    for (Object o : objects) {
-      if (o != null) {
-        return o;
+  @ScalarFunction
+  public static Object caseWhen(boolean c1, Object o1, Object oe) {
+    return caseWhenVar(c1, o1, oe);
+  }
+
+  @ScalarFunction
+  public static Object caseWhen(boolean c1, Object o1, boolean c2, Object o2, Object oe) {
+    return caseWhenVar(c1, o1, c2, o2, oe);
+  }
+
+  @ScalarFunction
+  public static Object caseWhen(boolean c1, Object o1, boolean c2, Object o2, boolean c3, Object o3, Object oe) {
+    return caseWhenVar(c1, o1, c2, o2, c3, o3, oe);
+  }
+
+  @ScalarFunction
+  public static Object caseWhen(boolean c1, Object o1, boolean c2, Object o2, boolean c3, Object o3, boolean c4,
+      Object o4, Object oe) {
+    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, oe);
+  }
+
+  @ScalarFunction
+  public static Object caseWhen(boolean c1, Object o1, boolean c2, Object o2, boolean c3, Object o3, boolean c4,
+      Object o4, boolean c5, Object o5, Object oe) {
+    return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, c5, o5, oe);
+  }
+
+  private static Object caseWhenVar(Object... objs) {
+    for (int i = 0; i < objs.length - 1; i += 2) {
+      if (BooleanUtils.toBoolean(objs[i])) {
+        return objs[i + 1];
       }
     }
-    return null;
+    // An odd argument count ends with the ELSE value; an even count has no ELSE, so the result is null.
+    return objs.length % 2 == 0 ? null : objs[objs.length - 1];
   }
 }
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
index 18cbca3f41..d77461e978 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.sql.parsers;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import java.io.StringReader;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -717,15 +718,16 @@ public class CalciteSqlParser {
         SqlNodeList thenOperands = caseSqlNode.getThenOperands();
         SqlNode elseOperand = caseSqlNode.getElseOperand();
         Expression caseFuncExpr = RequestUtils.getFunctionExpression("case");
-        for (SqlNode whenSqlNode : whenOperands.getList()) {
+        Preconditions.checkState(whenOperands.size() == thenOperands.size());
+        for (int i = 0; i < whenOperands.size(); i++) {
+          SqlNode whenSqlNode = whenOperands.get(i);
           Expression whenExpression = toExpression(whenSqlNode);
           if (isAggregateExpression(whenExpression)) {
             throw new SqlCompilationException(
                 "Aggregation functions inside WHEN Clause is not supported - " + whenSqlNode);
           }
           caseFuncExpr.getFunctionCall().addToOperands(whenExpression);
-        }
-        for (SqlNode thenSqlNode : thenOperands.getList()) {
+          SqlNode thenSqlNode = thenOperands.get(i);
           Expression thenExpression = toExpression(thenSqlNode);
           if (isAggregateExpression(thenExpression)) {
             throw new SqlCompilationException(
diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 2c1e9d26e3..d4afdc71b3 100644
--- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -93,11 +93,11 @@ public class CalciteSqlCompilerTest {
     Assert.assertEquals(greatThanFunc.getOperator(), FilterKind.GREATER_THAN.name());
     Assert.assertEquals(greatThanFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
     Assert.assertEquals(greatThanFunc.getOperands().get(1).getLiteral().getFieldValue(), 30L);
-    Function equalsFunc = caseFunc.getOperands().get(1).getFunctionCall();
+    Assert.assertEquals(caseFunc.getOperands().get(1).getLiteral().getFieldValue(), "The quantity is greater than 30");
+    Function equalsFunc = caseFunc.getOperands().get(2).getFunctionCall();
     Assert.assertEquals(equalsFunc.getOperator(), FilterKind.EQUALS.name());
     Assert.assertEquals(equalsFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
     Assert.assertEquals(equalsFunc.getOperands().get(1).getLiteral().getFieldValue(), 30L);
-    Assert.assertEquals(caseFunc.getOperands().get(2).getLiteral().getFieldValue(), "The quantity is greater than 30");
     Assert.assertEquals(caseFunc.getOperands().get(3).getLiteral().getFieldValue(), "The quantity is 30");
     Assert.assertEquals(caseFunc.getOperands().get(4).getLiteral().getFieldValue(), "The quantity is under 30");
 
@@ -124,16 +124,16 @@ public class CalciteSqlCompilerTest {
     Assert.assertEquals(greatThanFunc.getOperator(), FilterKind.GREATER_THAN.name());
     Assert.assertEquals(greatThanFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
     Assert.assertEquals(greatThanFunc.getOperands().get(1).getLiteral().getFieldValue(), 30L);
-    greatThanFunc = caseFunc.getOperands().get(1).getFunctionCall();
+    Assert.assertEquals(caseFunc.getOperands().get(1).getLiteral().getFieldValue(), 3L);
+    greatThanFunc = caseFunc.getOperands().get(2).getFunctionCall();
     Assert.assertEquals(greatThanFunc.getOperator(), FilterKind.GREATER_THAN.name());
     Assert.assertEquals(greatThanFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
     Assert.assertEquals(greatThanFunc.getOperands().get(1).getLiteral().getFieldValue(), 20L);
-    greatThanFunc = caseFunc.getOperands().get(2).getFunctionCall();
+    Assert.assertEquals(caseFunc.getOperands().get(3).getLiteral().getFieldValue(), 2L);
+    greatThanFunc = caseFunc.getOperands().get(4).getFunctionCall();
     Assert.assertEquals(greatThanFunc.getOperator(), FilterKind.GREATER_THAN.name());
     Assert.assertEquals(greatThanFunc.getOperands().get(0).getIdentifier().getName(), "Quantity");
     Assert.assertEquals(greatThanFunc.getOperands().get(1).getLiteral().getFieldValue(), 10L);
-    Assert.assertEquals(caseFunc.getOperands().get(3).getLiteral().getFieldValue(), 3L);
-    Assert.assertEquals(caseFunc.getOperands().get(4).getLiteral().getFieldValue(), 2L);
     Assert.assertEquals(caseFunc.getOperands().get(5).getLiteral().getFieldValue(), 1L);
     Assert.assertEquals(caseFunc.getOperands().get(6).getLiteral().getFieldValue(), 0L);
   }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
index ddb951eacf..15279ea86d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
@@ -84,17 +84,53 @@ public class CaseTransformFunction extends BaseTransformFunction {
     }
     int numWhenStatements = arguments.size() / 2;
     _whenStatements = new ArrayList<>(numWhenStatements);
+    _elseThenStatements = new ArrayList<>(numWhenStatements + 1);
+    constructStatementList(arguments);
+    _selections = new boolean[_elseThenStatements.size()];
+    Collections.reverse(_elseThenStatements);
+    Collections.reverse(_whenStatements);
+    _resultMetadata = calculateResultMetadata();
+  }
+
+  private void constructStatementList(List<TransformFunction> arguments) {
+    int numWhenStatements = arguments.size() / 2;
+    boolean allBooleanFirstHalf = true;
+    boolean notAllBooleanOddHalf = false;
+    for (int i = 0; i < numWhenStatements; i++) {
+      if (arguments.get(i).getResultMetadata().getDataType() != DataType.BOOLEAN) {
+        allBooleanFirstHalf = false;
+      }
+      if (arguments.get(i * 2).getResultMetadata().getDataType() != DataType.BOOLEAN) {
+        notAllBooleanOddHalf = true;
+      }
+    }
+    if (allBooleanFirstHalf && notAllBooleanOddHalf) {
+      constructStatementListLegacy(arguments);
+    } else {
+      constructStatementListCalcite(arguments);
+    }
+  }
+
+  private void constructStatementListCalcite(List<TransformFunction> arguments) {
+    int numWhenStatements = arguments.size() / 2;
+    // alternating WHEN and THEN clause, last one ELSE
+    for (int i = 0; i < numWhenStatements; i++) {
+      _whenStatements.add(arguments.get(i * 2));
+      _elseThenStatements.add(arguments.get(i * 2 + 1));
+    }
+    _elseThenStatements.add(arguments.get(arguments.size() - 1));
+  }
+
+  // TODO: Legacy format, this is here for backward compatibility support, remove after release 0.12
+  private void constructStatementListLegacy(List<TransformFunction> arguments) {
+    int numWhenStatements = arguments.size() / 2;
+    // first half WHEN, second half THEN, last one ELSE
     for (int i = 0; i < numWhenStatements; i++) {
       _whenStatements.add(arguments.get(i));
     }
-    _elseThenStatements = new ArrayList<>(numWhenStatements + 1);
     for (int i = numWhenStatements; i < numWhenStatements * 2 + 1; i++) {
       _elseThenStatements.add(arguments.get(i));
     }
-    _selections = new boolean[_elseThenStatements.size()];
-    Collections.reverse(_elseThenStatements);
-    Collections.reverse(_whenStatements);
-    _resultMetadata = calculateResultMetadata();
   }
 
   private TransformResultMetadata calculateResultMetadata() {
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 9c220b4438..87b3c940f2 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -259,6 +259,9 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
         new Object[]{"SELECT a.col1 FROM a WHERE a.col1 IN ()", "Encountered \"\" at line"},
         // AT TIME ZONE should fail
         new Object[]{"SELECT a.col1 AT TIME ZONE 'PST' FROM a", "No match found for function signature AT_TIME_ZONE"},
+        // CASE WHEN whose branch result types cannot be consolidated should fail at compile time.
+        new Object[]{"SELECT SUM(CASE WHEN col3 > 10 THEN 1 WHEN col3 > 20 THEN 2 WHEN col3 > 30 THEN 3 "
+            + "WHEN col3 > 40 THEN 4 WHEN col3 > 50 THEN '5' ELSE 0 END) FROM a", "while converting CASE WHEN"},
     };
   }
 
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java
index f8f90a7f2c..575efcc903 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java
@@ -217,6 +217,12 @@ public class QueryTestSet {
         // COALESCE function
         new Object[]{"SELECT a.col1, COALESCE(b.col3, 0) FROM a LEFT JOIN b ON a.col1 = b.col2"},
         new Object[]{"SELECT a.col1, COALESCE(a.col3, 0) FROM a WHERE COALESCE(a.col2, 'bar') = 'bar'"},
+
+        // CASE WHEN function
+        new Object[]{"SELECT MAX(CAST((CASE WHEN col3 > 0 THEN 1 WHEN col3 > 10 then 2 ELSE 0 END) AS INTEGER)) "
+            + " FROM a"},
+        new Object[]{"SELECT col2, CASE WHEN SUM(col3) > 0 THEN 1 WHEN SUM(col3) > 10 then 2 WHEN SUM(col3) > 100 "
+            + " THEN 3 ELSE 0 END FROM a GROUP BY col2"}
     };
   }
 }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerExceptionTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerExceptionTest.java
index 5b5c084ce5..15c981e206 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerExceptionTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerExceptionTest.java
@@ -73,6 +73,9 @@ public class QueryRunnerExceptionTest extends QueryRunnerTestBase {
             "ArithmeticFunctions.least(double,double) with arguments"},
         // Function that tries to cast String to Number should throw runtime exception
         new Object[]{"SELECT a.col2, b.col1 FROM a JOIN b ON a.col1 = b.col3", "transform function: cast"},
+        // A standard SqlOpTable function whose arguments match no implemented signature throws a not-found exception
+        new Object[]{"SELECT CASE WHEN col3 > 10 THEN 1 WHEN col3 > 20 THEN 2 WHEN col3 > 30 THEN 3 "
+            + "WHEN col3 > 40 THEN 4 WHEN col3 > 50 THEN 5 WHEN col3 > 60 THEN '6' ELSE 0 END FROM a", "caseWhen"},
     };
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org