You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2015/03/20 11:06:43 UTC
[04/53] [abbrv] flink git commit: [optimizer] Rename optimizer
project to "flink-optimizer" (previously flink-compiler)
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/BinaryCustomPartitioningCompatibilityTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/BinaryCustomPartitioningCompatibilityTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/BinaryCustomPartitioningCompatibilityTest.java
new file mode 100644
index 0000000..0273659
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/BinaryCustomPartitioningCompatibilityTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
+import org.apache.flink.optimizer.testfunctions.DummyCoGroupFunction;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial","unchecked"})
+/**
+ * Optimizer tests checking that binary operators (join, coGroup) reuse an existing custom
+ * partitioning when both inputs were explicitly partitioned with the same {@link Partitioner}
+ * on their respective key fields.
+ *
+ * <p>Test methods declare {@code throws Exception} so that any unexpected failure is reported
+ * by JUnit with its full stack trace instead of only its message.
+ */
+public class BinaryCustomPartitioningCompatibilityTest extends CompilerTestBase {
+
+	/**
+	 * Both join inputs are custom-partitioned with the same partitioner on the join keys
+	 * (field 1 of the first input, field 0 of the second). The join must then receive both
+	 * inputs via FORWARD, while the two partitioning operators receive their inputs via
+	 * PARTITION_CUSTOM with exactly that partitioner. The plan must also translate to a job graph.
+	 */
+	@Test
+	public void testCompatiblePartitioningJoin() throws Exception {
+		final Partitioner<Long> partitioner = new Partitioner<Long>() {
+			@Override
+			public int partition(Long key, int numPartitions) {
+				return 0;
+			}
+		};
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+		DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+		input1.partitionCustom(partitioner, 1)
+			.join(input2.partitionCustom(partitioner, 0))
+			.where(1).equalTo(0)
+			.print();
+
+		Plan p = env.createProgramPlan();
+		OptimizedPlan op = compileNoStats(p);
+
+		SinkPlanNode sink = op.getDataSinks().iterator().next();
+		DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource();
+		SingleInputPlanNode partitioner1 = (SingleInputPlanNode) join.getInput1().getSource();
+		SingleInputPlanNode partitioner2 = (SingleInputPlanNode) join.getInput2().getSource();
+
+		// the existing custom partitioning is reused, so no re-shipping in front of the join
+		assertEquals(ShipStrategyType.FORWARD, join.getInput1().getShipStrategy());
+		assertEquals(ShipStrategyType.FORWARD, join.getInput2().getShipStrategy());
+
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner1.getInput().getShipStrategy());
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner2.getInput().getShipStrategy());
+		assertEquals(partitioner, partitioner1.getInput().getPartitioner());
+		assertEquals(partitioner, partitioner2.getInput().getPartitioner());
+
+		new JobGraphGenerator().compileJobGraph(op);
+	}
+
+	/**
+	 * Same scenario as {@link #testCompatiblePartitioningJoin()}, but with a coGroup as the
+	 * binary operator: both coGroup inputs must be FORWARD-shipped, and both partitioning
+	 * operators must use PARTITION_CUSTOM with the given partitioner.
+	 */
+	@Test
+	public void testCompatiblePartitioningCoGroup() throws Exception {
+		final Partitioner<Long> partitioner = new Partitioner<Long>() {
+			@Override
+			public int partition(Long key, int numPartitions) {
+				return 0;
+			}
+		};
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+		DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+		input1.partitionCustom(partitioner, 1)
+			.coGroup(input2.partitionCustom(partitioner, 0))
+			.where(1).equalTo(0)
+			.with(new DummyCoGroupFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Long>>())
+			.print();
+
+		Plan p = env.createProgramPlan();
+		OptimizedPlan op = compileNoStats(p);
+
+		SinkPlanNode sink = op.getDataSinks().iterator().next();
+		DualInputPlanNode coGroup = (DualInputPlanNode) sink.getInput().getSource();
+		SingleInputPlanNode partitioner1 = (SingleInputPlanNode) coGroup.getInput1().getSource();
+		SingleInputPlanNode partitioner2 = (SingleInputPlanNode) coGroup.getInput2().getSource();
+
+		// the existing custom partitioning is reused, so no re-shipping in front of the coGroup
+		assertEquals(ShipStrategyType.FORWARD, coGroup.getInput1().getShipStrategy());
+		assertEquals(ShipStrategyType.FORWARD, coGroup.getInput2().getShipStrategy());
+
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner1.getInput().getShipStrategy());
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner2.getInput().getShipStrategy());
+		assertEquals(partitioner, partitioner1.getInput().getPartitioner());
+		assertEquals(partitioner, partitioner2.getInput().getPartitioner());
+
+		new JobGraphGenerator().compileJobGraph(op);
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CoGroupCustomPartitioningTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CoGroupCustomPartitioningTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CoGroupCustomPartitioningTest.java
new file mode 100644
index 0000000..08f7388
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CoGroupCustomPartitioningTest.java
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.testfunctions.DummyCoGroupFunction;
+import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer;
+import org.apache.flink.optimizer.testfunctions.IdentityMapper;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial", "unchecked"})
+/**
+ * Optimizer tests for coGroup operations configured with a custom {@link Partitioner} via
+ * {@code withPartitioner(...)}. The keys are given as tuple positions, POJO field names, or
+ * key selectors. The tests also check that a partitioner whose key type does not match the
+ * coGroup keys is rejected with an {@link InvalidProgramException}.
+ *
+ * <p>Test methods declare {@code throws Exception} so that any unexpected failure is reported
+ * by JUnit with its full stack trace instead of only its message.
+ */
+public class CoGroupCustomPartitioningTest extends CompilerTestBase {
+
+	/** Tuple keys with a matching (Long) partitioner: both inputs are shipped via PARTITION_CUSTOM. */
+	@Test
+	public void testCoGroupWithTuples() throws Exception {
+		final Partitioner<Long> partitioner = new TestPartitionerLong();
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+		DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+		input1
+			.coGroup(input2)
+			.where(1).equalTo(0)
+			.withPartitioner(partitioner)
+			.with(new DummyCoGroupFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Long>>())
+			.print();
+
+		Plan p = env.createProgramPlan();
+		OptimizedPlan op = compileNoStats(p);
+
+		SinkPlanNode sink = op.getDataSinks().iterator().next();
+		DualInputPlanNode coGroup = (DualInputPlanNode) sink.getInput().getSource();
+
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, coGroup.getInput1().getShipStrategy());
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, coGroup.getInput2().getShipStrategy());
+		assertEquals(partitioner, coGroup.getInput1().getPartitioner());
+		assertEquals(partitioner, coGroup.getInput2().getPartitioner());
+	}
+
+	/** Tuple keys of type Long combined with an Integer partitioner must be rejected. */
+	@Test
+	public void testCoGroupWithTuplesWrongType() throws Exception {
+		final Partitioner<Integer> partitioner = new TestPartitionerInt();
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+		DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+		// only the withPartitioner(...) call may produce the expected exception
+		try {
+			input1
+				.coGroup(input2)
+				.where(1).equalTo(0)
+				.withPartitioner(partitioner);
+			fail("should throw an exception");
+		}
+		catch (InvalidProgramException e) {
+			// expected
+		}
+	}
+
+	/** POJO field keys (int) with a matching Integer partitioner: both inputs use PARTITION_CUSTOM. */
+	@Test
+	public void testCoGroupWithPojos() throws Exception {
+		final Partitioner<Integer> partitioner = new TestPartitionerInt();
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+		DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+
+		input1
+			.coGroup(input2)
+			.where("b").equalTo("a")
+			.withPartitioner(partitioner)
+			.with(new DummyCoGroupFunction<Pojo2, Pojo3>())
+			.print();
+
+		Plan p = env.createProgramPlan();
+		OptimizedPlan op = compileNoStats(p);
+
+		SinkPlanNode sink = op.getDataSinks().iterator().next();
+		DualInputPlanNode coGroup = (DualInputPlanNode) sink.getInput().getSource();
+
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, coGroup.getInput1().getShipStrategy());
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, coGroup.getInput2().getShipStrategy());
+		assertEquals(partitioner, coGroup.getInput1().getPartitioner());
+		assertEquals(partitioner, coGroup.getInput2().getPartitioner());
+	}
+
+	/** POJO int field keys combined with a Long partitioner must be rejected. */
+	@Test
+	public void testCoGroupWithPojosWrongType() throws Exception {
+		final Partitioner<Long> partitioner = new TestPartitionerLong();
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+		DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+
+		try {
+			input1
+				.coGroup(input2)
+				.where("a").equalTo("b")
+				.withPartitioner(partitioner);
+
+			fail("should throw an exception");
+		}
+		catch (InvalidProgramException e) {
+			// expected
+		}
+	}
+
+	/** Integer key selectors with a matching Integer partitioner: both inputs use PARTITION_CUSTOM. */
+	@Test
+	public void testCoGroupWithKeySelectors() throws Exception {
+		final Partitioner<Integer> partitioner = new TestPartitionerInt();
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+		DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+
+		input1
+			.coGroup(input2)
+			.where(new Pojo2KeySelector()).equalTo(new Pojo3KeySelector())
+			.withPartitioner(partitioner)
+			.with(new DummyCoGroupFunction<Pojo2, Pojo3>())
+			.print();
+
+		Plan p = env.createProgramPlan();
+		OptimizedPlan op = compileNoStats(p);
+
+		SinkPlanNode sink = op.getDataSinks().iterator().next();
+		DualInputPlanNode coGroup = (DualInputPlanNode) sink.getInput().getSource();
+
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, coGroup.getInput1().getShipStrategy());
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, coGroup.getInput2().getShipStrategy());
+		assertEquals(partitioner, coGroup.getInput1().getPartitioner());
+		assertEquals(partitioner, coGroup.getInput2().getPartitioner());
+	}
+
+	/** Integer key selectors combined with a Long partitioner must be rejected. */
+	@Test
+	public void testCoGroupWithKeySelectorsWrongType() throws Exception {
+		final Partitioner<Long> partitioner = new TestPartitionerLong();
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+		DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+
+		try {
+			input1
+				.coGroup(input2)
+				.where(new Pojo2KeySelector()).equalTo(new Pojo3KeySelector())
+				.withPartitioner(partitioner);
+
+			fail("should throw an exception");
+		}
+		catch (InvalidProgramException e) {
+			// expected
+		}
+	}
+
+	/**
+	 * One coGroup input is derived from a custom-partitioned data set via distinct/groupBy.
+	 * The other input is the custom-partitioned data set itself.
+	 * The custom partitioning must not be treated as compatible with the hash partitioning
+	 * needed for the coGroup.
+	 * The first (grouped) input must be shipped with PARTITION_HASH. The second may be
+	 * PARTITION_HASH or FORWARD (presumably FORWARD when an upstream hash partitioning is
+	 * reused — NOTE(review): confirm against the optimizer).
+	 */
+	@Test
+	public void testIncompatibleHashAndCustomPartitioning() throws Exception {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Tuple3<Long, Long, Long>> input = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+		DataSet<Tuple3<Long, Long, Long>> partitioned = input
+			.partitionCustom(new Partitioner<Long>() {
+				@Override
+				public int partition(Long key, int numPartitions) { return 0; }
+			}, 0)
+			.map(new IdentityMapper<Tuple3<Long,Long,Long>>()).withForwardedFields("0", "1", "2");
+
+		DataSet<Tuple3<Long, Long, Long>> grouped = partitioned
+			.distinct(0, 1)
+			.groupBy(1)
+			.sortGroup(0, Order.ASCENDING)
+			.reduceGroup(new IdentityGroupReducer<Tuple3<Long,Long,Long>>()).withForwardedFields("0", "1");
+
+		grouped
+			.coGroup(partitioned).where(0).equalTo(0)
+			.with(new DummyCoGroupFunction<Tuple3<Long,Long,Long>, Tuple3<Long,Long,Long>>())
+			.print();
+
+		Plan p = env.createProgramPlan();
+		OptimizedPlan op = compileNoStats(p);
+
+		SinkPlanNode sink = op.getDataSinks().iterator().next();
+		DualInputPlanNode coGroup = (DualInputPlanNode) sink.getInput().getSource();
+
+		assertEquals(ShipStrategyType.PARTITION_HASH, coGroup.getInput1().getShipStrategy());
+		assertTrue(coGroup.getInput2().getShipStrategy() == ShipStrategyType.PARTITION_HASH ||
+				coGroup.getInput2().getShipStrategy() == ShipStrategyType.FORWARD);
+	}
+
+	// --------------------------------------------------------------------------------------------
+
+	/** Partitioner for Integer keys; always routes to partition 0 (only the plan is inspected). */
+	private static class TestPartitionerInt implements Partitioner<Integer> {
+		@Override
+		public int partition(Integer key, int numPartitions) {
+			return 0;
+		}
+	}
+
+	/** Partitioner for Long keys; always routes to partition 0 (only the plan is inspected). */
+	private static class TestPartitionerLong implements Partitioner<Long> {
+		@Override
+		public int partition(Long key, int numPartitions) {
+			return 0;
+		}
+	}
+
+	/** POJO with two int fields, used as the first coGroup input. */
+	public static class Pojo2 {
+		public int a;
+		public int b;
+	}
+
+	/** POJO with three int fields, used as the second coGroup input. */
+	public static class Pojo3 {
+		public int a;
+		public int b;
+		public int c;
+	}
+
+	/** Extracts {@link Pojo2#a} as an Integer key. */
+	private static class Pojo2KeySelector implements KeySelector<Pojo2, Integer> {
+		@Override
+		public Integer getKey(Pojo2 value) {
+			return value.a;
+		}
+	}
+
+	/** Extracts {@link Pojo3#b} as an Integer key. */
+	private static class Pojo3KeySelector implements KeySelector<Pojo3, Integer> {
+		@Override
+		public Integer getKey(Pojo3 value) {
+			return value.b;
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningGlobalOptimizationTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningGlobalOptimizationTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningGlobalOptimizationTest.java
new file mode 100644
index 0000000..9fd676f
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningGlobalOptimizationTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.custompartition;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial", "unchecked"})
+/**
+ * Optimizer test checking that a custom partitioning established by one operator is reused
+ * by a downstream operator that requests the same partitioner on the same key.
+ *
+ * <p>Test methods declare {@code throws Exception} so that any unexpected failure is reported
+ * by JUnit with its full stack trace instead of only its message.
+ */
+public class CustomPartitioningGlobalOptimizationTest extends CompilerTestBase {
+
+	/**
+	 * A join custom-partitioned on its keys is followed by a group-reduce on field 1 of the
+	 * join result, using the same partitioner.
+	 * The join output's field 1 is projected from the first input's field 1, which is that
+	 * input's join key.
+	 * The reducer must consume the join output directly via FORWARD. Both join inputs must be
+	 * shipped via PARTITION_CUSTOM with the given partitioner.
+	 */
+	@Test
+	public void testJoinReduceCombination() throws Exception {
+		final Partitioner<Long> partitioner = new TestPartitionerLong();
+
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+		DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+
+		DataSet<Tuple3<Long, Long, Long>> joined = input1.join(input2)
+			.where(1).equalTo(0)
+			.projectFirst(0, 1)
+			.<Tuple3<Long, Long, Long>>projectSecond(2)
+			.withPartitioner(partitioner);
+
+		joined.groupBy(1).withPartitioner(partitioner)
+			.reduceGroup(new IdentityGroupReducer<Tuple3<Long,Long,Long>>())
+			.print();
+
+		Plan p = env.createProgramPlan();
+		OptimizedPlan op = compileNoStats(p);
+
+		SinkPlanNode sink = op.getDataSinks().iterator().next();
+		SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+
+		assertTrue("Reduce is not chained, property reuse does not happen",
+				reducer.getInput().getSource() instanceof DualInputPlanNode);
+
+		DualInputPlanNode join = (DualInputPlanNode) reducer.getInput().getSource();
+
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy());
+		assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy());
+		assertEquals(partitioner, join.getInput1().getPartitioner());
+		assertEquals(partitioner, join.getInput2().getPartitioner());
+
+		// the join's custom partitioning satisfies the reducer, so no re-partitioning happens
+		assertEquals(ShipStrategyType.FORWARD, reducer.getInput().getShipStrategy());
+	}
+
+	// --------------------------------------------------------------------------------------------
+
+	/** Partitioner for Long keys; always routes to partition 0 (only the plan is inspected). */
+	private static class TestPartitionerLong implements Partitioner<Long> {
+		@Override
+		public int partition(Long key, int numPartitions) {
+			return 0;
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningTest.java
new file mode 100644
index 0000000..d397ea2
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningTest.java
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.testfunctions.IdentityPartitionerMapper;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial", "unchecked"})
+public class CustomPartitioningTest extends CompilerTestBase {
+
+ @Test
+ public void testPartitionTuples() {
+ try {
+ final Partitioner<Integer> part = new TestPartitionerInt();
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer,Integer>(0, 0))
+ .rebalance();
+
+ data
+ .partitionCustom(part, 0)
+ .mapPartition(new IdentityPartitionerMapper<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource();
+ SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(parallelism, sink.getParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy());
+ assertEquals(parallelism, mapper.getParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy());
+ assertEquals(part, partitioner.getInput().getPartitioner());
+ assertEquals(parallelism, partitioner.getParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy());
+ assertEquals(parallelism, balancer.getParallelism());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionTuplesInvalidType() {
+ try {
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer,Integer>(0, 0))
+ .rebalance();
+
+ try {
+ data
+ .partitionCustom(new TestPartitionerLong(), 0);
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionPojo() {
+ try {
+ final Partitioner<Integer> part = new TestPartitionerInt();
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Pojo> data = env.fromElements(new Pojo())
+ .rebalance();
+
+ data
+ .partitionCustom(part, "a")
+ .mapPartition(new IdentityPartitionerMapper<Pojo>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource();
+ SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(parallelism, sink.getParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy());
+ assertEquals(parallelism, mapper.getParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy());
+ assertEquals(part, partitioner.getInput().getPartitioner());
+ assertEquals(parallelism, partitioner.getParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy());
+ assertEquals(parallelism, balancer.getParallelism());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionPojoInvalidType() {
+ try {
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Pojo> data = env.fromElements(new Pojo())
+ .rebalance();
+
+ try {
+ data
+ .partitionCustom(new TestPartitionerLong(), "a");
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionKeySelector() {
+ try {
+ final Partitioner<Integer> part = new TestPartitionerInt();
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Pojo> data = env.fromElements(new Pojo())
+ .rebalance();
+
+ data
+ .partitionCustom(part, new TestKeySelectorInt<Pojo>())
+ .mapPartition(new IdentityPartitionerMapper<Pojo>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode keyRemover = (SingleInputPlanNode) mapper.getInput().getSource();
+ SingleInputPlanNode partitioner = (SingleInputPlanNode) keyRemover.getInput().getSource();
+ SingleInputPlanNode keyExtractor = (SingleInputPlanNode) partitioner.getInput().getSource();
+ SingleInputPlanNode balancer = (SingleInputPlanNode) keyExtractor.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(parallelism, sink.getParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy());
+ assertEquals(parallelism, mapper.getParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, keyRemover.getInput().getShipStrategy());
+ assertEquals(parallelism, keyRemover.getParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy());
+ assertEquals(part, partitioner.getInput().getPartitioner());
+ assertEquals(parallelism, partitioner.getParallelism());
+
+ assertEquals(ShipStrategyType.FORWARD, keyExtractor.getInput().getShipStrategy());
+ assertEquals(parallelism, keyExtractor.getParallelism());
+
+ assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy());
+ assertEquals(parallelism, balancer.getParallelism());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPartitionKeySelectorInvalidType() {
+ try {
+ final Partitioner<Integer> part = (Partitioner<Integer>) (Partitioner<?>) new TestPartitionerLong();
+ final int parallelism = 4;
+
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setDegreeOfParallelism(parallelism);
+
+ DataSet<Pojo> data = env.fromElements(new Pojo())
+ .rebalance();
+
+ try {
+ data
+ .partitionCustom(part, new TestKeySelectorInt<Pojo>());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ public static class Pojo {
+ public int a;
+ public int b;
+ }
+
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(Integer key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestKeySelectorInt<T> implements KeySelector<T, Integer> {
+ @Override
+ public Integer getKey(T value) {
+ return null;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingKeySelectorTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingKeySelectorTranslationTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingKeySelectorTranslationTest.java
new file mode 100644
index 0000000..360487b
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingKeySelectorTranslationTest.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.testfunctions.DummyReducer;
+import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+ /**
+  * Translation tests for custom partitioners on groupings whose keys come from a {@link KeySelector}.
+  * The positive tests compile a small plan and check that only the grouping operator's input is
+  * shipped with {@link ShipStrategyType#PARTITION_CUSTOM}. The negative tests check that
+  * partitioners whose key type does not match the grouping key, or composite keys, are rejected.
+  */
+ @SuppressWarnings({"serial", "unchecked"})
+ public class GroupingKeySelectorTranslationTest extends CompilerTestBase {
+
+ /**
+  * Reduce on a KeySelector grouping. The plan is sink <- key-removing mapper <- reducer <- combiner,
+  * and only the reducer's input is custom-partitioned.
+  */
+ @Test
+ public void testCustomPartitioningKeySelectorReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(new TestKeySelector<Tuple2<Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt())
+ .reduce(new DummyReducer<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode keyRemovingMapper = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) keyRemovingMapper.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, keyRemovingMapper.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ /**
+  * Group-reduce on a KeySelector grouping. The plan is sink <- reducer <- combiner, and only the
+  * reducer's input is custom-partitioned.
+  */
+ @Test
+ public void testCustomPartitioningKeySelectorGroupReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(new TestKeySelector<Tuple2<Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt())
+ .reduceGroup(new IdentityGroupReducer<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ /**
+  * Group-reduce on a KeySelector grouping that is also group-sorted with a KeySelector. The plan
+  * is sink <- reducer <- combiner, and only the reducer's input is custom-partitioned.
+  */
+ @Test
+ public void testCustomPartitioningKeySelectorGroupReduceSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt())
+ .sortGroup(new TestKeySelector<Tuple3<Integer, Integer, Integer>>(), Order.ASCENDING)
+ .reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ /** A Long partitioner must be rejected on an Integer-keyed KeySelector grouping. */
+ @Test
+ public void testCustomPartitioningKeySelectorInvalidType() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data
+ .groupBy(new TestKeySelector<Tuple2<Integer,Integer>>())
+ .withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected: partitioner key type (Long) does not match the KeySelector key type (Integer)
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ /**
+  * A Long partitioner must be rejected on an Integer-keyed KeySelector grouping that is also
+  * group-sorted.
+  */
+ @Test
+ public void testCustomPartitioningKeySelectorInvalidTypeSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ // Group-sort with a KeySelector, the same combination that
+ // testCustomPartitioningKeySelectorGroupReduceSorted shows is accepted. A field-index
+ // sort (sortGroup(1, ...)) on a KeySelector grouping is a different misuse that may be
+ // rejected by itself. That would let this test pass without ever reaching the
+ // partitioner type check it is meant to cover.
+ data
+ .groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>())
+ .sortGroup(new TestKeySelector<Tuple3<Integer, Integer, Integer>>(), Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected: partitioner key type (Long) does not match the KeySelector key type (Integer)
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ /** An Integer partitioner must be rejected when the KeySelector produces a composite Tuple2 key. */
+ @Test
+ public void testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data
+ .groupBy(new TestBinaryKeySelector<Tuple3<Integer,Integer,Integer>>())
+ .withPartitioner(new TestPartitionerInt());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {
+ // expected: a composite (Tuple2) key cannot be fed to an Integer partitioner
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ /** Partitioner over Integer keys that assigns every key to partition 0. */
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(Integer key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ /** Partitioner over Long keys that assigns every key to partition 0. */
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ /** Selects tuple field 0 as an Integer key. */
+ private static class TestKeySelector<T extends Tuple> implements KeySelector<T, Integer> {
+ @Override
+ public Integer getKey(T value) {
+ return value.getField(0);
+ }
+ }
+
+ /** Selects tuple fields 0 and 1 as a composite Tuple2 key. */
+ private static class TestBinaryKeySelector<T extends Tuple> implements KeySelector<T, Tuple2<Integer, Integer>> {
+ @Override
+ public Tuple2<Integer, Integer> getKey(T value) {
+ return new Tuple2<Integer, Integer>(value.<Integer>getField(0), value.<Integer>getField(1));
+ }
+ }
+ }
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingPojoTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingPojoTranslationTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingPojoTranslationTest.java
new file mode 100644
index 0000000..8cd4809
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingPojoTranslationTest.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.testfunctions.DummyReducer;
+import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+ /**
+  * Translation tests for custom partitioners on POJO groupings keyed by field expressions
+  * (e.g. {@code groupBy("a")}). The test method names mention "Tuple", but every input here is a POJO.
+  */
+ @SuppressWarnings("serial")
+ public class GroupingPojoTranslationTest extends CompilerTestBase {
+
+ @Test
+ public void testCustomPartitioningTupleReduce() {
+ try {
+ ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Pojo2> input = environment.fromElements(new Pojo2()).rebalance().setParallelism(4);
+
+ input.groupBy("a")
+ .withPartitioner(new TestPartitionerInt())
+ .reduce(new DummyReducer<Pojo2>())
+ .print();
+
+ Plan program = environment.createProgramPlan();
+ assertReducerInputIsCustomPartitioned(compileNoStats(program));
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduce() {
+ try {
+ ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Pojo2> input = environment.fromElements(new Pojo2()).rebalance().setParallelism(4);
+
+ input.groupBy("a")
+ .withPartitioner(new TestPartitionerInt())
+ .reduceGroup(new IdentityGroupReducer<Pojo2>())
+ .print();
+
+ Plan program = environment.createProgramPlan();
+ assertReducerInputIsCustomPartitioned(compileNoStats(program));
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduceSorted() {
+ try {
+ ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Pojo3> input = environment.fromElements(new Pojo3()).rebalance().setParallelism(4);
+
+ input.groupBy("a")
+ .withPartitioner(new TestPartitionerInt())
+ .sortGroup("b", Order.ASCENDING)
+ .reduceGroup(new IdentityGroupReducer<Pojo3>())
+ .print();
+
+ Plan program = environment.createProgramPlan();
+ assertReducerInputIsCustomPartitioned(compileNoStats(program));
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduceSorted2() {
+ try {
+ ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Pojo4> input = environment.fromElements(new Pojo4()).rebalance().setParallelism(4);
+
+ input.groupBy("a")
+ .withPartitioner(new TestPartitionerInt())
+ .sortGroup("b", Order.ASCENDING)
+ .sortGroup("c", Order.DESCENDING)
+ .reduceGroup(new IdentityGroupReducer<Pojo4>())
+ .print();
+
+ Plan program = environment.createProgramPlan();
+ assertReducerInputIsCustomPartitioned(compileNoStats(program));
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleInvalidType() {
+ try {
+ ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Pojo2> input = environment.fromElements(new Pojo2()).rebalance().setParallelism(4);
+
+ try {
+ input.groupBy("a").withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException expected) {
+ // expected: Long partitioner vs. int key field "a"
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleInvalidTypeSorted() {
+ try {
+ ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Pojo3> input = environment.fromElements(new Pojo3()).rebalance().setParallelism(4);
+
+ try {
+ input.groupBy("a")
+ .sortGroup("b", Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException expected) {
+ // expected: Long partitioner vs. int key field "a"
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Pojo2> input = environment.fromElements(new Pojo2()).rebalance().setParallelism(4);
+
+ try {
+ input.groupBy("a", "b")
+ .withPartitioner(new TestPartitionerInt());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException expected) {
+ // expected: a two-field key cannot be fed to a single-key partitioner
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ /**
+  * Checks that the plan has the shape sink <- reducer <- combiner. Only the reducer's input may use
+  * the custom partitioner; the sink's and the combiner's inputs must be forwarded.
+  */
+ private static void assertReducerInputIsCustomPartitioned(OptimizedPlan plan) {
+ SinkPlanNode sinkNode = plan.getDataSinks().iterator().next();
+ SingleInputPlanNode reduceNode = (SingleInputPlanNode) sinkNode.getInput().getSource();
+ SingleInputPlanNode combineNode = (SingleInputPlanNode) reduceNode.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sinkNode.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reduceNode.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combineNode.getInput().getShipStrategy());
+ }
+
+ /** POJO with two int fields. */
+ public static class Pojo2 {
+ public int a;
+ public int b;
+ }
+
+ /** POJO with three int fields. */
+ public static class Pojo3 {
+ public int a;
+ public int b;
+ public int c;
+ }
+
+ /** POJO with four int fields. */
+ public static class Pojo4 {
+ public int a;
+ public int b;
+ public int c;
+ public int d;
+ }
+
+ /** Partitioner over Integer keys that assigns every key to partition 0. */
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(final Integer key, final int numPartitions) {
+ return 0;
+ }
+ }
+
+ /** Partitioner over Long keys that assigns every key to partition 0. */
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(final Long key, final int numPartitions) {
+ return 0;
+ }
+ }
+ }
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingTupleTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingTupleTranslationTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingTupleTranslationTest.java
new file mode 100644
index 0000000..779b8e5
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingTupleTranslationTest.java
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.testfunctions.DummyReducer;
+import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+@SuppressWarnings({"serial", "unchecked"})
+public class GroupingTupleTranslationTest extends CompilerTestBase {
+
+ @Test
+ public void testCustomPartitioningTupleAgg() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .sum(1)
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .reduce(new DummyReducer<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduce() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .reduceGroup(new IdentityGroupReducer<Tuple2<Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduceSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleGroupReduceSorted2() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple4<Integer,Integer,Integer, Integer>> data = env.fromElements(new Tuple4<Integer,Integer,Integer,Integer>(0, 0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ data.groupBy(0).withPartitioner(new TestPartitionerInt())
+ .sortGroup(1, Order.ASCENDING)
+ .sortGroup(2, Order.DESCENDING)
+ .reduceGroup(new IdentityGroupReducer<Tuple4<Integer,Integer,Integer,Integer>>())
+ .print();
+
+ Plan p = env.createProgramPlan();
+ OptimizedPlan op = compileNoStats(p);
+
+ SinkPlanNode sink = op.getDataSinks().iterator().next();
+ SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource();
+ SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource();
+
+ assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy());
+ assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleInvalidType() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy(0).withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleInvalidTypeSorted() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy(0)
+ .sortGroup(1, Order.ASCENDING)
+ .withPartitioner(new TestPartitionerLong());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testCustomPartitioningTupleRejectCompositeKey() {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0))
+ .rebalance().setParallelism(4);
+
+ try {
+ data.groupBy(0, 1)
+ .withPartitioner(new TestPartitionerInt());
+ fail("Should throw an exception");
+ }
+ catch (InvalidProgramException e) {}
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+
+ private static class TestPartitionerInt implements Partitioner<Integer> {
+ @Override
+ public int partition(Integer key, int numPartitions) {
+ return 0;
+ }
+ }
+
+ private static class TestPartitionerLong implements Partitioner<Long> {
+ @Override
+ public int partition(Long key, int numPartitions) {
+ return 0;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/JoinCustomPartitioningTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/JoinCustomPartitioningTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/JoinCustomPartitioningTest.java
new file mode 100644
index 0000000..eae40cf
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/JoinCustomPartitioningTest.java
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.custompartition;
+
+import static org.junit.Assert.*;
+
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.operators.Order;
+import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.testfunctions.DummyFlatJoinFunction;
+import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer;
+import org.apache.flink.optimizer.testfunctions.IdentityMapper;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.junit.Test;
+
+ /**
+  * Compiler tests for joins that use a user-supplied {@link Partitioner} via {@code withPartitioner(...)}.
+  *
+  * <p>The positive tests check that both join inputs are shipped with
+  * {@link ShipStrategyType#PARTITION_CUSTOM} using exactly the supplied partitioner. The "WrongType"
+  * tests check that a partitioner whose key type differs from the join key type is rejected with an
+  * {@link InvalidProgramException}.
+  */
+ @SuppressWarnings({"serial", "unchecked"})
+ public class JoinCustomPartitioningTest extends CompilerTestBase {
+ 
+ 	/** Join on tuple field positions with a {@code Long} partitioner matching the {@code Long} keys. */
+ 	@Test
+ 	public void testJoinWithTuples() {
+ 		try {
+ 			final Partitioner<Long> partitioner = new TestPartitionerLong();
+ 
+ 			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ 
+ 			DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+ 			DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+ 
+ 			input1
+ 				.join(input2, JoinHint.REPARTITION_HASH_FIRST).where(1).equalTo(0).withPartitioner(partitioner)
+ 				.print();
+ 
+ 			assertJoinInputsCustomPartitioned(env, partitioner);
+ 		}
+ 		catch (Exception e) {
+ 			e.printStackTrace();
+ 			fail(e.getMessage());
+ 		}
+ 	}
+ 
+ 	/** An {@code Integer} partitioner must be rejected for a join on {@code Long} tuple fields. */
+ 	@Test
+ 	public void testJoinWithTuplesWrongType() {
+ 		try {
+ 			final Partitioner<Integer> partitioner = new TestPartitionerInt();
+ 
+ 			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ 
+ 			DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
+ 			DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+ 
+ 			try {
+ 				input1
+ 					.join(input2, JoinHint.REPARTITION_HASH_FIRST).where(1).equalTo(0)
+ 					.withPartitioner(partitioner);
+ 
+ 				fail("should throw an exception");
+ 			}
+ 			catch (InvalidProgramException e) {
+ 				// expected: Integer partitioner vs. Long join key
+ 			}
+ 		}
+ 		catch (Exception e) {
+ 			e.printStackTrace();
+ 			fail(e.getMessage());
+ 		}
+ 	}
+ 
+ 	/** Join on POJO field names with an {@code Integer} partitioner matching the {@code int} fields. */
+ 	@Test
+ 	public void testJoinWithPojos() {
+ 		try {
+ 			final Partitioner<Integer> partitioner = new TestPartitionerInt();
+ 
+ 			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ 
+ 			DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+ 			DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+ 
+ 			input1
+ 				.join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ 				.where("b").equalTo("a").withPartitioner(partitioner)
+ 				.print();
+ 
+ 			assertJoinInputsCustomPartitioned(env, partitioner);
+ 		}
+ 		catch (Exception e) {
+ 			e.printStackTrace();
+ 			fail(e.getMessage());
+ 		}
+ 	}
+ 
+ 	/** A {@code Long} partitioner must be rejected for a join on {@code int} POJO fields. */
+ 	@Test
+ 	public void testJoinWithPojosWrongType() {
+ 		try {
+ 			final Partitioner<Long> partitioner = new TestPartitionerLong();
+ 
+ 			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ 
+ 			DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+ 			DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+ 
+ 			try {
+ 				input1
+ 					.join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ 					.where("a").equalTo("b")
+ 					.withPartitioner(partitioner);
+ 
+ 				fail("should throw an exception");
+ 			}
+ 			catch (InvalidProgramException e) {
+ 				// expected: Long partitioner vs. int join key
+ 			}
+ 		}
+ 		catch (Exception e) {
+ 			e.printStackTrace();
+ 			fail(e.getMessage());
+ 		}
+ 	}
+ 
+ 	/** Join on key selectors with an {@code Integer} partitioner matching the {@code Integer} keys. */
+ 	@Test
+ 	public void testJoinWithKeySelectors() {
+ 		try {
+ 			final Partitioner<Integer> partitioner = new TestPartitionerInt();
+ 
+ 			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ 
+ 			DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+ 			DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+ 
+ 			input1
+ 				.join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ 				.where(new Pojo2KeySelector())
+ 				.equalTo(new Pojo3KeySelector())
+ 				.withPartitioner(partitioner)
+ 				.print();
+ 
+ 			assertJoinInputsCustomPartitioned(env, partitioner);
+ 		}
+ 		catch (Exception e) {
+ 			e.printStackTrace();
+ 			fail(e.getMessage());
+ 		}
+ 	}
+ 
+ 	/** A {@code Long} partitioner must be rejected for a join on {@code Integer} key selectors. */
+ 	@Test
+ 	public void testJoinWithKeySelectorsWrongType() {
+ 		try {
+ 			final Partitioner<Long> partitioner = new TestPartitionerLong();
+ 
+ 			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ 
+ 			DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
+ 			DataSet<Pojo3> input2 = env.fromElements(new Pojo3());
+ 
+ 			try {
+ 				input1
+ 					.join(input2, JoinHint.REPARTITION_HASH_FIRST)
+ 					.where(new Pojo2KeySelector())
+ 					.equalTo(new Pojo3KeySelector())
+ 					.withPartitioner(partitioner);
+ 
+ 				fail("should throw an exception");
+ 			}
+ 			catch (InvalidProgramException e) {
+ 				// expected: Long partitioner vs. Integer key selector
+ 			}
+ 		}
+ 		catch (Exception e) {
+ 			e.printStackTrace();
+ 			fail(e.getMessage());
+ 		}
+ 	}
+ 
+ 	/**
+ 	 * Joins a distinct/grouped/reduced data set with a custom-partitioned one using a repartitioning
+ 	 * hash join. The first join input must be hash-partitioned; the second input must be either
+ 	 * hash-partitioned or forwarded.
+ 	 */
+ 	@Test
+ 	public void testIncompatibleHashAndCustomPartitioning() {
+ 		try {
+ 			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ 
+ 			DataSet<Tuple3<Long, Long, Long>> input = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));
+ 
+ 			DataSet<Tuple3<Long, Long, Long>> partitioned = input
+ 				.partitionCustom(new Partitioner<Long>() {
+ 					@Override
+ 					public int partition(Long key, int numPartitions) { return 0; }
+ 				}, 0)
+ 				.map(new IdentityMapper<Tuple3<Long,Long,Long>>()).withForwardedFields("0", "1", "2");
+ 
+ 			DataSet<Tuple3<Long, Long, Long>> grouped = partitioned
+ 				.distinct(0, 1)
+ 				.groupBy(1)
+ 				.sortGroup(0, Order.ASCENDING)
+ 				.reduceGroup(new IdentityGroupReducer<Tuple3<Long,Long,Long>>()).withForwardedFields("0", "1");
+ 
+ 			grouped
+ 				.join(partitioned, JoinHint.REPARTITION_HASH_FIRST).where(0).equalTo(0)
+ 				.with(new DummyFlatJoinFunction<Tuple3<Long,Long,Long>>())
+ 				.print();
+ 
+ 			Plan p = env.createProgramPlan();
+ 			OptimizedPlan op = compileNoStats(p);
+ 
+ 			SinkPlanNode sink = op.getDataSinks().iterator().next();
+ 			// the node feeding the sink is the join, not a co-group
+ 			DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource();
+ 
+ 			assertEquals(ShipStrategyType.PARTITION_HASH, join.getInput1().getShipStrategy());
+ 			assertTrue(join.getInput2().getShipStrategy() == ShipStrategyType.PARTITION_HASH ||
+ 					join.getInput2().getShipStrategy() == ShipStrategyType.FORWARD);
+ 		}
+ 		catch (Exception e) {
+ 			e.printStackTrace();
+ 			fail(e.getMessage());
+ 		}
+ 	}
+ 
+ 	// --------------------------------------------------------------------------------------------
+ 
+ 	/**
+ 	 * Creates the program plan of {@code env}, compiles it without statistics, and asserts that
+ 	 * both inputs of the join feeding the first data sink are shipped with
+ 	 * {@link ShipStrategyType#PARTITION_CUSTOM} using {@code partitioner}.
+ 	 *
+ 	 * @param env the environment on which the join program was defined
+ 	 * @param partitioner the partitioner expected on both join inputs
+ 	 * @throws Exception propagated from plan creation or compilation; the calling test reports it
+ 	 */
+ 	private void assertJoinInputsCustomPartitioned(ExecutionEnvironment env, Partitioner<?> partitioner) throws Exception {
+ 		Plan plan = env.createProgramPlan();
+ 		OptimizedPlan optimizedPlan = compileNoStats(plan);
+ 
+ 		SinkPlanNode sink = optimizedPlan.getDataSinks().iterator().next();
+ 		DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource();
+ 
+ 		assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy());
+ 		assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy());
+ 		assertEquals(partitioner, join.getInput1().getPartitioner());
+ 		assertEquals(partitioner, join.getInput2().getPartitioner());
+ 	}
+ 
+ 	/** Test partitioner for {@code Integer} keys; always returns partition 0. */
+ 	private static class TestPartitionerInt implements Partitioner<Integer> {
+ 		@Override
+ 		public int partition(Integer key, int numPartitions) {
+ 			return 0;
+ 		}
+ 	}
+ 
+ 	/** Test partitioner for {@code Long} keys; always returns partition 0. */
+ 	private static class TestPartitionerLong implements Partitioner<Long> {
+ 		@Override
+ 		public int partition(Long key, int numPartitions) {
+ 			return 0;
+ 		}
+ 	}
+ 
+ 	/** POJO with two {@code int} fields, used as the first join input. */
+ 	public static class Pojo2 {
+ 		public int a;
+ 		public int b;
+ 	}
+ 
+ 	/** POJO with three {@code int} fields, used as the second join input. */
+ 	public static class Pojo3 {
+ 		public int a;
+ 		public int b;
+ 		public int c;
+ 	}
+ 
+ 	/** Extracts field {@code a} of a {@link Pojo2} as the join key. */
+ 	private static class Pojo2KeySelector implements KeySelector<Pojo2, Integer> {
+ 		@Override
+ 		public Integer getKey(Pojo2 value) {
+ 			return value.a;
+ 		}
+ 	}
+ 
+ 	/** Extracts field {@code b} of a {@link Pojo3} as the join key. */
+ 	private static class Pojo3KeySelector implements KeySelector<Pojo3, Integer> {
+ 		@Override
+ 		public Integer getKey(Pojo3 value) {
+ 			return value.b;
+ 		}
+ 	}
+ }
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/DataExchangeModeClosedBranchingTest.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/DataExchangeModeClosedBranchingTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/DataExchangeModeClosedBranchingTest.java
new file mode 100644
index 0000000..cb4bd78
--- /dev/null
+++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/DataExchangeModeClosedBranchingTest.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.dataexchange;
+
+import org.apache.flink.api.common.ExecutionMode;
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.optimizer.CompilerTestBase;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.testfunctions.DummyCoGroupFunction;
+import org.apache.flink.optimizer.testfunctions.DummyFlatJoinFunction;
+import org.apache.flink.optimizer.testfunctions.IdentityFlatMapper;
+import org.apache.flink.optimizer.testfunctions.SelectOneReducer;
+import org.apache.flink.optimizer.testfunctions.Top1GroupReducer;
+import org.apache.flink.runtime.io.network.DataExchangeMode;
+import org.junit.Test;
+
+import java.util.Collection;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * This test checks the correct assignment of the DataExchangeMode to
+ * connections for programs that branch, and re-join those branches.
+ *
+ * <pre>
+ * /-> (sink)
+ * /
+ * /-> (reduce) -+ /-> (flatmap) -> (sink)
+ * / \ /
+ * (source) -> (map) - (join) -+-----\
+ * \ / \
+ * \-> (filter) -+ \
+ * \ (co group) -> (sink)
+ * \ /
+ * \-> (reduce) - /
+ * </pre>
+ */
+@SuppressWarnings("serial")
+public class DataExchangeModeClosedBranchingTest extends CompilerTestBase {
+
+ @Test
+ public void testPipelinedForced() {
+ // PIPELINED_FORCED should result in pipelining all the way
+ verifyBranchingJoiningPlan(ExecutionMode.PIPELINED_FORCED,
+ DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
+ DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
+ DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
+ DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
+ DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
+ DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
+ DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED);
+ }
+
+ @Test
+ public void testPipelined() {
+ // PIPELINED should result in pipelining all the way
+ verifyBranchingJoiningPlan(ExecutionMode.PIPELINED,
+ DataExchangeMode.PIPELINED, // to map
+ DataExchangeMode.PIPELINED, // to combiner connections are pipelined
+ DataExchangeMode.BATCH, // to reduce
+ DataExchangeMode.BATCH, // to filter
+ DataExchangeMode.PIPELINED, // to sink after reduce
+ DataExchangeMode.PIPELINED, // to join (first input)
+ DataExchangeMode.BATCH, // to join (second input)
+ DataExchangeMode.PIPELINED, // combiner connections are pipelined
+ DataExchangeMode.BATCH, // to other reducer
+ DataExchangeMode.PIPELINED, // to flatMap
+ DataExchangeMode.PIPELINED, // to sink after flatMap
+ DataExchangeMode.PIPELINED, // to coGroup (first input)
+ DataExchangeMode.PIPELINED, // to coGroup (second input)
+ DataExchangeMode.PIPELINED // to sink after coGroup
+ );
+ }
+
+ @Test
+ public void testBatch() {
+ // BATCH should result in batching the shuffle all the way
+ verifyBranchingJoiningPlan(ExecutionMode.BATCH,
+ DataExchangeMode.PIPELINED, // to map
+ DataExchangeMode.PIPELINED, // to combiner connections are pipelined
+ DataExchangeMode.BATCH, // to reduce
+ DataExchangeMode.BATCH, // to filter
+ DataExchangeMode.PIPELINED, // to sink after reduce
+ DataExchangeMode.BATCH, // to join (first input)
+ DataExchangeMode.BATCH, // to join (second input)
+ DataExchangeMode.PIPELINED, // combiner connections are pipelined
+ DataExchangeMode.BATCH, // to other reducer
+ DataExchangeMode.PIPELINED, // to flatMap
+ DataExchangeMode.PIPELINED, // to sink after flatMap
+ DataExchangeMode.BATCH, // to coGroup (first input)
+ DataExchangeMode.BATCH, // to coGroup (second input)
+ DataExchangeMode.PIPELINED // to sink after coGroup
+ );
+ }
+
+ @Test
+ public void testBatchForced() {
+ // BATCH_FORCED should result in batching all the way
+ verifyBranchingJoiningPlan(ExecutionMode.BATCH_FORCED,
+ DataExchangeMode.BATCH, // to map
+ DataExchangeMode.PIPELINED, // to combiner connections are pipelined
+ DataExchangeMode.BATCH, // to reduce
+ DataExchangeMode.BATCH, // to filter
+ DataExchangeMode.BATCH, // to sink after reduce
+ DataExchangeMode.BATCH, // to join (first input)
+ DataExchangeMode.BATCH, // to join (second input)
+ DataExchangeMode.PIPELINED, // combiner connections are pipelined
+ DataExchangeMode.BATCH, // to other reducer
+ DataExchangeMode.BATCH, // to flatMap
+ DataExchangeMode.BATCH, // to sink after flatMap
+ DataExchangeMode.BATCH, // to coGroup (first input)
+ DataExchangeMode.BATCH, // to coGroup (second input)
+ DataExchangeMode.BATCH // to sink after coGroup
+ );
+ }
+
+ private void verifyBranchingJoiningPlan(ExecutionMode execMode,
+ DataExchangeMode toMap,
+ DataExchangeMode toReduceCombiner,
+ DataExchangeMode toReduce,
+ DataExchangeMode toFilter,
+ DataExchangeMode toReduceSink,
+ DataExchangeMode toJoin1,
+ DataExchangeMode toJoin2,
+ DataExchangeMode toOtherReduceCombiner,
+ DataExchangeMode toOtherReduce,
+ DataExchangeMode toFlatMap,
+ DataExchangeMode toFlatMapSink,
+ DataExchangeMode toCoGroup1,
+ DataExchangeMode toCoGroup2,
+ DataExchangeMode toCoGroupSink)
+ {
+ try {
+ ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.getConfig().setExecutionMode(execMode);
+
+ DataSet<Tuple2<Long, Long>> data = env.fromElements(33L, 44L)
+ .map(new MapFunction<Long, Tuple2<Long, Long>>() {
+ @Override
+ public Tuple2<Long, Long> map(Long value) {
+ return new Tuple2<Long, Long>(value, value);
+ }
+ });
+
+ DataSet<Tuple2<Long, Long>> reduced = data.groupBy(0).reduce(new SelectOneReducer<Tuple2<Long, Long>>());
+ reduced.output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("reduceSink");
+
+ DataSet<Tuple2<Long, Long>> filtered = data.filter(new FilterFunction<Tuple2<Long, Long>>() {
+ @Override
+ public boolean filter(Tuple2<Long, Long> value) throws Exception {
+ return false;
+ }
+ });
+
+ DataSet<Tuple2<Long, Long>> joined = reduced.join(filtered)
+ .where(1).equalTo(1)
+ .with(new DummyFlatJoinFunction<Tuple2<Long, Long>>());
+
+ joined.flatMap(new IdentityFlatMapper<Tuple2<Long, Long>>())
+ .output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("flatMapSink");
+
+ joined.coGroup(filtered.groupBy(1).reduceGroup(new Top1GroupReducer<Tuple2<Long, Long>>()))
+ .where(0).equalTo(0)
+ .with(new DummyCoGroupFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>())
+ .output(new DiscardingOutputFormat<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>>()).name("cgSink");
+
+
+ OptimizedPlan optPlan = compileNoStats(env.createProgramPlan());
+
+ SinkPlanNode reduceSink = findSink(optPlan.getDataSinks(), "reduceSink");
+ SinkPlanNode flatMapSink = findSink(optPlan.getDataSinks(), "flatMapSink");
+ SinkPlanNode cgSink = findSink(optPlan.getDataSinks(), "cgSink");
+
+ DualInputPlanNode coGroupNode = (DualInputPlanNode) cgSink.getPredecessor();
+
+ DualInputPlanNode joinNode = (DualInputPlanNode) coGroupNode.getInput1().getSource();
+ SingleInputPlanNode otherReduceNode = (SingleInputPlanNode) coGroupNode.getInput2().getSource();
+ SingleInputPlanNode otherReduceCombinerNode = (SingleInputPlanNode) otherReduceNode.getPredecessor();
+
+ SingleInputPlanNode reduceNode = (SingleInputPlanNode) joinNode.getInput1().getSource();
+ SingleInputPlanNode reduceCombinerNode = (SingleInputPlanNode) reduceNode.getPredecessor();
+ assertEquals(reduceNode, reduceSink.getPredecessor());
+
+ SingleInputPlanNode filterNode = (SingleInputPlanNode) joinNode.getInput2().getSource();
+ assertEquals(filterNode, otherReduceCombinerNode.getPredecessor());
+
+ SingleInputPlanNode mapNode = (SingleInputPlanNode) filterNode.getPredecessor();
+ assertEquals(mapNode, reduceCombinerNode.getPredecessor());
+
+ SingleInputPlanNode flatMapNode = (SingleInputPlanNode) flatMapSink.getPredecessor();
+ assertEquals(joinNode, flatMapNode.getPredecessor());
+
+ // verify the data exchange modes
+
+ assertEquals(toReduceSink, reduceSink.getInput().getDataExchangeMode());
+ assertEquals(toFlatMapSink, flatMapSink.getInput().getDataExchangeMode());
+ assertEquals(toCoGroupSink, cgSink.getInput().getDataExchangeMode());
+
+ assertEquals(toCoGroup1, coGroupNode.getInput1().getDataExchangeMode());
+ assertEquals(toCoGroup2, coGroupNode.getInput2().getDataExchangeMode());
+
+ assertEquals(toJoin1, joinNode.getInput1().getDataExchangeMode());
+ assertEquals(toJoin2, joinNode.getInput2().getDataExchangeMode());
+
+ assertEquals(toOtherReduce, otherReduceNode.getInput().getDataExchangeMode());
+ assertEquals(toOtherReduceCombiner, otherReduceCombinerNode.getInput().getDataExchangeMode());
+
+ assertEquals(toFlatMap, flatMapNode.getInput().getDataExchangeMode());
+
+ assertEquals(toFilter, filterNode.getInput().getDataExchangeMode());
+ assertEquals(toReduce, reduceNode.getInput().getDataExchangeMode());
+ assertEquals(toReduceCombiner, reduceCombinerNode.getInput().getDataExchangeMode());
+
+ assertEquals(toMap, mapNode.getInput().getDataExchangeMode());
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ private SinkPlanNode findSink(Collection<SinkPlanNode> collection, String name) {
+ for (SinkPlanNode node : collection) {
+ String nodeName = node.getOptimizerNode().getOperator().getName();
+ if (nodeName != null && nodeName.equals(name)) {
+ return node;
+ }
+ }
+
+ throw new IllegalArgumentException("No node with that name was found.");
+ }
+}