You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/03/23 07:43:12 UTC

spark git commit: [SPARK-14088][SQL] Some Dataset API touch-up

Repository: spark
Updated Branches:
  refs/heads/master 1a22cf1e9 -> 926a93e54


[SPARK-14088][SQL] Some Dataset API touch-up

## What changes were proposed in this pull request?
1. Deprecated unionAll. It is confusing to have both "union" and "unionAll" when the two do the same thing in Spark but mean different things in SQL (UNION deduplicates rows, while UNION ALL does not).
2. Renamed reduce in KeyValueGroupedDataset to reduceGroups so it is more consistent with the rest of the functions in KeyValueGroupedDataset. This also makes the difference between "reduce" and "reduceGroups" more obvious; previously it was unclear whether "reduce" reduced the whole Dataset or just each group.
3. Added a "name" function, which is a more natural way to name columns than "as" for non-SQL users.
4. Removed the "subtract" function, since it is just an alias for "except".

## How was this patch tested?
All changes should be covered by existing tests. Also added a couple of test cases to cover "name".

Author: Reynold Xin <rx...@databricks.com>

Closes #11908 from rxin/SPARK-14088.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/926a93e5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/926a93e5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/926a93e5

Branch: refs/heads/master
Commit: 926a93e54b83f1ee596096f3301fef015705b627
Parents: 1a22cf1
Author: Reynold Xin <rx...@databricks.com>
Authored: Tue Mar 22 23:43:09 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Mar 22 23:43:09 2016 -0700

----------------------------------------------------------------------
 project/MimaExcludes.scala                      |  1 +
 python/pyspark/sql/column.py                    |  2 ++
 python/pyspark/sql/dataframe.py                 | 14 +++++++--
 .../scala/org/apache/spark/sql/Column.scala     | 29 ++++++++++++++-----
 .../scala/org/apache/spark/sql/Dataset.scala    | 30 +++++++-------------
 .../spark/sql/KeyValueGroupedDataset.scala      | 11 ++-----
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  4 +--
 .../spark/sql/ColumnExpressionSuite.scala       |  3 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  2 +-
 9 files changed, 56 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 68e9c50..42eafcb 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -317,6 +317,7 @@ object MimaExcludes {
         ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"),
         ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"),
 
         ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),

http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/python/pyspark/sql/column.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 19ec6fc..43e9bae 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -315,6 +315,8 @@ class Column(object):
             sc = SparkContext._active_spark_context
             return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
 
+    name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.")
+
     @ignore_unicode_prefix
     @since(1.3)
     def cast(self, dataType):

http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7e1854c..5cfc348 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -911,14 +911,24 @@ class DataFrame(object):
         """
         return self.groupBy().agg(*exprs)
 
+    @since(2.0)
+    def union(self, other):
+        """ Return a new :class:`DataFrame` containing union of rows in this
+        frame and another frame.
+
+        This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union
+        (that does deduplication of elements), use this function followed by a distinct.
+        """
+        return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+
     @since(1.3)
     def unionAll(self, other):
         """ Return a new :class:`DataFrame` containing union of rows in this
         frame and another frame.
 
-        This is equivalent to `UNION ALL` in SQL.
+        .. note:: Deprecated in 2.0, use union instead.
         """
-        return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+        return self.union(other)
 
     @since(1.3)
     def intersect(self, other):

http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 622a62a..d64736e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -856,7 +856,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 1.4.0
    */
-  def alias(alias: String): Column = as(alias)
+  def alias(alias: String): Column = name(alias)
 
   /**
    * Gives the column an alias.
@@ -871,12 +871,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 1.3.0
    */
-  def as(alias: String): Column = withExpr {
-    expr match {
-      case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
-      case other => Alias(other, alias)()
-    }
-  }
+  def as(alias: String): Column = name(alias)
 
   /**
    * (Scala-specific) Assigns the given aliases to the results of a table generating function.
@@ -937,6 +932,26 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   }
 
   /**
+   * Gives the column a name (alias).
+   * {{{
+   *   // Renames colA to colB in select output.
+   *   df.select($"colA".name("colB"))
+   * }}}
+   *
+   * If the current column has metadata associated with it, this metadata will be propagated
+   * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+   *
+   * @group expr_ops
+   * @since 2.0.0
+   */
+  def name(alias: String): Column = withExpr {
+    expr match {
+      case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
+      case other => Alias(other, alias)()
+    }
+  }
+
+  /**
    * Casts the column to a different data type.
    * {{{
    *   // Casts colA to IntegerType.

http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index be0dfe7..31864d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1350,20 +1350,24 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 2.0.0
    */
-  def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan {
-    // This breaks caching, but it's usually ok because it addresses a very specific use case:
-    // using union to union many files or partitions.
-    CombineUnions(Union(logicalPlan, other.logicalPlan))
-  }
+  @deprecated("use union()", "2.0.0")
+  def unionAll(other: Dataset[T]): Dataset[T] = union(other)
 
   /**
    * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset.
    * This is equivalent to `UNION ALL` in SQL.
    *
+   * To do a SQL-style set union (that does deduplication of elements), use this function followed
+   * by a [[distinct]].
+   *
    * @group typedrel
    * @since 2.0.0
    */
-  def union(other: Dataset[T]): Dataset[T] = unionAll(other)
+  def union(other: Dataset[T]): Dataset[T] = withTypedPlan {
+    // This breaks caching, but it's usually ok because it addresses a very specific use case:
+    // using union to union many files or partitions.
+    CombineUnions(Union(logicalPlan, other.logicalPlan))
+  }
 
   /**
    * Returns a new [[Dataset]] containing rows only in both this Dataset and another Dataset.
@@ -1394,18 +1398,6 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Returns a new [[Dataset]] containing rows in this Dataset but not in another Dataset.
-   * This is equivalent to `EXCEPT` in SQL.
-   *
-   * Note that, equality checking is performed directly on the encoded representation of the data
-   * and thus is not affected by a custom `equals` function defined on `T`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  def subtract(other: Dataset[T]): Dataset[T] = except(other)
-
-  /**
    * Returns a new [[Dataset]] by sampling a fraction of rows.
    *
    * @param withReplacement Sample with replacement or not.
@@ -1756,7 +1748,7 @@ class Dataset[T] private[sql](
         outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
       }
 
-      val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+      val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
 
       // Pivot the data so each summary is one row
       row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index f0f9682..8bb75bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -190,7 +190,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
    *
    * @since 1.6.0
    */
-  def reduce(f: (V, V) => V): Dataset[(K, V)] = {
+  def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
     val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
 
     implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder)
@@ -203,15 +203,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
    *
    * @since 1.6.0
    */
-  def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = {
-    reduce(f.call _)
+  def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = {
+    reduceGroups(f.call _)
   }
 
-  // This is here to prevent us from adding overloads that would be ambiguous.
-  @scala.annotation.varargs
-  private def agg(exprs: Column*): DataFrame =
-    groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)
-
   private def withEncoder(c: Column): Column = c match {
     case tc: TypedColumn[_, _] =>
       tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes)

http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 3bff129..18f17a8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -204,7 +204,7 @@ public class JavaDatasetSuite implements Serializable {
 
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
 
-    Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() {
+    Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
       @Override
       public String call(String v1, String v2) throws Exception {
         return v1 + v2;
@@ -300,7 +300,7 @@ public class JavaDatasetSuite implements Serializable {
       Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"),
       unioned.collectAsList());
 
-    Dataset<String> subtracted = ds.subtract(ds2);
+    Dataset<String> subtracted = ds.except(ds2);
     Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList());
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index c2434e4..351b03b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -105,10 +105,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
       Row("a") :: Nil)
   }
 
-  test("alias") {
+  test("alias and name") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     assert(df.select(df("a").as("b")).columns.head === "b")
     assert(df.select(df("a").alias("b")).columns.head === "b")
+    assert(df.select(df("a").name("b")).columns.head === "b")
   }
 
   test("as propagates metadata") {

http://git-wip-us.apache.org/repos/asf/spark/blob/926a93e5/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 677f84e..0bcc512 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -305,7 +305,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
   test("groupBy function, reduce") {
     val ds = Seq("abc", "xyz", "hello").toDS()
-    val agged = ds.groupByKey(_.length).reduce(_ + _)
+    val agged = ds.groupByKey(_.length).reduceGroups(_ + _)
 
     checkDataset(
       agged,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org