You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2015/06/16 18:49:48 UTC
flink git commit: [FLINK-2210] Table API support for aggregation on columns with null values
Repository: flink
Updated Branches:
refs/heads/master 46ad40588 -> b59c81bc4
[FLINK-2210] Table API support for aggregation on columns with null values
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b59c81bc
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b59c81bc
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b59c81bc
Branch: refs/heads/master
Commit: b59c81bc41f0fc4ade5359dfdf42549a76d412fa
Parents: 46ad405
Author: Shiti <ss...@gmail.com>
Authored: Mon Jun 15 00:29:02 2015 +0530
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Tue Jun 16 18:38:48 2015 +0200
----------------------------------------------------------------------
.../table/codegen/ExpressionCodeGenerator.scala | 19 +++++++
.../api/table/expressions/aggregations.scala | 2 +-
.../api/table/expressions/comparison.scala | 8 +++
.../runtime/ExpressionAggregateFunction.scala | 5 +-
.../scala/table/test/AggregationsITCase.scala | 58 +++++++++++++++++++-
5 files changed, 88 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
index 49f7600..e109574 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
@@ -489,6 +489,25 @@ abstract class ExpressionCodeGenerator[R](
""".stripMargin
}
+      case NumericIsNotNull(child) =>
+        // Integer indicator of whether `child` is non-null: 1 when it has a value, 0 when it is
+        // null. Avg uses it in place of the constant Literal(1) so null inputs are not counted.
+        val childCode = generateExpression(child)
+        if (nullCheck) {
+          // The indicator is itself never null. Both terms are declared at the top level of the
+          // generated snippet (not inside an if/else block) so code generated after this
+          // expression can reference them.
+          childCode.code +
+            s"""
+              |boolean $nullTerm = false;
+              |$resultTpe $resultTerm = ${childCode.nullTerm} ? 0 : 1;
+            """.stripMargin
+        } else {
+          childCode.code +
+            s"""
+              |$resultTpe $resultTerm = ${childCode.resultTerm} != null ? 1 : 0;
+            """.stripMargin
+        }
+
case _ => throw new ExpressionException("Could not generate code for expression " + expr)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala
index 08e319d..a762f66 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala
@@ -89,7 +89,7 @@ case class Count(child: Expression) extends Aggregation {
case class Avg(child: Expression) extends Aggregation {
override def toString = s"($child).avg"
- override def getIntermediateFields: Seq[Expression] = Seq(child, Literal(1))
+  // Intermediate values: the child itself plus a 0/1 indicator of whether it is non-null,
+  // replacing the previous constant Literal(1).
+  // NOTE(review): the indicator presumably feeds the divisor built in getFinalField, so null
+  // rows no longer count toward the average -- confirm against getFinalField.
+  override def getIntermediateFields: Seq[Expression] = {
+    val nonNullIndicator: Expression = NumericIsNotNull(child)
+    Seq(child, nonNullIndicator)
+  }
// This is just sweet. Use our own AST representation and let the code generator do
// our dirty work.
override def getFinalField(inputs: Seq[Expression]): Expression =
http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
index 687ea7a..c60acf9 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
@@ -91,3 +91,11 @@ case class IsNotNull(child: Expression) extends UnaryExpression {
override def toString = s"($child).isNotNull"
}
+
+/**
+ * Integer-valued null indicator for `child`: code generation produces 1 when `child` holds a
+ * value and 0 when it is null. `Avg` uses it as its second intermediate field in place of a
+ * constant 1, so that null inputs are not counted.
+ *
+ * @param child the expression whose nullness is tested
+ */
+case class NumericIsNotNull(child: Expression) extends UnaryExpression {
+  def typeInfo = BasicTypeInfo.INT_TYPE_INFO
+
+  override def toString = s"($child).numericIsNotNull"
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala
index 7e9bc0d..7d7dc1c 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala
@@ -53,7 +53,10 @@ class ExpressionAggregateFunction(
var i = 0
val len = functions.length
while (i < len) {
- functions(i).aggregate(current.productElement(fieldPositions(i)))
+ val element: Any = current.productElement(fieldPositions(i)) // value of the field aggregated by functions(i)
+ if (element != null){ // null field values are never handed to the aggregation function
+ functions(i).aggregate(element)
+ }
i += 1
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala b/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
index 3b7ab8d..62ac345 100644
--- a/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
+++ b/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
@@ -18,13 +18,16 @@
package org.apache.flink.api.scala.table.test
-import org.apache.flink.api.table.ExpressionException
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.table._
import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.api.table.typeinfo.RowTypeInfo
+import org.apache.flink.api.table.{ExpressionException, Row}
import org.apache.flink.core.fs.FileSystem.WriteMode
-import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase}
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
+import org.junit.Assert._
import org.junit._
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
@@ -123,5 +126,56 @@ class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBa
expected = ""
}
+  /**
+   * Aggregates an INT column in which two of four rows are null: avg and sum must ignore the
+   * nulls, while count covers all four rows.
+   */
+  @Test
+  def testAggregationWithNullValues(): Unit = {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val dataSet = env.fromElements[(Integer, String)](
+      (123, "a"), (234, "b"), (345, "c"), (0, "d"))
+
+    implicit val rowInfo: TypeInformation[Row] = new RowTypeInfo(
+      Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), Seq("id", "name"))
+
+    // Ids not above 200 are replaced by null, leaving only 234 and 345 as non-null values.
+    val rowDataSet = dataSet.map {
+      entry =>
+        val row = new Row(2)
+        val amount = if (entry._1 > 200) entry._1 else null
+        row.setField(0, amount)
+        row.setField(1, entry._2)
+        row
+    }
+
+    val result = rowDataSet.toTable.select('id.avg, 'id.sum, 'id.count).collect().head
+    val mean = result.productElement(0).toString.toInt
+    val sum = result.productElement(1).toString.toInt
+    val count = result.productElement(2).toString.toInt
+
+    // Assert concrete values instead of deriving the expected mean from the computed sum,
+    // which would let a wrong sum go unnoticed.
+    assertEquals(234 + 345, sum)
+    // Integer average over the two non-null ids: 579 / 2 == 289.
+    assertEquals((234 + 345) / 2, mean)
+    assertEquals(4, count)
+  }
+
+  /** Aggregating a column that contains only nulls must yield null rather than a value. */
+  @Test
+  def testAggregationWhenAllValuesAreNull(): Unit = {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val dataSet = env.fromElements[(Integer, String)](
+      (123, "a"), (234, "b"), (345, "c"), (0, "d"))
+
+    implicit val rowInfo: TypeInformation[Row] = new RowTypeInfo(
+      Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), Seq("id", "name"))
+
+    // Every id is discarded, so the aggregated column holds only nulls.
+    val rowDataSet = dataSet.map {
+      entry =>
+        val row = new Row(2)
+        row.setField(0, null)
+        row.setField(1, entry._2)
+        row
+    }
+
+    val maxId = rowDataSet.toTable.select('id.max).collect().head.productElement(0)
+    // assertNull states the intent directly and avoids the reversed expected/actual arguments.
+    assertNull(maxId)
+  }
}