Posted to commits@spark.apache.org by me...@apache.org on 2015/05/01 21:49:10 UTC

spark git commit: [SPARK-7274] [SQL] Create Column expression for array/struct creation.

Repository: spark
Updated Branches:
  refs/heads/master 168603272 -> 37537760d


[SPARK-7274] [SQL] Create Column expression for array/struct creation.

Author: Reynold Xin <rx...@databricks.com>

Closes #5802 from rxin/SPARK-7274 and squashes the following commits:

19aecaa [Reynold Xin] Fixed unicode tests.
bfc1538 [Reynold Xin] Export all Python functions.
2517b8c [Reynold Xin] Code review.
23da335 [Reynold Xin] Fixed Python bug.
132002e [Reynold Xin] Fixed tests.
56fce26 [Reynold Xin] Added Python support.
b0d591a [Reynold Xin] Fixed debug error.
86926a6 [Reynold Xin] Added test suite.
7dbb9ab [Reynold Xin] Ok one more.
470e2f5 [Reynold Xin] One more MLlib ...
e2d14f0 [Reynold Xin] [SPARK-7274][SQL] Create Column expression for array/struct creation.
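
In short, the commit adds array and struct functions to org.apache.spark.sql.functions
(with pyspark counterparts) so that array and struct columns can be built through the
public Column API instead of by constructing Catalyst expressions directly. A minimal
usage sketch, assuming an existing DataFrame df with an integer column "a" and a string
column "b" (names chosen for illustration only):

    import org.apache.spark.sql.functions.{array, col, struct}

    // array(...) builds a single ArrayType column; all inputs must share a data type.
    df.select(array(col("a"), col("a") + 1).as("arr"))

    // struct(...) builds a StructType column; every input must be named or aliased.
    df.select(struct(col("a"), col("b")).as("s"))
    df.select(struct((col("a") * 2).as("doubled"), col("b")).as("s"))

    // Both functions also accept plain column names.
    df.select(array("a", "a"), struct("a", "b"))

Because struct requires every input to be a named expression, passing an unnamed derived
expression such as col("a") * 2 directly fails with an IllegalArgumentException (see the
new DataFrameFunctionsSuite below).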


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/37537760
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/37537760
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/37537760

Branch: refs/heads/master
Commit: 37537760d19eab878a5e1a48641cc49e6cb4b989
Parents: 1686032
Author: Reynold Xin <rx...@databricks.com>
Authored: Fri May 1 12:49:02 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri May 1 12:49:02 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/feature/VectorAssembler.scala      | 13 ++-
 python/pyspark/sql/functions.py                 | 80 ++++++++++++++-----
 .../catalyst/expressions/BoundAttribute.scala   | 10 ++-
 .../scala/org/apache/spark/sql/functions.scala  | 41 +++++++++-
 .../spark/sql/DataFrameFunctionsSuite.scala     | 84 ++++++++++++++++++++
 5 files changed, 199 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/37537760/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 7b2a451..5e781a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
-import org.apache.spark.sql.{Column, DataFrame, Row}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct}
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
     val inputColNames = map(inputCols)
     val args = inputColNames.map { c =>
       schema(c).dataType match {
-        case DoubleType => UnresolvedAttribute(c)
-        case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c)
-        case _: NumericType | BooleanType =>
-          Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
+        case DoubleType => dataset(c)
+        case _: VectorUDT => dataset(c)
+        case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
       }
     }
-    dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
+    dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol)))
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
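
The VectorAssembler hunk above replaces direct Catalyst expression construction
(UnresolvedAttribute, Alias, Cast, CreateStruct) with the public Column and functions
API. A minimal sketch of the resulting pattern, assuming a DataFrame named dataset with
an integer column "x" and a double column "y" (illustrative names, not the actual
VectorAssembler inputs):

    import org.apache.spark.sql.functions.{col, struct}
    import org.apache.spark.sql.types.DoubleType

    val args = Seq(
      dataset("y"),                                 // already DoubleType: passed through
      dataset("x").cast(DoubleType).as("x_double")  // numeric/boolean: cast and alias
    )

    // Pack the prepared columns into one struct column via struct(...),
    // where the old code built new Column(CreateStruct(args)) by hand.
    dataset.select(col("*"), struct(args: _*).as("assembled"))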

http://git-wip-us.apache.org/repos/asf/spark/blob/37537760/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 241f821..641220a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -24,13 +24,20 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import SparkContext
-from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
 
 
-__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
+__all__ = [
+    'approxCountDistinct',
+    'countDistinct',
+    'monotonicallyIncreasingId',
+    'rand',
+    'randn',
+    'sparkPartitionId',
+    'udf']
 
 
 def _create_function(name, doc=""):
@@ -74,27 +81,21 @@ __all__ += _functions.keys()
 __all__.sort()
 
 
-def rand(seed=None):
-    """
-    Generate a random column with i.i.d. samples from U[0.0, 1.0].
-    """
-    sc = SparkContext._active_spark_context
-    if seed:
-        jc = sc._jvm.functions.rand(seed)
-    else:
-        jc = sc._jvm.functions.rand()
-    return Column(jc)
+def array(*cols):
+    """Creates a new array column.
 
+    :param cols: list of column names (string) or list of :class:`Column` expressions that have
+        the same data type.
 
-def randn(seed=None):
-    """
-    Generate a column with i.i.d. samples from the standard normal distribution.
+    >>> df.select(array('age', 'age').alias("arr")).collect()
+    [Row(arr=[2, 2]), Row(arr=[5, 5])]
+    >>> df.select(array([df.age, df.age]).alias("arr")).collect()
+    [Row(arr=[2, 2]), Row(arr=[5, 5])]
     """
     sc = SparkContext._active_spark_context
-    if seed:
-        jc = sc._jvm.functions.randn(seed)
-    else:
-        jc = sc._jvm.functions.randn()
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column))
     return Column(jc)
 
 
@@ -146,6 +147,28 @@ def monotonicallyIncreasingId():
     return Column(sc._jvm.functions.monotonicallyIncreasingId())
 
 
+def rand(seed=None):
+    """Generates a random column with i.i.d. samples from U[0.0, 1.0].
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.rand(seed)
+    else:
+        jc = sc._jvm.functions.rand()
+    return Column(jc)
+
+
+def randn(seed=None):
+    """Generates a column with i.i.d. samples from the standard normal distribution.
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.randn(seed)
+    else:
+        jc = sc._jvm.functions.randn()
+    return Column(jc)
+
+
 def sparkPartitionId():
     """A column for partition ID of the Spark task.
 
@@ -158,6 +181,25 @@ def sparkPartitionId():
     return Column(sc._jvm.functions.sparkPartitionId())
 
 
+@ignore_unicode_prefix
+def struct(*cols):
+    """Creates a new struct column.
+
+    :param cols: list of column names (string) or list of :class:`Column` expressions
+        that are named or aliased.
+
+    >>> df.select(struct('age', 'name').alias("struct")).collect()
+    [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
+    >>> df.select(struct([df.age, df.name]).alias("struct")).collect()
+    [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
+    """
+    sc = SparkContext._active_spark_context
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python

http://git-wip-us.apache.org/repos/asf/spark/blob/37537760/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 2225621..c6217f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -28,13 +28,21 @@ import org.apache.spark.sql.catalyst.trees
  * the layout of intermediate tuples, BindReferences should be run after all such transformations.
  */
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
-  extends Expression with trees.LeafNode[Expression] {
+  extends NamedExpression with trees.LeafNode[Expression] {
 
   type EvaluatedType = Any
 
   override def toString: String = s"input[$ordinal]"
 
   override def eval(input: Row): Any = input(ordinal)
+
+  override def name: String = s"i[$ordinal]"
+
+  override def toAttribute: Attribute = throw new UnsupportedOperationException
+
+  override def qualifiers: Seq[String] = throw new UnsupportedOperationException
+
+  override def exprId: ExprId = throw new UnsupportedOperationException
 }
 
 object BindReferences extends Logging {

http://git-wip-us.apache.org/repos/asf/spark/blob/37537760/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 242e64d..7e28339 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -284,6 +284,23 @@ object functions {
   def abs(e: Column): Column = Abs(e.expr)
 
   /**
+   * Creates a new array column. The input columns must all have the same data type.
+   *
+   * @group normal_funcs
+   */
+  @scala.annotation.varargs
+  def array(cols: Column*): Column = CreateArray(cols.map(_.expr))
+
+  /**
+   * Creates a new array column. The input columns must all have the same data type.
+   *
+   * @group normal_funcs
+   */
+  def array(colName: String, colNames: String*): Column = {
+    array((colName +: colNames).map(col) : _*)
+  }
+
+  /**
    * Returns the first column that is not null.
    * {{{
    *   df.select(coalesce(df("a"), df("b")))
@@ -391,6 +408,28 @@ object functions {
   def sqrt(e: Column): Column = Sqrt(e.expr)
 
   /**
+   * Creates a new struct column. The input column must be a column in a [[DataFrame]], or
+   * a derived column expression that is named (i.e. aliased).
+   *
+   * @group normal_funcs
+   */
+  @scala.annotation.varargs
+  def struct(cols: Column*): Column = {
+    require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
+      s"struct input columns must all be named or aliased ($cols)")
+    CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
+  }
+
+  /**
+   * Creates a new struct column that composes multiple input columns.
+   *
+   * @group normal_funcs
+   */
+  def struct(colName: String, colNames: String*): Column = {
+    struct((colName +: colNames).map(col) : _*)
+  }
+
+  /**
    * Converts a string expression to upper case.
    *
    * @group normal_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/37537760/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
new file mode 100644
index 0000000..ca03713
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for functions in [[org.apache.spark.sql.functions]].
+ */
+class DataFrameFunctionsSuite extends QueryTest {
+
+  test("array with column name") {
+    val df = Seq((0, 1)).toDF("a", "b")
+    val row = df.select(array("a", "b")).first()
+
+    val expectedType = ArrayType(IntegerType, containsNull = false)
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getAs[Seq[Int]](0) === Seq(0, 1))
+  }
+
+  test("array with column expression") {
+    val df = Seq((0, 1)).toDF("a", "b")
+    val row = df.select(array(col("a"), col("b") + col("b"))).first()
+
+    val expectedType = ArrayType(IntegerType, containsNull = false)
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getAs[Seq[Int]](0) === Seq(0, 2))
+  }
+
+  // Turn this on once we add a rule to the analyzer to throw a friendly exception
+  ignore("array: throw exception if putting columns of different types into an array") {
+    val df = Seq((0, "str")).toDF("a", "b")
+    intercept[AnalysisException] {
+      df.select(array("a", "b"))
+    }
+  }
+
+  test("struct with column name") {
+    val df = Seq((1, "str")).toDF("a", "b")
+    val row = df.select(struct("a", "b")).first()
+
+    val expectedType = StructType(Seq(
+      StructField("a", IntegerType, nullable = false),
+      StructField("b", StringType)
+    ))
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getAs[Row](0) === Row(1, "str"))
+  }
+
+  test("struct with column expression") {
+    val df = Seq((1, "str")).toDF("a", "b")
+    val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first()
+
+    val expectedType = StructType(Seq(
+      StructField("c", IntegerType, nullable = false),
+      StructField("b", StringType)
+    ))
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getAs[Row](0) === Row(2, "str"))
+  }
+
+  test("struct: must use named column expression") {
+    intercept[IllegalArgumentException] {
+      struct(col("a") * 2)
+    }
+  }
+}

