You are viewing a plain text version of this content; the canonical HTML version is available in the mailing-list archive.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/04/05 07:30:58 UTC
spark git commit: [SPARK-14359] Create built-in functions for typed
aggregates in Java
Repository: spark
Updated Branches:
refs/heads/master 7db56244f -> 064623014
[SPARK-14359] Create built-in functions for typed aggregates in Java
## What changes were proposed in this pull request?
This adds the corresponding Java static functions for built-in typed aggregates already exposed in Scala.
## How was this patch tested?
Unit tests.
rxin
Author: Eric Liang <ek...@databricks.com>
Closes #12168 from ericl/sc-2794.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/06462301
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/06462301
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/06462301
Branch: refs/heads/master
Commit: 064623014e0d6dfb0376722f24e81027fde649de
Parents: 7db5624
Author: Eric Liang <ek...@databricks.com>
Authored: Tue Apr 5 00:30:55 2016 -0500
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Apr 5 00:30:55 2016 -0500
----------------------------------------------------------------------
.../execution/aggregate/typedaggregators.scala | 33 +++++++++++++
.../spark/sql/expressions/java/typed.java | 42 +++++++++++++++++
.../sql/sources/JavaDatasetAggregatorSuite.java | 49 ++++++++++++++++++++
3 files changed, 124 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/06462301/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index 9afc290..7a18d0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.execution.aggregate
+import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.sql.TypedColumn
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -30,6 +33,8 @@ class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT]
override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
override def finish(reduction: OUT): OUT = reduction
+
+ // TODO(ekl) java api support once this is exposed in scala
}
@@ -38,6 +43,13 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
override def reduce(b: Double, a: IN): Double = b + f(a)
override def merge(b1: Double, b2: Double): Double = b1 + b2
override def finish(reduction: Double): Double = reduction
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+ def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
+ toColumn(ExpressionEncoder(), ExpressionEncoder())
+ .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ }
}
@@ -46,6 +58,13 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
override def reduce(b: Long, a: IN): Long = b + f(a)
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
+ def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
+ toColumn(ExpressionEncoder(), ExpressionEncoder())
+ .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ }
}
@@ -56,6 +75,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
}
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction
+
+ // Java api support
+ def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
+ def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
+ toColumn(ExpressionEncoder(), ExpressionEncoder())
+ .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ }
}
@@ -66,4 +92,11 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D
override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
(b1._1 + b2._1, b1._2 + b2._2)
}
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+ def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
+ toColumn(ExpressionEncoder(), ExpressionEncoder())
+ .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/06462301/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
index cdba970..8ff7b65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
@@ -18,7 +18,13 @@
package org.apache.spark.sql.expressions.java;
import org.apache.spark.annotation.Experimental;
+import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.execution.aggregate.TypedAverage;
+import org.apache.spark.sql.execution.aggregate.TypedCount;
+import org.apache.spark.sql.execution.aggregate.TypedSumDouble;
+import org.apache.spark.sql.execution.aggregate.TypedSumLong;
/**
* :: Experimental ::
@@ -30,5 +36,41 @@ import org.apache.spark.sql.Dataset;
*/
@Experimental
public class typed {
+ // Note: make sure to keep in sync with typed.scala
+ /**
+ * Average aggregate function.
+ *
+ * @since 2.0.0
+ */
+ public static<T> TypedColumn<T, Double> avg(MapFunction<T, Double> f) {
+ return new TypedAverage<T>(f).toColumnJava();
+ }
+
+ /**
+ * Count aggregate function.
+ *
+ * @since 2.0.0
+ */
+ public static<T> TypedColumn<T, Long> count(MapFunction<T, Object> f) {
+ return new TypedCount<T>(f).toColumnJava();
+ }
+
+ /**
+ * Sum aggregate function for floating point (double) type.
+ *
+ * @since 2.0.0
+ */
+ public static<T> TypedColumn<T, Double> sum(MapFunction<T, Double> f) {
+ return new TypedSumDouble<T>(f).toColumnJava();
+ }
+
+ /**
+ * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+ *
+ * @since 2.0.0
+ */
+ public static<T> TypedColumn<T, Long> sumLong(MapFunction<T, Long> f) {
+ return new TypedSumLong<T>(f).toColumnJava();
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/06462301/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index c4c455b..c8d0eec 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.expressions.java.typed;
import org.apache.spark.sql.test.TestSQLContext;
/**
@@ -120,4 +121,52 @@ public class JavaDatasetAggregatorSuite implements Serializable {
return reduction;
}
}
+
+ @Test
+ public void testTypedAggregationAverage() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(
+ new MapFunction<Tuple2<String, Integer>, Double>() {
+ public Double call(Tuple2<String, Integer> value) throws Exception {
+ return (double)(value._2() * 2);
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationCount() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(
+ new MapFunction<Tuple2<String, Integer>, Object>() {
+ public Object call(Tuple2<String, Integer> value) throws Exception {
+ return value;
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumDouble() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(
+ new MapFunction<Tuple2<String, Integer>, Double>() {
+ public Double call(Tuple2<String, Integer> value) throws Exception {
+ return (double)value._2();
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumLong() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(
+ new MapFunction<Tuple2<String, Integer>, Long>() {
+ public Long call(Tuple2<String, Integer> value) throws Exception {
+ return (long)value._2();
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org