Posted to commits@beam.apache.org by am...@apache.org on 2020/05/04 18:32:42 UTC

[beam] branch master updated: [BEAM-9418] Support ANY_VALUE aggregation functions

This is an automated email from the ASF dual-hosted git repository.

amaliujia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 27876f0  [BEAM-9418] Support ANY_VALUE aggregation functions
     new 6453e85  Merge pull request #11333 from jhnmora000/master
27876f0 is described below

commit 27876f035e3afdd478ef5e7f21d8524d17b59e13
Author: John Mora <jh...@gmail.com>
AuthorDate: Tue Apr 28 22:11:51 2020 -0500

    [BEAM-9418] Support ANY_VALUE aggregation functions
    
    The implementation is based on the function Sample#anyCombineFn(int)
    of the Java SDK core. Support for ANY_VALUE in the ZetaSQL dialect was
    also enabled.
---
 .../org/apache/beam/sdk/transforms/Sample.java     | 38 +++++++++++++++++
 .../impl/transform/BeamBuiltinAggregations.java    |  2 +
 .../extensions/sql/BeamSqlDslAggregationTest.java  | 49 ++++++++++++++++++++++
 .../sql/zetasql/SqlStdOperatorMappingTable.java    |  3 +-
 .../sql/zetasql/ZetaSQLDialectSpecTest.java        | 33 +++++++++++++++
 5 files changed, 124 insertions(+), 1 deletion(-)
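
With this change, ANY_VALUE behaves like the other built-in aggregations in Beam
SQL. A minimal usage sketch, mirroring the test added to BeamSqlDslAggregationTest
below (the inputRows collection, its field names, and the step name are
illustrative, not part of this commit):

    import org.apache.beam.sdk.extensions.sql.SqlTransform;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.Row;

    // inputRows (illustrative): a PCollection<Row> whose schema has an int32 "key"
    // field and an int32 "col" field.
    PCollection<Row> result =
        inputRows.apply(
            "AnyValuePerKey",
            SqlTransform.query(
                "SELECT key, ANY_VALUE(col) AS any_value FROM PCOLLECTION GROUP BY key"));
    // Each output row carries one arbitrary (non-deterministic) value of "col" per "key".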

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
index 2594dee..4b93596 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
@@ -59,6 +59,14 @@ public class Sample {
   }
 
   /**
+   * Returns a {@link CombineFn} that computes a single and potentially non-uniform sample value of
+   * its inputs.
+   */
+  public static <T> CombineFn<T, ?, T> anyValueCombineFn() {
+    return new AnyValueCombineFn<>();
+  }
+
+  /**
    * {@code Sample#any(long)} takes a {@code PCollection<T>} and a limit, and produces a new {@code
    * PCollection<T>} containing up to limit elements of the input {@code PCollection}.
    *
@@ -246,6 +254,36 @@ public class Sample {
     }
   }
 
+  /** A {@link CombineFn} that combines its inputs into a single, arbitrarily chosen element. */
+  private static class AnyValueCombineFn<T> extends CombineFn<T, List<T>, T> {
+    private SampleAnyCombineFn<T> internal;
+
+    private AnyValueCombineFn() {
+      internal = new SampleAnyCombineFn<>(1);
+    }
+
+    @Override
+    public List<T> createAccumulator() {
+      return internal.createAccumulator();
+    }
+
+    @Override
+    public List<T> addInput(List<T> accumulator, T input) {
+      return internal.addInput(accumulator, input);
+    }
+
+    @Override
+    public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
+      return internal.mergeAccumulators(accumulators);
+    }
+
+    @Override
+    public T extractOutput(List<T> accumulator) {
+      Iterator<T> it = internal.extractOutput(accumulator).iterator();
+      return it.hasNext() ? it.next() : null;
+    }
+  }
+
   /**
    * {@code CombineFn} that computes a fixed-size sample of a collection of values.
    *
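
The combiner added above can also be used directly from the Java SDK, for example
with Combine.perKey. A minimal sketch, assuming a keyed input (the collection and
step names below are illustrative):

    import org.apache.beam.sdk.transforms.Combine;
    import org.apache.beam.sdk.transforms.Sample;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.PCollection;

    // scoresPerUser (illustrative): a PCollection<KV<String, Integer>>.
    // Sample.anyValueCombineFn() keeps exactly one element per key; which element
    // is kept is not deterministic across runs.
    PCollection<KV<String, Integer>> anyScorePerUser =
        scoresPerUser.apply(
            "AnyScorePerUser", Combine.perKey(Sample.<Integer>anyValueCombineFn()));
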
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
index ad99c28..106e609 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
@@ -38,6 +38,7 @@ import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Max;
 import org.apache.beam.sdk.transforms.Min;
+import org.apache.beam.sdk.transforms.Sample;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
@@ -48,6 +49,7 @@ public class BeamBuiltinAggregations {
   public static final Map<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>
       BUILTIN_AGGREGATOR_FACTORIES =
           ImmutableMap.<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>builder()
+              .put("ANY_VALUE", typeName -> Sample.anyValueCombineFn())
               .put("COUNT", typeName -> Count.combineFn())
               .put("MAX", BeamBuiltinAggregations::createMax)
               .put("MIN", BeamBuiltinAggregations::createMin)
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
index d350062..80964f5 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
@@ -27,8 +27,10 @@ import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
 
 import java.math.BigDecimal;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import org.apache.beam.sdk.extensions.sql.impl.ParseException;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.testing.PAssert;
@@ -241,6 +243,53 @@ public class BeamSqlDslAggregationTest extends BeamSqlDslBase {
     pipeline.run().waitUntilFinish();
   }
 
+  /** GROUP-BY with the any_value aggregation function. */
+  @Test
+  public void testAnyValueFunction() throws Exception {
+    pipeline.enableAbandonedNodeEnforcement(false);
+
+    Schema schema = Schema.builder().addInt32Field("key").addInt32Field("col").build();
+
+    PCollection<Row> inputRows =
+        pipeline
+            .apply(
+                Create.of(
+                    TestUtils.rowsBuilderOf(schema)
+                        .addRows(
+                            0, 1,
+                            0, 2,
+                            1, 3,
+                            2, 4,
+                            2, 5)
+                        .getRows()))
+            .setRowSchema(schema);
+
+    String sql = "SELECT key, any_value(col) as any_value FROM PCOLLECTION GROUP BY key";
+
+    PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+    Map<Integer, List<Integer>> allowedTuples = new HashMap<>();
+    allowedTuples.put(0, Arrays.asList(1, 2));
+    allowedTuples.put(1, Arrays.asList(3));
+    allowedTuples.put(2, Arrays.asList(4, 5));
+
+    PAssert.that(result)
+        .satisfies(
+            input -> {
+              Iterator<Row> iter = input.iterator();
+              while (iter.hasNext()) {
+                Row row = iter.next();
+                List<Integer> values = allowedTuples.remove(row.getInt32("key"));
+                assertTrue(values != null);
+                assertTrue(values.contains(row.getInt32("any_value")));
+              }
+              assertTrue(allowedTuples.isEmpty());
+              return null;
+            });
+
+    pipeline.run();
+  }
+
   private static class CheckerBigDecimalDivide
       implements SerializableFunction<Iterable<Row>, Void> {
 
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlStdOperatorMappingTable.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlStdOperatorMappingTable.java
index 0f31fa6..22b2de9 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlStdOperatorMappingTable.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlStdOperatorMappingTable.java
@@ -32,6 +32,7 @@ public class SqlStdOperatorMappingTable {
   static final List<FunctionSignatureId> ZETASQL_BUILTIN_FUNCTION_WHITELIST =
       ImmutableList.of(
           FunctionSignatureId.FN_AND,
+          FunctionSignatureId.FN_ANY_VALUE,
           FunctionSignatureId.FN_OR,
           FunctionSignatureId.FN_NOT,
           FunctionSignatureId.FN_MULTIPLY_DOUBLE,
@@ -204,7 +205,7 @@ public class SqlStdOperatorMappingTable {
           .put("min", SqlStdOperatorTable.MIN)
           .put("avg", SqlStdOperatorTable.AVG)
           .put("sum", SqlStdOperatorTable.SUM)
-          // .put("any_value", SqlStdOperatorTable.ANY_VALUE)
+          .put("any_value", SqlStdOperatorTable.ANY_VALUE)
           .put("count", SqlStdOperatorTable.COUNT)
 
           // aggregate UDF
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLDialectSpecTest.java
index 26e34ca..960a89e 100644
--- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLDialectSpecTest.java
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLDialectSpecTest.java
@@ -44,6 +44,7 @@ import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.TIMESTAMP_TAB
 import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.TIMESTAMP_TABLE_TWO;
 import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.TIME_TABLE;
 import static org.apache.beam.sdk.schemas.Schema.FieldType.DATETIME;
+import static org.junit.Assert.assertTrue;
 
 import com.google.protobuf.ByteString;
 import com.google.zetasql.SqlException;
@@ -56,6 +57,8 @@ import com.google.zetasql.ZetaSQLValue.ValueProto;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
 import org.apache.beam.sdk.extensions.sql.impl.JdbcConnection;
 import org.apache.beam.sdk.extensions.sql.impl.JdbcDriver;
@@ -1497,6 +1500,36 @@ public class ZetaSQLDialectSpecTest {
   }
 
   @Test
+  public void testZetaSQLAnyValueInGroupBy() {
+    String sql =
+        "SELECT rowCol.row_id as key, ANY_VALUE(rowCol.data) as any_value FROM table_with_struct_two GROUP BY rowCol.row_id";
+
+    ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
+    BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
+    PCollection<Row> stream = BeamSqlRelUtils.toPCollection(pipeline, beamRelNode);
+    Map<Long, List<String>> allowedTuples = new HashMap<>();
+    allowedTuples.put(1L, Arrays.asList("data1"));
+    allowedTuples.put(2L, Arrays.asList("data2"));
+    allowedTuples.put(3L, Arrays.asList("data2", "data3"));
+
+    PAssert.that(stream)
+        .satisfies(
+            input -> {
+              Iterator<Row> iter = input.iterator();
+              while (iter.hasNext()) {
+                Row row = iter.next();
+                List<String> values = allowedTuples.remove(row.getInt64("key"));
+                assertTrue(values != null);
+                assertTrue(values.contains(row.getString("any_value")));
+              }
+              assertTrue(allowedTuples.isEmpty());
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+  }
+
+  @Test
   public void testZetaSQLStructFieldAccessInGroupBy2() {
     String sql =
         "SELECT rowCol.data, MAX(rowCol.row_id), MIN(rowCol.row_id) FROM table_with_struct_two"