Posted to commits@spark.apache.org by rx...@apache.org on 2016/04/05 07:30:58 UTC

spark git commit: [SPARK-14359] Create built-in functions for typed aggregates in Java

Repository: spark
Updated Branches:
  refs/heads/master 7db56244f -> 064623014


[SPARK-14359] Create built-in functions for typed aggregates in Java

## What changes were proposed in this pull request?

This adds Java static functions corresponding to the built-in typed aggregates already exposed in Scala.
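
For reference, a minimal usage sketch of the new Java API (a fragment, not a complete program; `ds` is an assumed input, and the pair shape and aggregation mirror the unit tests in the diff below):

    import org.apache.spark.api.java.function.MapFunction;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.KeyValueGroupedDataset;
    import org.apache.spark.sql.expressions.java.typed;
    import scala.Tuple2;

    // ds is an assumed input: Dataset<Tuple2<String, Integer>> of (key, value) pairs.
    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped =
        ds.groupByKey(
            new MapFunction<Tuple2<String, Integer>, String>() {
              public String call(Tuple2<String, Integer> value) throws Exception {
                return value._1();
              }
            },
            Encoders.STRING());

    // Per-key sum of the Integer field, widened to double.
    Dataset<Tuple2<String, Double>> sums = grouped.agg(typed.sum(
        new MapFunction<Tuple2<String, Integer>, Double>() {
          public Double call(Tuple2<String, Integer> value) throws Exception {
            return (double) value._2();
          }
        }));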

## How was this patch tested?

Unit tests.

rxin

Author: Eric Liang <ek...@databricks.com>

Closes #12168 from ericl/sc-2794.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/06462301
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/06462301
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/06462301

Branch: refs/heads/master
Commit: 064623014e0d6dfb0376722f24e81027fde649de
Parents: 7db5624
Author: Eric Liang <ek...@databricks.com>
Authored: Tue Apr 5 00:30:55 2016 -0500
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Apr 5 00:30:55 2016 -0500

----------------------------------------------------------------------
 .../execution/aggregate/typedaggregators.scala  | 33 +++++++++++++
 .../spark/sql/expressions/java/typed.java       | 42 +++++++++++++++++
 .../sql/sources/JavaDatasetAggregatorSuite.java | 49 ++++++++++++++++++++
 3 files changed, 124 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/06462301/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index 9afc290..7a18d0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.sql.TypedColumn
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -30,6 +33,8 @@ class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT]
   override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
   override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
   override def finish(reduction: OUT): OUT = reduction
+
+  // TODO(ekl) Java API support once this is exposed in Scala
 }
 
 
@@ -38,6 +43,13 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
   override def reduce(b: Double, a: IN): Double = b + f(a)
   override def merge(b1: Double, b2: Double): Double = b1 + b2
   override def finish(reduction: Double): Double = reduction
+
+  // Java API support
+  def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
+    toColumn(ExpressionEncoder(), ExpressionEncoder())
+      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+  }
 }
 
 
@@ -46,6 +58,13 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def reduce(b: Long, a: IN): Long = b + f(a)
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
+
+  // Java API support
+  def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
+  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
+    toColumn(ExpressionEncoder(), ExpressionEncoder())
+      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+  }
 }
 
 
@@ -56,6 +75,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
   }
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
+
+  // Java API support
+  def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
+  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
+    toColumn(ExpressionEncoder(), ExpressionEncoder())
+      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+  }
 }
 
 
@@ -66,4 +92,11 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D
   override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
     (b1._1 + b2._1, b1._2 + b2._2)
   }
+
+  // Java API support
+  def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
+    toColumn(ExpressionEncoder(), ExpressionEncoder())
+      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+  }
 }
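
All four aggregators follow the same Java interop pattern: a secondary constructor adapts a Java MapFunction to the Scala closure the aggregator expects (unboxing java.lang.Double/java.lang.Long along the way), and toColumnJava() materializes fresh ExpressionEncoders and casts the resulting TypedColumn to the boxed Java result type. A sketch of what this gives a Java caller when used directly (the typed class in the next file simply wraps this call; the Tuple2 input type is illustrative):

    import org.apache.spark.api.java.function.MapFunction;
    import org.apache.spark.sql.TypedColumn;
    import org.apache.spark.sql.execution.aggregate.TypedSumDouble;
    import scala.Tuple2;

    // Build the aggregator from a Java MapFunction and turn it into a typed
    // column; typed.sum(f) in the next file does exactly this.
    MapFunction<Tuple2<String, Integer>, Double> f =
        new MapFunction<Tuple2<String, Integer>, Double>() {
          public Double call(Tuple2<String, Integer> value) throws Exception {
            return (double) value._2();
          }
        };
    TypedColumn<Tuple2<String, Integer>, Double> sumCol =
        new TypedSumDouble<Tuple2<String, Integer>>(f).toColumnJava();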

http://git-wip-us.apache.org/repos/asf/spark/blob/06462301/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
index cdba970..8ff7b65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
@@ -18,7 +18,13 @@
 package org.apache.spark.sql.expressions.java;
 
 import org.apache.spark.annotation.Experimental;
+import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.execution.aggregate.TypedAverage;
+import org.apache.spark.sql.execution.aggregate.TypedCount;
+import org.apache.spark.sql.execution.aggregate.TypedSumDouble;
+import org.apache.spark.sql.execution.aggregate.TypedSumLong;
 
 /**
  * :: Experimental ::
@@ -30,5 +36,41 @@ import org.apache.spark.sql.Dataset;
  */
 @Experimental
 public class typed {
+  // Note: make sure to keep in sync with typed.scala
 
+  /**
+   * Average aggregate function.
+   *
+   * @since 2.0.0
+   */
+  public static <T> TypedColumn<T, Double> avg(MapFunction<T, Double> f) {
+    return new TypedAverage<T>(f).toColumnJava();
+  }
+
+  /**
+   * Count aggregate function.
+   *
+   * @since 2.0.0
+   */
+  public static <T> TypedColumn<T, Long> count(MapFunction<T, Object> f) {
+    return new TypedCount<T>(f).toColumnJava();
+  }
+
+  /**
+   * Sum aggregate function for floating point (double) type.
+   *
+   * @since 2.0.0
+   */
+  public static <T> TypedColumn<T, Double> sum(MapFunction<T, Double> f) {
+    return new TypedSumDouble<T>(f).toColumnJava();
+  }
+
+  /**
+   * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+   *
+   * @since 2.0.0
+   */
+  public static <T> TypedColumn<T, Long> sumLong(MapFunction<T, Long> f) {
+    return new TypedSumLong<T>(f).toColumnJava();
+  }
 }
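
Since MapFunction is a single-abstract-method interface, Java 8 callers can also pass lambdas in place of the anonymous classes used in the test suite below; a minimal sketch, assuming the same grouped dataset as the tests:

    // Same aggregations as the tests below, written with Java 8 lambdas.
    Dataset<Tuple2<String, Double>> avgs =
        grouped.agg(typed.avg(value -> (double) (value._2() * 2)));
    Dataset<Tuple2<String, Long>> counts =
        grouped.agg(typed.count(value -> value));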

http://git-wip-us.apache.org/repos/asf/spark/blob/06462301/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index c4c455b..c8d0eec 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.KeyValueGroupedDataset;
 import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.expressions.java.typed;
 import org.apache.spark.sql.test.TestSQLContext;
 
 /**
@@ -120,4 +121,52 @@ public class JavaDatasetAggregatorSuite implements Serializable {
       return reduction;
     }
   }
+
+  @Test
+  public void testTypedAggregationAverage() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(
+      new MapFunction<Tuple2<String, Integer>, Double>() {
+        public Double call(Tuple2<String, Integer> value) throws Exception {
+          return (double)(value._2() * 2);
+        }
+      }));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationCount() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(
+      new MapFunction<Tuple2<String, Integer>, Object>() {
+        public Object call(Tuple2<String, Integer> value) throws Exception {
+          return value;
+        }
+      }));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationSumDouble() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(
+      new MapFunction<Tuple2<String, Integer>, Double>() {
+        public Double call(Tuple2<String, Integer> value) throws Exception {
+          return (double)value._2();
+        }
+      }));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationSumLong() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(
+      new MapFunction<Tuple2<String, Integer>, Long>() {
+        public Long call(Tuple2<String, Integer> value) throws Exception {
+          return (long)value._2();
+        }
+      }));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+  }
 }

