You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ue...@apache.org on 2018/05/21 14:14:13 UTC

spark git commit: [SPARK-23935][SQL] Adding map_entries function

Repository: spark
Updated Branches:
  refs/heads/master e480eccd9 -> a6e883feb


[SPARK-23935][SQL] Adding map_entries function

## What changes were proposed in this pull request?

This PR adds a `map_entries` function that returns an unordered array of all entries in the given map.

## How was this patch tested?

New tests were added to:
- `CollectionExpressionSuite`
- `DataFrameFunctionsSuite`

## CodeGen examples
### Primitive types
```
val df = Seq(Map(1 -> 5, 2 -> 6)).toDF("m")
df.filter('m.isNotNull).select(map_entries('m)).debugCodegen
```
Result:
```
/* 042 */         boolean project_isNull_0 = false;
/* 043 */
/* 044 */         ArrayData project_value_0 = null;
/* 045 */
/* 046 */         final int project_numElements_0 = inputadapter_value_0.numElements();
/* 047 */         final ArrayData project_keys_0 = inputadapter_value_0.keyArray();
/* 048 */         final ArrayData project_values_0 = inputadapter_value_0.valueArray();
/* 049 */
/* 050 */         final long project_size_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 051 */           project_numElements_0,
/* 052 */           32);
/* 053 */         if (project_size_0 > 2147483632) {
/* 054 */           final Object[] project_internalRowArray_0 = new Object[project_numElements_0];
/* 055 */           for (int z = 0; z < project_numElements_0; z++) {
/* 056 */             project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getInt(z), project_values_0.getInt(z)});
/* 057 */           }
/* 058 */           project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0);
/* 059 */
/* 060 */         } else {
/* 061 */           final byte[] project_arrayBytes_0 = new byte[(int)project_size_0];
/* 062 */           UnsafeArrayData project_unsafeArrayData_0 = new UnsafeArrayData();
/* 063 */           Platform.putLong(project_arrayBytes_0, 16, project_numElements_0);
/* 064 */           project_unsafeArrayData_0.pointTo(project_arrayBytes_0, 16, (int)project_size_0);
/* 065 */
/* 066 */           final int project_structsOffset_0 = UnsafeArrayData.calculateHeaderPortionInBytes(project_numElements_0) + project_numElements_0 * 8;
/* 067 */           UnsafeRow project_unsafeRow_0 = new UnsafeRow(2);
/* 068 */           for (int z = 0; z < project_numElements_0; z++) {
/* 069 */             long offset = project_structsOffset_0 + z * 24L;
/* 070 */             project_unsafeArrayData_0.setLong(z, (offset << 32) + 24L);
/* 071 */             project_unsafeRow_0.pointTo(project_arrayBytes_0, 16 + offset, 24);
/* 072 */             project_unsafeRow_0.setInt(0, project_keys_0.getInt(z));
/* 073 */             project_unsafeRow_0.setInt(1, project_values_0.getInt(z));
/* 074 */           }
/* 075 */           project_value_0 = project_unsafeArrayData_0;
/* 076 */
/* 077 */         }
```
### Non-primitive types
```
val df = Seq(Map("a" -> "foo", "b" -> null)).toDF("m")
df.filter('m.isNotNull).select(map_entries('m)).debugCodegen
```
Result:
```
/* 042 */         boolean project_isNull_0 = false;
/* 043 */
/* 044 */         ArrayData project_value_0 = null;
/* 045 */
/* 046 */         final int project_numElements_0 = inputadapter_value_0.numElements();
/* 047 */         final ArrayData project_keys_0 = inputadapter_value_0.keyArray();
/* 048 */         final ArrayData project_values_0 = inputadapter_value_0.valueArray();
/* 049 */
/* 050 */         final Object[] project_internalRowArray_0 = new Object[project_numElements_0];
/* 051 */         for (int z = 0; z < project_numElements_0; z++) {
/* 052 */           project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getUTF8String(z), project_values_0.getUTF8String(z)});
/* 053 */         }
/* 054 */         project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0);
```

Author: Marek Novotny <mn...@gmail.com>

Closes #21236 from mn-mikke/feature/array-api-map_entries-to-master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6e883fe
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6e883fe
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6e883fe

Branch: refs/heads/master
Commit: a6e883feb3b78232ad5cf636f7f7d5e825183041
Parents: e480ecc
Author: Marek Novotny <mn...@gmail.com>
Authored: Mon May 21 23:14:03 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Mon May 21 23:14:03 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  20 +++
 .../sql/catalyst/expressions/UnsafeRow.java     |   2 +
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/codegen/CodeGenerator.scala     |  34 +++++
 .../expressions/collectionOperations.scala      | 153 +++++++++++++++++++
 .../CollectionExpressionsSuite.scala            |  23 +++
 .../expressions/ExpressionEvalHelper.scala      |   3 +
 .../scala/org/apache/spark/sql/functions.scala  |   7 +
 .../spark/sql/DataFrameFunctionsSuite.scala     |  44 ++++++
 9 files changed, 287 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8490081..fbc8a2d 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2344,6 +2344,26 @@ def map_values(col):
     return Column(sc._jvm.functions.map_values(_to_java_column(col)))
 
 
+@since(2.4)
+def map_entries(col):
+    """
+    Collection function: Returns an unordered array of all entries in the given map.
+
+    :param col: name of column or expression
+
+    >>> from pyspark.sql.functions import map_entries
+    >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+    >>> df.select(map_entries("data").alias("entries")).show()
+    +----------------+
+    |         entries|
+    +----------------+
+    |[[1, a], [2, b]]|
+    +----------------+
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
+
+
 @ignore_unicode_prefix
 @since(2.4)
 def array_repeat(col, count):

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 29a1411..469b0e6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -62,6 +62,8 @@ import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
  */
 public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {
 
+  public static final int WORD_SIZE = 8;
+
   //////////////////////////////////////////////////////////////////////////////
   // Static methods
   //////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 867c2d5..1134a88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -419,6 +419,7 @@ object FunctionRegistry {
     expression[ElementAt]("element_at"),
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
+    expression[MapEntries]("map_entries"),
     expression[Size]("size"),
     expression[Slice]("slice"),
     expression[Size]("cardinality"),

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4dda525..d382d9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -765,6 +765,40 @@ class CodegenContext {
   }
 
   /**
+   * Generates code creating an [[UnsafeArrayData]] over a newly allocated byte array. When the
+   * array would exceed `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, `fallbackCode` runs instead.
+   * @param arrayName name of the [[UnsafeArrayData]] variable, in scope only within `bodyCode`
+   * @param numElements a piece of code giving the element count; it is spliced in twice
+   * @param elementSize size of one element in bytes
+   * @param bodyCode a function that receives the name of the backing byte array and returns
+   *                 the code filling up the [[UnsafeArrayData]]
+   * @param fallbackCode a piece of code executed when the array size limit is exceeded
+   */
+  def createUnsafeArrayWithFallback(
+      arrayName: String,
+      numElements: String,
+      elementSize: Int,
+      bodyCode: String => String,
+      fallbackCode: String): String = {
+    val arraySize = freshName("size")
+    val arrayBytes = freshName("arrayBytes")
+    s"""
+       |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+       |  $numElements,
+       |  $elementSize);
+       |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+       |  $fallbackCode
+       |} else {
+       |  final byte[] $arrayBytes = new byte[(int)$arraySize];
+       |  UnsafeArrayData $arrayName = new UnsafeArrayData();
+       |  Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
+       |  $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
+       |  ${bodyCode(arrayBytes)}
+       |}
+     """.stripMargin
+  }
+
+  /**
    * Generates code to do null safe execution, i.e. only execute the code when the input is not
    * null by adding null check if necessary.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c82db83..8d763dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
@@ -155,6 +156,178 @@ case class MapValues(child: Expression)
 }
 
 /**
+ * Returns an unordered array of all entries in the given map.
+ * Each entry is a struct with a non-nullable `key` field and a `value` field that is nullable
+ * exactly when the input map's `valueContainsNull` is set.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'));
+       [(1,"a"),(2,"b")]
+  """,
+  since = "2.4.0")
+case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+  lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
+
+  override def dataType: DataType = {
+    ArrayType(
+      StructType(
+        StructField("key", childDataType.keyType, false) ::
+        StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
+        Nil),
+      false)
+  }
+
+  /** Interpreted path: wraps each (key, value) pair in a [[GenericInternalRow]]. */
+  override protected def nullSafeEval(input: Any): Any = {
+    val childMap = input.asInstanceOf[MapData]
+    val keys = childMap.keyArray()
+    val values = childMap.valueArray()
+    val length = childMap.numElements()
+    val resultData = new Array[AnyRef](length)
+    var i = 0
+    while (i < length) {
+      val key = keys.get(i, childDataType.keyType)
+      val value = values.get(i, childDataType.valueType)
+      val row = new GenericInternalRow(Array[Any](key, value))
+      resultData.update(i, row)
+      i += 1
+    }
+    new GenericArrayData(resultData)
+  }
+
+  /**
+   * When both the key and the value type are primitive, writes the entries straight into an
+   * [[UnsafeArrayData]] of [[UnsafeRow]]s (falling back to the generic path for oversized
+   * results); otherwise builds a [[GenericArrayData]] of [[GenericInternalRow]]s.
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => {
+      val numElements = ctx.freshName("numElements")
+      val keys = ctx.freshName("keys")
+      val values = ctx.freshName("values")
+      val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
+      val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+      val code = if (isKeyPrimitive && isValuePrimitive) {
+        genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
+      } else {
+        genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
+      }
+      s"""
+         |final int $numElements = $c.numElements();
+         |final ArrayData $keys = $c.keyArray();
+         |final ArrayData $values = $c.valueArray();
+         |$code
+       """.stripMargin
+    })
+  }
+
+  // Java expressions reading element `z` of the key/value array held in variable `varName`.
+  private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")
+
+  private def getValue(varName: String) = {
+    CodeGenerator.getValue(varName, childDataType.valueType, "z")
+  }
+
+  /**
+   * Generates code writing all entries into one byte array laid out as an [[UnsafeArrayData]]
+   * whose elements are fixed-size two-field [[UnsafeRow]]s. If that byte array would exceed the
+   * array size limit, the code produced by `genCodeForAnyElements` runs instead.
+   */
+  private def genCodeForPrimitiveElements(
+      ctx: CodegenContext,
+      keys: String,
+      values: String,
+      arrayData: String,
+      numElements: String): String = {
+    val unsafeRow = ctx.freshName("unsafeRow")
+    val unsafeArrayData = ctx.freshName("unsafeArrayData")
+    val structsOffset = ctx.freshName("structsOffset")
+    val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
+
+    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+    val wordSize = UnsafeRow.WORD_SIZE
+    val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
+    val structSizeAsLong = s"${structSize}L"
+    val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
+    val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
+
+    val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
+    val valueAssignmentChecked = if (childDataType.valueContainsNull) {
+      s"""
+         |if ($values.isNullAt(z)) {
+         |  $unsafeRow.setNullAt(1);
+         |} else {
+         |  $valueAssignment
+         |}
+       """.stripMargin
+    } else {
+      valueAssignment
+    }
+
+    // Structs are stored after the array header and one 8-byte offset/size slot per element.
+    val assignmentLoop = (byteArray: String) =>
+      s"""
+         |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
+         |UnsafeRow $unsafeRow = new UnsafeRow(2);
+         |for (int z = 0; z < $numElements; z++) {
+         |  long offset = $structsOffset + z * $structSizeAsLong;
+         |  $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
+         |  $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
+         |  $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
+         |  $valueAssignmentChecked
+         |}
+         |$arrayData = $unsafeArrayData;
+       """.stripMargin
+
+    ctx.createUnsafeArrayWithFallback(
+      unsafeArrayData,
+      numElements,
+      // Per element: the struct's bytes plus its 8-byte offset/size slot.
+      structSize + wordSize,
+      assignmentLoop,
+      genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
+  }
+
+  /**
+   * Generates code wrapping each entry in a [[GenericInternalRow]] and collecting the rows into
+   * a [[GenericArrayData]]. Nullable primitive values are null-checked and boxed explicitly.
+   */
+  private def genCodeForAnyElements(
+      ctx: CodegenContext,
+      keys: String,
+      values: String,
+      arrayData: String,
+      numElements: String): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val rowClass = classOf[GenericInternalRow].getName
+    val data = ctx.freshName("internalRowArray")
+
+    val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+    val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
+      s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
+    } else {
+      getValue(values)
+    }
+
+    s"""
+       |final Object[] $data = new Object[$numElements];
+       |for (int z = 0; z < $numElements; z++) {
+       |  $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
+       |}
+       |$arrayData = new $genericArrayClass($data);
+     """.stripMargin
+  }
+
+  override def prettyName: String = "map_entries"
+}
+
+/**
  * Common base class for [[SortArray]] and [[ArraySort]].
  */
 trait ArraySortLike extends ExpectsInputTypes {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 6ae1ac1..71ff96b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,32 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MapValues(m2), null)
   }
 
+  test("MapEntries") {
+    def r(values: Any*): InternalRow = create_row(values: _*)
+
+    // Primitive-type keys/values
+    val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType))
+    val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType))
+    val mi2 = Literal.create(null, MapType(IntegerType, IntegerType))
+
+    checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2)))
+    checkEvaluation(MapEntries(mi1), Seq.empty)
+    checkEvaluation(MapEntries(mi2), null)
+
+    // Primitive-type keys/values of different types, so key and value setters must differ
+    val ml0 = Literal.create(Map(1 -> 10L, 2 -> null, 3 -> 30L), MapType(IntegerType, LongType))
+    checkEvaluation(MapEntries(ml0), Seq(r(1, 10L), r(2, null), r(3, 30L)))
+
+    // Non-primitive-type keys/values
+    val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType))
+    val ms1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+    val ms2 = Literal.create(null, MapType(StringType, StringType))
+
+    checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null)))
+    checkEvaluation(MapEntries(ms1), Seq.empty)
+    checkEvaluation(MapEntries(ms2), null)
+  }
+
   test("Sort Array") {
     val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a22e9d4..c2a44e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
         if (expected.isNaN) result.isNaN else expected == result
       case (result: Float, expected: Float) =>
         if (expected.isNaN) result.isNaN else expected == result
+      case (result: UnsafeRow, expected: GenericInternalRow) =>
+        val structType = exprDataType.asInstanceOf[StructType]
+        result.toSeq(structType) == expected.toSeq(structType)
       case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
       case _ =>
         result == expected

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2a8fe58..5ab9cb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3492,6 +3492,13 @@ object functions {
    */
   def map_values(e: Column): Column = withExpr { MapValues(e.expr) }
 
+  /**
+   * Returns an unordered array of all entries in the given map, each as a `key`/`value` struct.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
+
   // scalastyle:off line.size.limit
   // scalastyle:off parameter.number
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e883fe/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index d08982a..df23e07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -405,6 +405,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("map_entries") {
+    // Keeps every row. NOTE(review): presumably here so that map_entries also runs on top of a
+    // Filter, exercising a different (e.g. whole-stage codegen) path -- confirm.
+    val keepAll = (c: Column) => c.isNotNull || c.isNull
+
+    // Evaluates map_entries on column `m` via the Column API, via SQL, and behind a filter.
+    def checkMapEntries(df: DataFrame, expected: Seq[Row]): Unit = {
+      checkAnswer(df.select(map_entries('m)), expected)
+      checkAnswer(df.selectExpr("map_entries(m)"), expected)
+      checkAnswer(df.filter(keepAll('m)).select(map_entries('m)), expected)
+    }
+
+    // Primitive-type elements
+    val intMaps = Seq(
+      Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300),
+      Map[Int, Int](),
+      null
+    ).toDF("m")
+    checkMapEntries(intMaps, Seq(
+      Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))),
+      Row(Seq.empty),
+      Row(null)))
+
+    // Map literal whose values are all null
+    val nullsQuery = "map_entries(map(1, null, 2, null))"
+    val nullsExpected = Seq(Row(Seq(Row(1, null), Row(2, null))))
+    checkAnswer(spark.range(1).selectExpr(nullsQuery), nullsExpected)
+    checkAnswer(spark.range(1).filter(keepAll('id)).selectExpr(nullsQuery),
+      nullsExpected)
+
+    // Non-primitive-type elements
+    val stringMaps = Seq(
+      Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"),
+      Map[String, String]("a" -> null, "b" -> null),
+      Map[String, String](),
+      null
+    ).toDF("m")
+    checkMapEntries(stringMaps, Seq(
+      Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))),
+      Row(Seq(Row("a", null), Row("b", null))),
+      Row(Seq.empty),
+      Row(null)))
+  }
+
   test("array contains function") {
     val df = Seq(
       (Seq[Int](1, 2), "x"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org