You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/10/20 17:08:16 UTC

spark git commit: [SPARK-15780][SQL] Support mapValues on KeyValueGroupedDataset

Repository: spark
Updated Branches:
  refs/heads/master fb0894b3a -> 84b245f2d


[SPARK-15780][SQL] Support mapValues on KeyValueGroupedDataset

## What changes were proposed in this pull request?

Add mapValues to KeyValueGroupedDataset

## How was this patch tested?

New test in DatasetSuite for groupBy function, mapValues, flatMap

Author: Koert Kuipers <ko...@tresata.com>

Closes #13526 from koertkuipers/feat-keyvaluegroupeddataset-mapvalues.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84b245f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84b245f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84b245f2

Branch: refs/heads/master
Commit: 84b245f2dd31c1cebbf12458bf11f67e287e93f4
Parents: fb0894b
Author: Koert Kuipers <ko...@tresata.com>
Authored: Thu Oct 20 10:08:12 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Oct 20 10:08:12 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/plans/logical/object.scala     | 13 ++++++
 .../spark/sql/KeyValueGroupedDataset.scala      | 42 ++++++++++++++++++++
 .../org/apache/spark/sql/DatasetSuite.scala     | 11 +++++
 3 files changed, 66 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/84b245f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index fefe5a3..0ab4c90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -230,6 +230,19 @@ object AppendColumns {
       encoderFor[U].namedExpressions,
       child)
   }
+
+  // Overload of `apply` taking explicit `inputAttributes`: the deserializer of `T` is
+  // wrapped in an UnresolvedDeserializer over those attributes.
+  def apply[T : Encoder, U : Encoder](
+      func: T => U,
+      inputAttributes: Seq[Attribute],
+      child: LogicalPlan): AppendColumns = {
+    val inputEncoder = implicitly[Encoder[T]]
+    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes)
+    new AppendColumns(
+      func.asInstanceOf[Any => Any], inputEncoder.clsTag.runtimeClass, inputEncoder.schema,
+      deserializer, encoderFor[U].namedExpressions, child)
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/84b245f2/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 828eb94..4cb0313 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -67,6 +67,48 @@ class KeyValueGroupedDataset[K, V] private[sql](
       groupingAttributes)
 
   /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
+   * to the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 2.1.0
+   */
+  def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = {
+    // Append the columns that `func` produces, then project down to those new columns
+    // plus the grouping columns.
+    val appended = AppendColumns(func, dataAttributes, logicalPlan)
+    val valueColumns = appended.newColumns
+    val execution = sparkSession.sessionState.executePlan(
+      Project(valueColumns ++ groupingAttributes, appended))
+
+    // Wrap the executed plan, passing the appended columns and the unchanged grouping ones.
+    new KeyValueGroupedDataset(
+      encoderFor[K], encoderFor[W], execution, valueColumns, groupingAttributes)
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
+   * to the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); // Java 8
+   * }}}
+   *
+   * @since 2.1.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    implicit val wEncoder: Encoder[W] = encoder // context bound for the Scala overload
+    mapValues((value: V) => func.call(value))
+  }
+
+  /**
    * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping
    * over the Dataset to extract the keys and then running a distinct operation on those.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/84b245f2/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 5fce9b4..cc367ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -336,6 +336,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "a", "30", "b", "3", "c", "1")
   }
 
+  test("groupBy function, mapValues, flatMap") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val keyValue = ds.groupByKey(_._1).mapValues(_._2)
+    val agged = keyValue.mapGroups { case (g, iter) => (g, iter.sum) }
+    checkDataset(agged, ("a", 30), ("b", 3), ("c", 1))
+    // Same aggregation with tuple keys and tuple values; the check must target agged1.
+    val keyValue1 = ds.groupByKey(t => (t._1, "key")).mapValues(t => (t._2, "value"))
+    val agged1 = keyValue1.mapGroups { case (g, iter) => (g._1, iter.map(_._1).sum) }
+    checkDataset(agged1, ("a", 30), ("b", 3), ("c", 1))
+  }
+
   test("groupBy function, reduce") {
     val ds = Seq("abc", "xyz", "hello").toDS()
     val agged = ds.groupByKey(_.length).reduceGroups(_ + _)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org