You are viewing a plain-text version of this content. The canonical link for it ("here") was not preserved in this copy.
Posted to commits@flink.apache.org by tr...@apache.org on 2016/06/28 08:45:46 UTC

flink git commit: [FLINK-4113] [runtime] Always copy first value in ChainedAllReduceDriver

Repository: flink
Updated Branches:
  refs/heads/master f9552d8dc -> a6feea32a


[FLINK-4113] [runtime] Always copy first value in ChainedAllReduceDriver

Guard test for ChainedAllReduceDriver

This closes #2156.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/a6feea32
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/a6feea32
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/a6feea32

Branch: refs/heads/master
Commit: a6feea32a328f3f3216960298d6c8a5d3f30a234
Parents: f9552d8
Author: Greg Hogan <co...@greghogan.com>
Authored: Thu Jun 23 12:37:37 2016 -0400
Committer: Till Rohrmann <tr...@apache.org>
Committed: Tue Jun 28 10:44:50 2016 +0200

----------------------------------------------------------------------
 .../chaining/ChainedAllReduceDriver.java        |   2 +-
 .../chaining/ChainedAllReduceDriverTest.java    | 141 +++++++++++++++++++
 2 files changed, 142 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/a6feea32/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriver.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriver.java
index 1e3482f..d47c3a6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriver.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriver.java
@@ -89,7 +89,7 @@ public class ChainedAllReduceDriver<IT> extends ChainedDriver<IT, IT> {
 		numRecordsIn.inc();
 		try {
 			if (base == null) {
-				base = objectReuseEnabled ? record : serializer.copy(record);
+				base = serializer.copy(record);
 			} else {
 				base = objectReuseEnabled ? reducer.reduce(base, record) : serializer.copy(reducer.reduce(base, record));
 			}

http://git-wip-us.apache.org/repos/asf/flink/blob/a6feea32/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriverTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriverTest.java
new file mode 100644
index 0000000..4a037cd
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriverTest.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.operators.chaining;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.operators.BatchTask;
+import org.apache.flink.runtime.operators.DriverStrategy;
+import org.apache.flink.runtime.operators.FlatMapDriver;
+import org.apache.flink.runtime.operators.FlatMapTaskTest.MockMapStub;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.apache.flink.runtime.operators.testutils.TaskTestBase;
+import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
+import org.apache.flink.runtime.operators.util.TaskConfig;
+import org.apache.flink.runtime.taskmanager.Task;
+import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory;
+import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory;
+import org.apache.flink.types.IntValue;
+import org.apache.flink.types.Record;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+import java.util.ArrayList;
+import java.util.List;
+
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({Task.class, ResultPartitionWriter.class})
+public class ChainedAllReduceDriverTest extends TaskTestBase {
+	// This class guards FLINK-4113: a FlatMap task with ChainedAllReduceDriver chained to it, run with object reuse enabled.
+	private static final int MEMORY_MANAGER_SIZE = 1024 * 1024 * 3;
+
+	private static final int NETWORK_BUFFER_SIZE = 1024;
+
+	private final List<Record> outList = new ArrayList<>(); // registered via addOutput(); read back by the assertions
+
+	@SuppressWarnings("unchecked")
+	private final RecordComparatorFactory compFact = new RecordComparatorFactory(new int[]{0}, new Class[]{IntValue.class}, new boolean[] {true});
+	private final RecordSerializerFactory serFact = RecordSerializerFactory.get();
+	/** Checks that the chained all-reduce emits exactly one record whose field 0 equals the expected sum. */
+	@Test
+	public void testMapTask() { // NOTE(review): despite the name, the chained all-reduce is what is under test
+		final int keyCnt = 100;
+		final int valCnt = 20;
+
+		final double memoryFraction = 1.0;
+
+		try {
+			// environment; object reuse must be on, since only in that mode was the first record previously left uncopied
+			initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
+			mockEnv.getExecutionConfig().enableObjectReuse();
+			addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);
+			addOutput(this.outList);
+
+			// chained reduce config: ChainedAllReduceDriver is registered as a chained task of the head task below
+			{
+				final TaskConfig reduceConfig = new TaskConfig(new Configuration());
+
+				// input
+				reduceConfig.addInputToGroup(0);
+				reduceConfig.setInputSerializer(serFact, 0);
+
+				// output
+				reduceConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
+				reduceConfig.setOutputSerializer(serFact);
+
+				// driver
+				reduceConfig.setDriverStrategy(DriverStrategy.ALL_REDUCE);
+				reduceConfig.setDriverComparator(compFact, 0); // NOTE(review): comparators may be unused by ALL_REDUCE — confirm
+				reduceConfig.setDriverComparator(compFact, 1);
+				reduceConfig.setRelativeMemoryDriver(memoryFraction);
+
+				// udf: MockReduceStub mutates and returns its first argument
+				reduceConfig.setStubWrapper(new UserCodeClassWrapper<>(MockReduceStub.class));
+
+				getTaskConfig().addChainedTask(ChainedAllReduceDriver.class, reduceConfig, "reduce");
+			}
+
+			// head task: FlatMapDriver running MockMapStub, with the reduce chained to it
+			{
+				BatchTask<FlatMapFunction<Record, Record>, Record> testTask = new BatchTask<>();
+				registerTask(testTask, FlatMapDriver.class, MockMapStub.class);
+
+				try {
+					testTask.invoke();
+				} catch (Exception e) {
+					e.printStackTrace();
+					Assert.fail("Invoke method caused exception.");
+				}
+			}
+
+			int sumTotal = valCnt * keyCnt * (keyCnt - 1) / 2; // NOTE(review): assumes field 0 takes keys 0..keyCnt-1, valCnt times each — confirm
+
+			Assert.assertEquals(1, this.outList.size()); // the all-reduce yields one record for the whole input
+			Assert.assertEquals(sumTotal, this.outList.get(0).getField(0, IntValue.class).getValue());
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			Assert.fail(e.getMessage());
+		}
+	}
+
+	private static class MockReduceStub implements ReduceFunction<Record> {
+		private static final long serialVersionUID = 1047525105526690165L;
+
+		@Override
+		public Record reduce(Record value1, Record value2) throws Exception {
+			IntValue v1 = value1.getField(0, IntValue.class);
+			IntValue v2 = value2.getField(0, IntValue.class);
+
+			// Sum field 0 into value1, write it back and re-serialize, then return value1
+			// itself: reusing the first argument exercises ChainedAllReduceDriver.collect()
+			// with object reuse enabled, where the first record must be copied (FLINK-4113)
+			v1.setValue(v1.getValue() + v2.getValue());
+			value1.setField(0, v1);
+			value1.updateBinaryRepresenation();
+			return value1;
+		}
+	}
+}