You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ja...@apache.org on 2020/08/03 22:24:29 UTC

[incubator-pinot] branch master updated: Add SegmentPartitionedDistinctCount aggregation function (#5786)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 657e245  Add SegmentPartitionedDistinctCount aggregation function (#5786)
657e245 is described below

commit 657e2452176c45edfe71a3c69d60ce3b7cec6982
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Mon Aug 3 15:24:16 2020 -0700

    Add SegmentPartitionedDistinctCount aggregation function (#5786)
    
    Add a new `SegmentPartitionedDistinctCountAggregationFunction` to calculate the number of distinct values when values are partitioned for each segment.
    This function calculates the exact number of distinct values (using raw value instead of hash code) within the segment, then simply sums up the results from different segments to get the final result.
---
 .../common/function/AggregationFunctionType.java   |   1 +
 .../query/DictionaryBasedAggregationOperator.java  |   3 +
 .../core/plan/maker/InstancePlanMakerImplV2.java   |   8 +-
 .../function/AggregationFunctionFactory.java       |   2 +
 .../function/AggregationFunctionUtils.java         |  14 +-
 .../function/AggregationFunctionVisitorBase.java   |   4 +-
 ...artitionedDistinctCountAggregationFunction.java | 425 +++++++++++++++++++++
 ...SegmentPartitionedDistinctCountQueriesTest.java | 253 ++++++++++++
 8 files changed, 699 insertions(+), 11 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
index cf48ea6..62704db 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
@@ -28,6 +28,7 @@ public enum AggregationFunctionType {
   MINMAXRANGE("minMaxRange"),
   DISTINCTCOUNT("distinctCount"),
   DISTINCTCOUNTBITMAP("distinctCountBitmap"),
+  SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount"),
   DISTINCTCOUNTHLL("distinctCountHLL"),
   DISTINCTCOUNTRAWHLL("distinctCountRawHLL"),
   FASTHLL("fastHLL"),
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
index 35abe86..7fa6798 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
@@ -114,6 +114,9 @@ public class DictionaryBasedAggregationOperator extends BaseOperator<Intermediat
           }
           aggregationResults.add(set);
           break;
+        case SEGMENTPARTITIONEDDISTINCTCOUNT:
+          aggregationResults.add((long) dictionarySize);
+          break;
         default:
           throw new IllegalStateException(
               "Dictionary based aggregation operator does not support function type: " + aggregationFunction.getType());
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
index d7f8bad..caae551 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
@@ -158,7 +158,7 @@ public class InstancePlanMakerImplV2 implements PlanMaker {
   /**
    * Returns {@code true} if the given aggregation-only without filter QueryContext can be solved with dictionary,
    * {@code false} otherwise.
-   * <p>Aggregations supported: MIN, MAX, MINMAXRANGE, DISTINCTCOUNT
+   * <p>Aggregations supported: MIN, MAX, MIN_MAX_RANGE, DISTINCT_COUNT, SEGMENT_PARTITIONED_DISTINCT_COUNT
    */
   @VisibleForTesting
   static boolean isFitForDictionaryBasedPlan(QueryContext queryContext, IndexSegment indexSegment) {
@@ -179,8 +179,10 @@ public class InstancePlanMakerImplV2 implements PlanMaker {
       if (dictionary == null) {
         return false;
       }
-      // NOTE: DISTINCTCOUNT does not require sorted dictionary
-      if (!dictionary.isSorted() && !functionName.equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())) {
+      // TODO: Remove this check because MutableDictionary maintains min/max value
+      // NOTE: DISTINCT_COUNT and SEGMENT_PARTITIONED_DISTINCT_COUNT do not require sorted dictionary
+      if (!dictionary.isSorted() && !functionName.equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())
+          && !functionName.equalsIgnoreCase(AggregationFunctionType.SEGMENTPARTITIONEDDISTINCTCOUNT.name())) {
         return false;
       }
     }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 3a1bb01..08eed30 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -123,6 +123,8 @@ public class AggregationFunctionFactory {
             return new DistinctCountAggregationFunction(firstArgument);
           case DISTINCTCOUNTBITMAP:
             return new DistinctCountBitmapAggregationFunction(firstArgument);
+          case SEGMENTPARTITIONEDDISTINCTCOUNT:
+            return new SegmentPartitionedDistinctCountAggregationFunction(firstArgument);
           case DISTINCTCOUNTHLL:
             return new DistinctCountHLLAggregationFunction(arguments);
           case DISTINCTCOUNTRAWHLL:
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 3822373..b5784fb 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -209,12 +209,12 @@ public class AggregationFunctionUtils {
   }
 
   public static boolean isFitForDictionaryBasedComputation(String functionName) {
-    if (functionName.equalsIgnoreCase(AggregationFunctionType.MIN.name()) ||  //
-        functionName.equalsIgnoreCase(AggregationFunctionType.MAX.name()) || //
-        functionName.equalsIgnoreCase(AggregationFunctionType.MINMAXRANGE.name()) || //
-        functionName.equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())) {
-      return true;
-    }
-    return false;
+    //@formatter:off
+    return functionName.equalsIgnoreCase(AggregationFunctionType.MIN.name())
+        || functionName.equalsIgnoreCase(AggregationFunctionType.MAX.name())
+        || functionName.equalsIgnoreCase(AggregationFunctionType.MINMAXRANGE.name())
+        || functionName.equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())
+        || functionName.equalsIgnoreCase(AggregationFunctionType.SEGMENTPARTITIONEDDISTINCTCOUNT.name());
+    //@formatter:on
   }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java
index 72c8d4c..2710ba7 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionVisitorBase.java
@@ -51,6 +51,9 @@ public class AggregationFunctionVisitorBase {
   public void visit(DistinctCountBitmapMVAggregationFunction function) {
   }
 
+  public void visit(SegmentPartitionedDistinctCountAggregationFunction function) {
+  }
+
   public void visit(DistinctCountHLLAggregationFunction function) {
   }
 
@@ -106,7 +109,6 @@ public class AggregationFunctionVisitorBase {
   }
 
   public void visit(StUnionAggregationFunction function) {
-
   }
 }
 
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
new file mode 100644
index 0000000..221969d
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SegmentPartitionedDistinctCountAggregationFunction.java
@@ -0,0 +1,425 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet;
+import it.unimi.dsi.fastutil.floats.FloatOpenHashSet;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
+import java.util.Collection;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.core.query.request.context.ExpressionContext;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.utils.ByteArray;
+import org.roaringbitmap.RoaringBitmap;
+
+
+/**
+ * The {@code SegmentPartitionedDistinctCountAggregationFunction} calculates the number of distinct values for a given
+ * single-value expression.
+ * <p>IMPORTANT: This function relies on the expression values being partitioned by segment, i.e. no value appears in
+ * more than one segment; a value shared across segments is counted once for each segment that contains it.
+ * <p>This function calculates the exact number of distinct values within the segment, then simply sums up the results
+ * from different segments to get the final result.
+ */
+public class SegmentPartitionedDistinctCountAggregationFunction extends BaseSingleInputAggregationFunction<Long, Long> {
+
+  public SegmentPartitionedDistinctCountAggregationFunction(ExpressionContext expression) {
+    super(expression);
+  }
+
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.SEGMENTPARTITIONEDDISTINCTCOUNT;
+  }
+
+  @Override
+  public void accept(AggregationFunctionVisitorBase visitor) {
+    visitor.visit(this);
+  }
+
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+  }
+
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
+
+    // For dictionary-encoded expression, store dictionary ids into a RoaringBitmap
+    if (blockValSet.getDictionary() != null) {
+      int[] dictIds = blockValSet.getDictionaryIdsSV();
+      RoaringBitmap bitmap = aggregationResultHolder.getResult();
+      if (bitmap == null) {
+        bitmap = new RoaringBitmap();
+        aggregationResultHolder.setValue(bitmap);
+      }
+      bitmap.addN(dictIds, 0, length);
+      return;
+    }
+
+    // For non-dictionary-encoded expression, store INT values into a RoaringBitmap, other types into an OpenHashSet
+    DataType valueType = blockValSet.getValueType();
+    switch (valueType) {
+      case INT:
+        int[] intValues = blockValSet.getIntValuesSV();
+        RoaringBitmap bitmap = aggregationResultHolder.getResult();
+        if (bitmap == null) {
+          bitmap = new RoaringBitmap();
+          aggregationResultHolder.setValue(bitmap);
+        }
+        bitmap.addN(intValues, 0, length);
+        break;
+      case LONG:
+        long[] longValues = blockValSet.getLongValuesSV();
+        LongOpenHashSet longSet = aggregationResultHolder.getResult();
+        if (longSet == null) {
+          longSet = new LongOpenHashSet();
+          aggregationResultHolder.setValue(longSet);
+        }
+        for (int i = 0; i < length; i++) {
+          longSet.add(longValues[i]);
+        }
+        break;
+      case FLOAT:
+        float[] floatValues = blockValSet.getFloatValuesSV();
+        FloatOpenHashSet floatSet = aggregationResultHolder.getResult();
+        if (floatSet == null) {
+          floatSet = new FloatOpenHashSet();
+          aggregationResultHolder.setValue(floatSet);
+        }
+        for (int i = 0; i < length; i++) {
+          floatSet.add(floatValues[i]);
+        }
+        break;
+      case DOUBLE:
+        double[] doubleValues = blockValSet.getDoubleValuesSV();
+        DoubleOpenHashSet doubleSet = aggregationResultHolder.getResult();
+        if (doubleSet == null) {
+          doubleSet = new DoubleOpenHashSet();
+          aggregationResultHolder.setValue(doubleSet);
+        }
+        for (int i = 0; i < length; i++) {
+          doubleSet.add(doubleValues[i]);
+        }
+        break;
+      case STRING:
+        String[] stringValues = blockValSet.getStringValuesSV();
+        ObjectOpenHashSet<String> stringSet = aggregationResultHolder.getResult();
+        if (stringSet == null) {
+          stringSet = new ObjectOpenHashSet<>();
+          aggregationResultHolder.setValue(stringSet);
+        }
+        //noinspection ManualArrayToCollectionCopy
+        for (int i = 0; i < length; i++) {
+          stringSet.add(stringValues[i]);
+        }
+        break;
+      case BYTES:
+        byte[][] bytesValues = blockValSet.getBytesValuesSV();
+        ObjectOpenHashSet<ByteArray> bytesSet = aggregationResultHolder.getResult();
+        if (bytesSet == null) {
+          bytesSet = new ObjectOpenHashSet<>();
+          aggregationResultHolder.setValue(bytesSet);
+        }
+        for (int i = 0; i < length; i++) {
+          bytesSet.add(new ByteArray(bytesValues[i]));
+        }
+        break;
+      default:
+        throw new IllegalStateException(
+            "Illegal data type for SEGMENT_PARTITIONED_DISTINCT_COUNT aggregation function: " + valueType);
+    }
+  }
+
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
+
+    // For dictionary-encoded expression, store dictionary ids into a RoaringBitmap
+    if (blockValSet.getDictionary() != null) {
+      int[] dictIds = blockValSet.getDictionaryIdsSV();
+      for (int i = 0; i < length; i++) {
+        setIntValueForGroup(groupByResultHolder, groupKeyArray[i], dictIds[i]);
+      }
+      return;
+    }
+
+    // For non-dictionary-encoded expression, store INT values into a RoaringBitmap, other types into an OpenHashSet
+    DataType valueType = blockValSet.getValueType();
+    switch (valueType) {
+      case INT:
+        int[] intValues = blockValSet.getIntValuesSV();
+        for (int i = 0; i < length; i++) {
+          setIntValueForGroup(groupByResultHolder, groupKeyArray[i], intValues[i]);
+        }
+        break;
+      case LONG:
+        long[] longValues = blockValSet.getLongValuesSV();
+        for (int i = 0; i < length; i++) {
+          setLongValueForGroup(groupByResultHolder, groupKeyArray[i], longValues[i]);
+        }
+        break;
+      case FLOAT:
+        float[] floatValues = blockValSet.getFloatValuesSV();
+        for (int i = 0; i < length; i++) {
+          setFloatValueForGroup(groupByResultHolder, groupKeyArray[i], floatValues[i]);
+        }
+        break;
+      case DOUBLE:
+        double[] doubleValues = blockValSet.getDoubleValuesSV();
+        for (int i = 0; i < length; i++) {
+          setDoubleValueForGroup(groupByResultHolder, groupKeyArray[i], doubleValues[i]);
+        }
+        break;
+      case STRING:
+        String[] stringValues = blockValSet.getStringValuesSV();
+        for (int i = 0; i < length; i++) {
+          setStringValueForGroup(groupByResultHolder, groupKeyArray[i], stringValues[i]);
+        }
+        break;
+      case BYTES:
+        byte[][] bytesValues = blockValSet.getBytesValuesSV();
+        for (int i = 0; i < length; i++) {
+          setBytesValueForGroup(groupByResultHolder, groupKeyArray[i], new ByteArray(bytesValues[i]));
+        }
+        break;
+      default:
+        throw new IllegalStateException(
+            "Illegal data type for SEGMENT_PARTITIONED_DISTINCT_COUNT aggregation function: " + valueType);
+    }
+  }
+
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
+
+    // For dictionary-encoded expression, store dictionary ids into a RoaringBitmap
+    if (blockValSet.getDictionary() != null) {
+      int[] dictIds = blockValSet.getDictionaryIdsSV();
+      for (int i = 0; i < length; i++) {
+        int dictId = dictIds[i];
+        for (int groupKey : groupKeysArray[i]) {
+          setIntValueForGroup(groupByResultHolder, groupKey, dictId);
+        }
+      }
+      return;
+    }
+
+    // For non-dictionary-encoded expression, store INT values into a RoaringBitmap, other types into an OpenHashSet
+    DataType valueType = blockValSet.getValueType();
+    switch (valueType) {
+      case INT:
+        int[] intValues = blockValSet.getIntValuesSV();
+        for (int i = 0; i < length; i++) {
+          int value = intValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setIntValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case LONG:
+        long[] longValues = blockValSet.getLongValuesSV();
+        for (int i = 0; i < length; i++) {
+          long value = longValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setLongValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case FLOAT:
+        float[] floatValues = blockValSet.getFloatValuesSV();
+        for (int i = 0; i < length; i++) {
+          float value = floatValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setFloatValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case DOUBLE:
+        double[] doubleValues = blockValSet.getDoubleValuesSV();
+        for (int i = 0; i < length; i++) {
+          double value = doubleValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setDoubleValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case STRING:
+        String[] stringValues = blockValSet.getStringValuesSV();
+        for (int i = 0; i < length; i++) {
+          String value = stringValues[i];
+          for (int groupKey : groupKeysArray[i]) {
+            setStringValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      case BYTES:
+        byte[][] bytesValues = blockValSet.getBytesValuesSV();
+        for (int i = 0; i < length; i++) {
+          ByteArray value = new ByteArray(bytesValues[i]);
+          for (int groupKey : groupKeysArray[i]) {
+            setBytesValueForGroup(groupByResultHolder, groupKey, value);
+          }
+        }
+        break;
+      default:
+        throw new IllegalStateException(
+            "Illegal data type for SEGMENT_PARTITIONED_DISTINCT_COUNT aggregation function: " + valueType);
+    }
+  }
+
+  @Override
+  public Long extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    return extractIntermediateResult(aggregationResultHolder.getResult());
+  }
+
+  @Override
+  public Long extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    return extractIntermediateResult(groupByResultHolder.getResult(groupKey));
+  }
+
+  @Override
+  public Long merge(Long intermediateResult1, Long intermediateResult2) {
+    return intermediateResult1 + intermediateResult2;
+  }
+
+  @Override
+  public boolean isIntermediateResultComparable() {
+    return true;
+  }
+
+  @Override
+  public ColumnDataType getIntermediateResultColumnType() {
+    return ColumnDataType.LONG;
+  }
+
+  @Override
+  public ColumnDataType getFinalResultColumnType() {
+    return ColumnDataType.LONG;
+  }
+
+  @Override
+  public Long extractFinalResult(Long intermediateResult) {
+    return intermediateResult;
+  }
+
+  /**
+   * Helper method to add an INT value into the RoaringBitmap for the given group key, creating it if absent.
+   */
+  private static void setIntValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, int value) {
+    RoaringBitmap bitmap = groupByResultHolder.getResult(groupKey);
+    if (bitmap == null) {
+      bitmap = new RoaringBitmap();
+      groupByResultHolder.setValueForKey(groupKey, bitmap);
+    }
+    bitmap.add(value);
+  }
+
+  /**
+   * Helper method to add a LONG value into the value set for the given group key, creating the set if absent.
+   */
+  private static void setLongValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, long value) {
+    LongOpenHashSet longSet = groupByResultHolder.getResult(groupKey);
+    if (longSet == null) {
+      longSet = new LongOpenHashSet();
+      groupByResultHolder.setValueForKey(groupKey, longSet);
+    }
+    longSet.add(value);
+  }
+
+  /**
+   * Helper method to add a FLOAT value into the value set for the given group key, creating the set if absent.
+   */
+  private static void setFloatValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, float value) {
+    FloatOpenHashSet floatSet = groupByResultHolder.getResult(groupKey);
+    if (floatSet == null) {
+      floatSet = new FloatOpenHashSet();
+      groupByResultHolder.setValueForKey(groupKey, floatSet);
+    }
+    floatSet.add(value);
+  }
+
+  /**
+   * Helper method to add a DOUBLE value into the value set for the given group key, creating the set if absent.
+   */
+  private static void setDoubleValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, double value) {
+    DoubleOpenHashSet doubleSet = groupByResultHolder.getResult(groupKey);
+    if (doubleSet == null) {
+      doubleSet = new DoubleOpenHashSet();
+      groupByResultHolder.setValueForKey(groupKey, doubleSet);
+    }
+    doubleSet.add(value);
+  }
+
+  /**
+   * Helper method to add a STRING value into the value set for the given group key, creating the set if absent.
+   */
+  private static void setStringValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, String value) {
+    ObjectOpenHashSet<String> stringSet = groupByResultHolder.getResult(groupKey);
+    if (stringSet == null) {
+      stringSet = new ObjectOpenHashSet<>();
+      groupByResultHolder.setValueForKey(groupKey, stringSet);
+    }
+    stringSet.add(value);
+  }
+
+  /**
+   * Helper method to add a BYTES value into the value set for the given group key, creating the set if absent.
+   */
+  private static void setBytesValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, ByteArray value) {
+    ObjectOpenHashSet<ByteArray> bytesSet = groupByResultHolder.getResult(groupKey);
+    if (bytesSet == null) {
+      bytesSet = new ObjectOpenHashSet<>();
+      groupByResultHolder.setValueForKey(groupKey, bytesSet);
+    }
+    bytesSet.add(value);
+  }
+
+  /**
+   * Returns the number of distinct values in the given segment-level result (bitmap or set), or 0 if it is null.
+   */
+  private static long extractIntermediateResult(@Nullable Object result) {
+    if (result == null) {
+      return 0L;
+    }
+    if (result instanceof RoaringBitmap) {
+      return ((RoaringBitmap) result).getLongCardinality();
+    }
+    assert result instanceof Collection;
+    return ((Collection<?>) result).size();
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/SegmentPartitionedDistinctCountQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/SegmentPartitionedDistinctCountQueriesTest.java
new file mode 100644
index 0000000..bef3e57
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/SegmentPartitionedDistinctCountQueriesTest.java
@@ -0,0 +1,253 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.common.response.broker.AggregationResult;
+import org.apache.pinot.common.response.broker.BrokerResponseNative;
+import org.apache.pinot.common.response.broker.GroupByResult;
+import org.apache.pinot.common.segment.ReadMode;
+import org.apache.pinot.common.utils.HashUtil;
+import org.apache.pinot.common.utils.StringUtil;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.data.readers.GenericRowRecordReader;
+import org.apache.pinot.core.indexsegment.IndexSegment;
+import org.apache.pinot.core.indexsegment.generator.SegmentGeneratorConfig;
+import org.apache.pinot.core.indexsegment.immutable.ImmutableSegment;
+import org.apache.pinot.core.indexsegment.immutable.ImmutableSegmentLoader;
+import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationGroupByOperator;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import org.apache.pinot.core.operator.query.DictionaryBasedAggregationOperator;
+import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import org.apache.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.Assert;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Queries test for SEGMENT_PARTITIONED_DISTINCT_COUNT queries.
+ *
+ * <p>Each record stores the same random integer in every column (INT, LONG, FLOAT, DOUBLE, STRING and zero-padded
+ * BYTES), so every aggregation in a test query is expected to produce the same result.
+ */
+@SuppressWarnings("rawtypes")
+public class SegmentPartitionedDistinctCountQueriesTest extends BaseQueriesTest {
+  private static final File INDEX_DIR =
+      new File(FileUtils.getTempDirectory(), "SegmentPartitionedDistinctCountQueriesTest");
+  private static final String RAW_TABLE_NAME = "testTable";
+  private static final String SEGMENT_NAME = "testSegment";
+  private static final Random RANDOM = new Random();
+
+  private static final int NUM_RECORDS = 2000;
+  private static final int MAX_VALUE = 1000;
+  // Number of SEGMENTPARTITIONEDDISTINCTCOUNT expressions in each test query (one per column)
+  private static final int NUM_AGGREGATIONS = 6;
+
+  private static final String INT_COLUMN = "intColumn";
+  private static final String LONG_COLUMN = "longColumn";
+  private static final String FLOAT_COLUMN = "floatColumn";
+  private static final String DOUBLE_COLUMN = "doubleColumn";
+  private static final String STRING_COLUMN = "stringColumn";
+  private static final String BYTES_COLUMN = "bytesColumn";
+  private static final Schema SCHEMA = new Schema.SchemaBuilder().addSingleValueDimension(INT_COLUMN, DataType.INT)
+      .addSingleValueDimension(LONG_COLUMN, DataType.LONG).addSingleValueDimension(FLOAT_COLUMN, DataType.FLOAT)
+      .addSingleValueDimension(DOUBLE_COLUMN, DataType.DOUBLE).addSingleValueDimension(STRING_COLUMN, DataType.STRING)
+      .addSingleValueDimension(BYTES_COLUMN, DataType.BYTES).build();
+  private static final TableConfig TABLE_CONFIG =
+      new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+
+  // Distinct raw values written into the segment
+  private Set<Integer> _values;
+  // Expected number of distinct values within a single segment
+  private long _expectedResult;
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  @Override
+  protected String getFilter() {
+    // NOTE: Use a match all filter to switch between DictionaryBasedAggregationOperator and AggregationOperator
+    return " WHERE intColumn >= 0";
+  }
+
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+
+  /**
+   * Generates {@code NUM_RECORDS} random records, builds a single segment from them and loads it.
+   */
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteDirectory(INDEX_DIR);
+
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+    int hashMapCapacity = HashUtil.getHashMapCapacity(MAX_VALUE);
+    _values = new HashSet<>(hashMapCapacity);
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      int value = RANDOM.nextInt(MAX_VALUE);
+      // NOTE: Track the raw value (not a hash) because group keys are later compared against it
+      _values.add(value);
+      GenericRow record = new GenericRow();
+      record.putValue(INT_COLUMN, value);
+      record.putValue(LONG_COLUMN, (long) value);
+      record.putValue(FLOAT_COLUMN, (float) value);
+      record.putValue(DOUBLE_COLUMN, (double) value);
+      String stringValue = Integer.toString(value);
+      record.putValue(STRING_COLUMN, stringValue);
+      // NOTE: Create fixed-length bytes so that dictionary can be generated
+      byte[] bytesValue = StringUtil.encodeUtf8(StringUtils.leftPad(stringValue, 3, '0'));
+      record.putValue(BYTES_COLUMN, bytesValue);
+      records.add(record);
+    }
+    _expectedResult = _values.size();
+
+    SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
+    segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+    segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+    segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+    driver.build();
+
+    ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+    _indexSegment = immutableSegment;
+    _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+  }
+
+  /**
+   * Verifies aggregation-only queries both with the dictionary based operator (no filter) and the regular aggregation
+   * operator (match-all filter), at inner segment and inter segments level.
+   */
+  @Test
+  public void testAggregationOnly() {
+    String query =
+        "SELECT SEGMENTPARTITIONEDDISTINCTCOUNT(intColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(longColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(floatColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(doubleColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(stringColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(bytesColumn) FROM testTable";
+
+    // Inner segment
+    Operator operator = getOperatorForPqlQuery(query);
+    assertTrue(operator instanceof DictionaryBasedAggregationOperator);
+    IntermediateResultsBlock resultsBlock = ((DictionaryBasedAggregationOperator) operator).nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, 0, NUM_RECORDS);
+    List<Object> aggregationResult = resultsBlock.getAggregationResult();
+
+    operator = getOperatorForPqlQueryWithFilter(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0,
+            NUM_AGGREGATIONS * NUM_RECORDS, NUM_RECORDS);
+    List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult();
+
+    assertNotNull(aggregationResult);
+    assertNotNull(aggregationResultWithFilter);
+    assertEquals(aggregationResult, aggregationResultWithFilter);
+    for (int i = 0; i < NUM_AGGREGATIONS; i++) {
+      assertEquals((long) aggregationResult.get(i), _expectedResult);
+    }
+
+    // Inter segments (expect 4 * inner segment result)
+    String[] expectedResults = new String[NUM_AGGREGATIONS];
+    Arrays.fill(expectedResults, Long.toString(4 * _expectedResult));
+    BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+    QueriesTestUtils
+        .testInterSegmentAggregationResult(brokerResponse, 4 * NUM_RECORDS, 0, 0, 4 * NUM_RECORDS, expectedResults);
+    brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+    QueriesTestUtils
+        .testInterSegmentAggregationResult(brokerResponse, 4 * NUM_RECORDS, 0, 4 * NUM_AGGREGATIONS * NUM_RECORDS,
+            4 * NUM_RECORDS, expectedResults);
+  }
+
+  /**
+   * Verifies group-by queries on the INT column: every group holds a single distinct value per segment.
+   */
+  @Test
+  public void testAggregationGroupBy() {
+    String query =
+        "SELECT SEGMENTPARTITIONEDDISTINCTCOUNT(intColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(longColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(floatColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(doubleColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(stringColumn), SEGMENTPARTITIONEDDISTINCTCOUNT(bytesColumn) FROM testTable GROUP BY intColumn";
+
+    // Inner segment
+    Operator operator = getOperatorForPqlQuery(query);
+    assertTrue(operator instanceof AggregationGroupByOperator);
+    IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock();
+    QueriesTestUtils
+        .testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0,
+            NUM_AGGREGATIONS * NUM_RECORDS, NUM_RECORDS);
+    AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult();
+    assertNotNull(aggregationGroupByResult);
+    int numGroups = 0;
+    Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator();
+    while (groupKeyIterator.hasNext()) {
+      numGroups++;
+      GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next();
+      assertTrue(_values.contains(Integer.parseInt(groupKey._stringKey)));
+      for (int i = 0; i < NUM_AGGREGATIONS; i++) {
+        assertEquals((long) aggregationGroupByResult.getResultForKey(groupKey, i), 1);
+      }
+    }
+    assertEquals(numGroups, _values.size());
+
+    // Inter segments (expect 4 * inner segment result)
+    BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+    assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+    assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+    assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * NUM_AGGREGATIONS * NUM_RECORDS);
+    assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+    // size of this array will be equal to number of aggregation functions since
+    // we return each aggregation function separately
+    List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults();
+    int numAggregationColumns = aggregationResults.size();
+    assertEquals(numAggregationColumns, NUM_AGGREGATIONS);
+    for (AggregationResult aggregationResult : aggregationResults) {
+      Assert.assertNull(aggregationResult.getValue());
+      List<GroupByResult> groupByResults = aggregationResult.getGroupByResult();
+      // Guard against a vacuous pass when the broker returns no groups
+      assertTrue(!groupByResults.isEmpty(), "Expected at least one group in the broker response");
+      for (GroupByResult groupByResult : groupByResults) {
+        List<String> group = groupByResult.getGroup();
+        assertEquals(group.size(), 1);
+        assertTrue(_values.contains(Integer.parseInt(group.get(0))));
+        assertEquals(groupByResult.getValue(), Long.toString(4));
+      }
+    }
+  }
+
+  /**
+   * Destroys the loaded segment and removes the temporary index directory.
+   */
+  @AfterClass
+  public void tearDown()
+      throws IOException {
+    _indexSegment.destroy();
+    FileUtils.deleteDirectory(INDEX_DIR);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org