You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by xi...@apache.org on 2020/09/30 21:45:21 UTC

[incubator-pinot] 01/01: Adding array transform functions: array_average, array_max, array_min, array_sum

This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch support_array_transform_function_phase1
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git

commit c5e290be29335dd763ccee11f019546247e5d8f2
Author: Xiang Fu <fx...@gmail.com>
AuthorDate: Wed Sep 30 14:44:28 2020 -0700

    Adding array transform functions: array_average, array_max, array_min, array_sum
---
 .../common/function/TransformFunctionType.java     |   4 +
 .../function/ArrayAverageTransformFunction.java    | 127 +++++++++++++++
 .../function/ArrayMaxTransformFunction.java        | 179 +++++++++++++++++++++
 .../function/ArrayMinTransformFunction.java        | 178 ++++++++++++++++++++
 .../function/ArraySumTransformFunction.java        | 148 +++++++++++++++++
 .../function/TransformFunctionFactory.java         |   6 +
 .../ArrayAverageTransformFunctionTest.java         |  49 ++++++
 .../function/ArrayBaseTransformFunctionTest.java   |  97 +++++++++++
 .../function/ArrayLengthTransformFunctionTest.java |  38 ++---
 .../function/ArrayMaxTransformFunctionTest.java    |  49 ++++++
 .../function/ArrayMinTransformFunctionTest.java    |  49 ++++++
 .../function/ArraySumTransformFunctionTest.java    |  57 +++++++
 12 files changed, 957 insertions(+), 24 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 1347f1d..0dea6d2 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -53,6 +53,10 @@ public enum TransformFunctionType {
   DATETIMECONVERT("dateTimeConvert"),
   DATETRUNC("dateTrunc"),
   ARRAYLENGTH("arrayLength"),
+  ARRAY_AVERAGE("array_average"),
+  ARRAY_MIN("array_min"),
+  ARRAY_MAX("array_max"),
+  ARRAY_SUM("array_sum"),
   VALUEIN("valueIn"),
   MAPVALUE("mapValue"),
   INIDSET("inIdSet"),
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayAverageTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayAverageTransformFunction.java
new file mode 100644
index 0000000..e0610ed
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayAverageTransformFunction.java
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.core.common.DataSource;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+
+
+/**
+ * The ArrayAverageTransformFunction class implements array_average function for multi-valued columns.
+ *
+ * <p>For each document the result is the arithmetic mean of the argument's values, returned as a DOUBLE. Values are
+ * accumulated in a {@code double}, so LONG magnitudes beyond 2^53 may lose precision, and an empty array yields
+ * {@code NaN} (0.0 / 0).
+ *
+ * Sample queries:
+ * SELECT COUNT(*) FROM table WHERE array_average(mvColumn) > 2
+ * SELECT COUNT(*) FROM table GROUP BY array_average(mvColumn)
+ * SELECT SUM(array_average(mvColumn)) FROM table
+ */
+public class ArrayAverageTransformFunction extends BaseTransformFunction {
+  public static final String FUNCTION_NAME = "array_average";
+
+  // Output buffer, allocated on first use and reused for every subsequent block.
+  private double[] _results;
+  private TransformFunction _argument;
+
+  @Override
+  public String getName() {
+    return FUNCTION_NAME;
+  }
+
+  /**
+   * Accepts exactly one argument, which must be a non-literal, multi-valued, numeric expression.
+   *
+   * @throws IllegalArgumentException if the argument list does not satisfy these constraints
+   */
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
+    if (arguments.size() != 1) {
+      throw new IllegalArgumentException("Exactly 1 argument is required for ARRAY_AVERAGE transform function");
+    }
+    TransformFunction argument = arguments.get(0);
+    // The literal check short-circuits, so a literal's metadata is never consulted.
+    boolean isLiteral = argument instanceof LiteralTransformFunction;
+    if (isLiteral || argument.getResultMetadata().isSingleValue()) {
+      throw new IllegalArgumentException(
+          "The argument of ARRAY_AVERAGE transform function must be a multi-valued column or a transform function");
+    }
+    if (!argument.getResultMetadata().getDataType().isNumeric()) {
+      throw new IllegalArgumentException("The argument of ARRAY_AVERAGE transform function must be numeric");
+    }
+    _argument = argument;
+  }
+
+  @Override
+  public TransformResultMetadata getResultMetadata() {
+    return DOUBLE_SV_NO_DICTIONARY_METADATA;
+  }
+
+  @Override
+  public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
+    if (_results == null) {
+      _results = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int numDocs = projectionBlock.getNumDocs();
+    // Each branch fetches the argument in its native width and averages every document's array.
+    switch (_argument.getResultMetadata().getDataType()) {
+      case INT: {
+        int[][] rows = _argument.transformToIntValuesMV(projectionBlock);
+        for (int docId = 0; docId < numDocs; docId++) {
+          _results[docId] = average(rows[docId]);
+        }
+        break;
+      }
+      case LONG: {
+        long[][] rows = _argument.transformToLongValuesMV(projectionBlock);
+        for (int docId = 0; docId < numDocs; docId++) {
+          _results[docId] = average(rows[docId]);
+        }
+        break;
+      }
+      case FLOAT: {
+        float[][] rows = _argument.transformToFloatValuesMV(projectionBlock);
+        for (int docId = 0; docId < numDocs; docId++) {
+          _results[docId] = average(rows[docId]);
+        }
+        break;
+      }
+      case DOUBLE: {
+        double[][] rows = _argument.transformToDoubleValuesMV(projectionBlock);
+        for (int docId = 0; docId < numDocs; docId++) {
+          _results[docId] = average(rows[docId]);
+        }
+        break;
+      }
+      default:
+        throw new IllegalStateException();
+    }
+    return _results;
+  }
+
+  /** Mean of {@code values} computed in double precision; {@code NaN} for an empty array. */
+  private static double average(int[] values) {
+    double sum = 0;
+    for (int value : values) {
+      sum += value;
+    }
+    return sum / values.length;
+  }
+
+  /** Mean of {@code values} computed in double precision; {@code NaN} for an empty array. */
+  private static double average(long[] values) {
+    double sum = 0;
+    for (long value : values) {
+      sum += value;
+    }
+    return sum / values.length;
+  }
+
+  /** Mean of {@code values} computed in double precision; {@code NaN} for an empty array. */
+  private static double average(float[] values) {
+    double sum = 0;
+    for (float value : values) {
+      sum += value;
+    }
+    return sum / values.length;
+  }
+
+  /** Mean of {@code values} computed in double precision; {@code NaN} for an empty array. */
+  private static double average(double[] values) {
+    double sum = 0;
+    for (double value : values) {
+      sum += value;
+    }
+    return sum / values.length;
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayMaxTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayMaxTransformFunction.java
new file mode 100644
index 0000000..f877244
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayMaxTransformFunction.java
@@ -0,0 +1,202 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.core.common.DataSource;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.spi.data.FieldSpec;
+
+
+/**
+ * The ArrayMaxTransformFunction class implements array_max function for multi-valued columns.
+ *
+ * <p>The result has the same data type as the argument (INT, LONG, FLOAT, DOUBLE or STRING). For an empty array the
+ * result is {@link Integer#MIN_VALUE}, {@link Long#MIN_VALUE}, {@link Float#NEGATIVE_INFINITY},
+ * {@link Double#NEGATIVE_INFINITY}, or {@code null} for STRING. Each transformTo*ValuesSV override computes the max
+ * natively only when the argument has the matching type, and otherwise delegates to the base-class implementation.
+ *
+ * Sample queries:
+ * SELECT COUNT(*) FROM table WHERE array_max(mvColumn) > 2
+ * SELECT COUNT(*) FROM table GROUP BY array_max(mvColumn)
+ * SELECT SUM(array_max(mvColumn)) FROM table
+ */
+public class ArrayMaxTransformFunction extends BaseTransformFunction {
+  public static final String FUNCTION_NAME = "array_max";
+
+  // Per-type output buffers, allocated on first use and reused across blocks.
+  private int[] _intValuesSV;
+  private long[] _longValuesSV;
+  private float[] _floatValuesSV;
+  private double[] _doubleValuesSV;
+  private String[] _stringValuesSV;
+  private TransformFunction _argument;
+  private TransformResultMetadata _resultMetadata;
+
+  @Override
+  public String getName() {
+    return FUNCTION_NAME;
+  }
+
+  /**
+   * Accepts exactly one non-literal, multi-valued argument of type INT, LONG, FLOAT, DOUBLE or STRING.
+   *
+   * @throws IllegalArgumentException if the argument list does not satisfy these constraints
+   */
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
+    // Check that there is only 1 argument
+    if (arguments.size() != 1) {
+      throw new IllegalArgumentException("Exactly 1 argument is required for ARRAY_MAX transform function");
+    }
+
+    // Check that the argument is a multi-valued column or transform function
+    TransformFunction firstArgument = arguments.get(0);
+    if (firstArgument instanceof LiteralTransformFunction || firstArgument.getResultMetadata().isSingleValue()) {
+      throw new IllegalArgumentException(
+          "The argument of ARRAY_MAX transform function must be a multi-valued column or a transform function");
+    }
+    // Reject types that none of the transformTo*ValuesSV overrides below implement natively
+    FieldSpec.DataType dataType = firstArgument.getResultMetadata().getDataType();
+    switch (dataType) {
+      case INT:
+      case LONG:
+      case FLOAT:
+      case DOUBLE:
+      case STRING:
+        break;
+      default:
+        throw new IllegalArgumentException(
+            "Unsupported data type for ARRAY_MAX transform function: " + dataType);
+    }
+    _resultMetadata = new TransformResultMetadata(dataType, true, false);
+    _argument = firstArgument;
+  }
+
+  @Override
+  public TransformResultMetadata getResultMetadata() {
+    return _resultMetadata;
+  }
+
+  @Override
+  public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.INT) {
+      return super.transformToIntValuesSV(projectionBlock);
+    }
+    if (_intValuesSV == null) {
+      _intValuesSV = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    int[][] intValuesMV = _argument.transformToIntValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      int maxRes = Integer.MIN_VALUE;
+      for (int j = 0; j < intValuesMV[i].length; j++) {
+        maxRes = Math.max(maxRes, intValuesMV[i][j]);
+      }
+      _intValuesSV[i] = maxRes;
+    }
+    return _intValuesSV;
+  }
+
+  @Override
+  public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.LONG) {
+      return super.transformToLongValuesSV(projectionBlock);
+    }
+    if (_longValuesSV == null) {
+      _longValuesSV = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    long[][] longValuesMV = _argument.transformToLongValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      long maxRes = Long.MIN_VALUE;
+      for (int j = 0; j < longValuesMV[i].length; j++) {
+        maxRes = Math.max(maxRes, longValuesMV[i][j]);
+      }
+      _longValuesSV[i] = maxRes;
+    }
+    return _longValuesSV;
+  }
+
+  @Override
+  public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.FLOAT) {
+      return super.transformToFloatValuesSV(projectionBlock);
+    }
+    if (_floatValuesSV == null) {
+      _floatValuesSV = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    float[][] floatValuesMV = _argument.transformToFloatValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      float maxRes = Float.NEGATIVE_INFINITY;
+      for (int j = 0; j < floatValuesMV[i].length; j++) {
+        maxRes = Math.max(maxRes, floatValuesMV[i][j]);
+      }
+      _floatValuesSV[i] = maxRes;
+    }
+    return _floatValuesSV;
+  }
+
+  @Override
+  public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.DOUBLE) {
+      return super.transformToDoubleValuesSV(projectionBlock);
+    }
+    if (_doubleValuesSV == null) {
+      _doubleValuesSV = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    double[][] doubleValuesMV = _argument.transformToDoubleValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      double maxRes = Double.NEGATIVE_INFINITY;
+      for (int j = 0; j < doubleValuesMV[i].length; j++) {
+        maxRes = Math.max(maxRes, doubleValuesMV[i][j]);
+      }
+      _doubleValuesSV[i] = maxRes;
+    }
+    return _doubleValuesSV;
+  }
+
+  @Override
+  public String[] transformToStringValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.STRING) {
+      return super.transformToStringValuesSV(projectionBlock);
+    }
+    if (_stringValuesSV == null) {
+      _stringValuesSV = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    String[][] stringValuesMV = _argument.transformToStringValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      String maxRes = null;
+      for (int j = 0; j < stringValuesMV[i].length; j++) {
+        if (StringUtils.compare(maxRes, stringValuesMV[i][j]) < 0) {
+          maxRes = stringValuesMV[i][j];
+        }
+      }
+      _stringValuesSV[i] = maxRes;
+    }
+    return _stringValuesSV;
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayMinTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayMinTransformFunction.java
new file mode 100644
index 0000000..86abf33
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayMinTransformFunction.java
@@ -0,0 +1,204 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.core.common.DataSource;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.spi.data.FieldSpec;
+
+
+/**
+ * The ArrayMinTransformFunction class implements array_min function for multi-valued columns.
+ *
+ * <p>The result has the same data type as the argument (INT, LONG, FLOAT, DOUBLE or STRING). For an empty array the
+ * result is {@link Integer#MAX_VALUE}, {@link Long#MAX_VALUE}, {@link Float#POSITIVE_INFINITY},
+ * {@link Double#POSITIVE_INFINITY}, or {@code null} for STRING. Each transformTo*ValuesSV override computes the min
+ * natively only when the argument has the matching type, and otherwise delegates to the base-class implementation.
+ *
+ * Sample queries:
+ * SELECT COUNT(*) FROM table WHERE array_min(mvColumn) > 2
+ * SELECT COUNT(*) FROM table GROUP BY array_min(mvColumn)
+ * SELECT SUM(array_min(mvColumn)) FROM table
+ */
+public class ArrayMinTransformFunction extends BaseTransformFunction {
+  public static final String FUNCTION_NAME = "array_min";
+
+  // Per-type output buffers, allocated on first use and reused across blocks.
+  private int[] _intValuesSV;
+  private long[] _longValuesSV;
+  private float[] _floatValuesSV;
+  private double[] _doubleValuesSV;
+  private String[] _stringValuesSV;
+  private TransformFunction _argument;
+  private TransformResultMetadata _resultMetadata;
+
+  @Override
+  public String getName() {
+    return FUNCTION_NAME;
+  }
+
+  /**
+   * Accepts exactly one non-literal, multi-valued argument of type INT, LONG, FLOAT, DOUBLE or STRING.
+   *
+   * @throws IllegalArgumentException if the argument list does not satisfy these constraints
+   */
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
+    // Check that there is only 1 argument
+    if (arguments.size() != 1) {
+      throw new IllegalArgumentException("Exactly 1 argument is required for ARRAY_MIN transform function");
+    }
+
+    // Check that the argument is a multi-valued column or transform function
+    TransformFunction firstArgument = arguments.get(0);
+    if (firstArgument instanceof LiteralTransformFunction || firstArgument.getResultMetadata().isSingleValue()) {
+      throw new IllegalArgumentException(
+          "The argument of ARRAY_MIN transform function must be a multi-valued column or a transform function");
+    }
+    // Reject types that none of the transformTo*ValuesSV overrides below implement natively
+    FieldSpec.DataType dataType = firstArgument.getResultMetadata().getDataType();
+    switch (dataType) {
+      case INT:
+      case LONG:
+      case FLOAT:
+      case DOUBLE:
+      case STRING:
+        break;
+      default:
+        throw new IllegalArgumentException(
+            "Unsupported data type for ARRAY_MIN transform function: " + dataType);
+    }
+    _resultMetadata = new TransformResultMetadata(dataType, true, false);
+    _argument = firstArgument;
+  }
+
+  @Override
+  public TransformResultMetadata getResultMetadata() {
+    return _resultMetadata;
+  }
+
+  @Override
+  public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.INT) {
+      return super.transformToIntValuesSV(projectionBlock);
+    }
+    if (_intValuesSV == null) {
+      _intValuesSV = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    int[][] intValuesMV = _argument.transformToIntValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      int minRes = Integer.MAX_VALUE;
+      for (int j = 0; j < intValuesMV[i].length; j++) {
+        minRes = Math.min(minRes, intValuesMV[i][j]);
+      }
+      _intValuesSV[i] = minRes;
+    }
+    return _intValuesSV;
+  }
+
+  @Override
+  public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.LONG) {
+      return super.transformToLongValuesSV(projectionBlock);
+    }
+    if (_longValuesSV == null) {
+      _longValuesSV = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    long[][] longValuesMV = _argument.transformToLongValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      long minRes = Long.MAX_VALUE;
+      for (int j = 0; j < longValuesMV[i].length; j++) {
+        minRes = Math.min(minRes, longValuesMV[i][j]);
+      }
+      _longValuesSV[i] = minRes;
+    }
+    return _longValuesSV;
+  }
+
+  @Override
+  public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.FLOAT) {
+      return super.transformToFloatValuesSV(projectionBlock);
+    }
+    if (_floatValuesSV == null) {
+      _floatValuesSV = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    float[][] floatValuesMV = _argument.transformToFloatValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      float minRes = Float.POSITIVE_INFINITY;
+      for (int j = 0; j < floatValuesMV[i].length; j++) {
+        minRes = Math.min(minRes, floatValuesMV[i][j]);
+      }
+      _floatValuesSV[i] = minRes;
+    }
+    return _floatValuesSV;
+  }
+
+  @Override
+  public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.DOUBLE) {
+      return super.transformToDoubleValuesSV(projectionBlock);
+    }
+    if (_doubleValuesSV == null) {
+      _doubleValuesSV = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    double[][] doubleValuesMV = _argument.transformToDoubleValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      double minRes = Double.POSITIVE_INFINITY;
+      for (int j = 0; j < doubleValuesMV[i].length; j++) {
+        minRes = Math.min(minRes, doubleValuesMV[i][j]);
+      }
+      _doubleValuesSV[i] = minRes;
+    }
+    return _doubleValuesSV;
+  }
+
+  @Override
+  public String[] transformToStringValuesSV(ProjectionBlock projectionBlock) {
+    if (_argument.getResultMetadata().getDataType() != FieldSpec.DataType.STRING) {
+      return super.transformToStringValuesSV(projectionBlock);
+    }
+    if (_stringValuesSV == null) {
+      _stringValuesSV = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    String[][] stringValuesMV = _argument.transformToStringValuesMV(projectionBlock);
+    for (int i = 0; i < length; i++) {
+      String minRes = null;
+      for (int j = 0; j < stringValuesMV[i].length; j++) {
+        // nullIsLess=false: the initial null must compare as greater than any value, otherwise minRes would never
+        // be replaced and the result would always be null
+        if (StringUtils.compare(minRes, stringValuesMV[i][j], false) > 0) {
+          minRes = stringValuesMV[i][j];
+        }
+      }
+      _stringValuesSV[i] = minRes;
+    }
+    return _stringValuesSV;
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArraySumTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArraySumTransformFunction.java
new file mode 100644
index 0000000..ae5f6e2
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArraySumTransformFunction.java
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.core.common.DataSource;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.core.util.ArrayCopyUtils;
+import org.apache.pinot.spi.data.FieldSpec;
+
+
+/**
+ * The ArraySumTransformFunction class implements array_sum function for multi-valued columns.
+ *
+ * <p>INT and LONG arguments are summed into a LONG result (no overflow detection); FLOAT and DOUBLE arguments are
+ * summed into a DOUBLE result. Requesting the other representation converts the natively computed sums. An empty
+ * array sums to 0.
+ *
+ * Sample queries:
+ * SELECT COUNT(*) FROM table WHERE array_sum(mvColumn) > 2
+ * SELECT COUNT(*) FROM table GROUP BY array_sum(mvColumn)
+ * SELECT SUM(array_sum(mvColumn)) FROM table
+ */
+public class ArraySumTransformFunction extends BaseTransformFunction {
+  public static final String FUNCTION_NAME = "array_sum";
+
+  // Output buffers, allocated on first use and reused across blocks.
+  private long[] _longResults;
+  private double[] _doubleResults;
+  private TransformFunction _argument;
+  private TransformResultMetadata _resultMetadata;
+
+  @Override
+  public String getName() {
+    return FUNCTION_NAME;
+  }
+
+  /**
+   * Accepts exactly one non-literal, multi-valued numeric argument and derives the result type from it.
+   *
+   * @throws IllegalArgumentException if the argument list does not satisfy these constraints
+   */
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
+    // Check that there is only 1 argument
+    if (arguments.size() != 1) {
+      throw new IllegalArgumentException("Exactly 1 argument is required for ARRAY_SUM transform function");
+    }
+
+    // Check that the argument is a multi-valued column or transform function
+    TransformFunction firstArgument = arguments.get(0);
+    if (firstArgument instanceof LiteralTransformFunction || firstArgument.getResultMetadata().isSingleValue()) {
+      throw new IllegalArgumentException(
+          "The argument of ARRAY_SUM transform function must be a multi-valued column or a transform function");
+    }
+    // Integral inputs are summed as LONG, floating-point inputs as DOUBLE
+    FieldSpec.DataType resultDataType;
+    switch (firstArgument.getResultMetadata().getDataType()) {
+      case INT:
+      case LONG:
+        resultDataType = FieldSpec.DataType.LONG;
+        break;
+      case FLOAT:
+      case DOUBLE:
+        resultDataType = FieldSpec.DataType.DOUBLE;
+        break;
+      default:
+        throw new IllegalArgumentException("The argument of ARRAY_SUM transform function must be numeric");
+    }
+    _resultMetadata = new TransformResultMetadata(resultDataType, true, false);
+    _argument = firstArgument;
+  }
+
+  @Override
+  public TransformResultMetadata getResultMetadata() {
+    return _resultMetadata;
+  }
+
+  @Override
+  public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
+    if (_longResults == null) {
+      _longResults = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    FieldSpec.DataType argumentType = _argument.getResultMetadata().getDataType();
+    switch (argumentType) {
+      case INT:
+      case LONG:
+        // INT arguments are also read through transformToLongValuesMV so one loop covers both integral types
+        long[][] longValuesMV = _argument.transformToLongValuesMV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          long sumRes = 0;
+          for (long value : longValuesMV[i]) {
+            sumRes += value;
+          }
+          _longResults[i] = sumRes;
+        }
+        break;
+      case FLOAT:
+      case DOUBLE:
+        // Compute the sums natively as double, then convert
+        double[] doubleValues = transformToDoubleValuesSV(projectionBlock);
+        ArrayCopyUtils.copy(doubleValues, _longResults, length);
+        break;
+      default:
+        throw new IllegalStateException("Unsupported argument data type for ARRAY_SUM: " + argumentType);
+    }
+    return _longResults;
+  }
+
+  @Override
+  public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
+    if (_doubleResults == null) {
+      _doubleResults = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    int length = projectionBlock.getNumDocs();
+    FieldSpec.DataType argumentType = _argument.getResultMetadata().getDataType();
+    switch (argumentType) {
+      case INT:
+      case LONG:
+        // Compute the sums natively as long, then convert
+        long[] longValues = transformToLongValuesSV(projectionBlock);
+        ArrayCopyUtils.copy(longValues, _doubleResults, length);
+        break;
+      case FLOAT:
+      case DOUBLE:
+        double[][] doubleValuesMV = _argument.transformToDoubleValuesMV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          double sumRes = 0;
+          for (double value : doubleValuesMV[i]) {
+            sumRes += value;
+          }
+          _doubleResults[i] = sumRes;
+        }
+        break;
+      default:
+        throw new IllegalStateException("Unsupported argument data type for ARRAY_SUM: " + argumentType);
+    }
+    return _doubleResults;
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index e70e5ed..4917708 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -92,6 +92,12 @@ public class TransformFunctionFactory {
           put(TransformFunctionType.MAPVALUE.getName().toLowerCase(), MapValueTransformFunction.class);
           put(TransformFunctionType.INIDSET.getName().toLowerCase(), InIdSetTransformFunction.class);
 
+          // Array functions
+          put(TransformFunctionType.ARRAY_AVERAGE.getName().toLowerCase(), ArrayAverageTransformFunction.class);
+          put(TransformFunctionType.ARRAY_MAX.getName().toLowerCase(), ArrayMaxTransformFunction.class);
+          put(TransformFunctionType.ARRAY_MIN.getName().toLowerCase(), ArrayMinTransformFunction.class);
+          put(TransformFunctionType.ARRAY_SUM.getName().toLowerCase(), ArraySumTransformFunction.class);
+
           put(TransformFunctionType.GROOVY.getName().toLowerCase(), GroovyTransformFunction.class);
           put(TransformFunctionType.CASE.getName().toLowerCase(), CaseTransformFunction.class);
 
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayAverageTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayAverageTransformFunctionTest.java
new file mode 100644
index 0000000..0971adb
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayAverageTransformFunctionTest.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import org.apache.pinot.spi.data.FieldSpec;
+
+
+/**
+ * Expectations for {@code array_average}, checked by the shared {@link ArrayBaseTransformFunctionTest} driver.
+ */
+public class ArrayAverageTransformFunctionTest extends ArrayBaseTransformFunctionTest {
+
+  @Override
+  String getFunctionName() {
+    return ArrayAverageTransformFunction.FUNCTION_NAME;
+  }
+
+  /** Arithmetic mean of the row, accumulated in double precision; {@code NaN} for an empty row. */
+  @Override
+  Object getExpectResult(int[] intArray) {
+    double total = 0;
+    int index = 0;
+    while (index < intArray.length) {
+      total += intArray[index++];
+    }
+    return total / intArray.length;
+  }
+
+  @Override
+  Class getArrayFunctionClass() {
+    return ArrayAverageTransformFunction.class;
+  }
+
+  /** array_average always reports DOUBLE, whatever the input type. */
+  @Override
+  FieldSpec.DataType getResultDataType(FieldSpec.DataType inputDataType) {
+    return FieldSpec.DataType.DOUBLE;
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayBaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayBaseTransformFunctionTest.java
new file mode 100644
index 0000000..3f5f02c
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayBaseTransformFunctionTest.java
@@ -0,0 +1,146 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import org.apache.pinot.core.query.exception.BadQueryRequestException;
+import org.apache.pinot.core.query.request.context.ExpressionContext;
+import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+
+/**
+ * Shared tests for array transform functions that reduce each row of {@code INT_MV_COLUMN} to a single
+ * value. Subclasses describe the function under test (name, implementation class, expected per-row result
+ * and result data type); this class verifies factory wiring, result metadata, per-row results and the
+ * rejection of illegal arguments.
+ */
+public abstract class ArrayBaseTransformFunctionTest extends BaseTransformFunctionTest {
+
+  // Floating-point results are compared with a relative tolerance, since the function may accumulate in a
+  // different order (or precision) than the reference computation in getExpectResult().
+  private static final double DOUBLE_RELATIVE_TOLERANCE = 1e-10;
+  private static final double FLOAT_RELATIVE_TOLERANCE = 1e-5;
+
+  @Test
+  public void testArrayTransformFunction() {
+    ExpressionContext expression =
+        QueryContextConverterUtils.getExpression(String.format("%s(%s)", getFunctionName(), INT_MV_COLUMN));
+    TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
+    FieldSpec.DataType resultDataType = getResultDataType(FieldSpec.DataType.INT);
+    Assert.assertEquals(transformFunction.getClass(), getArrayFunctionClass());
+    Assert.assertEquals(transformFunction.getName(), getFunctionName());
+    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), resultDataType);
+    Assert.assertTrue(transformFunction.getResultMetadata().isSingleValue());
+    Assert.assertFalse(transformFunction.getResultMetadata().hasDictionary());
+
+    switch (resultDataType) {
+      case INT:
+        int[] intResults = transformFunction.transformToIntValuesSV(_projectionBlock);
+        for (int i = 0; i < NUM_ROWS; i++) {
+          Assert.assertEquals((long) intResults[i], expectedNumber(i).longValue());
+        }
+        break;
+      case LONG:
+        long[] longResults = transformFunction.transformToLongValuesSV(_projectionBlock);
+        for (int i = 0; i < NUM_ROWS; i++) {
+          Assert.assertEquals(longResults[i], expectedNumber(i).longValue());
+        }
+        break;
+      case FLOAT:
+        float[] floatResults = transformFunction.transformToFloatValuesSV(_projectionBlock);
+        for (int i = 0; i < NUM_ROWS; i++) {
+          assertClose(floatResults[i], expectedNumber(i).doubleValue(), FLOAT_RELATIVE_TOLERANCE);
+        }
+        break;
+      case DOUBLE:
+        double[] doubleResults = transformFunction.transformToDoubleValuesSV(_projectionBlock);
+        for (int i = 0; i < NUM_ROWS; i++) {
+          assertClose(doubleResults[i], expectedNumber(i).doubleValue(), DOUBLE_RELATIVE_TOLERANCE);
+        }
+        break;
+      case STRING:
+        String[] stringResults = transformFunction.transformToStringValuesSV(_projectionBlock);
+        for (int i = 0; i < NUM_ROWS; i++) {
+          Assert.assertEquals(stringResults[i], getExpectResult(_intMVValues[i]));
+        }
+        break;
+      default:
+        // Fail loudly instead of silently skipping the value checks for an unhandled result type.
+        Assert.fail("Unsupported result data type: " + resultDataType);
+    }
+  }
+
+  @Test(dataProvider = "testIllegalArguments", expectedExceptions = {BadQueryRequestException.class})
+  public void testIllegalArguments(String expressionStr) {
+    ExpressionContext expression = QueryContextConverterUtils.getExpression(expressionStr);
+    TransformFunctionFactory.get(expression, _dataSourceMap);
+  }
+
+  @DataProvider(name = "testIllegalArguments")
+  public Object[][] testIllegalArguments() {
+    String functionName = getFunctionName();
+    return new Object[][]{
+        // Extra argument after the multi-value column.
+        new Object[]{String.format("%s(%s,1)", functionName, INT_MV_COLUMN)},
+        // Literal instead of a column.
+        new Object[]{String.format("%s(2)", functionName)},
+        // Single-value column where a multi-value column is required.
+        new Object[]{String.format("%s(%s)", functionName, LONG_SV_COLUMN)}
+    };
+  }
+
+  /**
+   * Returns the expected result for the given row as a {@link Number}, so numeric results can be compared
+   * by value regardless of which boxed type {@link #getExpectResult(int[])} returns.
+   */
+  private Number expectedNumber(int row) {
+    return (Number) getExpectResult(_intMVValues[row]);
+  }
+
+  /**
+   * Asserts that {@code actual} equals {@code expected} within {@code relativeTolerance} scaled by the
+   * magnitude of {@code expected} (never less than {@code relativeTolerance} itself). A NaN expectation
+   * requires a NaN result.
+   */
+  private static void assertClose(double actual, double expected, double relativeTolerance) {
+    if (Double.isNaN(expected)) {
+      Assert.assertTrue(Double.isNaN(actual), "Expected NaN but got " + actual);
+      return;
+    }
+    Assert.assertEquals(actual, expected, Math.max(1.0, Math.abs(expected)) * relativeTolerance);
+  }
+
+  /** Returns the function name as it appears in query expressions, e.g. {@code name(column)}. */
+  abstract String getFunctionName();
+
+  /**
+   * Returns the value the function is expected to produce for one row of the INT multi-value column. Numeric
+   * result types must return a {@link Number}; STRING results are compared with {@code equals}.
+   */
+  abstract Object getExpectResult(int[] intArray);
+
+  /** Returns the class {@link TransformFunctionFactory} is expected to instantiate for the function name. */
+  abstract Class<? extends TransformFunction> getArrayFunctionClass();
+
+  /** Returns the result data type the function is expected to report for the given input element type. */
+  abstract FieldSpec.DataType getResultDataType(FieldSpec.DataType inputDataType);
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunctionTest.java
index 81bcdcf..049e573 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunctionTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLengthTransformFunctionTest.java
@@ -27,35 +27,30 @@ import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 
-public class ArrayLengthTransformFunctionTest extends BaseTransformFunctionTest {
+/**
+ * Tests {@link ArrayLengthTransformFunction} via the shared checks in {@link ArrayBaseTransformFunctionTest}.
+ */
+public class ArrayLengthTransformFunctionTest extends ArrayBaseTransformFunctionTest {
 
-  @Test
-  public void testLengthTransformFunction() {
-    ExpressionContext expression =
-        QueryContextConverterUtils.getExpression(String.format("arrayLength(%s)", INT_MV_COLUMN));
-    TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
-    Assert.assertTrue(transformFunction instanceof ArrayLengthTransformFunction);
-    Assert.assertEquals(transformFunction.getName(), ArrayLengthTransformFunction.FUNCTION_NAME);
-    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), FieldSpec.DataType.INT);
-    Assert.assertTrue(transformFunction.getResultMetadata().isSingleValue());
-    Assert.assertFalse(transformFunction.getResultMetadata().hasDictionary());
+  @Override
+  String getFunctionName() {
+    return ArrayLengthTransformFunction.FUNCTION_NAME;
+  }
 
-    int[] results = transformFunction.transformToIntValuesSV(_projectionBlock);
-    for (int i = 0; i < NUM_ROWS; i++) {
-      Assert.assertEquals(results[i], _intMVValues[i].length);
-    }
+  /** The expected result is the number of elements in the row's array. */
+  @Override
+  Object getExpectResult(int[] intArray) {
+    return intArray.length;
   }
 
-  @Test(dataProvider = "testIllegalArguments", expectedExceptions = {BadQueryRequestException.class})
-  public void testIllegalArguments(String expressionStr) {
-    ExpressionContext expression = QueryContextConverterUtils.getExpression(expressionStr);
-    TransformFunctionFactory.get(expression, _dataSourceMap);
+  @Override
+  Class<? extends TransformFunction> getArrayFunctionClass() {
+    return ArrayLengthTransformFunction.class;
   }
 
-  @DataProvider(name = "testIllegalArguments")
-  public Object[][] testIllegalArguments() {
-    return new Object[][]{new Object[]{String.format("arrayLength(%s,1)",
-        INT_MV_COLUMN)}, new Object[]{"arrayLength(2)"}, new Object[]{String.format("arrayLength(%s)",
-        LONG_SV_COLUMN)}};
+  /** Array length is always expected as INT, regardless of the input element type. */
+  @Override
+  FieldSpec.DataType getResultDataType(FieldSpec.DataType inputDataType) {
+    return FieldSpec.DataType.INT;
   }
 }
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayMaxTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayMaxTransformFunctionTest.java
new file mode 100644
index 0000000..1c98462
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayMaxTransformFunctionTest.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import org.apache.pinot.spi.data.FieldSpec;
+
+
+/**
+ * Tests {@link ArrayMaxTransformFunction} via the shared checks in {@link ArrayBaseTransformFunctionTest}.
+ */
+public class ArrayMaxTransformFunctionTest extends ArrayBaseTransformFunctionTest {
+
+  @Override
+  String getFunctionName() {
+    return ArrayMaxTransformFunction.FUNCTION_NAME;
+  }
+
+  /**
+   * Returns the largest element of the array; an empty array yields {@link Integer#MIN_VALUE}, the identity
+   * of the max reduction.
+   */
+  @Override
+  Object getExpectResult(int[] intArray) {
+    int maxValue = Integer.MIN_VALUE;
+    for (int value : intArray) {
+      maxValue = Math.max(maxValue, value);
+    }
+    return maxValue;
+  }
+
+  @Override
+  Class<? extends TransformFunction> getArrayFunctionClass() {
+    return ArrayMaxTransformFunction.class;
+  }
+
+  /** The maximum is expected to keep the input element type. */
+  @Override
+  FieldSpec.DataType getResultDataType(FieldSpec.DataType inputDataType) {
+    return inputDataType;
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayMinTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayMinTransformFunctionTest.java
new file mode 100644
index 0000000..48da84b
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayMinTransformFunctionTest.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import org.apache.pinot.spi.data.FieldSpec;
+
+
+/**
+ * Tests {@link ArrayMinTransformFunction} via the shared checks in {@link ArrayBaseTransformFunctionTest}.
+ */
+public class ArrayMinTransformFunctionTest extends ArrayBaseTransformFunctionTest {
+
+  @Override
+  String getFunctionName() {
+    return ArrayMinTransformFunction.FUNCTION_NAME;
+  }
+
+  /**
+   * Returns the smallest element of the array; an empty array yields {@link Integer#MAX_VALUE}, the identity
+   * of the min reduction.
+   */
+  @Override
+  Object getExpectResult(int[] intArray) {
+    int minValue = Integer.MAX_VALUE;
+    for (int value : intArray) {
+      minValue = Math.min(minValue, value);
+    }
+    return minValue;
+  }
+
+  @Override
+  Class<? extends TransformFunction> getArrayFunctionClass() {
+    return ArrayMinTransformFunction.class;
+  }
+
+  /** The minimum is expected to keep the input element type. */
+  @Override
+  FieldSpec.DataType getResultDataType(FieldSpec.DataType inputDataType) {
+    return inputDataType;
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArraySumTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArraySumTransformFunctionTest.java
new file mode 100644
index 0000000..15e455b
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArraySumTransformFunctionTest.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import org.apache.pinot.spi.data.FieldSpec;
+
+
+/**
+ * Tests {@link ArraySumTransformFunction} via the shared checks in {@link ArrayBaseTransformFunctionTest}.
+ */
+public class ArraySumTransformFunctionTest extends ArrayBaseTransformFunctionTest {
+
+  @Override
+  String getFunctionName() {
+    return ArraySumTransformFunction.FUNCTION_NAME;
+  }
+
+  /** Sums the elements in {@code long} arithmetic so the expected value does not overflow int. */
+  @Override
+  Object getExpectResult(int[] intArray) {
+    long sum = 0;
+    for (int value : intArray) {
+      sum += value;
+    }
+    return sum;
+  }
+
+  @Override
+  Class<? extends TransformFunction> getArrayFunctionClass() {
+    return ArraySumTransformFunction.class;
+  }
+
+  /**
+   * Integral inputs are expected to be summed as LONG and floating-point inputs as DOUBLE; any other input
+   * type is rejected with an {@link IllegalArgumentException}.
+   */
+  @Override
+  FieldSpec.DataType getResultDataType(FieldSpec.DataType inputDataType) {
+    switch (inputDataType) {
+      case INT:
+      case LONG:
+        return FieldSpec.DataType.LONG;
+      case FLOAT:
+      case DOUBLE:
+        return FieldSpec.DataType.DOUBLE;
+      default:
+        throw new IllegalArgumentException("Unsupported input data type - " + inputDataType);
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org