You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2015/06/16 18:49:48 UTC

flink git commit: [FLINK-2210] Table API support for aggregation on columns with null values

Repository: flink
Updated Branches:
  refs/heads/master 46ad40588 -> b59c81bc4


[FLINK-2210] Table API support for aggregation on columns with null values


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b59c81bc
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b59c81bc
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b59c81bc

Branch: refs/heads/master
Commit: b59c81bc41f0fc4ade5359dfdf42549a76d412fa
Parents: 46ad405
Author: Shiti <ss...@gmail.com>
Authored: Mon Jun 15 00:29:02 2015 +0530
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Tue Jun 16 18:38:48 2015 +0200

----------------------------------------------------------------------
 .../table/codegen/ExpressionCodeGenerator.scala | 19 +++++++
 .../api/table/expressions/aggregations.scala    |  2 +-
 .../api/table/expressions/comparison.scala      |  8 +++
 .../runtime/ExpressionAggregateFunction.scala   |  5 +-
 .../scala/table/test/AggregationsITCase.scala   | 58 +++++++++++++++++++-
 5 files changed, 88 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
index 49f7600..e109574 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala
@@ -489,6 +489,25 @@ abstract class ExpressionCodeGenerator[R](
             """.stripMargin
         }
 
+      case NumericIsNotNull(child) =>
+        val childCode = generateExpression(child)
+        if (nullCheck) {
+          // Declare the result before the branch: a bare `0;` is not a Java statement.
+          childCode.code +
+            s"""
+               |boolean $nullTerm = ${childCode.nullTerm};
+               |$resultTpe $resultTerm = 0;
+               |if (!$nullTerm) {
+               |  $resultTerm = ${childCode.resultTerm} != null ? 1 : 0;
+               |}
+            """.stripMargin
+        } else {
+          childCode.code +
+            s"""
+               |$resultTpe $resultTerm = ${childCode.resultTerm} != null ? 1 : 0;
+            """.stripMargin
+        }
+
       case _ => throw new ExpressionException("Could not generate code for expression " + expr)
     }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala
index 08e319d..a762f66 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala
@@ -89,7 +89,7 @@ case class Count(child: Expression) extends Aggregation {
 case class Avg(child: Expression) extends Aggregation {
   override def toString = s"($child).avg"
 
-  override def getIntermediateFields: Seq[Expression] = Seq(child, Literal(1))
+  override def getIntermediateFields: Seq[Expression] = Seq(child, NumericIsNotNull(child))
   // This is just sweet. Use our own AST representation and let the code generator do
   // our dirty work.
   override def getFinalField(inputs: Seq[Expression]): Expression =

http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
index 687ea7a..c60acf9 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala
@@ -91,3 +91,11 @@ case class IsNotNull(child: Expression) extends UnaryExpression {
 
   override def toString = s"($child).isNotNull"
 }
+
+/** Int-valued null test: code generation yields 1 for a non-null child, 0 otherwise. */
+case class NumericIsNotNull(child: Expression) extends UnaryExpression {
+  override def toString = s"($child).numericIsNotNull"
+
+  // The result is always typed as INT, whatever the type of the child expression.
+  def typeInfo = BasicTypeInfo.INT_TYPE_INFO
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala
index 7e9bc0d..7d7dc1c 100644
--- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala
+++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala
@@ -53,7 +53,10 @@ class ExpressionAggregateFunction(
       var i = 0
       val len = functions.length
       while (i < len) {
-        functions(i).aggregate(current.productElement(fieldPositions(i)))
+        val element: Any = current.productElement(fieldPositions(i))
+        if (element != null){
+          functions(i).aggregate(element)
+        }
         i += 1
       }
     }

http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala b/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
index 3b7ab8d..62ac345 100644
--- a/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
+++ b/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
@@ -18,13 +18,16 @@
 
 package org.apache.flink.api.scala.table.test
 
-import org.apache.flink.api.table.ExpressionException
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.scala._
 import org.apache.flink.api.scala.table._
 import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.api.table.typeinfo.RowTypeInfo
+import org.apache.flink.api.table.{ExpressionException, Row}
 import org.apache.flink.core.fs.FileSystem.WriteMode
-import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase}
 import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
+import org.junit.Assert._
 import org.junit._
 import org.junit.rules.TemporaryFolder
 import org.junit.runner.RunWith
@@ -123,5 +126,56 @@ class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBa
     expected = ""
   }
 
+  @Test
+  def testAggregationWithNullValues(): Unit = {
+    // Ids not above 200 are stored as null, leaving two non-null ids (234 and 345).
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val input = env.fromElements[(Integer, String)](
+      (123, "a"), (234, "b"), (345, "c"), (0, "d"))
+
+    implicit val rowInfo: TypeInformation[Row] = new RowTypeInfo(
+      Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), Seq("id", "name"))
+
+    val rows = input.map { entry =>
+      val id = if (entry._1 > 200) entry._1 else null
+      val row = new Row(2)
+      row.setField(0, id)
+      row.setField(1, entry._2)
+      row
+    }
+
+    val result = rows.toTable.select('id.avg, 'id.sum, 'id.count).collect().head
+    // Reads the aggregate at the given position through its string form.
+    def intAt(pos: Int): Int = result.productElement(pos).toString.toInt
+    val mean = intAt(0)
+    val sum = intAt(1)
+    val count = intAt(2)
+
+    assertEquals(4, count)
+    // The mean is expected to be taken over the two non-null ids only.
+    assertEquals(sum / 2, mean)
+  }
+
+  @Test
+  def testAggregationWhenAllValuesAreNull(): Unit = {
+    // Every id is replaced by null, so the max aggregate is expected to come back as null.
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val dataSet = env.fromElements[(Integer, String)](
+      (123, "a"), (234, "b"), (345, "c"), (0, "d"))
+
+    implicit val rowInfo: TypeInformation[Row] = new RowTypeInfo(
+      Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), Seq("id", "name"))
+
+    val rowDataSet = dataSet.map {
+      entry =>
+        val row = new Row(2)
+        row.setField(0, null)
+        row.setField(1, entry._2)
+        row
+    }
+
+    val maxId = rowDataSet.toTable.select('id.max).collect().head.productElement(0)
+    assertNull(maxId)
+  }
 
 }