Posted to commits@flink.apache.org by fh...@apache.org on 2017/09/05 12:22:42 UTC

[3/5] flink git commit: [FLINK-6751] [docs] Add documentation for user-defined AggregateFunction.

[FLINK-6751] [docs] Add documentation for user-defined AggregateFunction.

This closes #4546.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/09344aa2
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/09344aa2
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/09344aa2

Branch: refs/heads/master
Commit: 09344aa2dc36b9b3ea4c5b7573ff532e26f9b0dd
Parents: ba03b78
Author: shaoxuan-wang <sh...@apache.org>
Authored: Wed Aug 16 00:00:12 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Sep 5 13:53:59 2017 +0200

----------------------------------------------------------------------
 docs/dev/table/udfs.md                          | 408 ++++++++++++++++++-
 docs/fig/udagg-mechanism.png                    | Bin 0 -> 201262 bytes
 .../table/functions/AggregateFunction.scala     |   2 +-
 3 files changed, 403 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/09344aa2/docs/dev/table/udfs.md
----------------------------------------------------------------------
diff --git a/docs/dev/table/udfs.md b/docs/dev/table/udfs.md
index 55f58b6..6c9bc1a 100644
--- a/docs/dev/table/udfs.md
+++ b/docs/dev/table/udfs.md
@@ -24,15 +24,18 @@ under the License.
 
 User-defined functions are an important feature, because they significantly extend the expressiveness of queries.
 
-**TODO**
-
 * This will be replaced by the TOC
 {:toc}
 
 Register User-Defined Functions
 -------------------------------
+In most cases, a user-defined function must be registered before it can be used in a query. It is not necessary to register functions for the Scala Table API.
+
+Functions are registered with the `TableEnvironment` by calling a `registerFunction()` method. When a user-defined function is registered, it is inserted into the function catalog of the `TableEnvironment` so that the Table API or SQL parser can recognize and properly translate it.
+
+Please find detailed examples of how to register and how to call each type of user-defined function
+(`ScalarFunction`, `TableFunction`, and `AggregateFunction`) in the following sub-sections.
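The catalog idea described above can be made concrete with a minimal, Flink-free sketch. `ToyFunctionCatalog` is a hypothetical class, not part of Flink's API, and the case-insensitive lookup (names normalized to upper case) is an assumption of the sketch, not a statement about how Flink stores names.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

// Hypothetical, simplified stand-in for a function catalog; NOT Flink's API.
final class ToyFunctionCatalog {
    private final Map<String, Object> functions = new HashMap<>();

    // Mirrors the idea behind registerFunction(name, function): store the
    // function under a normalized name (upper case is an assumption here).
    void registerFunction(String name, Object function) {
        functions.put(name.toUpperCase(Locale.ROOT), function);
    }

    // What a parser would do when it meets a call such as HASHCODE(...):
    // resolve the identifier against the registered functions.
    Optional<Object> lookup(String nameInQuery) {
        return Optional.ofNullable(functions.get(nameInQuery.toUpperCase(Locale.ROOT)));
    }
}
```

An unregistered name resolves to an empty `Optional`, which corresponds to the parser failing to recognize the function.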
 
-**TODO**
 
 {% top %}
 
@@ -97,8 +100,6 @@ tableEnv.sql("SELECT string, HASHCODE(string) FROM MyTable");
 
 By default the result type of an evaluation method is determined by Flink's type extraction facilities. This is sufficient for basic types or simple POJOs but might be wrong for more complex, custom, or composite types. In these cases `TypeInformation` of the result type can be manually defined by overriding `ScalarFunction#getResultType()`.
 
-Internally, the Table API and SQL code generation works with primitive values as much as possible. If a user-defined scalar function should not introduce much overhead through object creation/casting during runtime, it is recommended to declare parameters and result types as primitive types instead of their boxed classes. `Types.DATE` and `Types.TIME` can also be represented as `int`. `Types.TIMESTAMP` can be represented as `long`.
-
 The following example shows an advanced example which takes the internal timestamp representation and also returns the internal timestamp representation as a long value. By overriding `ScalarFunction#getResultType()` we define that the returned long value should be interpreted as a `Types.TIMESTAMP` by the code generation.
 
 <div class="codetabs" markdown="1">
@@ -264,10 +265,405 @@ class CustomTypeSplit extends TableFunction[Row] {
 
 {% top %}
 
+
 Aggregation Functions
 ---------------------
 
-**TODO**
+User-Defined Aggregate Functions (UDAGGs) aggregate a table (one or more rows with one or more attributes) to a scalar value.
+
+<center>
+<img alt="UDAGG mechanism" src="{{ site.baseurl }}/fig/udagg-mechanism.png" width="80%">
+</center>
+
+The above figure shows an example of an aggregation. Assume you have a table that contains data about beverages. The table consists of three columns (`id`, `name`, and `price`) and 5 rows. Imagine you need to find the highest price of all beverages in the table, i.e., perform a `max()` aggregation. You would need to check each of the 5 rows, and the result would be a single numeric value.
+
+User-defined aggregation functions are implemented by extending the `AggregateFunction` class. An `AggregateFunction` works as follows. First, it needs an `accumulator`, which is the data structure that holds the intermediate result of the aggregation. An empty accumulator is created by calling the `createAccumulator()` method of the `AggregateFunction`. Subsequently, the `accumulate()` method of the function is called for each input row to update the accumulator. Once all rows have been processed, the `getValue()` method of the function is called to compute and return the final result. 
+
+**The following methods are mandatory for each `AggregateFunction`:**
+
+- `createAccumulator()`
+- `accumulate()` 
+- `getValue()`
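The lifecycle described above, applied to the `max()` example from the figure, can be sketched in plain Java. The sketch does not extend Flink's `AggregateFunction`; `MaxPriceSketch` and `MaxAccum` are hypothetical names, and the prices in the usage example are made up rather than taken from the figure.

```java
// Plain-Java sketch of the three mandatory methods; it does NOT extend Flink's
// AggregateFunction. MaxPriceSketch and MaxAccum are hypothetical names.
final class MaxPriceSketch {

    // The accumulator holds the intermediate result: the highest price seen so far.
    static final class MaxAccum {
        boolean empty = true;
        int max = 0;
    }

    // createAccumulator(): returns an empty accumulator.
    MaxAccum createAccumulator() {
        return new MaxAccum();
    }

    // accumulate(): called once per input row to update the accumulator.
    void accumulate(MaxAccum acc, int price) {
        if (acc.empty || price > acc.max) {
            acc.max = price;
            acc.empty = false;
        }
    }

    // getValue(): computes the final result from the accumulator (null if no rows).
    Integer getValue(MaxAccum acc) {
        return acc.empty ? null : acc.max;
    }

    // Drives the lifecycle: create once, accumulate per row, then getValue.
    static Integer maxOf(int... prices) {
        MaxPriceSketch fn = new MaxPriceSketch();
        MaxAccum acc = fn.createAccumulator();
        for (int price : prices) {
            fn.accumulate(acc, price);
        }
        return fn.getValue(acc);
    }
}
```

For example, `MaxPriceSketch.maxOf(3, 4, 2, 5, 1)` returns `5`, and calling it without prices returns `null`, because an accumulator that never saw a row has no maximum.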
+
+Flink's type extraction facilities can fail to identify complex data types, e.g., types that are neither basic types nor simple POJOs. Therefore, similar to `ScalarFunction` and `TableFunction`, `AggregateFunction` provides methods to specify the `TypeInformation` of the result type (through `AggregateFunction#getResultType()`) and of the accumulator type (through `AggregateFunction#getAccumulatorType()`).
+ 
+Besides the above methods, there are a few contracted methods that can be optionally implemented. While some of these methods allow the system to execute queries more efficiently, others are mandatory for certain use cases. For instance, the `merge()` method is mandatory if the aggregation function should be applied in the context of a session group window (the accumulators of two session windows need to be merged when a row is observed that "connects" them).
+
+**The following methods of `AggregateFunction` are required depending on the use case:**
+
+- `retract()` is required for aggregations on bounded `OVER` windows.
+- `merge()` is required for many batch aggregations and session window aggregations.
+- `resetAccumulator()` is required for many batch aggregations.
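The following plain-Java sketch (hypothetical names, no Flink dependency) illustrates what `retract()` and `merge()` contribute for a simple average: a bounded `OVER` window can update its accumulator incrementally by retracting the row that leaves the window, and two partial accumulators, e.g., of two session windows that a new row connects, can be folded into one. It does not imply that Flink's runtime drives these methods exactly like this loop.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Plain-Java sketch (not Flink code); AvgAccum and AvgSketch are hypothetical names.
final class AvgAccum {
    long sum = 0;
    int count = 0;
}

final class AvgSketch {
    static void accumulate(AvgAccum acc, long v) { acc.sum += v; acc.count += 1; }

    // retract(): undoes a previous accumulate() of the same value.
    static void retract(AvgAccum acc, long v) { acc.sum -= v; acc.count -= 1; }

    // merge(): folds other accumulators into acc without discarding acc's content.
    static void merge(AvgAccum acc, Iterable<AvgAccum> others) {
        for (AvgAccum o : others) {
            acc.sum += o.sum;
            acc.count += o.count;
        }
    }

    static Double getValue(AvgAccum acc) {
        return acc.count == 0 ? null : (double) acc.sum / acc.count;
    }

    // Bounded OVER window covering the current row and the (windowSize - 1)
    // preceding rows: add the new row, then retract the row that left the window.
    static double[] slidingAvg(long[] values, int windowSize) {
        AvgAccum acc = new AvgAccum();
        Deque<Long> window = new ArrayDeque<>();
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            accumulate(acc, values[i]);
            window.addLast(values[i]);
            if (window.size() > windowSize) {
                retract(acc, window.removeFirst());
            }
            out[i] = getValue(acc);
        }
        return out;
    }
}
```

With the input `{1, 2, 3, 4, 5}` and a window of 3 rows, `slidingAvg` yields `1.0, 1.5, 2.0, 3.0, 4.0` without ever recomputing a window from scratch.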
+
+All methods of `AggregateFunction` must be declared `public`, must not be `static`, and must be named exactly as listed above. The methods `createAccumulator`, `getValue`, `getResultType`, and `getAccumulatorType` are defined in the `AggregateFunction` abstract class, while the others are contracted methods. In order to define an aggregate function, one has to extend the base class `org.apache.flink.table.functions.AggregateFunction` and implement one (or more) `accumulate` methods.
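Because contracted methods such as `accumulate()` are not declared in the base class, a framework can presumably only find them by name and signature, for example via reflection. The sketch below is a hypothetical check in that spirit, not Flink's actual validation code; it shows why a `static`, non-`public`, or misspelled method would not be picked up.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

// Hypothetical check illustrating the naming and visibility rules; NOT Flink's code.
final class ContractCheck {
    static boolean hasValidAccumulate(Class<?> clazz) {
        // getMethods() returns public methods only, so private ones are never seen.
        for (Method m : clazz.getMethods()) {
            if (m.getName().equals("accumulate") && !Modifier.isStatic(m.getModifiers())) {
                return true;
            }
        }
        return false;
    }
}

// Hypothetical example classes; a long[] of size 1 plays the role of the accumulator.
class GoodAgg { public void accumulate(long[] acc, long v) { acc[0] += v; } }
class StaticAgg { public static void accumulate(long[] acc, long v) { acc[0] += v; } }
class PrivateAgg { private void accumulate(long[] acc, long v) { acc[0] += v; } }
class MisnamedAgg { public void accumulateValue(long[] acc, long v) { acc[0] += v; } }
```

Only `GoodAgg` passes the check; the other three classes each violate one of the rules listed above.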
+
+Detailed documentation for all methods of `AggregateFunction` is given below. 
+
+<div class="codetabs" markdown="1">
+<div data-lang="java" markdown="1">
+{% highlight java %}
+/**
+  * Base class for aggregation functions. 
+  *
+  * @param <T>   the type of the aggregation result
+  * @param <ACC> the type of the aggregation accumulator. The accumulator is used to keep the
+  *             aggregated values which are needed to compute an aggregation result.
+  *             AggregateFunction represents its state using accumulator, thereby the state of the
+  *             AggregateFunction must be put into the accumulator.
+  */
+public abstract class AggregateFunction<T, ACC> extends UserDefinedFunction {
+
+  /**
+    * Creates and initializes the accumulator for this AggregateFunction.
+    *
+    * @return the accumulator with the initial value
+    */
+  public abstract ACC createAccumulator(); // MANDATORY
+
+  /**
+    * Processes the input values and updates the provided accumulator instance. The method
+    * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
+    * requires at least one accumulate() method.
+    *
+    * @param accumulator           the accumulator which contains the current aggregated results
+    * @param [user defined inputs] the input value (usually obtained from newly arrived data).
+    */
+  public void accumulate(ACC accumulator, [user defined inputs]); // MANDATORY
+
+  /**
+    * Retracts the input values from the accumulator instance. The current design assumes the
+    * inputs are the values that have been previously accumulated. The method retract can be
+    * overloaded with different custom types and arguments. This method must be implemented for
+    * bounded OVER aggregates on streams.
+    *
+    * @param accumulator           the accumulator which contains the current aggregated results
+    * @param [user defined inputs] the input value (usually obtained from newly arrived data).
+    */
+  public void retract(ACC accumulator, [user defined inputs]); // OPTIONAL
+
+  /**
+    * Merges a group of accumulator instances into one accumulator instance. This method must be
+    * implemented for session window aggregates on streams and for grouped aggregates on batch data.
+    *
+    * @param accumulator  the accumulator which will keep the merged aggregate results. Note that
+    *                     the accumulator may already contain previously aggregated results.
+    *                     Therefore, the user should not replace or reset this instance in the
+    *                     custom merge method.
+    * @param its          an {@link java.lang.Iterable} over the group of accumulators that will be
+    *                     merged.
+    */
+  public void merge(ACC accumulator, java.lang.Iterable<ACC> its); // OPTIONAL
+
+  /**
+    * Called every time an aggregation result should be materialized.
+    * The returned value could be either an early and incomplete result
+    * (periodically emitted as data arrive) or the final result of the
+    * aggregation.
+    *
+    * @param accumulator the accumulator which contains the current
+    *                    aggregated results
+    * @return the aggregation result
+    */
+  public abstract T getValue(ACC accumulator); // MANDATORY
+
+  /**
+    * Resets the accumulator for this AggregateFunction. This method must be implemented for
+    * grouped aggregates on batch data.
+    *
+    * @param accumulator  the accumulator which needs to be reset
+    */
+  public void resetAccumulator(ACC accumulator); // OPTIONAL
+
+  /**
+    * Returns true if this AggregateFunction can only be applied in an OVER window.
+    *
+    * @return true if the AggregateFunction requires an OVER window, false otherwise.
+    */
+  public boolean requiresOver() { // PRE-DEFINED
+    return false;
+  }
+
+  /**
+    * Returns the TypeInformation of the AggregateFunction's result.
+    *
+    * @return The TypeInformation of the AggregateFunction's result or null if the result type
+    *         should be automatically inferred.
+    */
+  public TypeInformation<T> getResultType() { // PRE-DEFINED
+    return null;
+  }
+
+  /**
+    * Returns the TypeInformation of the AggregateFunction's accumulator.
+    *
+    * @return The TypeInformation of the AggregateFunction's accumulator or null if the
+    *         accumulator type should be automatically inferred.
+    */
+  public TypeInformation<ACC> getAccumulatorType() { // PRE-DEFINED
+    return null;
+  }
+}
+{% endhighlight %}
+</div>
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+/**
+  * Base class for aggregation functions. 
+  *
+  * @tparam T   the type of the aggregation result
+  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
+  *             aggregated values which are needed to compute an aggregation result.
+  *             AggregateFunction represents its state using accumulator, thereby the state of the
+  *             AggregateFunction must be put into the accumulator.
+  */
+abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
+  /**
+    * Creates and initializes the accumulator for this [[AggregateFunction]].
+    *
+    * @return the accumulator with the initial value
+    */
+  def createAccumulator(): ACC // MANDATORY
+
+  /**
+    * Processes the input values and updates the provided accumulator instance. The method
+    * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
+    * requires at least one accumulate() method.
+    *
+    * @param accumulator           the accumulator which contains the current aggregated results
+    * @param [user defined inputs] the input value (usually obtained from newly arrived data).
+    */
+  def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY
+
+  /**
+    * Retracts the input values from the accumulator instance. The current design assumes the
+    * inputs are the values that have been previously accumulated. The method retract can be
+    * overloaded with different custom types and arguments. This method must be implemented for
+    * bounded OVER aggregates on streams.
+    *
+    * @param accumulator           the accumulator which contains the current aggregated results
+    * @param [user defined inputs] the input value (usually obtained from newly arrived data).
+    */
+  def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL
+
+  /**
+    * Merges a group of accumulator instances into one accumulator instance. This method must be
+    * implemented for session window aggregates on streams and for grouped aggregates on batch data.
+    *
+    * @param accumulator  the accumulator which will keep the merged aggregate results. Note that
+    *                     the accumulator may already contain previously aggregated results.
+    *                     Therefore, the user should not replace or reset this instance in the
+    *                     custom merge method.
+    * @param its          an [[java.lang.Iterable]] over the group of accumulators that will be
+    *                     merged.
+    */
+  def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL
+  
+  /**
+    * Called every time an aggregation result should be materialized.
+    * The returned value could be either an early and incomplete result
+    * (periodically emitted as data arrive) or the final result of the
+    * aggregation.
+    *
+    * @param accumulator the accumulator which contains the current
+    *                    aggregated results
+    * @return the aggregation result
+    */
+  def getValue(accumulator: ACC): T // MANDATORY
+
+  /**
+    * Resets the accumulator for this [[AggregateFunction]]. This method must be implemented for
+    * grouped aggregates on batch data.
+    *
+    * @param accumulator  the accumulator which needs to be reset
+    */
+  def resetAccumulator(accumulator: ACC): Unit // OPTIONAL
+
+  /**
+    * Returns true if this AggregateFunction can only be applied in an OVER window.
+    *
+    * @return true if the AggregateFunction requires an OVER window, false otherwise.
+    */
+  def requiresOver: Boolean = false // PRE-DEFINED
+
+  /**
+    * Returns the TypeInformation of the AggregateFunction's result.
+    *
+    * @return The TypeInformation of the AggregateFunction's result or null if the result type
+    *         should be automatically inferred.
+    */
+  def getResultType: TypeInformation[T] = null // PRE-DEFINED
+
+  /**
+    * Returns the TypeInformation of the AggregateFunction's accumulator.
+    *
+    * @return The TypeInformation of the AggregateFunction's accumulator or null if the
+    *         accumulator type should be automatically inferred.
+    */
+  def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
+}
+{% endhighlight %}
+</div>
+</div>
+
+
+The following example shows how to
+
+- define an `AggregateFunction` that calculates the weighted average on a given column, 
+- register the function in the `TableEnvironment`, and 
+- use the function in a query.  
+
+To calculate a weighted average value, the accumulator needs to store the weighted sum and the total weight (kept in the field `count`) of all the data that has been accumulated. In our example, we define a class `WeightedAvgAccum` to be the accumulator. Accumulators are automatically backed up by Flink's checkpointing mechanism and restored in case of a failure to ensure exactly-once semantics.
+
+The `accumulate()` method of our `WeightedAvg` `AggregateFunction` has three parameters. The first one is the `WeightedAvgAccum` accumulator; the other two are user-defined inputs: the input value `iValue` and the weight of the input `iWeight`. Although the `retract()`, `merge()`, and `resetAccumulator()` methods are not mandatory for most aggregation types, we provide them below as examples. Please note that we use Java boxed types (`java.lang.Long` and `java.lang.Integer`) and define the `getResultType()` and `getAccumulatorType()` methods in the Scala example because Flink's type extraction does not work very well for Scala types.
+
+<div class="codetabs" markdown="1">
+<div data-lang="java" markdown="1">
+{% highlight java %}
+import java.util.Iterator;
+
+import org.apache.flink.table.functions.AggregateFunction;
+
+/**
+ * Accumulator for WeightedAvg.
+ */
+public static class WeightedAvgAccum {
+    public long sum = 0;
+    public int count = 0;
+}
+
+/**
+ * Weighted Average user-defined aggregate function.
+ */
+public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {
+
+    @Override
+    public WeightedAvgAccum createAccumulator() {
+        return new WeightedAvgAccum();
+    }
+
+    @Override
+    public Long getValue(WeightedAvgAccum acc) {
+        if (acc.count == 0) {
+            return null;
+        } else {
+            // integer division: the weighted average is truncated to a long
+            return acc.sum / acc.count;
+        }
+    }
+
+    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
+        acc.sum += iValue * iWeight;
+        acc.count += iWeight;
+    }
+
+    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
+        acc.sum -= iValue * iWeight;
+        acc.count -= iWeight;
+    }
+    
+    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
+        Iterator<WeightedAvgAccum> iter = it.iterator();
+        while (iter.hasNext()) {
+            WeightedAvgAccum a = iter.next();
+            acc.count += a.count;
+            acc.sum += a.sum;
+        }
+    }
+    
+    public void resetAccumulator(WeightedAvgAccum acc) {
+        acc.count = 0;
+        acc.sum = 0L;
+    }
+}
+
+// register function
+StreamTableEnvironment tEnv = ...
+tEnv.registerFunction("wAvg", new WeightedAvg());
+
+// use function
+tEnv.sql("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");
+
+{% endhighlight %}
+</div>
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
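+import java.lang.{Long => JLong, Integer => JInteger}
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
+import org.apache.flink.table.functions.AggregateFunction
+
+/**
+ * Accumulator for WeightedAvg. The weighted sum is stored in field f0 and the
+ * total weight (count) in field f1 of the underlying Flink Tuple2.
+ */
+class WeightedAvgAccum extends JTuple2[JLong, JInteger] {
+  f0 = 0L
+  f1 = 0
+
+  // named accessors so that the function below can read and update sum and count
+  def sum: JLong = f0
+  def sum_=(value: JLong): Unit = { f0 = value }
+  def count: JInteger = f1
+  def count_=(value: JInteger): Unit = { f1 = value }
+}
+
+/**
+ * Weighted Average user-defined aggregate function.
+ */
+class WeightedAvg extends AggregateFunction[JLong, WeightedAvgAccum] {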
+
+  override def createAccumulator(): WeightedAvgAccum = {
+    new WeightedAvgAccum
+  }
+  
+  override def getValue(acc: WeightedAvgAccum): JLong = {
+    if (acc.count == 0) {
+        null
+    } else {
+        acc.sum / acc.count
+    }
+  }
+  
+  def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
+    acc.sum += iValue * iWeight
+    acc.count += iWeight
+  }
+
+  def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
+    acc.sum -= iValue * iWeight
+    acc.count -= iWeight
+  }
+    
+  def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {
+    val iter = it.iterator()
+    while (iter.hasNext) {
+      val a = iter.next()
+      acc.count += a.count
+      acc.sum += a.sum
+    }
+  }
+
+  def resetAccumulator(acc: WeightedAvgAccum): Unit = {
+    acc.count = 0
+    acc.sum = 0L
+  }
+
+  override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = {
+    new TupleTypeInfo(classOf[WeightedAvgAccum], 
+                      BasicTypeInfo.LONG_TYPE_INFO,
+                      BasicTypeInfo.INT_TYPE_INFO)
+  }
+
+  override def getResultType: TypeInformation[JLong] =
+    BasicTypeInfo.LONG_TYPE_INFO
+}
+
+// register function
+val tEnv: StreamTableEnvironment = ???
+tEnv.registerFunction("wAvg", new WeightedAvg())
+
+// use function
+tEnv.sql("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")
+
+{% endhighlight %}
+</div>
+</div>
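To check the arithmetic of the example by hand, the following Flink-free sketch copies the accumulator logic of `WeightedAvg` into a hypothetical helper class `WeightedAvgCheck`. It makes one property of the example visible: because `sum / count` divides a `long` by an `int`, the result is truncated.

```java
// Flink-free copy of the WeightedAvg arithmetic, for checking results by hand.
// WeightedAvgCheck is a hypothetical helper; it does not extend AggregateFunction.
final class WeightedAvgCheck {
    long sum = 0;   // weighted sum: sum of value * weight
    int count = 0;  // total weight: sum of weights

    void accumulate(long iValue, int iWeight) {
        sum += iValue * iWeight;
        count += iWeight;
    }

    void retract(long iValue, int iWeight) {
        sum -= iValue * iWeight;
        count -= iWeight;
    }

    Long getValue() {
        // long / int is integer division, so the result is truncated toward zero
        return count == 0 ? null : sum / count;
    }
}
```

For the rows `(points=10, level=1)` and `(points=20, level=2)`, the weighted sum is 50 and the total weight is 3, so `getValue()` returns `16` rather than 16.67. If a fractional result is needed, one option is to declare the result type as `Double` and divide in floating point.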
+
+
+{% top %}
+
+Best Practices for Implementing UDFs
+------------------------------------
+
+The Table API and SQL code generation internally tries to work with primitive values as much as possible. A user-defined function can introduce significant overhead through object creation, casting, and (un)boxing. Therefore, it is highly recommended to declare parameters and result types as primitive types instead of their boxed classes. `Types.DATE` and `Types.TIME` can also be represented as `int`. `Types.TIMESTAMP` can be represented as `long`.
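The `int`/`long` representations mentioned above can be illustrated with `java.time`. The encodings below (days since 1970-01-01 for `DATE`, milliseconds since midnight for `TIME`, milliseconds since 1970-01-01T00:00 for `TIMESTAMP`, without any time-zone adjustment) are an assumption based on the Calcite-style representation, not something this section specifies; `InternalTimeEncoding` is a hypothetical helper.

```java
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.ZoneOffset;

// Hypothetical helper; the encodings are ASSUMED (Calcite-style), not taken from Flink docs:
// DATE -> int days since 1970-01-01, TIME -> int milliseconds since midnight,
// TIMESTAMP -> long milliseconds since 1970-01-01T00:00 (no time zone applied).
final class InternalTimeEncoding {
    static int dateToInt(LocalDate date) {
        return (int) date.toEpochDay();
    }

    static int timeToInt(LocalTime time) {
        return (int) (time.toNanoOfDay() / 1_000_000L);
    }

    static long timestampToLong(LocalDateTime ts) {
        return ts.toInstant(ZoneOffset.UTC).toEpochMilli();
    }
}
```

Working with such primitives avoids allocating `java.sql.Date`/`Timestamp` objects per row, which is the overhead the paragraph above warns about.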
+
+We recommend writing user-defined functions in Java instead of Scala, because Scala types pose a challenge for Flink's type extractor.
 
 {% top %}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/09344aa2/docs/fig/udagg-mechanism.png
----------------------------------------------------------------------
diff --git a/docs/fig/udagg-mechanism.png b/docs/fig/udagg-mechanism.png
new file mode 100644
index 0000000..043196f
Binary files /dev/null and b/docs/fig/udagg-mechanism.png differ

http://git-wip-us.apache.org/repos/asf/flink/blob/09344aa2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index 8f50971..d3f9497 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -33,7 +33,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
   *  - merge, and
   *  - resetAccumulator
   *
-  * All these methods muse be declared publicly, not static and named exactly as the names
+  * All these methods must be declared publicly, not static and named exactly as the names
   * mentioned above. The methods createAccumulator and getValue are defined in the
   * [[AggregateFunction]] functions, while other methods are explained below.
   *