Posted to commits@spark.apache.org by ue...@apache.org on 2018/04/26 04:37:19 UTC

spark git commit: [SPARK-23916][SQL] Add array_join function

Repository: spark
Updated Branches:
  refs/heads/master 58c55cb4a -> cd10f9df8


[SPARK-23916][SQL] Add array_join function

## What changes were proposed in this pull request?

This PR adds the SQL function `array_join`. Its behavior is modeled on Presto's function of the same name.

The function accepts an `array` of `string` to be joined, a `string` delimiter to place between the items of the first argument, and, optionally, a `string` used to replace `null` values.
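
For a quick look at the resulting API, here is a minimal spark-shell sketch (not part of the patch); it assumes a build containing this commit (2.4.0+) and the usual `spark` session with `spark.implicits._` in scope:

```scala
import org.apache.spark.sql.functions.array_join

val df = Seq(Seq("a", null, "b")).toDF("xs")

// Two-argument form: null elements are dropped before joining.
df.select(array_join($"xs", ",")).show()          // a,b

// Three-argument form: null elements are replaced before joining.
df.select(array_join($"xs", ",", "NULL")).show()  // a,NULL,b
```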

## How was this patch tested?

Added unit tests.

Author: Marco Gaido <ma...@gmail.com>

Closes #21011 from mgaido91/SPARK-23916.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cd10f9df
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cd10f9df
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cd10f9df

Branch: refs/heads/master
Commit: cd10f9df8284ee8a5d287b2cd204c70b8ba87f5e
Parents: 58c55cb
Author: Marco Gaido <ma...@gmail.com>
Authored: Thu Apr 26 13:37:13 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Thu Apr 26 13:37:13 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  21 +++
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/collectionOperations.scala      | 169 +++++++++++++++++++
 .../CollectionExpressionsSuite.scala            |  35 ++++
 .../scala/org/apache/spark/sql/functions.scala  |  19 +++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  23 +++
 6 files changed, 268 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 38ae41a..ad4bd6f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1834,6 +1834,27 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def array_join(col, delimiter, null_replacement=None):
+    """
+    Concatenates the elements of `col` using the `delimiter`. Null values are replaced with
+    `null_replacement` if set; otherwise they are ignored.
+
+    >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])
+    >>> df.select(array_join(df.data, ",").alias("joined")).collect()
+    [Row(joined=u'a,b,c'), Row(joined=u'a')]
+    >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
+    [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')]
+    """
+    sc = SparkContext._active_spark_context
+    if null_replacement is None:
+        return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter))
+    else:
+        return Column(sc._jvm.functions.array_join(
+            _to_java_column(col), delimiter, null_replacement))
+
+
 @since(1.5)
 @ignore_unicode_prefix
 def concat(*cols):

http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6afcf30..6bc7b4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -401,6 +401,7 @@ object FunctionRegistry {
     // collection functions
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
+    expression[ArrayJoin]("array_join"),
     expression[ArrayPosition]("array_position"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
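
This single registry entry is what binds the SQL name `array_join` to the new expression; once it is in place, the function resolves directly from SQL text. A minimal spark-shell sketch (again assuming a build containing this commit):

```scala
spark.sql("SELECT array_join(array('hello', NULL, 'world'), ' ')").show()
// hello world

spark.sql("SELECT array_join(array('hello', NULL, 'world'), ' ', '?')").show()
// hello ? world
```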

http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index bc71b5f..90223b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -379,6 +379,175 @@ case class ArrayContains(left: Expression, right: Expression)
 }
 
 /**
+ * Creates a String containing all the elements of the input array separated by the delimiter.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array
+      using the delimiter and an optional string to replace nulls. If no value is set for
+      nullReplacement, any null value is filtered out.""",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array('hello', 'world'), ' ');
+       hello world
+      > SELECT _FUNC_(array('hello', null, 'world'), ' ');
+       hello world
+      > SELECT _FUNC_(array('hello', null, 'world'), ' ', ',');
+       hello , world
+  """, since = "2.4.0")
+case class ArrayJoin(
+    array: Expression,
+    delimiter: Expression,
+    nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes {
+
+  def this(array: Expression, delimiter: Expression) = this(array, delimiter, None)
+
+  def this(array: Expression, delimiter: Expression, nullReplacement: Expression) =
+    this(array, delimiter, Some(nullReplacement))
+
+  override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
+    Seq(ArrayType(StringType), StringType, StringType)
+  } else {
+    Seq(ArrayType(StringType), StringType)
+  }
+
+  override def children: Seq[Expression] = if (nullReplacement.isDefined) {
+    Seq(array, delimiter, nullReplacement.get)
+  } else {
+    Seq(array, delimiter)
+  }
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def eval(input: InternalRow): Any = {
+    val arrayEval = array.eval(input)
+    if (arrayEval == null) return null
+    val delimiterEval = delimiter.eval(input)
+    if (delimiterEval == null) return null
+    val nullReplacementEval = nullReplacement.map(_.eval(input))
+    if (nullReplacementEval.contains(null)) return null
+
+    val buffer = new UTF8StringBuilder()
+    var firstItem = true
+    val nullHandling = nullReplacementEval match {
+      case Some(rep) => (prependDelimiter: Boolean) => {
+        if (!prependDelimiter) {
+          buffer.append(delimiterEval.asInstanceOf[UTF8String])
+        }
+        buffer.append(rep.asInstanceOf[UTF8String])
+        true
+      }
+      case None => (_: Boolean) => false
+    }
+    arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => {
+      if (item == null) {
+        if (nullHandling(firstItem)) {
+          firstItem = false
+        }
+      } else {
+        if (!firstItem) {
+          buffer.append(delimiterEval.asInstanceOf[UTF8String])
+        }
+        buffer.append(item.asInstanceOf[UTF8String])
+        firstItem = false
+      }
+    })
+    buffer.build()
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val code = nullReplacement match {
+      case Some(replacement) =>
+        val replacementGen = replacement.genCode(ctx)
+        val nullHandling = (buffer: String, delimiter: String, firstItem: String) => {
+          s"""
+             |if (!$firstItem) {
+             |  $buffer.append($delimiter);
+             |}
+             |$buffer.append(${replacementGen.value});
+             |$firstItem = false;
+           """.stripMargin
+        }
+        val execCode = if (replacement.nullable) {
+          ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) {
+            genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
+          }
+        } else {
+          genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
+        }
+        s"""
+           |${replacementGen.code}
+           |$execCode
+         """.stripMargin
+      case None => genCodeForArrayAndDelimiter(ctx, ev,
+        (_: String, _: String, _: String) => "// nulls are ignored")
+    }
+    if (nullable) {
+      ev.copy(
+        s"""
+           |boolean ${ev.isNull} = true;
+           |UTF8String ${ev.value} = null;
+           |$code
+         """.stripMargin)
+    } else {
+      ev.copy(
+        s"""
+           |UTF8String ${ev.value} = null;
+           |$code
+         """.stripMargin, FalseLiteral)
+    }
+  }
+
+  private def genCodeForArrayAndDelimiter(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      nullEval: (String, String, String) => String): String = {
+    val arrayGen = array.genCode(ctx)
+    val delimiterGen = delimiter.genCode(ctx)
+    val buffer = ctx.freshName("buffer")
+    val bufferClass = classOf[UTF8StringBuilder].getName
+    val i = ctx.freshName("i")
+    val firstItem = ctx.freshName("firstItem")
+    val resultCode =
+      s"""
+         |$bufferClass $buffer = new $bufferClass();
+         |boolean $firstItem = true;
+         |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) {
+         |  if (${arrayGen.value}.isNullAt($i)) {
+         |    ${nullEval(buffer, delimiterGen.value, firstItem)}
+         |  } else {
+         |    if (!$firstItem) {
+         |      $buffer.append(${delimiterGen.value});
+         |    }
+         |    $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)});
+         |    $firstItem = false;
+         |  }
+         |}
+         |${ev.value} = $buffer.build();""".stripMargin
+
+    if (array.nullable || delimiter.nullable) {
+      arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) {
+        delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) {
+          s"""
+             |${ev.isNull} = false;
+             |$resultCode""".stripMargin
+        }
+      }
+    } else {
+      s"""
+         |${arrayGen.code}
+         |${delimiterGen.code}
+         |$resultCode""".stripMargin
+    }
+  }
+
+  override def dataType: DataType = StringType
+
+}
+
+/**
  * Returns the minimum value in the array.
  */
 @ExpressionDescription(
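
For readers skimming the diff, the interpreted path in `eval()` reduces to "replace nulls if a replacement is given, otherwise drop them, then join". Here is a standalone plain-Scala sketch of that logic (illustrative only; the real code operates on Catalyst's `ArrayData` and `UTF8String`, and the names below are not part of the patch):

```scala
// Plain-Scala model of ArrayJoin.eval()'s null handling.
def arrayJoin(
    items: Seq[String],
    delimiter: String,
    nullReplacement: Option[String] = None): String = {
  val normalized = nullReplacement match {
    case Some(rep) => items.map(x => if (x == null) rep else x) // replace nulls
    case None      => items.filter(_ != null)                   // drop nulls
  }
  normalized.mkString(delimiter)
}

assert(arrayJoin(Seq("a", null, "b"), ",") == "a,b")
assert(arrayJoin(Seq("a", null, "b"), ",", Some("NULL")) == "a,NULL,b")
assert(arrayJoin(Seq(null), ",") == "")
```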

http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index b49fa76..7048d93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -106,6 +106,41 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
   }
 
+  test("ArrayJoin") {
+    def testArrays(
+        arrays: Seq[Expression],
+        nullReplacement: Option[Expression],
+        expected: Seq[String]): Unit = {
+      assert(arrays.length == expected.length)
+      arrays.zip(expected).foreach { case (arr, exp) =>
+        checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp)
+      }
+    }
+
+    val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)),
+      Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)),
+      Literal.create(Seq[String](null), ArrayType(StringType)),
+      Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)),
+      Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)),
+      Literal.create(Seq[String]("a"), ArrayType(StringType)))
+
+    val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a")
+    val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a")
+    testArrays(arrays, None, withoutNullReplacement)
+    testArrays(arrays, Some(Literal("NULL")), withNullReplacement)
+
+    checkEvaluation(ArrayJoin(
+      Literal.create(null, ArrayType(StringType)), Literal(","), None), null)
+    checkEvaluation(ArrayJoin(
+      Literal.create(Seq[String](null), ArrayType(StringType)),
+      Literal.create(null, StringType),
+      None), null)
+    checkEvaluation(ArrayJoin(
+      Literal.create(Seq[String](null), ArrayType(StringType)),
+      Literal(","),
+      Some(Literal.create(null, StringType))), null)
+  }
+
   test("Array Min") {
     checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
     checkEvaluation(

http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f1587cd..25afaac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3040,6 +3040,25 @@ object functions {
   }
 
   /**
+   * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with
+   * `nullReplacement`.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr {
+    ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement)))
+  }
+
+  /**
+   * Concatenates the elements of `column` using the `delimiter`; null values are ignored.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_join(column: Column, delimiter: String): Column = withExpr {
+    ArrayJoin(column.expr, Literal(delimiter), None)
+  }
+
+  /**
    * Concatenates multiple input columns together into a single column.
    * The function works with strings, binary and compatible array columns.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/cd10f9df/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 03605c3..c216d13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -413,6 +413,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("array_join function") {
+    val df = Seq(
+      (Seq[String]("a", "b"), ","),
+      (Seq[String]("a", null, "b"), ","),
+      (Seq.empty[String], ",")
+    ).toDF("x", "delimiter")
+
+    checkAnswer(
+      df.select(array_join(df("x"), ";")),
+      Seq(Row("a;b"), Row("a;b"), Row(""))
+    )
+    checkAnswer(
+      df.select(array_join(df("x"), ";", "NULL")),
+      Seq(Row("a;b"), Row("a;NULL;b"), Row(""))
+    )
+    checkAnswer(
+      df.selectExpr("array_join(x, delimiter)"),
+      Seq(Row("a,b"), Row("a,b"), Row("")))
+    checkAnswer(
+      df.selectExpr("array_join(x, delimiter, 'NULL')"),
+      Seq(Row("a,b"), Row("a,NULL,b"), Row("")))
+  }
+
   test("array_min function") {
     val df = Seq(
       Seq[Option[Int]](Some(1), Some(3), Some(2)),

