You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ne...@apache.org on 2020/07/25 02:38:49 UTC
[incubator-pinot] branch master updated: GROOVY transform function
UDF for queries (#5748)
This is an automated email from the ASF dual-hosted git repository.
nehapawar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 6911172 GROOVY transform function UDF for queries (#5748)
6911172 is described below
commit 69111727f6f011a7d864698cde292df0cd2e22d9
Author: Neha Pawar <ne...@gmail.com>
AuthorDate: Fri Jul 24 19:38:33 2020 -0700
GROOVY transform function UDF for queries (#5748)
---
.../common/function/TransformFunctionType.java | 2 +-
.../data/function/GroovyFunctionEvaluator.java | 19 +-
.../function/GroovyTransformFunction.java | 438 +++++++++++++++++++++
.../function/TransformFunctionFactory.java | 2 +
.../function/GroovyTransformFunctionTest.java | 292 ++++++++++++++
5 files changed, 749 insertions(+), 4 deletions(-)
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 583e678..65cf1df 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -55,7 +55,7 @@ public enum TransformFunctionType {
ARRAYLENGTH("arrayLength"),
VALUEIN("valueIn"),
MAPVALUE("mapValue"),
-
+ GROOVY("groovy"),
// Special type for annotation based scalar functions
SCALAR("scalar"),
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/function/GroovyFunctionEvaluator.java b/pinot-core/src/main/java/org/apache/pinot/core/data/function/GroovyFunctionEvaluator.java
index 64e7d9a..ee5422e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/function/GroovyFunctionEvaluator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/function/GroovyFunctionEvaluator.java
@@ -55,18 +55,20 @@ public class GroovyFunctionEvaluator implements FunctionEvaluator {
private static final String ARGUMENTS_SEPARATOR = ",";
private final List<String> _arguments;
+ private final int _numArguments;
private final Binding _binding;
private final Script _script;
- public GroovyFunctionEvaluator(String transformExpression) {
- Matcher matcher = GROOVY_FUNCTION_PATTERN.matcher(transformExpression);
- Preconditions.checkState(matcher.matches(), "Invalid transform expression: %s", transformExpression);
+ public GroovyFunctionEvaluator(String closure) {
+ Matcher matcher = GROOVY_FUNCTION_PATTERN.matcher(closure);
+ Preconditions.checkState(matcher.matches(), "Invalid transform expression: %s", closure);
String arguments = matcher.group(ARGUMENTS_GROUP_NAME);
if (arguments != null) {
_arguments = Splitter.on(ARGUMENTS_SEPARATOR).trimResults().splitToList(arguments);
} else {
_arguments = Collections.emptyList();
}
+ _numArguments = _arguments.size();
_binding = new Binding();
_script = new GroovyShell(_binding).parse(matcher.group(SCRIPT_GROUP_NAME));
}
@@ -92,4 +94,15 @@ public class GroovyFunctionEvaluator implements FunctionEvaluator {
}
return _script.run();
}
+
+ /**
+ * Evaluates the Groovy script after binding values[i] to the i-th declared argument name.
+ * The values array must hold at least _numArguments elements; only the first _numArguments entries are read.
+ */
+ public Object evaluate(Object[] values) {
+ for (int i = 0; i < _numArguments; i++) {
+ _binding.setVariable(_arguments.get(i), values[i]);
+ }
+ return _script.run();
+ }
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunction.java
new file mode 100644
index 0000000..504d226
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunction.java
@@ -0,0 +1,438 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.google.common.base.Preconditions;
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.floats.FloatArrayList;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.commons.lang3.EnumUtils;
+import org.apache.pinot.core.common.DataSource;
+import org.apache.pinot.core.data.function.GroovyFunctionEvaluator;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.utils.JsonUtils;
+
+
+/**
+ * The GroovyTransformFunction executes groovy expressions
+ * 1st argument - json string containing returnType and isSingleValue e.g. '{"returnType":"LONG", "isSingleValue":false}'
+ * 2nd argument - groovy script (string) using arg0, arg1, arg2... as arguments e.g. 'arg0 + " " + arg1', 'arg0 + arg1.toList().max() + arg2' etc
+ * rest of the arguments - identifiers/functions to the groovy script
+ *
+ * Sample queries:
+ * SELECT GROOVY('{"returnType":"LONG", "isSingleValue":false}', 'arg0.findIndexValues{it==1}', products) FROM myTable
+ * SELECT GROOVY('{"returnType":"INT", "isSingleValue":true}', 'arg0 * arg1 * 10', arraylength(units), columnB ) FROM bob
+ */
+public class GroovyTransformFunction extends BaseTransformFunction {
+ public static final String FUNCTION_NAME = "groovy";
+
+ private static final String RETURN_TYPE_KEY = "returnType";
+ private static final String IS_SINGLE_VALUE_KEY = "isSingleValue";
+ private static final String ARGUMENT_PREFIX = "arg";
+ private static final String GROOVY_TEMPLATE_WITH_ARGS = "Groovy({%s}, %s)";
+ private static final String GROOVY_TEMPLATE_WITHOUT_ARGS = "Groovy({%s})";
+ private static final String GROOVY_ARG_DELIMITER = ",";
+
+ private int[] _intResultSV;
+ private long[] _longResultSV;
+ private double[] _doubleResultSV;
+ private float[] _floatResultSV;
+ private String[] _stringResultSV;
+ private int[][] _intResultMV;
+ private long[][] _longResultMV;
+ private double[][] _doubleResultMV;
+ private float[][] _floatResultMV;
+ private String[][] _stringResultMV;
+ private TransformResultMetadata _resultMetadata;
+
+ private GroovyFunctionEvaluator _groovyFunctionEvaluator;
+ private int _numGroovyArgs;
+ private TransformFunction[] _groovyArguments;
+ private boolean[] _isSourceSingleValue;
+ private FieldSpec.DataType[] _sourceDataType;
+ private BiFunction<TransformFunction, ProjectionBlock, Object>[] _transformToValuesFunctions;
+ private BiFunction<Object, Integer, Object>[] _fetchElementFunctions;
+ private Object[] _sourceArrays;
+ private Object[] _bindingValues;
+
+ @Override
+ public String getName() {
+ return FUNCTION_NAME; // the constant "groovy" declared on this class
+ }
+
+  /**
+   * Validates the arguments of the GROOVY function and prepares everything needed to evaluate it.
+   * <ul>
+   *   <li>argument 0 - literal json string carrying non-null 'returnType' and 'isSingleValue'</li>
+   *   <li>argument 1 - literal string holding the groovy script</li>
+   *   <li>arguments 2..n - non-literal transform functions, exposed to the script as arg0, arg1, ...</li>
+   * </ul>
+   * The dataSourceMap parameter is not read by this method.
+   *
+   * @throws IllegalArgumentException if fewer than 2 arguments are supplied
+   * @throws IllegalStateException if any argument fails validation or the json string cannot be parsed
+   */
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
+    int argumentCount = arguments.size();
+    if (argumentCount < 2) {
+      throw new IllegalArgumentException("GROOVY transform function requires at least 2 arguments");
+    }
+
+    // Argument 0: json literal describing the result type
+    TransformFunction metadataArgument = arguments.get(0);
+    Preconditions.checkState(metadataArgument instanceof LiteralTransformFunction,
+        "First argument of GROOVY transform function must be a literal, representing a json string");
+    String metadataLiteral = ((LiteralTransformFunction) metadataArgument).getLiteral();
+    JsonNode metadataJson;
+    try {
+      metadataJson = JsonUtils.stringToJsonNode(metadataLiteral);
+    } catch (IOException e) {
+      throw new IllegalStateException(
+          "Caught exception when converting json string '" + metadataLiteral + "' to JsonNode", e);
+    }
+    Preconditions.checkState(metadataJson.hasNonNull(RETURN_TYPE_KEY),
+        "The json string in the first argument of GROOVY transform function must have non-null 'returnType'");
+    Preconditions.checkState(metadataJson.hasNonNull(IS_SINGLE_VALUE_KEY),
+        "The json string in the first argument of GROOVY transform function must have non-null 'isSingleValue'");
+    String returnTypeName = metadataJson.get(RETURN_TYPE_KEY).asText();
+    Preconditions.checkState(EnumUtils.isValidEnum(FieldSpec.DataType.class, returnTypeName),
+        "The 'returnType' in the json string which is the first argument of GROOVY transform function must be a valid FieldSpec.DataType enum value");
+    FieldSpec.DataType returnType = FieldSpec.DataType.valueOf(returnTypeName);
+    boolean resultIsSingleValue = metadataJson.get(IS_SINGLE_VALUE_KEY).asBoolean(true);
+    _resultMetadata = new TransformResultMetadata(returnType, resultIsSingleValue, false);
+
+    // Argument 1: literal holding the groovy script
+    TransformFunction scriptArgument = arguments.get(1);
+    Preconditions.checkState(scriptArgument instanceof LiteralTransformFunction,
+        "Second argument of GROOVY transform function must be a literal string, representing the groovy expression");
+    LiteralTransformFunction scriptLiteral = (LiteralTransformFunction) scriptArgument;
+
+    // Everything from argument 2 onwards feeds the script
+    _numGroovyArgs = argumentCount - 2;
+    if (_numGroovyArgs == 0) {
+      String closure = String.format(GROOVY_TEMPLATE_WITHOUT_ARGS, scriptLiteral.getLiteral());
+      _groovyFunctionEvaluator = new GroovyFunctionEvaluator(closure);
+    } else {
+      _groovyArguments = new TransformFunction[_numGroovyArgs];
+      _isSourceSingleValue = new boolean[_numGroovyArgs];
+      _sourceDataType = new FieldSpec.DataType[_numGroovyArgs];
+      for (int idx = 0; idx < _numGroovyArgs; idx++) {
+        TransformFunction argument = arguments.get(idx + 2);
+        Preconditions.checkState(!(argument instanceof LiteralTransformFunction),
+            "Third argument onwards, all arguments must be a column or other transform function");
+        _groovyArguments[idx] = argument;
+        TransformResultMetadata argumentMetadata = argument.getResultMetadata();
+        _isSourceSingleValue[idx] = argumentMetadata.isSingleValue();
+        _sourceDataType[idx] = argumentMetadata.getDataType();
+      }
+
+      // Build the "arg0,arg1,..." list handed to GroovyFunctionEvaluator
+      StringBuilder argumentNames = new StringBuilder();
+      for (int idx = 0; idx < _numGroovyArgs; idx++) {
+        if (idx > 0) {
+          argumentNames.append(GROOVY_ARG_DELIMITER);
+        }
+        argumentNames.append(ARGUMENT_PREFIX).append(idx);
+      }
+      String closure =
+          String.format(GROOVY_TEMPLATE_WITH_ARGS, scriptLiteral.getLiteral(), argumentNames.toString());
+      _groovyFunctionEvaluator = new GroovyFunctionEvaluator(closure);
+
+      _transformToValuesFunctions = new BiFunction[_numGroovyArgs];
+      _fetchElementFunctions = new BiFunction[_numGroovyArgs];
+      initFunctions();
+    }
+    _sourceArrays = new Object[_numGroovyArgs];
+    _bindingValues = new Object[_numGroovyArgs];
+  }
+
+ @Override
+ public TransformResultMetadata getResultMetadata() {
+ return _resultMetadata; // built in init() from 'returnType' and 'isSingleValue' of the 1st argument
+ }
+
+  /**
+   * Fills _transformToValuesFunctions and _fetchElementFunctions for every groovy argument.
+   * For each argument, the first function pulls the whole value array out of a projection block
+   * and the second one picks the entry for a single document position out of that array,
+   * both chosen from the argument's data type and single/multi value flag.
+   *
+   * @throws IllegalStateException if an argument's data type is not INT, LONG, FLOAT, DOUBLE or STRING
+   */
+  private void initFunctions() {
+    for (int argIdx = 0; argIdx < _numGroovyArgs; argIdx++) {
+      FieldSpec.DataType sourceType = _sourceDataType[argIdx];
+      if (!_isSourceSingleValue[argIdx]) {
+        // Multi-value source: each document position holds an array
+        switch (sourceType) {
+          case INT:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToIntValuesMV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((int[][]) values)[docIdx];
+            break;
+          case LONG:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToLongValuesMV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((long[][]) values)[docIdx];
+            break;
+          case FLOAT:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToFloatValuesMV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((float[][]) values)[docIdx];
+            break;
+          case DOUBLE:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToDoubleValuesMV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((double[][]) values)[docIdx];
+            break;
+          case STRING:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToStringValuesMV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((String[][]) values)[docIdx];
+            break;
+          default:
+            throw new IllegalStateException(
+                "Unsupported data type '" + sourceType + "' for GROOVY transform function");
+        }
+      } else {
+        // Single-value source: each document position holds one (boxed on fetch) value
+        switch (sourceType) {
+          case INT:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToIntValuesSV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((int[]) values)[docIdx];
+            break;
+          case LONG:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToLongValuesSV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((long[]) values)[docIdx];
+            break;
+          case FLOAT:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToFloatValuesSV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((float[]) values)[docIdx];
+            break;
+          case DOUBLE:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToDoubleValuesSV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((double[]) values)[docIdx];
+            break;
+          case STRING:
+            _transformToValuesFunctions[argIdx] = TransformFunction::transformToStringValuesSV;
+            _fetchElementFunctions[argIdx] = (values, docIdx) -> ((String[]) values)[docIdx];
+            break;
+          default:
+            throw new IllegalStateException(
+                "Unsupported data type '" + sourceType + "' for GROOVY transform function");
+        }
+      }
+    }
+  }
+
+  /**
+   * Runs the groovy script once per document of the block and returns the results as ints.
+   * Any numeric script result is accepted and converted with {@link Number#intValue()}, which may
+   * truncate. The returned array is the reusable _intResultSV buffer; this call writes only its
+   * first projectionBlock.getNumDocs() entries.
+   *
+   * @throws ClassCastException if the script returns a non-numeric value
+   * @throws NullPointerException if the script returns null
+   */
+  @Override
+  public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
+    if (_intResultSV == null) {
+      _intResultSV = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    // Materialize every argument's values for the whole block up front
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+      }
+      // A plain (int) cast only accepts a boxed Integer; scripts may yield Long, BigDecimal, etc.
+      _intResultSV[i] = ((Number) _groovyFunctionEvaluator.evaluate(_bindingValues)).intValue();
+    }
+    return _intResultSV;
+  }
+
+ @Override
+ public int[][] transformToIntValuesMV(ProjectionBlock projectionBlock) {
+ if (_intResultMV == null) {
+ _intResultMV = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+ }
+ for (int i = 0; i < _numGroovyArgs; i++) {
+ _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+ }
+ int length = projectionBlock.getNumDocs();
+ for (int i = 0; i < length; i++) {
+ for (int j = 0; j < _numGroovyArgs; j++) {
+ _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+ }
+ Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+ if (result instanceof List) {
+ _intResultMV[i] = new IntArrayList((List<Integer>) result).toIntArray();
+ } else if (result instanceof int[]) {
+ _intResultMV[i] = (int[]) result;
+ } else {
+ throw new IllegalStateException("Unexpected result type '" + result.getClass() + "' for GROOVY function");
+ }
+ }
+ return _intResultMV;
+ }
+
+  /**
+   * Runs the groovy script once per document of the block and returns the results as doubles.
+   * Any numeric script result is accepted and converted with {@link Number#doubleValue()}.
+   * The returned array is the reusable _doubleResultSV buffer; this call writes only its first
+   * projectionBlock.getNumDocs() entries.
+   *
+   * @throws ClassCastException if the script returns a non-numeric value
+   * @throws NullPointerException if the script returns null
+   */
+  @Override
+  public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
+    if (_doubleResultSV == null) {
+      _doubleResultSV = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    // Materialize every argument's values for the whole block up front
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+      }
+      // A plain (double) cast only accepts a boxed Double; scripts may yield Integer, BigDecimal, etc.
+      _doubleResultSV[i] = ((Number) _groovyFunctionEvaluator.evaluate(_bindingValues)).doubleValue();
+    }
+    return _doubleResultSV;
+  }
+
+ @Override
+ public double[][] transformToDoubleValuesMV(ProjectionBlock projectionBlock) {
+ if (_doubleResultMV == null) {
+ _doubleResultMV = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+ }
+ for (int i = 0; i < _numGroovyArgs; i++) {
+ _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+ }
+ int length = projectionBlock.getNumDocs();
+ for (int i = 0; i < length; i++) {
+ for (int j = 0; j < _numGroovyArgs; j++) {
+ _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+ }
+ Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+ if (result instanceof List) {
+ _doubleResultMV[i] = new DoubleArrayList((List<Double>) result).toDoubleArray();
+ } else if (result instanceof double[]) {
+ _doubleResultMV[i] = (double[]) result;
+ } else {
+ throw new IllegalStateException("Unexpected result type '" + result.getClass() + "' for GROOVY function");
+ }
+ }
+ return _doubleResultMV;
+ }
+
+  /**
+   * Runs the groovy script once per document of the block and returns the results as longs.
+   * Any numeric script result is accepted and converted with {@link Number#longValue()}, which may
+   * truncate. The returned array is the reusable _longResultSV buffer; this call writes only its
+   * first projectionBlock.getNumDocs() entries.
+   *
+   * @throws ClassCastException if the script returns a non-numeric value
+   * @throws NullPointerException if the script returns null
+   */
+  @Override
+  public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
+    if (_longResultSV == null) {
+      _longResultSV = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    // Materialize every argument's values for the whole block up front
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+      }
+      // A plain (long) cast only accepts a boxed Long; int-only arithmetic in the script yields Integer
+      _longResultSV[i] = ((Number) _groovyFunctionEvaluator.evaluate(_bindingValues)).longValue();
+    }
+    return _longResultSV;
+  }
+
+ @Override
+ public long[][] transformToLongValuesMV(ProjectionBlock projectionBlock) {
+ if (_longResultMV == null) {
+ _longResultMV = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+ }
+ for (int i = 0; i < _numGroovyArgs; i++) {
+ _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+ }
+ int length = projectionBlock.getNumDocs();
+ for (int i = 0; i < length; i++) {
+ for (int j = 0; j < _numGroovyArgs; j++) {
+ _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+ }
+ Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+ if (result instanceof List) {
+ _longResultMV[i] = new LongArrayList((List<Long>) result).toLongArray();
+ } else if (result instanceof long[]) {
+ _longResultMV[i] = (long[]) result;
+ } else {
+ throw new IllegalStateException("Unexpected result type '" + result.getClass() + "' for GROOVY function");
+ }
+ }
+ return _longResultMV;
+ }
+
+  /**
+   * Runs the groovy script once per document of the block and returns the results as floats.
+   * Any numeric script result is accepted and converted with {@link Number#floatValue()}, which may
+   * lose precision. The returned array is the reusable _floatResultSV buffer; this call writes only
+   * its first projectionBlock.getNumDocs() entries.
+   *
+   * @throws ClassCastException if the script returns a non-numeric value
+   * @throws NullPointerException if the script returns null
+   */
+  @Override
+  public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
+    if (_floatResultSV == null) {
+      _floatResultSV = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    // Materialize every argument's values for the whole block up front
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+      }
+      // A plain (float) cast only accepts a boxed Float; groovy decimal literals are BigDecimal
+      _floatResultSV[i] = ((Number) _groovyFunctionEvaluator.evaluate(_bindingValues)).floatValue();
+    }
+    return _floatResultSV;
+  }
+
+ @Override
+ public float[][] transformToFloatValuesMV(ProjectionBlock projectionBlock) {
+ if (_floatResultMV == null) {
+ _floatResultMV = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+ }
+ for (int i = 0; i < _numGroovyArgs; i++) {
+ _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+ }
+ int length = projectionBlock.getNumDocs();
+ for (int i = 0; i < length; i++) {
+ for (int j = 0; j < _numGroovyArgs; j++) {
+ _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+ }
+ Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+ if (result instanceof List) {
+ _floatResultMV[i] = new FloatArrayList((List<Float>) result).toFloatArray();
+ } else if (result instanceof float[]) {
+ _floatResultMV[i] = (float[]) result;
+ } else {
+ throw new IllegalStateException("Unexpected result type '" + result.getClass() + "' for GROOVY function");
+ }
+ }
+ return _floatResultMV;
+ }
+
+ @Override
+ public String[] transformToStringValuesSV(ProjectionBlock projectionBlock) {
+ if (_stringResultSV == null) {
+ _stringResultSV = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+ }
+ for (int i = 0; i < _numGroovyArgs; i++) {
+ _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+ }
+ int length = projectionBlock.getNumDocs();
+ for (int i = 0; i < length; i++) {
+ for (int j = 0; j < _numGroovyArgs; j++) {
+ _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+ }
+ _stringResultSV[i] = (String) _groovyFunctionEvaluator.evaluate(_bindingValues);
+ }
+ return _stringResultSV;
+ }
+
+ @Override
+ public String[][] transformToStringValuesMV(ProjectionBlock projectionBlock) {
+ if (_stringResultMV == null) {
+ _stringResultMV = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+ }
+ for (int i = 0; i < _numGroovyArgs; i++) {
+ _sourceArrays[i] = _transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+ }
+ int length = projectionBlock.getNumDocs();
+ for (int i = 0; i < length; i++) {
+ for (int j = 0; j < _numGroovyArgs; j++) {
+ _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], i);
+ }
+ Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+ if (result instanceof List) {
+ _stringResultMV[i] = ((List<String>) result).toArray(new String[0]);
+ } else if (result instanceof String[]) {
+ _stringResultMV[i] = (String[]) result;
+ } else {
+ throw new IllegalStateException("Unexpected result type '" + result.getClass() + "' for GROOVY function");
+ }
+ }
+ return _stringResultMV;
+ }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 670d8a1..a4f9548 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -90,6 +90,8 @@ public class TransformFunctionFactory {
put(TransformFunctionType.ARRAYLENGTH.getName().toLowerCase(), ArrayLengthTransformFunction.class);
put(TransformFunctionType.VALUEIN.getName().toLowerCase(), ValueInTransformFunction.class);
put(TransformFunctionType.MAPVALUE.getName().toLowerCase(), MapValueTransformFunction.class);
+
+ put(TransformFunctionType.GROOVY.getName().toLowerCase(), GroovyTransformFunction.class);
put(TransformFunctionType.CASE.getName().toLowerCase(), CaseTransformFunction.class);
put(TransformFunctionType.EQUALS.getName().toLowerCase(), EqualsTransformFunction.class);
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunctionTest.java
new file mode 100644
index 0000000..187fb61
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunctionTest.java
@@ -0,0 +1,292 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.google.common.base.Joiner;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.IntSummaryStatistics;
+import java.util.List;
+import java.util.stream.IntStream;
+import org.apache.pinot.core.query.exception.BadQueryRequestException;
+import org.apache.pinot.core.query.request.context.ExpressionContext;
+import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+
+/**
+ * Tests the GROOVY transform function
+ */
+public class GroovyTransformFunctionTest extends BaseTransformFunctionTest {
+
+ @DataProvider(name = "groovyFunctionDataProvider")
+ public Object[][] groovyFunctionDataProvider() {
+
+ String groovyTransformFunction;
+ List<Object[]> inputs = new ArrayList<>();
+
+ // max in array (returns SV INT)
+ groovyTransformFunction = String
+ .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', "
+ + "'arg0.toList().max()', "
+ + "%s)", INT_MV_COLUMN);
+ int[] expectedResult1 = new int[NUM_ROWS];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ expectedResult1[i] = Arrays.stream(_intMVValues[i]).max().getAsInt();
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.INT, true, expectedResult1});
+
+ // simple addition (returns SV LONG)
+ groovyTransformFunction = String
+ .format("groovy('{\"returnType\":\"LONG\", \"isSingleValue\":true}', "
+ + "'arg0 + arg1', "
+ + "%s, %s)", INT_SV_COLUMN, LONG_SV_COLUMN);
+ long[] expectedResult2 = new long[NUM_ROWS];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ expectedResult2[i] = _intSVValues[i] + _longSVValues[i];
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.LONG, true, expectedResult2});
+
+ // minimum of 2 numbers (returns SV DOUBLE)
+ groovyTransformFunction = String
+ .format("groovy('{\"returnType\":\"DOUBLE\", \"isSingleValue\":true}', "
+ + "'Math.min(arg0, arg1)', "
+ + "%s, %s)", DOUBLE_SV_COLUMN, INT_SV_COLUMN);
+ double[] expectedResult3 = new double[NUM_ROWS];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ expectedResult3[i] = Math.min(_intSVValues[i], _doubleSVValues[i]);
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.DOUBLE, true, expectedResult3});
+
+ // (returns SV FLOAT)
+ groovyTransformFunction = String.format(
+ "groovy('{\"returnType\":\"FLOAT\", \"isSingleValue\":true}', "
+ + "'def result; switch(arg0.length()) { case 10: result = 1.1; break; case 20: result = 1.2; break; default: result = 1.3;}; return result.floatValue()', "
+ + "%s)", STRING_ALPHANUM_SV_COLUMN);
+ float[] expectedResult4 = new float[NUM_ROWS];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ expectedResult4[i] =
+ _stringAlphaNumericSVValues.length == 10 ? 1.1f : (_stringAlphaNumericSVValues.length == 20 ? 1.2f : 1.3f);
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.FLOAT, true, expectedResult4});
+
+ // string operations (returns SV STRING)
+ groovyTransformFunction = String.format(
+ "groovy('{\"returnType\":\"STRING\", \"isSingleValue\":true}', "
+ + "'[arg0, arg1, arg2].join(\"_\")', "
+ + "%s, %s, %s)", FLOAT_SV_COLUMN, STRING_SV_COLUMN, DOUBLE_SV_COLUMN);
+ String[] expectedResult5 = new String[NUM_ROWS];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ expectedResult5[i] = Joiner.on("_").join(_floatSVValues[i], _stringSVValues[i], _doubleSVValues[i]);
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.STRING, true, expectedResult5});
+
+ // find all in array that match (returns MV INT)
+ groovyTransformFunction = String
+ .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":false}', "
+ + "'arg0.findAll{it < 5}', "
+ + "%s)", INT_MV_COLUMN);
+ int[][] expectedResult6 = new int[NUM_ROWS][];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ expectedResult6[i] = Arrays.stream(_intMVValues[i]).filter(e -> e < 5).toArray();
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.INT, false, expectedResult6});
+
+ // (returns MV LONG)
+ groovyTransformFunction = String
+ .format("groovy('{\"returnType\":\"LONG\", \"isSingleValue\":false}', "
+ + "'arg0.findIndexValues{it == 5}', "
+ + "%s)", INT_MV_COLUMN);
+ long[][] expectedResult7 = new long[NUM_ROWS][];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ int[] intMVValue = _intMVValues[i];
+ expectedResult7[i] =
+ IntStream.range(0, intMVValue.length).filter(e -> intMVValue[e] == 5).mapToLong(e -> (long) e).toArray();
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.LONG, false, expectedResult7});
+
+ // no-args function (returns MV STRING)
+ groovyTransformFunction = "groovy('{\"returnType\":\"STRING\", \"isSingleValue\":false}', '[\"foo\", \"bar\"]')";
+ String[][] expectedResult8 = new String[NUM_ROWS][];
+ Arrays.fill(expectedResult8, new String[]{"foo", "bar"});
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.STRING, false, expectedResult8});
+
+ // nested groovy functions
+ String groovy1 = String
+ .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 'arg0.toList().max()', %s)", INT_MV_COLUMN);
+ String groovy2 = String
+ .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 'arg0.toList().min()', %s)", INT_MV_COLUMN);
+ groovyTransformFunction = String
+ .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":false}', '[arg0, arg1, arg2.sum()]', %s, %s, %s)",
+ groovy1, groovy2, INT_MV_COLUMN);
+ int[][] expectedResult9 = new int[NUM_ROWS][];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ IntSummaryStatistics stats = Arrays.stream(_intMVValues[i]).summaryStatistics();
+ expectedResult9[i] = new int[]{stats.getMax(), stats.getMin(), (int) stats.getSum()};
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.INT, false, expectedResult9});
+
+ // nested with other functions
+ groovyTransformFunction = String
+ .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 'arg0 + arg1', %s, arraylength(%s))",
+ INT_SV_COLUMN, INT_MV_COLUMN);
+ int[] expectedResult10 = new int[NUM_ROWS];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ expectedResult10[i] = _intSVValues[i] + _intMVValues[i].length;
+ }
+ inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.INT, true, expectedResult10});
+
+ return inputs.toArray(new Object[0][]);
+ }
+
+ /**
+  * Builds a transform function from the given groovy expression, verifies its name and result metadata,
+  * and then compares the transformed values of the projection block against the expected values.
+  *
+  * @param expressionStr the groovy(...) expression to compile into a transform function
+  * @param resultType the data type the function must report; also selects which transformTo* method is invoked
+  * @param isResultSingleValue whether the function must report (and is evaluated as) single-value or multi-value
+  * @param expectedResult the expected values; cast to a one-dimensional array of the result type for
+  *     single-value results and to a two-dimensional array for multi-value results
+  */
+ @Test(dataProvider = "groovyFunctionDataProvider")
+ public void testGroovyTransformFunctions(String expressionStr, FieldSpec.DataType resultType,
+     boolean isResultSingleValue, Object expectedResult) {
+   ExpressionContext expression = QueryContextConverterUtils.getExpression(expressionStr);
+   TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
+   Assert.assertTrue(transformFunction instanceof GroovyTransformFunction);
+   Assert.assertEquals(transformFunction.getName(), GroovyTransformFunction.FUNCTION_NAME);
+   Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), resultType);
+   Assert.assertEquals(transformFunction.getResultMetadata().isSingleValue(), isResultSingleValue);
+   Assert.assertFalse(transformFunction.getResultMetadata().hasDictionary());
+
+   // Only the first NUM_ROWS entries of each result are compared; the length of the returned array is not asserted.
+   if (isResultSingleValue) {
+     switch (resultType) {
+       case INT:
+         int[] intResults = transformFunction.transformToIntValuesSV(_projectionBlock);
+         int[] expectedInts = (int[]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(intResults[i], expectedInts[i]);
+         }
+         break;
+       case LONG:
+         long[] longResults = transformFunction.transformToLongValuesSV(_projectionBlock);
+         long[] expectedLongs = (long[]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(longResults[i], expectedLongs[i]);
+         }
+         break;
+       case FLOAT:
+         float[] floatResults = transformFunction.transformToFloatValuesSV(_projectionBlock);
+         float[] expectedFloats = (float[]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(floatResults[i], expectedFloats[i]);
+         }
+         break;
+       case DOUBLE:
+         double[] doubleResults = transformFunction.transformToDoubleValuesSV(_projectionBlock);
+         double[] expectedDoubles = (double[]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(doubleResults[i], expectedDoubles[i]);
+         }
+         break;
+       case STRING:
+         String[] stringResults = transformFunction.transformToStringValuesSV(_projectionBlock);
+         String[] expectedStrings = (String[]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(stringResults[i], expectedStrings[i]);
+         }
+         break;
+       default:
+         // Without this branch a data-provider row with an unhandled type would pass without comparing any value.
+         Assert.fail("Unsupported single-value result type in test: " + resultType);
+     }
+   } else {
+     switch (resultType) {
+       case INT:
+         int[][] intResults = transformFunction.transformToIntValuesMV(_projectionBlock);
+         int[][] expectedInts = (int[][]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(intResults[i], expectedInts[i]);
+         }
+         break;
+       case LONG:
+         long[][] longResults = transformFunction.transformToLongValuesMV(_projectionBlock);
+         long[][] expectedLongs = (long[][]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(longResults[i], expectedLongs[i]);
+         }
+         break;
+       case FLOAT:
+         float[][] floatResults = transformFunction.transformToFloatValuesMV(_projectionBlock);
+         float[][] expectedFloats = (float[][]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(floatResults[i], expectedFloats[i]);
+         }
+         break;
+       case DOUBLE:
+         double[][] doubleResults = transformFunction.transformToDoubleValuesMV(_projectionBlock);
+         double[][] expectedDoubles = (double[][]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(doubleResults[i], expectedDoubles[i]);
+         }
+         break;
+       case STRING:
+         String[][] stringResults = transformFunction.transformToStringValuesMV(_projectionBlock);
+         String[][] expectedStrings = (String[][]) expectedResult;
+         for (int i = 0; i < NUM_ROWS; i++) {
+           Assert.assertEquals(stringResults[i], expectedStrings[i]);
+         }
+         break;
+       default:
+         // Without this branch a data-provider row with an unhandled type would pass without comparing any value.
+         Assert.fail("Unsupported multi-value result type in test: " + resultType);
+     }
+   }
+ }
+
+ /**
+  * Expects a {@link BadQueryRequestException} while parsing the given expression and building a transform
+  * function from it; the expressions are supplied by the "testIllegalArguments" data provider.
+  *
+  * @param expressionStr a groovy(...) expression that is expected to be rejected
+  */
+ @Test(dataProvider = "testIllegalArguments", expectedExceptions = BadQueryRequestException.class)
+ public void testIllegalArguments(String expressionStr) {
+   TransformFunctionFactory.get(QueryContextConverterUtils.getExpression(expressionStr), _dataSourceMap);
+ }
+
+ @DataProvider(name = "testIllegalArguments")
+ public Object[][] testIllegalArguments() {
+ List<Object[]> inputs = new ArrayList<>();
+ // incorrect number of arguments
+ inputs.add(new Object[]{String.format("groovy(%s)", STRING_SV_COLUMN)});
+ // first argument must be literal
+ inputs.add(new Object[]{String.format("groovy(%s, %s)", DOUBLE_SV_COLUMN, STRING_SV_COLUMN)});
+ // second argument must be a literal
+ inputs.add(new Object[]{String.format("groovy('arg0 + 10', %s)", STRING_SV_COLUMN)});
+ // first argument must be a valid json
+ inputs.add(new Object[]{String.format("groovy(']]', 'arg0 + 10', %s)", STRING_SV_COLUMN)});
+ // first argument json must contain non-null key returnType
+ inputs.add(new Object[]{String.format("groovy('{\"isSingleValue\":true}', 'arg0 + 10', %s)", INT_SV_COLUMN)});
+ inputs.add(new Object[]{String.format("groovy('{\"returnType\":null, \"isSingleValue\":true}', 'arg0 + 10', %s)",
+ INT_SV_COLUMN)});
+ // first argument json must contain non-null key isSingleValue
+ inputs.add(new Object[]{String.format("groovy('{\"returnType\":\"INT\"}', 'arg0 + 10', %s)", INT_SV_COLUMN)});
+ inputs.add(new Object[]{String.format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":null}', 'arg0 + 10', %s)",
+ INT_SV_COLUMN)});
+ // return type must be valid DataType enum
+ inputs.add(new Object[]{String.format("groovy('{\"returnType\":\"foo\", \"isSingleValue\":true}', 'arg0 + 10', %s)",
+ INT_SV_COLUMN)});
+ // arguments must be columns/transform functions
+ inputs.add(new Object[]{"groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 'arg0 + 10', 'foo')"});
+ inputs.add(new Object[]{String.format(
+ "groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 'arg0 + arg1 + 10', 'arraylength(colB)', %s)",
+ INT_SV_COLUMN)});
+ // invalid groovy expression
+ inputs.add(new Object[]{"groovy('{\"returnType\":\"INT\"}', '+-+')"});
+ inputs.add(new Object[]{String.format("groovy('{\"returnType\":\"INT\"}', '+-+arg0 arg1', %s, %s)", INT_SV_COLUMN,
+ DOUBLE_SV_COLUMN)});
+ return inputs.toArray(new Object[0][]);
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org