You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by am...@apache.org on 2020/07/26 19:50:32 UTC

[beam] branch master updated: Implement Numbering functions (#12375)

This is an automated email from the ASF dual-hosted git repository.

amaliujia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d460db  Implement Numbering functions (#12375)
8d460db is described below

commit 8d460db620d2ff1257b0e092218294df15b409a1
Author: jhnmora000 <48...@users.noreply.github.com>
AuthorDate: Sun Jul 26 14:50:00 2020 -0500

    Implement Numbering functions (#12375)
    
    RANK, DENSE_RANK, PERCENT_RANK, ROW_NUMBER
---
 .../sdk/extensions/sql/impl/rel/BeamWindowRel.java |  23 ++-
 .../transform/BeamBuiltinAnalyticFunctions.java    | 183 ++++++++++++++++++++-
 .../transform/agg/AggregationCombineFnAdapter.java |   8 +-
 .../extensions/sql/BeamAnalyticFunctionsTest.java  | 115 +++++++++++++
 4 files changed, 318 insertions(+), 11 deletions(-)

diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
index 09ca5b1..5d5da8d 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
@@ -28,6 +28,7 @@ import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
 import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
+import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAnalyticFunctions;
 import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
 import org.apache.beam.sdk.schemas.Schema;
@@ -291,10 +292,26 @@ public class BeamWindowRel extends Window implements BeamRelNode {
             aggRange = getRange(indexRange, sortedRowsAsList.get(idx));
           }
           Object accumulator = fieldAgg.combineFn.createAccumulator();
-          final int aggFieldIndex = fieldAgg.inputFields.get(0);
+          // if no inputs are needed, use a placeholder field index
+          final int aggFieldIndex =
+              fieldAgg.inputFields.isEmpty() ? -1 : fieldAgg.inputFields.get(0);
+          long count = 0;
           for (Row aggRow : aggRange) {
-            accumulator =
-                fieldAgg.combineFn.addInput(accumulator, aggRow.getBaseValue(aggFieldIndex));
+            if (fieldAgg.combineFn instanceof BeamBuiltinAnalyticFunctions.PositionAwareCombineFn) {
+              BeamBuiltinAnalyticFunctions.PositionAwareCombineFn fn =
+                  (BeamBuiltinAnalyticFunctions.PositionAwareCombineFn) fieldAgg.combineFn;
+              accumulator =
+                  fn.addInput(
+                      accumulator,
+                      getOrderByValue(aggRow),
+                      count,
+                      (long) idx,
+                      (long) sortedRowsAsList.size());
+            } else {
+              accumulator =
+                  fieldAgg.combineFn.addInput(accumulator, aggRow.getBaseValue(aggFieldIndex));
+            }
+            count++;
           }
           Object result = fieldAgg.combineFn.extractOutput(accumulator);
           Row processingRow = sortedRowsAsList.get(idx);
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAnalyticFunctions.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAnalyticFunctions.java
index 14fe20d..457e0db 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAnalyticFunctions.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAnalyticFunctions.java
@@ -17,11 +17,13 @@
  */
 package org.apache.beam.sdk.extensions.sql.impl.transform;
 
+import java.math.BigDecimal;
 import java.util.Map;
 import java.util.Optional;
 import java.util.function.Function;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
 
 /** Built-in Analytic Functions for the aggregation analytics functionality. */
@@ -29,10 +31,16 @@ public class BeamBuiltinAnalyticFunctions {
   public static final Map<String, Function<Schema.FieldType, Combine.CombineFn<?, ?, ?>>>
       BUILTIN_ANALYTIC_FACTORIES =
           ImmutableMap.<String, Function<Schema.FieldType, Combine.CombineFn<?, ?, ?>>>builder()
+              // Aggregate Analytic Functions
               .putAll(BeamBuiltinAggregations.BUILTIN_AGGREGATOR_FACTORIES)
+              // Navigation Functions
               .put("FIRST_VALUE", typeName -> navigationFirstValue())
               .put("LAST_VALUE", typeName -> navigationLastValue())
-              // Pending Numbering functions
+              // Numbering Functions
+              .put("ROW_NUMBER", typeName -> numberingRowNumber())
+              .put("DENSE_RANK", typeName -> numberingDenseRank())
+              .put("RANK", typeName -> numberingRank())
+              .put("PERCENT_RANK", typeName -> numberingPercentRank())
               .build();
 
   public static Combine.CombineFn<?, ?, ?> create(String functionName, Schema.FieldType fieldType) {
@@ -45,6 +53,7 @@ public class BeamBuiltinAnalyticFunctions {
         String.format("Analytics Function [%s] is not supported", functionName));
   }
 
+  // Navigation functions
   public static <T> Combine.CombineFn<T, ?, T> navigationFirstValue() {
     return new FirstValueCombineFn();
   }
@@ -105,4 +114,176 @@ public class BeamBuiltinAnalyticFunctions {
       return accumulator.isPresent() ? accumulator.get() : null;
     }
   }
+
+  // Numbering functions
+  public static <T> Combine.CombineFn<T, ?, T> numberingRowNumber() {
+    return new RowNumberCombineFn();
+  }
+
+  public static <T> Combine.CombineFn<T, ?, T> numberingDenseRank() {
+    return new DenseRankCombineFn();
+  }
+
+  public static <T> Combine.CombineFn<T, ?, T> numberingRank() {
+    return new RankCombineFn();
+  }
+
+  public static <T> Combine.CombineFn<T, ?, T> numberingPercentRank() {
+    return new PercentRankCombineFn();
+  }
+
+  public abstract static class PositionAwareCombineFn<InputT, AccumT, OutputT>
+      extends Combine.CombineFn<InputT, AccumT, OutputT> {
+    public abstract AccumT addInput(
+        AccumT accumulator,
+        InputT input,
+        Long cursorWindow,
+        Long cursorPartition,
+        Long countPartition);
+
+    @Override
+    public AccumT addInput(AccumT mutableAccumulator, InputT input) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  private static class RowNumberCombineFn<T>
+      extends PositionAwareCombineFn<BigDecimal, Optional<Long>, Long> {
+
+    @Override
+    public Optional<Long> addInput(
+        Optional<Long> accumulator,
+        BigDecimal input,
+        Long cursorPosition,
+        Long cursorPartition,
+        Long countPartition) {
+      return Optional.of(cursorPartition);
+    }
+
+    @Override
+    public Optional<Long> createAccumulator() {
+      return Optional.empty();
+    }
+
+    @Override
+    public Long extractOutput(Optional<Long> accumulator) {
+      // 1-based result
+      return accumulator.isPresent() ? accumulator.get() + 1 : null;
+    }
+  }
+
+  private static class DenseRankCombineFn<T>
+      extends PositionAwareCombineFn<BigDecimal, KV<BigDecimal, Long>, Long> {
+
+    @Override
+    public KV<BigDecimal, Long> addInput(
+        KV<BigDecimal, Long> accumulator,
+        BigDecimal input,
+        Long cursorPosition,
+        Long cursorPartition,
+        Long countPartition) {
+      KV<BigDecimal, Long> r = null;
+      if (accumulator == null) {
+        r = KV.of(input, 0L);
+      } else {
+        if (accumulator.getKey().compareTo(input) == 0) {
+          r = KV.of(input, accumulator.getValue());
+        } else {
+          r = KV.of(input, accumulator.getValue() + 1);
+        }
+      }
+      return r;
+    }
+
+    @Override
+    public KV<BigDecimal, Long> createAccumulator() {
+      return null;
+    }
+
+    @Override
+    public Long extractOutput(KV<BigDecimal, Long> accumulator) {
+      // 1-based result
+      return accumulator != null ? accumulator.getValue() + 1 : null;
+    }
+  }
+
+  private static class RankCombineFn<T>
+      extends PositionAwareCombineFn<BigDecimal, KV<BigDecimal, Long>, Long> {
+
+    @Override
+    public KV<BigDecimal, Long> addInput(
+        KV<BigDecimal, Long> accumulator,
+        BigDecimal input,
+        Long cursorPosition,
+        Long cursorPartition,
+        Long countPartition) {
+      KV<BigDecimal, Long> r = null;
+      if (accumulator == null) {
+        r = KV.of(input, 0L);
+      } else {
+        if (accumulator.getKey().compareTo(input) == 0) {
+          r = KV.of(input, accumulator.getValue());
+        } else {
+          r = KV.of(input, cursorPosition);
+        }
+      }
+      return r;
+    }
+
+    @Override
+    public KV<BigDecimal, Long> createAccumulator() {
+      return null;
+    }
+
+    @Override
+    public Long extractOutput(KV<BigDecimal, Long> accumulator) {
+      // 1-based result
+      return accumulator != null ? accumulator.getValue() + 1 : null;
+    }
+  }
+
+  private static class PercentRankCombineFn<T>
+      extends PositionAwareCombineFn<BigDecimal, KV<Optional<Long>, KV<BigDecimal, Long>>, Double> {
+
+    RankCombineFn internalRank;
+
+    PercentRankCombineFn() {
+      internalRank = new RankCombineFn();
+    }
+
+    @Override
+    public KV<Optional<Long>, KV<BigDecimal, Long>> addInput(
+        KV<Optional<Long>, KV<BigDecimal, Long>> accumulator,
+        BigDecimal input,
+        Long cursorPosition,
+        Long cursorPartition,
+        Long countPartition) {
+      KV<BigDecimal, Long> ac1 =
+          internalRank.addInput(
+              accumulator.getValue(), input, cursorPosition, cursorPartition, countPartition);
+      Optional<Long> ac2 = Optional.of(countPartition);
+      return KV.of(ac2, ac1);
+    }
+
+    @Override
+    public KV<Optional<Long>, KV<BigDecimal, Long>> createAccumulator() {
+      return KV.of(Optional.empty(), internalRank.createAccumulator());
+    }
+
+    @Override
+    public Double extractOutput(KV<Optional<Long>, KV<BigDecimal, Long>> accumulator) {
+      Long nr = accumulator.getKey().orElse(null);
+      Long rk = internalRank.extractOutput(accumulator.getValue());
+      Double r = 0.0;
+      if (nr != null && rk != null && nr > 1L) {
+        r = (rk.doubleValue() - 1) / (nr.doubleValue() - 1);
+      }
+      return r;
+    }
+  }
 }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
index 3b5b267..b2110af 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
@@ -174,13 +174,7 @@ public class AggregationCombineFnAdapter<T> {
     } else {
       combineFn = BeamBuiltinAnalyticFunctions.create(functionName, field.getType());
     }
-    if (call.getArgList().isEmpty()) {
-      return new SingleInputCombiner(combineFn);
-    } else if (call.getArgList().size() == 1) {
-      return new SingleInputCombiner(combineFn);
-    } else {
-      return new MultiInputCombiner(combineFn);
-    }
+    return combineFn;
   }
 
   public static CombineFn<Row, ?, Row> createConstantCombineFn() {
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamAnalyticFunctionsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamAnalyticFunctionsTest.java
index 2e06a0b..d66fdb5 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamAnalyticFunctionsTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamAnalyticFunctionsTest.java
@@ -67,6 +67,17 @@ public class BeamAnalyticFunctionsTest extends BeamSqlDslBase {
   }
 
   /**
+   * Table schema and data taken from.
+   * https://cloud.google.com/bigquery/docs/reference/standard-sql/analytic-function-concepts#numbering_function_concepts
+   */
+  private PCollection<Row> inputData2() {
+    Schema schema = Schema.builder().addInt32Field("x").build();
+    return pipeline
+        .apply(Create.of(TestUtils.rowsBuilderOf(schema).addRows(1, 2, 2, 5, 8, 10, 10).getRows()))
+        .setRowSchema(schema);
+  }
+
+  /**
    * Compute a cumulative sum query taken from.
    * https://cloud.google.com/bigquery/docs/reference/standard-sql/analytic-function-concepts#compute_a_cumulative_sum
    */
@@ -473,4 +484,108 @@ public class BeamAnalyticFunctionsTest extends BeamSqlDslBase {
 
     pipeline.run();
   }
+
+  @Test
+  public void testRowNumberFunction() throws Exception {
+    pipeline.enableAbandonedNodeEnforcement(false);
+    PCollection<Row> inputRows = inputData2();
+    String sql = "SELECT x, ROW_NUMBER() over (ORDER BY x ) as agg  FROM PCOLLECTION";
+    PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+    Schema overResultSchema = Schema.builder().addInt32Field("x").addInt64Field("agg").build();
+
+    List<Row> overResult =
+        TestUtils.RowsBuilder.of(overResultSchema)
+            .addRows(
+                1, 1L,
+                2, 2L,
+                2, 3L,
+                5, 4L,
+                8, 5L,
+                10, 6L,
+                10, 7L)
+            .getRows();
+
+    PAssert.that(result).containsInAnyOrder(overResult);
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testDenseRankFunction() throws Exception {
+    pipeline.enableAbandonedNodeEnforcement(false);
+    PCollection<Row> inputRows = inputData2();
+    String sql = "SELECT x, DENSE_RANK() over (ORDER BY x ) as agg  FROM PCOLLECTION";
+    PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+    Schema overResultSchema = Schema.builder().addInt32Field("x").addInt64Field("agg").build();
+
+    List<Row> overResult =
+        TestUtils.RowsBuilder.of(overResultSchema)
+            .addRows(
+                1, 1L,
+                2, 2L,
+                2, 2L,
+                5, 3L,
+                8, 4L,
+                10, 5L,
+                10, 5L)
+            .getRows();
+
+    PAssert.that(result).containsInAnyOrder(overResult);
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testRankFunction() throws Exception {
+    pipeline.enableAbandonedNodeEnforcement(false);
+    PCollection<Row> inputRows = inputData2();
+    String sql = "SELECT x, RANK() over (ORDER BY x ) as agg  FROM PCOLLECTION";
+    PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+    Schema overResultSchema = Schema.builder().addInt32Field("x").addInt64Field("agg").build();
+
+    List<Row> overResult =
+        TestUtils.RowsBuilder.of(overResultSchema)
+            .addRows(
+                1, 1L,
+                2, 2L,
+                2, 2L,
+                5, 4L,
+                8, 5L,
+                10, 6L,
+                10, 6L)
+            .getRows();
+
+    PAssert.that(result).containsInAnyOrder(overResult);
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testPercentRankFunction() throws Exception {
+    pipeline.enableAbandonedNodeEnforcement(false);
+    PCollection<Row> inputRows = inputData2();
+    String sql = "SELECT x, PERCENT_RANK() over (ORDER BY x ) as agg  FROM PCOLLECTION";
+    PCollection<Row> result = inputRows.apply("sql", SqlTransform.query(sql));
+
+    Schema overResultSchema = Schema.builder().addInt32Field("x").addDoubleField("agg").build();
+
+    List<Row> overResult =
+        TestUtils.RowsBuilder.of(overResultSchema)
+            .addRows(
+                1, 0.0 / 6.0,
+                2, 1.0 / 6.0,
+                2, 1.0 / 6.0,
+                5, 3.0 / 6.0,
+                8, 4.0 / 6.0,
+                10, 5.0 / 6.0,
+                10, 5.0 / 6.0)
+            .getRows();
+
+    PAssert.that(result).containsInAnyOrder(overResult);
+
+    pipeline.run();
+  }
 }