You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bc...@apache.org on 2016/10/13 00:38:29 UTC
[2/4] incubator-beam git commit: [BEAM-65] SplittableDoFn prototype.
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java
index 3eee74a..f671a67 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java
@@ -25,6 +25,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowingInternals;
@@ -37,8 +38,8 @@ import org.joda.time.Instant;
/**
* Utility class containing adapters to/from {@link DoFn} and {@link OldDoFn}.
*
- * @deprecated This class will go away when we start running {@link DoFn}'s directly (using
- * {@link DoFnInvoker}) rather than via {@link OldDoFn}.
+ * @deprecated This class will go away when we start running {@link DoFn}'s directly (using {@link
+ * DoFnInvoker}) rather than via {@link OldDoFn}.
*/
@Deprecated
public class DoFnAdapters {
@@ -176,6 +177,18 @@ public class DoFnAdapters {
}
/**
+ * If the fn was created using {@link #toOldDoFn}, returns the original {@link DoFn}. Otherwise,
+ * returns {@code null}.
+ */
+ public static <InputT, OutputT> DoFn<InputT, OutputT> getDoFn(OldDoFn<InputT, OutputT> fn) {
+ if (fn instanceof SimpleDoFnAdapter) {
+ return ((SimpleDoFnAdapter<InputT, OutputT>) fn).fn;
+ } else {
+ return null;
+ }
+ }
+
+ /**
* Wraps a {@link DoFn} that doesn't require access to {@link BoundedWindow} as an {@link
* OldDoFn}.
*/
@@ -324,6 +337,11 @@ public class DoFnAdapters {
public DoFn.OutputReceiver<OutputT> outputReceiver() {
throw new UnsupportedOperationException("outputReceiver() exists only for testing");
}
+
+ @Override
+ public <RestrictionT> RestrictionTracker<RestrictionT> restrictionTracker() {
+ throw new UnsupportedOperationException("This is a non-splittable DoFn");
+ }
}
/**
@@ -412,5 +430,10 @@ public class DoFnAdapters {
public DoFn.OutputReceiver<OutputT> outputReceiver() {
throw new UnsupportedOperationException("outputReceiver() exists only for testing");
}
+
+ @Override
+ public <RestrictionT> RestrictionTracker<RestrictionT> restrictionTracker() {
+ throw new UnsupportedOperationException("This is a non-splittable DoFn");
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
index 11a4cbd..302bb02 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
@@ -46,7 +46,9 @@ import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowingInternals;
import org.apache.beam.sdk.util.state.InMemoryStateInternals;
+import org.apache.beam.sdk.util.state.InMemoryTimerInternals;
import org.apache.beam.sdk.util.state.StateInternals;
+import org.apache.beam.sdk.util.state.TimerCallback;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.sdk.values.TupleTag;
@@ -222,8 +224,11 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
if (state == State.UNINITIALIZED) {
initializeState();
}
- TestContext<InputT, OutputT> context = createContext(fn);
+ TestContext context = createContext(fn);
context.setupDelegateAggregators();
+ // State and timer internals are per-bundle.
+ stateInternals = InMemoryStateInternals.forKey(new Object());
+ timerInternals = new InMemoryTimerInternals();
try {
fn.startBundle(context);
} catch (UserCodeException e) {
@@ -460,6 +465,35 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
return extractAggregatorValue(agg.getName(), agg.getCombineFn());
}
+ private static TimerCallback collectInto(final List<TimerInternals.TimerData> firedTimers) {
+ return new TimerCallback() {
+ @Override
+ public void onTimer(TimerInternals.TimerData timer) throws Exception {
+ firedTimers.add(timer);
+ }
+ };
+ }
+
+ public List<TimerInternals.TimerData> advanceInputWatermark(Instant newWatermark) {
+ try {
+ final List<TimerInternals.TimerData> firedTimers = new ArrayList<>();
+ timerInternals.advanceInputWatermark(collectInto(firedTimers), newWatermark);
+ return firedTimers;
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public List<TimerInternals.TimerData> advanceProcessingTime(Instant newProcessingTime) {
+ try {
+ final List<TimerInternals.TimerData> firedTimers = new ArrayList<>();
+ timerInternals.advanceProcessingTime(collectInto(firedTimers), newProcessingTime);
+ return firedTimers;
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
private <AccumT, AggregateT> AggregateT extractAggregatorValue(
String name, CombineFn<?, AccumT, AggregateT> combiner) {
@SuppressWarnings("unchecked")
@@ -476,41 +510,27 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
return MoreObjects.firstNonNull(elems, Collections.<WindowedValue<T>>emptyList());
}
- private TestContext<InputT, OutputT> createContext(OldDoFn<InputT, OutputT> fn) {
- return new TestContext<>(fn, options, mainOutputTag, outputs, accumulators);
+ private TestContext createContext(OldDoFn<InputT, OutputT> fn) {
+ return new TestContext();
}
- private static class TestContext<InT, OutT> extends OldDoFn<InT, OutT>.Context {
- private final PipelineOptions opts;
- private final TupleTag<OutT> mainOutputTag;
- private final Map<TupleTag<?>, List<WindowedValue<?>>> outputs;
- private final Map<String, Object> accumulators;
-
- public TestContext(
- OldDoFn<InT, OutT> fn,
- PipelineOptions opts,
- TupleTag<OutT> mainOutputTag,
- Map<TupleTag<?>, List<WindowedValue<?>>> outputs,
- Map<String, Object> accumulators) {
+ private class TestContext extends OldDoFn<InputT, OutputT>.Context {
+ TestContext() {
fn.super();
- this.opts = opts;
- this.mainOutputTag = mainOutputTag;
- this.outputs = outputs;
- this.accumulators = accumulators;
}
@Override
public PipelineOptions getPipelineOptions() {
- return opts;
+ return options;
}
@Override
- public void output(OutT output) {
+ public void output(OutputT output) {
sideOutput(mainOutputTag, output);
}
@Override
- public void outputWithTimestamp(OutT output, Instant timestamp) {
+ public void outputWithTimestamp(OutputT output, Instant timestamp) {
sideOutputWithTimestamp(mainOutputTag, output, timestamp);
}
@@ -570,40 +590,27 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
}
}
- private TestProcessContext<InputT, OutputT> createProcessContext(
+ private TestProcessContext createProcessContext(
OldDoFn<InputT, OutputT> fn,
TimestampedValue<InputT> elem) {
WindowedValue<InputT> windowedValue = WindowedValue.timestampedValueInGlobalWindow(
elem.getValue(), elem.getTimestamp());
- return new TestProcessContext<>(fn,
- createContext(fn),
- windowedValue,
- mainOutputTag,
- sideInputs);
- }
-
- private static class TestProcessContext<InT, OutT> extends OldDoFn<InT, OutT>.ProcessContext {
- private final TestContext<InT, OutT> context;
- private final TupleTag<OutT> mainOutputTag;
- private final WindowedValue<InT> element;
- private final Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs;
-
- private TestProcessContext(
- OldDoFn<InT, OutT> fn,
- TestContext<InT, OutT> context,
- WindowedValue<InT> element,
- TupleTag<OutT> mainOutputTag,
- Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs) {
+ return new TestProcessContext(windowedValue);
+ }
+
+ private class TestProcessContext extends OldDoFn<InputT, OutputT>.ProcessContext {
+ private final TestContext context;
+ private final WindowedValue<InputT> element;
+
+ private TestProcessContext(WindowedValue<InputT> element) {
fn.super();
- this.context = context;
+ this.context = createContext(fn);
this.element = element;
- this.mainOutputTag = mainOutputTag;
- this.sideInputs = sideInputs;
}
@Override
- public InT element() {
+ public InputT element() {
return element.getValue();
}
@@ -638,10 +645,8 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
}
@Override
- public WindowingInternals<InT, OutT> windowingInternals() {
- return new WindowingInternals<InT, OutT>() {
- StateInternals<?> stateInternals = InMemoryStateInternals.forKey(new Object());
-
+ public WindowingInternals<InputT, OutputT> windowingInternals() {
+ return new WindowingInternals<InputT, OutputT>() {
@Override
public StateInternals<?> stateInternals() {
return stateInternals;
@@ -649,7 +654,7 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
@Override
public void outputWindowedValue(
- OutT output,
+ OutputT output,
Instant timestamp,
Collection<? extends BoundedWindow> windows,
PaneInfo pane) {
@@ -658,8 +663,7 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
@Override
public TimerInternals timerInternals() {
- throw
- new UnsupportedOperationException("Timer Internals are not supported in DoFnTester");
+ return timerInternals;
}
@Override
@@ -695,12 +699,12 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
}
@Override
- public void output(OutT output) {
+ public void output(OutputT output) {
sideOutput(mainOutputTag, output);
}
@Override
- public void outputWithTimestamp(OutT output, Instant timestamp) {
+ public void outputWithTimestamp(OutputT output, Instant timestamp) {
sideOutputWithTimestamp(mainOutputTag, output, timestamp);
}
@@ -774,6 +778,9 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
/** The outputs from the {@link DoFn} under test. */
private Map<TupleTag<?>, List<WindowedValue<?>>> outputs;
+ private InMemoryStateInternals<?> stateInternals;
+ private InMemoryTimerInternals timerInternals;
+
/** The state of processing of the {@link DoFn} under test. */
private State state = State.UNINITIALIZED;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 2443d8e..fdef908 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.transforms;
+import static com.google.common.base.Preconditions.checkArgument;
+
import com.google.common.collect.ImmutableList;
import java.io.Serializable;
import java.util.Arrays;
@@ -27,6 +29,7 @@ import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.runners.PipelineRunner;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.DisplayData.Builder;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.StringUtils;
@@ -716,6 +719,8 @@ public class ParDo {
@Override
public PCollection<OutputT> apply(PCollection<? extends InputT> input) {
+ checkArgument(
+ !isSplittable(fn), "Splittable DoFn not supported by the current runner");
return PCollection.<OutputT>createPrimitiveOutputInternal(
input.getPipeline(),
input.getWindowingStrategy(),
@@ -925,6 +930,9 @@ public class ParDo {
@Override
public PCollectionTuple apply(PCollection<? extends InputT> input) {
+ checkArgument(
+ !isSplittable(fn), "Splittable DoFn not supported by the current runner");
+
PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal(
input.getPipeline(),
TupleTagList.of(mainOutputTag).and(sideOutputTags.getAll()),
@@ -997,4 +1005,15 @@ public class ParDo {
.add(DisplayData.item("fn", fnClass)
.withLabel("Transform Function"));
}
+
+ /**
+  * Returns whether {@code oldDoFn} wraps a {@link DoFn} whose {@link DoFn.ProcessElement} method is
+  * splittable. An {@link OldDoFn} that does not wrap a {@link DoFn} is never considered splittable.
+  */
+ private static boolean isSplittable(OldDoFn<?, ?> oldDoFn) {
+   DoFn<?, ?> wrapped = DoFnAdapters.getDoFn(oldDoFn);
+   return wrapped != null
+       && DoFnSignatures.INSTANCE
+           .getOrParseSignature(wrapped.getClass())
+           .processElement()
+           .isSplittable();
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
index eb6961c..9672d53 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
@@ -17,7 +17,10 @@
*/
package org.apache.beam.sdk.transforms.reflect;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
/**
* Interface for invoking the {@code DoFn} processing methods.
@@ -43,7 +46,28 @@ public interface DoFnInvoker<InputT, OutputT> {
*
* @param c The {@link DoFn.ProcessContext} to invoke the fn with.
* @param extra Factory for producing extra parameter objects (such as window), if necessary.
+ * @return The {@link DoFn.ProcessContinuation} returned by the underlying method, or {@link
+ * DoFn.ProcessContinuation#stop()} if it returns {@code void}.
*/
- void invokeProcessElement(
+ DoFn.ProcessContinuation invokeProcessElement(
DoFn<InputT, OutputT>.ProcessContext c, DoFn.ExtraContextFactory<InputT, OutputT> extra);
+
+ /**
+ * Invoke the {@link DoFn.GetInitialRestriction} method on the bound {@link DoFn}.
+ *
+ * @param element the input element to obtain an initial restriction for
+ * @return the restriction returned by the {@link DoFn}
+ */
+ <RestrictionT> RestrictionT invokeGetInitialRestriction(InputT element);
+
+ /**
+ * Invoke the {@link DoFn.GetRestrictionCoder} method on the bound {@link DoFn}. Called only
+ * during pipeline construction time.
+ *
+ * <p>Invokers generated by {@link DoFnInvokers} fall back to looking the coder up in {@code
+ * coderRegistry} when a splittable {@link DoFn} does not declare this method.
+ *
+ * @param coderRegistry registry available for coder inference
+ * @return the coder for the {@link DoFn}'s restriction type
+ */
+ <RestrictionT> Coder<RestrictionT> invokeGetRestrictionCoder(CoderRegistry coderRegistry);
+
+ /**
+ * Invoke the {@link DoFn.SplitRestriction} method on the bound {@link DoFn}.
+ *
+ * <p>Invokers generated by {@link DoFnInvokers} output {@code restriction} unchanged when the
+ * {@link DoFn} does not declare this method.
+ *
+ * @param element the input element associated with {@code restriction}
+ * @param restriction the restriction to split
+ * @param restrictionReceiver receives each resulting restriction
+ */
+ <RestrictionT> void invokeSplitRestriction(
+ InputT element,
+ RestrictionT restriction,
+ DoFn.OutputReceiver<RestrictionT> restrictionReceiver);
+
+ /**
+ * Invoke the {@link DoFn.NewTracker} method on the bound {@link DoFn}.
+ *
+ * @param restriction the restriction to create a tracker for
+ * @return the tracker returned by the {@link DoFn}
+ */
+ <RestrictionT, TrackerT extends RestrictionTracker<RestrictionT>> TrackerT invokeNewTracker(
+ RestrictionT restriction);
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java
index da88587..fd057c3 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java
@@ -19,6 +19,7 @@ package org.apache.beam.sdk.transforms.reflect;
import static com.google.common.base.Preconditions.checkArgument;
+import com.google.common.reflect.TypeToken;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
@@ -26,6 +27,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.LinkedHashMap;
+import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import net.bytebuddy.ByteBuddy;
@@ -35,10 +37,12 @@ import net.bytebuddy.description.method.MethodDescription;
import net.bytebuddy.description.modifier.FieldManifestation;
import net.bytebuddy.description.modifier.Visibility;
import net.bytebuddy.description.type.TypeDescription;
+import net.bytebuddy.description.type.TypeList;
import net.bytebuddy.dynamic.DynamicType;
import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
import net.bytebuddy.dynamic.scaffold.InstrumentedType;
import net.bytebuddy.dynamic.scaffold.subclass.ConstructorStrategy;
+import net.bytebuddy.implementation.ExceptionMethod;
import net.bytebuddy.implementation.FixedValue;
import net.bytebuddy.implementation.Implementation;
import net.bytebuddy.implementation.Implementation.Context;
@@ -48,6 +52,7 @@ import net.bytebuddy.implementation.bytecode.ByteCodeAppender;
import net.bytebuddy.implementation.bytecode.StackManipulation;
import net.bytebuddy.implementation.bytecode.Throw;
import net.bytebuddy.implementation.bytecode.assign.Assigner;
+import net.bytebuddy.implementation.bytecode.assign.TypeCasting;
import net.bytebuddy.implementation.bytecode.member.FieldAccess;
import net.bytebuddy.implementation.bytecode.member.MethodInvocation;
import net.bytebuddy.implementation.bytecode.member.MethodReturn;
@@ -57,12 +62,17 @@ import net.bytebuddy.jar.asm.MethodVisitor;
import net.bytebuddy.jar.asm.Opcodes;
import net.bytebuddy.jar.asm.Type;
import net.bytebuddy.matcher.ElementMatchers;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ExtraContextFactory;
import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
import org.apache.beam.sdk.transforms.DoFnAdapters;
import org.apache.beam.sdk.transforms.OldDoFn;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.beam.sdk.values.TypeDescriptor;
/** Dynamically generates {@link DoFnInvoker} instances for invoking a {@link DoFn}. */
public class DoFnInvokers {
@@ -81,10 +91,10 @@ public class DoFnInvokers {
private DoFnInvokers() {}
/**
- * Creates a {@link DoFnInvoker} for the given {@link Object}, which should be either a
- * {@link DoFn} or an {@link OldDoFn}. The expected use would be to deserialize a user's
- * function as an {@link Object} and then pass it to this method, so there is no need to
- * statically specify what sort of object it is.
+ * Creates a {@link DoFnInvoker} for the given {@link Object}, which should be either a {@link
+ * DoFn} or an {@link OldDoFn}. The expected use would be to deserialize a user's function as an
+ * {@link Object} and then pass it to this method, so there is no need to statically specify what
+ * sort of object it is.
*
* @deprecated this is to be used only as a migration path for decoupling upgrades
*/
@@ -92,15 +102,16 @@ public class DoFnInvokers {
public DoFnInvoker<?, ?> invokerFor(Object deserializedFn) {
if (deserializedFn instanceof DoFn) {
return newByteBuddyInvoker((DoFn<?, ?>) deserializedFn);
- } else if (deserializedFn instanceof OldDoFn){
+ } else if (deserializedFn instanceof OldDoFn) {
return new OldDoFnInvoker<>((OldDoFn<?, ?>) deserializedFn);
} else {
- throw new IllegalArgumentException(String.format(
- "Cannot create a %s for %s; it should be either a %s or an %s.",
- DoFnInvoker.class.getSimpleName(),
- deserializedFn.toString(),
- DoFn.class.getSimpleName(),
- OldDoFn.class.getSimpleName()));
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot create a %s for %s; it should be either a %s or an %s.",
+ DoFnInvoker.class.getSimpleName(),
+ deserializedFn.toString(),
+ DoFn.class.getSimpleName(),
+ OldDoFn.class.getSimpleName()));
}
}
@@ -113,12 +124,13 @@ public class DoFnInvokers {
}
@Override
- public void invokeProcessElement(
+ public DoFn.ProcessContinuation invokeProcessElement(
DoFn<InputT, OutputT>.ProcessContext c, ExtraContextFactory<InputT, OutputT> extra) {
OldDoFn<InputT, OutputT>.ProcessContext oldCtx =
DoFnAdapters.adaptProcessContext(fn, c, extra);
try {
fn.processElement(oldCtx);
+ return DoFn.ProcessContinuation.stop();
} catch (Throwable exc) {
throw UserCodeException.wrap(exc);
}
@@ -161,14 +173,37 @@ public class DoFnInvokers {
throw UserCodeException.wrap(exc);
}
}
+
+ @Override
+ public <RestrictionT> RestrictionT invokeGetInitialRestriction(InputT element) {
+   throw notSplittable();
+ }
+
+ @Override
+ public <RestrictionT> Coder<RestrictionT> invokeGetRestrictionCoder(
+     CoderRegistry coderRegistry) {
+   throw notSplittable();
+ }
+
+ @Override
+ public <RestrictionT> void invokeSplitRestriction(
+     InputT element, RestrictionT restriction, DoFn.OutputReceiver<RestrictionT> receiver) {
+   throw notSplittable();
+ }
+
+ @Override
+ public <RestrictionT, TrackerT extends RestrictionTracker<RestrictionT>>
+     TrackerT invokeNewTracker(RestrictionT restriction) {
+   throw notSplittable();
+ }
+
+ /** Builds the exception raised by every splittable-DoFn entry point of an {@link OldDoFn}. */
+ private UnsupportedOperationException notSplittable() {
+   return new UnsupportedOperationException("OldDoFn is not splittable");
+ }
}
/** @return the {@link DoFnInvoker} for the given {@link DoFn}. */
@SuppressWarnings({"unchecked", "rawtypes"})
public <InputT, OutputT> DoFnInvoker<InputT, OutputT> newByteBuddyInvoker(
DoFn<InputT, OutputT> fn) {
- return newByteBuddyInvoker(DoFnSignatures.INSTANCE.getOrParseSignature(
- (Class) fn.getClass()), fn);
+ return newByteBuddyInvoker(
+ DoFnSignatures.INSTANCE.getOrParseSignature((Class) fn.getClass()), fn);
}
/** @return the {@link DoFnInvoker} for the given {@link DoFn}. */
@@ -214,6 +249,32 @@ public class DoFnInvokers {
return constructor;
}
+ /**
+  * Fallback for {@link DoFn.SplitRestriction}, used when a {@link DoFn} does not declare one; the
+  * generated invoker's {@code invokeSplitRestriction} is delegated here by bytebuddy.
+  */
+ public static class DefaultSplitRestriction {
+   /** Emits {@code original} as-is, i.e. performs no splitting. */
+   @SuppressWarnings("unused")
+   public static <InputT, RestrictionT> void invokeSplitRestriction(
+       InputT element, RestrictionT original, DoFn.OutputReceiver<RestrictionT> out) {
+     out.output(original);
+   }
+ }
+
+ /** Default implementation of {@link DoFn.GetRestrictionCoder}, for delegation by bytebuddy. */
+ public static class DefaultRestrictionCoder {
+ private final TypeToken<?> restrictionType;
+
+ DefaultRestrictionCoder(TypeToken<?> restrictionType) {
+ this.restrictionType = restrictionType;
+ }
+
+ /** Doesn't split the restriction. */
+ @SuppressWarnings({"unused", "unchecked"})
+ public <RestrictionT> Coder<RestrictionT> invokeGetRestrictionCoder(CoderRegistry registry)
+ throws CannotProvideCoderException {
+ return (Coder) registry.getCoder(TypeDescriptor.of(restrictionType.getType()));
+ }
+ }
+
/** Generates a {@link DoFnInvoker} class for the given {@link DoFnSignature}. */
private static Class<? extends DoFnInvoker<?, ?>> generateInvokerClass(DoFnSignature signature) {
Class<? extends DoFn<?, ?>> fnClass = signature.fnClass();
@@ -247,7 +308,15 @@ public class DoFnInvokers {
.method(ElementMatchers.named("invokeSetup"))
.intercept(delegateOrNoop(signature.setup()))
.method(ElementMatchers.named("invokeTeardown"))
- .intercept(delegateOrNoop(signature.teardown()));
+ .intercept(delegateOrNoop(signature.teardown()))
+ .method(ElementMatchers.named("invokeGetInitialRestriction"))
+ .intercept(delegateWithDowncastOrThrow(signature.getInitialRestriction()))
+ .method(ElementMatchers.named("invokeSplitRestriction"))
+ .intercept(splitRestrictionDelegation(signature))
+ .method(ElementMatchers.named("invokeGetRestrictionCoder"))
+ .intercept(getRestrictionCoderDelegation(signature))
+ .method(ElementMatchers.named("invokeNewTracker"))
+ .intercept(delegateWithDowncastOrThrow(signature.newTracker()));
DynamicType.Unloaded<?> unloaded = builder.make();
@@ -260,6 +329,28 @@ public class DoFnInvokers {
return res;
}
+ /**
+  * Chooses the implementation of {@code invokeGetRestrictionCoder}. Non-splittable {@link DoFn}s
+  * throw {@link UnsupportedOperationException}. Splittable ones delegate to their {@link
+  * DoFn.GetRestrictionCoder} method when declared, and otherwise infer the coder from the
+  * restriction type of their {@link DoFn.GetInitialRestriction} method.
+  */
+ private static Implementation getRestrictionCoderDelegation(DoFnSignature signature) {
+   if (!signature.processElement().isSplittable()) {
+     return ExceptionMethod.throwing(UnsupportedOperationException.class);
+   }
+   DoFnSignature.GetRestrictionCoderMethod coderMethod = signature.getRestrictionCoder();
+   if (coderMethod != null) {
+     return new DowncastingParametersMethodDelegation(coderMethod.targetMethod());
+   }
+   return MethodDelegation.to(
+       new DefaultRestrictionCoder(signature.getInitialRestriction().restrictionT()));
+ }
+
+ /**
+  * Implements {@code invokeSplitRestriction} by delegating to the {@link DoFn}'s {@link
+  * DoFn.SplitRestriction} method, or to {@link DefaultSplitRestriction} when it declares none.
+  */
+ private static Implementation splitRestrictionDelegation(DoFnSignature signature) {
+   DoFnSignature.SplitRestrictionMethod splitMethod = signature.splitRestriction();
+   return (splitMethod == null)
+       ? MethodDelegation.to(DefaultSplitRestriction.class)
+       : new DowncastingParametersMethodDelegation(splitMethod.targetMethod());
+ }
+
/** Delegates to the given method if available, or does nothing. */
private static Implementation delegateOrNoop(DoFnSignature.DoFnMethod method) {
return (method == null)
@@ -267,6 +358,13 @@ public class DoFnInvokers {
: new DoFnMethodDelegation(method.targetMethod());
}
+ /**
+  * Delegates to {@code method}, downcasting each argument to the target parameter type; when
+  * {@code method} is {@code null}, the generated method throws {@link
+  * UnsupportedOperationException} instead.
+  */
+ private static Implementation delegateWithDowncastOrThrow(DoFnSignature.DoFnMethod method) {
+   if (method != null) {
+     return new DowncastingParametersMethodDelegation(method.targetMethod());
+   }
+   return ExceptionMethod.throwing(UnsupportedOperationException.class);
+ }
+
/**
* Implements a method of {@link DoFnInvoker} (the "instrumented method") by delegating to a
* "target method" of the wrapped {@link DoFn}.
@@ -374,12 +472,37 @@ public class DoFnInvokers {
}
/**
+ * Passes parameters to the delegated method by downcasting each parameter of non-primitive type
+ * to its expected type.
+ */
+ private static class DowncastingParametersMethodDelegation extends DoFnMethodDelegation {
+ DowncastingParametersMethodDelegation(Method method) {
+ super(method);
+ }
+
+ @Override
+ protected StackManipulation beforeDelegation(MethodDescription instrumentedMethod) {
+ // Emits one load (plus a cast for reference types) per parameter of the *target* method.
+ // The instrumented method's own parameter list is not consulted, so if the target declares
+ // fewer parameters than the invoker method, the trailing invoker arguments are not passed.
+ List<StackManipulation> pushParameters = new ArrayList<>();
+ TypeList.Generic paramTypes = targetMethod.getParameters().asTypeList();
+ for (int i = 0; i < paramTypes.size(); i++) {
+ TypeDescription.Generic paramT = paramTypes.get(i);
+ // Slot 0 of the instance method holds `this`, so argument i is read from slot i + 1.
+ // NOTE(review): this assumes every preceding argument occupies a single local-variable slot
+ // (no long/double) and that the load opcode chosen from the target's parameter type matches
+ // the invoker argument's declared type — confirm if primitive parameters are ever allowed.
+ pushParameters.add(MethodVariableAccess.of(paramT).loadOffset(i + 1));
+ if (!paramT.isPrimitive()) {
+ // The invoker declares its arguments with type variables (InputT, RestrictionT), so the
+ // loaded value is cast to the type the target method actually expects.
+ pushParameters.add(TypeCasting.to(paramT));
+ }
+ }
+ return new StackManipulation.Compound(pushParameters);
+ }
+ }
+
+ /**
* Implements the invoker's {@link DoFnInvoker#invokeProcessElement} method by delegating to the
* {@link DoFn.ProcessElement} method.
*/
private static final class ProcessElementDelegation extends DoFnMethodDelegation {
private static final Map<DoFnSignature.Parameter, MethodDescription>
EXTRA_CONTEXT_FACTORY_METHODS;
+ private static final MethodDescription PROCESS_CONTINUATION_STOP_METHOD;
static {
try {
@@ -397,11 +520,21 @@ public class DoFnInvokers {
DoFnSignature.Parameter.OUTPUT_RECEIVER,
new MethodDescription.ForLoadedMethod(
DoFn.ExtraContextFactory.class.getMethod("outputReceiver")));
+ methods.put(
+ DoFnSignature.Parameter.RESTRICTION_TRACKER,
+ new MethodDescription.ForLoadedMethod(
+ DoFn.ExtraContextFactory.class.getMethod("restrictionTracker")));
EXTRA_CONTEXT_FACTORY_METHODS = Collections.unmodifiableMap(methods);
} catch (Exception e) {
throw new RuntimeException(
"Failed to locate an ExtraContextFactory method that was expected to exist", e);
}
+ try {
+   // Resolved once, so generated invokers can call ProcessContinuation.stop() when the user's
+   // @ProcessElement method returns void.
+   PROCESS_CONTINUATION_STOP_METHOD =
+       new MethodDescription.ForLoadedMethod(DoFn.ProcessContinuation.class.getMethod("stop"));
+ } catch (NoSuchMethodException e) {
+   // Keep the reflective failure as the cause so the missing method is diagnosable.
+   throw new RuntimeException("Failed to locate ProcessContinuation.stop()", e);
+ }
}
private final DoFnSignature.ProcessElementMethod signature;
@@ -427,14 +560,26 @@ public class DoFnInvokers {
parameters.add(
new StackManipulation.Compound(
pushExtraContextFactory,
- MethodInvocation.invoke(EXTRA_CONTEXT_FACTORY_METHODS.get(param))));
+ MethodInvocation.invoke(EXTRA_CONTEXT_FACTORY_METHODS.get(param)),
+ // ExtraContextFactory.restrictionTracker() returns a RestrictionTracker,
+ // but the @ProcessElement method expects a concrete subtype of it.
+ // Insert a downcast.
+ (param == DoFnSignature.Parameter.RESTRICTION_TRACKER)
+ ? TypeCasting.to(
+ new TypeDescription.ForLoadedType(signature.trackerT().getRawType()))
+ : StackManipulation.Trivial.INSTANCE));
}
return new StackManipulation.Compound(parameters);
}
@Override
protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) {
- return MethodReturn.VOID;
+ if (TypeDescription.VOID.equals(targetMethod.getReturnType().asErasure())) {
+ return new StackManipulation.Compound(
+ MethodInvocation.invoke(PROCESS_CONTINUATION_STOP_METHOD), MethodReturn.REFERENCE);
+ } else {
+ return MethodReturn.returning(targetMethod.getReturnType().asErasure());
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
index 756df07..632f817 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
@@ -18,11 +18,16 @@
package org.apache.beam.sdk.transforms.reflect;
import com.google.auto.value.AutoValue;
+import com.google.common.reflect.TypeToken;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.values.PCollection;
/**
* Describes the signature of a {@link DoFn}, in particular, which features it uses, which extra
@@ -35,6 +40,9 @@ public abstract class DoFnSignature {
/** Class of the original {@link DoFn} from which this signature was produced. */
public abstract Class<? extends DoFn<?, ?>> fnClass();
+ /**
+ * Whether this {@link DoFn} does a bounded amount of work per element, expressed as a {@link
+ * PCollection.IsBounded} value rather than a boolean.
+ */
+ public abstract PCollection.IsBounded isBoundedPerElement();
+
/** Details about this {@link DoFn}'s {@link DoFn.ProcessElement} method. */
public abstract ProcessElementMethod processElement();
@@ -54,6 +62,22 @@ public abstract class DoFnSignature {
@Nullable
public abstract LifecycleMethod teardown();
+ /**
+ * Details about this {@link DoFn}'s {@link DoFn.GetInitialRestriction} method, or {@code null}
+ * if it does not declare one.
+ */
+ @Nullable
+ public abstract GetInitialRestrictionMethod getInitialRestriction();
+
+ /**
+ * Details about this {@link DoFn}'s {@link DoFn.SplitRestriction} method, or {@code null} if it
+ * does not declare one (invokers then emit the restriction unsplit).
+ */
+ @Nullable
+ public abstract SplitRestrictionMethod splitRestriction();
+
+ /**
+ * Details about this {@link DoFn}'s {@link DoFn.GetRestrictionCoder} method, or {@code null} if
+ * it does not declare one (invokers then infer the coder from the restriction type).
+ */
+ @Nullable
+ public abstract GetRestrictionCoderMethod getRestrictionCoder();
+
+ /**
+ * Details about this {@link DoFn}'s {@link DoFn.NewTracker} method, or {@code null} if it does
+ * not declare one.
+ */
+ @Nullable
+ public abstract NewTrackerMethod newTracker();
+
static Builder builder() {
return new AutoValue_DoFnSignature.Builder();
}
@@ -61,11 +85,16 @@ public abstract class DoFnSignature {
+ /**
+ * AutoValue builder for {@link DoFnSignature}. Setters for {@code @Nullable} properties, such as
+ * the restriction-related methods, may be left uncalled when the {@link DoFn} does not declare
+ * the corresponding method.
+ */
@AutoValue.Builder
abstract static class Builder {
abstract Builder setFnClass(Class<? extends DoFn<?, ?>> fnClass);
+ abstract Builder setIsBoundedPerElement(PCollection.IsBounded isBounded);
abstract Builder setProcessElement(ProcessElementMethod processElement);
abstract Builder setStartBundle(BundleMethod startBundle);
abstract Builder setFinishBundle(BundleMethod finishBundle);
abstract Builder setSetup(LifecycleMethod setup);
abstract Builder setTeardown(LifecycleMethod teardown);
+ abstract Builder setGetInitialRestriction(GetInitialRestrictionMethod getInitialRestriction);
+ abstract Builder setSplitRestriction(SplitRestrictionMethod splitRestriction);
+ abstract Builder setGetRestrictionCoder(GetRestrictionCoderMethod getRestrictionCoder);
+ abstract Builder setNewTracker(NewTrackerMethod newTracker);
abstract DoFnSignature build();
}
@@ -80,6 +109,7 @@ public abstract class DoFnSignature {
BOUNDED_WINDOW,
INPUT_PROVIDER,
OUTPUT_RECEIVER,
+ RESTRICTION_TRACKER
}
/** Describes a {@link DoFn.ProcessElement} method. */
@@ -92,17 +122,33 @@ public abstract class DoFnSignature {
/** Types of optional parameters of the annotated method, in the order they appear. */
public abstract List<Parameter> extraParameters();
+ /**
+ * Concrete type of the {@link RestrictionTracker} parameter, or {@code null} if the method takes
+ * no such parameter (i.e. the {@link DoFn} is not splittable).
+ */
+ @Nullable
+ abstract TypeToken<?> trackerT();
+
+ /**
+ * Whether the {@link DoFn.ProcessElement} method returns a {@link ProcessContinuation} ({@code
+ * true}) rather than {@code void} ({@code false}).
+ */
+ public abstract boolean hasReturnValue();
+
+ /**
+ * Creates a {@link ProcessElementMethod}. {@code extraParameters} is wrapped in an unmodifiable
+ * view; {@code trackerT} may be {@code null} for a non-splittable {@link DoFn}.
+ */
static ProcessElementMethod create(
Method targetMethod,
- List<Parameter> extraParameters) {
+ List<Parameter> extraParameters,
+ TypeToken<?> trackerT,
+ boolean hasReturnValue) {
return new AutoValue_DoFnSignature_ProcessElementMethod(
- targetMethod, Collections.unmodifiableList(extraParameters));
+ targetMethod, Collections.unmodifiableList(extraParameters), trackerT, hasReturnValue);
}
/** Whether this {@link DoFn} uses a Single Window. */
public boolean usesSingleWindow() {
return extraParameters().contains(Parameter.BOUNDED_WINDOW);
}
+
+ /**
+ * Whether this {@link DoFn} is <a href="https://s.apache.org/splittable-do-fn">splittable</a>,
+ * i.e. whether its {@link DoFn.ProcessElement} method takes a {@link RestrictionTracker}
+ * parameter.
+ */
+ public boolean isSplittable() {
+ return extraParameters().contains(Parameter.RESTRICTION_TRACKER);
+ }
}
/** Describes a {@link DoFn.StartBundle} or {@link DoFn.FinishBundle} method. */
@@ -128,4 +174,68 @@ public abstract class DoFnSignature {
return new AutoValue_DoFnSignature_LifecycleMethod(targetMethod);
}
}
+
+ /** Describes a {@link DoFn.GetInitialRestriction} method. */
+ @AutoValue
+ public abstract static class GetInitialRestrictionMethod implements DoFnMethod {
+ /** The annotated method itself. */
+ @Override
+ public abstract Method targetMethod();
+
+ /**
+ * Type of the returned restriction, with the {@link DoFn}'s type variables resolved. The other
+ * restriction-related methods of a splittable {@link DoFn} are validated against this type.
+ */
+ abstract TypeToken<?> restrictionT();
+
+ static GetInitialRestrictionMethod create(Method targetMethod, TypeToken<?> restrictionT) {
+ return new AutoValue_DoFnSignature_GetInitialRestrictionMethod(targetMethod, restrictionT);
+ }
+ }
+
+ /** Describes a {@link DoFn.SplitRestriction} method. */
+ @AutoValue
+ public abstract static class SplitRestrictionMethod implements DoFnMethod {
+ /** The annotated method itself. */
+ @Override
+ public abstract Method targetMethod();
+
+ /**
+ * Type of the restriction taken as the method's second argument. The method returns void; its
+ * third argument is a {@link DoFn.OutputReceiver} of this same restriction type.
+ */
+ abstract TypeToken<?> restrictionT();
+
+ static SplitRestrictionMethod create(Method targetMethod, TypeToken<?> restrictionT) {
+ return new AutoValue_DoFnSignature_SplitRestrictionMethod(targetMethod, restrictionT);
+ }
+ }
+
+ /** Describes a {@link DoFn.NewTracker} method. */
+ @AutoValue
+ public abstract static class NewTrackerMethod implements DoFnMethod {
+ /** The annotated method itself. */
+ @Override
+ public abstract Method targetMethod();
+
+ /** Type of the input restriction, i.e. of the method's single argument. */
+ abstract TypeToken<?> restrictionT();
+
+ /**
+ * Declared return type: a subtype of {@code RestrictionTracker<RestrictionT>}. For a valid
+ * splittable {@link DoFn} it equals the tracker parameter type of {@link DoFn.ProcessElement}.
+ */
+ abstract TypeToken<?> trackerT();
+
+ static NewTrackerMethod create(
+ Method targetMethod, TypeToken<?> restrictionT, TypeToken<?> trackerT) {
+ return new AutoValue_DoFnSignature_NewTrackerMethod(targetMethod, restrictionT, trackerT);
+ }
+ }
+
+ /** Describes a {@link DoFn.GetRestrictionCoder} method. */
+ @AutoValue
+ public abstract static class GetRestrictionCoderMethod implements DoFnMethod {
+ /** The annotated method itself. */
+ @Override
+ public abstract Method targetMethod();
+
+ /**
+ * Declared return type: a subtype of {@link Coder}. For a valid splittable {@link DoFn} it is
+ * also a subtype of {@code Coder<RestrictionT>}.
+ */
+ abstract TypeToken<?> coderT();
+
+ static GetRestrictionCoderMethod create(Method targetMethod, TypeToken<?> coderT) {
+ return new AutoValue_DoFnSignature_GetRestrictionCoderMethod(targetMethod, coderT);
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index ad15127..524ea24 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.transforms.reflect;
+import static com.google.common.base.Preconditions.checkState;
+
import com.google.common.annotations.VisibleForTesting;
import com.google.common.reflect.TypeParameter;
import com.google.common.reflect.TypeToken;
@@ -34,9 +36,12 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.common.ReflectHelpers;
+import org.apache.beam.sdk.values.PCollection;
/**
* Parses a {@link DoFn} and computes its {@link DoFnSignature}. See {@link #getOrParseSignature}.
@@ -88,6 +93,14 @@ public class DoFnSignatures {
Method setupMethod = findAnnotatedMethod(errors, DoFn.Setup.class, fnClass, false);
Method teardownMethod = findAnnotatedMethod(errors, DoFn.Teardown.class, fnClass, false);
+ Method getInitialRestrictionMethod =
+ findAnnotatedMethod(errors, DoFn.GetInitialRestriction.class, fnClass, false);
+ Method splitRestrictionMethod =
+ findAnnotatedMethod(errors, DoFn.SplitRestriction.class, fnClass, false);
+ Method getRestrictionCoderMethod =
+ findAnnotatedMethod(errors, DoFn.GetRestrictionCoder.class, fnClass, false);
+ Method newTrackerMethod = findAnnotatedMethod(errors, DoFn.NewTracker.class, fnClass, false);
+
ErrorReporter processElementErrors =
errors.forMethod(DoFn.ProcessElement.class, processElementMethod);
DoFnSignature.ProcessElementMethod processElement =
@@ -119,7 +132,213 @@ public class DoFnSignatures {
errors.forMethod(DoFn.Teardown.class, teardownMethod), teardownMethod));
}
- return builder.build();
+ DoFnSignature.GetInitialRestrictionMethod getInitialRestriction = null;
+ ErrorReporter getInitialRestrictionErrors = null;
+ if (getInitialRestrictionMethod != null) {
+ getInitialRestrictionErrors =
+ errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestrictionMethod);
+ builder.setGetInitialRestriction(
+ getInitialRestriction =
+ analyzeGetInitialRestrictionMethod(
+ getInitialRestrictionErrors, fnToken, getInitialRestrictionMethod, inputT));
+ }
+
+ DoFnSignature.SplitRestrictionMethod splitRestriction = null;
+ if (splitRestrictionMethod != null) {
+ ErrorReporter splitRestrictionErrors =
+ errors.forMethod(DoFn.SplitRestriction.class, splitRestrictionMethod);
+ builder.setSplitRestriction(
+ splitRestriction =
+ analyzeSplitRestrictionMethod(
+ splitRestrictionErrors, fnToken, splitRestrictionMethod, inputT));
+ }
+
+ DoFnSignature.GetRestrictionCoderMethod getRestrictionCoder = null;
+ if (getRestrictionCoderMethod != null) {
+ ErrorReporter getRestrictionCoderErrors =
+ errors.forMethod(DoFn.GetRestrictionCoder.class, getRestrictionCoderMethod);
+ builder.setGetRestrictionCoder(
+ getRestrictionCoder =
+ analyzeGetRestrictionCoderMethod(
+ getRestrictionCoderErrors, fnToken, getRestrictionCoderMethod));
+ }
+
+ DoFnSignature.NewTrackerMethod newTracker = null;
+ if (newTrackerMethod != null) {
+ ErrorReporter newTrackerErrors = errors.forMethod(DoFn.NewTracker.class, newTrackerMethod);
+ builder.setNewTracker(
+ newTracker = analyzeNewTrackerMethod(newTrackerErrors, fnToken, newTrackerMethod));
+ }
+
+ builder.setIsBoundedPerElement(inferBoundedness(fnToken, processElement, errors));
+
+ DoFnSignature signature = builder.build();
+
+ // Additional validation for splittable DoFn's.
+ if (processElement.isSplittable()) {
+ verifySplittableMethods(signature, errors);
+ } else {
+ verifyUnsplittableMethods(errors, signature);
+ }
+
+ return signature;
+ }
+
+  /**
+   * Infers the boundedness of the {@link DoFn.ProcessElement} method (whether or not it performs a
+   * bounded amount of work per element) using the following criteria:
+   *
+   * <ol>
+   *   <li>If the {@link DoFn} is not splittable, then it is bounded, it must not be annotated as
+   *       {@link DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, and {@link
+   *       DoFn.ProcessElement} must return {@code void}.
+   *   <li>If the {@link DoFn} (or any of its supertypes) is annotated as {@link
+   *       DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, use that. The annotations
+   *       must not conflict: no single type may carry both, and no two types in the hierarchy may
+   *       carry different ones. Repeating the same annotation on a subtype is accepted.
+   *   <li>If {@link DoFn.ProcessElement} returns {@link DoFn.ProcessContinuation}, assume it is
+   *       unbounded. Otherwise (if it returns {@code void}), assume it is bounded.
+   *   <li>If {@link DoFn.ProcessElement} returns {@code void}, but the {@link DoFn} is annotated
+   *       {@link DoFn.UnboundedPerElement}, this is an error. NOTE(review): this rule is not
+   *       enforced below; either add the check or drop the rule.
+   * </ol>
+   *
+   * <p>Violations are reported through {@code errors}, attributed to the {@link
+   * DoFn.ProcessElement} method where the problem lies in its signature.
+   */
+  private static PCollection.IsBounded inferBoundedness(
+      TypeToken<? extends DoFn> fnToken,
+      DoFnSignature.ProcessElementMethod processElement,
+      ErrorReporter errors) {
+    PCollection.IsBounded isBounded = null;
+    for (TypeToken<?> supertype : fnToken.getTypes()) {
+      Class<?> rawType = supertype.getRawType();
+      boolean annotatedBounded = rawType.isAnnotationPresent(DoFn.BoundedPerElement.class);
+      boolean annotatedUnbounded = rawType.isAnnotationPresent(DoFn.UnboundedPerElement.class);
+      if (!annotatedBounded && !annotatedUnbounded) {
+        continue;
+      }
+      PCollection.IsBounded declared =
+          annotatedBounded ? PCollection.IsBounded.BOUNDED : PCollection.IsBounded.UNBOUNDED;
+      // Reject a type carrying both annotations, and a hierarchy whose types disagree. Agreeing
+      // annotations (e.g. a subclass repeating its superclass's annotation) are accepted.
+      errors.checkArgument(
+          !(annotatedBounded && annotatedUnbounded)
+              && (isBounded == null || isBounded == declared),
+          "Both @%s and @%s specified",
+          DoFn.BoundedPerElement.class.getSimpleName(),
+          DoFn.UnboundedPerElement.class.getSimpleName());
+      isBounded = declared;
+    }
+    if (processElement.isSplittable()) {
+      if (isBounded == null) {
+        isBounded =
+            processElement.hasReturnValue()
+                ? PCollection.IsBounded.UNBOUNDED
+                : PCollection.IsBounded.BOUNDED;
+      }
+    } else {
+      errors.checkArgument(
+          isBounded == null,
+          "Non-splittable, but annotated as @"
+              + ((isBounded == PCollection.IsBounded.BOUNDED)
+                  ? DoFn.BoundedPerElement.class.getSimpleName()
+                  : DoFn.UnboundedPerElement.class.getSimpleName()));
+      // A @ProcessElement method may return ProcessContinuation without taking a
+      // RestrictionTracker, which leaves it non-splittable. Report that as a user error on the
+      // method rather than failing the internal state check below.
+      errors
+          .forMethod(DoFn.ProcessElement.class, processElement.targetMethod())
+          .checkArgument(
+              !processElement.hasReturnValue(),
+              "Returns %s, but is not splittable: has no %s parameter",
+              DoFn.ProcessContinuation.class.getSimpleName(),
+              RestrictionTracker.class.getSimpleName());
+      // Internal invariant, guaranteed by the check above.
+      checkState(!processElement.hasReturnValue(), "Should have been inferred splittable");
+      isBounded = PCollection.IsBounded.BOUNDED;
+    }
+    return isBounded;
+  }
+
+ /**
+ * Verifies properties related to methods of splittable {@link DoFn}:
+ *
+ * <ul>
+ * <li>Must declare the required {@link DoFn.GetInitialRestriction} and {@link DoFn.NewTracker}
+ * methods; {@link DoFn.SplitRestriction} and {@link DoFn.GetRestrictionCoder} are optional.
+ * <li>Tracker types must match exactly between {@link DoFn.ProcessElement} and {@link
+ * DoFn.NewTracker}.
+ * <li>Restriction types must match exactly between {@link DoFn.GetInitialRestriction}, {@link
+ * DoFn.NewTracker} and (if present) {@link DoFn.SplitRestriction}.
+ * <li>If present, {@link DoFn.GetRestrictionCoder} must return a subtype of {@code
+ * Coder<RestrictionT>}.
+ * </ul>
+ */
+ private static void verifySplittableMethods(DoFnSignature signature, ErrorReporter errors) {
+ DoFnSignature.ProcessElementMethod processElement = signature.processElement();
+ DoFnSignature.GetInitialRestrictionMethod getInitialRestriction =
+ signature.getInitialRestriction();
+ DoFnSignature.NewTrackerMethod newTracker = signature.newTracker();
+ DoFnSignature.GetRestrictionCoderMethod getRestrictionCoder = signature.getRestrictionCoder();
+ DoFnSignature.SplitRestrictionMethod splitRestriction = signature.splitRestriction();
+
+ ErrorReporter processElementErrors =
+ errors.forMethod(DoFn.ProcessElement.class, processElement.targetMethod());
+
+ // Collect all missing required methods so a single error lists every one of them.
+ List<String> missingRequiredMethods = new ArrayList<>();
+ if (getInitialRestriction == null) {
+ missingRequiredMethods.add("@" + DoFn.GetInitialRestriction.class.getSimpleName());
+ }
+ if (newTracker == null) {
+ missingRequiredMethods.add("@" + DoFn.NewTracker.class.getSimpleName());
+ }
+ if (!missingRequiredMethods.isEmpty()) {
+ processElementErrors.throwIllegalArgument(
+ "Splittable, but does not define the following required methods: %s",
+ missingRequiredMethods);
+ }
+ // throwIllegalArgument is expected not to return, so getInitialRestriction and newTracker are
+ // dereferenced unconditionally below.
+
+ processElementErrors.checkArgument(
+ processElement.trackerT().equals(newTracker.trackerT()),
+ "Has tracker type %s, but @%s method %s uses tracker type %s",
+ formatType(processElement.trackerT()),
+ DoFn.NewTracker.class.getSimpleName(),
+ format(newTracker.targetMethod()),
+ formatType(newTracker.trackerT()));
+
+ ErrorReporter getInitialRestrictionErrors =
+ errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestriction.targetMethod());
+ // The restriction type returned by @GetInitialRestriction is the reference all other
+ // restriction-related methods are checked against.
+ TypeToken<?> restrictionT = getInitialRestriction.restrictionT();
+
+ getInitialRestrictionErrors.checkArgument(
+ restrictionT.equals(newTracker.restrictionT()),
+ "Uses restriction type %s, but @%s method %s uses restriction type %s",
+ formatType(restrictionT),
+ DoFn.NewTracker.class.getSimpleName(),
+ format(newTracker.targetMethod()),
+ formatType(newTracker.restrictionT()));
+
+ if (getRestrictionCoder != null) {
+ getInitialRestrictionErrors.checkArgument(
+ getRestrictionCoder.coderT().isSubtypeOf(coderTypeOf(restrictionT)),
+ "Uses restriction type %s, but @%s method %s returns %s "
+ + "which is not a subtype of %s",
+ formatType(restrictionT),
+ DoFn.GetRestrictionCoder.class.getSimpleName(),
+ format(getRestrictionCoder.targetMethod()),
+ formatType(getRestrictionCoder.coderT()),
+ formatType(coderTypeOf(restrictionT)));
+ }
+
+ if (splitRestriction != null) {
+ getInitialRestrictionErrors.checkArgument(
+ splitRestriction.restrictionT().equals(restrictionT),
+ "Uses restriction type %s, but @%s method %s uses restriction type %s",
+ formatType(restrictionT),
+ DoFn.SplitRestriction.class.getSimpleName(),
+ format(splitRestriction.targetMethod()),
+ formatType(splitRestriction.restrictionT()));
+ }
+ }
+
+ /**
+ * Verifies that a non-splittable {@link DoFn} does not declare any methods that only make sense
+ * for splittable {@link DoFn}: {@link DoFn.GetInitialRestriction}, {@link DoFn.SplitRestriction},
+ * {@link DoFn.NewTracker}, {@link DoFn.GetRestrictionCoder}. All offending methods are reported
+ * together in a single error.
+ */
+ private static void verifyUnsplittableMethods(ErrorReporter errors, DoFnSignature signature) {
+ // Each non-null accessor means the DoFn declared the corresponding annotated method.
+ List<String> forbiddenMethods = new ArrayList<>();
+ if (signature.getInitialRestriction() != null) {
+ forbiddenMethods.add("@" + DoFn.GetInitialRestriction.class.getSimpleName());
+ }
+ if (signature.splitRestriction() != null) {
+ forbiddenMethods.add("@" + DoFn.SplitRestriction.class.getSimpleName());
+ }
+ if (signature.newTracker() != null) {
+ forbiddenMethods.add("@" + DoFn.NewTracker.class.getSimpleName());
+ }
+ if (signature.getRestrictionCoder() != null) {
+ forbiddenMethods.add("@" + DoFn.GetRestrictionCoder.class.getSimpleName());
+ }
+ errors.checkArgument(
+ forbiddenMethods.isEmpty(), "Non-splittable, but defines methods: %s", forbiddenMethods);
}
/**
@@ -166,7 +385,11 @@ public class DoFnSignatures {
Method m,
TypeToken<?> inputT,
TypeToken<?> outputT) {
- errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void");
+ errors.checkArgument(
+ void.class.equals(m.getReturnType())
+ || DoFn.ProcessContinuation.class.equals(m.getReturnType()),
+ "Must return void or %s",
+ DoFn.ProcessContinuation.class.getSimpleName());
TypeToken<?> processContextToken = doFnProcessContextTypeOf(inputT, outputT);
@@ -181,6 +404,7 @@ public class DoFnSignatures {
formatType(processContextToken));
List<DoFnSignature.Parameter> extraParameters = new ArrayList<>();
+ TypeToken<?> trackerT = null;
TypeToken<?> expectedInputProviderT = inputProviderTypeOf(inputT);
TypeToken<?> expectedOutputReceiverT = outputReceiverTypeOf(outputT);
@@ -190,38 +414,62 @@ public class DoFnSignatures {
if (rawType.equals(BoundedWindow.class)) {
errors.checkArgument(
!extraParameters.contains(DoFnSignature.Parameter.BOUNDED_WINDOW),
- "Multiple BoundedWindow parameters");
+ "Multiple %s parameters",
+ BoundedWindow.class.getSimpleName());
extraParameters.add(DoFnSignature.Parameter.BOUNDED_WINDOW);
} else if (rawType.equals(DoFn.InputProvider.class)) {
errors.checkArgument(
!extraParameters.contains(DoFnSignature.Parameter.INPUT_PROVIDER),
- "Multiple InputProvider parameters");
+ "Multiple %s parameters",
+ DoFn.InputProvider.class.getSimpleName());
errors.checkArgument(
paramT.equals(expectedInputProviderT),
- "Wrong type of InputProvider parameter: %s, should be %s",
+ "Wrong type of %s parameter: %s, should be %s",
+ DoFn.InputProvider.class.getSimpleName(),
formatType(paramT),
formatType(expectedInputProviderT));
extraParameters.add(DoFnSignature.Parameter.INPUT_PROVIDER);
} else if (rawType.equals(DoFn.OutputReceiver.class)) {
errors.checkArgument(
!extraParameters.contains(DoFnSignature.Parameter.OUTPUT_RECEIVER),
- "Multiple OutputReceiver parameters");
+ "Multiple %s parameters",
+ DoFn.OutputReceiver.class.getSimpleName());
errors.checkArgument(
paramT.equals(expectedOutputReceiverT),
- "Wrong type of OutputReceiver parameter: %s, should be %s",
+ "Wrong type of %s parameter: %s, should be %s",
+ DoFn.OutputReceiver.class.getSimpleName(),
formatType(paramT),
formatType(expectedOutputReceiverT));
extraParameters.add(DoFnSignature.Parameter.OUTPUT_RECEIVER);
+ } else if (RestrictionTracker.class.isAssignableFrom(rawType)) {
+ errors.checkArgument(
+ !extraParameters.contains(DoFnSignature.Parameter.RESTRICTION_TRACKER),
+ "Multiple %s parameters",
+ RestrictionTracker.class.getSimpleName());
+ extraParameters.add(DoFnSignature.Parameter.RESTRICTION_TRACKER);
+ trackerT = paramT;
} else {
List<String> allowedParamTypes =
- Arrays.asList(formatType(new TypeToken<BoundedWindow>() {}));
+ Arrays.asList(
+ formatType(new TypeToken<BoundedWindow>() {}),
+ formatType(new TypeToken<RestrictionTracker<?>>() {}));
errors.throwIllegalArgument(
"%s is not a valid context parameter. Should be one of %s",
formatType(paramT), allowedParamTypes);
}
}
- return DoFnSignature.ProcessElementMethod.create(m, extraParameters);
+ // A splittable DoFn can not have any other extra context parameters.
+ if (extraParameters.contains(DoFnSignature.Parameter.RESTRICTION_TRACKER)) {
+ errors.checkArgument(
+ extraParameters.size() == 1,
+ "Splittable DoFn must not have any extra context arguments apart from %s, but has: %s",
+ trackerT,
+ extraParameters);
+ }
+
+ return DoFnSignature.ProcessElementMethod.create(
+ m, extraParameters, trackerT, DoFn.ProcessContinuation.class.equals(m.getReturnType()));
}
@VisibleForTesting
@@ -248,6 +496,100 @@ public class DoFnSignatures {
return DoFnSignature.LifecycleMethod.create(m);
}
+ /**
+ * Validates a {@link DoFn.GetInitialRestriction} method: it must take exactly one argument whose
+ * resolved type equals {@code inputT}. Its resolved return type becomes the restriction type.
+ */
+ @VisibleForTesting
+ static DoFnSignature.GetInitialRestrictionMethod analyzeGetInitialRestrictionMethod(
+ ErrorReporter errors, TypeToken<? extends DoFn> fnToken, Method m, TypeToken<?> inputT) {
+ // Method is of the form:
+ // @GetInitialRestriction
+ // RestrictionT getInitialRestriction(InputT element);
+ Type[] params = m.getGenericParameterTypes();
+ // && short-circuits, so params[0] is only read when there is exactly one parameter.
+ errors.checkArgument(
+ params.length == 1 && fnToken.resolveType(params[0]).equals(inputT),
+ "Must take a single argument of type %s",
+ formatType(inputT));
+ // The return type is not validated here; its agreement with the other restriction-related
+ // methods is checked in verifySplittableMethods.
+ return DoFnSignature.GetInitialRestrictionMethod.create(
+ m, fnToken.resolveType(m.getGenericReturnType()));
+ }
+
+ /** Generates a type token for {@code List<T>} given {@code T}. */
+ // NOTE(review): not referenced by any other line of this patch; remove if it stays unused.
+ private static <T> TypeToken<List<T>> listTypeOf(TypeToken<T> elementT) {
+ return new TypeToken<List<T>>() {}.where(new TypeParameter<T>() {}, elementT);
+ }
+
+ /**
+ * Validates a {@link DoFn.SplitRestriction} method: it must return void and take the element,
+ * a restriction, and a {@link DoFn.OutputReceiver} of that same restriction type.
+ */
+ @VisibleForTesting
+ static DoFnSignature.SplitRestrictionMethod analyzeSplitRestrictionMethod(
+ ErrorReporter errors, TypeToken<? extends DoFn> fnToken, Method m, TypeToken<?> inputT) {
+ // Method is of the form:
+ // @SplitRestriction
+ // void splitRestriction(
+ //     InputT element, RestrictionT restriction, OutputReceiver<RestrictionT> receiver);
+ errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void");
+
+ Type[] params = m.getGenericParameterTypes();
+ // Relies on checkArgument throwing when the count is wrong; params[1] and params[2] are
+ // indexed unconditionally below.
+ errors.checkArgument(params.length == 3, "Must have exactly 3 arguments");
+ errors.checkArgument(
+ fnToken.resolveType(params[0]).equals(inputT),
+ "First argument must be the element type %s",
+ formatType(inputT));
+
+ // The restriction type is taken from the declaration here; its agreement with
+ // @GetInitialRestriction is checked in verifySplittableMethods.
+ TypeToken<?> restrictionT = fnToken.resolveType(params[1]);
+ TypeToken<?> receiverT = fnToken.resolveType(params[2]);
+ TypeToken<?> expectedReceiverT = outputReceiverTypeOf(restrictionT);
+ errors.checkArgument(
+ receiverT.equals(expectedReceiverT),
+ "Third argument must be %s, but is %s",
+ formatType(expectedReceiverT),
+ formatType(receiverT));
+
+ return DoFnSignature.SplitRestrictionMethod.create(m, restrictionT);
+ }
+
+ /**
+ * Generates a type token for {@code Coder<T>} given {@code T}. Used, for example, to check that a
+ * {@link DoFn.GetRestrictionCoder} method returns a coder for the restriction type.
+ */
+ private static <T> TypeToken<Coder<T>> coderTypeOf(TypeToken<T> elementT) {
+ return new TypeToken<Coder<T>>() {}.where(new TypeParameter<T>() {}, elementT);
+ }
+
+ /**
+ * Validates a {@link DoFn.GetRestrictionCoder} method: it must take no arguments and return a
+ * subtype of {@link Coder}. Whether it matches the restriction type is checked separately.
+ */
+ @VisibleForTesting
+ static DoFnSignature.GetRestrictionCoderMethod analyzeGetRestrictionCoderMethod(
+ ErrorReporter errors, TypeToken<? extends DoFn> fnToken, Method m) {
+ // Method is of the form:
+ // @GetRestrictionCoder
+ // CoderT getRestrictionCoder();
+ errors.checkArgument(m.getParameterTypes().length == 0, "Must have zero arguments");
+ TypeToken<?> resT = fnToken.resolveType(m.getGenericReturnType());
+ errors.checkArgument(
+ resT.isSubtypeOf(TypeToken.of(Coder.class)),
+ "Must return a Coder, but returns %s",
+ formatType(resT));
+ return DoFnSignature.GetRestrictionCoderMethod.create(m, resT);
+ }
+
+ /**
+ * Generates a type token for {@code RestrictionTracker<RestrictionT>} given {@code RestrictionT}.
+ * Used to validate the return type of a {@link DoFn.NewTracker} method.
+ */
+ private static <RestrictionT>
+ TypeToken<RestrictionTracker<RestrictionT>> restrictionTrackerTypeOf(
+ TypeToken<RestrictionT> restrictionT) {
+ return new TypeToken<RestrictionTracker<RestrictionT>>() {}.where(
+ new TypeParameter<RestrictionT>() {}, restrictionT);
+ }
+
+ /**
+ * Validates a {@link DoFn.NewTracker} method: it must take a single restriction argument and
+ * return a subtype of {@code RestrictionTracker<RestrictionT>} for that restriction type.
+ */
+ @VisibleForTesting
+ static DoFnSignature.NewTrackerMethod analyzeNewTrackerMethod(
+ ErrorReporter errors, TypeToken<? extends DoFn> fnToken, Method m) {
+ // Method is of the form:
+ // @NewTracker
+ // TrackerT newTracker(RestrictionT restriction);
+ Type[] params = m.getGenericParameterTypes();
+ errors.checkArgument(params.length == 1, "Must have a single argument");
+
+ // Agreement of restrictionT with @GetInitialRestriction, and of trackerT with the
+ // @ProcessElement tracker parameter, is checked in verifySplittableMethods.
+ TypeToken<?> restrictionT = fnToken.resolveType(params[0]);
+ TypeToken<?> trackerT = fnToken.resolveType(m.getGenericReturnType());
+ TypeToken<?> expectedTrackerT = restrictionTrackerTypeOf(restrictionT);
+ errors.checkArgument(
+ trackerT.isSubtypeOf(expectedTrackerT),
+ "Returns %s, but must return a subtype of %s",
+ formatType(trackerT),
+ formatType(expectedTrackerT));
+ return DoFnSignature.NewTrackerMethod.create(m, restrictionT, trackerT);
+ }
+
private static Collection<Method> declaredMethodsWithAnnotation(
Class<? extends Annotation> anno, Class<?> startClass, Class<?> stopClass) {
Collection<Method> matches = new ArrayList<>();
@@ -310,7 +652,7 @@ public class DoFnSignatures {
}
+ /**
+ * Formats {@code method} for inclusion in error messages, using {@link
+ * ReflectHelpers#METHOD_FORMATTER}.
+ */
private static String format(Method method) {
- return ReflectHelpers.CLASS_AND_METHOD_FORMATTER.apply(method);
+ return ReflectHelpers.METHOD_FORMATTER.apply(method);
}
private static String formatType(TypeToken<?> t) {
@@ -327,7 +669,9 @@ public class DoFnSignatures {
+ /**
+ * Returns a child reporter for errors about {@code method}, labelled {@code "@Annotation
+ * <method>"} using the annotation's simple name, or {@code "@Annotation (absent)"} when {@code
+ * method} is {@code null}.
+ */
ErrorReporter forMethod(Class<? extends Annotation> annotation, Method method) {
return new ErrorReporter(
this,
- String.format("@%s %s", annotation, (method == null) ? "(absent)" : format(method)));
+ String.format(
+ "@%s %s",
+ annotation.getSimpleName(), (method == null) ? "(absent)" : format(method)));
}
void throwIllegalArgument(String message, Object... args) {
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java
new file mode 100644
index 0000000..6b249ee
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms.splittabledofn;
+
+import org.apache.beam.sdk.transforms.DoFn;
+
+/**
+ * Manages concurrent access to the restriction and keeps track of its claimed part for a <a
+ * href="https://s.apache.org/splittable-do-fn">splittable</a> {@link DoFn}.
+ *
+ * @param <RestrictionT> type of the restriction being tracked
+ */
+public interface RestrictionTracker<RestrictionT> {
+ /**
+ * Returns a restriction accurately describing the full range of work the current {@link
+ * DoFn.ProcessElement} call will do, including already completed work.
+ */
+ RestrictionT currentRestriction();
+
+ /**
+ * Signals that the current {@link DoFn.ProcessElement} call should terminate as soon as possible.
+ * Modifies {@link #currentRestriction}. Returns a restriction representing the rest of the work:
+ * the old value of {@link #currentRestriction} is equivalent to the new value and the return
+ * value of this method combined.
+ */
+ RestrictionT checkpoint();
+
+ // TODO: Add the more general splitRemainderAfterFraction() and other methods.
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/package-info.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/package-info.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/package-info.java
new file mode 100644
index 0000000..1ceb880
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Defines utilities related to <a href="https://s.apache.org/splittable-do-fn">splittable</a>
+ * {@link org.apache.beam.sdk.transforms.DoFn}'s.
+ */
+package org.apache.beam.sdk.transforms.splittabledofn;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/KvCoderTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/KvCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/KvCoderTest.java
index f0f7d22..436e227 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/KvCoderTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/KvCoderTest.java
@@ -17,11 +17,9 @@
*/
package org.apache.beam.sdk.coders;
-import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.Map;
import org.apache.beam.sdk.testing.CoderProperties;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.KV;
@@ -31,40 +29,55 @@ import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-/**
- * Test case for {@link KvCoder}.
- */
+/** Test case for {@link KvCoder}. */
@RunWith(JUnit4.class)
public class KvCoderTest {
+ /** A coder together with sample values of the type it encodes. */
+ private static class CoderAndData<T> {
+ Coder<T> coder;
+ List<T> data;
+ }
+
+ /**
+ * Non-generic holder for a {@link CoderAndData} of some element type, so that entries with
+ * different element types can share one {@code List<AnyCoderAndData>}.
+ */
+ private static class AnyCoderAndData {
+ private CoderAndData<?> coderAndData;
+ }
- private static final Map<Coder<?>, Iterable<?>> TEST_DATA =
- new ImmutableMap.Builder<Coder<?>, Iterable<?>>()
- .put(VarIntCoder.of(),
- Arrays.asList(-1, 0, 1, 13, Integer.MAX_VALUE, Integer.MIN_VALUE))
- .put(BigEndianLongCoder.of(),
- Arrays.asList(-1L, 0L, 1L, 13L, Long.MAX_VALUE, Long.MIN_VALUE))
- .put(StringUtf8Coder.of(),
- Arrays.asList("", "hello", "goodbye", "1"))
- .put(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()),
- Arrays.asList(KV.of("", -1), KV.of("hello", 0), KV.of("goodbye", Integer.MAX_VALUE)))
- .put(ListCoder.of(VarLongCoder.of()),
- Arrays.asList(
- Arrays.asList(1L, 2L, 3L),
- Collections.emptyList()))
- .build();
+ /**
+ * Pairs {@code coder} with {@code data}; the shared type parameter makes the compiler check that
+ * the values match the coder's type before it is erased into an {@link AnyCoderAndData}.
+ */
+ private static <T> AnyCoderAndData coderAndData(Coder<T> coder, List<T> data) {
+ CoderAndData<T> coderAndData = new CoderAndData<>();
+ coderAndData.coder = coder;
+ coderAndData.data = data;
+ AnyCoderAndData res = new AnyCoderAndData();
+ res.coderAndData = coderAndData;
+ return res;
+ }
+
+ /**
+ * Sample coders with matching values. {@link #testDecodeEncodeEqual} uses every combination of
+ * two entries as the key and value components of a {@link KvCoder}.
+ */
+ private static final List<AnyCoderAndData> TEST_DATA =
+ Arrays.asList(
+ coderAndData(
+ VarIntCoder.of(), Arrays.asList(-1, 0, 1, 13, Integer.MAX_VALUE, Integer.MIN_VALUE)),
+ coderAndData(
+ BigEndianLongCoder.of(),
+ Arrays.asList(-1L, 0L, 1L, 13L, Long.MAX_VALUE, Long.MIN_VALUE)),
+ coderAndData(StringUtf8Coder.of(), Arrays.asList("", "hello", "goodbye", "1")),
+ coderAndData(
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()),
+ Arrays.asList(KV.of("", -1), KV.of("hello", 0), KV.of("goodbye", Integer.MAX_VALUE))),
+ coderAndData(
+ ListCoder.of(VarLongCoder.of()),
+ Arrays.asList(Arrays.asList(1L, 2L, 3L), Collections.<Long>emptyList())));
+ /**
+ * Runs {@code CoderProperties.coderDecodeEncodeEqual} on a {@link KvCoder} for every key/value
+ * combination drawn from {@link #TEST_DATA}, including nested and mixed coder types.
+ */
@Test
+ @SuppressWarnings("rawtypes")
public void testDecodeEncodeEqual() throws Exception {
- for (Map.Entry<Coder<?>, Iterable<?>> entry : TEST_DATA.entrySet()) {
- // The coder and corresponding values must be the same type.
- // If someone messes this up in the above test data, the test
- // will fail anyhow (unless the coder magically works on data
- // it does not understand).
- @SuppressWarnings("unchecked")
- Coder<Object> coder = (Coder<Object>) entry.getKey();
- Iterable<?> values = entry.getValue();
- for (Object value : values) {
- CoderProperties.coderDecodeEncodeEqual(coder, value);
+ // Raw Coder types are used because each TEST_DATA entry has its own element type; keys and
+ // values are therefore handled as Object. coderAndData() ties each coder to its values.
+ for (AnyCoderAndData keyCoderAndData : TEST_DATA) {
+ Coder keyCoder = keyCoderAndData.coderAndData.coder;
+ for (Object key : keyCoderAndData.coderAndData.data) {
+ for (AnyCoderAndData valueCoderAndData : TEST_DATA) {
+ Coder valueCoder = valueCoderAndData.coderAndData.coder;
+ for (Object value : valueCoderAndData.coderAndData.data) {
+ CoderProperties.coderDecodeEncodeEqual(
+ KvCoder.of(keyCoder, valueCoder), KV.of(key, value));
+ }
+ }
}
}
}
@@ -75,37 +88,29 @@ public class KvCoderTest {
+ /**
+ * Checks via {@code CoderProperties.coderHasEncodingId} that a {@link KvCoder} of two
+ * {@link VarIntCoder}s has {@code EXPECTED_ENCODING_ID}.
+ */
@Test
public void testEncodingId() throws Exception {
CoderProperties.coderHasEncodingId(
- KvCoder.of(VarIntCoder.of(), VarIntCoder.of()),
- EXPECTED_ENCODING_ID);
+ KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), EXPECTED_ENCODING_ID);
}
- /**
- * Homogeneously typed test value for ease of use with the wire format test utility.
- */
+ /** Homogeneously typed test value for ease of use with the wire format test utility. */
private static final Coder<KV<String, Integer>> TEST_CODER =
KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of());
- private static final List<KV<String, Integer>> TEST_VALUES = Arrays.asList(
- KV.of("", -1),
- KV.of("hello", 0),
- KV.of("goodbye", Integer.MAX_VALUE));
+ /**
+ * Values encoded with {@link #TEST_CODER} in {@link #testWireFormatEncode}; presumably paired
+ * by position with {@link #TEST_ENCODINGS}.
+ */
+ private static final List<KV<String, Integer>> TEST_VALUES =
+ Arrays.asList(KV.of("", -1), KV.of("hello", 0), KV.of("goodbye", Integer.MAX_VALUE));
/**
- * Generated data to check that the wire format has not changed. To regenerate, see
- * {@link org.apache.beam.sdk.coders.PrintBase64Encodings}.
+ * Generated data to check that the wire format has not changed. To regenerate, see {@link
+ * org.apache.beam.sdk.coders.PrintBase64Encodings}.
*/
- private static final List<String> TEST_ENCODINGS = Arrays.asList(
- "AP____8P",
- "BWhlbGxvAA",
- "B2dvb2RieWX_____Bw");
+ private static final List<String> TEST_ENCODINGS =
+ Arrays.asList("AP____8P", "BWhlbGxvAA", "B2dvb2RieWX_____Bw");
+ /**
+ * Checks that {@link #TEST_CODER} encodes {@link #TEST_VALUES} to the base64 strings in
+ * {@link #TEST_ENCODINGS}, guarding against wire-format changes.
+ */
@Test
public void testWireFormatEncode() throws Exception {
CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS);
}
- @Rule
- public ExpectedException thrown = ExpectedException.none();
+ /** Lets an individual test declare the exception it expects; by default none is expected. */
+ @Rule public ExpectedException thrown = ExpectedException.none();
@Test
public void encodeNullThrowsCoderException() throws Exception {
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 7ce98bc..9c7b991 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -56,6 +56,7 @@ import org.apache.beam.sdk.transforms.ParDo.Bound;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.DisplayData.Builder;
import org.apache.beam.sdk.transforms.display.DisplayDataMatchers;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
@@ -1469,4 +1470,49 @@ public class ParDoTest implements Serializable {
assertThat(displayData, includesDisplayDataFrom(fn));
assertThat(displayData, hasDisplayItem("fn", fn.getClass()));
}
+
+ /**
+ * Tracker type used only in {@link TestSplittableDoFn}'s signatures. Being abstract, it needs no
+ * {@link RestrictionTracker} method implementations.
+ */
+ private abstract static class SomeTracker implements RestrictionTracker<Object> {}
+ /**
+ * Minimal DoFn that is treated as splittable. Its {@code @ProcessElement} method takes a
+ * {@link RestrictionTracker} subtype, and it declares {@code @GetInitialRestriction} and
+ * {@code @NewTracker} methods.
+ *
+ * <p>The method bodies are stubs (they return {@code null}). The tests using this DoFn expect
+ * the {@code apply} call itself to fail, so they never run a pipeline with it.
+ */
+ private static class TestSplittableDoFn extends DoFn<Integer, String> {
+ @ProcessElement
+ public void processElement(ProcessContext context, SomeTracker tracker) {}
+
+ @GetInitialRestriction
+ public Object getInitialRestriction(Integer element) {
+ return null;
+ }
+
+ @NewTracker
+ public SomeTracker newTracker(Object restriction) {
+ return null;
+ }
+ }
+
+ /**
+ * Applying a single-output {@link ParDo} of a splittable DoFn, with no runner override, must
+ * fail with an {@link IllegalArgumentException}.
+ */
+ @Test
+ public void testRejectsSplittableDoFnByDefault() {
+ // ParDo with a splittable DoFn must be overridden by the runner.
+ // Without an override, applying it directly must fail.
+ Pipeline p = TestPipeline.create();
+
+ // Expectations must be registered before the call that is expected to throw.
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("Splittable DoFn not supported by the current runner");
+
+ p.apply(Create.of(1, 2, 3)).apply(ParDo.of(new TestSplittableDoFn()));
+ }
+
+ /**
+ * Same as {@link #testRejectsSplittableDoFnByDefault}, but for the multi-output form
+ * ({@code withOutputTags}) of a {@link ParDo} of a splittable DoFn.
+ */
+ @Test
+ public void testMultiRejectsSplittableDoFnByDefault() {
+ // ParDo with a splittable DoFn must be overridden by the runner.
+ // Without an override, applying it directly must fail.
+ Pipeline p = TestPipeline.create();
+
+ // Expectations must be registered before the call that is expected to throw.
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("Splittable DoFn not supported by the current runner");
+
+ // Anonymous TupleTag subclasses ({}) are used for the main and side output tags.
+ p.apply(Create.of(1, 2, 3))
+ .apply(
+ ParDo.of(new TestSplittableDoFn())
+ .withOutputTags(
+ new TupleTag<String>("main") {},
+ TupleTagList.of(new TupleTag<String>("side1") {})));
+ }
}