You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2018/04/16 04:45:59 UTC

spark git commit: [SPARK-23917][SQL] Add array_max function

Repository: spark
Updated Branches:
  refs/heads/master c0964935d -> 693102203


[SPARK-23917][SQL] Add array_max function

## What changes were proposed in this pull request?

This PR adds the SQL function `array_max`. It takes an array as its argument and returns the maximum non-null value in it, or null if there is none.

## How was this patch tested?

Added unit tests to `CollectionExpressionsSuite` and `DataFrameFunctionsSuite`, plus a PySpark doctest.

Author: Marco Gaido <ma...@gmail.com>

Closes #21024 from mgaido91/SPARK-23917.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/69310220
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/69310220
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/69310220

Branch: refs/heads/master
Commit: 69310220319163bac18c9ee69d7da6d92227253b
Parents: c096493
Author: Marco Gaido <ma...@gmail.com>
Authored: Sun Apr 15 21:45:55 2018 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Sun Apr 15 21:45:55 2018 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 15 +++++
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../sql/catalyst/expressions/arithmetic.scala   |  6 +-
 .../expressions/codegen/CodeGenerator.scala     | 17 +++++
 .../expressions/collectionOperations.scala      | 68 +++++++++++++++++++-
 .../CollectionExpressionsSuite.scala            | 10 +++
 .../scala/org/apache/spark/sql/functions.scala  |  8 +++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 14 ++++
 8 files changed, 133 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1b19268..f3492ae 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2080,6 +2080,21 @@ def size(col):
     return Column(sc._jvm.functions.size(_to_java_column(col)))
 
 
+@since(2.4)
+def array_max(col):
+    """
+    Collection function: returns the maximum value of the array.
+
+    :param col: name of column or expression
+
+    >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+    >>> df.select(array_max(df.data).alias('max')).collect()
+    [Row(max=3), Row(max=10)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_max(_to_java_column(col)))
+
+
 @since(1.5)
 def sort_array(col, asc=True):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 131b958..05bfa2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -409,6 +409,7 @@ object FunctionRegistry {
     expression[MapValues]("map_values"),
     expression[Size]("size"),
     expression[SortArray]("sort_array"),
+    expression[ArrayMax]("array_max"),
     CreateStruct.registryEntry,
 
     // misc functions

http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 9212c3d..942dfd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
     val evals = evalChildren.map(eval =>
       s"""
          |${eval.code}
-         |if (!${eval.isNull} && (${ev.isNull} ||
-         |  ${ctx.genGreater(dataType, eval.value, ev.value)})) {
-         |  ${ev.isNull} = false;
-         |  ${ev.value} = ${eval.value};
-         |}
+         |${ctx.reassignIfGreater(dataType, ev, eval)}
       """.stripMargin
     )
 

http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 0abfc9f..c86c5be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -700,6 +700,23 @@ class CodegenContext {
   }
 
   /**
+   * Generates code updating `partialResult` if `item` is non-null and greater than it.
+   *
+   * @param dataType data type of the expressions
+   * @param partialResult assignable running result; when null, any non-null `item` replaces it
+   * @param item candidate value; a null `item` never changes `partialResult`
+   */
+  def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
+    s"""
+       |if (!${item.isNull} && (${partialResult.isNull} ||
+       |  ${genGreater(dataType, item.value, partialResult.value)})) {
+       |  ${partialResult.isNull} = false;
+       |  ${partialResult.value} = ${item.value};
+       |}
+      """.stripMargin
+  }
+
+  /**
    * Generates code to do null safe execution, i.e. only execute the code when the input is not
    * null by adding null check if necessary.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 91188da..e2614a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -21,7 +21,7 @@ import java.util.Comparator
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
 
 /**
@@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression)
 
   override def prettyName: String = "array_contains"
 }
+
+
+/**
+ * Returns the largest non-null element of an array; null if the array is null or has none.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 20, null, 3));
+       20
+  """, since = "2.4.0")
+case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def prettyName: String = "array_max"
+
+  /** The element type of the input array; fails for anything that is not an array. */
+  override def dataType: DataType = child.dataType match {
+    case ArrayType(elementType, _) => elementType
+    case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
+  }
+
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+  // Comparator used by the interpreted path; built lazily from the element type.
+  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val inputCheck = super.checkInputDataTypes()
+    if (!inputCheck.isSuccess) {
+      inputCheck
+    } else {
+      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val arr = child.genCode(ctx)
+    val elemJavaType = CodeGenerator.javaType(dataType)
+    val idx = ctx.freshName("i")
+    val element = ExprCode("",
+      isNull = JavaCode.isNullExpression(s"${arr.value}.isNullAt($idx)"),
+      value = JavaCode.expression(CodeGenerator.getValue(arr.value, dataType, idx), dataType))
+    val keepGreater = ctx.reassignIfGreater(dataType, ev, element)
+    val body = s"""
+         |${arr.code}
+         |boolean ${ev.isNull} = true;
+         |$elemJavaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+         |if (!${arr.isNull}) {
+         |  for (int $idx = 0; $idx < ${arr.value}.numElements(); $idx ++) {
+         |    $keepGreater
+         |  }
+         |}
+      """.stripMargin
+    ev.copy(code = body)
+  }
+
+  override protected def nullSafeEval(input: Any): Any = {
+    var best: Any = null
+    input.asInstanceOf[ArrayData].foreach(dataType, (_, element) =>
+      if (element != null && (best == null || ordering.gt(element, best))) best = element)
+    best
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 020687e..a238401 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayContains(a3, Literal("")), null)
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
   }
+
+  test("Array max") {
+    def arrayMax(values: Seq[Any], et: DataType) = ArrayMax(Literal.create(values, ArrayType(et)))
+    checkEvaluation(arrayMax(Seq(1, 10, 2), IntegerType), 10)
+    checkEvaluation(arrayMax(Seq[String](null, "abc", ""), StringType), "abc")
+    checkEvaluation(arrayMax(Seq(null), LongType), null)
+    checkEvaluation(arrayMax(Seq.empty[Int], IntegerType), null)
+    checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null)
+    checkEvaluation(arrayMax(Seq(1.123, 0.1234, 1.121), DoubleType), 1.123)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c658f25..daf4079 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3301,6 +3301,14 @@ object functions {
   def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
 
   /**
+   * Returns the maximum non-null value in the array, or null if there is none.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }
+
+  /**
    * Returns an unordered array containing the keys of the map.
    * @group collection_funcs
    * @since 2.3.0

http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 50e4759..5d5d92c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("array_max function") {
+    // Each input array paired with the row array_max is expected to return for it.
+    val cases: Seq[(Seq[Option[Int]], Row)] = Seq(
+      Seq(Some(1), Some(3), Some(2)) -> Row(3),
+      Seq.empty[Option[Int]] -> Row(null),
+      Seq[Option[Int]](None) -> Row(null),
+      Seq(None, Some(1), Some(-100)) -> Row(1))
+    val df = cases.map(_._1).toDF("a")
+    val expected = cases.map(_._2)
+
+    checkAnswer(df.select(array_max(df("a"))), expected)
+    checkAnswer(df.selectExpr("array_max(a)"), expected)
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org