You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2015/09/01 13:50:18 UTC

[2/2] flink git commit: [FLINK-2590] fixing DataSetUtils.zipWithUniqueId() and DataSetUtils.zipWithIndex()

[FLINK-2590] fixing DataSetUtils.zipWithUniqueId() and DataSetUtils.zipWithIndex()

* modified algorithm as explained in the issue
* updated method documentation

[FLINK-2590] reducing required bit shift size

* the shift width is now the bit size of getNumberOfParallelSubtasks() - 1 (instead of log2 of the parallelism)

This closes #1075.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ab14f901
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ab14f901
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ab14f901

Branch: refs/heads/master
Commit: ab14f90142fd69426bb695cbdb641f0a5a0c46f7
Parents: 8c852c2
Author: Martin Junghanns <ma...@gmx.net>
Authored: Sat Aug 29 22:51:19 2015 +0200
Committer: Till Rohrmann <tr...@apache.org>
Committed: Tue Sep 1 13:49:11 2015 +0200

----------------------------------------------------------------------
 .../flink/api/java/utils/DataSetUtils.java      | 70 +++++++++++---------
 .../flink/test/util/DataSetUtilsITCase.java     | 65 ++++++++----------
 2 files changed, 68 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/ab14f901/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
index d268925..722fc6b 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
@@ -18,8 +18,9 @@
 
 package org.apache.flink.api.java.utils;
 
+import com.google.common.collect.Lists;
+import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
-import org.apache.flink.api.java.sampling.IntermediateSampleData;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.Utils;
 import org.apache.flink.api.java.functions.SampleInCoordinator;
@@ -27,6 +28,7 @@ import org.apache.flink.api.java.functions.SampleInPartition;
 import org.apache.flink.api.java.functions.SampleWithFraction;
 import org.apache.flink.api.java.operators.GroupReduceOperator;
 import org.apache.flink.api.java.operators.MapPartitionOperator;
+import org.apache.flink.api.java.sampling.IntermediateSampleData;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.Collector;
@@ -49,11 +51,11 @@ public class DataSetUtils {
 	 * @return a data set containing tuples of subtask index, number of elements mappings.
 	 */
 	private static <T> DataSet<Tuple2<Integer, Long>> countElements(DataSet<T> input) {
-		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer,Long>>() {
+		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
 				long counter = 0;
-				for(T value: values) {
+				for (T value : values) {
 					counter++;
 				}
 
@@ -63,8 +65,8 @@ public class DataSetUtils {
 	}
 
 	/**
-	 * Method that takes a set of subtask index, total number of elements mappings
-	 * and assigns ids to all the elements from the input data set.
+	 * Method that assigns a unique {@link Long} value to all elements in the input data set. The generated values are
+	 * consecutive.
 	 *
 	 * @param input the input data set
 	 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
@@ -77,28 +79,36 @@ public class DataSetUtils {
 
 			long start = 0;
 
-			// compute the offset for each partition
 			@Override
 			public void open(Configuration parameters) throws Exception {
 				super.open(parameters);
 
-				List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariable("counts");
-
-				Collections.sort(offsets, new Comparator<Tuple2<Integer, Long>>() {
-					@Override
-					public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
-						return compareInts(o1.f0, o2.f0);
-					}
-				});
-
-				for(int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
+				List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariableWithInitializer(
+						"counts",
+						new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
+							@Override
+							public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
+								// sort the list by task id to calculate the correct offset
+								List<Tuple2<Integer, Long>> sortedData = Lists.newArrayList(data);
+								Collections.sort(sortedData, new Comparator<Tuple2<Integer, Long>>() {
+									@Override
+									public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
+										return o1.f0.compareTo(o2.f0);
+									}
+								});
+								return sortedData;
+							}
+						});
+
+				// compute the offset for each partition
+				for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
 					start += offsets.get(i).f1;
 				}
 			}
 
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
-				for(T value: values) {
+				for (T value: values) {
 					out.collect(new Tuple2<Long, T>(start++, value));
 				}
 			}
@@ -106,12 +116,13 @@ public class DataSetUtils {
 	}
 
 	/**
-	 * Method that assigns unique Long labels to all the elements in the input data set by making use of the
-	 * following abstractions:
+	 * Method that assigns a unique {@link Long} value to all elements in the input data set in the following way:
 	 * <ul>
-	 * 	<li> a map function generates an n-bit (n - number of parallel tasks) ID based on its own index
-	 * 	<li> with each record, a counter c is increased
-	 * 	<li> the unique label is then produced by shifting the counter c by the n-bit mapper ID
+	 *  <li> a map function is applied to the input data set
+	 *  <li> each map task holds a counter c which is increased for each record
+	 *  <li> c is shifted by n bits where n = log2(number of parallel tasks)
+	 * 	<li> to create a unique ID among all tasks, the task id is added to the counter
+	 * 	<li> for each record, the resulting counter is collected
 	 * </ul>
 	 *
 	 * @param input the input data set
@@ -121,6 +132,7 @@ public class DataSetUtils {
 
 		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
 
+			long maxBitSize = getBitSize(Long.MAX_VALUE);
 			long shifter = 0;
 			long start = 0;
 			long taskId = 0;
@@ -129,16 +141,16 @@ public class DataSetUtils {
 			@Override
 			public void open(Configuration parameters) throws Exception {
 				super.open(parameters);
-				shifter = log2(getRuntimeContext().getNumberOfParallelSubtasks());
+				shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
 				taskId = getRuntimeContext().getIndexOfThisSubtask();
 			}
 
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
-				for(T value: values) {
-					label = start << shifter + taskId;
+				for (T value : values) {
+					label = (start << shifter) + taskId;
 
-					if(log2(start) + shifter < log2(Long.MAX_VALUE)) {
+					if (getBitSize(start) + shifter < maxBitSize) {
 						out.collect(new Tuple2<Long, T>(label, value));
 						start++;
 					} else {
@@ -241,11 +253,7 @@ public class DataSetUtils {
 	//     UTIL METHODS
 	// *************************************************************************
 
-	private static int compareInts(int x, int y) {
-		return (x < y) ? -1 : ((x == y) ? 0 : 1);
-	}
-
-	private static int log2(long value){
+	public static int getBitSize(long value){
 		if(value > Integer.MAX_VALUE) {
 			return 64 - Integer.numberOfLeadingZeros((int)(value >> 32));
 		} else {

http://git-wip-us.apache.org/repos/asf/flink/blob/ab14f901/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java b/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
index 1e5363b..a289116 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
@@ -18,66 +18,59 @@
 
 package org.apache.flink.test.util;
 
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.utils.DataSetUtils;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
+import org.junit.Assert;
 import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Set;
+
 @RunWith(Parameterized.class)
 public class DataSetUtilsITCase extends MultipleProgramsTestBase {
 
-	private String resultPath;
-	private String expectedResult;
-
-	@Rule
-	public TemporaryFolder tempFolder = new TemporaryFolder();
-
 	public DataSetUtilsITCase(TestExecutionMode mode) {
 		super(mode);
 	}
 
-	@Before
-	public void before() throws Exception{
-		resultPath = tempFolder.newFile().toURI().toString();
-	}
-
 	@Test
 	public void testZipWithIndex() throws Exception {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-		env.setParallelism(1);
-		DataSet<String> in = env.fromElements("A", "B", "C", "D", "E", "F");
-
-		DataSet<Tuple2<Long, String>> result = DataSetUtils.zipWithIndex(in);
-
-		result.writeAsCsv(resultPath, "\n", ",");
-		env.execute();
-
-		expectedResult = "0,A\n" + "1,B\n" + "2,C\n" + "3,D\n" + "4,E\n" + "5,F";
+		long expectedSize = 100L;
+		DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);
+
+		List<Tuple2<Long, Long>> result = Lists.newArrayList(DataSetUtils.zipWithIndex(numbers).collect());
+
+		Assert.assertEquals(expectedSize, result.size());
+		// sort result by created index
+		Collections.sort(result, new Comparator<Tuple2<Long, Long>>() {
+			@Override
+			public int compare(Tuple2<Long, Long> o1, Tuple2<Long, Long> o2) {
+				return o1.f0.compareTo(o2.f0);
+			}
+		});
+		// test if index is consecutive
+		for (int i = 0; i < expectedSize; i++) {
+			Assert.assertEquals(i, (long) result.get(i).f0);
+		}
 	}
 
 	@Test
 	public void testZipWithUniqueId() throws Exception {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-		env.setParallelism(1);
-		DataSet<String> in = env.fromElements("A", "B", "C", "D", "E", "F");
-
-		DataSet<Tuple2<Long, String>> result = DataSetUtils.zipWithUniqueId(in);
+		long expectedSize = 100L;
+		DataSet<Long> numbers = env.generateSequence(1L, expectedSize);
 
-		result.writeAsCsv(resultPath, "\n", ",");
-		env.execute();
-
-		expectedResult = "0,A\n" + "2,B\n" + "4,C\n" + "6,D\n" + "8,E\n" + "10,F";
-	}
+		Set<Tuple2<Long, Long>> result = Sets.newHashSet(DataSetUtils.zipWithUniqueId(numbers).collect());
 
-	@After
-	public void after() throws Exception{
-		compareResultsByLinesInMemory(expectedResult, resultPath);
+		Assert.assertEquals(expectedSize, result.size());
 	}
 }