You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2014/07/09 12:21:30 UTC
[03/12] git commit: [FLINK-836] Rework of the cached match driver
[FLINK-836] Rework of the cached match driver
Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/47a68adc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/47a68adc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/47a68adc
Branch: refs/heads/master
Commit: 47a68adc7fa5da7743cb51fe69f81264e908e630
Parents: 99c888c
Author: Markus Holzemer <ma...@gmx.de>
Authored: Mon Jun 16 09:41:46 2014 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Tue Jul 8 16:54:15 2014 +0200
----------------------------------------------------------------------
.../compiler/dag/OptimizerNode.java | 10 -
.../operators/HashJoinBuildFirstProperties.java | 2 +-
.../HashJoinBuildSecondProperties.java | 2 +-
.../eu/stratosphere/compiler/plan/Channel.java | 4 -
.../eu/stratosphere/compiler/plan/PlanNode.java | 4 -
.../services/memorymanager/MemoryManager.java | 8 +-
.../memorymanager/spi/DefaultMemoryManager.java | 4 +-
.../BuildFirstReOpenableHashMatchIterator.java | 10 +-
.../hash/BuildSecondHashMatchIterator.java | 6 +-
.../BuildSecondReOpenableHashMatchIterator.java | 71 +++
.../runtime/sort/AsynchronousPartialSorter.java | 13 +-
.../pact/runtime/sort/UnilateralSortMerger.java | 2 +-
.../AbstractCachedBuildSideMatchDriver.java | 131 ++---
.../pact/runtime/task/RegularPactTask.java | 10 +
.../sort/AsynchonousPartialSorterITCase.java | 6 +-
.../pact/runtime/task/CachedMatchTaskTest.java | 491 +++++++++++++++++++
.../pact/runtime/test/util/DriverTestBase.java | 33 ++
17 files changed, 683 insertions(+), 124 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java
----------------------------------------------------------------------
diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java
index 6f295a5..ec9bd69 100644
--- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java
+++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java
@@ -92,8 +92,6 @@ public abstract class OptimizerNode implements Visitable<OptimizerNode>, Estimat
protected boolean onDynamicPath;
- protected boolean insideIteration;
-
protected List<PlanNode> cachedPlans; // cache candidates, because the may be accessed repeatedly
protected int[][] remappedKeys;
@@ -501,14 +499,6 @@ public abstract class OptimizerNode implements Visitable<OptimizerNode>, Estimat
}
}
- public boolean isInsideIteration() {
- return insideIteration;
- }
-
- public void setInsideIteration(boolean insideIteration) {
- this.insideIteration = insideIteration;
- }
-
/**
* Checks whether this node has branching output. A node's output is branched, if it has more
* than one output connection.
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java
----------------------------------------------------------------------
diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java
index 661b316..4f694c5 100644
--- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java
+++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java
@@ -56,7 +56,7 @@ public class HashJoinBuildFirstProperties extends AbstractJoinDescriptor {
public DualInputPlanNode instantiate(Channel in1, Channel in2, TwoInputNode node) {
DriverStrategy strategy;
- if(!in1.isOnDynamicPath() && in1.isInsideIteration() && in2.isInsideIteration()) {
+ if(!in1.isOnDynamicPath() && in2.isOnDynamicPath()) {
strategy = DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED;
}
else {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java
----------------------------------------------------------------------
diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java
index e085588..6bea65a 100644
--- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java
+++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java
@@ -53,7 +53,7 @@ public final class HashJoinBuildSecondProperties extends AbstractJoinDescriptor
public DualInputPlanNode instantiate(Channel in1, Channel in2, TwoInputNode node) {
DriverStrategy strategy;
- if(!in2.isOnDynamicPath() && in1.isInsideIteration() && in2.isInsideIteration()) {
+ if(!in2.isOnDynamicPath() && in1.isOnDynamicPath()) {
strategy = DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED;
}
else {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java
----------------------------------------------------------------------
diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java
index d83da82..6f9418f 100644
--- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java
+++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java
@@ -306,10 +306,6 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
return this.source.isOnDynamicPath();
}
- public boolean isInsideIteration() {
- return this.source.isInsideIteration();
- }
-
public int getCostWeight() {
return this.source.getCostWeight();
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java
----------------------------------------------------------------------
diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java
index da24b5f..539006c 100644
--- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java
+++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java
@@ -423,10 +423,6 @@ public abstract class PlanNode implements Visitable<PlanNode>, DumpableNode<Plan
return this.template.getCostWeight();
}
- public boolean isInsideIteration() {
- return this.template.isInsideIteration();
- }
-
// --------------------------------------------------------------------------------------------
/**
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/MemoryManager.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/MemoryManager.java b/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/MemoryManager.java
index 8b20c75..6c5109b 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/MemoryManager.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/MemoryManager.java
@@ -72,7 +72,8 @@ public interface MemoryManager {
/**
* Returns the total size of memory.
- * @return
+ *
+ * @return The total size of memory.
*/
long getMemorySize();
@@ -88,8 +89,9 @@ public interface MemoryManager {
/**
* Computes the memory size of the fraction per slot.
- * @param fraction
- * @return
+ *
+ * @param fraction The fraction of the memory of the task slot.
+ * @return The number of pages corresponding to the memory fraction.
*/
long computeMemorySize(double fraction);
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/spi/DefaultMemoryManager.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/spi/DefaultMemoryManager.java b/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/spi/DefaultMemoryManager.java
index d4a2b36..85455d2 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/spi/DefaultMemoryManager.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/nephele/services/memorymanager/spi/DefaultMemoryManager.java
@@ -394,8 +394,8 @@ public class DefaultMemoryManager implements MemoryManager {
}
private final int getRelativeNumPages(double fraction){
- if(fraction < 0){
- throw new IllegalArgumentException("The fraction of memory to allocate must not be negative.");
+ if (fraction <= 0 || fraction > 1) {
+ throw new IllegalArgumentException("The fraction of memory to allocate must within (0, 1].");
}
return (int)(this.totalNumPages * fraction / this.numberOfSlots);
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java
index 8c2b9ca..7898c41 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java
@@ -37,15 +37,19 @@ public class BuildFirstReOpenableHashMatchIterator<V1, V2, O> extends BuildFirst
TypeSerializer<V1> serializer1, TypeComparator<V1> comparator1,
TypeSerializer<V2> serializer2, TypeComparator<V2> comparator2,
TypePairComparator<V2, V1> pairComparator,
- MemoryManager memManager, IOManager ioManager,
- AbstractInvokable ownerTask, double memoryFraction)
- throws MemoryAllocationException {
+ MemoryManager memManager,
+ IOManager ioManager,
+ AbstractInvokable ownerTask,
+ double memoryFraction)
+ throws MemoryAllocationException
+ {
super(firstInput, secondInput, serializer1, comparator1, serializer2,
comparator2, pairComparator, memManager, ioManager, ownerTask,
memoryFraction);
reopenHashTable = (ReOpenableMutableHashTable<V1, V2>) hashJoin;
}
+ @Override
public <BT, PT> MutableHashTable<BT, PT> getHashJoin(TypeSerializer<BT> buildSideSerializer, TypeComparator<BT> buildSideComparator,
TypeSerializer<PT> probeSideSerializer, TypeComparator<PT> probeSideComparator,
TypePairComparator<PT, BT> pairComparator,
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java
index 732d256..9f3fd97 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java
@@ -34,9 +34,9 @@ import eu.stratosphere.util.MutableObjectIterator;
* An implementation of the {@link eu.stratosphere.pact.runtime.task.util.JoinTaskIterator} that uses a hybrid-hash-join
* internally to match the records with equal key. The build side of the hash is the second input of the match.
*/
-public final class BuildSecondHashMatchIterator<V1, V2, O> implements JoinTaskIterator<V1, V2, O> {
+public class BuildSecondHashMatchIterator<V1, V2, O> implements JoinTaskIterator<V1, V2, O> {
- private final MutableHashTable<V2, V1> hashJoin;
+ protected final MutableHashTable<V2, V1> hashJoin;
private final V2 nextBuildSideObject;
@@ -44,7 +44,7 @@ public final class BuildSecondHashMatchIterator<V1, V2, O> implements JoinTaskIt
private final V1 probeCopy;
- private final TypeSerializer<V1> probeSideSerializer;
+ protected final TypeSerializer<V1> probeSideSerializer;
private final MemoryManager memManager;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondReOpenableHashMatchIterator.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondReOpenableHashMatchIterator.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondReOpenableHashMatchIterator.java
new file mode 100644
index 0000000..597ae73
--- /dev/null
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondReOpenableHashMatchIterator.java
@@ -0,0 +1,71 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ **********************************************************************************************************************/
+
+package eu.stratosphere.pact.runtime.hash;
+
+import java.io.IOException;
+import java.util.List;
+
+import eu.stratosphere.api.common.typeutils.TypeComparator;
+import eu.stratosphere.api.common.typeutils.TypePairComparator;
+import eu.stratosphere.api.common.typeutils.TypeSerializer;
+import eu.stratosphere.core.memory.MemorySegment;
+import eu.stratosphere.nephele.services.iomanager.IOManager;
+import eu.stratosphere.nephele.services.memorymanager.MemoryAllocationException;
+import eu.stratosphere.nephele.services.memorymanager.MemoryManager;
+import eu.stratosphere.nephele.template.AbstractInvokable;
+import eu.stratosphere.util.MutableObjectIterator;
+/** Hash match iterator with the second input as build side whose probe side can be re-opened on a new first input via {@link #reopenProbe}. */
+public class BuildSecondReOpenableHashMatchIterator<V1, V2, O> extends BuildSecondHashMatchIterator<V1, V2, O> {
+
+ /** The inherited hashJoin, viewed as a ReOpenableMutableHashTable so that reopenProbe() can be delegated to it. */
+ private final ReOpenableMutableHashTable<V2, V1> reopenHashTable;
+
+ public BuildSecondReOpenableHashMatchIterator(
+ MutableObjectIterator<V1> firstInput,
+ MutableObjectIterator<V2> secondInput,
+ TypeSerializer<V1> serializer1, TypeComparator<V1> comparator1,
+ TypeSerializer<V2> serializer2, TypeComparator<V2> comparator2,
+ TypePairComparator<V1, V2> pairComparator,
+ MemoryManager memManager,
+ IOManager ioManager,
+ AbstractInvokable ownerTask,
+ double memoryFraction)
+ throws MemoryAllocationException
+ {
+ super(firstInput, secondInput, serializer1, comparator1, serializer2,
+ comparator2, pairComparator, memManager, ioManager, ownerTask, memoryFraction);
+ reopenHashTable = (ReOpenableMutableHashTable<V2, V1>) hashJoin; // relies on hashJoin coming from the getHashJoin() override below -- TODO confirm the super constructor calls it
+ }
+
+ @Override // allocates the pages for memoryFraction and wraps them in a ReOpenableMutableHashTable
+ public <BT, PT> MutableHashTable<BT, PT> getHashJoin(TypeSerializer<BT> buildSideSerializer, TypeComparator<BT> buildSideComparator,
+ TypeSerializer<PT> probeSideSerializer, TypeComparator<PT> probeSideComparator,
+ TypePairComparator<PT, BT> pairComparator,
+ MemoryManager memManager, IOManager ioManager, AbstractInvokable ownerTask, double memoryFraction)
+ throws MemoryAllocationException
+ {
+ final int numPages = memManager.computeNumberOfPages(memoryFraction);
+ final List<MemorySegment> memorySegments = memManager.allocatePages(ownerTask, numPages);
+ return new ReOpenableMutableHashTable<BT, PT>(buildSideSerializer, probeSideSerializer, buildSideComparator, probeSideComparator, pairComparator, memorySegments, ioManager);
+ }
+
+ /** Re-opens the hash table's probe side on a new first-input iterator (presumably reusing the already-built second-input side).
+ * @param probeInput The new probe-side (first) input.
+ * @throws IOException Propagated from ReOpenableMutableHashTable.reopenProbe().
+ */
+ public void reopenProbe(MutableObjectIterator<V1> probeInput) throws IOException {
+ reopenHashTable.reopenProbe(probeInput);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java
index 35377cf..f87f1f8 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java
@@ -34,8 +34,6 @@ import eu.stratosphere.util.MutableObjectIterator;
*/
public class AsynchronousPartialSorter<E> extends UnilateralSortMerger<E> {
- private static final int MAX_MEM_PER_PARTIAL_SORT = 64 * 1024 * 0124;
-
private BufferQueueIterator bufferIterator;
// ------------------------------------------------------------------------
@@ -62,11 +60,7 @@ public class AsynchronousPartialSorter<E> extends UnilateralSortMerger<E> {
double memoryFraction)
throws IOException, MemoryAllocationException
{
- super(memoryManager, null, input, parentTask, serializerFactory, comparator, memoryFraction,
- memoryManager.computeNumberOfPages(memoryFraction) < 2 * MIN_NUM_SORT_MEM_SEGMENTS ? 1 :
- Math.max((int) Math.ceil(((double) memoryManager.computeMemorySize(memoryFraction)) /
- MAX_MEM_PER_PARTIAL_SORT), 2),
- 2, 0.0f, true);
+ super(memoryManager, null, input, parentTask, serializerFactory, comparator, memoryFraction, 1, 2, 0.0f, true);
}
@@ -101,11 +95,6 @@ public class AsynchronousPartialSorter<E> extends UnilateralSortMerger<E> {
// ------------------------------------------------------------------------
- /**
- * This class implements an iterator over values from a {@link eu.stratosphere.pact.runtime.sort.BufferSortable}.
- * The iterator returns the values of a given
- * interval.
- */
private final class BufferQueueIterator implements MutableObjectIterator<E> {
private final CircularQueues<E> queues;
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java
index 6905b85..9109e9a 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java
@@ -239,7 +239,7 @@ public class UnilateralSortMerger<E> implements Sorter<E> {
* @param maxNumFileHandles The maximum number of files to be merged at once.
* @param startSpillingFraction The faction of the buffers that have to be filled before the spilling thread
* actually begins spilling data to disk.
- * @param noSpilling When set to true, no memory will be allocated for writing and no spilling thread
+ * @param noSpillingMemory When set to true, no memory will be allocated for writing and no spilling thread
* will be spawned.
*
* @throws IOException Thrown, if an error occurs initializing the resources for external sorting.
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java
index 1d3c55d..1c8c427 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java
@@ -13,26 +13,20 @@
package eu.stratosphere.pact.runtime.task;
-import java.util.List;
-
import eu.stratosphere.api.common.functions.GenericJoiner;
import eu.stratosphere.api.common.typeutils.TypeComparator;
import eu.stratosphere.api.common.typeutils.TypePairComparatorFactory;
import eu.stratosphere.api.common.typeutils.TypeSerializer;
-import eu.stratosphere.core.memory.MemorySegment;
-import eu.stratosphere.pact.runtime.hash.MutableHashTable;
+import eu.stratosphere.pact.runtime.hash.BuildFirstReOpenableHashMatchIterator;
+import eu.stratosphere.pact.runtime.hash.BuildSecondReOpenableHashMatchIterator;
+import eu.stratosphere.pact.runtime.task.util.JoinTaskIterator;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
-import eu.stratosphere.pact.runtime.util.EmptyMutableObjectIterator;
import eu.stratosphere.util.Collector;
import eu.stratosphere.util.MutableObjectIterator;
public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends MatchDriver<IT1, IT2, OT> implements ResettablePactDriver<GenericJoiner<IT1, IT2, OT>, OT> {
-
- /**
- * We keep it without generic parameters, because they vary depending on which input is the build side.
- */
- protected volatile MutableHashTable<?, ?> hashJoin;
+ private volatile JoinTaskIterator<IT1, IT2, OT> matchIterator;
private final int buildSideIndex;
@@ -67,23 +61,39 @@ public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends M
TypePairComparatorFactory<IT1, IT2> pairComparatorFactory =
this.taskContext.getTaskConfig().getPairComparatorFactory(this.taskContext.getUserCodeClassLoader());
- int numMemoryPages = this.taskContext.getMemoryManager().computeNumberOfPages(config.getRelativeMemoryDriver());
- List<MemorySegment> memSegments = this.taskContext.getMemoryManager().allocatePages(
- this.taskContext.getOwningNepheleTask(), numMemoryPages);
+ double availableMemory = config.getRelativeMemoryDriver();
if (buildSideIndex == 0 && probeSideIndex == 1) {
- MutableHashTable<IT1, IT2> hashJoin = new MutableHashTable<IT1, IT2>(serializer1, serializer2, comparator1, comparator2,
- pairComparatorFactory.createComparator21(comparator1, comparator2), memSegments, this.taskContext.getIOManager());
- this.hashJoin = hashJoin;
- hashJoin.open(input1, EmptyMutableObjectIterator.<IT2>get());
+
+ matchIterator =
+ new BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT>(input1, input2,
+ serializer1, comparator1,
+ serializer2, comparator2,
+ pairComparatorFactory.createComparator21(comparator1, comparator2),
+ this.taskContext.getMemoryManager(),
+ this.taskContext.getIOManager(),
+ this.taskContext.getOwningNepheleTask(),
+ availableMemory
+ );
+
} else if (buildSideIndex == 1 && probeSideIndex == 0) {
- MutableHashTable<IT2, IT1> hashJoin = new MutableHashTable<IT2, IT1>(serializer2, serializer1, comparator2, comparator1,
- pairComparatorFactory.createComparator12(comparator1, comparator2), memSegments, this.taskContext.getIOManager());
- this.hashJoin = hashJoin;
- hashJoin.open(input2, EmptyMutableObjectIterator.<IT1>get());
+
+ matchIterator =
+ new BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT>(input1, input2,
+ serializer1, comparator1,
+ serializer2, comparator2,
+ pairComparatorFactory.createComparator12(comparator1, comparator2),
+ this.taskContext.getMemoryManager(),
+ this.taskContext.getIOManager(),
+ this.taskContext.getOwningNepheleTask(),
+ availableMemory
+ );
+
} else {
throw new Exception("Error: Inconcistent setup for repeatable hash join driver.");
}
+
+ this.matchIterator.open();
}
@Override
@@ -98,63 +108,17 @@ public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends M
final Collector<OT> collector = this.taskContext.getOutputCollector();
if (buildSideIndex == 0) {
- final TypeSerializer<IT1> buildSideSerializer = taskContext.<IT1> getInputSerializer(0).getSerializer();
- final TypeSerializer<IT2> probeSideSerializer = taskContext.<IT2> getInputSerializer(1).getSerializer();
- IT1 buildSideRecordFirst;
- IT1 buildSideRecordOther;
- IT2 probeSideRecord;
- IT2 probeSideRecordCopy;
- final IT1 buildSideRecordFirstReuse = buildSideSerializer.createInstance();
- final IT1 buildSideRecordOtherReuse = buildSideSerializer.createInstance();
- final IT2 probeSideRecordReuse = probeSideSerializer.createInstance();
- final IT2 probeSideRecordCopyReuse = probeSideSerializer.createInstance();
+ final BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT> matchIterator = (BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT>) this.matchIterator;
- @SuppressWarnings("unchecked")
- final MutableHashTable<IT1, IT2> join = (MutableHashTable<IT1, IT2>) this.hashJoin;
+ while (this.running && matchIterator != null && matchIterator.callWithNextKey(matchStub, collector));
- final MutableObjectIterator<IT2> probeSideInput = taskContext.<IT2>getInput(1);
-
- while (this.running && ((probeSideRecord = probeSideInput.next(probeSideRecordReuse)) != null)) {
- final MutableHashTable.HashBucketIterator<IT1, IT2> bucket = join.getMatchesFor(probeSideRecord);
-
- if ((buildSideRecordFirst = bucket.next(buildSideRecordFirstReuse)) != null) {
- while ((buildSideRecordOther = bucket.next(buildSideRecordOtherReuse)) != null) {
- probeSideRecordCopy = probeSideSerializer.copy(probeSideRecord, probeSideRecordCopyReuse);
- matchStub.join(buildSideRecordOther, probeSideRecordCopy, collector);
- }
- matchStub.join(buildSideRecordFirst, probeSideRecord, collector);
- }
- }
} else if (buildSideIndex == 1) {
- final TypeSerializer<IT2> buildSideSerializer = taskContext.<IT2>getInputSerializer(1).getSerializer();
- final TypeSerializer<IT1> probeSideSerializer = taskContext.<IT1>getInputSerializer(0).getSerializer();
- IT2 buildSideRecordFirst;
- IT2 buildSideRecordOther;
- IT1 probeSideRecord;
- IT1 probeSideRecordCopy;
- final IT2 buildSideRecordFirstReuse = buildSideSerializer.createInstance();
- final IT2 buildSideRecordOtherReuse = buildSideSerializer.createInstance();
- final IT1 probeSideRecordReuse = probeSideSerializer.createInstance();
- final IT1 probeSideRecordCopyReuse = probeSideSerializer.createInstance();
-
- @SuppressWarnings("unchecked")
- final MutableHashTable<IT2, IT1> join = (MutableHashTable<IT2, IT1>) this.hashJoin;
+ final BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT> matchIterator = (BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT>) this.matchIterator;
- final MutableObjectIterator<IT1> probeSideInput = taskContext.<IT1>getInput(0);
+ while (this.running && matchIterator != null && matchIterator.callWithNextKey(matchStub, collector));
- while (this.running && ((probeSideRecord = probeSideInput.next(probeSideRecordReuse)) != null)) {
- final MutableHashTable.HashBucketIterator<IT2, IT1> bucket = join.getMatchesFor(probeSideRecord);
-
- if ((buildSideRecordFirst = bucket.next(buildSideRecordFirstReuse)) != null) {
- while ((buildSideRecordOther = bucket.next(buildSideRecordOtherReuse)) != null) {
- probeSideRecordCopy = probeSideSerializer.copy(probeSideRecord, probeSideRecordCopyReuse);
- matchStub.join(probeSideRecordCopy, buildSideRecordOther, collector);
- }
- matchStub.join(probeSideRecord, buildSideRecordFirst, collector);
- }
- }
} else {
throw new Exception();
}
@@ -164,21 +128,34 @@ public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends M
public void cleanup() throws Exception {}
@Override
- public void reset() throws Exception {}
+ public void reset() throws Exception {
+
+ MutableObjectIterator<IT1> input1 = this.taskContext.getInput(0);
+ MutableObjectIterator<IT2> input2 = this.taskContext.getInput(1);
+
+ if (buildSideIndex == 0 && probeSideIndex == 1) {
+ final BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT> matchIterator = (BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT>) this.matchIterator;
+ matchIterator.reopenProbe(input2);
+ }
+ else {
+ final BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT> matchIterator = (BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT>) this.matchIterator;
+ matchIterator.reopenProbe(input1);
+ }
+ }
@Override
public void teardown() {
- MutableHashTable<?, ?> ht = this.hashJoin;
- if (ht != null) {
- ht.close();
+ this.running = false;
+ if (this.matchIterator != null) {
+ this.matchIterator.close();
}
}
@Override
public void cancel() {
this.running = false;
- if (this.hashJoin != null) {
- this.hashJoin.close();
+ if (this.matchIterator != null) {
+ this.matchIterator.abort();
}
}
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java
index 3140525..3e4a1fd 100644
--- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java
+++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java
@@ -531,6 +531,16 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
}
catch (Throwable t) {}
}
+
+ // if resettable driver invoke treardown
+ if (this.driver instanceof ResettablePactDriver) {
+ final ResettablePactDriver<?, ?> resDriver = (ResettablePactDriver<?, ?>) this.driver;
+ try {
+ resDriver.teardown();
+ } catch (Throwable t) {
+ throw new Exception("Error while shutting down an iterative operator: " + t.getMessage(), t);
+ }
+ }
RegularPactTask.cancelChainedTasks(this.chainedTasks);
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java
index f191075..155bf28 100644
--- a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java
+++ b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java
@@ -128,7 +128,7 @@ public class AsynchonousPartialSorterITCase {
// merge iterator
LOG.debug("Initializing sortmerger...");
Sorter<Record> sorter = new AsynchronousPartialSorter<Record>(this.memoryManager, source,
- this.parentTask, this.serializer, this.comparator, 1.0);
+ this.parentTask, this.serializer, this.comparator, 0.2);
runPartialSorter(sorter, NUM_RECORDS, 2);
}
@@ -151,9 +151,9 @@ public class AsynchonousPartialSorterITCase {
// merge iterator
LOG.debug("Initializing sortmerger...");
Sorter<Record> sorter = new AsynchronousPartialSorter<Record>(this.memoryManager, source,
- this.parentTask, this.serializer, this.comparator, 1.0);
+ this.parentTask, this.serializer, this.comparator, 0.15);
- runPartialSorter(sorter, NUM_RECORDS, 28);
+ runPartialSorter(sorter, NUM_RECORDS, 27);
}
catch (Exception t) {
t.printStackTrace();
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/task/CachedMatchTaskTest.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/task/CachedMatchTaskTest.java b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/task/CachedMatchTaskTest.java
new file mode 100644
index 0000000..ff560df
--- /dev/null
+++ b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/task/CachedMatchTaskTest.java
@@ -0,0 +1,279 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ **********************************************************************************************************************/
+
+package eu.stratosphere.pact.runtime.task;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import eu.stratosphere.api.common.functions.GenericJoiner;
+import eu.stratosphere.api.java.record.functions.JoinFunction;
+import eu.stratosphere.api.java.typeutils.runtime.record.RecordComparator;
+import eu.stratosphere.api.java.typeutils.runtime.record.RecordPairComparatorFactory;
+import eu.stratosphere.pact.runtime.test.util.DelayingInfinitiveInputIterator;
+import eu.stratosphere.pact.runtime.test.util.DriverTestBase;
+import eu.stratosphere.pact.runtime.test.util.ExpectedTestException;
+import eu.stratosphere.pact.runtime.test.util.NirvanaOutputList;
+import eu.stratosphere.pact.runtime.test.util.TaskCancelThread;
+import eu.stratosphere.pact.runtime.test.util.UniformRecordGenerator;
+import eu.stratosphere.types.IntValue;
+import eu.stratosphere.types.Key;
+import eu.stratosphere.types.Record;
+import eu.stratosphere.util.Collector;
+
+/**
+ * Tests the cached (resettable) hybrid-hash match drivers, building either on the first or on the
+ * second input: result sizes over repeated runs, forwarding of stub exceptions, and cancellation.
+ */
+public class CachedMatchTaskTest extends DriverTestBase<GenericJoiner<Record, Record, Record>>
+{
+ private static final long HASH_MEM = 6*1024*1024;
+
+ private static final long SORT_MEM = 3*1024*1024;
+
+ /** Runs per resettable test; the first run initializes the driver, every later run resets it. */
+ private static final int ITERATIONS = 3;
+
+ @SuppressWarnings("unchecked")
+ private final RecordComparator comparator1 = new RecordComparator(
+ new int[]{0}, (Class<? extends Key<?>>[])new Class[]{ IntValue.class });
+
+ @SuppressWarnings("unchecked")
+ private final RecordComparator comparator2 = new RecordComparator(
+ new int[]{0}, (Class<? extends Key<?>>[])new Class[]{ IntValue.class });
+
+ private final List<Record> outList = new ArrayList<Record>();
+
+
+ public CachedMatchTaskTest() {
+ super(HASH_MEM, 2, SORT_MEM);
+ }
+
+ @Test
+ public void testHash1MatchTask() {
+ runMatchTask(true, 20, 1, 10, 2);
+ }
+
+ @Test
+ public void testHash2MatchTask() {
+ runMatchTask(false, 20, 1, 20, 1);
+ }
+
+ @Test
+ public void testHash3MatchTask() {
+ runMatchTask(true, 20, 1, 20, 20);
+ }
+
+ @Test
+ public void testHash4MatchTask() {
+ runMatchTask(false, 20, 20, 20, 1);
+ }
+
+ @Test
+ public void testHash5MatchTask() {
+ runMatchTask(true, 20, 20, 20, 20);
+ }
+
+ @Test
+ public void testFailingHashFirstMatchTask() {
+ runFailingMatchTask(true);
+ }
+
+ @Test
+ public void testFailingHashSecondMatchTask() {
+ runFailingMatchTask(false);
+ }
+
+ @Test
+ public void testCancelHashMatchTaskWhileBuildFirst() {
+ // the delaying (presumably endless) input is the first one, i.e. the build side
+ addInput(new DelayingInfinitiveInputIterator(100));
+ addInput(new UniformRecordGenerator(20, 20, false));
+ runCanceledTask(true);
+ }
+
+ @Test
+ public void testHashCancelMatchTaskWhileBuildSecond() {
+ // the delaying (presumably endless) input is the second one, i.e. the build side
+ addInput(new UniformRecordGenerator(20, 20, false));
+ addInput(new DelayingInfinitiveInputIterator(100));
+ runCanceledTask(false);
+ }
+
+ @Test
+ public void testHashFirstCancelMatchTaskWhileMatching() {
+ addInput(new UniformRecordGenerator(20, 20, false));
+ addInput(new UniformRecordGenerator(20, 20, false));
+ runCanceledTask(true);
+ }
+
+ @Test
+ public void testHashSecondCancelMatchTaskWhileMatching() {
+ addInput(new UniformRecordGenerator(20, 20, false));
+ addInput(new UniformRecordGenerator(20, 20, false));
+ runCanceledTask(false);
+ }
+
+ // =================================================================================================
+
+ /**
+ * Joins two uniform inputs with a cached hash-join driver over {@link #ITERATIONS} resettable runs
+ * and checks the number of collected result records.
+ *
+ * @param buildFirst true to build the hash table on the first input, false for the second
+ */
+ private void runMatchTask(boolean buildFirst, int keyCnt1, int valCnt1, int keyCnt2, int valCnt2) {
+ addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
+ addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
+ configureCachedJoin(buildFirst, this.outList);
+
+ try {
+ testResettableDriver(createCachedDriver(buildFirst), MockMatchStub.class, ITERATIONS);
+ } catch (Exception e) {
+ e.printStackTrace();
+ Assert.fail("Test caused an exception: " + e.getMessage());
+ }
+
+ // NOTE(review): the expectation covers a single run although the driver runs ITERATIONS times;
+ // presumably the probe-side input is exhausted after the first run -- confirm.
+ final int expCnt = valCnt1*valCnt2*Math.min(keyCnt1, keyCnt2);
+ Assert.assertEquals("Wrong result set size.", expCnt, this.outList.size());
+ this.outList.clear();
+ }
+
+ /**
+ * Runs a cached hash join whose stub throws from its tenth call on and checks that the
+ * {@link ExpectedTestException} is forwarded to the caller.
+ */
+ private void runFailingMatchTask(boolean buildFirst) {
+ addInput(new UniformRecordGenerator(20, 20, false));
+ addInput(new UniformRecordGenerator(20, 20, false));
+ configureCachedJoin(buildFirst, new NirvanaOutputList());
+
+ try {
+ testResettableDriver(createCachedDriver(buildFirst), MockFailingMatchStub.class, ITERATIONS);
+ Assert.fail("Function exception was not forwarded.");
+ } catch (ExpectedTestException etex) {
+ // good!
+ } catch (Exception e) {
+ e.printStackTrace();
+ Assert.fail("Test caused an exception: " + e.getMessage());
+ }
+ }
+
+ /**
+ * Runs a cached hash join in its own thread, cancels it through a {@link TaskCancelThread} and
+ * checks that the canceled run finished without an exception.
+ */
+ private void runCanceledTask(boolean buildFirst) {
+ configureCachedJoin(buildFirst, new NirvanaOutputList());
+
+ // NOTE(review): this goes through testDriver rather than testResettableDriver, so initialize()
+ // is presumably never called on the cached driver -- confirm the cached build path is exercised.
+ final ResettablePactDriver<?, ?> testTask = createCachedDriver(buildFirst);
+ final AtomicBoolean success = new AtomicBoolean(false);
+
+ Thread taskRunner = new Thread() {
+ @Override
+ public void run() {
+ try {
+ testDriver(testTask, MockMatchStub.class);
+ success.set(true);
+ } catch (Exception ie) {
+ ie.printStackTrace();
+ }
+ }
+ };
+ taskRunner.start();
+
+ TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this);
+ tct.start();
+
+ try {
+ tct.join();
+ taskRunner.join();
+ } catch(InterruptedException ie) {
+ // restore the interrupt flag so the interruption is not lost when failing
+ Thread.currentThread().interrupt();
+ Assert.fail("Joining threads failed");
+ }
+
+ Assert.assertTrue("Test threw an exception even though it was properly canceled.", success.get());
+ }
+
+ /** Registers key comparators, output, the cached hash strategy and the driver memory share. */
+ private void configureCachedJoin(boolean buildFirst, List<Record> output) {
+ addInputComparator(this.comparator1);
+ addInputComparator(this.comparator2);
+ getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
+ setOutput(output);
+ getTaskConfig().setDriverStrategy(buildFirst
+ ? DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED : DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED);
+ getTaskConfig().setRelativeMemoryDriver(1.0f);
+ }
+
+ /** Creates a new cached hash-join driver whose build side is the first or the second input. */
+ private static ResettablePactDriver<?, ?> createCachedDriver(boolean buildFirst) {
+ if (buildFirst) {
+ return new BuildFirstCachedMatchDriver<Record, Record, Record>();
+ }
+ return new BuildSecondCachedMatchDriver<Record, Record, Record>();
+ }
+
+ // =================================================================================================
+
+ /** Join stub that emits the first record of every matched pair. */
+ public static final class MockMatchStub extends JoinFunction {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public void join(Record record1, Record record2, Collector<Record> out) throws Exception {
+ out.collect(record1);
+ }
+ }
+
+ /** Join stub that emits the first record of each pair, but throws from its tenth invocation on. */
+ public static final class MockFailingMatchStub extends JoinFunction {
+ private static final long serialVersionUID = 1L;
+
+ private int cnt = 0;
+
+ @Override
+ public void join(Record record1, Record record2, Collector<Record> out) {
+ if (++this.cnt >= 10) {
+ throw new ExpectedTestException();
+ }
+
+ out.collect(record1);
+ }
+ }
+
+ /** Join stub that emits nothing and sleeps 100 ms per join call. */
+ public static final class MockDelayingMatchStub extends JoinFunction {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public void join(Record record1, Record record2, Collector<Record> out) {
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException e) {
+ // restore the interrupt flag instead of silently swallowing the interruption
+ Thread.currentThread().interrupt();
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/47a68adc/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java
----------------------------------------------------------------------
diff --git a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java
index 531382e..4156204 100644
--- a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java
+++ b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java
@@ -35,6 +35,7 @@ import eu.stratosphere.api.java.typeutils.runtime.record.RecordSerializerFactory
import eu.stratosphere.pact.runtime.sort.UnilateralSortMerger;
import eu.stratosphere.pact.runtime.task.PactDriver;
import eu.stratosphere.pact.runtime.task.PactTaskContext;
+import eu.stratosphere.pact.runtime.task.ResettablePactDriver;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.types.Record;
import eu.stratosphere.util.Collector;
@@ -194,16 +195,48 @@ public class DriverTestBase<S extends Function> implements PactTaskContext<S, Re
catch (Throwable t) {}
}
+ // if the driver is resettable, invoke its teardown
+ if (this.driver instanceof ResettablePactDriver) {
+ final ResettablePactDriver<?, ?> resDriver = (ResettablePactDriver<?, ?>) this.driver;
+ try {
+ resDriver.teardown();
+ } catch (Throwable t) {
+ throw new Exception("Error while shutting down an iterative operator: " + t.getMessage(), t);
+ }
+ }
+
// drop exception, if the task was canceled
if (this.running) {
throw ex;
}
+
}
finally {
driver.cleanup();
}
}
+ /**
+ * Runs a resettable driver {@code iterations} times against this context. setup(this) is called
+ * once, initialize() before the first run, reset() before each later run and teardown() at the end;
+ * an exception from a run propagates without reaching that final teardown() call.
+ */
+ @SuppressWarnings({"unchecked","rawtypes"})
+ public void testResettableDriver(ResettablePactDriver driver, Class stubClass, int iterations) throws Exception {
+ driver.setup(this);
+
+ for (int run = 0; run < iterations; run++) {
+ // only the first run initializes; afterwards the driver is merely reset
+ if (run != 0) {
+ driver.reset();
+ } else {
+ driver.initialize();
+ }
+ testDriver(driver, stubClass);
+ }
+ driver.teardown();
+ }
+
public void cancel() throws Exception {
this.running = false;
this.driver.cancel();