You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/16 09:51:56 UTC

[spark] branch master updated: [SPARK-41233][SQL][PYTHON] Add `array_prepend` function

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3dd629629ab [SPARK-41233][SQL][PYTHON] Add `array_prepend` function
3dd629629ab is described below

commit 3dd629629ab151688b82a3aa66e1b5fa568afbfa
Author: Navin Viswanath <na...@gmail.com>
AuthorDate: Thu Mar 16 17:51:33 2023 +0800

    [SPARK-41233][SQL][PYTHON] Add `array_prepend` function
    
    ### What changes were proposed in this pull request?
    Adds a new array function array_prepend to catalyst.
    
    ### Why are the changes needed?
    This adds a function that exists in many SQL implementations, specifically Snowflake: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.functions.array_prepend.html
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #38947 from navinvishy/array-prepend.
    
    Lead-authored-by: Navin Viswanath <na...@gmail.com>
    Co-authored-by: navinvishy <na...@gmail.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../source/reference/pyspark.sql/functions.rst     |   1 +
 python/pyspark/sql/functions.py                    |  30 +++++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../expressions/collectionOperations.scala         | 146 +++++++++++++++++++++
 .../expressions/CollectionExpressionsSuite.scala   |  44 +++++++
 .../scala/org/apache/spark/sql/functions.scala     |  10 ++
 .../sql-functions/sql-expression-schema.md         |   3 +-
 .../src/test/resources/sql-tests/inputs/array.sql  |  11 ++
 .../resources/sql-tests/results/ansi/array.sql.out |  72 ++++++++++
 .../test/resources/sql-tests/results/array.sql.out |  72 ++++++++++
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  68 ++++++++++
 11 files changed, 457 insertions(+), 1 deletion(-)

diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 70fc04ef9cf..cbc46e1fae1 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -159,6 +159,7 @@ Collection Functions
     array_sort
     array_insert
     array_remove
+    array_prepend
     array_distinct
     array_intersect
     array_union
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 051fd52a13c..1f02be3ad21 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7631,6 +7631,36 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)
 
 
+@try_remote_functions
+def array_prepend(col: "ColumnOrName", value: Any) -> Column:
+    """
+    Collection function: Returns an array containing the given element as
+    well as all elements from the array. The new element is positioned
+    at the beginning of the array.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    value :
+        a literal value, or a :class:`~pyspark.sql.Column` expression.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array with the given value prepended.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
+    >>> df.select(array_prepend(df.data, 1)).collect()
+    [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
+    """
+    return _invoke_function_over_columns("array_prepend", col, lit(value))
+
+
 @try_remote_functions
 def array_remove(col: "ColumnOrName", element: Any) -> Column:
     """
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ad82a836199..aca73741c63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -697,6 +697,7 @@ object FunctionRegistry {
     expression[Sequence]("sequence"),
     expression[ArrayRepeat]("array_repeat"),
     expression[ArrayRemove]("array_remove"),
+    expression[ArrayPrepend]("array_prepend"),
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
     expression[MapFilter]("map_filter"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 289859d420b..2ccb3a6d0cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1399,6 +1399,152 @@ case class ArrayContains(left: Expression, right: Expression)
     copy(left = newLeft, right = newRight)
 }
 
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+      _FUNC_(array, element) - Add the element at the beginning of the array passed as first
+      argument. Type of element should be the same as the type of the elements of the array.
+      A null element is also prepended to the array. But if the array passed is NULL,
+      the output is NULL.
+    """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+       ["d","b","d","c","a"]
+      > SELECT _FUNC_(array(1, 2, 3, null), null);
+       [null,1,2,3,null]
+      > SELECT _FUNC_(CAST(null as Array<Int>), 2);
+       NULL
+  """,
+  group = "array_funcs",
+  since = "3.5.0")
+case class ArrayPrepend(left: Expression, right: Expression)
+  extends BinaryExpression
+    with ImplicitCastInputTypes
+    with ComplexTypeMergingExpression
+    with QueryErrorsBase {
+
+  override def nullable: Boolean = left.nullable
+
+  @transient protected lazy val elementType: DataType =
+    inputTypes.head.asInstanceOf[ArrayType].elementType
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    if (value1 == null) {
+      null
+    } else {
+      val value2 = right.eval(input)
+      nullSafeEval(value1, value2)
+    }
+  }
+  override def nullSafeEval(arr: Any, elementData: Any): Any = {
+    val arrayData = arr.asInstanceOf[ArrayData]
+    val numberOfElements = arrayData.numElements() + 1
+    if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+      throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
+    }
+    val finalData = new Array[Any](numberOfElements)
+    finalData.update(0, elementData)
+    arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v))
+    new GenericArrayData(finalData)
+  }
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val leftGen = left.genCode(ctx)
+    val rightGen = right.genCode(ctx)
+    val f = (arr: String, value: String) => {
+      val newArraySize = s"$arr.numElements() + 1"
+      val newArray = ctx.freshName("newArray")
+      val i = ctx.freshName("i")
+      val iPlus1 = s"$i+1"
+      val zero = "0"
+      val allocation = CodeGenerator.createArrayData(
+        newArray,
+        elementType,
+        newArraySize,
+        s" $prettyName failed.")
+      val assignment =
+        CodeGenerator.createArrayAssignment(newArray, elementType, arr, iPlus1, i, false)
+      val newElemAssignment =
+        CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull))
+      s"""
+         |$allocation
+         |$newElemAssignment
+         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |  $assignment
+         |}
+         |${ev.value} = $newArray;
+         |""".stripMargin
+    }
+    val resultCode = f(leftGen.value, rightGen.value)
+    if(nullable) {
+      val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) {
+        s"""
+           |${ev.isNull} = false;
+           |${resultCode}
+           |""".stripMargin
+      }
+      ev.copy(code =
+        code"""
+          |boolean ${ev.isNull} = true;
+          |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+          |$nullSafeEval
+        """.stripMargin
+      )
+    } else {
+      ev.copy(code =
+        code"""
+          |${leftGen.code}
+          |${rightGen.code}
+          |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+          |$resultCode
+        """.stripMargin, isNull = FalseLiteral)
+    }
+  }
+
+  override def prettyName: String = "array_prepend"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayPrepend =
+    copy(left = newLeft, right = newRight)
+
+  override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess
+      case (ArrayType(e1, _), e2) => DataTypeMismatch(
+        errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "leftType" -> toSQLType(left.dataType),
+          "rightType" -> toSQLType(right.dataType),
+          "dataType" -> toSQLType(ArrayType)
+        ))
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "0",
+            "requiredType" -> toSQLType(ArrayType),
+            "inputSql" -> toSQLExpr(left),
+            "inputType" -> toSQLType(left.dataType)
+          )
+        )
+    }
+  }
+  override def inputTypes: Seq[AbstractDataType] = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
+    }
+  }
+}
+
 /**
  * Checks if the two arrays contain at least one common element.
  */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 60300ba62f2..3abc70a3d55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1855,6 +1855,50 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
   }
 
+  test("SPARK-41233: ArrayPrepend") {
+    val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
+    val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+    val a3 = Literal.create(null, ArrayType(StringType))
+
+    checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4))
+    checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c"))
+    checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1))
+    checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null))
+    checkEvaluation(ArrayPrepend(a3, Literal("a")), null)
+    checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null)
+
+    // complex data types
+    val data = Seq[Array[Byte]](
+      Array[Byte](5, 6),
+      Array[Byte](1, 2),
+      Array[Byte](1, 2),
+      Array[Byte](5, 6))
+    val b0 = Literal.create(
+      data,
+      ArrayType(BinaryType))
+    val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType))
+    val nullBinary = Literal.create(null, BinaryType)
+    // Calling ArrayPrepend with a null element should result in NULL being prepended to the array
+    val dataWithNullPrepended = null +: data
+    checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended)
+    val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType)
+    checkEvaluation(
+      ArrayPrepend(b1, dataToPrepend1),
+      Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null))
+
+    val c0 = Literal.create(
+      Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+      ArrayType(ArrayType(IntegerType)))
+    val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType))
+    checkEvaluation(
+      ArrayPrepend(c0, dataToPrepend2),
+      Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4)))
+    checkEvaluation(
+      ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))),
+      Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4)))
+  }
+
   test("Array remove") {
     val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
     val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index cb5c1ad5c49..d771367f318 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4043,6 +4043,16 @@ object functions {
   def array_compact(column: Column): Column = withExpr {
     ArrayCompact(column.expr)
   }
+    /**
+     * Returns an array containing value as well as all elements from array. The new element is
+     * positioned at the beginning of the array.
+     *
+     * @group collection_funcs
+     * @since 3.5.0
+     */
+  def array_prepend(column: Column, element: Any): Column = withExpr {
+    ArrayPrepend(column.expr, lit(element).expr)
+  }
 
   /**
    * Removes duplicate values from the array.
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 0894d03f9d4..6b5b67f9849 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -26,6 +26,7 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct<array_max(array(1, 20, NULL, 3)):int> |
 | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct<array_min(array(1, 20, NULL, 3)):int> |
 | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct<array_position(array(3, 2, 1), 1):bigint> |
+| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct<array_prepend(array(b, d, c, a), d):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct<array_remove(array(1, 2, 3, NULL, 3), 3):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | array_repeat | SELECT array_repeat('123', 2) | struct<array_repeat(123, 2):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct<array_size(array(b, d, c, a)):int> |
@@ -421,4 +422,4 @@
 | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
-| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
\ No newline at end of file
+| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
index 3d107cb6dfc..d3c36b79d1f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/array.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY<String>), CAST(null as String));
 select array_append(array(), 1);
 select array_append(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
 select array_append(array(CAST(NULL AS String)), CAST(NULL AS String));
+
+-- function array_prepend
+select array_prepend(array(1, 2, 3), 4);
+select array_prepend(array('a', 'b', 'c'), 'd');
+select array_prepend(array(1, 2, 3, NULL), NULL);
+select array_prepend(array('a', 'b', 'c', NULL), NULL);
+select array_prepend(CAST(null AS ARRAY<String>), 'a');
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String));
+select array_prepend(array(), 1);
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
index 0d8ef39ed60..d228c605705 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String))
 struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
 -- !query output
 [null,null]
+
+
+-- !query
+select array_prepend(array(1, 2, 3), 4)
+-- !query schema
+struct<array_prepend(array(1, 2, 3), 4):array<int>>
+-- !query output
+[4,1,2,3]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c'), 'd')
+-- !query schema
+struct<array_prepend(array(a, b, c), d):array<string>>
+-- !query output
+["d","a","b","c"]
+
+
+-- !query
+select array_prepend(array(1, 2, 3, NULL), NULL)
+-- !query schema
+struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>>
+-- !query output
+[null,1,2,3,null]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c', NULL), NULL)
+-- !query schema
+struct<array_prepend(array(a, b, c, NULL), NULL):array<string>>
+-- !query output
+[null,"a","b","c",null]
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), 'a')
+-- !query schema
+struct<array_prepend(NULL, a):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String))
+-- !query schema
+struct<array_prepend(NULL, CAST(NULL AS STRING)):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(array(), 1)
+-- !query schema
+struct<array_prepend(array(), 1):array<int>>
+-- !query output
+[1]
+
+
+-- !query
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null]
+
+
+-- !query
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null,null]
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 609122a23d3..029bd767f54 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -665,3 +665,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String))
 struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
 -- !query output
 [null,null]
+
+
+-- !query
+select array_prepend(array(1, 2, 3), 4)
+-- !query schema
+struct<array_prepend(array(1, 2, 3), 4):array<int>>
+-- !query output
+[4,1,2,3]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c'), 'd')
+-- !query schema
+struct<array_prepend(array(a, b, c), d):array<string>>
+-- !query output
+["d","a","b","c"]
+
+
+-- !query
+select array_prepend(array(1, 2, 3, NULL), NULL)
+-- !query schema
+struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>>
+-- !query output
+[null,1,2,3,null]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c', NULL), NULL)
+-- !query schema
+struct<array_prepend(array(a, b, c, NULL), NULL):array<string>>
+-- !query output
+[null,"a","b","c",null]
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), 'a')
+-- !query schema
+struct<array_prepend(NULL, a):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String))
+-- !query schema
+struct<array_prepend(NULL, CAST(NULL AS STRING)):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(array(), 1)
+-- !query schema
+struct<array_prepend(array(), 1):array<int>>
+-- !query output
+[1]
+
+
+-- !query
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null]
+
+
+-- !query
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null,null]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bd03d292820..355f2dfffb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2651,6 +2651,74 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     )
   }
 
+  test("SPARK-41233: array prepend") {
+    val df = Seq(
+      (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2),
+      (Array.empty[Int], Array.empty[String], Array.empty[String], 2),
+      (null, null, null, 2)).toDF("a", "b", "c", "d")
+    checkAnswer(
+      df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")),
+      Seq(
+        Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")),
+        Row(Seq(1), Seq("a"), Seq("")),
+        Row(null, null, null)))
+    checkAnswer(
+      df.select(array_prepend($"a", $"d")),
+      Seq(
+        Row(Seq(2, 2, 3, 4)),
+        Row(Seq(2)),
+        Row(null)))
+    checkAnswer(
+      df.selectExpr("array_prepend(a, d)"),
+      Seq(
+        Row(Seq(2, 2, 3, 4)),
+        Row(Seq(2)),
+        Row(null)))
+    checkAnswer(
+      OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"),
+      Seq(
+        Row(Seq(1.23, 1.0, 2.0))
+      )
+    )
+    checkAnswer(
+      df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"),
+      Seq(
+        Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")),
+        Row(Seq(1), Seq("a"), Seq("")),
+        Row(null, null, null)))
+    checkError(
+      exception = intercept[AnalysisException] {
+        Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)")
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "paramIndex" -> "0",
+        "sqlExpr" -> "\"array_prepend(_1, _2)\"",
+        "inputSql" -> "\"_1\"",
+        "inputType" -> "\"STRING\"",
+        "requiredType" -> "\"ARRAY\""),
+      queryContext = Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)")))
+    checkError(
+      exception = intercept[AnalysisException] {
+        OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')")
+      },
+      errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"",
+        "functionName" -> "`array_prepend`",
+        "dataType" -> "\"ARRAY\"",
+        "leftType" -> "\"ARRAY<INT>\"",
+        "rightType" -> "\"STRING\""),
+      queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')")))
+    val df2 = Seq((Array[String]("a", "b", "c"), "d"),
+      (null, "d"),
+      (Array[String]("x", "y", "z"), null),
+      (null, null)
+    ).toDF("a", "b")
+    checkAnswer(df2.selectExpr("array_prepend(a, b)"),
+      Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null)))
+  }
+
   test("array remove") {
     val df = Seq(
       (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org