You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by dh...@apache.org on 2017/05/13 00:47:12 UTC
[1/2] beam git commit: Support commonly used aggregation functions in
SQL, including: COUNT, SUM, AVG, MAX, MIN
Repository: beam
Updated Branches:
refs/heads/DSL_SQL 6729a027d -> 523482be0
Support commonly used aggregation functions in SQL, including:
COUNT, SUM, AVG, MAX, MIN
rename BeamAggregationTransform to BeamAggregationTransforms
update comments
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/f728fbe5
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/f728fbe5
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/f728fbe5
Branch: refs/heads/DSL_SQL
Commit: f728fbe5c7153341ace046fa8b2465ef8844be1b
Parents: 6729a02
Author: mingmxu <mi...@ebay.com>
Authored: Wed May 10 16:38:13 2017 -0700
Committer: mingmxu <mi...@ebay.com>
Committed: Wed May 10 20:47:40 2017 -0700
----------------------------------------------------------------------
.../interpreter/operator/BeamSqlPrimitive.java | 35 +
.../beam/dsls/sql/rel/BeamAggregationRel.java | 40 +-
.../apache/beam/dsls/sql/schema/BeamSQLRow.java | 4 +
.../beam/dsls/sql/schema/BeamSqlRowCoder.java | 4 +-
.../sql/transform/BeamAggregationTransform.java | 120 ----
.../transform/BeamAggregationTransforms.java | 671 +++++++++++++++++++
.../transform/BeamAggregationTransformTest.java | 436 ++++++++++++
.../schema/transform/BeamTransformBaseTest.java | 96 +++
8 files changed, 1261 insertions(+), 145 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java
index 3309577..a5938f3 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java
@@ -65,6 +65,41 @@ public class BeamSqlPrimitive<T> extends BeamSqlExpression{
return value;
}
+ public long getLong() {
+ return (Long) getValue();
+ }
+
+ public double getDouble() {
+ return (Double) getValue();
+ }
+
+ public float getFloat() {
+ return (Float) getValue();
+ }
+
+ public int getInteger() {
+ return (Integer) getValue();
+ }
+
+ public short getShort() {
+ return (Short) getValue();
+ }
+
+ public byte getByte() {
+ return (Byte) getValue();
+ }
+ public boolean getBoolean() {
+ return (Boolean) getValue();
+ }
+
+ public String getString() {
+ return (String) getValue();
+ }
+
+ public Date getDate() {
+ return (Date) getValue();
+ }
+
@Override
public boolean accept() {
if (value == null) {
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
index 2c7626d..ab98cc4 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
@@ -18,15 +18,13 @@
package org.apache.beam.dsls.sql.rel;
import java.util.List;
-import org.apache.beam.dsls.sql.exception.BeamSqlUnsupportedException;
import org.apache.beam.dsls.sql.planner.BeamPipelineCreator;
import org.apache.beam.dsls.sql.planner.BeamSQLRelUtils;
import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
import org.apache.beam.dsls.sql.schema.BeamSQLRow;
-import org.apache.beam.dsls.sql.transform.BeamAggregationTransform;
+import org.apache.beam.dsls.sql.transform.BeamAggregationTransforms;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.WithKeys;
@@ -79,7 +77,7 @@ public class BeamAggregationRel extends Aggregate implements BeamRelNode {
PCollection<BeamSQLRow> upstream = planCreator.popUpstream();
if (windowFieldIdx != -1) {
upstream = upstream.apply("assignEventTimestamp", WithTimestamps
- .<BeamSQLRow>of(new BeamAggregationTransform.WindowTimestampFn(windowFieldIdx)));
+ .<BeamSQLRow>of(new BeamAggregationTransforms.WindowTimestampFn(windowFieldIdx)));
}
PCollection<BeamSQLRow> windowStream = upstream.apply("window",
@@ -88,32 +86,26 @@ public class BeamAggregationRel extends Aggregate implements BeamRelNode {
.withAllowedLateness(allowedLatence)
.accumulatingFiredPanes());
+ //1. extract fields in group-by key part
PCollection<KV<BeamSQLRow, BeamSQLRow>> exGroupByStream = windowStream.apply("exGroupBy",
WithKeys
- .of(new BeamAggregationTransform.AggregationGroupByKeyFn(windowFieldIdx, groupSet)));
+ .of(new BeamAggregationTransforms.AggregationGroupByKeyFn(windowFieldIdx, groupSet)));
+ //2. apply a GroupByKey.
PCollection<KV<BeamSQLRow, Iterable<BeamSQLRow>>> groupedStream = exGroupByStream
.apply("groupBy", GroupByKey.<BeamSQLRow, BeamSQLRow>create());
- if (aggCalls.size() > 1) {
- throw new BeamSqlUnsupportedException("only single aggregation is supported now.");
- }
-
- AggregateCall aggCall = aggCalls.get(0);
- switch (aggCall.getAggregation().getName()) {
- case "COUNT":
- PCollection<KV<BeamSQLRow, Long>> aggregatedStream = groupedStream.apply("count",
- Combine.<BeamSQLRow, BeamSQLRow, Long>groupedValues(Count.combineFn()));
- PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord",
- ParDo.of(new BeamAggregationTransform.MergeAggregationRecord(
- BeamSQLRecordType.from(getRowType()), aggCall.getName())));
- planCreator.pushUpstream(mergedStream);
- break;
- default:
- //Only support COUNT now, more are added in BEAM-2008
- throw new BeamSqlUnsupportedException(
- String.format("Unsupported aggregation [%s]", aggCall.getAggregation().getName()));
- }
+ //3. run aggregation functions
+ PCollection<KV<BeamSQLRow, BeamSQLRow>> aggregatedStream = groupedStream.apply("aggregation",
+ Combine.<BeamSQLRow, BeamSQLRow, BeamSQLRow>groupedValues(
+ new BeamAggregationTransforms.AggregationCombineFn(getAggCallList(),
+ BeamSQLRecordType.from(input.getRowType()))));
+
+ //4. flat KV to a single record
+ PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord",
+ ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord(
+ BeamSQLRecordType.from(getRowType()), getAggCallList())));
+ planCreator.pushUpstream(mergedStream);
return planCreator.getPipeline();
}
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java
index 65f4a41..5bdd5d2 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java
@@ -144,6 +144,10 @@ public class BeamSQLRow implements Serializable {
dataValues.set(index, fieldValue);
}
+ public byte getByte(int idx) {
+ return (Byte) getFieldValue(idx);
+ }
+
public short getShort(int idx) {
return (Short) getFieldValue(idx);
}
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java
index 3100ba5..0accb9a 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java
@@ -70,9 +70,11 @@ public class BeamSqlRowCoder extends StandardCoder<BeamSQLRow>{
intCoder.encode(value.getInteger(idx), outStream, context.nested());
break;
case SMALLINT:
- case TINYINT:
intCoder.encode((int) value.getShort(idx), outStream, context.nested());
break;
+ case TINYINT:
+ intCoder.encode((int) value.getByte(idx), outStream, context.nested());
+ break;
case DOUBLE:
doubleCoder.encode(value.getDouble(idx), outStream, context.nested());
break;
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java
deleted file mode 100644
index f478363..0000000
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.dsls.sql.transform;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
-import org.apache.beam.dsls.sql.schema.BeamSQLRow;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.values.KV;
-import org.apache.calcite.util.ImmutableBitSet;
-import org.joda.time.Instant;
-
-/**
- * Collections of {@code PTransform} and {@code DoFn} used to perform GROUP-BY operation.
- */
-public class BeamAggregationTransform implements Serializable{
- /**
- * Merge KV to single record.
- */
- public static class MergeAggregationRecord extends DoFn<KV<BeamSQLRow, Long>, BeamSQLRow> {
- private BeamSQLRecordType outRecordType;
- private String aggFieldName;
-
- public MergeAggregationRecord(BeamSQLRecordType outRecordType, String aggFieldName) {
- this.outRecordType = outRecordType;
- this.aggFieldName = aggFieldName;
- }
-
- @ProcessElement
- public void processElement(ProcessContext c, BoundedWindow window) {
- BeamSQLRow outRecord = new BeamSQLRow(outRecordType);
- outRecord.updateWindowRange(c.element().getKey(), window);
-
- KV<BeamSQLRow, Long> kvRecord = c.element();
- for (String f : kvRecord.getKey().getDataType().getFieldsName()) {
- outRecord.addField(f, kvRecord.getKey().getFieldValue(f));
- }
- outRecord.addField(aggFieldName, kvRecord.getValue());
-
-// if (c.pane().isLast()) {
- c.output(outRecord);
-// }
- }
- }
-
- /**
- * extract group-by fields.
- */
- public static class AggregationGroupByKeyFn
- implements SerializableFunction<BeamSQLRow, BeamSQLRow> {
- private List<Integer> groupByKeys;
-
- public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) {
- this.groupByKeys = new ArrayList<>();
- for (int i : groupSet.asList()) {
- if (i != windowFieldIdx) {
- groupByKeys.add(i);
- }
- }
- }
-
- @Override
- public BeamSQLRow apply(BeamSQLRow input) {
- BeamSQLRecordType typeOfKey = exTypeOfKeyRecord(input.getDataType());
- BeamSQLRow keyOfRecord = new BeamSQLRow(typeOfKey);
- keyOfRecord.updateWindowRange(input, null);
-
- for (int idx = 0; idx < groupByKeys.size(); ++idx) {
- keyOfRecord.addField(idx, input.getFieldValue(groupByKeys.get(idx)));
- }
- return keyOfRecord;
- }
-
- private BeamSQLRecordType exTypeOfKeyRecord(BeamSQLRecordType dataType) {
- BeamSQLRecordType typeOfKey = new BeamSQLRecordType();
- for (int idx : groupByKeys) {
- typeOfKey.addField(dataType.getFieldsName().get(idx), dataType.getFieldsType().get(idx));
- }
- return typeOfKey;
- }
-
- }
-
- /**
- * Assign event timestamp.
- */
- public static class WindowTimestampFn implements SerializableFunction<BeamSQLRow, Instant> {
- private int windowFieldIdx = -1;
-
- public WindowTimestampFn(int windowFieldIdx) {
- super();
- this.windowFieldIdx = windowFieldIdx;
- }
-
- @Override
- public Instant apply(BeamSQLRow input) {
- return new Instant(input.getDate(windowFieldIdx).getTime());
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
new file mode 100644
index 0000000..943c897
--- /dev/null
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
@@ -0,0 +1,671 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql.transform;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.List;
+import org.apache.beam.dsls.sql.exception.BeamSqlUnsupportedException;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlExpression;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlInputRefExpression;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlPrimitive;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSQLRow;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.joda.time.Instant;
+
+/**
+ * Collections of {@code PTransform} and {@code DoFn} used to perform GROUP-BY operation.
+ */
+public class BeamAggregationTransforms implements Serializable{
+ /**
+ * Merge KV to single record.
+ */
+ public static class MergeAggregationRecord extends DoFn<KV<BeamSQLRow, BeamSQLRow>, BeamSQLRow> {
+ private BeamSQLRecordType outRecordType;
+ private List<String> aggFieldNames;
+
+ public MergeAggregationRecord(BeamSQLRecordType outRecordType, List<AggregateCall> aggList) {
+ this.outRecordType = outRecordType;
+ this.aggFieldNames = new ArrayList<>();
+ for (AggregateCall ac : aggList) {
+ aggFieldNames.add(ac.getName());
+ }
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext c, BoundedWindow window) {
+ BeamSQLRow outRecord = new BeamSQLRow(outRecordType);
+ outRecord.updateWindowRange(c.element().getKey(), window);
+
+ KV<BeamSQLRow, BeamSQLRow> kvRecord = c.element();
+ for (String f : kvRecord.getKey().getDataType().getFieldsName()) {
+ outRecord.addField(f, kvRecord.getKey().getFieldValue(f));
+ }
+ for (int idx = 0; idx < aggFieldNames.size(); ++idx) {
+ outRecord.addField(aggFieldNames.get(idx), kvRecord.getValue().getFieldValue(idx));
+ }
+
+ // if (c.pane().isLast()) {
+ c.output(outRecord);
+ // }
+ }
+ }
+
+ /**
+ * extract group-by fields.
+ */
+ public static class AggregationGroupByKeyFn
+ implements SerializableFunction<BeamSQLRow, BeamSQLRow> {
+ private List<Integer> groupByKeys;
+
+ public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) {
+ this.groupByKeys = new ArrayList<>();
+ for (int i : groupSet.asList()) {
+ if (i != windowFieldIdx) {
+ groupByKeys.add(i);
+ }
+ }
+ }
+
+ @Override
+ public BeamSQLRow apply(BeamSQLRow input) {
+ BeamSQLRecordType typeOfKey = exTypeOfKeyRecord(input.getDataType());
+ BeamSQLRow keyOfRecord = new BeamSQLRow(typeOfKey);
+ keyOfRecord.updateWindowRange(input, null);
+
+ for (int idx = 0; idx < groupByKeys.size(); ++idx) {
+ keyOfRecord.addField(idx, input.getFieldValue(groupByKeys.get(idx)));
+ }
+ return keyOfRecord;
+ }
+
+ private BeamSQLRecordType exTypeOfKeyRecord(BeamSQLRecordType dataType) {
+ BeamSQLRecordType typeOfKey = new BeamSQLRecordType();
+ for (int idx : groupByKeys) {
+ typeOfKey.addField(dataType.getFieldsName().get(idx), dataType.getFieldsType().get(idx));
+ }
+ return typeOfKey;
+ }
+
+ }
+
+ /**
+ * Assign event timestamp.
+ */
+ public static class WindowTimestampFn implements SerializableFunction<BeamSQLRow, Instant> {
+ private int windowFieldIdx = -1;
+
+ public WindowTimestampFn(int windowFieldIdx) {
+ super();
+ this.windowFieldIdx = windowFieldIdx;
+ }
+
+ @Override
+ public Instant apply(BeamSQLRow input) {
+ return new Instant(input.getDate(windowFieldIdx).getTime());
+ }
+ }
+
+ /**
+ * Aggregation function which supports COUNT, MAX, MIN, SUM, AVG.
+ *
+ * <p>Multiple aggregation functions are combined together.
+ * For each aggregation function, it may accept part of all data types:<br>
+ * 1). COUNT works for any data type;<br>
+ * 2). MAX/MIN works for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, TINYINT, TIMESTAMP;<br>
+ * 3). SUM/AVG works for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, TINYINT;<br>
+ *
+ */
+ public static class AggregationCombineFn extends CombineFn<BeamSQLRow, BeamSQLRow, BeamSQLRow> {
+ private BeamSQLRecordType aggDataType;
+
+ private int countIndex = -1;
+
+ List<String> aggFunctions;
+ List<BeamSqlExpression> aggElementExpressions;
+
+ public AggregationCombineFn(List<AggregateCall> aggregationCalls,
+ BeamSQLRecordType sourceRowRecordType) {
+ this.aggDataType = new BeamSQLRecordType();
+ this.aggFunctions = new ArrayList<>();
+ this.aggElementExpressions = new ArrayList<>();
+
+ boolean hasAvg = false;
+ boolean hasCount = false;
+ int countIndex = -1;
+ for (int idx = 0; idx < aggregationCalls.size(); ++idx) {
+ AggregateCall ac = aggregationCalls.get(idx);
+ //verify it's supported.
+ verifySupportedAggregation(ac);
+
+ aggDataType.addField(ac.name, ac.type.getSqlTypeName());
+
+ SqlAggFunction aggFn = ac.getAggregation();
+ switch (aggFn.getName()) {
+ case "COUNT":
+ aggElementExpressions.add(BeamSqlPrimitive.<Long>of(SqlTypeName.BIGINT, 1L));
+ hasCount = true;
+ countIndex = idx;
+ break;
+ case "SUM":
+ case "MAX":
+ case "MIN":
+ case "AVG":
+ int refIndex = ac.getArgList().get(0);
+ aggElementExpressions.add(new BeamSqlInputRefExpression(
+ sourceRowRecordType.getFieldsType().get(refIndex), refIndex));
+ if ("AVG".equals(aggFn.getName())) {
+ hasAvg = true;
+ }
+ break;
+
+ default:
+ break;
+ }
+ aggFunctions.add(aggFn.getName());
+ }
+ // add a COUNT holder if only have AVG
+ if (hasAvg && !hasCount) {
+ aggDataType.addField("__COUNT", SqlTypeName.BIGINT);
+
+ aggFunctions.add("COUNT");
+ aggElementExpressions.add(BeamSqlPrimitive.<Long>of(SqlTypeName.BIGINT, 1L));
+
+ hasCount = true;
+ countIndex = aggDataType.size() - 1;
+ }
+
+ this.countIndex = countIndex;
+ }
+
+ private void verifySupportedAggregation(AggregateCall ac) {
+ //donot support DISTINCT
+ if (ac.isDistinct()) {
+ throw new BeamSqlUnsupportedException("DISTINCT is not supported yet.");
+ }
+ String aggFnName = ac.getAggregation().getName();
+ switch (aggFnName) {
+ case "COUNT":
+ //COUNT works for any data type;
+ break;
+ case "SUM":
+ // SUM only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT,
+ // TINYINT now
+ if (!Arrays
+ .asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE,
+ SqlTypeName.SMALLINT, SqlTypeName.TINYINT)
+ .contains(ac.type.getSqlTypeName())) {
+ throw new BeamSqlUnsupportedException(
+ "SUM only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT");
+ }
+ break;
+ case "MAX":
+ case "MIN":
+ // MAX/MIN only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT,
+ // TINYINT, TIMESTAMP now
+ if (!Arrays.asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT,
+ SqlTypeName.DOUBLE, SqlTypeName.SMALLINT, SqlTypeName.TINYINT,
+ SqlTypeName.TIMESTAMP).contains(ac.type.getSqlTypeName())) {
+ throw new BeamSqlUnsupportedException("MAX/MIN only support for INT, LONG, FLOAT,"
+ + " DOUBLE, SMALLINT, TINYINT, TIMESTAMP");
+ }
+ break;
+ case "AVG":
+ // AVG only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT,
+ // TINYINT now
+ if (!Arrays
+ .asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE,
+ SqlTypeName.SMALLINT, SqlTypeName.TINYINT)
+ .contains(ac.type.getSqlTypeName())) {
+ throw new BeamSqlUnsupportedException(
+ "AVG only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT");
+ }
+ break;
+ default:
+ throw new BeamSqlUnsupportedException(
+ String.format("[%s] is not supported.", aggFnName));
+ }
+ }
+
+ @Override
+ public BeamSQLRow createAccumulator() {
+ BeamSQLRow initialRecord = new BeamSQLRow(aggDataType);
+ for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+ BeamSqlExpression ex = aggElementExpressions.get(idx);
+ String aggFnName = aggFunctions.get(idx);
+ switch (aggFnName) {
+ case "COUNT":
+ initialRecord.addField(idx, 0L);
+ break;
+ case "AVG":
+ case "SUM":
+ //for both AVG/SUM, a summary value is hold at first.
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ initialRecord.addField(idx, 0);
+ break;
+ case BIGINT:
+ initialRecord.addField(idx, 0L);
+ break;
+ case SMALLINT:
+ initialRecord.addField(idx, (short) 0);
+ break;
+ case TINYINT:
+ initialRecord.addField(idx, (byte) 0);
+ break;
+ case FLOAT:
+ initialRecord.addField(idx, 0.0f);
+ break;
+ case DOUBLE:
+ initialRecord.addField(idx, 0.0);
+ break;
+ default:
+ break;
+ }
+ break;
+ case "MAX":
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ initialRecord.addField(idx, Integer.MIN_VALUE);
+ break;
+ case BIGINT:
+ initialRecord.addField(idx, Long.MIN_VALUE);
+ break;
+ case SMALLINT:
+ initialRecord.addField(idx, Short.MIN_VALUE);
+ break;
+ case TINYINT:
+ initialRecord.addField(idx, Byte.MIN_VALUE);
+ break;
+ case FLOAT:
+ initialRecord.addField(idx, Float.MIN_VALUE);
+ break;
+ case DOUBLE:
+ initialRecord.addField(idx, Double.MIN_VALUE);
+ break;
+ case TIMESTAMP:
+ initialRecord.addField(idx, new Date(0));
+ break;
+ default:
+ break;
+ }
+ break;
+ case "MIN":
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ initialRecord.addField(idx, Integer.MAX_VALUE);
+ break;
+ case BIGINT:
+ initialRecord.addField(idx, Long.MAX_VALUE);
+ break;
+ case SMALLINT:
+ initialRecord.addField(idx, Short.MAX_VALUE);
+ break;
+ case TINYINT:
+ initialRecord.addField(idx, Byte.MAX_VALUE);
+ break;
+ case FLOAT:
+ initialRecord.addField(idx, Float.MAX_VALUE);
+ break;
+ case DOUBLE:
+ initialRecord.addField(idx, Double.MAX_VALUE);
+ break;
+ case TIMESTAMP:
+ initialRecord.addField(idx, new Date(Long.MAX_VALUE));
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ return initialRecord;
+ }
+
+ @Override
+ public BeamSQLRow addInput(BeamSQLRow accumulator, BeamSQLRow input) {
+ BeamSQLRow deltaRecord = new BeamSQLRow(aggDataType);
+ for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+ BeamSqlExpression ex = aggElementExpressions.get(idx);
+ String aggFnName = aggFunctions.get(idx);
+ switch (aggFnName) {
+ case "COUNT":
+ deltaRecord.addField(idx, 1 + accumulator.getLong(idx));
+ break;
+ case "AVG":
+ case "SUM":
+ // for both AVG/SUM, a summary value is hold at first.
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ deltaRecord.addField(idx,
+ ex.evaluate(input).getInteger() + accumulator.getInteger(idx));
+ break;
+ case BIGINT:
+ deltaRecord.addField(idx, ex.evaluate(input).getLong() + accumulator.getLong(idx));
+ break;
+ case SMALLINT:
+ deltaRecord.addField(idx,
+ (short) (ex.evaluate(input).getShort() + accumulator.getShort(idx)));
+ break;
+ case TINYINT:
+ deltaRecord.addField(idx,
+ (byte) (ex.evaluate(input).getByte() + accumulator.getByte(idx)));
+ break;
+ case FLOAT:
+ deltaRecord.addField(idx,
+ (float) (ex.evaluate(input).getFloat() + accumulator.getFloat(idx)));
+ break;
+ case DOUBLE:
+ deltaRecord.addField(idx, ex.evaluate(input).getDouble() + accumulator.getDouble(idx));
+ break;
+ default:
+ break;
+ }
+ break;
+ case "MAX":
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ deltaRecord.addField(idx,
+ Math.max(ex.evaluate(input).getInteger(), accumulator.getInteger(idx)));
+ break;
+ case BIGINT:
+ deltaRecord.addField(idx,
+ Math.max(ex.evaluate(input).getLong(), accumulator.getLong(idx)));
+ break;
+ case SMALLINT:
+ deltaRecord.addField(idx,
+ (short) Math.max(ex.evaluate(input).getShort(), accumulator.getShort(idx)));
+ break;
+ case TINYINT:
+ deltaRecord.addField(idx,
+ (byte) Math.max(ex.evaluate(input).getByte(), accumulator.getByte(idx)));
+ break;
+ case FLOAT:
+ deltaRecord.addField(idx,
+ Math.max(ex.evaluate(input).getFloat(), accumulator.getFloat(idx)));
+ break;
+ case DOUBLE:
+ deltaRecord.addField(idx,
+ Math.max(ex.evaluate(input).getDouble(), accumulator.getDouble(idx)));
+ break;
+ case TIMESTAMP:
+ Date preDate = accumulator.getDate(idx);
+ Date nowDate = ex.evaluate(input).getDate();
+ deltaRecord.addField(idx, preDate.getTime() > nowDate.getTime() ? preDate : nowDate);
+ break;
+ default:
+ break;
+ }
+ break;
+ case "MIN":
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ deltaRecord.addField(idx,
+ Math.min(ex.evaluate(input).getInteger(), accumulator.getInteger(idx)));
+ break;
+ case BIGINT:
+ deltaRecord.addField(idx,
+ Math.min(ex.evaluate(input).getLong(), accumulator.getLong(idx)));
+ break;
+ case SMALLINT:
+ deltaRecord.addField(idx,
+ (short) Math.min(ex.evaluate(input).getShort(), accumulator.getShort(idx)));
+ break;
+ case TINYINT:
+ deltaRecord.addField(idx,
+ (byte) Math.min(ex.evaluate(input).getByte(), accumulator.getByte(idx)));
+ break;
+ case FLOAT:
+ deltaRecord.addField(idx,
+ Math.min(ex.evaluate(input).getFloat(), accumulator.getFloat(idx)));
+ break;
+ case DOUBLE:
+ deltaRecord.addField(idx,
+ Math.min(ex.evaluate(input).getDouble(), accumulator.getDouble(idx)));
+ break;
+ case TIMESTAMP:
+ Date preDate = accumulator.getDate(idx);
+ Date nowDate = ex.evaluate(input).getDate();
+ deltaRecord.addField(idx, preDate.getTime() < nowDate.getTime() ? preDate : nowDate);
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ return deltaRecord;
+ }
+
+ @Override
+ public BeamSQLRow mergeAccumulators(Iterable<BeamSQLRow> accumulators) {
+ BeamSQLRow deltaRecord = new BeamSQLRow(aggDataType);
+
+ while (accumulators.iterator().hasNext()) {
+ BeamSQLRow sa = accumulators.iterator().next();
+ for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+ BeamSqlExpression ex = aggElementExpressions.get(idx);
+ String aggFnName = aggFunctions.get(idx);
+ switch (aggFnName) {
+ case "COUNT":
+ deltaRecord.addField(idx, deltaRecord.getLong(idx) + sa.getLong(idx));
+ break;
+ case "AVG":
+ case "SUM":
+ // for both AVG/SUM, a summary value is hold at first.
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ deltaRecord.addField(idx, deltaRecord.getInteger(idx) + sa.getInteger(idx));
+ break;
+ case BIGINT:
+ deltaRecord.addField(idx, deltaRecord.getLong(idx) + sa.getLong(idx));
+ break;
+ case SMALLINT:
+ deltaRecord.addField(idx, (short) (deltaRecord.getShort(idx) + sa.getShort(idx)));
+ break;
+ case TINYINT:
+ deltaRecord.addField(idx, (byte) (deltaRecord.getByte(idx) + sa.getByte(idx)));
+ break;
+ case FLOAT:
+ deltaRecord.addField(idx, (float) (deltaRecord.getFloat(idx) + sa.getFloat(idx)));
+ break;
+ case DOUBLE:
+ deltaRecord.addField(idx, deltaRecord.getDouble(idx) + sa.getDouble(idx));
+ break;
+ default:
+ break;
+ }
+ break;
+ case "MAX":
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ deltaRecord.addField(idx, Math.max(deltaRecord.getInteger(idx), sa.getInteger(idx)));
+ break;
+ case BIGINT:
+ deltaRecord.addField(idx, Math.max(deltaRecord.getLong(idx), sa.getLong(idx)));
+ break;
+ case SMALLINT:
+ deltaRecord.addField(idx,
+ (short) Math.max(deltaRecord.getShort(idx), sa.getShort(idx)));
+ break;
+ case TINYINT:
+ deltaRecord.addField(idx, (byte) Math.max(deltaRecord.getByte(idx), sa.getByte(idx)));
+ break;
+ case FLOAT:
+ deltaRecord.addField(idx, Math.max(deltaRecord.getFloat(idx), sa.getFloat(idx)));
+ break;
+ case DOUBLE:
+ deltaRecord.addField(idx, Math.max(deltaRecord.getDouble(idx), sa.getDouble(idx)));
+ break;
+ case TIMESTAMP:
+ Date preDate = deltaRecord.getDate(idx);
+ Date nowDate = sa.getDate(idx);
+ deltaRecord.addField(idx, preDate.getTime() > nowDate.getTime() ? preDate : nowDate);
+ break;
+ default:
+ break;
+ }
+ break;
+ case "MIN":
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ deltaRecord.addField(idx, Math.min(deltaRecord.getInteger(idx), sa.getInteger(idx)));
+ break;
+ case BIGINT:
+ deltaRecord.addField(idx, Math.min(deltaRecord.getLong(idx), sa.getLong(idx)));
+ break;
+ case SMALLINT:
+ deltaRecord.addField(idx,
+ (short) Math.min(deltaRecord.getShort(idx), sa.getShort(idx)));
+ break;
+ case TINYINT:
+ deltaRecord.addField(idx, (byte) Math.min(deltaRecord.getByte(idx), sa.getByte(idx)));
+ break;
+ case FLOAT:
+ deltaRecord.addField(idx, Math.min(deltaRecord.getFloat(idx), sa.getFloat(idx)));
+ break;
+ case DOUBLE:
+ deltaRecord.addField(idx, Math.min(deltaRecord.getDouble(idx), sa.getDouble(idx)));
+ break;
+ case TIMESTAMP:
+ Date preDate = deltaRecord.getDate(idx);
+ Date nowDate = sa.getDate(idx);
+ deltaRecord.addField(idx, preDate.getTime() < nowDate.getTime() ? preDate : nowDate);
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ return deltaRecord;
+ }
+
+ @Override
+ public BeamSQLRow extractOutput(BeamSQLRow accumulator) {
+ BeamSQLRow finalRecord = new BeamSQLRow(aggDataType);
+ for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+ BeamSqlExpression ex = aggElementExpressions.get(idx);
+ String aggFnName = aggFunctions.get(idx);
+ switch (aggFnName) {
+ case "COUNT":
+ finalRecord.addField(idx, accumulator.getLong(idx));
+ break;
+ case "AVG":
+ long count = accumulator.getLong(countIndex);
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ finalRecord.addField(idx, (int) (accumulator.getInteger(idx) / count));
+ break;
+ case BIGINT:
+ finalRecord.addField(idx, accumulator.getLong(idx) / count);
+ break;
+ case SMALLINT:
+ finalRecord.addField(idx, (short) (accumulator.getShort(idx) / count));
+ break;
+ case TINYINT:
+ finalRecord.addField(idx, (byte) (accumulator.getByte(idx) / count));
+ break;
+ case FLOAT:
+ finalRecord.addField(idx, (float) (accumulator.getFloat(idx) / count));
+ break;
+ case DOUBLE:
+ finalRecord.addField(idx, accumulator.getDouble(idx) / count);
+ break;
+ default:
+ break;
+ }
+ break;
+ case "SUM":
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ finalRecord.addField(idx, accumulator.getInteger(idx));
+ break;
+ case BIGINT:
+ finalRecord.addField(idx, accumulator.getLong(idx));
+ break;
+ case SMALLINT:
+ finalRecord.addField(idx, accumulator.getShort(idx));
+ break;
+ case TINYINT:
+ finalRecord.addField(idx, accumulator.getByte(idx));
+ break;
+ case FLOAT:
+ finalRecord.addField(idx, accumulator.getFloat(idx));
+ break;
+ case DOUBLE:
+ finalRecord.addField(idx, accumulator.getDouble(idx));
+ break;
+ default:
+ break;
+ }
+ break;
+ case "MAX":
+ case "MIN":
+ switch (ex.getOutputType()) {
+ case INTEGER:
+ finalRecord.addField(idx, accumulator.getInteger(idx));
+ break;
+ case BIGINT:
+ finalRecord.addField(idx, accumulator.getLong(idx));
+ break;
+ case SMALLINT:
+ finalRecord.addField(idx, accumulator.getShort(idx));
+ break;
+ case TINYINT:
+ finalRecord.addField(idx, accumulator.getByte(idx));
+ break;
+ case FLOAT:
+ finalRecord.addField(idx, accumulator.getFloat(idx));
+ break;
+ case DOUBLE:
+ finalRecord.addField(idx, accumulator.getDouble(idx));
+ break;
+ case TIMESTAMP:
+ finalRecord.addField(idx, accumulator.getDate(idx));
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ return finalRecord;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
new file mode 100644
index 0000000..f174b9c
--- /dev/null
+++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
@@ -0,0 +1,436 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql.schema.transform;
+
+import java.text.ParseException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.beam.dsls.sql.planner.BeamQueryPlanner;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordTypeCoder;
+import org.apache.beam.dsls.sql.schema.BeamSQLRow;
+import org.apache.beam.dsls.sql.schema.BeamSqlRowCoder;
+import org.apache.beam.dsls.sql.transform.BeamAggregationTransforms;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder;
+import org.apache.calcite.rel.type.RelDataTypeSystem;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlAvgAggFunction;
+import org.apache.calcite.sql.fun.SqlCountAggFunction;
+import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
+import org.apache.calcite.sql.fun.SqlSumAggFunction;
+import org.apache.calcite.sql.type.BasicSqlType;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Unit tests for {@link BeamAggregationTransforms}, covering the group-by key
+ * extraction, the combine function, and the final record merge.
+ */
+public class BeamAggregationTransformTest extends BeamTransformBaseTest{
+
+ @Rule
+ public TestPipeline p = TestPipeline.create();
+
+ // Populated by prepareAggregationCalls() before the pipeline is built.
+ private List<AggregateCall> aggCalls;
+ // Row type of the group-by key: the single column f_int.
+ private BeamSQLRecordType keyType = initTypeOfSqlRow(
+ Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER)));
+
+ /**
+ * The transform chain under test is equivalent to the query below.
+ * <pre>
+ * SELECT `f_int`
+ * , COUNT(*) AS `size`
+ * , SUM(`f_long`) AS `sum1`, AVG(`f_long`) AS `avg1`
+ * , MAX(`f_long`) AS `max1`, MIN(`f_long`) AS `min1`
+ * , SUM(`f_short`) AS `sum2`, AVG(`f_short`) AS `avg2`
+ * , MAX(`f_short`) AS `max2`, MIN(`f_short`) AS `min2`
+ * , SUM(`f_byte`) AS `sum3`, AVG(`f_byte`) AS `avg3`
+ * , MAX(`f_byte`) AS `max3`, MIN(`f_byte`) AS `min3`
+ * , SUM(`f_float`) AS `sum4`, AVG(`f_float`) AS `avg4`
+ * , MAX(`f_float`) AS `max4`, MIN(`f_float`) AS `min4`
+ * , SUM(`f_double`) AS `sum5`, AVG(`f_double`) AS `avg5`
+ * , MAX(`f_double`) AS `max5`, MIN(`f_double`) AS `min5`
+ * , MAX(`f_timestamp`) AS `max7`, MIN(`f_timestamp`) AS `min7`
+ * ,SUM(`f_int2`) AS `sum8`, AVG(`f_int2`) AS `avg8`
+ * , MAX(`f_int2`) AS `max8`, MIN(`f_int2`) AS `min8`
+ * FROM TABLE_NAME
+ * GROUP BY `f_int`
+ * </pre>
+ * <p>NOTE(review): no aggregation is applied on column 6 (f_string);
+ * presumably VARCHAR aggregations are out of scope here — confirm.
+ * @throws ParseException
+ */
+ @Test
+ public void testCountPerElementBasic() throws ParseException {
+ setupEnvironment();
+
+ PCollection<BeamSQLRow> input = p.apply(Create.of(inputRows));
+
+ //1. extract fields in group-by key part
+ PCollection<KV<BeamSQLRow, BeamSQLRow>> exGroupByStream = input.apply("exGroupBy",
+ WithKeys
+ .of(new BeamAggregationTransforms.AggregationGroupByKeyFn(-1, ImmutableBitSet.of(0))));
+
+ //2. apply a GroupByKey.
+ PCollection<KV<BeamSQLRow, Iterable<BeamSQLRow>>> groupedStream = exGroupByStream
+ .apply("groupBy", GroupByKey.<BeamSQLRow, BeamSQLRow>create());
+
+ //3. run aggregation functions
+ PCollection<KV<BeamSQLRow, BeamSQLRow>> aggregatedStream = groupedStream.apply("aggregation",
+ Combine.<BeamSQLRow, BeamSQLRow, BeamSQLRow>groupedValues(
+ new BeamAggregationTransforms.AggregationCombineFn(aggCalls, inputRowType)));
+
+ //4. flat KV to a single record
+ PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord",
+ ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord(
+ BeamSQLRecordType.from(prepareFinalRowType()), aggCalls)));
+
+ //assert function BeamAggregationTransform.AggregationGroupByKeyFn
+ PAssert.that(exGroupByStream).containsInAnyOrder(prepareResultOfAggregationGroupByKeyFn());
+
+ //assert BeamAggregationTransform.AggregationCombineFn
+ PAssert.that(aggregatedStream).containsInAnyOrder(prepareResultOfAggregationCombineFn());
+
+ //assert BeamAggregationTransform.MergeAggregationRecord
+ PAssert.that(mergedStream).containsInAnyOrder(prepareResultOfMergeAggregationRecord());
+
+ p.run();
+}
+
+ // Registers coders and builds the list of aggregation calls used above.
+ private void setupEnvironment() {
+ regiesterCoder();
+ prepareAggregationCalls();
+ }
+
+ /**
+ * Registers the coders for BeamSQL rows and record types with the pipeline.
+ * NOTE(review): method name has a typo ("regiester" -> "register").
+ */
+ private void regiesterCoder() {
+ CoderRegistry cr = p.getCoderRegistry();
+ cr.registerCoder(BeamSQLRow.class, BeamSqlRowCoder.of());
+ cr.registerCoder(BeamSQLRecordType.class, BeamSQLRecordTypeCoder.of());
+ }
+
+ /**
+ * create list of all {@link AggregateCall}: COUNT, then SUM/AVG/MAX/MIN per
+ * numeric column plus MAX/MIN on the TIMESTAMP column. The deprecated
+ * AggregateCall constructor is used, hence the suppression below.
+ */
+ @SuppressWarnings("deprecation")
+ private void prepareAggregationCalls() {
+ //aggregations for all data type
+ aggCalls = new ArrayList<>();
+ aggCalls.add(
+ new AggregateCall(new SqlCountAggFunction(), false,
+ Arrays.<Integer>asList(),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "count")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlSumAggFunction(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT)), false,
+ Arrays.<Integer>asList(1),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "sum1")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(1),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "avg1")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(1),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "max1")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(1),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "min1")
+ );
+
+ aggCalls.add(
+ new AggregateCall(new SqlSumAggFunction(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT)), false,
+ Arrays.<Integer>asList(2),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT),
+ "sum2")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(2),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT),
+ "avg2")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(2),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT),
+ "max2")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(2),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT),
+ "min2")
+ );
+
+ aggCalls.add(
+ new AggregateCall(
+ new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT)),
+ false,
+ Arrays.<Integer>asList(3),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT),
+ "sum3")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(3),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT),
+ "avg3")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(3),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT),
+ "max3")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(3),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT),
+ "min3")
+ );
+
+ aggCalls.add(
+ new AggregateCall(
+ new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT)),
+ false,
+ Arrays.<Integer>asList(4),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT),
+ "sum4")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(4),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT),
+ "avg4")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(4),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT),
+ "max4")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(4),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT),
+ "min4")
+ );
+
+ aggCalls.add(
+ new AggregateCall(
+ new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE)),
+ false,
+ Arrays.<Integer>asList(5),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE),
+ "sum5")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(5),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE),
+ "avg5")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(5),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE),
+ "max5")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(5),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE),
+ "min5")
+ );
+
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(7),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TIMESTAMP),
+ "max7")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(7),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TIMESTAMP),
+ "min7")
+ );
+
+ aggCalls.add(
+ new AggregateCall(
+ new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER)),
+ false,
+ Arrays.<Integer>asList(8),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER),
+ "sum8")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(8),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER),
+ "avg8")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(8),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER),
+ "max8")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(8),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER),
+ "min8")
+ );
+ }
+
+ /**
+ * expected results after {@link BeamAggregationTransforms.AggregationGroupByKeyFn}:
+ * each input row keyed by its f_int value (column 0).
+ */
+ private List<KV<BeamSQLRow, BeamSQLRow>> prepareResultOfAggregationGroupByKeyFn() {
+ return Arrays.asList(
+ KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(0).getInteger(0))),
+ inputRows.get(0)),
+ KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(1).getInteger(0))),
+ inputRows.get(1)),
+ KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(2).getInteger(0))),
+ inputRows.get(2)),
+ KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(3).getInteger(0))),
+ inputRows.get(3)));
+ }
+
+ /**
+ * expected results after {@link BeamAggregationTransforms.AggregationCombineFn}.
+ */
+ private List<KV<BeamSQLRow, BeamSQLRow>> prepareResultOfAggregationCombineFn()
+ throws ParseException {
+ BeamSQLRecordType aggPartType = initTypeOfSqlRow(
+ Arrays.asList(KV.of("count", SqlTypeName.BIGINT),
+
+ KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT),
+ KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT),
+
+ KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT),
+ KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT),
+
+ KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT),
+ KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT),
+
+ KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT),
+ KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT),
+
+ KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE),
+ KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE),
+
+ KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP),
+
+ KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER),
+ KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER)
+ ));
+ // All four input rows share key f_int=1, so a single aggregated row is expected.
+ return Arrays.asList(
+ KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(0).getInteger(0))),
+ new BeamSQLRow(aggPartType, Arrays.<Object>asList(
+ 4L,
+ 10000L, 2500L, 4000L, 1000L,
+ (short) 10, (short) 2, (short) 4, (short) 1,
+ (byte) 10, (byte) 2, (byte) 4, (byte) 1,
+ 10.0F, 2.5F, 4.0F, 1.0F,
+ 10.0, 2.5, 4.0, 1.0,
+ format.parse("2017-01-01 02:04:03"), format.parse("2017-01-01 01:01:03"),
+ 10, 2, 4, 1
+ )))
+ );
+ }
+
+ /**
+ * Row type of final output row: the group-by key column followed by all
+ * aggregation result columns.
+ */
+ private RelDataType prepareFinalRowType() {
+ FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder();
+ List<KV<String, SqlTypeName>> columnMetadata =
+ Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER), KV.of("count", SqlTypeName.BIGINT),
+
+ KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT),
+ KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT),
+
+ KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT),
+ KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT),
+
+ KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT),
+ KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT),
+
+ KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT),
+ KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT),
+
+ KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE),
+ KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE),
+
+ KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP),
+
+ KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER),
+ KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER)
+ );
+ for (KV<String, SqlTypeName> cm : columnMetadata) {
+ builder.add(cm.getKey(), cm.getValue());
+ }
+ return builder.build();
+ }
+
+ /**
+ * expected results after {@link BeamAggregationTransforms.MergeAggregationRecord}:
+ * the key value (1) prepended to the aggregation results above.
+ */
+ private BeamSQLRow prepareResultOfMergeAggregationRecord() throws ParseException {
+ return new BeamSQLRow(BeamSQLRecordType.from(prepareFinalRowType()), Arrays.<Object>asList(
+ 1, 4L,
+ 10000L, 2500L, 4000L, 1000L,
+ (short) 10, (short) 2, (short) 4, (short) 1,
+ (byte) 10, (byte) 2, (byte) 4, (byte) 1,
+ 10.0F, 2.5F, 4.0F, 1.0F,
+ 10.0, 2.5, 4.0, 1.0,
+ format.parse("2017-01-01 02:04:03"), format.parse("2017-01-01 01:01:03"),
+ 10, 2, 4, 1
+ ));
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java
new file mode 100644
index 0000000..820d7f5
--- /dev/null
+++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql.schema.transform;
+
+import java.text.DateFormat;
+import java.text.ParseException;
+import java.text.SimpleDateFormat;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.beam.dsls.sql.planner.BeamQueryPlanner;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSQLRow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.junit.BeforeClass;
+
+/**
+ * Shared fixtures and helper methods for testing PTransforms that execute
+ * Beam SQL steps.
+ */
+public class BeamTransformBaseTest {
+ // NOTE(review): SimpleDateFormat is not thread-safe; sharing a static
+ // instance is acceptable only while these tests run single-threaded.
+ public static DateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
+
+ // Row type and rows shared by all subclasses; filled in by prepareInput().
+ public static BeamSQLRecordType inputRowType;
+ public static List<BeamSQLRow> inputRows;
+
+ /**
+ * Builds the 9-column input schema and four input rows. All rows share
+ * f_int=1 (so they fall into a single group) with per-row values 1..4
+ * in the remaining columns.
+ */
+ @BeforeClass
+ public static void prepareInput() throws NumberFormatException, ParseException{
+ List<KV<String, SqlTypeName>> columnMetadata = Arrays.asList(
+ KV.of("f_int", SqlTypeName.INTEGER), KV.of("f_long", SqlTypeName.BIGINT),
+ KV.of("f_short", SqlTypeName.SMALLINT), KV.of("f_byte", SqlTypeName.TINYINT),
+ KV.of("f_float", SqlTypeName.FLOAT), KV.of("f_double", SqlTypeName.DOUBLE),
+ KV.of("f_string", SqlTypeName.VARCHAR), KV.of("f_timestamp", SqlTypeName.TIMESTAMP),
+ KV.of("f_int2", SqlTypeName.INTEGER)
+ );
+ inputRowType = initTypeOfSqlRow(columnMetadata);
+ inputRows = Arrays.asList(
+ initBeamSqlRow(columnMetadata,
+ Arrays.<Object>asList(1, 1000L, Short.valueOf("1"), Byte.valueOf("1"), 1.0F, 1.0,
+ "string_row1", format.parse("2017-01-01 01:01:03"), 1)),
+ initBeamSqlRow(columnMetadata,
+ Arrays.<Object>asList(1, 2000L, Short.valueOf("2"), Byte.valueOf("2"), 2.0F, 2.0,
+ "string_row2", format.parse("2017-01-01 01:02:03"), 2)),
+ initBeamSqlRow(columnMetadata,
+ Arrays.<Object>asList(1, 3000L, Short.valueOf("3"), Byte.valueOf("3"), 3.0F, 3.0,
+ "string_row3", format.parse("2017-01-01 01:03:03"), 3)),
+ initBeamSqlRow(columnMetadata, Arrays.<Object>asList(1, 4000L, Short.valueOf("4"),
+ Byte.valueOf("4"), 4.0F, 4.0, "string_row4", format.parse("2017-01-01 02:04:03"), 4)));
+ }
+
+ /**
+ * Creates a {@code BeamSQLRecordType} for the given column metadata.
+ */
+ public static BeamSQLRecordType initTypeOfSqlRow(List<KV<String, SqlTypeName>> columnMetadata){
+ FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder();
+ for (KV<String, SqlTypeName> cm : columnMetadata) {
+ builder.add(cm.getKey(), cm.getValue());
+ }
+ return BeamSQLRecordType.from(builder.build());
+ }
+
+ /**
+ * Creates an empty row (no values) with the given column metadata.
+ */
+ public static BeamSQLRow initBeamSqlRow(List<KV<String, SqlTypeName>> columnMetadata) {
+ return initBeamSqlRow(columnMetadata, Arrays.asList());
+ }
+
+ /**
+ * Creates a row with the given column metadata and one value per column.
+ */
+ public static BeamSQLRow initBeamSqlRow(List<KV<String, SqlTypeName>> columnMetadata,
+ List<Object> rowValues){
+ BeamSQLRecordType rowType = initTypeOfSqlRow(columnMetadata);
+
+ return new BeamSQLRow(rowType, rowValues);
+ }
+
+}
[2/2] beam git commit: This closes #3067
Posted by dh...@apache.org.
This closes #3067
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/523482be
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/523482be
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/523482be
Branch: refs/heads/DSL_SQL
Commit: 523482be0501a7bce79087f47c7752b900178a00
Parents: 6729a02 f728fbe
Author: Dan Halperin <dh...@google.com>
Authored: Fri May 12 17:47:06 2017 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Fri May 12 17:47:06 2017 -0700
----------------------------------------------------------------------
.../interpreter/operator/BeamSqlPrimitive.java | 35 +
.../beam/dsls/sql/rel/BeamAggregationRel.java | 40 +-
.../apache/beam/dsls/sql/schema/BeamSQLRow.java | 4 +
.../beam/dsls/sql/schema/BeamSqlRowCoder.java | 4 +-
.../sql/transform/BeamAggregationTransform.java | 120 ----
.../transform/BeamAggregationTransforms.java | 671 +++++++++++++++++++
.../transform/BeamAggregationTransformTest.java | 436 ++++++++++++
.../schema/transform/BeamTransformBaseTest.java | 96 +++
8 files changed, 1261 insertions(+), 145 deletions(-)
----------------------------------------------------------------------