You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ue...@apache.org on 2018/04/18 09:42:25 UTC
spark git commit: [SPARK-23926][SQL] Extending reverse function to support ArrayType arguments
Repository: spark
Updated Branches:
refs/heads/master cce469435 -> f81fa478f
[SPARK-23926][SQL] Extending reverse function to support ArrayType arguments
## What changes were proposed in this pull request?
This PR extends the `reverse` function so that it can operate on array columns as well as strings. It covers:
- Introduction of a `Reverse` expression that implements reversal for both arrays and strings
- Removal of the `StringReverse` expression, which `Reverse` supersedes
- A PySpark wrapper for the extended `reverse` function
## How was this patch tested?
New tests were added to:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite
## Codegen examples
### Primitive element type
```
// Two-row input: one non-empty int array and one null array, so the dump
// below shows both the null check and the reversal loop.
val df = Seq(
Seq(1, 3, 4, 2),
null
).toDF("i")
// The filter is always true and does not change the data. Its only purpose is
// to put a Filter in the plan, so that debugCodegen prints the code generated
// for reverse($"i") inline with it (the test suite uses the same trick to
// "switch codegen on").
df.filter($"i".isNotNull || $"i".isNull).select(reverse($"i")).debugCodegen
```
Result (excerpt of the generated code):
```
/* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 033 */ ArrayData inputadapter_value = inputadapter_isNull ?
/* 034 */ null : (inputadapter_row.getArray(0));
/* 035 */
/* 036 */ boolean filter_value = true;
/* 037 */
/* 038 */ if (!(!inputadapter_isNull)) {
/* 039 */ filter_value = inputadapter_isNull;
/* 040 */ }
/* 041 */ if (!filter_value) continue;
/* 042 */
/* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 044 */
/* 045 */ boolean project_isNull = inputadapter_isNull;
/* 046 */ ArrayData project_value = null;
/* 047 */
/* 048 */ if (!inputadapter_isNull) {
/* 049 */ final int project_length = inputadapter_value.numElements();
/* 050 */ project_value = inputadapter_value.copy();
/* 051 */ for(int k = 0; k < project_length / 2; k++) {
/* 052 */ int l = project_length - k - 1;
/* 053 */ boolean isNullAtK = project_value.isNullAt(k);
/* 054 */ boolean isNullAtL = project_value.isNullAt(l);
/* 055 */ if(!isNullAtK) {
/* 056 */ int el = project_value.getInt(k);
/* 057 */ if(!isNullAtL) {
/* 058 */ project_value.setInt(k, project_value.getInt(l));
/* 059 */ } else {
/* 060 */ project_value.setNullAt(k);
/* 061 */ }
/* 062 */ project_value.setInt(l, el);
/* 063 */ } else if (!isNullAtL) {
/* 064 */ project_value.setInt(k, project_value.getInt(l));
/* 065 */ project_value.setNullAt(l);
/* 066 */ }
/* 067 */ }
/* 068 */
/* 069 */ }
```
### Non-primitive element type
```
// Two-row input: one non-empty string array and one null array. Strings are a
// non-primitive element type, so a different reversal code path is generated
// than in the int example.
val df = Seq(
Seq("a", "c", "d", "b"),
null
).toDF("s")
// The filter is always true and does not change the data. It only puts a
// Filter in the plan, so that debugCodegen prints the code generated for
// reverse($"s") inline with it.
df.filter($"s".isNotNull || $"s".isNull).select(reverse($"s")).debugCodegen
```
Result (excerpt of the generated code):
```
/* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 033 */ ArrayData inputadapter_value = inputadapter_isNull ?
/* 034 */ null : (inputadapter_row.getArray(0));
/* 035 */
/* 036 */ boolean filter_value = true;
/* 037 */
/* 038 */ if (!(!inputadapter_isNull)) {
/* 039 */ filter_value = inputadapter_isNull;
/* 040 */ }
/* 041 */ if (!filter_value) continue;
/* 042 */
/* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 044 */
/* 045 */ boolean project_isNull = inputadapter_isNull;
/* 046 */ ArrayData project_value = null;
/* 047 */
/* 048 */ if (!inputadapter_isNull) {
/* 049 */ final int project_length = inputadapter_value.numElements();
/* 050 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(new Object[project_length]);
/* 051 */ for(int k = 0; k < project_length; k++) {
/* 052 */ int l = project_length - k - 1;
/* 053 */ project_value.update(k, inputadapter_value.getUTF8String(l));
/* 054 */ }
/* 055 */
/* 056 */ }
```
Author: mn-mikke <mrkAha12346github>
Closes #21034 from mn-mikke/feature/array-api-reverse-to-master.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f81fa478
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f81fa478
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f81fa478
Branch: refs/heads/master
Commit: f81fa478ff990146e2a8e463ac252271448d96f5
Parents: cce4694
Author: mn-mikke <mrkAha12346github>
Authored: Wed Apr 18 18:41:55 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Wed Apr 18 18:41:55 2018 +0900
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 20 ++++-
.../catalyst/analysis/FunctionRegistry.scala | 2 +-
.../expressions/collectionOperations.scala | 88 ++++++++++++++++++
.../expressions/stringExpressions.scala | 20 -----
.../CollectionExpressionsSuite.scala | 44 +++++++++
.../expressions/StringExpressionsSuite.scala | 6 +-
.../scala/org/apache/spark/sql/functions.scala | 15 ++--
.../spark/sql/DataFrameFunctionsSuite.scala | 94 ++++++++++++++++++++
8 files changed, 256 insertions(+), 33 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f81fa478/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6ca22b6..d3bb0a5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1414,7 +1414,6 @@ _string_functions = {
'uppercase. Words are delimited by whitespace.',
'lower': 'Converts a string column to lower case.',
'upper': 'Converts a string column to upper case.',
- 'reverse': 'Reverses the string column and returns it as a new string column.',
'ltrim': 'Trim the spaces from left end for the specified string value.',
'rtrim': 'Trim the spaces from right end for the specified string value.',
'trim': 'Trim the spaces from both ends for the specified string column.',
@@ -2128,6 +2127,25 @@ def sort_array(col, asc=True):
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
+@since(1.5)
+@ignore_unicode_prefix
+def reverse(col):
+ """
+ Collection function: returns a reversed string or an array with reverse order of elements.
+
+ :param col: name of column or expression
+
+ >>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
+ >>> df.select(reverse(df.data).alias('s')).collect()
+ [Row(s=u'LQS krapS')]
+ >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
+ >>> df.select(reverse(df.data).alias('r')).collect()
+ [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.reverse(_to_java_column(col)))
+
+
@since(2.3)
def map_keys(col):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/f81fa478/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 4dd1ca5..38c874a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -336,7 +336,6 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
- expression[StringReverse]("reverse"),
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
@@ -411,6 +410,7 @@ object FunctionRegistry {
expression[SortArray]("sort_array"),
expression[ArrayMin]("array_min"),
expression[ArrayMax]("array_max"),
+ expression[Reverse]("reverse"),
CreateStruct.registryEntry,
// misc functions
http://git-wip-us.apache.org/repos/asf/spark/blob/f81fa478/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 7c87777..76b71f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* Given an array or map, returns its size. Returns -1 if null.
@@ -213,6 +214,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
}
/**
+ * Returns a reversed string or an array with reverse order of elements.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('Spark SQL');
+ LQS krapS
+ > SELECT _FUNC_(array(2, 1, 4, 3));
+ [3, 4, 1, 2]
+ """,
+ since = "1.5.0",
+ note = "Reverse logic for arrays is available since 2.4.0."
+)
+case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+ // Input types are utilized by type coercion in ImplicitTypeCasts.
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
+
+ override def dataType: DataType = child.dataType
+
+ lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+
+ override def nullSafeEval(input: Any): Any = input match {
+ case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
+ case s: UTF8String => s.reverse()
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => dataType match {
+ case _: StringType => stringCodeGen(ev, c)
+ case _: ArrayType => arrayCodeGen(ctx, ev, c)
+ })
+ }
+
+ private def stringCodeGen(ev: ExprCode, childName: String): String = {
+ s"${ev.value} = ($childName).reverse();"
+ }
+
+ private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
+ val length = ctx.freshName("length")
+ val javaElementType = CodeGenerator.javaType(elementType)
+ val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
+
+ val initialization = if (isPrimitiveType) {
+ s"$childName.copy()"
+ } else {
+ s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
+ }
+
+ val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length
+
+ val swapAssigments = if (isPrimitiveType) {
+ val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
+ val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
+ s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
+ |boolean isNullAtL = ${ev.value}.isNullAt(l);
+ |if(!isNullAtK) {
+ | $javaElementType el = ${getCall("k")};
+ | if(!isNullAtL) {
+ | ${ev.value}.$setFunc(k, ${getCall("l")});
+ | } else {
+ | ${ev.value}.setNullAt(k);
+ | }
+ | ${ev.value}.$setFunc(l, el);
+ |} else if (!isNullAtL) {
+ | ${ev.value}.$setFunc(k, ${getCall("l")});
+ | ${ev.value}.setNullAt(l);
+ |}""".stripMargin
+ } else {
+ s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
+ }
+
+ s"""
+ |final int $length = $childName.numElements();
+ |${ev.value} = $initialization;
+ |for(int k = 0; k < $numberOfIterations; k++) {
+ | int l = $length - k - 1;
+ | $swapAssigments
+ |}
+ """.stripMargin
+ }
+
+ override def prettyName: String = "reverse"
+}
+
+/**
* Checks if the array (left) has the element (right)
*/
@ExpressionDescription(
http://git-wip-us.apache.org/repos/asf/spark/blob/f81fa478/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 22fbb89..5a02ca0 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1505,26 +1505,6 @@ case class StringRepeat(str: Expression, times: Expression)
}
/**
- * Returns the reversed given string.
- */
-@ExpressionDescription(
- usage = "_FUNC_(str) - Returns the reversed given string.",
- examples = """
- Examples:
- > SELECT _FUNC_('Spark SQL');
- LQS krapS
- """)
-case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
- override def convert(v: UTF8String): UTF8String = v.reverse()
-
- override def prettyName: String = "reverse"
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- defineCodeGen(ctx, ev, c => s"($c).reverse()")
- }
-}
-
-/**
* Returns a string consisting of n spaces.
*/
@ExpressionDescription(
http://git-wip-us.apache.org/repos/asf/spark/blob/f81fa478/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 5a31e3a..517639d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -125,4 +125,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
}
+
+ test("Reverse") {
+ // Primitive-type elements
+ val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType))
+ val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+ val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType))
+ val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType))
+ val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType))
+ val ai5 = Literal.create(Seq(1), ArrayType(IntegerType))
+ val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType))
+ val ai7 = Literal.create(null, ArrayType(IntegerType))
+
+ checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2))
+ checkEvaluation(Reverse(ai1), Seq(3, 1, 2))
+ checkEvaluation(Reverse(ai2), Seq(3, null, 1, null))
+ checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2))
+ checkEvaluation(Reverse(ai4), Seq(null, null, null))
+ checkEvaluation(Reverse(ai5), Seq(1))
+ checkEvaluation(Reverse(ai6), Seq.empty)
+ checkEvaluation(Reverse(ai7), null)
+
+ // Non-primitive-type elements
+ val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType))
+ val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType))
+ val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType))
+ val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType))
+ val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType))
+ val as5 = Literal.create(Seq("a"), ArrayType(StringType))
+ val as6 = Literal.create(Seq.empty, ArrayType(StringType))
+ val as7 = Literal.create(null, ArrayType(StringType))
+ val aa = Literal.create(
+ Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")),
+ ArrayType(ArrayType(StringType)))
+
+ checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b"))
+ checkEvaluation(Reverse(as1), Seq("c", "a", "b"))
+ checkEvaluation(Reverse(as2), Seq("c", null, "a", null))
+ checkEvaluation(Reverse(as3), Seq(null, "d", null, "b"))
+ checkEvaluation(Reverse(as4), Seq(null, null, null))
+ checkEvaluation(Reverse(as5), Seq("a"))
+ checkEvaluation(Reverse(as6), Seq.empty)
+ checkEvaluation(Reverse(as7), null)
+ checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f81fa478/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 9a1a4da..f1a6f9b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("REVERSE") {
val s = 'a.string.at(0)
val row1 = create_row("abccc")
- checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
- checkEvaluation(StringReverse(s), "cccba", row1)
- checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
+ checkEvaluation(Reverse(Literal("abccc")), "cccba", row1)
+ checkEvaluation(Reverse(s), "cccba", row1)
+ checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1)
}
test("SPACE") {
http://git-wip-us.apache.org/repos/asf/spark/blob/f81fa478/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 642ac05..a55a800 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2465,14 +2465,6 @@ object functions {
}
/**
- * Reverses the string column and returns it as a new string column.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def reverse(str: Column): Column = withExpr { StringReverse(str.expr) }
-
- /**
* Trim the spaces from right end for the specified string value.
*
* @group string_funcs
@@ -3317,6 +3309,13 @@ object functions {
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }
/**
+ * Returns a reversed string or an array with reverse order of elements.
+ * @group collection_funcs
+ * @since 1.5.0
+ */
+ def reverse(e: Column): Column = withExpr { Reverse(e.expr) }
+
+ /**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
* @since 2.3.0
http://git-wip-us.apache.org/repos/asf/spark/blob/f81fa478/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 636e86b..74c42f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -441,6 +441,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.selectExpr("array_max(a)"), answer)
}
+ test("reverse function") {
+ val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on
+
+ // String test cases
+ val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")
+
+ checkAnswer(
+ oneRowDF.select(reverse('s)),
+ Seq(Row("krapS"))
+ )
+ checkAnswer(
+ oneRowDF.selectExpr("reverse(s)"),
+ Seq(Row("krapS"))
+ )
+ checkAnswer(
+ oneRowDF.select(reverse('i)),
+ Seq(Row("5123"))
+ )
+ checkAnswer(
+ oneRowDF.selectExpr("reverse(i)"),
+ Seq(Row("5123"))
+ )
+ checkAnswer(
+ oneRowDF.selectExpr("reverse(null)"),
+ Seq(Row(null))
+ )
+
+ // Array test cases (primitive-type elements)
+ val idf = Seq(
+ Seq(1, 9, 8, 7),
+ Seq(5, 8, 9, 7, 2),
+ Seq.empty,
+ null
+ ).toDF("i")
+
+ checkAnswer(
+ idf.select(reverse('i)),
+ Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
+ )
+ checkAnswer(
+ idf.filter(dummyFilter('i)).select(reverse('i)),
+ Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
+ )
+ checkAnswer(
+ idf.selectExpr("reverse(i)"),
+ Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
+ )
+ checkAnswer(
+ oneRowDF.selectExpr("reverse(array(1, null, 2, null))"),
+ Seq(Row(Seq(null, 2, null, 1)))
+ )
+ checkAnswer(
+ oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"),
+ Seq(Row(Seq(null, 2, null, 1)))
+ )
+
+ // Array test cases (non-primitive-type elements)
+ val sdf = Seq(
+ Seq("c", "a", "b"),
+ Seq("b", null, "c", null),
+ Seq.empty,
+ null
+ ).toDF("s")
+
+ checkAnswer(
+ sdf.select(reverse('s)),
+ Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
+ )
+ checkAnswer(
+ sdf.filter(dummyFilter('s)).select(reverse('s)),
+ Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
+ )
+ checkAnswer(
+ sdf.selectExpr("reverse(s)"),
+ Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
+ )
+ checkAnswer(
+ oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
+ Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
+ )
+ checkAnswer(
+ oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
+ Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
+ )
+
+ // Error test cases
+ intercept[AnalysisException] {
+ oneRowDF.selectExpr("reverse(struct(1, 'a'))")
+ }
+ intercept[AnalysisException] {
+ oneRowDF.selectExpr("reverse(map(1, 'a'))")
+ }
+ }
+
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org