You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2018/12/14 13:11:03 UTC

[GitHub] srowen closed pull request #18113: [SPARK-20890][SQL] Added min and max typed aggregation functions

srowen closed pull request #18113: [SPARK-20890][SQL] Added min and max typed aggregation functions
URL: https://github.com/apache/spark/pull/18113
 
 
   

This pull request comes from a forked repository. Because GitHub hides
the original diff of a foreign (forked) pull request once it is merged
or closed, the diff is reproduced below for the sake of provenance:

diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java
index ec9c107b1c119..f426dd95cec27 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java
@@ -21,10 +21,7 @@
 import org.apache.spark.annotation.InterfaceStability;
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.TypedColumn;
-import org.apache.spark.sql.execution.aggregate.TypedAverage;
-import org.apache.spark.sql.execution.aggregate.TypedCount;
-import org.apache.spark.sql.execution.aggregate.TypedSumDouble;
-import org.apache.spark.sql.execution.aggregate.TypedSumLong;
+import org.apache.spark.sql.execution.aggregate.*;
 
 /**
  * :: Experimental ::
@@ -74,4 +71,40 @@
   public static <T> TypedColumn<T, Long> sumLong(MapFunction<T, Long> f) {
     return new TypedSumLong<T>(f).toColumnJava();
   }
+
+  /**
+   * Min aggregate function for floating point (double) type.
+   *
+   * @since 2.3.0
+   */
+  public static <T> TypedColumn<T, Double> min(MapFunction<T, Double> f) {
+    return new JavaTypedMinDouble<T>(f).toColumn();
+  }
+
+  /**
+   * Min aggregate function for integral (long, i.e. 64 bit integer) type.
+   *
+   * @since 2.3.0
+   */
+  public static <T> TypedColumn<T, Long> minLong(MapFunction<T, Long> f) {
+    return new JavaTypedMinLong<T>(f).toColumn();
+  }
+
+  /**
+   * Max aggregate function for floating point (double) type.
+   *
+   * @since 2.3.0
+   */
+  public static <T> TypedColumn<T, Double> max(MapFunction<T, Double> f) {
+    return new JavaTypedMaxDouble<T>(f).toColumn();
+  }
+
+  /**
+   * Max aggregate function for integral (long, i.e. 64 bit integer) type.
+   *
+   * @since 2.3.0
+   */
+  public static <T> TypedColumn<T, Long> maxLong(MapFunction<T, Long> f) {
+    return new JavaTypedMaxLong<T>(f).toColumn();
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index b6550bf3e4aac..1d019e4f78f6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.{Encoder, TypedColumn}
+import org.apache.spark.sql.{Encoder, Encoders, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 
@@ -38,13 +38,11 @@ class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Dou
 
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) => f.call(x).asInstanceOf[Double])
-
   def toColumnJava: TypedColumn[IN, java.lang.Double] = {
     toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
 
-
 class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def zero: Long = 0L
   override def reduce(b: Long, a: IN): Long = b + f(a)
@@ -56,13 +54,11 @@ class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
 
   // Java api support
   def this(f: MapFunction[IN, java.lang.Long]) = this((x: IN) => f.call(x).asInstanceOf[Long])
-
   def toColumnJava: TypedColumn[IN, java.lang.Long] = {
     toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
-
 class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
   override def zero: Long = 0
   override def reduce(b: Long, a: IN): Long = {
@@ -81,14 +77,13 @@ class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
   }
 }
 
-
 class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
   override def zero: (Double, Long) = (0.0, 0L)
   override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
-  override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
   override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
     (b1._1 + b2._1, b1._2 + b2._2)
   }
+  override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
 
   override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]()
   override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
@@ -99,3 +94,192 @@ class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long
     toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
+
+/** Minimum of `f` over doubles; the buffer stays null until the first row is reduced. */
+trait TypedMinDouble[IN, OUT] extends Aggregator[IN, MutableDouble, OUT] {
+  val f: IN => Double
+  override def zero: MutableDouble = null
+  override def reduce(b: MutableDouble, a: IN): MutableDouble = {
+    if (b == null) {
+      new MutableDouble(f(a))
+    } else {
+      b.value = math.min(b.value, f(a))
+      b
+    }
+  }
+  override def merge(b1: MutableDouble, b2: MutableDouble): MutableDouble = {
+    if (b1 == null) {
+      b2
+    } else if (b2 == null) {
+      b1
+    } else {
+      b1.value = math.min(b1.value, b2.value)
+      b1
+    }
+  }
+
+  override def bufferEncoder: Encoder[MutableDouble] = Encoders.kryo[MutableDouble]
+}
+
+class JavaTypedMinDouble[IN](override val f: IN => Double)
+  extends TypedMinDouble[IN, java.lang.Double] {
+  override def outputEncoder: Encoder[java.lang.Double] = ExpressionEncoder[java.lang.Double]()
+  // A null buffer means no rows were aggregated (zero is null): return null, not an NPE.
+  override def finish(reduction: MutableDouble): java.lang.Double =
+    if (reduction == null) null else reduction.value
+  def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) => f.call(x))
+}
+
+class ScalaTypedMinDouble[IN](override val f: IN => Double)
+  extends TypedMinDouble[IN, Option[Double]] {
+  override def outputEncoder: Encoder[Option[Double]] = ExpressionEncoder[Option[Double]]()
+  override def finish(reduction: MutableDouble): Option[Double] = {
+    if (reduction != null) {
+      Some(reduction.value)
+    } else {
+      None
+    }
+  }
+}
+
+/** Maximum of `f` over doubles; the buffer stays null until the first row is reduced. */
+trait TypedMaxDouble[IN, OUT] extends Aggregator[IN, MutableDouble, OUT] {
+  val f: IN => Double
+  override def zero: MutableDouble = null
+  override def reduce(b: MutableDouble, a: IN): MutableDouble = {
+    if (b == null) {
+      new MutableDouble(f(a))
+    } else {
+      b.value = math.max(b.value, f(a))
+      b
+    }
+  }
+  override def merge(b1: MutableDouble, b2: MutableDouble): MutableDouble = {
+    if (b1 == null) {
+      b2
+    } else if (b2 == null) {
+      b1
+    } else {
+      b1.value = math.max(b1.value, b2.value)
+      b1
+    }
+  }
+
+  override def bufferEncoder: Encoder[MutableDouble] = Encoders.kryo[MutableDouble]
+}
+
+class JavaTypedMaxDouble[IN](override val f: IN => Double)
+  extends TypedMaxDouble[IN, java.lang.Double] {
+  override def outputEncoder: Encoder[java.lang.Double] = ExpressionEncoder[java.lang.Double]()
+  // A null buffer means no rows were aggregated (zero is null): return null, not an NPE.
+  override def finish(reduction: MutableDouble): java.lang.Double =
+    if (reduction == null) null else reduction.value
+  def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) => f.call(x))
+}
+
+class ScalaTypedMaxDouble[IN](override val f: IN => Double)
+  extends TypedMaxDouble[IN, Option[Double]] {
+  override def outputEncoder: Encoder[Option[Double]] = ExpressionEncoder[Option[Double]]()
+  override def finish(reduction: MutableDouble): Option[Double] = {
+    if (reduction != null) {
+      Some(reduction.value)
+    } else {
+      None
+    }
+  }
+}
+
+/** Minimum of `f` over longs; the buffer stays null until the first row is reduced. */
+trait TypedMinLong[IN, OUT] extends Aggregator[IN, MutableLong, OUT] {
+  val f: IN => Long
+  override def zero: MutableLong = null
+  override def reduce(b: MutableLong, a: IN): MutableLong = {
+    if (b == null) {
+      new MutableLong(f(a))
+    } else {
+      b.value = math.min(b.value, f(a))
+      b
+    }
+  }
+  override def merge(b1: MutableLong, b2: MutableLong): MutableLong = {
+    if (b1 == null) {
+      b2
+    } else if (b2 == null) {
+      b1
+    } else {
+      b1.value = math.min(b1.value, b2.value)
+      b1
+    }
+  }
+
+  override def bufferEncoder: Encoder[MutableLong] = Encoders.kryo[MutableLong]
+}
+
+class JavaTypedMinLong[IN](override val f: IN => Long) extends TypedMinLong[IN, java.lang.Long] {
+  override def outputEncoder: Encoder[java.lang.Long] = ExpressionEncoder[java.lang.Long]()
+  // A null buffer means no rows were aggregated (zero is null): return null, not an NPE.
+  override def finish(reduction: MutableLong): java.lang.Long =
+    if (reduction == null) null else reduction.value
+  def this(f: MapFunction[IN, java.lang.Long]) = this((x: IN) => f.call(x))
+}
+
+class ScalaTypedMinLong[IN](override val f: IN => Long) extends TypedMinLong[IN, Option[Long]] {
+  override def outputEncoder: Encoder[Option[Long]] = ExpressionEncoder[Option[Long]]()
+  override def finish(reduction: MutableLong): Option[Long] = {
+    if (reduction != null) {
+      Some(reduction.value)
+    } else {
+      None
+    }
+  }
+}
+
+/** Maximum of `f` over longs; the buffer stays null until the first row is reduced. */
+trait TypedMaxLong[IN, OUT] extends Aggregator[IN, MutableLong, OUT] {
+  val f: IN => Long
+  override def zero: MutableLong = null
+  override def reduce(b: MutableLong, a: IN): MutableLong = {
+    if (b == null) {
+      new MutableLong(f(a))
+    } else {
+      b.value = math.max(b.value, f(a))
+      b
+    }
+  }
+  override def merge(b1: MutableLong, b2: MutableLong): MutableLong = {
+    if (b1 == null) {
+      b2
+    } else if (b2 == null) {
+      b1
+    } else {
+      b1.value = math.max(b1.value, b2.value)
+      b1
+    }
+  }
+
+  override def bufferEncoder: Encoder[MutableLong] = Encoders.kryo[MutableLong]
+}
+
+class JavaTypedMaxLong[IN](override val f: IN => Long) extends TypedMaxLong[IN, java.lang.Long] {
+  override def outputEncoder: Encoder[java.lang.Long] = ExpressionEncoder[java.lang.Long]()
+  // A null buffer means no rows were aggregated (zero is null): return null, not an NPE.
+  override def finish(reduction: MutableLong): java.lang.Long =
+    if (reduction == null) null else reduction.value
+  def this(f: MapFunction[IN, java.lang.Long]) = this((x: IN) => f.call(x))
+}
+
+class ScalaTypedMaxLong[IN](override val f: IN => Long) extends TypedMaxLong[IN, Option[Long]] {
+  override def outputEncoder: Encoder[Option[Long]] = ExpressionEncoder[Option[Long]]()
+  override def finish(reduction: MutableLong): Option[Long] = {
+    if (reduction != null) {
+      Some(reduction.value)
+    } else {
+      None
+    }
+  }
+}
+
+// Mutable holders used as the Kryo-encoded buffers of the min/max aggregators above.
+private class MutableLong(var value: Long) extends Serializable
+
+private class MutableDouble(var value: Double) extends Serializable
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 058c38c8cb8f4..b6e3e86922e94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.expressions
 
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{Dataset, Encoder, TypedColumn}
+import org.apache.spark.sql.{Encoder, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
index 650ffd4586592..36bc78ede9d25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
@@ -47,8 +47,6 @@ object typed {
     override protected def _sqlContext: SQLContext = null
   }
 
-  import implicits._
-
   /**
    * Average aggregate function.
    *
@@ -77,14 +75,47 @@ object typed {
    */
   def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn
 
+  /**
+   * Min aggregate function for floating point (double) type.
+   * Returns `None` when aggregating zero rows.
+   *
+   * @since 2.3.0
+   */
+  def min[IN](f: IN => Double): TypedColumn[IN, Option[Double]] =
+    new ScalaTypedMinDouble[IN](f).toColumn
+
+  /**
+   * Min aggregate function for integral (long, i.e. 64 bit integer) type.
+   * Returns `None` when aggregating zero rows.
+   *
+   * @since 2.3.0
+   */
+  def minLong[IN](f: IN => Long): TypedColumn[IN, Option[Long]] =
+    new ScalaTypedMinLong[IN](f).toColumn
+
+  /**
+   * Max aggregate function for floating point (double) type.
+   * Returns `None` when aggregating zero rows.
+   *
+   * @since 2.3.0
+   */
+  def max[IN](f: IN => Double): TypedColumn[IN, Option[Double]] =
+    new ScalaTypedMaxDouble[IN](f).toColumn
+
+  /**
+   * Max aggregate function for integral (long, i.e. 64 bit integer) type.
+   * Returns `None` when aggregating zero rows.
+   *
+   * @since 2.3.0
+   */
+  def maxLong[IN](f: IN => Long): TypedColumn[IN, Option[Long]] =
+    new ScalaTypedMaxLong[IN](f).toColumn
+
   // TODO:
   // stddevOf: Double
   // varianceOf: Double
   // approxCountDistinct: Long
 
-  // minOf: T
-  // maxOf: T
-
   // firstOf: T
   // lastOf: T
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
index 6ffccee52c0fe..5e5cf8c79592f 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
@@ -66,4 +66,40 @@ public void testTypedAggregationSumLong() {
         Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
         agged.collectAsList());
   }
+
+  @Test
+  public void testTypedAggregationMinDouble() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.min(v -> (double) v._2()));
+    Assert.assertEquals(
+        Arrays.asList(new Tuple2<>("a", 1.0), new Tuple2<>("b", 3.0)),
+        agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationMinLong() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.minLong(v -> (long) v._2()));
+    Assert.assertEquals(
+        Arrays.asList(new Tuple2<>("a", 1L), new Tuple2<>("b", 3L)),
+        agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationMaxDouble() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.max(v -> (double) v._2()));
+    Assert.assertEquals(
+        Arrays.asList(new Tuple2<>("a", 2.0), new Tuple2<>("b", 3.0)),
+        agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationMaxLong() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.maxLong(v -> (long) v._2()));
+    Assert.assertEquals(
+        Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 3L)),
+        agged.collectAsList());
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 0e7eaa9e88d57..6bbee401ae5d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.StringType
 
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -263,6 +262,26 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ("a", 4), ("b", 3))
   }
 
+  test("typed aggregate: min, max") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 4, "b" -> -4, "b" -> 0).toDS()
+    checkDataset(
+      ds.groupByKey(_._1).agg(
+        typed.min(_._2), typed.minLong(_._2), typed.max(_._2), typed.maxLong(_._2)),
+      ("a", Some(1.0), Some(1L), Some(3.0), Some(3L)),
+      ("b", Some(-4.0), Some(-4L), Some(4.0), Some(4L)))
+  }
+
+  test("typed aggregate: empty") {
+    val empty = Seq.empty[(Double, Double)].toDS()
+    val f = (x: (Double, Double)) => x._2
+    val g = (x: (Double, Double)) => x._2.toLong
+    // Zero rows: avg divides 0.0 by a zero count (NaN) and min/max have no value (None).
+    checkDataset(
+      empty.agg(typed.sum(f), typed.sumLong(g), typed.avg(f),
+        typed.min(f), typed.minLong(g), typed.max(f), typed.maxLong(g)),
+      Row(0.0, 0L, Double.NaN, None, None, None, None))
+  }
+
   test("SPARK-12555 - result should not be corrupted after input columns are reordered") {
     val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org