You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2017/08/29 20:11:11 UTC
[4/5] flink git commit: [FLINK-7206] [table] Add DataView to support
direct state access in AggregateFunction accumulators.
[FLINK-7206] [table] Add DataView to support direct state access in AggregateFunction accumulators.
This closes #4355.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1fc0b641
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1fc0b641
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1fc0b641
Branch: refs/heads/master
Commit: 1fc0b6413c74eff0ace25f4329451e35e84849b5
Parents: 88848e7
Author: 宝牛 <ba...@taobao.com>
Authored: Wed Aug 23 17:45:05 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Aug 29 22:10:17 2017 +0200
----------------------------------------------------------------------
.../flink/table/api/dataview/DataView.scala | 35 +++
.../flink/table/api/dataview/DataViewSpec.scala | 55 +++++
.../flink/table/api/dataview/ListView.scala | 142 ++++++++++++
.../flink/table/api/dataview/MapView.scala | 198 ++++++++++++++++
.../codegen/AggregationCodeGenerator.scala | 231 ++++++++++++++++---
.../flink/table/codegen/CodeGenerator.scala | 27 ++-
.../table/dataview/ListViewSerializer.scala | 111 +++++++++
.../flink/table/dataview/ListViewTypeInfo.scala | 66 ++++++
.../dataview/ListViewTypeInfoFactory.scala | 43 ++++
.../table/dataview/MapViewSerializer.scala | 121 ++++++++++
.../flink/table/dataview/MapViewTypeInfo.scala | 72 ++++++
.../table/dataview/MapViewTypeInfoFactory.scala | 51 ++++
.../flink/table/dataview/StateListView.scala | 47 ++++
.../flink/table/dataview/StateMapView.scala | 54 +++++
.../utils/UserDefinedFunctionUtils.scala | 113 ++++++++-
.../table/runtime/aggregate/AggregateUtil.scala | 95 +++++---
.../aggregate/GeneratedAggregations.scala | 21 +-
.../aggregate/GroupAggProcessFunction.scala | 8 +-
.../aggregate/ProcTimeBoundedRangeOver.scala | 11 +-
.../aggregate/ProcTimeBoundedRowsOver.scala | 11 +-
.../aggregate/ProcTimeUnboundedOver.scala | 9 +-
.../aggregate/RowTimeBoundedRangeOver.scala | 6 +
.../aggregate/RowTimeBoundedRowsOver.scala | 6 +
.../aggregate/RowTimeUnboundedOver.scala | 5 +
.../utils/JavaUserDefinedAggFunctions.java | 199 ++++++++++++++++
.../table/dataview/ListViewSerializerTest.scala | 62 +++++
.../table/dataview/MapViewSerializerTest.scala | 68 ++++++
.../runtime/batch/table/AggregateITCase.scala | 9 +-
.../table/runtime/harness/HarnessTestBase.scala | 17 ++
.../runtime/stream/table/AggregateITCase.scala | 41 +++-
.../stream/table/GroupWindowITCase.scala | 40 ++--
.../runtime/stream/table/OverWindowITCase.scala | 118 +++++-----
32 files changed, 1933 insertions(+), 159 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataView.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataView.scala
new file mode 100644
index 0000000..2214086
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataView.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.dataview
+
+/**
+ * A [[DataView]] is a collection type that can be used in the accumulator of an
+ * [[org.apache.flink.table.functions.AggregateFunction]].
+ *
+ * Depending on the context in which the [[org.apache.flink.table.functions.AggregateFunction]] is
+ * used, a [[DataView]] can be backed by a Java heap collection or a state backend.
+ */
+trait DataView extends Serializable {
+ // clear() is the only operation declared here; element access is defined by the subtypes.
+ /**
+ * Clears the [[DataView]] and removes all data.
+ */
+ def clear(): Unit
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataViewSpec.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataViewSpec.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataViewSpec.scala
new file mode 100644
index 0000000..943fe03
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/DataViewSpec.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.dataview
+
+import java.lang.reflect.Field
+
+import org.apache.flink.api.common.state.{ListStateDescriptor, MapStateDescriptor, State, StateDescriptor}
+import org.apache.flink.table.dataview.{ListViewTypeInfo, MapViewTypeInfo}
+
+/**
+ * Specification of a [[DataView]] field: a state id, the reflective [[Field]] and the
+ * [[StateDescriptor]] that is derived from them.
+ * @tparam ACC the [[DataView]] subtype described by this specification
+ */
+trait DataViewSpec[ACC <: DataView] {
+ def stateId: String
+ def field: Field
+ def toStateDescriptor: StateDescriptor[_ <: State, _]
+}
+/** Spec of a [[ListView]] field; yields a [[ListStateDescriptor]] of the view's element type. */
+case class ListViewSpec[T](
+ stateId: String,
+ field: Field,
+ listViewTypeInfo: ListViewTypeInfo[T])
+ extends DataViewSpec[ListView[T]] {
+ // stateId is passed as the first argument of the descriptor
+ override def toStateDescriptor: StateDescriptor[_ <: State, _] =
+ new ListStateDescriptor[T](stateId, listViewTypeInfo.elementType)
+}
+/** Spec of a [[MapView]] field; yields a [[MapStateDescriptor]] of the view's key/value types. */
+case class MapViewSpec[K, V](
+ stateId: String,
+ field: Field,
+ mapViewTypeInfo: MapViewTypeInfo[K, V])
+ extends DataViewSpec[MapView[K, V]] {
+ // stateId is passed as the first argument of the descriptor
+ override def toStateDescriptor: StateDescriptor[_ <: State, _] =
+ new MapStateDescriptor[K, V](stateId, mapViewTypeInfo.keyType, mapViewTypeInfo.valueType)
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/ListView.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/ListView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/ListView.scala
new file mode 100644
index 0000000..59b2426
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/ListView.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.dataview
+
+import java.lang.{Iterable => JIterable}
+import java.util
+
+import org.apache.flink.api.common.typeinfo.{TypeInfo, TypeInformation}
+import org.apache.flink.table.dataview.ListViewTypeInfoFactory
+
+/**
+ * A [[ListView]] provides List functionality for accumulators used by user-defined aggregate
+ * functions [[org.apache.flink.table.functions.AggregateFunction]].
+ *
+ * A [[ListView]] can be backed by a Java ArrayList or a state backend, depending on the context in
+ * which the aggregate function is used.
+ *
+ * At runtime [[ListView]] will be replaced by a [[org.apache.flink.table.dataview.StateListView]]
+ * if it is backed by a state backend.
+ *
+ * Example of an accumulator type with a [[ListView]] and an aggregate function that uses it:
+ * {{{
+ *
+ * public class MyAccum {
+ * public ListView<String> list;
+ * public long count;
+ * }
+ *
+ * public class MyAgg extends AggregateFunction<Long, MyAccum> {
+ *
+ * @Override
+ * public MyAccum createAccumulator() {
+ * MyAccum accum = new MyAccum();
+ * accum.list = new ListView<>(Types.STRING);
+ * accum.count = 0L;
+ * return accum;
+ * }
+ *
+ * public void accumulate(MyAccum accumulator, String id) throws Exception {
+ * accumulator.list.add(id);
+ * ... ...
+ * accumulator.list.get()
+ * ... ...
+ * }
+ *
+ * @Override
+ * public Long getValue(MyAccum accumulator) {
+ * // the list view can also be read while computing the result
+ * ... ...
+ * accumulator.list.get()
+ * ... ...
+ * return accumulator.count;
+ * }
+ * }
+ *
+ * }}}
+ *
+ * @param elementTypeInfo element type information
+ * @tparam T element type
+ */
+@TypeInfo(classOf[ListViewTypeInfoFactory[_]])
+class ListView[T](
+ @transient private[flink] val elementTypeInfo: TypeInformation[T],
+ private[flink] val list: util.List[T])
+ extends DataView {
+
+ /**
+ * Creates a list view for elements of the specified type.
+ *
+ * @param elementTypeInfo The type of the list view elements.
+ */
+ def this(elementTypeInfo: TypeInformation[T]) = {
+ this(elementTypeInfo, new util.ArrayList[T]())
+ }
+
+ /**
+ * Creates a list view without element type information (elementTypeInfo is null).
+ */
+ def this() = this(null)
+
+ /**
+ * Returns an iterable of the list view.
+ *
+ * @throws Exception Thrown if the system cannot get data.
+ * @return The iterable of the list or { @code null} if the list is empty.
+ */
+ @throws[Exception]
+ def get: JIterable[T] = {
+ if (!list.isEmpty) {
+ list
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Adds the given value to the list.
+ *
+ * @throws Exception Thrown if the system cannot add data.
+ * @param value The element to be appended to this list view.
+ */
+ @throws[Exception]
+ def add(value: T): Unit = list.add(value)
+
+ /**
+ * Adds all of the elements of the specified list to this list view.
+ *
+ * @throws Exception Thrown if the system cannot add all data.
+ * @param list The list with the elements that will be stored in this list view.
+ */
+ @throws[Exception]
+ def addAll(list: util.List[T]): Unit = this.list.addAll(list)
+
+ /**
+ * Removes all of the elements from this list view.
+ */
+ override def clear(): Unit = list.clear()
+
+ override def equals(other: Any): Boolean = other match {
+ case that: ListView[_] => // T is erased at runtime, so only the backing lists are compared
+ list.equals(that.list)
+ case _ => false
+ }
+
+ override def hashCode(): Int = list.hashCode()
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/MapView.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/MapView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/MapView.scala
new file mode 100644
index 0000000..9206d6a
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/dataview/MapView.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.dataview
+
+import java.lang.{Iterable => JIterable}
+import java.util
+
+import org.apache.flink.api.common.typeinfo.{TypeInfo, TypeInformation}
+import org.apache.flink.table.dataview.MapViewTypeInfoFactory
+
+/**
+ * A [[MapView]] provides Map functionality for accumulators used by user-defined aggregate
+ * functions [[org.apache.flink.table.functions.AggregateFunction]].
+ *
+ * A [[MapView]] can be backed by a Java HashMap or a state backend, depending on the context in
+ * which the aggregate function is used.
+ *
+ * At runtime [[MapView]] will be replaced by a [[org.apache.flink.table.dataview.StateMapView]]
+ * if it is backed by a state backend.
+ *
+ * Example of an accumulator type with a [[MapView]] and an aggregate function that uses it:
+ * {{{
+ *
+ * public class MyAccum {
+ * public MapView<String, Integer> map;
+ * public long count;
+ * }
+ *
+ * public class MyAgg extends AggregateFunction<Long, MyAccum> {
+ *
+ * @Override
+ * public MyAccum createAccumulator() {
+ * MyAccum accum = new MyAccum();
+ * accum.map = new MapView<>(Types.STRING, Types.INT);
+ * accum.count = 0L;
+ * return accum;
+ * }
+ *
+ * public void accumulate(MyAccum accumulator, String id) {
+ * try {
+ * if (!accumulator.map.contains(id)) {
+ * accumulator.map.put(id, 1);
+ * accumulator.count++;
+ * }
+ * } catch (Exception e) {
+ * e.printStackTrace();
+ * }
+ * }
+ *
+ * @Override
+ * public Long getValue(MyAccum accumulator) {
+ * return accumulator.count;
+ * }
+ * }
+ *
+ * }}}
+ *
+ * @param keyTypeInfo key type information
+ * @param valueTypeInfo value type information
+ * @tparam K key type
+ * @tparam V value type
+ */
+@TypeInfo(classOf[MapViewTypeInfoFactory[_, _]])
+class MapView[K, V](
+ @transient private[flink] val keyTypeInfo: TypeInformation[K],
+ @transient private[flink] val valueTypeInfo: TypeInformation[V],
+ private[flink] val map: util.Map[K, V])
+ extends DataView {
+
+ /**
+ * Creates a MapView with the specified key and value types.
+ *
+ * @param keyTypeInfo The type of keys of the MapView.
+ * @param valueTypeInfo The type of the values of the MapView.
+ */
+ def this(keyTypeInfo: TypeInformation[K], valueTypeInfo: TypeInformation[V]) = {
+ this(keyTypeInfo, valueTypeInfo, new util.HashMap[K, V]())
+ }
+
+ /**
+ * Creates a MapView without key and value type information (both are null).
+ */
+ def this() = this(null, null)
+
+ /**
+ * Returns the value for the specified key or { @code null } if the key is not in the map view.
+ *
+ * @param key The look up key.
+ * @return The value for the specified key.
+ * @throws Exception Thrown if the system cannot get data.
+ */
+ @throws[Exception]
+ def get(key: K): V = map.get(key)
+
+ /**
+ * Inserts a value for the given key into the map view.
+ * If the map view already contains a value for the key, the existing value is overwritten.
+ *
+ * @param key The key for which the value is inserted.
+ * @param value The value that is inserted for the key.
+ * @throws Exception Thrown if the system cannot put data.
+ */
+ @throws[Exception]
+ def put(key: K, value: V): Unit = map.put(key, value)
+
+ /**
+ * Inserts all mappings from the specified map to this map view.
+ *
+ * @param map The map whose entries are inserted into this map view.
+ * @throws Exception Thrown if the system cannot access the map.
+ */
+ @throws[Exception]
+ def putAll(map: util.Map[K, V]): Unit = this.map.putAll(map)
+
+ /**
+ * Deletes the value for the given key.
+ *
+ * @param key The key for which the value is deleted.
+ * @throws Exception Thrown if the system cannot access the map.
+ */
+ @throws[Exception]
+ def remove(key: K): Unit = map.remove(key)
+
+ /**
+ * Checks if the map view contains a value for a given key.
+ *
+ * @param key The key to check.
+ * @return True if there exists a value for the given key, false otherwise.
+ * @throws Exception Thrown if the system cannot access the map.
+ */
+ @throws[Exception]
+ def contains(key: K): Boolean = map.containsKey(key)
+
+ /**
+ * Returns all entries of the map view.
+ *
+ * @return An iterable of all the key-value pairs in the map view.
+ * @throws Exception Thrown if the system cannot access the map.
+ */
+ @throws[Exception]
+ def entries: JIterable[util.Map.Entry[K, V]] = map.entrySet()
+
+ /**
+ * Returns all the keys in the map view.
+ *
+ * @return An iterable of all the keys in the map.
+ * @throws Exception Thrown if the system cannot access the map.
+ */
+ @throws[Exception]
+ def keys: JIterable[K] = map.keySet()
+
+ /**
+ * Returns all the values in the map view.
+ *
+ * @return An iterable of all the values in the map.
+ * @throws Exception Thrown if the system cannot access the map.
+ */
+ @throws[Exception]
+ def values: JIterable[V] = map.values()
+
+ /**
+ * Returns an iterator over all entries of the map view.
+ *
+ * @return An iterator over all the mappings in the map.
+ * @throws Exception Thrown if the system cannot access the map.
+ */
+ @throws[Exception]
+ def iterator: util.Iterator[util.Map.Entry[K, V]] = map.entrySet().iterator()
+
+ /**
+ * Removes all entries of this map.
+ */
+ override def clear(): Unit = map.clear()
+
+ override def equals(other: Any): Boolean = other match {
+ case that: MapView[_, _] => // K and V are erased at runtime; only the backing maps are compared
+ map.equals(that.map)
+ case _ => false
+ }
+
+ override def hashCode(): Int = map.hashCode()
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index 25527cc..22ce5ba 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -17,17 +17,23 @@
*/
package org.apache.flink.table.codegen
-import java.lang.reflect.ParameterizedType
+import java.lang.reflect.{Modifier, ParameterizedType}
import java.lang.{Iterable => JIterable}
+import org.apache.commons.codec.binary.Base64
+import org.apache.flink.api.common.state.{State, StateDescriptor}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.TableConfig
+import org.apache.flink.table.api.dataview._
import org.apache.flink.table.codegen.Indenter.toISC
-import org.apache.flink.table.codegen.CodeGenUtils.newName
+import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess}
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString}
import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, SingleElementIterable}
+import org.apache.flink.util.InstantiationUtil
+
+import scala.collection.mutable
/**
* A code generator for generating [[GeneratedAggregations]].
@@ -42,6 +48,18 @@ class AggregationCodeGenerator(
input: TypeInformation[_ <: Any])
extends CodeGenerator(config, nullableInput, input) {
+ // set of statements for cleaning up data views; each statement is added only once
+ // we use a LinkedHashSet to keep the insertion order
+ private val reusableCleanupStatements = mutable.LinkedHashSet[String]()
+
+ /**
+ * Joins the collected cleanup statements, one per line, and appends a final newline.
+ * @return code block to be placed in the cleanup() method of [[GeneratedAggregations]]
+ */
+ def reuseCleanupCode(): String = {
+ reusableCleanupStatements.mkString("\n") + "\n"
+ }
+
/**
* Generates a [[org.apache.flink.table.runtime.aggregate.GeneratedAggregations]] that can be
* passed to a Java compiler.
@@ -79,13 +97,15 @@ class AggregationCodeGenerator(
outputArity: Int,
needRetract: Boolean,
needMerge: Boolean,
- needReset: Boolean)
+ needReset: Boolean,
+ accConfig: Option[Array[Seq[DataViewSpec[_]]]])
: GeneratedAggregationsFunction = {
// get unique function name
val funcName = newName(name)
// register UDAGGs
- val aggs = aggregates.map(a => addReusableFunction(a))
+ val aggs = aggregates.map(a => addReusableFunction(a, contextTerm))
+
// get java types of accumulators
val accTypeClasses = aggregates.map { a =>
a.getClass.getMethod("createAccumulator").getReturnType
@@ -105,6 +125,9 @@ class AggregationCodeGenerator(
inFields.map(classes(_))
}
+ // initialize and create data views
+ addReusableDataViews()
+
// check and validate the needed methods
aggregates.zipWithIndex.map {
case (a, i) =>
@@ -161,13 +184,121 @@ class AggregationCodeGenerator(
}
}
+ /**
+ * Builds the name of the generated DataView member, e.g. acc1_map_dataview.
+ *
+ * @param aggIndex index of the aggregate function
+ * @param fieldName name of the DataView field in the accumulator
+ * @return term to access [[MapView]] or [[ListView]]
+ */
+ def createDataViewTerm(aggIndex: Int, fieldName: String): String = {
+ Seq(s"acc$aggIndex", fieldName, "dataview").mkString("_")
+ }
+
+ /**
+ * Adds a reusable [[org.apache.flink.table.api.dataview.DataView]] to the member, open
+ * and cleanup area of the generated function (nothing is registered for close).
+ * Does nothing if no accConfig is defined.
+ */
+ def addReusableDataViews(): Unit = {
+ if (accConfig.isDefined) {
+ val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get
+ .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
+ .toMap[String, StateDescriptor[_ <: State, _]]
+
+ for (i <- aggs.indices) {
+ for (spec <- accConfig.get(i)) {
+ val dataViewField = spec.field
+ val dataViewTypeTerm = dataViewField.getType.getCanonicalName
+ val desc = descMapping.getOrElse(spec.stateId,
+ throw new CodeGenException(
+ s"Can not find DataView in accumulator by id: ${spec.stateId}"))
+
+ // define the DataView variables
+ val serializedData = serializeStateDescriptor(desc)
+ val dataViewFieldTerm = createDataViewTerm(i, dataViewField.getName)
+ val field =
+ s"""
+ | transient $dataViewTypeTerm $dataViewFieldTerm = null;
+ |""".stripMargin
+ reusableMemberStatements.add(field)
+
+ // create DataViews
+ val descFieldTerm = s"${dataViewFieldTerm}_desc"
+ val descClassQualifier = classOf[StateDescriptor[_, _]].getCanonicalName
+ val descDeserializeCode =
+ s"""
+ | $descClassQualifier $descFieldTerm = ($descClassQualifier)
+ | org.apache.flink.util.InstantiationUtil.deserializeObject(
+ | org.apache.commons.codec.binary.Base64.decodeBase64("$serializedData"),
+ | $contextTerm.getUserCodeClassLoader());
+ |""".stripMargin
+ val createDataView = if (dataViewField.getType == classOf[MapView[_, _]]) {
+ s"""
+ | $descDeserializeCode
+ | $dataViewFieldTerm = new org.apache.flink.table.dataview.StateMapView(
+ | $contextTerm.getMapState((
+ | org.apache.flink.api.common.state.MapStateDescriptor)$descFieldTerm));
+ |""".stripMargin
+ } else if (dataViewField.getType == classOf[ListView[_]]) {
+ s"""
+ | $descDeserializeCode
+ | $dataViewFieldTerm = new org.apache.flink.table.dataview.StateListView(
+ | $contextTerm.getListState((
+ | org.apache.flink.api.common.state.ListStateDescriptor)$descFieldTerm));
+ |""".stripMargin
+ } else {
+ throw new CodeGenException(s"Unsupported dataview type: $dataViewTypeTerm")
+ }
+ reusableOpenStatements.add(createDataView)
+
+ // cleanup DataViews
+ val cleanup =
+ s"""
+ | $dataViewFieldTerm.clear();
+ |""".stripMargin
+ reusableCleanupStatements.add(cleanup)
+ }
+ }
+ }
+ }
+
+ /**
+ * Generates the statements that assign the generated data view members to an accumulator.
+ *
+ * @param accTerm term of the accumulator whose data view fields are set
+ * @param aggIndex index of the aggregate function the accumulator belongs to
+ * @return the field assignment statements, or an empty string if no accConfig is defined
+ */
+ def genDataViewFieldSetter(accTerm: String, aggIndex: Int): String = {
+ accConfig match {
+ case Some(specsPerAgg) =>
+ val statements = specsPerAgg(aggIndex).map { viewSpec =>
+ val dvField = viewSpec.field
+ val viewTerm = createDataViewTerm(aggIndex, dvField.getName)
+ // non-public fields are written through a reusable reflective accessor
+ val assignment = if (!Modifier.isPublic(dvField.getModifiers)) {
+ val accessor = addReusablePrivateFieldAccess(dvField.getDeclaringClass, dvField.getName)
+ s"${reflectiveFieldWriteAccess(accessor, dvField, accTerm, viewTerm)};"
+ } else {
+ s"$accTerm.${dvField.getName} = $viewTerm;"
+ }
+ s"""
+ | $assignment
+ """.stripMargin
+ }
+ statements.mkString("\n")
+ case None => ""
+ }
+ }
+
def genSetAggregationResults: String = {
val sig: String =
j"""
| public final void setAggregationResults(
| org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row output)""".stripMargin
+ | org.apache.flink.types.Row output) throws Exception """.stripMargin
val setAggs: String = {
for (i <- aggs.indices) yield
@@ -181,10 +312,11 @@ class AggregationCodeGenerator(
j"""
| org.apache.flink.table.functions.AggregateFunction baseClass$i =
| (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
- |
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | ${genDataViewFieldSetter(s"acc$i", i)}
| output.setField(
| ${aggMapping(i)},
- | baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin
+ | baseClass$i.getValue(acc$i));""".stripMargin
}
}.mkString("\n")
@@ -200,14 +332,17 @@ class AggregationCodeGenerator(
j"""
| public final void accumulate(
| org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row input)""".stripMargin
+ | org.apache.flink.types.Row input) throws Exception """.stripMargin
val accumulate: String = {
- for (i <- aggs.indices) yield
+ for (i <- aggs.indices) yield {
j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | ${genDataViewFieldSetter(s"acc$i", i)}
| ${aggs(i)}.accumulate(
- | ((${accTypes(i)}) accs.getField($i)),
+ | acc$i,
| ${parametersCode(i)});""".stripMargin
+ }
}.mkString("\n")
j"""$sig {
@@ -221,14 +356,17 @@ class AggregationCodeGenerator(
j"""
| public final void retract(
| org.apache.flink.types.Row accs,
- | org.apache.flink.types.Row input)""".stripMargin
+ | org.apache.flink.types.Row input) throws Exception """.stripMargin
val retract: String = {
- for (i <- aggs.indices) yield
+ for (i <- aggs.indices) yield {
j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | ${genDataViewFieldSetter(s"acc$i", i)}
| ${aggs(i)}.retract(
- | ((${accTypes(i)}) accs.getField($i)),
+ | acc$i,
| ${parametersCode(i)});""".stripMargin
+ }
}.mkString("\n")
if (needRetract) {
@@ -247,7 +385,7 @@ class AggregationCodeGenerator(
val sig: String =
j"""
- | public final org.apache.flink.types.Row createAccumulators()
+ | public final org.apache.flink.types.Row createAccumulators() throws Exception
| """.stripMargin
val init: String =
j"""
@@ -255,12 +393,15 @@ class AggregationCodeGenerator(
| new org.apache.flink.types.Row(${aggs.length});"""
.stripMargin
val create: String = {
- for (i <- aggs.indices) yield
+ for (i <- aggs.indices) yield {
j"""
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) ${aggs(i)}.createAccumulator();
+ | ${genDataViewFieldSetter(s"acc$i", i)}
| accs.setField(
| $i,
- | ${aggs(i)}.createAccumulator());"""
- .stripMargin
+ | acc$i);"""
+ .stripMargin
+ }
}.mkString("\n")
val ret: String =
j"""
@@ -356,6 +497,10 @@ class AggregationCodeGenerator(
""".stripMargin
if (needMerge) {
+ if (accConfig.isDefined) {
+ throw new CodeGenException("DataView doesn't support merge when the backend uses " +
+ s"state when generate aggregation for $funcName.")
+ }
j"""
|$sig {
|$merge
@@ -385,13 +530,15 @@ class AggregationCodeGenerator(
val sig: String =
j"""
| public final void resetAccumulator(
- | org.apache.flink.types.Row accs)""".stripMargin
+ | org.apache.flink.types.Row accs) throws Exception """.stripMargin
val reset: String = {
- for (i <- aggs.indices) yield
+ for (i <- aggs.indices) yield {
j"""
- | ${aggs(i)}.resetAccumulator(
- | ((${accTypes(i)}) accs.getField($i)));""".stripMargin
+ | ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
+ | ${genDataViewFieldSetter(s"acc$i", i)}
+ | ${aggs(i)}.resetAccumulator(acc$i);""".stripMargin
+ }
}.mkString("\n")
if (needReset) {
@@ -404,6 +551,17 @@ class AggregationCodeGenerator(
}
}
+ val aggFuncCode = Seq(
+ genSetAggregationResults,
+ genAccumulate,
+ genRetract,
+ genCreateAccumulators,
+ genSetForwardedFields,
+ genSetConstantFlags,
+ genCreateOutputRow,
+ genMergeAccumulatorsPair,
+ genResetAccumulator).mkString("\n")
+
val generatedAggregationsClass = classOf[GeneratedAggregations].getCanonicalName
var funcCode =
j"""
@@ -416,20 +574,29 @@ class AggregationCodeGenerator(
| }
| ${reuseConstructorCode(funcName)}
|
+ | public final void open(
+ | org.apache.flink.api.common.functions.RuntimeContext $contextTerm) throws Exception {
+ | ${reuseOpenCode()}
+ | }
+ |
+ | $aggFuncCode
+ |
+ | public final void cleanup() throws Exception {
+ | ${reuseCleanupCode()}
+ | }
+ |
+ | public final void close() throws Exception {
+ | ${reuseCloseCode()}
+ | }
+ |}
""".stripMargin
- funcCode += genSetAggregationResults + "\n"
- funcCode += genAccumulate + "\n"
- funcCode += genRetract + "\n"
- funcCode += genCreateAccumulators + "\n"
- funcCode += genSetForwardedFields + "\n"
- funcCode += genSetConstantFlags + "\n"
- funcCode += genCreateOutputRow + "\n"
- funcCode += genMergeAccumulatorsPair + "\n"
- funcCode += genResetAccumulator + "\n"
- funcCode += "}"
-
GeneratedAggregationsFunction(funcName, funcCode)
}
+ @throws[Exception]
+ def serializeStateDescriptor(stateDescriptor: StateDescriptor[_, _]): String = {
+ // serialize the descriptor to bytes and encode them as URL-safe Base64 text
+ Base64.encodeBase64URLSafeString(InstantiationUtil.serializeObject(stateDescriptor))
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 946c6cd..154e8ad 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -41,8 +41,9 @@ import org.apache.flink.table.codegen.calls.FunctionGenerator
import org.apache.flink.table.codegen.calls.ScalarOperators._
import org.apache.flink.table.functions.sql.{ProctimeSqlFunction, ScalarSqlFunctions}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
-import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction}
+
import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo
+import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction}
import org.apache.flink.table.typeutils.TypeCheckUtils._
import scala.collection.JavaConversions._
@@ -108,31 +109,31 @@ abstract class CodeGenerator(
// set of member statements that will be added only once
// we use a LinkedHashSet to keep the insertion order
- private val reusableMemberStatements = mutable.LinkedHashSet[String]()
+ protected val reusableMemberStatements = mutable.LinkedHashSet[String]()
// set of constructor statements that will be added only once
// we use a LinkedHashSet to keep the insertion order
- private val reusableInitStatements = mutable.LinkedHashSet[String]()
+ protected val reusableInitStatements = mutable.LinkedHashSet[String]()
// set of open statements for RichFunction that will be added only once
// we use a LinkedHashSet to keep the insertion order
- private val reusableOpenStatements = mutable.LinkedHashSet[String]()
+ protected val reusableOpenStatements = mutable.LinkedHashSet[String]()
// set of close statements for RichFunction that will be added only once
// we use a LinkedHashSet to keep the insertion order
- private val reusableCloseStatements = mutable.LinkedHashSet[String]()
+ protected val reusableCloseStatements = mutable.LinkedHashSet[String]()
// set of statements that will be added only once per record
// we use a LinkedHashSet to keep the insertion order
- private val reusablePerRecordStatements = mutable.LinkedHashSet[String]()
+ protected val reusablePerRecordStatements = mutable.LinkedHashSet[String]()
// map of initial input unboxing expressions that will be added only once
// (inputTerm, index) -> expr
- private val reusableInputUnboxingExprs = mutable.Map[(String, Int), GeneratedExpression]()
+ protected val reusableInputUnboxingExprs = mutable.Map[(String, Int), GeneratedExpression]()
// set of constructor statements that will be added only once
// we use a LinkedHashSet to keep the insertion order
- private val reusableConstructorStatements = mutable.LinkedHashSet[(String, String)]()
+ protected val reusableConstructorStatements = mutable.LinkedHashSet[(String, String)]()
/**
* @return code block of statements that need to be placed in the member area of the Function
@@ -1458,9 +1459,10 @@ abstract class CodeGenerator(
* Adds a reusable [[UserDefinedFunction]] to the member area of the generated [[Function]].
*
* @param function [[UserDefinedFunction]] object to be instantiated during runtime
+ * @param contextTerm [[RuntimeContext]] term to access the [[RuntimeContext]]
* @return member variable term
*/
- def addReusableFunction(function: UserDefinedFunction): String = {
+ def addReusableFunction(function: UserDefinedFunction, contextTerm: String = null): String = {
val classQualifier = function.getClass.getCanonicalName
val functionSerializedData = UserDefinedFunctionUtils.serialize(function)
val fieldTerm = s"function_${function.functionIdentifier}"
@@ -1480,10 +1482,15 @@ abstract class CodeGenerator(
reusableInitStatements.add(functionDeserialization)
- val openFunction =
+ val openFunction = if (contextTerm != null) {
+ s"""
+ |$fieldTerm.open(new ${classOf[FunctionContext].getCanonicalName}($contextTerm));
+ """.stripMargin
+ } else {
s"""
|$fieldTerm.open(new ${classOf[FunctionContext].getCanonicalName}(getRuntimeContext()));
""".stripMargin
+ }
reusableOpenStatements.add(openFunction)
val closeFunction =
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewSerializer.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewSerializer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewSerializer.scala
new file mode 100644
index 0000000..a450c2c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewSerializer.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.dataview
+
+import org.apache.flink.api.common.typeutils._
+import org.apache.flink.api.common.typeutils.base.{CollectionSerializerConfigSnapshot, ListSerializer}
+import org.apache.flink.core.memory.{DataInputView, DataOutputView}
+import org.apache.flink.table.api.dataview.ListView
+
+/**
+ * A serializer for [[ListView]]. The serializer relies on an element
+ * serializer for the serialization of the list's elements.
+ *
+ * The serialization format for the list is as follows: four bytes for the length of the list,
+ * followed by the serialized representation of each element.
+ *
+ * @param listSerializer List serializer.
+ * @tparam T The type of element in the list.
+ */
+class ListViewSerializer[T](val listSerializer: ListSerializer[T])
+ extends TypeSerializer[ListView[T]] {
+
+ override def isImmutableType: Boolean = false
+
+ override def duplicate(): TypeSerializer[ListView[T]] = {
+ new ListViewSerializer[T](listSerializer.duplicate().asInstanceOf[ListSerializer[T]])
+ }
+
+ override def createInstance(): ListView[T] = {
+ new ListView[T]
+ }
+
+ override def copy(from: ListView[T]): ListView[T] = {
+ new ListView[T](null, listSerializer.copy(from.list))
+ }
+
+ override def copy(from: ListView[T], reuse: ListView[T]): ListView[T] = copy(from)
+
+ override def getLength: Int = -1
+
+ override def serialize(record: ListView[T], target: DataOutputView): Unit = {
+ listSerializer.serialize(record.list, target)
+ }
+
+ override def deserialize(source: DataInputView): ListView[T] = {
+ new ListView[T](null, listSerializer.deserialize(source))
+ }
+
+ override def deserialize(reuse: ListView[T], source: DataInputView): ListView[T] =
+ deserialize(source)
+
+ override def copy(source: DataInputView, target: DataOutputView): Unit =
+ listSerializer.copy(source, target)
+
+ override def canEqual(obj: scala.Any): Boolean = obj != null && obj.getClass == getClass
+
+ override def hashCode(): Int = listSerializer.hashCode()
+
+ override def equals(obj: Any): Boolean = canEqual(obj) && // rejects null and other classes
+ listSerializer.equals(obj.asInstanceOf[ListViewSerializer[_]].listSerializer)
+
+ override def snapshotConfiguration(): TypeSerializerConfigSnapshot =
+ listSerializer.snapshotConfiguration()
+
+ // copied and modified from ListSerializer.ensureCompatibility
+ override def ensureCompatibility(
+ configSnapshot: TypeSerializerConfigSnapshot): CompatibilityResult[ListView[T]] = {
+
+ configSnapshot match {
+ case snapshot: CollectionSerializerConfigSnapshot[_] =>
+ val previousListSerializerAndConfig = snapshot.getSingleNestedSerializerAndConfig
+
+ val compatResult = CompatibilityUtil.resolveCompatibilityResult(
+ previousListSerializerAndConfig.f0,
+ classOf[UnloadableDummyTypeSerializer[_]],
+ previousListSerializerAndConfig.f1,
+ listSerializer.getElementSerializer)
+
+ if (!compatResult.isRequiresMigration) {
+ CompatibilityResult.compatible[ListView[T]]
+ } else if (compatResult.getConvertDeserializer != null) {
+ CompatibilityResult.requiresMigration(
+ new ListViewSerializer[T](
+ new ListSerializer[T](
+ new TypeDeserializerAdapter[T](compatResult.getConvertDeserializer))
+ )
+ )
+ } else {
+ CompatibilityResult.requiresMigration[ListView[T]]
+ }
+
+ case _ => CompatibilityResult.requiresMigration[ListView[T]]
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfo.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfo.scala
new file mode 100644
index 0000000..a10b675
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfo.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.dataview
+
+import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeutils.TypeSerializer
+import org.apache.flink.api.common.typeutils.base.ListSerializer
+import org.apache.flink.table.api.dataview.ListView
+
+/**
+ * [[ListView]] type information.
+ *
+ * @param elementType element type information
+ * @tparam T element type
+ */
+class ListViewTypeInfo[T](val elementType: TypeInformation[T])
+ extends TypeInformation[ListView[T]] {
+
+ override def isBasicType: Boolean = false
+
+ override def isTupleType: Boolean = false
+
+ override def getArity: Int = 1
+
+ override def getTotalFields: Int = 1
+
+ override def getTypeClass: Class[ListView[T]] = classOf[ListView[T]]
+
+ override def isKeyType: Boolean = false
+
+ override def createSerializer(config: ExecutionConfig): TypeSerializer[ListView[T]] = {
+ val typeSer = elementType.createSerializer(config)
+ new ListViewSerializer[T](new ListSerializer[T](typeSer))
+ }
+
+ override def canEqual(obj: scala.Any): Boolean = obj != null && obj.getClass == getClass
+
+ override def hashCode(): Int = 31 * elementType.hashCode
+
+ override def equals(obj: Any): Boolean = canEqual(obj) && {
+ obj match {
+ case other: ListViewTypeInfo[_] => // wildcard: the type argument is erased at runtime
+ elementType.equals(other.elementType)
+ case _ => false
+ }
+ }
+
+ override def toString: String = s"ListView<$elementType>"
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfoFactory.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfoFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfoFactory.scala
new file mode 100644
index 0000000..eda6cb9
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/ListViewTypeInfoFactory.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.dataview
+
+import java.lang.reflect.Type
+import java.util
+
+import org.apache.flink.api.common.typeinfo.{TypeInfoFactory, TypeInformation}
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
+import org.apache.flink.table.api.dataview.ListView
+
+class ListViewTypeInfoFactory[T] extends TypeInfoFactory[ListView[T]] {
+
+ override def createTypeInfo(
+ t: Type,
+ genericParameters: util.Map[String, TypeInformation[_]]): TypeInformation[ListView[T]] = {
+
+ var elementType = genericParameters.get("T")
+
+ if (elementType == null) {
+ // we might be able to get the elementType later from the ListView constructor
+ elementType = new GenericTypeInfo(classOf[Any])
+ }
+
+ new ListViewTypeInfo[T](elementType.asInstanceOf[TypeInformation[T]])
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewSerializer.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewSerializer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewSerializer.scala
new file mode 100644
index 0000000..c53f10c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewSerializer.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.dataview
+
+import org.apache.flink.api.common.typeutils._
+import org.apache.flink.api.common.typeutils.base.{MapSerializer, MapSerializerConfigSnapshot}
+import org.apache.flink.core.memory.{DataInputView, DataOutputView}
+import org.apache.flink.table.api.dataview.MapView
+
+/**
+ * A serializer for [[MapView]]. The serializer relies on a key serializer and a value
+ * serializer for the serialization of the map's key-value pairs.
+ *
+ * The serialization format for the map is as follows: four bytes for the length of the map,
+ * followed by the serialized representation of each key-value pair. To allow null values,
+ * each value is prefixed by a null marker.
+ *
+ * @param mapSerializer Map serializer.
+ * @tparam K The type of the keys in the map.
+ * @tparam V The type of the values in the map.
+ */
+class MapViewSerializer[K, V](val mapSerializer: MapSerializer[K, V])
+ extends TypeSerializer[MapView[K, V]] {
+
+ override def isImmutableType: Boolean = false
+
+ override def duplicate(): TypeSerializer[MapView[K, V]] =
+ new MapViewSerializer[K, V](
+ mapSerializer.duplicate().asInstanceOf[MapSerializer[K, V]])
+
+ override def createInstance(): MapView[K, V] = {
+ new MapView[K, V]()
+ }
+
+ override def copy(from: MapView[K, V]): MapView[K, V] = {
+ new MapView[K, V](null, null, mapSerializer.copy(from.map))
+ }
+
+ override def copy(from: MapView[K, V], reuse: MapView[K, V]): MapView[K, V] = copy(from)
+
+ override def getLength: Int = -1 // var length
+
+ override def serialize(record: MapView[K, V], target: DataOutputView): Unit = {
+ mapSerializer.serialize(record.map, target)
+ }
+
+ override def deserialize(source: DataInputView): MapView[K, V] = {
+ new MapView[K, V](null, null, mapSerializer.deserialize(source))
+ }
+
+ override def deserialize(reuse: MapView[K, V], source: DataInputView): MapView[K, V] =
+ deserialize(source)
+
+ override def copy(source: DataInputView, target: DataOutputView): Unit =
+ mapSerializer.copy(source, target)
+
+ override def canEqual(obj: Any): Boolean = obj != null && obj.getClass == getClass
+
+ override def hashCode(): Int = mapSerializer.hashCode()
+
+ override def equals(obj: Any): Boolean = canEqual(obj) && // rejects null and other classes
+ mapSerializer.equals(obj.asInstanceOf[MapViewSerializer[_, _]].mapSerializer)
+
+ override def snapshotConfiguration(): TypeSerializerConfigSnapshot =
+ mapSerializer.snapshotConfiguration()
+
+ // copied and modified from MapSerializer.ensureCompatibility
+ override def ensureCompatibility(configSnapshot: TypeSerializerConfigSnapshot)
+ : CompatibilityResult[MapView[K, V]] = {
+
+ configSnapshot match {
+ case snapshot: MapSerializerConfigSnapshot[_, _] =>
+ val previousKvSerializersAndConfigs = snapshot.getNestedSerializersAndConfigs
+
+ val keyCompatResult = CompatibilityUtil.resolveCompatibilityResult(
+ previousKvSerializersAndConfigs.get(0).f0,
+ classOf[UnloadableDummyTypeSerializer[_]],
+ previousKvSerializersAndConfigs.get(0).f1,
+ mapSerializer.getKeySerializer)
+
+ val valueCompatResult = CompatibilityUtil.resolveCompatibilityResult(
+ previousKvSerializersAndConfigs.get(1).f0,
+ classOf[UnloadableDummyTypeSerializer[_]],
+ previousKvSerializersAndConfigs.get(1).f1,
+ mapSerializer.getValueSerializer)
+
+ if (!keyCompatResult.isRequiresMigration && !valueCompatResult.isRequiresMigration) {
+ CompatibilityResult.compatible[MapView[K, V]]
+ } else if (keyCompatResult.getConvertDeserializer != null
+ && valueCompatResult.getConvertDeserializer != null) {
+ CompatibilityResult.requiresMigration(
+ new MapViewSerializer[K, V](
+ new MapSerializer[K, V](
+ new TypeDeserializerAdapter[K](keyCompatResult.getConvertDeserializer),
+ new TypeDeserializerAdapter[V](valueCompatResult.getConvertDeserializer))
+ )
+ )
+ } else {
+ CompatibilityResult.requiresMigration[MapView[K, V]]
+ }
+
+ case _ => CompatibilityResult.requiresMigration[MapView[K, V]]
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfo.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfo.scala
new file mode 100644
index 0000000..ec5c222
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfo.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.dataview
+
+import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeutils.TypeSerializer
+import org.apache.flink.api.common.typeutils.base.MapSerializer
+import org.apache.flink.table.api.dataview.MapView
+
+/**
+ * [[MapView]] type information.
+ *
+ * @param keyType key type information
+ * @param valueType value type information
+ * @tparam K key type
+ * @tparam V value type
+ */
+class MapViewTypeInfo[K, V](
+ val keyType: TypeInformation[K],
+ val valueType: TypeInformation[V])
+ extends TypeInformation[MapView[K, V]] {
+
+ override def isBasicType = false
+
+ override def isTupleType = false
+
+ override def getArity = 1
+
+ override def getTotalFields = 1
+
+ override def getTypeClass: Class[MapView[K, V]] = classOf[MapView[K, V]]
+
+ override def isKeyType: Boolean = false
+
+ override def createSerializer(config: ExecutionConfig): TypeSerializer[MapView[K, V]] = {
+ val keySer = keyType.createSerializer(config)
+ val valueSer = valueType.createSerializer(config)
+ new MapViewSerializer[K, V](new MapSerializer[K, V](keySer, valueSer))
+ }
+
+ override def canEqual(obj: scala.Any): Boolean = obj != null && obj.getClass == getClass
+
+ override def hashCode(): Int = 31 * keyType.hashCode + valueType.hashCode
+
+ override def equals(obj: Any): Boolean = canEqual(obj) && {
+ obj match {
+ case other: MapViewTypeInfo[_, _] =>
+ keyType.equals(other.keyType) &&
+ valueType.equals(other.valueType)
+ case _ => false
+ }
+ }
+
+ override def toString: String = s"MapView<$keyType, $valueType>"
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfoFactory.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfoFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfoFactory.scala
new file mode 100644
index 0000000..33c3ffe
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/MapViewTypeInfoFactory.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.dataview
+
+import java.lang.reflect.Type
+import java.util
+
+import org.apache.flink.api.common.typeinfo.{TypeInfoFactory, TypeInformation}
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
+import org.apache.flink.table.api.dataview.MapView
+
+class MapViewTypeInfoFactory[K, V] extends TypeInfoFactory[MapView[K, V]] {
+
+ override def createTypeInfo(
+ t: Type,
+ genericParameters: util.Map[String, TypeInformation[_]]): TypeInformation[MapView[K, V]] = {
+
+ var keyType = genericParameters.get("K")
+ var valueType = genericParameters.get("V")
+
+ if (keyType == null) {
+ // we might be able to get the keyType later from the MapView constructor
+ keyType = new GenericTypeInfo(classOf[Any])
+ }
+
+ if (valueType == null) {
+ // we might be able to get the valueType later from the MapView constructor
+ valueType = new GenericTypeInfo(classOf[Any])
+ }
+
+ new MapViewTypeInfo[K, V](
+ keyType.asInstanceOf[TypeInformation[K]],
+ valueType.asInstanceOf[TypeInformation[V]])
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateListView.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateListView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateListView.scala
new file mode 100644
index 0000000..70756ca
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateListView.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.dataview
+
+import java.util
+import java.lang.{Iterable => JIterable}
+
+import org.apache.flink.api.common.state._
+import org.apache.flink.table.api.dataview.ListView
+
+/**
+ * A [[ListView]] implementation that uses the state backend (delegates to a [[ListState]]).
+ *
+ * @param state list state
+ * @tparam T element type
+ */
+class StateListView[T](state: ListState[T]) extends ListView[T] {
+
+ override def get: JIterable[T] = state.get()
+
+ override def add(value: T): Unit = state.add(value)
+
+ override def addAll(list: util.List[T]): Unit = {
+ val iterator = list.iterator()
+ while (iterator.hasNext) {
+ state.add(iterator.next())
+ }
+ }
+
+ override def clear(): Unit = state.clear()
+}
+
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateMapView.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateMapView.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateMapView.scala
new file mode 100644
index 0000000..22f5f0b
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/dataview/StateMapView.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.dataview
+
+import java.util
+import java.lang.{Iterable => JIterable}
+
+import org.apache.flink.api.common.state.MapState
+import org.apache.flink.table.api.dataview.MapView
+
+/**
+ * A [[MapView]] implementation that uses the state backend (delegates to a [[MapState]]).
+ *
+ * @param state map state
+ * @tparam K key type
+ * @tparam V value type
+ */
+class StateMapView[K, V](state: MapState[K, V]) extends MapView[K, V] {
+
+ override def get(key: K): V = state.get(key)
+
+ override def put(key: K, value: V): Unit = state.put(key, value)
+
+ override def putAll(map: util.Map[K, V]): Unit = state.putAll(map)
+
+ override def remove(key: K): Unit = state.remove(key)
+
+ override def contains(key: K): Boolean = state.contains(key)
+
+ override def entries: JIterable[util.Map.Entry[K, V]] = state.entries()
+
+ override def keys: JIterable[K] = state.keys()
+
+ override def values: JIterable[V] = state.values()
+
+ override def iterator: util.Iterator[util.Map.Entry[K, V]] = state.iterator()
+
+ override def clear(): Unit = state.clear()
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index b44c28e..f53bcde 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -19,6 +19,7 @@
package org.apache.flink.table.functions.utils
+import java.util
import java.lang.{Integer => JInt, Long => JLong}
import java.lang.reflect.{Method, Modifier}
import java.sql.{Date, Time, Timestamp}
@@ -29,7 +30,10 @@ import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.{SqlCallBinding, SqlFunction}
import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, TypeExtractor}
+import org.apache.flink.table.api.dataview._
+import org.apache.flink.table.dataview._
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.expressions._
@@ -38,6 +42,8 @@ import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, Tabl
import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.util.InstantiationUtil
+import scala.collection.mutable
+
object UserDefinedFunctionUtils {
/**
@@ -307,6 +313,111 @@ object UserDefinedFunctionUtils {
// ----------------------------------------------------------------------------------------------
/**
+ * Extracts the data view specs of an accumulator and removes state-backed views from its type.
+ *
+ * @param index index of the aggregate function, used as prefix of the data view ids
+ * @param aggFun aggregate function whose freshly created accumulator is inspected
+ * @param accType accumulator type information; data views are only supported in POJO types
+ * @param isStateBackedDataViews whether the data views are backed by a state backend
+ * @return the resulting accumulator type information and, for POJO accumulators, the specs
+ * (id, field, and type information) of their non-null data view fields
+ */
+ def removeStateViewFieldsFromAccTypeInfo(
+ index: Int,
+ aggFun: AggregateFunction[_, _],
+ accType: TypeInformation[_],
+ isStateBackedDataViews: Boolean)
+ : (TypeInformation[_], Option[Seq[DataViewSpec[_]]]) = {
+
+ /** Recursively checks if composite type includes a data view type. */
+ def includesDataView(ct: CompositeType[_]): Boolean = {
+ (0 until ct.getArity).exists(i =>
+ ct.getTypeAt(i) match {
+ case nestedCT: CompositeType[_] => includesDataView(nestedCT)
+ case t: TypeInformation[_] if t.getTypeClass == classOf[ListView[_]] => true
+ case t: TypeInformation[_] if t.getTypeClass == classOf[MapView[_, _]] => true
+ case _ => false
+ }
+ )
+ }
+
+ val acc = aggFun.createAccumulator()
+ accType match {
+ case pojoType: PojoTypeInfo[_] if pojoType.getArity > 0 =>
+ val arity = pojoType.getArity
+ val newPojoFields = new util.ArrayList[PojoField]()
+ val accumulatorSpecs = new mutable.ArrayBuffer[DataViewSpec[_]]
+ for (i <- 0 until arity) {
+ val pojoField = pojoType.getPojoFieldAt(i)
+ val field = pojoField.getField
+ val fieldName = field.getName
+ field.setAccessible(true)
+
+ pojoField.getTypeInformation match {
+ case ct: CompositeType[_] if includesDataView(ct) =>
+ throw new TableException(
+ "MapView and ListView only supported at first level of accumulators of Pojo type.")
+ case map: MapViewTypeInfo[_, _] =>
+ val mapView = field.get(acc).asInstanceOf[MapView[_, _]]
+ if (mapView != null) {
+ val keyTypeInfo = mapView.keyTypeInfo
+ val valueTypeInfo = mapView.valueTypeInfo
+ val newTypeInfo = if (keyTypeInfo != null && valueTypeInfo != null) {
+ new MapViewTypeInfo(keyTypeInfo, valueTypeInfo)
+ } else {
+ map
+ }
+
+ // create map view specs with unique id (used as state name)
+ val spec = MapViewSpec(
+ "agg" + index + "$" + fieldName,
+ field,
+ newTypeInfo)
+
+ accumulatorSpecs += spec
+ if (!isStateBackedDataViews) {
+ // add data view field if it is not backed by a state backend.
+ // data view fields which are backed by state backend are not serialized.
+ newPojoFields.add(new PojoField(field, newTypeInfo))
+ }
+ }
+
+ case list: ListViewTypeInfo[_] =>
+ val listView = field.get(acc).asInstanceOf[ListView[_]]
+ if (listView != null) {
+ val elementTypeInfo = listView.elementTypeInfo
+ val newTypeInfo = if (elementTypeInfo != null) {
+ new ListViewTypeInfo(elementTypeInfo)
+ } else {
+ list
+ }
+
+ // create list view specs with unique id (used as state name)
+ val spec = ListViewSpec(
+ "agg" + index + "$" + fieldName,
+ field,
+ newTypeInfo)
+
+ accumulatorSpecs += spec
+ if (!isStateBackedDataViews) {
+ // add data view field if it is not backed by a state backend.
+ // data view fields which are backed by state backend are not serialized.
+ newPojoFields.add(new PojoField(field, newTypeInfo))
+ }
+ }
+
+ case _ => newPojoFields.add(pojoField)
+ }
+ }
+ (new PojoTypeInfo(accType.getTypeClass, newPojoFields), Some(accumulatorSpecs))
+ case ct: CompositeType[_] if includesDataView(ct) =>
+ throw new TableException(
+ "MapView and ListView only supported in accumulators of POJO type.")
+ case _ => (accType, None)
+ }
+ }
+
+ /**
* Tries to infer the TypeInformation of an AggregateFunction's return type.
*
* @param aggregateFunction The AggregateFunction for which the return type is inferred.
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 6304dc4..58940d0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -32,6 +32,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
+import org.apache.flink.table.api.dataview.DataViewSpec
import org.apache.flink.table.api.{StreamQueryConfig, TableException}
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
@@ -82,11 +83,12 @@ object AggregateUtil {
isRowsClause: Boolean)
: ProcessFunction[CRow, CRow] = {
- val (aggFields, aggregates, accTypes) =
+ val (aggFields, aggregates, accTypes, accSpecs) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetraction = false)
+ needRetraction = false,
+ isStateBackedDataViews = true)
val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*)
@@ -107,7 +109,8 @@ object AggregateUtil {
outputArity,
needRetract = false,
needMerge = false,
- needReset = false
+ needReset = false,
+ accConfig = Some(accSpecs)
)
if (rowTimeIdx.isDefined) {
@@ -160,11 +163,12 @@ object AggregateUtil {
generateRetraction: Boolean,
consumeRetraction: Boolean): ProcessFunction[CRow, CRow] = {
- val (aggFields, aggregates, accTypes) =
+ val (aggFields, aggregates, accTypes, accSpecs) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputRowType,
- consumeRetraction)
+ consumeRetraction,
+ isStateBackedDataViews = true)
val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
@@ -185,7 +189,8 @@ object AggregateUtil {
outputArity,
consumeRetraction,
needMerge = false,
- needReset = false
+ needReset = false,
+ accConfig = Some(accSpecs)
)
new GroupAggProcessFunction(
@@ -223,11 +228,12 @@ object AggregateUtil {
: ProcessFunction[CRow, CRow] = {
val needRetract = true
- val (aggFields, aggregates, accTypes) =
+ val (aggFields, aggregates, accTypes, accSpecs) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
- needRetract)
+ needRetract,
+ isStateBackedDataViews = true)
val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*)
val inputRowType = CRowTypeInfo(inputTypeInfo)
@@ -249,7 +255,8 @@ object AggregateUtil {
outputArity,
needRetract,
needMerge = false,
- needReset = true
+ needReset = false,
+ accConfig = Some(accSpecs)
)
if (rowTimeIdx.isDefined) {
@@ -323,7 +330,7 @@ object AggregateUtil {
: MapFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, accTypes) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
needRetract)
@@ -380,7 +387,8 @@ object AggregateUtil {
outputArity,
needRetract,
needMerge = false,
- needReset = true
+ needReset = true,
+ None
)
new DataSetWindowAggMapFunction(
@@ -428,7 +436,7 @@ object AggregateUtil {
: RichGroupReduceFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, accTypes) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
physicalInputRowType,
needRetract)
@@ -458,7 +466,8 @@ object AggregateUtil {
keysAndAggregatesArity + 1,
needRetract,
needMerge = true,
- needReset = true
+ needReset = true,
+ None
)
new DataSetSlideTimeWindowAggReduceGroupFunction(
genFunction,
@@ -541,7 +550,7 @@ object AggregateUtil {
: RichGroupReduceFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, _) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, _, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
physicalInputRowType,
needRetract)
@@ -561,7 +570,8 @@ object AggregateUtil {
outputType.getFieldCount,
needRetract,
needMerge = true,
- needReset = true
+ needReset = true,
+ None
)
val genFinalAggFunction = generator.generateAggregations(
@@ -577,7 +587,8 @@ object AggregateUtil {
outputType.getFieldCount,
needRetract,
needMerge = true,
- needReset = true
+ needReset = true,
+ None
)
val keysAndAggregatesArity = groupings.length + namedAggregates.length
@@ -686,7 +697,7 @@ object AggregateUtil {
groupings: Array[Int]): MapPartitionFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, accTypes) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
physicalInputRowType,
needRetract)
@@ -718,7 +729,8 @@ object AggregateUtil {
groupings.length + aggregates.length + 2,
needRetract,
needMerge = true,
- needReset = true
+ needReset = true,
+ None
)
new DataSetSessionWindowAggregatePreProcessor(
@@ -759,7 +771,7 @@ object AggregateUtil {
: GroupCombineFunction[Row, Row] = {
val needRetract = false
- val (aggFieldIndexes, aggregates, accTypes) = transformToAggregateFunctions(
+ val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
physicalInputRowType,
needRetract)
@@ -792,7 +804,8 @@ object AggregateUtil {
groupings.length + aggregates.length + 2,
needRetract,
needMerge = true,
- needReset = true
+ needReset = true,
+ None
)
new DataSetSessionWindowAggregatePreProcessor(
@@ -825,7 +838,7 @@ object AggregateUtil {
RichGroupReduceFunction[Row, Row]) = {
val needRetract = false
- val (aggInFields, aggregates, accTypes) = transformToAggregateFunctions(
+ val (aggInFields, aggregates, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
needRetract)
@@ -872,7 +885,8 @@ object AggregateUtil {
groupings.length + aggregates.length,
needRetract,
needMerge = false,
- needReset = true
+ needReset = true,
+ None
)
// compute mapping of forwarded grouping keys
@@ -898,7 +912,8 @@ object AggregateUtil {
outputType.getFieldCount,
needRetract,
needMerge = true,
- needReset = true
+ needReset = true,
+ None
)
(
@@ -921,7 +936,8 @@ object AggregateUtil {
outputType.getFieldCount,
needRetract,
needMerge = false,
- needReset = true
+ needReset = true,
+ None
)
(
@@ -996,7 +1012,7 @@ object AggregateUtil {
: (DataStreamAggFunction[CRow, Row, Row], RowTypeInfo, RowTypeInfo) = {
val needRetract = false
- val (aggFields, aggregates, accTypes) =
+ val (aggFields, aggregates, accTypes, _) =
transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
@@ -1018,7 +1034,8 @@ object AggregateUtil {
outputArity,
needRetract,
needMerge,
- needReset = false
+ needReset = false,
+ None
)
val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType))
@@ -1159,8 +1176,12 @@ object AggregateUtil {
private def transformToAggregateFunctions(
aggregateCalls: Seq[AggregateCall],
inputType: RelDataType,
- needRetraction: Boolean)
- : (Array[Array[Int]], Array[TableAggregateFunction[_, _]], Array[TypeInformation[_]]) = {
+ needRetraction: Boolean,
+ isStateBackedDataViews: Boolean = false)
+ : (Array[Array[Int]],
+ Array[TableAggregateFunction[_, _]],
+ Array[TypeInformation[_]],
+ Array[Seq[DataViewSpec[_]]]) = {
// store the aggregate fields of each aggregate function, by the same order of aggregates.
val aggFieldIndexes = new Array[Array[Int]](aggregateCalls.size)
@@ -1398,14 +1419,28 @@ object AggregateUtil {
}
}
+ val accSpecs = new Array[Seq[DataViewSpec[_]]](aggregateCalls.size)
+
// create accumulator type information for every aggregate function
aggregates.zipWithIndex.foreach { case (agg, index) =>
- if (null == accTypes(index)) {
+ if (accTypes(index) != null) {
+ val (accType, specs) = removeStateViewFieldsFromAccTypeInfo(index,
+ agg,
+ accTypes(index),
+ isStateBackedDataViews)
+ if (specs.isDefined) {
+ accSpecs(index) = specs.get
+ accTypes(index) = accType
+ } else {
+ accSpecs(index) = Seq()
+ }
+ } else {
+ accSpecs(index) = Seq()
accTypes(index) = getAccumulatorTypeOfAggregateFunction(agg)
}
}
- (aggFieldIndexes, aggregates, accTypes)
+ (aggFieldIndexes, aggregates, accTypes, accSpecs)
}
private def createRowTypeForKeysAndAggregates(
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
index 5f48e09..7b20114 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
@@ -18,7 +18,7 @@
package org.apache.flink.table.runtime.aggregate
-import org.apache.flink.api.common.functions.Function
+import org.apache.flink.api.common.functions.{Function, RuntimeContext}
import org.apache.flink.types.Row
/**
@@ -27,6 +27,14 @@ import org.apache.flink.types.Row
abstract class GeneratedAggregations extends Function {
/**
+ * Setup method for [[org.apache.flink.table.functions.AggregateFunction]].
+ * It can be used for initialization work. The method is abstract and has no default behavior.
+ *
+ * @param ctx The runtime context.
+ */
+ def open(ctx: RuntimeContext)
+
+ /**
* Sets the results of the aggregations (partial or final) to the output row.
* Final results are computed with the aggregation function.
* Partial results are the accumulators themselves.
@@ -100,6 +108,17 @@ abstract class GeneratedAggregations extends Function {
* aggregated results
*/
def resetAccumulator(accumulators: Row)
+
+ /**
+ * Cleanup for the accumulators.
+ */
+ def cleanup()
+
+ /**
+ * Tear-down method for [[org.apache.flink.table.functions.AggregateFunction]].
+ * It can be used for clean up work. The method is abstract and has no default behavior.
+ */
+ def close()
}
class SingleElementIterable[T] extends java.lang.Iterable[T] {
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
index 690a7c3..a476987 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala
@@ -23,9 +23,8 @@ import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.types.Row
import org.apache.flink.util.Collector
-import org.apache.flink.api.common.state.ValueStateDescriptor
+import org.apache.flink.api.common.state.{StateDescriptor, ValueState, ValueStateDescriptor}
import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.api.common.state.ValueState
import org.apache.flink.table.api.{StreamQueryConfig, Types}
import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
import org.slf4j.{Logger, LoggerFactory}
@@ -65,6 +64,7 @@ class GroupAggProcessFunction(
genAggregations.code)
LOG.debug("Instantiating AggregateHelper.")
function = clazz.newInstance()
+ function.open(getRuntimeContext)
newRow = new CRow(function.createOutputRow(), true)
prevRow = new CRow(function.createOutputRow(), false)
@@ -162,7 +162,11 @@ class GroupAggProcessFunction(
if (needToCleanupState(timestamp)) {
cleanupState(state, cntState)
+ function.cleanup()
}
}
+ override def close(): Unit = {
+ function.close() // forward tear-down to the generated aggregation helper
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/1fc0b641/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
index ab3dc1d..5c28519 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
@@ -22,10 +22,7 @@ import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.types.Row
import org.apache.flink.util.Collector
-import org.apache.flink.api.common.state.ValueState
-import org.apache.flink.api.common.state.ValueStateDescriptor
-import org.apache.flink.api.common.state.MapState
-import org.apache.flink.api.common.state.MapStateDescriptor
+import org.apache.flink.api.common.state._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ListTypeInfo
import java.util.{ArrayList, List => JList}
@@ -71,6 +68,8 @@ class ProcTimeBoundedRangeOver(
genAggregations.code)
LOG.debug("Instantiating AggregateHelper.")
function = clazz.newInstance()
+ function.open(getRuntimeContext)
+
output = new CRow(function.createOutputRow(), true)
// We keep the elements received in a MapState indexed based on their ingestion time
@@ -121,6 +120,7 @@ class ProcTimeBoundedRangeOver(
if (needToCleanupState(timestamp)) {
// clean up and return
cleanupState(rowMapState, accumulatorState)
+ function.cleanup()
return
}
@@ -201,4 +201,7 @@ class ProcTimeBoundedRangeOver(
accumulatorState.update(accumulators)
}
+ override def close(): Unit = {
+ function.close() // forward tear-down to the generated aggregation helper
+ }
}