You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ue...@apache.org on 2018/04/26 04:37:19 UTC
spark git commit: [SPARK-23916][SQL] Add array_join function
Repository: spark
Updated Branches:
refs/heads/master 58c55cb4a -> cd10f9df8
[SPARK-23916][SQL] Add array_join function
## What changes were proposed in this pull request?
The PR adds the SQL function `array_join`. Its behavior is modeled on Presto's function of the same name.
The function accepts an `array` of `string` whose elements are joined, a `string` delimiter that is placed between consecutive elements and, optionally, a `string` used to replace `null` elements. When no replacement is given, `null` elements are skipped.
## How was this patch tested?
added UTs
Author: Marco Gaido <ma...@gmail.com>
Closes #21011 from mgaido91/SPARK-23916.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cd10f9df
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cd10f9df
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cd10f9df
Branch: refs/heads/master
Commit: cd10f9df8284ee8a5d287b2cd204c70b8ba87f5e
Parents: 58c55cb
Author: Marco Gaido <ma...@gmail.com>
Authored: Thu Apr 26 13:37:13 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Thu Apr 26 13:37:13 2018 +0900
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 21 +++
.../catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/collectionOperations.scala | 169 +++++++++++++++++++
.../CollectionExpressionsSuite.scala | 35 ++++
.../scala/org/apache/spark/sql/functions.scala | 19 +++
.../spark/sql/DataFrameFunctionsSuite.scala | 23 +++
6 files changed, 268 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 38ae41a..ad4bd6f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1834,6 +1834,27 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
+@ignore_unicode_prefix
+@since(2.4)
+def array_join(col, delimiter, null_replacement=None):
+ """
+ Concatenates the elements of `col` using the `delimiter`. Null values are replaced with
+ `null_replacement` if set, otherwise they are ignored.
+
+ :param col: array column whose elements are joined
+ :param delimiter: string inserted between consecutive elements
+ :param null_replacement: optional string substituted for null elements
+
+ >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])
+ >>> df.select(array_join(df.data, ",").alias("joined")).collect()
+ [Row(joined=u'a,b,c'), Row(joined=u'a')]
+ >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
+ [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')]
+ """
+ sc = SparkContext._active_spark_context
+ # Call the two- or three-argument JVM array_join depending on whether a replacement was given.
+ if null_replacement is None:
+ return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter))
+ else:
+ return Column(sc._jvm.functions.array_join(
+ _to_java_column(col), delimiter, null_replacement))
+
+
@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6afcf30..6bc7b4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -401,6 +401,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
+ expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index bc71b5f..90223b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -379,6 +379,175 @@ case class ArrayContains(left: Expression, right: Expression)
}
/**
+ * Creates a String containing all the elements of the input array separated by the delimiter.
+ *
+ * Null elements are skipped unless `nullReplacement` is given, in which case every null element
+ * is rendered as that string. The result is null when the array, the delimiter or a supplied
+ * `nullReplacement` evaluates to null.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array
+ using the delimiter and an optional string to replace nulls. If no value is set for
+ nullReplacement, any null value is filtered.""",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array('hello', 'world'), ' ');
+ hello world
+ > SELECT _FUNC_(array('hello', null ,'world'), ' ');
+ hello world
+ > SELECT _FUNC_(array('hello', null ,'world'), ' ', ',');
+ hello , world
+ """, since = "2.4.0")
+case class ArrayJoin(
+    array: Expression,
+    delimiter: Expression,
+    nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes {
+
+  // Auxiliary constructors with plain Expression arguments for the 2- and 3-argument forms.
+  def this(array: Expression, delimiter: Expression) = this(array, delimiter, None)
+
+  def this(array: Expression, delimiter: Expression, nullReplacement: Expression) =
+    this(array, delimiter, Some(nullReplacement))
+
+  override def inputTypes: Seq[AbstractDataType] = nullReplacement match {
+    case Some(_) => Seq(ArrayType(StringType), StringType, StringType)
+    case None => Seq(ArrayType(StringType), StringType)
+  }
+
+  override def children: Seq[Expression] = nullReplacement match {
+    case Some(replacement) => Seq(array, delimiter, replacement)
+    case None => Seq(array, delimiter)
+  }
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  /**
+   * Interpreted evaluation. Returns null as soon as the array, the delimiter or the supplied
+   * replacement evaluates to null; otherwise returns the joined string.
+   */
+  override def eval(input: InternalRow): Any = {
+    val arrayEval = array.eval(input)
+    if (arrayEval == null) return null
+    val delimiterEval = delimiter.eval(input)
+    if (delimiterEval == null) return null
+    val nullReplacementEval = nullReplacement.map(_.eval(input))
+    if (nullReplacementEval.contains(null)) return null
+
+    val separator = delimiterEval.asInstanceOf[UTF8String]
+    val replacementOpt = nullReplacementEval.map(_.asInstanceOf[UTF8String])
+    val buffer = new UTF8StringBuilder()
+    var firstItem = true
+
+    // Appends one output token, preceded by the delimiter for every token but the first.
+    def appendToken(token: UTF8String): Unit = {
+      if (!firstItem) {
+        buffer.append(separator)
+      }
+      buffer.append(token)
+      firstItem = false
+    }
+
+    arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => {
+      if (item == null) {
+        // Without a replacement, null elements produce no output and no delimiter.
+        replacementOpt.foreach(appendToken)
+      } else {
+        appendToken(item.asInstanceOf[UTF8String])
+      }
+    })
+    buffer.build()
+  }
+
+  /**
+   * Generates the Java code for this expression. The replacement (if any) is evaluated first
+   * and guards the join code when it is nullable; the array/delimiter handling and the join
+   * loop are produced by `genCodeForArrayAndDelimiter`.
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val code = nullReplacement match {
+      case Some(replacement) =>
+        val replacementGen = replacement.genCode(ctx)
+        // Code emitted for a null element: append the replacement like any other token.
+        val nullHandling = (buffer: String, delimiter: String, firstItem: String) => {
+          s"""
+             |if (!$firstItem) {
+             | $buffer.append($delimiter);
+             |}
+             |$buffer.append(${replacementGen.value});
+             |$firstItem = false;
+           """.stripMargin
+        }
+        val execCode = if (replacement.nullable) {
+          ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) {
+            genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
+          }
+        } else {
+          genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
+        }
+        s"""
+           |${replacementGen.code}
+           |$execCode
+         """.stripMargin
+      case None => genCodeForArrayAndDelimiter(ctx, ev,
+        (_: String, _: String, _: String) => "// nulls are ignored")
+    }
+    if (nullable) {
+      // isNull starts as true; the join code resets it once every nullable input is non-null.
+      ev.copy(
+        s"""
+           |boolean ${ev.isNull} = true;
+           |UTF8String ${ev.value} = null;
+           |$code
+         """.stripMargin)
+    } else {
+      ev.copy(
+        s"""
+           |UTF8String ${ev.value} = null;
+           |$code
+         """.stripMargin, FalseLiteral)
+    }
+  }
+
+  /**
+   * Generates the code that evaluates the array and the delimiter and joins the elements.
+   *
+   * @param ctx the codegen context
+   * @param ev the result holder whose `value` (and `isNull`, when nullable) is assigned
+   * @param nullEval given the buffer, delimiter and first-item variable names, returns the
+   *                 code to run for a null element
+   */
+  private def genCodeForArrayAndDelimiter(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      nullEval: (String, String, String) => String): String = {
+    val arrayGen = array.genCode(ctx)
+    val delimiterGen = delimiter.genCode(ctx)
+    val buffer = ctx.freshName("buffer")
+    val bufferClass = classOf[UTF8StringBuilder].getName
+    val i = ctx.freshName("i")
+    val firstItem = ctx.freshName("firstItem")
+    // doGenCode declares isNull as true whenever any child is nullable, so it must be reset
+    // here in that case too. Previously it was reset only when the array or the delimiter was
+    // nullable, so a nullable replacement alone made the generated code always return null.
+    val markNotNull = if (nullable) s"${ev.isNull} = false;" else ""
+    val resultCode =
+      s"""
+         |$markNotNull
+         |$bufferClass $buffer = new $bufferClass();
+         |boolean $firstItem = true;
+         |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) {
+         | if (${arrayGen.value}.isNullAt($i)) {
+         | ${nullEval(buffer, delimiterGen.value, firstItem)}
+         | } else {
+         | if (!$firstItem) {
+         | $buffer.append(${delimiterGen.value});
+         | }
+         | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)});
+         | $firstItem = false;
+         | }
+         |}
+         |${ev.value} = $buffer.build();""".stripMargin
+
+    if (array.nullable || delimiter.nullable) {
+      arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) {
+        delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) {
+          resultCode
+        }
+      }
+    } else {
+      s"""
+         |${arrayGen.code}
+         |${delimiterGen.code}
+         |$resultCode""".stripMargin
+    }
+  }
+
+  override def dataType: DataType = StringType
+
+}
+
+/**
* Returns the minimum value in the array.
*/
@ExpressionDescription(
http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index b49fa76..7048d93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -106,6 +106,41 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}
+  test("ArrayJoin") {
+    val stringArrayType = ArrayType(StringType)
+    val comma = Literal(",")
+
+    // Each row: input elements, expected join when nulls are skipped, and expected join when
+    // nulls are replaced with the string "NULL".
+    val cases = Seq(
+      (Seq[String]("a", "b"), "a,b", "a,b"),
+      (Seq[String]("a", null, "b"), "a,b", "a,NULL,b"),
+      (Seq[String](null), "", "NULL"),
+      (Seq[String]("a", "b", null), "a,b", "a,b,NULL"),
+      (Seq[String](null, "a", "b"), "a,b", "NULL,a,b"),
+      (Seq[String]("a"), "a", "a"))
+
+    cases.foreach { case (elements, _, _) =>
+      assert(elements != null)
+    }
+    // Same order as before: first every case without a replacement, then every case with one.
+    cases.foreach { case (elements, nullsSkipped, _) =>
+      val arr = Literal.create(elements, stringArrayType)
+      checkEvaluation(ArrayJoin(arr, comma, None), nullsSkipped)
+    }
+    cases.foreach { case (elements, _, nullsReplaced) =>
+      val arr = Literal.create(elements, stringArrayType)
+      checkEvaluation(ArrayJoin(arr, comma, Some(Literal("NULL"))), nullsReplaced)
+    }
+
+    // A null array, a null delimiter or a null replacement each make the whole result null.
+    val singleNullArray = Literal.create(Seq[String](null), stringArrayType)
+    val nullString = Literal.create(null, StringType)
+    checkEvaluation(ArrayJoin(Literal.create(null, stringArrayType), comma, None), null)
+    checkEvaluation(ArrayJoin(singleNullArray, nullString, None), null)
+    checkEvaluation(ArrayJoin(singleNullArray, comma, Some(nullString)), null)
+  }
+
test("Array Min") {
checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
checkEvaluation(
http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f1587cd..25afaac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3040,6 +3040,25 @@ object functions {
}
/**
+   * Joins the elements of the array `column` into a single string, placing `delimiter` between
+   * consecutive elements. Null elements are rendered as `nullReplacement`.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_join(column: Column, delimiter: String, nullReplacement: String): Column = {
+    val replacement = Some(Literal(nullReplacement))
+    withExpr(ArrayJoin(column.expr, Literal(delimiter), replacement))
+  }
+
+  /**
+   * Joins the elements of the array `column` into a single string, placing `delimiter` between
+   * consecutive elements. No null replacement is supplied, so null elements are skipped.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_join(column: Column, delimiter: String): Column = {
+    val separator = Literal(delimiter)
+    withExpr(ArrayJoin(column.expr, separator, None))
+  }
+
+ /**
* Concatenates multiple input columns together into a single column.
* The function works with strings, binary and compatible array columns.
*
http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 03605c3..c216d13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -413,6 +413,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
+  test("array_join function") {
+    val input = Seq(
+      (Seq[String]("a", "b"), ","),
+      (Seq[String]("a", null, "b"), ","),
+      (Seq.empty[String], ",")
+    ).toDF("x", "delimiter")
+    val arrayColumn = input("x")
+
+    // DataFrame API: the delimiter and the replacement are passed as literal strings.
+    val joinedSkippingNulls = input.select(array_join(arrayColumn, ";"))
+    checkAnswer(joinedSkippingNulls, Seq(Row("a;b"), Row("a;b"), Row("")))
+
+    val joinedReplacingNulls = input.select(array_join(arrayColumn, ";", "NULL"))
+    checkAnswer(joinedReplacingNulls, Seq(Row("a;b"), Row("a;NULL;b"), Row("")))
+
+    // SQL expressions: the delimiter is read from the `delimiter` column.
+    val sqlSkippingNulls = input.selectExpr("array_join(x, delimiter)")
+    checkAnswer(sqlSkippingNulls, Seq(Row("a,b"), Row("a,b"), Row("")))
+
+    val sqlReplacingNulls = input.selectExpr("array_join(x, delimiter, 'NULL')")
+    checkAnswer(sqlReplacingNulls, Seq(Row("a,b"), Row("a,NULL,b"), Row("")))
+  }
+
test("array_min function") {
val df = Seq(
Seq[Option[Int]](Some(1), Some(3), Some(2)),
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org