You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/16 09:51:56 UTC
[spark] branch master updated: [SPARK-41233][SQL][PYTHON] Add `array_prepend` function
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3dd629629ab [SPARK-41233][SQL][PYTHON] Add `array_prepend` function
3dd629629ab is described below
commit 3dd629629ab151688b82a3aa66e1b5fa568afbfa
Author: Navin Viswanath <na...@gmail.com>
AuthorDate: Thu Mar 16 17:51:33 2023 +0800
[SPARK-41233][SQL][PYTHON] Add `array_prepend` function
### What changes were proposed in this pull request?
Adds a new array function array_prepend to catalyst.
### Why are the changes needed?
This adds a function that exists in many SQL implementations, specifically Snowflake: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.functions.array_prepend.html
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
Added unit tests.
Closes #38947 from navinvishy/array-prepend.
Lead-authored-by: Navin Viswanath <na...@gmail.com>
Co-authored-by: navinvishy <na...@gmail.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../source/reference/pyspark.sql/functions.rst | 1 +
python/pyspark/sql/functions.py | 30 +++++
.../sql/catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/collectionOperations.scala | 146 +++++++++++++++++++++
.../expressions/CollectionExpressionsSuite.scala | 44 +++++++
.../scala/org/apache/spark/sql/functions.scala | 10 ++
.../sql-functions/sql-expression-schema.md | 3 +-
.../src/test/resources/sql-tests/inputs/array.sql | 11 ++
.../resources/sql-tests/results/ansi/array.sql.out | 72 ++++++++++
.../test/resources/sql-tests/results/array.sql.out | 72 ++++++++++
.../apache/spark/sql/DataFrameFunctionsSuite.scala | 68 ++++++++++
11 files changed, 457 insertions(+), 1 deletion(-)
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 70fc04ef9cf..cbc46e1fae1 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -159,6 +159,7 @@ Collection Functions
array_sort
array_insert
array_remove
+ array_prepend
array_distinct
array_intersect
array_union
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 051fd52a13c..1f02be3ad21 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7631,6 +7631,36 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
return _invoke_function_over_columns("get", col, index)
+@try_remote_functions
+def array_prepend(col: "ColumnOrName", value: Any) -> Column:
+ """
+ Collection function: Returns an array containing element as
+ well as all elements from array. The new element is positioned
+ at the beginning of the array.
+
+ .. versionadded:: 3.5.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column containing array
+ value :
+ a literal value, or a :class:`~pyspark.sql.Column` expression.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ an array with the given value prepended at the beginning.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
+ >>> df.select(array_prepend(df.data, 1)).collect()
+ [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
+ """
+ return _invoke_function_over_columns("array_prepend", col, lit(value))
+
+
@try_remote_functions
def array_remove(col: "ColumnOrName", element: Any) -> Column:
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ad82a836199..aca73741c63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -697,6 +697,7 @@ object FunctionRegistry {
expression[Sequence]("sequence"),
expression[ArrayRepeat]("array_repeat"),
expression[ArrayRemove]("array_remove"),
+ expression[ArrayPrepend]("array_prepend"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 289859d420b..2ccb3a6d0cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1399,6 +1399,152 @@ case class ArrayContains(left: Expression, right: Expression)
copy(left = newLeft, right = newRight)
}
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array, element) - Add the element at the beginning of the array passed as the first
+ argument. The type of the element should be the same as the type of the elements of the array.
+ A null element is also prepended to the array, but if the array passed is NULL,
+ the output is NULL.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+ ["d","b","d","c","a"]
+ > SELECT _FUNC_(array(1, 2, 3, null), null);
+ [null,1,2,3,null]
+ > SELECT _FUNC_(CAST(null as Array<Int>), 2);
+ NULL
+ """,
+ group = "array_funcs",
+ since = "3.5.0")
+case class ArrayPrepend(left: Expression, right: Expression)
+ extends BinaryExpression
+ with ImplicitCastInputTypes
+ with ComplexTypeMergingExpression
+ with QueryErrorsBase {
+
+ override def nullable: Boolean = left.nullable
+
+ @transient protected lazy val elementType: DataType =
+ inputTypes.head.asInstanceOf[ArrayType].elementType
+
+ override def eval(input: InternalRow): Any = {
+ val value1 = left.eval(input)
+ if (value1 == null) {
+ null
+ } else {
+ val value2 = right.eval(input)
+ nullSafeEval(value1, value2)
+ }
+ }
+ override def nullSafeEval(arr: Any, elementData: Any): Any = {
+ val arrayData = arr.asInstanceOf[ArrayData]
+ val numberOfElements = arrayData.numElements() + 1
+ if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
+ }
+ val finalData = new Array[Any](numberOfElements)
+ finalData.update(0, elementData)
+ arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v))
+ new GenericArrayData(finalData)
+ }
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val leftGen = left.genCode(ctx)
+ val rightGen = right.genCode(ctx)
+ val f = (arr: String, value: String) => {
+ val newArraySize = s"$arr.numElements() + 1"
+ val newArray = ctx.freshName("newArray")
+ val i = ctx.freshName("i")
+ val iPlus1 = s"$i+1"
+ val zero = "0"
+ val allocation = CodeGenerator.createArrayData(
+ newArray,
+ elementType,
+ newArraySize,
+ s" $prettyName failed.")
+ val assignment =
+ CodeGenerator.createArrayAssignment(newArray, elementType, arr, iPlus1, i, false)
+ val newElemAssignment =
+ CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull))
+ s"""
+ |$allocation
+ |$newElemAssignment
+ |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+ | $assignment
+ |}
+ |${ev.value} = $newArray;
+ |""".stripMargin
+ }
+ val resultCode = f(leftGen.value, rightGen.value)
+ if(nullable) {
+ val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) {
+ s"""
+ |${ev.isNull} = false;
+ |${resultCode}
+ |""".stripMargin
+ }
+ ev.copy(code =
+ code"""
+ |boolean ${ev.isNull} = true;
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ |$nullSafeEval
+ """.stripMargin
+ )
+ } else {
+ ev.copy(code =
+ code"""
+ |${leftGen.code}
+ |${rightGen.code}
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ |$resultCode
+ """.stripMargin, isNull = FalseLiteral)
+ }
+ }
+
+ override def prettyName: String = "array_prepend"
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): ArrayPrepend =
+ copy(left = newLeft, right = newRight)
+
+ override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess
+ case (ArrayType(e1, _), e2) => DataTypeMismatch(
+ errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+ messageParameters = Map(
+ "functionName" -> toSQLId(prettyName),
+ "leftType" -> toSQLType(left.dataType),
+ "rightType" -> toSQLType(right.dataType),
+ "dataType" -> toSQLType(ArrayType)
+ ))
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> "0",
+ "requiredType" -> toSQLType(ArrayType),
+ "inputSql" -> toSQLExpr(left),
+ "inputType" -> toSQLType(left.dataType)
+ )
+ )
+ }
+ }
+ override def inputTypes: Seq[AbstractDataType] = {
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, hasNull), e2) =>
+ TypeCoercion.findTightestCommonType(e1, e2) match {
+ case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+ case _ => Seq.empty
+ }
+ case _ => Seq.empty
+ }
+ }
+}
+
/**
* Checks if the two arrays contain at least one common element.
*/
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 60300ba62f2..3abc70a3d55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1855,6 +1855,50 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
}
+ test("SPARK-41233: ArrayPrepend") {
+ val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
+ val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+ val a3 = Literal.create(null, ArrayType(StringType))
+
+ checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4))
+ checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c"))
+ checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1))
+ checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null))
+ checkEvaluation(ArrayPrepend(a3, Literal("a")), null)
+ checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null)
+
+ // complex data types
+ val data = Seq[Array[Byte]](
+ Array[Byte](5, 6),
+ Array[Byte](1, 2),
+ Array[Byte](1, 2),
+ Array[Byte](5, 6))
+ val b0 = Literal.create(
+ data,
+ ArrayType(BinaryType))
+ val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType))
+ val nullBinary = Literal.create(null, BinaryType)
+ // Calling ArrayPrepend with a null element should result in NULL being prepended to the array
+ val dataWithNullPrepended = null +: data
+ checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended)
+ val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType)
+ checkEvaluation(
+ ArrayPrepend(b1, dataToPrepend1),
+ Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null))
+
+ val c0 = Literal.create(
+ Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+ ArrayType(ArrayType(IntegerType)))
+ val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType))
+ checkEvaluation(
+ ArrayPrepend(c0, dataToPrepend2),
+ Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4)))
+ checkEvaluation(
+ ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))),
+ Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4)))
+ }
+
test("Array remove") {
val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index cb5c1ad5c49..d771367f318 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4043,6 +4043,16 @@ object functions {
def array_compact(column: Column): Column = withExpr {
ArrayCompact(column.expr)
}
+ /**
+ * Returns an array containing value as well as all elements from array. The new element is
+ * positioned at the beginning of the array.
+ *
+ * @group collection_funcs
+ * @since 3.5.0
+ */
+ def array_prepend(column: Column, element: Any): Column = withExpr {
+ ArrayPrepend(column.expr, lit(element).expr)
+ }
/**
* Removes duplicate values from the array.
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 0894d03f9d4..6b5b67f9849 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -26,6 +26,7 @@
| org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct<array_max(array(1, 20, NULL, 3)):int> |
| org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct<array_min(array(1, 20, NULL, 3)):int> |
| org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct<array_position(array(3, 2, 1), 1):bigint> |
+| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct<array_prepend(array(b, d, c, a), d):array<string>> |
| org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct<array_remove(array(1, 2, 3, NULL, 3), 3):array<int>> |
| org.apache.spark.sql.catalyst.expressions.ArrayRepeat | array_repeat | SELECT array_repeat('123', 2) | struct<array_repeat(123, 2):array<string>> |
| org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct<array_size(array(b, d, c, a)):int> |
@@ -421,4 +422,4 @@
| org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
-| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
\ No newline at end of file
+| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
index 3d107cb6dfc..d3c36b79d1f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/array.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY<String>), CAST(null as String));
select array_append(array(), 1);
select array_append(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
select array_append(array(CAST(NULL AS String)), CAST(NULL AS String));
+
+-- function array_prepend
+select array_prepend(array(1, 2, 3), 4);
+select array_prepend(array('a', 'b', 'c'), 'd');
+select array_prepend(array(1, 2, 3, NULL), NULL);
+select array_prepend(array('a', 'b', 'c', NULL), NULL);
+select array_prepend(CAST(null AS ARRAY<String>), 'a');
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String));
+select array_prepend(array(), 1);
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
index 0d8ef39ed60..d228c605705 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String))
struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
-- !query output
[null,null]
+
+
+-- !query
+select array_prepend(array(1, 2, 3), 4)
+-- !query schema
+struct<array_prepend(array(1, 2, 3), 4):array<int>>
+-- !query output
+[4,1,2,3]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c'), 'd')
+-- !query schema
+struct<array_prepend(array(a, b, c), d):array<string>>
+-- !query output
+["d","a","b","c"]
+
+
+-- !query
+select array_prepend(array(1, 2, 3, NULL), NULL)
+-- !query schema
+struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>>
+-- !query output
+[null,1,2,3,null]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c', NULL), NULL)
+-- !query schema
+struct<array_prepend(array(a, b, c, NULL), NULL):array<string>>
+-- !query output
+[null,"a","b","c",null]
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), 'a')
+-- !query schema
+struct<array_prepend(NULL, a):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String))
+-- !query schema
+struct<array_prepend(NULL, CAST(NULL AS STRING)):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(array(), 1)
+-- !query schema
+struct<array_prepend(array(), 1):array<int>>
+-- !query output
+[1]
+
+
+-- !query
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null]
+
+
+-- !query
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null,null]
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 609122a23d3..029bd767f54 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -665,3 +665,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String))
struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
-- !query output
[null,null]
+
+
+-- !query
+select array_prepend(array(1, 2, 3), 4)
+-- !query schema
+struct<array_prepend(array(1, 2, 3), 4):array<int>>
+-- !query output
+[4,1,2,3]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c'), 'd')
+-- !query schema
+struct<array_prepend(array(a, b, c), d):array<string>>
+-- !query output
+["d","a","b","c"]
+
+
+-- !query
+select array_prepend(array(1, 2, 3, NULL), NULL)
+-- !query schema
+struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>>
+-- !query output
+[null,1,2,3,null]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c', NULL), NULL)
+-- !query schema
+struct<array_prepend(array(a, b, c, NULL), NULL):array<string>>
+-- !query output
+[null,"a","b","c",null]
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), 'a')
+-- !query schema
+struct<array_prepend(NULL, a):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String))
+-- !query schema
+struct<array_prepend(NULL, CAST(NULL AS STRING)):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(array(), 1)
+-- !query schema
+struct<array_prepend(array(), 1):array<int>>
+-- !query output
+[1]
+
+
+-- !query
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null]
+
+
+-- !query
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null,null]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bd03d292820..355f2dfffb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2651,6 +2651,74 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
)
}
+ test("SPARK-41233: array prepend") {
+ val df = Seq(
+ (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2),
+ (Array.empty[Int], Array.empty[String], Array.empty[String], 2),
+ (null, null, null, 2)).toDF("a", "b", "c", "d")
+ checkAnswer(
+ df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")),
+ Seq(
+ Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")),
+ Row(Seq(1), Seq("a"), Seq("")),
+ Row(null, null, null)))
+ checkAnswer(
+ df.select(array_prepend($"a", $"d")),
+ Seq(
+ Row(Seq(2, 2, 3, 4)),
+ Row(Seq(2)),
+ Row(null)))
+ checkAnswer(
+ df.selectExpr("array_prepend(a, d)"),
+ Seq(
+ Row(Seq(2, 2, 3, 4)),
+ Row(Seq(2)),
+ Row(null)))
+ checkAnswer(
+ OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"),
+ Seq(
+ Row(Seq(1.23, 1.0, 2.0))
+ )
+ )
+ checkAnswer(
+ df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"),
+ Seq(
+ Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")),
+ Row(Seq(1), Seq("a"), Seq("")),
+ Row(null, null, null)))
+ checkError(
+ exception = intercept[AnalysisException] {
+ Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)")
+ },
+ errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ parameters = Map(
+ "paramIndex" -> "0",
+ "sqlExpr" -> "\"array_prepend(_1, _2)\"",
+ "inputSql" -> "\"_1\"",
+ "inputType" -> "\"STRING\"",
+ "requiredType" -> "\"ARRAY\""),
+ queryContext = Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)")))
+ checkError(
+ exception = intercept[AnalysisException] {
+ OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')")
+ },
+ errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
+ parameters = Map(
+ "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"",
+ "functionName" -> "`array_prepend`",
+ "dataType" -> "\"ARRAY\"",
+ "leftType" -> "\"ARRAY<INT>\"",
+ "rightType" -> "\"STRING\""),
+ queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')")))
+ val df2 = Seq((Array[String]("a", "b", "c"), "d"),
+ (null, "d"),
+ (Array[String]("x", "y", "z"), null),
+ (null, null)
+ ).toDF("a", "b")
+ checkAnswer(df2.selectExpr("array_prepend(a, b)"),
+ Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null)))
+ }
+
test("array remove") {
val df = Seq(
(Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2),
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org