Posted to commits@spark.apache.org by ma...@apache.org on 2015/07/02 19:12:29 UTC

spark git commit: [SPARK-8407] [SQL] complex type constructors: struct and named_struct

Repository: spark
Updated Branches:
  refs/heads/master afa021e03 -> 52302a803


[SPARK-8407] [SQL] complex type constructors: struct and named_struct

This is a follow-up to [SPARK-8283](https://issues.apache.org/jira/browse/SPARK-8283) ([PR-6828](https://github.com/apache/spark/pull/6828)) that adds support for both `struct` and `named_struct` in Spark SQL.

After [#6725](https://github.com/apache/spark/pull/6828), the semantics of [`CreateStruct`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala#L56) changed slightly: it is no longer limited to columns of `NamedExpression` type, and it names non-`NamedExpression` fields following the Hive convention: col1, col2, ...
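
The naming convention described above can be sketched in plain Scala without any Spark dependency. The types `Expr`, `Named` and `Unnamed` and the helper `StructNaming.fieldNames` are hypothetical stand-ins for illustration only, not Catalyst classes; the sketch assumes that a named child keeps its own name and that any other child is named by its 1-based position.

```scala
// Hypothetical, simplified stand-ins for Catalyst expressions (not Spark classes).
sealed trait Expr
final case class Named(name: String) extends Expr   // models a NamedExpression
final case class Unnamed(repr: String) extends Expr // models any other Expression

object StructNaming {
  // Named children keep their name; every other child is named col<index + 1>.
  def fieldNames(children: Seq[Expr]): Seq[String] =
    children.zipWithIndex.map {
      case (Named(name), _) => name
      case (_, idx)         => s"col${idx + 1}"
    }
}
```

Under these assumptions, `StructNaming.fieldNames(Seq(Unnamed("a * 2"), Named("b")))` yields the names `col1` and `b`, which mirrors the field names expected by the `struct with column expression to be automatically named` test in this commit.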

This PR both loosens [`struct`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L723) to take children of `Expression` type and adds `named_struct` support.
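
As a rough illustration of how `named_struct` arguments are interpreted, the sketch below pairs a flat `Seq(name1, val1, name2, val2, ...)` into name/value tuples and rejects an odd argument count or a non-string name. `NamedStructSketch.pair` is a hypothetical helper operating on plain Scala values, not the `CreateNamedStruct` expression itself, and its error messages are illustrative rather than the ones Spark emits.

```scala
object NamedStructSketch {
  // Pairs up Seq(name1, val1, name2, val2, ...), or reports why it cannot.
  def pair(args: Seq[Any]): Either[String, Seq[(String, Any)]] =
    if (args.size % 2 != 0) {
      Left("expects an even number of arguments")
    } else {
      // The size is even here, so every group holds exactly two elements.
      val pairs = args.grouped(2).map { case Seq(n, v) => (n, v) }.toList
      // Names sit at the 1st, 3rd, ... positions and must be strings.
      val badNames = pairs.collect { case (n, _) if !n.isInstanceOf[String] => n }
      if (badNames.nonEmpty) {
        Left(s"field names must be strings, got: ${badNames.mkString(",")}")
      } else {
        Right(pairs.map { case (n, v) => (n.asInstanceOf[String], v) })
      }
    }
}
```

For example, `NamedStructSketch.pair(Seq("a", 1, "b", 2.0))` returns a `Right` holding the pairs `("a", 1)` and `("b", 2.0)`, while `Seq("a", "b", 2.0)` and `Seq(1, "a", "b", 2.0)` both return a `Left`, matching the failure cases exercised in `ExpressionTypeCheckingSuite`.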

Author: Yijie Shen <he...@gmail.com>

Closes #6874 from yijieshen/SPARK-8283 and squashes the following commits:

4cd3375ac [Yijie Shen] change struct documentation
d599d0b [Yijie Shen] rebase code
9a7039e [Yijie Shen] fix reviews and regenerate golden answers
b487354 [Yijie Shen] replace assert using checkAnswer
f07e114 [Yijie Shen] tiny fix
9613be9 [Yijie Shen] review fix
7fef712 [Yijie Shen] Fix checkInputTypes' implementation using foldable and nullable
60812a7 [Yijie Shen] Fix type check
828d694 [Yijie Shen] remove unnecessary resolved assertion inside dataType method
fd3cd8e [Yijie Shen] remove type check from eval
7a71255 [Yijie Shen] tiny fix
ccbbd86 [Yijie Shen] Fix reviews
47da332 [Yijie Shen] remove nameStruct API from DataFrame
917e680 [Yijie Shen] Fix reviews
4bd75ad [Yijie Shen] loosen struct method in functions.scala to take Expression children
0acb7be [Yijie Shen] Add CreateNamedStruct in both DataFrame function API and FunctionRegistery


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/52302a80
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/52302a80
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/52302a80

Branch: refs/heads/master
Commit: 52302a803967114b29a8bf6b74459477364c5b88
Parents: afa021e
Author: Yijie Shen <he...@gmail.com>
Authored: Thu Jul 2 10:12:25 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Thu Jul 2 10:12:25 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  1 -
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../expressions/complexTypeCreator.scala        | 49 ++++++++++++++++++++
 .../analysis/ExpressionTypeCheckingSuite.scala  | 11 +++++
 .../catalyst/expressions/ComplexTypeSuite.scala | 24 +++++++++-
 .../scala/org/apache/spark/sql/functions.scala  | 11 +++--
 .../spark/sql/DataFrameFunctionsSuite.scala     | 40 ++++++++++++++--
 ...neric udf-0-638f81ad9077c7d0c5c735c6e73742ad |  1 +
 ...neric udf-0-cc120a2331158f570a073599985d3f55 |  1 -
 .../sql/hive/execution/HiveQuerySuite.scala     |  2 +-
 10 files changed, 127 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bccde60..12263e6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -467,7 +467,6 @@ def struct(*cols):
     """Creates a new struct column.
 
     :param cols: list of column names (string) or list of :class:`Column` expressions
-        that are named or aliased.
 
     >>> df.select(struct('age', 'name').alias("struct")).collect()
     [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index aa051b1..e7e4d1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -96,6 +96,7 @@ object FunctionRegistry {
     expression[Rand]("rand"),
     expression[Randn]("randn"),
     expression[CreateStruct]("struct"),
+    expression[CreateNamedStruct]("named_struct"),
     expression[Sqrt]("sqrt"),
 
     // math functions

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 67e7dc4..fa70409 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Returns an Array containing the evaluation of all children expressions.
@@ -54,6 +57,8 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
 
   override def foldable: Boolean = children.forall(_.foldable)
 
+  override lazy val resolved: Boolean = childrenResolved
+
   override lazy val dataType: StructType = {
     val fields = children.zipWithIndex.map { case (child, idx) =>
       child match {
@@ -74,3 +79,47 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
 
   override def prettyName: String = "struct"
 }
+
+/**
+ * Creates a struct with the given field names and values
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
+ */
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
+
+  private lazy val (nameExprs, valExprs) =
+    children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
+
+  private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
+
+  override lazy val dataType: StructType = {
+    val fields = names.zip(valExprs).map { case (name, valExpr) =>
+      StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+    }
+    StructType(fields)
+  }
+
+  override def foldable: Boolean = valExprs.forall(_.foldable)
+
+  override def nullable: Boolean = false
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size % 2 != 0) {
+      TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.")
+    } else {
+      val invalidNames =
+        nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable)
+      if (invalidNames.size != 0) {
+        TypeCheckResult.TypeCheckFailure(
+          s"Odd position only allow foldable and not-null StringType expressions, got :" +
+            s" ${invalidNames.mkString(",")}")
+      } else {
+        TypeCheckResult.TypeCheckSuccess
+      }
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    InternalRow(valExprs.map(_.eval(input)): _*)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index bc1537b..8e0551b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -160,4 +160,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(Explode('intField),
       "input to function explode should be array or map type")
   }
+
+  test("check types for CreateNamedStruct") {
+    assertError(
+      CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
+    assertError(
+      CreateNamedStruct(Seq(1, "a", "b", 2.0)),
+        "Odd position only allow foldable and not-null StringType expressions")
+    assertError(
+      CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
+        "Odd position only allow foldable and not-null StringType expressions")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 3515d04..a09014e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.scalatest.exceptions.TestFailedException
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -119,11 +121,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("CreateStruct") {
     val row = create_row(1, 2, 3)
-    val c1 = 'a.int.at(0).as("a")
-    val c3 = 'c.int.at(2).as("c")
+    val c1 = 'a.int.at(0)
+    val c3 = 'c.int.at(2)
     checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
   }
 
+  test("CreateNamedStruct") {
+    val row = InternalRow(1, 2, 3)
+    val c1 = 'a.int.at(0)
+    val c3 = 'c.int.at(2)
+    checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row)
+  }
+
+  test("CreateNamedStruct with literal field") {
+    val row = InternalRow(1, 2, 3)
+    val c1 = 'a.int.at(0)
+    checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row)
+  }
+
+  test("CreateNamedStruct from all literal fields") {
+    checkEvaluation(
+      CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty)
+  }
+
   test("test dsl for complex type") {
     def quickResolve(u: UnresolvedExtractValue): Expression = {
       ExtractValue(u.child, u.extraction, _ == _)

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index a5b6828..4ee1fb8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -739,17 +739,18 @@ object functions {
   def sqrt(colName: String): Column = sqrt(Column(colName))
 
   /**
-   * Creates a new struct column. The input column must be a column in a [[DataFrame]], or
-   * a derived column expression that is named (i.e. aliased).
+   * Creates a new struct column.
+   * If the input column is a column in a [[DataFrame]], or a derived column expression
+   * that is named (i.e. aliased), its name would be remained as the StructField's name,
+   * otherwise, the newly generated StructField's name would be auto generated as col${index + 1},
+   * i.e. col1, col2, col3, ...
    *
    * @group normal_funcs
    * @since 1.4.0
    */
   @scala.annotation.varargs
   def struct(cols: Column*): Column = {
-    require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
-      s"struct input columns must all be named or aliased ($cols)")
-    CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
+    CreateStruct(cols.map(_.expr))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7ae89bc..0d43aca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -79,10 +79,42 @@ class DataFrameFunctionsSuite extends QueryTest {
     assert(row.getAs[Row](0) === Row(2, "str"))
   }
 
-  test("struct: must use named column expression") {
-    intercept[IllegalArgumentException] {
-      struct(col("a") * 2)
-    }
+  test("struct with column expression to be automatically named") {
+    val df = Seq((1, "str")).toDF("a", "b")
+    val result = df.select(struct((col("a") * 2), col("b")))
+
+    val expectedType = StructType(Seq(
+      StructField("col1", IntegerType, nullable = false),
+      StructField("b", StringType)
+    ))
+    assert(result.first.schema(0).dataType === expectedType)
+    checkAnswer(result, Row(Row(2, "str")))
+  }
+
+  test("struct with literal columns") {
+    val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
+    val result = df.select(struct((col("a") * 2), lit(5.0)))
+
+    val expectedType = StructType(Seq(
+      StructField("col1", IntegerType, nullable = false),
+      StructField("col2", DoubleType, nullable = false)
+    ))
+
+    assert(result.first.schema(0).dataType === expectedType)
+    checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0))))
+  }
+
+  test("struct with all literal columns") {
+    val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
+    val result = df.select(struct(lit("v"), lit(5.0)))
+
+    val expectedType = StructType(Seq(
+      StructField("col1", StringType, nullable = false),
+      StructField("col2", DoubleType, nullable = false)
+    ))
+
+    assert(result.first.schema(0).dataType === expectedType)
+    checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0))))
   }
 
   test("constant functions") {

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad
new file mode 100644
index 0000000..7bc77e7
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad	
@@ -0,0 +1 @@
+{"aa":"10","aaaaaa":"11","aaaaaa":"12","bb12":"13","s14s14":"14"}

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55
deleted file mode 100644
index 7bc77e7..0000000
--- a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55	
+++ /dev/null
@@ -1 +0,0 @@
-{"aa":"10","aaaaaa":"11","aaaaaa":"12","bb12":"13","s14s14":"14"}

http://git-wip-us.apache.org/repos/asf/spark/blob/52302a80/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 4cdba03..991da2f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -132,7 +132,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
       lower("AA"), "10",
       repeat(lower("AA"), 3), "11",
       lower(repeat("AA", 3)), "12",
-      printf("Bb%d", 12), "13",
+      printf("bb%d", 12), "13",
       repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""")
 
   createQueryTest("NaN to Decimal",

