Posted to commits@flink.apache.org by fh...@apache.org on 2017/05/04 22:51:41 UTC
[2/2] flink git commit: [FLINK-5906] [table] Add support to register UDAGGs for Table API and SQL.
[FLINK-5906] [table] Add support to register UDAGGs for Table API and SQL.
This closes #3809.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/981dea41
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/981dea41
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/981dea41
Branch: refs/heads/master
Commit: 981dea41e593f3db763af3d0366bf7adbdd1d3bf
Parents: d6435e8
Author: shaoxuan-wang <ws...@gmail.com>
Authored: Tue May 2 23:00:51 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu May 4 23:40:31 2017 +0200
----------------------------------------------------------------------
.../flink/table/api/TableEnvironment.scala | 25 ++-
.../table/api/java/BatchTableEnvironment.scala | 22 ++-
.../table/api/java/StreamTableEnvironment.scala | 22 ++-
.../table/api/scala/BatchTableEnvironment.scala | 18 +-
.../api/scala/StreamTableEnvironment.scala | 18 +-
.../flink/table/api/scala/expressionDsl.scala | 3 +
.../org/apache/flink/table/api/table.scala | 21 +-
.../flink/table/codegen/CodeGenerator.scala | 97 +++++++--
.../codegen/calls/ScalarFunctionCallGen.scala | 4 +-
.../codegen/calls/TableFunctionCallGen.scala | 2 +-
.../table/expressions/UDAGGExpression.scala | 36 ++++
.../flink/table/expressions/aggregations.scala | 75 ++++++-
.../apache/flink/table/expressions/call.scala | 13 +-
.../table/functions/AggregateFunction.scala | 14 +-
.../aggfunctions/CountAggFunction.scala | 3 +
.../table/functions/utils/AggSqlFunction.scala | 179 +++++++++++++++++
.../functions/utils/ScalarSqlFunction.scala | 33 +---
.../utils/UserDefinedFunctionUtils.scala | 198 ++++++++++++++-----
.../flink/table/plan/ProjectionTranslator.scala | 57 ++++++
.../flink/table/plan/logical/operators.scala | 4 +-
.../table/plan/nodes/CommonAggregate.scala | 2 +-
.../flink/table/plan/nodes/OverAggregate.scala | 2 +-
.../nodes/datastream/DataStreamAggregate.scala | 12 +-
.../table/runtime/aggregate/AggregateUtil.scala | 48 +++--
.../flink/table/validate/FunctionCatalog.scala | 13 +-
.../api/java/utils/UserDefinedAggFunctions.java | 95 +++++++++
.../scala/batch/sql/AggregationsITCase.scala | 43 ++--
.../scala/batch/sql/WindowAggregateTest.scala | 43 +++-
.../scala/batch/table/AggregationsITCase.scala | 9 +-
.../api/scala/batch/table/GroupWindowTest.scala | 155 +++++++++++++++
.../AggregationsStringExpressionTest.scala | 58 ++++++
.../validation/AggregationsValidationTest.scala | 99 ++++++++++
.../scala/stream/sql/WindowAggregateTest.scala | 74 ++++---
.../scala/stream/table/AggregationsITCase.scala | 43 ++--
.../scala/stream/table/GroupWindowTest.scala | 120 ++++++++++-
.../scala/stream/table/OverWindowITCase.scala | 45 +++--
.../api/scala/stream/table/OverWindowTest.scala | 88 ++++++---
.../GroupWindowStringExpressionTest.scala | 65 ++++++
38 files changed, 1617 insertions(+), 241 deletions(-)
----------------------------------------------------------------------
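Background for the diff below: a user-defined aggregate function (UDAGG) is a class extending AggregateFunction[T, ACC], where T is the result type and ACC the accumulator type. A minimal, hedged sketch of such a function; WeightedAvg and WeightedAvgAccum are illustrative names only and are not part of this commit:

    import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
    import org.apache.flink.table.functions.AggregateFunction

    // accumulator: a POJO-style class holding the intermediate state
    class WeightedAvgAccum { var sum: Long = 0L; var count: Long = 0L }

    // a weighted average over a Long value and an Int weight (illustrative)
    class WeightedAvg extends AggregateFunction[java.lang.Long, WeightedAvgAccum] {

      override def createAccumulator(): WeightedAvgAccum = new WeightedAvgAccum

      // called for each input row; one accumulate method per supported call signature
      def accumulate(acc: WeightedAvgAccum, value: Long, weight: Int): Unit = {
        acc.sum += value * weight
        acc.count += weight
      }

      override def getValue(acc: WeightedAvgAccum): java.lang.Long =
        if (acc.count == 0) 0L else acc.sum / acc.count

      // optional, as described in the AggregateFunction documentation changed below
      def getResultType: TypeInformation[_] = BasicTypeInfo.LONG_TYPE_INFO
    }
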
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
index 45267d2..06c405e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
@@ -50,8 +50,8 @@ import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkRelBuilder, FlinkT
import org.apache.flink.table.catalog.{ExternalCatalog, ExternalCatalogSchema}
import org.apache.flink.table.codegen.{CodeGenerator, ExpressionReducer}
import org.apache.flink.table.expressions.{Alias, Expression, UnresolvedFieldReference}
-import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, createScalarSqlFunction, createTableSqlFunctions}
-import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.table.functions.{ScalarFunction, TableFunction, AggregateFunction}
import org.apache.flink.table.plan.cost.DataSetCostFactory
import org.apache.flink.table.plan.logical.{CatalogNode, LogicalRelNode}
import org.apache.flink.table.plan.rules.FlinkRuleSets
@@ -352,6 +352,27 @@ abstract class TableEnvironment(val config: TableConfig) {
}
/**
+ * Registers an [[AggregateFunction]] under a unique name. Replaces already existing
+ * user-defined functions under this name.
+ */
+ private[flink] def registerAggregateFunctionInternal[T: TypeInformation, ACC](
+ name: String, function: AggregateFunction[T, ACC]): Unit = {
+ // check if class not Scala object
+ checkNotSingleton(function.getClass)
+ // check if class could be instantiated
+ checkForInstantiation(function.getClass)
+
+ val typeInfo: TypeInformation[_] = implicitly[TypeInformation[T]]
+
+ // register in Table API
+ functionCatalog.registerFunction(name, function.getClass)
+
+ // register in SQL API
+ val sqlFunctions = createAggregateSqlFunction(name, function, typeInfo, typeFactory)
+ functionCatalog.registerSqlFunction(sqlFunctions)
+ }
+
+ /**
* Registers a [[Table]] under a unique name in the TableEnvironment's catalog.
* Registered tables can be referenced in SQL queries.
*
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala
index de5f789..03fb77e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala
@@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.table.expressions.ExpressionParser
import org.apache.flink.table.api._
-import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
/**
* The [[TableEnvironment]] for a Java batch [[DataSet]]
@@ -178,4 +178,24 @@ class BatchTableEnvironment(
registerTableFunctionInternal[T](name, tf)
}
+
+ /**
+ * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can be referenced in Table API and SQL queries.
+ *
+ * @param name The name under which the function is registered.
+ * @param f The AggregateFunction to register.
+ * @tparam T The type of the output value.
+ * @tparam ACC The type of aggregate accumulator.
+ */
+ def registerFunction[T, ACC](
+ name: String,
+ f: AggregateFunction[T, ACC])
+ : Unit = {
+ implicit val typeInfo: TypeInformation[T] = TypeExtractor
+ .createTypeInfo(f, classOf[AggregateFunction[T, ACC]], f.getClass, 0)
+ .asInstanceOf[TypeInformation[T]]
+
+ registerAggregateFunctionInternal[T, ACC](name, f)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala
index 4d9f1e1..a649584 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvironment.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.api.java
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.table.api._
-import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
import org.apache.flink.table.expressions.ExpressionParser
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
@@ -180,4 +180,24 @@ class StreamTableEnvironment(
registerTableFunctionInternal[T](name, tf)
}
+
+ /**
+ * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can be referenced in Table API and SQL queries.
+ *
+ * @param name The name under which the function is registered.
+ * @param f The AggregateFunction to register.
+ * @tparam T The type of the output value.
+ * @tparam ACC The type of aggregate accumulator.
+ */
+ def registerFunction[T, ACC](
+ name: String,
+ f: AggregateFunction[T, ACC])
+ : Unit = {
+ implicit val typeInfo: TypeInformation[T] = TypeExtractor
+ .createTypeInfo(f, classOf[AggregateFunction[T, ACC]], f.getClass, 0)
+ .asInstanceOf[TypeInformation[T]]
+
+ registerAggregateFunctionInternal[T, ACC](name, f)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala
index 3ae8c31..0dd7ca0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvironment.scala
@@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.expressions.Expression
-import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
import _root_.scala.reflect.ClassTag
@@ -151,4 +151,20 @@ class BatchTableEnvironment(
def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
registerTableFunctionInternal(name, tf)
}
+
+ /**
+ * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can be referenced in Table API and SQL queries.
+ *
+ * @param name The name under which the function is registered.
+ * @param f The AggregateFunction to register.
+ * @tparam T The type of the output value.
+ * @tparam ACC The type of aggregate accumulator.
+ */
+ def registerFunction[T: TypeInformation, ACC](
+ name: String,
+ f: AggregateFunction[T, ACC])
+ : Unit = {
+ registerAggregateFunctionInternal[T, ACC](name, f)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala
index 0113146..0552d7c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvironment.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.api.scala
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.{TableEnvironment, Table, TableConfig}
-import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
import org.apache.flink.table.expressions.Expression
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.scala.asScalaStream
@@ -152,4 +152,20 @@ class StreamTableEnvironment(
def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
registerTableFunctionInternal(name, tf)
}
+
+ /**
+ * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
+ * Registered functions can be referenced in Table API and SQL queries.
+ *
+ * @param name The name under which the function is registered.
+ * @param f The AggregateFunction to register.
+ * @tparam T The type of the output value.
+ * @tparam ACC The type of aggregate accumulator.
+ */
+ def registerFunction[T: TypeInformation, ACC](
+ name: String,
+ f: AggregateFunction[T, ACC])
+ : Unit = {
+ registerAggregateFunctionInternal[T, ACC](name, f)
+ }
}
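For illustration, how the new registerFunction overload might be used from the Scala streaming Table API. WeightedAvg is the hypothetical UDAGG sketched above, and the environment setup reflects the API of this Flink version; treat the exact calls as an assumption, not part of this commit:

    import org.apache.flink.streaming.api.scala._
    import org.apache.flink.table.api.TableEnvironment
    import org.apache.flink.table.api.scala._

    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // makes "wAvg" available to SQL queries and string-based Table API expressions
    tEnv.registerFunction("wAvg", new WeightedAvg)
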
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index b3de4a4..cc58ff5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -26,6 +26,7 @@ import org.apache.flink.table.api.{TableException, CurrentRow, CurrentRange, Unb
import org.apache.flink.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval}
import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit
import org.apache.flink.table.expressions._
+import org.apache.flink.table.functions.AggregateFunction
import scala.language.implicitConversions
@@ -773,6 +774,8 @@ trait ImplicitExpressionConversions {
implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression =
Literal(sqlTimestamp)
implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array)
+ implicit def userDefinedAggFunctionConstructor[T: TypeInformation, ACC]
+ (udagg: AggregateFunction[T, ACC]): UDAGGExpression[T, ACC] = UDAGGExpression(udagg)
}
// ------------------------------------------------------------------------------------------------
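With this implicit conversion, an AggregateFunction instance can also be applied inline in the Scala expression DSL without prior registration. A minimal sketch, again assuming the hypothetical WeightedAvg and a Table with fields user, points, and level:

    import org.apache.flink.table.api.scala._

    val wAvg = new WeightedAvg   // hypothetical UDAGG from the sketch above
    // table: org.apache.flink.table.api.Table with fields (user, points, level), defined elsewhere
    val result = table
      .groupBy('user)
      .select('user, wAvg('points, 'level))   // wAvg(...) builds an AggFunctionCall
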
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
index 9606979..87dde0a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
@@ -151,7 +151,9 @@ class Table(
*/
def select(fields: String): Table = {
val fieldExprs = ExpressionParser.parseExpressionList(fields)
- select(fieldExprs: _*)
+ //get the correct expression for AggFunctionCall
+ val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, tableEnv))
+ select(withResolvedAggFunctionCall: _*)
}
/**
@@ -167,7 +169,7 @@ class Table(
def as(fields: Expression*): Table = {
logicalPlan match {
- case functionCall: LogicalTableFunctionCall if functionCall.child == null => {
+ case functionCall: LogicalTableFunctionCall if functionCall.child == null =>
// If the logical plan is a TableFunctionCall, we replace its field names to avoid special
// cases during the validation.
if (fields.length != functionCall.output.length) {
@@ -181,7 +183,7 @@ class Table(
}
new Table(
tableEnv,
- new LogicalTableFunctionCall(
+ LogicalTableFunctionCall(
functionCall.functionName,
functionCall.tableFunction,
functionCall.parameters,
@@ -189,7 +191,6 @@ class Table(
fields.map(_.asInstanceOf[UnresolvedFieldReference].name).toArray,
functionCall.child)
)
- }
case _ =>
// prepend an AliasNode
new Table(tableEnv, AliasNode(fields, logicalPlan).validate(tableEnv))
@@ -908,7 +909,9 @@ class GroupedTable(
*/
def select(fields: String): Table = {
val fieldExprs = ExpressionParser.parseExpressionList(fields)
- select(fieldExprs: _*)
+ //get the correct expression for AggFunctionCall
+ val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv))
+ select(withResolvedAggFunctionCall: _*)
}
}
@@ -983,7 +986,9 @@ class OverWindowedTable(
def select(fields: String): Table = {
val fieldExprs = ExpressionParser.parseExpressionList(fields)
- select(fieldExprs: _*)
+ //get the correct expression for AggFunctionCall
+ val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv))
+ select(withResolvedAggFunctionCall: _*)
}
}
@@ -1043,7 +1048,9 @@ class WindowGroupedTable(
*/
def select(fields: String): Table = {
val fieldExprs = ExpressionParser.parseExpressionList(fields)
- select(fieldExprs: _*)
+ //get the correct expression for AggFunctionCall
+ val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv))
+ select(withResolvedAggFunctionCall: _*)
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 648efe6..5bb3b0e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -18,6 +18,8 @@
package org.apache.flink.table.codegen
+import java.lang.reflect.ParameterizedType
+import java.lang.{Iterable => JIterable}
import java.math.{BigDecimal => JBigDecimal}
import org.apache.calcite.avatica.util.DateTimeUtils
@@ -45,6 +47,7 @@ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.runtime.TableFunctionCollector
import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.types.Row
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString}
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -258,6 +261,9 @@ class CodeGenerator(
* @param constantFlags An optional parameter to define where to set constant boolean flags in
* the output row.
* @param outputArity The number of fields in the output row.
+ * @param needRetract a flag to indicate if the aggregate needs the retract method
+ * @param needMerge a flag to indicate if the aggregate needs the merge method
+ * @param needReset a flag to indicate if the aggregate needs the resetAccumulator method
*
* @return A GeneratedAggregationsFunction
*/
@@ -274,7 +280,8 @@ class CodeGenerator(
constantFlags: Option[Array[(Int, Boolean)]],
outputArity: Int,
needRetract: Boolean,
- needMerge: Boolean)
+ needMerge: Boolean,
+ needReset: Boolean)
: GeneratedAggregationsFunction = {
// get unique function name
@@ -282,19 +289,80 @@ class CodeGenerator(
// register UDAGGs
val aggs = aggregates.map(a => generator.addReusableFunction(a))
// get java types of accumulators
- val accTypes = aggregates.map { a =>
- a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
+ val accTypeClasses = aggregates.map { a =>
+ a.getClass.getMethod("createAccumulator").getReturnType
}
+ val accTypes = accTypeClasses.map(_.getCanonicalName)
- // get java types of input fields
- val javaTypes = inputType.getFieldList
- .map(f => FlinkTypeFactory.toTypeInfo(f.getType))
- .map(t => t.getTypeClass.getCanonicalName)
+ // get java classes of input fields
+ val javaClasses = inputType.getFieldList
+ .map(f => FlinkTypeFactory.toTypeInfo(f.getType).getTypeClass)
// get parameter lists for aggregation functions
- val parameters = aggFields.map {inFields =>
- val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)"
+ val parameters = aggFields.map { inFields =>
+ val fields = for (f <- inFields) yield
+ s"(${javaClasses(f).getCanonicalName}) input.getField($f)"
fields.mkString(", ")
}
+ val methodSignaturesList = aggFields.map {
+ inFields => for (f <- inFields) yield javaClasses(f)
+ }
+
+ // check and validate the needed methods
+ aggregates.zipWithIndex.map {
+ case (a, i) => {
+ getUserDefinedMethod(a, "accumulate", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching accumulate method found for AggregateFunction " +
+ s"'${a.getClass.getCanonicalName}'" +
+ s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
+ )
+
+ if (needRetract) {
+ getUserDefinedMethod(a, "retract", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching retract method found for AggregateFunction " +
+ s"'${a.getClass.getCanonicalName}'" +
+ s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
+ )
+ }
+
+ if (needMerge) {
+ val methods =
+ getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]]))
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching merge method found for AggregateFunction " +
+ s"${a.getClass.getCanonicalName}'.")
+ )
+
+ var iterableTypeClass = methods.getGenericParameterTypes.apply(1)
+ .asInstanceOf[ParameterizedType].getActualTypeArguments.apply(0)
+ // further extract iterableTypeClass if the accumulator has generic type
+ iterableTypeClass match {
+ case impl: ParameterizedType => iterableTypeClass = impl.getRawType
+ case _ =>
+ }
+
+ if (iterableTypeClass != accTypeClasses(i)) {
+ throw new CodeGenException(
+ s"merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " +
+ s"the correct Iterable type. Actually: ${iterableTypeClass.toString}. " +
+ s"Expected: ${accTypeClasses(i).toString}")
+ }
+ }
+
+ if (needReset) {
+ getUserDefinedMethod(a, "resetAccumulator", Array(accTypeClasses(i)))
+ .getOrElse(
+ throw new CodeGenException(
+ s"No matching resetAccumulator method found for " +
+ s"aggregate ${a.getClass.getCanonicalName}'.")
+ )
+ }
+ }
+ }
def genSetAggregationResults: String = {
@@ -529,9 +597,14 @@ class CodeGenerator(
| ((${accTypes(i)}) accs.getField($i)));""".stripMargin
}.mkString("\n")
- j"""$sig {
- |$reset
- | }""".stripMargin
+ if (needReset) {
+ j"""$sig {
+ |$reset
+ | }""".stripMargin
+ } else {
+ j"""$sig {
+ | }""".stripMargin
+ }
}
var funcCode =
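The generated aggregations code now reflectively verifies, per AggregateFunction, that the methods the runtime will call are present. A sketch of the expected method shapes on a hypothetical SumAgg; only accumulate is always required, the others are looked up when the corresponding flag is set:

    import java.lang.{Iterable => JIterable}
    import org.apache.flink.table.functions.AggregateFunction

    class SumAcc { var sum: Long = 0L }

    class SumAgg extends AggregateFunction[java.lang.Long, SumAcc] {

      override def createAccumulator(): SumAcc = new SumAcc
      override def getValue(acc: SumAcc): java.lang.Long = acc.sum

      // always required: one accumulate method per supported call signature
      def accumulate(acc: SumAcc, value: Long): Unit = { acc.sum += value }

      // looked up only if needRetract is true
      def retract(acc: SumAcc, value: Long): Unit = { acc.sum -= value }

      // looked up only if needMerge is true; the Iterable element type must be the accumulator type
      def merge(acc: SumAcc, its: JIterable[SumAcc]): Unit = {
        val it = its.iterator()
        while (it.hasNext) { acc.sum += it.next().sum }
      }

      // looked up only if needReset is true
      def resetAccumulator(acc: SumAcc): Unit = { acc.sum = 0L }
    }
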
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
index b0b4e09..07a8708 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
@@ -44,10 +44,10 @@ class ScalarFunctionCallGen(
operands: Seq[GeneratedExpression])
: GeneratedExpression = {
// determine function method and result class
- val matchingMethod = getEvalMethod(scalarFunction, signature)
+ val matchingMethod = getUserDefinedMethod(scalarFunction, "eval", typeInfoToClass(signature))
.getOrElse(throw new CodeGenException("No matching signature found."))
val matchingSignature = matchingMethod.getParameterTypes
- val resultClass = getResultTypeClass(scalarFunction, matchingSignature)
+ val resultClass = getResultTypeClassOfScalarFunction(scalarFunction, matchingSignature)
// zip for variable signatures
var paramToOperands = matchingSignature.zip(operands)
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
index ba90292..a3609c1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
@@ -45,7 +45,7 @@ class TableFunctionCallGen(
operands: Seq[GeneratedExpression])
: GeneratedExpression = {
// determine function method
- val matchingMethod = getEvalMethod(tableFunction, signature)
+ val matchingMethod = getUserDefinedMethod(tableFunction, "eval", typeInfoToClass(signature))
.getOrElse(throw new CodeGenException("No matching signature found."))
val matchingSignature = matchingMethod.getParameterTypes
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/UDAGGExpression.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/UDAGGExpression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/UDAGGExpression.scala
new file mode 100644
index 0000000..c0e213d
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/UDAGGExpression.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.expressions
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.functions.AggregateFunction
+
+/**
+ * A class which creates a call to an aggregateFunction
+ */
+case class UDAGGExpression[T: TypeInformation, ACC](aggregateFunction: AggregateFunction[T, ACC]) {
+
+ /**
+ * Creates a call to an [[AggregateFunction]].
+ *
+ * @param params actual parameters of function
+ * @return a [[AggFunctionCall]]
+ */
+ def apply(params: Expression*): AggFunctionCall =
+ AggFunctionCall(aggregateFunction, params)
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index 72e7e4b..7b180ae 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -23,13 +23,18 @@ import org.apache.calcite.sql.SqlKind._
import org.apache.calcite.sql.fun._
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.tools.RelBuilder.AggCall
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction
+import org.apache.flink.table.typeutils.TypeCheckUtils
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.typeutils.TypeCheckUtils
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
-abstract sealed class Aggregation extends UnaryExpression {
+abstract sealed class Aggregation extends Expression {
- override def toString = s"Aggregate($child)"
+ override def toString = s"Aggregate"
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
throw new UnsupportedOperationException("Aggregate cannot be transformed to RexNode")
@@ -47,6 +52,7 @@ abstract sealed class Aggregation extends UnaryExpression {
}
case class Sum(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"sum($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -67,6 +73,7 @@ case class Sum(child: Expression) extends Aggregation {
}
case class Sum0(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"sum0($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -83,6 +90,7 @@ case class Sum0(child: Expression) extends Aggregation {
}
case class Min(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"min($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -100,6 +108,7 @@ case class Min(child: Expression) extends Aggregation {
}
case class Max(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"max($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -117,6 +126,7 @@ case class Max(child: Expression) extends Aggregation {
}
case class Count(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"count($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -131,6 +141,7 @@ case class Count(child: Expression) extends Aggregation {
}
case class Avg(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"avg($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -148,6 +159,7 @@ case class Avg(child: Expression) extends Aggregation {
}
case class StddevPop(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"stddev_pop($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -164,6 +176,7 @@ case class StddevPop(child: Expression) extends Aggregation {
}
case class StddevSamp(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"stddev_samp($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -180,6 +193,7 @@ case class StddevSamp(child: Expression) extends Aggregation {
}
case class VarPop(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"var_pop($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -196,6 +210,7 @@ case class VarPop(child: Expression) extends Aggregation {
}
case class VarSamp(child: Expression) extends Aggregation {
+ override private[flink] def children: Seq[Expression] = Seq(child)
override def toString = s"var_samp($child)"
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
@@ -210,3 +225,57 @@ case class VarSamp(child: Expression) extends Aggregation {
override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
new SqlAvgAggFunction(VAR_SAMP)
}
+
+case class AggFunctionCall(
+ aggregateFunction: AggregateFunction[_, _],
+ args: Seq[Expression])
+ extends Aggregation {
+
+ override private[flink] def children: Seq[Expression] = args
+
+ override def resultType: TypeInformation[_] = getResultTypeOfAggregateFunction(aggregateFunction)
+
+ override def validateInput(): ValidationResult = {
+ val signature = children.map(_.resultType)
+ // look for a signature that matches the input types
+ val foundSignature = getAccumulateMethodSignature(aggregateFunction, signature)
+ if (foundSignature.isEmpty) {
+ ValidationFailure(s"Given parameters do not match any signature. \n" +
+ s"Actual: ${signatureToString(signature)} \n" +
+ s"Expected: ${
+ getMethodSignatures(aggregateFunction, "accumulate").drop(1)
+ .map(signatureToString).mkString(", ")}")
+ } else {
+ ValidationSuccess
+ }
+ }
+
+ override def toString(): String = s"${aggregateFunction.getClass.getSimpleName}($args)"
+
+ override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val sqlFunction = AggSqlFunction(aggregateFunction.getClass.getSimpleName,
+ aggregateFunction,
+ resultType,
+ typeFactory)
+ relBuilder.aggregateCall(sqlFunction, false, null, name, args.map(_.toRexNode): _*)
+ }
+
+ override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ AggSqlFunction(aggregateFunction.getClass.getSimpleName,
+ aggregateFunction,
+ resultType,
+ typeFactory)
+ }
+
+ override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ relBuilder.call(
+ AggSqlFunction(aggregateFunction.getClass.getSimpleName,
+ aggregateFunction,
+ resultType,
+ typeFactory),
+ args.map(_.toRexNode): _*)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
index 68ed688..5f7204a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala
@@ -106,8 +106,8 @@ case class OverCall(
.getTypeFactory.asInstanceOf[FlinkTypeFactory]
.createTypeFromTypeInfo(agg.resultType)
- val aggChildName = agg.asInstanceOf[Aggregation].child.asInstanceOf[ResolvedFieldReference].name
- val aggExprs = List(relBuilder.field(aggChildName).asInstanceOf[RexNode]).asJava
+ // assemble exprs by agg children
+ val aggExprs = agg.asInstanceOf[Aggregation].children.map(_.toRexNode(relBuilder)).asJava
// assemble order by key
val orderKey = orderBy match {
@@ -281,16 +281,19 @@ case class ScalarFunctionCall(
override def toString =
s"${scalarFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"
- override private[flink] def resultType = getResultType(scalarFunction, foundSignature.get)
+ override private[flink] def resultType =
+ getResultTypeOfScalarFunction(
+ scalarFunction,
+ foundSignature.get)
override private[flink] def validateInput(): ValidationResult = {
val signature = children.map(_.resultType)
// look for a signature that matches the input types
- foundSignature = getSignature(scalarFunction, signature)
+ foundSignature = getEvalMethodSignature(scalarFunction, signature)
if (foundSignature.isEmpty) {
ValidationFailure(s"Given parameters do not match any signature. \n" +
s"Actual: ${signatureToString(signature)} \n" +
- s"Expected: ${signaturesToString(scalarFunction)}")
+ s"Expected: ${signaturesToString(scalarFunction, "eval")}")
} else {
ValidationSuccess
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index 7a74112..9c79439 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -89,7 +89,7 @@ package org.apache.flink.table.functions
*
* {{{
* Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the accumulator. This
- * function is optional and can be implemented if the accumulator type cannot automatically
+ * function is optional and can be implemented if the accumulator type cannot be automatically
* inferred from the instance returned by createAccumulator method.
*
* @return the type information for the accumulator.
@@ -98,6 +98,18 @@ package org.apache.flink.table.functions
* }}}
*
*
+ * {{{
+ * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the return value. This
+ * function is optional and needed in case Flink's type extraction facilities are not sufficient
+ * to extract the TypeInformation. Flink's type extraction facilities can handle basic types or
+ * simple POJOs but might be wrong for more complex, custom, or composite types.
+ *
+ * @return the type information for the return value.
+ *
+ * def getResultType: TypeInformation[_]
+ * }}}
+ *
+ *
* @tparam T the type of the aggregation result
* @tparam ACC base class for aggregate Accumulator. The accumulator is used to keep the aggregated
* values which are needed to compute an aggregation result. AggregateFunction
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
index 77341cd..2b8ec14 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
@@ -68,4 +68,7 @@ class CountAggFunction extends AggregateFunction[Long, CountAccumulator] {
def getAccumulatorType(): TypeInformation[_] = {
new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO)
}
+
+ def getResultType(): TypeInformation[_] =
+ BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[_]]
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
new file mode 100644
index 0000000..c3f6c4c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.utils
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql._
+import org.apache.calcite.sql.`type`._
+import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
+import org.apache.calcite.sql.parser.SqlParserPos
+import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction
+import org.apache.flink.api.common.typeinfo._
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
+
+/**
+ * Calcite wrapper for user-defined aggregate functions.
+ *
+ * @param name function name (used by SQL parser)
+ * @param aggregateFunction aggregate function to be called
+ * @param returnType the type information of returned value
+ * @param typeFactory type factory for converting Flink's between Calcite's types
+ */
+class AggSqlFunction(
+ name: String,
+ aggregateFunction: AggregateFunction[_, _],
+ returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory)
+ extends SqlUserDefinedAggFunction(
+ new SqlIdentifier(name, SqlParserPos.ZERO),
+ createReturnTypeInference(returnType, typeFactory),
+ createOperandTypeInference(aggregateFunction, typeFactory),
+ createOperandTypeChecker(aggregateFunction),
+ // Do not need to provide a calcite aggregateFunction here. Flink aggregateion function
+ // will be generated when translating the calcite relnode to flink runtime execution plan
+ null
+ ) {
+
+ def getFunction: AggregateFunction[_, _] = aggregateFunction
+}
+
+object AggSqlFunction {
+
+ def apply(
+ name: String,
+ aggregateFunction: AggregateFunction[_, _],
+ returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+ new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ }
+
+ private[flink] def createOperandTypeInference(
+ aggregateFunction: AggregateFunction[_, _],
+ typeFactory: FlinkTypeFactory)
+ : SqlOperandTypeInference = {
+ /**
+ * Operand type inference based on [[AggregateFunction]] given information.
+ */
+ new SqlOperandTypeInference {
+ override def inferOperandTypes(
+ callBinding: SqlCallBinding,
+ returnType: RelDataType,
+ operandTypes: Array[RelDataType]): Unit = {
+
+ val operandTypeInfo = getOperandTypeInfo(callBinding)
+
+ val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
+ .getOrElse(
+ throw new ValidationException(
+ s"Operand types of ${signatureToString(operandTypeInfo)} could not be inferred."))
+
+ val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1))
+ .map(typeFactory.createTypeFromTypeInfo)
+
+ for (i <- operandTypes.indices) {
+ if (i < inferredTypes.length - 1) {
+ operandTypes(i) = inferredTypes(i)
+ } else if (null != inferredTypes.last.getComponentType) {
+ // last argument is a collection, the array type
+ operandTypes(i) = inferredTypes.last.getComponentType
+ } else {
+ operandTypes(i) = inferredTypes.last
+ }
+ }
+ }
+ }
+ }
+
+ private[flink] def createReturnTypeInference(
+ resultType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory)
+ : SqlReturnTypeInference = {
+
+ new SqlReturnTypeInference {
+ override def inferReturnType(opBinding: SqlOperatorBinding): RelDataType = {
+ typeFactory.createTypeFromTypeInfo(resultType)
+ }
+ }
+ }
+
+ private[flink] def createOperandTypeChecker(aggregateFunction: AggregateFunction[_, _])
+ : SqlOperandTypeChecker = {
+
+ val signatures = getMethodSignatures(aggregateFunction, "accumulate")
+
+ /**
+ * Operand type checker based on [[AggregateFunction]] given information.
+ */
+ new SqlOperandTypeChecker {
+ override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
+ s"$opName[${signaturesToString(aggregateFunction, "accumulate")}]"
+ }
+
+ override def getOperandCountRange: SqlOperandCountRange = {
+ var min = 255
+ var max = -1
+ signatures.foreach(
+ sig => {
+ //do not count accumulator as input
+ val inputSig = sig.drop(1)
+ var len = inputSig.length
+ if (len > 0 && inputSig.last.isArray) {
+ max = 253 // according to JVM spec 4.3.3
+ len = sig.length - 1
+ }
+ max = Math.max(len, max)
+ min = Math.min(len, min)
+ })
+ SqlOperandCountRanges.between(min, max)
+ }
+
+ override def checkOperandTypes(
+ callBinding: SqlCallBinding,
+ throwOnFailure: Boolean)
+ : Boolean = {
+ val operandTypeInfo = getOperandTypeInfo(callBinding)
+
+ val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
+
+ if (foundSignature.isEmpty) {
+ if (throwOnFailure) {
+ throw new ValidationException(
+ s"Given parameters of function do not match any signature. \n" +
+ s"Actual: ${signatureToString(operandTypeInfo)} \n" +
+ s"Expected: ${signaturesToString(aggregateFunction, "accumulate")}")
+ } else {
+ false
+ }
+ } else {
+ true
+ }
+ }
+
+ override def isOptional(i: Int): Boolean = false
+
+ override def getConsistency: Consistency = Consistency.NONE
+
+ }
+ }
+}
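Once wrapped in an AggSqlFunction and registered, the UDAGG becomes resolvable by the SQL validator. A hedged usage sketch, assuming the hypothetical "wAvg" registration shown earlier and a registered table T; the sql() entry point reflects the API of this Flink version:

    // the operand type checker validates the call against wAvg's accumulate(...) signatures
    val result = tEnv.sql("SELECT user, wAvg(points, level) FROM T GROUP BY user")
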
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
index e2cd272..bbfa3aa 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
@@ -23,12 +23,11 @@ import org.apache.calcite.sql._
import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
import org.apache.calcite.sql.`type`._
import org.apache.calcite.sql.parser.SqlParserPos
-import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.functions.ScalarFunction
-import org.apache.flink.table.functions.utils.ScalarSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
-import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, getSignatures, signatureToString, signaturesToString}
+import org.apache.flink.table.functions.utils.ScalarSqlFunction._
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import scala.collection.JavaConverters._
@@ -77,14 +76,14 @@ object ScalarSqlFunction {
FlinkTypeFactory.toTypeInfo(operandType)
}
}
- val foundSignature = getSignature(scalarFunction, parameters)
+ val foundSignature = getEvalMethodSignature(scalarFunction, parameters)
if (foundSignature.isEmpty) {
throw new ValidationException(
s"Given parameters of function '$name' do not match any signature. \n" +
s"Actual: ${signatureToString(parameters)} \n" +
- s"Expected: ${signaturesToString(scalarFunction)}")
+ s"Expected: ${signaturesToString(scalarFunction, "eval")}")
}
- val resultType = getResultType(scalarFunction, foundSignature.get)
+ val resultType = getResultTypeOfScalarFunction(scalarFunction, foundSignature.get)
val t = typeFactory.createTypeFromTypeInfo(resultType)
typeFactory.createTypeWithNullability(t, nullable = true)
}
@@ -106,7 +105,7 @@ object ScalarSqlFunction {
val operandTypeInfo = getOperandTypeInfo(callBinding)
- val foundSignature = getSignature(scalarFunction, operandTypeInfo)
+ val foundSignature = getEvalMethodSignature(scalarFunction, operandTypeInfo)
.getOrElse(throw new ValidationException(s"Operand types of could not be inferred."))
val inferredTypes = scalarFunction
@@ -132,14 +131,14 @@ object ScalarSqlFunction {
scalarFunction: ScalarFunction)
: SqlOperandTypeChecker = {
- val signatures = getSignatures(scalarFunction)
+ val signatures = getMethodSignatures(scalarFunction, "eval")
/**
* Operand type checker based on [[ScalarFunction]] given information.
*/
new SqlOperandTypeChecker {
override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
- s"$opName[${signaturesToString(scalarFunction)}]"
+ s"$opName[${signaturesToString(scalarFunction, "eval")}]"
}
override def getOperandCountRange: SqlOperandCountRange = {
@@ -163,14 +162,14 @@ object ScalarSqlFunction {
: Boolean = {
val operandTypeInfo = getOperandTypeInfo(callBinding)
- val foundSignature = getSignature(scalarFunction, operandTypeInfo)
+ val foundSignature = getEvalMethodSignature(scalarFunction, operandTypeInfo)
if (foundSignature.isEmpty) {
if (throwOnFailure) {
throw new ValidationException(
s"Given parameters of function '$name' do not match any signature. \n" +
s"Actual: ${signatureToString(operandTypeInfo)} \n" +
- s"Expected: ${signaturesToString(scalarFunction)}")
+ s"Expected: ${signaturesToString(scalarFunction, "eval")}")
} else {
false
}
@@ -185,16 +184,4 @@ object ScalarSqlFunction {
}
}
-
- private[flink] def getOperandTypeInfo(callBinding: SqlCallBinding): Seq[TypeInformation[_]] = {
- val operandTypes = for (i <- 0 until callBinding.getOperandCount)
- yield callBinding.getOperandType(i)
- operandTypes.map { operandType =>
- if (operandType.getSqlTypeName == SqlTypeName.NULL) {
- null
- } else {
- FlinkTypeFactory.toTypeInfo(operandType)
- }
- }
- }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index d108e31..689bf0e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -25,15 +25,16 @@ import java.sql.{Date, Time, Timestamp}
import org.apache.commons.codec.binary.Base64
import com.google.common.primitives.Primitives
-import org.apache.calcite.sql.SqlFunction
+import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.calcite.sql.{SqlCallBinding, SqlFunction}
import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.plan.logical._
+import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.util.InstantiationUtil
@@ -69,52 +70,89 @@ object UserDefinedFunctionUtils {
}
// ----------------------------------------------------------------------------------------------
- // Utilities for eval methods
+ // Utilities for user-defined methods
// ----------------------------------------------------------------------------------------------
/**
- * Returns signatures matching the given signature of [[TypeInformation]].
+ * Returns signatures of eval methods matching the given signature of [[TypeInformation]].
* Elements of the signature can be null (act as a wildcard).
*/
- def getSignature(
- function: UserDefinedFunction,
- signature: Seq[TypeInformation[_]])
+ def getEvalMethodSignature(
+ function: UserDefinedFunction,
+ signature: Seq[TypeInformation[_]])
: Option[Array[Class[_]]] = {
- getEvalMethod(function, signature).map(_.getParameterTypes)
+ getUserDefinedMethod(function, "eval", typeInfoToClass(signature)).map(_.getParameterTypes)
}
/**
- * Returns eval method matching the given signature of [[TypeInformation]].
+ * Returns signatures of accumulate methods matching the given signature of [[TypeInformation]].
+ * Elements of the signature can be null (act as a wildcard).
*/
- def getEvalMethod(
- function: UserDefinedFunction,
+ def getAccumulateMethodSignature(
+ function: AggregateFunction[_, _],
signature: Seq[TypeInformation[_]])
+ : Option[Array[Class[_]]] = {
+ val accType = TypeExtractor.createTypeInfo(
+ function, classOf[AggregateFunction[_, _]], function.getClass, 1)
+ val input = (Array(accType) ++ signature).toSeq
+ getUserDefinedMethod(
+ function,
+ "accumulate",
+ typeInfoToClass(input)).map(_.getParameterTypes)
+ }
+
+ def getParameterTypes(
+ function: UserDefinedFunction,
+ signature: Array[Class[_]]): Array[TypeInformation[_]] = {
+ signature.map { c =>
+ try {
+ TypeExtractor.getForClass(c)
+ } catch {
+ case ite: InvalidTypesException =>
+ throw new ValidationException(
+ s"Parameter types of function '${function.getClass.getCanonicalName}' cannot be " +
+ s"automatically determined. Please provide type information manually.")
+ }
+ }
+ }
+
+ /**
+ * Returns user defined method matching the given name and signature.
+ *
+ * @param function function instance
+ * @param methodName method name
+ * @param methodSignature an array of raw Java classes. We compare the raw Java classes not the
+ * TypeInformation. TypeInformation does not matter during runtime (e.g.
+ * within a MapFunction)
+ */
+ def getUserDefinedMethod(
+ function: UserDefinedFunction,
+ methodName: String,
+ methodSignature: Array[Class[_]])
: Option[Method] = {
- // We compare the raw Java classes not the TypeInformation.
- // TypeInformation does not matter during runtime (e.g. within a MapFunction).
- val actualSignature = typeInfoToClass(signature)
- val evalMethods = checkAndExtractEvalMethods(function)
- val filtered = evalMethods
- // go over all eval methods and filter out matching methods
+ val methods = checkAndExtractMethods(function, methodName)
+
+ val filtered = methods
+ // go over all the methods and filter out matching methods
.filter {
case cur if !cur.isVarArgs =>
val signatures = cur.getParameterTypes
// match parameters of signature to actual parameters
- actualSignature.length == signatures.length &&
+ methodSignature.length == signatures.length &&
signatures.zipWithIndex.forall { case (clazz, i) =>
- parameterTypeEquals(actualSignature(i), clazz)
+ parameterTypeEquals(methodSignature(i), clazz)
}
case cur if cur.isVarArgs =>
val signatures = cur.getParameterTypes
- actualSignature.zipWithIndex.forall {
+ methodSignature.zipWithIndex.forall {
// non-varargs
case (clazz, i) if i < signatures.length - 1 =>
parameterTypeEquals(clazz, signatures(i))
// varargs
case (clazz, i) if i >= signatures.length - 1 =>
parameterTypeEquals(clazz, signatures.last.getComponentType)
- } || (actualSignature.isEmpty && signatures.length == 1) // empty varargs
+ } || (methodSignature.isEmpty && signatures.length == 1) // empty varargs
}
// if there is a fixed method, compiler will call this method preferentially
@@ -126,19 +164,21 @@ object UserDefinedFunctionUtils {
// check if there is a Scala varargs annotation
if (found.isEmpty &&
- evalMethods.exists { evalMethod =>
- val signatures = evalMethod.getParameterTypes
+ methods.exists { method =>
+ val signatures = method.getParameterTypes
signatures.zipWithIndex.forall {
case (clazz, i) if i < signatures.length - 1 =>
- parameterTypeEquals(actualSignature(i), clazz)
+ parameterTypeEquals(methodSignature(i), clazz)
case (clazz, i) if i == signatures.length - 1 =>
clazz.getName.equals("scala.collection.Seq")
}
}) {
- throw new ValidationException("Scala-style variable arguments in 'eval' methods are not " +
- "supported. Please add a @scala.annotation.varargs annotation.")
+ throw new ValidationException(
+ s"Scala-style variable arguments in '${methodName}' methods are not supported. Please " +
+ s"add a @scala.annotation.varargs annotation.")
} else if (found.length > 1) {
- throw new ValidationException("Found multiple 'eval' methods which match the signature.")
+ throw new ValidationException(
+ "Found multiple '${methodName}' methods which match the signature.")
}
found.headOption
}
@@ -157,16 +197,18 @@ object UserDefinedFunctionUtils {
}
/**
- * Extracts "eval" methods and throws a [[ValidationException]] if no implementation
+ * Extracts methods with the given name and throws a [[ValidationException]] if no implementation
* can be found, or implementation does not match the requirements.
*/
- def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
+ def checkAndExtractMethods(
+ function: UserDefinedFunction,
+ methodName: String): Array[Method] = {
val methods = function
.getClass
- .getDeclaredMethods
+ .getMethods
.filter { m =>
val modifiers = m.getModifiers
- m.getName == "eval" &&
+ m.getName == methodName &&
Modifier.isPublic(modifiers) &&
!Modifier.isAbstract(modifiers) &&
!(function.isInstanceOf[TableFunction[_]] && Modifier.isStatic(modifiers))
@@ -175,15 +217,17 @@ object UserDefinedFunctionUtils {
if (methods.isEmpty) {
throw new ValidationException(
s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
- s"one method named 'eval' which is public, not abstract and " +
+ s"one method named '${methodName}' which is public, not abstract and " +
s"(in case of table functions) not static.")
}
methods
}
- def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
- checkAndExtractEvalMethods(function).map(_.getParameterTypes)
+ def getMethodSignatures(
+ function: UserDefinedFunction,
+ methodName: String): Array[Array[Class[_]]] = {
+ checkAndExtractMethods(function, methodName).map(_.getParameterTypes)
}
// ----------------------------------------------------------------------------------------------
@@ -222,7 +266,7 @@ object UserDefinedFunctionUtils {
typeFactory: FlinkTypeFactory)
: Seq[SqlFunction] = {
val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)
- val evalMethods = checkAndExtractEvalMethods(tableFunction)
+ val evalMethods = checkAndExtractMethods(tableFunction, "eval")
evalMethods.map { method =>
val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method)
@@ -230,29 +274,75 @@ object UserDefinedFunctionUtils {
}
}
+ /**
+ * Creates a [[SqlFunction]] for an [[AggregateFunction]].
+ *
+ * @param name function name
+ * @param aggFunction aggregate function
+ * @param typeInfo externally provided result type of the aggregate function
+ * @param typeFactory type factory
+ * @return the AggSqlFunction
+ */
+ def createAggregateSqlFunction(
+ name: String,
+ aggFunction: AggregateFunction[_, _],
+ typeInfo: TypeInformation[_],
+ typeFactory: FlinkTypeFactory)
+ : SqlFunction = {
+ // check that a qualifying accumulate method exists before creating the SqlFunction
+ checkAndExtractMethods(aggFunction, "accumulate")
+ val resultType: TypeInformation[_] = getResultTypeOfAggregateFunction(aggFunction, typeInfo)
+ AggSqlFunction(name, aggFunction, resultType, typeFactory)
+ }
+
// ----------------------------------------------------------------------------------------------
- // Utilities for scalar functions
+ // Utilities for user-defined functions
// ----------------------------------------------------------------------------------------------
/**
+ * Internal method to determine the result type of an [[AggregateFunction]]. It checks for a
+ * user-defined getResultType() method and falls back to the given extracted type or to
+ * [[TypeExtractor]] based type inference.
+ */
+ def getResultTypeOfAggregateFunction(
+ aggregateFunction: AggregateFunction[_, _],
+ extractedType: TypeInformation[_] = null)
+ : TypeInformation[_] = {
+
+ val resultType = try {
+ val method: Method = aggregateFunction.getClass.getMethod("getResultType")
+ method.invoke(aggregateFunction).asInstanceOf[TypeInformation[_]]
+ } catch {
+ case _: NoSuchMethodException => null
+ case ite: Throwable => throw new TableException("Unexpected exception:", ite)
+ }
+ if (resultType != null) {
+ resultType
+ } else if (extractedType != null) {
+ extractedType
+ } else {
+ TypeExtractor
+ .createTypeInfo(aggregateFunction,
+ classOf[AggregateFunction[_, _]],
+ aggregateFunction.getClass,
+ 0)
+ .asInstanceOf[TypeInformation[_]]
+ }
+ }
+
+ /**
* Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
* [[TypeExtractor]] as default return type inference.
*/
- def getResultType(
+ def getResultTypeOfScalarFunction(
function: ScalarFunction,
signature: Array[Class[_]])
: TypeInformation[_] = {
- // find method for signature
- val evalMethod = checkAndExtractEvalMethods(function)
- .find(m => signature.sameElements(m.getParameterTypes))
- .getOrElse(throw new ValidationException("Given signature is invalid."))
val userDefinedTypeInfo = function.getResultType(signature)
if (userDefinedTypeInfo != null) {
userDefinedTypeInfo
} else {
try {
- TypeExtractor.getForClass(evalMethod.getReturnType)
+ TypeExtractor.getForClass(getResultTypeClassOfScalarFunction(function, signature))
} catch {
case ite: InvalidTypesException =>
throw new ValidationException(
@@ -265,12 +355,12 @@ object UserDefinedFunctionUtils {
/**
* Returns the return type of the evaluation method matching the given signature.
*/
- def getResultTypeClass(
+ def getResultTypeClassOfScalarFunction(
function: ScalarFunction,
signature: Array[Class[_]])
: Class[_] = {
// find method for signature
- val evalMethod = checkAndExtractEvalMethods(function)
+ val evalMethod = checkAndExtractMethods(function, "eval")
.find(m => signature.sameElements(m.getParameterTypes))
.getOrElse(throw new IllegalArgumentException("Given signature is invalid."))
evalMethod.getReturnType
@@ -317,16 +407,16 @@ object UserDefinedFunctionUtils {
}
/**
- * Prints all eval methods signatures of a class.
+ * Returns a string listing all signatures of methods with the given name in a class.
*/
- def signaturesToString(function: UserDefinedFunction): String = {
- getSignatures(function).map(signatureToString).mkString(", ")
+ def signaturesToString(function: UserDefinedFunction, name: String): String = {
+ getMethodSignatures(function, name).map(signatureToString).mkString(", ")
}
/**
* Extracts type classes of [[TypeInformation]] in a null-aware way.
*/
- private def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
+ def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
typeInfos.map { typeInfo =>
if (typeInfo == null) {
null
@@ -393,4 +483,16 @@ object UserDefinedFunctionUtils {
.as(alias).toLogicalTableFunctionCall(child = null)
functionCall
}
+
+ def getOperandTypeInfo(callBinding: SqlCallBinding): Seq[TypeInformation[_]] = {
+ val operandTypes = for (i <- 0 until callBinding.getOperandCount)
+ yield callBinding.getOperandType(i)
+ operandTypes.map { operandType =>
+ if (operandType.getSqlTypeName == SqlTypeName.NULL) {
+ null
+ } else {
+ FlinkTypeFactory.toTypeInfo(operandType)
+ }
+ }
+ }
}
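For orientation, a minimal sketch of the kind of aggregate these utilities reflect over. The names (WeightedAvg, WeightedAvgAccum, the accumulate parameters) are illustrative and not part of this change; the sketch only assumes the createAccumulator()/getValue() contract of AggregateFunction plus a public accumulate method, which is what checkAndExtractMethods(aggFunction, "accumulate") and getResultTypeOfAggregateFunction above look up reflectively.

import org.apache.flink.table.functions.AggregateFunction

// Illustrative accumulator type (mutable, POJO-style fields).
class WeightedAvgAccum {
  var sum: Long = 0L
  var count: Int = 0
}

// Minimal user-defined aggregate. accumulate() may take several arguments,
// since the "aggregate on multi fields" restriction is dropped in AggregateUtil below.
class WeightedAvg extends AggregateFunction[Long, WeightedAvgAccum] {

  override def createAccumulator(): WeightedAvgAccum = new WeightedAvgAccum

  override def getValue(acc: WeightedAvgAccum): Long =
    if (acc.count == 0) 0L else acc.sum / acc.count

  // discovered reflectively by checkAndExtractMethods(aggFunction, "accumulate")
  def accumulate(acc: WeightedAvgAccum, value: Long, weight: Int): Unit = {
    acc.sum += value * weight
    acc.count += weight
  }
}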
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
index 0d45a37..d26cdcf 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
@@ -303,6 +303,11 @@ object ProjectionTranslator {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}
+ case aggfc @ AggFunctionCall(clazz, args) =>
+ args.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }
+
// array constructor
case c @ ArrayConstructor(args) =>
args.foldLeft(fieldReferences) {
@@ -327,4 +332,56 @@ object ProjectionTranslator {
}
}
+ /**
+ * Finds and replaces UDAGG function calls ([[Call]]) with [[AggFunctionCall]]s.
+ *
+ * @param field the expression to check
+ * @param tableEnv the TableEnvironment
+ * @return the expression with UDAGG calls replaced by [[AggFunctionCall]]s
+ */
+ def replaceAggFunctionCall(field: Expression, tableEnv: TableEnvironment): Expression = {
+ field match {
+ case l: LeafExpression => l
+
+ case u: UnaryExpression =>
+ val c = replaceAggFunctionCall(u.child, tableEnv)
+ u.makeCopy(Array(c))
+
+ case b: BinaryExpression =>
+ val l = replaceAggFunctionCall(b.left, tableEnv)
+ val r = replaceAggFunctionCall(b.right, tableEnv)
+ b.makeCopy(Array(l, r))
+
+ // Function calls
+ case c @ Call(name, args) =>
+ val function = tableEnv.getFunctionCatalog.lookupFunction(name, args)
+ if (function.isInstanceOf[AggFunctionCall]) {
+ function
+ } else {
+ val newArgs =
+ args.map(
+ (exp: Expression) =>
+ replaceAggFunctionCall(exp, tableEnv))
+ c.makeCopy(Array(name, newArgs))
+ }
+
+ // Scalar functions
+ case sfc @ ScalarFunctionCall(clazz, args) =>
+ val newArgs: Seq[Expression] =
+ args.map(
+ (exp: Expression) =>
+ replaceAggFunctionCall(exp, tableEnv))
+ sfc.makeCopy(Array(clazz, newArgs))
+
+ // Array constructor
+ case c @ ArrayConstructor(args) =>
+ val newArgs =
+ c.elements
+ .map((exp: Expression) => replaceAggFunctionCall(exp, tableEnv))
+ c.makeCopy(Array(newArgs))
+
+ // Other expressions
+ case e: Expression => e
+ }
+ }
}
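Assuming a registerFunction overload for AggregateFunction like the one this commit adds to the table environments, string-expression usage could look roughly as follows (table and column names are made up, and the sketch assumes a batch table environment `tEnv` with the needed implicits in scope). The expression parser produces a Call("wAvg", ...), which replaceAggFunctionCall above, together with the FunctionCatalog change further below, resolves into an AggFunctionCall bound to the registered function.

// Sketch only: assumes a registered table "T" with columns user, points and level;
// WeightedAvg is the illustrative UDAGG sketched earlier.
tEnv.registerFunction("wAvg", new WeightedAvg)

val result = tEnv
  .scan("T")
  .groupBy("user")
  // parsed as Call("wAvg", ...) and rewritten to AggFunctionCall during projection translation
  .select("user, wAvg(points, level)")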
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
index c67bfd1..5f2394c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -685,12 +685,12 @@ case class LogicalTableFunctionCall(
checkForInstantiation(tableFunction.getClass)
// look for a signature that matches the input types
val signature = node.parameters.map(_.resultType)
- val foundMethod = getEvalMethod(tableFunction, signature)
+ val foundMethod = getUserDefinedMethod(tableFunction, "eval", typeInfoToClass(signature))
if (foundMethod.isEmpty) {
failValidation(
s"Given parameters of function '$functionName' do not match any signature. \n" +
s"Actual: ${signatureToString(signature)} \n" +
- s"Expected: ${signaturesToString(tableFunction)}")
+ s"Expected: ${signaturesToString(tableFunction, "eval")}")
} else {
node.evalMethod = foundMethod.get
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala
index 3883b14..e95747c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonAggregate.scala
@@ -50,7 +50,7 @@ trait CommonAggregate {
val aggs = namedAggregates.map(_.getKey)
val aggStrings = aggs.map( a => s"${a.getAggregation}(${
if (a.getArgList.size() > 0) {
- inFields(a.getArgList.get(0))
+ a.getArgList.asScala.map(inFields(_)).mkString(", ")
} else {
"*"
}
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
index 91c8cef..6878473 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
@@ -88,7 +88,7 @@ trait OverAggregate {
val aggStrings = namedAggregates.map(_.getKey).map(
a => s"${a.getAggregation}(${
if (a.getArgList.size() > 0) {
- inFields(a.getArgList.get(0))
+ a.getArgList.asScala.map(inFields(_)).mkString(", ")
} else {
"*"
}
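The two identical changes above (CommonAggregate and OverAggregate) make the explain string list every argument of an aggregate call instead of only the first one, which matters now that UDAGGs may take multiple fields. A tiny standalone illustration of the new rendering (field names are made up):

val inFields = Seq("user", "points", "level")
val argIndexes = Seq(1, 2)

// before: only inFields(argIndexes.head)   => "points"
// after:  all arguments joined with ", "   => "points, level"
val rendered = argIndexes.map(inFields(_)).mkString(", ")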
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
index 187773d..c232a71 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
@@ -122,6 +122,12 @@ class DataStreamAggregate(
false,
inputDS.getType)
+ val needMerge = window match {
+ case ProcessingTimeSessionGroupWindow(_, _) => true
+ case EventTimeSessionGroupWindow(_, _, _) => true
+ case _ => false
+ }
+
// grouped / keyed aggregation
if (grouping.length > 0) {
val windowFunction = AggregateUtil.createAggregationGroupWindowFunction(
@@ -141,7 +147,8 @@ class DataStreamAggregate(
generator,
namedAggregates,
inputType,
- rowRelDataType)
+ rowRelDataType,
+ needMerge)
windowedStream
.aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo)
@@ -163,7 +170,8 @@ class DataStreamAggregate(
generator,
namedAggregates,
inputType,
- rowRelDataType)
+ rowRelDataType,
+ needMerge)
windowedStream
.aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo)
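needMerge is derived from the window type because session windows merge their panes; the generated aggregation then invokes a merge method on every aggregate function, so a UDAGG used with a session group window needs one. A hedged sketch for the illustrative WeightedAvg above, assuming the java.lang.Iterable-based merge convention the code generator reflects over:

// Merges a group of accumulators into `acc`; required when needMerge = true
// (processing-time and event-time session windows above).
def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {
  val iter = it.iterator()
  while (iter.hasNext) {
    val other = iter.next()
    acc.sum += other.sum
    acc.count += other.count
  }
}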
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 2950a78..e38207d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -39,6 +39,7 @@ import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.aggfunctions._
+import org.apache.flink.table.functions.utils.AggSqlFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.functions.{AggregateFunction => TableAggregateFunction}
import org.apache.flink.table.plan.logical._
@@ -101,7 +102,8 @@ object AggregateUtil {
None,
outputArity,
needRetract,
- needMerge = false
+ needMerge = false,
+ needReset = false
)
if (isRowTimeType) {
@@ -178,7 +180,8 @@ object AggregateUtil {
None,
outputArity,
needRetract,
- needMerge = false
+ needMerge = false,
+ needReset = true
)
if (isRowTimeType) {
@@ -303,7 +306,8 @@ object AggregateUtil {
None,
outputArity,
needRetract,
- needMerge = false
+ needMerge = false,
+ needReset = true
)
new DataSetWindowAggMapFunction(
@@ -374,12 +378,13 @@ object AggregateUtil {
aggFieldIndexes,
aggregates.indices.map(_ + groupings.length).toArray,
partialResults = true,
- groupings,
+ groupings.indices.toArray,
Some(aggregates.indices.map(_ + groupings.length).toArray),
None,
keysAndAggregatesArity + 1,
needRetract,
- needMerge = true
+ needMerge = true,
+ needReset = true
)
new DataSetSlideTimeWindowAggReduceGroupFunction(
genFunction,
@@ -481,7 +486,8 @@ object AggregateUtil {
None,
outputType.getFieldCount,
needRetract,
- needMerge = true
+ needMerge = true,
+ needReset = true
)
val genFinalAggFunction = generator.generateAggregations(
@@ -497,7 +503,8 @@ object AggregateUtil {
None,
outputType.getFieldCount,
needRetract,
- needMerge = true
+ needMerge = true,
+ needReset = true
)
val keysAndAggregatesArity = groupings.length + namedAggregates.length
@@ -636,7 +643,8 @@ object AggregateUtil {
None,
groupings.length + aggregates.length + 2,
needRetract,
- needMerge = true
+ needMerge = true,
+ needReset = true
)
new DataSetSessionWindowAggregatePreProcessor(
@@ -708,7 +716,8 @@ object AggregateUtil {
None,
groupings.length + aggregates.length + 2,
needRetract,
- needMerge = true
+ needMerge = true,
+ needReset = true
)
new DataSetSessionWindowAggregatePreProcessor(
@@ -787,7 +796,8 @@ object AggregateUtil {
None,
groupings.length + aggregates.length,
needRetract,
- needMerge = false
+ needMerge = false,
+ needReset = true
)
// compute mapping of forwarded grouping keys
@@ -813,7 +823,8 @@ object AggregateUtil {
constantFlags,
outputType.getFieldCount,
needRetract,
- needMerge = true
+ needMerge = true,
+ needReset = true
)
(
@@ -836,7 +847,8 @@ object AggregateUtil {
constantFlags,
outputType.getFieldCount,
needRetract,
- needMerge = false
+ needMerge = false,
+ needReset = true
)
(
@@ -902,7 +914,8 @@ object AggregateUtil {
generator: CodeGenerator,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
- outputType: RelDataType)
+ outputType: RelDataType,
+ needMerge: Boolean)
: (DataStreamAggFunction[Row, Row, Row], RowTypeInfo, RowTypeInfo) = {
val needRetract = false
@@ -928,7 +941,8 @@ object AggregateUtil {
None,
outputArity,
needRetract,
- needMerge = true
+ needMerge,
+ needReset = false
)
val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType))
@@ -1083,9 +1097,6 @@ object AggregateUtil {
throw new TableException("Aggregate fields should not be empty.")
}
} else {
- if (argList.size() > 1) {
- throw new TableException("Currently, do not support aggregate on multi fields.")
- }
aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray
}
val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
@@ -1298,6 +1309,9 @@ object AggregateUtil {
case _: SqlCountAggFunction =>
aggregates(index) = new CountAggFunction
+ case udagg: AggSqlFunction =>
+ aggregates(index) = udagg.getFunction
+
case unSupported: SqlAggFunction =>
throw new TableException("unsupported Function: " + unSupported.getName)
}
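The DataSet paths above now pass needReset = true, i.e. the generated batch aggregations reuse accumulator instances between groups and reset them instead of creating new ones. For the illustrative WeightedAvg this would be roughly the following method, assuming the reflectively discovered resetAccumulator convention:

// Resets a reused accumulator before the next group is processed
// (needReset = true on the DataSet/batch code paths).
def resetAccumulator(acc: WeightedAvgAccum): Unit = {
  acc.sum = 0L
  acc.count = 0
}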
http://git-wip-us.apache.org/repos/asf/flink/blob/981dea41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 1022e4d..63dc1ae 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -23,8 +23,8 @@ import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTabl
import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable}
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.utils.{ScalarSqlFunction, TableSqlFunction}
-import org.apache.flink.table.functions.{EventTimeExtractor, RowTime, ScalarFunction, TableFunction, _}
+import org.apache.flink.table.functions.utils.{AggSqlFunction, ScalarSqlFunction, TableSqlFunction}
+import org.apache.flink.table.functions.{AggregateFunction, EventTimeExtractor, RowTime, ScalarFunction, TableFunction, _}
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -97,6 +97,15 @@ class FunctionCatalog {
val function = tableSqlFunction.getTableFunction
TableFunctionCall(name, function, children, typeInfo)
+ // user-defined aggregate function call
+ case af if classOf[AggregateFunction[_, _]].isAssignableFrom(af) =>
+ val aggregateFunction = sqlFunctions
+ .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[AggSqlFunction])
+ .getOrElse(throw ValidationException(s"Undefined table function: $name"))
+ .asInstanceOf[AggSqlFunction]
+ val function = aggregateFunction.getFunction
+ AggFunctionCall(function, children)
+
// general expression call
case expression if classOf[Expression].isAssignableFrom(expression) =>
// try to find a constructor accepts `Seq[Expression]`