You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ja...@apache.org on 2021/10/25 07:54:17 UTC

[pinot] branch master updated: Parsing Support for FILTER Clauses (#7566)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new c3d57c1  Parsing Support for FILTER Clauses (#7566)
c3d57c1 is described below

commit c3d57c117647e380710ea1c50af0266ff99e91eb
Author: Atri Sharma <at...@gmail.com>
AuthorDate: Mon Oct 25 13:23:54 2021 +0530

    Parsing Support for FILTER Clauses (#7566)
    
    Introduce the parsing and context construction logic for FILTER clauses in aggregates.
---
 .../request/context/RequestContextUtils.java       | 93 ++++++++++++++++++++++
 .../core/query/request/context/QueryContext.java   | 57 +++++++++++--
 .../BrokerRequestToQueryContextConverterTest.java  | 46 +++++++----
 3 files changed, 174 insertions(+), 22 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
index 3d8f185..86c5c30 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java
@@ -260,6 +260,99 @@ public class RequestContextUtils {
   }
 
   /**
+   * Converts the given filter {@link ExpressionContext} into a {@link FilterContext}.
+   * <p>NOTE: Currently the query engine only accepts string literals as the right-hand side of the predicate, so we
+   *          always convert the right-hand side expressions into strings.
+   */
+  public static FilterContext getFilter(ExpressionContext filterExpression) {
+    FunctionContext filterFunction = filterExpression.getFunction();
+    FilterKind filterKind = FilterKind.valueOf(filterFunction.getFunctionName().toUpperCase());
+    List<ExpressionContext> operands = filterFunction.getArguments();
+    int numOperands = operands.size();
+    switch (filterKind) {
+      case AND:
+        List<FilterContext> children = new ArrayList<>(numOperands);
+        for (ExpressionContext operand : operands) {
+          children.add(getFilter(operand));
+        }
+        return new FilterContext(FilterContext.Type.AND, children, null);
+      case OR:
+        children = new ArrayList<>(numOperands);
+        for (ExpressionContext operand : operands) {
+          children.add(getFilter(operand));
+        }
+        return new FilterContext(FilterContext.Type.OR, children, null);
+      case EQUALS:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new EqPredicate(operands.get(0), getStringValue(operands.get(1))));
+      case NOT_EQUALS:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new NotEqPredicate(operands.get(0), getStringValue(operands.get(1))));
+      case IN:
+        List<String> values = new ArrayList<>(numOperands - 1);
+        for (int i = 1; i < numOperands; i++) {
+          values.add(getStringValue(operands.get(i)));
+        }
+        return new FilterContext(FilterContext.Type.PREDICATE, null, new InPredicate(operands.get(0), values));
+      case NOT_IN:
+        values = new ArrayList<>(numOperands - 1);
+        for (int i = 1; i < numOperands; i++) {
+          values.add(getStringValue(operands.get(i)));
+        }
+        return new FilterContext(FilterContext.Type.PREDICATE, null, new NotInPredicate(operands.get(0), values));
+      case GREATER_THAN:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new RangePredicate(operands.get(0), false, getStringValue(operands.get(1)), false,
+                RangePredicate.UNBOUNDED));
+      case GREATER_THAN_OR_EQUAL:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new RangePredicate(operands.get(0), true, getStringValue(operands.get(1)), false,
+                RangePredicate.UNBOUNDED));
+      case LESS_THAN:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new RangePredicate(operands.get(0), false, RangePredicate.UNBOUNDED, false,
+                getStringValue(operands.get(1))));
+      case LESS_THAN_OR_EQUAL:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new RangePredicate(operands.get(0), false, RangePredicate.UNBOUNDED, true,
+                getStringValue(operands.get(1))));
+      case BETWEEN:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new RangePredicate(operands.get(0), true, getStringValue(operands.get(1)), true,
+                getStringValue(operands.get(2))));
+      case RANGE:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new RangePredicate(operands.get(0), getStringValue(operands.get(1))));
+      case REGEXP_LIKE:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new RegexpLikePredicate(operands.get(0), getStringValue(operands.get(1))));
+      case LIKE:
+        return new FilterContext(FilterContext.Type.PREDICATE, null, new RegexpLikePredicate(operands.get(0),
+            LikeToRegexpLikePatternConverterUtils.processValue(getStringValue(operands.get(1)))));
+      case TEXT_MATCH:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new TextMatchPredicate(operands.get(0), getStringValue(operands.get(1))));
+      case JSON_MATCH:
+        return new FilterContext(FilterContext.Type.PREDICATE, null,
+            new JsonMatchPredicate(operands.get(0), getStringValue(operands.get(1))));
+      case IS_NULL:
+        return new FilterContext(FilterContext.Type.PREDICATE, null, new IsNullPredicate(operands.get(0)));
+      case IS_NOT_NULL:
+        return new FilterContext(FilterContext.Type.PREDICATE, null, new IsNotNullPredicate(operands.get(0)));
+      default:
+        throw new IllegalStateException();
+    }
+  }
+
+  /**
+   * Returns the literal value of the given expression, which is used as the right-hand side of a predicate.
+   *
+   * @throws BadQueryRequestException if the expression is not a literal
+   */
+  private static String getStringValue(ExpressionContext expressionContext) {
+    // Only literals are accepted as predicate values
+    if (expressionContext.getType() == ExpressionContext.Type.LITERAL) {
+      return expressionContext.getLiteral();
+    }
+    throw new BadQueryRequestException(
+        "Pinot does not support column or function on the right-hand side of the predicate");
+  }
+
+  /**
    * Converts the given {@link FilterQueryTree} into a {@link FilterContext}.
    */
   public static FilterContext getFilter(FilterQueryTree node) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
index a1f68fe..f857bcd 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.core.query.request.context;
 
+import com.google.common.base.Preconditions;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
@@ -26,11 +27,13 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FilterContext;
 import org.apache.pinot.common.request.context.FunctionContext;
 import org.apache.pinot.common.request.context.OrderByExpressionContext;
+import org.apache.pinot.common.request.context.RequestContextUtils;
 import org.apache.pinot.core.plan.maker.InstancePlanMakerImplV2;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
@@ -82,6 +85,9 @@ public class QueryContext {
 
   // Pre-calculate the aggregation functions and columns for the query so that it can be shared across all the segments
   private AggregationFunction[] _aggregationFunctions;
+  private List<Pair<AggregationFunction, FilterContext>> _filteredAggregationFunctions;
+  // TODO: Use Pair<FunctionContext, FilterContext> as key to support filtered aggregations in order-by and post
+  //       aggregation
   private Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
   private Set<String> _columns;
 
@@ -221,6 +227,14 @@ public class QueryContext {
   }
 
   /**
+   * Returns the filtered aggregation expressions for the query.
+   */
+  @Nullable
+  public List<Pair<AggregationFunction, FilterContext>> getFilteredAggregationFunctions() {
+    return _filteredAggregationFunctions;
+  }
+
+  /**
    * Returns a map from the AGGREGATION FunctionContext to the index of the corresponding AggregationFunction in the
    * aggregation functions array.
    */
@@ -408,19 +422,28 @@ public class QueryContext {
      */
     private void generateAggregationFunctions(QueryContext queryContext) {
       List<AggregationFunction> aggregationFunctions = new ArrayList<>();
+      List<Pair<AggregationFunction, FilterContext>> filteredAggregationFunctions = new ArrayList<>();
       Map<FunctionContext, Integer> aggregationFunctionIndexMap = new HashMap<>();
 
       // Add aggregation functions in the SELECT clause
       // NOTE: DO NOT deduplicate the aggregation functions in the SELECT clause because that involves protocol change.
       List<FunctionContext> aggregationsInSelect = new ArrayList<>();
+      List<Pair<FunctionContext, FilterContext>> filteredAggregations = new ArrayList<>();
       for (ExpressionContext selectExpression : queryContext._selectExpressions) {
-        getAggregations(selectExpression, aggregationsInSelect);
+        getAggregations(selectExpression, aggregationsInSelect, filteredAggregations);
       }
       for (FunctionContext function : aggregationsInSelect) {
         int functionIndex = aggregationFunctions.size();
-        aggregationFunctions.add(AggregationFunctionFactory.getAggregationFunction(function, queryContext));
+        AggregationFunction aggregationFunction =
+            AggregationFunctionFactory.getAggregationFunction(function, queryContext);
+        aggregationFunctions.add(aggregationFunction);
         aggregationFunctionIndexMap.put(function, functionIndex);
       }
+      for (Pair<FunctionContext, FilterContext> pair : filteredAggregations) {
+        AggregationFunction aggregationFunction =
+            aggregationFunctions.get(aggregationFunctionIndexMap.get(pair.getLeft()));
+        filteredAggregationFunctions.add(Pair.of(aggregationFunction, pair.getRight()));
+      }
 
       // Add aggregation functions in the HAVING clause but not in the SELECT clause
       if (queryContext._havingFilter != null) {
@@ -439,7 +462,7 @@ public class QueryContext {
       if (queryContext._orderByExpressions != null) {
         List<FunctionContext> aggregationsInOrderBy = new ArrayList<>();
         for (OrderByExpressionContext orderByExpression : queryContext._orderByExpressions) {
-          getAggregations(orderByExpression.getExpression(), aggregationsInOrderBy);
+          getAggregations(orderByExpression.getExpression(), aggregationsInOrderBy, null);
         }
         for (FunctionContext function : aggregationsInOrderBy) {
           if (!aggregationFunctionIndexMap.containsKey(function)) {
@@ -452,6 +475,7 @@ public class QueryContext {
 
       if (!aggregationFunctions.isEmpty()) {
         queryContext._aggregationFunctions = aggregationFunctions.toArray(new AggregationFunction[0]);
+        queryContext._filteredAggregationFunctions = filteredAggregationFunctions;
         queryContext._aggregationFunctionIndexMap = aggregationFunctionIndexMap;
       }
     }
@@ -459,7 +483,8 @@ public class QueryContext {
     /**
      * Helper method to extract AGGREGATION FunctionContexts from the given expression.
      */
-    private static void getAggregations(ExpressionContext expression, List<FunctionContext> aggregations) {
+    private static void getAggregations(ExpressionContext expression, List<FunctionContext> aggregations,
+        List<Pair<FunctionContext, FilterContext>> filteredAggregations) {
       FunctionContext function = expression.getFunction();
       if (function == null) {
         return;
@@ -468,9 +493,25 @@ public class QueryContext {
         // Aggregation
         aggregations.add(function);
       } else {
-        // Transform
-        for (ExpressionContext argument : function.getArguments()) {
-          getAggregations(argument, aggregations);
+        List<ExpressionContext> arguments = function.getArguments();
+        if (function.getFunctionName().equalsIgnoreCase("filter")) {
+          // Filtered aggregation
+          Preconditions.checkState(arguments.size() == 2, "FILTER must contain 2 arguments");
+          FunctionContext aggregation = arguments.get(0).getFunction();
+          Preconditions.checkState(aggregation != null && aggregation.getType() == FunctionContext.Type.AGGREGATION,
+              "First argument of FILTER must be an aggregation function");
+          ExpressionContext filterExpression = arguments.get(1);
+          Preconditions.checkState(filterExpression.getFunction() != null
+                  && filterExpression.getFunction().getType() == FunctionContext.Type.TRANSFORM,
+              "Second argument of FILTER must be a filter expression");
+          FilterContext filter = RequestContextUtils.getFilter(filterExpression);
+          aggregations.add(aggregation);
+          filteredAggregations.add(Pair.of(aggregation, filter));
+        } else {
+          // Transform
+          for (ExpressionContext argument : arguments) {
+            getAggregations(argument, aggregations, filteredAggregations);
+          }
         }
       }
     }
@@ -485,7 +526,7 @@ public class QueryContext {
           getAggregations(child, aggregations);
         }
       } else {
-        getAggregations(filter.getPredicate().getLhs(), aggregations);
+        getAggregations(filter.getPredicate().getLhs(), aggregations, null);
       }
     }
 
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
index f312104..1f690c1 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java
@@ -27,6 +27,7 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.request.BrokerRequest;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FilterContext;
@@ -37,6 +38,7 @@ import org.apache.pinot.common.request.context.predicate.Predicate;
 import org.apache.pinot.common.request.context.predicate.RangePredicate;
 import org.apache.pinot.common.request.context.predicate.TextMatchPredicate;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.CountAggregationFunction;
 import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.pql.parsers.Pql2Compiler;
 import org.testng.annotations.Test;
@@ -44,6 +46,7 @@ import org.testng.annotations.Test;
 import static org.testng.Assert.*;
 
 
+@SuppressWarnings("rawtypes")
 public class BrokerRequestToQueryContextConverterTest {
 
   private int getAliasCount(List<String> aliasList) {
@@ -149,8 +152,8 @@ public class BrokerRequestToQueryContextConverterTest {
         List<ExpressionContext> selectExpressions = queryContext.getSelectExpressions();
         assertEquals(selectExpressions.size(), 1);
         assertEquals(selectExpressions.get(0), ExpressionContext.forFunction(
-            new FunctionContext(FunctionContext.Type.AGGREGATION, "distinct", Arrays
-                .asList(ExpressionContext.forIdentifier("foo"), ExpressionContext.forIdentifier("bar"),
+            new FunctionContext(FunctionContext.Type.AGGREGATION, "distinct",
+                Arrays.asList(ExpressionContext.forIdentifier("foo"), ExpressionContext.forIdentifier("bar"),
                     ExpressionContext.forIdentifier("foobar")))));
         assertEquals(selectExpressions.get(0).toString(), "distinct(foo,bar,foobar)");
         assertEquals(getAliasCount(queryContext.getAliasList()), 0);
@@ -185,10 +188,11 @@ public class BrokerRequestToQueryContextConverterTest {
         List<ExpressionContext> selectExpressions = queryContext.getSelectExpressions();
         assertEquals(selectExpressions.size(), 2);
         assertEquals(selectExpressions.get(0), ExpressionContext.forFunction(
-            new FunctionContext(FunctionContext.Type.TRANSFORM, "add", Arrays
-                .asList(ExpressionContext.forIdentifier("foo"), ExpressionContext.forFunction(
-                    new FunctionContext(FunctionContext.Type.TRANSFORM, "add", Arrays
-                        .asList(ExpressionContext.forIdentifier("bar"), ExpressionContext.forLiteral("123"))))))));
+            new FunctionContext(FunctionContext.Type.TRANSFORM, "add",
+                Arrays.asList(ExpressionContext.forIdentifier("foo"), ExpressionContext.forFunction(
+                    new FunctionContext(FunctionContext.Type.TRANSFORM, "add",
+                        Arrays.asList(ExpressionContext.forIdentifier("bar"),
+                            ExpressionContext.forLiteral("123"))))))));
         assertEquals(selectExpressions.get(0).toString(), "add(foo,add(bar,'123'))");
         assertEquals(selectExpressions.get(1), ExpressionContext.forFunction(
             new FunctionContext(FunctionContext.Type.TRANSFORM, "sub",
@@ -230,8 +234,8 @@ public class BrokerRequestToQueryContextConverterTest {
         assertTrue(numSelectExpressions == 1 || numSelectExpressions == 3);
         ExpressionContext aggregationExpression = selectExpressions.get(numSelectExpressions - 1);
         assertEquals(aggregationExpression, ExpressionContext.forFunction(
-            new FunctionContext(FunctionContext.Type.AGGREGATION, "sum", Collections.singletonList(ExpressionContext
-                .forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "add",
+            new FunctionContext(FunctionContext.Type.AGGREGATION, "sum", Collections.singletonList(
+                ExpressionContext.forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "add",
                     Arrays.asList(ExpressionContext.forIdentifier("foo"), ExpressionContext.forIdentifier("bar"))))))));
         assertEquals(aggregationExpression.toString(), "sum(add(foo,bar))");
         if (numSelectExpressions == 3) {
@@ -257,8 +261,8 @@ public class BrokerRequestToQueryContextConverterTest {
         assertNotNull(orderByExpressions);
         assertEquals(orderByExpressions.size(), 2);
         assertEquals(orderByExpressions.get(0), new OrderByExpressionContext(ExpressionContext.forFunction(
-            new FunctionContext(FunctionContext.Type.AGGREGATION, "sum", Collections.singletonList(ExpressionContext
-                .forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "add",
+            new FunctionContext(FunctionContext.Type.AGGREGATION, "sum", Collections.singletonList(
+                ExpressionContext.forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "add",
                     Arrays.asList(ExpressionContext.forIdentifier("foo"), ExpressionContext.forIdentifier("bar"))))))),
             true));
         assertEquals(orderByExpressions.get(0).toString(), "sum(add(foo,bar)) ASC");
@@ -388,8 +392,8 @@ public class BrokerRequestToQueryContextConverterTest {
       assertNull(queryContext.getOrderByExpressions());
       FilterContext havingFilter = queryContext.getHavingFilter();
       assertNotNull(havingFilter);
-      assertEquals(havingFilter, new FilterContext(FilterContext.Type.PREDICATE, null, new InPredicate(ExpressionContext
-          .forFunction(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
+      assertEquals(havingFilter, new FilterContext(FilterContext.Type.PREDICATE, null, new InPredicate(
+          ExpressionContext.forFunction(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
               Collections.singletonList(ExpressionContext.forIdentifier("foo")))), Arrays.asList("5", "10", "15"))));
       assertEquals(havingFilter.toString(), "sum(foo) IN ('5','10','15')");
       assertEquals(queryContext.getLimit(), 10);
@@ -568,12 +572,26 @@ public class BrokerRequestToQueryContextConverterTest {
 
   private QueryContext[] getQueryContexts(String pqlQuery, String sqlQuery) {
     return new QueryContext[]{
-        QueryContextConverterUtils.getQueryContextFromPQL(pqlQuery),
-        QueryContextConverterUtils.getQueryContextFromSQL(sqlQuery)
+        QueryContextConverterUtils.getQueryContextFromPQL(pqlQuery), QueryContextConverterUtils.getQueryContextFromSQL(
+        sqlQuery)
     };
   }
 
   @Test
+  public void testFilteredAggregations() {
+    String query = "SELECT COUNT(*) FILTER(WHERE foo > 5), COUNT(*) FILTER(WHERE foo < 6) FROM testTable WHERE bar > 0";
+    QueryContext queryContext = QueryContextConverterUtils.getQueryContextFromSQL(query);
+    List<Pair<AggregationFunction, FilterContext>> filteredAggregationList =
+        queryContext.getFilteredAggregationFunctions();
+    assertNotNull(filteredAggregationList);
+    assertEquals(filteredAggregationList.size(), 2);
+    assertTrue(filteredAggregationList.get(0).getLeft() instanceof CountAggregationFunction);
+    assertEquals(filteredAggregationList.get(0).getRight().toString(), "foo > '5'");
+    assertTrue(filteredAggregationList.get(1).getLeft() instanceof CountAggregationFunction);
+    assertEquals(filteredAggregationList.get(1).getRight().toString(), "foo < '6'");
+  }
+
+  @Test
   public void testServerQueryBackwardCompatible() {
     // Backward compatible: Select query with LIMIT set only in Select
     // Presto may send a BrokerRequest with LIMIT only set in side brokerRequest.getSelections().getSize()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org