You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by dh...@apache.org on 2017/05/13 00:47:12 UTC

[1/2] beam git commit: Support commonly used aggregation functions in SQL, including: COUNT, SUM, AVG, MAX, MIN

Repository: beam
Updated Branches:
  refs/heads/DSL_SQL 6729a027d -> 523482be0


Support commonly used aggregation functions in SQL, including:
  COUNT, SUM, AVG, MAX, MIN

rename BeamAggregationTransform to BeamAggregationTransforms

update comments


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/f728fbe5
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/f728fbe5
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/f728fbe5

Branch: refs/heads/DSL_SQL
Commit: f728fbe5c7153341ace046fa8b2465ef8844be1b
Parents: 6729a02
Author: mingmxu <mi...@ebay.com>
Authored: Wed May 10 16:38:13 2017 -0700
Committer: mingmxu <mi...@ebay.com>
Committed: Wed May 10 20:47:40 2017 -0700

----------------------------------------------------------------------
 .../interpreter/operator/BeamSqlPrimitive.java  |  35 +
 .../beam/dsls/sql/rel/BeamAggregationRel.java   |  40 +-
 .../apache/beam/dsls/sql/schema/BeamSQLRow.java |   4 +
 .../beam/dsls/sql/schema/BeamSqlRowCoder.java   |   4 +-
 .../sql/transform/BeamAggregationTransform.java | 120 ----
 .../transform/BeamAggregationTransforms.java    | 671 +++++++++++++++++++
 .../transform/BeamAggregationTransformTest.java | 436 ++++++++++++
 .../schema/transform/BeamTransformBaseTest.java |  96 +++
 8 files changed, 1261 insertions(+), 145 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java
index 3309577..a5938f3 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java
@@ -65,6 +65,41 @@ public class BeamSqlPrimitive<T> extends BeamSqlExpression{
     return value;
   }
 
+  /**
+   * Returns the wrapped value unboxed as a {@code long}; the value must be a non-null
+   * {@link Long}.
+   */
+  public long getLong() {
+    return ((Long) getValue()).longValue();
+  }
+
+  /**
+   * Returns the wrapped value unboxed as a {@code double}; the value must be a non-null
+   * {@link Double}.
+   */
+  public double getDouble() {
+    return ((Double) getValue()).doubleValue();
+  }
+
+  /**
+   * Returns the wrapped value unboxed as a {@code float}; the value must be a non-null
+   * {@link Float}.
+   */
+  public float getFloat() {
+    return ((Float) getValue()).floatValue();
+  }
+
+  /**
+   * Returns the wrapped value unboxed as an {@code int}; the value must be a non-null
+   * {@link Integer}.
+   */
+  public int getInteger() {
+    return ((Integer) getValue()).intValue();
+  }
+
+  /**
+   * Returns the wrapped value unboxed as a {@code short}; the value must be a non-null
+   * {@link Short}.
+   */
+  public short getShort() {
+    return ((Short) getValue()).shortValue();
+  }
+
+  /**
+   * Returns the wrapped value unboxed as a {@code byte}; the value must be a non-null
+   * {@link Byte}.
+   */
+  public byte getByte() {
+    return ((Byte) getValue()).byteValue();
+  }
+
+  /**
+   * Returns the wrapped value unboxed as a {@code boolean}; the value must be a non-null
+   * {@link Boolean}.
+   */
+  public boolean getBoolean() {
+    return ((Boolean) getValue()).booleanValue();
+  }
+
+  /**
+   * Returns the wrapped value cast to {@link String}; {@code null} stays {@code null}.
+   */
+  public String getString() {
+    String text = (String) getValue();
+    return text;
+  }
+
+  /**
+   * Returns the wrapped value cast to {@code Date}; {@code null} stays {@code null}.
+   */
+  public Date getDate() {
+    Date date = (Date) getValue();
+    return date;
+  }
+
   @Override
   public boolean accept() {
     if (value == null) {

http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
index 2c7626d..ab98cc4 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
@@ -18,15 +18,13 @@
 package org.apache.beam.dsls.sql.rel;
 
 import java.util.List;
-import org.apache.beam.dsls.sql.exception.BeamSqlUnsupportedException;
 import org.apache.beam.dsls.sql.planner.BeamPipelineCreator;
 import org.apache.beam.dsls.sql.planner.BeamSQLRelUtils;
 import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
 import org.apache.beam.dsls.sql.schema.BeamSQLRow;
-import org.apache.beam.dsls.sql.transform.BeamAggregationTransform;
+import org.apache.beam.dsls.sql.transform.BeamAggregationTransforms;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.WithKeys;
@@ -79,7 +77,7 @@ public class BeamAggregationRel extends Aggregate implements BeamRelNode {
     PCollection<BeamSQLRow> upstream = planCreator.popUpstream();
     if (windowFieldIdx != -1) {
       upstream = upstream.apply("assignEventTimestamp", WithTimestamps
-          .<BeamSQLRow>of(new BeamAggregationTransform.WindowTimestampFn(windowFieldIdx)));
+          .<BeamSQLRow>of(new BeamAggregationTransforms.WindowTimestampFn(windowFieldIdx)));
     }
 
     PCollection<BeamSQLRow> windowStream = upstream.apply("window",
@@ -88,32 +86,26 @@ public class BeamAggregationRel extends Aggregate implements BeamRelNode {
         .withAllowedLateness(allowedLatence)
         .accumulatingFiredPanes());
 
+    //1. extract fields in group-by key part
     PCollection<KV<BeamSQLRow, BeamSQLRow>> exGroupByStream = windowStream.apply("exGroupBy",
         WithKeys
-            .of(new BeamAggregationTransform.AggregationGroupByKeyFn(windowFieldIdx, groupSet)));
+            .of(new BeamAggregationTransforms.AggregationGroupByKeyFn(windowFieldIdx, groupSet)));
 
+    //2. apply a GroupByKey.
     PCollection<KV<BeamSQLRow, Iterable<BeamSQLRow>>> groupedStream = exGroupByStream
         .apply("groupBy", GroupByKey.<BeamSQLRow, BeamSQLRow>create());
 
-    if (aggCalls.size() > 1) {
-      throw new BeamSqlUnsupportedException("only single aggregation is supported now.");
-    }
-
-    AggregateCall aggCall = aggCalls.get(0);
-    switch (aggCall.getAggregation().getName()) {
-    case "COUNT":
-      PCollection<KV<BeamSQLRow, Long>> aggregatedStream = groupedStream.apply("count",
-          Combine.<BeamSQLRow, BeamSQLRow, Long>groupedValues(Count.combineFn()));
-      PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord",
-          ParDo.of(new BeamAggregationTransform.MergeAggregationRecord(
-              BeamSQLRecordType.from(getRowType()), aggCall.getName())));
-      planCreator.pushUpstream(mergedStream);
-      break;
-    default:
-      //Only support COUNT now, more are added in BEAM-2008
-      throw new BeamSqlUnsupportedException(
-          String.format("Unsupported aggregation [%s]", aggCall.getAggregation().getName()));
-    }
+    //3. run aggregation functions
+    PCollection<KV<BeamSQLRow, BeamSQLRow>> aggregatedStream = groupedStream.apply("aggregation",
+        Combine.<BeamSQLRow, BeamSQLRow, BeamSQLRow>groupedValues(
+            new BeamAggregationTransforms.AggregationCombineFn(getAggCallList(),
+                BeamSQLRecordType.from(input.getRowType()))));
+
+    //4. flat KV to a single record
+    PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord",
+        ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord(
+            BeamSQLRecordType.from(getRowType()), getAggCallList())));
+    planCreator.pushUpstream(mergedStream);
 
     return planCreator.getPipeline();
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java
index 65f4a41..5bdd5d2 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java
@@ -144,6 +144,10 @@ public class BeamSQLRow implements Serializable {
     dataValues.set(index, fieldValue);
   }
 
+  /**
+   * Returns the field at position {@code idx} unboxed as a {@code byte}; the stored value must
+   * be a non-null {@link Byte}.
+   */
+  public byte getByte(int idx) {
+    return ((Byte) getFieldValue(idx)).byteValue();
+  }
+
   public short getShort(int idx) {
     return (Short) getFieldValue(idx);
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java
index 3100ba5..0accb9a 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java
@@ -70,9 +70,11 @@ public class BeamSqlRowCoder extends StandardCoder<BeamSQLRow>{
           intCoder.encode(value.getInteger(idx), outStream, context.nested());
           break;
         case SMALLINT:
-        case TINYINT:
           intCoder.encode((int) value.getShort(idx), outStream, context.nested());
           break;
+        case TINYINT:
+          intCoder.encode((int) value.getByte(idx), outStream, context.nested());
+          break;
         case DOUBLE:
           doubleCoder.encode(value.getDouble(idx), outStream, context.nested());
           break;

http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java
deleted file mode 100644
index f478363..0000000
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.dsls.sql.transform;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
-import org.apache.beam.dsls.sql.schema.BeamSQLRow;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.values.KV;
-import org.apache.calcite.util.ImmutableBitSet;
-import org.joda.time.Instant;
-
-/**
- * Collections of {@code PTransform} and {@code DoFn} used to perform GROUP-BY operation.
- */
-public class BeamAggregationTransform implements Serializable{
-  /**
-   * Merge KV to single record.
-   */
-  public static class MergeAggregationRecord extends DoFn<KV<BeamSQLRow, Long>, BeamSQLRow> {
-    private BeamSQLRecordType outRecordType;
-    private String aggFieldName;
-
-    public MergeAggregationRecord(BeamSQLRecordType outRecordType, String aggFieldName) {
-      this.outRecordType = outRecordType;
-      this.aggFieldName = aggFieldName;
-    }
-
-    @ProcessElement
-    public void processElement(ProcessContext c, BoundedWindow window) {
-      BeamSQLRow outRecord = new BeamSQLRow(outRecordType);
-      outRecord.updateWindowRange(c.element().getKey(), window);
-
-      KV<BeamSQLRow, Long> kvRecord = c.element();
-      for (String f : kvRecord.getKey().getDataType().getFieldsName()) {
-        outRecord.addField(f, kvRecord.getKey().getFieldValue(f));
-      }
-      outRecord.addField(aggFieldName, kvRecord.getValue());
-
-//      if (c.pane().isLast()) {
-      c.output(outRecord);
-//      }
-    }
-  }
-
-  /**
-   * extract group-by fields.
-   */
-  public static class AggregationGroupByKeyFn
-      implements SerializableFunction<BeamSQLRow, BeamSQLRow> {
-    private List<Integer> groupByKeys;
-
-    public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) {
-      this.groupByKeys = new ArrayList<>();
-      for (int i : groupSet.asList()) {
-        if (i != windowFieldIdx) {
-          groupByKeys.add(i);
-        }
-      }
-    }
-
-    @Override
-    public BeamSQLRow apply(BeamSQLRow input) {
-      BeamSQLRecordType typeOfKey = exTypeOfKeyRecord(input.getDataType());
-      BeamSQLRow keyOfRecord = new BeamSQLRow(typeOfKey);
-      keyOfRecord.updateWindowRange(input, null);
-
-      for (int idx = 0; idx < groupByKeys.size(); ++idx) {
-        keyOfRecord.addField(idx, input.getFieldValue(groupByKeys.get(idx)));
-      }
-      return keyOfRecord;
-    }
-
-    private BeamSQLRecordType exTypeOfKeyRecord(BeamSQLRecordType dataType) {
-      BeamSQLRecordType typeOfKey = new BeamSQLRecordType();
-      for (int idx : groupByKeys) {
-        typeOfKey.addField(dataType.getFieldsName().get(idx), dataType.getFieldsType().get(idx));
-      }
-      return typeOfKey;
-    }
-
-  }
-
-  /**
-   * Assign event timestamp.
-   */
-  public static class WindowTimestampFn implements SerializableFunction<BeamSQLRow, Instant> {
-    private int windowFieldIdx = -1;
-
-    public WindowTimestampFn(int windowFieldIdx) {
-      super();
-      this.windowFieldIdx = windowFieldIdx;
-    }
-
-    @Override
-    public Instant apply(BeamSQLRow input) {
-      return new Instant(input.getDate(windowFieldIdx).getTime());
-    }
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
new file mode 100644
index 0000000..943c897
--- /dev/null
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
@@ -0,0 +1,671 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql.transform;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.List;
+import org.apache.beam.dsls.sql.exception.BeamSqlUnsupportedException;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlExpression;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlInputRefExpression;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlPrimitive;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSQLRow;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.joda.time.Instant;
+
+/**
+ * Collection of {@code DoFn}s, {@code CombineFn}s and other functions used to perform GROUP BY
+ * operations.
+ */
+public class BeamAggregationTransforms implements Serializable{
+  /**
+   * Merges a {@code KV<groupKey, aggregatedValues>} into a single output record.
+   *
+   * <p>The output record contains every field of the key record, copied by name, followed by
+   * one field per aggregation call, read positionally from the value record. Its window range is
+   * set from the key record and the element's window.
+   */
+  public static class MergeAggregationRecord extends DoFn<KV<BeamSQLRow, BeamSQLRow>, BeamSQLRow> {
+    private final BeamSQLRecordType outRecordType;
+    // Output field name of each aggregation call, in call order.
+    private final List<String> aggFieldNames;
+
+    public MergeAggregationRecord(BeamSQLRecordType outRecordType, List<AggregateCall> aggList) {
+      this.outRecordType = outRecordType;
+      this.aggFieldNames = new ArrayList<>();
+      for (AggregateCall ac : aggList) {
+        aggFieldNames.add(ac.getName());
+      }
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext c, BoundedWindow window) {
+      KV<BeamSQLRow, BeamSQLRow> kvRecord = c.element();
+      BeamSQLRow groupKey = kvRecord.getKey();
+      BeamSQLRow aggValues = kvRecord.getValue();
+
+      BeamSQLRow outRecord = new BeamSQLRow(outRecordType);
+      outRecord.updateWindowRange(groupKey, window);
+
+      for (String f : groupKey.getDataType().getFieldsName()) {
+        outRecord.addField(f, groupKey.getFieldValue(f));
+      }
+      // Only the first aggFieldNames.size() values of the aggregated record are copied.
+      for (int idx = 0; idx < aggFieldNames.size(); ++idx) {
+        outRecord.addField(aggFieldNames.get(idx), aggValues.getFieldValue(idx));
+      }
+
+      // Every fired pane is emitted; no final-pane filtering is applied here.
+      c.output(outRecord);
+    }
+  }
+
+  /**
+   * Extracts the group-by key of a record.
+   *
+   * <p>The key is a new record holding the values of the group-by fields in group-set order,
+   * with the window field left out. {@code updateWindowRange(input, null)} is applied to it.
+   */
+  public static class AggregationGroupByKeyFn
+      implements SerializableFunction<BeamSQLRow, BeamSQLRow> {
+    private List<Integer> groupByKeys;
+
+    public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) {
+      List<Integer> keyFields = new ArrayList<>();
+      for (int fieldIdx : groupSet.asList()) {
+        // The window field is not part of the key.
+        if (fieldIdx == windowFieldIdx) {
+          continue;
+        }
+        keyFields.add(fieldIdx);
+      }
+      this.groupByKeys = keyFields;
+    }
+
+    @Override
+    public BeamSQLRow apply(BeamSQLRow input) {
+      BeamSQLRow keyOfRecord = new BeamSQLRow(exTypeOfKeyRecord(input.getDataType()));
+      keyOfRecord.updateWindowRange(input, null);
+
+      int keyPos = 0;
+      for (Integer sourceIdx : groupByKeys) {
+        keyOfRecord.addField(keyPos++, input.getFieldValue(sourceIdx));
+      }
+      return keyOfRecord;
+    }
+
+    /** Builds the key record type: name and type of each group-by field, in key order. */
+    private BeamSQLRecordType exTypeOfKeyRecord(BeamSQLRecordType dataType) {
+      BeamSQLRecordType keyType = new BeamSQLRecordType();
+      for (int fieldIdx : groupByKeys) {
+        String name = dataType.getFieldsName().get(fieldIdx);
+        keyType.addField(name, dataType.getFieldsType().get(fieldIdx));
+      }
+      return keyType;
+    }
+
+  }
+
+  /**
+   * Assigns the event timestamp of a record from the {@code Date} stored in its window field.
+   */
+  public static class WindowTimestampFn implements SerializableFunction<BeamSQLRow, Instant> {
+    private final int windowFieldIdx;
+
+    public WindowTimestampFn(int windowFieldIdx) {
+      this.windowFieldIdx = windowFieldIdx;
+    }
+
+    @Override
+    public Instant apply(BeamSQLRow input) {
+      long epochMillis = input.getDate(windowFieldIdx).getTime();
+      return new Instant(epochMillis);
+    }
+  }
+
+  /**
+   * Combine function which evaluates COUNT, MAX, MIN, SUM and AVG.
+   *
+   * <p>Multiple aggregation calls are evaluated together in one accumulator record. Each
+   * function accepts only some data types:
+   * <ul>
+   *   <li>COUNT works for any data type;</li>
+   *   <li>MAX/MIN work for INTEGER, BIGINT, FLOAT, DOUBLE, SMALLINT, TINYINT and TIMESTAMP;</li>
+   *   <li>SUM/AVG work for INTEGER, BIGINT, FLOAT, DOUBLE, SMALLINT and TINYINT.</li>
+   * </ul>
+   * DECIMAL and DISTINCT aggregations are not supported yet.
+   */
+  public static class AggregationCombineFn extends CombineFn<BeamSQLRow, BeamSQLRow, BeamSQLRow> {
+    // Schema of the accumulator record: one field per aggregation call, plus a hidden
+    // "__COUNT" BIGINT field when AVG is used without any COUNT call.
+    private BeamSQLRecordType aggDataType;
+
+    // Accumulator position of the COUNT value (the last COUNT call, or the hidden "__COUNT"
+    // field); stays -1 when there is neither a COUNT nor an AVG call.
+    private int countIndex = -1;
+
+    // Function name per accumulator position: "COUNT", "SUM", "MAX", "MIN" or "AVG".
+    List<String> aggFunctions;
+    // Expression evaluated on each input row per accumulator position: the constant 1L for
+    // COUNT, a reference to the argument field for SUM/MAX/MIN/AVG.
+    List<BeamSqlExpression> aggElementExpressions;
+
+    /**
+     * Creates a combine function that evaluates all {@code aggregationCalls} in one accumulator.
+     *
+     * @param aggregationCalls aggregation calls to evaluate, in output order
+     * @param sourceRowRecordType record type of the input rows; used to type the field
+     *     references of SUM/MAX/MIN/AVG
+     * @throws BeamSqlUnsupportedException if a call is DISTINCT or an unsupported
+     *     function/type combination
+     */
+    public AggregationCombineFn(List<AggregateCall> aggregationCalls,
+        BeamSQLRecordType sourceRowRecordType) {
+      this.aggDataType = new BeamSQLRecordType();
+      this.aggFunctions = new ArrayList<>();
+      this.aggElementExpressions = new ArrayList<>();
+      // Updated in place below instead of through a shadowing local variable.
+      this.countIndex = -1;
+
+      boolean hasAvg = false;
+      for (int idx = 0; idx < aggregationCalls.size(); ++idx) {
+        AggregateCall ac = aggregationCalls.get(idx);
+        verifySupportedAggregation(ac);
+
+        aggDataType.addField(ac.name, ac.type.getSqlTypeName());
+
+        SqlAggFunction aggFn = ac.getAggregation();
+        String aggFnName = aggFn.getName();
+        switch (aggFnName) {
+        case "COUNT":
+          // COUNT adds a constant 1L per row. The last COUNT position is remembered so that
+          // AVG can reuse it instead of a hidden counter.
+          aggElementExpressions.add(BeamSqlPrimitive.<Long>of(SqlTypeName.BIGINT, 1L));
+          countIndex = idx;
+          break;
+        case "SUM":
+        case "MAX":
+        case "MIN":
+        case "AVG":
+          int refIndex = ac.getArgList().get(0);
+          aggElementExpressions.add(new BeamSqlInputRefExpression(
+              sourceRowRecordType.getFieldsType().get(refIndex), refIndex));
+          if ("AVG".equals(aggFnName)) {
+            hasAvg = true;
+          }
+          break;
+        default:
+          break;
+        }
+        aggFunctions.add(aggFnName);
+      }
+
+      // AVG needs a row count next to its running sum; add a hidden COUNT field when the
+      // query has no COUNT call of its own.
+      if (hasAvg && countIndex == -1) {
+        aggDataType.addField("__COUNT", SqlTypeName.BIGINT);
+        aggFunctions.add("COUNT");
+        aggElementExpressions.add(BeamSqlPrimitive.<Long>of(SqlTypeName.BIGINT, 1L));
+        countIndex = aggDataType.size() - 1;
+      }
+    }
+
+    /**
+     * Rejects aggregation calls this combine function cannot evaluate.
+     *
+     * <p>DISTINCT is never supported. COUNT accepts any type. SUM and AVG accept INTEGER, BIGINT,
+     * FLOAT, DOUBLE, SMALLINT and TINYINT. MAX and MIN additionally accept TIMESTAMP. DECIMAL is
+     * not supported yet. The type checked is the call's result type ({@code ac.type}).
+     *
+     * @throws BeamSqlUnsupportedException if the call is not supported
+     */
+    private void verifySupportedAggregation(AggregateCall ac) {
+      // DISTINCT aggregations are not supported.
+      if (ac.isDistinct()) {
+        throw new BeamSqlUnsupportedException("DISTINCT is not supported yet.");
+      }
+      String aggFnName = ac.getAggregation().getName();
+      // Numeric types accepted by SUM/AVG, and by MAX/MIN together with TIMESTAMP.
+      List<SqlTypeName> numericTypes = Arrays.asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT,
+          SqlTypeName.FLOAT, SqlTypeName.DOUBLE, SqlTypeName.SMALLINT, SqlTypeName.TINYINT);
+      SqlTypeName resultType = ac.type.getSqlTypeName();
+      switch (aggFnName) {
+      case "COUNT":
+        // COUNT works for any data type.
+        break;
+      case "SUM":
+        if (!numericTypes.contains(resultType)) {
+          throw new BeamSqlUnsupportedException(
+              "SUM only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT");
+        }
+        break;
+      case "MAX":
+      case "MIN":
+        if (!numericTypes.contains(resultType) && resultType != SqlTypeName.TIMESTAMP) {
+          throw new BeamSqlUnsupportedException("MAX/MIN only support for INT, LONG, FLOAT,"
+              + " DOUBLE, SMALLINT, TINYINT, TIMESTAMP");
+        }
+        break;
+      case "AVG":
+        if (!numericTypes.contains(resultType)) {
+          throw new BeamSqlUnsupportedException(
+              "AVG only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT");
+        }
+        break;
+      default:
+        throw new BeamSqlUnsupportedException(
+            String.format("[%s] is not supported.", aggFnName));
+      }
+    }
+
+    /**
+     * Creates the initial accumulator, with one field per aggregation position set to the
+     * identity value of its function.
+     *
+     * <p>COUNT and SUM/AVG start at zero of the field's type. MAX starts at the lowest value of
+     * the type and MIN at the highest, so every input compares greater/less or equal.
+     *
+     * <p>For FLOAT/DOUBLE the infinities are used. {@code Float.MIN_VALUE} and
+     * {@code Double.MIN_VALUE} are the smallest <em>positive</em> values, so they would yield a
+     * wrong MAX over all-negative inputs.
+     *
+     * <p>TIMESTAMP uses {@code new Date(Long.MIN_VALUE)} and {@code new Date(Long.MAX_VALUE)}, so
+     * pre-1970 timestamps are handled.
+     *
+     * <p>Positions whose output type is not handled are left unset.
+     */
+    @Override
+    public BeamSQLRow createAccumulator() {
+      BeamSQLRow initialRecord = new BeamSQLRow(aggDataType);
+      for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+        BeamSqlExpression ex = aggElementExpressions.get(idx);
+        String aggFnName = aggFunctions.get(idx);
+        switch (aggFnName) {
+        case "COUNT":
+          initialRecord.addField(idx, 0L);
+          break;
+        case "AVG":
+        case "SUM":
+          // Both AVG and SUM keep a running sum, starting from zero.
+          switch (ex.getOutputType()) {
+          case INTEGER:
+            initialRecord.addField(idx, 0);
+            break;
+          case BIGINT:
+            initialRecord.addField(idx, 0L);
+            break;
+          case SMALLINT:
+            initialRecord.addField(idx, (short) 0);
+            break;
+          case TINYINT:
+            initialRecord.addField(idx, (byte) 0);
+            break;
+          case FLOAT:
+            initialRecord.addField(idx, 0.0f);
+            break;
+          case DOUBLE:
+            initialRecord.addField(idx, 0.0);
+            break;
+          default:
+            break;
+          }
+          break;
+        case "MAX":
+          // Lowest value of each type: the identity element of max.
+          switch (ex.getOutputType()) {
+          case INTEGER:
+            initialRecord.addField(idx, Integer.MIN_VALUE);
+            break;
+          case BIGINT:
+            initialRecord.addField(idx, Long.MIN_VALUE);
+            break;
+          case SMALLINT:
+            initialRecord.addField(idx, Short.MIN_VALUE);
+            break;
+          case TINYINT:
+            initialRecord.addField(idx, Byte.MIN_VALUE);
+            break;
+          case FLOAT:
+            initialRecord.addField(idx, Float.NEGATIVE_INFINITY);
+            break;
+          case DOUBLE:
+            initialRecord.addField(idx, Double.NEGATIVE_INFINITY);
+            break;
+          case TIMESTAMP:
+            initialRecord.addField(idx, new Date(Long.MIN_VALUE));
+            break;
+          default:
+            break;
+          }
+          break;
+        case "MIN":
+          // Highest value of each type: the identity element of min.
+          switch (ex.getOutputType()) {
+          case INTEGER:
+            initialRecord.addField(idx, Integer.MAX_VALUE);
+            break;
+          case BIGINT:
+            initialRecord.addField(idx, Long.MAX_VALUE);
+            break;
+          case SMALLINT:
+            initialRecord.addField(idx, Short.MAX_VALUE);
+            break;
+          case TINYINT:
+            initialRecord.addField(idx, Byte.MAX_VALUE);
+            break;
+          case FLOAT:
+            initialRecord.addField(idx, Float.POSITIVE_INFINITY);
+            break;
+          case DOUBLE:
+            initialRecord.addField(idx, Double.POSITIVE_INFINITY);
+            break;
+          case TIMESTAMP:
+            initialRecord.addField(idx, new Date(Long.MAX_VALUE));
+            break;
+          default:
+            break;
+          }
+          break;
+        default:
+          break;
+        }
+      }
+      return initialRecord;
+    }
+
+    /**
+     * Folds one input row into the accumulator and returns the result as a new record.
+     *
+     * <p>For each accumulator position:
+     * <ul>
+     *   <li>COUNT adds one;</li>
+     *   <li>SUM/AVG add the evaluated input value to the running sum;</li>
+     *   <li>MAX/MIN keep the larger/smaller of the running value and the evaluated input
+     *       value.</li>
+     * </ul>
+     * Positions whose function or output type is not handled are left unset in the returned
+     * record.
+     */
+    @Override
+    public BeamSQLRow addInput(BeamSQLRow accumulator, BeamSQLRow input) {
+      BeamSQLRow result = new BeamSQLRow(aggDataType);
+      for (int pos = 0; pos < aggElementExpressions.size(); pos++) {
+        BeamSqlExpression expr = aggElementExpressions.get(pos);
+        switch (aggFunctions.get(pos)) {
+        case "COUNT":
+          result.addField(pos, 1 + accumulator.getLong(pos));
+          break;
+        case "AVG":
+        case "SUM":
+          // AVG keeps a running sum here, exactly like SUM.
+          addToSum(result, pos, expr, input, accumulator);
+          break;
+        case "MAX":
+          addToMax(result, pos, expr, input, accumulator);
+          break;
+        case "MIN":
+          addToMin(result, pos, expr, input, accumulator);
+          break;
+        default:
+          break;
+        }
+      }
+      return result;
+    }
+
+    /** Stores at {@code pos} the evaluated input value plus the running sum from {@code acc}. */
+    private static void addToSum(BeamSQLRow out, int pos, BeamSqlExpression expr,
+        BeamSQLRow input, BeamSQLRow acc) {
+      switch (expr.getOutputType()) {
+      case INTEGER:
+        out.addField(pos, expr.evaluate(input).getInteger() + acc.getInteger(pos));
+        break;
+      case BIGINT:
+        out.addField(pos, expr.evaluate(input).getLong() + acc.getLong(pos));
+        break;
+      case SMALLINT:
+        // The addition happens in int; the result is narrowed back to the column type.
+        out.addField(pos, (short) (expr.evaluate(input).getShort() + acc.getShort(pos)));
+        break;
+      case TINYINT:
+        out.addField(pos, (byte) (expr.evaluate(input).getByte() + acc.getByte(pos)));
+        break;
+      case FLOAT:
+        out.addField(pos, (float) (expr.evaluate(input).getFloat() + acc.getFloat(pos)));
+        break;
+      case DOUBLE:
+        out.addField(pos, expr.evaluate(input).getDouble() + acc.getDouble(pos));
+        break;
+      default:
+        break;
+      }
+    }
+
+    /** Stores at {@code pos} the larger of the evaluated input value and the running maximum. */
+    private static void addToMax(BeamSQLRow out, int pos, BeamSqlExpression expr,
+        BeamSQLRow input, BeamSQLRow acc) {
+      switch (expr.getOutputType()) {
+      case INTEGER:
+        out.addField(pos, Math.max(expr.evaluate(input).getInteger(), acc.getInteger(pos)));
+        break;
+      case BIGINT:
+        out.addField(pos, Math.max(expr.evaluate(input).getLong(), acc.getLong(pos)));
+        break;
+      case SMALLINT:
+        out.addField(pos, (short) Math.max(expr.evaluate(input).getShort(), acc.getShort(pos)));
+        break;
+      case TINYINT:
+        out.addField(pos, (byte) Math.max(expr.evaluate(input).getByte(), acc.getByte(pos)));
+        break;
+      case FLOAT:
+        out.addField(pos, Math.max(expr.evaluate(input).getFloat(), acc.getFloat(pos)));
+        break;
+      case DOUBLE:
+        out.addField(pos, Math.max(expr.evaluate(input).getDouble(), acc.getDouble(pos)));
+        break;
+      case TIMESTAMP:
+        Date previous = acc.getDate(pos);
+        Date current = expr.evaluate(input).getDate();
+        out.addField(pos, previous.getTime() > current.getTime() ? previous : current);
+        break;
+      default:
+        break;
+      }
+    }
+
+    /** Stores at {@code pos} the smaller of the evaluated input value and the running minimum. */
+    private static void addToMin(BeamSQLRow out, int pos, BeamSqlExpression expr,
+        BeamSQLRow input, BeamSQLRow acc) {
+      switch (expr.getOutputType()) {
+      case INTEGER:
+        out.addField(pos, Math.min(expr.evaluate(input).getInteger(), acc.getInteger(pos)));
+        break;
+      case BIGINT:
+        out.addField(pos, Math.min(expr.evaluate(input).getLong(), acc.getLong(pos)));
+        break;
+      case SMALLINT:
+        out.addField(pos, (short) Math.min(expr.evaluate(input).getShort(), acc.getShort(pos)));
+        break;
+      case TINYINT:
+        out.addField(pos, (byte) Math.min(expr.evaluate(input).getByte(), acc.getByte(pos)));
+        break;
+      case FLOAT:
+        out.addField(pos, Math.min(expr.evaluate(input).getFloat(), acc.getFloat(pos)));
+        break;
+      case DOUBLE:
+        out.addField(pos, Math.min(expr.evaluate(input).getDouble(), acc.getDouble(pos)));
+        break;
+      case TIMESTAMP:
+        Date previous = acc.getDate(pos);
+        Date current = expr.evaluate(input).getDate();
+        out.addField(pos, previous.getTime() < current.getTime() ? previous : current);
+        break;
+      default:
+        break;
+      }
+    }
+
+    @Override
+    public BeamSQLRow mergeAccumulators(Iterable<BeamSQLRow> accumulators) {
+      BeamSQLRow deltaRecord = new BeamSQLRow(aggDataType);
+
+      while (accumulators.iterator().hasNext()) {
+        BeamSQLRow sa = accumulators.iterator().next();
+        for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+          BeamSqlExpression ex = aggElementExpressions.get(idx);
+          String aggFnName = aggFunctions.get(idx);
+          switch (aggFnName) {
+          case "COUNT":
+            deltaRecord.addField(idx, deltaRecord.getLong(idx) + sa.getLong(idx));
+            break;
+          case "AVG":
+          case "SUM":
+            // for both AVG/SUM, a summary value is hold at first.
+            switch (ex.getOutputType()) {
+            case INTEGER:
+              deltaRecord.addField(idx, deltaRecord.getInteger(idx) + sa.getInteger(idx));
+              break;
+            case BIGINT:
+              deltaRecord.addField(idx, deltaRecord.getLong(idx) + sa.getLong(idx));
+              break;
+            case SMALLINT:
+              deltaRecord.addField(idx, (short) (deltaRecord.getShort(idx) + sa.getShort(idx)));
+              break;
+            case TINYINT:
+              deltaRecord.addField(idx, (byte) (deltaRecord.getByte(idx) + sa.getByte(idx)));
+              break;
+            case FLOAT:
+              deltaRecord.addField(idx, (float) (deltaRecord.getFloat(idx) + sa.getFloat(idx)));
+              break;
+            case DOUBLE:
+              deltaRecord.addField(idx, deltaRecord.getDouble(idx) + sa.getDouble(idx));
+              break;
+            default:
+              break;
+            }
+            break;
+          case "MAX":
+            switch (ex.getOutputType()) {
+            case INTEGER:
+              deltaRecord.addField(idx, Math.max(deltaRecord.getInteger(idx), sa.getInteger(idx)));
+              break;
+            case BIGINT:
+              deltaRecord.addField(idx, Math.max(deltaRecord.getLong(idx), sa.getLong(idx)));
+              break;
+            case SMALLINT:
+              deltaRecord.addField(idx,
+                  (short) Math.max(deltaRecord.getShort(idx), sa.getShort(idx)));
+              break;
+            case TINYINT:
+              deltaRecord.addField(idx, (byte) Math.max(deltaRecord.getByte(idx), sa.getByte(idx)));
+              break;
+            case FLOAT:
+              deltaRecord.addField(idx, Math.max(deltaRecord.getFloat(idx), sa.getFloat(idx)));
+              break;
+            case DOUBLE:
+              deltaRecord.addField(idx, Math.max(deltaRecord.getDouble(idx), sa.getDouble(idx)));
+              break;
+            case TIMESTAMP:
+              Date preDate = deltaRecord.getDate(idx);
+              Date nowDate = sa.getDate(idx);
+              deltaRecord.addField(idx, preDate.getTime() > nowDate.getTime() ? preDate : nowDate);
+              break;
+            default:
+              break;
+            }
+            break;
+          case "MIN":
+            switch (ex.getOutputType()) {
+            case INTEGER:
+              deltaRecord.addField(idx, Math.min(deltaRecord.getInteger(idx), sa.getInteger(idx)));
+              break;
+            case BIGINT:
+              deltaRecord.addField(idx, Math.min(deltaRecord.getLong(idx), sa.getLong(idx)));
+              break;
+            case SMALLINT:
+              deltaRecord.addField(idx,
+                  (short) Math.min(deltaRecord.getShort(idx), sa.getShort(idx)));
+              break;
+            case TINYINT:
+              deltaRecord.addField(idx, (byte) Math.min(deltaRecord.getByte(idx), sa.getByte(idx)));
+              break;
+            case FLOAT:
+              deltaRecord.addField(idx, Math.min(deltaRecord.getFloat(idx), sa.getFloat(idx)));
+              break;
+            case DOUBLE:
+              deltaRecord.addField(idx, Math.min(deltaRecord.getDouble(idx), sa.getDouble(idx)));
+              break;
+            case TIMESTAMP:
+              Date preDate = deltaRecord.getDate(idx);
+              Date nowDate = sa.getDate(idx);
+              deltaRecord.addField(idx, preDate.getTime() < nowDate.getTime() ? preDate : nowDate);
+              break;
+            default:
+              break;
+            }
+            break;
+          default:
+            break;
+          }
+        }
+      }
+      return deltaRecord;
+    }
+
+    @Override
+    public BeamSQLRow extractOutput(BeamSQLRow accumulator) {
+      BeamSQLRow finalRecord = new BeamSQLRow(aggDataType);
+      for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+        BeamSqlExpression ex = aggElementExpressions.get(idx);
+        String aggFnName = aggFunctions.get(idx);
+        switch (aggFnName) {
+        case "COUNT":
+          finalRecord.addField(idx, accumulator.getLong(idx));
+          break;
+        case "AVG":
+          long count = accumulator.getLong(countIndex);
+          switch (ex.getOutputType()) {
+          case INTEGER:
+            finalRecord.addField(idx, (int) (accumulator.getInteger(idx) / count));
+            break;
+          case BIGINT:
+            finalRecord.addField(idx, accumulator.getLong(idx) / count);
+            break;
+          case SMALLINT:
+            finalRecord.addField(idx, (short) (accumulator.getShort(idx) / count));
+            break;
+          case TINYINT:
+            finalRecord.addField(idx, (byte) (accumulator.getByte(idx) / count));
+            break;
+          case FLOAT:
+            finalRecord.addField(idx, (float) (accumulator.getFloat(idx) / count));
+            break;
+          case DOUBLE:
+            finalRecord.addField(idx, accumulator.getDouble(idx) / count);
+            break;
+          default:
+            break;
+          }
+          break;
+        case "SUM":
+          switch (ex.getOutputType()) {
+          case INTEGER:
+            finalRecord.addField(idx, accumulator.getInteger(idx));
+            break;
+          case BIGINT:
+            finalRecord.addField(idx, accumulator.getLong(idx));
+            break;
+          case SMALLINT:
+            finalRecord.addField(idx, accumulator.getShort(idx));
+            break;
+          case TINYINT:
+            finalRecord.addField(idx, accumulator.getByte(idx));
+            break;
+          case FLOAT:
+            finalRecord.addField(idx, accumulator.getFloat(idx));
+            break;
+          case DOUBLE:
+            finalRecord.addField(idx, accumulator.getDouble(idx));
+            break;
+          default:
+            break;
+          }
+          break;
+        case "MAX":
+        case "MIN":
+          switch (ex.getOutputType()) {
+          case INTEGER:
+            finalRecord.addField(idx, accumulator.getInteger(idx));
+            break;
+          case BIGINT:
+            finalRecord.addField(idx, accumulator.getLong(idx));
+            break;
+          case SMALLINT:
+            finalRecord.addField(idx, accumulator.getShort(idx));
+            break;
+          case TINYINT:
+            finalRecord.addField(idx, accumulator.getByte(idx));
+            break;
+          case FLOAT:
+            finalRecord.addField(idx, accumulator.getFloat(idx));
+            break;
+          case DOUBLE:
+            finalRecord.addField(idx, accumulator.getDouble(idx));
+            break;
+          case TIMESTAMP:
+            finalRecord.addField(idx, accumulator.getDate(idx));
+            break;
+          default:
+            break;
+          }
+          break;
+        default:
+          break;
+        }
+      }
+      return finalRecord;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
new file mode 100644
index 0000000..f174b9c
--- /dev/null
+++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
@@ -0,0 +1,436 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql.schema.transform;
+
+import java.text.ParseException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.beam.dsls.sql.planner.BeamQueryPlanner;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordTypeCoder;
+import org.apache.beam.dsls.sql.schema.BeamSQLRow;
+import org.apache.beam.dsls.sql.schema.BeamSqlRowCoder;
+import org.apache.beam.dsls.sql.transform.BeamAggregationTransforms;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder;
+import org.apache.calcite.rel.type.RelDataTypeSystem;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlAvgAggFunction;
+import org.apache.calcite.sql.fun.SqlCountAggFunction;
+import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
+import org.apache.calcite.sql.fun.SqlSumAggFunction;
+import org.apache.calcite.sql.type.BasicSqlType;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Unit tests for {@link BeamAggregationTransforms}.
+ *
+ * <p>The pipeline under test extracts the group-by key, applies a {@link GroupByKey}, combines
+ * each group with {@link BeamAggregationTransforms.AggregationCombineFn}, and finally flattens
+ * key and aggregates into one row with {@link BeamAggregationTransforms.MergeAggregationRecord}.
+ * The output of every step is asserted.
+ */
+public class BeamAggregationTransformTest extends BeamTransformBaseTest {
+
+  @Rule
+  public TestPipeline p = TestPipeline.create();
+
+  private List<AggregateCall> aggCalls;
+  private BeamSQLRecordType keyType = initTypeOfSqlRow(
+      Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER)));
+
+  /**
+   * Runs an aggregation equivalent to the query below and checks the output of each step.
+   * <pre>
+   * SELECT `f_int`
+   * , COUNT(*) AS `count`
+   * , SUM(`f_long`) AS `sum1`, AVG(`f_long`) AS `avg1`
+   * , MAX(`f_long`) AS `max1`, MIN(`f_long`) AS `min1`
+   * , SUM(`f_short`) AS `sum2`, AVG(`f_short`) AS `avg2`
+   * , MAX(`f_short`) AS `max2`, MIN(`f_short`) AS `min2`
+   * , SUM(`f_byte`) AS `sum3`, AVG(`f_byte`) AS `avg3`
+   * , MAX(`f_byte`) AS `max3`, MIN(`f_byte`) AS `min3`
+   * , SUM(`f_float`) AS `sum4`, AVG(`f_float`) AS `avg4`
+   * , MAX(`f_float`) AS `max4`, MIN(`f_float`) AS `min4`
+   * , SUM(`f_double`) AS `sum5`, AVG(`f_double`) AS `avg5`
+   * , MAX(`f_double`) AS `max5`, MIN(`f_double`) AS `min5`
+   * , MAX(`f_timestamp`) AS `max7`, MIN(`f_timestamp`) AS `min7`
+   * , SUM(`f_int2`) AS `sum8`, AVG(`f_int2`) AS `avg8`
+   * , MAX(`f_int2`) AS `max8`, MIN(`f_int2`) AS `min8`
+   * FROM TABLE_NAME
+   * GROUP BY `f_int`
+   * </pre>
+   *
+   * @throws ParseException if an expected timestamp literal cannot be parsed
+   */
+  @Test
+  public void testCountPerElementBasic() throws ParseException {
+    setupEnvironment();
+
+    PCollection<BeamSQLRow> input = p.apply(Create.of(inputRows));
+
+    // 1. extract the group-by key (field 0, `f_int`).
+    // NOTE(review): -1 presumably means "no window field" - confirm against AggregationGroupByKeyFn.
+    PCollection<KV<BeamSQLRow, BeamSQLRow>> exGroupByStream = input.apply("exGroupBy",
+        WithKeys
+            .of(new BeamAggregationTransforms.AggregationGroupByKeyFn(-1, ImmutableBitSet.of(0))));
+
+    // 2. group the rows by key
+    PCollection<KV<BeamSQLRow, Iterable<BeamSQLRow>>> groupedStream = exGroupByStream
+        .apply("groupBy", GroupByKey.<BeamSQLRow, BeamSQLRow>create());
+
+    // 3. run the aggregation functions over each group
+    PCollection<KV<BeamSQLRow, BeamSQLRow>> aggregatedStream = groupedStream.apply("aggregation",
+        Combine.<BeamSQLRow, BeamSQLRow, BeamSQLRow>groupedValues(
+            new BeamAggregationTransforms.AggregationCombineFn(aggCalls, inputRowType)));
+
+    // 4. flatten each KV into a single record
+    PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord",
+        ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord(
+            BeamSQLRecordType.from(prepareFinalRowType()), aggCalls)));
+
+    // assert BeamAggregationTransforms.AggregationGroupByKeyFn
+    PAssert.that(exGroupByStream).containsInAnyOrder(prepareResultOfAggregationGroupByKeyFn());
+
+    // assert BeamAggregationTransforms.AggregationCombineFn
+    PAssert.that(aggregatedStream).containsInAnyOrder(prepareResultOfAggregationCombineFn());
+
+    // assert BeamAggregationTransforms.MergeAggregationRecord
+    PAssert.that(mergedStream).containsInAnyOrder(prepareResultOfMergeAggregationRecord());
+
+    p.run();
+  }
+
+  private void setupEnvironment() {
+    registerCoders();
+    prepareAggregationCalls();
+  }
+
+  /**
+   * Registers the Beam SQL coders for rows and record types on the test pipeline.
+   */
+  private void registerCoders() {
+    CoderRegistry cr = p.getCoderRegistry();
+    cr.registerCoder(BeamSQLRow.class, BeamSqlRowCoder.of());
+    cr.registerCoder(BeamSQLRecordType.class, BeamSQLRecordTypeCoder.of());
+  }
+
+  /**
+   * Creates the list of {@link AggregateCall}s, in the same order as the columns returned by
+   * {@link #aggregationColumns()}.
+   */
+  @SuppressWarnings("deprecation")
+  private void prepareAggregationCalls() {
+    aggCalls = new ArrayList<>();
+    aggCalls.add(aggCall(new SqlCountAggFunction(), Arrays.<Integer>asList(),
+        SqlTypeName.BIGINT, "count"));
+    addNumericAggCalls(aggCalls, 1, SqlTypeName.BIGINT, "1");
+    addNumericAggCalls(aggCalls, 2, SqlTypeName.SMALLINT, "2");
+    addNumericAggCalls(aggCalls, 3, SqlTypeName.TINYINT, "3");
+    addNumericAggCalls(aggCalls, 4, SqlTypeName.FLOAT, "4");
+    addNumericAggCalls(aggCalls, 5, SqlTypeName.DOUBLE, "5");
+    aggCalls.add(aggCall(new SqlMinMaxAggFunction(SqlKind.MAX), Arrays.asList(7),
+        SqlTypeName.TIMESTAMP, "max7"));
+    aggCalls.add(aggCall(new SqlMinMaxAggFunction(SqlKind.MIN), Arrays.asList(7),
+        SqlTypeName.TIMESTAMP, "min7"));
+    addNumericAggCalls(aggCalls, 8, SqlTypeName.INTEGER, "8");
+  }
+
+  /**
+   * Appends SUM, AVG, MAX and MIN calls (in that order) over input field {@code fieldIndex}.
+   * The calls are named {@code sum<suffix>}, {@code avg<suffix>}, {@code max<suffix>} and
+   * {@code min<suffix>}, and all of them return {@code type}.
+   */
+  @SuppressWarnings("deprecation")
+  private static void addNumericAggCalls(List<AggregateCall> calls, int fieldIndex,
+      SqlTypeName type, String suffix) {
+    calls.add(aggCall(
+        new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, type)),
+        Arrays.asList(fieldIndex), type, "sum" + suffix));
+    calls.add(aggCall(new SqlAvgAggFunction(SqlKind.AVG),
+        Arrays.asList(fieldIndex), type, "avg" + suffix));
+    calls.add(aggCall(new SqlMinMaxAggFunction(SqlKind.MAX),
+        Arrays.asList(fieldIndex), type, "max" + suffix));
+    calls.add(aggCall(new SqlMinMaxAggFunction(SqlKind.MIN),
+        Arrays.asList(fieldIndex), type, "min" + suffix));
+  }
+
+  /** Builds a non-distinct {@link AggregateCall} whose return type is {@code returnType}. */
+  @SuppressWarnings("deprecation")
+  private static AggregateCall aggCall(SqlAggFunction function, List<Integer> args,
+      SqlTypeName returnType, String name) {
+    return new AggregateCall(function, false, args,
+        new BasicSqlType(RelDataTypeSystem.DEFAULT, returnType), name);
+  }
+
+  /** Names and types of the aggregate columns, in the same order as {@link #aggCalls}. */
+  private static List<KV<String, SqlTypeName>> aggregationColumns() {
+    return Arrays.asList(KV.of("count", SqlTypeName.BIGINT),
+
+        KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT),
+        KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT),
+
+        KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT),
+        KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT),
+
+        KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT),
+        KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT),
+
+        KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT),
+        KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT),
+
+        KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE),
+        KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE),
+
+        KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP),
+
+        KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER),
+        KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER));
+  }
+
+  /**
+   * Expected aggregate values for the single group in the input, in the same order as
+   * {@link #aggregationColumns()}.
+   */
+  private static List<Object> expectedAggregationValues() throws ParseException {
+    return Arrays.<Object>asList(
+        4L,
+        10000L, 2500L, 4000L, 1000L,
+        (short) 10, (short) 2, (short) 4, (short) 1,
+        (byte) 10, (byte) 2, (byte) 4, (byte) 1,
+        10.0F, 2.5F, 4.0F, 1.0F,
+        10.0, 2.5, 4.0, 1.0,
+        format.parse("2017-01-01 02:04:03"), format.parse("2017-01-01 01:01:03"),
+        10, 2, 4, 1);
+  }
+
+  /** Key row that the group-by step is expected to extract from {@code row}. */
+  private BeamSQLRow groupKey(BeamSQLRow row) {
+    return new BeamSQLRow(keyType, Arrays.<Object>asList(row.getInteger(0)));
+  }
+
+  /**
+   * expected results after {@link BeamAggregationTransforms.AggregationGroupByKeyFn}: one
+   * {@code KV(key, row)} per input row.
+   */
+  private List<KV<BeamSQLRow, BeamSQLRow>> prepareResultOfAggregationGroupByKeyFn() {
+    List<KV<BeamSQLRow, BeamSQLRow>> expected = new ArrayList<>();
+    for (BeamSQLRow row : inputRows) {
+      expected.add(KV.of(groupKey(row), row));
+    }
+    return expected;
+  }
+
+  /**
+   * expected results after {@link BeamAggregationTransforms.AggregationCombineFn}: all input
+   * rows share one key, so a single aggregated row is expected.
+   */
+  private List<KV<BeamSQLRow, BeamSQLRow>> prepareResultOfAggregationCombineFn()
+      throws ParseException {
+    BeamSQLRecordType aggPartType = initTypeOfSqlRow(aggregationColumns());
+    return Arrays.asList(
+        KV.of(groupKey(inputRows.get(0)),
+            new BeamSQLRow(aggPartType, expectedAggregationValues())));
+  }
+
+  /**
+   * Row type of the final output row: the group-by key followed by the aggregate columns.
+   */
+  private RelDataType prepareFinalRowType() {
+    FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder();
+    builder.add("f_int", SqlTypeName.INTEGER);
+    for (KV<String, SqlTypeName> cm : aggregationColumns()) {
+      builder.add(cm.getKey(), cm.getValue());
+    }
+    return builder.build();
+  }
+
+  /**
+   * expected results after {@link BeamAggregationTransforms.MergeAggregationRecord}.
+   */
+  private BeamSQLRow prepareResultOfMergeAggregationRecord() throws ParseException {
+    List<Object> values = new ArrayList<>();
+    // Same key value as the one expected from the group-by / combine steps.
+    values.add(inputRows.get(0).getInteger(0));
+    values.addAll(expectedAggregationValues());
+    return new BeamSQLRow(BeamSQLRecordType.from(prepareFinalRowType()), values);
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java
new file mode 100644
index 0000000..820d7f5
--- /dev/null
+++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql.schema.transform;
+
+import java.text.DateFormat;
+import java.text.ParseException;
+import java.text.SimpleDateFormat;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.beam.dsls.sql.planner.BeamQueryPlanner;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSQLRow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.junit.BeforeClass;
+
+/**
+ * Shared fixtures and helper methods for testing PTransforms that execute Beam SQL steps.
+ *
+ * <p>{@link #prepareInput()} runs once per test class and populates {@link #inputRowType} and
+ * {@link #inputRows}.
+ */
+public class BeamTransformBaseTest {
+  /**
+   * Parser for the {@code TIMESTAMP} literals used in fixtures and expected results.
+   *
+   * <p>{@link SimpleDateFormat} is not thread-safe and interprets text in the JVM default time
+   * zone, so it should only be used from single-threaded setup and assertion code.
+   */
+  public static DateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
+
+  /** Record type of {@link #inputRows}; assigned by {@link #prepareInput()}. */
+  public static BeamSQLRecordType inputRowType;
+  /** Four sample input rows, all with {@code f_int == 1}; assigned by {@link #prepareInput()}. */
+  public static List<BeamSQLRow> inputRows;
+
+  /**
+   * Builds {@link #inputRowType} and the four sample {@link #inputRows} before any test runs.
+   *
+   * <p>Every row has {@code f_int == 1}. Row {@code i} (1..4) has {@code f_long == i * 1000};
+   * {@code f_short}, {@code f_byte}, {@code f_float}, {@code f_double} and {@code f_int2} equal
+   * {@code i}. Timestamps are parsed with {@link #format}.
+   *
+   * @throws ParseException if a timestamp literal does not match {@link #format}
+   */
+  @BeforeClass
+  public static void prepareInput() throws ParseException {
+    List<KV<String, SqlTypeName>> columnMetadata = Arrays.asList(
+        KV.of("f_int", SqlTypeName.INTEGER), KV.of("f_long", SqlTypeName.BIGINT),
+        KV.of("f_short", SqlTypeName.SMALLINT), KV.of("f_byte", SqlTypeName.TINYINT),
+        KV.of("f_float", SqlTypeName.FLOAT), KV.of("f_double", SqlTypeName.DOUBLE),
+        KV.of("f_string", SqlTypeName.VARCHAR), KV.of("f_timestamp", SqlTypeName.TIMESTAMP),
+        KV.of("f_int2", SqlTypeName.INTEGER)
+        );
+    inputRowType = initTypeOfSqlRow(columnMetadata);
+    // (short)/(byte) literals autobox to the same cached Short/Byte values that
+    // Short.valueOf("n")/Byte.valueOf("n") returned, without any string parsing.
+    inputRows = Arrays.asList(
+        initBeamSqlRow(columnMetadata,
+            Arrays.<Object>asList(1, 1000L, (short) 1, (byte) 1, 1.0F, 1.0,
+                "string_row1", format.parse("2017-01-01 01:01:03"), 1)),
+        initBeamSqlRow(columnMetadata,
+            Arrays.<Object>asList(1, 2000L, (short) 2, (byte) 2, 2.0F, 2.0,
+                "string_row2", format.parse("2017-01-01 01:02:03"), 2)),
+        initBeamSqlRow(columnMetadata,
+            Arrays.<Object>asList(1, 3000L, (short) 3, (byte) 3, 3.0F, 3.0,
+                "string_row3", format.parse("2017-01-01 01:03:03"), 3)),
+        initBeamSqlRow(columnMetadata,
+            Arrays.<Object>asList(1, 4000L, (short) 4, (byte) 4, 4.0F, 4.0,
+                "string_row4", format.parse("2017-01-01 02:04:03"), 4)));
+  }
+
+  /**
+   * Creates a {@code BeamSQLRecordType} whose fields are the given (name, SQL type) pairs, in order.
+   *
+   * @param columnMetadata column names and SQL types, in column order
+   * @return the record type built via {@code BeamQueryPlanner.TYPE_FACTORY}
+   */
+  public static BeamSQLRecordType initTypeOfSqlRow(List<KV<String, SqlTypeName>> columnMetadata) {
+    FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder();
+    for (KV<String, SqlTypeName> cm : columnMetadata) {
+      builder.add(cm.getKey(), cm.getValue());
+    }
+    return BeamSQLRecordType.from(builder.build());
+  }
+
+  /**
+   * Creates a row with the given column metadata and no values supplied.
+   *
+   * @param columnMetadata column names and SQL types, in column order
+   * @return a row built from an empty value list
+   */
+  public static BeamSQLRow initBeamSqlRow(List<KV<String, SqlTypeName>> columnMetadata) {
+    return initBeamSqlRow(columnMetadata, Arrays.<Object>asList());
+  }
+
+  /**
+   * Creates a row with the given column metadata and one value per column.
+   *
+   * <p>A fresh record type is built from {@code columnMetadata} on every call.
+   *
+   * @param columnMetadata column names and SQL types, in column order
+   * @param rowValues column values, in the same order as {@code columnMetadata}
+   * @return the new row
+   */
+  public static BeamSQLRow initBeamSqlRow(List<KV<String, SqlTypeName>> columnMetadata,
+      List<Object> rowValues) {
+    BeamSQLRecordType rowType = initTypeOfSqlRow(columnMetadata);
+
+    return new BeamSQLRow(rowType, rowValues);
+  }
+
+}


[2/2] beam git commit: This closes #3067

Posted by dh...@apache.org.
This closes #3067


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/523482be
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/523482be
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/523482be

Branch: refs/heads/DSL_SQL
Commit: 523482be0501a7bce79087f47c7752b900178a00
Parents: 6729a02 f728fbe
Author: Dan Halperin <dh...@google.com>
Authored: Fri May 12 17:47:06 2017 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Fri May 12 17:47:06 2017 -0700

----------------------------------------------------------------------
 .../interpreter/operator/BeamSqlPrimitive.java  |  35 +
 .../beam/dsls/sql/rel/BeamAggregationRel.java   |  40 +-
 .../apache/beam/dsls/sql/schema/BeamSQLRow.java |   4 +
 .../beam/dsls/sql/schema/BeamSqlRowCoder.java   |   4 +-
 .../sql/transform/BeamAggregationTransform.java | 120 ----
 .../transform/BeamAggregationTransforms.java    | 671 +++++++++++++++++++
 .../transform/BeamAggregationTransformTest.java | 436 ++++++++++++
 .../schema/transform/BeamTransformBaseTest.java |  96 +++
 8 files changed, 1261 insertions(+), 145 deletions(-)
----------------------------------------------------------------------