Posted to commits@spark.apache.org by ma...@apache.org on 2015/10/28 13:59:00 UTC

spark git commit: [SPARK-11313][SQL] implement cogroup on DataSets (support 2 datasets)

Repository: spark
Updated Branches:
  refs/heads/master 5f1cee6f1 -> 075ce4914


[SPARK-11313][SQL] implement cogroup on DataSets (support 2 datasets)

A simpler version of https://github.com/apache/spark/pull/9279 that supports only 2 datasets.

Author: Wenchen Fan <we...@databricks.com>

Closes #9324 from cloud-fan/cogroup2.
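
For reference, a minimal usage sketch of the new API (assuming a SQLContext with its
implicits in scope; the input data and the expected groups mirror the test added to
DatasetSuite below):

    import sqlContext.implicits._

    val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
    val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()

    val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) {
      case (key, data1, data2) =>
        Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
    }
    // expected contents: 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er"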


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/075ce491
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/075ce491
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/075ce491

Branch: refs/heads/master
Commit: 075ce4914fdcbbcc7286c3c30cb940ed28d474d2
Parents: 5f1cee6
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Oct 28 13:58:52 2015 +0100
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Oct 28 13:58:52 2015 +0100

----------------------------------------------------------------------
 .../sql/catalyst/expressions/UnsafeRow.java     |  1 +
 .../catalyst/plans/logical/basicOperators.scala | 39 +++++++++
 .../org/apache/spark/sql/GroupedDataset.scala   | 20 +++++
 .../spark/sql/execution/CoGroupedIterator.scala | 89 ++++++++++++++++++++
 .../spark/sql/execution/SparkStrategies.scala   |  4 +
 .../spark/sql/execution/basicOperators.scala    | 41 +++++++++
 .../org/apache/spark/sql/DatasetSuite.scala     | 12 +++
 .../sql/execution/CoGroupedIteratorSuite.scala  | 51 +++++++++++
 8 files changed, 257 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/075ce491/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 850838a..5ba14eb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -591,6 +591,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
       build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
       build.append(',');
     }
+    build.deleteCharAt(build.length() - 1);
     build.append(']');
     return build.toString();
   }
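
(Besides the new operators, the patch also touches UnsafeRow.toString above.  A small
sketch of the fix, in Scala with a hypothetical helper, shows why the extra deleteCharAt
call is needed: each word is appended as "<hex>,", so the last separator has to be dropped
before the closing bracket or the dump ends in ",]".)

    def hexDump(words: Seq[Long]): String = {
      val build = new StringBuilder("[")
      words.foreach { w =>
        build.append(java.lang.Long.toHexString(w))
        build.append(',')
      }
      if (words.nonEmpty) build.deleteCharAt(build.length - 1)  // drop trailing ','
      build.append(']')
      build.toString
    }

    hexDump(Seq(0x1L, 0x2aL))  // "[1,2a]" instead of "[1,2a,]"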

http://git-wip-us.apache.org/repos/asf/spark/blob/075ce491/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d2d3db0..4cb67aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -513,3 +513,42 @@ case class MapGroups[K, T, U](
   override def missingInput: AttributeSet = AttributeSet.empty
 }
 
+/** Factory for constructing new `CoGroup` nodes. */
+object CoGroup {
+  def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
+      func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+      leftGroup: Seq[Attribute],
+      rightGroup: Seq[Attribute],
+      left: LogicalPlan,
+      right: LogicalPlan): CoGroup[K, Left, Right, R] = {
+    CoGroup(
+      func,
+      encoderFor[K],
+      encoderFor[Left],
+      encoderFor[Right],
+      encoderFor[R],
+      encoderFor[R].schema.toAttributes,
+      leftGroup,
+      rightGroup,
+      left,
+      right)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each grouping key and associated values from left and
+ * right children.
+ */
+case class CoGroup[K, Left, Right, R](
+    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+    kEncoder: ExpressionEncoder[K],
+    leftEnc: ExpressionEncoder[Left],
+    rightEnc: ExpressionEncoder[Right],
+    rEncoder: ExpressionEncoder[R],
+    output: Seq[Attribute],
+    leftGroup: Seq[Attribute],
+    rightGroup: Seq[Attribute],
+    left: LogicalPlan,
+    right: LogicalPlan) extends BinaryNode {
+  override def missingInput: AttributeSet = AttributeSet.empty
+}
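
The factory above follows a common Scala shape: the case class keeps the fully resolved
ExpressionEncoders as plain constructor fields, while the companion's apply materializes
them from the implicit Encoder context bounds via encoderFor.  A minimal, self-contained
sketch of the same shape (Enc, Node and build are illustrative names, not part of this
patch):

    trait Enc[A]
    object Enc {
      implicit val intEnc: Enc[Int] = new Enc[Int] {}
      def of[A](implicit e: Enc[A]): Enc[A] = e
    }

    // The node stores resolved instances as plain fields...
    case class Node[A](value: A, enc: Enc[A])

    // ...and the companion resolves them from the context bound.
    object Node {
      def build[A: Enc](value: A): Node[A] = Node(value, Enc.of[A])
    }

    // Node.build(42) supplies Enc[Int] implicitly.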

http://git-wip-us.apache.org/repos/asf/spark/blob/075ce491/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 89a16dd..612f2b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -65,4 +65,24 @@ class GroupedDataset[K, T] private[sql](
       sqlContext,
       MapGroups(f, groupingAttributes, logicalPlan))
   }
+
+  /**
+   * Applies the given function to each group of cogrouped data.  For each unique group, the function will
+   * be passed the grouping key and 2 iterators containing all elements in the group from
+   * [[Dataset]] `this` and `other`.  The function can return an iterator containing elements of an
+   * arbitrary type which will be returned as a new [[Dataset]].
+   */
+  def cogroup[U, R : Encoder](
+      other: GroupedDataset[K, U])(
+      f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = {
+    implicit def uEnc: Encoder[U] = other.tEncoder
+    new Dataset[R](
+      sqlContext,
+      CoGroup(
+        f,
+        this.groupingAttributes,
+        other.groupingAttributes,
+        this.logicalPlan,
+        other.logicalPlan))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/075ce491/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
new file mode 100644
index 0000000..ce58278
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+
+/**
+ * Iterates over [[GroupedIterator]]s and returns the cogrouped data, i.e. each record is a
+ * grouping key with its associated values from all [[GroupedIterator]]s.
+ * Note: we assume the output of each [[GroupedIterator]] is ordered by the grouping key.
+ */
+class CoGroupedIterator(
+    left: Iterator[(InternalRow, Iterator[InternalRow])],
+    right: Iterator[(InternalRow, Iterator[InternalRow])],
+    groupingSchema: Seq[Attribute])
+  extends Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] {
+
+  private val keyOrdering =
+    GenerateOrdering.generate(groupingSchema.map(SortOrder(_, Ascending)), groupingSchema)
+
+  private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
+  private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
+
+  override def hasNext: Boolean = left.hasNext || right.hasNext
+
+  override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+    if (currentLeftData.eq(null) && left.hasNext) {
+      currentLeftData = left.next()
+    }
+    if (currentRightData.eq(null) && right.hasNext) {
+      currentRightData = right.next()
+    }
+
+    assert(currentLeftData.ne(null) || currentRightData.ne(null))
+
+    if (currentLeftData.eq(null)) {
+      // left is null, right is not null, consume the right data.
+      rightOnly()
+    } else if (currentRightData.eq(null)) {
+      // left is not null, right is null, consume the left data.
+      leftOnly()
+    } else if (currentLeftData._1 == currentRightData._1) {
+      // left and right have the same grouping key, consume both of them.
+      val result = (currentLeftData._1, currentLeftData._2, currentRightData._2)
+      currentLeftData = null
+      currentRightData = null
+      result
+    } else {
+      val compare = keyOrdering.compare(currentLeftData._1, currentRightData._1)
+      assert(compare != 0)
+      if (compare < 0) {
+        // the grouping key of left is smaller, consume the left data.
+        leftOnly()
+      } else {
+        // the grouping key of right is smaller, consume the right data.
+        rightOnly()
+      }
+    }
+  }
+
+  private def leftOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+    val result = (currentLeftData._1, currentLeftData._2, Iterator.empty)
+    currentLeftData = null
+    result
+  }
+
+  private def rightOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+    val result = (currentRightData._1, Iterator.empty, currentRightData._2)
+    currentRightData = null
+    result
+  }
+}
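
A minimal, self-contained sketch of the same sort-merge idea, written against plain Scala
iterators and an Ordering[K] instead of InternalRow and a generated ordering (cogroupSorted
is an illustrative name, not part of this patch):

    def cogroupSorted[K, A, B](
        left: Iterator[(K, Iterator[A])],
        right: Iterator[(K, Iterator[B])])(
        implicit ord: Ordering[K]): Iterator[(K, Iterator[A], Iterator[B])] =
      new Iterator[(K, Iterator[A], Iterator[B])] {
        private var l: (K, Iterator[A]) = _
        private var r: (K, Iterator[B]) = _

        // a group already buffered in l or r but not yet emitted still counts
        def hasNext: Boolean = l != null || r != null || left.hasNext || right.hasNext

        def next(): (K, Iterator[A], Iterator[B]) = {
          if (l == null && left.hasNext) l = left.next()
          if (r == null && right.hasNext) r = right.next()
          if (l == null) { val out = (r._1, Iterator.empty, r._2); r = null; out }
          else if (r == null) { val out = (l._1, l._2, Iterator.empty); l = null; out }
          else {
            val c = ord.compare(l._1, r._1)
            if (c == 0) { val out = (l._1, l._2, r._2); l = null; r = null; out }
            else if (c < 0) { val out = (l._1, l._2, Iterator.empty); l = null; out }
            else { val out = (r._1, Iterator.empty, r._2); r = null; out }
          }
        }
      }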

http://git-wip-us.apache.org/repos/asf/spark/blob/075ce491/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ee97162..3206726 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -393,6 +393,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
       case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
         execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
+      case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
+        leftGroup, rightGroup, left, right) =>
+        execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
+          planLater(left), planLater(right)) :: Nil
 
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {

http://git-wip-us.apache.org/repos/asf/spark/blob/075ce491/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 8993847..d5a803f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -390,3 +390,44 @@ case class MapGroups[K, T, U](
     }
   }
 }
+
+/**
+ * Co-groups the data from left and right children, and calls the function with each group and 2
+ * iterators containing all elements in the group from the left and right sides.
+ * The result of this function is encoded and flattened before being output.
+ */
+case class CoGroup[K, Left, Right, R](
+    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+    kEncoder: ExpressionEncoder[K],
+    leftEnc: ExpressionEncoder[Left],
+    rightEnc: ExpressionEncoder[Right],
+    rEncoder: ExpressionEncoder[R],
+    output: Seq[Attribute],
+    leftGroup: Seq[Attribute],
+    rightGroup: Seq[Attribute],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
+      val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
+      val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
+      val groupKeyEncoder = kEncoder.bind(leftGroup)
+
+      new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
+        case (key, leftResult, rightResult) =>
+          val result = func(
+            groupKeyEncoder.fromRow(key),
+            leftResult.map(leftEnc.fromRow),
+            rightResult.map(rightEnc.fromRow))
+          result.map(rEncoder.toRow)
+      }
+    }
+  }
+}
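
Because both children are required to be clustered and sorted by their grouping attributes,
zipPartitions pairs up co-partitioned partitions and the streaming merge above only ever
buffers one group per side.  A simplified sketch of the per-partition flow, with the Spark
internals (GroupedIterator, CoGroupedIterator, ExpressionEncoder) replaced by plain
functions; the decode and encode parameters are hypothetical stand-ins for the bound
encoders' fromRow and toRow calls:

    def runPartition[Row, K, L, R, Out](
        cogrouped: Iterator[(Row, Iterator[Row], Iterator[Row])],
        decodeKey: Row => K,
        decodeLeft: Row => L,
        decodeRight: Row => R,
        encodeResult: Out => Row,
        func: (K, Iterator[L], Iterator[R]) => Iterator[Out]): Iterator[Row] =
      cogrouped.flatMap { case (key, leftRows, rightRows) =>
        func(decodeKey(key), leftRows.map(decodeLeft), rightRows.map(decodeRight))
          .map(encodeResult)
      }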

http://git-wip-us.apache.org/repos/asf/spark/blob/075ce491/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index aebb390..993e6d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -202,4 +202,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
+
+  test("cogroup") {
+    val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
+    val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
+    val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
+      Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
+    }
+
+    checkAnswer(
+      cogrouped,
+      1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/075ce491/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
new file mode 100644
index 0000000..d1fe819
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
+
+class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("basic") {
+    val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator
+    val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator
+    val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
+    val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
+    val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
+
+    val result = cogrouped.map {
+      case (key, leftData, rightData) =>
+        assert(key.numFields == 1)
+        (key.getInt(0), leftData.toSeq, rightData.toSeq)
+    }.toSeq
+    assert(result ==
+      (1,
+        Seq(create_row(1, "a"), create_row(1, "b")),
+        Seq(create_row(1, 2L))) ::
+      (2,
+        Seq(create_row(2, "c")),
+        Seq(create_row(2, 3L))) ::
+      (3,
+        Seq.empty,
+        Seq(create_row(3, 4L))) ::
+      Nil
+    )
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org