You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ib...@apache.org on 2021/04/22 22:06:31 UTC

[beam] branch master updated: [BEAM-12194] Implement CREATE AGGREGATE FUNCTION.

This is an automated email from the ASF dual-hosted git repository.

ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c8c330  [BEAM-12194] Implement CREATE AGGREGATE FUNCTION.
     new c4cea56  Merge pull request #14609 from ibzib/BEAM-12194
8c8c330 is described below

commit 8c8c330e6d0afbc6740563a53c884b90ee16a4ef
Author: Kyle Weaver <kc...@google.com>
AuthorDate: Tue Apr 20 14:27:46 2021 -0700

    [BEAM-12194] Implement CREATE AGGREGATE FUNCTION.
---
 .../sdk/extensions/sql/impl/JavaUdfLoader.java     | 50 ++++++++++++++-
 .../sql/impl/LazyAggregateCombineFn.java           | 71 ++++++++++++++++++++++
 .../sdk/extensions/sql/impl/JavaUdfLoaderTest.java | 16 +++++
 .../extensions/sql/provider/UdfTestProvider.java   | 32 ++++++++++
 .../extensions/sql/zetasql/BeamZetaSqlCatalog.java | 10 +++
 .../extensions/sql/zetasql/ZetaSqlJavaUdfTest.java | 38 ++++++++++++
 6 files changed, 216 insertions(+), 1 deletion(-)

diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoader.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoader.java
index 57791a7..429ee63 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoader.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoader.java
@@ -35,6 +35,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.ServiceLoader;
+import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
 import org.apache.beam.sdk.extensions.sql.udf.ScalarFn;
 import org.apache.beam.sdk.extensions.sql.udf.UdfProvider;
 import org.apache.beam.sdk.io.FileSystems;
@@ -91,6 +92,33 @@ public class JavaUdfLoader {
     }
   }
 
+  /** Loads the aggregate function {@code functionPath} from the jar at {@code jarPath}. */
+  public AggregateFn loadAggregateFunction(List<String> functionPath, String jarPath) {
+    String functionFullName = String.join(".", functionPath);
+    try {
+      AggregateFn aggregateFn = loadJar(jarPath).aggregateFunctions().get(functionPath);
+      if (aggregateFn != null) {
+        return aggregateFn;
+      }
+      throw new IllegalArgumentException(
+          String.format(
+              "No implementation of aggregate function %s found in %s.%n"
+                  + " 1. Create a class implementing %s and annotate it with @AutoService(%s.class).%n"
+                  + " 2. Add function %s to the class's userDefinedAggregateFunctions implementation.",
+              functionFullName,
+              jarPath,
+              UdfProvider.class.getSimpleName(),
+              UdfProvider.class.getSimpleName(),
+              functionFullName));
+    } catch (IOException e) {
+      throw new RuntimeException(
+          String.format(
+              "Failed to load user-defined aggregate function %s from %s",
+              functionFullName, jarPath),
+          e);
+    }
+  }
+
   /**
    * Creates a temporary local copy of the file at {@code inputPath}, and returns a handle to the
    * local copy.
@@ -167,6 +195,7 @@ public class JavaUdfLoader {
 
     ClassLoader classLoader = createClassLoader(jarPath);
     Map<List<String>, ScalarFn> scalarFunctions = new HashMap<>();
+    Map<List<String>, AggregateFn> aggregateFunctions = new HashMap<>();
     Iterator<UdfProvider> providers = getUdfProviders(classLoader);
     int providersCount = 0;
     while (providers.hasNext()) {
@@ -185,6 +214,19 @@ public class JavaUdfLoader {
                 }
                 scalarFunctions.put(functionPath, implementation);
               });
+      provider
+          .userDefinedAggregateFunctions()
+          .forEach(
+              (functionName, implementation) -> {
+                List<String> functionPath = ImmutableList.copyOf(functionName.split("\\."));
+                if (aggregateFunctions.containsKey(functionPath)) {
+                  throw new IllegalArgumentException(
+                      String.format(
+                          "Found multiple definitions of aggregate function %s in %s.",
+                          functionName, jarPath));
+                }
+                aggregateFunctions.put(functionPath, implementation);
+              });
     }
     if (providersCount == 0) {
       throw new ProviderNotFoundException(
@@ -204,6 +246,7 @@ public class JavaUdfLoader {
     FunctionDefinitions userFunctionDefinitions =
         FunctionDefinitions.newBuilder()
             .setScalarFunctions(ImmutableMap.copyOf(scalarFunctions))
+            .setAggregateFunctions(ImmutableMap.copyOf(aggregateFunctions))
             .build();
 
     functionCache.put(jarPath, userFunctionDefinitions);
@@ -216,16 +259,21 @@ public class JavaUdfLoader {
   abstract static class FunctionDefinitions {
     abstract ImmutableMap<List<String>, ScalarFn> scalarFunctions();
 
+    abstract ImmutableMap<List<String>, AggregateFn> aggregateFunctions();
+
     @AutoValue.Builder
     abstract static class Builder {
       abstract Builder setScalarFunctions(ImmutableMap<List<String>, ScalarFn> value);
 
+      abstract Builder setAggregateFunctions(ImmutableMap<List<String>, AggregateFn> value);
+
       abstract FunctionDefinitions build();
     }
 
     static Builder newBuilder() {
       return new AutoValue_JavaUdfLoader_FunctionDefinitions.Builder()
-          .setScalarFunctions(ImmutableMap.of());
+          .setScalarFunctions(ImmutableMap.of())
+          .setAggregateFunctions(ImmutableMap.of());
     }
   }
 }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
new file mode 100644
index 0000000..90abc0c
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.impl;
+
+import edu.umd.cs.findbugs.annotations.Nullable;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
+import org.apache.beam.sdk.transforms.Combine;
+
+/**
+ * {@link org.apache.beam.sdk.transforms.Combine.CombineFn} that wraps an {@link AggregateFn}. The
+ * {@link AggregateFn} is lazily instantiated so it doesn't have to be serialized/deserialized.
+ *
+ * <p>Only {@code functionPath} and {@code jarPath} are serialized; each deserialized instance
+ * re-loads the {@link AggregateFn} from the jar via {@link JavaUdfLoader} on first use.
+ */
+public class LazyAggregateCombineFn<InputT, AccumT, OutputT>
+    extends Combine.CombineFn<InputT, AccumT, OutputT> {
+  private final List<String> functionPath;
+  private final String jarPath;
+  // Transient, so it is never serialized; populated on demand by getAggregateFn().
+  private transient @Nullable AggregateFn<InputT, AccumT, OutputT> aggregateFn = null;
+
+  /**
+   * @param functionPath the function's name path, i.e. its registered name split on {@code .}
+   * @param jarPath path of the jar whose {@code UdfProvider} defines the function
+   */
+  public LazyAggregateCombineFn(List<String> functionPath, String jarPath) {
+    this.functionPath = functionPath;
+    this.jarPath = jarPath;
+  }
+
+  /** Returns the wrapped {@link AggregateFn}, loading it from the jar on the first call. */
+  private AggregateFn<InputT, AccumT, OutputT> getAggregateFn() {
+    if (aggregateFn == null) {
+      JavaUdfLoader loader = new JavaUdfLoader();
+      aggregateFn = loader.loadAggregateFunction(functionPath, jarPath);
+    }
+    return aggregateFn;
+  }
+
+  @Override
+  public AccumT createAccumulator() {
+    return getAggregateFn().createAccumulator();
+  }
+
+  @Override
+  public AccumT addInput(AccumT mutableAccumulator, InputT input) {
+    return getAggregateFn().addInput(mutableAccumulator, input);
+  }
+
+  /**
+   * Merges {@code accumulators} by handing the first one to {@link AggregateFn#mergeAccumulators}
+   * as the mutable accumulator and the remaining ones as the immutable accumulators.
+   *
+   * <p>{@code accumulators} itself is never modified: {@link Iterator#remove()} is an optional
+   * operation the caller's {@code Iterable} may not support, so the remaining accumulators are
+   * passed as a view that skips the first element. An empty input yields a fresh accumulator.
+   */
+  @Override
+  public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+    Iterator<AccumT> it = accumulators.iterator();
+    if (!it.hasNext()) {
+      return createAccumulator();
+    }
+    AccumT first = it.next();
+    Iterable<AccumT> rest =
+        () -> {
+          Iterator<AccumT> restIt = accumulators.iterator();
+          restIt.next(); // Skip the accumulator that is passed as the mutable one.
+          return restIt;
+        };
+    return getAggregateFn().mergeAccumulators(first, rest);
+  }
+
+  @Override
+  public OutputT extractOutput(AccumT accumulator) {
+    return getAggregateFn().extractOutput(accumulator);
+  }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoaderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoaderTest.java
index d169d65e..d769600 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoaderTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/JavaUdfLoaderTest.java
@@ -69,6 +69,12 @@ public class JavaUdfLoaderTest {
   }
 
   @Test
+  public void testLoadAggregateFunction() {
+    // loadAggregateFunction throws unless the jar at jarPath registers "my_sum".
+    new JavaUdfLoader().loadAggregateFunction(Collections.singletonList("my_sum"), jarPath);
+  }
+
+  @Test
   public void testLoadUnregisteredScalarFunctionThrowsRuntimeException() {
     JavaUdfLoader udfLoader = new JavaUdfLoader();
     thrown.expect(RuntimeException.class);
@@ -78,6 +84,16 @@ public class JavaUdfLoaderTest {
   }
 
   @Test
+  public void testLoadUnregisteredAggregateFunctionThrowsRuntimeException() {
+    String expectedError =
+        String.format(
+            "No implementation of aggregate function notRegistered found in %s.", jarPath);
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage(expectedError);
+    new JavaUdfLoader().loadAggregateFunction(Collections.singletonList("notRegistered"), jarPath);
+  }
+
+  @Test
   public void testJarMissingUdfProviderThrowsProviderNotFoundException() {
     JavaUdfLoader udfLoader = new JavaUdfLoader();
     thrown.expect(ProviderNotFoundException.class);
diff --git a/sdks/java/extensions/sql/udf-test-provider/src/main/java/org/apache/beam/sdk/extensions/sql/provider/UdfTestProvider.java b/sdks/java/extensions/sql/udf-test-provider/src/main/java/org/apache/beam/sdk/extensions/sql/provider/UdfTestProvider.java
index 1f574db..1d615a9 100644
--- a/sdks/java/extensions/sql/udf-test-provider/src/main/java/org/apache/beam/sdk/extensions/sql/provider/UdfTestProvider.java
+++ b/sdks/java/extensions/sql/udf-test-provider/src/main/java/org/apache/beam/sdk/extensions/sql/provider/UdfTestProvider.java
@@ -19,6 +19,7 @@ package org.apache.beam.sdk.extensions.sql.provider;
 
 import com.google.auto.service.AutoService;
 import java.util.Map;
+import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
 import org.apache.beam.sdk.extensions.sql.udf.ScalarFn;
 import org.apache.beam.sdk.extensions.sql.udf.UdfProvider;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
@@ -39,6 +40,11 @@ public class UdfTestProvider implements UdfProvider {
         new IsNullFn());
   }
 
+  @Override
+  public Map<String, AggregateFn<?, ?, ?>> userDefinedAggregateFunctions() {
+    return ImmutableMap.<String, AggregateFn<?, ?, ?>>builder().put("my_sum", new Sum()).build();
+  }
+
   public static class HelloWorldFn extends ScalarFn {
     @ApplyMethod
     public String helloWorld() {
@@ -73,4 +79,30 @@ public class UdfTestProvider implements UdfProvider {
       return "This method is not registered as a UDF.";
     }
   }
+
+  public static class Sum implements AggregateFn<Long, Long, Long> {
+    @Override
+    public Long createAccumulator() {
+      return Long.valueOf(0L);
+    }
+
+    @Override
+    public Long addInput(Long mutableAccumulator, Long input) {
+      return Long.sum(mutableAccumulator, input);
+    }
+
+    @Override
+    public Long mergeAccumulators(Long mutableAccumulator, Iterable<Long> immutableAccumulators) {
+      Long total = mutableAccumulator;
+      for (Long partial : immutableAccumulators) {
+        total = Long.sum(total, partial);
+      }
+      return total;
+    }
+
+    @Override
+    public Long extractOutput(Long mutableAccumulator) {
+      return mutableAccumulator;
+    }
+  }
 }
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java
index e38cb5e..8a8b47e 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java
@@ -42,6 +42,7 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.stream.Collectors;
 import org.apache.beam.sdk.extensions.sql.impl.JavaUdfLoader;
+import org.apache.beam.sdk.extensions.sql.impl.LazyAggregateCombineFn;
 import org.apache.beam.sdk.extensions.sql.impl.ScalarFnReflector;
 import org.apache.beam.sdk.extensions.sql.impl.ScalarFunctionImpl;
 import org.apache.beam.sdk.extensions.sql.impl.UdafImpl;
@@ -150,6 +151,15 @@ public class BeamZetaSqlCatalog {
             createFunctionStmt.getNamePath(),
             UserFunctionDefinitions.JavaScalarFunction.create(method, jarPath));
         break;
+      case USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS:
+        jarPath = getJarPath(createFunctionStmt);
+        // Try loading the aggregate function just to make sure it exists. LazyAggregateCombineFn
+        // will need to fetch it again at runtime.
+        javaUdfLoader.loadAggregateFunction(createFunctionStmt.getNamePath(), jarPath);
+        Combine.CombineFn<?, ?, ?> combineFn =
+            new LazyAggregateCombineFn<>(createFunctionStmt.getNamePath(), jarPath);
+        javaUdafs.put(createFunctionStmt.getNamePath(), combineFn);
+        break;
       default:
         throw new IllegalArgumentException(
             String.format("Encountered unrecognized function group %s.", functionGroup));
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
index 22b18f9..5cd0826 100644
--- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
@@ -400,6 +400,44 @@ public class ZetaSqlJavaUdfTest extends ZetaSqlTestBase {
   }
 
   @Test
+  public void testUdaf() {
+    String sql =
+        String.format(
+            "CREATE AGGREGATE FUNCTION my_sum(f INT64) RETURNS INT64 LANGUAGE java OPTIONS (path='%s'); "
+                + "SELECT my_sum(f_int_1) from aggregate_test_table",
+            jarPath);
+    BeamRelNode beamRelNode = new ZetaSQLQueryPlanner(config).convertToBeamRel(sql);
+    PCollection<Row> result = BeamSqlRelUtils.toPCollection(pipeline, beamRelNode);
+
+    Schema outputSchema = Schema.builder().addInt64Field("field1").build();
+    Row expected = Row.withSchema(outputSchema).addValues(28L).build();
+
+    PAssert.that(result).containsInAnyOrder(expected);
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
+  public void testUdafNotFoundFailsToParse() {
+    String sql =
+        String.format(
+            "CREATE AGGREGATE FUNCTION nonexistent(f INT64) RETURNS INT64 LANGUAGE java OPTIONS (path='%s'); "
+                + "SELECT nonexistent(f_int_1) from aggregate_test_table",
+            jarPath);
+    ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
+
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("Failed to define function 'nonexistent'");
+    thrown.expectCause(
+        allOf(
+            isA(IllegalArgumentException.class),
+            hasProperty(
+                "message",
+                containsString("No implementation of aggregate function nonexistent found"))));
+
+    zetaSQLQueryPlanner.convertToBeamRel(sql); // Expected to throw; the result is never used.
+  }
+
+  @Test
   public void testRegisterUdaf() {
     String sql = "SELECT my_sum(k) FROM UNNEST([1, 2, 3]) k;";
     PCollection<Row> stream =