You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@beam.apache.org by ib...@apache.org on 2021/04/30 16:42:22 UTC

[beam] branch master updated: [BEAM-12257] Infer accumulator coder from underlying AggregateFn in LazyAggregateCombineFn.

This is an automated email from the ASF dual-hosted git repository.

ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 24d6cda  [BEAM-12257] Infer accumulator coder from underlying AggregateFn in LazyAggregateCombineFn.
     new 7e1aa0d  Merge pull request #14693 from ibzib/accum-coder
24d6cda is described below

commit 24d6cda57476bd9c878dbad40daba5453cb5e1e5
Author: Kyle Weaver <kc...@google.com>
AuthorDate: Thu Apr 29 17:43:31 2021 -0700

    [BEAM-12257] Infer accumulator coder from underlying AggregateFn in LazyAggregateCombineFn.
---
 .../sql/impl/LazyAggregateCombineFn.java           | 31 ++++++++++
 .../sql/impl/LazyAggregateCombineFnTest.java       | 69 ++++++++++++++++++++++
 2 files changed, 100 insertions(+)

diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
index 90abc0c..3b782d9 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
@@ -18,10 +18,18 @@
 package org.apache.beam.sdk.extensions.sql.impl;
 
 import edu.umd.cs.findbugs.annotations.Nullable;
+import java.lang.reflect.Type;
+import java.lang.reflect.TypeVariable;
 import java.util.Iterator;
 import java.util.List;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
 import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 
 /**
  * {@link org.apache.beam.sdk.transforms.Combine.CombineFn} that wraps an {@link AggregateFn}. The
@@ -38,6 +46,13 @@ public class LazyAggregateCombineFn<InputT, AccumT, OutputT>
     this.jarPath = jarPath;
   }
 
+  @VisibleForTesting
+  LazyAggregateCombineFn(AggregateFn aggregateFn) {
+    this.functionPath = ImmutableList.of();
+    this.jarPath = "";
+    this.aggregateFn = aggregateFn;
+  }
+
   private AggregateFn<InputT, AccumT, OutputT> getAggregateFn() {
     if (aggregateFn == null) {
       JavaUdfLoader loader = new JavaUdfLoader();
@@ -68,4 +83,20 @@ public class LazyAggregateCombineFn<InputT, AccumT, OutputT>
   public OutputT extractOutput(AccumT accumulator) {
     return getAggregateFn().extractOutput(accumulator);
   }
+
+  @Override
+  public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder)
+      throws CannotProvideCoderException {
+    // Infer coder based on underlying AggregateFn instance.
+    return registry.getCoder(
+        getAggregateFn().getClass(),
+        AggregateFn.class,
+        ImmutableMap.<Type, Coder<?>>of(getInputTVariable(), inputCoder),
+        getAccumTVariable());
+  }
+
+  @Override
+  public TypeVariable<?> getAccumTVariable() {
+    return AggregateFn.class.getTypeParameters()[1];
+  }
 }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
new file mode 100644
index 0000000..21ab8d0
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.impl;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.instanceOf;
+
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link LazyAggregateCombineFn}. */
+@RunWith(JUnit4.class)
+public class LazyAggregateCombineFnTest {
+
+  @Test
+  public void getAccumulatorCoderInfersCoderForWildcardTypeParameter()
+      throws CannotProvideCoderException {
+    LazyAggregateCombineFn<Long, ?, ?> combiner = new LazyAggregateCombineFn<>(new Sum());
+    Coder<?> coder = combiner.getAccumulatorCoder(CoderRegistry.createDefault(), VarLongCoder.of());
+    assertThat(coder, instanceOf(VarLongCoder.class));
+  }
+
+  public static class Sum implements AggregateFn<Long, Long, Long> {
+
+    @Override
+    public Long createAccumulator() {
+      return 0L;
+    }
+
+    @Override
+    public Long addInput(Long mutableAccumulator, Long input) {
+      return mutableAccumulator + input;
+    }
+
+    @Override
+    public Long mergeAccumulators(Long mutableAccumulator, Iterable<Long> immutableAccumulators) {
+      for (Long x : immutableAccumulators) {
+        mutableAccumulator += x;
+      }
+      return mutableAccumulator;
+    }
+
+    @Override
+    public Long extractOutput(Long mutableAccumulator) {
+      return mutableAccumulator;
+    }
+  }
+}