Posted to commits@spark.apache.org by ue...@apache.org on 2018/07/09 12:21:45 UTC

spark git commit: [SPARK-23936][SQL] Implement map_concat

Repository: spark
Updated Branches:
  refs/heads/master e2c7e09f7 -> 034913b62


[SPARK-23936][SQL] Implement map_concat

## What changes were proposed in this pull request?

Implement the `map_concat` function.

This implementation does not pick a winner when the given maps have overlapping keys: it preserves existing duplicate keys and can introduce new ones. (After discussion with ueshin, we settled on option 1 from [here](https://issues.apache.org/jira/browse/SPARK-23936?focusedCommentId=16464245&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16464245).)
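
For illustration, a minimal spark-shell check of these semantics (a hypothetical session; the tabular layout mirrors the PySpark doctest below):

    // Overlapping key 2 is kept twice; no winner is picked.
    spark.sql("SELECT map_concat(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')) AS m").show(false)
    // +--------------------------------+
    // |m                               |
    // +--------------------------------+
    // |[1 -> a, 2 -> b, 2 -> c, 3 -> d]|
    // +--------------------------------+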

## How was this patch tested?

- New tests
- Manual tests
- Ran all sbt SQL tests
- Ran all PySpark SQL tests

Author: Bruce Robbins <be...@gmail.com>

Closes #21073 from bersprockets/SPARK-23936.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/034913b6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/034913b6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/034913b6

Branch: refs/heads/master
Commit: 034913b62b579ae003431231c0272513de8f496c
Parents: e2c7e09
Author: Bruce Robbins <be...@gmail.com>
Authored: Mon Jul 9 21:21:38 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Mon Jul 9 21:21:38 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  22 ++
 .../sql/catalyst/CatalystTypeConverters.scala   |   6 +
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../sql/catalyst/analysis/TypeCoercion.scala    |   8 +
 .../expressions/collectionOperations.scala      | 231 +++++++++++++++++++
 .../CollectionExpressionsSuite.scala            | 126 ++++++++++
 .../scala/org/apache/spark/sql/functions.scala  |   8 +
 .../inputs/typeCoercion/native/mapconcat.sql    |  94 ++++++++
 .../typeCoercion/native/mapconcat.sql.out       | 143 ++++++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  78 +++++++
 10 files changed, 717 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 55e7d57..9f61e29 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2510,6 +2510,28 @@ def arrays_zip(*cols):
     return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column)))
 
 
+@since(2.4)
+def map_concat(*cols):
+    """Returns the union of all the given maps.
+
+    :param cols: list of column names (string) or list of :class:`Column` expressions
+
+    >>> from pyspark.sql.functions import map_concat
+    >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
+    >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+    +--------------------------------+
+    |map3                            |
+    +--------------------------------+
+    |[1 -> a, 2 -> b, 3 -> c, 1 -> d]|
+    +--------------------------------+
+    """
+    sc = SparkContext._active_spark_context
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 # ---------------------------- User Defined Function ----------------------------------
 
 class PandasUDFType(object):

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 93df73a..6f5fbdd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -431,6 +431,12 @@ object CatalystTypeConverters {
         map,
         (key: Any) => convertToCatalyst(key),
         (value: Any) => convertToCatalyst(value))
+    case (keys: Array[_], values: Array[_]) =>
+      // case for MapData with duplicate keys
+      new ArrayBasedMapData(
+        new GenericArrayData(keys.map(convertToCatalyst)),
+        new GenericArrayData(values.map(convertToCatalyst))
+      )
     case other => other
   }
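
The new case above exists because a plain Scala Map cannot represent the duplicate keys that map_concat may now produce, so test expectations are written as a (keys, values) pair of parallel arrays instead. A minimal Scala sketch of the underlying limitation:

    // Scala Maps silently collapse duplicate keys:
    val collapsed = Map("a" -> "1", "b" -> "2", "a" -> "3")
    println(collapsed) // Map(a -> 3, b -> 2): the first "a" entry is lost
    // Hence expected results with duplicates are expressed as parallel arrays,
    // which the converter above turns into an ArrayBasedMapData:
    val expected = (Array("a", "b", "a"), Array("1", "2", "3"))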
 

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 80a0af6..e7517e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -422,6 +422,7 @@ object FunctionRegistry {
     expression[MapValues]("map_values"),
     expression[MapEntries]("map_entries"),
     expression[MapFromEntries]("map_from_entries"),
+    expression[MapConcat]("map_concat"),
     expression[Size]("size"),
     expression[Slice]("slice"),
     expression[Size]("cardinality"),

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index b6ca30c..72908c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -548,6 +548,14 @@ object TypeCoercion {
           case None => s
         }
 
+      case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
+        !haveSameType(children) =>
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
+          case None => m
+        }
+
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
         (!haveSameType(m.keys) || !haveSameType(m.values)) =>
         val newKeys = if (haveSameType(m.keys)) {
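
The rule added above widens mismatched but mutually compatible map types before concatenation, mirroring the existing CreateMap handling. A hypothetical spark-shell sketch (alias m added for readability; schema output assumed):

    // tinyint and smallint keys/values widen to smallint, matching the
    // mapconcat.sql golden file further below.
    spark.sql("SELECT map_concat(map(1Y, 2Y), map(3S, 4S)) AS m").printSchema()
    // root
    //  |-- m: map (nullable = false)
    //  |    |-- key: short
    //  |    |-- value: short (valueContainsNull = false)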

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index fcac3a5..879603b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -504,6 +504,237 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
 }
 
 /**
+ * Returns the union of all the given maps.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
+       {1:"a",2:"b",2:"c",3:"d"}
+  """, since = "2.4.0")
+case class MapConcat(children: Seq[Expression]) extends Expression {
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val funcName = s"function $prettyName"
+    if (children.exists(!_.dataType.isInstanceOf[MapType])) {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to $funcName should all be of type map, but it's " +
+          children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else {
+      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName)
+    }
+  }
+
+  override def dataType: MapType = {
+    val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
+      .getOrElse(MapType(StringType, StringType))
+    val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
+      .exists(_.valueContainsNull)
+    if (dt.valueContainsNull != valueContainsNull) {
+      dt.copy(valueContainsNull = valueContainsNull)
+    } else {
+      dt
+    }
+  }
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def eval(input: InternalRow): Any = {
+    val maps = children.map(_.eval(input))
+    if (maps.contains(null)) {
+      return null
+    }
+    val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray())
+    val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray())
+
+    val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+    if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+      throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
+        s"elements due to exceeding the map size limit " +
+        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+    }
+    val finalKeyArray = new Array[AnyRef](numElements.toInt)
+    val finalValueArray = new Array[AnyRef](numElements.toInt)
+    var position = 0
+    for (i <- keyArrayDatas.indices) {
+      val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType)
+      val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType)
+      Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length)
+      Array.copy(valueArray, 0, finalValueArray, position, valueArray.length)
+      position += keyArray.length
+    }
+
+    new ArrayBasedMapData(new GenericArrayData(finalKeyArray),
+      new GenericArrayData(finalValueArray))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val mapCodes = children.map(_.genCode(ctx))
+    val keyType = dataType.keyType
+    val valueType = dataType.valueType
+    val argsName = ctx.freshName("args")
+    val hasNullName = ctx.freshName("hasNull")
+    val mapDataClass = classOf[MapData].getName
+    val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
+    val arrayDataClass = classOf[ArrayData].getName
+
+    val init =
+      s"""
+        |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}];
+        |boolean ${ev.isNull}, $hasNullName = false;
+        |$mapDataClass ${ev.value} = null;
+      """.stripMargin
+
+    val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
+      s"""
+         |if (!$hasNullName) {
+         |  ${m.code}
+         |  $argsName[$i] = ${m.value};
+         |  if (${m.isNull}) {
+         |    $hasNullName = true;
+         |  }
+         |}
+       """.stripMargin
+    }
+
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = assignments,
+      funcName = "getMapConcatInputs",
+      extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil,
+      returnType = "boolean",
+      makeSplitFunction = body =>
+        s"""
+           |$body
+           |return $hasNullName;
+        """.stripMargin,
+      foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n")
+    )
+
+    val idxName = ctx.freshName("idx")
+    val numElementsName = ctx.freshName("numElems")
+    val finKeysName = ctx.freshName("finalKeys")
+    val finValsName = ctx.freshName("finalValues")
+
+    val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
+      genCodeForPrimitiveArrays(ctx, keyType, false)
+    } else {
+      genCodeForNonPrimitiveArrays(ctx, keyType)
+    }
+
+    val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
+      genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
+    } else {
+      genCodeForNonPrimitiveArrays(ctx, valueType)
+    }
+
+    val keyArgsName = ctx.freshName("keyArgs")
+    val valArgsName = ctx.freshName("valArgs")
+
+    val mapMerge =
+      s"""
+        |${ev.isNull} = $hasNullName;
+        |if (!${ev.isNull}) {
+        |  $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}];
+        |  $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}];
+        |  long $numElementsName = 0;
+        |  for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
+        |    $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
+        |    $valArgsName[$idxName] = $argsName[$idxName].valueArray();
+        |    $numElementsName += $argsName[$idxName].numElements();
+        |  }
+        |  if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+        |    throw new RuntimeException("Unsuccessful attempt to concat maps with " +
+        |       $numElementsName + " elements due to exceeding the map size limit " +
+        |       "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
+        |  }
+        |  $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
+        |    (int) $numElementsName);
+        |  $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName,
+        |    (int) $numElementsName);
+        |  ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
+        |}
+      """.stripMargin
+
+    ev.copy(
+      code = code"""
+        |$init
+        |$codes
+        |$mapMerge
+      """.stripMargin)
+  }
+
+  private def genCodeForPrimitiveArrays(
+      ctx: CodegenContext,
+      elementType: DataType,
+      checkForNull: Boolean): String = {
+    val counter = ctx.freshName("counter")
+    val arrayData = ctx.freshName("arrayData")
+    val argsName = ctx.freshName("args")
+    val numElemName = ctx.freshName("numElements")
+    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+    val setterCode1 =
+      s"""
+         |$arrayData.set$primitiveValueTypeName(
+         |  $counter,
+         |  ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}
+         |);""".stripMargin
+
+    val setterCode = if (checkForNull) {
+      s"""
+         |if ($argsName[y].isNullAt(z)) {
+         |  $arrayData.setNullAt($counter);
+         |} else {
+         |  $setterCode1
+         |}""".stripMargin
+    } else {
+      setterCode1
+    }
+
+    s"""
+       |new Object() {
+       |  public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
+       |    ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
+       |    int $counter = 0;
+       |    for (int y = 0; y < ${children.length}; y++) {
+       |      for (int z = 0; z < $argsName[y].numElements(); z++) {
+       |        $setterCode
+       |        $counter++;
+       |      }
+       |    }
+       |    return $arrayData;
+       |  }
+       |}""".stripMargin.stripPrefix("\n")
+  }
+
+  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val arrayData = ctx.freshName("arrayObjects")
+    val counter = ctx.freshName("counter")
+    val argsName = ctx.freshName("args")
+    val numElemName = ctx.freshName("numElements")
+
+    s"""
+       |new Object() {
+       |  public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
+       |    Object[] $arrayData = new Object[$numElemName];
+       |    int $counter = 0;
+       |    for (int y = 0; y < ${children.length}; y++) {
+       |      for (int z = 0; z < $argsName[y].numElements(); z++) {
+       |        $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
+       |        $counter++;
+       |      }
+       |    }
+       |    return new $genericArrayClass($arrayData);
+       |  }
+       |}""".stripMargin.stripPrefix("\n")
+  }
+
+  override def prettyName: String = "map_concat"
+}
+
+/**
  * Returns a map created from the given array of entries.
  */
 @ExpressionDescription(
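
One subtlety worth noting in the interpreted eval path of MapConcat above: element counts are summed as a Long and checked against ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH before the cast to Int, so oversized concatenations fail fast instead of silently overflowing. A self-contained Scala sketch of the arithmetic (the constant's value, Integer.MAX_VALUE - 15, is assumed from the Spark source):

    // Summing as Long keeps the total exact even for two near-max inputs.
    val maxRounded = Int.MaxValue - 15 // assumed value of MAX_ROUNDED_ARRAY_LENGTH
    val total = Int.MaxValue.toLong + Int.MaxValue.toLong // 4294967294
    assert(total > maxRounded) // eval() throws a RuntimeException in this case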

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 496ee1d..173c98a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -98,6 +98,132 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MapEntries(ms2), null)
   }
 
+  test("Map Concat") {
+    val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType,
+      valueContainsNull = false))
+    val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType,
+      valueContainsNull = false))
+    val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType))
+    val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
+    val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType))
+    val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType))
+    val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType))
+    val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2),
+      MapType(ArrayType(IntegerType), IntegerType))
+    val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4),
+      MapType(ArrayType(IntegerType), IntegerType))
+    val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2),
+      MapType(MapType(IntegerType, IntegerType), IntegerType))
+    val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4),
+      MapType(MapType(IntegerType, IntegerType), IntegerType))
+    val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType,
+      valueContainsNull = false))
+    val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
+      valueContainsNull = false))
+    val mNull = Literal.create(null, MapType(StringType, StringType))
+
+    // overlapping maps
+    checkEvaluation(MapConcat(Seq(m0, m1)),
+      (
+        Array("a", "b", "c", "a"), // keys
+        Array("1", "2", "3", "4") // values
+      )
+    )
+
+    // maps with no overlap
+    checkEvaluation(MapConcat(Seq(m0, m2)),
+      Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5"))
+
+    // 3 maps
+    checkEvaluation(MapConcat(Seq(m0, m1, m2)),
+      (
+        Array("a", "b", "c", "a", "d", "e"), // keys
+        Array("1", "2", "3", "4", "4", "5") // values
+      )
+    )
+
+    // null reference values
+    checkEvaluation(MapConcat(Seq(m3, m4)),
+      (
+        Array("a", "b", "a", "c"), // keys
+        Array("1", "2", null, "3") // values
+      )
+    )
+
+    // null primitive values
+    checkEvaluation(MapConcat(Seq(m5, m6)),
+      (
+        Array("a", "b", "a", "c"), // keys
+        Array(1, 2, null, 3) // values
+      )
+    )
+
+    // keys that are primitive
+    checkEvaluation(MapConcat(Seq(m11, m12)),
+      (
+        Array(1, 2, 3, 4), // keys
+        Array("1", "2", "3", "4") // values
+      )
+    )
+
+    // keys that are arrays, with overlap
+    checkEvaluation(MapConcat(Seq(m7, m8)),
+      (
+        Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys
+        Array(1, 2, 3, 4) // values
+      )
+    )
+
+    // keys that are maps, with overlap
+    checkEvaluation(MapConcat(Seq(m9, m10)),
+      (
+        Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12),
+          Map(1 -> 2, 3 -> 4)), // keys
+        Array(1, 2, 3, 4) // values
+      )
+    )
+
+    // null map
+    checkEvaluation(MapConcat(Seq(m0, mNull)), null)
+    checkEvaluation(MapConcat(Seq(mNull, m0)), null)
+    checkEvaluation(MapConcat(Seq(mNull, mNull)), null)
+    checkEvaluation(MapConcat(Seq(mNull)), null)
+
+    // single map
+    checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2"))
+
+    // no map
+    checkEvaluation(MapConcat(Seq.empty), Map.empty)
+
+    // force split expressions for input in generated code
+    val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e")
+    val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5")
+    checkEvaluation(MapConcat(
+      Seq(
+        m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0,
+        m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0,
+        m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2
+      )),
+      (expectedKeys, expectedValues))
+
+    // argument checking
+    assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess)
+    assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess)
+    assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure)
+    assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure)
+    assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType)
+    assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType)
+    assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull)
+    assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType)
+    assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType)
+    assert(MapConcat(Seq.empty).dataType.keyType == StringType)
+    assert(MapConcat(Seq.empty).dataType.valueType == StringType)
+    assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull)
+    assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull)
+    assert(!MapConcat(Seq(m1, m2)).nullable)
+    assert(MapConcat(Seq(m1, mNull)).nullable)
+  }
+
   test("MapFromEntries") {
     def arrayType(keyType: DataType, valueType: DataType) : DataType = {
       ArrayType(

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f2627e6..89dbba1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3627,6 +3627,14 @@ object functions {
   @scala.annotation.varargs
   def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) }
 
+  /**
+   * Returns the union of all the given maps.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  @scala.annotation.varargs
+  def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) }
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Mask functions
   //////////////////////////////////////////////////////////////////////////////////////////////
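
A brief usage sketch of the new Scala API (a hypothetical spark-shell session with spark.implicits._ in scope):

    import org.apache.spark.sql.functions.map_concat
    val df = Seq((Map(1 -> "a"), Map(2 -> "b"))).toDF("m1", "m2")
    df.select(map_concat($"m1", $"m2").as("m3")).show(false)
    // +----------------+
    // |m3              |
    // +----------------+
    // |[1 -> a, 2 -> b]|
    // +----------------+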

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
new file mode 100644
index 0000000..fc26397
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
@@ -0,0 +1,94 @@
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false), map(false, true),
+  map(1Y, 2Y), map(3Y, 4Y),
+  map(1S, 2S), map(3S, 4S),
+  map(4, 6), map(7, 8),
+  map(6L, 7L), map(8L, 9L),
+  map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809),  
+  map(1.0D, 2.0D), map(3.0D, 4.0D),
+  map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)),
+  map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'),
+  map('a', 'b'), map('c', 'd'),
+  map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
+  map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
+  map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
+  map('a', 1), map('c', 2),
+  map(1, 'a'), map(2, 'c')
+) AS various_maps (
+  boolean_map1, boolean_map2,
+  tinyint_map1, tinyint_map2,
+  smallint_map1, smallint_map2,
+  int_map1, int_map2,
+  bigint_map1, bigint_map2,
+  decimal_map1, decimal_map2,
+  double_map1, double_map2,
+  float_map1, float_map2,
+  date_map1, date_map2,
+  timestamp_map1,
+  timestamp_map2,
+  string_map1, string_map2,
+  array_map1, array_map2,
+  struct_map1, struct_map2,
+  map_map1, map_map2,
+  string_int_map1, string_int_map2,
+  int_string_map1, int_string_map2
+);
+
+-- Concatenate maps of the same type
+SELECT
+    map_concat(boolean_map1, boolean_map2) boolean_map,
+    map_concat(tinyint_map1, tinyint_map2) tinyint_map,
+    map_concat(smallint_map1, smallint_map2) smallint_map,
+    map_concat(int_map1, int_map2) int_map,
+    map_concat(bigint_map1, bigint_map2) bigint_map,
+    map_concat(decimal_map1, decimal_map2) decimal_map,
+    map_concat(float_map1, float_map2) float_map,
+    map_concat(double_map1, double_map2) double_map,
+    map_concat(date_map1, date_map2) date_map,
+    map_concat(timestamp_map1, timestamp_map2) timestamp_map,
+    map_concat(string_map1, string_map2) string_map,
+    map_concat(array_map1, array_map2) array_map,
+    map_concat(struct_map1, struct_map2) struct_map,
+    map_concat(map_map1, map_map2) map_map,
+    map_concat(string_int_map1, string_int_map2) string_int_map,
+    map_concat(int_string_map1, int_string_map2) int_string_map
+FROM various_maps;
+
+-- Concatenate maps of different types
+SELECT
+    map_concat(tinyint_map1, smallint_map2) ts_map,
+    map_concat(smallint_map1, int_map2) si_map,
+    map_concat(int_map1, bigint_map2) ib_map,
+    map_concat(decimal_map1, float_map2) df_map,
+    map_concat(string_map1, date_map2) std_map,
+    map_concat(timestamp_map1, string_map2) tst_map,
+    map_concat(string_map1, int_map2) sti_map,
+    map_concat(int_string_map1, tinyint_map2) istt_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 1
+SELECT
+    map_concat(tinyint_map1, map_map2) tm_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 2
+SELECT
+    map_concat(boolean_map1, int_map2) bi_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 3
+SELECT
+    map_concat(int_map1, struct_map2) is_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 4
+SELECT
+    map_concat(map_map1, array_map2) ma_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 5
+SELECT
+    map_concat(map_map1, struct_map2) ms_map
+FROM various_maps;

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
new file mode 100644
index 0000000..d352b72
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
@@ -0,0 +1,143 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 8
+
+
+-- !query 0
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false), map(false, true),
+  map(1Y, 2Y), map(3Y, 4Y),
+  map(1S, 2S), map(3S, 4S),
+  map(4, 6), map(7, 8),
+  map(6L, 7L), map(8L, 9L),
+  map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809),  
+  map(1.0D, 2.0D), map(3.0D, 4.0D),
+  map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)),
+  map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'),
+  map('a', 'b'), map('c', 'd'),
+  map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
+  map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
+  map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
+  map('a', 1), map('c', 2),
+  map(1, 'a'), map(2, 'c')
+) AS various_maps (
+  boolean_map1, boolean_map2,
+  tinyint_map1, tinyint_map2,
+  smallint_map1, smallint_map2,
+  int_map1, int_map2,
+  bigint_map1, bigint_map2,
+  decimal_map1, decimal_map2,
+  double_map1, double_map2,
+  float_map1, float_map2,
+  date_map1, date_map2,
+  timestamp_map1,
+  timestamp_map2,
+  string_map1, string_map2,
+  array_map1, array_map2,
+  struct_map1, struct_map2,
+  map_map1, map_map2,
+  string_int_map1, string_int_map2,
+  int_string_map1, int_string_map2
+)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT
+    map_concat(boolean_map1, boolean_map2) boolean_map,
+    map_concat(tinyint_map1, tinyint_map2) tinyint_map,
+    map_concat(smallint_map1, smallint_map2) smallint_map,
+    map_concat(int_map1, int_map2) int_map,
+    map_concat(bigint_map1, bigint_map2) bigint_map,
+    map_concat(decimal_map1, decimal_map2) decimal_map,
+    map_concat(float_map1, float_map2) float_map,
+    map_concat(double_map1, double_map2) double_map,
+    map_concat(date_map1, date_map2) date_map,
+    map_concat(timestamp_map1, timestamp_map2) timestamp_map,
+    map_concat(string_map1, string_map2) string_map,
+    map_concat(array_map1, array_map2) array_map,
+    map_concat(struct_map1, struct_map2) struct_map,
+    map_concat(map_map1, map_map2) map_map,
+    map_concat(string_int_map1, string_int_map2) string_int_map,
+    map_concat(int_string_map1, int_string_map2) int_string_map
+FROM various_maps
+-- !query 1 schema
+struct<boolean_map:map<boolean,boolean>,tinyint_map:map<tinyint,tinyint>,smallint_map:map<smallint,smallint>,int_map:map<int,int>,bigint_map:map<bigint,bigint>,decimal_map:map<decimal(19,0),decimal(19,0)>,float_map:map<float,float>,double_map:map<double,double>,date_map:map<date,date>,timestamp_map:map<timestamp,timestamp>,string_map:map<string,string>,array_map:map<array<string>,array<string>>,struct_map:map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>,map_map:map<map<string,int>,map<string,int>>,string_int_map:map<string,int>,int_string_map:map<int,string>>
+-- !query 1 output
+{false:true,true:false}	{1:2,3:4}	{1:2,3:4}	{4:6,7:8}	{6:7,8:9}	{9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808}	{1.0:2.0,3.0:4.0}	{1.0:2.0,3.0:4.0}	{2016-03-12:2016-03-11,2016-03-14:2016-03-13}	{2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0}	{"a":"b","c":"d"}	{["a","b"]:["c","d"],["e"]:["f"]}	{{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}}	{{"a":1}:{"b":2},{"c":3}:{"d":4}}	{"a":1,"c":2}	{1:"a",2:"c"}
+
+
+-- !query 2
+SELECT
+    map_concat(tinyint_map1, smallint_map2) ts_map,
+    map_concat(smallint_map1, int_map2) si_map,
+    map_concat(int_map1, bigint_map2) ib_map,
+    map_concat(decimal_map1, float_map2) df_map,
+    map_concat(string_map1, date_map2) std_map,
+    map_concat(timestamp_map1, string_map2) tst_map,
+    map_concat(string_map1, int_map2) sti_map,
+    map_concat(int_string_map1, tinyint_map2) istt_map
+FROM various_maps
+-- !query 2 schema
+struct<ts_map:map<smallint,smallint>,si_map:map<int,int>,ib_map:map<bigint,bigint>,df_map:map<double,double>,std_map:map<string,string>,tst_map:map<string,string>,sti_map:map<string,string>,istt_map:map<int,string>>
+-- !query 2 output
+{1:2,3:4}	{1:2,7:8}	{4:6,8:9}	{3.0:4.0,9.223372036854776E18:9.223372036854776E18}	{"2016-03-12":"2016-03-11","a":"b"}	{"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"}	{"7":"8","a":"b"}	{1:"a",3:"4"}
+
+
+-- !query 3
+SELECT
+    map_concat(tinyint_map1, map_map2) tm_map
+FROM various_maps
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<tinyint,tinyint>, map<map<string,int>,map<string,int>>]; line 2 pos 4
+
+
+-- !query 4
+SELECT
+    map_concat(boolean_map1, int_map2) bi_map
+FROM various_maps
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<boolean,boolean>, map<int,int>]; line 2 pos 4
+
+
+-- !query 5
+SELECT
+    map_concat(int_map1, struct_map2) is_map
+FROM various_maps
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<int,int>, map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>]; line 2 pos 4
+
+
+-- !query 6
+SELECT
+    map_concat(map_map1, array_map2) ma_map
+FROM various_maps
+-- !query 6 schema
+struct<>
+-- !query 6 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<map<string,int>,map<string,int>>, map<array<string>,array<string>>]; line 2 pos 4
+
+
+-- !query 7
+SELECT
+    map_concat(map_map1, struct_map2) ms_map
+FROM various_maps
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<map<string,int>,map<string,int>>, map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>]; line 2 pos 4

http://git-wip-us.apache.org/repos/asf/spark/blob/034913b6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 4c28e2f..d60ed7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -657,6 +657,84 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected)
   }
 
+  test("map_concat function") {
+    val df1 = Seq(
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)),
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)),
+      (null, Map[Int, Int](3 -> 300, 4 -> 400))
+    ).toDF("map1", "map2")
+
+    val expected1a = Seq(
+      Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)),
+      Row(Map(1 -> 400, 2 -> 200, 3 -> 300)),
+      Row(null)
+    )
+
+    checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
+    checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a)
+
+    val expected1b = Seq(
+      Row(Map(1 -> 100, 2 -> 200)),
+      Row(Map(1 -> 100, 2 -> 200)),
+      Row(null)
+    )
+
+    checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b)
+    checkAnswer(df1.select(map_concat('map1)), expected1b)
+
+    val df2 = Seq(
+      (
+        Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200),
+        Map[String, Int]("3" -> 300, "4" -> 400)
+      )
+    ).toDF("map1", "map2")
+
+    val expected2 = Seq(Row(Map()))
+
+    checkAnswer(df2.selectExpr("map_concat()"), expected2)
+    checkAnswer(df2.select(map_concat()), expected2)
+
+    val df3 = {
+      val schema = StructType(
+        StructField("map1", MapType(StringType, IntegerType, true), false)  ::
+        StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil
+      )
+      val data = Seq(
+        Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)),
+        Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4))
+      )
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+
+    val expected3 = Seq(
+      Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)),
+      Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4))
+    )
+
+    checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3)
+    checkAnswer(df3.select(map_concat('map1, 'map2)), expected3)
+
+    val expectedMessage1 = "input to function map_concat should all be the same type"
+
+    assert(intercept[AnalysisException] {
+      df2.selectExpr("map_concat(map1, map2)").collect()
+    }.getMessage().contains(expectedMessage1))
+
+    assert(intercept[AnalysisException] {
+      df2.select(map_concat('map1, 'map2)).collect()
+    }.getMessage().contains(expectedMessage1))
+
+    val expectedMessage2 = "input to function map_concat should all be of type map"
+
+    assert(intercept[AnalysisException] {
+      df2.selectExpr("map_concat(map1, 12)").collect()
+    }.getMessage().contains(expectedMessage2))
+
+    assert(intercept[AnalysisException] {
+      df2.select(map_concat('map1, lit(12))).collect()
+    }.getMessage().contains(expectedMessage2))
+  }
+
   test("map_from_entries function") {
     def dummyFilter(c: Column): Column = c.isNull || c.isNotNull
     val oneRowDF = Seq(3215).toDF("i")

