Posted to commits@spark.apache.org by li...@apache.org on 2018/04/16 04:45:59 UTC
spark git commit: [SPARK-23917][SQL] Add array_max function
Repository: spark
Updated Branches:
refs/heads/master c0964935d -> 693102203
[SPARK-23917][SQL] Add array_max function
## What changes were proposed in this pull request?
The PR adds the SQL function `array_max`. It takes an array as its argument and returns the maximum value in that array, skipping NULL elements.
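For illustration, a minimal spark-shell sketch of the new Scala API (mirroring the Python doctest and the DataFrame tests in this patch; the commented output is the expected result, not a captured session):

    import spark.implicits._
    import org.apache.spark.sql.functions.array_max

    val df = Seq(
      Seq[Option[Int]](Some(2), Some(1), Some(3)),
      Seq[Option[Int]](None, Some(10), Some(-1))
    ).toDF("data")

    df.select(array_max($"data")).collect()
    // expected: Array(Row(3), Row(10)) -- NULL elements are skipped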
## How was this patch tested?
Added unit tests in `CollectionExpressionsSuite` and `DataFrameFunctionsSuite`.
Author: Marco Gaido <ma...@gmail.com>
Closes #21024 from mgaido91/SPARK-23917.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/69310220
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/69310220
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/69310220
Branch: refs/heads/master
Commit: 69310220319163bac18c9ee69d7da6d92227253b
Parents: c096493
Author: Marco Gaido <ma...@gmail.com>
Authored: Sun Apr 15 21:45:55 2018 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Sun Apr 15 21:45:55 2018 -0700
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 15 +++++
.../catalyst/analysis/FunctionRegistry.scala | 1 +
.../sql/catalyst/expressions/arithmetic.scala | 6 +-
.../expressions/codegen/CodeGenerator.scala | 17 +++++
.../expressions/collectionOperations.scala | 68 +++++++++++++++++++-
.../CollectionExpressionsSuite.scala | 10 +++
.../scala/org/apache/spark/sql/functions.scala | 8 +++
.../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++
8 files changed, 133 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1b19268..f3492ae 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2080,6 +2080,21 @@ def size(col):
return Column(sc._jvm.functions.size(_to_java_column(col)))
+@since(2.4)
+def array_max(col):
+ """
+ Collection function: returns the maximum value of the array.
+
+ :param col: name of column or expression
+
+ >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+ >>> df.select(array_max(df.data).alias('max')).collect()
+ [Row(max=3), Row(max=10)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_max(_to_java_column(col)))
+
+
@since(1.5)
def sort_array(col, asc=True):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 131b958..05bfa2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -409,6 +409,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
+ expression[ArrayMax]("array_max"),
CreateStruct.registryEntry,
// misc functions
http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 9212c3d..942dfd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
- |if (!${eval.isNull} && (${ev.isNull} ||
- | ${ctx.genGreater(dataType, eval.value, ev.value)})) {
- | ${ev.isNull} = false;
- | ${ev.value} = ${eval.value};
- |}
+ |${ctx.reassignIfGreater(dataType, ev, eval)}
""".stripMargin
)
http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 0abfc9f..c86c5be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -700,6 +700,23 @@ class CodegenContext {
}
/**
+ * Generates code for updating `partialResult` if `item` is greater than it.
+ *
+ * @param dataType data type of the expressions
+ * @param partialResult `ExprCode` representing the partial result which has to be updated
+ * @param item `ExprCode` representing the new expression to evaluate for the result
+ */
+ def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
+ s"""
+ |if (!${item.isNull} && (${partialResult.isNull} ||
+ | ${genGreater(dataType, item.value, partialResult.value)})) {
+ | ${partialResult.isNull} = false;
+ | ${partialResult.value} = ${item.value};
+ |}
+ """.stripMargin
+ }
+
+ /**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
*
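For reference, given an int column the helper above would emit Java along these lines (identifiers are freshly generated at codegen time; the names below are illustrative only):

    if (!item_isNull && (max_isNull || item_value > max_value)) {
      max_isNull = false;
      max_value = item_value;
    }

For non-primitive element types, genGreater falls back to a comparator-based comparison (genComp(...) > 0) rather than the primitive > operator.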
http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 91188da..e2614a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -21,7 +21,7 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
/**
@@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}
+
+
+/**
+ * Returns the maximum value in the array.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 20, null, 3));
+ 20
+ """, since = "2.4.0")
+case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+ override def nullable: Boolean = true
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val typeCheckResult = super.checkInputDataTypes()
+ if (typeCheckResult.isSuccess) {
+ TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
+ } else {
+ typeCheckResult
+ }
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val childGen = child.genCode(ctx)
+ val javaType = CodeGenerator.javaType(dataType)
+ val i = ctx.freshName("i")
+ val item = ExprCode("",
+ isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
+ value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
+ ev.copy(code =
+ s"""
+ |${childGen.code}
+ |boolean ${ev.isNull} = true;
+ |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ |if (!${childGen.isNull}) {
+ | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
+ | ${ctx.reassignIfGreater(dataType, ev, item)}
+ | }
+ |}
+ """.stripMargin)
+ }
+
+ override protected def nullSafeEval(input: Any): Any = {
+ var max: Any = null
+ input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
+ if (item != null && (max == null || ordering.gt(item, max))) {
+ max = item
+ }
+ )
+ max
+ }
+
+ override def dataType: DataType = child.dataType match {
+ case ArrayType(dt, _) => dt
+ case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
+ }
+
+ override def prettyName: String = "array_max"
+}
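Since the interpreted path uses TypeUtils.getInterpretedOrdering, array_max handles any orderable element type, not just numerics; strings, for example, compare lexicographically. A hedged spark-shell sketch mirroring the expression suite below (output is the expected result):

    import spark.implicits._
    import org.apache.spark.sql.functions.array_max

    Seq(Seq[String](null, "abc", "")).toDF("s")
      .select(array_max($"s")).collect()
    // expected: Array(Row("abc")) -- the null element is skipped and "" < "abc"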
http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 020687e..a238401 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}
+
+ test("Array max") {
+ checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10)
+ checkEvaluation(
+ ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc")
+ checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null)
+ checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null)
+ checkEvaluation(
+ ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c658f25..daf4079 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3301,6 +3301,14 @@ object functions {
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
/**
+ * Returns the maximum value in the array.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }
+
+ /**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
* @since 2.3.0
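Because the expression is also registered in FunctionRegistry (see above), the function is reachable both through this typed API and as a SQL expression string; the DataFrameFunctionsSuite below exercises both paths:

    df.select(array_max(df("a")))    // typed Scala API
    df.selectExpr("array_max(a)")    // SQL expression string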
http://git-wip-us.apache.org/repos/asf/spark/blob/69310220/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 50e4759..5d5d92c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
+ test("array_max function") {
+ val df = Seq(
+ Seq[Option[Int]](Some(1), Some(3), Some(2)),
+ Seq.empty[Option[Int]],
+ Seq[Option[Int]](None),
+ Seq[Option[Int]](None, Some(1), Some(-100))
+ ).toDF("a")
+
+ val answer = Seq(Row(3), Row(null), Row(null), Row(1))
+
+ checkAnswer(df.select(array_max(df("a"))), answer)
+ checkAnswer(df.selectExpr("array_max(a)"), answer)
+ }
+
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {