You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ta...@apache.org on 2017/08/12 00:14:11 UTC
[1/2] beam git commit: take CombineFn as UDAF.
Repository: beam
Updated Branches:
refs/heads/DSL_SQL f37a7a19c -> 9eec6a030
take CombineFn as UDAF.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1770c861
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1770c861
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1770c861
Branch: refs/heads/DSL_SQL
Commit: 1770c86121d7edc388cadc0e2791c19b027cc50f
Parents: f37a7a1
Author: mingmxu <mi...@ebay.com>
Authored: Thu Aug 10 17:42:29 2017 -0700
Committer: Tyler Akidau <ta...@apache.org>
Committed: Fri Aug 11 17:11:23 2017 -0700
----------------------------------------------------------------------
.../apache/beam/sdk/coders/BeamRecordCoder.java | 16 +-
.../apache/beam/sdk/extensions/sql/BeamSql.java | 22 +-
.../beam/sdk/extensions/sql/BeamSqlEnv.java | 11 +-
.../operator/BeamSqlInputRefExpression.java | 4 +
.../sql/impl/interpreter/operator/UdafImpl.java | 87 ++++
.../transform/BeamAggregationTransforms.java | 44 +-
.../impl/transform/BeamBuiltinAggregations.java | 504 +++++++------------
.../sdk/extensions/sql/schema/BeamSqlUdaf.java | 72 ---
.../extensions/sql/BeamSqlDslUdfUdafTest.java | 22 +-
9 files changed, 344 insertions(+), 438 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java
index cbed87d..7b1b681 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java
@@ -35,11 +35,11 @@ public class BeamRecordCoder extends CustomCoder<BeamRecord> {
private static final BitSetCoder nullListCoder = BitSetCoder.of();
private BeamRecordType recordType;
- private List<Coder> coderArray;
+ private List<Coder> coders;
- private BeamRecordCoder(BeamRecordType recordType, List<Coder> coderArray) {
+ private BeamRecordCoder(BeamRecordType recordType, List<Coder> coders) {
this.recordType = recordType;
- this.coderArray = coderArray;
+ this.coders = coders;
}
public static BeamRecordCoder of(BeamRecordType recordType, List<Coder> coderArray){
@@ -62,7 +62,7 @@ public class BeamRecordCoder extends CustomCoder<BeamRecord> {
continue;
}
- coderArray.get(idx).encode(value.getFieldValue(idx), outStream);
+ coders.get(idx).encode(value.getFieldValue(idx), outStream);
}
}
@@ -75,7 +75,7 @@ public class BeamRecordCoder extends CustomCoder<BeamRecord> {
if (nullFields.get(idx)) {
fieldValues.add(null);
} else {
- fieldValues.add(coderArray.get(idx).decode(inStream));
+ fieldValues.add(coders.get(idx).decode(inStream));
}
}
BeamRecord record = new BeamRecord(recordType, fieldValues);
@@ -99,8 +99,12 @@ public class BeamRecordCoder extends CustomCoder<BeamRecord> {
@Override
public void verifyDeterministic()
throws org.apache.beam.sdk.coders.Coder.NonDeterministicException {
- for (Coder c : coderArray) {
+ for (Coder c : coders) {
c.verifyDeterministic();
}
}
+
+ public List<Coder> getCoders() {
+ return coders;
+ }
}
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java
index a1e9877..bf6a9c0 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java
@@ -23,8 +23,8 @@ import org.apache.beam.sdk.coders.BeamRecordCoder;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
import org.apache.beam.sdk.extensions.sql.schema.BeamPCollectionTable;
import org.apache.beam.sdk.extensions.sql.schema.BeamRecordSqlType;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdf;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.BeamRecord;
@@ -155,10 +155,10 @@ public class BeamSql {
}
/**
- * register a UDAF function used in this query.
+ * register a {@link CombineFn} as UDAF function used in this query.
*/
- public QueryTransform withUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz){
- getSqlEnv().registerUdaf(functionName, clazz);
+ public QueryTransform withUdaf(String functionName, CombineFn combineFn){
+ getSqlEnv().registerUdaf(functionName, combineFn);
return this;
}
@@ -231,13 +231,13 @@ public class BeamSql {
return this;
}
- /**
- * register a UDAF function used in this query.
- */
- public SimpleQueryTransform withUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz){
- getSqlEnv().registerUdaf(functionName, clazz);
- return this;
- }
+ /**
+ * register a {@link CombineFn} as UDAF function used in this query.
+ */
+ public SimpleQueryTransform withUdaf(String functionName, CombineFn combineFn){
+ getSqlEnv().registerUdaf(functionName, combineFn);
+ return this;
+ }
private void validateQuery() {
SqlNode sqlNode;
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java
index 0737c49..79f2b32 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java
@@ -18,12 +18,13 @@
package org.apache.beam.sdk.extensions.sql;
import java.io.Serializable;
+import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.UdafImpl;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.extensions.sql.schema.BaseBeamTable;
import org.apache.beam.sdk.extensions.sql.schema.BeamRecordSqlType;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdf;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.calcite.DataContext;
import org.apache.calcite.linq4j.Enumerable;
@@ -34,7 +35,6 @@ import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.schema.Statistic;
import org.apache.calcite.schema.Statistics;
-import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.schema.impl.ScalarFunctionImpl;
import org.apache.calcite.tools.Frameworks;
@@ -69,11 +69,10 @@ public class BeamSqlEnv implements Serializable{
}
/**
- * Register a UDAF function which can be used in GROUP-BY expression.
- * See {@link BeamSqlUdaf} on how to implement a UDAF.
+ * Register a {@link CombineFn} as UDAF function which can be used in GROUP-BY expression.
*/
- public void registerUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz) {
- schema.add(functionName, AggregateFunctionImpl.create(clazz));
+ public void registerUdaf(String functionName, CombineFn combineFn) {
+ schema.add(functionName, new UdafImpl(combineFn));
}
/**
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java
index a2d1624..2c321f7 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java
@@ -41,4 +41,8 @@ public class BeamSqlInputRefExpression extends BeamSqlExpression {
public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) {
return BeamSqlPrimitive.of(outputType, inputRow.getFieldValue(inputRef));
}
+
+ public int getInputRef() {
+ return inputRef;
+ }
}
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/UdafImpl.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/UdafImpl.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/UdafImpl.java
new file mode 100644
index 0000000..83ed7f8
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/UdafImpl.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator;
+
+import java.io.Serializable;
+import java.lang.reflect.ParameterizedType;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.calcite.adapter.enumerable.AggImplementor;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.schema.AggregateFunction;
+import org.apache.calcite.schema.FunctionParameter;
+import org.apache.calcite.schema.ImplementableAggFunction;
+
+/**
+ * Implement {@link AggregateFunction} to take a {@link CombineFn} as UDAF.
+ */
+public final class UdafImpl<InputT, AccumT, OutputT>
+ implements AggregateFunction, ImplementableAggFunction, Serializable{
+ private CombineFn<InputT, AccumT, OutputT> combineFn;
+
+ public UdafImpl(CombineFn<InputT, AccumT, OutputT> combineFn) {
+ this.combineFn = combineFn;
+ }
+
+ public CombineFn<InputT, AccumT, OutputT> getCombineFn() {
+ return combineFn;
+ }
+
+ @Override
+ public List<FunctionParameter> getParameters() {
+ List<FunctionParameter> para = new ArrayList<>();
+ para.add(new FunctionParameter() {
+ public int getOrdinal() {
+ return 0; //up to one parameter is supported in UDAF.
+ }
+
+ public String getName() {
+ // not used as Beam SQL uses its own execution engine
+ return null;
+ }
+
+ public RelDataType getType(RelDataTypeFactory typeFactory) {
+ //the first generic type of CombineFn is the input type.
+ ParameterizedType parameterizedType = (ParameterizedType) combineFn.getClass()
+ .getGenericSuperclass();
+ return typeFactory.createJavaType(
+ (Class) parameterizedType.getActualTypeArguments()[0]);
+ }
+
+ public boolean isOptional() {
+ // not used as Beam SQL uses its own execution engine
+ return false;
+ }
+ });
+ return para;
+ }
+
+ @Override
+ public AggImplementor getImplementor(boolean windowContext) {
+ // not used as Beam SQL uses its own execution engine
+ return null;
+ }
+
+ @Override
+ public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
+ return typeFactory.createJavaType((Class) combineFn.getOutputType().getType());
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
index 0f90bee..40b7b58 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
@@ -25,6 +25,7 @@ import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import org.apache.beam.sdk.coders.BeamRecordCoder;
import org.apache.beam.sdk.coders.BigDecimalCoder;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
@@ -32,13 +33,13 @@ import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
-import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression;
import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlInputRefExpression;
+import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.UdafImpl;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.extensions.sql.schema.BeamRecordSqlType;
import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRecordHelper;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -46,7 +47,6 @@ import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.values.BeamRecord;
import org.apache.beam.sdk.values.KV;
import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.calcite.util.ImmutableBitSet;
import org.joda.time.Instant;
@@ -151,8 +151,8 @@ public class BeamAggregationTransforms implements Serializable{
*/
public static class AggregationAdaptor
extends CombineFn<BeamRecord, AggregationAccumulator, BeamRecord> {
- private List<BeamSqlUdaf> aggregators;
- private List<BeamSqlExpression> sourceFieldExps;
+ private List<CombineFn> aggregators;
+ private List<BeamSqlInputRefExpression> sourceFieldExps;
private BeamRecordSqlType finalRowType;
public AggregationAdaptor(List<AggregateCall> aggregationCalls,
@@ -163,7 +163,7 @@ public class BeamAggregationTransforms implements Serializable{
List<Integer> outFieldsType = new ArrayList<>();
for (AggregateCall call : aggregationCalls) {
int refIndex = call.getArgList().size() > 0 ? call.getArgList().get(0) : 0;
- BeamSqlExpression sourceExp = new BeamSqlInputRefExpression(
+ BeamSqlInputRefExpression sourceExp = new BeamSqlInputRefExpression(
CalciteUtils.getFieldType(sourceRowType, refIndex), refIndex);
sourceFieldExps.add(sourceExp);
@@ -173,27 +173,27 @@ public class BeamAggregationTransforms implements Serializable{
switch (call.getAggregation().getName()) {
case "COUNT":
- aggregators.add(new BeamBuiltinAggregations.Count());
+ aggregators.add(Count.combineFn());
break;
case "MAX":
- aggregators.add(BeamBuiltinAggregations.Max.create(call.type.getSqlTypeName()));
+ aggregators.add(BeamBuiltinAggregations.createMax(call.type.getSqlTypeName()));
break;
case "MIN":
- aggregators.add(BeamBuiltinAggregations.Min.create(call.type.getSqlTypeName()));
+ aggregators.add(BeamBuiltinAggregations.createMin(call.type.getSqlTypeName()));
break;
case "SUM":
- aggregators.add(BeamBuiltinAggregations.Sum.create(call.type.getSqlTypeName()));
+ aggregators.add(BeamBuiltinAggregations.createSum(call.type.getSqlTypeName()));
break;
case "AVG":
- aggregators.add(BeamBuiltinAggregations.Avg.create(call.type.getSqlTypeName()));
+ aggregators.add(BeamBuiltinAggregations.createAvg(call.type.getSqlTypeName()));
break;
default:
if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
// handle UDAF.
SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction) call.getAggregation();
- AggregateFunctionImpl fn = (AggregateFunctionImpl) udaf.function;
+ UdafImpl fn = (UdafImpl) udaf.function;
try {
- aggregators.add((BeamSqlUdaf) fn.declaringClass.newInstance());
+ aggregators.add(fn.getCombineFn());
} catch (Exception e) {
throw new IllegalStateException(e);
}
@@ -210,8 +210,8 @@ public class BeamAggregationTransforms implements Serializable{
@Override
public AggregationAccumulator createAccumulator() {
AggregationAccumulator initialAccu = new AggregationAccumulator();
- for (BeamSqlUdaf agg : aggregators) {
- initialAccu.accumulatorElements.add(agg.init());
+ for (CombineFn agg : aggregators) {
+ initialAccu.accumulatorElements.add(agg.createAccumulator());
}
return initialAccu;
}
@@ -220,7 +220,7 @@ public class BeamAggregationTransforms implements Serializable{
AggregationAccumulator deltaAcc = new AggregationAccumulator();
for (int idx = 0; idx < aggregators.size(); ++idx) {
deltaAcc.accumulatorElements.add(
- aggregators.get(idx).add(accumulator.accumulatorElements.get(idx),
+ aggregators.get(idx).addInput(accumulator.accumulatorElements.get(idx),
sourceFieldExps.get(idx).evaluate(input, null).getValue()));
}
return deltaAcc;
@@ -234,7 +234,7 @@ public class BeamAggregationTransforms implements Serializable{
while (ite.hasNext()) {
accs.add(ite.next().accumulatorElements.get(idx));
}
- deltaAcc.accumulatorElements.add(aggregators.get(idx).merge(accs));
+ deltaAcc.accumulatorElements.add(aggregators.get(idx).mergeAccumulators(accs));
}
return deltaAcc;
}
@@ -242,7 +242,8 @@ public class BeamAggregationTransforms implements Serializable{
public BeamRecord extractOutput(AggregationAccumulator accumulator) {
List<Object> fieldValues = new ArrayList<>(aggregators.size());
for (int idx = 0; idx < aggregators.size(); ++idx) {
- fieldValues.add(aggregators.get(idx).result(accumulator.accumulatorElements.get(idx)));
+ fieldValues
+ .add(aggregators.get(idx).extractOutput(accumulator.accumulatorElements.get(idx)));
}
return new BeamRecord(finalRowType, fieldValues);
}
@@ -250,10 +251,13 @@ public class BeamAggregationTransforms implements Serializable{
public Coder<AggregationAccumulator> getAccumulatorCoder(
CoderRegistry registry, Coder<BeamRecord> inputCoder)
throws CannotProvideCoderException {
+ BeamRecordCoder beamRecordCoder = (BeamRecordCoder) inputCoder;
registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of());
List<Coder> aggAccuCoderList = new ArrayList<>();
- for (BeamSqlUdaf udaf : aggregators) {
- aggAccuCoderList.add(udaf.getAccumulatorCoder(registry));
+ for (int idx = 0; idx < aggregators.size(); ++idx) {
+ int srcFieldIndex = sourceFieldExps.get(idx).getInputRef();
+ Coder srcFieldCoder = beamRecordCoder.getCoders().get(srcFieldIndex);
+ aggAccuCoderList.add(aggregators.get(idx).getAccumulatorCoder(registry, srcFieldCoder));
}
return new AggregationAccumulatorCoder(aggAccuCoderList);
}
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
index 1fc8cf6..03edf13 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
@@ -21,16 +21,16 @@ import java.math.BigDecimal;
import java.util.Date;
import java.util.Iterator;
import org.apache.beam.sdk.coders.BigDecimalCoder;
-import org.apache.beam.sdk.coders.ByteCoder;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.SerializableCoder;
-import org.apache.beam.sdk.coders.VarIntCoder;
-import org.apache.beam.sdk.coders.VarLongCoder;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Max;
+import org.apache.beam.sdk.transforms.Min;
+import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.KV;
import org.apache.calcite.sql.type.SqlTypeName;
@@ -39,374 +39,258 @@ import org.apache.calcite.sql.type.SqlTypeName;
*/
class BeamBuiltinAggregations {
/**
- * Built-in aggregation for COUNT.
+ * {@link CombineFn} for MAX based on {@link Max} and {@link Combine.BinaryCombineFn}.
*/
- public static final class Count<T> extends BeamSqlUdaf<T, Long, Long> {
- public Count() {}
-
- @Override
- public Long init() {
- return 0L;
- }
-
- @Override
- public Long add(Long accumulator, T input) {
- return accumulator + 1;
- }
+ public static CombineFn createMax(SqlTypeName fieldType) {
+ switch (fieldType) {
+ case INTEGER:
+ return Max.ofIntegers();
+ case SMALLINT:
+ return new CustMax<Short>();
+ case TINYINT:
+ return new CustMax<Byte>();
+ case BIGINT:
+ return Max.ofLongs();
+ case FLOAT:
+ return new CustMax<Float>();
+ case DOUBLE:
+ return Max.ofDoubles();
+ case TIMESTAMP:
+ return new CustMax<Date>();
+ case DECIMAL:
+ return new CustMax<BigDecimal>();
+ default:
+ throw new UnsupportedOperationException(
+ String.format("[%s] is not support in MAX", fieldType));
+ }
+ }
- @Override
- public Long merge(Iterable<Long> accumulators) {
- long v = 0L;
- Iterator<Long> ite = accumulators.iterator();
- while (ite.hasNext()) {
- v += ite.next();
- }
- return v;
- }
+ /**
+ * {@link CombineFn} for MIN based on {@link Min} and {@link Combine.BinaryCombineFn}.
+ */
+ public static CombineFn createMin(SqlTypeName fieldType) {
+ switch (fieldType) {
+ case INTEGER:
+ return Min.ofIntegers();
+ case SMALLINT:
+ return new CustMin<Short>();
+ case TINYINT:
+ return new CustMin<Byte>();
+ case BIGINT:
+ return Min.ofLongs();
+ case FLOAT:
+ return new CustMin<Float>();
+ case DOUBLE:
+ return Min.ofDoubles();
+ case TIMESTAMP:
+ return new CustMin<Date>();
+ case DECIMAL:
+ return new CustMin<BigDecimal>();
+ default:
+ throw new UnsupportedOperationException(
+ String.format("[%s] is not support in MIN", fieldType));
+ }
+ }
- @Override
- public Long result(Long accumulator) {
- return accumulator;
- }
+ /**
+ * {@link CombineFn} for SUM based on {@link Sum} and {@link Combine.BinaryCombineFn}.
+ */
+ public static CombineFn createSum(SqlTypeName fieldType) {
+ switch (fieldType) {
+ case INTEGER:
+ return Sum.ofIntegers();
+ case SMALLINT:
+ return new ShortSum();
+ case TINYINT:
+ return new ByteSum();
+ case BIGINT:
+ return Sum.ofLongs();
+ case FLOAT:
+ return new FloatSum();
+ case DOUBLE:
+ return Sum.ofDoubles();
+ case DECIMAL:
+ return new BigDecimalSum();
+ default:
+ throw new UnsupportedOperationException(
+ String.format("[%s] is not support in SUM", fieldType));
+ }
}
/**
- * Built-in aggregation for MAX.
+ * {@link CombineFn} for AVG.
*/
- public static final class Max<T extends Comparable<T>> extends BeamSqlUdaf<T, T, T> {
- public static Max create(SqlTypeName fieldType) {
- switch (fieldType) {
- case INTEGER:
- return new BeamBuiltinAggregations.Max<Integer>(fieldType);
- case SMALLINT:
- return new BeamBuiltinAggregations.Max<Short>(fieldType);
- case TINYINT:
- return new BeamBuiltinAggregations.Max<Byte>(fieldType);
- case BIGINT:
- return new BeamBuiltinAggregations.Max<Long>(fieldType);
- case FLOAT:
- return new BeamBuiltinAggregations.Max<Float>(fieldType);
- case DOUBLE:
- return new BeamBuiltinAggregations.Max<Double>(fieldType);
- case TIMESTAMP:
- return new BeamBuiltinAggregations.Max<Date>(fieldType);
- case DECIMAL:
- return new BeamBuiltinAggregations.Max<BigDecimal>(fieldType);
- default:
- throw new UnsupportedOperationException(
- String.format("[%s] is not support in MAX", fieldType));
- }
- }
+ public static CombineFn createAvg(SqlTypeName fieldType) {
+ switch (fieldType) {
+ case INTEGER:
+ return new IntegerAvg();
+ case SMALLINT:
+ return new ShortAvg();
+ case TINYINT:
+ return new ByteAvg();
+ case BIGINT:
+ return new LongAvg();
+ case FLOAT:
+ return new FloatAvg();
+ case DOUBLE:
+ return new DoubleAvg();
+ case DECIMAL:
+ return new BigDecimalAvg();
+ default:
+ throw new UnsupportedOperationException(
+ String.format("[%s] is not support in AVG", fieldType));
+ }
+ }
- private final SqlTypeName fieldType;
- private Max(SqlTypeName fieldType) {
- this.fieldType = fieldType;
+ static class CustMax<T extends Comparable<T>> extends Combine.BinaryCombineFn<T> {
+ public T apply(T left, T right) {
+ return (right == null || right.compareTo(left) < 0) ? left : right;
}
+ }
- @Override
- public T init() {
- return null;
+ static class CustMin<T extends Comparable<T>> extends Combine.BinaryCombineFn<T> {
+ public T apply(T left, T right) {
+ return (left == null || left.compareTo(right) < 0) ? left : right;
}
+ }
- @Override
- public T add(T accumulator, T input) {
- return (accumulator == null || accumulator.compareTo(input) < 0) ? input : accumulator;
+ static class ShortSum extends Combine.BinaryCombineFn<Short> {
+ public Short apply(Short left, Short right) {
+ return (short) (left + right);
}
+ }
- @Override
- public T merge(Iterable<T> accumulators) {
- Iterator<T> ite = accumulators.iterator();
- T mergedV = ite.next();
- while (ite.hasNext()) {
- T v = ite.next();
- mergedV = mergedV.compareTo(v) > 0 ? mergedV : v;
- }
- return mergedV;
+ static class ByteSum extends Combine.BinaryCombineFn<Byte> {
+ public Byte apply(Byte left, Byte right) {
+ return (byte) (left + right);
}
+ }
- @Override
- public T result(T accumulator) {
- return accumulator;
+ static class FloatSum extends Combine.BinaryCombineFn<Float> {
+ public Float apply(Float left, Float right) {
+ return left + right;
}
+ }
- @Override
- public Coder<T> getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException {
- return BeamBuiltinAggregations.getSqlTypeCoder(fieldType);
+ static class BigDecimalSum extends Combine.BinaryCombineFn<BigDecimal> {
+ public BigDecimal apply(BigDecimal left, BigDecimal right) {
+ return left.add(right);
}
}
/**
- * Built-in aggregation for MIN.
+ * {@link CombineFn} for <em>AVG</em> on {@link Number} types.
*/
- public static final class Min<T extends Comparable<T>> extends BeamSqlUdaf<T, T, T> {
- public static Min create(SqlTypeName fieldType) {
- switch (fieldType) {
- case INTEGER:
- return new BeamBuiltinAggregations.Min<Integer>(fieldType);
- case SMALLINT:
- return new BeamBuiltinAggregations.Min<Short>(fieldType);
- case TINYINT:
- return new BeamBuiltinAggregations.Min<Byte>(fieldType);
- case BIGINT:
- return new BeamBuiltinAggregations.Min<Long>(fieldType);
- case FLOAT:
- return new BeamBuiltinAggregations.Min<Float>(fieldType);
- case DOUBLE:
- return new BeamBuiltinAggregations.Min<Double>(fieldType);
- case TIMESTAMP:
- return new BeamBuiltinAggregations.Min<Date>(fieldType);
- case DECIMAL:
- return new BeamBuiltinAggregations.Min<BigDecimal>(fieldType);
- default:
- throw new UnsupportedOperationException(
- String.format("[%s] is not support in MIN", fieldType));
- }
- }
-
- private final SqlTypeName fieldType;
- private Min(SqlTypeName fieldType) {
- this.fieldType = fieldType;
- }
-
+ abstract static class Avg<T extends Number>
+ extends CombineFn<T, KV<Integer, BigDecimal>, T> {
@Override
- public T init() {
- return null;
+ public KV<Integer, BigDecimal> createAccumulator() {
+ return KV.of(0, new BigDecimal(0));
}
@Override
- public T add(T accumulator, T input) {
- return (accumulator == null || accumulator.compareTo(input) > 0) ? input : accumulator;
+ public KV<Integer, BigDecimal> addInput(KV<Integer, BigDecimal> accumulator, T input) {
+ return KV.of(accumulator.getKey() + 1, accumulator.getValue().add(toBigDecimal(input)));
}
@Override
- public T merge(Iterable<T> accumulators) {
- Iterator<T> ite = accumulators.iterator();
- T mergedV = ite.next();
+ public KV<Integer, BigDecimal> mergeAccumulators(
+ Iterable<KV<Integer, BigDecimal>> accumulators) {
+ int size = 0;
+ BigDecimal acc = new BigDecimal(0);
+ Iterator<KV<Integer, BigDecimal>> ite = accumulators.iterator();
while (ite.hasNext()) {
- T v = ite.next();
- mergedV = mergedV.compareTo(v) < 0 ? mergedV : v;
+ KV<Integer, BigDecimal> ele = ite.next();
+ size += ele.getKey();
+ acc = acc.add(ele.getValue());
}
- return mergedV;
+ return KV.of(size, acc);
}
@Override
- public T result(T accumulator) {
- return accumulator;
+ public Coder<KV<Integer, BigDecimal>> getAccumulatorCoder(CoderRegistry registry,
+ Coder<T> inputCoder) throws CannotProvideCoderException {
+ return KvCoder.of(BigEndianIntegerCoder.of(), BigDecimalCoder.of());
}
- @Override
- public Coder<T> getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException {
- return BeamBuiltinAggregations.getSqlTypeCoder(fieldType);
- }
+ public abstract T extractOutput(KV<Integer, BigDecimal> accumulator);
+ public abstract BigDecimal toBigDecimal(T record);
}
- /**
- * Built-in aggregation for SUM.
- */
- public static final class Sum<T> extends BeamSqlUdaf<T, BigDecimal, T> {
- public static Sum create(SqlTypeName fieldType) {
- switch (fieldType) {
- case INTEGER:
- return new BeamBuiltinAggregations.Sum<Integer>(fieldType);
- case SMALLINT:
- return new BeamBuiltinAggregations.Sum<Short>(fieldType);
- case TINYINT:
- return new BeamBuiltinAggregations.Sum<Byte>(fieldType);
- case BIGINT:
- return new BeamBuiltinAggregations.Sum<Long>(fieldType);
- case FLOAT:
- return new BeamBuiltinAggregations.Sum<Float>(fieldType);
- case DOUBLE:
- return new BeamBuiltinAggregations.Sum<Double>(fieldType);
- case TIMESTAMP:
- return new BeamBuiltinAggregations.Sum<Date>(fieldType);
- case DECIMAL:
- return new BeamBuiltinAggregations.Sum<BigDecimal>(fieldType);
- default:
- throw new UnsupportedOperationException(
- String.format("[%s] is not support in SUM", fieldType));
- }
+ static class IntegerAvg extends Avg<Integer>{
+ public Integer extractOutput(KV<Integer, BigDecimal> accumulator) {
+ return accumulator.getKey() == 0 ? null
+ : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).intValue();
}
- private SqlTypeName fieldType;
- private Sum(SqlTypeName fieldType) {
- this.fieldType = fieldType;
- }
+ public BigDecimal toBigDecimal(Integer record) {
+ return new BigDecimal(record);
+ }
+ }
- @Override
- public BigDecimal init() {
- return new BigDecimal(0);
+ static class LongAvg extends Avg<Long>{
+ public Long extractOutput(KV<Integer, BigDecimal> accumulator) {
+ return accumulator.getKey() == 0 ? null
+ : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).longValue();
}
- @Override
- public BigDecimal add(BigDecimal accumulator, T input) {
- return accumulator.add(new BigDecimal(input.toString()));
+ public BigDecimal toBigDecimal(Long record) {
+ return new BigDecimal(record);
}
+ }
- @Override
- public BigDecimal merge(Iterable<BigDecimal> accumulators) {
- BigDecimal v = new BigDecimal(0);
- Iterator<BigDecimal> ite = accumulators.iterator();
- while (ite.hasNext()) {
- v = v.add(ite.next());
- }
- return v;
+ static class ShortAvg extends Avg<Short>{
+ public Short extractOutput(KV<Integer, BigDecimal> accumulator) {
+ return accumulator.getKey() == 0 ? null
+ : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).shortValue();
}
- @Override
- public T result(BigDecimal accumulator) {
- Object result = null;
- switch (fieldType) {
- case INTEGER:
- result = accumulator.intValue();
- break;
- case BIGINT:
- result = accumulator.longValue();
- break;
- case SMALLINT:
- result = accumulator.shortValue();
- break;
- case TINYINT:
- result = accumulator.byteValue();
- break;
- case DOUBLE:
- result = accumulator.doubleValue();
- break;
- case FLOAT:
- result = accumulator.floatValue();
- break;
- case DECIMAL:
- result = accumulator;
- break;
- default:
- break;
- }
- return (T) result;
+ public BigDecimal toBigDecimal(Short record) {
+ return new BigDecimal(record);
}
}
- /**
- * Built-in aggregation for AVG.
- */
- public static final class Avg<T> extends BeamSqlUdaf<T, KV<BigDecimal, Long>, T> {
- public static Avg create(SqlTypeName fieldType) {
- switch (fieldType) {
- case INTEGER:
- return new BeamBuiltinAggregations.Avg<Integer>(fieldType);
- case SMALLINT:
- return new BeamBuiltinAggregations.Avg<Short>(fieldType);
- case TINYINT:
- return new BeamBuiltinAggregations.Avg<Byte>(fieldType);
- case BIGINT:
- return new BeamBuiltinAggregations.Avg<Long>(fieldType);
- case FLOAT:
- return new BeamBuiltinAggregations.Avg<Float>(fieldType);
- case DOUBLE:
- return new BeamBuiltinAggregations.Avg<Double>(fieldType);
- case TIMESTAMP:
- return new BeamBuiltinAggregations.Avg<Date>(fieldType);
- case DECIMAL:
- return new BeamBuiltinAggregations.Avg<BigDecimal>(fieldType);
- default:
- throw new UnsupportedOperationException(
- String.format("[%s] is not support in AVG", fieldType));
- }
+ static class ByteAvg extends Avg<Byte>{
+ public Byte extractOutput(KV<Integer, BigDecimal> accumulator) {
+ return accumulator.getKey() == 0 ? null
+ : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).byteValue();
}
- private SqlTypeName fieldType;
- private Avg(SqlTypeName fieldType) {
- this.fieldType = fieldType;
- }
-
- @Override
- public KV<BigDecimal, Long> init() {
- return KV.of(new BigDecimal(0), 0L);
+ public BigDecimal toBigDecimal(Byte record) {
+ return new BigDecimal(record);
}
+ }
- @Override
- public KV<BigDecimal, Long> add(KV<BigDecimal, Long> accumulator, T input) {
- return KV.of(
- accumulator.getKey().add(new BigDecimal(input.toString())),
- accumulator.getValue() + 1);
+ static class FloatAvg extends Avg<Float>{
+ public Float extractOutput(KV<Integer, BigDecimal> accumulator) {
+ return accumulator.getKey() == 0 ? null
+ : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).floatValue();
}
- @Override
- public KV<BigDecimal, Long> merge(Iterable<KV<BigDecimal, Long>> accumulators) {
- BigDecimal v = new BigDecimal(0);
- long s = 0;
- Iterator<KV<BigDecimal, Long>> ite = accumulators.iterator();
- while (ite.hasNext()) {
- KV<BigDecimal, Long> r = ite.next();
- v = v.add(r.getKey());
- s += r.getValue();
- }
- return KV.of(v, s);
+ public BigDecimal toBigDecimal(Float record) {
+ return new BigDecimal(record);
}
+ }
- @Override
- public T result(KV<BigDecimal, Long> accumulator) {
- BigDecimal decimalAvg = accumulator.getKey().divide(
- new BigDecimal(accumulator.getValue()));
- Object result = null;
- switch (fieldType) {
- case INTEGER:
- result = decimalAvg.intValue();
- break;
- case BIGINT:
- result = decimalAvg.longValue();
- break;
- case SMALLINT:
- result = decimalAvg.shortValue();
- break;
- case TINYINT:
- result = decimalAvg.byteValue();
- break;
- case DOUBLE:
- result = decimalAvg.doubleValue();
- break;
- case FLOAT:
- result = decimalAvg.floatValue();
- break;
- case DECIMAL:
- result = decimalAvg;
- break;
- default:
- break;
- }
- return (T) result;
+ static class DoubleAvg extends Avg<Double>{
+ public Double extractOutput(KV<Integer, BigDecimal> accumulator) {
+ return accumulator.getKey() == 0 ? null
+ : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).doubleValue();
}
- @Override
- public Coder<KV<BigDecimal, Long>> getAccumulatorCoder(CoderRegistry registry)
- throws CannotProvideCoderException {
- return KvCoder.of(BigDecimalCoder.of(), VarLongCoder.of());
+ public BigDecimal toBigDecimal(Double record) {
+ return new BigDecimal(record);
}
}
- /**
- * Find {@link Coder} for Beam SQL field types.
- */
- private static Coder getSqlTypeCoder(SqlTypeName sqlType) {
- switch (sqlType) {
- case INTEGER:
- return VarIntCoder.of();
- case SMALLINT:
- return SerializableCoder.of(Short.class);
- case TINYINT:
- return ByteCoder.of();
- case BIGINT:
- return VarLongCoder.of();
- case FLOAT:
- return SerializableCoder.of(Float.class);
- case DOUBLE:
- return DoubleCoder.of();
- case TIMESTAMP:
- return SerializableCoder.of(Date.class);
- case DECIMAL:
- return BigDecimalCoder.of();
- default:
- throw new UnsupportedOperationException(
- String.format("Cannot find a Coder for data type [%s]", sqlType));
+ static class BigDecimalAvg extends Avg<BigDecimal>{
+ public BigDecimal extractOutput(KV<Integer, BigDecimal> accumulator) {
+ return accumulator.getKey() == 0 ? null
+ : accumulator.getValue().divide(new BigDecimal(accumulator.getKey()));
+ }
+
+ public BigDecimal toBigDecimal(BigDecimal record) {
+ return record;
}
}
}
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlUdaf.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlUdaf.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlUdaf.java
deleted file mode 100644
index 2f78586..0000000
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlUdaf.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.extensions.sql.schema;
-
-import java.io.Serializable;
-import java.lang.reflect.ParameterizedType;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-
-/**
- * abstract class of aggregation functions in Beam SQL.
- *
- * <p>There're several constrains for a UDAF:<br>
- * 1. A constructor with an empty argument list is required;<br>
- * 2. The type of {@code InputT} and {@code OutputT} can only be Interger/Long/Short/Byte/Double
- * /Float/Date/BigDecimal, mapping as SQL type INTEGER/BIGINT/SMALLINT/TINYINE/DOUBLE/FLOAT
- * /TIMESTAMP/DECIMAL;<br>
- * 3. Keep intermediate data in {@code AccumT}, and do not rely on elements in class;<br>
- */
-public abstract class BeamSqlUdaf<InputT, AccumT, OutputT> implements Serializable {
- public BeamSqlUdaf(){}
-
- /**
- * create an initial aggregation object, equals to {@link CombineFn#createAccumulator()}.
- */
- public abstract AccumT init();
-
- /**
- * add an input value, equals to {@link CombineFn#addInput(Object, Object)}.
- */
- public abstract AccumT add(AccumT accumulator, InputT input);
-
- /**
- * merge aggregation objects from parallel tasks, equals to
- * {@link CombineFn#mergeAccumulators(Iterable)}.
- */
- public abstract AccumT merge(Iterable<AccumT> accumulators);
-
- /**
- * extract output value from aggregation object, equals to
- * {@link CombineFn#extractOutput(Object)}.
- */
- public abstract OutputT result(AccumT accumulator);
-
- /**
- * get the coder for AccumT which stores the intermediate result.
- * By default it's fetched from {@link CoderRegistry}.
- */
- public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry)
- throws CannotProvideCoderException {
- return registry.getCoder(
- (Class<AccumT>) ((ParameterizedType) getClass()
- .getGenericSuperclass()).getActualTypeArguments()[1]);
- }
-}
http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
index 0552cbf..1541123 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
@@ -21,9 +21,9 @@ import java.sql.Types;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.beam.sdk.extensions.sql.schema.BeamRecordSqlType;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdf;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.BeamRecord;
import org.apache.beam.sdk.values.PCollection;
@@ -49,7 +49,7 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
+ " FROM PCOLLECTION GROUP BY f_int2";
PCollection<BeamRecord> result1 =
boundedInput1.apply("testUdaf1",
- BeamSql.simpleQuery(sql1).withUdaf("squaresum1", SquareSum.class));
+ BeamSql.simpleQuery(sql1).withUdaf("squaresum1", new SquareSum()));
PAssert.that(result1).containsInAnyOrder(record);
String sql2 = "SELECT f_int2, squaresum2(f_int) AS `squaresum`"
@@ -57,7 +57,7 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
PCollection<BeamRecord> result2 =
PCollectionTuple.of(new TupleTag<BeamRecord>("PCOLLECTION"), boundedInput1)
.apply("testUdaf2",
- BeamSql.query(sql2).withUdaf("squaresum2", SquareSum.class));
+ BeamSql.query(sql2).withUdaf("squaresum2", new SquareSum()));
PAssert.that(result2).containsInAnyOrder(record);
pipeline.run().waitUntilFinish();
@@ -90,25 +90,21 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
}
/**
- * UDAF for test, which returns the sum of square.
+ * UDAF(CombineFn) for test, which returns the sum of square.
*/
- public static class SquareSum extends BeamSqlUdaf<Integer, Integer, Integer> {
-
- public SquareSum() {
- }
-
+ public static class SquareSum extends CombineFn<Integer, Integer, Integer> {
@Override
- public Integer init() {
+ public Integer createAccumulator() {
return 0;
}
@Override
- public Integer add(Integer accumulator, Integer input) {
+ public Integer addInput(Integer accumulator, Integer input) {
return accumulator + input * input;
}
@Override
- public Integer merge(Iterable<Integer> accumulators) {
+ public Integer mergeAccumulators(Iterable<Integer> accumulators) {
int v = 0;
Iterator<Integer> ite = accumulators.iterator();
while (ite.hasNext()) {
@@ -118,7 +114,7 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
}
@Override
- public Integer result(Integer accumulator) {
+ public Integer extractOutput(Integer accumulator) {
return accumulator;
}
[2/2] beam git commit: [BEAM-2747] This closes #3716
Posted by ta...@apache.org.
[BEAM-2747] This closes #3716
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/9eec6a03
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/9eec6a03
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/9eec6a03
Branch: refs/heads/DSL_SQL
Commit: 9eec6a030a4d14b504730b57a8d863a5bd97468e
Parents: f37a7a1 1770c86
Author: Tyler Akidau <ta...@apache.org>
Authored: Fri Aug 11 17:12:29 2017 -0700
Committer: Tyler Akidau <ta...@apache.org>
Committed: Fri Aug 11 17:12:29 2017 -0700
----------------------------------------------------------------------
.../apache/beam/sdk/coders/BeamRecordCoder.java | 16 +-
.../apache/beam/sdk/extensions/sql/BeamSql.java | 22 +-
.../beam/sdk/extensions/sql/BeamSqlEnv.java | 11 +-
.../operator/BeamSqlInputRefExpression.java | 4 +
.../sql/impl/interpreter/operator/UdafImpl.java | 87 ++++
.../transform/BeamAggregationTransforms.java | 44 +-
.../impl/transform/BeamBuiltinAggregations.java | 504 +++++++------------
.../sdk/extensions/sql/schema/BeamSqlUdaf.java | 72 ---
.../extensions/sql/BeamSqlDslUdfUdafTest.java | 22 +-
9 files changed, 344 insertions(+), 438 deletions(-)
----------------------------------------------------------------------