You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by am...@apache.org on 2020/05/04 18:32:42 UTC
[beam] branch master updated: [BEAM-9418] Support ANY_VALUE aggregation functions
This is an automated email from the ASF dual-hosted git repository.
amaliujia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 27876f0 [BEAM-9418] Support ANY_VALUE aggregation functions
new 6453e85 Merge pull request #11333 from jhnmora000/master
27876f0 is described below
commit 27876f035e3afdd478ef5e7f21d8524d17b59e13
Author: John Mora <jh...@gmail.com>
AuthorDate: Tue Apr 28 22:11:51 2020 -0500
[BEAM-9418] Support ANY_VALUE aggregation functions
The implementation is based on the function Sample#anyCombineFn(int)
of the Java SDK core. Also, support for ZetaSQL was enabled.
---
.../org/apache/beam/sdk/transforms/Sample.java | 38 +++++++++++++++++
.../impl/transform/BeamBuiltinAggregations.java | 2 +
.../extensions/sql/BeamSqlDslAggregationTest.java | 49 ++++++++++++++++++++++
.../sql/zetasql/SqlStdOperatorMappingTable.java | 3 +-
.../sql/zetasql/ZetaSQLDialectSpecTest.java | 33 +++++++++++++++
5 files changed, 124 insertions(+), 1 deletion(-)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
index 2594dee..4b93596 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
@@ -59,6 +59,14 @@ public class Sample {
}
/**
+ * Returns a {@link CombineFn} that computes a single and potentially non-uniform sample value of
+ * its inputs.
+ */
+ public static <T> CombineFn<T, ?, T> anyValueCombineFn() {
+ return new AnyValueCombineFn<>();
+ }
+
+ /**
* {@code Sample#any(long)} takes a {@code PCollection<T>} and a limit, and produces a new {@code
* PCollection<T>} containing up to limit elements of the input {@code PCollection}.
*
@@ -246,6 +254,36 @@ public class Sample {
}
}
+ /** A {@link CombineFn} that combines into a single element. */
+ private static class AnyValueCombineFn<T> extends CombineFn<T, List<T>, T> {
+ private SampleAnyCombineFn internal;
+
+ private AnyValueCombineFn() {
+ internal = new SampleAnyCombineFn<>(1);
+ }
+
+ @Override
+ public List<T> createAccumulator() {
+ return internal.createAccumulator();
+ }
+
+ @Override
+ public List<T> addInput(List<T> accumulator, T input) {
+ return internal.addInput(accumulator, input);
+ }
+
+ @Override
+ public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
+ return internal.mergeAccumulators(accumulators);
+ }
+
+ @Override
+ public T extractOutput(List<T> accumulator) {
+ Iterator<T> it = internal.extractOutput(accumulator).iterator();
+ return it.hasNext() ? it.next() : null;
+ }
+ }
+
/**
* {@code CombineFn} that computes a fixed-size sample of a collection of values.
*
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
index ad99c28..106e609 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
@@ -38,6 +38,7 @@ import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Max;
import org.apache.beam.sdk.transforms.Min;
+import org.apache.beam.sdk.transforms.Sample;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
@@ -48,6 +49,7 @@ public class BeamBuiltinAggregations {
public static final Map<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>
BUILTIN_AGGREGATOR_FACTORIES =
ImmutableMap.<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>builder()
+ .put("ANY_VALUE", typeName -> Sample.anyValueCombineFn())
.put("COUNT", typeName -> Count.combineFn())
.put("MAX", BeamBuiltinAggregations::createMax)
.put("MIN", BeamBuiltinAggregations::createMin)
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
index d350062..80964f5 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
@@ -27,8 +27,10 @@ import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
import java.math.BigDecimal;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
+import java.util.Map;
import org.apache.beam.sdk.extensions.sql.impl.ParseException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
@@ -241,6 +243,53 @@ public class BeamSqlDslAggregationTest extends BeamSqlDslBase {
pipeline.run().waitUntilFinish();
}
+ /** GROUP-BY with the any_value aggregation function. */
+ @Test
+ public void testAnyValueFunction() throws Exception {
+ pipeline.enableAbandonedNodeEnforcement(false);
+
+ Schema schema = Schema.builder().addInt32Field("key").addInt32Field("col").build();
+
+ PCollection<Row> inputRows =
+ pipeline
+ .apply(
+ Create.of(
+ TestUtils.rowsBuilderOf(schema)
+ .addRows(
+ 0, 1,
+ 0, 2,
+ 1, 3,
+ 2, 4,
+ 2, 5)
+ .getRows()))
+ .setRowSchema(schema);
+
+ String sql = "SELECT key, any_value(col) as any_value FROM PCOLLECTION GROUP BY key";
+
+ PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+ Map<Integer, List<Integer>> allowedTuples = new HashMap<>();
+ allowedTuples.put(0, Arrays.asList(1, 2));
+ allowedTuples.put(1, Arrays.asList(3));
+ allowedTuples.put(2, Arrays.asList(4, 5));
+
+ PAssert.that(result)
+ .satisfies(
+ input -> {
+ Iterator<Row> iter = input.iterator();
+ while (iter.hasNext()) {
+ Row row = iter.next();
+ List<Integer> values = allowedTuples.remove(row.getInt32("key"));
+ assertTrue(values != null);
+ assertTrue(values.contains(row.getInt32("any_value")));
+ }
+ assertTrue(allowedTuples.isEmpty());
+ return null;
+ });
+
+ pipeline.run();
+ }
+
private static class CheckerBigDecimalDivide
implements SerializableFunction<Iterable<Row>, Void> {
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlStdOperatorMappingTable.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlStdOperatorMappingTable.java
index 0f31fa6..22b2de9 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlStdOperatorMappingTable.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlStdOperatorMappingTable.java
@@ -32,6 +32,7 @@ public class SqlStdOperatorMappingTable {
static final List<FunctionSignatureId> ZETASQL_BUILTIN_FUNCTION_WHITELIST =
ImmutableList.of(
FunctionSignatureId.FN_AND,
+ FunctionSignatureId.FN_ANY_VALUE,
FunctionSignatureId.FN_OR,
FunctionSignatureId.FN_NOT,
FunctionSignatureId.FN_MULTIPLY_DOUBLE,
@@ -204,7 +205,7 @@ public class SqlStdOperatorMappingTable {
.put("min", SqlStdOperatorTable.MIN)
.put("avg", SqlStdOperatorTable.AVG)
.put("sum", SqlStdOperatorTable.SUM)
- // .put("any_value", SqlStdOperatorTable.ANY_VALUE)
+ .put("any_value", SqlStdOperatorTable.ANY_VALUE)
.put("count", SqlStdOperatorTable.COUNT)
// aggregate UDF
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLDialectSpecTest.java
index 26e34ca..960a89e 100644
--- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLDialectSpecTest.java
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLDialectSpecTest.java
@@ -44,6 +44,7 @@ import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.TIMESTAMP_TAB
import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.TIMESTAMP_TABLE_TWO;
import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.TIME_TABLE;
import static org.apache.beam.sdk.schemas.Schema.FieldType.DATETIME;
+import static org.junit.Assert.assertTrue;
import com.google.protobuf.ByteString;
import com.google.zetasql.SqlException;
@@ -56,6 +57,8 @@ import com.google.zetasql.ZetaSQLValue.ValueProto;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
import java.util.Map;
import org.apache.beam.sdk.extensions.sql.impl.JdbcConnection;
import org.apache.beam.sdk.extensions.sql.impl.JdbcDriver;
@@ -1497,6 +1500,36 @@ public class ZetaSQLDialectSpecTest {
}
@Test
+ public void testZetaSQLAnyValueInGroupBy() {
+ String sql =
+ "SELECT rowCol.row_id as key, ANY_VALUE(rowCol.data) as any_value FROM table_with_struct_two GROUP BY rowCol.row_id";
+
+ ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
+ BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
+ PCollection<Row> stream = BeamSqlRelUtils.toPCollection(pipeline, beamRelNode);
+ Map<Long, List<String>> allowedTuples = new HashMap<>();
+ allowedTuples.put(1L, Arrays.asList("data1"));
+ allowedTuples.put(2L, Arrays.asList("data2"));
+ allowedTuples.put(3L, Arrays.asList("data2", "data3"));
+
+ PAssert.that(stream)
+ .satisfies(
+ input -> {
+ Iterator<Row> iter = input.iterator();
+ while (iter.hasNext()) {
+ Row row = iter.next();
+ List<String> values = allowedTuples.remove(row.getInt64("key"));
+ assertTrue(values != null);
+ assertTrue(values.contains(row.getString("any_value")));
+ }
+ assertTrue(allowedTuples.isEmpty());
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
+ }
+
+ @Test
public void testZetaSQLStructFieldAccessInGroupBy2() {
String sql =
"SELECT rowCol.data, MAX(rowCol.row_id), MIN(rowCol.row_id) FROM table_with_struct_two"