You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by am...@apache.org on 2020/07/26 19:50:32 UTC
[beam] branch master updated: Implement Numbering functions (#12375)
This is an automated email from the ASF dual-hosted git repository.
amaliujia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8d460db Implement Numbering functions (#12375)
8d460db is described below
commit 8d460db620d2ff1257b0e092218294df15b409a1
Author: jhnmora000 <48...@users.noreply.github.com>
AuthorDate: Sun Jul 26 14:50:00 2020 -0500
Implement Numbering functions (#12375)
RANK, DENSE_RANK, PERCENT_RANK, ROW_NUMBER
---
.../sdk/extensions/sql/impl/rel/BeamWindowRel.java | 23 ++-
.../transform/BeamBuiltinAnalyticFunctions.java | 183 ++++++++++++++++++++-
.../transform/agg/AggregationCombineFnAdapter.java | 8 +-
.../extensions/sql/BeamAnalyticFunctionsTest.java | 115 +++++++++++++
4 files changed, 318 insertions(+), 11 deletions(-)
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
index 09ca5b1..5d5da8d 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
@@ -28,6 +28,7 @@ import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
+import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAnalyticFunctions;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
@@ -291,10 +292,26 @@ public class BeamWindowRel extends Window implements BeamRelNode {
aggRange = getRange(indexRange, sortedRowsAsList.get(idx));
}
Object accumulator = fieldAgg.combineFn.createAccumulator();
- final int aggFieldIndex = fieldAgg.inputFields.get(0);
+ // if no inputs are needed, put a mock field index
+ final int aggFieldIndex =
+ fieldAgg.inputFields.isEmpty() ? -1 : fieldAgg.inputFields.get(0);
+ long count = 0;
for (Row aggRow : aggRange) {
- accumulator =
- fieldAgg.combineFn.addInput(accumulator, aggRow.getBaseValue(aggFieldIndex));
+ if (fieldAgg.combineFn instanceof BeamBuiltinAnalyticFunctions.PositionAwareCombineFn) {
+ BeamBuiltinAnalyticFunctions.PositionAwareCombineFn fn =
+ (BeamBuiltinAnalyticFunctions.PositionAwareCombineFn) fieldAgg.combineFn;
+ accumulator =
+ fn.addInput(
+ accumulator,
+ getOrderByValue(aggRow),
+ count,
+ (long) idx,
+ (long) sortedRowsAsList.size());
+ } else {
+ accumulator =
+ fieldAgg.combineFn.addInput(accumulator, aggRow.getBaseValue(aggFieldIndex));
+ }
+ count++;
}
Object result = fieldAgg.combineFn.extractOutput(accumulator);
Row processingRow = sortedRowsAsList.get(idx);
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAnalyticFunctions.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAnalyticFunctions.java
index 14fe20d..457e0db 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAnalyticFunctions.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAnalyticFunctions.java
@@ -17,11 +17,13 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.transform;
+import java.math.BigDecimal;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
/** Built-in Analytic Functions for the aggregation analytics functionality. */
@@ -29,10 +31,16 @@ public class BeamBuiltinAnalyticFunctions {
public static final Map<String, Function<Schema.FieldType, Combine.CombineFn<?, ?, ?>>>
BUILTIN_ANALYTIC_FACTORIES =
ImmutableMap.<String, Function<Schema.FieldType, Combine.CombineFn<?, ?, ?>>>builder()
+ // Aggregate Analytic Functions
.putAll(BeamBuiltinAggregations.BUILTIN_AGGREGATOR_FACTORIES)
+ // Navigation Functions
.put("FIRST_VALUE", typeName -> navigationFirstValue())
.put("LAST_VALUE", typeName -> navigationLastValue())
- // Pending Numbering functions
+ // Numbering Functions
+ .put("ROW_NUMBER", typeName -> numberingRowNumber())
+ .put("DENSE_RANK", typeName -> numberingDenseRank())
+ .put("RANK", typeName -> numberingRank())
+ .put("PERCENT_RANK", typeName -> numberingPercentRank())
.build();
public static Combine.CombineFn<?, ?, ?> create(String functionName, Schema.FieldType fieldType) {
@@ -45,6 +53,7 @@ public class BeamBuiltinAnalyticFunctions {
String.format("Analytics Function [%s] is not supported", functionName));
}
+ // Navigation functions
public static <T> Combine.CombineFn<T, ?, T> navigationFirstValue() {
return new FirstValueCombineFn();
}
@@ -105,4 +114,176 @@ public class BeamBuiltinAnalyticFunctions {
return accumulator.isPresent() ? accumulator.get() : null;
}
}
+
+ // Numbering functions
+ public static <T> Combine.CombineFn<T, ?, T> numberingRowNumber() {
+ return new RowNumberCombineFn();
+ }
+
+ public static <T> Combine.CombineFn<T, ?, T> numberingDenseRank() {
+ return new DenseRankCombineFn();
+ }
+
+ public static <T> Combine.CombineFn<T, ?, T> numberingRank() {
+ return new RankCombineFn();
+ }
+
+ public static <T> Combine.CombineFn<T, ?, T> numberingPercentRank() {
+ return new PercentRankCombineFn();
+ }
+
+ public abstract static class PositionAwareCombineFn<InputT, AccumT, OutputT>
+ extends Combine.CombineFn<InputT, AccumT, OutputT> {
+ public abstract AccumT addInput(
+ AccumT accumulator,
+ InputT input,
+ Long cursorWindow,
+ Long cursorPartition,
+ Long countPartition);
+
+ @Override
+ public AccumT addInput(AccumT mutableAccumulator, InputT input) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ private static class RowNumberCombineFn<T>
+ extends PositionAwareCombineFn<BigDecimal, Optional<Long>, Long> {
+
+ @Override
+ public Optional<Long> addInput(
+ Optional<Long> accumulator,
+ BigDecimal input,
+ Long cursorPosition,
+ Long cursorPartition,
+ Long countPartition) {
+ return Optional.of(cursorPartition);
+ }
+
+ @Override
+ public Optional<Long> createAccumulator() {
+ return Optional.empty();
+ }
+
+ @Override
+ public Long extractOutput(Optional<Long> accumulator) {
+ // 1-based result
+ return accumulator.isPresent() ? accumulator.get() + 1 : null;
+ }
+ }
+
+ private static class DenseRankCombineFn<T>
+ extends PositionAwareCombineFn<BigDecimal, KV<BigDecimal, Long>, Long> {
+
+ @Override
+ public KV<BigDecimal, Long> addInput(
+ KV<BigDecimal, Long> accumulator,
+ BigDecimal input,
+ Long cursorPosition,
+ Long cursorPartition,
+ Long countPartition) {
+ KV<BigDecimal, Long> r = null;
+ if (accumulator == null) {
+ r = KV.of(input, 0L);
+ } else {
+ if (accumulator.getKey().compareTo(input) == 0) {
+ r = KV.of(input, accumulator.getValue());
+ } else {
+ r = KV.of(input, accumulator.getValue() + 1);
+ }
+ }
+ return r;
+ }
+
+ @Override
+ public KV<BigDecimal, Long> createAccumulator() {
+ return null;
+ }
+
+ @Override
+ public Long extractOutput(KV<BigDecimal, Long> accumulator) {
+ // 1-based result
+ return accumulator != null ? accumulator.getValue() + 1 : null;
+ }
+ }
+
+ private static class RankCombineFn<T>
+ extends PositionAwareCombineFn<BigDecimal, KV<BigDecimal, Long>, Long> {
+
+ @Override
+ public KV<BigDecimal, Long> addInput(
+ KV<BigDecimal, Long> accumulator,
+ BigDecimal input,
+ Long cursorPosition,
+ Long cursorPartition,
+ Long countPartition) {
+ KV<BigDecimal, Long> r = null;
+ if (accumulator == null) {
+ r = KV.of(input, 0L);
+ } else {
+ if (accumulator.getKey().compareTo(input) == 0) {
+ r = KV.of(input, accumulator.getValue());
+ } else {
+ r = KV.of(input, cursorPosition);
+ }
+ }
+ return r;
+ }
+
+ @Override
+ public KV<BigDecimal, Long> createAccumulator() {
+ return null;
+ }
+
+ @Override
+ public Long extractOutput(KV<BigDecimal, Long> accumulator) {
+ // 1-based result
+ return accumulator != null ? accumulator.getValue() + 1 : null;
+ }
+ }
+
+ private static class PercentRankCombineFn<T>
+ extends PositionAwareCombineFn<BigDecimal, KV<Optional<Long>, KV<BigDecimal, Long>>, Double> {
+
+ RankCombineFn internalRank;
+
+ PercentRankCombineFn() {
+ internalRank = new RankCombineFn();
+ }
+
+ @Override
+ public KV<Optional<Long>, KV<BigDecimal, Long>> addInput(
+ KV<Optional<Long>, KV<BigDecimal, Long>> accumulator,
+ BigDecimal input,
+ Long cursorPosition,
+ Long cursorPartition,
+ Long countPartition) {
+ KV<BigDecimal, Long> ac1 =
+ internalRank.addInput(
+ accumulator.getValue(), input, cursorPosition, cursorPartition, countPartition);
+ Optional<Long> ac2 = Optional.of(countPartition);
+ return KV.of(ac2, ac1);
+ }
+
+ @Override
+ public KV<Optional<Long>, KV<BigDecimal, Long>> createAccumulator() {
+ return KV.of(Optional.empty(), internalRank.createAccumulator());
+ }
+
+ @Override
+ public Double extractOutput(KV<Optional<Long>, KV<BigDecimal, Long>> accumulator) {
+ Long nr = accumulator.getKey().orElse(null);
+ Long rk = internalRank.extractOutput(accumulator.getValue());
+ Double r = 0.0;
+ if (nr != null && rk != null && nr > 1L) {
+ r = (rk.doubleValue() - 1) / (nr.doubleValue() - 1);
+ }
+ return r;
+ }
+ }
}
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
index 3b5b267..b2110af 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
@@ -174,13 +174,7 @@ public class AggregationCombineFnAdapter<T> {
} else {
combineFn = BeamBuiltinAnalyticFunctions.create(functionName, field.getType());
}
- if (call.getArgList().isEmpty()) {
- return new SingleInputCombiner(combineFn);
- } else if (call.getArgList().size() == 1) {
- return new SingleInputCombiner(combineFn);
- } else {
- return new MultiInputCombiner(combineFn);
- }
+ return combineFn;
}
public static CombineFn<Row, ?, Row> createConstantCombineFn() {
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamAnalyticFunctionsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamAnalyticFunctionsTest.java
index 2e06a0b..d66fdb5 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamAnalyticFunctionsTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamAnalyticFunctionsTest.java
@@ -67,6 +67,17 @@ public class BeamAnalyticFunctionsTest extends BeamSqlDslBase {
}
/**
+ * Table schema and data taken from.
+ * https://cloud.google.com/bigquery/docs/reference/standard-sql/analytic-function-concepts#numbering_function_concepts
+ */
+ private PCollection<Row> inputData2() {
+ Schema schema = Schema.builder().addInt32Field("x").build();
+ return pipeline
+ .apply(Create.of(TestUtils.rowsBuilderOf(schema).addRows(1, 2, 2, 5, 8, 10, 10).getRows()))
+ .setRowSchema(schema);
+ }
+
+ /**
* Compute a cumulative sum query taken from.
* https://cloud.google.com/bigquery/docs/reference/standard-sql/analytic-function-concepts#compute_a_cumulative_sum
*/
@@ -473,4 +484,108 @@ public class BeamAnalyticFunctionsTest extends BeamSqlDslBase {
pipeline.run();
}
+
+ @Test
+ public void testRowNumberFunction() throws Exception {
+ pipeline.enableAbandonedNodeEnforcement(false);
+ PCollection<Row> inputRows = inputData2();
+ String sql = "SELECT x, ROW_NUMBER() over (ORDER BY x ) as agg FROM PCOLLECTION";
+ PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+ Schema overResultSchema = Schema.builder().addInt32Field("x").addInt64Field("agg").build();
+
+ List<Row> overResult =
+ TestUtils.RowsBuilder.of(overResultSchema)
+ .addRows(
+ 1, 1L,
+ 2, 2L,
+ 2, 3L,
+ 5, 4L,
+ 8, 5L,
+ 10, 6L,
+ 10, 7L)
+ .getRows();
+
+ PAssert.that(result).containsInAnyOrder(overResult);
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testDenseRankFunction() throws Exception {
+ pipeline.enableAbandonedNodeEnforcement(false);
+ PCollection<Row> inputRows = inputData2();
+ String sql = "SELECT x, DENSE_RANK() over (ORDER BY x ) as agg FROM PCOLLECTION";
+ PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+ Schema overResultSchema = Schema.builder().addInt32Field("x").addInt64Field("agg").build();
+
+ List<Row> overResult =
+ TestUtils.RowsBuilder.of(overResultSchema)
+ .addRows(
+ 1, 1L,
+ 2, 2L,
+ 2, 2L,
+ 5, 3L,
+ 8, 4L,
+ 10, 5L,
+ 10, 5L)
+ .getRows();
+
+ PAssert.that(result).containsInAnyOrder(overResult);
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testRankFunction() throws Exception {
+ pipeline.enableAbandonedNodeEnforcement(false);
+ PCollection<Row> inputRows = inputData2();
+ String sql = "SELECT x, RANK() over (ORDER BY x ) as agg FROM PCOLLECTION";
+ PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+ Schema overResultSchema = Schema.builder().addInt32Field("x").addInt64Field("agg").build();
+
+ List<Row> overResult =
+ TestUtils.RowsBuilder.of(overResultSchema)
+ .addRows(
+ 1, 1L,
+ 2, 2L,
+ 2, 2L,
+ 5, 4L,
+ 8, 5L,
+ 10, 6L,
+ 10, 6L)
+ .getRows();
+
+ PAssert.that(result).containsInAnyOrder(overResult);
+
+ pipeline.run();
+ }
+
+ @Test
+ public void testPercentRankFunction() throws Exception {
+ pipeline.enableAbandonedNodeEnforcement(false);
+ PCollection<Row> inputRows = inputData2();
+ String sql = "SELECT x, PERCENT_RANK() over (ORDER BY x ) as agg FROM PCOLLECTION";
+ PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+ Schema overResultSchema = Schema.builder().addInt32Field("x").addDoubleField("agg").build();
+
+ List<Row> overResult =
+ TestUtils.RowsBuilder.of(overResultSchema)
+ .addRows(
+ 1, 0.0 / 6.0,
+ 2, 1.0 / 6.0,
+ 2, 1.0 / 6.0,
+ 5, 3.0 / 6.0,
+ 8, 4.0 / 6.0,
+ 10, 5.0 / 6.0,
+ 10, 5.0 / 6.0)
+ .getRows();
+
+ PAssert.that(result).containsInAnyOrder(overResult);
+
+ pipeline.run();
+ }
}