You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ib...@apache.org on 2021/04/22 22:06:31 UTC
[beam] branch master updated: [BEAM-12194] Implement CREATE AGGREGATE FUNCTION.
This is an automated email from the ASF dual-hosted git repository.
ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8c8c330 [BEAM-12194] Implement CREATE AGGREGATE FUNCTION.
new c4cea56 Merge pull request #14609 from ibzib/BEAM-12194
8c8c330 is described below
commit 8c8c330e6d0afbc6740563a53c884b90ee16a4ef
Author: Kyle Weaver <kc...@google.com>
AuthorDate: Tue Apr 20 14:27:46 2021 -0700
[BEAM-12194] Implement CREATE AGGREGATE FUNCTION.
---
.../sdk/extensions/sql/impl/JavaUdfLoader.java | 50 ++++++++++++++-
.../sql/impl/LazyAggregateCombineFn.java | 71 ++++++++++++++++++++++
.../sdk/extensions/sql/impl/JavaUdfLoaderTest.java | 16 +++++
.../extensions/sql/provider/UdfTestProvider.java | 32 ++++++++++
.../extensions/sql/zetasql/BeamZetaSqlCatalog.java | 10 +++
.../extensions/sql/zetasql/ZetaSqlJavaUdfTest.java | 38 ++++++++++++
6 files changed, 216 insertions(+), 1 deletion(-)
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoader.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoader.java
index 57791a7..429ee63 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoader.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoader.java
@@ -35,6 +35,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
+import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
import org.apache.beam.sdk.extensions.sql.udf.ScalarFn;
import org.apache.beam.sdk.extensions.sql.udf.UdfProvider;
import org.apache.beam.sdk.io.FileSystems;
@@ -91,6 +92,33 @@ public class JavaUdfLoader {
}
}
+  /**
+   * Looks up the user-defined aggregate function registered under {@code functionPath} in the jar
+   * at {@code jarPath}.
+   *
+   * @param functionPath name components of the function; joined with "." for error messages
+   * @param jarPath location of the jar whose function definitions are loaded
+   * @return the implementation registered under {@code functionPath}
+   * @throws IllegalArgumentException if the jar has no aggregate function under that path
+   * @throws RuntimeException wrapping the {@link IOException} raised while loading the jar
+   */
+  public AggregateFn loadAggregateFunction(List<String> functionPath, String jarPath) {
+    String qualifiedName = String.join(".", functionPath);
+
+    FunctionDefinitions definitions;
+    try {
+      definitions = loadJar(jarPath);
+    } catch (IOException e) {
+      throw new RuntimeException(
+          String.format(
+              "Failed to load user-defined aggregate function %s from %s",
+              qualifiedName, jarPath),
+          e);
+    }
+
+    // The definitions map is an ImmutableMap, so a null result can only mean "no such key".
+    AggregateFn implementation = definitions.aggregateFunctions().get(functionPath);
+    if (implementation == null) {
+      String providerName = UdfProvider.class.getSimpleName();
+      throw new IllegalArgumentException(
+          String.format(
+              "No implementation of aggregate function %s found in %s.%n"
+                  + " 1. Create a class implementing %s and annotate it with @AutoService(%s.class).%n"
+                  + " 2. Add function %s to the class's userDefinedAggregateFunctions implementation.",
+              qualifiedName,
+              jarPath,
+              providerName,
+              providerName,
+              qualifiedName));
+    }
+    return implementation;
+  }
+
/**
* Creates a temporary local copy of the file at {@code inputPath}, and returns a handle to the
* local copy.
@@ -167,6 +195,7 @@ public class JavaUdfLoader {
ClassLoader classLoader = createClassLoader(jarPath);
Map<List<String>, ScalarFn> scalarFunctions = new HashMap<>();
+ Map<List<String>, AggregateFn> aggregateFunctions = new HashMap<>();
Iterator<UdfProvider> providers = getUdfProviders(classLoader);
int providersCount = 0;
while (providers.hasNext()) {
@@ -185,6 +214,19 @@ public class JavaUdfLoader {
}
scalarFunctions.put(functionPath, implementation);
});
+ provider
+ .userDefinedAggregateFunctions()
+ .forEach(
+ (functionName, implementation) -> {
+ List<String> functionPath = ImmutableList.copyOf(functionName.split("\\."));
+ if (aggregateFunctions.containsKey(functionPath)) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Found multiple definitions of aggregate function %s in %s.",
+ functionName, jarPath));
+ }
+ aggregateFunctions.put(functionPath, implementation);
+ });
}
if (providersCount == 0) {
throw new ProviderNotFoundException(
@@ -204,6 +246,7 @@ public class JavaUdfLoader {
FunctionDefinitions userFunctionDefinitions =
FunctionDefinitions.newBuilder()
.setScalarFunctions(ImmutableMap.copyOf(scalarFunctions))
+ .setAggregateFunctions(ImmutableMap.copyOf(aggregateFunctions))
.build();
functionCache.put(jarPath, userFunctionDefinitions);
@@ -216,16 +259,21 @@ public class JavaUdfLoader {
abstract static class FunctionDefinitions {
abstract ImmutableMap<List<String>, ScalarFn> scalarFunctions();
+ abstract ImmutableMap<List<String>, AggregateFn> aggregateFunctions();
+
@AutoValue.Builder
abstract static class Builder {
abstract Builder setScalarFunctions(ImmutableMap<List<String>, ScalarFn> value);
+ abstract Builder setAggregateFunctions(ImmutableMap<List<String>, AggregateFn> value);
+
abstract FunctionDefinitions build();
}
static Builder newBuilder() {
return new AutoValue_JavaUdfLoader_FunctionDefinitions.Builder()
- .setScalarFunctions(ImmutableMap.of());
+ .setScalarFunctions(ImmutableMap.of())
+ .setAggregateFunctions(ImmutableMap.of());
}
}
}
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
new file mode 100644
index 0000000..90abc0c
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.impl;
+
+import edu.umd.cs.findbugs.annotations.Nullable;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
+import org.apache.beam.sdk.transforms.Combine;
+
+/**
+ * {@link org.apache.beam.sdk.transforms.Combine.CombineFn} that wraps an {@link AggregateFn}. The
+ * {@link AggregateFn} is lazily instantiated so it doesn't have to be serialized/deserialized.
+ *
+ * <p>Only the function path and the jar path are kept as regular fields; the wrapped function is a
+ * transient field that is loaded through {@link JavaUdfLoader} the first time it is needed.
+ */
+public class LazyAggregateCombineFn<InputT, AccumT, OutputT>
+    extends Combine.CombineFn<InputT, AccumT, OutputT> {
+  private final List<String> functionPath;
+  private final String jarPath;
+  private transient @Nullable AggregateFn<InputT, AccumT, OutputT> aggregateFn = null;
+
+  /**
+   * @param functionPath name components identifying the aggregate function inside the jar
+   * @param jarPath location of the jar that provides the function
+   */
+  public LazyAggregateCombineFn(List<String> functionPath, String jarPath) {
+    this.functionPath = functionPath;
+    this.jarPath = jarPath;
+  }
+
+  /** Returns the wrapped function, loading it from the jar on first use. */
+  private AggregateFn<InputT, AccumT, OutputT> getAggregateFn() {
+    if (aggregateFn == null) {
+      JavaUdfLoader loader = new JavaUdfLoader();
+      aggregateFn = loader.loadAggregateFunction(functionPath, jarPath);
+    }
+    return aggregateFn;
+  }
+
+  @Override
+  public AccumT createAccumulator() {
+    return getAggregateFn().createAccumulator();
+  }
+
+  @Override
+  public AccumT addInput(AccumT mutableAccumulator, InputT input) {
+    return getAggregateFn().addInput(mutableAccumulator, input);
+  }
+
+  /**
+   * Merges {@code accumulators} by handing the first one to the wrapped function as the mutable
+   * accumulator and the remaining ones as the accumulators to fold into it.
+   *
+   * <p>The input is never modified: {@link Iterator#remove()} is an optional operation that fails
+   * on unmodifiable iterables, so the remaining accumulators are exposed through a view that skips
+   * the first element instead. The view re-iterates {@code accumulators}, so the input must
+   * support being iterated more than once.
+   *
+   * @return the merged accumulator, or a fresh accumulator if {@code accumulators} is empty
+   */
+  @Override
+  public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+    Iterator<AccumT> it = accumulators.iterator();
+    if (!it.hasNext()) {
+      // Nothing to merge; an empty accumulator represents "no inputs".
+      return createAccumulator();
+    }
+    AccumT first = it.next();
+    Iterable<AccumT> rest =
+        () -> {
+          Iterator<AccumT> remaining = accumulators.iterator();
+          // Skip the element that is already passed as the mutable accumulator.
+          remaining.next();
+          return remaining;
+        };
+    return getAggregateFn().mergeAccumulators(first, rest);
+  }
+
+  @Override
+  public OutputT extractOutput(AccumT accumulator) {
+    return getAggregateFn().extractOutput(accumulator);
+  }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoaderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoaderTest.java
index d169d65e..d769600 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoaderTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoaderTest.java
@@ -69,6 +69,12 @@ public class JavaUdfLoaderTest {
}
@Test
+  public void testLoadAggregateFunction() {
+    // Passes as long as looking up "my_sum" completes without throwing.
+    new JavaUdfLoader().loadAggregateFunction(Collections.singletonList("my_sum"), jarPath);
+  }
+
+ @Test
public void testLoadUnregisteredScalarFunctionThrowsRuntimeException() {
JavaUdfLoader udfLoader = new JavaUdfLoader();
thrown.expect(RuntimeException.class);
@@ -78,6 +84,16 @@ public class JavaUdfLoaderTest {
}
@Test
+  public void testLoadUnregisteredAggregateFunctionThrowsRuntimeException() {
+    JavaUdfLoader loader = new JavaUdfLoader();
+    String expectedMessage =
+        String.format(
+            "No implementation of aggregate function notRegistered found in %s.", jarPath);
+
+    // Expectations must be registered before the call that is supposed to fail.
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage(expectedMessage);
+
+    loader.loadAggregateFunction(Collections.singletonList("notRegistered"), jarPath);
+  }
+
+ @Test
public void testJarMissingUdfProviderThrowsProviderNotFoundException() {
JavaUdfLoader udfLoader = new JavaUdfLoader();
thrown.expect(ProviderNotFoundException.class);
diff --git a/sdks/java/extensions/sql/udf-test-provider/src/main/java/org/apache/beam/sdk/extensions/sql/provider/UdfTestProvider.java b/sdks/java/extensions/sql/udf-test-provider/src/main/java/org/apache/beam/sdk/extensions/sql/provider/UdfTestProvider.java
index 1f574db..1d615a9 100644
--- a/sdks/java/extensions/sql/udf-test-provider/src/main/java/org/apache/beam/sdk/extensions/sql/provider/UdfTestProvider.java
+++ b/sdks/java/extensions/sql/udf-test-provider/src/main/java/org/apache/beam/sdk/extensions/sql/provider/UdfTestProvider.java
@@ -19,6 +19,7 @@ package org.apache.beam.sdk.extensions.sql.provider;
import com.google.auto.service.AutoService;
import java.util.Map;
+import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
import org.apache.beam.sdk.extensions.sql.udf.ScalarFn;
import org.apache.beam.sdk.extensions.sql.udf.UdfProvider;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
@@ -39,6 +40,11 @@ public class UdfTestProvider implements UdfProvider {
new IsNullFn());
}
+  /** Registers {@link Sum} as the aggregate function named "my_sum". */
+  @Override
+  public Map<String, AggregateFn<?, ?, ?>> userDefinedAggregateFunctions() {
+    return ImmutableMap.<String, AggregateFn<?, ?, ?>>of("my_sum", new Sum());
+  }
+
public static class HelloWorldFn extends ScalarFn {
@ApplyMethod
public String helloWorld() {
@@ -73,4 +79,30 @@ public class UdfTestProvider implements UdfProvider {
return "This method is not registered as a UDF.";
}
}
+
+ public static class Sum implements AggregateFn<Long, Long, Long> {
+
+ @Override
+ public Long createAccumulator() {
+ return 0L;
+ }
+
+ @Override
+ public Long addInput(Long mutableAccumulator, Long input) {
+ return mutableAccumulator + input;
+ }
+
+ @Override
+ public Long mergeAccumulators(Long mutableAccumulator, Iterable<Long> immutableAccumulators) {
+ for (Long x : immutableAccumulators) {
+ mutableAccumulator += x;
+ }
+ return mutableAccumulator;
+ }
+
+ @Override
+ public Long extractOutput(Long mutableAccumulator) {
+ return mutableAccumulator;
+ }
+ }
}
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java
index e38cb5e..8a8b47e 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java
@@ -42,6 +42,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.beam.sdk.extensions.sql.impl.JavaUdfLoader;
+import org.apache.beam.sdk.extensions.sql.impl.LazyAggregateCombineFn;
import org.apache.beam.sdk.extensions.sql.impl.ScalarFnReflector;
import org.apache.beam.sdk.extensions.sql.impl.ScalarFunctionImpl;
import org.apache.beam.sdk.extensions.sql.impl.UdafImpl;
@@ -150,6 +151,15 @@ public class BeamZetaSqlCatalog {
createFunctionStmt.getNamePath(),
UserFunctionDefinitions.JavaScalarFunction.create(method, jarPath));
break;
+ case USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS:
+ jarPath = getJarPath(createFunctionStmt);
+ // Try loading the aggregate function just to make sure it exists. LazyAggregateCombineFn
+ // will need to fetch it again at runtime.
+ javaUdfLoader.loadAggregateFunction(createFunctionStmt.getNamePath(), jarPath);
+ Combine.CombineFn<?, ?, ?> combineFn =
+ new LazyAggregateCombineFn<>(createFunctionStmt.getNamePath(), jarPath);
+ javaUdafs.put(createFunctionStmt.getNamePath(), combineFn);
+ break;
default:
throw new IllegalArgumentException(
String.format("Encountered unrecognized function group %s.", functionGroup));
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
index 22b18f9..5cd0826 100644
--- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
@@ -400,6 +400,44 @@ public class ZetaSqlJavaUdfTest extends ZetaSqlTestBase {
}
@Test
+  public void testUdaf() {
+    // Define the Java aggregate function and use it in the same script.
+    String query =
+        String.format(
+            "CREATE AGGREGATE FUNCTION my_sum(f INT64) RETURNS INT64 LANGUAGE java OPTIONS (path='%s'); "
+                + "SELECT my_sum(f_int_1) from aggregate_test_table",
+            jarPath);
+
+    BeamRelNode relNode = new ZetaSQLQueryPlanner(config).convertToBeamRel(query);
+    PCollection<Row> result = BeamSqlRelUtils.toPCollection(pipeline, relNode);
+
+    Schema outputSchema = Schema.builder().addInt64Field("field1").build();
+    Row expectedRow = Row.withSchema(outputSchema).addValues(28L).build();
+
+    PAssert.that(result).containsInAnyOrder(expectedRow);
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+ @Test
+  public void testUdafNotFoundFailsToParse() {
+    // The jar does not provide an aggregate function named "nonexistent", so defining it must
+    // already fail while the statement is converted.
+    String sql =
+        String.format(
+            "CREATE AGGREGATE FUNCTION nonexistent(f INT64) RETURNS INT64 LANGUAGE java OPTIONS (path='%s'); "
+                + "SELECT nonexistent(f_int_1) from aggregate_test_table",
+            jarPath);
+    ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
+
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("Failed to define function 'nonexistent'");
+    thrown.expectCause(
+        allOf(
+            isA(IllegalArgumentException.class),
+            hasProperty(
+                "message",
+                containsString("No implementation of aggregate function nonexistent found"))));
+
+    // Called only for the expected exception; the result is intentionally not kept.
+    zetaSQLQueryPlanner.convertToBeamRel(sql);
+  }
+
+ @Test
public void testRegisterUdaf() {
String sql = "SELECT my_sum(k) FROM UNNEST([1, 2, 3]) k;";
PCollection<Row> stream =