Posted to commits@flink.apache.org by se...@apache.org on 2014/11/18 12:22:49 UTC
[1/4] incubator-flink git commit: [FLINK-1237] Add support for custom
partitioners - Functions: GroupReduce, Reduce, Aggregate on UnsortedGrouping,
SortedGrouping, Join (Java API & Scala API) - Manual partition on DataSet
(Java API & Scala API)
Repository: incubator-flink
Updated Branches:
refs/heads/master 83d02563e -> 2000b45ce
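
For orientation, the user-facing API added by this commit looks roughly as follows. This is a minimal Scala sketch mirroring the tests in this patch; the ZeroPartitioner and the sample data are illustrative, not part of the commit.

import org.apache.flink.api.common.functions.Partitioner
import org.apache.flink.api.scala._

object CustomPartitioningSketch {

  // Illustrative partitioner: sends every key to partition 0.
  class ZeroPartitioner extends Partitioner[Int] {
    override def partition(key: Int, numPartitions: Int): Int = 0
  }

  def main(args: Array[String]) {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val data = env.fromElements( (0, 0) ).rebalance()

    // Custom partitioner on a grouping: the shuffle into the reduce
    // uses PARTITION_CUSTOM instead of the default hash partitioning.
    data
      .groupBy(0).withPartitioner(new ZeroPartitioner())
      .reduce( (a, b) => a )
      .print()

    // Manual custom partitioning of a DataSet on tuple field 0.
    data
      .partitionCustom(new ZeroPartitioner(), 0)
      .mapPartition( x => x )
      .print()

    env.execute()
  }
}
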
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingPojoTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingPojoTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingPojoTest.scala
new file mode 100644
index 0000000..8ffba8e
--- /dev/null
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingPojoTest.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.operators.translation
+
+import org.junit.Assert._
+import org.junit.Test
+import org.apache.flink.api.scala._
+import org.apache.flink.api.common.functions.Partitioner
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType
+import org.apache.flink.compiler.plan.SingleInputPlanNode
+import org.apache.flink.test.compiler.util.CompilerTestBase
+import scala.collection.immutable.Seq
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.InvalidProgramException
+
+
+class CustomPartitioningGroupingPojoTest extends CompilerTestBase {
+
+ @Test
+ def testCustomPartitioningTupleReduce() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val data = env.fromElements(new Pojo2()).rebalance().setParallelism(4)
+
+ data
+ .groupBy("a").withPartitioner(new TestPartitionerInt())
+ .reduce( (a,b) => a )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleGroupReduce() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements(new Pojo2()).rebalance().setParallelism(4)
+
+ data
+ .groupBy("a").withPartitioner(new TestPartitionerInt())
+ .reduceGroup( iter => Seq(iter.next) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleGroupReduceSorted() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements(new Pojo3()).rebalance().setParallelism(4)
+
+ data
+ .groupBy("a").withPartitioner(new TestPartitionerInt())
+ .sortGroup("b", Order.ASCENDING)
+ .reduceGroup( iter => Seq(iter.next) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleGroupReduceSorted2() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements(new Pojo4()).rebalance().setParallelism(4)
+
+ data
+ .groupBy("a").withPartitioner(new TestPartitionerInt())
+ .sortGroup("b", Order.ASCENDING)
+ .sortGroup("c", Order.DESCENDING)
+ .reduceGroup( iter => Seq(iter.next) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleInvalidType() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements(new Pojo2()).rebalance().setParallelism(4)
+
+ try {
+ data.groupBy("a").withPartitioner(new TestPartitionerLong())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleInvalidTypeSorted() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements(new Pojo3()).rebalance().setParallelism(4)
+
+ try {
+ data
+ .groupBy("a")
+ .sortGroup("b", Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val data = env.fromElements(new Pojo2()).rebalance().setParallelism(4)
+ try {
+ data.groupBy("a", "b").withPartitioner(new TestPartitionerInt())
+ fail("Should throw an exception")
+ } catch {
+ case e: InvalidProgramException =>
+ }
+ } catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ //-----------------------------------------------------------------------------------------------
+
+ class Pojo2 {
+
+ var a: Int = _
+ var b: Int = _
+ }
+
+ class Pojo3 {
+
+ var a: Int = _
+ var b: Int = _
+ var c: Int = _
+ }
+
+ class Pojo4 {
+
+ var a: Int = _
+ var b: Int = _
+ var c: Int = _
+ var d: Int = _
+ }
+
+ private class TestPartitionerInt extends Partitioner[Int] {
+
+ override def partition(key: Int, numPartitions: Int): Int = 0
+ }
+
+ private class TestPartitionerLong extends Partitioner[Long] {
+
+ override def partition(key: Long, numPartitions: Int): Int = 0
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingTupleTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingTupleTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingTupleTest.scala
new file mode 100644
index 0000000..b5f266f
--- /dev/null
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingTupleTest.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.operators.translation
+
+import org.junit.Assert._
+import org.junit.Test
+
+import org.apache.flink.api.common.functions.Partitioner
+import org.apache.flink.api.scala._
+import org.apache.flink.test.compiler.util.CompilerTestBase
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType
+import org.apache.flink.compiler.plan.SingleInputPlanNode
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.InvalidProgramException
+
+
+class CustomPartitioningGroupingTupleTest extends CompilerTestBase {
+
+ @Test
+ def testCustomPartitioningTupleAgg() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements((0, 0)).rebalance()
+
+ data.groupBy(0)
+ .withPartitioner(new TestPartitionerInt())
+ .sum(1)
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleReduce() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy(0).withPartitioner(new TestPartitionerInt())
+ .reduce( (a,b) => a )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleGroupReduce() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy(0).withPartitioner(new TestPartitionerInt())
+ .reduceGroup( iter => Seq(iter.next) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleGroupReduceSorted() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy(0).withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .reduceGroup( iter => Seq(iter.next) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleGroupReduceSorted2() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0, 0, 0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy(0).withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .sortGroup(2, Order.DESCENDING)
+ .reduceGroup( iter => Seq(iter.next) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleInvalidType() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0) ).rebalance().setParallelism(4)
+
+ try {
+ data.groupBy(0).withPartitioner(new TestPartitionerLong())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleInvalidTypeSorted() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
+ try {
+ data.groupBy(0).sortGroup(1, Order.ASCENDING).withPartitioner(new TestPartitionerLong())
+ fail("Should throw an exception")
+ } catch {
+ case e: InvalidProgramException =>
+ }
+ } catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
+ try {
+ data.groupBy(0, 1).withPartitioner(new TestPartitionerInt())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------------
+
+ class TestPartitionerInt extends Partitioner[Int] {
+
+ override def partition(key: Int, numPartitions: Int): Int = 0
+ }
+
+ class TestPartitionerLong extends Partitioner[Long] {
+
+ override def partition(key: Long, numPartitions: Int): Int = 0
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningTest.scala
new file mode 100644
index 0000000..d4e438f
--- /dev/null
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningTest.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.operators.translation
+
+import org.apache.flink.api.scala._
+import org.junit.Test
+import org.junit.Assert._
+import org.apache.flink.api.common.functions.Partitioner
+import org.apache.flink.test.compiler.util.CompilerTestBase
+import org.apache.flink.compiler.plan.SingleInputPlanNode
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType
+import org.apache.flink.api.common.InvalidProgramException
+
+
+class CustomPartitioningTest extends CompilerTestBase {
+
+ @Test
+ def testPartitionTuples() {
+ try {
+ val part = new TestPartitionerInt()
+ val parallelism = 4
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ env.setDegreeOfParallelism(parallelism)
+
+ val data = env.fromElements( (0,0) ).rebalance()
+
+ data.partitionCustom(part, 0)
+ .mapPartition( x => x )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val mapper = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val partitioner = mapper.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val balancer = partitioner.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(parallelism, sink.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput.getShipStrategy)
+ assertEquals(parallelism, mapper.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput.getShipStrategy)
+ assertEquals(part, partitioner.getInput.getPartitioner)
+ assertEquals(parallelism, partitioner.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput.getShipStrategy)
+ assertEquals(parallelism, balancer.getDegreeOfParallelism)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testPartitionTuplesInvalidType() {
+ try {
+ val parallelism = 4
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ env.setDegreeOfParallelism(parallelism)
+
+ val data = env.fromElements( (0,0) ).rebalance()
+ try {
+ data.partitionCustom(new TestPartitionerLong(), 0)
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testPartitionPojo() {
+ try {
+ val part = new TestPartitionerInt()
+ val parallelism = 4
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ env.setDegreeOfParallelism(parallelism)
+
+ val data = env.fromElements(new Pojo()).rebalance()
+
+ data
+ .partitionCustom(part, "a")
+ .mapPartition( x => x)
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val mapper = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val partitioner = mapper.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val balancer = partitioner.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(parallelism, sink.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput.getShipStrategy)
+ assertEquals(parallelism, mapper.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput.getShipStrategy)
+ assertEquals(part, partitioner.getInput.getPartitioner)
+ assertEquals(parallelism, partitioner.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput.getShipStrategy)
+ assertEquals(parallelism, balancer.getDegreeOfParallelism)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testPartitionPojoInvalidType() {
+ try {
+ val parallelism = 4
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ env.setDegreeOfParallelism(parallelism)
+
+ val data = env.fromElements(new Pojo()).rebalance()
+
+ try {
+ data.partitionCustom(new TestPartitionerLong(), "a")
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testPartitionKeySelector() {
+ try {
+ val part = new TestPartitionerInt()
+ val parallelism = 4
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ env.setDegreeOfParallelism(parallelism)
+
+ val data = env.fromElements(new Pojo()).rebalance()
+
+ data
+ .partitionCustom(part, pojo => pojo.a)
+ .mapPartition( x => x)
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val mapper = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val keyRemover = mapper.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val partitioner = keyRemover.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val keyExtractor = partitioner.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val balancer = keyExtractor.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(parallelism, sink.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput.getShipStrategy)
+ assertEquals(parallelism, mapper.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.FORWARD, keyRemover.getInput.getShipStrategy)
+ assertEquals(parallelism, keyRemover.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput.getShipStrategy)
+ assertEquals(part, partitioner.getInput.getPartitioner)
+ assertEquals(parallelism, partitioner.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.FORWARD, keyExtractor.getInput.getShipStrategy)
+ assertEquals(parallelism, keyExtractor.getDegreeOfParallelism)
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput.getShipStrategy)
+ assertEquals(parallelism, balancer.getDegreeOfParallelism)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------------
+
+ class Pojo {
+
+ var a: Int = _
+ var b: Long = _
+ }
+
+ class TestPartitionerInt extends Partitioner[Int] {
+
+ override def partition(key: Int, numPartitions: Int): Int = 0
+ }
+
+ class TestPartitionerLong extends Partitioner[Long] {
+
+ override def partition(key: Long, numPartitions: Int): Int = 0
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/JoinCustomPartitioningTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/JoinCustomPartitioningTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/JoinCustomPartitioningTest.scala
new file mode 100644
index 0000000..debd48d
--- /dev/null
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/JoinCustomPartitioningTest.scala
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.operators.translation
+
+import org.junit.Assert._
+import org.junit.Test
+import org.apache.flink.api.common.functions.Partitioner
+import org.apache.flink.api.scala._
+import org.apache.flink.test.compiler.util.CompilerTestBase
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType
+import org.apache.flink.compiler.plan.SingleInputPlanNode
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.InvalidProgramException
+import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
+import org.apache.flink.compiler.plan.DualInputPlanNode
+
+class JoinCustomPartitioningTest extends CompilerTestBase {
+
+ @Test
+ def testJoinWithTuples() {
+ try {
+ val partitioner = new TestPartitionerLong()
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input1 = env.fromElements( (0L, 0L) )
+ val input2 = env.fromElements( (0L, 0L, 0L) )
+
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where(1).equalTo(0)
+ .withPartitioner(partitioner)
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val join = sink.getInput.getSource.asInstanceOf[DualInputPlanNode]
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2.getShipStrategy)
+ assertEquals(partitioner, join.getInput1.getPartitioner)
+ assertEquals(partitioner, join.getInput2.getPartitioner)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testJoinWithTuplesWrongType() {
+ try {
+ val partitioner = new TestPartitionerInt()
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input1 = env.fromElements( (0L, 0L) )
+ val input2 = env.fromElements( (0L, 0L, 0L) )
+
+ try {
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where(1).equalTo(0)
+ .withPartitioner(partitioner)
+ fail("should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testJoinWithPojos() {
+ try {
+ val partitioner = new TestPartitionerInt()
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input1 = env.fromElements(new Pojo2())
+ val input2 = env.fromElements(new Pojo3())
+
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where("b").equalTo("a")
+ .withPartitioner(partitioner)
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val join = sink.getInput.getSource.asInstanceOf[DualInputPlanNode]
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2.getShipStrategy)
+ assertEquals(partitioner, join.getInput1.getPartitioner)
+ assertEquals(partitioner, join.getInput2.getPartitioner)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testJoinWithPojosWrongType() {
+ try {
+ val partitioner = new TestPartitionerLong()
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input1 = env.fromElements(new Pojo2())
+ val input2 = env.fromElements(new Pojo3())
+
+ try {
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where("a").equalTo("b")
+ .withPartitioner(partitioner)
+ fail("should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testJoinWithKeySelectors() {
+ try {
+ val partitioner = new TestPartitionerInt()
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input1 = env.fromElements(new Pojo2())
+ val input2 = env.fromElements(new Pojo3())
+
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where( _.a ).equalTo( _.b )
+ .withPartitioner(partitioner)
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val join = sink.getInput.getSource.asInstanceOf[DualInputPlanNode]
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2.getShipStrategy)
+ assertEquals(partitioner, join.getInput1.getPartitioner)
+ assertEquals(partitioner, join.getInput2.getPartitioner)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testJoinWithKeySelectorsWrongType() {
+ try {
+ val partitioner = new TestPartitionerLong()
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input1 = env.fromElements(new Pojo2())
+ val input2 = env.fromElements(new Pojo3())
+
+ try {
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where( _.a ).equalTo( _.b )
+ .withPartitioner(partitioner)
+ fail("should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+
+ // ----------------------------------------------------------------------------------------------
+
+ private class TestPartitionerInt extends Partitioner[Int] {
+
+ override def partition(key: Int, numPartitions: Int): Int = 0
+ }
+
+ private class TestPartitionerLong extends Partitioner[Long] {
+
+ override def partition(key: Long, numPartitions: Int): Int = 0
+ }
+
+ class Pojo2 {
+
+ var a: Int = _
+ var b: Int = _
+ }
+
+ class Pojo3 {
+
+ var a: Int = _
+ var b: Int = _
+ var c: Int = _
+ }
+}
[2/4] incubator-flink git commit: [FLINK-1237] Add support for custom
partitioners - Functions: GroupReduce, Reduce, Aggregate on UnsortedGrouping,
SortedGrouping, Join (Java API & Scala API) - Manual partition on DataSet
(Java API & Scala API)
Posted by se...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
index e906232..66821ae 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
@@ -208,10 +208,9 @@ public class AggregateOperator<IN> extends SingleInputOperator<IN, IN, Aggregate
po.setCombinable(true);
- // set input
po.setInput(input);
- // set dop
po.setDegreeOfParallelism(this.getParallelism());
+ po.setCustomPartitioner(grouping.getCustomPartitioner());
SingleInputSemanticProperties props = new SingleInputSemanticProperties();
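
With this wiring, an aggregation on a custom-partitioned grouping now picks up the partitioner for its shuffle, as testCustomPartitioningTupleAgg above verifies. In user code (illustrative Scala, reusing the ZeroPartitioner from the opening sketch):

data.groupBy(0).withPartitioner(new ZeroPartitioner()).sum(1).print()
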
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java
index 126949c..e60c7de 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java
@@ -22,9 +22,11 @@ import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.operators.Operator;
+import org.apache.flink.api.common.operators.SingleInputSemanticProperties;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
+import org.apache.flink.api.common.operators.util.FieldSet;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
@@ -71,7 +73,7 @@ public class DistinctOperator<T> extends SingleInputOperator<T, T, DistinctOpera
}
- // FieldPositionKeys can only be applied on Tuples
+ // FieldPositionKeys can only be applied on Tuples and POJOs
if (keys instanceof Keys.ExpressionKeys && !(input.getType() instanceof CompositeType)) {
throw new InvalidProgramException("Distinction on field positions is only possible on composite type DataSets.");
}
@@ -84,7 +86,7 @@ public class DistinctOperator<T> extends SingleInputOperator<T, T, DistinctOpera
final RichGroupReduceFunction<T, T> function = new DistinctFunction<T>();
- String name = "Distinct at "+distinctLocationName;
+ String name = "Distinct at " + distinctLocationName;
if (keys instanceof Keys.ExpressionKeys) {
@@ -95,7 +97,19 @@ public class DistinctOperator<T> extends SingleInputOperator<T, T, DistinctOpera
po.setCombinable(true);
po.setInput(input);
- po.setDegreeOfParallelism(this.getParallelism());
+ po.setDegreeOfParallelism(getParallelism());
+
+ // make sure that distinct preserves the partitioning of the fields on which it operates
+ if (getType().isTupleType()) {
+ SingleInputSemanticProperties sProps = new SingleInputSemanticProperties();
+
+ for (int field : keys.computeLogicalKeyPositions()) {
+ sProps.setForwardedField(field, new FieldSet(field));
+ }
+
+ po.setSemanticProperties(sProps);
+ }
+
return po;
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java
index 327d12f..bef91ed 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java
@@ -113,10 +113,14 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
return this;
}
+ // --------------------------------------------------------------------------------------------
+ // Translation
+ // --------------------------------------------------------------------------------------------
+
@Override
- protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase<?, OUT, ?> translateToDataFlow(Operator<IN> input) {
+ protected GroupReduceOperatorBase<?, OUT, ?> translateToDataFlow(Operator<IN> input) {
- String name = getName() != null ? getName() : "GroupReduce at "+defaultName;
+ String name = getName() != null ? getName() : "GroupReduce at " + defaultName;
// distinguish between grouped reduce and non-grouped reduce
if (grouper == null) {
@@ -124,9 +128,8 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
UnaryOperatorInformation<IN, OUT> operatorInfo = new UnaryOperatorInformation<IN, OUT>(getInputType(), getResultType());
GroupReduceOperatorBase<IN, OUT, GroupReduceFunction<IN, OUT>> po =
new GroupReduceOperatorBase<IN, OUT, GroupReduceFunction<IN, OUT>>(function, operatorInfo, new int[0], name);
-
+
po.setCombinable(combinable);
- // set input
po.setInput(input);
// the degree of parallelism for a non grouped reduce can only be 1
po.setDegreeOfParallelism(1);
@@ -141,7 +144,8 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
PlanUnwrappingReduceGroupOperator<IN, OUT, ?> po = translateSelectorFunctionReducer(
selectorKeys, function, getInputType(), getResultType(), name, input, isCombinable());
- po.setDegreeOfParallelism(this.getParallelism());
+ po.setDegreeOfParallelism(getParallelism());
+ po.setCustomPartitioner(grouper.getCustomPartitioner());
return po;
}
@@ -154,7 +158,8 @@ public class GroupReduceOperator<IN, OUT> extends SingleInputUdfOperator<IN, OUT
po.setCombinable(combinable);
po.setInput(input);
- po.setDegreeOfParallelism(this.getParallelism());
+ po.setDegreeOfParallelism(getParallelism());
+ po.setCustomPartitioner(grouper.getCustomPartitioner());
// set group order
if (grouper instanceof SortedGrouping) {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java
index 36a364e..3c0d07f 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Grouping.java
@@ -19,7 +19,7 @@
package org.apache.flink.api.java.operators;
import org.apache.flink.api.common.InvalidProgramException;
-
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.java.DataSet;
/**
@@ -40,7 +40,10 @@ public abstract class Grouping<T> {
protected final DataSet<T> dataSet;
protected final Keys<T> keys;
+
+ protected Partitioner<?> customPartitioner;
+
public Grouping(DataSet<T> set, Keys<T> keys) {
if (set == null || keys == null) {
throw new NullPointerException();
@@ -62,5 +65,14 @@ public abstract class Grouping<T> {
public Keys<T> getKeys() {
return this.keys;
}
-
+
+ /**
+ * Gets the custom partitioner to be used for this grouping, or {@code null}, if
+ * none was defined.
+ *
+ * @return The custom partitioner to be used for this grouping.
+ */
+ public Partitioner<?> getCustomPartitioner() {
+ return this.customPartitioner;
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java
index 93e0371..21534f1 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputSemanticProperties;
@@ -46,13 +47,16 @@ import org.apache.flink.api.java.operators.translation.PlanBothUnwrappingJoinOpe
import org.apache.flink.api.java.operators.translation.PlanLeftUnwrappingJoinOperator;
import org.apache.flink.api.java.operators.translation.PlanRightUnwrappingJoinOperator;
import org.apache.flink.api.java.operators.translation.WrappingFunction;
-//CHECKSTYLE.OFF: AvoidStarImport - Needed for TupleGenerator
-import org.apache.flink.api.java.tuple.*;
-//CHECKSTYLE.ON: AvoidStarImport
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.util.Collector;
+//CHECKSTYLE.OFF: AvoidStarImport - Needed for TupleGenerator
+import org.apache.flink.api.java.tuple.*;
+
+import com.google.common.base.Preconditions;
+//CHECKSTYLE.ON: AvoidStarImport
+
/**
* A {@link DataSet} that is the result of a Join transformation.
*
@@ -69,14 +73,25 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
private final JoinHint joinHint;
+ private Partitioner<?> customPartitioner;
+
+
protected JoinOperator(DataSet<I1> input1, DataSet<I2> input2,
Keys<I1> keys1, Keys<I2> keys2,
TypeInformation<OUT> returnType, JoinHint hint)
{
super(input1, input2, returnType);
- if (keys1 == null || keys2 == null) {
- throw new NullPointerException();
+ Preconditions.checkNotNull(keys1);
+ Preconditions.checkNotNull(keys2);
+
+ try {
+ if (!keys1.areCompatible(keys2)) {
+ throw new InvalidProgramException("The types of the key fields do not match.");
+ }
+ }
+ catch (IncompatibleKeysException ike) {
+ throw new InvalidProgramException("The types of the key fields do not match: " + ike.getMessage(), ike);
}
// sanity check solution set key mismatches
@@ -110,10 +125,43 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
return this.keys2;
}
+ /**
+ * Gets the JoinHint that describes how the join is executed.
+ *
+ * @return The JoinHint.
+ */
public JoinHint getJoinHint() {
return this.joinHint;
}
+ /**
+ * Sets a custom partitioner for this join. The partitioner will be called on the join keys to determine
+ * the partition a key should be assigned to. The partitioner is evaluated on both join inputs in the
+ * same way.
+ * <p>
+ * NOTE: A custom partitioner can only be used with single-field join keys, not with composite join keys.
+ *
+ * @param partitioner The custom partitioner to be used.
+ * @return This join operator, to allow for function chaining.
+ */
+ public JoinOperator<I1, I2, OUT> withPartitioner(Partitioner<?> partitioner) {
+ if (partitioner != null) {
+ keys1.validateCustomPartitioner(partitioner, null);
+ keys2.validateCustomPartitioner(partitioner, null);
+ }
+ this.customPartitioner = partitioner;
+ return this;
+ }
+
+ /**
+ * Gets the custom partitioner used by this join, or {@code null}, if none is set.
+ *
+ * @return The custom partitioner used by this join.
+ */
+ public Partitioner<?> getPartitioner() {
+ return customPartitioner;
+ }
+
// --------------------------------------------------------------------------------------------
// special join types
// --------------------------------------------------------------------------------------------
@@ -206,30 +254,20 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
// }
@Override
- protected JoinOperatorBase<?, ?, OUT, ?> translateToDataFlow(
- Operator<I1> input1,
- Operator<I2> input2) {
+ protected JoinOperatorBase<?, ?, OUT, ?> translateToDataFlow(Operator<I1> input1, Operator<I2> input2) {
String name = getName() != null ? getName() : "Join at "+joinLocationName;
- try {
- keys1.areCompatible(super.keys2);
- } catch(IncompatibleKeysException ike) {
- throw new InvalidProgramException("The types of the key fields do not match.", ike);
- }
final JoinOperatorBase<?, ?, OUT, ?> translated;
- if (keys1 instanceof Keys.SelectorFunctionKeys
- && keys2 instanceof Keys.SelectorFunctionKeys) {
+ if (keys1 instanceof Keys.SelectorFunctionKeys && keys2 instanceof Keys.SelectorFunctionKeys) {
// Both join sides have a key selector function, so we need to do the
// tuple wrapping/unwrapping on both sides.
@SuppressWarnings("unchecked")
- Keys.SelectorFunctionKeys<I1, ?> selectorKeys1 =
- (Keys.SelectorFunctionKeys<I1, ?>) keys1;
+ Keys.SelectorFunctionKeys<I1, ?> selectorKeys1 = (Keys.SelectorFunctionKeys<I1, ?>) keys1;
@SuppressWarnings("unchecked")
- Keys.SelectorFunctionKeys<I2, ?> selectorKeys2 =
- (Keys.SelectorFunctionKeys<I2, ?>) keys2;
+ Keys.SelectorFunctionKeys<I2, ?> selectorKeys2 = (Keys.SelectorFunctionKeys<I2, ?>) keys2;
PlanBothUnwrappingJoinOperator<I1, I2, OUT, ?> po =
translateSelectorFunctionJoin(selectorKeys1, selectorKeys2, function,
@@ -304,6 +342,7 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
}
translated.setJoinHint(getJoinHint());
+ translated.setCustomPartitioner(getPartitioner());
return translated;
}
@@ -506,22 +545,6 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
out.collect (this.wrappedFunction.join(left, right));
}
}
-
- /*
- private static class GeneratedFlatJoinFunction<IN1, IN2, OUT> extends FlatJoinFunction<IN1, IN2, OUT> {
-
- private Joinable<IN1,IN2,OUT> function;
-
- private GeneratedFlatJoinFunction(Joinable<IN1, IN2, OUT> function) {
- this.function = function;
- }
-
- @Override
- public void join(IN1 first, IN2 second, Collector<OUT> out) throws Exception {
- out.collect(function.join(first, second));
- }
- }
- */
/**
* Initiates a ProjectJoin transformation and projects the first join input<br/>
@@ -933,32 +956,6 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
}
}
- public static final class LeftSemiFlatJoinFunction<T1, T2> extends RichFlatJoinFunction<T1, T2, T1> {
-
- private static final long serialVersionUID = 1L;
-
- @Override
- //public T1 join(T1 left, T2 right) throws Exception {
- // return left;
- //}
- public void join (T1 left, T2 right, Collector<T1> out) {
- out.collect(left);
- }
- }
-
- public static final class RightSemiFlatJoinFunction<T1, T2> extends RichFlatJoinFunction<T1, T2, T2> {
-
- private static final long serialVersionUID = 1L;
-
- @Override
- //public T2 join(T1 left, T2 right) throws Exception {
- // return right;
- //}
- public void join (T1 left, T2 right, Collector<T2> out) {
- out.collect(right);
- }
- }
-
public static final class JoinProjection<I1, I2> {
private final DataSet<I1> ds1;
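
In use, the join-side withPartitioner added above behaves as in the condensed Scala sketch below (illustrative, following testJoinWithTuples from this commit). The optimizer then ships both join inputs with PARTITION_CUSTOM and the same partitioner instance.

import org.apache.flink.api.common.functions.Partitioner
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
import org.apache.flink.api.scala._

// Illustrative Partitioner[Long], matching the single Long join keys below.
class LongZeroPartitioner extends Partitioner[Long] {
  override def partition(key: Long, numPartitions: Int): Int = 0
}

val env = ExecutionEnvironment.getExecutionEnvironment
val input1 = env.fromElements( (0L, 0L) )
val input2 = env.fromElements( (0L, 0L, 0L) )

input1
  .join(input2, JoinHint.REPARTITION_HASH_FIRST)
  .where(1).equalTo(0)                        // single-field keys of type Long
  .withPartitioner(new LongZeroPartitioner())
  .print()
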
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java
index 46bbfab..c2a2a8e 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java
@@ -24,13 +24,17 @@ import java.util.LinkedList;
import java.util.List;
import com.google.common.base.Joiner;
+
import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,6 +58,8 @@ public abstract class Keys<T> {
public abstract int[] computeLogicalKeyPositions();
+ public abstract <E> void validateCustomPartitioner(Partitioner<E> partitioner, TypeInformation<E> typeInfo);
+
// --------------------------------------------------------------------------------------------
// Specializations for expression-based / extractor-based grouping
@@ -146,6 +152,27 @@ public abstract class Keys<T> {
public int[] computeLogicalKeyPositions() {
return logicalKeyFields;
}
+
+ @Override
+ public <E> void validateCustomPartitioner(Partitioner<E> partitioner, TypeInformation<E> typeInfo) {
+ if (logicalKeyFields.length != 1) {
+ throw new InvalidProgramException("Custom partitioners can only be used with keys that have one key field.");
+ }
+
+ if (typeInfo == null) {
+ try {
+ typeInfo = TypeExtractor.getPartitionerTypes(partitioner);
+ }
+ catch (Throwable t) {
+ // best effort check, so we ignore exceptions
+ }
+ }
+
+ if (typeInfo != null && !(typeInfo instanceof GenericTypeInfo) && (!keyType.equals(typeInfo))) {
+ throw new InvalidProgramException("The partitioner is imcompatible with the key type. "
+ + "Partitioner type: " + typeInfo + " , key type: " + keyType);
+ }
+ }
@Override
public String toString() {
@@ -299,12 +326,36 @@ public abstract class Keys<T> {
@Override
public int[] computeLogicalKeyPositions() {
- List<Integer> logicalKeys = new LinkedList<Integer>();
- for(FlatFieldDescriptor kd : keyFields) {
- logicalKeys.addAll( Ints.asList(kd.getPosition()));
+ List<Integer> logicalKeys = new ArrayList<Integer>();
+ for (FlatFieldDescriptor kd : keyFields) {
+ logicalKeys.add(kd.getPosition());
}
return Ints.toArray(logicalKeys);
}
+
+ @Override
+ public <E> void validateCustomPartitioner(Partitioner<E> partitioner, TypeInformation<E> typeInfo) {
+ if (keyFields.size() != 1) {
+ throw new InvalidProgramException("Custom partitioners can only be used with keys that have one key field.");
+ }
+
+ if (typeInfo == null) {
+ try {
+ typeInfo = TypeExtractor.getPartitionerTypes(partitioner);
+ }
+ catch (Throwable t) {
+ // best effort check, so we ignore exceptions
+ }
+ }
+
+ if (typeInfo != null && !(typeInfo instanceof GenericTypeInfo)) {
+ TypeInformation<?> keyType = keyFields.get(0).getType();
+ if (!keyType.equals(typeInfo)) {
+ throw new InvalidProgramException("The partitioner is incompatible with the key type. "
+ + "Partitioner type: " + typeInfo + " , key type: " + keyType);
+ }
+ }
+ }
@Override
public String toString() {
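
This validation is what the *InvalidType and *RejectCompositeKey tests above exercise: a custom partitioner is only accepted for a single key field whose type matches the partitioner's key type. A minimal Scala sketch of the observable behavior (illustrative):

import org.apache.flink.api.common.functions.Partitioner
import org.apache.flink.api.scala._

class IntZeroPartitioner extends Partitioner[Int] {
  override def partition(key: Int, numPartitions: Int): Int = 0
}
class LongZeroPartitioner extends Partitioner[Long] {
  override def partition(key: Long, numPartitions: Int): Int = 0
}

val env = ExecutionEnvironment.getExecutionEnvironment
val data = env.fromElements( (0, 0) )   // both key fields are Int

// Rejected with InvalidProgramException: Partitioner[Long] vs. Int key field.
// data.groupBy(0).withPartitioner(new LongZeroPartitioner())

// Rejected with InvalidProgramException: composite (two-field) key.
// data.groupBy(0, 1).withPartitioner(new IntZeroPartitioner())

// Accepted: single Int key field, matching Partitioner[Int].
data.groupBy(0).withPartitioner(new IntZeroPartitioner())
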
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
index 77d5681..22d4d44 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java
@@ -19,6 +19,7 @@
package org.apache.flink.api.java.operators;
import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
@@ -32,6 +33,8 @@ import org.apache.flink.api.java.operators.translation.KeyRemovingMapper;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import com.google.common.base.Preconditions;
+
/**
* This operator represents a partitioning.
*
@@ -42,66 +45,102 @@ public class PartitionOperator<T> extends SingleInputUdfOperator<T, T, Partition
private final Keys<T> pKeys;
private final PartitionMethod pMethod;
private final String partitionLocationName;
+ private final Partitioner<?> customPartitioner;
+
public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys, String partitionLocationName) {
+ this(input, pMethod, pKeys, null, null, partitionLocationName);
+ }
+
+ public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, String partitionLocationName) {
+ this(input, pMethod, null, null, null, partitionLocationName);
+ }
+
+ public PartitionOperator(DataSet<T> input, Keys<T> pKeys, Partitioner<?> customPartitioner, String partitionLocationName) {
+ this(input, PartitionMethod.CUSTOM, pKeys, customPartitioner, null, partitionLocationName);
+ }
+
+ public <P> PartitionOperator(DataSet<T> input, Keys<T> pKeys, Partitioner<P> customPartitioner,
+ TypeInformation<P> partitionerTypeInfo, String partitionLocationName)
+ {
+ this(input, PartitionMethod.CUSTOM, pKeys, customPartitioner, partitionerTypeInfo, partitionLocationName);
+ }
+
+ private <P> PartitionOperator(DataSet<T> input, PartitionMethod pMethod, Keys<T> pKeys, Partitioner<P> customPartitioner,
+ TypeInformation<P> partitionerTypeInfo, String partitionLocationName)
+ {
super(input, input.getType());
- this.partitionLocationName = partitionLocationName;
-
- if(pMethod == PartitionMethod.HASH && pKeys == null) {
- throw new IllegalArgumentException("Hash Partitioning requires keys");
- } else if(pMethod == PartitionMethod.RANGE) {
- throw new UnsupportedOperationException("Range Partitioning not yet supported");
+
+ Preconditions.checkNotNull(pMethod);
+ Preconditions.checkArgument(pKeys != null || pMethod == PartitionMethod.REBALANCE, "Partitioning requires keys");
+ Preconditions.checkArgument(pMethod != PartitionMethod.CUSTOM || customPartitioner != null, "Custom partitioning requires a partitioner.");
+ Preconditions.checkArgument(pMethod != PartitionMethod.RANGE, "Range partitioning is not yet supported");
+
+ if (pKeys instanceof Keys.ExpressionKeys<?> && !(input.getType() instanceof CompositeType) ) {
+ throw new IllegalArgumentException("Hash Partitioning with key fields only possible on Tuple or POJO DataSets");
}
- if(pKeys instanceof Keys.ExpressionKeys<?> && !(input.getType() instanceof CompositeType) ) {
- throw new IllegalArgumentException("Hash Partitioning with key fields only possible on Composite-type DataSets");
+ if (customPartitioner != null) {
+ pKeys.validateCustomPartitioner(customPartitioner, partitionerTypeInfo);
}
this.pMethod = pMethod;
this.pKeys = pKeys;
+ this.partitionLocationName = partitionLocationName;
+ this.customPartitioner = customPartitioner;
}
- public PartitionOperator(DataSet<T> input, PartitionMethod pMethod, String partitionLocationName) {
- this(input, pMethod, null, partitionLocationName);
- }
+ // --------------------------------------------------------------------------------------------
+ // Properties
+ // --------------------------------------------------------------------------------------------
- /*
- * Translation of partitioning
+ /**
+ * Gets the custom partitioner from this partitioning.
+ *
+ * @return The custom partitioner.
*/
+ public Partitioner<?> getCustomPartitioner() {
+ return customPartitioner;
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Translation
+ // --------------------------------------------------------------------------------------------
+
protected org.apache.flink.api.common.operators.SingleInputOperator<?, T, ?> translateToDataFlow(Operator<T> input) {
- String name = "Partition at "+partitionLocationName;
+ String name = "Partition at " + partitionLocationName;
// distinguish between partition types
if (pMethod == PartitionMethod.REBALANCE) {
UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<T, T>(getType(), getType());
PartitionOperatorBase<T> noop = new PartitionOperatorBase<T>(operatorInfo, pMethod, name);
- // set input
+
noop.setInput(input);
- // set DOP
noop.setDegreeOfParallelism(getParallelism());
return noop;
}
- else if (pMethod == PartitionMethod.HASH) {
+ else if (pMethod == PartitionMethod.HASH || pMethod == PartitionMethod.CUSTOM) {
if (pKeys instanceof Keys.ExpressionKeys) {
int[] logicalKeyPositions = pKeys.computeLogicalKeyPositions();
UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<T, T>(getType(), getType());
PartitionOperatorBase<T> noop = new PartitionOperatorBase<T>(operatorInfo, pMethod, logicalKeyPositions, name);
- // set input
+
noop.setInput(input);
- // set DOP
noop.setDegreeOfParallelism(getParallelism());
+ noop.setCustomPartitioner(customPartitioner);
return noop;
- } else if (pKeys instanceof Keys.SelectorFunctionKeys) {
+ }
+ else if (pKeys instanceof Keys.SelectorFunctionKeys) {
@SuppressWarnings("unchecked")
Keys.SelectorFunctionKeys<T, ?> selectorKeys = (Keys.SelectorFunctionKeys<T, ?>) pKeys;
- MapOperatorBase<?, T, ?> po = translateSelectorFunctionReducer(selectorKeys, pMethod, getType(), name, input, getParallelism());
+ MapOperatorBase<?, T, ?> po = translateSelectorFunctionPartitioner(selectorKeys, pMethod, getType(), name, input, getParallelism(), customPartitioner);
return po;
}
else {
@@ -112,14 +151,13 @@ public class PartitionOperator<T> extends SingleInputUdfOperator<T, T, Partition
else if (pMethod == PartitionMethod.RANGE) {
throw new UnsupportedOperationException("Range partitioning not yet supported");
}
-
- return null;
+ else {
+ throw new UnsupportedOperationException("Unsupported partitioning method: " + pMethod.name());
+ }
}
-
- // --------------------------------------------------------------------------------------------
- private static <T, K> MapOperatorBase<Tuple2<K, T>, T, ?> translateSelectorFunctionReducer(Keys.SelectorFunctionKeys<T, ?> rawKeys,
- PartitionMethod pMethod, TypeInformation<T> inputType, String name, Operator<T> input, int partitionDop)
+ private static <T, K> MapOperatorBase<Tuple2<K, T>, T, ?> translateSelectorFunctionPartitioner(Keys.SelectorFunctionKeys<T, ?> rawKeys,
+ PartitionMethod pMethod, TypeInformation<T> inputType, String name, Operator<T> input, int partitionDop, Partitioner<?> customPartitioner)
{
@SuppressWarnings("unchecked")
final Keys.SelectorFunctionKeys<T, K> keys = (Keys.SelectorFunctionKeys<T, K>) rawKeys;
@@ -137,6 +175,8 @@ public class PartitionOperator<T> extends SingleInputUdfOperator<T, T, Partition
noop.setInput(keyExtractingMap);
keyRemovingMap.setInput(noop);
+ noop.setCustomPartitioner(customPartitioner);
+
// set dop
keyExtractingMap.setDegreeOfParallelism(input.getDegreeOfParallelism());
noop.setDegreeOfParallelism(partitionDop);
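
Both PARTITION_HASH and PARTITION_CUSTOM now flow through the same translation; the custom partitioner is simply attached to the resulting no-op partition node. A minimal usage sketch (Scala API, using the partitionCustom method this commit adds to DataSet further below):

    import org.apache.flink.api.scala._
    import org.apache.flink.api.common.functions.Partitioner

    val env  = ExecutionEnvironment.getExecutionEnvironment
    val data = env.fromElements((1, "a"), (2, "b"), (3, "c"))

    // Ship each record to a channel derived from the first tuple field.
    val partitioned = data.partitionCustom(new Partitioner[Int] {
      override def partition(key: Int, numPartitions: Int): Int =
        math.abs(key) % numPartitions
    }, 0)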
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java
index 7089cf6..02b0ede 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java
@@ -87,9 +87,8 @@ public class ReduceOperator<IN> extends SingleInputUdfOperator<IN, IN, ReduceOpe
UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<IN, IN>(getInputType(), getInputType());
ReduceOperatorBase<IN, ReduceFunction<IN>> po =
new ReduceOperatorBase<IN, ReduceFunction<IN>>(function, operatorInfo, new int[0], name);
- // set input
- po.setInput(input);
+ po.setInput(input);
// the degree of parallelism for a non-grouped reduce can only be 1
po.setDegreeOfParallelism(1);
@@ -102,7 +101,9 @@ public class ReduceOperator<IN> extends SingleInputUdfOperator<IN, IN, ReduceOpe
@SuppressWarnings("unchecked")
Keys.SelectorFunctionKeys<IN, ?> selectorKeys = (Keys.SelectorFunctionKeys<IN, ?>) grouper.getKeys();
- MapOperatorBase<?, IN, ?> po = translateSelectorFunctionReducer(selectorKeys, function, getInputType(), name, input, this.getParallelism());
+ MapOperatorBase<?, IN, ?> po = translateSelectorFunctionReducer(selectorKeys, function, getInputType(), name, input, getParallelism());
+ ((PlanUnwrappingReduceOperator<?, ?>) po.getInput()).setCustomPartitioner(grouper.getCustomPartitioner());
+
return po;
}
else if (grouper.getKeys() instanceof Keys.ExpressionKeys) {
@@ -113,17 +114,16 @@ public class ReduceOperator<IN> extends SingleInputUdfOperator<IN, IN, ReduceOpe
ReduceOperatorBase<IN, ReduceFunction<IN>> po =
new ReduceOperatorBase<IN, ReduceFunction<IN>>(function, operatorInfo, logicalKeyPositions, name);
- // set input
+ po.setCustomPartitioner(grouper.getCustomPartitioner());
+
po.setInput(input);
- // set dop
- po.setDegreeOfParallelism(this.getParallelism());
+ po.setDegreeOfParallelism(getParallelism());
return po;
}
else {
throw new UnsupportedOperationException("Unrecognized key type.");
}
-
}
// --------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java
index 36d14ee..63e5a19 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortedGrouping.java
@@ -27,6 +27,7 @@ import java.util.Arrays;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.GroupReduceFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Keys.ExpressionKeys;
@@ -34,7 +35,6 @@ import org.apache.flink.api.java.typeutils.TypeExtractor;
import com.google.common.base.Preconditions;
-
/**
* SortedGrouping is an intermediate step for a transformation on a grouped and sorted DataSet.<br/>
* The following transformation can be applied on sorted groups:
@@ -84,6 +84,8 @@ public class SortedGrouping<T> extends Grouping<T> {
Arrays.fill(this.groupSortOrders, order); // if field == "*"
}
+ // --------------------------------------------------------------------------------------------
+
protected int[] getGroupSortKeyPositions() {
return this.groupSortKeyPositions;
}
@@ -91,6 +93,21 @@ public class SortedGrouping<T> extends Grouping<T> {
protected Order[] getGroupSortOrders() {
return this.groupSortOrders;
}
+
+ /**
+ * Uses a custom partitioner for the grouping.
+ *
+ * @param partitioner The custom partitioner.
+ * @return The grouping object itself, to allow for method chaining.
+ */
+ public SortedGrouping<T> withPartitioner(Partitioner<?> partitioner) {
+ Preconditions.checkNotNull(partitioner);
+
+ getKeys().validateCustomPartitioner(partitioner, null);
+
+ this.customPartitioner = partitioner;
+ return this;
+ }
/**
* Applies a GroupReduce transformation on a grouped and sorted {@link DataSet}.<br/>
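
withPartitioner validates the partitioner against the grouping key type up front and stores it on the grouping. A sketch in the Scala API, reusing the data and imports from the sketch above plus org.apache.flink.api.common.operators.Order; TestPartitionerInt is the test helper defined at the end of this message:

    data.groupBy(0)
      .withPartitioner(new TestPartitionerInt())
      .sortGroup(1, Order.ASCENDING)
      .reduceGroup( iter => Seq(iter.next()) )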
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java
index b504e37..d323eae 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/UnsortedGrouping.java
@@ -20,6 +20,7 @@ package org.apache.flink.api.java.operators;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.GroupReduceFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -32,11 +33,27 @@ import org.apache.flink.api.java.functions.SelectByMinFunction;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
+import com.google.common.base.Preconditions;
+
public class UnsortedGrouping<T> extends Grouping<T> {
public UnsortedGrouping(DataSet<T> set, Keys<T> keys) {
super(set, keys);
}
+
+ /**
+ * Uses a custom partitioner for the grouping.
+ *
+ * @param partitioner The custom partitioner.
+ * @return The grouping object itself, to allow for method chaining.
+ */
+ public UnsortedGrouping<T> withPartitioner(Partitioner<?> partitioner) {
+ Preconditions.checkNotNull(partitioner);
+ getKeys().validateCustomPartitioner(partitioner, null);
+
+ this.customPartitioner = partitioner;
+ return this;
+ }
// --------------------------------------------------------------------------------------------
// Operations / Transformations
@@ -213,7 +230,9 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public SortedGrouping<T> sortGroup(int field, Order order) {
- return new SortedGrouping<T>(this.dataSet, this.keys, field, order);
+ SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order);
+ sg.customPartitioner = getCustomPartitioner();
+ return sg;
}
/**
@@ -228,7 +247,9 @@ public class UnsortedGrouping<T> extends Grouping<T> {
* @see Order
*/
public SortedGrouping<T> sortGroup(String field, Order order) {
- return new SortedGrouping<T>(this.dataSet, this.keys, field, order);
+ SortedGrouping<T> sg = new SortedGrouping<T>(this.dataSet, this.keys, field, order);
+ sg.customPartitioner = getCustomPartitioner();
+ return sg;
}
}
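
Note that sortGroup now forwards the custom partitioner from the UnsortedGrouping onto the SortedGrouping it creates, so the partitioner survives adding a secondary sort. The unsorted path itself stays simple (same setup as above):

    data.groupBy(0)
      .withPartitioner(new TestPartitionerInt())
      .reduce( (a, b) => a )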
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java
index 15333e8..e3ad06f 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvInputFormat.java
@@ -20,6 +20,7 @@
package org.apache.flink.api.java.record.io;
import com.google.common.base.Preconditions;
+
import org.apache.flink.api.common.io.GenericCsvInputFormat;
import org.apache.flink.api.common.io.ParseException;
import org.apache.flink.api.common.operators.CompilerHints;
@@ -54,6 +55,7 @@ import java.io.IOException;
* @see Configuration
* @see Record
*/
+@SuppressWarnings("deprecation")
public class CsvInputFormat extends GenericCsvInputFormat<Record> {
private static final long serialVersionUID = 1L;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java
index 2c514fe..a5d83c3 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/record/io/CsvOutputFormat.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.api.java.record.io;
import java.io.BufferedOutputStream;
@@ -52,6 +51,7 @@ import org.apache.flink.types.Value;
* @see Configuration
* @see Record
*/
+@SuppressWarnings("deprecation")
public class CsvOutputFormat extends FileOutputFormat {
private static final long serialVersionUID = 1L;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java
index 0818f45..49d9a2a 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/record/io/DelimitedOutputFormat.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.api.java.record.io;
@@ -27,10 +26,10 @@ import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Record;
-
/**
* The base class for output formats that serialize their records into a delimited sequence.
*/
+@SuppressWarnings("deprecation")
public abstract class DelimitedOutputFormat extends FileOutputFormat {
private static final long serialVersionUID = 1L;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java
index 5329a69..d09c2dc 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/record/operators/ReduceOperator.java
@@ -53,6 +53,7 @@ import org.apache.flink.util.InstantiationUtil;
*
* @see ReduceFunction
*/
+@SuppressWarnings("deprecation")
public class ReduceOperator extends GroupReduceOperatorBase<Record, Record, GroupReduceFunction<Record, Record>> implements RecordOperator {
private static final String DEFAULT_NAME = "<Unnamed Reducer>"; // the default name for contracts
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
index d52e1b0..33750b5 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java
@@ -42,6 +42,7 @@ import org.apache.flink.api.common.functions.InvalidTypesException;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.io.InputFormat;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
@@ -115,6 +116,10 @@ public class TypeExtractor {
return getUnaryOperatorReturnType((Function) selectorInterface, KeySelector.class, false, false, inType);
}
+ public static <T> TypeInformation<T> getPartitionerTypes(Partitioner<T> partitioner) {
+ return new TypeExtractor().privateCreateTypeInfo(Partitioner.class, partitioner.getClass(), 0, null, null);
+ }
+
@SuppressWarnings("unchecked")
public static <IN> TypeInformation<IN> getInputFormatTypes(InputFormat<IN, ?> inputFormatInterface) {
if(inputFormatInterface instanceof ResultTypeQueryable) {
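
getPartitionerTypes reads the key type K out of the Partitioner's generic signature; this is the type that validateCustomPartitioner compares against the grouping key. A sketch, again assuming the TestPartitionerInt helper from the tests at the end of this message:

    import org.apache.flink.api.common.typeinfo.TypeInformation
    import org.apache.flink.api.java.typeutils.TypeExtractor

    // The key type is inferred from "extends Partitioner[Int]".
    val keyType: TypeInformation[Int] =
      TypeExtractor.getPartitionerTypes(new TestPartitionerInt())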
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java
index a62de77..c780f87 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/ChannelSelector.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.runtime.io.network.api;
import org.apache.flink.core.io.IOReadableWritable;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java
index b39b402..c1037b5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.accumulators.AccumulatorHelper;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.FlatCombineFunction;
import org.apache.flink.api.common.functions.Function;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
@@ -1269,7 +1270,9 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
throw new Exception("Incompatibe serializer-/comparator factories.");
}
final DataDistribution distribution = config.getOutputDataDistribution(i, cl);
- oe = new RecordOutputEmitter(strategy, comparator, distribution);
+ final Partitioner<?> partitioner = config.getOutputPartitioner(i, cl);
+
+ oe = new RecordOutputEmitter(strategy, comparator, partitioner, distribution);
}
writers.add(new RecordWriter<Record>(task, oe));
@@ -1292,17 +1295,17 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
// create the OutputEmitter from output ship strategy
final ShipStrategyType strategy = config.getOutputShipStrategy(i);
final TypeComparatorFactory<T> compFactory = config.getOutputComparator(i, cl);
- final DataDistribution dataDist = config.getOutputDataDistribution(i, cl);
final ChannelSelector<SerializationDelegate<T>> oe;
if (compFactory == null) {
oe = new OutputEmitter<T>(strategy);
- } else if (dataDist == null){
- final TypeComparator<T> comparator = compFactory.createComparator();
- oe = new OutputEmitter<T>(strategy, comparator);
- } else {
+ }
+ else {
+ final DataDistribution dataDist = config.getOutputDataDistribution(i, cl);
+ final Partitioner<?> partitioner = config.getOutputPartitioner(i, cl);
+
final TypeComparator<T> comparator = compFactory.createComparator();
- oe = new OutputEmitter<T>(strategy, comparator, dataDist);
+ oe = new OutputEmitter<T>(strategy, comparator, partitioner, dataDist);
}
writers.add(new RecordWriter<SerializationDelegate<T>>(task, oe));
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/HistogramPartitionFunction.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/HistogramPartitionFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/HistogramPartitionFunction.java
deleted file mode 100644
index 54bb901..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/HistogramPartitionFunction.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.runtime.operators.shipping;
-
-import java.util.Arrays;
-
-import org.apache.flink.api.common.operators.Order;
-import org.apache.flink.types.Record;
-
-public class HistogramPartitionFunction implements PartitionFunction {
- private final Record[] splitBorders;
- private final Order partitionOrder;
-
- public HistogramPartitionFunction(Record[] splitBorders, Order partitionOrder) {
- this.splitBorders = splitBorders;
- this.partitionOrder = partitionOrder;
- }
-
- @Override
- public void selectChannels(Record data, int numChannels, int[] channels) {
- //TODO: Check partition borders match number of channels
- int pos = Arrays.binarySearch(splitBorders, data);
-
- /*
- *
- * TODO CHECK ONLY FOR KEYS NOT FOR WHOLE RECORD
- *
- */
-
- if(pos < 0) {
- pos++;
- pos = -pos;
- }
-
- if(partitionOrder == Order.ASCENDING || partitionOrder == Order.ANY) {
- channels[0] = pos;
- } else {
- channels[0] = splitBorders.length - pos;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java
index 4f297b0..ec92e3f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java
@@ -20,6 +20,7 @@
package org.apache.flink.runtime.operators.shipping;
import org.apache.flink.api.common.distributions.DataDistribution;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.runtime.io.network.api.ChannelSelector;
import org.apache.flink.runtime.plugable.SerializationDelegate;
@@ -33,6 +34,10 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
private int nextChannelToSendTo = 0; // counter to go over channels round robin
private final TypeComparator<T> comparator; // the comparator for hashing / sorting
+
+ private final Partitioner<Object> partitioner;
+
+ private Object[] extractedKeys;
// ------------------------------------------------------------------------
// Constructors
@@ -62,7 +67,7 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
* @param comparator The comparator used to hash / compare the records.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator) {
- this(strategy, comparator, null);
+ this(strategy, comparator, null, null);
}
/**
@@ -74,12 +79,22 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
* @param distr The distribution pattern used in the case of a range partitioning.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, DataDistribution distr) {
+ this(strategy, comparator, null, distr);
+ }
+
+ public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Partitioner<?> partitioner) {
+ this(strategy, comparator, partitioner, null);
+ }
+
+ @SuppressWarnings("unchecked")
+ public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Partitioner<?> partitioner, DataDistribution distr) {
if (strategy == null) {
throw new NullPointerException();
}
this.strategy = strategy;
this.comparator = comparator;
+ this.partitioner = (Partitioner<Object>) partitioner;
switch (strategy) {
case FORWARD:
@@ -87,6 +102,7 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
case PARTITION_RANGE:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
+ case PARTITION_CUSTOM:
case BROADCAST:
break;
default:
@@ -96,6 +112,9 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
if ((strategy == ShipStrategyType.PARTITION_RANGE) && distr == null) {
throw new NullPointerException("Data distribution must not be null when the ship strategy is range partitioning.");
}
+ if (strategy == ShipStrategyType.PARTITION_CUSTOM && partitioner == null) {
+ throw new NullPointerException("Partitioner must not be null when the ship strategy is set to custom partitioning.");
+ }
}
// ------------------------------------------------------------------------
@@ -111,10 +130,12 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
return robin(numberOfChannels);
case PARTITION_HASH:
return hashPartitionDefault(record.getInstance(), numberOfChannels);
- case PARTITION_RANGE:
- return rangePartition(record.getInstance(), numberOfChannels);
case BROADCAST:
return broadcast(numberOfChannels);
+ case PARTITION_CUSTOM:
+ return customPartition(record.getInstance(), numberOfChannels);
+ case PARTITION_RANGE:
+ return rangePartition(record.getInstance(), numberOfChannels);
default:
throw new UnsupportedOperationException("Unsupported distribution strategy: " + strategy.name());
}
@@ -189,4 +210,25 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
private final int[] rangePartition(T record, int numberOfChannels) {
throw new UnsupportedOperationException();
}
+
+ private final int[] customPartition(T record, int numberOfChannels) {
+ if (channels == null) {
+ channels = new int[1];
+ extractedKeys = new Object[1];
+ }
+
+ try {
+ if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
+ final Object key = extractedKeys[0];
+ channels[0] = partitioner.partition(key, numberOfChannels);
+ return channels;
+ }
+ else {
+ throw new RuntimeException("Inconsistency in the key comparator - comparator extracted more than one field.");
+ }
+ }
+ catch (Throwable t) {
+ throw new RuntimeException("Error while calling custom partitioner.", t);
+ }
+ }
}
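
At runtime, the custom path asks the comparator for exactly one key object and hands it to the user partitioner; the returned index must lie in [0, numberOfChannels). A sketch of a conforming partitioner:

    import org.apache.flink.api.common.functions.Partitioner

    // Non-negative modulo keeps the result inside [0, numPartitions).
    class ModPartitioner extends Partitioner[Int] {
      override def partition(key: Int, numPartitions: Int): Int =
        ((key % numPartitions) + numPartitions) % numPartitions
    }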
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/PartitionFunction.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/PartitionFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/PartitionFunction.java
deleted file mode 100644
index dadec16..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/PartitionFunction.java
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.runtime.operators.shipping;
-
-import org.apache.flink.types.Record;
-
-public interface PartitionFunction {
- public void selectChannels(Record data, int numChannels, int[] channels);
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java
index 8a375e0..9d06aad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/RecordOutputEmitter.java
@@ -20,6 +20,7 @@
package org.apache.flink.runtime.operators.shipping;
import org.apache.flink.api.common.distributions.DataDistribution;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.runtime.io.network.api.ChannelSelector;
import org.apache.flink.types.Key;
@@ -43,7 +44,11 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
private final DataDistribution distribution; // the data distribution to create the partition boundaries for range partitioning
+ private final Partitioner<Object> partitioner;
+
private int nextChannelToSendTo; // counter to go over channels round robin
+
+ private Object[] extractedKeys;
// ------------------------------------------------------------------------
// Constructors
@@ -66,7 +71,7 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
* @param comparator The comparator used to hash / compare the records.
*/
public RecordOutputEmitter(ShipStrategyType strategy, TypeComparator<Record> comparator) {
- this(strategy, comparator, null);
+ this(strategy, comparator, null, null);
}
/**
@@ -78,6 +83,15 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
* @param distr The distribution pattern used in the case of a range partitioning.
*/
public RecordOutputEmitter(ShipStrategyType strategy, TypeComparator<Record> comparator, DataDistribution distr) {
+ this(strategy, comparator, null, distr);
+ }
+
+ public RecordOutputEmitter(ShipStrategyType strategy, TypeComparator<Record> comparator, Partitioner<?> partitioner) {
+ this(strategy, comparator, partitioner, null);
+ }
+
+ @SuppressWarnings("unchecked")
+ public RecordOutputEmitter(ShipStrategyType strategy, TypeComparator<Record> comparator, Partitioner<?> partitioner, DataDistribution distr) {
if (strategy == null) {
throw new NullPointerException();
}
@@ -85,6 +99,7 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
this.strategy = strategy;
this.comparator = comparator;
this.distribution = distr;
+ this.partitioner = (Partitioner<Object>) partitioner;
switch (strategy) {
case FORWARD:
@@ -94,6 +109,7 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
this.channels = new int[1];
break;
case BROADCAST:
+ case PARTITION_CUSTOM:
break;
default:
throw new IllegalArgumentException("Invalid shipping strategy for OutputEmitter: " + strategy.name());
@@ -102,6 +118,9 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
if ((strategy == ShipStrategyType.PARTITION_RANGE) && distr == null) {
throw new NullPointerException("Data distribution must not be null when the ship strategy is range partitioning.");
}
+ if (strategy == ShipStrategyType.PARTITION_CUSTOM && partitioner == null) {
+ throw new NullPointerException("Partitioner must not be null when the ship strategy is set to custom partitioning.");
+ }
}
// ------------------------------------------------------------------------
@@ -113,13 +132,16 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
switch (strategy) {
case FORWARD:
case PARTITION_RANDOM:
+ case PARTITION_FORCED_REBALANCE:
return robin(numberOfChannels);
case PARTITION_HASH:
return hashPartitionDefault(record, numberOfChannels);
- case PARTITION_RANGE:
- return rangePartition(record, numberOfChannels);
+ case PARTITION_CUSTOM:
+ return customPartition(record, numberOfChannels);
case BROADCAST:
return broadcast(numberOfChannels);
+ case PARTITION_RANGE:
+ return rangePartition(record, numberOfChannels);
default:
throw new UnsupportedOperationException("Unsupported distribution strategy: " + strategy.name());
}
@@ -200,4 +222,25 @@ public class RecordOutputEmitter implements ChannelSelector<Record> {
"The number of channels to partition among is inconsistent with the partitioners state.");
}
}
+
+ private final int[] customPartition(Record record, int numberOfChannels) {
+ if (channels == null) {
+ channels = new int[1];
+ extractedKeys = new Object[1];
+ }
+
+ try {
+ if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
+ final Object key = extractedKeys[0];
+ channels[0] = partitioner.partition(key, numberOfChannels);
+ return channels;
+ }
+ else {
+ throw new RuntimeException("Inconsistency in the key comparator - comparator extracted more than one field.");
+ }
+ }
+ catch (Throwable t) {
+ throw new RuntimeException("Error while calling custom partitioner.", t);
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java
index 45134a1..fb32a6e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/ShipStrategyType.java
@@ -51,14 +51,19 @@ public enum ShipStrategyType {
PARTITION_RANGE(true, true),
/**
- * Partitioning the data evenly
+ * Partitioning the data evenly, forced at a specific location (cannot be pushed down by the optimizer).
*/
PARTITION_FORCED_REBALANCE(true, false),
/**
* Replicating the data set to all instances.
*/
- BROADCAST(true, false);
+ BROADCAST(true, false),
+
+ /**
+ * Partitioning using a custom partitioner.
+ */
+ PARTITION_CUSTOM(true, true);
// --------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
index 1b44a3b..89cf98a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
@@ -36,6 +36,7 @@ import org.apache.flink.api.common.aggregators.AggregatorWithName;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.Function;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
@@ -141,6 +142,8 @@ public class TaskConfig {
private static final String OUTPUT_DATA_DISTRIBUTION_PREFIX = "out.distribution.";
+ private static final String OUTPUT_PARTITIONER = "out.partitioner.";
+
// ------------------------------------- Chaining ---------------------------------------------
private static final String CHAINING_NUM_STUBS = "chaining.num";
@@ -597,6 +600,27 @@ public class TaskConfig {
}
}
+ public void setOutputPartitioner(Partitioner<?> partitioner, int outputNum) {
+ try {
+ InstantiationUtil.writeObjectToConfig(partitioner, config, OUTPUT_PARTITIONER + outputNum);
+ }
+ catch (Throwable t) {
+ throw new RuntimeException("Could not serialize custom partitioner.", t);
+ }
+ }
+
+ public Partitioner<?> getOutputPartitioner(int outputNum, final ClassLoader cl) throws ClassNotFoundException {
+ try {
+ return (Partitioner<?>) InstantiationUtil.readObjectFromConfig(config, OUTPUT_PARTITIONER + outputNum, cl);
+ }
+ catch (ClassNotFoundException e) {
+ throw e;
+ }
+ catch (Throwable t) {
+ throw new RuntimeException("Could not deserialize custom partitioner.", t);
+ }
+ }
+
// --------------------------------------------------------------------------------------------
// Parameters to configure the memory and I/O behavior
// --------------------------------------------------------------------------------------------
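
The partitioner travels to the task as a plain serialized object inside the task configuration, so its class must be serializable and available on the task's classpath. A round-trip sketch using the two methods above (ModPartitioner from the earlier sketch):

    import org.apache.flink.configuration.Configuration
    import org.apache.flink.runtime.operators.util.TaskConfig

    val cfg = new TaskConfig(new Configuration())
    cfg.setOutputPartitioner(new ModPartitioner(), 0)   // attach to output channel 0
    val restored = cfg.getOutputPartitioner(0, getClass.getClassLoader)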
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
index 534ef45..3d76921 100644
--- a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
+++ b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
@@ -232,6 +232,7 @@ public class ScalaAggregateOperator<IN> extends SingleInputOperator<IN, IN, Scal
}
po.setSemanticProperties(props);
+ po.setCustomPartitioner(grouping.getCustomPartitioner());
return po;
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index ca8e469..d1233e6 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -1009,9 +1009,75 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
getCallLocationName())
wrap(op)
}
+
+ /**
+ * Partitions a tuple DataSet on the specified key fields using a custom partitioner.
+ * This method takes the key position to partition on, and a partitioner that accepts the key
+ * type.
+ * <p>
+ * Note: This method works only on single field keys.
+ */
+ def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], field: Int) : DataSet[T] = {
+ val op = new PartitionOperator[T](
+ javaSet,
+ new Keys.ExpressionKeys[T](Array[Int](field), javaSet.getType, false),
+ partitioner,
+ implicitly[TypeInformation[K]],
+ getCallLocationName())
+
+ wrap(op)
+ }
+
+ /**
+ * Partitions a POJO DataSet on the specified key fields using a custom partitioner.
+ * This method takes the key expression to partition on, and a partitioner that accepts the key
+ * type.
+ * <p>
+ * Note: This method works only on single field keys.
+ */
+ def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], field: String)
+ : DataSet[T] = {
+ val op = new PartitionOperator[T](
+ javaSet,
+ new Keys.ExpressionKeys[T](Array[String](field), javaSet.getType),
+ partitioner,
+ implicitly[TypeInformation[K]],
+ getCallLocationName())
+
+ wrap(op)
+ }
+
+ /**
+ * Partitions a DataSet on the key returned by the selector, using a custom partitioner.
+ * This method takes the key selector to get the key to partition on, and a partitioner that
+ * accepts the key type.
+ * <p>
+ * Note: This method works only on single field keys, i.e. the selector cannot return tuples
+ * of fields.
+ */
+ def partitionCustom[K: TypeInformation](partitioner: Partitioner[K], fun: T => K)
+ : DataSet[T] = {
+ val keyExtractor = new KeySelector[T, K] {
+ def getKey(in: T) = fun(in)
+ }
+
+ val keyType = implicitly[TypeInformation[K]]
+
+ val op = new PartitionOperator[T](
+ javaSet,
+ new Keys.SelectorFunctionKeys[T, K](
+ keyExtractor,
+ javaSet.getType,
+ keyType),
+ partitioner,
+ keyType,
+ getCallLocationName())
+
+ wrap(op)
+ }
/**
- * Enforces a rebalancing of the DataSet, i.e., the DataSet is evenly distributed over all
+ * Enforces a re-balancing of the DataSet, i.e., the DataSet is evenly distributed over all
* parallel instances of the
* following task. This can help to improve performance in case of heavy data skew and compute
* intensive operations.
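
Taken together, the new Scala methods cover tuple positions, POJO field expressions, and key selector functions, each restricted to a single-field key. A combined sketch, assuming the imports from the first sketch above:

    val env  = ExecutionEnvironment.getExecutionEnvironment
    val data = env.fromElements((1, "a"), (2, "b"))

    val p = new Partitioner[Int] {
      override def partition(key: Int, numPartitions: Int): Int =
        math.abs(key) % numPartitions
    }

    data.partitionCustom(p, 0)                            // by tuple field position
    data.partitionCustom(p, (t: (Int, String)) => t._1)   // by key selector function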
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
index 23edc74..d87426e 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
@@ -20,9 +20,7 @@ package org.apache.flink.api.scala
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.functions.FirstReducer
import org.apache.flink.api.scala.operators.ScalaAggregateOperator
-
import scala.collection.JavaConverters._
-
import org.apache.commons.lang3.Validate
import org.apache.flink.api.common.functions.{GroupReduceFunction, ReduceFunction}
import org.apache.flink.api.common.operators.Order
@@ -30,9 +28,10 @@ import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.java.operators._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.util.Collector
-
import scala.collection.mutable
import scala.reflect.ClassTag
+import org.apache.flink.api.common.functions.Partitioner
+import com.google.common.base.Preconditions
/**
* A [[DataSet]] to which a grouping key was added. Operations work on groups of elements with the
@@ -49,6 +48,8 @@ class GroupedDataSet[T: ClassTag](
// when using a group-at-a-time reduce function.
private val groupSortKeyPositions = mutable.MutableList[Either[Int, String]]()
private val groupSortOrders = mutable.MutableList[Order]()
+
+ private var partitioner : Partitioner[_] = _
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
@@ -113,16 +114,51 @@ class GroupedDataSet[T: ClassTag](
}
}
- grouping
+
+ if (partitioner == null) {
+ grouping
+ } else {
+ grouping.withPartitioner(partitioner)
+ }
+
} else {
- new UnsortedGrouping[T](set.javaSet, keys)
+ createUnsortedGrouping()
}
}
/** Convenience methods for creating the [[UnsortedGrouping]] */
- private def createUnsortedGrouping(): Grouping[T] = new UnsortedGrouping[T](set.javaSet, keys)
+ private def createUnsortedGrouping(): Grouping[T] = {
+ val grp = new UnsortedGrouping[T](set.javaSet, keys)
+ if (partitioner == null) {
+ grp
+ } else {
+ grp.withPartitioner(partitioner)
+ }
+ }
/**
+ * Sets a custom partitioner for the grouping.
+ */
+ def withPartitioner[K : TypeInformation](partitioner: Partitioner[K]) : GroupedDataSet[T] = {
+ Preconditions.checkNotNull(partitioner)
+ keys.validateCustomPartitioner(partitioner, implicitly[TypeInformation[K]])
+ this.partitioner = partitioner
+ this
+ }
+
+ /**
+ * Gets the custom partitioner to be used for this grouping, or null, if
+ * none was defined.
+ */
+ def getCustomPartitioner[K]() : Partitioner[K] = {
+ partitioner.asInstanceOf[Partitioner[K]]
+ }
+
+ // ----------------------------------------------------------------------------------------------
+ // Operations
+ // ----------------------------------------------------------------------------------------------
+
+ /**
* Creates a new [[DataSet]] by aggregating the specified tuple field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* tuples with the same key.
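
On the Scala GroupedDataSet, withPartitioner validates eagerly: a partitioner whose key type does not match the grouping key is rejected at graph-construction time, not at runtime. A sketch, with TestPartitionerLong as defined in the tests at the end of this message:

    // Key type is Int, partitioner expects Long: throws InvalidProgramException here.
    data.groupBy( _._1 ).withPartitioner(new TestPartitionerLong())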
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
index 7062c63..f5b0783 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
@@ -21,15 +21,15 @@ import org.apache.commons.lang3.Validate
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.functions.{JoinFunction, RichFlatJoinFunction, FlatJoinFunction}
import org.apache.flink.api.common.typeutils.TypeSerializer
-import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
+import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
import org.apache.flink.api.java.operators.JoinOperator.DefaultJoin.WrappingFlatJoinFunction
-import org.apache.flink.api.java.operators.JoinOperator.EquiJoin;
+import org.apache.flink.api.java.operators.JoinOperator.EquiJoin
import org.apache.flink.api.java.operators._
import org.apache.flink.api.scala.typeutils.{CaseClassSerializer, CaseClassTypeInfo}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.util.Collector
-
import scala.reflect.ClassTag
+import org.apache.flink.api.common.functions.Partitioner
/**
* A specific [[DataSet]] that results from a `join` operation. The result of a default join is a
@@ -66,6 +66,8 @@ class JoinDataSet[L, R](
rightKeys: Keys[R])
extends DataSet(defaultJoin) {
+ var customPartitioner : Partitioner[_] = _
+
/**
* Creates a new [[DataSet]] where the result for each pair of joined elements is the result
* of the given function.
@@ -86,8 +88,12 @@ class JoinDataSet[L, R](
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint,
getCallLocationName())
-
- wrap(joinOperator)
+
+ if (customPartitioner != null) {
+ wrap(joinOperator.withPartitioner(customPartitioner))
+ } else {
+ wrap(joinOperator)
+ }
}
/**
@@ -112,7 +118,11 @@ class JoinDataSet[L, R](
defaultJoin.getJoinHint,
getCallLocationName())
- wrap(joinOperator)
+ if (customPartitioner != null) {
+ wrap(joinOperator.withPartitioner(customPartitioner))
+ } else {
+ wrap(joinOperator)
+ }
}
/**
@@ -136,7 +146,11 @@ class JoinDataSet[L, R](
defaultJoin.getJoinHint,
getCallLocationName())
- wrap(joinOperator)
+ if (customPartitioner != null) {
+ wrap(joinOperator.withPartitioner(customPartitioner))
+ } else {
+ wrap(joinOperator)
+ }
}
/**
@@ -161,7 +175,35 @@ class JoinDataSet[L, R](
defaultJoin.getJoinHint,
getCallLocationName())
- wrap(joinOperator)
+ if (customPartitioner != null) {
+ wrap(joinOperator.withPartitioner(customPartitioner))
+ } else {
+ wrap(joinOperator)
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------------
+ // Properties
+ // ----------------------------------------------------------------------------------------------
+
+ def withPartitioner[K : TypeInformation](partitioner : Partitioner[K]) : JoinDataSet[L, R] = {
+ if (partitioner != null) {
+ val typeInfo : TypeInformation[K] = implicitly[TypeInformation[K]]
+
+ leftKeys.validateCustomPartitioner(partitioner, typeInfo)
+ rightKeys.validateCustomPartitioner(partitioner, typeInfo)
+ }
+ this.customPartitioner = partitioner
+ defaultJoin.withPartitioner(partitioner)
+
+ this
+ }
+
+ /**
+ * Gets the custom partitioner used by this join, or null, if none is set.
+ */
+ def getPartitioner[K]() : Partitioner[K] = {
+ customPartitioner.asInstanceOf[Partitioner[K]]
}
}
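
For joins, withPartitioner validates the partitioner against both the left and the right key set and forwards it to the underlying Java join operator, so both inputs are shipped with the same custom partitioning. A sketch, reusing the env and p from the sketch above:

    val left   = env.fromElements((1, "a"))
    val right  = env.fromElements((1, "b"))

    val joined = left.join(right).where(0).equalTo(0).withPartitioner(p)
    val result = joined { (l, r) => (l._1, l._2, r._2) }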
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java
index 5859a4a..bc40df6 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/cancelling/CancellingTestBase.java
@@ -201,6 +201,8 @@ public abstract class CancellingTestBase {
case FAILING:
case CREATED:
break;
+ case RESTARTING:
+ throw new IllegalStateException("Job restarted");
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java
index 19fc936..975e4aa 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/StaticlyNestedIterationsITCase.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.io.DiscardingOuputFormat;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.test.util.JavaProgramTestBase;
@@ -49,7 +50,7 @@ public class StaticlyNestedIterationsITCase extends JavaProgramTestBase {
DataSet<Long> mainResult = mainIteration.closeWith(joined);
- mainResult.print();
+ mainResult.output(new DiscardingOuputFormat<Long>());
env.execute();
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java
index 3a7cdb7..9dcdf75 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/IterationWithChainingNepheleITCase.java
@@ -68,6 +68,7 @@ import org.junit.runners.Parameterized;
*
* {@link IterationWithChainingITCase}
*/
+@SuppressWarnings("deprecation")
@RunWith(Parameterized.class)
public class IterationWithChainingNepheleITCase extends RecordAPITestBase {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala
index c4d7dc8..c9b1a3a 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/AggregateTranslationTest.scala
@@ -15,6 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.flink.api.scala.operators.translation
import org.apache.flink.api.common.Plan
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingKeySelectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingKeySelectorTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingKeySelectorTest.scala
new file mode 100644
index 0000000..17ecc3f
--- /dev/null
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/translation/CustomPartitioningGroupingKeySelectorTest.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.operators.translation
+
+import org.junit.Assert._
+import org.junit.Test
+import org.apache.flink.api.scala._
+import org.apache.flink.api.common.functions.Partitioner
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType
+import org.apache.flink.compiler.plan.SingleInputPlanNode
+import org.apache.flink.test.compiler.util.CompilerTestBase
+import scala.collection.immutable.Seq
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.InvalidProgramException
+
+class CustomPartitioningGroupingKeySelectorTest extends CompilerTestBase {
+
+ @Test
+ def testCustomPartitioningKeySelectorReduce() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0,0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
+ .reduce( (a,b) => a )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val keyRemovingMapper = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val reducer = keyRemovingMapper.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, keyRemovingMapper.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorGroupReduce() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0,0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
+ .reduceGroup( iter => Seq(iter.next()) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorGroupReduceSorted() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0,0,0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy( _._1 )
+ .withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .reduceGroup( iter => Seq(iter.next()) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorGroupReduceSorted2() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0,0,0,0) ).rebalance().setParallelism(4)
+
+ data
+ .groupBy( _._1 ).withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .sortGroup(2, Order.DESCENDING)
+ .reduceGroup( iter => Seq(iter.next()) )
+ .print()
+
+ val p = env.createProgramPlan()
+ val op = compileNoStats(p)
+
+ val sink = op.getDataSinks.iterator().next()
+ val reducer = sink.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+ val combiner = reducer.getInput.getSource.asInstanceOf[SingleInputPlanNode]
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput.getShipStrategy)
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput.getShipStrategy)
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorInvalidType() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0) ).rebalance().setParallelism(4)
+
+ try {
+ data
+ .groupBy( _._1 )
+ .withPartitioner(new TestPartitionerLong())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningKeySelectorInvalidTypeSorted() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
+
+ try {
+ data
+ .groupBy( _._1 )
+ .sortGroup(1, Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ @Test
+ def testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val data = env.fromElements( (0, 0, 0) ).rebalance().setParallelism(4)
+
+ try {
+ data.groupBy( v => (v._1, v._2) ).withPartitioner(new TestPartitionerInt())
+ fail("Should throw an exception")
+ }
+ catch {
+ case e: InvalidProgramException =>
+ }
+ }
+ catch {
+ case e: Exception => {
+ e.printStackTrace()
+ fail(e.getMessage)
+ }
+ }
+ }
+
+ // ----------------------------------------------------------------------------------------------
+
+ private class TestPartitionerInt extends Partitioner[Int] {
+
+ override def partition(key: Int, numPartitions: Int): Int = 0
+ }
+
+ private class TestPartitionerLong extends Partitioner[Long] {
+
+ override def partition(key: Long, numPartitions: Int): Int = 0
+ }
+}
[3/4] incubator-flink git commit: [FLINK-1237] Add support for custom
partitioners - Functions: GroupReduce, Reduce, Aggregate on UnsortedGrouping,
SortedGrouping,
Join (Java API & Scala API) - Manual partition on DataSet (Java API & S
Posted by se...@apache.org.
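
For orientation before the translation tests that follow: a minimal, self-contained sketch of the DataSet-side API this commit adds (groupBy(...).withPartitioner(...)). The sketch is illustrative and not part of the commit; the class names ModPartitioner and CustomPartitionerUsageSketch and the sample data are assumptions made here for demonstration.

import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;

public class CustomPartitionerUsageSketch {

    // Illustrative partitioner: routes each key by simple modulo.
    // Returning a constant (as the tests' TestPartitionerInt does) is
    // equally legal; the contract is only 0 <= result < numPartitions.
    private static class ModPartitioner implements Partitioner<Integer> {
        @Override
        public int partition(Integer key, int numPartitions) {
            return Math.abs(key % numPartitions);
        }
    }

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple2<Integer, Integer>> data = env.fromElements(
                new Tuple2<Integer, Integer>(1, 10),
                new Tuple2<Integer, Integer>(2, 20));

        // The custom partitioner replaces the default hash partitioning for
        // the shuffle in front of the grouped aggregation; the optimizer
        // plans this as a PARTITION_CUSTOM ship strategy, which is exactly
        // what the tests below assert on the reducer's input channel.
        data.groupBy(0)
            .withPartitioner(new ModPartitioner())
            .sum(1)
            .print();

        env.execute();
    }
}
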
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingKeySelectorTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingKeySelectorTranslationTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingKeySelectorTranslationTest.java
new file mode 100644
index 0000000..8f446a7
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingKeySelectorTranslationTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.compiler.CompilerTestBase;
+import org.apache.flink.compiler.plan.OptimizedPlan;
+import org.apache.flink.compiler.plan.SingleInputPlanNode;
+import org.apache.flink.compiler.plan.SinkPlanNode;
+import org.apache.flink.compiler.testfunctions.DummyReducer;
+import org.apache.flink.compiler.testfunctions.IdentityGroupReducer;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial", "unchecked"})
+public class GroupingKeySelectorTranslationTest extends CompilerTestBase {
+
+ @Test
+ public void testCustomPartitioningKeySelectorReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(new TestKeySelector<Tuple2<Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt())
+ .reduce(new DummyReducer<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode keyRemovingMapper = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) keyRemovingMapper.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, keyRemovingMapper.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningKeySelectorGroupReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(new TestKeySelector<Tuple2<Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt())
+ .reduceGroup(new IdentityGroupReducer<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningKeySelectorGroupReduceSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningKeySelectorGroupReduceSorted2() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple4<Integer,Integer,Integer, Integer>> data = env.fromElements(new Tuple4<Integer,Integer,Integer,Integer>(0, 0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ data
+ .groupBy(new TestKeySelector<Tuple4<Integer,Integer,Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .sortGroup(2, Order.DESCENDING)
+ .reduceGroup(new IdentityGroupReducer<Tuple4<Integer,Integer,Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningKeySelectorInvalidType() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data
+ .groupBy(new TestKeySelector<Tuple2<Integer,Integer>>())
+ .withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningKeySelectorInvalidTypeSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data
+ .groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>())
+ .sortGroup(1, Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data
+ .groupBy(new TestBinaryKeySelector<Tuple3<Integer,Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(Integer key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestKeySelector<T extends Tuple> implements KeySelector<T, Integer> {
+ @Override
+ public Integer getKey(T value) {
+ return value.getField(0);
+ }
+ }
+
+ private static class TestBinaryKeySelector<T extends Tuple> implements KeySelector<T, Tuple2<Integer, Integer>> {
+ @Override
+ public Tuple2<Integer, Integer> getKey(T value) {
+ return new Tuple2<Integer, Integer>(value.<Integer>getField(0), value.<Integer>getField(1));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingPojoTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingPojoTranslationTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingPojoTranslationTest.java
new file mode 100644
index 0000000..087d32d
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingPojoTranslationTest.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.compiler.CompilerTestBase;
+import org.apache.flink.compiler.plan.OptimizedPlan;
+import org.apache.flink.compiler.plan.SingleInputPlanNode;
+import org.apache.flink.compiler.plan.SinkPlanNode;
+import org.apache.flink.compiler.testfunctions.DummyReducer;
+import org.apache.flink.compiler.testfunctions.IdentityGroupReducer;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings("serial")
+public class GroupingPojoTranslationTest extends CompilerTestBase {
+
+ @Test
+ public void testCustomPartitioningTupleReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo2> data = env.fromElements(new Pojo2())
+ .rebalance().setParallelism(4);
+
+ data.groupBy("a").withPartitioner(new TestPartitionerInt())
+ .reduce(new DummyReducer<Pojo2>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo2> data = env.fromElements(new Pojo2())
+ .rebalance().setParallelism(4);
+
+ data.groupBy("a").withPartitioner(new TestPartitionerInt())
+ .reduceGroup(new IdentityGroupReducer<Pojo2>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduceSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo3> data = env.fromElements(new Pojo3())
+ .rebalance().setParallelism(4);
+
+ data.groupBy("a").withPartitioner(new TestPartitionerInt())
+ .sortGroup("b", Order.ASCENDING)
+ .reduceGroup(new IdentityGroupReducer<Pojo3>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduceSorted2() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo4> data = env.fromElements(new Pojo4())
+ .rebalance().setParallelism(4);
+
+ data.groupBy("a").withPartitioner(new TestPartitionerInt())
+ .sortGroup("b", Order.ASCENDING)
+ .sortGroup("c", Order.DESCENDING)
+ .reduceGroup(new IdentityGroupReducer<Pojo4>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleInvalidType() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo2> data = env.fromElements(new Pojo2())
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy("a").withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleInvalidTypeSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo3> data = env.fromElements(new Pojo3())
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy("a")
+ .sortGroup("b", Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo2> data = env.fromElements(new Pojo2())
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy("a", "b")
+ .withPartitioner(new TestPartitionerInt());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ public static class Pojo2 {
+ public int a;
+ public int b;
+ }
+
+ public static class Pojo3 {
+ public int a;
+ public int b;
+ public int c;
+ }
+
+ public static class Pojo4 {
+ public int a;
+ public int b;
+ public int c;
+ public int d;
+ }
+
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(Integer key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingTupleTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingTupleTranslationTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingTupleTranslationTest.java
new file mode 100644
index 0000000..7cfabfb
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/GroupingTupleTranslationTest.java
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.compiler.CompilerTestBase;
+import org.apache.flink.compiler.plan.OptimizedPlan;
+import org.apache.flink.compiler.plan.SingleInputPlanNode;
+import org.apache.flink.compiler.plan.SinkPlanNode;
+import org.apache.flink.compiler.testfunctions.DummyReducer;
+import org.apache.flink.compiler.testfunctions.IdentityGroupReducer;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial", "unchecked"})
+public class GroupingTupleTranslationTest extends CompilerTestBase {
+
+ @Test
+ public void testCustomPartitioningTupleAgg() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .sum(1)
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .reduce(new DummyReducer<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .reduceGroup(new IdentityGroupReducer<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduceSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduceSorted2() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple4<Integer,Integer,Integer, Integer>> data = env.fromElements(new Tuple4<Integer,Integer,Integer,Integer>(0, 0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .sortGroup(2, Order.DESCENDING)
+ .reduceGroup(new IdentityGroupReducer<Tuple4<Integer,Integer,Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleInvalidType() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy(0).withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleInvalidTypeSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy(0)
+ .sortGroup(1, Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy(0, 1)
+ .withPartitioner(new TestPartitionerInt());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(Integer key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/JoinCustomPartitioningTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/JoinCustomPartitioningTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/JoinCustomPartitioningTest.java
new file mode 100644
index 0000000..0020c66
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/JoinCustomPartitioningTest.java
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.compiler.CompilerTestBase;
+import org.apache.flink.compiler.plan.DualInputPlanNode;
+import org.apache.flink.compiler.plan.OptimizedPlan;
+import org.apache.flink.compiler.plan.SinkPlanNode;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial", "unchecked"})
+public class JoinCustomPartitioningTest extends CompilerTestBase {
+
+ @Test
+ public void testJoinWithTuples() {
+ try {
+ final Partitioner<Long> partitioner = new TestPartitionerLong();
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+ DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST).where(1).equalTo(0).withPartitioner(partitioner)
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource();
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy());
+ assertEquals(partitioner, join.getInput1().getPartitioner());
+ assertEquals(partitioner, join.getInput2().getPartitioner());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testJoinWithTuplesWrongType() {
+ try {
+ final Partitioner<Integer> partitioner = new TestPartitionerInt();
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+ DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+ try {
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST).where(1).equalTo(0)
+ .withPartitioner(partitioner);
+
+ fail("should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testJoinWithPojos() {
+ try {
+ final Partitioner<Integer> partitioner = new TestPartitionerInt();
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+ DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where("b").equalTo("a").withPartitioner(partitioner)
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource();
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy());
+ assertEquals(partitioner, join.getInput1().getPartitioner());
+ assertEquals(partitioner, join.getInput2().getPartitioner());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testJoinWithPojosWrongType() {
+ try {
+ final Partitioner<Long> partitioner = new TestPartitionerLong();
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+ DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+
+ try {
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where("a").equalTo("b")
+ .withPartitioner(partitioner);
+
+ fail("should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testJoinWithKeySelectors() {
+ try {
+ final Partitioner<Integer> partitioner = new TestPartitionerInt();
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+ DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where(new Pojo2KeySelector())
+ .equalTo(new Pojo3KeySelector())
+ .withPartitioner(partitioner)
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource();
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy());
+ assertEquals(partitioner, join.getInput1().getPartitioner());
+ assertEquals(partitioner, join.getInput2().getPartitioner());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testJoinWithKeySelectorsWrongType() {
+ try {
+ final Partitioner<Long> partitioner = new TestPartitionerLong();
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+ DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+
+ try {
+ input1
+ .join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ .where(new Pojo2KeySelector())
+ .equalTo(new Pojo3KeySelector())
+ .withPartitioner(partitioner);
+
+ fail("should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(Integer key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ public static class Pojo2 {
+ public int a;
+ public int b;
+ }
+
+ public static class Pojo3 {
+ public int a;
+ public int b;
+ public int c;
+ }
+
+ private static class Pojo2KeySelector implements KeySelector<Pojo2, Integer> {
+ @Override
+ public Integer getKey(Pojo2 value) {
+ return value.a;
+ }
+ }
+
+ private static class Pojo3KeySelector implements KeySelector<Pojo3, Integer> {
+ @Override
+ public Integer getKey(Pojo3 value) {
+ return value.b;
+ }
+ }
+}
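
For comparison with the assertions above, a minimal sketch of the join-side API these tests cover; on a join, withPartitioner forces PARTITION_CUSTOM on both inputs so that matching keys meet in the same parallel instance. The sketch is illustrative and not from the commit; LongModPartitioner, JoinPartitionerUsageSketch, and the sample data are assumptions made here.

import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;

public class JoinPartitionerUsageSketch {

    // Illustrative partitioner over the join key type (Long).
    private static class LongModPartitioner implements Partitioner<Long> {
        @Override
        public int partition(Long key, int numPartitions) {
            return (int) (Math.abs(key) % numPartitions);
        }
    }

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple2<Long, Long>> left = env.fromElements(new Tuple2<Long, Long>(0L, 1L));
        DataSet<Tuple2<Long, Long>> right = env.fromElements(new Tuple2<Long, Long>(0L, 2L));

        // The partitioner's key type must match the join key type; a
        // mismatch raises InvalidProgramException at pre-flight time, as
        // the *WrongType tests above verify.
        left.join(right, JoinHint.REPARTITION_HASH_FIRST)
            .where(0).equalTo(0)
            .withPartitioner(new LongModPartitioner())
            .print();

        env.execute();
    }
}
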
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesFilteringTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesFilteringTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesFilteringTest.java
new file mode 100644
index 0000000..fc8616e
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesFilteringTest.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.dataproperties;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.util.FieldList;
+import org.apache.flink.api.common.operators.util.FieldSet;
+import org.apache.flink.compiler.dag.OptimizerNode;
+import org.junit.Test;
+import org.mockito.Matchers;
+
+public class GlobalPropertiesFilteringTest {
+
+ @Test
+ public void testCustomPartitioningPreserves() {
+ try {
+ Partitioner<?> partitioner = new MockPartitioner();
+
+ GlobalProperties gp = new GlobalProperties();
+ gp.setCustomPartitioned(new FieldList(2, 3), partitioner);
+
+ OptimizerNode node = mock(OptimizerNode.class);
+ when(node.isFieldConstant(Matchers.anyInt(), Matchers.anyInt())).thenReturn(true);
+
+ GlobalProperties filtered = gp.filterByNodesConstantSet(node, 0);
+
+ assertTrue(filtered.isPartitionedOnFields(new FieldSet(2, 3)));
+ assertEquals(PartitioningProperty.CUSTOM_PARTITIONING, filtered.getPartitioning());
+ assertEquals(partitioner, filtered.getCustomPartitioner());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesMatchingTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesMatchingTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesMatchingTest.java
new file mode 100644
index 0000000..fd4ad82
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesMatchingTest.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.dataproperties;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.common.operators.Ordering;
+import org.apache.flink.api.common.operators.util.FieldList;
+import org.apache.flink.api.common.operators.util.FieldSet;
+import org.junit.Test;
+
+public class GlobalPropertiesMatchingTest {
+
+ @Test
+ public void testMatchingAnyPartitioning() {
+ try {
+
+ RequestedGlobalProperties req = new RequestedGlobalProperties();
+ req.setAnyPartitioning(new FieldSet(6, 2));
+
+ // match any partitioning
+ {
+ GlobalProperties gp1 = new GlobalProperties();
+ gp1.setAnyPartitioning(new FieldList(2, 6));
+ assertTrue(req.isMetBy(gp1));
+
+ GlobalProperties gp2 = new GlobalProperties();
+ gp2.setAnyPartitioning(new FieldList(6, 2));
+ assertTrue(req.isMetBy(gp2));
+
+ GlobalProperties gp3 = new GlobalProperties();
+ gp3.setAnyPartitioning(new FieldList(6, 1));
+ assertFalse(req.isMetBy(gp3));
+
+ GlobalProperties gp4 = new GlobalProperties();
+ gp4.setAnyPartitioning(new FieldList(2));
+ assertTrue(req.isMetBy(gp4));
+ }
+
+ // match hash partitioning
+ {
+ GlobalProperties gp1 = new GlobalProperties();
+ gp1.setHashPartitioned(new FieldList(2, 6));
+ assertTrue(req.isMetBy(gp1));
+
+ GlobalProperties gp2 = new GlobalProperties();
+ gp2.setHashPartitioned(new FieldList(6, 2));
+ assertTrue(req.isMetBy(gp2));
+
+ GlobalProperties gp3 = new GlobalProperties();
+ gp3.setHashPartitioned(new FieldList(6, 1));
+ assertFalse(req.isMetBy(gp3));
+ }
+
+ // match range partitioning
+ {
+ GlobalProperties gp1 = new GlobalProperties();
+ gp1.setRangePartitioned(new Ordering(2, null, Order.DESCENDING).appendOrdering(6, null, Order.ASCENDING));
+ assertTrue(req.isMetBy(gp1));
+
+ GlobalProperties gp2 = new GlobalProperties();
+ gp2.setRangePartitioned(new Ordering(6, null, Order.DESCENDING).appendOrdering(2, null, Order.ASCENDING));
+ assertTrue(req.isMetBy(gp2));
+
+ GlobalProperties gp3 = new GlobalProperties();
+ gp3.setRangePartitioned(new Ordering(6, null, Order.DESCENDING).appendOrdering(1, null, Order.ASCENDING));
+ assertFalse(req.isMetBy(gp3));
+
+ GlobalProperties gp4 = new GlobalProperties();
+ gp4.setRangePartitioned(new Ordering(6, null, Order.DESCENDING));
+ assertTrue(req.isMetBy(gp4));
+ }
+
+ // match custom partitioning
+ {
+ GlobalProperties gp1 = new GlobalProperties();
+ gp1.setCustomPartitioned(new FieldList(2, 6), new MockPartitioner());
+ assertTrue(req.isMetBy(gp1));
+
+ GlobalProperties gp2 = new GlobalProperties();
+ gp2.setCustomPartitioned(new FieldList(6, 2), new MockPartitioner());
+ assertTrue(req.isMetBy(gp2));
+
+ GlobalProperties gp3 = new GlobalProperties();
+ gp3.setCustomPartitioned(new FieldList(6, 1), new MockPartitioner());
+ assertFalse(req.isMetBy(gp3));
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testMatchingCustomPartitioning() {
+ try {
+ final Partitioner<Long> partitioner = new MockPartitioner();
+
+ RequestedGlobalProperties req = new RequestedGlobalProperties();
+ req.setCustomPartitioned(new FieldSet(6, 2), partitioner);
+
+ // match custom partitionings
+ {
+ GlobalProperties gp1 = new GlobalProperties();
+ gp1.setCustomPartitioned(new FieldList(2, 6), partitioner);
+ assertTrue(req.isMetBy(gp1));
+
+ GlobalProperties gp2 = new GlobalProperties();
+ gp2.setCustomPartitioned(new FieldList(6, 2), partitioner);
+ assertTrue(req.isMetBy(gp2));
+
+ GlobalProperties gp3 = new GlobalProperties();
+ gp3.setCustomPartitioned(new FieldList(6, 2), new MockPartitioner());
+ assertFalse(req.isMetBy(gp3));
+ }
+
+ // cannot match other types of partitionings
+ {
+ GlobalProperties gp1 = new GlobalProperties();
+ gp1.setAnyPartitioning(new FieldList(6, 2));
+ assertFalse(req.isMetBy(gp1));
+
+ GlobalProperties gp2 = new GlobalProperties();
+ gp2.setHashPartitioned(new FieldList(6, 2));
+ assertFalse(req.isMetBy(gp2));
+
+ GlobalProperties gp3 = new GlobalProperties();
+ gp3.setRangePartitioned(new Ordering(2, null, Order.DESCENDING).appendOrdering(6, null, Order.ASCENDING));
+ assertFalse(req.isMetBy(gp3));
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesPushdownTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesPushdownTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesPushdownTest.java
new file mode 100644
index 0000000..f99ebb6
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/GlobalPropertiesPushdownTest.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.dataproperties;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.flink.api.common.operators.util.FieldSet;
+import org.apache.flink.compiler.dag.OptimizerNode;
+import org.junit.Test;
+import org.mockito.Matchers;
+
+public class GlobalPropertiesPushdownTest {
+
+ @Test
+ public void testAnyPartitioningPushedDown() {
+ try {
+ RequestedGlobalProperties req = new RequestedGlobalProperties();
+ req.setAnyPartitioning(new FieldSet(3, 1));
+
+ RequestedGlobalProperties preserved = req.filterByNodesConstantSet(getAllPreservingNode(), 0);
+ assertEquals(PartitioningProperty.ANY_PARTITIONING, preserved.getPartitioning());
+ assertTrue(preserved.getPartitionedFields().isValidSubset(new FieldSet(1, 3)));
+
+ RequestedGlobalProperties nonPreserved = req.filterByNodesConstantSet(getNonePreservingNode(), 0);
+ assertTrue(nonPreserved == null || nonPreserved.isTrivial());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testHashPartitioningPushedDown() {
+ try {
+ RequestedGlobalProperties req = new RequestedGlobalProperties();
+ req.setHashPartitioned(new FieldSet(3, 1));
+
+ RequestedGlobalProperties preserved = req.filterByNodesConstantSet(getAllPreservingNode(), 0);
+ assertEquals(PartitioningProperty.HASH_PARTITIONED, preserved.getPartitioning());
+ assertTrue(preserved.getPartitionedFields().isValidSubset(new FieldSet(1, 3)));
+
+ RequestedGlobalProperties nonPreserved = req.filterByNodesConstantSet(getNonePreservingNode(), 0);
+ assertTrue(nonPreserved == null || nonPreserved.isTrivial());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningNotPushedDown() {
+ try {
+ RequestedGlobalProperties req = new RequestedGlobalProperties();
+ req.setCustomPartitioned(new FieldSet(3, 1), new MockPartitioner());
+
+ RequestedGlobalProperties pushedDown = req.filterByNodesConstantSet(getAllPreservingNode(), 0);
+ assertTrue(pushedDown == null || pushedDown.isTrivial());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testForcedRebalancingNotPushedDown() {
+ try {
+ RequestedGlobalProperties req = new RequestedGlobalProperties();
+ req.setForceRebalancing();
+
+ RequestedGlobalProperties pushedDown = req.filterByNodesConstantSet(getAllPreservingNode(), 0);
+ assertTrue(pushedDown == null || pushedDown.isTrivial());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ private static OptimizerNode getAllPreservingNode() {
+ OptimizerNode node = mock(OptimizerNode.class);
+ when(node.isFieldConstant(Matchers.anyInt(), Matchers.anyInt())).thenReturn(true);
+ return node;
+ }
+
+ private static OptimizerNode getNonePreservingNode() {
+ OptimizerNode node = mock(OptimizerNode.class);
+ when(node.isFieldConstant(Matchers.anyInt(), Matchers.anyInt())).thenReturn(false);
+ return node;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/MockPartitioner.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/MockPartitioner.java b/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/MockPartitioner.java
new file mode 100644
index 0000000..71e4c3a
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/dataproperties/MockPartitioner.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.dataproperties;
+
+import org.apache.flink.api.common.functions.Partitioner;
+
+class MockPartitioner implements Partitioner<Long> {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/java/DistinctAndGroupingOptimizerTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/java/DistinctAndGroupingOptimizerTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/java/DistinctAndGroupingOptimizerTest.java
new file mode 100644
index 0000000..45b389a
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/java/DistinctAndGroupingOptimizerTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.java;
+
+import static org.junit.Assert.*;
+
+import org.junit.Test;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.compiler.CompilerTestBase;
+import org.apache.flink.compiler.plan.OptimizedPlan;
+import org.apache.flink.compiler.plan.SingleInputPlanNode;
+import org.apache.flink.compiler.plan.SinkPlanNode;
+import org.apache.flink.compiler.testfunctions.IdentityMapper;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+
+@SuppressWarnings("serial")
+public class DistinctAndGroupingOptimizerTest extends CompilerTestBase {
+
+ @Test
+ public void testDistinctPreservesPartitioningOfDistinctFields() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(4);
+
+ @SuppressWarnings("unchecked")
+ DataSet<Tuple2<Long, Long>> data = env.fromElements(new Tuple2<Long, Long>(0L, 0L), new Tuple2<Long, Long>(1L, 1L))
+ .map(new IdentityMapper<Tuple2<Long,Long>>()).setParallelism(4);
+
+ data.distinct(0)
+ .groupBy(0)
+ .sum(1)
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode distinctReducer = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+
+ // reducer can be forward, reuses partitioning from distinct
+ assertEquals(ShipStrategyType.FORWARD, reducer.getInput().getShipStrategy());
+
+ // distinct reducer is partitioned
+ assertEquals(ShipStrategyType.PARTITION_HASH, distinctReducer.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testDistinctDestroysPartitioningOfNonDistinctFields() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(4);
+
+ @SuppressWarnings("unchecked")
+ DataSet<Tuple2<Long, Long>> data = env.fromElements(new Tuple2<Long, Long>(0L, 0L), new Tuple2<Long, Long>(1L, 1L))
+ .map(new IdentityMapper<Tuple2<Long,Long>>()).setParallelism(4);
+
+ data.distinct(1)
+ .groupBy(0)
+ .sum(1)
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+ SingleInputPlanNode distinctReducer = (SingleInputPlanNode) combiner.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+
+ // reducer must repartition, because it works on a different field
+ assertEquals(ShipStrategyType.PARTITION_HASH, reducer.getInput().getShipStrategy());
+
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+
+ // distinct reducer is partitioned
+ assertEquals(ShipStrategyType.PARTITION_HASH, distinctReducer.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/testfunctions/DummyReducer.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/testfunctions/DummyReducer.java b/flink-compiler/src/test/java/org/apache/flink/compiler/testfunctions/DummyReducer.java
new file mode 100644
index 0000000..a536bfd
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/testfunctions/DummyReducer.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.testfunctions;
+
+import org.apache.flink.api.common.functions.RichReduceFunction;
+
+public class DummyReducer<T> extends RichReduceFunction<T> {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public T reduce(T a, T b) {
+ return a;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/testfunctions/IdentityPartitionerMapper.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/testfunctions/IdentityPartitionerMapper.java b/flink-compiler/src/test/java/org/apache/flink/compiler/testfunctions/IdentityPartitionerMapper.java
new file mode 100644
index 0000000..d5c6cfe
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/testfunctions/IdentityPartitionerMapper.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.testfunctions;
+
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.util.Collector;
+
+public class IdentityPartitionerMapper<T> extends RichMapPartitionFunction<T, T> {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public void mapPartition(Iterable<T> values, Collector<T> out) {
+ for (T in : values) {
+ out.collect(in);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/main/java/org/apache/flink/api/common/functions/Partitioner.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/Partitioner.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/Partitioner.java
new file mode 100644
index 0000000..f686e94
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/Partitioner.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.functions;
+
+/**
+ * Function to implement a custom partition assignment for keys.
+ *
+ * @param <K> The type of the key to be partitioned.
+ */
+public interface Partitioner<K> extends java.io.Serializable {
+
+ /**
+ * Computes the partition for the given key.
+ *
+ * @param key The key.
+ * @param numPartitions The number of partitions to partition into.
+ * @return The partition index.
+ */
+ int partition(K key, int numPartitions);
+}
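For reference, a minimal implementation of this interface might look like the following sketch (illustrative only, not part of this commit; the class name is made up). It maps Long keys to partitions with a floor modulo, so that negative keys also yield a valid partition index:

    import org.apache.flink.api.common.functions.Partitioner;

    public class ModuloPartitioner implements Partitioner<Long> {

        private static final long serialVersionUID = 1L;

        @Override
        public int partition(Long key, int numPartitions) {
            // floor modulo keeps the result in [0, numPartitions), even for negative keys
            return (int) (((key % numPartitions) + numPartitions) % numPartitions);
        }
    }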
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java
index f500717..ddfd874 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java
@@ -22,6 +22,7 @@ import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FlatCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.CopyingListCollector;
import org.apache.flink.api.common.functions.util.FunctionUtils;
@@ -50,13 +51,13 @@ import java.util.List;
*/
public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN, OUT>> extends SingleInputOperator<IN, OUT, FT> {
- /**
- * The ordering for the order inside a reduce group.
- */
+ /** The ordering for the order inside a reduce group. */
private Ordering groupOrder;
private boolean combinable;
+ private Partitioner<?> customPartitioner;
+
public GroupReduceOperatorBase(UserCodeWrapper<FT> udf, UnaryOperatorInformation<IN, OUT> operatorInfo, int[] keyPositions, String name) {
super(udf, operatorInfo, keyPositions, name);
@@ -82,7 +83,8 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN,
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, name);
}
-
+ // --------------------------------------------------------------------------------------------
+
/**
* Sets the order of the elements within a reduce group.
*
@@ -102,8 +104,6 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN,
return this.groupOrder;
}
- // --------------------------------------------------------------------------------------------
-
/**
* Marks the group reduce operation as combinable. Combinable operations may pre-reduce the
* data before the actual group reduce operations. Combinable user-defined functions
@@ -132,6 +132,23 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN,
return this.combinable;
}
+ public void setCustomPartitioner(Partitioner<?> customPartitioner) {
+ if (customPartitioner != null) {
+ int[] keys = getKeyColumns(0);
+ if (keys == null || keys.length == 0) {
+ throw new IllegalArgumentException("Cannot use custom partitioner for a non-grouped GroupReduce (AllGroupReduce)");
+ }
+ if (keys.length > 1) {
+ throw new IllegalArgumentException("Cannot use the key partitioner for composite keys (more than one key field)");
+ }
+ }
+ this.customPartitioner = customPartitioner;
+ }
+
+ public Partitioner<?> getCustomPartitioner() {
+ return customPartitioner;
+ }
+
// --------------------------------------------------------------------------------------------
@Override
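The setter above limits custom partitioners to operators grouped on exactly one key field. At the user-facing API, the partitioner is attached to a grouping; a minimal usage sketch (assuming the withPartitioner method on groupings that this commit's translation tests exercise, and the hypothetical ModuloPartitioner sketched earlier):

    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.tuple.Tuple2;

    public class GroupingPartitionerExample {
        @SuppressWarnings("unchecked")
        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            DataSet<Tuple2<Long, Long>> data = env.fromElements(
                    new Tuple2<Long, Long>(0L, 0L), new Tuple2<Long, Long>(1L, 2L));

            data.groupBy(0)
                .withPartitioner(new ModuloPartitioner()) // grouping key ships via PARTITION_CUSTOM
                .sum(1)
                .print();

            env.execute("custom partitioned grouping");
        }
    }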
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java
index ba71b01..bc0f4a0 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java
@@ -19,6 +19,7 @@
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.functions.FlatJoinFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.CopyingListCollector;
import org.apache.flink.api.common.functions.util.FunctionUtils;
@@ -98,6 +99,8 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
private JoinHint joinHint = JoinHint.OPTIMIZER_CHOOSES;
+ private Partitioner<?> partitioner;
+
public JoinOperatorBase(UserCodeWrapper<FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) {
super(udf, operatorInfo, keyPositions1, keyPositions2, name);
@@ -123,6 +126,14 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
return joinHint;
}
+ public void setCustomPartitioner(Partitioner<?> partitioner) {
+ this.partitioner = partitioner;
+ }
+
+ public Partitioner<?> getCustomPartitioner() {
+ return partitioner;
+ }
+
// --------------------------------------------------------------------------------------------
@SuppressWarnings("unchecked")
@@ -143,35 +154,37 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
TypeComparator<IN1> leftComparator;
TypeComparator<IN2> rightComparator;
- if (leftInformation instanceof AtomicType){
+ if (leftInformation instanceof AtomicType) {
leftComparator = ((AtomicType<IN1>) leftInformation).createComparator(true);
}
- else if(leftInformation instanceof CompositeType){
+ else if (leftInformation instanceof CompositeType) {
int[] keyPositions = getKeyColumns(0);
boolean[] orders = new boolean[keyPositions.length];
Arrays.fill(orders, true);
leftComparator = ((CompositeType<IN1>) leftInformation).createComparator(keyPositions, orders, 0);
- }else{
+ }
+ else {
throw new RuntimeException("Type information for left input of type " + leftInformation.getClass()
.getCanonicalName() + " is not supported. Could not generate a comparator.");
}
- if(rightInformation instanceof AtomicType){
+ if (rightInformation instanceof AtomicType) {
rightComparator = ((AtomicType<IN2>) rightInformation).createComparator(true);
- }else if(rightInformation instanceof CompositeType){
+ }
+ else if (rightInformation instanceof CompositeType) {
int[] keyPositions = getKeyColumns(1);
boolean[] orders = new boolean[keyPositions.length];
Arrays.fill(orders, true);
rightComparator = ((CompositeType<IN2>) rightInformation).createComparator(keyPositions, orders, 0);
- }else{
+ }
+ else {
throw new RuntimeException("Type information for right input of type " + rightInformation.getClass()
.getCanonicalName() + " is not supported. Could not generate a comparator.");
}
- TypePairComparator<IN1, IN2> pairComparator = new GenericPairComparator<IN1, IN2>(leftComparator,
- rightComparator);
+ TypePairComparator<IN1, IN2> pairComparator = new GenericPairComparator<IN1, IN2>(leftComparator, rightComparator);
List<OUT> result = new ArrayList<OUT>();
Collector<OUT> collector = mutableObjectSafe ? new CopyingListCollector<OUT>(result, outInformation.createSerializer())
@@ -196,7 +209,7 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
if (matchingHashes != null) {
pairComparator.setReference(left);
- for (IN2 right : matchingHashes){
+ for (IN2 right : matchingHashes) {
if (pairComparator.equalToReference(right)) {
if (mutableObjectSafe) {
function.join(leftSerializer.copy(left), rightSerializer.copy(right), collector);
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java
index af8a111..ee3b259 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/PartitionOperatorBase.java
@@ -20,6 +20,7 @@ package org.apache.flink.api.common.operators.base;
import java.util.List;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.NoOpFunction;
import org.apache.flink.api.common.operators.SingleInputOperator;
@@ -32,8 +33,20 @@ import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
*/
public class PartitionOperatorBase<IN> extends SingleInputOperator<IN, IN, NoOpFunction> {
+ public static enum PartitionMethod {
+ REBALANCE,
+ HASH,
+ RANGE,
+ CUSTOM;
+ }
+
+ // --------------------------------------------------------------------------------------------
+
private final PartitionMethod partitionMethod;
+ private Partitioner<?> customPartitioner;
+
+
public PartitionOperatorBase(UnaryOperatorInformation<IN, IN> operatorInfo, PartitionMethod pMethod, int[] keys, String name) {
super(new UserCodeObjectWrapper<NoOpFunction>(new NoOpFunction()), operatorInfo, keys, name);
this.partitionMethod = pMethod;
@@ -44,16 +57,31 @@ public class PartitionOperatorBase<IN> extends SingleInputOperator<IN, IN, NoOpF
this.partitionMethod = pMethod;
}
+ // --------------------------------------------------------------------------------------------
+
public PartitionMethod getPartitionMethod() {
return this.partitionMethod;
}
- public static enum PartitionMethod {
- REBALANCE,
- HASH,
- RANGE;
+ public Partitioner<?> getCustomPartitioner() {
+ return customPartitioner;
+ }
+
+ public void setCustomPartitioner(Partitioner<?> customPartitioner) {
+ if (customPartitioner != null) {
+ int[] keys = getKeyColumns(0);
+ if (keys == null || keys.length == 0) {
+ throw new IllegalArgumentException("Cannot use custom partitioner for a non-grouped GroupReduce (AllGroupReduce)");
+ }
+ if (keys.length > 1) {
+ throw new IllegalArgumentException("Cannot use the key partitioner for composite keys (more than one key field)");
+ }
+ }
+ this.customPartitioner = customPartitioner;
}
+ // --------------------------------------------------------------------------------------------
+
@Override
protected List<IN> executeOnCollections(List<IN> inputData, RuntimeContext runtimeContext, boolean mutableObjectSafeMode) {
return inputData;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java
index 30ff176..f1bf0e9 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java
@@ -19,6 +19,7 @@
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
@@ -50,6 +51,9 @@ import java.util.Map;
*/
public class ReduceOperatorBase<T, FT extends ReduceFunction<T>> extends SingleInputOperator<T, T, FT> {
+ private Partitioner<?> customPartitioner;
+
+
/**
* Creates a grouped reduce data flow operator.
*
@@ -124,7 +128,26 @@ public class ReduceOperatorBase<T, FT extends ReduceFunction<T>> extends SingleI
}
// --------------------------------------------------------------------------------------------
+
+ public void setCustomPartitioner(Partitioner<?> customPartitioner) {
+ if (customPartitioner != null) {
+ int[] keys = getKeyColumns(0);
+ if (keys == null || keys.length == 0) {
+ throw new IllegalArgumentException("Cannot use custom partitioner for a non-grouped GroupReduce (AllGroupReduce)");
+ }
+ if (keys.length > 1) {
+ throw new IllegalArgumentException("Cannot use the key partitioner for composite keys (more than one key field)");
+ }
+ }
+ this.customPartitioner = customPartitioner;
+ }
+
+ public Partitioner<?> getCustomPartitioner() {
+ return customPartitioner;
+ }
+ // --------------------------------------------------------------------------------------------
+
@Override
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
// make sure we can handle empty inputs
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java
index f98f6e0..f98a05e 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java
@@ -294,7 +294,7 @@ public abstract class TypeComparator<T> implements Serializable {
public abstract int extractKeys(Object record, Object[] target, int index);
/**
- * Get the field comparators. This is used together with {@link #extractKeys(Object)} to provide
+ * Get the field comparators. This is used together with {@link #extractKeys(Object, Object[], int)} to provide
* interoperability between different record types.
*/
@SuppressWarnings("rawtypes")
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/test/java/org/apache/flink/api/common/operators/base/JoinOperatorBaseTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/operators/base/JoinOperatorBaseTest.java b/flink-core/src/test/java/org/apache/flink/api/common/operators/base/JoinOperatorBaseTest.java
index 0ab8e72..feb2223 100644
--- a/flink-core/src/test/java/org/apache/flink/api/common/operators/base/JoinOperatorBaseTest.java
+++ b/flink-core/src/test/java/org/apache/flink/api/common/operators/base/JoinOperatorBaseTest.java
@@ -1,4 +1,4 @@
-/**
+/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/test/java/org/apache/flink/api/common/operators/base/MapOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/operators/base/MapOperatorTest.java b/flink-core/src/test/java/org/apache/flink/api/common/operators/base/MapOperatorTest.java
index 1a742b6..fd23d40 100644
--- a/flink-core/src/test/java/org/apache/flink/api/common/operators/base/MapOperatorTest.java
+++ b/flink-core/src/test/java/org/apache/flink/api/common/operators/base/MapOperatorTest.java
@@ -1,4 +1,4 @@
-/**
+/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/test/java/org/apache/flink/api/common/operators/base/PartitionMapOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/operators/base/PartitionMapOperatorTest.java b/flink-core/src/test/java/org/apache/flink/api/common/operators/base/PartitionMapOperatorTest.java
index dadd1ca..50c6b98 100644
--- a/flink-core/src/test/java/org/apache/flink/api/common/operators/base/PartitionMapOperatorTest.java
+++ b/flink-core/src/test/java/org/apache/flink/api/common/operators/base/PartitionMapOperatorTest.java
@@ -1,4 +1,4 @@
-/**
+/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldListTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldListTest.java b/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldListTest.java
index 24783ac..39a3301 100644
--- a/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldListTest.java
+++ b/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldListTest.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.api.common.operators.util;
import static org.junit.Assert.assertEquals;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldSetTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldSetTest.java b/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldSetTest.java
index 7549f43..27d8bcc 100644
--- a/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldSetTest.java
+++ b/flink-core/src/test/java/org/apache/flink/api/common/operators/util/FieldSetTest.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.api.common.operators.util;
import static org.junit.Assert.assertEquals;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
index c78cc7a..6b768b7 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.io.FileOutputFormat;
import org.apache.flink.api.common.io.OutputFormat;
@@ -153,8 +154,8 @@ public abstract class DataSet<T> {
- /**
- * Applies a Map-style operation to the entire partition of the data.
+ /**
+ * Applies a Map-style operation to the entire partition of the data.
* The function is called once per parallel partition of the data,
* and the entire partition is available through the given Iterator.
* The number of elements that each instance of the MapPartition function
@@ -165,12 +166,12 @@ public abstract class DataSet<T> {
* the use of {@code map()} and {@code flatMap()} is preferable.
*
* @param mapPartition The MapPartitionFunction that is called for the full DataSet.
- * @return A MapPartitionOperator that represents the transformed DataSet.
- *
- * @see MapPartitionFunction
- * @see MapPartitionOperator
- * @see DataSet
- */
+ * @return A MapPartitionOperator that represents the transformed DataSet.
+ *
+ * @see MapPartitionFunction
+ * @see MapPartitionOperator
+ * @see DataSet
+ */
public <R> MapPartitionOperator<T, R> mapPartition(MapPartitionFunction<T, R> mapPartition ){
if (mapPartition == null) {
throw new NullPointerException("MapPartition function must not be null.");
@@ -344,7 +345,7 @@ public abstract class DataSet<T> {
return new GroupReduceOperator<T, R>(this, resultType, reducer, Utils.getCallLocationName());
}
-/**
+ /**
* Applies a special case of a reduce transformation (minBy) on a non-grouped {@link DataSet}.<br/>
* The transformation consecutively calls a {@link ReduceFunction}
* until only a single element remains which is the result of the transformation.
@@ -926,12 +927,59 @@ public abstract class DataSet<T> {
}
/**
- * Enforces a rebalancing of the DataSet, i.e., the DataSet is evenly distributed over all parallel instances of the
+ * Partitions a tuple DataSet on the specified key fields using a custom partitioner.
+ * This method takes the key position to partition on, and a partitioner that accepts the key type.
+ * <p>
+ * Note: This method works only on single field keys.
+ *
+ * @param partitioner The partitioner to assign partitions to keys.
+ * @param field The field index on which the DataSet is to be partitioned.
+ * @return The partitioned DataSet.
+ */
+ public <K> PartitionOperator<T> partitionCustom(Partitioner<K> partitioner, int field) {
+ return new PartitionOperator<T>(this, new Keys.ExpressionKeys<T>(new int[] {field}, getType(), false), partitioner, Utils.getCallLocationName());
+ }
+
+ /**
+ * Partitions a POJO DataSet on the specified key fields using a custom partitioner.
+ * This method takes the key expression to partition on, and a partitioner that accepts the key type.
+ * <p>
+ * Note: This method works only on single field keys.
+ *
+ * @param partitioner The partitioner to assign partitions to keys.
+ * @param field The field expression on which the DataSet is to be partitioned.
+ * @return The partitioned DataSet.
+ */
+ public <K> PartitionOperator<T> partitionCustom(Partitioner<K> partitioner, String field) {
+ return new PartitionOperator<T>(this, new Keys.ExpressionKeys<T>(new String[] {field}, getType()), partitioner, Utils.getCallLocationName());
+ }
+
+ /**
+ * Partitions a DataSet on the key returned by the selector, using a custom partitioner.
+ * This method takes the key selector to get the key to partition on, and a partitioner that
+ * accepts the key type.
+ * <p>
+ * Note: This method works only on single field keys, i.e. the selector cannot return tuples
+ * of fields.
+ *
+ * @param partitioner The partitioner to assign partitions to keys.
+ * @param keyExtractor The KeyExtractor with which the DataSet is partitioned.
+ * @return The partitioned DataSet.
+ *
+ * @see KeySelector
+ */
+ public <K extends Comparable<K>> PartitionOperator<T> partitionCustom(Partitioner<K> partitioner, KeySelector<T, K> keyExtractor) {
+ final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, type);
+ return new PartitionOperator<T>(this, new Keys.SelectorFunctionKeys<T, K>(keyExtractor, this.getType(), keyType), partitioner, Utils.getCallLocationName());
+ }
+
+ /**
+ * Enforces a re-balancing of the DataSet, i.e., the DataSet is evenly distributed over all parallel instances of the
* following task. This can help to improve performance in case of heavy data skew and compute intensive operations.
* <p>
* <b>Important:</b>This operation shuffles the whole DataSet over the network and can take significant amount of time.
*
- * @return The rebalanced DataSet.
+ * @return The re-balanced DataSet.
*/
public PartitionOperator<T> rebalance() {
return new PartitionOperator<T>(this, PartitionMethod.REBALANCE, Utils.getCallLocationName());
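A usage sketch for the three partitionCustom variants added above (again using the hypothetical ModuloPartitioner; note that field expressions primarily target POJO field names, so whether the tuple-style expression "f0" is accepted depends on the expression-keys support in this version):

    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.functions.KeySelector;
    import org.apache.flink.api.java.tuple.Tuple2;

    public class PartitionCustomExample {
        @SuppressWarnings("unchecked")
        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            DataSet<Tuple2<Long, String>> data = env.fromElements(
                    new Tuple2<Long, String>(1L, "a"), new Tuple2<Long, String>(2L, "b"));

            // variant 1: key given as a tuple field position
            data.partitionCustom(new ModuloPartitioner(), 0).print();

            // variant 2: key given as a field expression
            data.partitionCustom(new ModuloPartitioner(), "f0").print();

            // variant 3: key extracted by a KeySelector
            data.partitionCustom(new ModuloPartitioner(), new KeySelector<Tuple2<Long, String>, Long>() {
                @Override
                public Long getKey(Tuple2<Long, String> value) {
                    return value.f0;
                }
            }).print();

            env.execute("partitionCustom variants");
        }
    }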
[4/4] incubator-flink git commit: [FLINK-1237] Add support for custom
partitioners - Functions: GroupReduce, Reduce, Aggregate on UnsortedGrouping,
SortedGrouping,
Join (Java API & Scala API) - Manual partition on DataSet (Java API & S
Posted by se...@apache.org.
[FLINK-1237] Add support for custom partitioners
- Functions: GroupReduce, Reduce, Aggregate on UnsortedGrouping, SortedGrouping,
Join (Java API & Scala API)
- Manual partition on DataSet (Java API & Scala API)
- Distinct operations provide semantic properties for preservation of the distinct fields
- Tests for pushdown (or non-pushdown) of custom partitionings and forced rebalancing
- Tests for GlobalProperties matching of partitionings
- Caching of generated requested data properties for unary operators
This closes #207
Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/2000b45c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/2000b45c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/2000b45c
Branch: refs/heads/master
Commit: 2000b45ce3e71ed6eddecbb3f8658ebecec58230
Parents: 83d0256
Author: Stephan Ewen <se...@apache.org>
Authored: Thu Nov 13 16:26:07 2014 +0100
Committer: Stephan Ewen <se...@apache.org>
Committed: Tue Nov 18 12:19:37 2014 +0100
----------------------------------------------------------------------
.../org/apache/flink/compiler/PactCompiler.java | 8 +-
.../flink/compiler/costs/CostEstimator.java | 1 +
.../flink/compiler/dag/BinaryUnionNode.java | 1 -
.../flink/compiler/dag/CollectorMapNode.java | 6 +-
.../apache/flink/compiler/dag/FilterNode.java | 7 +-
.../apache/flink/compiler/dag/FlatMapNode.java | 8 +-
.../flink/compiler/dag/GroupReduceNode.java | 93 +++---
.../org/apache/flink/compiler/dag/JoinNode.java | 187 ++++++++++++
.../org/apache/flink/compiler/dag/MapNode.java | 11 +-
.../flink/compiler/dag/MapPartitionNode.java | 6 +-
.../flink/compiler/dag/PartitionNode.java | 25 +-
.../apache/flink/compiler/dag/ReduceNode.java | 16 +-
.../dataproperties/GlobalProperties.java | 47 ++-
.../dataproperties/PartitioningProperty.java | 12 +-
.../RequestedGlobalProperties.java | 76 ++++-
.../operators/AbstractJoinDescriptor.java | 23 +-
.../operators/GroupReduceProperties.java | 29 +-
.../GroupReduceWithCombineProperties.java | 23 +-
.../compiler/operators/ReduceProperties.java | 14 +-
.../operators/SortMergeJoinDescriptor.java | 1 -
.../org/apache/flink/compiler/plan/Channel.java | 30 +-
.../plandump/PlanJSONDumpGenerator.java | 3 +
.../plantranslate/NepheleJobGraphGenerator.java | 1 +
.../compiler/FeedbackPropertiesMatchTest.java | 6 +-
...ustomPartitioningGlobalOptimizationTest.java | 93 ++++++
.../custompartition/CustomPartitioningTest.java | 287 +++++++++++++++++++
.../GroupingKeySelectorTranslationTest.java | 268 +++++++++++++++++
.../GroupingPojoTranslationTest.java | 257 +++++++++++++++++
.../GroupingTupleTranslationTest.java | 270 +++++++++++++++++
.../JoinCustomPartitioningTest.java | 263 +++++++++++++++++
.../GlobalPropertiesFilteringTest.java | 55 ++++
.../GlobalPropertiesMatchingTest.java | 159 ++++++++++
.../GlobalPropertiesPushdownTest.java | 113 ++++++++
.../dataproperties/MockPartitioner.java | 31 ++
.../java/DistinctAndGroupingOptimizerTest.java | 112 ++++++++
.../compiler/testfunctions/DummyReducer.java | 31 ++
.../IdentityPartitionerMapper.java | 34 +++
.../flink/api/common/functions/Partitioner.java | 36 +++
.../operators/base/GroupReduceOperatorBase.java | 29 +-
.../common/operators/base/JoinOperatorBase.java | 31 +-
.../operators/base/PartitionOperatorBase.java | 36 ++-
.../operators/base/ReduceOperatorBase.java | 23 ++
.../api/common/typeutils/TypeComparator.java | 2 +-
.../operators/base/JoinOperatorBaseTest.java | 2 +-
.../common/operators/base/MapOperatorTest.java | 2 +-
.../base/PartitionMapOperatorTest.java | 2 +-
.../common/operators/util/FieldListTest.java | 1 -
.../api/common/operators/util/FieldSetTest.java | 1 -
.../java/org/apache/flink/api/java/DataSet.java | 70 ++++-
.../api/java/operators/AggregateOperator.java | 3 +-
.../api/java/operators/DistinctOperator.java | 20 +-
.../api/java/operators/GroupReduceOperator.java | 17 +-
.../flink/api/java/operators/Grouping.java | 16 +-
.../flink/api/java/operators/JoinOperator.java | 119 ++++----
.../apache/flink/api/java/operators/Keys.java | 57 +++-
.../api/java/operators/PartitionOperator.java | 94 ++++--
.../api/java/operators/ReduceOperator.java | 14 +-
.../api/java/operators/SortedGrouping.java | 19 +-
.../api/java/operators/UnsortedGrouping.java | 25 +-
.../api/java/record/io/CsvInputFormat.java | 2 +
.../api/java/record/io/CsvOutputFormat.java | 2 +-
.../java/record/io/DelimitedOutputFormat.java | 3 +-
.../java/record/operators/ReduceOperator.java | 1 +
.../flink/api/java/typeutils/TypeExtractor.java | 5 +
.../runtime/io/network/api/ChannelSelector.java | 1 -
.../runtime/operators/RegularPactTask.java | 17 +-
.../shipping/HistogramPartitionFunction.java | 58 ----
.../operators/shipping/OutputEmitter.java | 48 +++-
.../operators/shipping/PartitionFunction.java | 26 --
.../operators/shipping/RecordOutputEmitter.java | 49 +++-
.../operators/shipping/ShipStrategyType.java | 9 +-
.../runtime/operators/util/TaskConfig.java | 24 ++
.../scala/operators/ScalaAggregateOperator.java | 1 +
.../org/apache/flink/api/scala/DataSet.scala | 68 ++++-
.../apache/flink/api/scala/GroupedDataSet.scala | 48 +++-
.../apache/flink/api/scala/joinDataSet.scala | 58 +++-
.../test/cancelling/CancellingTestBase.java | 2 +
.../StaticlyNestedIterationsITCase.java | 3 +-
.../IterationWithChainingNepheleITCase.java | 1 +
.../translation/AggregateTranslationTest.scala | 1 +
...tomPartitioningGroupingKeySelectorTest.scala | 249 ++++++++++++++++
.../CustomPartitioningGroupingPojoTest.scala | 255 ++++++++++++++++
.../CustomPartitioningGroupingTupleTest.scala | 262 +++++++++++++++++
.../translation/CustomPartitioningTest.scala | 243 ++++++++++++++++
.../JoinCustomPartitioningTest.scala | 252 ++++++++++++++++
85 files changed, 4533 insertions(+), 381 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/PactCompiler.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/PactCompiler.java b/flink-compiler/src/main/java/org/apache/flink/compiler/PactCompiler.java
index 2ce2495..d1d6343 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/PactCompiler.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/PactCompiler.java
@@ -67,7 +67,7 @@ import org.apache.flink.compiler.dag.GroupReduceNode;
import org.apache.flink.compiler.dag.IterationNode;
import org.apache.flink.compiler.dag.MapNode;
import org.apache.flink.compiler.dag.MapPartitionNode;
-import org.apache.flink.compiler.dag.MatchNode;
+import org.apache.flink.compiler.dag.JoinNode;
import org.apache.flink.compiler.dag.OptimizerNode;
import org.apache.flink.compiler.dag.PactConnection;
import org.apache.flink.compiler.dag.PartitionNode;
@@ -696,7 +696,7 @@ public class PactCompiler {
n = new GroupReduceNode((GroupReduceOperatorBase<?, ?, ?>) c);
}
else if (c instanceof JoinOperatorBase) {
- n = new MatchNode((JoinOperatorBase<?, ?, ?, ?>) c);
+ n = new JoinNode((JoinOperatorBase<?, ?, ?, ?>) c);
}
else if (c instanceof CoGroupOperatorBase) {
n = new CoGroupNode((CoGroupOperatorBase<?, ?, ?, ?>) c);
@@ -883,9 +883,9 @@ public class PactCompiler {
for (PactConnection conn : solutionSetNode.getOutgoingConnections()) {
OptimizerNode successor = conn.getTarget();
- if (successor.getClass() == MatchNode.class) {
+ if (successor.getClass() == JoinNode.class) {
// find out which input to the match the solution set is
- MatchNode mn = (MatchNode) successor;
+ JoinNode mn = (JoinNode) successor;
if (mn.getFirstPredecessorNode() == solutionSetNode) {
mn.makeJoinWithSolutionSet(0);
} else if (mn.getSecondPredecessorNode() == solutionSetNode) {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/costs/CostEstimator.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/costs/CostEstimator.java b/flink-compiler/src/main/java/org/apache/flink/compiler/costs/CostEstimator.java
index 99a9b12..b13c1be 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/costs/CostEstimator.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/costs/CostEstimator.java
@@ -105,6 +105,7 @@ public abstract class CostEstimator {
addRandomPartitioningCost(channel, costs);
break;
case PARTITION_HASH:
+ case PARTITION_CUSTOM:
addHashPartitioningCost(channel, costs);
break;
case PARTITION_RANGE:
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/BinaryUnionNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/BinaryUnionNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/BinaryUnionNode.java
index b229a4e..9003c92 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/BinaryUnionNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/BinaryUnionNode.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.dag;
import java.util.ArrayList;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/CollectorMapNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/CollectorMapNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/CollectorMapNode.java
index dbf97b5..53a760e 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/CollectorMapNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/CollectorMapNode.java
@@ -32,9 +32,13 @@ import org.apache.flink.compiler.operators.OperatorDescriptorSingle;
*/
public class CollectorMapNode extends SingleInputNode {
+ private final List<OperatorDescriptorSingle> possibleProperties;
+
public CollectorMapNode(SingleInputOperator<?, ?, ?> operator) {
super(operator);
+
+ this.possibleProperties = Collections.<OperatorDescriptorSingle>singletonList(new CollectorMapDescriptor());
}
@Override
@@ -44,7 +48,7 @@ public class CollectorMapNode extends SingleInputNode {
@Override
protected List<OperatorDescriptorSingle> getPossibleProperties() {
- return Collections.<OperatorDescriptorSingle>singletonList(new CollectorMapDescriptor());
+ return this.possibleProperties;
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FilterNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FilterNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FilterNode.java
index fe12fbf..df304b1 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FilterNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FilterNode.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.dag;
import java.util.Collections;
@@ -32,9 +31,11 @@ import org.apache.flink.compiler.operators.OperatorDescriptorSingle;
*/
public class FilterNode extends SingleInputNode {
-
+ private final List<OperatorDescriptorSingle> possibleProperties;
+
public FilterNode(FilterOperatorBase<?, ?> operator) {
super(operator);
+ this.possibleProperties = Collections.<OperatorDescriptorSingle>singletonList(new FilterDescriptor());
}
@Override
@@ -54,7 +55,7 @@ public class FilterNode extends SingleInputNode {
@Override
protected List<OperatorDescriptorSingle> getPossibleProperties() {
- return Collections.<OperatorDescriptorSingle>singletonList(new FilterDescriptor());
+ return this.possibleProperties;
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FlatMapNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FlatMapNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FlatMapNode.java
index c1a86b3..234b26a 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FlatMapNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/FlatMapNode.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.dag;
import java.util.Collections;
@@ -32,9 +31,12 @@ import org.apache.flink.compiler.operators.OperatorDescriptorSingle;
*/
public class FlatMapNode extends SingleInputNode {
-
+ private final List<OperatorDescriptorSingle> possibleProperties;
+
public FlatMapNode(FlatMapOperatorBase<?, ?, ?> operator) {
super(operator);
+
+ this.possibleProperties = Collections.<OperatorDescriptorSingle>singletonList(new FlatMapDescriptor());
}
@Override
@@ -49,7 +51,7 @@ public class FlatMapNode extends SingleInputNode {
@Override
protected List<OperatorDescriptorSingle> getPossibleProperties() {
- return Collections.<OperatorDescriptorSingle>singletonList(new FlatMapDescriptor());
+ return this.possibleProperties;
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/GroupReduceNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/GroupReduceNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/GroupReduceNode.java
index a6bb207..527adcc 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/GroupReduceNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/GroupReduceNode.java
@@ -16,12 +16,12 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.dag;
import java.util.Collections;
import java.util.List;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.compiler.CompilerException;
@@ -35,59 +35,37 @@ import org.apache.flink.compiler.operators.OperatorDescriptorSingle;
import org.apache.flink.configuration.Configuration;
/**
- * The Optimizer representation of a <i>Reduce</i> contract node.
+ * The optimizer representation of a <i>GroupReduce</i> operation.
*/
public class GroupReduceNode extends SingleInputNode {
+ private final List<OperatorDescriptorSingle> possibleProperties;
+
private GroupReduceNode combinerUtilityNode;
/**
- * Creates a new ReduceNode for the given contract.
+ * Creates a new optimizer node for the given operator.
*
- * @param pactContract The reduce contract object.
+ * @param operator The reduce operation.
*/
- public GroupReduceNode(GroupReduceOperatorBase<?, ?, ?> pactContract) {
- super(pactContract);
+ public GroupReduceNode(GroupReduceOperatorBase<?, ?, ?> operator) {
+ super(operator);
if (this.keys == null) {
// case of a key-less reducer. force a parallelism of 1
setDegreeOfParallelism(1);
}
+
+ this.possibleProperties = initPossibleProperties(operator.getCustomPartitioner());
}
public GroupReduceNode(GroupReduceNode reducerToCopyForCombiner) {
super(reducerToCopyForCombiner);
- }
-
- // ------------------------------------------------------------------------
-
- /**
- * Gets the contract object for this reduce node.
- *
- * @return The contract.
- */
- @Override
- public GroupReduceOperatorBase<?, ?, ?> getPactContract() {
- return (GroupReduceOperatorBase<?, ?, ?>) super.getPactContract();
- }
-
- /**
- * Checks, whether a combiner function has been given for the function encapsulated
- * by this reduce contract.
- *
- * @return True, if a combiner has been given, false otherwise.
- */
- public boolean isCombineable() {
- return getPactContract().isCombinable();
- }
-
- @Override
- public String getName() {
- return "GroupReduce";
+
+ this.possibleProperties = Collections.emptyList();
}
- @Override
- protected List<OperatorDescriptorSingle> getPossibleProperties() {
+ private List<OperatorDescriptorSingle> initPossibleProperties(Partitioner<?> customPartitioner) {
// see if an internal hint dictates the strategy to use
final Configuration conf = getPactContract().getParameters();
final String localStrategy = conf.getString(PactCompiler.HINT_LOCAL_STRATEGY, null);
@@ -96,10 +74,11 @@ public class GroupReduceNode extends SingleInputNode {
if (localStrategy != null) {
if (PactCompiler.HINT_LOCAL_STRATEGY_SORT.equals(localStrategy)) {
useCombiner = false;
- } else if (PactCompiler.HINT_LOCAL_STRATEGY_COMBINING_SORT.equals(localStrategy)) {
+ }
+ else if (PactCompiler.HINT_LOCAL_STRATEGY_COMBINING_SORT.equals(localStrategy)) {
if (!isCombineable()) {
- PactCompiler.LOG.warn("Strategy hint for Reduce Pact '" + getPactContract().getName() +
- "' desires combinable reduce, but user function is not marked combinable.");
+ PactCompiler.LOG.warn("Strategy hint for GroupReduce '" + getPactContract().getName() +
+ "' requires combinable reduce, but user function is not marked combinable.");
}
useCombiner = true;
} else {
@@ -119,10 +98,42 @@ public class GroupReduceNode extends SingleInputNode {
}
OperatorDescriptorSingle props = useCombiner ?
- (this.keys == null ? new AllGroupWithPartialPreGroupProperties() : new GroupReduceWithCombineProperties(this.keys, groupOrder)) :
- (this.keys == null ? new AllGroupReduceProperties() : new GroupReduceProperties(this.keys, groupOrder));
+ (this.keys == null ? new AllGroupWithPartialPreGroupProperties() : new GroupReduceWithCombineProperties(this.keys, groupOrder, customPartitioner)) :
+ (this.keys == null ? new AllGroupReduceProperties() : new GroupReduceProperties(this.keys, groupOrder, customPartitioner));
+
+ return Collections.singletonList(props);
+ }
+
+ // ------------------------------------------------------------------------
+
+ /**
+ * Gets the operator represented by this optimizer node.
+ *
+ * @return The operator represented by this optimizer node.
+ */
+ @Override
+ public GroupReduceOperatorBase<?, ?, ?> getPactContract() {
+ return (GroupReduceOperatorBase<?, ?, ?>) super.getPactContract();
+ }
+
+ /**
+ * Checks whether a combiner function has been given for the function encapsulated
+ * by this reduce contract.
+ *
+ * @return True, if a combiner has been given, false otherwise.
+ */
+ public boolean isCombineable() {
+ return getPactContract().isCombinable();
+ }
- return Collections.singletonList(props);
+ @Override
+ public String getName() {
+ return "GroupReduce";
+ }
+
+ @Override
+ protected List<OperatorDescriptorSingle> getPossibleProperties() {
+ return this.possibleProperties;
}
// --------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/JoinNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/JoinNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/JoinNode.java
new file mode 100644
index 0000000..19b753d
--- /dev/null
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/JoinNode.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.dag;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.base.JoinOperatorBase;
+import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
+import org.apache.flink.compiler.CompilerException;
+import org.apache.flink.compiler.DataStatistics;
+import org.apache.flink.compiler.PactCompiler;
+import org.apache.flink.compiler.operators.AbstractJoinDescriptor;
+import org.apache.flink.compiler.operators.HashJoinBuildFirstProperties;
+import org.apache.flink.compiler.operators.HashJoinBuildSecondProperties;
+import org.apache.flink.compiler.operators.OperatorDescriptorDual;
+import org.apache.flink.compiler.operators.SortMergeJoinDescriptor;
+import org.apache.flink.configuration.Configuration;
+
+/**
+ * The optimizer representation of a join operator.
+ */
+public class JoinNode extends TwoInputNode {
+
+ private List<OperatorDescriptorDual> dataProperties;
+
+ /**
+ * Creates a new JoinNode for the given join operator.
+ *
+ * @param joinOperatorBase The join operator object.
+ */
+ public JoinNode(JoinOperatorBase<?, ?, ?, ?> joinOperatorBase) {
+ super(joinOperatorBase);
+
+ this.dataProperties = getDataProperties(joinOperatorBase,
+ joinOperatorBase.getJoinHint(), joinOperatorBase.getCustomPartitioner());
+ }
+
+ // ------------------------------------------------------------------------
+
+ /**
+ * Gets the contract object for this join node.
+ *
+ * @return The contract.
+ */
+ @Override
+ public JoinOperatorBase<?, ?, ?, ?> getPactContract() {
+ return (JoinOperatorBase<?, ?, ?, ?>) super.getPactContract();
+ }
+
+ @Override
+ public String getName() {
+ return "Join";
+ }
+
+ @Override
+ protected List<OperatorDescriptorDual> getPossibleProperties() {
+ return this.dataProperties;
+ }
+
+ public void makeJoinWithSolutionSet(int solutionsetInputIndex) {
+ OperatorDescriptorDual op;
+ if (solutionsetInputIndex == 0) {
+ op = new HashJoinBuildFirstProperties(this.keys1, this.keys2);
+ } else if (solutionsetInputIndex == 1) {
+ op = new HashJoinBuildSecondProperties(this.keys1, this.keys2);
+ } else {
+ throw new IllegalArgumentException();
+ }
+
+ this.dataProperties = Collections.singletonList(op);
+ }
+
+ /**
+ * The default estimates build on the principle of inclusion: The smaller input key domain is included in the larger
+ * input key domain. We also assume that every key from the larger input has one join partner in the smaller input.
+ * The result cardinality is hence the larger one.
+ */
+ @Override
+ protected void computeOperatorSpecificDefaultEstimates(DataStatistics statistics) {
+ long card1 = getFirstPredecessorNode().getEstimatedNumRecords();
+ long card2 = getSecondPredecessorNode().getEstimatedNumRecords();
+ this.estimatedNumRecords = (card1 < 0 || card2 < 0) ? -1 : Math.max(card1, card2);
+
+ if (this.estimatedNumRecords >= 0) {
+ float width1 = getFirstPredecessorNode().getEstimatedAvgWidthPerOutputRecord();
+ float width2 = getSecondPredecessorNode().getEstimatedAvgWidthPerOutputRecord();
+ float width = (width1 <= 0 || width2 <= 0) ? -1 : width1 + width2;
+
+ if (width > 0) {
+ this.estimatedOutputSize = (long) (width * this.estimatedNumRecords);
+ }
+ }
+ }
+
+ private List<OperatorDescriptorDual> getDataProperties(JoinOperatorBase<?, ?, ?, ?> joinOperatorBase, JoinHint joinHint,
+ Partitioner<?> customPartitioner)
+ {
+ // see if an internal hint dictates the strategy to use
+ Configuration conf = joinOperatorBase.getParameters();
+ String localStrategy = conf.getString(PactCompiler.HINT_LOCAL_STRATEGY, null);
+
+ if (localStrategy != null) {
+ final AbstractJoinDescriptor fixedDriverStrat;
+ if (PactCompiler.HINT_LOCAL_STRATEGY_SORT_BOTH_MERGE.equals(localStrategy) ||
+ PactCompiler.HINT_LOCAL_STRATEGY_SORT_FIRST_MERGE.equals(localStrategy) ||
+ PactCompiler.HINT_LOCAL_STRATEGY_SORT_SECOND_MERGE.equals(localStrategy) ||
+ PactCompiler.HINT_LOCAL_STRATEGY_MERGE.equals(localStrategy) )
+ {
+ fixedDriverStrat = new SortMergeJoinDescriptor(this.keys1, this.keys2);
+ }
+ else if (PactCompiler.HINT_LOCAL_STRATEGY_HASH_BUILD_FIRST.equals(localStrategy)) {
+ fixedDriverStrat = new HashJoinBuildFirstProperties(this.keys1, this.keys2);
+ }
+ else if (PactCompiler.HINT_LOCAL_STRATEGY_HASH_BUILD_SECOND.equals(localStrategy)) {
+ fixedDriverStrat = new HashJoinBuildSecondProperties(this.keys1, this.keys2);
+ }
+ else {
+ throw new CompilerException("Invalid local strategy hint for match contract: " + localStrategy);
+ }
+
+ if (customPartitioner != null) {
+ fixedDriverStrat.setCustomPartitioner(customPartitioner);
+ }
+
+ ArrayList<OperatorDescriptorDual> list = new ArrayList<OperatorDescriptorDual>();
+ list.add(fixedDriverStrat);
+ return list;
+ }
+ else {
+ ArrayList<OperatorDescriptorDual> list = new ArrayList<OperatorDescriptorDual>();
+
+ joinHint = joinHint == null ? JoinHint.OPTIMIZER_CHOOSES : joinHint;
+
+ switch (joinHint) {
+ case BROADCAST_HASH_FIRST:
+ list.add(new HashJoinBuildFirstProperties(this.keys1, this.keys2, true, false, false));
+ break;
+ case BROADCAST_HASH_SECOND:
+ list.add(new HashJoinBuildSecondProperties(this.keys1, this.keys2, false, true, false));
+ break;
+ case REPARTITION_HASH_FIRST:
+ list.add(new HashJoinBuildFirstProperties(this.keys1, this.keys2, false, false, true));
+ break;
+ case REPARTITION_HASH_SECOND:
+ list.add(new HashJoinBuildSecondProperties(this.keys1, this.keys2, false, false, true));
+ break;
+ case REPARTITION_SORT_MERGE:
+ list.add(new SortMergeJoinDescriptor(this.keys1, this.keys2, false, false, true));
+ break;
+ case OPTIMIZER_CHOOSES:
+ list.add(new SortMergeJoinDescriptor(this.keys1, this.keys2));
+ list.add(new HashJoinBuildFirstProperties(this.keys1, this.keys2));
+ list.add(new HashJoinBuildSecondProperties(this.keys1, this.keys2));
+ break;
+ default:
+ throw new CompilerException("Unrecognized join hint: " + joinHint);
+ }
+
+ if (customPartitioner != null) {
+ for (OperatorDescriptorDual descr : list) {
+ ((AbstractJoinDescriptor) descr).setCustomPartitioner(customPartitioner);
+ }
+ }
+
+ return list;
+ }
+ }
+}
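
For context, the custom partitioner that getDataProperties() consumes is attached on the API side. A minimal sketch, modeled on the CustomPartitioningGlobalOptimizationTest added below (TestPartitionerLong is that test's dummy partitioner):

    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
    DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));

    // withPartitioner() stores the Partitioner on the join operator; JoinNode reads it
    // via joinOperatorBase.getCustomPartitioner() and requests PARTITION_CUSTOM on
    // both inputs instead of the default hash partitioning.
    input1.join(input2)
        .where(1).equalTo(0)
        .projectFirst(0, 1).projectSecond(2).types(Long.class, Long.class, Long.class)
        .withPartitioner(new TestPartitionerLong())
        .print();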
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapNode.java
index e65febb..f1e26cd 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapNode.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.dag;
import java.util.Collections;
@@ -32,13 +31,17 @@ import org.apache.flink.compiler.operators.OperatorDescriptorSingle;
*/
public class MapNode extends SingleInputNode {
+ private final List<OperatorDescriptorSingle> possibleProperties;
+
/**
- * Creates a new MapNode for the given contract.
+ * Creates a new MapNode for the given operator.
*
- * @param operator The map contract object.
+ * @param operator The map operator.
*/
public MapNode(SingleInputOperator<?, ?, ?> operator) {
super(operator);
+
+ this.possibleProperties = Collections.<OperatorDescriptorSingle>singletonList(new MapDescriptor());
}
@Override
@@ -48,7 +51,7 @@ public class MapNode extends SingleInputNode {
@Override
protected List<OperatorDescriptorSingle> getPossibleProperties() {
- return Collections.<OperatorDescriptorSingle>singletonList(new MapDescriptor());
+ return this.possibleProperties;
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapPartitionNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapPartitionNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapPartitionNode.java
index a180968..e21b7fc 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapPartitionNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/MapPartitionNode.java
@@ -32,6 +32,8 @@ import org.apache.flink.compiler.operators.OperatorDescriptorSingle;
*/
public class MapPartitionNode extends SingleInputNode {
+ private final List<OperatorDescriptorSingle> possibleProperties;
+
/**
* Creates a new MapPartitionNode for the given operator.
*
@@ -39,6 +41,8 @@ public class MapPartitionNode extends SingleInputNode {
*/
public MapPartitionNode(SingleInputOperator<?, ?, ?> operator) {
super(operator);
+
+ this.possibleProperties = Collections.<OperatorDescriptorSingle>singletonList(new MapPartitionDescriptor());
}
@Override
@@ -48,7 +52,7 @@ public class MapPartitionNode extends SingleInputNode {
@Override
protected List<OperatorDescriptorSingle> getPossibleProperties() {
- return Collections.<OperatorDescriptorSingle>singletonList(new MapPartitionDescriptor());
+ return this.possibleProperties;
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/PartitionNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/PartitionNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/PartitionNode.java
index ccd48c5..53b5dd9 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/PartitionNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/PartitionNode.java
@@ -22,6 +22,7 @@ package org.apache.flink.compiler.dag;
import java.util.Collections;
import java.util.List;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod;
import org.apache.flink.api.common.operators.util.FieldSet;
@@ -40,8 +41,14 @@ import org.apache.flink.runtime.operators.DriverStrategy;
*/
public class PartitionNode extends SingleInputNode {
+ private final List<OperatorDescriptorSingle> possibleProperties;
+
public PartitionNode(PartitionOperatorBase<?> operator) {
super(operator);
+
+ OperatorDescriptorSingle descr = new PartitionDescriptor(
+ this.getPactContract().getPartitionMethod(), this.keys, operator.getCustomPartitioner());
+ this.possibleProperties = Collections.singletonList(descr);
}
@Override
@@ -56,13 +63,14 @@ public class PartitionNode extends SingleInputNode {
@Override
protected List<OperatorDescriptorSingle> getPossibleProperties() {
- return Collections.<OperatorDescriptorSingle>singletonList(new PartitionDescriptor(this.getPactContract().getPartitionMethod(), this.keys));
+ return this.possibleProperties;
}
@Override
protected void computeOperatorSpecificDefaultEstimates(DataStatistics statistics) {
// partitioning does not change the number of records
this.estimatedNumRecords = getPredecessorNode().getEstimatedNumRecords();
+ this.estimatedOutputSize = getPredecessorNode().getEstimatedOutputSize();
}
@Override
@@ -71,15 +79,18 @@ public class PartitionNode extends SingleInputNode {
return true;
}
+ // --------------------------------------------------------------------------------------------
public static class PartitionDescriptor extends OperatorDescriptorSingle {
private final PartitionMethod pMethod;
- private final FieldSet pKeys;
+ private final Partitioner<?> customPartitioner;
- public PartitionDescriptor(PartitionMethod pMethod, FieldSet pKeys) {
+ public PartitionDescriptor(PartitionMethod pMethod, FieldSet pKeys, Partitioner<?> customPartitioner) {
+ super(pKeys);
+
this.pMethod = pMethod;
- this.pKeys = pKeys;
+ this.customPartitioner = customPartitioner;
}
@Override
@@ -98,11 +109,14 @@ public class PartitionNode extends SingleInputNode {
switch (this.pMethod) {
case HASH:
- rgps.setHashPartitioned(pKeys.toFieldList());
+ rgps.setHashPartitioned(this.keys);
break;
case REBALANCE:
rgps.setForceRebalancing();
break;
+ case CUSTOM:
+ rgps.setCustomPartitioned(this.keys, this.customPartitioner);
+ break;
case RANGE:
throw new UnsupportedOperationException("Not yet supported");
default:
@@ -130,5 +144,4 @@ public class PartitionNode extends SingleInputNode {
return lProps;
}
}
-
}
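
The new CUSTOM branch above is driven by DataSet.partitionCustom(). A usage sketch, mirroring the CustomPartitioningTest added below (TestPartitionerInt and IdentityPartitionerMapper are the test helpers):

    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0));

    // partitionCustom() creates a PartitionOperatorBase with PartitionMethod.CUSTOM;
    // the PartitionDescriptor then requests setCustomPartitioned(keys, partitioner).
    data.partitionCustom(new TestPartitionerInt(), 0)
        .mapPartition(new IdentityPartitionerMapper<Tuple2<Integer, Integer>>())
        .print();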
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dag/ReduceNode.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/ReduceNode.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/ReduceNode.java
index 2abbfb9..defae04 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dag/ReduceNode.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dag/ReduceNode.java
@@ -33,6 +33,8 @@ import org.apache.flink.compiler.operators.ReduceProperties;
*/
public class ReduceNode extends SingleInputNode {
+ private final List<OperatorDescriptorSingle> possibleProperties;
+
private ReduceNode preReduceUtilityNode;
@@ -43,10 +45,18 @@ public class ReduceNode extends SingleInputNode {
// case of a key-less reducer. force a parallelism of 1
setDegreeOfParallelism(1);
}
+
+ OperatorDescriptorSingle props = this.keys == null ?
+ new AllReduceProperties() :
+ new ReduceProperties(this.keys, operator.getCustomPartitioner());
+
+ this.possibleProperties = Collections.singletonList(props);
}
public ReduceNode(ReduceNode reducerToCopyForCombiner) {
super(reducerToCopyForCombiner);
+
+ this.possibleProperties = Collections.emptyList();
}
// ------------------------------------------------------------------------
@@ -63,11 +73,7 @@ public class ReduceNode extends SingleInputNode {
@Override
protected List<OperatorDescriptorSingle> getPossibleProperties() {
- OperatorDescriptorSingle props = this.keys == null ?
- new AllReduceProperties() :
- new ReduceProperties(this.keys);
-
- return Collections.singletonList(props);
+ return this.possibleProperties;
}
// --------------------------------------------------------------------------------------------
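
On the API side, operator.getCustomPartitioner() becomes non-null when the grouping carries a partitioner. A minimal sketch (the anonymous ReduceFunction is illustrative; the tests in this commit use named helper classes):

    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0));

    // groupBy(...).withPartitioner(...) hands the Partitioner to the generated
    // reduce operator, which is what the ReduceNode constructor above inspects.
    data.groupBy(0).withPartitioner(new TestPartitionerInt())
        .reduce(new ReduceFunction<Tuple2<Integer, Integer>>() {
            @Override
            public Tuple2<Integer, Integer> reduce(Tuple2<Integer, Integer> a, Tuple2<Integer, Integer> b) {
                return a;
            }
        })
        .print();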
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/GlobalProperties.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/GlobalProperties.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/GlobalProperties.java
index f3d9c2d..7dedc53 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/GlobalProperties.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/GlobalProperties.java
@@ -16,13 +16,13 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.dataproperties;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.util.FieldList;
@@ -50,6 +50,8 @@ public class GlobalProperties implements Cloneable {
private Set<FieldSet> uniqueFieldCombinations;
+ private Partitioner<?> customPartitioner;
+
// --------------------------------------------------------------------------------------------
/**
@@ -67,6 +69,10 @@ public class GlobalProperties implements Cloneable {
* @param partitionedFields
*/
public void setHashPartitioned(FieldList partitionedFields) {
+ if (partitionedFields == null) {
+ throw new NullPointerException();
+ }
+
this.partitioning = PartitioningProperty.HASH_PARTITIONED;
this.partitioningFields = partitionedFields;
this.ordering = null;
@@ -74,12 +80,20 @@ public class GlobalProperties implements Cloneable {
public void setRangePartitioned(Ordering ordering) {
+ if (ordering == null) {
+ throw new NullPointerException();
+ }
+
this.partitioning = PartitioningProperty.RANGE_PARTITIONED;
this.ordering = ordering;
this.partitioningFields = ordering.getInvolvedIndexes();
}
public void setAnyPartitioning(FieldList partitionedFields) {
+ if (partitionedFields == null) {
+ throw new NullPointerException();
+ }
+
this.partitioning = PartitioningProperty.ANY_PARTITIONING;
this.partitioningFields = partitionedFields;
this.ordering = null;
@@ -103,7 +117,21 @@ public class GlobalProperties implements Cloneable {
this.ordering = null;
}
+ public void setCustomPartitioned(FieldList partitionedFields, Partitioner<?> partitioner) {
+ if (partitionedFields == null || partitioner == null) {
+ throw new NullPointerException();
+ }
+
+ this.partitioning = PartitioningProperty.CUSTOM_PARTITIONING;
+ this.partitioningFields = partitionedFields;
+ this.ordering = null;
+ this.customPartitioner = partitioner;
+ }
+
public void addUniqueFieldCombination(FieldSet fields) {
+ if (fields == null) {
+ return;
+ }
if (this.uniqueFieldCombinations == null) {
this.uniqueFieldCombinations = new HashSet<FieldSet>();
}
@@ -128,12 +156,16 @@ public class GlobalProperties implements Cloneable {
return this.ordering;
}
- // --------------------------------------------------------------------------------------------
-
public PartitioningProperty getPartitioning() {
return this.partitioning;
}
+ public Partitioner<?> getCustomPartitioner() {
+ return this.customPartitioner;
+ }
+
+ // --------------------------------------------------------------------------------------------
+
public boolean isPartitionedOnFields(FieldSet fields) {
if (this.partitioning.isPartitionedOnKey() && fields.isValidSubset(this.partitioningFields)) {
return true;
@@ -267,8 +299,14 @@ public class GlobalProperties implements Cloneable {
case RANGE_PARTITIONED:
channel.setShipStrategy(ShipStrategyType.PARTITION_RANGE, this.ordering.getInvolvedIndexes(), this.ordering.getFieldSortDirections());
break;
+ case FORCED_REBALANCED:
+ channel.setShipStrategy(ShipStrategyType.PARTITION_RANDOM);
+ break;
+ case CUSTOM_PARTITIONING:
+ channel.setShipStrategy(ShipStrategyType.PARTITION_CUSTOM, this.partitioningFields, this.customPartitioner);
+ break;
default:
- throw new CompilerException();
+ throw new CompilerException("Unsupported partitioning strategy");
}
}
@@ -322,6 +360,7 @@ public class GlobalProperties implements Cloneable {
newProps.partitioning = this.partitioning;
newProps.partitioningFields = this.partitioningFields;
newProps.ordering = this.ordering;
+ newProps.customPartitioner = this.customPartitioner;
newProps.uniqueFieldCombinations = this.uniqueFieldCombinations == null ? null : new HashSet<FieldSet>(this.uniqueFieldCombinations);
return newProps;
}
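
A small sketch of the new GlobalProperties contract (FieldList's single-field constructor and TestPartitionerInt, as in the tests below, are assumptions): setCustomPartitioned() rejects nulls and records both the fields and the partitioner, so the optimizer can later re-establish exactly the same distribution.

    GlobalProperties gp = new GlobalProperties();
    gp.setCustomPartitioned(new FieldList(0), new TestPartitionerInt());

    // Custom partitioning counts as a real partitioning (isComputablyPartitioned()
    // below now includes CUSTOM_PARTITIONING), and the partitioner survives clone().
    assert gp.getPartitioning() == PartitioningProperty.CUSTOM_PARTITIONING;
    assert gp.getCustomPartitioner() != null;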
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/PartitioningProperty.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/PartitioningProperty.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/PartitioningProperty.java
index f73f491..47cd6b8 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/PartitioningProperty.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/PartitioningProperty.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.dataproperties;
/**
@@ -50,9 +49,14 @@ public enum PartitioningProperty {
FULL_REPLICATION,
/**
- * Constant indicating a forced even rebalancing.
+ * Constant indicating a forced even re-balancing.
+ */
+ FORCED_REBALANCED,
+
+ /**
+ * A custom partitioning, accompanied by a {@link org.apache.flink.api.common.functions.Partitioner}.
*/
- FORCED_REBALANCED;
+ CUSTOM_PARTITIONING;
/**
* Checks, if this property represents in fact a partitioning. That is,
@@ -95,6 +99,6 @@ public enum PartitioningProperty {
* @return True, if this enum constant is a re-computable partitioning.
*/
public boolean isComputablyPartitioned() {
- return this == HASH_PARTITIONED || this == RANGE_PARTITIONED;
+ return this == HASH_PARTITIONED || this == RANGE_PARTITIONED || this == CUSTOM_PARTITIONING;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/RequestedGlobalProperties.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/RequestedGlobalProperties.java b/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/RequestedGlobalProperties.java
index dcf0afa..4e9d60a 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/RequestedGlobalProperties.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/dataproperties/RequestedGlobalProperties.java
@@ -16,10 +16,10 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.dataproperties;
import org.apache.flink.api.common.distributions.DataDistribution;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.util.FieldSet;
import org.apache.flink.compiler.CompilerException;
@@ -43,7 +43,9 @@ public final class RequestedGlobalProperties implements Cloneable {
private Ordering ordering; // order of the partitioned fields, if it is an ordered (range) partitioning
- private DataDistribution dataDistribution; // optional data distribution, for a range partitioning
+ private DataDistribution dataDistribution; // optional data distribution, for a range partitioning
+
+ private Partitioner<?> customPartitioner; // optional, partitioner for custom partitioning
// --------------------------------------------------------------------------------------------
@@ -112,6 +114,17 @@ public final class RequestedGlobalProperties implements Cloneable {
this.ordering = null;
}
+ public void setCustomPartitioned(FieldSet partitionedFields, Partitioner<?> partitioner) {
+ if (partitionedFields == null || partitioner == null) {
+ throw new NullPointerException();
+ }
+
+ this.partitioning = PartitioningProperty.CUSTOM_PARTITIONING;
+ this.partitioningFields = partitionedFields;
+ this.ordering = null;
+ this.customPartitioner = partitioner;
+ }
+
/**
* Gets the partitioning property.
*
@@ -147,6 +160,15 @@ public final class RequestedGlobalProperties implements Cloneable {
public DataDistribution getDataDistribution() {
return this.dataDistribution;
}
+
+ /**
+ * Gets the custom partitioner associated with these properties.
+ *
+ * @return The custom partitioner associated with these properties.
+ */
+ public Partitioner<?> getCustomPartitioner() {
+ return customPartitioner;
+ }
/**
* Checks, if the properties in this object are trivial, i.e. only standard values.
@@ -162,6 +184,8 @@ public final class RequestedGlobalProperties implements Cloneable {
this.partitioning = PartitioningProperty.RANDOM;
this.ordering = null;
this.partitioningFields = null;
+ this.dataDistribution = null;
+ this.customPartitioner = null;
}
/**
@@ -188,7 +212,12 @@ public final class RequestedGlobalProperties implements Cloneable {
}
}
- if (this.partitioning == PartitioningProperty.FULL_REPLICATION) {
+ // make sure that certain properties are not pushed down
+ final PartitioningProperty partitioning = this.partitioning;
+ if (partitioning == PartitioningProperty.FULL_REPLICATION ||
+ partitioning == PartitioningProperty.FORCED_REBALANCED ||
+ partitioning == PartitioningProperty.CUSTOM_PARTITIONING)
+ {
return null;
}
@@ -205,22 +234,34 @@ public final class RequestedGlobalProperties implements Cloneable {
public boolean isMetBy(GlobalProperties props) {
if (this.partitioning == PartitioningProperty.FULL_REPLICATION) {
return props.isFullyReplicated();
- } else if (props.isFullyReplicated()) {
+ }
+ else if (props.isFullyReplicated()) {
return false;
- } else if (this.partitioning == PartitioningProperty.RANDOM) {
+ }
+ else if (this.partitioning == PartitioningProperty.RANDOM) {
return true;
- } else if (this.partitioning == PartitioningProperty.ANY_PARTITIONING) {
+ }
+ else if (this.partitioning == PartitioningProperty.ANY_PARTITIONING) {
return props.isPartitionedOnFields(this.partitioningFields);
- } else if (this.partitioning == PartitioningProperty.HASH_PARTITIONED) {
+ }
+ else if (this.partitioning == PartitioningProperty.HASH_PARTITIONED) {
return props.getPartitioning() == PartitioningProperty.HASH_PARTITIONED &&
props.isPartitionedOnFields(this.partitioningFields);
- } else if (this.partitioning == PartitioningProperty.RANGE_PARTITIONED) {
+ }
+ else if (this.partitioning == PartitioningProperty.RANGE_PARTITIONED) {
return props.getPartitioning() == PartitioningProperty.RANGE_PARTITIONED &&
props.matchesOrderedPartitioning(this.ordering);
- } else if (this.partitioning == PartitioningProperty.FORCED_REBALANCED) {
+ }
+ else if (this.partitioning == PartitioningProperty.FORCED_REBALANCED) {
return props.getPartitioning() == PartitioningProperty.FORCED_REBALANCED;
- } else {
- throw new CompilerException("Bug in properties matching logic.");
+ }
+ else if (this.partitioning == PartitioningProperty.CUSTOM_PARTITIONING) {
+ return props.getPartitioning() == PartitioningProperty.CUSTOM_PARTITIONING &&
+ props.isPartitionedOnFields(this.partitioningFields) &&
+ props.getCustomPartitioner().equals(this.customPartitioner);
+ }
+ else {
+ throw new CompilerException("Properties matching logic leaves open cases.");
}
}
@@ -250,22 +291,29 @@ public final class RequestedGlobalProperties implements Cloneable {
case FULL_REPLICATION:
channel.setShipStrategy(ShipStrategyType.BROADCAST);
break;
+
case ANY_PARTITIONING:
case HASH_PARTITIONED:
channel.setShipStrategy(ShipStrategyType.PARTITION_HASH, Utils.createOrderedFromSet(this.partitioningFields));
break;
+
case RANGE_PARTITIONED:
-
- channel.setShipStrategy(ShipStrategyType.PARTITION_RANGE, this.ordering.getInvolvedIndexes(), this.ordering.getFieldSortDirections());
+ channel.setShipStrategy(ShipStrategyType.PARTITION_RANGE, this.ordering.getInvolvedIndexes(), this.ordering.getFieldSortDirections());
if(this.dataDistribution != null) {
channel.setDataDistribution(this.dataDistribution);
}
break;
+
case FORCED_REBALANCED:
channel.setShipStrategy(ShipStrategyType.PARTITION_FORCED_REBALANCE);
break;
+
+ case CUSTOM_PARTITIONING:
+ channel.setShipStrategy(ShipStrategyType.PARTITION_CUSTOM, Utils.createOrderedFromSet(this.partitioningFields), this.customPartitioner);
+ break;
+
default:
- throw new CompilerException();
+ throw new CompilerException("Invalid partitioning to create through a data exchange: " + this.partitioning.name());
}
}
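
Note that isMetBy() above compares custom partitioners with equals(). The tests in this commit reuse a single Partitioner instance, which always matches; a stateless partitioner can instead override equals(), as in this hedged sketch (RegionPartitioner is a hypothetical name):

    public class RegionPartitioner implements Partitioner<Integer> {

        @Override
        public int partition(Integer key, int numPartitions) {
            return Math.abs(key) % numPartitions;
        }

        // Stateless, so all instances are interchangeable. Overriding equals() lets
        // isMetBy() recognize an existing custom partitioning even when the program
        // created two separate instances of the same partitioner class.
        @Override
        public boolean equals(Object obj) {
            return obj instanceof RegionPartitioner;
        }

        @Override
        public int hashCode() {
            return RegionPartitioner.class.hashCode();
        }
    }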
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/operators/AbstractJoinDescriptor.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/AbstractJoinDescriptor.java b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/AbstractJoinDescriptor.java
index 47069e6..84af77c 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/AbstractJoinDescriptor.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/AbstractJoinDescriptor.java
@@ -21,6 +21,7 @@ package org.apache.flink.compiler.operators;
import java.util.ArrayList;
import java.util.List;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.compiler.dataproperties.GlobalProperties;
import org.apache.flink.compiler.dataproperties.PartitioningProperty;
@@ -35,6 +36,8 @@ public abstract class AbstractJoinDescriptor extends OperatorDescriptorDual {
private final boolean broadcastSecondAllowed;
private final boolean repartitionAllowed;
+ private Partitioner<?> customPartitioner;
+
protected AbstractJoinDescriptor(FieldList keys1, FieldList keys2) {
this(keys1, keys2, true, true, true);
}
@@ -49,16 +52,30 @@ public abstract class AbstractJoinDescriptor extends OperatorDescriptorDual {
this.repartitionAllowed = repartitionAllowed;
}
+ public void setCustomPartitioner(Partitioner<?> partitioner) {
+ customPartitioner = partitioner;
+ }
+
@Override
protected List<GlobalPropertiesPair> createPossibleGlobalProperties() {
ArrayList<GlobalPropertiesPair> pairs = new ArrayList<GlobalPropertiesPair>();
if (repartitionAllowed) {
- // partition both (hash)
+ // partition both (hash or custom)
RequestedGlobalProperties partitioned1 = new RequestedGlobalProperties();
- partitioned1.setHashPartitioned(this.keys1);
+ if (customPartitioner == null) {
+ partitioned1.setHashPartitioned(this.keys1);
+ } else {
+ partitioned1.setCustomPartitioned(this.keys1, this.customPartitioner);
+ }
+
RequestedGlobalProperties partitioned2 = new RequestedGlobalProperties();
- partitioned2.setHashPartitioned(this.keys2);
+ if (customPartitioner == null) {
+ partitioned2.setHashPartitioned(this.keys2);
+ } else {
+ partitioned2.setCustomPartitioned(this.keys2, this.customPartitioner);
+ }
+
pairs.add(new GlobalPropertiesPair(partitioned1, partitioned2));
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceProperties.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceProperties.java b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceProperties.java
index bf09bcc..ab93170 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceProperties.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceProperties.java
@@ -21,6 +21,7 @@ package org.apache.flink.compiler.operators;
import java.util.Collections;
import java.util.List;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.util.FieldSet;
@@ -38,12 +39,22 @@ public final class GroupReduceProperties extends OperatorDescriptorSingle {
private final Ordering ordering; // ordering that we need to use if an additional ordering is requested
+ private final Partitioner<?> customPartitioner;
+
public GroupReduceProperties(FieldSet keys) {
- this(keys, null);
+ this(keys, null, null);
+ }
+
+ public GroupReduceProperties(FieldSet keys, Ordering additionalOrderKeys) {
+ this(keys, additionalOrderKeys, null);
}
- public GroupReduceProperties(FieldSet groupKeys, Ordering additionalOrderKeys) {
+ public GroupReduceProperties(FieldSet keys, Partitioner<?> customPartitioner) {
+ this(keys, null, customPartitioner);
+ }
+
+ public GroupReduceProperties(FieldSet groupKeys, Ordering additionalOrderKeys, Partitioner<?> customPartitioner) {
super(groupKeys);
// if we have an additional ordering, construct the ordering to have primarily the grouping fields
@@ -59,9 +70,12 @@ public final class GroupReduceProperties extends OperatorDescriptorSingle {
Order order = additionalOrderKeys.getOrder(i);
this.ordering.appendOrdering(field, additionalOrderKeys.getType(i), order);
}
- } else {
+ }
+ else {
this.ordering = null;
}
+
+ this.customPartitioner = customPartitioner;
}
@Override
@@ -71,13 +85,18 @@ public final class GroupReduceProperties extends OperatorDescriptorSingle {
@Override
public SingleInputPlanNode instantiate(Channel in, SingleInputNode node) {
- return new SingleInputPlanNode(node, "Reduce("+node.getPactContract().getName()+")", in, DriverStrategy.SORTED_GROUP_REDUCE, this.keyList);
+ return new SingleInputPlanNode(node, "GroupReduce ("+node.getPactContract().getName()+")", in, DriverStrategy.SORTED_GROUP_REDUCE, this.keyList);
}
@Override
protected List<RequestedGlobalProperties> createPossibleGlobalProperties() {
RequestedGlobalProperties props = new RequestedGlobalProperties();
- props.setAnyPartitioning(this.keys);
+
+ if (customPartitioner == null) {
+ props.setAnyPartitioning(this.keys);
+ } else {
+ props.setCustomPartitioned(this.keys, this.customPartitioner);
+ }
return Collections.singletonList(props);
}
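
The three-argument constructor is reached from a sorted grouping that also carries a custom partitioner. A sketch of that API path (the exact chaining of sortGroup() and withPartitioner() is an assumption based on the commit message's mention of SortedGrouping; IdentityGroupReducer is the test helper used below):

    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Tuple3<Long, Long, Long>> data = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));

    // Group key, additional group order, and custom partitioner together end up in
    // GroupReduceProperties(groupKeys, additionalOrderKeys, customPartitioner).
    data.groupBy(0)
        .sortGroup(1, Order.ASCENDING)
        .withPartitioner(new TestPartitionerLong())
        .reduceGroup(new IdentityGroupReducer<Tuple3<Long, Long, Long>>())
        .print();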
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceWithCombineProperties.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceWithCombineProperties.java b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceWithCombineProperties.java
index 92b2297..8604951 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceWithCombineProperties.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/GroupReduceWithCombineProperties.java
@@ -21,6 +21,7 @@ package org.apache.flink.compiler.operators;
import java.util.Collections;
import java.util.List;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.util.FieldSet;
@@ -42,12 +43,22 @@ public final class GroupReduceWithCombineProperties extends OperatorDescriptorSi
private final Ordering ordering; // ordering that we need to use if an additional ordering is requested
+ private final Partitioner<?> customPartitioner;
- public GroupReduceWithCombineProperties(FieldSet keys) {
- this(keys, null);
+
+ public GroupReduceWithCombineProperties(FieldSet groupKeys) {
+ this(groupKeys, null, null);
}
public GroupReduceWithCombineProperties(FieldSet groupKeys, Ordering additionalOrderKeys) {
+ this(groupKeys, additionalOrderKeys, null);
+ }
+
+ public GroupReduceWithCombineProperties(FieldSet groupKeys, Partitioner<?> customPartitioner) {
+ this(groupKeys, null, customPartitioner);
+ }
+
+ public GroupReduceWithCombineProperties(FieldSet groupKeys, Ordering additionalOrderKeys, Partitioner<?> customPartitioner) {
super(groupKeys);
// if we have an additional ordering, construct the ordering to have primarily the grouping fields
@@ -66,6 +77,8 @@ public final class GroupReduceWithCombineProperties extends OperatorDescriptorSi
} else {
this.ordering = null;
}
+
+ this.customPartitioner = customPartitioner;
}
@Override
@@ -111,7 +124,11 @@ public final class GroupReduceWithCombineProperties extends OperatorDescriptorSi
@Override
protected List<RequestedGlobalProperties> createPossibleGlobalProperties() {
RequestedGlobalProperties props = new RequestedGlobalProperties();
- props.setAnyPartitioning(this.keys);
+ if (customPartitioner == null) {
+ props.setAnyPartitioning(this.keys);
+ } else {
+ props.setCustomPartitioned(this.keys, this.customPartitioner);
+ }
return Collections.singletonList(props);
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/operators/ReduceProperties.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/ReduceProperties.java b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/ReduceProperties.java
index 9d2e86a..813af20 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/ReduceProperties.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/ReduceProperties.java
@@ -21,6 +21,7 @@ package org.apache.flink.compiler.operators;
import java.util.Collections;
import java.util.List;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.util.FieldSet;
import org.apache.flink.compiler.costs.Costs;
import org.apache.flink.compiler.dag.ReduceNode;
@@ -38,8 +39,15 @@ import org.apache.flink.runtime.operators.util.LocalStrategy;
public final class ReduceProperties extends OperatorDescriptorSingle {
+ private final Partitioner<?> customPartitioner;
+
public ReduceProperties(FieldSet keys) {
+ this(keys, null);
+ }
+
+ public ReduceProperties(FieldSet keys, Partitioner<?> customPartitioner) {
super(keys);
+ this.customPartitioner = customPartitioner;
}
@Override
@@ -77,7 +85,11 @@ public final class ReduceProperties extends OperatorDescriptorSingle {
@Override
protected List<RequestedGlobalProperties> createPossibleGlobalProperties() {
RequestedGlobalProperties props = new RequestedGlobalProperties();
- props.setAnyPartitioning(this.keys);
+ if (customPartitioner == null) {
+ props.setAnyPartitioning(this.keys);
+ } else {
+ props.setCustomPartitioned(this.keys, this.customPartitioner);
+ }
return Collections.singletonList(props);
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/operators/SortMergeJoinDescriptor.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/SortMergeJoinDescriptor.java b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/SortMergeJoinDescriptor.java
index 5c6de30..cd6094e 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/operators/SortMergeJoinDescriptor.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/operators/SortMergeJoinDescriptor.java
@@ -16,7 +16,6 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.operators;
import java.util.Collections;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/plan/Channel.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/plan/Channel.java b/flink-compiler/src/main/java/org/apache/flink/compiler/plan/Channel.java
index 5fb03f5..e159481 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/plan/Channel.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/plan/Channel.java
@@ -16,10 +16,10 @@
* limitations under the License.
*/
-
package org.apache.flink.compiler.plan;
import org.apache.flink.api.common.distributions.DataDistribution;
+import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
@@ -36,7 +36,7 @@ import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.util.LocalStrategy;
/**
- *
+ * A Channel is a data exchange between two operators.
*/
public class Channel implements EstimateProvider, Cloneable, DumpableConnection<PlanNode> {
@@ -72,6 +72,8 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
private DataDistribution dataDistribution;
+ private Partitioner<?> partitioner;
+
private TempMode tempMode;
private double relativeTempMemory;
@@ -125,17 +127,27 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
}
public void setShipStrategy(ShipStrategyType strategy) {
- setShipStrategy(strategy, null, null);
+ setShipStrategy(strategy, null, null, null);
}
public void setShipStrategy(ShipStrategyType strategy, FieldList keys) {
- setShipStrategy(strategy, keys, null);
+ setShipStrategy(strategy, keys, null, null);
}
public void setShipStrategy(ShipStrategyType strategy, FieldList keys, boolean[] sortDirection) {
+ setShipStrategy(strategy, keys, sortDirection, null);
+ }
+
+ public void setShipStrategy(ShipStrategyType strategy, FieldList keys, Partitioner<?> partitioner) {
+ setShipStrategy(strategy, keys, null, partitioner);
+ }
+
+ public void setShipStrategy(ShipStrategyType strategy, FieldList keys, boolean[] sortDirection, Partitioner<?> partitioner) {
this.shipStrategy = strategy;
this.shipKeys = keys;
this.shipSortOrder = sortDirection;
+ this.partitioner = partitioner;
+
this.globalProps = null; // reset the global properties
}
@@ -187,6 +199,10 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
return this.dataDistribution;
}
+ public Partitioner<?> getPartitioner() {
+ return partitioner;
+ }
+
public TempMode getTempMode() {
return this.tempMode;
}
@@ -245,7 +261,6 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
public TypeSerializerFactory<?> getSerializer() {
return serializer;
}
-
/**
* Sets the serializer for this Channel.
@@ -381,6 +396,9 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
case PARTITION_FORCED_REBALANCE:
this.globalProps.setForcedRebalanced();
break;
+ case PARTITION_CUSTOM:
+ this.globalProps.setCustomPartitioned(this.shipKeys, this.partitioner);
+ break;
case NONE:
throw new CompilerException("Cannot produce GlobalProperties before ship strategy is set.");
}
@@ -411,6 +429,7 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
switch (this.shipStrategy) {
case BROADCAST:
case PARTITION_HASH:
+ case PARTITION_CUSTOM:
case PARTITION_RANGE:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
@@ -448,6 +467,7 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
case PARTITION_RANGE:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
+ case PARTITION_CUSTOM:
return;
}
throw new CompilerException("Unrecognized Ship Strategy Type: " + this.shipStrategy);
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/plandump/PlanJSONDumpGenerator.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/plandump/PlanJSONDumpGenerator.java b/flink-compiler/src/main/java/org/apache/flink/compiler/plandump/PlanJSONDumpGenerator.java
index 41dfd9b..7728948 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/plandump/PlanJSONDumpGenerator.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/plandump/PlanJSONDumpGenerator.java
@@ -336,6 +336,9 @@ public class PlanJSONDumpGenerator {
case PARTITION_FORCED_REBALANCE:
shipStrategy = "Rebalance";
break;
+ case PARTITION_CUSTOM:
+ shipStrategy = "Custom Partition";
+ break;
default:
throw new CompilerException("Unknown ship strategy '" + inConn.getShipStrategy().name()
+ "' in JSON generator.");
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/main/java/org/apache/flink/compiler/plantranslate/NepheleJobGraphGenerator.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/plantranslate/NepheleJobGraphGenerator.java b/flink-compiler/src/main/java/org/apache/flink/compiler/plantranslate/NepheleJobGraphGenerator.java
index b717924..64eca7c 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/plantranslate/NepheleJobGraphGenerator.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/plantranslate/NepheleJobGraphGenerator.java
@@ -1046,6 +1046,7 @@ public class NepheleJobGraphGenerator implements Visitor<PlanNode> {
case PARTITION_RANDOM:
case BROADCAST:
case PARTITION_HASH:
+ case PARTITION_CUSTOM:
case PARTITION_RANGE:
case PARTITION_FORCED_REBALANCE:
distributionPattern = DistributionPattern.BIPARTITE;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/FeedbackPropertiesMatchTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/FeedbackPropertiesMatchTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/FeedbackPropertiesMatchTest.java
index 5d45159..e3f5267 100644
--- a/flink-compiler/src/test/java/org/apache/flink/compiler/FeedbackPropertiesMatchTest.java
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/FeedbackPropertiesMatchTest.java
@@ -37,7 +37,7 @@ import org.apache.flink.api.common.operators.util.FieldSet;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.compiler.dag.DataSourceNode;
import org.apache.flink.compiler.dag.MapNode;
-import org.apache.flink.compiler.dag.MatchNode;
+import org.apache.flink.compiler.dag.JoinNode;
import org.apache.flink.compiler.dataproperties.GlobalProperties;
import org.apache.flink.compiler.dataproperties.LocalProperties;
import org.apache.flink.compiler.dataproperties.RequestedGlobalProperties;
@@ -1429,7 +1429,7 @@ public class FeedbackPropertiesMatchTest {
return new MapNode(new MapOperatorBase<String, String, MapFunction<String,String>>(new IdentityMapper<String>(), new UnaryOperatorInformation<String, String>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), "map op"));
}
- private static final MatchNode getJoinNode() {
- return new MatchNode(new JoinOperatorBase<String, String, String, FlatJoinFunction<String, String, String>>(new DummyFlatJoinFunction<String>(), new BinaryOperatorInformation<String, String, String>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), new int[] {1}, new int[] {2}, "join op"));
+ private static final JoinNode getJoinNode() {
+ return new JoinNode(new JoinOperatorBase<String, String, String, FlatJoinFunction<String, String, String>>(new DummyFlatJoinFunction<String>(), new BinaryOperatorInformation<String, String, String>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), new int[] {1}, new int[] {2}, "join op"));
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/CustomPartitioningGlobalOptimizationTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/CustomPartitioningGlobalOptimizationTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/CustomPartitioningGlobalOptimizationTest.java
new file mode 100644
index 0000000..34484d7
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/CustomPartitioningGlobalOptimizationTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.junit.Test;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.compiler.CompilerTestBase;
+import org.apache.flink.compiler.plan.DualInputPlanNode;
+import org.apache.flink.compiler.plan.OptimizedPlan;
+import org.apache.flink.compiler.plan.SingleInputPlanNode;
+import org.apache.flink.compiler.plan.SinkPlanNode;
+import org.apache.flink.compiler.testfunctions.IdentityGroupReducer;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+
+
+@SuppressWarnings({"serial", "unchecked"})
+public class CustomPartitioningGlobalOptimizationTest extends CompilerTestBase {
+
+ @Test
+ public void testJoinReduceCombination() {
+ try {
+ final Partitioner<Long> partitioner = new TestPartitionerLong();
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+ DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+ DataSet<Tuple3<Long, Long, Long>> joined = input1.join(input2)
+ .where(1).equalTo(0)
+ .projectFirst(0,1).projectSecond(2).types(Long.class, Long.class, Long.class)
+ .withPartitioner(partitioner);
+
+ joined.groupBy(1).withPartitioner(partitioner)
+ .reduceGroup(new IdentityGroupReducer<Tuple3<Long,Long,Long>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+
+ assertTrue("Reduce is not chained, property reuse does not happen",
+ reducer.getInput().getSource() instanceof DualInputPlanNode);
+
+ DualInputPlanNode join = (DualInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy());
+ assertEquals(partitioner, join.getInput1().getPartitioner());
+ assertEquals(partitioner, join.getInput2().getPartitioner());
+
+ assertEquals(ShipStrategyType.FORWARD, reducer.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/2000b45c/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/CustomPartitioningTest.java
----------------------------------------------------------------------
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/CustomPartitioningTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/CustomPartitioningTest.java
new file mode 100644
index 0000000..67505bf
--- /dev/null
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/custompartition/CustomPartitioningTest.java
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.compiler.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.compiler.CompilerTestBase;
+import org.apache.flink.compiler.plan.OptimizedPlan;
+import org.apache.flink.compiler.plan.SingleInputPlanNode;
+import org.apache.flink.compiler.plan.SinkPlanNode;
+import org.apache.flink.compiler.testfunctions.IdentityPartitionerMapper;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial", "unchecked"})
+public class CustomPartitioningTest extends CompilerTestBase {
+
+ @Test
+ public void testPartitionTuples() {
+ try {
+ final Partitioner<Integer> part = new TestPartitionerInt();
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer,Integer>(0, 0))
+ .rebalance();
+
+ data
+ .partitionCustom(part, 0)
+ .mapPartition(new IdentityPartitionerMapper<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource();
+ SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(parallelism, sink.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy());
+ assertEquals(parallelism, mapper.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy());
+ assertEquals(part, partitioner.getInput().getPartitioner());
+ assertEquals(parallelism, partitioner.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy());
+ assertEquals(parallelism, balancer.getDegreeOfParallelism());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionTuplesInvalidType() {
+ try {
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer,Integer>(0, 0))
+ .rebalance();
+
+ try {
+ data
+ .partitionCustom(new TestPartitionerLong(), 0);
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionPojo() {
+ try {
+ final Partitioner<Integer> part = new TestPartitionerInt();
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Pojo> data = env.fromElements(new Pojo())
+ .rebalance();
+
+ data
+ .partitionCustom(part, "a")
+ .mapPartition(new IdentityPartitionerMapper<Pojo>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource();
+ SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(parallelism, sink.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy());
+ assertEquals(parallelism, mapper.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy());
+ assertEquals(part, partitioner.getInput().getPartitioner());
+ assertEquals(parallelism, partitioner.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy());
+ assertEquals(parallelism, balancer.getDegreeOfParallelism());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionPojoInvalidType() {
+ try {
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Pojo> data = env.fromElements(new Pojo())
+ .rebalance();
+
+ try {
+ data
+ .partitionCustom(new TestPartitionerLong(), "a");
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionKeySelector() {
+ try {
+ final Partitioner<Integer> part = new TestPartitionerInt();
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Pojo> data = env.fromElements(new Pojo())
+ .rebalance();
+
+ data
+ .partitionCustom(part, new TestKeySelectorInt<Pojo>())
+ .mapPartition(new IdentityPartitionerMapper<Pojo>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode keyRemover = (SingleInputPlanNode) mapper.getInput().getSource();
+ SingleInputPlanNode partitioner = (SingleInputPlanNode) keyRemover.getInput().getSource();
+ SingleInputPlanNode keyExtractor = (SingleInputPlanNode) partitioner.getInput().getSource();
+ SingleInputPlanNode balancer = (SingleInputPlanNode) keyExtractor.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(parallelism, sink.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy());
+ assertEquals(parallelism, mapper.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, keyRemover.getInput().getShipStrategy());
+ assertEquals(parallelism, keyRemover.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy());
+ assertEquals(part, partitioner.getInput().getPartitioner());
+ assertEquals(parallelism, partitioner.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, keyExtractor.getInput().getShipStrategy());
+ assertEquals(parallelism, keyExtractor.getDegreeOfParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy());
+ assertEquals(parallelism, balancer.getDegreeOfParallelism());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionKeySelectorInvalidType() {
+ try {
+ final Partitioner<Integer> part = (Partitioner<Integer>) (Partitioner<?>) new TestPartitionerLong();
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Pojo> data = env.fromElements(new Pojo())
+ .rebalance();
+
+ try {
+ data
+ .partitionCustom(part, new TestKeySelectorInt<Pojo>());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ public static class Pojo {
+ public int a;
+ public int b;
+ }
+
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(Integer key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestKeySelectorInt<T> implements KeySelector<T, Integer> {
+ @Override
+ public Integer getKey(T value) {
+ return null;
+ }
+ }
+}