Posted to commits@spark.apache.org by ue...@apache.org on 2018/05/21 14:14:13 UTC
spark git commit: [SPARK-23935][SQL] Adding map_entries function
Repository: spark
Updated Branches:
refs/heads/master e480eccd9 -> a6e883feb
[SPARK-23935][SQL] Adding map_entries function
## What changes were proposed in this pull request?
This PR adds the `map_entries` function, which returns an unordered array of all entries in the given map.
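A quick usage example from the new Scala API (expected output taken from the docstring examples added in this patch; each map entry becomes a struct with fields `key` and `value`):
```
import org.apache.spark.sql.functions._
val df = spark.sql("SELECT map(1, 'a', 2, 'b') AS data")
df.select(map_entries(col("data")).alias("entries")).show()
// +----------------+
// |         entries|
// +----------------+
// |[[1, a], [2, b]]|
// +----------------+
```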
## How was this patch tested?
New tests added into:
- `CollectionExpressionsSuite`
- `DataFrameFunctionsSuite`
## CodeGen examples
### Primitive types
```
val df = Seq(Map(1 -> 5, 2 -> 6)).toDF("m")
df.filter('m.isNotNull).select(map_entries('m)).debugCodegen
```
Result:
```
/* 042 */ boolean project_isNull_0 = false;
/* 043 */
/* 044 */ ArrayData project_value_0 = null;
/* 045 */
/* 046 */ final int project_numElements_0 = inputadapter_value_0.numElements();
/* 047 */ final ArrayData project_keys_0 = inputadapter_value_0.keyArray();
/* 048 */ final ArrayData project_values_0 = inputadapter_value_0.valueArray();
/* 049 */
/* 050 */ final long project_size_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 051 */ project_numElements_0,
/* 052 */ 32);
/* 053 */ if (project_size_0 > 2147483632) {
/* 054 */ final Object[] project_internalRowArray_0 = new Object[project_numElements_0];
/* 055 */ for (int z = 0; z < project_numElements_0; z++) {
/* 056 */ project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getInt(z), project_values_0.getInt(z)});
/* 057 */ }
/* 058 */ project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0);
/* 059 */
/* 060 */ } else {
/* 061 */ final byte[] project_arrayBytes_0 = new byte[(int)project_size_0];
/* 062 */ UnsafeArrayData project_unsafeArrayData_0 = new UnsafeArrayData();
/* 063 */ Platform.putLong(project_arrayBytes_0, 16, project_numElements_0);
/* 064 */ project_unsafeArrayData_0.pointTo(project_arrayBytes_0, 16, (int)project_size_0);
/* 065 */
/* 066 */ final int project_structsOffset_0 = UnsafeArrayData.calculateHeaderPortionInBytes(project_numElements_0) + project_numElements_0 * 8;
/* 067 */ UnsafeRow project_unsafeRow_0 = new UnsafeRow(2);
/* 068 */ for (int z = 0; z < project_numElements_0; z++) {
/* 069 */ long offset = project_structsOffset_0 + z * 24L;
/* 070 */ project_unsafeArrayData_0.setLong(z, (offset << 32) + 24L);
/* 071 */ project_unsafeRow_0.pointTo(project_arrayBytes_0, 16 + offset, 24);
/* 072 */ project_unsafeRow_0.setInt(0, project_keys_0.getInt(z));
/* 073 */ project_unsafeRow_0.setInt(1, project_values_0.getInt(z));
/* 074 */ }
/* 075 */ project_value_0 = project_unsafeArrayData_0;
/* 076 */
/* 077 */ }
```
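In the primitive-type path above, each 8-byte element slot of the `UnsafeArrayData` holds the struct's offset and length packed into a single long (offset in the upper 32 bits, length in the lower 32); that is what the `setLong` call at `/* 070 */` computes. A minimal sketch of the packing, where `packEntry` is a hypothetical helper for illustration only:
```
// An int/int entry struct occupies 24 bytes:
// an 8-byte null bitset plus two 8-byte words (key, value).
def packEntry(structsOffset: Long, z: Int): Long = {
  val structSize = 24L
  val offset = structsOffset + z * structSize  // where struct z starts in the byte array
  (offset << 32) + structSize                  // upper 32 bits: offset, lower 32: length
}
```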
### Non-primitive types
```
val df = Seq(Map("a" -> "foo", "b" -> null)).toDF("m")
df.filter('m.isNotNull).select(map_entries('m)).debugCodegen
```
Result:
```
/* 042 */ boolean project_isNull_0 = false;
/* 043 */
/* 044 */ ArrayData project_value_0 = null;
/* 045 */
/* 046 */ final int project_numElements_0 = inputadapter_value_0.numElements();
/* 047 */ final ArrayData project_keys_0 = inputadapter_value_0.keyArray();
/* 048 */ final ArrayData project_values_0 = inputadapter_value_0.valueArray();
/* 049 */
/* 050 */ final Object[] project_internalRowArray_0 = new Object[project_numElements_0];
/* 051 */ for (int z = 0; z < project_numElements_0; z++) {
/* 052 */ project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getUTF8String(z), project_values_0.getUTF8String(z)});
/* 053 */ }
/* 054 */ project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0);
```
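For non-primitive keys/values (and as the fallback when the unsafe array would exceed the size limit), the generated code builds one `GenericInternalRow` per entry, mirroring the interpreted `nullSafeEval` path in the patch. A simplified sketch, assuming `numElements`, `keys`, `values`, `keyType` and `valueType` are bound as in `nullSafeEval`:
```
// One GenericInternalRow per map entry; null values pass through as-is.
val entries = new Array[AnyRef](numElements)
var i = 0
while (i < numElements) {
  entries(i) = new GenericInternalRow(
    Array[Any](keys.get(i, keyType), values.get(i, valueType)))
  i += 1
}
new GenericArrayData(entries)
```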
Author: Marek Novotny <mn...@gmail.com>
Closes #21236 from mn-mikke/feature/array-api-map_entries-to-master.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6e883fe
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6e883fe
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6e883fe
Branch: refs/heads/master
Commit: a6e883feb3b78232ad5cf636f7f7d5e825183041
Parents: e480ecc
Author: Marek Novotny <mn...@gmail.com>
Authored: Mon May 21 23:14:03 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Mon May 21 23:14:03 2018 +0900
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 20 +++
.../sql/catalyst/expressions/UnsafeRow.java | 2 +
.../catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/codegen/CodeGenerator.scala | 34 +++++
.../expressions/collectionOperations.scala | 153 +++++++++++++++++++
.../CollectionExpressionsSuite.scala | 23 +++
.../expressions/ExpressionEvalHelper.scala | 3 +
.../scala/org/apache/spark/sql/functions.scala | 7 +
.../spark/sql/DataFrameFunctionsSuite.scala | 44 ++++++
9 files changed, 287 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8490081..fbc8a2d 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2344,6 +2344,26 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))
+@since(2.4)
+def map_entries(col):
+ """
+ Collection function: Returns an unordered array of all entries in the given map.
+
+ :param col: name of column or expression
+
+ >>> from pyspark.sql.functions import map_entries
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+ >>> df.select(map_entries("data").alias("entries")).show()
+ +----------------+
+ |         entries|
+ +----------------+
+ |[[1, a], [2, b]]|
+ +----------------+
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
+
+
@ignore_unicode_prefix
@since(2.4)
def array_repeat(col, count):
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 29a1411..469b0e6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -62,6 +62,8 @@ import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
*/
public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {
+ public static final int WORD_SIZE = 8;
+
//////////////////////////////////////////////////////////////////////////////
// Static methods
//////////////////////////////////////////////////////////////////////////////
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 867c2d5..1134a88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -419,6 +419,7 @@ object FunctionRegistry {
expression[ElementAt]("element_at"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
+ expression[MapEntries]("map_entries"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4dda525..d382d9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -765,6 +765,40 @@ class CodegenContext {
}
/**
+ * Generates code creating an [[UnsafeArrayData]]. The generated code executes
+ * a provided fallback when the size of the backing array would exceed the array size limit.
+ * @param arrayName the name of the array to create
+ * @param numElements a piece of code representing the number of elements the array should contain
+ * @param elementSize the size of an element in bytes
+ * @param bodyCode a function that generates code filling the [[UnsafeArrayData]],
+ * taking the name of the backing byte array as a parameter
+ * @param fallbackCode a piece of code executed when the array size limit is exceeded
+ */
+ def createUnsafeArrayWithFallback(
+ arrayName: String,
+ numElements: String,
+ elementSize: Int,
+ bodyCode: String => String,
+ fallbackCode: String): String = {
+ val arraySize = freshName("size")
+ val arrayBytes = freshName("arrayBytes")
+ s"""
+ |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+ | $numElements,
+ | $elementSize);
+ |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | $fallbackCode
+ |} else {
+ | final byte[] $arrayBytes = new byte[(int)$arraySize];
+ | UnsafeArrayData $arrayName = new UnsafeArrayData();
+ | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
+ | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
+ | ${bodyCode(arrayBytes)}
+ |}
+ """.stripMargin
+ }
+
+ /**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
*
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c82db83..8d763dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -155,6 +156,158 @@ case class MapValues(child: Expression)
}
/**
+ * Returns an unordered array of all entries in the given map.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(map(1, 'a', 2, 'b'));
+ [(1,"a"),(2,"b")]
+ """,
+ since = "2.4.0")
+case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+ lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
+
+ override def dataType: DataType = {
+ ArrayType(
+ StructType(
+ StructField("key", childDataType.keyType, false) ::
+ StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
+ Nil),
+ false)
+ }
+
+ override protected def nullSafeEval(input: Any): Any = {
+ val childMap = input.asInstanceOf[MapData]
+ val keys = childMap.keyArray()
+ val values = childMap.valueArray()
+ val length = childMap.numElements()
+ val resultData = new Array[AnyRef](length)
+ var i = 0
+ while (i < length) {
+ val key = keys.get(i, childDataType.keyType)
+ val value = values.get(i, childDataType.valueType)
+ val row = new GenericInternalRow(Array[Any](key, value))
+ resultData.update(i, row)
+ i += 1
+ }
+ new GenericArrayData(resultData)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => {
+ val numElements = ctx.freshName("numElements")
+ val keys = ctx.freshName("keys")
+ val values = ctx.freshName("values")
+ val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
+ val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+ val code = if (isKeyPrimitive && isValuePrimitive) {
+ genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
+ } else {
+ genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
+ }
+ s"""
+ |final int $numElements = $c.numElements();
+ |final ArrayData $keys = $c.keyArray();
+ |final ArrayData $values = $c.valueArray();
+ |$code
+ """.stripMargin
+ })
+ }
+
+ private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")
+
+ private def getValue(varName: String) = {
+ CodeGenerator.getValue(varName, childDataType.valueType, "z")
+ }
+
+ private def genCodeForPrimitiveElements(
+ ctx: CodegenContext,
+ keys: String,
+ values: String,
+ arrayData: String,
+ numElements: String): String = {
+ val unsafeRow = ctx.freshName("unsafeRow")
+ val unsafeArrayData = ctx.freshName("unsafeArrayData")
+ val structsOffset = ctx.freshName("structsOffset")
+ val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
+
+ val baseOffset = Platform.BYTE_ARRAY_OFFSET
+ val wordSize = UnsafeRow.WORD_SIZE
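+ // an entry struct takes one 8-byte null-bitset word plus one word each for key and value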
+ val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
+ val structSizeAsLong = structSize + "L"
+ val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
+ val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
+
+ val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
+ val valueAssignmentChecked = if (childDataType.valueContainsNull) {
+ s"""
+ |if ($values.isNullAt(z)) {
+ | $unsafeRow.setNullAt(1);
+ |} else {
+ | $valueAssignment
+ |}
+ """.stripMargin
+ } else {
+ valueAssignment
+ }
+
+ val assignmentLoop = (byteArray: String) =>
+ s"""
+ |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
+ |UnsafeRow $unsafeRow = new UnsafeRow(2);
+ |for (int z = 0; z < $numElements; z++) {
+ | long offset = $structsOffset + z * $structSizeAsLong;
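+ | // pack the struct's offset (upper 32 bits) and size (lower 32 bits) into the element slot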
+ | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
+ | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
+ | $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
+ | $valueAssignmentChecked
+ |}
+ |$arrayData = $unsafeArrayData;
+ """.stripMargin
+
+ ctx.createUnsafeArrayWithFallback(
+ unsafeArrayData,
+ numElements,
+ structSize + wordSize,
+ assignmentLoop,
+ genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
+ }
+
+ private def genCodeForAnyElements(
+ ctx: CodegenContext,
+ keys: String,
+ values: String,
+ arrayData: String,
+ numElements: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val rowClass = classOf[GenericInternalRow].getName
+ val data = ctx.freshName("internalRowArray")
+
+ val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+ val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
+ s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
+ } else {
+ getValue(values)
+ }
+
+ s"""
+ |final Object[] $data = new Object[$numElements];
+ |for (int z = 0; z < $numElements; z++) {
+ | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
+ |}
+ |$arrayData = new $genericArrayClass($data);
+ """.stripMargin
+ }
+
+ override def prettyName: String = "map_entries"
+}
+
+/**
* Common base class for [[SortArray]] and [[ArraySort]].
*/
trait ArraySortLike extends ExpectsInputTypes {
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 6ae1ac1..71ff96b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapValues(m2), null)
}
+ test("MapEntries") {
+ def r(values: Any*): InternalRow = create_row(values: _*)
+
+ // Primitive-type keys/values
+ val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType))
+ val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType))
+ val mi2 = Literal.create(null, MapType(IntegerType, IntegerType))
+
+ checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2)))
+ checkEvaluation(MapEntries(mi1), Seq.empty)
+ checkEvaluation(MapEntries(mi2), null)
+
+ // Non-primitive-type keys/values
+ val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType))
+ val ms1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+ val ms2 = Literal.create(null, MapType(StringType, StringType))
+
+ checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null)))
+ checkEvaluation(MapEntries(ms1), Seq.empty)
+ checkEvaluation(MapEntries(ms2), null)
+ }
+
test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a22e9d4..c2a44e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
if (expected.isNaN) result.isNaN else expected == result
+ case (result: UnsafeRow, expected: GenericInternalRow) =>
+ val structType = exprDataType.asInstanceOf[StructType]
+ result.toSeq(structType) == expected.toSeq(structType)
case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
case _ =>
result == expected
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2a8fe58..5ab9cb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3492,6 +3492,13 @@ object functions {
*/
def map_values(e: Column): Column = withExpr { MapValues(e.expr) }
+ /**
+ * Returns an unordered array of all entries in the given map.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
+
// scalastyle:off line.size.limit
// scalastyle:off parameter.number
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index d08982a..df23e07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -405,6 +405,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
+ test("map_entries") {
+ val dummyFilter = (c: Column) => c.isNotNull || c.isNull
+
+ // Primitive-type elements
+ val idf = Seq(
+ Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300),
+ Map[Int, Int](),
+ null
+ ).toDF("m")
+ val iExpected = Seq(
+ Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))),
+ Row(Seq.empty),
+ Row(null)
+ )
+
+ checkAnswer(idf.select(map_entries('m)), iExpected)
+ checkAnswer(idf.selectExpr("map_entries(m)"), iExpected)
+ checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected)
+ checkAnswer(
+ spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"),
+ Seq(Row(Seq(Row(1, null), Row(2, null)))))
+ checkAnswer(
+ spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"),
+ Seq(Row(Seq(Row(1, null), Row(2, null)))))
+
+ // Non-primitive-type elements
+ val sdf = Seq(
+ Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"),
+ Map[String, String]("a" -> null, "b" -> null),
+ Map[String, String](),
+ null
+ ).toDF("m")
+ val sExpected = Seq(
+ Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))),
+ Row(Seq(Row("a", null), Row("b", null))),
+ Row(Seq.empty),
+ Row(null)
+ )
+
+ checkAnswer(sdf.select(map_entries('m)), sExpected)
+ checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected)
+ checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected)
+ }
+
test("array contains function") {
val df = Seq(
(Seq[Int](1, 2), "x"),