Posted to commits@flink.apache.org by al...@apache.org on 2014/09/29 18:49:10 UTC

[3/4] git commit: [FLINK-1120] Add Explicit Partition/Rebalance Operator for Scala API

[FLINK-1120] Add Explicit Partition/Rebalance Operator for Scala API


Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/66f236d2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/66f236d2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/66f236d2

Branch: refs/heads/master
Commit: 66f236d25003094911dcbf21b68f3ddc8a77f707
Parents: e2c0b9d
Author: Aljoscha Krettek <al...@gmail.com>
Authored: Mon Sep 29 15:09:28 2014 +0200
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Mon Sep 29 18:16:36 2014 +0200

----------------------------------------------------------------------
 .../org/apache/flink/api/scala/DataSet.scala    |  71 +++++++
 .../api/scala/ScalaAPICompletenessTest.scala    |  11 +-
 .../scala/operators/FirstNOperatorTest.scala    |   2 +-
 .../api/scala/operators/PartitionITCase.scala   | 213 +++++++++++++++++++
 .../test/javaApiOperators/PartitionITCase.java  |   2 +-
 5 files changed, 287 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/66f236d2/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index d520f1f..2e15625 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -22,12 +22,14 @@ import org.apache.flink.api.common.InvalidProgramException
 import org.apache.flink.api.common.aggregators.Aggregator
 import org.apache.flink.api.common.functions._
 import org.apache.flink.api.common.io.{FileOutputFormat, OutputFormat}
+import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod
 import org.apache.flink.api.java.aggregation.Aggregations
 import org.apache.flink.api.java.functions.{FirstReducer, KeySelector}
 import org.apache.flink.api.java.io.{PrintingOutputFormat, TextOutputFormat}
 import org.apache.flink.api.java.operators.JoinOperator.JoinHint
 import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
 import org.apache.flink.api.java.operators._
+import org.apache.flink.api.java.typeutils.TypeExtractor
 import org.apache.flink.api.java.{DataSet => JavaDataSet}
 import org.apache.flink.api.scala.operators.{ScalaCsvOutputFormat, ScalaAggregateOperator}
 import org.apache.flink.core.fs.FileSystem.WriteMode
@@ -873,6 +875,75 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
   def union(other: DataSet[T]): DataSet[T] = wrap(new UnionOperator[T](set, other.set))
 
   // --------------------------------------------------------------------------------------------
+  //  Partitioning
+  // --------------------------------------------------------------------------------------------
+
+  /**
+   * Hash-partitions a DataSet on the specified tuple field positions.
+   *
+   * '''Important:''' This operation shuffles the whole DataSet over the network and can take a
+   * significant amount of time.
+   */
+  def partitionByHash(fields: Int*): DataSet[T] = {
+    val op = new PartitionOperator[T](
+      set,
+      PartitionMethod.HASH,
+      new Keys.FieldPositionKeys[T](fields.toArray, set.getType, false))
+    wrap(op)
+  }
+
+  /**
+   * Hash-partitions a DataSet on the specified fields.
+   *
+   * '''Important:''' This operation shuffles the whole DataSet over the network and can take a
+   * significant amount of time.
+   */
+  def partitionByHash(firstField: String, otherFields: String*): DataSet[T] = {
+    val fieldIndices = fieldNames2Indices(set.getType, firstField +: otherFields.toArray)
+
+    val op = new PartitionOperator[T](
+      set,
+      PartitionMethod.HASH,
+      new Keys.FieldPositionKeys[T](fieldIndices, set.getType, false))
+    wrap(op)
+  }
+
+  /**
+   * Hash-partitions a DataSet using the specified key selector function.
+   *
+   * '''Important:''' This operation shuffles the whole DataSet over the network and can take a
+   * significant amount of time.
+   */
+  def partitionByHash[K: TypeInformation](fun: (T) => K): DataSet[T] = {
+    val keyExtractor = new KeySelector[T, K] {
+      def getKey(in: T) = fun(in)
+    }
+    val op = new PartitionOperator[T](
+      set,
+      PartitionMethod.HASH,
+      new Keys.SelectorFunctionKeys[T, K](
+        keyExtractor,
+        set.getType,
+        implicitly[TypeInformation[K]]))
+    wrap(op)
+  }
+
+  /**
+   * Enforces a rebalancing of the DataSet, i.e., the DataSet is evenly
+   * distributed over all parallel instances of the following task. This can
+   * help to improve performance in case of heavy data skew and
+   * compute-intensive operations.
+   *
+   * '''Important:''' This operation shuffles the whole DataSet over the network and can take
+   * significant amount of time.
+   *
+   * @return The rebalanced DataSet.
+   */
+  def rebalance(): DataSet[T] = {
+    wrap(new PartitionOperator[T](set, PartitionMethod.REBALANCE))
+  }
+
+  // --------------------------------------------------------------------------------------------
   //  Result writing
   // --------------------------------------------------------------------------------------------
 

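For reference, a minimal usage sketch of the four operators this hunk adds (illustration only,
not part of the commit; the case class and sample data are made up):

  import org.apache.flink.api.scala._

  case class Visit(userId: Int, url: String)  // hypothetical example type

  val env = ExecutionEnvironment.getExecutionEnvironment
  val visits = env.fromCollection(Seq(Visit(1, "/a"), Visit(2, "/b"), Visit(1, "/c")))

  // hash-partition on a field position (case classes are treated as tuple types)
  val byPosition = visits.partitionByHash(0)

  // hash-partition on a field name
  val byName = visits.partitionByHash("userId")

  // hash-partition using a key selector function
  val byKey = visits.partitionByHash(_.userId)

  // evenly redistribute the data, e.g. after a skew-producing filter
  val balanced = visits.filter(_.userId > 1).rebalance()
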
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/66f236d2/flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaAPICompletenessTest.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaAPICompletenessTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaAPICompletenessTest.scala
index ba0f6f1..1aad4d1 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaAPICompletenessTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/ScalaAPICompletenessTest.scala
@@ -65,17 +65,8 @@ class ScalaAPICompletenessTest {
       "org.apache.flink.api.java.DataSet.minBy",
       "org.apache.flink.api.java.DataSet.maxBy",
       "org.apache.flink.api.java.operators.UnsortedGrouping.minBy",
-      "org.apache.flink.api.java.operators.UnsortedGrouping.maxBy",
+      "org.apache.flink.api.java.operators.UnsortedGrouping.maxBy"
       
-      // Exclude first operator for now
-      "org.apache.flink.api.java.DataSet.first",
-      "org.apache.flink.api.java.operators.SortedGrouping.first",
-      "org.apache.flink.api.java.operators.UnsortedGrouping.first",
-      
-      // Exclude explicit rebalance and hashPartitionBy for now
-      "org.apache.flink.api.java.DataSet.partitionByHash",
-      "org.apache.flink.api.java.DataSet.rebalance"
-
     )
     val excludedPatterns = Seq(
       // We don't have project on tuples in the Scala API

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/66f236d2/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala
index 0f4e776..7c259b8 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala
@@ -7,7 +7,7 @@
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/66f236d2/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala
new file mode 100644
index 0000000..4c212f5
--- /dev/null
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.operators
+
+import org.apache.flink.api.common.functions.{RichFilterFunction, RichMapFunction}
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.test.util.JavaProgramTestBase
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.junit.runners.Parameterized.Parameters
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.flink.api.scala._
+
+
+object PartitionProgs {
+  var NUM_PROGRAMS: Int = 6
+
+  val tupleInput = Array(
+    (1, "Foo"),
+    (1, "Foo"),
+    (1, "Foo"),
+    (2, "Foo"),
+    (2, "Foo"),
+    (2, "Foo"),
+    (2, "Foo"),
+    (2, "Foo"),
+    (3, "Foo"),
+    (3, "Foo"),
+    (3, "Foo"),
+    (4, "Foo"),
+    (4, "Foo"),
+    (4, "Foo"),
+    (4, "Foo"),
+    (5, "Foo"),
+    (5, "Foo"),
+    (6, "Foo"),
+    (6, "Foo"),
+    (6, "Foo"),
+    (6, "Foo")
+  )
+
+
+  def runProgram(progId: Int, resultPath: String): String = {
+    progId match {
+      case 1 =>
+        val env = ExecutionEnvironment.getExecutionEnvironment
+        val ds = env.fromCollection(tupleInput)
+
+        val unique = ds.partitionByHash(0).mapPartition( _.map(_._1).toSet )
+
+        unique.writeAsText(resultPath)
+        env.execute()
+
+        "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
+
+      case 2 =>
+        val env = ExecutionEnvironment.getExecutionEnvironment
+        val ds = env.fromCollection(tupleInput)
+        val unique = ds.partitionByHash( _._1 ).mapPartition( _.map(_._1).toSet )
+
+        unique.writeAsText(resultPath)
+        env.execute()
+        "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
+
+      case 3 =>
+        val env = ExecutionEnvironment.getExecutionEnvironment
+        val ds = env.generateSequence(1, 3000)
+
+        val skewed = ds.filter(_ > 780)
+        val rebalanced = skewed.rebalance()
+
+        val countsInPartition = rebalanced.map( new RichMapFunction[Long, (Int, Long)] {
+          def map(in: Long) = {
+            (getRuntimeContext.getIndexOfThisSubtask, 1)
+          }
+        })
+          .groupBy(0)
+          .reduce { (v1, v2) => (v1._1, v1._2 + v2._2) }
+          // round counts to mitigate runtime scheduling effects (lazy split assignment)
+          .map { in => (in._1, in._2 / 10) }
+
+        countsInPartition.writeAsText(resultPath)
+        env.execute()
+
+        "(0,55)\n" + "(1,55)\n" + "(2,55)\n" + "(3,55)\n"
+
+      case 4 =>
+        // Verify that mapPartition operation after repartition picks up correct
+        // DOP
+        val env = ExecutionEnvironment.getExecutionEnvironment
+        val ds = env.fromCollection(tupleInput)
+        env.setDegreeOfParallelism(1)
+
+        val unique = ds.partitionByHash(0).setParallelism(4).mapPartition( _.map(_._1).toSet )
+
+        unique.writeAsText(resultPath)
+        env.execute()
+
+        "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
+
+      case 5 =>
+        // Verify that map operation after repartition picks up correct
+        // DOP
+        val env = ExecutionEnvironment.getExecutionEnvironment
+        val ds = env.fromCollection(tupleInput)
+        env.setDegreeOfParallelism(1)
+
+        val count = ds.partitionByHash(0).setParallelism(4).map(
+          new RichMapFunction[(Int, String), Tuple1[Int]] {
+            var first = true
+            override def map(in: (Int, String)): Tuple1[Int] = {
+              // only output one value with count 1
+              if (first) {
+                first = false
+                Tuple1(1)
+              } else {
+                Tuple1(0)
+              }
+            }
+          }).sum(0)
+
+        count.writeAsText(resultPath)
+        env.execute()
+
+        "(4)\n"
+
+      case 6 =>
+        // Verify that filter operation after repartition picks up correct
+        // DOP
+        val env = ExecutionEnvironment.getExecutionEnvironment
+        val ds = env.fromCollection(tupleInput)
+        env.setDegreeOfParallelism(1)
+
+        val count = ds.partitionByHash(0).setParallelism(4).filter(
+          new RichFilterFunction[(Int, String)] {
+            var first = true
+            override def filter(in: (Int, String)): Boolean = {
+              // only output one value with count 1
+              if (first) {
+                first = false
+                true
+              } else {
+                false
+              }
+            }
+        })
+          .map( _ => Tuple1(1)).sum(0)
+
+        count.writeAsText(resultPath)
+        env.execute()
+
+        "(4)\n"
+
+      case _ =>
+        throw new IllegalArgumentException("Invalid program id")
+    }
+  }
+}
+
+
+@RunWith(classOf[Parameterized])
+class PartitionITCase(config: Configuration) extends JavaProgramTestBase(config) {
+
+  private var curProgId: Int = config.getInteger("ProgramId", -1)
+  private var resultPath: String = null
+  private var expectedResult: String = null
+
+  protected override def preSubmit(): Unit = {
+    resultPath = getTempDirPath("result")
+  }
+
+  protected def testProgram(): Unit = {
+    expectedResult = PartitionProgs.runProgram(curProgId, resultPath)
+  }
+
+  protected override def postSubmit(): Unit = {
+    compareResultsByLinesInMemory(expectedResult, resultPath)
+  }
+}
+
+object PartitionITCase {
+  @Parameters
+  def getConfigurations: java.util.Collection[Array[AnyRef]] = {
+    val configs = mutable.MutableList[Array[AnyRef]]()
+    for (i <- 1 to PartitionProgs.NUM_PROGRAMS) {
+      val config = new Configuration()
+      config.setInteger("ProgramId", i)
+      configs += Array(config)
+    }
+
+    configs.asJavaCollection
+  }
+}
+

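A quick sanity check on the expected result in case 3 above (illustration only, not part of the
commit; the subtask count of 4 is inferred from the expected output):

  // generateSequence(1, 3000) filtered with _ > 780 keeps 781..3000
  val surviving = (781 to 3000).size  // 2220 elements
  // rebalance() spreads them evenly over the 4 subtasks
  val perSubtask = surviving / 4      // 555
  // the test divides by 10 to absorb scheduling jitter (lazy split assignment)
  val rounded = perSubtask / 10       // 55, matching "(0,55)" ... "(3,55)"
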
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/66f236d2/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/PartitionITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/PartitionITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/PartitionITCase.java
index b44c450..c937d9d 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/PartitionITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/PartitionITCase.java
@@ -209,7 +209,7 @@ public class PartitionITCase extends JavaProgramTestBase {
 				DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
 				DataSet<Long> uniqLongs = ds
 						.partitionByHash(1).setParallelism(4)
-						.mapPartition(new UniqueLongMapper()).setParallelism(4);
+						.mapPartition(new UniqueLongMapper());
 				uniqLongs.writeAsText(resultPath);
 				
 				env.execute();
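
The one-line change above drops a setParallelism(4) call that had become redundant: as the new
Scala ITCase checks (cases 4 to 6, "picks up correct DOP"), an operator chained after
partitionByHash(...).setParallelism(4) already picks up that parallelism. The same idea in Scala
terms (a sketch, with a hypothetical 3-tuple input ds):

  // parallelism is set once, on the partitioning step...
  val partitioned = ds.partitionByHash(1).setParallelism(4)
  // ...and the downstream operator picks it up, so no second setParallelism is needed
  val uniqueLongs = partitioned.mapPartition( _.map(_._2).toSet )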