Posted to commits@spark.apache.org by yh...@apache.org on 2016/03/25 17:50:11 UTC

spark git commit: [SPARK-14061][SQL] implement CreateMap

Repository: spark
Updated Branches:
  refs/heads/master 6603d9f7e -> 43b15e01c


[SPARK-14061][SQL] implement CreateMap

## What changes were proposed in this pull request?

As we have `CreateArray` and `CreateStruct`, we should also have `CreateMap`. This PR adds the `CreateMap` expression, along with the corresponding DataFrame and Python APIs.
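
For illustration, a minimal usage sketch of the new APIs (the column names and data are made up, and the `$` syntax assumes the usual SQLContext implicits are in scope):

```scala
import org.apache.spark.sql.functions.map

val df = Seq(1 -> "a", 2 -> "b").toDF("k", "v")
// map(key1, value1, key2, value2, ...) takes interleaved key/value columns
// and produces a single MapType column, one map per row.
val withMap = df.select(map($"k", $"v").as("m"))
// The same function is also registered for SQL: SELECT map(1, 'a', 2, 'b')
```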

## How was this patch tested?

Various new tests, covering type checking, type coercion, expression evaluation, the DataFrame API, and SQL generation.

Author: Wenchen Fan <we...@databricks.com>

Closes #11879 from cloud-fan/create_map.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/43b15e01
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/43b15e01
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/43b15e01

Branch: refs/heads/master
Commit: 43b15e01c46ea1971569f74c9201a55de39e8917
Parents: 6603d9f
Author: Wenchen Fan <we...@databricks.com>
Authored: Fri Mar 25 09:50:06 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Fri Mar 25 09:50:06 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 20 +++++
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../catalyst/analysis/HiveTypeCoercion.scala    | 35 ++++++++-
 .../expressions/complexTypeCreator.scala        | 83 +++++++++++++++++++-
 .../sql/catalyst/util/ArrayBasedMapData.scala   |  5 +-
 .../analysis/ExpressionTypeCheckingSuite.scala  | 16 +++-
 .../analysis/HiveTypeCoercionSuite.scala        | 61 ++++++++++++++
 .../catalyst/expressions/ComplexTypeSuite.scala | 40 ++++++++++
 .../scala/org/apache/spark/sql/functions.scala  | 11 +++
 .../spark/sql/DataFrameComplexTypeSuite.scala   |  8 +-
 .../spark/sql/DataFrameFunctionsSuite.scala     | 15 ++--
 .../spark/sql/hive/ExpressionToSQLSuite.scala   |  1 +
 12 files changed, 277 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index dee3d53..f5d959e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1498,6 +1498,26 @@ def translate(srcCol, matching, replace):
 
 # ---------------------- Collection functions ------------------------------
 
+@ignore_unicode_prefix
+@since(2.0)
+def create_map(*cols):
+    """Creates a new map column.
+
+    :param cols: list of column names (string) or list of :class:`Column` expressions that are
+        grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).
+
+    >>> df.select(create_map('name', 'age').alias("map")).collect()
+    [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
+    >>> df.select(create_map([df.name, df.age]).alias("map")).collect()
+    [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
+    """
+    sc = SparkContext._active_spark_context
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.map(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 @since(1.4)
 def array(*cols):
     """Creates a new array column.

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 26bb96e..f584a4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -126,6 +126,7 @@ object FunctionRegistry {
     expression[IsNull]("isnull"),
     expression[IsNotNull]("isnotnull"),
     expression[Least]("least"),
+    expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[NaNvl]("nanvl"),
     expression[Coalesce]("nvl"),

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 0f85f44..823d249 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -160,6 +160,9 @@ object HiveTypeCoercion {
     })
   }
 
+  private def haveSameType(exprs: Seq[Expression]): Boolean =
+    exprs.map(_.dataType).distinct.length == 1
+
   /**
    * Applies any changes to [[AttributeReference]] data types that are made by other rules to
    * instances higher in the query tree.
@@ -443,13 +446,37 @@ object HiveTypeCoercion {
      // Skip nodes whose children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
+      case a @ CreateArray(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
         findTightestCommonTypeAndPromoteToString(types) match {
           case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
           case None => a
         }
 
+      case m @ CreateMap(children) if m.keys.length == m.values.length &&
+        (!haveSameType(m.keys) || !haveSameType(m.values)) =>
+        val newKeys = if (haveSameType(m.keys)) {
+          m.keys
+        } else {
+          val types = m.keys.map(_.dataType)
+          findTightestCommonTypeAndPromoteToString(types) match {
+            case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
+            case None => m.keys
+          }
+        }
+
+        val newValues = if (haveSameType(m.values)) {
+          m.values
+        } else {
+          val types = m.values.map(_.dataType)
+          findTightestCommonTypeAndPromoteToString(types) match {
+            case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
+            case None => m.values
+          }
+        }
+
+        CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
+
       // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
       case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
       case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
@@ -468,21 +495,21 @@ object HiveTypeCoercion {
       // Coalesce should return the first non-null value, which could be any column
       // from the list. So we need to make sure the return type is deterministic and
       // compatible with every child column.
-      case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
+      case c @ Coalesce(es) if !haveSameType(es) =>
         val types = es.map(_.dataType)
         findWiderCommonType(types) match {
           case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
           case None => c
         }
 
-      case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 =>
+      case g @ Greatest(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
         findTightestCommonType(types) match {
           case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
           case None => g
         }
 
-      case l @ Least(children) if children.map(_.dataType).distinct.size > 1 =>
+      case l @ Least(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
         findTightestCommonType(types) match {
           case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
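
For context, a hedged sketch of what the new CreateMap branch above does at the column level (the literals are illustrative; see also the "CreateMap casts" test further down):

```scala
import org.apache.spark.sql.functions.{lit, map}

// Keys 1 (int) and 2.0 (double) have different types, so the rule casts
// both to the tightest common type, DoubleType. Values "a" and "b" already
// agree on StringType and are left untouched.
val m = map(lit(1), lit("a"), lit(2.0), lit("b"))
// After FunctionArgumentConversion the underlying expression is equivalent
// to CreateMap(Cast(1, double), "a", Cast(2.0, double), "b"), i.e. a column
// of type MapType(DoubleType, StringType).
```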

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index efd7529..c299586 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -70,6 +70,87 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
 }
 
 /**
+ * Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
+ * The children are a flattened sequence of key-value pairs, e.g. (key1, value1, key2, value2, ...)
+ */
+case class CreateMap(children: Seq[Expression]) extends Expression {
+  private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children)
+  private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children)
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size % 2 != 0) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName expects an positive even number of arguments.")
+    } else if (keys.map(_.dataType).distinct.length > 1) {
+      TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
+        "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else if (values.map(_.dataType).distinct.length > 1) {
+      TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
+        "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def dataType: DataType = {
+    MapType(
+      keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
+      valueType = values.headOption.map(_.dataType).getOrElse(NullType),
+      valueContainsNull = values.exists(_.nullable))
+  }
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = {
+    val keyArray = keys.map(_.eval(input)).toArray
+    if (keyArray.contains(null)) {
+      throw new RuntimeException("Cannot use null as map key!")
+    }
+    val valueArray = values.map(_.eval(input)).toArray
+    new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))
+  }
+
+  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val arrayClass = classOf[GenericArrayData].getName
+    val mapClass = classOf[ArrayBasedMapData].getName
+    val keyArray = ctx.freshName("keyArray")
+    val valueArray = ctx.freshName("valueArray")
+    val keyData = s"new $arrayClass($keyArray)"
+    val valueData = s"new $arrayClass($valueArray)"
+    s"""
+      final boolean ${ev.isNull} = false;
+      final Object[] $keyArray = new Object[${keys.size}];
+      final Object[] $valueArray = new Object[${values.size}];
+    """ + keys.zipWithIndex.map {
+      case (key, i) =>
+        val eval = key.gen(ctx)
+        s"""
+          ${eval.code}
+          if (${eval.isNull}) {
+            throw new RuntimeException("Cannot use null as map key!");
+          } else {
+            $keyArray[$i] = ${eval.value};
+          }
+        """
+    }.mkString("\n") + values.zipWithIndex.map {
+      case (value, i) =>
+        val eval = value.gen(ctx)
+        s"""
+          ${eval.code}
+          if (${eval.isNull}) {
+            $valueArray[$i] = null;
+          } else {
+            $valueArray[$i] = ${eval.value};
+          }
+        """
+    }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);"
+  }
+
+  override def prettyName: String = "map"
+}
+
+/**
  * Returns a Row containing the evaluation of all children expressions.
  */
 case class CreateStruct(children: Seq[Expression]) extends Expression {
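
As a standalone illustration of the even/odd index split that CreateMap uses above (plain Scala, no Spark dependencies; the sample values are made up):

```scala
// children is an interleaved (key1, value1, key2, value2, ...) sequence.
val children = Seq("k1", "v1", "k2", "v2")

// A Seq is also a function Int => A, so .map(children) turns a sequence of
// indices into the sequence of elements at those indices.
val keys   = children.indices.filter(_ % 2 == 0).map(children)
val values = children.indices.filter(_ % 2 != 0).map(children)

assert(keys == Seq("k1", "k2"))
assert(values == Seq("v1", "v2"))
```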

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
index d85b72e..d46f03a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
@@ -24,7 +24,6 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
 
   override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())
 
-  // We need to check equality of map type in tests.
   override def equals(o: Any): Boolean = {
     if (!o.isInstanceOf[ArrayBasedMapData]) {
       return false
@@ -35,11 +34,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
       return false
     }
 
-    ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other)
+    this.keyArray == other.keyArray && this.valueArray == other.valueArray
   }
 
   override def hashCode: Int = {
-    ArrayBasedMapData.toScalaMap(this).hashCode()
+    keyArray.hashCode() * 37 + valueArray.hashCode()
   }
 
   override def toString: String = {
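
The equals/hashCode change above makes ArrayBasedMapData comparison order-sensitive: the old toScalaMap-based version compared unordered Scala maps, while comparing the backing arrays directly does not ignore entry order. A plain-Scala sketch of the distinction:

```scala
val inOrder  = Seq("a" -> 1, "b" -> 2)
val reversed = inOrder.reverse

// Scala Map equality ignores entry order, mirroring the old behavior...
assert(inOrder.toMap == reversed.toMap)
// ...while Seq equality (like comparing key/value arrays) does not.
assert(inOrder != reversed)
```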

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 92c8496..ace6e10 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -173,13 +173,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
     assertError(
       CreateNamedStruct(Seq(1, "a", "b", 2.0)),
-        "Only foldable StringType expressions are allowed to appear at odd position")
+      "Only foldable StringType expressions are allowed to appear at odd position")
     assertError(
       CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
-        "Only foldable StringType expressions are allowed to appear at odd position")
+      "Only foldable StringType expressions are allowed to appear at odd position")
     assertError(
       CreateNamedStruct(Seq(Literal.create(null, StringType), "a")),
-        "Field name should not be null")
+      "Field name should not be null")
+  }
+
+  test("check types for CreateMap") {
+    assertError(CreateMap(Seq("a", "b", 2.0)), "even number of arguments")
+    assertError(
+      CreateMap(Seq('intField, 'stringField, 'booleanField, 'stringField)),
+      "keys of function map should all be the same type")
+    assertError(
+      CreateMap(Seq('stringField, 'intField, 'stringField, 'booleanField)),
+      "values of function map should all be the same type")
   }
 
   test("check types for ROUND") {

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 6f289dc..883ef48 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -250,6 +250,67 @@ class HiveTypeCoercionSuite extends PlanTest {
         :: Nil))
   }
 
+  test("CreateArray casts") {
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateArray(Literal(1.0)
+        :: Literal(1)
+        :: Literal.create(1.0, FloatType)
+        :: Nil),
+      CreateArray(Cast(Literal(1.0), DoubleType)
+        :: Cast(Literal(1), DoubleType)
+        :: Cast(Literal.create(1.0, FloatType), DoubleType)
+        :: Nil))
+
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateArray(Literal(1.0)
+        :: Literal(1)
+        :: Literal("a")
+        :: Nil),
+      CreateArray(Cast(Literal(1.0), StringType)
+        :: Cast(Literal(1), StringType)
+        :: Cast(Literal("a"), StringType)
+        :: Nil))
+  }
+
+  test("CreateMap casts") {
+    // type coercion for map keys
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal.create(2.0, FloatType)
+        :: Literal("b")
+        :: Nil),
+      CreateMap(Cast(Literal(1), FloatType)
+        :: Literal("a")
+        :: Cast(Literal.create(2.0, FloatType), FloatType)
+        :: Literal("b")
+        :: Nil))
+    // type coercion for map values
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal(2)
+        :: Literal(3.0)
+        :: Nil),
+      CreateMap(Literal(1)
+        :: Cast(Literal("a"), StringType)
+        :: Literal(2)
+        :: Cast(Literal(3.0), StringType)
+        :: Nil))
+    // type coercion for both map keys and values
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal(2.0)
+        :: Literal(3.0)
+        :: Nil),
+      CreateMap(Cast(Literal(1), DoubleType)
+        :: Cast(Literal("a"), StringType)
+        :: Cast(Literal(2.0), DoubleType)
+        :: Cast(Literal(3.0), StringType)
+        :: Nil))
+  }
+
   test("greatest/least cast") {
     for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
       ruleTest(HiveTypeCoercion.FunctionArgumentConversion,

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 9c1688b..7c009a7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -134,6 +134,46 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }
 
+  test("CreateMap") {
+    def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
+      keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
+    }
+
+    def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
+      // Catalyst maps are order-sensitive, so we create a ListMap here to preserve element order.
+      scala.collection.immutable.ListMap(keys.zip(values): _*)
+    }
+
+    val intSeq = Seq(5, 10, 15, 20, 25)
+    val longSeq = intSeq.map(_.toLong)
+    val strSeq = intSeq.map(_.toString)
+    checkEvaluation(CreateMap(Nil), Map.empty)
+    checkEvaluation(
+      CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))),
+      createMap(intSeq, longSeq))
+    checkEvaluation(
+      CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))),
+      createMap(strSeq, longSeq))
+    checkEvaluation(
+      CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))),
+      createMap(longSeq, strSeq))
+
+    val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
+    checkEvaluation(
+      CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)),
+      createMap(intSeq, strWithNull.map(_.value)))
+    intercept[RuntimeException] {
+      checkEvaluationWithoutCodegen(
+        CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
+        null, null)
+    }
+    intercept[RuntimeException] {
+      checkEvalutionWithUnsafeProjection(
+        CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
+        null, null)
+    }
+  }
+
   test("CreateStruct") {
     val row = create_row(1, 2, 3)
     val c1 = 'a.int.at(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 304d747..8abb9d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -905,6 +905,17 @@ object functions {
   }
 
   /**
+   * Creates a new map column. The input columns must be grouped as key-value pairs, e.g.
+   * (key1, value1, key2, value2, ...). The key columns must all have the same data type, and can't
+   * be null. The value columns must all have the same data type.
+   *
+   * @group normal_funcs
+   * @since 2.0
+   */
+  @scala.annotation.varargs
+  def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }
+
+  /**
    * Marks a DataFrame as small enough for use in broadcast joins.
    *
    * The following example marks the right DataFrame for broadcast hash join using `joinKey`.
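
A hedged usage sketch of the new map function (it mirrors the DataFrameFunctionsSuite case further down; assumes the usual SQLContext implicits are in scope):

```scala
import org.apache.spark.sql.functions.map

val df = Seq(1 -> "a").toDF("a", "b")
// Arbitrary expressions are allowed for keys and values; per
// CreateMap.dataType the result is MapType(IntegerType, StringType,
// valueContainsNull = true). Null keys fail at evaluation time with
// "Cannot use null as map key!".
val row = df.select(map($"a" + 1, $"b")).first()
// row.getMap[Int, String](0) == Map(2 -> "a")
```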

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index b76fc73..72f676e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -41,7 +41,13 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
   test("UDF on array") {
     val f = udf((a: String) => a)
     val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
-    df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
+    df.select(array($"a").as("s")).select(f($"s".getItem(0))).collect()
+  }
+
+  test("UDF on map") {
+    val f = udf((a: String) => a)
+    val df = Seq("a" -> 1).toDF("a", "b")
+    df.select(map($"a", $"b").as("s")).select(f($"s".getItem("a"))).collect()
   }
 
   test("SPARK-12477 accessing null element in array field") {

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 2aa6f8d..746e25a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -44,15 +44,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
 
     val expectedType = ArrayType(IntegerType, containsNull = false)
     assert(row.schema(0).dataType === expectedType)
-    assert(row.getAs[Seq[Int]](0) === Seq(0, 2))
+    assert(row.getSeq[Int](0) === Seq(0, 2))
   }
 
-  // Turn this on once we add a rule to the analyzer to throw a friendly exception
-  ignore("array: throw exception if putting columns of different types into an array") {
-    val df = Seq((0, "str")).toDF("a", "b")
-    intercept[AnalysisException] {
-      df.select(array("a", "b"))
-    }
+  test("map with column expressions") {
+    val df = Seq(1 -> "a").toDF("a", "b")
+    val row = df.select(map($"a" + 1, $"b")).first()
+
+    val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getMap[Int, String](0) === Map(2 -> "a"))
   }
 
   test("struct with column name") {

http://git-wip-us.apache.org/repos/asf/spark/blob/43b15e01/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
index 4c9c48a..7593008 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
@@ -100,6 +100,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkSqlGeneration("SELECT isnull(null), isnull('a')")
     checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')")
     checkSqlGeneration("SELECT least(1,null,3)")
+    checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
     checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
     checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
     checkSqlGeneration("SELECT nvl(null, 1, 2)")

